123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
from concurrent.futures import ThreadPoolExecutor

import requests
import uvicorn
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, Response

# Load .env into the process environment *before* the project modules below are
# imported, so they see those values if they read the environment at import time.
load_dotenv()

from service import labeling  # noqa: E402
from config.config import log, subject_id  # noqa: E402
from common.valid_check import (  # noqa: E402
    LabelExceptionErrCode,
    valid_is_contained,
    valid_params,
)

# Keep-alive timeout handed to uvicorn, in seconds.
TIMEOUT_KEEP_ALIVE = 5

app = FastAPI()

# One worker thread: submitted labeling jobs run one at a time, in submission order.
executor = ThreadPoolExecutor(max_workers=1)
# Upper bound, in seconds, for each callback POST. Without it, a slow or
# unresponsive callback endpoint would hang the single labeling worker and
# stall every job queued behind it.
_CALLBACK_TIMEOUT_SECONDS = 30


def _validate_request(request_dict):
    """Check an /auto_label request body.

    Returns:
        None if the request may be processed, otherwise the error dict
        (``err_code`` / ``msg``) to send back to the client.
    """
    if not isinstance(request_dict, dict):
        # e.g. a JSON array or scalar body; .get() below would raise AttributeError.
        return {"err_code": LabelExceptionErrCode.PARAM_NOT_NULL.value, "msg": "入参格式不合法"}
    if not valid_params(request_dict.get("subject_id", None),
                        request_dict.get("topic_list", None),
                        request_dict.get("call_back_url", "")):
        return {"err_code": LabelExceptionErrCode.PARAM_NOT_NULL.value, "msg": "入参不能为空"}
    topic_list = request_dict["topic_list"]
    if not isinstance(topic_list, list):
        return {"err_code": LabelExceptionErrCode.PARAM_NOT_NULL.value, "msg": "入参格式不合法"}
    # Keep this limit in sync with the number in the error message.
    if len(topic_list) > 10:
        return {"err_code": LabelExceptionErrCode.NUM_OVER_LIMIT.value, "msg": "一次标注题目个数不能超过10道题"}
    for topic in topic_list:
        if not isinstance(topic, dict):
            return {"err_code": LabelExceptionErrCode.PARAM_NOT_NULL.value, "msg": "入参格式不合法"}
        if not valid_params(topic.get("topic_text", ""), topic.get("parse", None), topic.get("topic_id", None)):
            return {"err_code": LabelExceptionErrCode.PARAM_NOT_NULL.value, "msg": "入参不能为空"}
    if not valid_is_contained(request_dict["subject_id"], subject_id):
        # NOTE(review): reuses PARAM_NOT_NULL for an unknown subject id; a dedicated code may be intended.
        return {"err_code": LabelExceptionErrCode.PARAM_NOT_NULL.value, "msg": "学科id不合法"}
    return None


def _label_and_callback(request_dict):
    """Label the topics of a validated request and POST the result to its callback URL.

    Runs on the module-level ``executor`` thread. The Future returned by
    ``executor.submit`` is never inspected, so any exception is logged here
    instead of being silently dropped.
    """
    callback_url = request_dict["call_back_url"]
    try:
        label_result = labeling.auto_label(request_dict)
        if label_result["err_code"] != -1:
            # Placeholder difficulty fields, kept so difficulty labeling can be
            # supported later. Topics are mutated in place; the list must NOT be
            # appended to while iterating it (that never terminates).
            for topic in label_result["topic_list"]:
                topic["difficulty"] = 2
                topic["difficulty_state"] = 1
        # Failure results (err_code == -1) are forwarded unchanged.
        response = requests.post(callback_url, json=label_result, timeout=_CALLBACK_TIMEOUT_SECONDS)
        response.raise_for_status()
    except Exception:
        # Thread boundary: nothing above us can report the error, so log it.
        log.exception("auto_label background task failed, call_back_url: " + str(callback_url))


@app.post("/auto_label")
async def auto_label(request: Request) -> Response:
    """Accept a batch of topics for automatic labeling and process it asynchronously.

    The request is validated synchronously. Labeling runs on a background
    worker, whose outcome is POSTed to ``call_back_url``.

    Request body:
        {
            "subject_id": xx,
            "topic_list": [                 # at most 10 topics
                {
                    "topic_id": 23,         # topic id
                    "topic_text": "...",    # question stem
                    "parse": "...",         # explanation / worked solution
                    "option": []
                }
            ],
            "call_back_url": "callback url"
        }

    Callback payload (produced by labeling.auto_label; shape as documented
    here, verify against that module):
        {
            "topic_list": [
                {
                    "topic_id": 234,                      # topic id
                    "labels": ["Newton's second law",
                               "Newton's first law"],     # knowledge points
                    "knowsledge_state": -1,               # knowledge labeling: 1 success, -1 failure
                    "difficulty_state": 1,                # difficulty labeling: 1 success, -1 failure
                    "difficulty": 2                       # difficulty value
                }
            ]
        }

    Returns:
        ``{"err_code": 0, "msg": "success"}`` once the job is queued, otherwise
        an ``err_code`` / ``msg`` dict describing the validation failure.
    """
    try:
        request_dict = await request.json()
    except ValueError:
        # Malformed JSON or undecodable bytes (json.JSONDecodeError and
        # UnicodeDecodeError both subclass ValueError) -- previously an HTTP 500.
        return {"err_code": LabelExceptionErrCode.PARAM_NOT_NULL.value, "msg": "入参格式不合法"}
    log.info("request: " + str(request_dict))

    # 1. Validate parameters.
    error = _validate_request(request_dict)
    if error is not None:
        return error

    # 2. Hand labeling off to the background worker; respond immediately.
    executor.submit(_label_and_callback, request_dict)
    return JSONResponse({"err_code": 0, "msg": "success"})
if __name__ == "__main__":
    # Pass the app object rather than the "server:app" import string. With an
    # import string, uvicorn re-imports this script as module "server". That
    # re-runs module-level setup (load_dotenv, a second FastAPI app and a second
    # ThreadPoolExecutor) and fails outright if the file is not named server.py.
    # An app object is allowed here because workers == 1 and reload is off.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8840,
        log_level="debug",
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        workers=1,
    )
|