Browse Source

新增多维度分类

tujintao 10 months ago
parent
commit
13a1d22664

BIN
__pycache__/config.cpython-38.pyc


BIN
__pycache__/model.cpython-38.pyc


BIN
__pycache__/physical_quantity_extract.cpython-38.pyc


+ 65 - 25
comparison.py

@@ -1,14 +1,8 @@
 import json
+import numpy as np
 import pandas as pd
 
 keyword2id_dict = dict()
-# 求解类型
-solution_type_list = ["概念辨析","规律理解","现象解释","物理学史","计算分析","实验操作","连线作图","实验读数"]
-
-solving_type2id = dict()
-for i, ele in enumerate(solution_type_list):
-    solving_type2id[ele] = 1 + i
-keyword2id_dict["solving_type2id"] = solving_type2id
 
 # 物理量
 excel_path = r"data/物理量.xlsx"
@@ -25,22 +19,22 @@ for i in range(len(df)):
         quantity2id[knowledge] = sign_index
 keyword2id_dict["quantity2id"] = quantity2id
 
-# 物理场景
-excel_path = r"data/物理情景.xlsx"
-df = pd.read_excel(excel_path)
-scene2id = dict()
-count_index = 0
-for i in range(len(df)):
-    if not pd.isna(df['知识点'][i]):
-        count_index += 1
-        sign_index = 10000 + count_index * 10
-    knowledge = df['情景'][i]
-    if not pd.isna(knowledge):
-        sign_index += 1
-        scene2id[knowledge] = sign_index
-keyword2id_dict["scene2id"] = scene2id
+# # 物理场景
+# excel_path = r"data/物理情景.xlsx"
+# df = pd.read_excel(excel_path)
+# scene2id = dict()
+# count_index = 0
+# for i in range(len(df)):
+#     if not pd.isna(df['知识点'][i]):
+#         count_index += 1
+#         sign_index = 10000 + count_index * 10
+#     knowledge = df['情景'][i]
+#     if not pd.isna(knowledge):
+#         sign_index += 1
+#         scene2id[knowledge] = sign_index
+# keyword2id_dict["scene2id"] = scene2id
 
-# 知识点
+# 风向标-知识点
 excel_path = r"data/物理知识点.xlsx"
 df = pd.read_excel(excel_path)
 knowledge2id = dict()
@@ -54,15 +48,61 @@ for i in range(len(df)):
         # sign_index = 10000 + int(sign[0]) * 100 + int(sign[1]) * 10
         sign_index = 10000 + count_index * 100 + int(sign[1]) * 10
         init_id = sign_index
-        init_id2max_id[init_id] = sign_index
+        if init_id not in init_id2max_id:
+            init_id2max_id[init_id] = []
+        else:
+            init_id2max_id[init_id].append(sign_index)
     knowledge = df['4级知识点'][i]
     if not pd.isna(knowledge):
         sign_index += 1
         knowledge2id[knowledge] = sign_index
-        init_id2max_id[init_id] = sign_index
+        if init_id not in init_id2max_id:
+            init_id2max_id[init_id] = []
+        else:
+            init_id2max_id[init_id].append(sign_index)
 keyword2id_dict["knowledge2id"] = knowledge2id
 keyword2id_dict["init_id2max_id"] = init_id2max_id
 
+# # 考试院-知识点
+# excel_path = r"data/初中物理知识对应关系.xlsx"
+# df = pd.read_excel(excel_path)
+# knowledge2id = dict()
+# init_id2max_id = dict()
+# count_index = 0
+# for i in range(len(df)):
+#     if not pd.isna(df.iloc[i][2]):
+#         count_index += 1
+#         sign_index = 100000000 + count_index * 1000000
+#         if  pd.isna(df.iloc[i+1][3]):
+#             knowledge = df.iloc[i][2].split(' ')[1]
+#             knowledge2id[knowledge] = sign_index
+#             continue
+#     if not pd.isna(df.iloc[i][3]):
+#         sign_index = int(str(sign_index)[:-4]) * 10000
+#         sign_index += 10000
+#         relate_index = sign_index
+#         init_id2max_id[relate_index] = []
+#         if pd.isna(df.iloc[i+1][4]):
+#             knowledge = df.iloc[i][3].split(' ')[1]
+#             knowledge2id[knowledge] = sign_index
+#             continue
+#     if not pd.isna(df.iloc[i][4]):
+#         sign_index = int(str(sign_index)[:-2]) * 100
+#         sign_index += 100
+#         if pd.isna(df.iloc[i+1][5]):
+#             knowledge = df.iloc[i][4].split(' ')[1]
+#             knowledge2id[knowledge] = sign_index
+#             init_id2max_id[relate_index].append(sign_index)
+#             continue
+#     if not pd.isna(df.iloc[i][5]):
+#         sign_index += 1
+#         knowledge = df.iloc[i][5].split(' ')[1]
+#         knowledge2id[knowledge] = sign_index
+#         init_id2max_id[relate_index].append(sign_index)
+
+# keyword2id_dict["knowledge2id"] = knowledge2id
+# keyword2id_dict["init_id2max_id"] = init_id2max_id
+
 # 映射转换
-with open("data/keyword_mapping.json", 'w', encoding="utf8") as f:
+with open("model_data/keyword_mapping.json", 'w', encoding="utf8") as f:
     json.dump(keyword2id_dict, f, ensure_ascii=False, indent=2)

+ 23 - 32
comprehensive_score.py

