retrieval_app.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # from gevent import monkey; monkey.patch_all()
  2. import requests
  3. from gevent.pywsgi import WSGIServer
  4. from flask import Flask, request, jsonify
  5. from flask_cors import CORS
  6. import config
  7. from hnsw_retrieval import HNSW
  8. from info_retrieval import Info_Retrieval
  9. from data_preprocessing import DataPreProcessing
  10. from log_config import LogConfig
# Flask application object; the endpoints below are registered on it.
app = Flask(__name__)
# Enable CORS on the app and allow credentialed cross-origin requests
# (cookies / auth headers) via flask_cors' supports_credentials option.
CORS(app, supports_credentials=True)
# Logging setup: one logger shared by the models below and by every endpoint.
# NOTE(review): config.retrieval_path is presumably the log output location — confirm in config / log_config.
retrieval_LogConfig = LogConfig(config.retrieval_path, "retrieval")
retrieval_logger = retrieval_LogConfig.get_log()
# Shared data-preprocessing component; the same instance is handed to both models.
# These constructors run at import time, so any model loading they do happens when the module is imported.
data_process = DataPreProcessing(logger=retrieval_logger)
# HNSW model: used by /hnsw_retrieve (retrieve) and /formula_retrieve (formula_retrieve).
hnsw_model = HNSW(data_process, retrieval_logger)
# Information-retrieval (keyword) model: called directly by /info_retrieve.
# NOTE(review): n_grams_flag=True presumably enables n-gram matching — confirm in info_retrieval.
ir_model = Info_Retrieval(data_process, logger=retrieval_logger, n_grams_flag=True)
  22. # 文档查重
  23. @app.route('/hnsw_retrieve', methods=['GET', 'POST'])
  24. def hnsw_retrieve():
  25. if request.method == 'POST':
  26. # 获取post数据
  27. retrieve_dict = request.get_json()
  28. if not retrieve_dict:
  29. return "请输入查重数据"
  30. retrieve_list = retrieve_dict["content"]
  31. similar = retrieve_dict["similar"] / 100
  32. scale = retrieve_dict["scale"]
  33. doc_flag = True if retrieve_dict["doc_flag"] == 1 else False
  34. # 接收日志采集
  35. id_name = "文档查重" if doc_flag is True else "整题图片查重"
  36. retrieval_logger.info(config.log_msg.format(id=id_name,
  37. type="hnsw_retrieve接收",
  38. message=retrieve_dict))
  39. # hnsw模型查重
  40. post_url = config.illustration_url
  41. res_list = hnsw_model.retrieve(retrieve_list, post_url, similar, scale, doc_flag)
  42. # 返回日志采集
  43. retrieval_logger.info(config.log_msg.format(id=id_name,
  44. type="hnsw_retrieve返回",
  45. message=res_list))
  46. return jsonify(res_list)
  47. # 图片查重
  48. @app.route('/image_retrieve', methods=['GET', 'POST'])
  49. def image_retrieve():
  50. if request.method == 'POST':
  51. # 获取post数据
  52. retrieve_dict = request.get_json()
  53. if not retrieve_dict:
  54. return "请输入查重数据"
  55. retrieve_img = retrieve_dict["content"]
  56. similar = retrieve_dict["similar"] / 100
  57. # 图片查重链接
  58. post_url = config.image_url
  59. img_dict = dict(img_url=retrieve_img, img_threshold=similar, img_max_num=30)
  60. try:
  61. res_list = requests.post(post_url, json=img_dict, timeout=30).json()
  62. except Exception as e:
  63. res_list = []
  64. # 返回日志采集
  65. retrieval_logger.info(config.log_msg.format(id="图片查重",
  66. type="image_retrieve返回",
  67. message=res_list))
  68. return jsonify(res_list)
  69. # 公式查重
  70. @app.route('/formula_retrieve', methods=['GET', 'POST'])
  71. def formula_retrieve():
  72. if request.method == 'POST':
  73. # 获取post数据
  74. retrieve_dict = request.get_json()
  75. # 接收日志采集
  76. retrieval_logger.info(config.log_msg.format(id="公式查重",
  77. type="formula_retrieve接收",
  78. message=retrieve_dict))
  79. if not retrieve_dict:
  80. return "请输入查重数据"
  81. formula_string = retrieve_dict["content"]
  82. similar = retrieve_dict["similar"] / 100
  83. # 公式图片查重链接
  84. res_list = hnsw_model.formula_retrieve(formula_string, similar)
  85. # 返回日志采集
  86. retrieval_logger.info(config.log_msg.format(id="公式查重",
  87. type="formula_retrieve返回",
  88. message=res_list))
  89. return jsonify(res_list)
  90. # 文本查重
  91. @app.route('/info_retrieve', methods=['GET', 'POST'])
  92. def info_retrieve():
  93. if request.method == 'POST':
  94. # 获取post数据
  95. retrieve_dict = request.get_json()
  96. # 接收日志采集
  97. retrieval_logger.info(config.log_msg.format(id="文本查重",
  98. type="info_retrieve接收",
  99. message=retrieve_dict))
  100. if not retrieve_dict:
  101. return "请输入检索数据"
  102. sentence = retrieve_dict["content"]
  103. similar = retrieve_dict["similar"] / 100
  104. scale = retrieve_dict["scale"]
  105. # 文本关键词检索
  106. id_list, seg_list = ir_model(sentence)
  107. id_list = [int(idx) for idx in id_list]
  108. res_dict = dict(info=[id_list, seg_list])
  109. # 返回日志采集
  110. retrieval_logger.info(config.log_msg.format(id="文本查重",
  111. type="info_retrieve返回",
  112. message=res_dict))
  113. return jsonify(res_dict)
  114. if __name__ == "__main__":
  115. # app.run(host='0.0.0.0',port='8835')
  116. server = WSGIServer(('0.0.0.0', 8835), app)
  117. server.serve_forever()