import json
import pickle
import requests
import numpy as np
from fuzzywuzzy import fuzz
from sentence_transformers import util
from pprint import pprint
import config
from formula_process import formula_recognize
from comprehensive_score import Comprehensive_Score
from physical_quantity_extract import physical_quantity_extract


class HNSW():
    """Duplicate-check retrieval engine combining HNSW semantic search,
    text edit-distance, image search and formula search.

    Results are returned per input topic as a dict with keys
    ``semantics`` / ``text`` / ``image`` / ``label`` / ``topic_num``.
    """

    def __init__(self, data_process, logger=None):
        """
        :param data_process: preprocessing callable/object; must be callable as
            ``data_process(retrieve_list, is_retrieve=True)`` and expose
            ``content_clear_func`` (cleaning + tokenizing).
        :param logger: optional logger; when None, log calls are skipped.
        """
        # Initial configuration data.
        self.mongo_coll = config.mongo_coll
        self.vector_dim = config.vector_dim
        self.hnsw_retrieve_url = config.hnsw_retrieve_url
        self.dim_classify_url = config.dim_classify_url
        # Logging.
        self.logger = logger
        self.log_msg = config.log_msg
        # Preprocessing instance.
        self.dpp = data_process
        # Semantic-similarity scorer instance.
        self.cph_score = Comprehensive_Score(config.dev_mode)
        # Mapping of difficulty labels to numeric values
        # (keys are runtime data — do not translate).
        self.difficulty_transfer = {"容易": 0.2, "较易": 0.4, "一般": 0.6, "较难": 0.8, "困难": 1.0}
        # Load formula-processing artifacts (bag-of-words model / raw vectors / raw data).
        with open(config.bow_model_path, "rb") as bm:
            self.bow_model = pickle.load(bm)
        self.bow_vector = np.load(config.bow_vector_path)
        with open(config.formula_data_path, 'r', encoding='utf8', errors='ignore') as f:
            self.formula_id_list = json.load(f)

    def img_retrieve(self, retrieve_text, post_url, similar, topic_num):
        """Image-search duplicate check.

        Posts the image URL to the external image-search service and returns
        its JSON result; returns ``[]`` on any failure or when ``post_url``
        is ``None`` (the original implicitly returned ``None`` in that case,
        which broke downstream sorting).
        """
        try:
            if post_url is not None:
                if self.logger is not None:
                    self.logger.info(self.log_msg.format(id="图片搜索查重",
                                                         type="{}图片搜索查重post".format(topic_num),
                                                         message=retrieve_text))
                img_dict = dict(img_url=retrieve_text, img_threshold=similar, img_max_num=40)
                img_res = requests.post(post_url, json=img_dict, timeout=30).json()
                if self.logger is not None:
                    self.logger.info(self.log_msg.format(id="图片搜索查重",
                                                         type="{}图片搜索查重success".format(topic_num),
                                                         message=img_res))
                return img_res
        except Exception:
            if self.logger is not None:
                self.logger.error(self.log_msg.format(id="图片搜索查重",
                                                      type="{}图片搜索查重error".format(topic_num),
                                                      message=retrieve_text))
        # Explicit empty result for both the no-URL and the error path.
        return []

    def formula_retrieve(self, formula_string, similar):
        """Formula-search duplicate check.

        Vectorizes the recognized formula with the bag-of-words model,
        filters candidates by cosine similarity and fuzzy edit distance,
        and returns up to 50 ``[formula_id, score]`` pairs sorted by score.
        """
        # Clean and tokenize; wrap in '$' so the cleaner treats it as a formula.
        formula_string = '$' + formula_string + '$'
        formula_string = self.dpp.content_clear_func(formula_string)
        # Formula recognition.
        formula_list = formula_recognize(formula_string)
        if len(formula_list) == 0:
            return []
        if self.logger is not None:
            self.logger.info(self.log_msg.format(id="formula_retrieve",
                                                 type=formula_string,
                                                 message=formula_list))
        try:
            # Sentence vector from the bag-of-words model.
            bow_vec = self.bow_model.transform([formula_list[0]]).toarray().astype("float32")
            # Batched cosine similarity against all stored formula vectors.
            formula_cos = np.array(util.cos_sim(bow_vec, self.bow_vector))
            # Indices of candidates whose cosine meets the threshold.
            cos_list = np.where(formula_cos[0] >= similar)[0]
        except Exception:  # narrowed from a bare except: don't swallow SystemExit/KeyboardInterrupt
            return []
        if len(cos_list) == 0:
            return []
        # Collect candidate ids whose fuzzy score also passes the fixed threshold.
        res_list = []
        formula_threshold = 0.7
        for idx in cos_list:
            fuzz_score = fuzz.ratio(formula_list[0], self.formula_id_list[idx][0]) / 100
            if fuzz_score < formula_threshold:
                continue
            # Discount the cosine score in the high band so near-identical
            # formulas still rank above merely-similar ones.
            cosine_score = formula_cos[0][idx]
            if 0.95 <= cosine_score < 0.98:
                cosine_score = cosine_score * 0.98
            elif cosine_score < 0.95:
                cosine_score = cosine_score * 0.95
            # Re-check the threshold after discounting.
            if cosine_score < similar:
                continue
            # Truncate (not round) to two decimals.
            res_list.append([self.formula_id_list[idx][1], int(cosine_score * 100) / 100])
        # Sort by score descending; cap the candidate pool at 80 before dedup.
        res_sort_list = sorted(res_list, key=lambda x: x[1], reverse=True)[:80]
        formula_res_list = []
        fid_set = set()
        for ele in res_sort_list:
            for fid in ele[0]:
                if fid in fid_set:
                    continue
                fid_set.add(fid)
                formula_res_list.append([fid, ele[1]])
        return formula_res_list[:50]

    def retrieve(self, retrieve_list, post_url, similar, scale, doc_flag):
        """HNSW duplicate check (supports mixed-subject retrieval).

        :param retrieve_list: list of topic dicts (expects at least ``stem``;
            optional ``topic_num`` and ``quesType``).
        :param post_url: image-search service URL (used when ``doc_flag``).
        :param similar: similarity threshold for all channels.
        :param scale: weighting passed through to ``Comprehensive_Score``.
        :param doc_flag: when True, also run the image-search channel.

        return: [
            {
                'semantics': [[20232015, 1.0, {'quesType': 1.0, 'knowledge': 1.0,
                               'physical_scene': 1.0, 'solving_type': 1.0,
                               'difficulty': 1.0, 'physical_quantity': 1.0}]],
                'text': [[20232015, 0.97]],
                'image': [],
                'label': {'knowledge': ['串并联电路的辨别'],
                          'physical_scene': ['串并联电路的辨别'],
                          'solving_type': ['规律理解'],
                          'difficulty': 0.6,
                          'physical_quantity': ['电流']},
                'topic_num': 1
            },
            ...
        ]
        """
        # Clean/tokenize and compute sentence vectors for all inputs at once.
        sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
        retrieve_res_list = []
        for i, sent_vec in enumerate(sent_vec_list):
            # Per-topic result container.
            retrieve_value_dict = dict(semantics=[], text=[], image=[])
            # Topic sequence number (defaults to 1).
            topic_num = retrieve_list[i].get("topic_num", 1)
            # Image-search channel.
            if doc_flag is True:
                retrieve_value_dict["image"] = self.img_retrieve(retrieve_list[i]["stem"],
                                                                post_url, similar, topic_num)
            else:
                retrieve_value_dict["image"] = []

            """ Text-similarity special handling """
            # Skip topics whose sentence vector has the wrong dimension.
            if sent_vec.size != self.vector_dim:
                retrieve_res_list.append(retrieve_value_dict)
                continue
            # Query the HNSW retrieval service for candidate ids.
            try:
                hnsw_post_list = sent_vec.tolist()
                query_labels = requests.post(self.hnsw_retrieve_url,
                                             json=hnsw_post_list, timeout=10).json()
            except Exception:
                query_labels = []
                if self.logger is not None:
                    self.logger.error(self.log_msg.format(id="HNSW检索error",
                                                          type="当前题目HNSW检索error",
                                                          message=cont_clear_list[i]))
            if len(query_labels) == 0:
                retrieve_res_list.append(retrieve_value_dict)
                continue
            # Batch-read candidate documents from MongoDB.
            mongo_find_dict = {"id": {"$in": query_labels}}
            query_dataset = self.mongo_coll.find(mongo_find_dict)

            # Semantic label borrowing: remember the single best match with
            # cosine >= 0.9 so its labels can seed semantic scoring below.
            # BUGFIX: max_mark_score is initialized once before the loop;
            # the original reset it every iteration, which kept the *last*
            # qualifying match instead of the best one.
            query_data = dict()
            max_mark_score = 0
            for label_data in query_dataset:
                if "sentence_vec" not in label_data:
                    continue
                # Cosine similarity between query and candidate vectors.
                label_vec = pickle.loads(label_data["sentence_vec"])
                if label_vec.size != self.vector_dim:
                    continue
                cosine_score = util.cos_sim(sent_vec, label_vec)[0][0]
                if cosine_score < similar:
                    continue
                # Edit-distance validation; filter below threshold.
                fuzz_score = fuzz.ratio(cont_clear_list[i], label_data["content_clear"]) / 100
                if fuzz_score >= similar:
                    retrieve_value_dict["text"].append([label_data["id"], fuzz_score])
                if cosine_score >= 0.9 and cosine_score > max_mark_score:
                    max_mark_score = cosine_score
                    query_data = label_data

            """ Semantic-similarity special handling """
            label_dict = dict()
            # Knowledge labels borrowed from the best semantic match
            # (.get guards documents without a "knowledge" field).
            label_dict["knowledge"] = query_data.get("knowledge", []) if query_data else []
            tagging_id_list = [self.cph_score.knowledge2id[ele]
                               for ele in label_dict["knowledge"]
                               if ele in self.cph_score.knowledge2id]
            # Question type from the input (default is runtime data — keep as is).
            label_dict["quesType"] = retrieve_list[i].get("quesType", "选择题")
            # Multi-dimension classification API (solving type + difficulty).
            try:
                dim_post_list = {"sentence": cont_clear_list[i],
                                 "quesType": label_dict["quesType"]}
                dim_classify_dict = requests.post(self.dim_classify_url,
                                                  json=dim_post_list, timeout=10).json()
            except Exception:
                # Fallback defaults when the service is unavailable.
                dim_classify_dict = {"solving_type": ["规律理解"], "difficulty": 0.6}
                if self.logger is not None:
                    self.logger.error(self.log_msg.format(id="多维分类error",
                                                          type="当前题目多维分类error",
                                                          message=cont_clear_list[i]))
            label_dict["solving_type"] = dim_classify_dict["solving_type"]
            label_dict["difficulty"] = dim_classify_dict["difficulty"]
            # Rule-based physical-quantity extraction.
            label_dict["physical_quantity"] = physical_quantity_extract(cont_clear_list[i])

            # Expand tagged knowledge ids to all ids sharing the same root id.
            knowledge_id_list = []
            if len(tagging_id_list) > 0:
                # encode_base_value: 10000 for "ksy" (exam board), 10 otherwise.
                encode_base_value = 10000 if config.dev_mode == "ksy" else 10
                for ele in tagging_id_list:
                    init_id = int(ele / encode_base_value) * encode_base_value
                    init_id_list = self.cph_score.init_id2max_id.get(str(init_id), [])
                    knowledge_id_list.extend(init_id_list)
            knowledge_query_dataset = None
            if len(knowledge_id_list) > 0:
                mongo_find_dict = {"knowledge_id": {"$in": knowledge_id_list}}
                knowledge_query_dataset = self.mongo_coll.find(mongo_find_dict)
            # Score same-knowledge candidates; keep those above the threshold.
            if knowledge_query_dataset:
                for refer_data in knowledge_query_dataset:
                    # Convert difficulty label to its numeric value.
                    if refer_data["difficulty"] in self.difficulty_transfer:
                        refer_data["difficulty"] = self.difficulty_transfer[refer_data["difficulty"]]
                    sum_score, score_dict = self.cph_score(label_dict, refer_data, scale)
                    if sum_score < similar:
                        continue
                    retrieve_value_dict["semantics"].append([refer_data["id"], sum_score, score_dict])

            # Sort each channel by score, descending.
            retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)
                                  for k, value in retrieve_value_dict.items()}
            # Attach labels and topic sequence number.
            retrieve_sort_dict["label"] = label_dict
            retrieve_sort_dict["topic_num"] = topic_num
            retrieve_res_list.append(retrieve_sort_dict)
        return retrieve_res_list


if __name__ == "__main__":
    # Fetch test data from MongoDB.
    mongo_coll = config.mongo_coll
    from data_preprocessing import DataPreProcessing
    hnsw = HNSW(DataPreProcessing())
    test_data = []
    for idx in [201511100736265]:
        test_data.append(mongo_coll.find_one({"id": idx}))
    # BUGFIX: the original call passed only 4 arguments to the 5-parameter
    # retrieve() (scale was missing) and raised TypeError.
    # NOTE(review): {} is a placeholder scale — confirm the structure
    # Comprehensive_Score expects.
    res = hnsw.retrieve(test_data, '', 0.8, {}, False)
    pprint(res[0]["semantics"])

    # # Formula-search duplicate check.
    # formula_string = "ρ蜡=0.9*10^3Kg/m^3"
    # formula_string = "p蜡=0.9*10^3Kq/m^3"
    # print(hnsw.formula_retrieve(formula_string, 0.8))