@@ -3,10 +3,11 @@ from fuzzywuzzy import fuzz
 
 
 class Comprehensive_Score():
-    def __init__(self):
+    def __init__(self, dev_mode):
         with open("model_data/keyword_mapping.json", 'r', encoding="utf8") as f:
             keyword_mapping = json.load(f)
-        self.scene2id = keyword_mapping["scene2id"]
+        # 根据ksy和fxb判断值的次方
+        self.power = 2 if dev_mode == "ksy" else 1
         self.knowledge2id = keyword_mapping["knowledge2id"]
         self.quantity2id = keyword_mapping["quantity2id"]
         self.init_id2max_id = keyword_mapping["init_id2max_id"]
@@ -15,21 +16,19 @@ class Comprehensive_Score():
         score_dict = dict()
         quesType = self.compute_quesType(query["quesType"], refer["quesType"]["quesType"])
         knowledge = self.compute_knowledge(query["knowledge"], refer["knowledge"])
-        physical_scene = self.compute_physical_scene(query["physical_scene"], refer["physical_scene"])
         solving_type = self.compute_solving_type(query["solving_type"], refer["solving_type"])
         difficulty = self.compute_difficulty(query["difficulty"], refer["difficulty"])
         physical_quantity = self.compute_physical_quantity(query["physical_quantity"], refer["physical_quantity"])
         # image_semantics = self.compute_image_semantics(query["image_semantics"], refer["image_semantics"])
 
-        sum_score = quesType * scale["quesType"] + knowledge * scale["knowledge"] + physical_scene * scale["physical_scene"] + \
+        sum_score = quesType * scale["quesType"] + knowledge * scale["knowledge"] + \
                     solving_type * scale["solving_type"] + difficulty * scale["difficulty"] + \
-                    physical_quantity * scale["physical_quantity"]# + image_semantics * scale["image_semantics"]
+                    physical_quantity * scale["physical_quantity"]
         sum_score = int(sum_score * 100) / 100
         sum_score = min(sum_score, 1.0)
 
         score_dict["quesType"] = quesType
         score_dict["knowledge"] = knowledge
-        score_dict["physical_scene"] = physical_scene
         score_dict["solving_type"] = solving_type
         score_dict["difficulty"] = difficulty
         score_dict["physical_quantity"] = physical_quantity
@@ -38,7 +37,7 @@ class Comprehensive_Score():
         return sum_score, score_dict
 
     # 知识点/物理场景/物理量相互关联得分计算
-    def compute_relate_score(self, query_list, refer_list, keyword2id, mode=0):
+    def compute_relate_score(self, query_list, refer_list, keyword2id, mode):
         query_set, refer_set = set(query_list), set(refer_list)
         if query_set == refer_set:
             return 1.0
@@ -56,14 +55,11 @@ class Comprehensive_Score():
                     query_score += 1
                     continue
                 # 知识点
-                if mode == 0:
-                    if abs(query_id - refer_id) < 10: query_score += 0.3
-                    elif abs(query_id - refer_id) < 100: query_score += 0.2
+                if mode == 1:
+                    if abs(query_id - refer_id) < 10 ** self.power: query_score += 0.3
+                    elif abs(query_id - refer_id) < 100 ** self.power: query_score += 0.2
                     else: continue
-                elif mode == 1:
-                    if abs(query_id - refer_id) < 10: query_score += 0.5
-                    else: continue
-                elif mode == 3:
+                elif mode == 2:
                     if abs(query_id - refer_id) < 100: query_score += 0.2
                     else: continue
                 fuzz_score = fuzz.ratio(query, refer)
@@ -90,12 +86,7 @@ class Comprehensive_Score():
 
     # 知识点相似度评分
     def compute_knowledge(self, query_list, refer_list):
-        score = self.compute_relate_score(query_list, refer_list, self.knowledge2id, mode=0)
-        return int(score * 100) / 100
-
-    # 物理场景相似度评分
-    def compute_physical_scene(self, query_list, refer_list):
-        score = self.compute_relate_score(query_list, refer_list, self.scene2id, mode=1)
+        score = self.compute_relate_score(query_list, refer_list, self.knowledge2id, mode=1)
         return int(score * 100) / 100
 
     # 试题求解类型相似度评分
@@ -117,15 +108,15 @@ class Comprehensive_Score():
         score = self.compute_relate_score(query_list, refer_list, self.quantity2id, mode=2)
         return int(score * 100) / 100
 
-    # 图片语义相似度评分
-    def compute_image_semantics(self, query_list, refer_list):
-        query_set, refer_set = set(query_list), set(refer_list)
-        if len(query_set) == 0 and len(refer_set) == 0:
-            return 1
-        elif len(query_set) == 0 or len(refer_set) == 0:
-            return 0
-        elif len(query_set) > len(refer_set):
-            query_set, refer_set = refer_set, query_set
-        same_count = sum([1 for ele in query_set if ele in refer_set])
-        score = same_count / len(refer_set)
-        return int(score * 100) / 100
+    # # 图片语义相似度评分
+    # def compute_image_semantics(self, query_list, refer_list):
+    #     query_set, refer_set = set(query_list), set(refer_list)
+    #     if len(query_set) == 0 and len(refer_set) == 0:
+    #         return 1
+    #     elif len(query_set) == 0 or len(refer_set) == 0:
+    #         return 0
+    #     elif len(query_set) > len(refer_set):
+    #         query_set, refer_set = refer_set, query_set
+    #     same_count = sum([1 for ele in query_set if ele in refer_set])
+    #     score = same_count / len(refer_set)
+    #     return int(score * 100) / 100

