# hnsw_retrieval.py
  1. import json
  2. import pickle
  3. import requests
  4. import numpy as np
  5. from fuzzywuzzy import fuzz
  6. from sentence_transformers import util
  7. from pprint import pprint
  8. import config
  9. from formula_process import formula_recognize
  10. from comprehensive_score import Comprehensive_Score
  11. from physical_quantity_extract import physical_quantity_extract
  12. class HNSW():
  13. def __init__(self, data_process, logger=None):
  14. # 配置初始数据
  15. self.mongo_coll = config.mongo_coll
  16. self.vector_dim = config.vector_dim
  17. self.hnsw_retrieve_url = config.hnsw_retrieve_url
  18. self.dim_classify_url = config.dim_classify_url
  19. # 日志采集
  20. self.logger = logger
  21. self.log_msg = config.log_msg
  22. # 数据预处理实例化
  23. self.dpp = data_process
  24. # 语义相似度实例化
  25. self.cph_score = Comprehensive_Score(config.dev_mode)
  26. # 难度数值化定义
  27. self.difficulty_transfer = {"容易": 0.2, "较易": 0.4, "一般": 0.6, "较难": 0.8, "困难": 1.0}
  28. # 加载公式处理数据模型(词袋模型/原始向量/原始数据)
  29. with open(config.bow_model_path, "rb") as bm:
  30. self.bow_model = pickle.load(bm)
  31. self.bow_vector = np.load(config.bow_vector_path)
  32. with open(config.formula_data_path, 'r', encoding='utf8', errors='ignore') as f:
  33. self.formula_id_list = json.load(f)
  34. # 图片搜索查重功能
  35. def img_retrieve(self, retrieve_text, post_url, similar, topic_num):
  36. try:
  37. if post_url is not None:
  38. # 日志采集
  39. if self.logger is not None:
  40. self.logger.info(self.log_msg.format(id="图片搜索查重",
  41. type="{}图片搜索查重post".format(topic_num),
  42. message=retrieve_text))
  43. img_dict = dict(img_url=retrieve_text, img_threshold=similar, img_max_num=40)
  44. img_res = requests.post(post_url, json=img_dict, timeout=30).json()
  45. # 日志采集
  46. if self.logger is not None:
  47. self.logger.info(self.log_msg.format(id="图片搜索查重",
  48. type="{}图片搜索查重success".format(topic_num),
  49. message=img_res))
  50. return img_res
  51. except Exception as e:
  52. # 日志采集
  53. if self.logger is not None:
  54. self.logger.error(self.log_msg.format(id="图片搜索查重",
  55. type="{}图片搜索查重error".format(topic_num),
  56. message=retrieve_text))
  57. return []
  58. # 公式搜索查重功能
  59. def formula_retrieve(self, formula_string, similar):
  60. # 调用清洗分词函数
  61. formula_string = '$' + formula_string + '$'
  62. formula_string = self.dpp.content_clear_func(formula_string)
  63. # 公式识别
  64. formula_list = formula_recognize(formula_string)
  65. if len(formula_list) == 0:
  66. return []
  67. # 日志采集
  68. if self.logger is not None:
  69. self.logger.info(self.log_msg.format(id="formula_retrieve",
  70. type=formula_string,
  71. message=formula_list))
  72. try:
  73. # 使用词袋模型计算句向量
  74. bow_vec = self.bow_model.transform([formula_list[0]]).toarray().astype("float32")
  75. # 并行计算余弦相似度
  76. formula_cos = np.array(util.cos_sim(bow_vec, self.bow_vector))
  77. # 获取余弦值大于等于0.8的数据索引
  78. cos_list = np.where(formula_cos[0] >= similar)[0]
  79. except:
  80. return []
  81. if len(cos_list) == 0:
  82. return []
  83. # 根据阈值获取满足条件的题目id
  84. res_list = []
  85. # formula_threshold = similar
  86. formula_threshold = 0.7
  87. for idx in cos_list:
  88. fuzz_score = fuzz.ratio(formula_list[0], self.formula_id_list[idx][0]) / 100
  89. if fuzz_score >= formula_threshold:
  90. # res_list.append([self.formula_id_list[idx][1], fuzz_score])
  91. # 对余弦相似度进行折算
  92. cosine_score = formula_cos[0][idx]
  93. if 0.95 <= cosine_score < 0.98:
  94. cosine_score = cosine_score * 0.98
  95. elif cosine_score < 0.95:
  96. cosine_score = cosine_score * 0.95
  97. # 余弦相似度折算后阈值判断
  98. if cosine_score < similar:
  99. continue
  100. res_list.append([self.formula_id_list[idx][1], int(cosine_score * 100) / 100])
  101. # 根据分数对题目id排序并返回前50个
  102. res_sort_list = sorted(res_list, key=lambda x: x[1], reverse=True)[:80]
  103. formula_res_list = []
  104. fid_set = set()
  105. for ele in res_sort_list:
  106. for fid in ele[0]:
  107. if fid in fid_set:
  108. continue
  109. fid_set.add(fid)
  110. formula_res_list.append([fid, ele[1]])
  111. return formula_res_list[:50]
  112. # HNSW查(支持多学科混合查重)
  113. def retrieve(self, retrieve_list, post_url, similar, scale, doc_flag):
  114. """
  115. return:
  116. [
  117. {
  118. 'semantics': [[20232015, 1.0, {'quesType': 1.0, 'knowledge': 1.0, 'physical_scene': 1.0, 'solving_type': 1.0, 'difficulty': 1.0, 'physical_quantity': 1.0}]],
  119. 'text': [[20232015, 0.97]],
  120. 'image': [],
  121. 'label': {'knowledge': ['串并联电路的辨别'], 'physical_scene': ['串并联电路的辨别'], 'solving_type': ['规律理解'], 'difficulty': 0.6, 'physical_quantity': ['电流']},
  122. 'topic_num': 1
  123. },
  124. ...
  125. ]
  126. """
  127. # 计算retrieve_list的vec值
  128. # 调用清洗分词函数和句向量计算函数
  129. sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
  130. # HNSW查重
  131. retrieve_res_list = []
  132. for i,sent_vec in enumerate(sent_vec_list):
  133. # 初始化返回数据类型
  134. # retrieve_value_dict = dict(synthese=[], semantics=[], text=[], image=[])
  135. retrieve_value_dict = dict(semantics=[], text=[], image=[])
  136. # 获取题目序号
  137. topic_num = retrieve_list[i]["topic_num"] if "topic_num" in retrieve_list[i] else 1
  138. # 图片搜索查重功能
  139. if doc_flag is True:
  140. retrieve_value_dict["image"] = self.img_retrieve(retrieve_list[i]["stem"], post_url, similar, topic_num)
  141. else:
  142. retrieve_value_dict["image"] = []
  143. """
  144. 文本相似度特殊处理
  145. """
  146. # 判断句向量维度
  147. if sent_vec.size != self.vector_dim:
  148. retrieve_res_list.append(retrieve_value_dict)
  149. continue
  150. # 调用hnsw接口检索数据
  151. try:
  152. hnsw_post_list = sent_vec.tolist()
  153. query_labels = requests.post(self.hnsw_retrieve_url, json=hnsw_post_list, timeout=10).json()
  154. except Exception as e:
  155. query_labels = []
  156. # 日志采集
  157. if self.logger is not None:
  158. self.logger.error(self.log_msg.format(id="HNSW检索error",
  159. type="当前题目HNSW检索error",
  160. message=cont_clear_list[i]))
  161. if len(query_labels) == 0:
  162. retrieve_res_list.append(retrieve_value_dict)
  163. continue
  164. # 批量读取数据库
  165. mongo_find_dict = {"id": {"$in": query_labels}}
  166. query_dataset = self.mongo_coll.find(mongo_find_dict)
  167. ####################################### 语义相似度借靠 #######################################
  168. query_data = dict()
  169. ####################################### 语义相似度借靠 #######################################
  170. # 返回大于阈值的结果
  171. for label_data in query_dataset:
  172. if "sentence_vec" not in label_data:
  173. continue
  174. # 计算余弦相似度得分
  175. label_vec = pickle.loads(label_data["sentence_vec"])
  176. if label_vec.size != self.vector_dim:
  177. continue
  178. cosine_score = util.cos_sim(sent_vec, label_vec)[0][0]
  179. # 阈值判断
  180. if cosine_score < similar:
  181. continue
  182. # 计算编辑距离得分
  183. fuzz_score = fuzz.ratio(cont_clear_list[i], label_data["content_clear"]) / 100
  184. # 进行编辑距离得分验证,若小于设定分则过滤
  185. if fuzz_score >= similar:
  186. retrieve_value = [label_data["id"], fuzz_score]
  187. retrieve_value_dict["text"].append(retrieve_value)
  188. ####################################### 语义相似度借靠 #######################################
  189. max_mark_score = 0
  190. if cosine_score >= 0.9 and cosine_score > max_mark_score:
  191. max_mark_score = cosine_score
  192. query_data = label_data
  193. ####################################### 语义相似度借靠 #######################################
  194. """
  195. 语义相似度特殊处理
  196. """
  197. # 标签字典初始化
  198. label_dict = dict()
  199. # 知识点LLM标注
  200. # label_dict["knowledge"] = query_data["knowledge"] if query_data else []
  201. label_dict["knowledge"] = query_data["knowledge"] if query_data else []
  202. tagging_id_list = [self.cph_score.knowledge2id[ele] for ele in label_dict["knowledge"] \
  203. if ele in self.cph_score.knowledge2id]
  204. # 题型数据获取
  205. label_dict["quesType"] = retrieve_list[i].get("quesType", "选择题")
  206. # 多维分类api调用
  207. try:
  208. dim_post_list = {"sentence": cont_clear_list[i], "quesType": label_dict["quesType"]}
  209. dim_classify_dict = requests.post(self.dim_classify_url, json=dim_post_list, timeout=10).json()
  210. except Exception as e:
  211. dim_classify_dict = {"solving_type": ["规律理解"], "difficulty": 0.6}
  212. # 日志采集
  213. if self.logger is not None:
  214. self.logger.error(self.log_msg.format(id="多维分类error",
  215. type="当前题目多维分类error",
  216. message=cont_clear_list[i]))
  217. # 求解类型模型分类
  218. label_dict["solving_type"] = dim_classify_dict["solving_type"]
  219. # 难度模型分类
  220. label_dict["difficulty"] = dim_classify_dict["difficulty"]
  221. # 物理量规则提取
  222. label_dict["physical_quantity"] = physical_quantity_extract(cont_clear_list[i])
  223. # LLM标注知识点题目获取题库对应相似知识点题目数据
  224. knowledge_id_list = []
  225. if len(tagging_id_list) > 0:
  226. ####################################### encode_base_value设置 #######################################
  227. # 考试院: 10000, 风向标: 10
  228. encode_base_value = 10000 if config.dev_mode == "ksy" else 10
  229. ####################################### encode_base_value设置 #######################################
  230. for ele in tagging_id_list:
  231. init_id = int(ele / encode_base_value) * encode_base_value
  232. init_id_list = self.cph_score.init_id2max_id.get(str(init_id), [])
  233. knowledge_id_list.extend(init_id_list)
  234. knowledge_query_dataset = None
  235. if len(knowledge_id_list) > 0:
  236. mongo_find_dict = {"knowledge_id": {"$in": knowledge_id_list}}
  237. knowledge_query_dataset = self.mongo_coll.find(mongo_find_dict)
  238. # 返回大于阈值的结果
  239. if knowledge_query_dataset:
  240. for refer_data in knowledge_query_dataset:
  241. # 难度数值转换
  242. if refer_data["difficulty"] in self.difficulty_transfer:
  243. refer_data["difficulty"] = self.difficulty_transfer[refer_data["difficulty"]]
  244. sum_score, score_dict = self.cph_score(label_dict, refer_data, scale)
  245. if sum_score < similar:
  246. continue
  247. retrieve_value = [refer_data["id"], sum_score, score_dict]
  248. retrieve_value_dict["semantics"].append(retrieve_value)
  249. # 将组合结果按照score降序排序
  250. retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)
  251. for k,value in retrieve_value_dict.items()}
  252. # # 综合排序
  253. # synthese_list = sorted(sum(retrieve_sort_dict.values(), []), key=lambda x: x[1], reverse=True)
  254. # synthese_set = set()
  255. # for ele in synthese_list:
  256. # # 综合排序返回前50个
  257. # if ele[0] not in synthese_set and len(retrieve_sort_dict["synthese"]) < 50:
  258. # synthese_set.add(ele[0])
  259. # retrieve_sort_dict["synthese"].append(ele[:2])
  260. # 加入题目序号
  261. retrieve_sort_dict["label"] = label_dict
  262. retrieve_sort_dict["topic_num"] = topic_num
  263. # 以字典形式返回最终查重结果
  264. retrieve_res_list.append(retrieve_sort_dict)
  265. return retrieve_res_list
  266. if __name__ == "__main__":
  267. # 获取mongodb数据
  268. mongo_coll = config.mongo_coll
  269. from data_preprocessing import DataPreProcessing
  270. hnsw = HNSW(DataPreProcessing())
  271. test_data = []
  272. for idx in [201511100736265]:
  273. test_data.append(mongo_coll.find_one({"id": idx}))
  274. res = hnsw.retrieve(test_data, '', 0.8, False)
  275. pprint(res[0]["semantics"])
  276. # # 公式搜索查重功能
  277. # formula_string = "ρ蜡=0.9*10^3Kg/m^3"
  278. # formula_string = "p蜡=0.9*10^3Kq/m^3"
  279. # print(hnsw.formula_retrieve(formula_string, 0.8))