|
@@ -29,13 +29,28 @@ class HNSW():
|
|
|
self.formula_id_list = json.load(f)
|
|
|
|
|
|
# 图片搜索查重功能
|
|
|
- def img_retrieve(self, retrieve_text, post_url, similar):
|
|
|
+ def img_retrieve(self, retrieve_text, post_url, similar, topic_num):
|
|
|
try:
|
|
|
if post_url is not None:
|
|
|
+ # 日志采集
|
|
|
+ if self.logger is not None:
|
|
|
+ self.logger.info(self.log_msg.format(id="图片搜索查重",
|
|
|
+ type="{}图片搜索查重post".format(topic_num),
|
|
|
+ message=retrieve_text))
|
|
|
img_dict = dict(img_url=retrieve_text, img_threshold=similar, img_max_num=40)
|
|
|
- img_res = requests.post(post_url, json=img_dict, timeout=20).json()
|
|
|
+ img_res = requests.post(post_url, json=img_dict, timeout=30).json()
|
|
|
+ # 日志采集
|
|
|
+ if self.logger is not None:
|
|
|
+ self.logger.info(self.log_msg.format(id="图片搜索查重",
|
|
|
+ type="{}图片搜索查重success".format(topic_num),
|
|
|
+ message=img_res))
|
|
|
return img_res
|
|
|
except Exception as e:
|
|
|
+ # 日志采集
|
|
|
+ if self.logger is not None:
|
|
|
+ self.logger.error(self.log_msg.format(id="图片搜索查重",
|
|
|
+ type="{}图片搜索查重error".format(topic_num),
|
|
|
+ message=retrieve_text))
|
|
|
return []
|
|
|
|
|
|
# 公式搜索查重功能
|
|
@@ -74,11 +89,12 @@ class HNSW():
|
|
|
# 对余弦相似度进行折算
|
|
|
cosine_score = formula_cos[0][idx]
|
|
|
if 0.95 <= cosine_score < 0.98:
|
|
|
+ cosine_score = cosine_score * 0.98
|
|
|
+ elif cosine_score < 0.95:
|
|
|
cosine_score = cosine_score * 0.95
|
|
|
- elif 0.9 <= cosine_score < 0.95:
|
|
|
- cosine_score = cosine_score * 0.93
|
|
|
- elif cosine_score < 0.9:
|
|
|
- cosine_score = cosine_score * 0.91
|
|
|
+ # 余弦相似度折算后阈值判断
|
|
|
+ if cosine_score < similar:
|
|
|
+ continue
|
|
|
res_list.append([self.formula_id_list[idx][1], int(cosine_score * 100) / 100])
|
|
|
# 根据分数对题目id排序并返回前50个
|
|
|
res_sort_list = sorted(res_list, key=lambda x: x[1], reverse=True)[:80]
|
|
@@ -93,24 +109,120 @@ class HNSW():
|
|
|
|
|
|
return formula_res_list[:50]
|
|
|
|
|
|
+ # # HNSW查(支持多学科混合查重)
|
|
|
+ # def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0.56):
|
|
|
+ # # 计算retrieve_list的vec值
|
|
|
+ # # 调用清洗分词函数和句向量计算函数
|
|
|
+ # sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
|
|
|
+
|
|
|
+ # # HNSW查重
|
|
|
+ # def dup_search(retrieve_data, sent_vec, cont_clear):
|
|
|
+ # # 初始化返回数据类型
|
|
|
+ # retrieve_value_dict = dict(synthese=[], semantics=[], text=[], image=[])
|
|
|
+ # # 获取题目序号
|
|
|
+ # topic_num = retrieve_data["topic_num"] if "topic_num" in retrieve_data else 1
|
|
|
+ # # 图片搜索查重功能
|
|
|
+ # if doc_flag is True:
|
|
|
+ # retrieve_value_dict["image"] = self.img_retrieve(retrieve_data["stem"], post_url, similar, topic_num)
|
|
|
+ # else:
|
|
|
+ # retrieve_value_dict["image"] = []
|
|
|
+ # # 判断句向量维度
|
|
|
+ # if sent_vec.size != self.vector_dim:
|
|
|
+ # return retrieve_value_dict
|
|
|
+ # # 调用hnsw接口检索数据
|
|
|
+ # post_list = sent_vec.tolist()
|
|
|
+ # try:
|
|
|
+ # query_labels = requests.post(self.hnsw_retrieve_url, json=post_list, timeout=10).json()
|
|
|
+ # except Exception as e:
|
|
|
+ # query_labels = []
|
|
|
+ # # 日志采集
|
|
|
+ # if self.logger is not None:
|
|
|
+ # self.logger.error(self.log_msg.format(id="HNSW检索error",
|
|
|
+ # type="当前题目HNSW检索error",
|
|
|
+ # message=cont_clear))
|
|
|
+ # if len(query_labels) == 0:
|
|
|
+ # return retrieve_value_dict
|
|
|
+
|
|
|
+ # # 批量读取数据库
|
|
|
+ # mongo_find_dict = {"id": {"$in": query_labels}}
|
|
|
+ # query_dataset = self.mongo_coll.find(mongo_find_dict)
|
|
|
+
|
|
|
+ # # 返回大于阈值的结果
|
|
|
+ # filter_threshold = similar
|
|
|
+ # for label_data in query_dataset:
|
|
|
+ # if "sentence_vec" not in label_data:
|
|
|
+ # continue
|
|
|
+ # # 计算余弦相似度得分
|
|
|
+ # label_vec = pickle.loads(label_data["sentence_vec"])
|
|
|
+ # if label_vec.size != self.vector_dim:
|
|
|
+ # continue
|
|
|
+ # cosine_score = util.cos_sim(sent_vec, label_vec)[0][0]
|
|
|
+ # # 阈值判断
|
|
|
+ # if cosine_score < filter_threshold:
|
|
|
+ # continue
|
|
|
+ # # 计算编辑距离得分
|
|
|
+ # fuzz_score = fuzz.ratio(cont_clear, label_data["content_clear"]) / 100
|
|
|
+ # if fuzz_score < min_threshold:
|
|
|
+ # continue
|
|
|
+ # # 对余弦相似度进行折算
|
|
|
+ # if cosine_score >= 0.91 and fuzz_score < min_threshold + 0.06:
|
|
|
+ # cosine_score = cosine_score * 0.95
|
|
|
+ # elif cosine_score < 0.91 and fuzz_score < min_threshold + 0.06:
|
|
|
+ # cosine_score = cosine_score * 0.94
|
|
|
+ # # 余弦相似度折算后阈值判断
|
|
|
+ # if cosine_score < filter_threshold:
|
|
|
+ # continue
|
|
|
+ # retrieve_value = [label_data["id"], int(cosine_score * 100) / 100]
|
|
|
+ # retrieve_value_dict["semantics"].append(retrieve_value)
|
|
|
+ # # 进行编辑距离得分验证,若小于设定分则过滤
|
|
|
+ # if fuzz_score >= filter_threshold:
|
|
|
+ # retrieve_value = [label_data["id"], fuzz_score]
|
|
|
+ # retrieve_value_dict["text"].append(retrieve_value)
|
|
|
+
|
|
|
+ # # 将组合结果按照score降序排序并取得分前十个结果
|
|
|
+ # retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)
|
|
|
+ # for k,value in retrieve_value_dict.items()}
|
|
|
+
|
|
|
+ # # 综合排序
|
|
|
+ # synthese_list = sorted(sum(retrieve_sort_dict.values(), []), key=lambda x: x[1], reverse=True)
|
|
|
+ # synthese_set = set()
|
|
|
+ # for ele in synthese_list:
|
|
|
+ # if ele[0] not in synthese_set and len(retrieve_sort_dict["synthese"]) < 50:
|
|
|
+ # synthese_set.add(ele[0])
|
|
|
+ # retrieve_sort_dict["synthese"].append(ele)
|
|
|
+ # # 加入题目序号
|
|
|
+ # retrieve_sort_dict["topic_num"] = topic_num
|
|
|
+
|
|
|
+ # # 以字典形式返回最终查重结果
|
|
|
+ # return retrieve_sort_dict
|
|
|
+
|
|
|
+ # # 多线程HNSW查重
|
|
|
+ # with ThreadPoolExecutor(max_workers=5) as executor:
|
|
|
+ # retrieve_res_list = list(executor.map(dup_search, retrieve_list, sent_vec_list, cont_clear_list))
|
|
|
+
|
|
|
+ # return retrieve_res_list
|
|
|
+
|
|
|
# HNSW查(支持多学科混合查重)
|
|
|
- def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0.4):
|
|
|
+ def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0.56):
|
|
|
# 计算retrieve_list的vec值
|
|
|
# 调用清洗分词函数和句向量计算函数
|
|
|
sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
|
|
|
-
|
|
|
# HNSW查重
|
|
|
- def dup_search(retrieve_data, sent_vec, cont_clear):
|
|
|
+ retrieve_res_list = []
|
|
|
+ for i,sent_vec in enumerate(sent_vec_list):
|
|
|
# 初始化返回数据类型
|
|
|
retrieve_value_dict = dict(synthese=[], semantics=[], text=[], image=[])
|
|
|
+ # 获取题目序号
|
|
|
+ topic_num = retrieve_list[i]["topic_num"] if "topic_num" in retrieve_list[i] else 1
|
|
|
# 图片搜索查重功能
|
|
|
if doc_flag is True:
|
|
|
- retrieve_value_dict["image"] = self.img_retrieve(retrieve_data["stem"], post_url, similar)
|
|
|
+ retrieve_value_dict["image"] = self.img_retrieve(retrieve_list[i]["stem"], post_url, similar, topic_num)
|
|
|
else:
|
|
|
retrieve_value_dict["image"] = []
|
|
|
# 判断句向量维度
|
|
|
if sent_vec.size != self.vector_dim:
|
|
|
- return retrieve_value_dict
|
|
|
+ retrieve_res_list.append(retrieve_value_dict)
|
|
|
+ continue
|
|
|
# 调用hnsw接口检索数据
|
|
|
post_list = sent_vec.tolist()
|
|
|
try:
|
|
@@ -121,9 +233,10 @@ class HNSW():
|
|
|
if self.logger is not None:
|
|
|
self.logger.error(self.log_msg.format(id="HNSW检索error",
|
|
|
type="当前题目HNSW检索error",
|
|
|
- message=cont_clear))
|
|
|
+ message=cont_clear_list[i]))
|
|
|
if len(query_labels) == 0:
|
|
|
- return retrieve_value_dict
|
|
|
+ retrieve_res_list.append(retrieve_value_dict)
|
|
|
+ continue
|
|
|
|
|
|
# 批量读取数据库
|
|
|
mongo_find_dict = {"id": {"$in": query_labels}}
|
|
@@ -143,16 +256,17 @@ class HNSW():
|
|
|
if cosine_score < filter_threshold:
|
|
|
continue
|
|
|
# 计算编辑距离得分
|
|
|
- fuzz_score = fuzz.ratio(cont_clear, label_data["content_clear"]) / 100
|
|
|
+ fuzz_score = fuzz.ratio(cont_clear_list[i], label_data["content_clear"]) / 100
|
|
|
if fuzz_score < min_threshold:
|
|
|
continue
|
|
|
# 对余弦相似度进行折算
|
|
|
- if 0.95 <= cosine_score < 0.98:
|
|
|
- cosine_score = cosine_score * 0.93
|
|
|
- elif 0.9 <= cosine_score < 0.95:
|
|
|
- cosine_score = cosine_score * 0.87
|
|
|
- elif cosine_score < 0.9:
|
|
|
- cosine_score = cosine_score * 0.81
|
|
|
+ if cosine_score >= 0.91 and fuzz_score < min_threshold + 0.06:
|
|
|
+ cosine_score = cosine_score * 0.95
|
|
|
+ elif cosine_score < 0.91 and fuzz_score < min_threshold + 0.06:
|
|
|
+ cosine_score = cosine_score * 0.94
|
|
|
+ # 余弦相似度折算后阈值判断
|
|
|
+ if cosine_score < filter_threshold:
|
|
|
+ continue
|
|
|
retrieve_value = [label_data["id"], int(cosine_score * 100) / 100]
|
|
|
retrieve_value_dict["semantics"].append(retrieve_value)
|
|
|
# 进行编辑距离得分验证,若小于设定分则过滤
|
|
@@ -171,31 +285,29 @@ class HNSW():
|
|
|
if ele[0] not in synthese_set and len(retrieve_sort_dict["synthese"]) < 50:
|
|
|
synthese_set.add(ele[0])
|
|
|
retrieve_sort_dict["synthese"].append(ele)
|
|
|
-
|
|
|
+ # 加入题目序号
|
|
|
+ retrieve_sort_dict["topic_num"] = topic_num
|
|
|
+
|
|
|
# 以字典形式返回最终查重结果
|
|
|
- retrieve_sort_dict["topic_num"] = retrieve_data["topic_num"]
|
|
|
- return retrieve_sort_dict
|
|
|
+ retrieve_res_list.append(retrieve_sort_dict)
|
|
|
|
|
|
- # 多线程HNSW查重
|
|
|
- with ThreadPoolExecutor(max_workers=5) as executor:
|
|
|
- retrieve_res_list = list(executor.map(dup_search, retrieve_list, sent_vec_list, cont_clear_list))
|
|
|
-
|
|
|
return retrieve_res_list
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
# 获取mongodb数据
|
|
|
mongo_coll = config.mongo_coll
|
|
|
- hnsw = HNSW()
|
|
|
+ from data_preprocessing import DataPreProcessing
|
|
|
+ hnsw = HNSW(DataPreProcessing())
|
|
|
|
|
|
- # test_data = []
|
|
|
- # for idx in [15176736]:
|
|
|
- # test_data.append(mongo_coll.find_one({"id": idx}))
|
|
|
+ test_data = []
|
|
|
+ for idx in [201511100736265]:
|
|
|
+ test_data.append(mongo_coll.find_one({"id": idx}))
|
|
|
|
|
|
- # res = hnsw.retrieve(test_data)
|
|
|
- # pprint(res)
|
|
|
+ res = hnsw.retrieve(test_data, '', 0.8, False)
|
|
|
+ pprint(res[0]["semantics"])
|
|
|
|
|
|
- # 公式搜索查重功能
|
|
|
- formula_string = "ρ蜡=0.9*10^3Kg/m^3"
|
|
|
- formula_string = "p蜡=0.9*10^3Kq/m^3"
|
|
|
- print(hnsw.formula_retrieve(formula_string, 0.8))
|
|
|
+ # # 公式搜索查重功能
|
|
|
+ # formula_string = "ρ蜡=0.9*10^3Kg/m^3"
|
|
|
+ # formula_string = "p蜡=0.9*10^3Kq/m^3"
|
|
|
+ # print(hnsw.formula_retrieve(formula_string, 0.8))
|