+ 19 - 2
config.py

@@ -1,8 +1,13 @@
 import os
 import pymongo
 
+# 开发模式: 0-fxb, 1-ksy
+dev_mode_list = ["fxb", "ksy"]
+dev_mode = dev_mode_list[0]
+
 # 建立mongodb连接
-myclient = pymongo.MongoClient("mongodb://192.168.1.140:27017/")
+client_url = dict(fxb="mongodb://192.168.1.140:27017/", ksy="mongodb://127.0.0.1:27017/")[dev_mode]
+myclient = pymongo.MongoClient(client_url)
 mongo_info_db = myclient["ksy"]
 mongo_coll = mongo_info_db['test_topic']
 
@@ -25,8 +30,15 @@ num_elements = 1000000
 # hnsw召回数量参数
 hnsw_set_ef = 150
 
+# 调用api链接
+# 配图查重
+illustration_url = dict(fxb="http://192.168.1.204:8068/topic_retrieval_http", ksy="http://127.0.0.1:8068/topic_retrieval_http")[dev_mode]
+# 图片查重
+image_url = dict(fxb="http://192.168.1.204:8068/img_retrieval_http", ksy="http://127.0.0.1:8068/img_retrieval_http")[dev_mode]
 # hnsw模型检索链接
-hnsw_retrieve_url = r"http://localhost:8836/retrieve"
+hnsw_retrieve_url = r"http://127.0.0.1:8836/retrieve"
+# 多维度分类链接
+dim_classify_url = r"http://127.0.0.1:8837/dim_classify"
 
 # 根地址
 root_path = os.getcwd()
@@ -46,6 +58,11 @@ hnsw_path = "hnsw_model.bin"
 bow_model_path = os.path.join(data_root_path, "bow_model.pkl")
 bow_vector_path = os.path.join(data_root_path, "bow_vector.npy")
 formula_data_path = os.path.join(data_root_path, "formula_data.json")
+# 分词器地址
+bert_path = "bert-base-chinese"
+# 多维度分类模型地址
+solution_model_path = os.path.join(data_root_path, "solution_classify.pt")
+difficulty_model_path = os.path.join(data_root_path, "difficulty_classify.pt")
 
 # 日志地址
 log_root_path = os.path.join(root_path, "logs")

+ 2 - 2
data_preprocessing.py

@@ -196,8 +196,8 @@ class DataPreProcessing():
         content_clear = re.sub(r'\[题文\]', '', content_clear)
         content_clear = re.sub(r'(\([单多]选\)|\[[单多]选\])', '', content_clear)
         content_clear = re.sub(r'(\(\d{1,2}分\)|\[\d{1,2}分\])', '', content_clear)
-        # 将文本中的选项"A.B.C.D."改为";"
-        content_clear = re.sub(r'[ABCD]\.', ';', content_clear)
+        # # 将文本中的选项"A.B.C.D."改为";"
+        # content_clear = re.sub(r'[ABCD]\.', ';', content_clear)
         # # 若选项中只有1,2,3或(1)(2)(3),或者只有标点符号,则过滤该选项
         # content_clear = re.sub(r'(\(\d\)[、,;\.]?)+\(\d\)|\d[、,;]+\d', '', content_clear)
         # 去除题目开头(...年...[中模月]考)文本

+ 1 - 1
db_train_app.py

@@ -22,7 +22,7 @@ def convert_knowledge2id(mongo_coll, mongo_find_dict, sup, sub):
         print(data["knowledge"])
         condition = {"id": data["id"]}
         # 需要新增train_flag,防止机器奔溃重复训练
-        knowledge_list = [knowledge2id[ele] for ele in data["knowledge"] if knowledge2id.get(ele, 0)]
+        knowledge_list = [knowledge2id[ele] for ele in data["knowledge"] if ele in knowledge2id]
         update_elements = {"$set": {"knowledge_id": knowledge_list}}
         mongo_coll.update_one(condition, update_elements)
     print("耗时:", time.time()-start)

+ 122 - 0
dim_classify.py

