@@ -4,12 +4,12 @@ import requests
 import numpy as np
 from fuzzywuzzy import fuzz
 from sentence_transformers import util
-from concurrent.futures import ThreadPoolExecutor
 from pprint import pprint

 import config
 from formula_process import formula_recognize
 from comprehensive_score import Comprehensive_Score
+from physical_quantity_extract import physical_quantity_extract

 class HNSW():
     def __init__(self, data_process, logger=None):
@@ -17,13 +17,16 @@ class HNSW():
         self.mongo_coll = config.mongo_coll
         self.vector_dim = config.vector_dim
         self.hnsw_retrieve_url = config.hnsw_retrieve_url
+        self.dim_classify_url = config.dim_classify_url
         # Log collection
         self.logger = logger
         self.log_msg = config.log_msg
         # Data preprocessing instance
         self.dpp = data_process
         # Semantic similarity scorer instance
-        self.cph_score = Comprehensive_Score()
+        self.cph_score = Comprehensive_Score(config.dev_mode)
+        # Numeric mapping for difficulty labels
+        self.difficulty_transfer = {"容易": 0.2, "较易": 0.4, "一般": 0.6, "较难": 0.8, "困难": 1.0}
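+        # Note: these keys are the textual difficulty labels stored with question-bank data;
+        # mapping them onto the same 0-1 scale that the difficulty classifier returns
+        # (e.g. "较难" -> 0.8) lets Comprehensive_Score compare query and reference difficulty directly.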
         # Load formula-processing data models (bag-of-words model / raw vectors / raw data)
         with open(config.bow_model_path, "rb") as bm:
             self.bow_model = pickle.load(bm)
@@ -112,101 +115,21 @@ class HNSW():

         return formula_res_list[:50]

-    # # HNSW duplicate check (supports mixed multi-subject duplicate checking)
-    # def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0.56):
-    #     # Compute vectors for retrieve_list
-    #     # Call the cleaning/tokenization and sentence-vector computation functions
-    #     sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
-
-    #     # HNSW duplicate check
-    #     def dup_search(retrieve_data, sent_vec, cont_clear):
-    #         # Initialize the return data structure
-    #         retrieve_value_dict = dict(synthese=[], semantics=[], text=[], image=[])
-    #         # Get the question number
-    #         topic_num = retrieve_data["topic_num"] if "topic_num" in retrieve_data else 1
-    #         # Image-search duplicate check
-    #         if doc_flag is True:
-    #             retrieve_value_dict["image"] = self.img_retrieve(retrieve_data["stem"], post_url, similar, topic_num)
-    #         else:
-    #             retrieve_value_dict["image"] = []
-    #         # Check the sentence-vector dimension
-    #         if sent_vec.size != self.vector_dim:
-    #             return retrieve_value_dict
-    #         # Call the HNSW service to retrieve data
-    #         post_list = sent_vec.tolist()
-    #         try:
-    #             query_labels = requests.post(self.hnsw_retrieve_url, json=post_list, timeout=10).json()
-    #         except Exception as e:
-    #             query_labels = []
-    #             # Log collection
-    #             if self.logger is not None:
-    #                 self.logger.error(self.log_msg.format(id="HNSW检索error",
-    #                                                       type="当前题目HNSW检索error",
-    #                                                       message=cont_clear))
-    #         if len(query_labels) == 0:
-    #             return retrieve_value_dict
-
-    #         # Batch read from the database
-    #         mongo_find_dict = {"id": {"$in": query_labels}}
-    #         query_dataset = self.mongo_coll.find(mongo_find_dict)
-
-    #         # Return results above the threshold
-    #         filter_threshold = similar
-    #         for label_data in query_dataset:
-    #             if "sentence_vec" not in label_data:
-    #                 continue
-    #             # Compute the cosine-similarity score
-    #             label_vec = pickle.loads(label_data["sentence_vec"])
-    #             if label_vec.size != self.vector_dim:
-    #                 continue
-    #             cosine_score = util.cos_sim(sent_vec, label_vec)[0][0]
-    #             # Threshold check
-    #             if cosine_score < filter_threshold:
-    #                 continue
-    #             # Compute the edit-distance score
-    #             fuzz_score = fuzz.ratio(cont_clear, label_data["content_clear"]) / 100
-    #             if fuzz_score < min_threshold:
-    #                 continue
-    #             # Discount the cosine similarity
-    #             if cosine_score >= 0.91 and fuzz_score < min_threshold + 0.06:
-    #                 cosine_score = cosine_score * 0.95
-    #             elif cosine_score < 0.91 and fuzz_score < min_threshold + 0.06:
-    #                 cosine_score = cosine_score * 0.94
-    #             # Threshold check after the cosine-similarity discount
-    #             if cosine_score < filter_threshold:
-    #                 continue
-    #             retrieve_value = [label_data["id"], int(cosine_score * 100) / 100]
-    #             retrieve_value_dict["semantics"].append(retrieve_value)
-    #             # Validate the edit-distance score; filter out results below the configured value
-    #             if fuzz_score >= filter_threshold:
-    #                 retrieve_value = [label_data["id"], fuzz_score]
-    #                 retrieve_value_dict["text"].append(retrieve_value)
-
-    #         # Sort the combined results by score in descending order and keep the top-scoring results
-    #         retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)
-    #                               for k,value in retrieve_value_dict.items()}
-
-    #         # Overall ranking
-    #         synthese_list = sorted(sum(retrieve_sort_dict.values(), []), key=lambda x: x[1], reverse=True)
-    #         synthese_set = set()
-    #         for ele in synthese_list:
-    #             if ele[0] not in synthese_set and len(retrieve_sort_dict["synthese"]) < 50:
-    #                 synthese_set.add(ele[0])
-    #                 retrieve_sort_dict["synthese"].append(ele)
-    #         # Add the question number
-    #         retrieve_sort_dict["topic_num"] = topic_num
-
-    #         # Return the final duplicate-check result as a dict
-    #         return retrieve_sort_dict
-
-    #     # Multi-threaded HNSW duplicate check
-    #     with ThreadPoolExecutor(max_workers=5) as executor:
-    #         retrieve_res_list = list(executor.map(dup_search, retrieve_list, sent_vec_list, cont_clear_list))
-
-    #     return retrieve_res_list
-
     # HNSW duplicate check (supports mixed multi-subject duplicate checking)
     def retrieve(self, retrieve_list, post_url, similar, scale, doc_flag):
+        """
+        return:
+        [
+            {
+                'semantics': [[20232015, 1.0, {'quesType': 1.0, 'knowledge': 1.0, 'physical_scene': 1.0, 'solving_type': 1.0, 'difficulty': 1.0, 'physical_quantity': 1.0}]],
+                'text': [[20232015, 0.97]],
+                'image': [],
+                'label': {'knowledge': ['串并联电路的辨别'], 'physical_scene': ['串并联电路的辨别'], 'solving_type': ['规律理解'], 'difficulty': 0.6, 'physical_quantity': ['电流']},
+                'topic_num': 1
+            },
+            ...
+        ]
+        """
         # Compute vectors for retrieve_list
         # Call the cleaning/tokenization and sentence-vector computation functions
         sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
@@ -231,9 +154,9 @@ class HNSW():
                 retrieve_res_list.append(retrieve_value_dict)
                 continue
             # Call the HNSW service to retrieve data
-            post_list = sent_vec.tolist()
             try:
-                query_labels = requests.post(self.hnsw_retrieve_url, json=post_list, timeout=10).json()
+                hnsw_post_list = sent_vec.tolist()
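+                # numpy vectors are not JSON-serializable as-is, so the sentence vector is
+                # converted with tolist() before being posted to the HNSW retrieval service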
+                query_labels = requests.post(self.hnsw_retrieve_url, json=hnsw_post_list, timeout=10).json()
             except Exception as e:
                 query_labels = []
                 # Log collection
@@ -278,34 +201,59 @@ class HNSW():
             """
             Special handling for semantic similarity
             """
-            # Batch read from the database
-            knowledge_id_list = query_data["knowledge_id"] if query_data else []
+            # Initialize the label dict
             label_dict = dict()
-            # label_dict["quesType"] = retrieve_list[i]["quesType"] if query_data else []
+            # LLM-tagged knowledge points
+            # label_dict["knowledge"] = query_data["knowledge"] if query_data else []
             label_dict["knowledge"] = query_data["knowledge"] if query_data else []
-            label_dict["physical_scene"] = query_data["physical_scene"] if query_data else []
-            label_dict["solving_type"] = query_data["solving_type"] if query_data else []
-            label_dict["difficulty"] = float(query_data["difficulty"]) if query_data else 0
-            label_dict["physical_quantity"] = query_data["physical_quantity"] if query_data else []
-            # label_dict["image_semantics"] = query_data["image_semantics"] if query_data else []
-            query_data["quesType"] = retrieve_list[i].get("quesType", '')
+            tagging_id_list = [self.cph_score.knowledge2id[ele] for ele in label_dict["knowledge"] \
+                               if ele in self.cph_score.knowledge2id]
+            # Get the question type
+            label_dict["quesType"] = retrieve_list[i].get("quesType", "选择题")
+            # Call the multi-dimension classification API
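+            # Expected response shape (inferred from the fallback defaults below): a JSON object
+            # such as {"solving_type": ["规律理解"], "difficulty": 0.6}, i.e. a list of solving-type
+            # labels plus a difficulty value already on the 0-1 scale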
+            try:
+                dim_post_list = {"sentence": cont_clear_list[i], "quesType": label_dict["quesType"]}
+                dim_classify_dict = requests.post(self.dim_classify_url, json=dim_post_list, timeout=10).json()
+            except Exception as e:
+                dim_classify_dict = {"solving_type": ["规律理解"], "difficulty": 0.6}
+                # Log collection
+                if self.logger is not None:
+                    self.logger.error(self.log_msg.format(id="多维分类error",
+                                                          type="当前题目多维分类error",
+                                                          message=cont_clear_list[i]))
+            # Solving-type model classification
+            label_dict["solving_type"] = dim_classify_dict["solving_type"]
+            # Difficulty model classification
+            label_dict["difficulty"] = dim_classify_dict["difficulty"]
+            # Rule-based physical-quantity extraction
+            label_dict["physical_quantity"] = physical_quantity_extract(cont_clear_list[i])
+
+            # For LLM-tagged knowledge points, fetch question-bank questions with related knowledge points
+            knowledge_id_list = []
+            if len(tagging_id_list) > 0:
+                ####################################### encode_base_value setting #######################################
+                # 考试院: 10000, 风向标: 10
+                encode_base_value = 10000 if config.dev_mode == "ksy" else 10
+                ####################################### encode_base_value setting #######################################
+                for ele in tagging_id_list:
+                    init_id = int(ele / encode_base_value) * encode_base_value
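+                    # Illustrative example (assumed values): with encode_base_value = 10000, a knowledge
+                    # id such as 12345678 rounds down to init_id = 12340000; init_id2max_id is then
+                    # expected to return the ids grouped under that base id, as used below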
+                    init_id_list = self.cph_score.init_id2max_id.get(str(init_id), [])
+                    knowledge_id_list.extend(init_id_list)
+            knowledge_query_dataset = None
             if len(knowledge_id_list) > 0:
-                relate_list = []
-                for ele in knowledge_id_list:
-                    init_id = int(ele / 10) * 10
-                    last_id = self.cph_score.init_id2max_id[str(init_id)]
-                    relate_list.extend(np.arange(init_id + 1, last_id + 1).tolist())
-                knowledge_id_list = relate_list
-                mongo_find_dict = {"knowledge_id": {"$in": knowledge_id_list}}
-                query_dataset = self.mongo_coll.find(mongo_find_dict)
+                mongo_find_dict = {"knowledge_id": {"$in": knowledge_id_list}}
+                knowledge_query_dataset = self.mongo_coll.find(mongo_find_dict)
             # Return results above the threshold
-            for refer_data in query_dataset:
-                sum_score, score_dict = self.cph_score(query_data, refer_data, scale)
-                if sum_score < similar:
-                    continue
-                retrieve_value = [refer_data["id"], sum_score, score_dict]
-                retrieve_value_dict["semantics"].append(retrieve_value)
+            if knowledge_query_dataset:
+                for refer_data in knowledge_query_dataset:
+                    # Convert the difficulty label to a numeric value
+                    if refer_data["difficulty"] in self.difficulty_transfer:
+                        refer_data["difficulty"] = self.difficulty_transfer[refer_data["difficulty"]]
+                    sum_score, score_dict = self.cph_score(label_dict, refer_data, scale)
+                    if sum_score < similar:
+                        continue
+                    retrieve_value = [refer_data["id"], sum_score, score_dict]
+                    retrieve_value_dict["semantics"].append(retrieve_value)

             # Sort combined results by score in descending order
             retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)