tjt 1 ano atrás
pai
commit
1ad3d21baa
4 arquivos alterados com 13 adições e 16 exclusões
  1. 2 2
      hnsw_app.py
  2. 5 9
      hnsw_retrieval.py
  3. 5 4
      retrieval_app.py
  4. 1 1
      retrieval_monitor.py

+ 2 - 2
hnsw_app.py

@@ -17,13 +17,13 @@ class APS_Config(object):
 scheduler = APScheduler()
 
 # 定时重启retrieval_app服务, 防止内存累加
-@scheduler.task('cron', id='restart_retrieval_app', day='*', hour='00', minute='30', second='00', timezone='Asia/Shanghai')
+@scheduler.task('cron', id='restart_retrieval_app', day='*', hour='02', minute='00', second='00', timezone='Asia/Shanghai')
 def retrieval_app_schedule():
     # 重启retrieval_app服务
     restart_retrieval_app()
 
 # 定时训练HNSW模型并重启服务(0-'mon',...2-'wed',...,6-'sun')
-@scheduler.task('cron', id='hm_train', week='*', day_of_week='0', hour='01', minute='30', second='00', timezone='Asia/Shanghai')
+@scheduler.task('cron', id='hm_train', week='*', day_of_week='0', hour='02', minute='30', second='00', timezone='Asia/Shanghai')
 def hm_train_schedule():
     hm_train = HNSW_Model_Train(hm_logger)
     hm_train()

+ 5 - 9
hnsw_retrieval.py

@@ -83,7 +83,7 @@ class HNSW():
         return formula_res_list[:50]
 
     # HNSW查(支持多学科混合查重)
-    def retrieve(self, retrieve_list, post_url, similar, doc_flag):
+    def retrieve(self, retrieve_list, post_url, similar, doc_flag, min_threshold=0):
         # 计算retrieve_list的vec值
         # 调用清洗分词函数和句向量计算函数
         sent_vec_list, cont_clear_list = self.dpp(retrieve_list, is_retrieve=True)
@@ -123,7 +123,6 @@ class HNSW():
 
             # 返回大于阈值的结果
             filter_threshold = similar
-            # min_threshold = 0.25
             for label_data in query_dataset:
                 if "sentence_vec" not in label_data:
                     continue
@@ -135,15 +134,12 @@ class HNSW():
                 # 阈值判断
                 if cosine_score < filter_threshold:
                     continue
-                retrieve_value = [label_data["id"], int(cosine_score * 100) / 100]
-                retrieve_value_dict["semantics"].append(retrieve_value)
                 # 计算编辑距离得分
                 fuzz_score = fuzz.ratio(cont_clear_list[i], label_data["content_clear"]) / 100
-                # if fuzz_score < min_threshold:
-                #     continue
-                # if cosine_score >= filter_threshold:
-                #     retrieve_value = [label_data["id"], int(cosine_score * 100) / 100]
-                #     retrieve_value_dict["semantics"].append(retrieve_value)
+                if fuzz_score < min_threshold:
+                    continue
+                retrieve_value = [label_data["id"], int(cosine_score * 100) / 100]
+                retrieve_value_dict["semantics"].append(retrieve_value)
                 # 进行编辑距离得分验证,若小于设定分则过滤
                 if fuzz_score >= filter_threshold:
                     retrieve_value = [label_data["id"], fuzz_score]

+ 5 - 4
retrieval_app.py

@@ -40,7 +40,7 @@ def hnsw_retrieve():
                                                     type="hnsw_retrieve接收",
                                                     message=retrieve_dict))
         # hnsw模型查重
-        post_url = r"http://127.0.0.1:8068/topic_retrieval_http"
+        post_url = r"http://localhost:8068/topic_retrieval_http"
         res_list = hnsw_model.retrieve(retrieve_list, post_url, similar, doc_flag)
         # 返回日志采集
         retrieval_logger.info(config.log_msg.format(id=id_name,
@@ -59,7 +59,7 @@ def image_retrieve():
         retrieve_img = retrieve_dict["content"]
         similar = retrieve_dict["similar"] / 100
         # 图片查重链接
-        post_url = r"http://127.0.0.1:8068/img_retrieval_http"
+        post_url = r"http://localhost:8068/img_retrieval_http"
         img_dict = dict(img_url=retrieve_img, img_threshold=similar, img_max_num=30)
         try:
             res_list = requests.post(post_url, json=img_dict, timeout=20).json()
@@ -111,11 +111,12 @@ def info_retrieve():
         id_list, seg_list = ir_model(sentence)
         id_list = [int(idx) for idx in id_list]
         # 语义相似度查重
-        if len(sentence) > 15:
+        retrieve_list = [dict(stem=sentence, topic_num=1)]
+        if len(sentence) > 30:
             retrieve_list = [dict(stem=sentence, topic_num=1)]
             doc_list = hnsw_model.retrieve(retrieve_list, '', similar, False)[0]["semantics"]
         else:
-            doc_list = []
+            doc_list = hnsw_model.retrieve(retrieve_list, '', similar, False, 0.3)[0]["semantics"]
         res_dict = dict(info=[id_list, seg_list], doc=doc_list)
         # 返回日志采集
         retrieval_logger.info(config.log_msg.format(id="文本查重",

+ 1 - 1
retrieval_monitor.py

@@ -3,7 +3,7 @@ import time
 
 def server_run(port, command):
     # 设置服务缓存时间(防止更新服务产生冲突)
-    time.sleep(20)
+    time.sleep(12) if port == 8836 else time.sleep(8)
     server = os.popen("lsof -i:{}".format(port)).readlines()
     if not server:
         print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),