@@ -0,0 +1,122 @@
+import torch
+import torch.nn as nn
+from transformers import AutoConfig, BertTokenizer, AutoModel
+
+import config
+
+class Solution_Model(nn.Module):
+    def __init__(self):
+        super(Solution_Model, self).__init__()
+        self.bert_config = AutoConfig.from_pretrained(config.bert_path)
+        self.bert = AutoModel.from_pretrained(config.bert_path)
+        self.fc = nn.Linear(in_features=self.bert_config.hidden_size, out_features=8)
+
+    def forward(self, input_ids, attention_mask):
+        x = self.bert(input_ids, attention_mask)[0][:, 0, :]
+        x = self.fc(x)
+
+        return x
+
+class Difficulty_Model(nn.Module):
+    def __init__(self):
+        super(Difficulty_Model, self).__init__()
+        self.bert_config = AutoConfig.from_pretrained(config.bert_path)
+        self.bert = AutoModel.from_pretrained(config.bert_path)
+        self.fc = nn.Linear(in_features=self.bert_config.hidden_size, out_features=8)
+
+    def forward(self, input_ids, attention_mask):
+        x = self.bert(input_ids, attention_mask)[0][:, 0, :]
+        x = self.fc(x)
+
+        return x
+
+class Dimension_Classification():
+    def __init__(self, logger=None):
+        self.tokenizer = BertTokenizer.from_pretrained(config.bert_path)
+        self.solution_model = torch.load(config.solution_model_path)
+        self.difficulty_model = torch.load(config.difficulty_model_path)
+        self.max_squence_length = 500
+        self.solving_type_dict = {
+            0: "实验操作", 
+            1: "计算分析", 
+            2: "连线作图", 
+            3: "实验读数", 
+            4: "现象解释", 
+            5: "概念辨析", 
+            6: "规律理解", 
+            7: "物理学史"
+        }
+        # 日志采集
+        self.logger = logger
+
+    def __call__(self, sentence, quesType):
+        solution_list = self.solution_classify(sentence, quesType)
+        difficulty_value = self.difficulty_classify(sentence)
+        res_dict = {
+            "solving_type": solution_list, 
+            "difficulty": difficulty_value, 
+        }
+
+        return res_dict
+
+    def solution_classify(self, sentence, quesType):
+        solution_tensor = self.model_calculate(self.solution_model, sentence)
+        solution_tensor[solution_tensor >= 0.5] = 1
+        solution_tensor[solution_tensor < 0.5] = 0
+        solution_list = [self.solving_type_dict[idx] for idx in solution_tensor[0].int().tolist() if idx == 1]
+        # 题型判断
+        if quesType == "计算题":
+            solution_list.append("计算分析")
+        elif quesType == "作图题":
+            solution_list.append("连线作图")
+        if len(solution_list) == 0:
+            solution_list.append("规律理解")
+            
+        return list(set(solution_list))
+
+    def difficulty_classify(self, sentence):
+        difficulty_tensor = self.model_calculate(self.difficulty_model, sentence).item()
+        difficulty_value = 0.6
+        if difficulty_tensor >= 0.8:
+            difficulty_value = 0.8
+        elif difficulty_tensor <= 0.2:
+            difficulty_value = 0.4
+        else:
+            difficulty_value = 0.6
+        
+        return difficulty_value
+
+    def model_calculate(self, model, sentence):
+        model.eval()
+        with torch.no_grad():
+            token_list = self.sentence_tokenize(sentence)
+            mask_list = self.attention_mask(token_list)
+            output_tensor = model(torch.tensor(token_list), attention_mask=torch.tensor(mask_list))
+            output_tensor = torch.sigmoid(output_tensor)
+
+        return output_tensor
+
+    def sentence_tokenize(self, sentence):
+        # 直接截断
+        # 编码时: 开头添加[CLS]->101, 结尾添加[SEP]->102, 未知的字或单词变为[UNK]->100
+        token_list = self.tokenizer.encode(sentence[:self.max_squence_length])
+        # 补齐(pad的索引号就是0)
+        if len(token_list) < self.max_squence_length + 2:
+            token_list.extend([0] * (self.max_squence_length + 2 - len(token_list)))
+        
+        return [token_list]
+
+    def attention_mask(self, tokens_list):
+        # 在一个文本中,如果是PAD符号则是0,否则就是1
+        mask_list = []
+        for tokens in tokens_list:
+            mask = [float(token > 0) for token in tokens]
+            mask_list.append(mask)
+
+        return mask_list
+
+if __name__ == "__main__":
+    dc = Dimension_Classification()
+    sentence = "荆门市是国家循环经济试点市,目前正在沙洋建设全国最大的秸秆气化发电厂.电厂建成后每年可消化秸秆13万吨,发电9*10^7*kW*h.同时电厂所产生的灰渣将生成肥料返还农民,焦油用于精细化工,实现“农业--工业--农业”循环.(1)若秸秆电厂正常工作时,每小时可发电2.5*10^5*kW*h,按每户居民每天使用5只20*W的节能灯、1个800*W的电饭锅、1台100*W的电视机计算,该发电厂同时可供多少户居民正常用电?(2)与同等规模的火电厂相比,该电厂每年可减少6.4万吨二氧化碳的排放量,若火电厂煤燃烧的热利用率为20%,秸秆电厂每年可节约多少吨标准煤?(标准煤的热值按3.6*10^7J/k*g计算)"
+    res = dc(sentence, "")
+    print(res)

+ 41 - 0
dim_classify_app.py

@@ -0,0 +1,41 @@
+from gevent import monkey; monkey.patch_all()
+from flask import Flask, request, jsonify
+from gevent.pywsgi import WSGIServer
+
+import config
+from log_config import LogConfig
+from dim_classify import Dimension_Classification, Solution_Model, Difficulty_Model
+
+app = Flask(__name__)
+
+# 多维度分类接口
+@app.route('/dim_classify', methods=['GET', 'POST'])
+def retrieve():
+    if request.method == 'POST':
+        # 获取post数据
+        retrieve_dict = request.get_json()
+        # 接收日志采集
+        dc_logger.info(config.log_msg.format(id="多维分类",
+                                             type="dim_classify接收",
+                                             message=retrieve_dict))
+        sentence = retrieve_dict["sentence"]
+        questype = retrieve_dict["quesType"]
+        # 多维分类
+        res_list = dim_classify(sentence, questype)
+        # 返回日志采集
+        dc_logger.info(config.log_msg.format(id="多维分类",
+                                             type="dim_classify返回",
+                                             message=res_list))
+        return jsonify(res_list)
+
+
+if __name__ == '__main__':
+    # 日志采集初始化
+    dc_LogConfig = LogConfig(config.retrieval_path, "dim_classify")
+    dc_logger = dc_LogConfig.get_log()
+    # 多维分类模型初始化
+    dim_classify = Dimension_Classification()
+
+    # app.run(host='0.0.0.0',port='8837')
+    server = WSGIServer(('0.0.0.0', 8837), app)
+    server.serve_forever()

