123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- from gevent import monkey; monkey.patch_all()
- import requests
- from gevent.pywsgi import WSGIServer
- from flask import Flask, request, jsonify
- from flask_cors import CORS
- import config
- from hnsw_retrieval import HNSW
- from info_retrieval import Info_Retrieval
- from data_preprocessing import DataPreProcessing
- from log_config import LogConfig
# Flask application; CORS is enabled with credential support for browser clients.
app = Flask(__name__)
CORS(app, supports_credentials=True)

# Logging: build the log configuration for the retrieval service and fetch
# the logger that the request handlers and models below share.
retrieval_LogConfig = LogConfig(config.retrieval_path, "retrieval")
retrieval_logger = retrieval_LogConfig.get_log()

# Data preprocessing helper, handed to both retrieval models.
data_process = DataPreProcessing(logger=retrieval_logger)

# HNSW retrieval model.
hnsw_model = HNSW(data_process, retrieval_logger)

# Information retrieval model (n-grams flag switched on).
ir_model = Info_Retrieval(data_process, logger=retrieval_logger, n_grams_flag=True)
# Document duplicate check
@app.route('/hnsw_retrieve', methods=['GET', 'POST'])
def hnsw_retrieve():
    """Run the HNSW duplicate check on a posted JSON payload.

    The JSON body is read for three keys: ``content`` (passed to the model
    as the items to check), ``similar`` (divided by 100 before use,
    presumably a percentage threshold) and ``doc_flag`` (``1`` selects the
    document check, anything else the whole-question image check).

    Returns:
        A JSON response with the model's result list, or a plain-text
        prompt when the request is not a POST or carries no usable JSON.

    Raises:
        KeyError: if the JSON object lacks one of the three keys.
    """
    # The route also accepts GET; without this guard a non-POST request fell
    # through and returned None, which Flask rejects as an invalid response.
    if request.method != 'POST':
        return "请输入查重数据"

    # silent=True yields None for a missing or malformed JSON body instead of
    # an HTTP error, so the prompt below is actually reachable.
    retrieve_dict = request.get_json(silent=True)
    if not retrieve_dict:
        return "请输入查重数据"

    retrieve_list = retrieve_dict["content"]
    similar = retrieve_dict["similar"] / 100
    doc_flag = retrieve_dict["doc_flag"] == 1

    # Log the incoming request under the label of the selected check.
    id_name = "文档查重" if doc_flag else "整题图片查重"
    retrieval_logger.info(config.log_msg.format(id=id_name,
                                                type="hnsw_retrieve接收",
                                                message=retrieve_dict))

    # Run the HNSW check; the URL is handed to the model unchanged.
    post_url = r"http://192.168.1.209:8068/topic_retrieval_http"
    res_list = hnsw_model.retrieve(retrieve_list, post_url, similar, doc_flag)

    # Log the outgoing result.
    retrieval_logger.info(config.log_msg.format(id=id_name,
                                                type="hnsw_retrieve返回",
                                                message=res_list))
    return jsonify(res_list)
# Image duplicate check
@app.route('/image_retrieve', methods=['GET', 'POST'])
def image_retrieve():
    """Forward an image duplicate check to the image retrieval HTTP service.

    The JSON body is read for ``content`` (sent on as ``img_url``) and
    ``similar`` (divided by 100 and sent on as ``img_threshold``). The
    downstream request asks for at most 30 results and times out after
    30 seconds.

    Returns:
        A JSON response with the downstream service's decoded JSON, an
        empty list when the downstream call fails, or a plain-text prompt
        when the request is not a POST or carries no usable JSON.

    Raises:
        KeyError: if the JSON object lacks ``content`` or ``similar``.
    """
    # The route also accepts GET; without this guard a non-POST request fell
    # through and returned None, which Flask rejects as an invalid response.
    if request.method != 'POST':
        return "请输入查重数据"

    # silent=True yields None for a missing or malformed JSON body instead of
    # an HTTP error, so the prompt below is actually reachable.
    retrieve_dict = request.get_json(silent=True)

    # Log the incoming request, as the other retrieval endpoints do.
    retrieval_logger.info(config.log_msg.format(id="图片查重",
                                                type="image_retrieve接收",
                                                message=retrieve_dict))
    if not retrieve_dict:
        return "请输入查重数据"

    retrieve_img = retrieve_dict["content"]
    similar = retrieve_dict["similar"] / 100

    # Image duplicate-check service endpoint and its request payload.
    post_url = r"http://192.168.1.209:8068/img_retrieval_http"
    img_dict = dict(img_url=retrieve_img, img_threshold=similar, img_max_num=30)
    try:
        res_list = requests.post(post_url, json=img_dict, timeout=30).json()
    except Exception as e:
        # Best effort: a downstream failure still yields an empty result, but
        # the cause is now recorded instead of being silently dropped.
        retrieval_logger.info(config.log_msg.format(id="图片查重",
                                                    type="image_retrieve异常",
                                                    message=repr(e)))
        res_list = []

    # Log the outgoing result.
    retrieval_logger.info(config.log_msg.format(id="图片查重",
                                                type="image_retrieve返回",
                                                message=res_list))
    return jsonify(res_list)
