hnsw_retrieval.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. import json
  2. import pickle
  3. import requests
  4. import numpy as np
  5. from fuzzywuzzy import fuzz
  6. from sentence_transformers import util
  7. from concurrent.futures import ThreadPoolExecutor
  8. from pprint import pprint
  9. import config
  10. from formula_process import formula_recognize
  11. from comprehensive_score import Comprehensive_Score
  12. class HNSW():
  13. def __init__(self, data_process, logger=None):
  14. # 配置初始数据
  15. self.mongo_coll = config.mongo_coll
  16. self.vector_dim = config.vector_dim
  17. self.hnsw_retrieve_url = config.hnsw_retrieve_url
  18. # 日志采集
  19. self.logger = logger
  20. self.log_msg = config.log_msg
  21. # 数据预处理实例化
  22. self.dpp = data_process
  23. # 语义相似度实例化
  24. self.cph_score = Comprehensive_Score()
  25. # 加载公式处理数据模型(词袋模型/原始向量/原始数据)
  26. with open(config.bow_model_path, "rb") as bm:
  27. self.bow_model = pickle.load(bm)
  28. self.bow_vector = np.load(config.bow_vector_path)
  29. with open(config.formula_data_path, 'r', encoding='utf8', errors='ignore') as f:
  30. self.formula_id_list = json.load(f)
  31. # 图片搜索查重功能
  32. def img_retrieve(self, retrieve_text, post_url, similar, topic_num):
  33. try:
  34. if post_url is not None:
  35. # 日志采集
  36. if self.logger is not None:
  37. self.logger.info(self.log_msg.format(id="图片搜索查重",
  38. type="{}图片搜索查重post".format(topic_num),
  39. message=retrieve_text))
  40. img_dict = dict(img_url=retrieve_text, img_threshold=similar, img_max_num=40)
  41. img_res = requests.post(post_url, json=img_dict, timeout=30).json()
  42. # 日志采集
  43. if self.logger is not None:
  44. self.logger.info(self.log_msg.format(id="图片搜索查重",
  45. type="{}图片搜索查重success".format(topic_num),
  46. message=img_res))
  47. return img_res
  48. except Exception as e:
  49. # 日志采集
  50. if self.logger is not None:
  51. self.logger.error(self.log_msg.format(id="图片搜索查重",
  52. type="{}图片搜索查重error".format(topic_num),
  53. message=retrieve_text))
  54. return []
  55. # 公式搜索查重功能
  56. def formula_retrieve(self, formula_string, similar):
  57. # 调用清洗分词函数
  58. formula_string = '$' + formula_string + '$'
  59. formula_string = self.dpp.content_clear_func(formula_string)
  60. # 公式识别
  61. formula_list = formula_recognize(formula_string)
  62. if len(formula_list) == 0:
  63. return []
  64. # 日志采集
  65. if self.logger is not None:
  66. self.logger.info(self.log_msg.format(id="formula_retrieve",
  67. type=formula_string,
  68. message=formula_list))
  69. try:
  70. # 使用词袋模型计算句向量
  71. bow_vec = self.bow_model.transform([formula_list[0]]).toarray().astype("float32")
  72. # 并行计算余弦相似度
  73. formula_cos = np.array(util.cos_sim(bow_vec, self.bow_vector))
  74. # 获取余弦值大于等于0.8的数据索引
  75. cos_list = np.where(formula_cos[0] >= similar)[0]
  76. except:
  77. return []
  78. if len(cos_list) == 0:
  79. return []
  80. # 根据阈值获取满足条件的题目id
  81. res_list = []
  82. # formula_threshold = similar
  83. formula_threshold = 0.7
  84. for idx in cos_list:
  85. fuzz_score = fuzz.ratio(formula_list[0], self.formula_id_list[idx][0]) / 100
  86. if fuzz_score >= formula_threshold:
  87. # res_list.append([self.formula_id_list[idx][1], fuzz_score])
  88. # 对余弦相似度进行折算
  89. cosine_score = formula_cos[0][idx]
  90. if 0.95 <= cosine_score < 0.98:
  91. cosine_score = cosine_score * 0.98
  92. elif cosine_score < 0.95:
  93. cosine_score = cosine_score * 0.95
  94. # 余弦相似度折算后阈值判断
  95. if cosine_score < similar:
  96. continue
  97. res_list.append([self.formula_id_list[idx][1], int(cosine_score * 100) / 100])
  98. # 根据分数对题目id排序并返回前50个
  99. res_sort_list = sorted(res_list, key=lambda x: x[1], reverse=True)[:80]
  100. formula_res_list = []
  101. fid_set = set()
  102. for ele in res_sort_list:
  103. for fid in ele[0]:
  104. if fid in fid_set:
  105. continue
  106. fid_set.add(fid)
  107. formula_res_list.append([fid, ele[1]])
  108. return formula_res_list[:50]
  109. # # HNSW查(支持多学科混合查重)
  110. # def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0.56):
  111. # # 计算retrieve_list的vec值
  112. # # 调用清洗分词函数和句向量计算函数
  113. # sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
  114. # # HNSW查重
  115. # def dup_search(retrieve_data, sent_vec, cont_clear):
  116. # # 初始化返回数据类型
  117. # retrieve_value_dict = dict(synthese=[], semantics=[], text=[], image=[])
  118. # # 获取题目序号
  119. # topic_num = retrieve_data["topic_num"] if "topic_num" in retrieve_data else 1
  120. # # 图片搜索查重功能
  121. # if doc_flag is True:
  122. # retrieve_value_dict["image"] = self.img_retrieve(retrieve_data["stem"], post_url, similar, topic_num)
  123. # else:
  124. # retrieve_value_dict["image"] = []
  125. # # 判断句向量维度
  126. # if sent_vec.size != self.vector_dim:
  127. # return retrieve_value_dict
  128. # # 调用hnsw接口检索数据
  129. # post_list = sent_vec.tolist()
  130. # try:
  131. # query_labels = requests.post(self.hnsw_retrieve_url, json=post_list, timeout=10).json()
  132. # except Exception as e:
  133. # query_labels = []
  134. # # 日志采集
  135. # if self.logger is not None:
  136. # self.logger.error(self.log_msg.format(id="HNSW检索error",
  137. # type="当前题目HNSW检索error",
  138. # message=cont_clear))
  139. # if len(query_labels) == 0:
  140. # return retrieve_value_dict
  141. # # 批量读取数据库
  142. # mongo_find_dict = {"id": {"$in": query_labels}}
  143. # query_dataset = self.mongo_coll.find(mongo_find_dict)
  144. # # 返回大于阈值的结果
  145. # filter_threshold = similar
  146. # for label_data in query_dataset:
  147. # if "sentence_vec" not in label_data:
  148. # continue
  149. # # 计算余弦相似度得分
  150. # label_vec = pickle.loads(label_data["sentence_vec"])
  151. # if label_vec.size != self.vector_dim:
  152. # continue
  153. # cosine_score = util.cos_sim(sent_vec, label_vec)[0][0]
  154. # # 阈值判断
  155. # if cosine_score < filter_threshold:
  156. # continue
  157. # # 计算编辑距离得分
  158. # fuzz_score = fuzz.ratio(cont_clear, label_data["content_clear"]) / 100
  159. # if fuzz_score < min_threshold:
  160. # continue
  161. # # 对余弦相似度进行折算
  162. # if cosine_score >= 0.91 and fuzz_score < min_threshold + 0.06:
  163. # cosine_score = cosine_score * 0.95
  164. # elif cosine_score < 0.91 and fuzz_score < min_threshold + 0.06:
  165. # cosine_score = cosine_score * 0.94
  166. # # 余弦相似度折算后阈值判断
  167. # if cosine_score < filter_threshold:
  168. # continue
  169. # retrieve_value = [label_data["id"], int(cosine_score * 100) / 100]
  170. # retrieve_value_dict["semantics"].append(retrieve_value)
  171. # # 进行编辑距离得分验证,若小于设定分则过滤
  172. # if fuzz_score >= filter_threshold:
  173. # retrieve_value = [label_data["id"], fuzz_score]
  174. # retrieve_value_dict["text"].append(retrieve_value)
  175. # # 将组合结果按照score降序排序并取得分前十个结果
  176. # retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)
  177. # for k,value in retrieve_value_dict.items()}
  178. # # 综合排序
  179. # synthese_list = sorted(sum(retrieve_sort_dict.values(), []), key=lambda x: x[1], reverse=True)
  180. # synthese_set = set()
  181. # for ele in synthese_list:
  182. # if ele[0] not in synthese_set and len(retrieve_sort_dict["synthese"]) < 50:
  183. # synthese_set.add(ele[0])
  184. # retrieve_sort_dict["synthese"].append(ele)
  185. # # 加入题目序号
  186. # retrieve_sort_dict["topic_num"] = topic_num
  187. # # 以字典形式返回最终查重结果
  188. # return retrieve_sort_dict
  189. # # 多线程HNSW查重
  190. # with ThreadPoolExecutor(max_workers=5) as executor:
  191. # retrieve_res_list = list(executor.map(dup_search, retrieve_list, sent_vec_list, cont_clear_list))
  192. # return retrieve_res_list
  193. # HNSW查(支持多学科混合查重)
  194. def retrieve(self, retrieve_list, post_url, similar, scale, doc_flag):
  195. # 计算retrieve_list的vec值
  196. # 调用清洗分词函数和句向量计算函数
  197. sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
  198. # HNSW查重
  199. retrieve_res_list = []
  200. for i,sent_vec in enumerate(sent_vec_list):
  201. # 初始化返回数据类型
  202. # retrieve_value_dict = dict(synthese=[], semantics=[], text=[], image=[])
  203. retrieve_value_dict = dict(semantics=[], text=[], image=[])
  204. # 获取题目序号
  205. topic_num = retrieve_list[i]["topic_num"] if "topic_num" in retrieve_list[i] else 1
  206. # 图片搜索查重功能
  207. if doc_flag is True:
  208. retrieve_value_dict["image"] = self.img_retrieve(retrieve_list[i]["stem"], post_url, similar, topic_num)
  209. else:
  210. retrieve_value_dict["image"] = []
  211. """
  212. 文本相似度特殊处理
  213. """
  214. # 判断句向量维度
  215. if sent_vec.size != self.vector_dim:
  216. retrieve_res_list.append(retrieve_value_dict)
  217. continue
  218. # 调用hnsw接口检索数据
  219. post_list = sent_vec.tolist()
  220. try:
  221. query_labels = requests.post(self.hnsw_retrieve_url, json=post_list, timeout=10).json()
  222. except Exception as e:
  223. query_labels = []
  224. # 日志采集
  225. if self.logger is not None:
  226. self.logger.error(self.log_msg.format(id="HNSW检索error",
  227. type="当前题目HNSW检索error",
  228. message=cont_clear_list[i]))
  229. if len(query_labels) == 0:
  230. retrieve_res_list.append(retrieve_value_dict)
  231. continue
  232. # 批量读取数据库
  233. mongo_find_dict = {"id": {"$in": query_labels}}
  234. query_dataset = self.mongo_coll.find(mongo_find_dict)
  235. ####################################### 语义相似度借靠 #######################################
  236. query_data = dict()
  237. ####################################### 语义相似度借靠 #######################################
  238. # 返回大于阈值的结果
  239. for label_data in query_dataset:
  240. if "sentence_vec" not in label_data:
  241. continue
  242. # 计算余弦相似度得分
  243. label_vec = pickle.loads(label_data["sentence_vec"])
  244. if label_vec.size != self.vector_dim:
  245. continue
  246. cosine_score = util.cos_sim(sent_vec, label_vec)[0][0]
  247. # 阈值判断
  248. if cosine_score < similar:
  249. continue
  250. # 计算编辑距离得分
  251. fuzz_score = fuzz.ratio(cont_clear_list[i], label_data["content_clear"]) / 100
  252. # 进行编辑距离得分验证,若小于设定分则过滤
  253. if fuzz_score >= similar:
  254. retrieve_value = [label_data["id"], fuzz_score]
  255. retrieve_value_dict["text"].append(retrieve_value)
  256. ####################################### 语义相似度借靠 #######################################
  257. max_mark_score = 0
  258. if cosine_score >= 0.9 and cosine_score > max_mark_score:
  259. max_mark_score = cosine_score
  260. query_data = label_data
  261. ####################################### 语义相似度借靠 #######################################
  262. """
  263. 语义相似度特殊处理
  264. """
  265. # 批量读取数据库
  266. knowledge_id_list = query_data["knowledge_id"] if query_data else []
  267. label_dict = dict()
  268. # label_dict["quesType"] = retrieve_list[i]["quesType"] if query_data else []
  269. label_dict["knowledge"] = query_data["knowledge"] if query_data else []
  270. label_dict["physical_scene"] = query_data["physical_scene"] if query_data else []
  271. label_dict["solving_type"] = query_data["solving_type"] if query_data else []
  272. label_dict["difficulty"] = float(query_data["difficulty"]) if query_data else 0
  273. label_dict["physical_quantity"] = query_data["physical_quantity"] if query_data else []
  274. # label_dict["image_semantics"] = query_data["image_semantics"] if query_data else []
  275. query_data["quesType"] = retrieve_list[i].get("quesType", '')
  276. if len(knowledge_id_list) > 0:
  277. relate_list = []
  278. for ele in knowledge_id_list:
  279. init_id = int(ele / 10) * 10
  280. last_id = self.cph_score.init_id2max_id[str(init_id)]
  281. relate_list.extend(np.arange(init_id + 1, last_id + 1).tolist())
  282. knowledge_id_list = relate_list
  283. mongo_find_dict = {"knowledge_id": {"$in": knowledge_id_list}}
  284. query_dataset = self.mongo_coll.find(mongo_find_dict)
  285. # 返回大于阈值的结果
  286. for refer_data in query_dataset:
  287. sum_score, score_dict = self.cph_score(query_data, refer_data, scale)
  288. if sum_score < similar:
  289. continue
  290. retrieve_value = [refer_data["id"], sum_score, score_dict]
  291. retrieve_value_dict["semantics"].append(retrieve_value)
  292. # 将组合结果按照score降序排序
  293. retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)
  294. for k,value in retrieve_value_dict.items()}
  295. # # 综合排序
  296. # synthese_list = sorted(sum(retrieve_sort_dict.values(), []), key=lambda x: x[1], reverse=True)
  297. # synthese_set = set()
  298. # for ele in synthese_list:
  299. # # 综合排序返回前50个
  300. # if ele[0] not in synthese_set and len(retrieve_sort_dict["synthese"]) < 50:
  301. # synthese_set.add(ele[0])
  302. # retrieve_sort_dict["synthese"].append(ele[:2])
  303. # 加入题目序号
  304. retrieve_sort_dict["label"] = label_dict
  305. retrieve_sort_dict["topic_num"] = topic_num
  306. # 以字典形式返回最终查重结果
  307. retrieve_res_list.append(retrieve_sort_dict)
  308. return retrieve_res_list
  309. if __name__ == "__main__":
  310. # 获取mongodb数据
  311. mongo_coll = config.mongo_coll
  312. from data_preprocessing import DataPreProcessing
  313. hnsw = HNSW(DataPreProcessing())
  314. test_data = []
  315. for idx in [201511100736265]:
  316. test_data.append(mongo_coll.find_one({"id": idx}))
  317. res = hnsw.retrieve(test_data, '', 0.8, False)
  318. pprint(res[0]["semantics"])
  319. # # 公式搜索查重功能
  320. # formula_string = "ρ蜡=0.9*10^3Kg/m^3"
  321. # formula_string = "p蜡=0.9*10^3Kq/m^3"
  322. # print(hnsw.formula_retrieve(formula_string, 0.8))