tujintao 1 year ago
commit 15e42df364

BIN
__pycache__/comprehensive_score.cpython-38.pyc


BIN
__pycache__/config.cpython-38.pyc


BIN
__pycache__/data_preprocessing.cpython-38.pyc


BIN
__pycache__/formula_process.cpython-38.pyc


BIN
__pycache__/heap_sort.cpython-38.pyc


BIN
__pycache__/hnsw_model_train.cpython-38.pyc


BIN
__pycache__/hnsw_retrieval.cpython-38.pyc


BIN
__pycache__/info_retrieval.cpython-38.pyc


BIN
__pycache__/ir_db_establish.cpython-38.pyc


BIN
__pycache__/log_config.cpython-38.pyc


BIN
__pycache__/word_segment.cpython-38.pyc


+ 68 - 0
comparison.py

@@ -0,0 +1,68 @@
+import json
+import pandas as pd
+
+keyword2id_dict = dict()
+# Solution types
+solution_type_list = ["概念辨析","规律理解","现象解释","物理学史","计算分析","实验操作","连线作图","实验读数"]
+
+solving_type2id = dict()
+for i, ele in enumerate(solution_type_list):
+    solving_type2id[ele] = 1 + i
+keyword2id_dict["solving_type2id"] = solving_type2id
+
+# Physical quantities: id = category index * 100 + running offset within the category
+excel_path = r"data/物理量.xlsx"
+df = pd.read_excel(excel_path)
+quantity2id = dict()
+count_index = 0
+for i in range(len(df)):
+    if not pd.isna(df['类别'][i]):
+        count_index += 1
+        sign_index = count_index * 100
+    knowledge = df['物理量'][i]
+    if not pd.isna(knowledge):
+        sign_index += 1
+        quantity2id[knowledge] = sign_index
+keyword2id_dict["quantity2id"] = quantity2id
+
+# Physical scenes: id = 10000 + knowledge-point group index * 10 + running offset
+excel_path = r"data/物理情景.xlsx"
+df = pd.read_excel(excel_path)
+scene2id = dict()
+count_index = 0
+for i in range(len(df)):
+    if not pd.isna(df['知识点'][i]):
+        count_index += 1
+        sign_index = 10000 + count_index * 10
+    knowledge = df['情景'][i]
+    if not pd.isna(knowledge):
+        sign_index += 1
+        scene2id[knowledge] = sign_index
+keyword2id_dict["scene2id"] = scene2id
+
+# Knowledge points: id = 10000 + level-2 index * 100 + level-3 index * 10 + running offset
+excel_path = r"data/物理知识点.xlsx"
+df = pd.read_excel(excel_path)
+knowledge2id = dict()
+init_id2max_id = dict()
+count_index = 0
+for i in range(len(df)):
+    if not pd.isna(df['2级知识点'][i]):
+        count_index += 1
+    if not pd.isna(df['3级知识点'][i]):
+        sign = df['3级知识点'][i].split(' ')[0].split('.')
+        # sign_index = 10000 + int(sign[0]) * 100 + int(sign[1]) * 10
+        sign_index = 10000 + count_index * 100 + int(sign[1]) * 10
+        init_id = sign_index
+        init_id2max_id[init_id] = sign_index
+    knowledge = df['4级知识点'][i]
+    if not pd.isna(knowledge):
+        sign_index += 1
+        knowledge2id[knowledge] = sign_index
+        init_id2max_id[init_id] = sign_index
+keyword2id_dict["knowledge2id"] = knowledge2id
+keyword2id_dict["init_id2max_id"] = init_id2max_id
+
+# Export the keyword mappings as JSON
+with open("data/keyword_mapping.json", 'w', encoding="utf8") as f:
+    json.dump(keyword2id_dict, f, ensure_ascii=False, indent=2)
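The exported mapping is consumed downstream; db_train_app.py (below) loads it from model_data/keyword_mapping.json, so the file written here to data/ is presumably copied there. A minimal lookup sketch under that assumption:

import json

# Load the mapping written above; db_train_app.py reads it from
# model_data/, so this path assumes the exported file was copied there.
with open("model_data/keyword_mapping.json", 'r', encoding="utf8") as f:
    keyword2id_dict = json.load(f)

knowledge2id = keyword2id_dict["knowledge2id"]
# Unknown names map to 0 and get filtered out, mirroring db_train_app.py.
name = next(iter(knowledge2id))  # any known level-4 knowledge point
print(name, "->", knowledge2id.get(name, 0))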

+ 1 - 1
config.py

@@ -4,7 +4,7 @@ import pymongo
 # Establish the MongoDB connection
 myclient = pymongo.MongoClient("mongodb://192.168.1.140:27017/")
 mongo_info_db = myclient["ksy"]
-mongo_coll = mongo_info_db['topic']
+mongo_coll = mongo_info_db['test_topic']
 
 # MongoDB sentence-vector training flag
 sent_train_flag = 1
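As a quick sanity check of the collection switch, a hedged snippet (host, database, and collection names are taken from the diff; the count call is only illustrative):

import pymongo

# Connect the same way config.py does and confirm the new collection responds.
myclient = pymongo.MongoClient("mongodb://192.168.1.140:27017/")
mongo_coll = myclient["ksy"]["test_topic"]
print(mongo_coll.estimated_document_count())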

+ 22 - 3
db_train_app.py

@@ -1,16 +1,31 @@
 import sys
 import time
+import json
 import config
 from data_preprocessing import DataPreProcessing
 
 # Data cleaning and sentence-vector computation
-def clear_embedding_train(mongo_coll, sup, sub):
-    origin_dataset = mongo_coll.find(no_cursor_timeout=True, batch_size=5)
+def clear_embedding_train(mongo_coll, mongo_find_dict, sup, sub):
+    origin_dataset = mongo_coll.find(mongo_find_dict, no_cursor_timeout=True, batch_size=5)
     dpp = DataPreProcessing(mongo_coll, is_train=True)
     start = time.time()
     dpp(origin_dataset[sup:sub])
     print("耗时:", time.time()-start)
 
+# Convert knowledge points to ids for MongoDB retrieval
+def convert_knowledge2id(mongo_coll, mongo_find_dict, sup, sub):
+    with open("model_data/keyword_mapping.json", 'r', encoding="utf8") as f:
+        knowledge2id = json.load(f)["knowledge2id"]
+    origin_dataset = mongo_coll.find(mongo_find_dict, no_cursor_timeout=True, batch_size=5)
+    start = time.time()
+    for data in origin_dataset[sup:sub]:
+        print(data["knowledge"])
+        condition = {"id": data["id"]}
+        # TODO: add a train_flag so a machine crash does not trigger re-training
+        knowledge_list = [knowledge2id[ele] for ele in data["knowledge"] if knowledge2id.get(ele, 0)]
+        update_elements = {"$set": {"knowledge_id": knowledge_list}}
+        mongo_coll.update_one(condition, update_elements)
+    print("耗时:", time.time()-start)
 
 if __name__ == "__main__":
     # Parse shell input arguments
@@ -23,5 +38,9 @@ if __name__ == "__main__":
         sub = None if sub == '' else int(sub)
     # Get the MongoDB collection
     mongo_coll = config.mongo_coll
+    # mongo_find_dict = {"sent_train_flag": {"$exists": 0}}
+    mongo_find_dict = dict()
     # Clean text and compute sentence vectors (train_mode=1 means both cleaning and embedding are required)
-    clear_embedding_train(mongo_coll, sup, sub)
+    clear_embedding_train(mongo_coll, mongo_find_dict, sup, sub)
+    # Convert knowledge points to ids for MongoDB retrieval
+    convert_knowledge2id(mongo_coll, mongo_find_dict, sup, sub)
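The commented-out mongo_find_dict and the TODO in convert_knowledge2id both point toward incremental, crash-safe runs. A hedged sketch of such a filter, assuming documents that already carry knowledge_id can be skipped (the $exists check is an assumption, analogous to the sent_train_flag filter above):

import json
import config

# Sketch only: skip documents that already carry "knowledge_id", so a
# crashed run can resume without redoing finished documents.
mongo_coll = config.mongo_coll
with open("model_data/keyword_mapping.json", 'r', encoding="utf8") as f:
    knowledge2id = json.load(f)["knowledge2id"]
for data in mongo_coll.find({"knowledge_id": {"$exists": 0}},
                            no_cursor_timeout=True, batch_size=5):
    knowledge_list = [knowledge2id[ele] for ele in data["knowledge"]
                      if knowledge2id.get(ele, 0)]
    mongo_coll.update_one({"id": data["id"]},
                          {"$set": {"knowledge_id": knowledge_list}})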

+ 2 - 2
hnsw_app.py

@@ -17,8 +17,8 @@ scheduler = APScheduler()
 # Reset the logs on a weekly schedule
 @scheduler.task('cron', id='log_reset', week='*', day_of_week='sun', hour='05', minute='00', second='00', timezone='Asia/Shanghai')
 def log_reset_schedule():
-    hm_LogConfig.log_reset()
-    os.popen("nohup python restart_server.py > logs/temp_app.log 2>&1 &")
+    # hm_LogConfig.log_reset()
+    os.popen("nohup python restart_server.py 0 > logs/temp_app.log 2>&1 &")
 
 # HNSW model data retrieval
 @app.route('/retrieve', methods=['GET', 'POST'])

+ 62 - 27
hnsw_retrieval.py

@@ -9,6 +9,7 @@ from pprint import pprint
 
 import config
 from formula_process import formula_recognize
+from comprehensive_score import Comprehensive_Score
 
 class HNSW():
     def __init__(self, data_process, logger=None):
@@ -21,6 +22,8 @@ class HNSW():
         self.log_msg = config.log_msg
         # Data-preprocessing instance
         self.dpp = data_process
+        # Semantic-similarity scorer instance
+        self.cph_score = Comprehensive_Score()
         # Load the formula-processing models (bag-of-words model / raw vectors / raw data)
         with open(config.bow_model_path, "rb") as bm:
             self.bow_model = pickle.load(bm)
@@ -203,7 +206,7 @@ class HNSW():
     #     return retrieve_res_list
 
     # HNSW retrieval (supports mixed multi-subject duplicate checking)
-    def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0.56):
+    def retrieve(self, retrieve_list, post_url, similar, scale, doc_flag):
         # Compute vectors for retrieve_list
         # Call the cleaning/segmentation and sentence-vector functions
         sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
@@ -211,7 +214,8 @@ class HNSW():
         retrieve_res_list = []
         for i,sent_vec in enumerate(sent_vec_list):
             # Initialize the returned data structure
-            retrieve_value_dict = dict(synthese=[], semantics=[], text=[], image=[])
+            # retrieve_value_dict = dict(synthese=[], semantics=[], text=[], image=[])
+            retrieve_value_dict = dict(semantics=[], text=[], image=[])
             # Get the question number
             topic_num = retrieve_list[i]["topic_num"] if "topic_num" in retrieve_list[i] else 1
             # Image-search duplicate checking
@@ -219,6 +223,9 @@ class HNSW():
                 retrieve_value_dict["image"] = self.img_retrieve(retrieve_list[i]["stem"], post_url, similar, topic_num)
             else:
                 retrieve_value_dict["image"] = []
+            """
+            Text-similarity special handling
+            """
             # Check the sentence-vector dimensionality
             if sent_vec.size != self.vector_dim:
                 retrieve_res_list.append(retrieve_value_dict)
@@ -237,13 +244,13 @@ class HNSW():
             if len(query_labels) == 0:
                 retrieve_res_list.append(retrieve_value_dict)
                 continue
-
             # Batch-read from the database
             mongo_find_dict = {"id": {"$in": query_labels}}
             query_dataset = self.mongo_coll.find(mongo_find_dict)
-
+            ####################################### Semantic-similarity piggybacking #######################################
+            query_data, max_mark_score = dict(), 0
+            ####################################### Semantic-similarity piggybacking #######################################
             # Return results above the threshold
-            filter_threshold = similar
             for label_data in query_dataset:
                 if "sentence_vec" not in label_data:
                     continue
@@ -253,39 +260,67 @@ class HNSW():
                     continue
                 cosine_score = util.cos_sim(sent_vec, label_vec)[0][0]
                 # Threshold check
-                if cosine_score < filter_threshold:
+                if cosine_score < similar:
                     continue
                 # Compute the edit-distance score
                 fuzz_score = fuzz.ratio(cont_clear_list[i], label_data["content_clear"]) / 100
-                if fuzz_score < min_threshold:
-                    continue
-                # Discount the cosine similarity
-                if cosine_score >= 0.91 and fuzz_score < min_threshold + 0.06:
-                    cosine_score = cosine_score * 0.95
-                elif cosine_score < 0.91 and fuzz_score < min_threshold + 0.06:
-                    cosine_score = cosine_score * 0.94
-                # Threshold check after discounting the cosine similarity
-                if cosine_score < filter_threshold:
-                    continue
-                retrieve_value = [label_data["id"], int(cosine_score * 100) / 100]
-                retrieve_value_dict["semantics"].append(retrieve_value)
                 # Validate the edit-distance score; filter out anything below the set threshold
-                if fuzz_score >= filter_threshold:
+                if fuzz_score >= similar:
                     retrieve_value = [label_data["id"], fuzz_score]
                     retrieve_value_dict["text"].append(retrieve_value)
+                    ####################################### Semantic-similarity piggybacking #######################################
+                    # Keep the best match above 0.9 (max_mark_score is initialized before the loop; re-initializing it here would keep the last match, not the highest-scoring one)
+                    if cosine_score >= 0.9 and cosine_score > max_mark_score:
+                        max_mark_score = cosine_score
+                        query_data = label_data
+                    ####################################### Semantic-similarity piggybacking #######################################
             
-            # Sort the combined results by score in descending order and take the top ten
+            """
+            Semantic-similarity special handling
+            """
+            # Batch-read from the database
+            knowledge_id_list = query_data["knowledge_id"] if query_data else []
+            label_dict = dict()
+            # label_dict["quesType"] = retrieve_list[i]["quesType"] if query_data else []
+            label_dict["knowledge"] = query_data["knowledge"] if query_data else []
+            label_dict["physical_scene"] = query_data["physical_scene"] if query_data else []
+            label_dict["solving_type"] = query_data["solving_type"] if query_data else []
+            label_dict["difficulty"] = float(query_data["difficulty"]) if query_data else 0
+            label_dict["physical_quantity"] = query_data["physical_quantity"] if query_data else []
+            # label_dict["image_semantics"] = query_data["image_semantics"] if query_data else []
+            query_data["quesType"] = retrieve_list[i].get("quesType", '')
+
+            if len(knowledge_id_list) > 0:
+                relate_list = []
+                for ele in knowledge_id_list:
+                    init_id = int(ele / 10) * 10
+                    last_id = self.cph_score.init_id2max_id[str(init_id)]
+                    relate_list.extend(np.arange(init_id + 1, last_id + 1).tolist())
+                knowledge_id_list = relate_list
+            mongo_find_dict = {"knowledge_id": {"$in": knowledge_id_list}}
+            query_dataset = self.mongo_coll.find(mongo_find_dict)
+            # Return results above the threshold
+            for refer_data in query_dataset:
+                sum_score, score_dict = self.cph_score(query_data, refer_data, scale)
+                if sum_score < similar:
+                    continue
+                retrieve_value = [refer_data["id"], sum_score, score_dict]
+                retrieve_value_dict["semantics"].append(retrieve_value)
+
+            # Sort the combined results by score in descending order
             retrieve_sort_dict = {k: sorted(value, key=lambda x: x[1], reverse=True)
                                   for k,value in retrieve_value_dict.items()}
 
-            # Combined ranking
-            synthese_list = sorted(sum(retrieve_sort_dict.values(), []), key=lambda x: x[1], reverse=True)
-            synthese_set = set()
-            for ele in synthese_list:
-                if ele[0] not in synthese_set and len(retrieve_sort_dict["synthese"]) < 50:
-                    synthese_set.add(ele[0])
-                    retrieve_sort_dict["synthese"].append(ele)
+            # # Combined ranking
+            # synthese_list = sorted(sum(retrieve_sort_dict.values(), []), key=lambda x: x[1], reverse=True)
+            # synthese_set = set()
+            # for ele in synthese_list:
+            #     # Combined ranking returns the top 50
+            #     if ele[0] not in synthese_set and len(retrieve_sort_dict["synthese"]) < 50:
+            #         synthese_set.add(ele[0])
+            #         retrieve_sort_dict["synthese"].append(ele[:2])
             # Attach the question number
+            retrieve_sort_dict["label"] = label_dict
             retrieve_sort_dict["topic_num"] = topic_num
             
             # Return the final duplicate-check result as a dict
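The new semantic path depends on Comprehensive_Score exposing an init_id2max_id mapping and returning (sum_score, score_dict) when called; the real implementation lives in comprehensive_score.py, which this diff only imports. A minimal stub of the assumed interface, with hypothetical internals (Jaccard overlap stands in for the real per-dimension scoring, and scale is assumed to be a dict of weights keyed like label_dict):

import json

class Comprehensive_Score:
    # Stub of the interface assumed by HNSW.retrieve: exposes
    # init_id2max_id and returns (sum_score, score_dict) when called.
    def __init__(self, mapping_path="model_data/keyword_mapping.json"):
        with open(mapping_path, 'r', encoding="utf8") as f:
            self.init_id2max_id = json.load(f)["init_id2max_id"]

    def __call__(self, query_data, refer_data, scale):
        score_dict = dict()
        for key in scale:
            a, b = query_data.get(key, []), refer_data.get(key, [])
            a = set(a) if isinstance(a, list) else {a}
            b = set(b) if isinstance(b, list) else {b}
            # Hypothetical per-dimension score: Jaccard overlap.
            score_dict[key] = len(a & b) / len(a | b) if (a | b) else 0.0
        sum_score = sum(score_dict[k] * w for k, w in scale.items())
        return sum_score, score_dict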

BIN
main_clear/__pycache__/sci_clear.cpython-38.pyc


File diff suppressed because it is too large
+ 0 - 0
main_clear/sci_clear.py


+ 6 - 4
retrieval_app.py

@@ -1,4 +1,4 @@
-from gevent import monkey; monkey.patch_all()
+# from gevent import monkey; monkey.patch_all()
 import requests
 from gevent.pywsgi import WSGIServer
 from flask import Flask, request, jsonify
@@ -33,6 +33,7 @@ def hnsw_retrieve():
             return "请输入查重数据"
         retrieve_list = retrieve_dict["content"]
         similar = retrieve_dict["similar"] / 100
+        scale = retrieve_dict["scale"]
         doc_flag = True if retrieve_dict["doc_flag"] == 1 else False
         # Log the incoming request
         id_name = "document check" if doc_flag is True else "whole-question image check"
@@ -41,7 +42,7 @@ def hnsw_retrieve():
                                                     message=retrieve_dict))
         # HNSW model duplicate check
         post_url = r"http://192.168.1.209:8068/topic_retrieval_http"
-        res_list = hnsw_model.retrieve(retrieve_list, post_url, similar, doc_flag)
+        res_list = hnsw_model.retrieve(retrieve_list, post_url, similar, scale, doc_flag)
         # Log the response
         retrieval_logger.info(config.log_msg.format(id=id_name,
                                                     type="hnsw_retrieve返回",
@@ -107,15 +108,16 @@ def info_retrieve():
             return "请输入检索数据"
         sentence = retrieve_dict["content"]
         similar = retrieve_dict["similar"] / 100
+        scale = retrieve_dict["scale"]
         # Keyword-based text retrieval
         id_list, seg_list = ir_model(sentence)
         id_list = [int(idx) for idx in id_list]
         # Semantic-similarity duplicate check
         retrieve_list = [dict(stem=sentence)]
         if len(sentence) > 30:
-            doc_list = hnsw_model.retrieve(retrieve_list, '', similar, False)[0]["semantics"]
+            doc_list = hnsw_model.retrieve(retrieve_list, '', similar, scale, False)[0]["semantics"]
         else:
-            doc_list = hnsw_model.retrieve(retrieve_list, '', similar, False, 0.6)[0]["semantics"]
+            doc_list = hnsw_model.retrieve(retrieve_list, '', similar, scale, False)[0]["semantics"]
         res_dict = dict(info=[id_list, seg_list], doc=doc_list)
         # Log the response
         retrieval_logger.info(config.log_msg.format(id="text check",
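Both routes now read a scale field from the request body, so existing clients must be updated. A hedged example call (field names come from the diffs above; the URL, port, and weight values are assumptions):

import requests

# "similar" is a percentage (divided by 100 server-side); "scale" holds
# per-dimension weights for Comprehensive_Score. All values illustrative.
payload = {
    "content": [{"stem": "...", "topic_num": 1, "quesType": ""}],
    "similar": 80,
    "scale": {"knowledge": 0.4, "physical_scene": 0.2, "solving_type": 0.2,
              "difficulty": 0.1, "physical_quantity": 0.1},
    "doc_flag": 0,
}
res = requests.post("http://127.0.0.1:8068/retrieve", json=payload)  # URL assumed
print(res.json())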

Some files were not shown because too many files changed in this diff