+ 6 - 2
hnsw_app.py

@@ -27,8 +27,12 @@ def retrieve():
         # 获取post数据
         query_vec = request.get_json()
         # HNSW检索
-        query_labels = hnsw.retrieve(query_vec)
-        return jsonify(query_labels)
+        res_list = hnsw.retrieve(query_vec)
+        # 返回日志采集
+        hm_logger.info(config.log_msg.format(id="HNSW检索",
+                                             type="retrieve返回",
+                                             message=res_list))
+        return jsonify(res_list)
 
 
 if __name__ == '__main__':

+ 68 - 120
hnsw_retrieval.py

@@ -4,12 +4,12 @@ import requests
 import numpy as np
 from fuzzywuzzy import fuzz
 from sentence_transformers import util
-from concurrent.futures import ThreadPoolExecutor
 from pprint import pprint
 
 import config
 from formula_process import formula_recognize
 from comprehensive_score import Comprehensive_Score
+from physical_quantity_extract import physical_quantity_extract
 
 class HNSW():
     def __init__(self, data_process, logger=None):
@@ -17,13 +17,16 @@ class HNSW():
         self.mongo_coll = config.mongo_coll
         self.vector_dim = config.vector_dim
         self.hnsw_retrieve_url = config.hnsw_retrieve_url
+        self.dim_classify_url = config.dim_classify_url
         # 日志采集
         self.logger = logger
         self.log_msg = config.log_msg
         # 数据预处理实例化
         self.dpp = data_process
         # 语义相似度实例化
-        self.cph_score = Comprehensive_Score()
+        self.cph_score = Comprehensive_Score(config.dev_mode)
+        # 难度数值化定义
+        self.difficulty_transfer = {"容易": 0.2, "较易": 0.4, "一般": 0.6, "较难": 0.8, "困难": 1.0}
         # 加载公式处理数据模型(词袋模型/原始向量/原始数据)
         with open(config.bow_model_path, "rb") as bm:
             self.bow_model = pickle.load(bm)
@@ -112,101 +115,21 @@ class HNSW():
 
         return formula_res_list[:50]
 
-    # # HNSW查(支持多学科混合查重)
-    # def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0.56):
-    #     # 计算retrieve_list的vec值
-    #     # 调用清洗分词函数和句向量计算函数
-    #     sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
-
-    #     # HNSW查重
-    #     def dup_search(retrieve_data, sent_vec, cont_clear):
-    #         # 初始化返回数据类型
-    #         retrieve_value_dict = dict(synthese=[], semantics=[], text=[], image=[])
-    #         # 获取题目序号
-    #         topic_num = retrieve_data["topic_num"] if "topic_num" in retrieve_data else 1
-    #         # 图片搜索查重功能
-    #         if doc_flag is True:
-    #             retrieve_value_dict["image"] = self.img_retrieve(retrieve_data["stem"], post_url, similar, topic_num)
-    #         else:
-    #             retrieve_value_dict["image"] = []
-    #         # 判断句向量维度
-    #         if sent_vec.size != self.vector_dim:
-    #             return retrieve_value_dict
-    #         # 调用hnsw接口检索数据
-    #         post_list = sent_vec.tolist()
-    #         try:
-    #             query_labels = requests.post(self.hnsw_retrieve_url, json=post_list, timeout=10).json()
-    #         except Exception as e:
-    #             query_labels = []
-    #             # 日志采集
-    #             if self.logger is not None:
-    #                 self.logger.error(self.log_msg.format(id="HNSW检索error",
-    #                                                       type="当前题目HNSW检索error",
-    #                                                       message=cont_clear))
-    #         if len(query_labels) == 0:
-    #             return retrieve_value_dict
-
-    #         # 批量读取数据库
-    #         mongo_find_dict = {"id": {"$in": query_labels}}
-    #         query_dataset = self.mongo_coll.find(mongo_find_dict)
-
-    #         # 返回大于阈值的结果
-    #         filter_threshold = similar
-    #         for label_data in query_dataset:
-    #             if "sentence_vec" not in label_data:
-    #                 continue
-    #             # 计算余弦相似度得分
-    #             label_vec = pickle.loads(label_data["sentence_vec"])
-    #             if label_vec.size != self.vector_dim:
-    #                 continue
-    #             cosine_score = util.cos_sim(sent_vec, label_vec)[0][0]
-    #             # 阈值判断
-    #             if cosine_score < filter_threshold:
-    #                 continue
-    #             # 计算编辑距离得分
-    #             fuzz_score = fuzz.ratio(cont_clear, label_data["content_clear"]) / 100
-    #             if fuzz_score < min_threshold:
-    #                 continue
-    #             # 对余弦相似度进行折算
-    #             if cosine_score >= 0.91 and fuzz_score < min_threshold + 0.06:
-    #                 cosine_score = cosine_score * 0.95
-    #             elif cosine_score < 0.91 and fuzz_score < min_threshold + 0.06:
-    #                 cosine_score = cosine_score * 0.94
-    #             # 余弦相似度折算后阈值判断
-    #             if cosine_score < filter_threshold:
-    #                 continue
-    #             retrieve_value = [label_data["id"], int(cosine_score * 100) / 100]
-    #             retrieve_value_dict["semantics"].append(retrieve_value)
-    #             # 进行编辑距离得分验证,若小于设定分则过滤
-    #             if fuzz_score >= filter_threshold:
-    #                 retrieve_value = [label_data["id"], fuzz_score]
-    #                 retrieve_value_dict["text"].append(retrieve_value)
-            
-    #         # 将组合结果按照score降序排序并取得分前十个结果
-    #         retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)
-    #                               for k,value in retrieve_value_dict.items()}
-
-    #         # 综合排序
-    #         synthese_list = sorted(sum(retrieve_sort_dict.values(), []), key=lambda x: x[1], reverse=True)
-    #         synthese_set = set()
-    #         for ele in synthese_list:
-    #             if ele[0] not in synthese_set and len(retrieve_sort_dict["synthese"]) < 50:
-    #                 synthese_set.add(ele[0])
-    #                 retrieve_sort_dict["synthese"].append(ele)
-    #         # 加入题目序号
-    #         retrieve_sort_dict["topic_num"] = topic_num
-            
-    #         # 以字典形式返回最终查重结果
-    #         return retrieve_sort_dict
-
-    #     # 多线程HNSW查重
-    #     with ThreadPoolExecutor(max_workers=5) as executor:
-    #         retrieve_res_list = list(executor.map(dup_search, retrieve_list, sent_vec_list, cont_clear_list))
-
-    #     return retrieve_res_list
-
     # HNSW查(支持多学科混合查重)
     def retrieve(self, retrieve_list, post_url, similar, scale, doc_flag):
