- import json
- import pickle
- import requests
- import numpy as np
- from fuzzywuzzy import fuzz
- from sentence_transformers import util
- from pprint import pprint
- import config
- from formula_process import formula_recognize
- from comprehensive_score import Comprehensive_Score
- from physical_quantity_extract import physical_quantity_extract
class HNSW():
    """Duplicate-question retrieval engine (text / image / formula / semantics)."""

    def __init__(self, data_process, logger=None):
        """Wire up configuration, helpers and precomputed formula resources.

        :param data_process: pre-processing helper (cleaning + sentence vectors)
        :param logger: optional logger; when None, log calls are skipped
        """
        # Service endpoints and base configuration from the config module.
        self.mongo_coll = config.mongo_coll
        self.vector_dim = config.vector_dim
        self.hnsw_retrieve_url = config.hnsw_retrieve_url
        self.dim_classify_url = config.dim_classify_url
        self.knowledge_tagging_url = config.knowledge_tagging_url
        # Logging facilities.
        self.logger = logger
        self.log_msg = config.log_msg
        # Data pre-processing instance.
        self.dpp = data_process
        # Comprehensive semantic-similarity scorer.
        self.cph_score = Comprehensive_Score(config.dev_mode)
        # Difficulty label -> numeric value mapping.
        self.difficulty_transfer = {"容易": 0.2, "较易": 0.4, "一般": 0.6, "较难": 0.8, "困难": 1.0}
        # Formula resources: bag-of-words model, precomputed vectors, raw id data.
        with open(config.bow_model_path, "rb") as model_fh:
            self.bow_model = pickle.load(model_fh)
        self.bow_vector = np.load(config.bow_vector_path)
        with open(config.formula_data_path, 'r', encoding='utf8', errors='ignore') as data_fh:
            self.formula_id_list = json.load(data_fh)
