123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186 |
- import time
- import pickle
- import requests
- from bson.binary import Binary
- from fuzzywuzzy import fuzz
- from sentence_transformers import util
- from pprint import pprint
- import config
- from data_preprocessing import DataPreProcessing
class Hnsw_Logic:
    """Duplicate-question lookup backed by an HNSW retrieval service.

    Two indexes are addressed by ``hnsw_index``: 0 is the cloud question bank
    (云题库) and 1 is the school question bank (校本题库). ``self.mongo_coll[i]``
    is the MongoDB collection that holds the records for index ``i``.
    """

    def __init__(self, logger=None):
        """Load settings from ``config`` and set up text preprocessing.

        Args:
            logger: optional logger; when ``None`` all logging is skipped.
        """
        # Initial configuration
        self.mongo_coll = config.mongo_coll_list
        self.vector_dim = config.vector_dim
        self.database_threshold = config.database_threshold
        self.hnsw_update_url = config.hnsw_update_url
        self.hnsw_retrieve_url = config.hnsw_retrieve_url
        # Logging
        self.logger = logger
        self.log_msg = config.log_msg
        # Cleaning / tokenization / sentence-vector helper
        self.dpp = DataPreProcessing(self.mongo_coll, self.logger)

    def _log(self, level, **fields):
        """Format ``fields`` into ``self.log_msg`` and emit it at ``level``; no-op without a logger."""
        if self.logger is None:
            return
        getattr(self.logger, level)(self.log_msg.format(**fields))

    def logic_process(self, retrieve_list, hnsw_index, dup_threshold=0.97):
        """Run the duplicate check and store new questions in the school bank.

        Args:
            retrieve_list: list of question dicts; items that get inserted
                must carry ``topic_id``.
            hnsw_index: 0 queries only the cloud bank; any other value also
                queries the school bank and inserts non-duplicates into it.
            dup_threshold: an item whose best school-bank score is greater
                than this value is treated as already present and not inserted.

        Returns:
            ``cloud_list`` when ``hnsw_index == 0``, otherwise the tuple
            ``(cloud_list, school_list)`` as produced by ``retrieve``.
        """
        # Clean/tokenize the texts and compute their sentence vectors
        sent_vec_list, cont_clear_list = self.dpp(retrieve_list, hnsw_index, is_retrieve=True)
        # Duplicate check against the cloud bank
        cloud_list = self.retrieve(retrieve_list, sent_vec_list, cont_clear_list, hnsw_index=0)
        if hnsw_index == 0:
            return cloud_list
        # Duplicate check against the school bank
        school_list = self.retrieve(retrieve_list, sent_vec_list, cont_clear_list, hnsw_index=1)
        school_coll = self.mongo_coll[1]
        for i, data in enumerate(retrieve_list):
            matches = school_list[i]
            # Best school-bank match is near-identical: already stored
            if matches and matches[0][1] > dup_threshold:
                continue
            # Cleaned text shorter than 10 chars means cleaning failed: skip
            if len(cont_clear_list[i]) < 10:
                continue
            topic_id = int(data["topic_id"])
            # Never insert the same topic_id twice
            if school_coll.find_one({"topic_id": topic_id}) is not None:
                continue
            try:
                self.school_mongodb_insert(data, sent_vec_list[i], cont_clear_list[i])
                # Push the new record into the HNSW index right away
                self.update(topic_id, hnsw_index=1)
            except Exception as e:
                # Best effort per item: log and continue with the next one
                self._log("error", id=topic_id, type="chc retrieval insert",
                          message="chc查重数据插入失败-" + str(e))
        return cloud_list, school_list

    def school_mongodb_insert(self, data, sentence_vec, content_clear):
        """Insert one question into the school-bank collection ``self.mongo_coll[1]``.

        Args:
            data: incoming question dict; ``topic_id`` is required, the other
                fields fall back to ``''`` or ``0`` when absent.
            sentence_vec: sentence vector, stored pickled as BSON binary
                (subtype 128).
            content_clear: cleaned question text.
        """
        sentence_vec_byte = Binary(pickle.dumps(sentence_vec, protocol=-1), subtype=128)
        insert_dict = {
            'topic_id': int(data['topic_id']),
            'content_raw': data.get('stem', ''),
            'content': content_clear,
            'content_clear': content_clear,
            'sentence_vec': sentence_vec_byte,
            'sent_train_flag': config.sent_train_flag,
            'topic_type_id': int(data.get('topic_type_id', 0)),
            'school_id': int(data.get('school_id', 0)),
            'parse': data.get('parse', ''),
            'answer': data.get('answer', ''),
            'save_time': time.time(),
            'subject_id': int(data.get('subject_id', 0)),
        }
        self.mongo_coll[1].insert_one(insert_dict)
        self._log("info", id=insert_dict['topic_id'], type="chc retrieval insert",
                  message="已将chc查重数据插入mongo_coll_school")

    def update(self, update_id, hnsw_index):
        """Add or refresh ``update_id`` in HNSW index ``hnsw_index``.

        For the cloud bank (``hnsw_index == 0``) the stored record is first
        passed through ``self.dpp``. Presumably this refreshes its cleaned text
        and sentence vector (TODO confirm in DataPreProcessing). The HNSW
        service is then asked to load the id. Failures are logged, not raised.
        """
        if hnsw_index == 0:
            cloud_data = self.mongo_coll[hnsw_index].find_one({"topic_id": update_id})
            if cloud_data is None:
                # Nothing to preprocess or index for an unknown topic_id
                self._log("error", id="HNSW更新error", type="当前题目HNSW更新error",
                          message="{}-未找到该题目数据".format(update_id))
                return
            # Clean, tokenize and vectorize the stored record
            self.dpp([cloud_data], hnsw_index=0)
        update_dict = dict(id=update_id, hnsw_index=hnsw_index)
        db_name = "云题库" if hnsw_index == 0 else "校本题库"
        try:
            resp = requests.post(self.hnsw_update_url, json=update_dict, timeout=10)
            # Treat HTTP error statuses as failures rather than logging success
            resp.raise_for_status()
        except Exception as e:
            self._log("error", id="HNSW更新error", type="当前题目HNSW更新error",
                      message="{}-{}".format(update_id, e))
            return
        self._log("info", id=update_id, type="{}数据更新".format(db_name), message="数据更新完毕")

    def _hnsw_search(self, query_vec, hnsw_index, topic_id):
        """Ask the HNSW service for candidate topic_ids; return [] on any failure."""
        post_dict = dict(query_vec=query_vec.tolist(), hnsw_index=hnsw_index)
        try:
            resp = requests.post(self.hnsw_retrieve_url, json=post_dict, timeout=10)
            resp.raise_for_status()
            return resp.json()
        except Exception as e:
            self._log("error", id="HNSW检索error", type="当前题目HNSW检索error",
                      message="{}-{}".format(topic_id, e))
            return []

    def retrieve(self, retrieve_list, sent_vec_list, cont_clear_list, hnsw_index, top_k=20):
        """Find near-duplicates of each query in HNSW index ``hnsw_index``.

        Candidates from the HNSW service are loaded from
        ``self.mongo_coll[hnsw_index]`` with the same ``subject_id``. They are
        also filtered by ``topic_type_id`` when the item has one, and for the
        cloud bank they must have ``is_stop == 0``. A candidate is kept only if
        it passes both thresholds in ``self.database_threshold[hnsw_index]``:
        cosine similarity, and the fuzzy ratio of the first 200 cleaned chars.
        Mixed subjects in one batch are supported.

        Args:
            retrieve_list: question dicts, aligned with ``sent_vec_list``.
            sent_vec_list: query vectors exposing ``.size`` and ``.tolist()``
                (presumably numpy arrays — TODO confirm).
            cont_clear_list: cleaned texts, aligned with the lists above.
            hnsw_index: 0 = cloud bank, 1 = school bank.
            top_k: maximum number of matches kept per query.

        Returns:
            One list per query vector of ``[topic_id, score]`` pairs. Scores
            are truncated to two decimals and sorted descending. A list is empty
            when the item lacks ``subject_id``, its vector has the wrong size,
            or nothing matches.
        """
        cos_threshold = self.database_threshold[hnsw_index][0]
        fuzz_threshold = self.database_threshold[hnsw_index][1]
        coll = self.mongo_coll[hnsw_index]
        retrieve_res_list = []
        for i, query_vec in enumerate(sent_vec_list):
            item = retrieve_list[i]
            # Need a subject to filter on and a correctly sized vector
            if "subject_id" not in item or query_vec.size != self.vector_dim:
                retrieve_res_list.append([])
                continue
            subject_id = int(item["subject_id"])
            query_labels = self._hnsw_search(query_vec, hnsw_index, item.get("topic_id"))
            if not query_labels:
                retrieve_res_list.append([])
                continue
            # Load all candidates in one query
            find_filter = {"topic_id": {"$in": query_labels}, "subject_id": subject_id}
            if hnsw_index == 0:
                # is_stop == 0 filters out questions deleted from the cloud bank
                find_filter["is_stop"] = 0
            if "topic_type_id" in item:
                # Same question type only, to pinpoint real duplicates
                find_filter["topic_type_id"] = int(item["topic_type_id"])
            query_text = cont_clear_list[i][:200]
            matches = []
            for label_data in coll.find(find_filter):
                # Records without a stored vector cannot be scored
                if "sentence_vec" not in label_data:
                    continue
                # NOTE: pickle.loads is acceptable only because the vectors come from our own DB
                label_vec = pickle.loads(label_data["sentence_vec"])
                if label_vec.size != self.vector_dim:
                    continue
                cosine_score = util.cos_sim(query_vec, label_vec)[0][0]
                if cosine_score < cos_threshold:
                    continue
                # Second check: fuzzy edit-distance ratio on the first 200 cleaned chars
                label_text = (label_data.get("content_clear") or "")[:200]
                if fuzz.ratio(query_text, label_text) / 100 < fuzz_threshold:
                    continue
                matches.append([label_data["topic_id"], int(cosine_score * 100) / 100])
            # Highest score first, keep at most top_k
            matches.sort(key=lambda m: m[1], reverse=True)
            retrieve_res_list.append(matches[:top_k])
        if self.logger is not None:
            self._log("info",
                      id="云题库查重" if hnsw_index == 0 else "校本题库查重",
                      type="repeat检索" if hnsw_index == 0 else "chc检索",
                      message=str({idx + 1: ele for idx, ele in enumerate(retrieve_res_list)}))
        return retrieve_res_list
if __name__ == "__main__":
    # Smoke test: look up cloud-bank duplicates for a known question.
    # NOTE(review): assumes config.mongo_coll holds this topic — confirm.
    mongo_coll = config.mongo_coll
    hnsw_logic = Hnsw_Logic()
    # Skip ids that are not in the collection instead of passing None along
    test_data = [
        doc
        for doc in (mongo_coll.find_one({"topic_id": topic_id}) for topic_id in [201511100832270])
        if doc is not None
    ]
    # retrieve() needs precomputed vectors and cleaned texts; logic_process()
    # computes them and, for hnsw_index=0, returns only the cloud-bank results.
    res = hnsw_logic.logic_process(test_data, hnsw_index=0)
    pprint(res)
|