+        """
+        return: 
+        [
+            {
+                'semantics': [[20232015, 1.0, {'quesType': 1.0, 'knowledge': 1.0, 'physical_scene': 1.0, 'solving_type': 1.0, 'difficulty': 1.0, 'physical_quantity': 1.0}]], 
+                'text': [[20232015, 0.97]], 
+                'image': [], 
+                'label': {'knowledge': ['串并联电路的辨别'], 'physical_scene': ['串并联电路的辨别'], 'solving_type': ['规律理解'], 'difficulty': 0.6, 'physical_quantity': ['电流']}, 
+                'topic_num': 1
+            }, 
+            ...
+        ]
+        """
         # 计算retrieve_list的vec值
         # 调用清洗分词函数和句向量计算函数
         sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
@@ -231,9 +154,9 @@ class HNSW():
                 retrieve_res_list.append(retrieve_value_dict)
                 continue
             # 调用hnsw接口检索数据
-            post_list = sent_vec.tolist()
             try:
-                query_labels = requests.post(self.hnsw_retrieve_url, json=post_list, timeout=10).json()
+                hnsw_post_list = sent_vec.tolist()
+                query_labels = requests.post(self.hnsw_retrieve_url, json=hnsw_post_list, timeout=10).json()
             except Exception as e:
                 query_labels = []
                 # 日志采集
@@ -278,34 +201,59 @@ class HNSW():
             """
             语义相似度特殊处理
             """
-            # 批量读取数据库
-            knowledge_id_list = query_data["knowledge_id"] if query_data else []
+            # 标签字典初始化
             label_dict = dict()
-            # label_dict["quesType"] = retrieve_list[i]["quesType"] if query_data else []
+            # 知识点LLM标注
+            # label_dict["knowledge"] = query_data["knowledge"] if query_data else []
             label_dict["knowledge"] = query_data["knowledge"] if query_data else []
-            label_dict["physical_scene"] = query_data["physical_scene"] if query_data else []
-            label_dict["solving_type"] = query_data["solving_type"] if query_data else []
-            label_dict["difficulty"] = float(query_data["difficulty"]) if query_data else 0
-            label_dict["physical_quantity"] = query_data["physical_quantity"] if query_data else []
-            # label_dict["image_semantics"] = query_data["image_semantics"] if query_data else []
-            query_data["quesType"] = retrieve_list[i].get("quesType", '')
+            tagging_id_list = [self.cph_score.knowledge2id[ele] for ele in label_dict["knowledge"] \
+                               if ele in self.cph_score.knowledge2id]
+            # 题型数据获取
+            label_dict["quesType"] = retrieve_list[i].get("quesType", "选择题")
+            # 多维分类api调用
+            try:
+                dim_post_list = {"sentence": cont_clear_list[i], "quesType": label_dict["quesType"]}
+                dim_classify_dict = requests.post(self.dim_classify_url, json=dim_post_list, timeout=10).json()
+            except Exception as e:
+                dim_classify_dict = {"solving_type": ["规律理解"], "difficulty": 0.6}
+                # 日志采集
+                if self.logger is not None:
+                    self.logger.error(self.log_msg.format(id="多维分类error",
+                                                          type="当前题目多维分类error",
+                                                          message=cont_clear_list[i]))
+            # 求解类型模型分类
+            label_dict["solving_type"] = dim_classify_dict["solving_type"]
+            # 难度模型分类
+            label_dict["difficulty"] = dim_classify_dict["difficulty"]
+            # 物理量规则提取
+            label_dict["physical_quantity"] = physical_quantity_extract(cont_clear_list[i])
 