# Formula duplicate check
@app.route('/formula_retrieve', methods=['GET', 'POST'])
def formula_retrieve():
    """Run the formula duplicate check on a posted JSON payload.

    The JSON body is read for ``content`` (the formula string handed to
    the model) and ``similar`` (divided by 100 before use, presumably a
    percentage threshold).

    Returns:
        A JSON response with the model's result list, or a plain-text
        prompt when the request is not a POST or carries no usable JSON.

    Raises:
        KeyError: if the JSON object lacks ``content`` or ``similar``.
    """
    # The route also accepts GET; without this guard a non-POST request fell
    # through and returned None, which Flask rejects as an invalid response.
    if request.method != 'POST':
        return "请输入查重数据"

    # silent=True yields None for a missing or malformed JSON body instead of
    # an HTTP error, so the prompt below is actually reachable.
    retrieve_dict = request.get_json(silent=True)

    # Log the incoming request (before validation, so empty bodies are logged too).
    retrieval_logger.info(config.log_msg.format(id="公式查重",
                                                type="formula_retrieve接收",
                                                message=retrieve_dict))
    if not retrieve_dict:
        return "请输入查重数据"

    formula_string = retrieve_dict["content"]
    similar = retrieve_dict["similar"] / 100

    # Formula duplicate check through the HNSW model.
    res_list = hnsw_model.formula_retrieve(formula_string, similar)

    # Log the outgoing result.
    retrieval_logger.info(config.log_msg.format(id="公式查重",
                                                type="formula_retrieve返回",
                                                message=res_list))
    return jsonify(res_list)
# Text duplicate check
@app.route('/info_retrieve', methods=['GET', 'POST'])
def info_retrieve():
    """Run keyword retrieval plus the semantic duplicate check on posted text.

    The JSON body is read for ``content`` (the sentence to look up) and
    ``similar`` (divided by 100 before use, presumably a percentage
    threshold).

    Returns:
        A JSON response of the form ``{"info": [id_list, seg_list],
        "doc": doc_list}``, or a plain-text prompt when the request is not
        a POST or carries no usable JSON.

    Raises:
        KeyError: if the JSON object lacks ``content`` or ``similar``.
    """
    # The route also accepts GET; without this guard a non-POST request fell
    # through and returned None, which Flask rejects as an invalid response.
    if request.method != 'POST':
        return "请输入检索数据"

    # silent=True yields None for a missing or malformed JSON body instead of
    # an HTTP error, so the prompt below is actually reachable.
    retrieve_dict = request.get_json(silent=True)

    # Log the incoming request (before validation, so empty bodies are logged too).
    retrieval_logger.info(config.log_msg.format(id="文本查重",
                                                type="info_retrieve接收",
                                                message=retrieve_dict))
    if not retrieve_dict:
        return "请输入检索数据"

    sentence = retrieve_dict["content"]
    similar = retrieve_dict["similar"] / 100

    # Keyword retrieval; ids are coerced to int before being serialised.
    id_list, seg_list = ir_model(sentence)
    id_list = [int(idx) for idx in id_list]

    # Semantic duplicate check. Sentences of 30 characters or fewer pass an
    # extra positional 0.6 to the model; longer ones rely on its default.
    # NOTE(review): the meaning of that fifth argument is not visible here.
    retrieve_list = [dict(stem=sentence)]
    extra_args = () if len(sentence) > 30 else (0.6,)
    doc_list = hnsw_model.retrieve(retrieve_list, '', similar, False, *extra_args)[0]["semantics"]

    res_dict = dict(info=[id_list, seg_list], doc=doc_list)

    # Log the outgoing result.
    retrieval_logger.info(config.log_msg.format(id="文本查重",
                                                type="info_retrieve返回",
                                                message=res_dict))
    return jsonify(res_dict)
if __name__ == "__main__":
    # Serve the app with gevent's WSGI server on every interface, port 8835,
    # in place of Flask's built-in app.run().
    listen_address = ('0.0.0.0', 8835)
    http_server = WSGIServer(listen_address, app)
    http_server.serve_forever()
|