|
@@ -9,6 +9,7 @@ from pprint import pprint
|
|
|
|
|
|
import config
|
|
import config
|
|
from formula_process import formula_recognize
|
|
from formula_process import formula_recognize
|
|
|
|
+from comprehensive_score import Comprehensive_Score
|
|
|
|
|
|
class HNSW():
|
|
class HNSW():
|
|
def __init__(self, data_process, logger=None):
|
|
def __init__(self, data_process, logger=None):
|
|
@@ -21,6 +22,8 @@ class HNSW():
|
|
self.log_msg = config.log_msg
|
|
self.log_msg = config.log_msg
|
|
# 数据预处理实例化
|
|
# 数据预处理实例化
|
|
self.dpp = data_process
|
|
self.dpp = data_process
|
|
|
|
+ # 语义相似度实例化
|
|
|
|
+ self.cph_score = Comprehensive_Score()
|
|
# 加载公式处理数据模型(词袋模型/原始向量/原始数据)
|
|
# 加载公式处理数据模型(词袋模型/原始向量/原始数据)
|
|
with open(config.bow_model_path, "rb") as bm:
|
|
with open(config.bow_model_path, "rb") as bm:
|
|
self.bow_model = pickle.load(bm)
|
|
self.bow_model = pickle.load(bm)
|
|
@@ -203,7 +206,7 @@ class HNSW():
|
|
# return retrieve_res_list
|
|
# return retrieve_res_list
|
|
|
|
|
|
# HNSW查(支持多学科混合查重)
|
|
# HNSW查(支持多学科混合查重)
|
|
- def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0.56):
|
|
|
|
|
|
+ def retrieve(self, retrieve_list, post_url, similar, scale, doc_flag):
|
|
# 计算retrieve_list的vec值
|
|
# 计算retrieve_list的vec值
|
|
# 调用清洗分词函数和句向量计算函数
|
|
# 调用清洗分词函数和句向量计算函数
|
|
sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
|
|
sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
|
|
@@ -211,7 +214,8 @@ class HNSW():
|
|
retrieve_res_list = []
|
|
retrieve_res_list = []
|
|
for i,sent_vec in enumerate(sent_vec_list):
|
|
for i,sent_vec in enumerate(sent_vec_list):
|
|
# 初始化返回数据类型
|
|
# 初始化返回数据类型
|
|
- retrieve_value_dict = dict(synthese=[], semantics=[], text=[], image=[])
|
|
|
|
|
|
+ # retrieve_value_dict = dict(synthese=[], semantics=[], text=[], image=[])
|
|
|
|
+ retrieve_value_dict = dict(semantics=[], text=[], image=[])
|
|
# 获取题目序号
|
|
# 获取题目序号
|
|
topic_num = retrieve_list[i]["topic_num"] if "topic_num" in retrieve_list[i] else 1
|
|
topic_num = retrieve_list[i]["topic_num"] if "topic_num" in retrieve_list[i] else 1
|
|
# 图片搜索查重功能
|
|
# 图片搜索查重功能
|
|
@@ -219,6 +223,9 @@ class HNSW():
|
|
retrieve_value_dict["image"] = self.img_retrieve(retrieve_list[i]["stem"], post_url, similar, topic_num)
|
|
retrieve_value_dict["image"] = self.img_retrieve(retrieve_list[i]["stem"], post_url, similar, topic_num)
|
|
else:
|
|
else:
|
|
retrieve_value_dict["image"] = []
|
|
retrieve_value_dict["image"] = []
|
|
|
|
+ """
|
|
|
|
+ 文本相似度特殊处理
|
|
|
|
+ """
|
|
# 判断句向量维度
|
|
# 判断句向量维度
|
|
if sent_vec.size != self.vector_dim:
|
|
if sent_vec.size != self.vector_dim:
|
|
retrieve_res_list.append(retrieve_value_dict)
|
|
retrieve_res_list.append(retrieve_value_dict)
|
|
@@ -237,13 +244,13 @@ class HNSW():
|
|
if len(query_labels) == 0:
|
|
if len(query_labels) == 0:
|
|
retrieve_res_list.append(retrieve_value_dict)
|
|
retrieve_res_list.append(retrieve_value_dict)
|
|
continue
|
|
continue
|
|
-
|
|
|
|
# 批量读取数据库
|
|
# 批量读取数据库
|
|
mongo_find_dict = {"id": {"$in": query_labels}}
|
|
mongo_find_dict = {"id": {"$in": query_labels}}
|
|
query_dataset = self.mongo_coll.find(mongo_find_dict)
|
|
query_dataset = self.mongo_coll.find(mongo_find_dict)
|
|
-
|
|
|
|
|
|
+ ####################################### 语义相似度借靠 #######################################
|
|
|
|
+ query_data = dict()
|
|
|
|
+ ####################################### 语义相似度借靠 #######################################
|
|
# 返回大于阈值的结果
|
|
# 返回大于阈值的结果
|
|
- filter_threshold = similar
|
|
|
|
for label_data in query_dataset:
|
|
for label_data in query_dataset:
|
|
if "sentence_vec" not in label_data:
|
|
if "sentence_vec" not in label_data:
|
|
continue
|
|
continue
|
|
@@ -253,39 +260,67 @@ class HNSW():
|
|
continue
|
|
continue
|
|
cosine_score = util.cos_sim(sent_vec, label_vec)[0][0]
|
|
cosine_score = util.cos_sim(sent_vec, label_vec)[0][0]
|
|
# 阈值判断
|
|
# 阈值判断
|
|
- if cosine_score < filter_threshold:
|
|
|
|
|
|
+ if cosine_score < similar:
|
|
continue
|
|
continue
|
|
# 计算编辑距离得分
|
|
# 计算编辑距离得分
|
|
fuzz_score = fuzz.ratio(cont_clear_list[i], label_data["content_clear"]) / 100
|
|
fuzz_score = fuzz.ratio(cont_clear_list[i], label_data["content_clear"]) / 100
|
|
- if fuzz_score < min_threshold:
|
|
|
|
- continue
|
|
|
|
- # 对余弦相似度进行折算
|
|
|
|
- if cosine_score >= 0.91 and fuzz_score < min_threshold + 0.06:
|
|
|
|
- cosine_score = cosine_score * 0.95
|
|
|
|
- elif cosine_score < 0.91 and fuzz_score < min_threshold + 0.06:
|
|
|
|
- cosine_score = cosine_score * 0.94
|
|
|
|
- # 余弦相似度折算后阈值判断
|
|
|
|
- if cosine_score < filter_threshold:
|
|
|
|
- continue
|
|
|
|
- retrieve_value = [label_data["id"], int(cosine_score * 100) / 100]
|
|
|
|
- retrieve_value_dict["semantics"].append(retrieve_value)
|
|
|
|
# 进行编辑距离得分验证,若小于设定分则过滤
|
|
# 进行编辑距离得分验证,若小于设定分则过滤
|
|
- if fuzz_score >= filter_threshold:
|
|
|
|
|
|
+ if fuzz_score >= similar:
|
|
retrieve_value = [label_data["id"], fuzz_score]
|
|
retrieve_value = [label_data["id"], fuzz_score]
|
|
retrieve_value_dict["text"].append(retrieve_value)
|
|
retrieve_value_dict["text"].append(retrieve_value)
|
|
|
|
+ ####################################### 语义相似度借靠 #######################################
|
|
|
|
+ max_mark_score = 0
|
|
|
|
+ if cosine_score >= 0.9 and cosine_score > max_mark_score:
|
|
|
|
+ max_mark_score = cosine_score
|
|
|
|
+ query_data = label_data
|
|
|
|
+ ####################################### 语义相似度借靠 #######################################
|
|
|
|
|
|
- # 将组合结果按照score降序排序并取得分前十个结果
|
|
|
|
|
|
+ """
|
|
|
|
+ 语义相似度特殊处理
|
|
|
|
+ """
|
|
|
|
+ # 批量读取数据库
|
|
|
|
+ knowledge_id_list = query_data["knowledge_id"] if query_data else []
|
|
|
|
+ label_dict = dict()
|
|
|
|
+ # label_dict["quesType"] = retrieve_list[i]["quesType"] if query_data else []
|
|
|
|
+ label_dict["knowledge"] = query_data["knowledge"] if query_data else []
|
|
|
|
+ label_dict["physical_scene"] = query_data["physical_scene"] if query_data else []
|
|
|
|
+ label_dict["solving_type"] = query_data["solving_type"] if query_data else []
|
|
|
|
+ label_dict["difficulty"] = float(query_data["difficulty"]) if query_data else 0
|
|
|
|
+ label_dict["physical_quantity"] = query_data["physical_quantity"] if query_data else []
|
|
|
|
+ # label_dict["image_semantics"] = query_data["image_semantics"] if query_data else []
|
|
|
|
+ query_data["quesType"] = retrieve_list[i].get("quesType", '')
|
|
|
|
+
|
|
|
|
+ if len(knowledge_id_list) > 0:
|
|
|
|
+ relate_list = []
|
|
|
|
+ for ele in knowledge_id_list:
|
|
|
|
+ init_id = int(ele / 10) * 10
|
|
|
|
+ last_id = self.cph_score.init_id2max_id[str(init_id)]
|
|
|
|
+ relate_list.extend(np.arange(init_id + 1, last_id + 1).tolist())
|
|
|
|
+ knowledge_id_list = relate_list
|
|
|
|
+ mongo_find_dict = {"knowledge_id": {"$in": knowledge_id_list}}
|
|
|
|
+ query_dataset = self.mongo_coll.find(mongo_find_dict)
|
|
|
|
+ # 返回大于阈值的结果
|
|
|
|
+ for refer_data in query_dataset:
|
|
|
|
+ sum_score, score_dict = self.cph_score(query_data, refer_data, scale)
|
|
|
|
+ if sum_score < similar:
|
|
|
|
+ continue
|
|
|
|
+ retrieve_value = [refer_data["id"], sum_score, score_dict]
|
|
|
|
+ retrieve_value_dict["semantics"].append(retrieve_value)
|
|
|
|
+
|
|
|
|
+ # 将组合结果按照score降序排序
|
|
retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)
|
|
retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)
|
|
for k,value in retrieve_value_dict.items()}
|
|
for k,value in retrieve_value_dict.items()}
|
|
|
|
|
|
- # 综合排序
|
|
|
|
- synthese_list = sorted(sum(retrieve_sort_dict.values(), []), key=lambda x: x[1], reverse=True)
|
|
|
|
- synthese_set = set()
|
|
|
|
- for ele in synthese_list:
|
|
|
|
- if ele[0] not in synthese_set and len(retrieve_sort_dict["synthese"]) < 50:
|
|
|
|
- synthese_set.add(ele[0])
|
|
|
|
- retrieve_sort_dict["synthese"].append(ele)
|
|
|
|
|
|
+ # # 综合排序
|
|
|
|
+ # synthese_list = sorted(sum(retrieve_sort_dict.values(), []), key=lambda x: x[1], reverse=True)
|
|
|
|
+ # synthese_set = set()
|
|
|
|
+ # for ele in synthese_list:
|
|
|
|
+ # # 综合排序返回前50个
|
|
|
|
+ # if ele[0] not in synthese_set and len(retrieve_sort_dict["synthese"]) < 50:
|
|
|
|
+ # synthese_set.add(ele[0])
|
|
|
|
+ # retrieve_sort_dict["synthese"].append(ele[:2])
|
|
# 加入题目序号
|
|
# 加入题目序号
|
|
|
|
+ retrieve_sort_dict["label"] = label_dict
|
|
retrieve_sort_dict["topic_num"] = topic_num
|
|
retrieve_sort_dict["topic_num"] = topic_num
|
|
|
|
|
|
# 以字典形式返回最终查重结果
|
|
# 以字典形式返回最终查重结果
|