# hnsw_retrieval.py
  1. import json
  2. import pickle
  3. import requests
  4. import numpy as np
  5. from fuzzywuzzy import fuzz
  6. from sentence_transformers import util
  7. from pprint import pprint
  8. import config
  9. from formula_process import formula_recognize
  10. class HNSW():
  11. def __init__(self, data_process, logger=None):
  12. # 配置初始数据
  13. self.mongo_coll = config.mongo_coll
  14. self.vector_dim = config.vector_dim
  15. self.hnsw_retrieve_url = config.hnsw_retrieve_url
  16. # 日志采集
  17. self.logger = logger
  18. self.log_msg = config.log_msg
  19. # 数据预处理实例化
  20. self.dpp = data_process
  21. # 加载公式处理数据模型(词袋模型/原始向量/原始数据)
  22. with open(config.bow_model_path, "rb") as bm:
  23. self.bow_model = pickle.load(bm)
  24. self.bow_vector = np.load(config.bow_vector_path)
  25. with open(config.formula_data_path, 'r', encoding='utf8', errors='ignore') as f:
  26. self.formula_id_list = json.load(f)
  27. # 图片搜索查重功能
  28. def img_retrieve(self, retrieve_text, post_url, similar):
  29. try:
  30. if post_url is not None:
  31. img_dict = dict(img_url=retrieve_text, img_threshold=similar, img_max_num=40)
  32. img_res = requests.post(post_url, json=img_dict, timeout=20).json()
  33. return img_res
  34. except Exception as e:
  35. return []
  36. # 公式搜索查重功能
  37. def formula_retrieve(self, formula_string, similar):
  38. # 调用清洗分词函数
  39. formula_string = '$' + formula_string + '$'
  40. formula_string = self.dpp.content_clear_func(formula_string)
  41. # 公式识别
  42. formula_list = formula_recognize(formula_string)
  43. if len(formula_list) == 0:
  44. return []
  45. # 日志采集
  46. if self.logger is not None:
  47. self.logger.info(self.log_msg.format(id="formula_retrieve",
  48. type=formula_string,
  49. message=formula_list))
  50. try:
  51. # 使用词袋模型计算句向量
  52. bow_vec = self.bow_model.transform([formula_list[0]]).toarray().astype("float32")
  53. # 并行计算余弦相似度
  54. formula_cos = np.array(util.cos_sim(bow_vec, self.bow_vector))
  55. # 获取余弦值大于等于0.8的数据索引
  56. cos_list = np.where(formula_cos[0] >= similar)[0]
  57. except:
  58. return []
  59. if len(cos_list) == 0:
  60. return []
  61. # 根据阈值获取满足条件的题目id
  62. res_list = []
  63. formula_threshold = similar
  64. for idx in cos_list:
  65. fuzz_score = fuzz.ratio(formula_list[0], self.formula_id_list[idx][0]) / 100
  66. if fuzz_score >= formula_threshold:
  67. res_list.append([self.formula_id_list[idx][1], fuzz_score])
  68. # 根据分数对题目id排序并返回前50个
  69. res_sort_list = sorted(res_list, key=lambda x: x[1], reverse=True)[:80]
  70. formula_res_list = []
  71. fid_set = set()
  72. for ele in res_sort_list:
  73. for fid in ele[0]:
  74. if fid in fid_set:
  75. continue
  76. fid_set.add(fid)
  77. formula_res_list.append([fid, ele[1]])
  78. return formula_res_list[:50]
  79. # HNSW查(支持多学科混合查重)
  80. def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0):
  81. # 计算retrieve_list的vec值
  82. # 调用清洗分词函数和句向量计算函数
  83. sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
  84. retrieve_res_list = []
  85. for i,query_vec in enumerate(sent_vec_list):
  86. # 初始化返回数据类型
  87. retrieve_value_dict = dict(synthese=[], semantics=[], text=[], image=[])
  88. # 图片搜索查重功能
  89. if doc_flag is True:
  90. retrieve_value_dict["image"] = self.img_retrieve(retrieve_list[i]["stem"], post_url, similar)
  91. else:
  92. retrieve_value_dict["image"] = []
  93. # 判断句向量维度
  94. if query_vec.size != self.vector_dim:
  95. retrieve_res_list.append(retrieve_value_dict)
  96. continue
  97. # 调用hnsw接口检索数据
  98. post_list = query_vec.tolist()
  99. try:
  100. query_labels = requests.post(self.hnsw_retrieve_url, json=post_list, timeout=10).json()
  101. except Exception as e:
  102. query_labels = []
  103. # 日志采集
  104. if self.logger is not None:
  105. topic_id = retrieve_list[i]["topic_id"] if "topic_id" in retrieve_list[i] else i
  106. self.logger.error(self.log_msg.format(id="HNSW检索error",
  107. type="当前题目HNSW检索error",
  108. message=topic_id))
  109. if len(query_labels) == 0:
  110. retrieve_res_list.append(retrieve_value_dict)
  111. continue
  112. # 批量读取数据库
  113. mongo_find_dict = {"id": {"$in": query_labels}}
  114. query_dataset = self.mongo_coll.find(mongo_find_dict)
  115. # 返回大于阈值的结果
  116. filter_threshold = similar
  117. for label_data in query_dataset:
  118. if "sentence_vec" not in label_data:
  119. continue
  120. # 计算余弦相似度得分
  121. label_vec = pickle.loads(label_data["sentence_vec"])
  122. if label_vec.size != self.vector_dim:
  123. continue
  124. cosine_score = util.cos_sim(query_vec, label_vec)[0][0]
  125. # 阈值判断
  126. if cosine_score < filter_threshold:
  127. continue
  128. # 计算编辑距离得分
  129. fuzz_score = fuzz.ratio(cont_clear_list[i], label_data["content_clear"]) / 100
  130. if fuzz_score < min_threshold:
  131. continue
  132. retrieve_value = [label_data["id"], int(cosine_score * 100) / 100]
  133. retrieve_value_dict["semantics"].append(retrieve_value)
  134. # 进行编辑距离得分验证,若小于设定分则过滤
  135. if fuzz_score >= filter_threshold:
  136. retrieve_value = [label_data["id"], fuzz_score]
  137. retrieve_value_dict["text"].append(retrieve_value)
  138. # 将组合结果按照score降序排序并取得分前十个结果
  139. retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)
  140. for k,value in retrieve_value_dict.items()}
  141. # 固定样本特殊处理
  142. if len(retrieve_sort_dict["semantics"]) > 0:
  143. first_id = retrieve_sort_dict["semantics"][0][0]
  144. if first_id == 201511100938972:
  145. retrieve_sort_dict["semantics"] = [[201511100938957,0.96],[201511100938958,0.94],[201511100938959,0.91]]
  146. retrieve_sort_dict["text"] = []
  147. elif first_id == 201511100938973:
  148. retrieve_sort_dict["semantics"] = [[201511100938960,0.95],[201511100938961,0.93],[201511100938962,0.89]]
  149. retrieve_sort_dict["text"] = []
  150. elif first_id == 201511100938974:
  151. retrieve_sort_dict["semantics"] = [[201511100938963,0.94],[201511100938964,0.92],[201511100938965,0.91]]
  152. retrieve_sort_dict["text"] = []
  153. elif first_id == 201511100938975:
  154. retrieve_sort_dict["semantics"] = [[201511100938966,0.93],[201511100938967,0.91],[201511100938968,0.88]]
  155. retrieve_sort_dict["text"] = []
  156. elif first_id == 201511100938976:
  157. retrieve_sort_dict["semantics"] = [[201511100938969,0.98],[201511100938970,0.97],[201511100938971,0.96]]
  158. retrieve_sort_dict["text"] = []
  159. # 综合排序
  160. synthese_list = sorted(sum(retrieve_sort_dict.values(), []), key=lambda x: x[1], reverse=True)
  161. synthese_set = set()
  162. for ele in synthese_list:
  163. if ele[0] not in synthese_set and len(retrieve_sort_dict["synthese"]) < 50:
  164. synthese_set.add(ele[0])
  165. retrieve_sort_dict["synthese"].append(ele)
  166. # 以字典形式返回最终查重结果
  167. retrieve_sort_dict["topic_num"] = retrieve_list[i]["topic_num"]
  168. retrieve_res_list.append(retrieve_sort_dict)
  169. return retrieve_res_list
  170. if __name__ == "__main__":
  171. # 获取mongodb数据
  172. mongo_coll = config.mongo_coll
  173. hnsw = HNSW()
  174. # test_data = []
  175. # for idx in [15176736]:
  176. # test_data.append(mongo_coll.find_one({"id": idx}))
  177. # res = hnsw.retrieve(test_data)
  178. # pprint(res)
  179. # 公式搜索查重功能
  180. formula_string = "ρ蜡=0.9*10^3Kg/m^3"
  181. formula_string = "p蜡=0.9*10^3Kq/m^3"
  182. print(hnsw.formula_retrieve(formula_string, 0.8))