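"""Self-duplicate check for the cloud question bank.

Pipeline, as implemented below:
1. Reset ``group_id`` to 0 for every active (``is_stop == 0``) document in MongoDB and
   truncate the ``topic_duplicate_check`` MySQL table.
2. For every ungrouped document that has a sentence vector, query the HNSW retrieval
   service for candidate topic ids.
3. Verify each candidate with cosine similarity and a fuzzywuzzy edit-distance check on
   the first 200 characters of the cleaned content.
4. Write the resulting duplicate group back to MongoDB (``group_id``/``sim_score``) and
   insert one row per group member into ``topic_duplicate_check``.
"""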
import time
import pickle

import requests
from fuzzywuzzy import fuzz
from sentence_transformers import util

import config
from mysql_operate import mysql_operate


class DataBase_Dup_Check:
    def __init__(self):
        self.mongo_coll = config.mongo_coll_cloud
        # hnsw_app retrieval endpoint
        self.hnsw_retrieve_url = config.hnsw_retrieve_url
        # initial group_id counter
        self.group_count = 1

    # main entry point
    def __call__(self):
        start0 = time.time()
        # reset the group_id column for the subject's collection in MongoDB
        cond_reset = {"is_stop": 0}
        set_reset = {"$set": {"group_id": 0}}
        self.mongo_coll.update_many(cond_reset, set_reset)
        # logging
        print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                                    type="time to reset group_id in the cloud MongoDB:",
                                    message=time.time() - start0))
        # empty the MySQL result table
        delete_sql = "truncate table topic_duplicate_check;"
        mysql_operate(delete_sql, operate_flag=1)
        # iterate over the data and run the duplicate check
        origin_dataset = self.mongo_coll.find({"is_stop": 0}, no_cursor_timeout=True, batch_size=5)
        self.hnsw_retrieve(origin_dataset)

    # iterate over the dataset and check each topic for duplicates
    def hnsw_retrieve(self, origin_dataset):
        for data in origin_dataset:
            # a non-zero "group_id" means the topic is already grouped as a duplicate, so skip it
            # otherwise make sure "sentence_vec" is present in the document
            if data["group_id"] != 0 or "sentence_vec" not in data:
                continue
            # fetch the topic id, subject id and sentence vector
            topic_id = data["topic_id"]
            subject_id = data["subject_id"]
            query_vec = pickle.loads(data["sentence_vec"])
            query_cont = data["content_clear"]
            # skip sentence vectors with an unexpected dimension
            if query_vec.size != config.vector_dim:
                continue
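            # Assumed contract of the hnsw_retrieve_url service (it is not defined in this
            # file): it accepts a JSON body {"query_vec": [...], "hnsw_index": 0} and returns
            # a JSON list of candidate topic_ids; hnsw_index presumably selects which HNSW
            # index to search on the hnsw_app side.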
            # call the HNSW service to retrieve candidate topics
            post_dict = dict(query_vec=query_vec.tolist(), hnsw_index=0)
            try:
                query_labels = requests.post(self.hnsw_retrieve_url, json=post_dict, timeout=10).json()
            except Exception:
                continue
            # drop the query's own topic_id from the candidates
            query_labels = [idx for idx in query_labels if idx != topic_id]
            # batch-read the candidate documents from MongoDB
            mongo_find_dict = {"topic_id": {"$in": query_labels}, "subject_id": subject_id, "is_stop": 0}
            # also match on topic type to pin down true duplicates
            if "topic_type_id" in data:
                mongo_find_dict["topic_type_id"] = data["topic_type_id"]
            query_dataset = self.mongo_coll.find(mongo_find_dict)
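            # Two-stage verification below: a cosine-similarity gate on the sentence vectors,
            # then a fuzzywuzzy edit-distance check on the first 200 characters of
            # content_clear (presumably truncated to bound the cost of the character-level
            # comparison). Only candidates that pass both checks join the duplicate group.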
            # similarity threshold for joining a duplicate group
            threshold = 0.95
            retrieve_value_list = []
            for label_data in query_dataset:
                # a non-zero group_id means the candidate is already grouped, so skip it
                if label_data["group_id"] != 0:
                    continue
                # compute the cosine similarity score
                # skip candidates whose stored vector no longer deserializes or matches
                # (e.g. the data changed after encoding)
                try:
                    cosine_score = util.cos_sim(query_vec, pickle.loads(label_data["sentence_vec"]))[0][0]
                except Exception:
                    print("topic_id: {} -> error_id: {}".format(topic_id, label_data["topic_id"]))
                    continue
                if cosine_score < threshold:
                    continue
                # verify with edit distance (fuzzywuzzy, first 200 characters); filter out low scores
                if fuzz.ratio(query_cont[:200], label_data["content_clear"][:200]) / 100 < 0.8:
                    continue
                # keep the candidate: topic_id, score truncated to two decimals, full document
                retrieve_value_list.append([label_data["topic_id"], int(cosine_score * 100) / 100, label_data])
            # skip the current topic if no candidate passes the thresholds
            if len(retrieve_value_list) == 0:
                continue
            # otherwise append the query topic itself with a score of 1
            retrieve_value_list.append([topic_id, 1, data])
            # persist the group
            self.mysql_update(retrieve_value_list)
            # move on to the next group_id
            self.group_count += 1

    # persist one duplicate group to MongoDB and MySQL
    def mysql_update(self, mongo_update_list):
        insert_list = []
        last_index = len(mongo_update_list) - 1
        for i, ele in enumerate(mongo_update_list):
            # the query topic is appended last, so it is the only row with is_motif = 1
            is_motif = 1 if i == last_index else 0
            # write the group_id and similarity score back to the matching topic_id in MongoDB
            condition = {"topic_id": ele[0]}
            update_elements = {"$set": {"group_id": self.group_count, "sim_score": ele[1]}}
            self.mongo_coll.update_one(condition, update_elements)
            # build the row for the MySQL insert
            insert_value = self.get_insert_value(ele, is_motif)
            # collect the duplicate rows
            insert_list.append(insert_value)
        # sort by compute_final_draft and sim_score in descending order
        insert_list.sort(key=lambda x: (-x[-2], -x[2]))
        # SQL insert statement
        insert_sql = "insert into topic_duplicate_check(topic_id,group_id,sim_score,subject_id,topic_type_id," \
                     "unit_id,is_audit,is_mark,compute_final_draft,create_time) values(%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)"
        mysql_operate(insert_sql, insert_list, operate_flag=1)

    # build one row for the MySQL insert
    def get_insert_value(self, ele, compute_final_draft):
        sim_score, data = ele[1], ele[2]
        create_time = int(time.time())
        # look up the unit_id for this topic in MySQL
        sql = "SELECT kp.unit_id FROM tn_kp, kp WHERE tn_kp.kp_id=kp.kp_id AND tn_kp.topic_id=%d" % (ele[0])
        fetch_dict = mysql_operate(sql)
        unit_id = fetch_dict["unit_id"] if fetch_dict is not None and len(fetch_dict) > 0 else 0
        # assemble the row in the column order of the insert statement
        insert_value = (data["topic_id"], self.group_count, sim_score, data["subject_id"], data["topic_type_id"],
                        unit_id, data["is_audit"], data["is_mark"], compute_final_draft, create_time)
        return insert_value
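
# The mysql_operate helper is assumed, from its usage above, to behave roughly as:
#   mysql_operate(sql)                          -> returns a dict for a single-row SELECT
#   mysql_operate(sql, params, operate_flag=1)  -> executes a write (executemany-style
#                                                  when params is a list of tuples)
# A minimal sketch of the target table, with column types guessed from the values
# inserted above (the real DDL may differ):
#   CREATE TABLE topic_duplicate_check (
#       topic_id            BIGINT,
#       group_id            INT,
#       sim_score           FLOAT,
#       subject_id          INT,
#       topic_type_id       INT,
#       unit_id             BIGINT,
#       is_audit            TINYINT,
#       is_mark             TINYINT,
#       compute_final_draft TINYINT,
#       create_time         INT
#   );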

if __name__ == "__main__":
    ddc = DataBase_Dup_Check()
    start = time.time()
    # logging
    print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                                type="scheduled self-duplicate check",
                                message="starting the question-bank self-duplicate check"))
    ddc()
    # logging
    print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                                type="total time for the cloud question-bank self-duplicate check:",
                                message=time.time() - start))