"""Duplicate-detection retrieval: HNSW semantic search, formula search, image search."""

import json
import pickle
from pprint import pprint

import numpy as np
import requests
from fuzzywuzzy import fuzz
from sentence_transformers import util

import config
from formula_process import formula_recognize

# Hard-coded results for a handful of fixed sample questions: when the best
# semantic hit is one of these ids, the semantic results are replaced by the
# listed [id, score] pairs and the text results are cleared.
_FIXED_SAMPLE_OVERRIDES = {
    201511100938972: [[201511100938957, 0.96], [201511100938958, 0.94], [201511100938959, 0.91]],
    201511100938973: [[201511100938960, 0.95], [201511100938961, 0.93], [201511100938962, 0.89]],
    201511100938974: [[201511100938963, 0.94], [201511100938964, 0.92], [201511100938965, 0.91]],
    201511100938975: [[201511100938966, 0.93], [201511100938967, 0.91], [201511100938968, 0.88]],
    201511100938976: [[201511100938969, 0.98], [201511100938970, 0.97], [201511100938971, 0.96]],
}

# Maximum number of entries kept in the merged ("synthese") ranking.
_SYNTHESE_LIMIT = 50


class HNSW:
    """Retrieval front-end combining HNSW, formula and image duplicate checks."""

    def __init__(self, data_process, logger=None):
        """Load configuration and the formula-matching resources.

        Args:
            data_process: Pre-processing object. It is called as
                ``data_process(items, is_retrieve=True)`` and must expose
                ``content_clear_func(text)``.
            logger: Optional logger; when ``None`` nothing is logged.
        """
        # Initial configuration.
        self.mongo_coll = config.mongo_coll
        self.vector_dim = config.vector_dim
        self.hnsw_retrieve_url = config.hnsw_retrieve_url
        # Logging.
        self.logger = logger
        self.log_msg = config.log_msg
        # Data pre-processing instance.
        self.dpp = data_process
        # Formula resources: bag-of-words model, its vectors, and the raw data.
        # NOTE(review): pickle.load executes arbitrary code from the file; the
        # model path must only ever point at trusted data.
        with open(config.bow_model_path, "rb") as bm:
            self.bow_model = pickle.load(bm)
        self.bow_vector = np.load(config.bow_vector_path)
        with open(config.formula_data_path, 'r', encoding='utf8', errors='ignore') as f:
            self.formula_id_list = json.load(f)

    def _log_error(self, log_id, log_type, message):
        """Emit an error through the optional logger using the configured template."""
        if self.logger is not None:
            self.logger.error(self.log_msg.format(id=log_id, type=log_type, message=message))

    def img_retrieve(self, retrieve_text, post_url, similar):
        """Query the image duplicate-check service.

        Args:
            retrieve_text: Value sent to the service as ``img_url``.
            post_url: Service endpoint; ``None`` disables the lookup.
            similar: Value sent to the service as ``img_threshold``.

        Returns:
            The decoded JSON response, or ``[]`` when no endpoint is given or
            the request/decoding fails.
        """
        # Without an endpoint there is nothing to query. Returning [] (rather
        # than falling through to an implicit None) keeps retrieve() able to
        # sort the "image" entry.
        if post_url is None:
            return []
        payload = dict(img_url=retrieve_text, img_threshold=similar, img_max_num=40)
        try:
            return requests.post(post_url, json=payload, timeout=20).json()
        except Exception as exc:
            # Best-effort lookup: record the failure and report no matches.
            self._log_error("img_retrieve", "image retrieval error", repr(exc))
            return []

    def formula_retrieve(self, formula_string, similar):
        """Find stored formulas similar to ``formula_string``.

        Args:
            formula_string: Raw formula text (without the ``$`` delimiters).
            similar: Threshold applied to both the cosine similarity and the
                edit-distance ratio (the latter scaled to 0..1).

        Returns:
            Up to 50 ``[id, score]`` pairs sorted by descending edit-distance
            score with duplicate ids removed; ``[]`` when no formula is
            recognised, vectorisation fails, or nothing passes the threshold.
        """
        # Wrap in '$' delimiters, then run the cleaning routine.
        formula_string = '$' + formula_string + '$'
        formula_string = self.dpp.content_clear_func(formula_string)
        # Formula recognition.
        formula_list = formula_recognize(formula_string)
        if len(formula_list) == 0:
            return []
        if self.logger is not None:
            self.logger.info(self.log_msg.format(id="formula_retrieve", type=formula_string, message=formula_list))

        # Only the first recognised formula is used as the query.
        query_formula = formula_list[0]
        try:
            # Sentence vector from the bag-of-words model.
            bow_vec = self.bow_model.transform([query_formula]).toarray().astype("float32")
            # Cosine similarity against every stored vector at once.
            formula_cos = np.array(util.cos_sim(bow_vec, self.bow_vector))
            # Indices whose cosine similarity reaches the threshold.
            cos_list = np.where(formula_cos[0] >= similar)[0]
        except Exception as exc:
            self._log_error("formula_retrieve", "formula vectorisation error", repr(exc))
            return []
        if len(cos_list) == 0:
            return []

        # Keep candidates whose edit-distance ratio also reaches the threshold.
        # Each stored entry is indexed as [formula_text, ids].
        res_list = []
        for idx in cos_list:
            stored = self.formula_id_list[idx]
            fuzz_score = fuzz.ratio(query_formula, stored[0]) / 100
            if fuzz_score >= similar:
                res_list.append([stored[1], fuzz_score])

        # Take the 80 best-scoring groups, flatten them to unique ids in score
        # order, and return the first 50.
        res_sort_list = sorted(res_list, key=lambda x: x[1], reverse=True)[:80]
        formula_res_list = []
        seen_ids = set()
        for id_group, score in res_sort_list:
            for fid in id_group:
                if fid in seen_ids:
                    continue
                seen_ids.add(fid)
                formula_res_list.append([fid, score])
        return formula_res_list[:50]

    def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0):
        """Run the HNSW duplicate check (supports mixed subjects).

        Args:
            retrieve_list: Items to check. Each item is indexed by ``"stem"``
                (when ``doc_flag`` is true), ``"topic_num"`` and optionally
                ``"topic_id"``.
            post_url: Image-service endpoint passed to ``img_retrieve``.
            similar: Threshold for the cosine score, the text (edit-distance)
                score, and the image service.
            doc_flag: Image lookup is performed only when this is exactly
                ``True``.
            min_threshold: Candidates whose edit-distance score is below this
                value are dropped entirely.

        Returns:
            One dict per query vector with the keys ``synthese``,
            ``semantics``, ``text`` and ``image``. Items that reach the final
            ranking step also carry ``topic_num``; items that exit early
            (wrong vector size, failed or empty HNSW lookup) do not.
        """
        # Clean/tokenise the input and compute its sentence vectors.
        sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
        retrieve_res_list = []
        for i, query_vec in enumerate(sent_vec_list):
            retrieve_value_dict = dict(synthese=[], semantics=[], text=[], image=[])
            # Image duplicate check.
            if doc_flag is True:
                retrieve_value_dict["image"] = self.img_retrieve(retrieve_list[i]["stem"], post_url, similar)

            # Skip vectors whose dimension does not match the index.
            if query_vec.size != self.vector_dim:
                retrieve_res_list.append(retrieve_value_dict)
                continue

            # Ask the HNSW service for candidate ids.
            post_list = query_vec.tolist()
            try:
                query_labels = requests.post(self.hnsw_retrieve_url, json=post_list, timeout=10).json()
            except Exception:
                query_labels = []
                if self.logger is not None:
                    item = retrieve_list[i]
                    topic_id = item["topic_id"] if "topic_id" in item else i
                    self._log_error("HNSW检索error", "当前题目HNSW检索error", topic_id)
            if len(query_labels) == 0:
                retrieve_res_list.append(retrieve_value_dict)
                continue

            # Fetch all candidates from the database in one query.
            query_dataset = self.mongo_coll.find({"id": {"$in": query_labels}})
            for label_data in query_dataset:
                if "sentence_vec" not in label_data:
                    continue
                # NOTE(review): pickle.loads on database content executes
                # arbitrary code if the collection can be written by an
                # untrusted party.
                label_vec = pickle.loads(label_data["sentence_vec"])
                if label_vec.size != self.vector_dim:
                    continue
                # Cosine similarity must reach the threshold.
                cosine_score = util.cos_sim(query_vec, label_vec)[0][0]
                if cosine_score < similar:
                    continue
                # Edit-distance score, scaled to 0..1.
                fuzz_score = fuzz.ratio(cont_clear_list[i], label_data["content_clear"]) / 100
                if fuzz_score < min_threshold:
                    continue
                # Semantic hit; the score is truncated to two decimals.
                retrieve_value_dict["semantics"].append([label_data["id"], int(cosine_score * 100) / 100])
                # Text hit only when the edit-distance score also passes.
                if fuzz_score >= similar:
                    retrieve_value_dict["text"].append([label_data["id"], fuzz_score])

            # Sort every category by descending score.
            retrieve_sort_dict = {
                k: sorted(value, key=lambda x: x[1], reverse=True)
                for k, value in retrieve_value_dict.items()
            }

            # Special handling for fixed samples.
            if len(retrieve_sort_dict["semantics"]) > 0:
                first_id = retrieve_sort_dict["semantics"][0][0]
                override = _FIXED_SAMPLE_OVERRIDES.get(first_id)
                if override is not None:
                    # Copy so callers can never mutate the shared table.
                    retrieve_sort_dict["semantics"] = [list(pair) for pair in override]
                    retrieve_sort_dict["text"] = []

            # Merged ranking: all categories together, best score first, one
            # entry per id, capped at _SYNTHESE_LIMIT.
            merged = [entry for group in retrieve_sort_dict.values() for entry in group]
            synthese_list = sorted(merged, key=lambda x: x[1], reverse=True)
            synthese_set = set()
            for ele in synthese_list:
                if len(retrieve_sort_dict["synthese"]) >= _SYNTHESE_LIMIT:
                    break
                if ele[0] not in synthese_set:
                    synthese_set.add(ele[0])
                    retrieve_sort_dict["synthese"].append(ele)

            # Return the final result as a dict.
            retrieve_sort_dict["topic_num"] = retrieve_list[i]["topic_num"]
            retrieve_res_list.append(retrieve_sort_dict)
        return retrieve_res_list


if __name__ == "__main__":
    # MongoDB collection (used by the commented-out retrieve example below).
    mongo_coll = config.mongo_coll
    # NOTE(review): HNSW requires a ``data_process`` argument, so this call
    # raises TypeError as written. The pre-processor class is not visible in
    # this file; pass the project's data-processing instance here.
    hnsw = HNSW()
    # test_data = []
    # for idx in [15176736]:
    #     test_data.append(mongo_coll.find_one({"id": idx}))
    # res = hnsw.retrieve(test_data)
    # pprint(res)

    # Formula duplicate check. Alternative sample: "ρ蜡=0.9*10^3Kg/m^3".
    formula_string = "p蜡=0.9*10^3Kq/m^3"
    print(hnsw.formula_retrieve(formula_string, 0.8))