import json
import pickle
from concurrent.futures import ThreadPoolExecutor
from pprint import pprint

import numpy as np
import requests
from fuzzywuzzy import fuzz
from sentence_transformers import util

import config
from formula_process import formula_recognize


class HNSW:
    """Duplicate-detection front end combining image, formula and vector search.

    The instance reads its collaborators from ``config`` (Mongo collection,
    expected vector size, HNSW service URL, log message template) and loads
    the bag-of-words model, its pre-computed vectors and the formula lookup
    table from the paths configured there.
    """

    # Maximum number of hits requested from the image service.
    IMG_MAX_NUM = 40
    # HTTP timeouts (seconds) for the image service and the HNSW service.
    IMG_TIMEOUT = 30
    HNSW_TIMEOUT = 10
    # Minimum edit-distance ratio a formula candidate must reach.
    FORMULA_FUZZ_THRESHOLD = 0.7
    # Number of best-scoring formula entries expanded into individual ids.
    FORMULA_CANDIDATE_LIMIT = 80
    # Maximum number of ids returned by formula_retrieve / kept in "synthese".
    MAX_RESULTS = 50

    def __init__(self, data_process, logger=None):
        """Load configuration and the formula-matching resources.

        Args:
            data_process: Pre-processing object. It is called as
                ``data_process(retrieve_list, is_retrieve=True)`` and must
                expose ``content_clear_func``.
            logger: Optional logger; when ``None`` nothing is logged.
        """
        # Static configuration.
        self.mongo_coll = config.mongo_coll
        self.vector_dim = config.vector_dim
        self.hnsw_retrieve_url = config.hnsw_retrieve_url
        # Logging.
        self.logger = logger
        self.log_msg = config.log_msg
        # Data pre-processing instance.
        self.dpp = data_process
        # Formula resources: bag-of-words model, reference vectors, raw data.
        with open(config.bow_model_path, "rb") as bm:
            self.bow_model = pickle.load(bm)
        self.bow_vector = np.load(config.bow_vector_path)
        with open(config.formula_data_path, 'r', encoding='utf8', errors='ignore') as f:
            self.formula_id_list = json.load(f)

    def _log(self, level, log_id, log_type, message):
        """Emit one record through the optional logger using ``log_msg``."""
        if self.logger is None:
            return
        emit = getattr(self.logger, level)
        emit(self.log_msg.format(id=log_id, type=log_type, message=message))

    def img_retrieve(self, retrieve_text, post_url, similar, topic_num):
        """Query the image duplicate-search service.

        Args:
            retrieve_text: Value sent to the service as ``img_url``.
            post_url: Service endpoint; ``None`` disables the lookup.
            similar: Value sent to the service as ``img_threshold``.
            topic_num: Item number, used only in log messages.

        Returns:
            The decoded JSON response, or ``[]`` when ``post_url`` is ``None``
            or the request/decoding fails.
        """
        # Previously this path fell through and returned None, which broke
        # callers that sort the result.
        if post_url is None:
            return []
        try:
            self._log("info", "图片搜索查重",
                      "{}图片搜索查重post".format(topic_num), retrieve_text)
            img_dict = dict(img_url=retrieve_text, img_threshold=similar,
                            img_max_num=self.IMG_MAX_NUM)
            img_res = requests.post(post_url, json=img_dict,
                                    timeout=self.IMG_TIMEOUT).json()
            self._log("info", "图片搜索查重",
                      "{}图片搜索查重success".format(topic_num), img_res)
            return img_res
        except Exception:
            # Best effort: a failing image service must not abort the search.
            self._log("error", "图片搜索查重",
                      "{}图片搜索查重error".format(topic_num), retrieve_text)
            return []

    def formula_retrieve(self, formula_string, similar):
        """Find stored formulas similar to ``formula_string``.

        The string is wrapped in ``$...$``, cleaned, and passed to
        ``formula_recognize``; only the first recognised formula is used.
        Candidates are selected by cosine similarity of bag-of-words vectors,
        filtered by edit-distance ratio, then re-scored with a discount.

        Args:
            formula_string: Raw formula text.
            similar: Minimum cosine score, applied before and after the
                discount.

        Returns:
            Up to ``MAX_RESULTS`` ``[id, score]`` pairs in descending score
            order with duplicate ids removed; ``[]`` when nothing is
            recognised, nothing matches, or the similarity step fails.
        """
        formula_string = '$' + formula_string + '$'
        formula_string = self.dpp.content_clear_func(formula_string)
        formula_list = formula_recognize(formula_string)
        if len(formula_list) == 0:
            return []
        self._log("info", "formula_retrieve", formula_string, formula_list)

        query_formula = formula_list[0]
        try:
            # Sentence vector from the bag-of-words model.
            bow_vec = self.bow_model.transform([query_formula]).toarray().astype("float32")
            # Cosine similarity against every reference vector at once.
            formula_cos = np.array(util.cos_sim(bow_vec, self.bow_vector))
            # Indices whose cosine score reaches the requested threshold.
            cos_list = np.where(formula_cos[0] >= similar)[0]
        except Exception as exc:
            # Was a bare ``except:``; keep the best-effort result but leave a
            # trace and stop swallowing KeyboardInterrupt/SystemExit.
            self._log("error", "formula_retrieve", formula_string,
                      "similarity computation failed: {!r}".format(exc))
            return []
        if len(cos_list) == 0:
            return []

        res_list = []
        for idx in cos_list:
            # entry[0] is compared with the query text; entry[1] is iterated
            # below as the collection of ids attached to that formula.
            entry = self.formula_id_list[idx]
            fuzz_score = fuzz.ratio(query_formula, entry[0]) / 100
            if fuzz_score < self.FORMULA_FUZZ_THRESHOLD:
                continue
            # Discount lower cosine scores before re-checking the threshold.
            cosine_score = formula_cos[0][idx]
            if 0.95 <= cosine_score < 0.98:
                cosine_score = cosine_score * 0.98
            elif cosine_score < 0.95:
                cosine_score = cosine_score * 0.95
            if cosine_score < similar:
                continue
            # Truncate (not round) the score to two decimals.
            res_list.append([entry[1], int(cosine_score * 100) / 100])

        # Keep the best-scoring entries, then flatten them into unique ids.
        res_list.sort(key=lambda x: x[1], reverse=True)
        formula_res_list = []
        fid_set = set()
        for id_group, score in res_list[:self.FORMULA_CANDIDATE_LIMIT]:
            for fid in id_group:
                if fid in fid_set:
                    continue
                fid_set.add(fid)
                formula_res_list.append([fid, score])
        return formula_res_list[:self.MAX_RESULTS]

    def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0.56):
        """Run the HNSW duplicate search for every item in ``retrieve_list``.

        NOTE: an earlier variant dispatched the per-item search through a
        ``ThreadPoolExecutor(max_workers=5)``; items are now processed
        sequentially, in input order.

        Args:
            retrieve_list: Items to check. Each item is indexed with
                ``"stem"`` (when ``doc_flag`` is true) and optionally
                ``"topic_num"`` (default 1).
            post_url: Image service endpoint forwarded to ``img_retrieve``.
            similar: Threshold for the cosine score, for the ``text`` bucket's
                edit-distance ratio, and for the image service.
            doc_flag: Image lookup runs only when this is exactly ``True``.
            min_threshold: Minimum edit-distance ratio for any candidate.

        Returns:
            One dict per sentence vector with the keys ``synthese``,
            ``semantics``, ``text``, ``image`` and ``topic_num``.
        """
        # Clean/tokenise the items and compute their sentence vectors.
        sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
        return [
            self._search_one(retrieve_list[i], sent_vec, cont_clear_list[i],
                             post_url, similar, doc_flag, min_threshold)
            for i, sent_vec in enumerate(sent_vec_list)
        ]

    def _search_one(self, retrieve_data, sent_vec, cont_clear, post_url,
                    similar, doc_flag, min_threshold):
        """Build the result dict for a single item (see ``retrieve``)."""
        result = dict(synthese=[], semantics=[], text=[], image=[])
        topic_num = retrieve_data["topic_num"] if "topic_num" in retrieve_data else 1

        if doc_flag is True:
            result["image"] = self.img_retrieve(retrieve_data["stem"], post_url,
                                                similar, topic_num)

        # A vector of unexpected size cannot be searched.
        if sent_vec.size != self.vector_dim:
            # Early exits now carry topic_num like the normal path does.
            result["topic_num"] = topic_num
            return result

        # Ask the HNSW service for candidate ids.
        try:
            query_labels = requests.post(self.hnsw_retrieve_url,
                                         json=sent_vec.tolist(),
                                         timeout=self.HNSW_TIMEOUT).json()
        except Exception:
            query_labels = []
            self._log("error", "HNSW检索error", "当前题目HNSW检索error", cont_clear)
        if not query_labels:
            result["topic_num"] = topic_num
            return result

        # Fetch all candidates from the database in one query.
        for label_data in self.mongo_coll.find({"id": {"$in": query_labels}}):
            if "sentence_vec" not in label_data:
                continue
            # NOTE(review): pickle.loads runs on database content; this is
            # only safe while the collection is written by trusted code.
            label_vec = pickle.loads(label_data["sentence_vec"])
            if label_vec.size != self.vector_dim:
                continue
            cosine_score = util.cos_sim(sent_vec, label_vec)[0][0]
            if cosine_score < similar:
                continue
            fuzz_score = fuzz.ratio(cont_clear, label_data["content_clear"]) / 100
            if fuzz_score < min_threshold:
                continue
            # Texts that barely pass the edit-distance check get their cosine
            # score discounted, then the threshold is applied again.
            if fuzz_score < min_threshold + 0.06:
                cosine_score = cosine_score * (0.95 if cosine_score >= 0.91 else 0.94)
            if cosine_score < similar:
                continue
            # Truncate (not round) the score to two decimals.
            result["semantics"].append([label_data["id"], int(cosine_score * 100) / 100])
            # The text bucket additionally requires the edit-distance ratio
            # itself to reach the threshold.
            if fuzz_score >= similar:
                result["text"].append([label_data["id"], fuzz_score])

        # Sort every bucket by score, best first.
        sorted_result = {
            bucket: sorted(hits, key=lambda x: x[1], reverse=True)
            for bucket, hits in result.items()
        }
        # Combined ranking: merge all buckets, sort, drop repeated ids and
        # keep at most MAX_RESULTS entries.
        merged = [hit for hits in sorted_result.values() for hit in hits]
        merged.sort(key=lambda x: x[1], reverse=True)
        seen_ids = set()
        synthese = sorted_result["synthese"]
        for hit in merged:
            if len(synthese) >= self.MAX_RESULTS:
                break
            if hit[0] in seen_ids:
                continue
            seen_ids.add(hit[0])
            synthese.append(hit)

        sorted_result["topic_num"] = topic_num
        return sorted_result


if __name__ == "__main__":
    # Manual smoke test against the configured MongoDB collection.
    mongo_coll = config.mongo_coll
    from data_preprocessing import DataPreProcessing

    hnsw = HNSW(DataPreProcessing())
    test_data = [mongo_coll.find_one({"id": idx}) for idx in [201511100736265]]
    res = hnsw.retrieve(test_data, '', 0.8, False)
    pprint(res[0]["semantics"])

    # Formula duplicate-search example:
    # formula_string = "ρ蜡=0.9*10^3Kg/m^3"
    # formula_string = "p蜡=0.9*10^3Kq/m^3"
    # print(hnsw.formula_retrieve(formula_string, 0.8))