+            # LLM标注知识点题目获取题库对应相似知识点题目数据
+            knowledge_id_list = []
+            if len(tagging_id_list) > 0:
+                ####################################### encode_base_value设置 ####################################### 
+                # 考试院: 10000, 风向标: 10
+                encode_base_value = 10000 if config.dev_mode == "ksy" else 10
+                ####################################### encode_base_value设置 ####################################### 
+                for ele in tagging_id_list:
+                    init_id = int(ele / encode_base_value) * encode_base_value
+                    init_id_list = self.cph_score.init_id2max_id.get(str(init_id), [])
+                    knowledge_id_list.extend(init_id_list)
+            knowledge_query_dataset = None
             if len(knowledge_id_list) > 0:
-                relate_list = []
-                for ele in knowledge_id_list:
-                    init_id = int(ele / 10) * 10
-                    last_id = self.cph_score.init_id2max_id[str(init_id)]
-                    relate_list.extend(np.arange(init_id + 1, last_id + 1).tolist())
-                knowledge_id_list = relate_list
-            mongo_find_dict = {"knowledge_id": {"$in": knowledge_id_list}}
-            query_dataset = self.mongo_coll.find(mongo_find_dict)
+                mongo_find_dict = {"knowledge_id": {"$in": knowledge_id_list}}
+                knowledge_query_dataset = self.mongo_coll.find(mongo_find_dict)
             # 返回大于阈值的结果
-            for refer_data in query_dataset:
-                sum_score, score_dict = self.cph_score(query_data, refer_data, scale)
-                if sum_score < similar:
-                    continue
-                retrieve_value = [refer_data["id"], sum_score, score_dict]
-                retrieve_value_dict["semantics"].append(retrieve_value)
+            if knowledge_query_dataset:
+                for refer_data in knowledge_query_dataset:
+                    # 难度数值转换
+                    if refer_data["difficulty"] in self.difficulty_transfer:
+                        refer_data["difficulty"] = self.difficulty_transfer[refer_data["difficulty"]]
+                    sum_score, score_dict = self.cph_score(label_dict, refer_data, scale)
+                    if sum_score < similar:
+                        continue
+                    retrieve_value = [refer_data["id"], sum_score, score_dict]
+                    retrieve_value_dict["semantics"].append(retrieve_value)
 
             # 将组合结果按照score降序排序
             retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)

+ 62 - 0
physical_quantity_extract.py

@@ -0,0 +1,62 @@
import re

# Map each physical-quantity label (dict key) to the regex alternatives that
# signal it in question text.  A quantity is reported when ANY of its patterns
# occurs in the content.  Literals containing regex escapes ("\d", "\^") are
# raw strings so they do not trigger invalid-escape warnings on modern CPython;
# the byte values are identical to plain strings.
physical_quantity_dict = {
"长度": ["刻度尺", "米", "光年", "身高", "[长厚]度", "[厘分千]米"],
"距离": ["距离", "相距"],
"高度": ["高度", "高"],
"时间": ["时间", "秒", "分钟", "秒表"],
"质量": ["质量", "千克", "克", "惯性", "天平", "公?斤", r"\dk?g"],
"密度": ["密度", "鉴别", r"g/cm\^?3", r"kg/m\^?3"],
"速度": ["速度", "米每秒", "平均速度", "运动图像", "追及", "相遇", "m/s", "km/h"],
"重力": ["重力", "重心", "纬度"],
"弹力": ["形变", "弹力", "胡克", "弹簧"],
"拉力": ["拉力", "拉伸"],
"摩擦力": ["摩擦力", "静摩擦", "滑动摩擦", "滚动摩擦", "接触面粗糙程度", "摩擦"],
"压强": ["压强", "压力效果"],
"液体压强": ["液体压强", "连通器"],
"大气压强": ["大气压", "流体压强", "托里拆利实验"],
"浮力": ["浮力", "漂浮", "沉浮", "悬浮", "阿基米德原理"],
# (?!快): negative lookahead so a 功 immediately followed by 快 (e.g. 做功快慢,
# which belongs to 功率) is not counted as the quantity "功".
"功": ["做功(?!快)", "功(?!快)"],
"功率": ["功率", "做功快慢", "瓦特"],
"动能": ["动能"],
"重力势能": ["重力势能", "重力做功"],
"弹性势能": ["弹性势能"],
"机械能": ["机械能", "动能", "势能"],
"海拔高度": ["海拔", "海拔高度"],
"横截面积": ["横截面积", "底面积"],
"阻力臂": ["阻力臂", "杠杆"],
"动力臂": ["动力臂", "杠杆"],
"电荷": ["静电", "摩擦起电", "正电荷", "负电荷", "元电荷", "电荷", "验电器", "带电"],
"电流": ["定向移动", "安培", "电流", "电流热效应", "电流磁效应", "电流表", "安培表", "电笔"],
"电压": ["电压", "电压表", "伏特"],
"电阻": ["电阻", "电阻器", "电阻率", "欧姆", "半导体", "超导", "变阻器", "欧姆定律", "Ω"],
"电功": ["电功", "焦耳定律", "电能"],
"电功率": ["额定功率", "电功率", "电功快慢", "额定电压"],
"温度": ["温度", "温度计", "摄氏度", "华氏度", "温标", "保温", "体温计", "寒暑表", "℃"],
"内能": ["分子动理论", "热运动", "分子间作用力", "热传递"],
"热量": ["热量", "热值"],
"比热容": ["比热容", "吸热本领", "J/(kg·℃)"],
"音调": ["音调", "频率", "高音", "低音", "Hz"],
"响度": ["振幅", "分贝", "响度", "dB"],
"音色": ["乐器", "音色", "闻其声知其人"],
"入射角": ["入射", "入射角度"],
"反射角": ["反射", "反射角度"],
"光速": ["光速"],
"像距": ["像距"],
"物距": ["物距"],
}

