余弦相似度算法

概述:

用向量空间中两个向量夹角的余弦值作为衡量两个个体间差异的大小的度量。余弦值越接近1,就表明夹角越接近0度,也就是两个向量越相似,这就叫"余弦相似性"。

实现思路:

  1. 分词:对需要进比较的文本进行分词,获取出现的词和词频
  2. 汇总:对出现的词进行汇总去重,作为向量的维度
  3. 统计:每个文本以出现的词作为维度,词频作为维度值,没有的以0填充,构建向量
  4. 计算:
cos(\theta) = \frac{\sum_{i=1}^n(X_i × Y_i)}{\sqrt{\sum_{i=1}^n(X_i)^2} × {\sqrt{\sum_{i=1}^n(Y_i)^2}}}

举个简单的例子:

  1. 准备比较文本

    文本一:黑灰化肥灰会挥发发灰黑讳为黑灰花会飞

    文本二:灰黑化肥会会挥发发黑灰为讳飞花化为灰

  2. 分词并统计词频

    文本一:“灰”:4,“挥发”:1,“黑”:3,“发”:1,“讳”:1,“化肥”:1,“花会”:1,“会”:1,“为”:1,“飞”:1

    文本二:“灰”:3,“挥发”:1,“飞花”:1,“黑”:2,“发”:1,“讳”:1,“化肥”:1,“会”:2,“为”:1,“化为”:1

  3. 统计所有词

    所有词:“灰”,“挥发”,“飞花”,“黑”,“发”,“讳”,“化肥”,“花会”,“会”,“为”,“飞”,“化为”

    文本一:(4,1,0,3,1,0,1,1,1,1,0,1,0)

    文本二:(3,1,1,2,1,1,1,2,1,0,1,1,1)

  4. 计算

    x = \frac{4*3+1.1+0*1+...+0*1+1*1+0*1}{\sqrt{4^2+1^2+0^2+...+0^2+1^2+0^2}×\sqrt{3^2+1^2+1^2+...+1^2+1^2+1^2}} = 0.7795794428691074

引入相关依赖

<!--汉语言包,主要用于分词-->
<dependency>
    <groupId>org.ansj</groupId>
    <artifactId>ansj_seg</artifactId>
    <version>5.1.6</version>
</dependency>
<!-- json序列化 -->
<dependency>
    <groupId>com.google.code.gson</groupId>
    <artifactId>gson</artifactId>
    <version>2.8.9</version>
</dependency>

代码部分

import com.google.gson.Gson;
import org.ansj.recognition.impl.StopRecognition;
import org.ansj.splitWord.analysis.ToAnalysis;

import java.util.*;

public static void main(String[] args) {
    String str1 = "黑灰化肥灰会挥发发灰黑讳为黑灰花会飞";
    String str2 = "灰黑化肥会会挥发发黑灰为讳飞花化为灰";
    StopRecognition filter = new StopRecognition();
    // 过滤掉标点
    filter.insertStopNatures("w");
    // 分词-统计词频
    Map<String,Integer> map1= new HashMap<>();
    ToAnalysis.parse(str1).recognition(filter).forEach(item -> {
        // 没有则赋初始值,有则+1
        map1.put(item.getName(),map1.getOrDefault(item.getName(),0)+1);
    });
    Map<String,Integer> map2 = new HashMap<>();
    ToAnalysis.parse(str2).recognition(filter).forEach(item -> {
        // 没有则赋初始值,有则+1
        map2.put(item.getName(),map2.getOrDefault(item.getName(),0)+1);
    });
    Gson gson = new Gson();
    System.out.println("map1="+ gson.toJson(map1));
    System.out.println("map1="+ gson.toJson(map2));
    // 将分词存放到集合中进行汇总并去重
    Set<String> set1 = map1.keySet();
    Set<String> set2 = map2.keySet();
    Set<String> setAll = new HashSet<>();
    setAll.addAll(set1);
    setAll.addAll(set2);
    System.out.println("all="+gson.toJson(setAll));
    // 创建两个列表用于存储每个字符串中分词出现的次数
    List<Integer> list1 = new ArrayList<>(setAll.size());
    List<Integer> list2 = new ArrayList<>(setAll.size());
    //构建向量
    setAll.forEach(item ->{
        if (set1.contains(item)){
            list1.add(map1.get(item));
        }else {
            list1.add(0);
        }

        if (set2.contains(item)){
            list2.add(map2.get(item));
        }else {
            list2.add(0);
        }
    });
    //计算余弦相似度
    int sum =0;
    long sq1 = 0;
    long sq2 = 0;
    double result = 0;
    for (int i =0;i<setAll.size();i++){
        sum +=list1.get(i)*list2.get(i);
        sq1 += list1.get(i)*list1.get(i);
        sq2 += list2.get(i)*list2.get(i);
    }
    result = sum/(Math.sqrt(sq1)*Math.sqrt(sq2));
    System.out.println("余弦相似度="+result);
}