# hnsw_retrieval.py
  1. import json
  2. import pickle
  3. import requests
  4. import numpy as np
  5. from fuzzywuzzy import fuzz
  6. from sentence_transformers import util
  7. from concurrent.futures import ThreadPoolExecutor
  8. from pprint import pprint
  9. import config
  10. from formula_process import formula_recognize
  11. class HNSW():
  12. def __init__(self, data_process, logger=None):
  13. # 配置初始数据
  14. self.mongo_coll = config.mongo_coll
  15. self.vector_dim = config.vector_dim
  16. self.hnsw_retrieve_url = config.hnsw_retrieve_url
  17. # 日志采集
  18. self.logger = logger
  19. self.log_msg = config.log_msg
  20. # 数据预处理实例化
  21. self.dpp = data_process
  22. # 加载公式处理数据模型(词袋模型/原始向量/原始数据)
  23. with open(config.bow_model_path, "rb") as bm:
  24. self.bow_model = pickle.load(bm)
  25. self.bow_vector = np.load(config.bow_vector_path)
  26. with open(config.formula_data_path, 'r', encoding='utf8', errors='ignore') as f:
  27. self.formula_id_list = json.load(f)
  28. # 图片搜索查重功能
  29. def img_retrieve(self, retrieve_text, post_url, similar, topic_num):
  30. try:
  31. if post_url is not None:
  32. # 日志采集
  33. if self.logger is not None:
  34. self.logger.info(self.log_msg.format(id="图片搜索查重",
  35. type="{}图片搜索查重post".format(topic_num),
  36. message=retrieve_text))
  37. img_dict = dict(img_url=retrieve_text, img_threshold=similar, img_max_num=40)
  38. img_res = requests.post(post_url, json=img_dict, timeout=30).json()
  39. # 日志采集
  40. if self.logger is not None:
  41. self.logger.info(self.log_msg.format(id="图片搜索查重",
  42. type="{}图片搜索查重success".format(topic_num),
  43. message=img_res))
  44. return img_res
  45. except Exception as e:
  46. # 日志采集
  47. if self.logger is not None:
  48. self.logger.error(self.log_msg.format(id="图片搜索查重",
  49. type="{}图片搜索查重error".format(topic_num),
  50. message=retrieve_text))
  51. return []
  52. # 公式搜索查重功能
  53. def formula_retrieve(self, formula_string, similar):
  54. # 调用清洗分词函数
  55. formula_string = '$' + formula_string + '$'
  56. formula_string = self.dpp.content_clear_func(formula_string)
  57. # 公式识别
  58. formula_list = formula_recognize(formula_string)
  59. if len(formula_list) == 0:
  60. return []
  61. # 日志采集
  62. if self.logger is not None:
  63. self.logger.info(self.log_msg.format(id="formula_retrieve",
  64. type=formula_string,
  65. message=formula_list))
  66. try:
  67. # 使用词袋模型计算句向量
  68. bow_vec = self.bow_model.transform([formula_list[0]]).toarray().astype("float32")
  69. # 并行计算余弦相似度
  70. formula_cos = np.array(util.cos_sim(bow_vec, self.bow_vector))
  71. # 获取余弦值大于等于0.8的数据索引
  72. cos_list = np.where(formula_cos[0] >= similar)[0]
  73. except:
  74. return []
  75. if len(cos_list) == 0:
  76. return []
  77. # 根据阈值获取满足条件的题目id
  78. res_list = []
  79. # formula_threshold = similar
  80. formula_threshold = 0.7
  81. for idx in cos_list:
  82. fuzz_score = fuzz.ratio(formula_list[0], self.formula_id_list[idx][0]) / 100
  83. if fuzz_score >= formula_threshold:
  84. # res_list.append([self.formula_id_list[idx][1], fuzz_score])
  85. # 对余弦相似度进行折算
  86. cosine_score = formula_cos[0][idx]
  87. if 0.95 <= cosine_score < 0.98:
  88. cosine_score = cosine_score * 0.98
  89. elif cosine_score < 0.95:
  90. cosine_score = cosine_score * 0.95
  91. # 余弦相似度折算后阈值判断
  92. if cosine_score < similar:
  93. continue
  94. res_list.append([self.formula_id_list[idx][1], int(cosine_score * 100) / 100])
  95. # 根据分数对题目id排序并返回前50个
  96. res_sort_list = sorted(res_list, key=lambda x: x[1], reverse=True)[:80]
  97. formula_res_list = []
  98. fid_set = set()
  99. for ele in res_sort_list:
  100. for fid in ele[0]:
  101. if fid in fid_set:
  102. continue
  103. fid_set.add(fid)
  104. formula_res_list.append([fid, ele[1]])
  105. return formula_res_list[:50]
  106. # # HNSW查(支持多学科混合查重)
  107. # def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0.56):
  108. # # 计算retrieve_list的vec值
  109. # # 调用清洗分词函数和句向量计算函数
  110. # sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
  111. # # HNSW查重
  112. # def dup_search(retrieve_data, sent_vec, cont_clear):
  113. # # 初始化返回数据类型
  114. # retrieve_value_dict = dict(synthese=[], semantics=[], text=[], image=[])
  115. # # 获取题目序号
  116. # topic_num = retrieve_data["topic_num"] if "topic_num" in retrieve_data else 1
  117. # # 图片搜索查重功能
  118. # if doc_flag is True:
  119. # retrieve_value_dict["image"] = self.img_retrieve(retrieve_data["stem"], post_url, similar, topic_num)
  120. # else:
  121. # retrieve_value_dict["image"] = []
  122. # # 判断句向量维度
  123. # if sent_vec.size != self.vector_dim:
  124. # return retrieve_value_dict
  125. # # 调用hnsw接口检索数据
  126. # post_list = sent_vec.tolist()
  127. # try:
  128. # query_labels = requests.post(self.hnsw_retrieve_url, json=post_list, timeout=10).json()
  129. # except Exception as e:
  130. # query_labels = []
  131. # # 日志采集
  132. # if self.logger is not None:
  133. # self.logger.error(self.log_msg.format(id="HNSW检索error",
  134. # type="当前题目HNSW检索error",
  135. # message=cont_clear))
  136. # if len(query_labels) == 0:
  137. # return retrieve_value_dict
  138. # # 批量读取数据库
  139. # mongo_find_dict = {"id": {"$in": query_labels}}
  140. # query_dataset = self.mongo_coll.find(mongo_find_dict)
  141. # # 返回大于阈值的结果
  142. # filter_threshold = similar
  143. # for label_data in query_dataset:
  144. # if "sentence_vec" not in label_data:
  145. # continue
  146. # # 计算余弦相似度得分
  147. # label_vec = pickle.loads(label_data["sentence_vec"])
  148. # if label_vec.size != self.vector_dim:
  149. # continue
  150. # cosine_score = util.cos_sim(sent_vec, label_vec)[0][0]
  151. # # 阈值判断
  152. # if cosine_score < filter_threshold:
  153. # continue
  154. # # 计算编辑距离得分
  155. # fuzz_score = fuzz.ratio(cont_clear, label_data["content_clear"]) / 100
  156. # if fuzz_score < min_threshold:
  157. # continue
  158. # # 对余弦相似度进行折算
  159. # if cosine_score >= 0.91 and fuzz_score < min_threshold + 0.06:
  160. # cosine_score = cosine_score * 0.95
  161. # elif cosine_score < 0.91 and fuzz_score < min_threshold + 0.06:
  162. # cosine_score = cosine_score * 0.94
  163. # # 余弦相似度折算后阈值判断
  164. # if cosine_score < filter_threshold:
  165. # continue
  166. # retrieve_value = [label_data["id"], int(cosine_score * 100) / 100]
  167. # retrieve_value_dict["semantics"].append(retrieve_value)
  168. # # 进行编辑距离得分验证,若小于设定分则过滤
  169. # if fuzz_score >= filter_threshold:
  170. # retrieve_value = [label_data["id"], fuzz_score]
  171. # retrieve_value_dict["text"].append(retrieve_value)
  172. # # 将组合结果按照score降序排序并取得分前十个结果
  173. # retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)
  174. # for k,value in retrieve_value_dict.items()}
  175. # # 综合排序
  176. # synthese_list = sorted(sum(retrieve_sort_dict.values(), []), key=lambda x: x[1], reverse=True)
  177. # synthese_set = set()
  178. # for ele in synthese_list:
  179. # if ele[0] not in synthese_set and len(retrieve_sort_dict["synthese"]) < 50:
  180. # synthese_set.add(ele[0])
  181. # retrieve_sort_dict["synthese"].append(ele)
  182. # # 加入题目序号
  183. # retrieve_sort_dict["topic_num"] = topic_num
  184. # # 以字典形式返回最终查重结果
  185. # return retrieve_sort_dict
  186. # # 多线程HNSW查重
  187. # with ThreadPoolExecutor(max_workers=5) as executor:
  188. # retrieve_res_list = list(executor.map(dup_search, retrieve_list, sent_vec_list, cont_clear_list))
  189. # return retrieve_res_list
  190. # HNSW查(支持多学科混合查重)
  191. def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0.56):
  192. # 计算retrieve_list的vec值
  193. # 调用清洗分词函数和句向量计算函数
  194. sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
  195. # HNSW查重
  196. retrieve_res_list = []
  197. for i,sent_vec in enumerate(sent_vec_list):
  198. # 初始化返回数据类型
  199. retrieve_value_dict = dict(synthese=[], semantics=[], text=[], image=[])
  200. # 获取题目序号
  201. topic_num = retrieve_list[i]["topic_num"] if "topic_num" in retrieve_list[i] else 1
  202. # 图片搜索查重功能
  203. if doc_flag is True:
  204. retrieve_value_dict["image"] = self.img_retrieve(retrieve_list[i]["stem"], post_url, similar, topic_num)
  205. else:
  206. retrieve_value_dict["image"] = []
  207. # 判断句向量维度
  208. if sent_vec.size != self.vector_dim:
  209. retrieve_res_list.append(retrieve_value_dict)
  210. continue
  211. # 调用hnsw接口检索数据
  212. post_list = sent_vec.tolist()
  213. try:
  214. query_labels = requests.post(self.hnsw_retrieve_url, json=post_list, timeout=10).json()
  215. except Exception as e:
  216. query_labels = []
  217. # 日志采集
  218. if self.logger is not None:
  219. self.logger.error(self.log_msg.format(id="HNSW检索error",
  220. type="当前题目HNSW检索error",
  221. message=cont_clear_list[i]))
  222. if len(query_labels) == 0:
  223. retrieve_res_list.append(retrieve_value_dict)
  224. continue
  225. # 批量读取数据库
  226. mongo_find_dict = {"id": {"$in": query_labels}}
  227. query_dataset = self.mongo_coll.find(mongo_find_dict)
  228. # 返回大于阈值的结果
  229. filter_threshold = similar
  230. for label_data in query_dataset:
  231. if "sentence_vec" not in label_data:
  232. continue
  233. # 计算余弦相似度得分
  234. label_vec = pickle.loads(label_data["sentence_vec"])
  235. if label_vec.size != self.vector_dim:
  236. continue
  237. cosine_score = util.cos_sim(sent_vec, label_vec)[0][0]
  238. # 阈值判断
  239. if cosine_score < filter_threshold:
  240. continue
  241. # 计算编辑距离得分
  242. fuzz_score = fuzz.ratio(cont_clear_list[i], label_data["content_clear"]) / 100
  243. if fuzz_score < min_threshold:
  244. continue
  245. # 对余弦相似度进行折算
  246. if cosine_score >= 0.91 and fuzz_score < min_threshold + 0.06:
  247. cosine_score = cosine_score * 0.95
  248. elif cosine_score < 0.91 and fuzz_score < min_threshold + 0.06:
  249. cosine_score = cosine_score * 0.94
  250. # 余弦相似度折算后阈值判断
  251. if cosine_score < filter_threshold:
  252. continue
  253. retrieve_value = [label_data["id"], int(cosine_score * 100) / 100]
  254. retrieve_value_dict["semantics"].append(retrieve_value)
  255. # 进行编辑距离得分验证,若小于设定分则过滤
  256. if fuzz_score >= filter_threshold:
  257. retrieve_value = [label_data["id"], fuzz_score]
  258. retrieve_value_dict["text"].append(retrieve_value)
  259. # 将组合结果按照score降序排序并取得分前十个结果
  260. retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)
  261. for k,value in retrieve_value_dict.items()}
  262. # 综合排序
  263. synthese_list = sorted(sum(retrieve_sort_dict.values(), []), key=lambda x: x[1], reverse=True)
  264. synthese_set = set()
  265. for ele in synthese_list:
  266. if ele[0] not in synthese_set and len(retrieve_sort_dict["synthese"]) < 50:
  267. synthese_set.add(ele[0])
  268. retrieve_sort_dict["synthese"].append(ele)
  269. # 加入题目序号
  270. retrieve_sort_dict["topic_num"] = topic_num
  271. # 以字典形式返回最终查重结果
  272. retrieve_res_list.append(retrieve_sort_dict)
  273. return retrieve_res_list
  274. if __name__ == "__main__":
  275. # 获取mongodb数据
  276. mongo_coll = config.mongo_coll
  277. from data_preprocessing import DataPreProcessing
  278. hnsw = HNSW(DataPreProcessing())
  279. test_data = []
  280. for idx in [201511100736265]:
  281. test_data.append(mongo_coll.find_one({"id": idx}))
  282. res = hnsw.retrieve(test_data, '', 0.8, False)
  283. pprint(res[0]["semantics"])
  284. # # 公式搜索查重功能
  285. # formula_string = "ρ蜡=0.9*10^3Kg/m^3"
  286. # formula_string = "p蜡=0.9*10^3Kq/m^3"
  287. # print(hnsw.formula_retrieve(formula_string, 0.8))