# Compile one alternation per quantity ONCE at import time instead of joining
# and compiling inside every call.  Order is preserved (dict insertion order),
# so results come back in the same order as the original implementation.
_COMPILED_PATTERNS = [
    (name, re.compile("|".join(patterns)))
    for name, patterns in physical_quantity_dict.items()
]


def physical_quantity_extract(content):
    """Return the physical-quantity labels whose trigger patterns occur in *content*.

    Args:
        content (str): question text to scan (Chinese).

    Returns:
        list[str]: matched quantity names in dictionary order; empty list when
        nothing matches.
    """
    # re.search is sufficient (and cheaper than re.findall): only presence matters.
    return [name for name, pattern in _COMPILED_PATTERNS if pattern.search(content)]


if __name__ == "__main__":
    content = "如图,在老师的指导下,小军用测电笔(试电笔)试触某插座的插孔,用指尖抵住笔尾金属体,测电笔的氖管发光,此时(选填'有'或'没有')电流通过小军的身体,"
    res = physical_quantity_extract(content)
    print(res)

+ 9 - 1
restart_server.py

@@ -33,6 +33,14 @@ def restart_hnsw_app():
     # 启动服务
     os.popen("nohup python hnsw_app.py > logs/temp_app.log 2>&1 &")
 
+# Restart the dim_classify_app service (multi-dimension classification, port 8837).
+def restart_dim_classify_app():
+    # Kill whatever process currently holds the service port before relaunching.
+    server_killer(port=8837)
+    print("即将启动dim_classify_app服务")
+    # Relaunch detached in the background; stdout/stderr go to the shared temp log.
+    # NOTE(review): all restarted services redirect to logs/temp_app.log, so later
+    # restarts overwrite earlier output — confirm this is intentional.
+    os.popen("nohup python dim_classify_app.py > logs/temp_app.log 2>&1 &")
+
 # 重启retrieval_monitor服务
 def restart_retrieval_monitor():
     # 关闭服务进程
@@ -44,7 +52,7 @@ def restart_retrieval_monitor():
 
 if __name__ == "__main__":
     # 重启服务
-    server_list = [restart_retrieval_app, restart_hnsw_app, restart_retrieval_monitor]
+    server_list = [restart_retrieval_app, restart_hnsw_app, restart_dim_classify_app, restart_retrieval_monitor]
     argv_list = sys.argv
     if len(argv_list) == 1:
         [server() for server in server_list]

+ 3 - 9
retrieval_app.py

@@ -41,7 +41,7 @@ def hnsw_retrieve():
                                                     type="hnsw_retrieve接收",
                                                     message=retrieve_dict))
         # hnsw模型查重
-        post_url = r"http://192.168.1.209:8068/topic_retrieval_http"
+        post_url = config.illustration_url
         res_list = hnsw_model.retrieve(retrieve_list, post_url, similar, scale, doc_flag)
         # 返回日志采集
         retrieval_logger.info(config.log_msg.format(id=id_name,
@@ -60,7 +60,7 @@ def image_retrieve():
         retrieve_img = retrieve_dict["content"]
         similar = retrieve_dict["similar"] / 100
         # 图片查重链接
-        post_url = r"http://192.168.1.209:8068/img_retrieval_http"
+        post_url = config.image_url
         img_dict = dict(img_url=retrieve_img, img_threshold=similar, img_max_num=30)
         try:
             res_list = requests.post(post_url, json=img_dict, timeout=30).json()
@@ -112,13 +112,7 @@ def info_retrieve():
         # 文本关键词检索
         id_list, seg_list = ir_model(sentence)
         id_list = [int(idx) for idx in id_list]
-        # 语义相似度查重
-        retrieve_list = [dict(stem=sentence)]
-        if len(sentence) > 30:
-            doc_list = hnsw_model.retrieve(retrieve_list, '', similar, scale, False)[0]["semantics"]
-        else:
-            doc_list = hnsw_model.retrieve(retrieve_list, '', similar, scale, False)[0]["semantics"]
-        res_dict = dict(info=[id_list, seg_list], doc=doc_list)
+        res_dict = dict(info=[id_list, seg_list])
         # 返回日志采集
         retrieval_logger.info(config.log_msg.format(id="文本查重",
                                                     type="info_retrieve返回",

+ 6 - 2
retrieval_monitor.py

@@ -3,7 +3,7 @@ import time
 
 def server_run(port, command):
     # 设置服务缓存时间(防止更新服务产生冲突)
-    time.sleep(12) if port == 8836 else time.sleep(8)
+    time.sleep(12)
     server = os.popen("lsof -i:{}".format(port)).readlines()
     if not server:
         print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()), 
@@ -19,4 +19,8 @@ while True:
 
     hnsw_app_server = os.popen("lsof -i:8836").readlines()
     if not hnsw_app_server:
-        server_run(8836, "nohup python hnsw_app.py > logs/temp_app.log 2>&1 &")
+        server_run(8836, "nohup python hnsw_app.py > logs/temp_app.log 2>&1 &")
+
+    dim_classify_app_server = os.popen("lsof -i:8837").readlines()
+    if not dim_classify_app_server:
+        server_run(8837, "nohup python dim_classify_app.py > logs/temp_app.log 2>&1 &")