123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313 |
import json
import pickle
from concurrent.futures import ThreadPoolExecutor
from pprint import pprint

import numpy as np
import requests
from fuzzywuzzy import fuzz
from sentence_transformers import util

import config
from formula_process import formula_recognize
class HNSW:
    """Duplicate-question retrieval: semantic/text search, image search, formula search.

    Text search sends the query sentence vector to the HNSW service at
    ``config.hnsw_retrieve_url``, loads the returned candidate ids from
    ``config.mongo_coll`` and re-scores them with cosine similarity and
    ``fuzz.ratio``.  Image search is delegated to an external HTTP service.
    Formula search compares a bag-of-words vector of the query formula against
    the pre-computed vectors loaded in ``__init__``.
    """

    # Value sent as ``img_max_num`` to the image-search service.
    _IMG_MAX_NUM = 40
    # Number of best-scoring formula entries expanded into question ids.
    _FORMULA_MAX_ENTRIES = 80
    # Maximum number of results returned by formula search / merged ranking.
    _MAX_RESULTS = 50

    def __init__(self, data_process, logger=None):
        """Read settings from ``config`` and load the formula-search artefacts.

        Args:
            data_process: Pre-processing object.  It is called as
                ``data_process(items, is_retrieve=True)`` and must return
                ``(sentence_vectors, cleaned_texts)``; it must also provide
                ``content_clear_func(text)``.
            logger: Optional logger exposing ``info``/``error``.  When ``None``
                nothing is logged.
        """
        # Initial configuration.
        self.mongo_coll = config.mongo_coll
        self.vector_dim = config.vector_dim
        self.hnsw_retrieve_url = config.hnsw_retrieve_url
        # Logging; ``log_msg`` is formatted with ``id``, ``type`` and ``message``.
        self.logger = logger
        self.log_msg = config.log_msg
        # Data pre-processing instance.
        self.dpp = data_process
        # Formula data: bag-of-words model, its pre-computed vectors, raw entries.
        # NOTE(security): pickle.load can execute arbitrary code; the model file
        # must come from a trusted location.
        with open(config.bow_model_path, "rb") as bm:
            self.bow_model = pickle.load(bm)
        self.bow_vector = np.load(config.bow_vector_path)
        # Entries are used below as ``entry[0]`` (formula text) and ``entry[1]``
        # (iterable of question ids).
        with open(config.formula_data_path, 'r', encoding='utf8', errors='ignore') as f:
            self.formula_id_list = json.load(f)

    def _log(self, level, log_id, log_type, message):
        """Emit one record through ``self.logger`` (no-op when it is ``None``)."""
        if self.logger is None:
            return
        getattr(self.logger, level)(
            self.log_msg.format(id=log_id, type=log_type, message=message))

    # Image duplicate search.
    def img_retrieve(self, retrieve_text, post_url, similar, topic_num):
        """Query the image-search service for one question.

        Args:
            retrieve_text: Sent to the service as ``img_url``.
            post_url: Endpoint of the image-search service; ``None`` disables
                the lookup.
            similar: Sent to the service as ``img_threshold``.
            topic_num: Question number, used only in log records.

        Returns:
            The decoded JSON response, or ``[]`` when ``post_url`` is ``None``
            or the request / decoding fails.  A list is always returned on the
            failure paths so that callers can sort the result unconditionally.
        """
        if post_url is None:
            return []
        try:
            self._log("info", "图片搜索查重",
                      "{}图片搜索查重post".format(topic_num), retrieve_text)
            img_dict = dict(img_url=retrieve_text, img_threshold=similar,
                            img_max_num=self._IMG_MAX_NUM)
            img_res = requests.post(post_url, json=img_dict, timeout=30).json()
            self._log("info", "图片搜索查重",
                      "{}图片搜索查重success".format(topic_num), img_res)
            return img_res
        except Exception:
            # Best effort: an image-service failure must not abort the whole
            # duplicate check, so it is logged and reported as "no hits".
            self._log("error", "图片搜索查重",
                      "{}图片搜索查重error".format(topic_num), retrieve_text)
            return []

    # Formula duplicate search.
    def formula_retrieve(self, formula_string, similar, formula_threshold=0.7):
        """Find questions containing a formula similar to ``formula_string``.

        Only the first formula recognised in the input is used as the query.

        Args:
            formula_string: Raw formula text; it is wrapped in ``$...$`` before
                cleaning and recognition.
            similar: Minimum cosine similarity, applied both before and after
                the score discount below.
            formula_threshold: Minimum ``fuzz.ratio / 100`` between the query
                formula and a stored formula.

        Returns:
            Up to 50 ``[question_id, score]`` pairs ordered by descending
            score, each question id appearing once.  ``[]`` when no formula is
            recognised, vectorisation fails, or nothing passes the thresholds.
        """
        # Clean the text, then recognise formulas in it.
        formula_string = '$' + formula_string + '$'
        formula_string = self.dpp.content_clear_func(formula_string)
        formula_list = formula_recognize(formula_string)
        if len(formula_list) == 0:
            return []
        self._log("info", "formula_retrieve", formula_string, formula_list)
        query_formula = formula_list[0]
        try:
            # Bag-of-words sentence vector of the query formula.
            bow_vec = self.bow_model.transform([query_formula]).toarray().astype("float32")
            # Cosine similarity against every stored formula vector at once.
            formula_cos = np.array(util.cos_sim(bow_vec, self.bow_vector))[0]
            # Indices whose cosine similarity reaches ``similar``.
            cos_list = np.where(formula_cos >= similar)[0]
        except Exception:
            self._log("error", "formula_retrieve", formula_string,
                      "formula vectorisation failed")
            return []
        if len(cos_list) == 0:
            return []

        res_list = []
        for idx in cos_list:
            entry = self.formula_id_list[idx]
            fuzz_score = fuzz.ratio(query_formula, entry[0]) / 100
            if fuzz_score < formula_threshold:
                continue
            # Discount cosine scores that are not near-exact matches.
            cosine_score = formula_cos[idx]
            if 0.95 <= cosine_score < 0.98:
                cosine_score = cosine_score * 0.98
            elif cosine_score < 0.95:
                cosine_score = cosine_score * 0.95
            # Re-check the threshold after discounting.
            if cosine_score < similar:
                continue
            # Truncate (not round) the score to two decimals.
            res_list.append([entry[1], int(cosine_score * 100) / 100])

        # Best entries first; expand each entry into its question ids and keep
        # the first (highest-scoring) occurrence of every id.
        res_list.sort(key=lambda x: x[1], reverse=True)
        formula_res_list = []
        fid_set = set()
        for fid_group, score in res_list[:self._FORMULA_MAX_ENTRIES]:
            for fid in fid_group:
                if fid in fid_set:
                    continue
                fid_set.add(fid)
                formula_res_list.append([fid, score])
        return formula_res_list[:self._MAX_RESULTS]

    # HNSW duplicate search (supports mixed multi-subject lookup).
    def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0.56):
        """Run the duplicate search for every question in ``retrieve_list``.

        Args:
            retrieve_list: Question dicts.  ``"stem"`` is read when ``doc_flag``
                is ``True``; ``"topic_num"`` is optional and defaults to 1.
            post_url: Image-search endpoint forwarded to ``img_retrieve``.
            similar: Threshold applied to cosine scores, to edit-distance
                scores for the ``text`` list, and forwarded to image search.
            doc_flag: Image search runs only when this is exactly ``True``.
            min_threshold: Minimum ``fuzz.ratio / 100`` a candidate needs to be
                kept at all.

        Returns:
            One dict per sentence vector produced by the pre-processor, with
            keys ``synthese``, ``semantics``, ``text``, ``image`` (lists of
            ``[id, score]`` sorted by descending score) and ``topic_num``.
            Every result has this shape, including those for which the text
            search could not run (vector dimension mismatch, HNSW failure).
        """
        # Clean/tokenise the questions and compute their sentence vectors; the
        # two returned lists are indexed in parallel with ``retrieve_list``.
        sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
        return [
            self._retrieve_one(retrieve_list[i], sent_vec, cont_clear_list[i],
                               post_url, similar, doc_flag, min_threshold)
            for i, sent_vec in enumerate(sent_vec_list)
        ]

    def _retrieve_one(self, retrieve_data, sent_vec, cont_clear,
                      post_url, similar, doc_flag, min_threshold):
        """Build the result dict for a single question.

        NOTE: a thread-pool variant of ``retrieve`` used to live in this file
        as commented-out code.  The per-question logic is kept in this helper
        so it can be mapped over an executor if that is revisited.
        """
        topic_num = retrieve_data.get("topic_num", 1)
        hits = dict(synthese=[], semantics=[], text=[], image=[])
        # Image duplicate search.
        if doc_flag is True:
            hits["image"] = self.img_retrieve(
                retrieve_data["stem"], post_url, similar, topic_num) or []
        # Text search only makes sense for a vector of the expected dimension.
        if sent_vec.size == self.vector_dim:
            self._collect_text_hits(hits, sent_vec, cont_clear, similar, min_threshold)
        return self._rank(hits, topic_num)

    def _query_hnsw(self, sent_vec, cont_clear):
        """Return candidate ids from the HNSW service, or ``[]`` on failure."""
        post_list = sent_vec.tolist()
        try:
            query_labels = requests.post(self.hnsw_retrieve_url, json=post_list,
                                         timeout=10).json()
        except Exception:
            self._log("error", "HNSW检索error", "当前题目HNSW检索error", cont_clear)
            return []
        # The ids are passed to Mongo's ``$in``, which needs an array; treat any
        # other payload as "no candidates".
        return query_labels if isinstance(query_labels, list) else []

    def _collect_text_hits(self, hits, sent_vec, cont_clear, similar, min_threshold):
        """Fill ``hits["semantics"]`` and ``hits["text"]`` for one question."""
        query_labels = self._query_hnsw(sent_vec, cont_clear)
        if not query_labels:
            return
        # Load all candidates in one database query.
        for label_data in self.mongo_coll.find({"id": {"$in": query_labels}}):
            if "sentence_vec" not in label_data:
                continue
            # NOTE(security): pickle.loads executes arbitrary code if the stored
            # bytes are attacker-controlled; the collection must be trusted.
            label_vec = pickle.loads(label_data["sentence_vec"])
            if label_vec.size != self.vector_dim:
                continue
            cosine_score = util.cos_sim(sent_vec, label_vec)[0][0]
            if cosine_score < similar:
                continue
            # Edit-distance score on the cleaned texts.
            fuzz_score = fuzz.ratio(cont_clear, label_data["content_clear"]) / 100
            if fuzz_score < min_threshold:
                continue
            # Discount the cosine score when the texts are only barely similar.
            if fuzz_score < min_threshold + 0.06:
                cosine_score = cosine_score * (0.95 if cosine_score >= 0.91 else 0.94)
            # Re-check the threshold after discounting.
            if cosine_score < similar:
                continue
            # Truncate (not round) the score to two decimals.
            hits["semantics"].append([label_data["id"], int(cosine_score * 100) / 100])
            # The candidate is also a text hit when the edit-distance score
            # itself reaches the threshold.
            if fuzz_score >= similar:
                hits["text"].append([label_data["id"], fuzz_score])

    def _rank(self, hits, topic_num):
        """Sort every hit list by score and build the merged ``synthese`` list."""
        ranked = {name: sorted(group, key=lambda x: x[1], reverse=True)
                  for name, group in hits.items()}
        # Merged ranking: best score first, one entry per id, at most 50.
        merged = sorted((hit for group in ranked.values() for hit in group),
                        key=lambda x: x[1], reverse=True)
        seen_ids = set()
        for hit in merged:
            if len(ranked["synthese"]) >= self._MAX_RESULTS:
                break
            if hit[0] not in seen_ids:
                seen_ids.add(hit[0])
                ranked["synthese"].append(hit)
        # Attach the question number.
        ranked["topic_num"] = topic_num
        return ranked
if __name__ == "__main__":
    # Manual smoke test: load one stored question from MongoDB, run the
    # duplicate search on it and print its semantic hits.
    collection = config.mongo_coll
    from data_preprocessing import DataPreProcessing
    searcher = HNSW(DataPreProcessing())
    sample_ids = [201511100736265]
    samples = [collection.find_one({"id": sample_id}) for sample_id in sample_ids]
    results = searcher.retrieve(samples, '', 0.8, False)
    pprint(results[0]["semantics"])
    # Formula duplicate search example:
    # formula_string = "ρ蜡=0.9*10^3Kg/m^3"
    # formula_string = "p蜡=0.9*10^3Kq/m^3"
    # print(searcher.formula_retrieve(formula_string, 0.8))
|