# database_dup_check.py
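"""Offline self duplicate check for the cloud question bank.

Resets `group_id` for all active records in the cloud MongoDB collection,
truncates the MySQL table `topic_duplicate_check`, then walks every active
question: nearest neighbours are fetched from the hnsw_app retrieval service,
verified with a cosine-similarity threshold plus a fuzzywuzzy ratio check, and
confirmed duplicates are written back to MongoDB (group_id / sim_score) and
inserted into `topic_duplicate_check`.
"""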

import time
import pickle
import requests
from fuzzywuzzy import fuzz
from sentence_transformers import util
import config
from mysql_operate import mysql_operate
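
# NOTE (implied by how this script calls its local dependencies, not taken from those modules):
#   - `config` is expected to provide `mongo_coll_cloud` (a pymongo collection),
#     `hnsw_retrieve_url`, `vector_dim` and the `log_msg` format string.
#   - `mysql_operate` is called here as mysql_operate(sql), mysql_operate(sql, operate_flag=1)
#     and mysql_operate(sql, insert_list, operate_flag=1); reads are expected to return a row dict.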


class DataBase_Dup_Check():
    def __init__(self):
        self.mongo_coll = config.mongo_coll_cloud
        # hnsw_app retrieval endpoint
        self.hnsw_retrieve_url = config.hnsw_retrieve_url
        # Initial group_id counter
        self.group_count = 1

    # Main entry point
    def __call__(self):
        start0 = time.time()
        # Reset the group_id field for all active (is_stop == 0) records in the cloud MongoDB collection
        cond_reset = {"is_stop": 0}
        set_reset = {"$set": {"group_id": 0}}
        self.mongo_coll.update_many(cond_reset, set_reset)
        # Logging
        print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                                    type="cloud_mongodb group_id reset time:",
                                    message=time.time()-start0))
        # Truncate the result table
        delete_sql = "truncate table topic_duplicate_check;"
        mysql_operate(delete_sql, operate_flag=1)
        # Iterate over the data and check for duplicates
        origin_dataset = self.mongo_coll.find({"is_stop": 0}, no_cursor_timeout=True, batch_size=5)
        try:
            self.hnsw_retrieve(origin_dataset)
        finally:
            # Cursors opened with no_cursor_timeout=True are not closed automatically
            origin_dataset.close()

    # Iterate over the data and check for duplicates
    def hnsw_retrieve(self, origin_dataset):
        for data in origin_dataset:
            # A non-zero group_id means the question is already grouped as a duplicate;
            # also skip records that have no sentence vector
            if data["group_id"] != 0 or "sentence_vec" not in data:
                continue
            # Fetch the topic id, subject id and sentence vector
            topic_id = data["topic_id"]
            subject_id = data["subject_id"]
            query_vec = pickle.loads(data["sentence_vec"])
            query_cont = data["content_clear"]
            # Skip vectors with an unexpected dimension
            if query_vec.size != config.vector_dim:
                continue
            # Query the HNSW retrieval service
            post_dict = dict(query_vec=query_vec.tolist(), hnsw_index=0)
            try:
                query_labels = requests.post(self.hnsw_retrieve_url, json=post_dict, timeout=10).json()
            except Exception:
                continue
            # Drop the query's own topic_id from the candidates
            query_labels = [idx for idx in query_labels if idx != topic_id]
            # Batch-read the candidate records from MongoDB
            mongo_find_dict = {"topic_id": {'$in': query_labels}, "subject_id": subject_id, "is_stop": 0}
            # Also match on question type to pin down duplicates more precisely
            if "topic_type_id" in data:
                mongo_find_dict["topic_type_id"] = data["topic_type_id"]
            query_dataset = self.mongo_coll.find(mongo_find_dict)
            # Similarity threshold for duplicate groups
            threshold = 0.95
            retrieve_value_list = []
            for label_data in query_dataset:
                # Skip candidates that already belong to a duplicate group
                if label_data["group_id"] != 0:
                    continue
                # Compute the cosine similarity score; skip candidates whose stored
                # vector is unusable (e.g. the record changed and the vector with it)
                try:
                    cosine_score = util.cos_sim(query_vec, pickle.loads(label_data["sentence_vec"]))[0][0]
                except Exception:
                    print("topic_id: {} -> error_id: {}".format(topic_id, label_data["topic_id"]))
                    continue
                if cosine_score < threshold:
                    continue
                # Verify with a fuzzywuzzy ratio on the first 200 characters; drop weak matches
                if fuzz.ratio(query_cont[:200], label_data["content_clear"][:200]) / 100 < 0.8:
                    continue
                retrieve_value_list.append([label_data["topic_id"], int(cosine_score*100)/100, label_data])
            # Skip the question if no candidate passes the threshold
            if len(retrieve_value_list) == 0:
                continue
            # Append the query question itself as the head of the duplicate group
            retrieve_value_list.append([topic_id, 1, data])
            # Persist the duplicate group
            self.mysql_update(retrieve_value_list)
            # Advance the group_id counter
            self.group_count += 1

    # Persist a duplicate group to MongoDB and MySQL
    def mysql_update(self, mongo_update_list):
        insert_list = []
        last_index = len(mongo_update_list) - 1
        for i, ele in enumerate(mongo_update_list):
            is_motif = 1 if i == last_index else 0
            # Write the group_id and similarity score back to the matching MongoDB record
            condition = {"topic_id": ele[0]}
            update_elements = {"$set": {"group_id": self.group_count, "sim_score": ele[1]}}
            self.mongo_coll.update_one(condition, update_elements)
            # Build the row to insert
            insert_value = self.get_insert_value(ele, is_motif)
            # Collect the row for batch insertion
            insert_list.append(insert_value)
        # Sort by compute_final_draft and sim_score in descending order
        insert_list.sort(key=lambda x: (-x[-2], -x[2]))
        # SQL insert statement
        insert_sql = "insert into topic_duplicate_check(topic_id,group_id,sim_score,subject_id,topic_type_id," \
                     "unit_id,is_audit,is_mark,compute_final_draft,create_time) values(%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)"
        mysql_operate(insert_sql, insert_list, operate_flag=1)

    # Build a single insert row
    def get_insert_value(self, ele, compute_final_draft):
        sim_score, data = ele[1], ele[2]
        create_time = int(time.time())
        # Look up unit_id in MySQL
        sql = "SELECT kp.unit_id FROM tn_kp, kp WHERE tn_kp.kp_id=kp.kp_id AND tn_kp.topic_id=%d" % (ele[0])
        fetch_dict = mysql_operate(sql)
        unit_id = fetch_dict["unit_id"] if fetch_dict is not None and len(fetch_dict) > 0 else 0
        # Assemble the insert tuple
        insert_value = (data["topic_id"], self.group_count, sim_score, data["subject_id"], data["topic_type_id"],
                        unit_id, data["is_audit"], data["is_mark"], compute_final_draft, create_time)
        return insert_value


if __name__ == "__main__":
    ddc = DataBase_Dup_Check()
    start = time.time()
    # Logging
    print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                                type="scheduled self duplicate check",
                                message="starting the question-bank self duplicate check"))
    ddc()
    # Logging
    print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                                type="cloud question bank self duplicate check total time:",
                                message=time.time()-start))
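
# Usage sketch (inferred from the __main__ block and the "scheduled self duplicate check"
# log message above, not from project docs): run the full check directly, e.g. from a
# cron job or another scheduler:
#   python database_dup_check.py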