"""Scheduled self-duplicate-check for the cloud question bank: reset group ids,
retrieve candidate duplicates via the HNSW service, verify them with cosine
similarity and edit distance, and persist duplicate groups to MySQL/MongoDB."""
import time
import pickle

import requests
from fuzzywuzzy import fuzz
from sentence_transformers import util

import config
from mysql_operate import mysql_operate


class DataBase_Dup_Check:
    def __init__(self):
        self.mongo_coll = config.mongo_coll_cloud
        # hnsw_app retrieval endpoint
        self.hnsw_retrieve_url = config.hnsw_retrieve_url
        # initial group_id counter
        self.group_count = 1

    # main entry point
    def __call__(self):
        start0 = time.time()
        # reset the group_id field for all active documents in MongoDB
        cond_reset = {"is_stop": 0}
        set_reset = {"$set": {"group_id": 0}}
        self.mongo_coll.update_many(cond_reset, set_reset)
        # logging
        print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                                    type="cloud_mongodb group_id reset time:",
                                    message=time.time() - start0))
        # empty the result table
        delete_sql = "truncate table topic_duplicate_check;"
        mysql_operate(delete_sql, operate_flag=1)
        # iterate over the data for duplicate checking
        origin_dataset = self.mongo_coll.find({"is_stop": 0}, no_cursor_timeout=True, batch_size=5)
        self.hnsw_retrieve(origin_dataset)

    # iterate over the data for duplicate checking
    def hnsw_retrieve(self, origin_dataset):
        for data in origin_dataset:
            # a non-zero "group_id" means the question is already grouped as a duplicate, so skip it;
            # otherwise require that a "sentence_vec" is present in the document
            if data["group_id"] != 0 or "sentence_vec" not in data:
                continue
            # fetch the subject id and the sentence vector
            topic_id = data["topic_id"]
            subject_id = data["subject_id"]
            query_vec = pickle.loads(data["sentence_vec"])
            query_cont = data["content_clear"]
            # skip sentence vectors with an unexpected dimensionality
            if query_vec.size != config.vector_dim:
                continue
            # query the HNSW service for candidate neighbors
            post_dict = dict(query_vec=query_vec.tolist(), hnsw_index=0)
            try:
                query_labels = requests.post(self.hnsw_retrieve_url, json=post_dict, timeout=10).json()
            except Exception:
                continue
            # drop the query's own topic_id from the candidates
            query_labels = [idx for idx in query_labels if idx != topic_id]
            # batch-read the candidates from the database
            mongo_find_dict = {"topic_id": {"$in": query_labels}, "subject_id": subject_id, "is_stop": 0}
            # also match on question type to pinpoint true duplicates
            if "topic_type_id" in data:
                mongo_find_dict["topic_type_id"] = data["topic_type_id"]
            query_dataset = self.mongo_coll.find(mongo_find_dict)
            # cosine threshold for duplicate-group membership
            threshold = 0.95
            retrieve_value_list = []
            for label_data in query_dataset:
                # a non-zero group_id means the candidate is already grouped, so skip it
                if label_data["group_id"] != 0:
                    continue
                # compute the cosine similarity score; skip malformed vectors
                # (the stored data may have changed since the vector was built)
                try:
                    cosine_score = util.cos_sim(query_vec, pickle.loads(label_data["sentence_vec"]))[0][0]
                except Exception:
                    print("topic_id: {} -> error_id: {}".format(topic_id, label_data["topic_id"]))
                    continue
                if cosine_score < threshold:
                    continue
                # verify with edit distance (fuzzywuzzy, first 200 characters);
                # filter out pairs that score below the cutoff
                if fuzz.ratio(query_cont[:200], label_data["content_clear"][:200]) / 100 < 0.8:
                    continue
                retrieve_value_list.append([label_data["topic_id"], int(cosine_score * 100) / 100, label_data])
            # skip if no similar question passed the threshold
            if not retrieve_value_list:
                continue
            # append the query question itself with sim_score 1 (the group's source question)
            retrieve_value_list.append([topic_id, 1, data])
            # persist the results
            self.mysql_update(retrieve_value_list)
            # advance the group_id
            self.group_count += 1

    # write the results to the databases
    def mysql_update(self, mongo_update_list):
        insert_list = []
        last_index = len(mongo_update_list) - 1
        for i, ele in enumerate(mongo_update_list):
            is_motif = 1 if i == last_index else 0
            # write the duplicate group_id back to the matching topic_id in MongoDB
            condition = {"topic_id": ele[0]}
            update_elements = {"$set": {"group_id": self.group_count, "sim_score": ele[1]}}
            self.mongo_coll.update_one(condition, update_elements)
            # build the row to insert
            insert_value = self.get_insert_value(ele, is_motif)
            # collect the duplicate rows for a single batch insert
            insert_list.append(insert_value)
        # sort by compute_final_draft, then sim_score, both descending
        insert_list.sort(key=lambda x: (-x[-2], -x[2]))
        # SQL insert statement
        insert_sql = "insert into topic_duplicate_check(topic_id,group_id,sim_score,subject_id,topic_type_id," \
                     "unit_id,is_audit,is_mark,compute_final_draft,create_time) " \
                     "values(%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)"
        mysql_operate(insert_sql, insert_list, operate_flag=1)

    # build the row to insert
    def get_insert_value(self, ele, compute_final_draft):
        sim_score, data = ele[1], ele[2]
        create_time = int(time.time())
        # query MySQL for the unit_id
        sql = "SELECT kp.unit_id FROM tn_kp, kp WHERE tn_kp.kp_id=kp.kp_id AND tn_kp.topic_id=%d" % (ele[0])
        fetch_dict = mysql_operate(sql)
        unit_id = fetch_dict["unit_id"] if fetch_dict is not None and len(fetch_dict) > 0 else 0
        # assemble the insert tuple
        insert_value = (data["topic_id"], self.group_count, sim_score, data["subject_id"],
                        data["topic_type_id"], unit_id, data["is_audit"], data["is_mark"],
                        compute_final_draft, create_time)
        return insert_value


if __name__ == "__main__":
    ddc = DataBase_Dup_Check()
    start = time.time()
    # logging
    print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                                type="scheduled self-duplicate-check",
                                message="starting the question-bank self-duplicate-check"))
    ddc()
    # logging
    print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                                type="cloud question bank self-duplicate-check total time:",
                                message=time.time() - start))