123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201 |
- import json
- import pickle
- import requests
- import numpy as np
- from fuzzywuzzy import fuzz
- from sentence_transformers import util
- from pprint import pprint
- import config
- from formula_process import formula_recognize
class HNSW:
    """Duplicate-question lookup over images, formulas and question text.

    * ``img_retrieve``     -- forwards an image URL to an external image-search service.
    * ``formula_retrieve`` -- bag-of-words cosine pre-filter + fuzzy string ratio over a
      local formula index.
    * ``retrieve``         -- sentence-vector search through the HNSW HTTP service, re-scored
      against vectors stored in MongoDB (items of different subjects may be mixed).
    """

    # Hard-coded result overrides carried over from the original code: when the top
    # semantic hit is one of these ids, the semantic results are replaced by the listed
    # (id, score) pairs and the text results are cleared.
    # NOTE(review): the reason for these specific ids is not visible here -- confirm the
    # special case is still required.
    _FIXED_SAMPLE_OVERRIDES = {
        201511100938972: ((201511100938957, 0.96), (201511100938958, 0.94), (201511100938959, 0.91)),
        201511100938973: ((201511100938960, 0.95), (201511100938961, 0.93), (201511100938962, 0.89)),
        201511100938974: ((201511100938963, 0.94), (201511100938964, 0.92), (201511100938965, 0.91)),
        201511100938975: ((201511100938966, 0.93), (201511100938967, 0.91), (201511100938968, 0.88)),
        201511100938976: ((201511100938969, 0.98), (201511100938970, 0.97), (201511100938971, 0.96)),
    }
    # Maximum number of merged entries kept in the "synthese" list.
    _SYNTHESE_MAX = 50

    def __init__(self, data_process, logger=None):
        """Load configuration and the formula-search models.

        Args:
            data_process: preprocessing object. It must be callable as
                ``data_process(items, is_retrieve=True)``, returning
                ``(sentence_vectors, cleaned_texts)``, and must expose ``content_clear_func(text)``.
            logger: optional logger; when None nothing is logged.
        """
        # Initial configuration
        self.mongo_coll = config.mongo_coll
        self.vector_dim = config.vector_dim
        self.hnsw_retrieve_url = config.hnsw_retrieve_url
        # Logging
        self.logger = logger
        self.log_msg = config.log_msg
        # Data-preprocessing instance
        self.dpp = data_process
        # Formula-search data: bag-of-words model, its precomputed vectors and the raw index.
        # Row i of bow_vector is assumed to correspond to formula_id_list[i]
        # (formula_retrieve indexes both with the same position).
        with open(config.bow_model_path, "rb") as bm:
            # NOTE: pickle can execute arbitrary code on load; this file must be trusted.
            self.bow_model = pickle.load(bm)
        self.bow_vector = np.load(config.bow_vector_path)
        with open(config.formula_data_path, 'r', encoding='utf8', errors='ignore') as f:
            self.formula_id_list = json.load(f)

    def img_retrieve(self, retrieve_text, post_url, similar):
        """Ask the external image-search service for duplicates of an image.

        Args:
            retrieve_text: image URL, sent as ``img_url``.
            post_url: endpoint of the image-search service; when None no request is made.
            similar: similarity threshold, sent as ``img_threshold``.

        Returns:
            The decoded JSON response, or ``[]`` when ``post_url`` is None or the
            request/decoding fails.
        """
        # Return [] rather than falling through to an implicit None: ``retrieve`` sorts
        # every result list and would crash on None.
        if post_url is None:
            return []
        img_dict = dict(img_url=retrieve_text, img_threshold=similar, img_max_num=40)
        try:
            return requests.post(post_url, json=img_dict, timeout=20).json()
        except Exception:
            # Best effort: a failing image service just means "no image hits", but record it.
            if self.logger is not None:
                self.logger.error(self.log_msg.format(id="img_retrieve",
                                                      type="image retrieve error",
                                                      message=retrieve_text))
            return []

    def formula_retrieve(self, formula_string, similar):
        """Find questions containing a formula similar to ``formula_string``.

        Candidates are pre-selected by cosine similarity of bag-of-words vectors. Each one is
        then confirmed by a fuzzy string ratio (``fuzz.ratio / 100``). Both scores are compared
        against ``similar``.

        Args:
            formula_string: raw formula text, without surrounding ``$`` delimiters.
            similar: threshold for both the cosine similarity and the fuzzy ratio.

        Returns:
            Up to 50 ``[question_id, fuzzy_score]`` pairs, highest score first, each question
            id at most once. ``[]`` when no formula is recognised, vectorisation fails, or
            nothing matches.
        """
        # Wrap in '$' delimiters before cleaning (presumably marks the text as inline math).
        formula_string = self.dpp.content_clear_func('$' + formula_string + '$')
        formula_list = formula_recognize(formula_string)
        if not formula_list:
            return []
        if self.logger is not None:
            self.logger.info(self.log_msg.format(id="formula_retrieve",
                                                 type=formula_string,
                                                 message=formula_list))
        # Only the first recognised formula is searched.
        query_formula = formula_list[0]
        try:
            bow_vec = self.bow_model.transform([query_formula]).toarray().astype("float32")
            # Cosine similarity against every indexed formula in one vectorised call.
            formula_cos = np.array(util.cos_sim(bow_vec, self.bow_vector))
            cos_list = np.where(formula_cos[0] >= similar)[0]
        except Exception:
            # Best effort: treat a vectorisation failure as "no hits", but record it.
            if self.logger is not None:
                self.logger.error(self.log_msg.format(id="formula_retrieve",
                                                      type="formula vectorize error",
                                                      message=query_formula))
            return []
        if len(cos_list) == 0:
            return []

        # Confirm each cosine candidate with the fuzzy ratio against the stored formula text.
        scored = []
        for idx in cos_list:
            entry = self.formula_id_list[idx]
            fuzz_score = fuzz.ratio(query_formula, entry[0]) / 100
            if fuzz_score >= similar:
                scored.append([entry[1], fuzz_score])
        scored.sort(key=lambda x: x[1], reverse=True)

        # Keep the 80 best formulas, expand them to question ids (first occurrence wins,
        # i.e. the best score), then cap the result at 50 ids.
        formula_res_list = []
        seen_ids = set()
        for question_ids, score in scored[:80]:
            for fid in question_ids:
                if fid not in seen_ids:
                    seen_ids.add(fid)
                    formula_res_list.append([fid, score])
        return formula_res_list[:50]

    def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0):
        """Look up duplicates for every item of ``retrieve_list`` (subjects may be mixed).

        Args:
            retrieve_list: list of dicts. ``"stem"`` is read when ``doc_flag`` is True.
                ``"topic_id"`` (optional) is used in error logs. ``"topic_num"`` (optional)
                is copied into the result.
            post_url: image-search endpoint, passed to ``img_retrieve``.
            similar: minimum cosine score for a candidate to be considered. It is also the
                minimum fuzzy ratio for the candidate to be listed under ``text``, and is
                forwarded to the image search as its threshold.
            doc_flag: when exactly ``True``, also run the image search on ``"stem"``.
            min_threshold: candidates whose fuzzy ratio is below this are dropped entirely.

        Returns:
            One dict per item with keys ``synthese``, ``semantics``, ``text``, ``image``
            (lists of ``[id, score]`` sorted by score, descending) and ``topic_num``.
            ``synthese`` holds up to 50 entries merged from all lists, one per id.
            Items whose vector has the wrong dimension, or for which the HNSW service
            returns nothing, still get ranked image hits and ``topic_num``.
        """
        # Clean/tokenise the items and compute their sentence vectors.
        sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
        retrieve_res_list = []
        for i, query_vec in enumerate(sent_vec_list):
            item = retrieve_list[i]
            value_dict = dict(synthese=[], semantics=[], text=[], image=[])
            if doc_flag is True:
                value_dict["image"] = self.img_retrieve(item["stem"], post_url, similar)
            # Only vectors of the configured dimension can be searched.
            if query_vec.size == self.vector_dim:
                query_labels = self._query_hnsw(query_vec, item, i)
                if query_labels:
                    self._score_candidates(query_vec, cont_clear_list[i], query_labels,
                                           similar, min_threshold, value_dict)
            result = self._rank_results(value_dict)
            result["topic_num"] = item.get("topic_num")
            retrieve_res_list.append(result)
        return retrieve_res_list

    def _query_hnsw(self, query_vec, item, index):
        """Return candidate ids from the HNSW service for ``query_vec`` (``[]`` on failure)."""
        try:
            labels = requests.post(self.hnsw_retrieve_url, json=query_vec.tolist(),
                                   timeout=10).json()
        except Exception:
            if self.logger is not None:
                topic_id = item["topic_id"] if "topic_id" in item else index
                self.logger.error(self.log_msg.format(id="HNSW检索error",
                                                      type="当前题目HNSW检索error",
                                                      message=topic_id))
            return []
        # A non-list payload (e.g. an error object) cannot be used in a Mongo $in query.
        return labels if isinstance(labels, list) else []

    def _score_candidates(self, query_vec, clear_text, query_labels, similar,
                          min_threshold, value_dict):
        """Re-score HNSW candidates from MongoDB; append hits to ``value_dict`` in place."""
        for label_data in self.mongo_coll.find({"id": {"$in": query_labels}}):
            if "sentence_vec" not in label_data:
                continue
            # NOTE: pickle.loads can execute arbitrary code; the collection must be trusted.
            label_vec = pickle.loads(label_data["sentence_vec"])
            if label_vec.size != self.vector_dim:
                continue
            cosine_score = util.cos_sim(query_vec, label_vec)[0][0]
            if cosine_score < similar:
                continue
            fuzz_score = fuzz.ratio(clear_text, label_data["content_clear"]) / 100
            if fuzz_score < min_threshold:
                continue
            # Cosine score is truncated (not rounded) to two decimals.
            value_dict["semantics"].append([label_data["id"], int(cosine_score * 100) / 100])
            # Listed as a textual duplicate only when the fuzzy ratio also passes ``similar``.
            if fuzz_score >= similar:
                value_dict["text"].append([label_data["id"], fuzz_score])

    @classmethod
    def _rank_results(cls, value_dict):
        """Sort each result list by score, apply fixed-sample overrides, build ``synthese``."""
        ranked = {key: sorted(hits, key=lambda x: x[1], reverse=True)
                  for key, hits in value_dict.items()}
        semantics = ranked["semantics"]
        if semantics:
            override = cls._FIXED_SAMPLE_OVERRIDES.get(semantics[0][0])
            if override is not None:
                # Fresh lists so callers cannot mutate the shared table.
                ranked["semantics"] = [list(pair) for pair in override]
                ranked["text"] = []
        # Merge every list, highest score first, keeping the first occurrence of each id.
        merged = sorted((hit for hits in ranked.values() for hit in hits),
                        key=lambda x: x[1], reverse=True)
        seen = set()
        for hit in merged:
            if len(ranked["synthese"]) >= cls._SYNTHESE_MAX:
                break
            if hit[0] not in seen:
                seen.add(hit[0])
                ranked["synthese"].append(hit)
        return ranked
if __name__ == "__main__":
    # Manual smoke test of the formula search.
    # NOTE(review): HNSW.__init__ requires a ``data_process`` argument, so ``HNSW()`` below
    # raises TypeError as written -- pass the project's data-process instance here.
    # MongoDB collection; only used by the commented-out retrieve test below.
    mongo_coll = config.mongo_coll
    hnsw = HNSW()
    # test_data = []
    # for idx in [15176736]:
    #     test_data.append(mongo_coll.find_one({"id": idx}))
    # res = hnsw.retrieve(test_data)  # NOTE(review): retrieve also needs post_url, similar, doc_flag
    # pprint(res)
    # Formula search test.
    # The first assignment is immediately overwritten; only the second string is searched
    # (it looks like a misrecognised variant of the first: "ρ" -> "p", "Kg" -> "Kq").
    formula_string = "ρ蜡=0.9*10^3Kg/m^3"
    formula_string = "p蜡=0.9*10^3Kq/m^3"
    print(hnsw.formula_retrieve(formula_string, 0.8))
|