- # 图片搜索查重功能
- def img_retrieve(self, retrieve_text, post_url, similar, topic_num):
- try:
- if post_url is not None:
- # 日志采集
- if self.logger is not None:
- self.logger.info(self.log_msg.format(id="图片搜索查重",
- type="{}图片搜索查重post".format(topic_num),
- message=retrieve_text))
- img_dict = dict(img_url=retrieve_text, img_threshold=similar, img_max_num=40)
- img_res = requests.post(post_url, json=img_dict, timeout=30).json()
- # 日志采集
- if self.logger is not None:
- self.logger.info(self.log_msg.format(id="图片搜索查重",
- type="{}图片搜索查重success".format(topic_num),
- message=img_res))
- return img_res
- except Exception as e:
- # 日志采集
- if self.logger is not None:
- self.logger.error(self.log_msg.format(id="图片搜索查重",
- type="{}图片搜索查重error".format(topic_num),
- message=retrieve_text))
- return []
-
- # 公式搜索查重功能
- def formula_retrieve(self, formula_string, similar):
- # 调用清洗分词函数
- formula_string = '$' + formula_string + '$'
- formula_string = self.dpp.content_clear_func(formula_string)
- # 公式识别
- formula_list = formula_recognize(formula_string)
- if len(formula_list) == 0:
- return []
- # 日志采集
- if self.logger is not None:
- self.logger.info(self.log_msg.format(id="formula_retrieve",
- type=formula_string,
- message=formula_list))
- try:
- # 使用词袋模型计算句向量
- bow_vec = self.bow_model.transform([formula_list[0]]).toarray().astype("float32")
- # 并行计算余弦相似度
- formula_cos = np.array(util.cos_sim(bow_vec, self.bow_vector))
- # 获取余弦值大于等于0.8的数据索引
- cos_list = np.where(formula_cos[0] >= similar)[0]
- except:
- return []
- if len(cos_list) == 0:
- return []
- # 根据阈值获取满足条件的题目id
- res_list = []
- # formula_threshold = similar
- formula_threshold = 0.7
- for idx in cos_list:
- fuzz_score = fuzz.ratio(formula_list[0], self.formula_id_list[idx][0]) / 100
- if fuzz_score >= formula_threshold:
- # res_list.append([self.formula_id_list[idx][1], fuzz_score])
- # 对余弦相似度进行折算
- cosine_score = formula_cos[0][idx]
- if 0.95 <= cosine_score < 0.98:
- cosine_score = cosine_score * 0.98
- elif cosine_score < 0.95:
- cosine_score = cosine_score * 0.95
- # 余弦相似度折算后阈值判断
- if cosine_score < similar:
- continue
- res_list.append([self.formula_id_list[idx][1], int(cosine_score * 100) / 100])
- # 根据分数对题目id排序并返回前50个
- res_sort_list = sorted(res_list, key=lambda x: x[1], reverse=True)[:80]
- formula_res_list = []
- fid_set = set()
- for ele in res_sort_list:
- for fid in ele[0]:
- if fid in fid_set:
- continue
- fid_set.add(fid)
- formula_res_list.append([fid, ele[1]])
- return formula_res_list[:50]
- def api_post(self, post_url, post_data, log_info):
- try:
- post_result = requests.post(post_url, json=post_data, timeout=10).json()
- except Exception as e:
- post_result = []
- # 日志采集
- if self.logger is not None:
- self.logger.error(self.log_msg.format(id="{}error".format(log_info[0]),
- type="当前题目{}error".format(log_info[0]),
- message=log_info[1]))
-
- return post_result
    # HNSW retrieval (supports mixed multi-subject duplicate checking).
    def retrieve(self, retrieve_list, post_url, similar, scale, doc_flag):
        """
        Duplicate-check every question in retrieve_list; returns one dict per
        question with keys "semantics", "text", "image", "label", "topic_num".

        :param retrieve_list: list of question dicts (each with "stem", optional
                              "topic_num" and "quesType")
        :param post_url: image-search endpoint, forwarded to img_retrieve
        :param similar: similarity threshold applied to cosine, edit-distance
                        and comprehensive scores
        :param scale: weighting argument forwarded to Comprehensive_Score
                      (assumed score weights — TODO confirm expected type)
        :param doc_flag: when True, also run the image-based search
        return:
        [
            {
                'semantics': [[20232015, 1.0, {'quesType': 1.0, 'knowledge': 1.0, 'physical_scene': 1.0, 'solving_type': 1.0, 'difficulty': 1.0, 'physical_quantity': 1.0}]],
                'text': [[20232015, 0.97]],
                'image': [],
                'label': {'knowledge': ['串并联电路的辨别'], 'physical_scene': ['串并联电路的辨别'], 'solving_type': ['规律理解'], 'difficulty': 0.6, 'physical_quantity': ['电流']},
                'topic_num': 1
            },
            ...
        ]
        """
        # Compute vectors for retrieve_list: cleaning/tokenizing plus sentence embedding.
        sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
        # HNSW duplicate check: one result dict per input question.
        retrieve_res_list = []
        for i,sent_vec in enumerate(sent_vec_list):
            # Initialize the per-question result container.
            retrieve_value_dict = dict(semantics=[], text=[], image=[])
            # Question number (defaults to 1 when absent).
            topic_num = retrieve_list[i]["topic_num"] if "topic_num" in retrieve_list[i] else 1
            # Image-based duplicate search (only when requested).
            if doc_flag is True:
                retrieve_value_dict["image"] = self.img_retrieve(retrieve_list[i]["stem"], post_url, similar, topic_num)
            else:
                retrieve_value_dict["image"] = []
            """
            Text-similarity handling
            """
            # Skip questions whose sentence vector has the wrong dimension.
            if sent_vec.size != self.vector_dim:
                retrieve_res_list.append(retrieve_value_dict)
                continue
            # Call the HNSW service to fetch candidate ids.
            query_labels = self.api_post(self.hnsw_retrieve_url, sent_vec.tolist(), ["HNSW检索", cont_clear_list[i]])
            if len(query_labels) == 0:
                retrieve_res_list.append(retrieve_value_dict)
                continue
            # Batch-read the candidate questions from MongoDB.
            mongo_find_dict = {"id": {"$in": query_labels}}
            query_dataset = self.mongo_coll.find(mongo_find_dict)
            # "Semantic borrowing": remember one high-similarity neighbour to copy labels from.
            query_data = dict()
            # Keep text candidates whose scores pass the threshold.
            for label_data in query_dataset:
                if "sentence_vec" not in label_data:
                    continue
                # Cosine similarity between the query and candidate vectors.
                label_vec = pickle.loads(label_data["sentence_vec"])
                if label_vec.size != self.vector_dim:
                    continue
                cosine_score = util.cos_sim(sent_vec, label_vec)[0][0]
                # Cosine threshold gate.
                if cosine_score < similar:
                    continue
                # Edit-distance score as a second verification; filter below threshold.
                fuzz_score = fuzz.ratio(cont_clear_list[i], label_data["content_clear"]) / 100
                if fuzz_score >= similar:
                    retrieve_value = [label_data["id"], fuzz_score]
                    retrieve_value_dict["text"].append(retrieve_value)
                # NOTE(review): max_mark_score is reset to 0 on every iteration,
                # so query_data ends up as the LAST candidate with
                # cosine_score >= 0.9 rather than the highest-scoring one —
                # confirm whether this is intended.
                max_mark_score = 0
                if cosine_score >= 0.9 and cosine_score > max_mark_score:
                    max_mark_score = cosine_score
                    query_data = label_data

            """
            Semantic-similarity handling
            """
            # Label dictionary initialization.
            label_dict = dict() # todo: could short-circuit via a high-similarity copy instead of model prediction
            # Knowledge points: borrowed from the semantic neighbour when available.
            label_dict["knowledge"] = query_data["knowledge"] if query_data else []
            # Alternative: call the knowledge-tagging service instead of borrowing:
            # knowledge_post_data = {"sentence": cont_clear_list[i]}
            # label_dict["knowledge"] = self.api_post(self.knowledge_tagging_url, knowledge_post_data, ["知识点标注", cont_clear_list[i]])
            tagging_id_list = [self.cph_score.knowledge2id[ele] for ele in label_dict["knowledge"] \
                               if ele in self.cph_score.knowledge2id]
            # Question type (defaults to "选择题").
            label_dict["quesType"] = retrieve_list[i].get("quesType", "选择题")
            # Multi-dimension classification API call.
            dim_post_data = {"sentence": cont_clear_list[i], "quesType": label_dict["quesType"]}
            dim_classify_dict = self.api_post(self.dim_classify_url, dim_post_data, ["多维分类", cont_clear_list[i]])
            if len(dim_classify_dict) == 0:
                dim_classify_dict = {"solving_type": ["规律理解"], "difficulty": 0.6}
            # Solving-type model classification.
            label_dict["solving_type"] = dim_classify_dict["solving_type"]
            # Difficulty model classification.
            label_dict["difficulty"] = dim_classify_dict["difficulty"]
            # Rule-based physical-quantity extraction.
            label_dict["physical_quantity"] = physical_quantity_extract(cont_clear_list[i])
            # Expand tagged knowledge ids into the full id range of each knowledge group.
            knowledge_id_list = []
            if len(tagging_id_list) > 0:
                # encode_base_value: 10000 for the "ksy" deployment, 10 otherwise.
                encode_base_value = 10000 if config.dev_mode == "ksy" else 10
                for ele in tagging_id_list:
                    init_id = int(ele / encode_base_value) * encode_base_value
                    init_id_list = self.cph_score.init_id2max_id.get(str(init_id), [])
                    knowledge_id_list.extend(init_id_list)
            knowledge_query_dataset = None
            if len(knowledge_id_list) > 0:
                mongo_find_dict = {"knowledge_id": {"$in": knowledge_id_list}}
                knowledge_query_dataset = self.mongo_coll.find(mongo_find_dict)
            # Keep semantic candidates whose comprehensive score passes the threshold.
            if knowledge_query_dataset:
                # Convert the query question's difficulty label to its numeric value.
                if label_dict["difficulty"] in self.difficulty_transfer:
                    label_dict["difficulty"] = self.difficulty_transfer[label_dict["difficulty"]]
                for refer_data in knowledge_query_dataset:
                    # Convert the reference question's difficulty label likewise.
                    if refer_data["difficulty"] in self.difficulty_transfer:
                        refer_data["difficulty"] = self.difficulty_transfer[refer_data["difficulty"]]
                    sum_score, score_dict = self.cph_score(label_dict, refer_data, scale)
                    if sum_score < similar:
                        continue
                    retrieve_value = [refer_data["id"], sum_score, score_dict]
                    retrieve_value_dict["semantics"].append(retrieve_value)
            # Sort every result list by score, descending.
            retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)
                                  for k,value in retrieve_value_dict.items()}
            # Attach the labels and the question number.
            retrieve_sort_dict["label"] = label_dict
            retrieve_sort_dict["topic_num"] = topic_num

            # Append this question's result dict.
            retrieve_res_list.append(retrieve_sort_dict)
        return retrieve_res_list
if __name__ == "__main__":
    # Fetch test data from MongoDB.
    mongo_coll = config.mongo_coll
    from data_preprocessing import DataPreProcessing
    hnsw = HNSW(DataPreProcessing())
    test_data = []
    for idx in [201511100736265]:
        test_data.append(mongo_coll.find_one({"id": idx}))
    # Bug fix: retrieve() takes (retrieve_list, post_url, similar, scale,
    # doc_flag); the original call passed only 4 arguments (binding
    # scale=False and omitting doc_flag), which raises TypeError.
    # NOTE(review): scale is forwarded to Comprehensive_Score — 1.0 is a
    # placeholder; confirm the expected weighting value.
    res = hnsw.retrieve(test_data, '', 0.8, 1.0, False)
    pprint(res[0]["semantics"])
    # # Formula-based duplicate search demo:
    # formula_string = "ρ蜡=0.9*10^3Kg/m^3"
    # formula_string = "p蜡=0.9*10^3Kq/m^3"
    # print(hnsw.formula_retrieve(formula_string, 0.8))
|