hnsw_app.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. from gevent import monkey; monkey.patch_all()
  2. from flask import Flask, request, jsonify
  3. from flask_apscheduler import APScheduler
  4. from gevent.pywsgi import WSGIServer
  5. import requests
  6. from concurrent.futures import ThreadPoolExecutor
  7. import config
  8. from hnsw_model import HNSW
  9. from restart_server import restart_math_dup_app
  10. app = Flask(__name__)
  11. class APS_Config(object):
  12. SCHEDULER_API_ENABLED = True
  13. scheduler = APScheduler()
  14. # 定时更新HNSW模型增/改/删变化
  15. @scheduler.task('cron', id='hnsw_update', day='*', hour='01', minute='10', second='00', timezone='Asia/Shanghai')
  16. def hnsw_update_schedule():
  17. if sum(hnsw.hnsw_update_flag_list) > 0:
  18. # 保存HNSW模型
  19. hnsw.save_hnsw()
  20. # 重启math_dup_app服务
  21. restart_math_dup_app()
  22. # 日志采集
  23. hnsw_retrieve_logger.info(config.log_msg.format(id="数据更新",
  24. type="HNSW Update",
  25. message="HNSW模型定时更新完毕"))
  26. # 手动HNSW数据更新
  27. @app.route('/hnsw/update', methods=['GET', 'POST'])
  28. def hnsw_update():
  29. if request.method == 'POST':
  30. # 获取post数据
  31. update_command = request.get_json()
  32. # 日志采集
  33. if update_command == "save" and sum(hnsw.hnsw_update_flag_list) > 0:
  34. # 保存HNSW模型
  35. hnsw.save_hnsw()
  36. # 日志采集
  37. hnsw_retrieve_logger.info(config.log_msg.format(id="数据更新",
  38. type="hnsw/update",
  39. message="手动HNSW数据更新完毕"))
  40. return jsonify("手动HNSW数据更新完毕")
  41. elif update_command != "save" or sum(hnsw.hnsw_update_flag_list) == 0:
  42. return jsonify("不符合手动更新条件")
  43. # hnsw模型数据检索
  44. @app.route('/retrieve', methods=['GET', 'POST'])
  45. def retrieve():
  46. if request.method == 'POST':
  47. # 获取post数据
  48. query_dict = request.get_json()
  49. # HNSW检索
  50. query_labels = hnsw.retrieve(query_dict["query_vec"], query_dict["hnsw_index"])
  51. return jsonify(query_labels)
  52. # hnsw模型数据更新
  53. @app.route('/update', methods=['GET', 'POST'])
  54. def update():
  55. if request.method == 'POST':
  56. # 获取post数据
  57. update_dict = request.get_json()
  58. # 更新HNSW模型
  59. hnsw.update(update_dict["id"], update_dict["hnsw_index"])
  60. # 追加保存hnsw_update_data.txt中更新数据,等待定时或手动更新
  61. with open("hnsw_update_data.txt", 'a', encoding='utf8') as f:
  62. f.write(str(update_dict)+"\n")
  63. return jsonify("数据更新完毕")
  64. # hnsw模型数据更新
  65. @app.route('/chc/transfer', methods=['GET', 'POST'])
  66. def chc_transfer():
  67. if request.method == 'POST':
  68. # 获取post数据
  69. topics_dict = request.get_json()
  70. # 异步查重
  71. executor = ThreadPoolExecutor(max_workers=1)
  72. def async_chc():
  73. nonlocal topics_dict
  74. # 将结果post给callback_url
  75. requests.post(r"http://localhost:8855/chc/process", json=topics_dict).json()
  76. # 线程内存回收
  77. nonlocal executor
  78. executor.shutdown()
  79. executor.submit(async_chc)
  80. return jsonify("")
  81. if __name__ == '__main__':
  82. # 日志采集初始化
  83. hnsw_retrieve_LogConfig = config.LogConfig(config.math_dup_path, "hnsw_retrieve")
  84. hnsw_retrieve_logger = hnsw_retrieve_LogConfig.get_log()
  85. # HNSW模型初始化
  86. hnsw = HNSW(hnsw_retrieve_logger)
  87. # 定时更新HNSW模型增/改/删变化
  88. app.config.from_object(APS_Config())
  89. scheduler.init_app(app)
  90. scheduler.start()
  91. # app.run(host='0.0.0.0',port='8858')
  92. server = WSGIServer(('0.0.0.0', 8858), app)
  93. server.serve_forever()