server.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from fastapi import FastAPI, Request, BackgroundTasks
  2. import uvicorn
  3. from fastapi.responses import JSONResponse, Response
  4. from dotenv import load_dotenv
  5. load_dotenv()
  6. from service import labeling
  7. from config.config import subject_id
  8. from common.valid_check import valid_params, valid_is_contained, LabelExceptionErrCode
  9. TIMEOUT_KEEP_ALIVE = 5 # seconds.
  10. from config.config import log
  11. app = FastAPI()
  12. from concurrent.futures import ThreadPoolExecutor
  13. import requests
  14. executor = ThreadPoolExecutor(max_workers=1)
  15. @app.post("/auto_label")
  16. async def auto_label(request: Request) -> Response:
  17. """
  18. 入参格式:
  19. {
  20. "subject_id": xx,
  21. "topic_list": [
  22. {
  23. "topic_id": 23,#题目id
  24. "topic_text": "题干",
  25. "parse": "解析",
  26. "option": []
  27. }
  28. ],
  29. "call_back_url": "回调url"
  30. }
  31. 回调响应函数格式:
  32. {
  33. "topic_list": [
  34. {
  35. "topic_id": 234,#题目id
  36. "labels": ["牛顿第二定律","牛顿第一定律"],#考点列表
  37. "knowsledge_state": -1,#考点标注状态,1,成功,-1失败
  38. "difficulty_state":1,#难度标注状态,1成功,-1失败
  39. "difficulty": 2#难度值
  40. }
  41. ]
  42. }
  43. """
  44. request_dict = await request.json()
  45. log.info("request: "+ str(request_dict))
  46. result = {"err_code": 0, "msg": "success"}
  47. #1. 校验参数是否合法
  48. if not valid_params(request_dict.get("subject_id", None),
  49. request_dict.get("topic_list", None),
  50. request_dict.get("call_back_url", "")):
  51. result = {"err_code": LabelExceptionErrCode.PARAM_NOT_NULL.value, "msg": "入参不能为空"}
  52. return result
  53. topic_list = request_dict["topic_list"]
  54. if len(topic_list) > 10:
  55. result = {"err_code": LabelExceptionErrCode.NUM_OVER_LIMIT.value, "msg": "一次标注题目个数不能超过10道题"}
  56. return result
  57. for topic in topic_list:
  58. if not valid_params(topic.get("topic_text", ""),topic.get("parse", None), topic.get("topic_id", None)):
  59. result = {"err_code": LabelExceptionErrCode.PARAM_NOT_NULL.value, "msg": "入参不能为空"}
  60. return result
  61. if not valid_is_contained(request_dict["subject_id"], subject_id):
  62. result = {"err_code": LabelExceptionErrCode.PARAM_NOT_NULL.value, "msg": "学科id不合法"}
  63. return result
  64. #2. 启动线程标注
  65. def async_chc():
  66. nonlocal request_dict
  67. result = labeling.auto_label(request_dict)
  68. if result["err_code"] == -1:
  69. requests.post(request_dict["call_back_url"], json=result)
  70. else:
  71. #冗余一个难度字段,用于后续支持难度标注
  72. for topic in result["topic_list"]:
  73. topic["difficulty"] = 2
  74. topic["difficulty_state"] = 1
  75. result["topic_list"].append(topic)
  76. # 将结果post给callback_url
  77. requests.post(request_dict["call_back_url"], json=result)
  78. executor.submit(async_chc)
  79. return JSONResponse(result)
  80. if __name__ == "__main__":
  81. uvicorn.run('server:app',
  82. host="0.0.0.0",
  83. port=8840,
  84. log_level="debug",
  85. timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
  86. workers=1,
  87. )