# retrieval_app.py (5.5 KB)
  1. from gevent import monkey; monkey.patch_all()
  2. import requests
  3. from gevent.pywsgi import WSGIServer
  4. from flask import Flask, request, jsonify
  5. from flask_cors import CORS
  6. import config
  7. from hnsw_retrieval import HNSW
  8. from info_retrieval import Info_Retrieval
  9. from data_preprocessing import DataPreProcessing
  10. from log_config import LogConfig
  11. app = Flask(__name__)
  12. CORS(app, supports_credentials=True)
  13. # 日志采集初始化
  14. retrieval_LogConfig = LogConfig(config.retrieval_path, "retrieval")
  15. retrieval_logger = retrieval_LogConfig.get_log()
  16. # 数据处理初始化
  17. data_process = DataPreProcessing(logger=retrieval_logger)
  18. # HNSW模型初始化
  19. hnsw_model = HNSW(data_process, retrieval_logger)
  20. # 信息检索模型初始化
  21. ir_model = Info_Retrieval(data_process, logger=retrieval_logger, n_grams_flag=True)
  22. # 文档查重
  23. @app.route('/hnsw_retrieve', methods=['GET', 'POST'])
  24. def hnsw_retrieve():
  25. if request.method == 'POST':
  26. # 获取post数据
  27. retrieve_dict = request.get_json()
  28. if not retrieve_dict:
  29. return "请输入查重数据"
  30. retrieve_list = retrieve_dict["content"]
  31. similar = retrieve_dict["similar"] / 100
  32. doc_flag = True if retrieve_dict["doc_flag"] == 1 else False
  33. # 接收日志采集
  34. id_name = "文档查重" if doc_flag is True else "整题图片查重"
  35. retrieval_logger.info(config.log_msg.format(id=id_name,
  36. type="hnsw_retrieve接收",
  37. message=retrieve_dict))
  38. # hnsw模型查重
  39. post_url = r"http://localhost:8068/topic_retrieval_http"
  40. res_list = hnsw_model.retrieve(retrieve_list, post_url, similar, doc_flag)
  41. # 返回日志采集
  42. retrieval_logger.info(config.log_msg.format(id=id_name,
  43. type="hnsw_retrieve返回",
  44. message=res_list))
  45. return jsonify(res_list)
  46. # 图片查重
  47. @app.route('/image_retrieve', methods=['GET', 'POST'])
  48. def image_retrieve():
  49. if request.method == 'POST':
  50. # 获取post数据
  51. retrieve_dict = request.get_json()
  52. if not retrieve_dict:
  53. return "请输入查重数据"
  54. retrieve_img = retrieve_dict["content"]
  55. similar = retrieve_dict["similar"] / 100
  56. # 图片查重链接
  57. post_url = r"http://localhost:8068/img_retrieval_http"
  58. img_dict = dict(img_url=retrieve_img, img_threshold=similar, img_max_num=30)
  59. try:
  60. res_list = requests.post(post_url, json=img_dict, timeout=20).json()
  61. except Exception as e:
  62. res_list = []
  63. # 返回日志采集
  64. retrieval_logger.info(config.log_msg.format(id="图片查重",
  65. type="image_retrieve返回",
  66. message=res_list))
  67. return jsonify(res_list)
  68. # 公式查重
  69. @app.route('/formula_retrieve', methods=['GET', 'POST'])
  70. def formula_retrieve():
  71. if request.method == 'POST':
  72. # 获取post数据
  73. retrieve_dict = request.get_json()
  74. # 接收日志采集
  75. retrieval_logger.info(config.log_msg.format(id="公式查重",
  76. type="formula_retrieve接收",
  77. message=retrieve_dict))
  78. if not retrieve_dict:
  79. return "请输入查重数据"
  80. formula_string = retrieve_dict["content"]
  81. similar = retrieve_dict["similar"] / 100
  82. # 公式图片查重链接
  83. res_list = hnsw_model.formula_retrieve(formula_string, similar)
  84. # 返回日志采集
  85. retrieval_logger.info(config.log_msg.format(id="公式查重",
  86. type="formula_retrieve返回",
  87. message=res_list))
  88. return jsonify(res_list)
  89. # 文本查重
  90. @app.route('/info_retrieve', methods=['GET', 'POST'])
  91. def info_retrieve():
  92. if request.method == 'POST':
  93. # 获取post数据
  94. retrieve_dict = request.get_json()
  95. # 接收日志采集
  96. retrieval_logger.info(config.log_msg.format(id="文本查重",
  97. type="info_retrieve接收",
  98. message=retrieve_dict))
  99. if not retrieve_dict:
  100. return "请输入检索数据"
  101. sentence = retrieve_dict["content"]
  102. similar = retrieve_dict["similar"] / 100
  103. # 文本关键词检索
  104. id_list, seg_list = ir_model(sentence)
  105. id_list = [int(idx) for idx in id_list]
  106. # 语义相似度查重
  107. retrieve_list = [dict(stem=sentence, topic_num=1)]
  108. if len(sentence) > 30:
  109. retrieve_list = [dict(stem=sentence, topic_num=1)]
  110. doc_list = hnsw_model.retrieve(retrieve_list, '', similar, False)[0]["semantics"]
  111. else:
  112. doc_list = hnsw_model.retrieve(retrieve_list, '', similar, False, 0.3)[0]["semantics"]
  113. res_dict = dict(info=[id_list, seg_list], doc=doc_list)
  114. # 返回日志采集
  115. retrieval_logger.info(config.log_msg.format(id="文本查重",
  116. type="info_retrieve返回",
  117. message=res_dict))
  118. return jsonify(res_dict)
  119. if __name__ == "__main__":
  120. # app.run(host='0.0.0.0',port='8835')
  121. server = WSGIServer(('0.0.0.0', 8835), app)
  122. server.serve_forever()