api.py

from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
import uvicorn, datetime
import torch

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE

def torch_gc():
    # Release cached GPU memory between requests.
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

app = FastAPI()
@app.post("/")
async def create_item(request: Request):
    global model, tokenizer
    # request.json() already returns a parsed dict; no dumps/loads round-trip needed.
    json_post = await request.json()
    prompt = json_post.get('prompt')
    history = json_post.get('history')
    max_length = json_post.get('max_length')
    top_p = json_post.get('top_p')
    temperature = json_post.get('temperature')
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length if max_length else 2048,
                                   top_p=top_p if top_p else 0.7,
                                   temperature=temperature if temperature else 0.95)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": time
    }
    log = "[" + time + "] " + 'prompt: "' + prompt + '", response: "' + repr(response) + '"'
    print(log)
    torch_gc()
    return answer
if __name__ == '__main__':
    #model_file = "chatglm-6b"
    #model_file = "THUDM/chatglm-6b-int4-qe"
    model_file = "chatglm-6b-int4-qe"
    tokenizer = AutoTokenizer.from_pretrained(model_file, trust_remote_code=True)
    #quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    # The int4-qe checkpoint ships pre-quantized, so no extra .quantize(4) call is
    # needed; max_memory was dropped because from_pretrained expects a
    # device->size dict (used with device_map), not a raw byte count.
    model = AutoModel.from_pretrained(model_file, trust_remote_code=True).half().cuda()
    #model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
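
For reference, a minimal client sketch (the prompts and host are illustrative; it assumes the server above is reachable at localhost:8000 and that the `requests` package is installed). The `history` list returned by each call is passed back in on the next request so the model sees the whole conversation:

import requests

# Hypothetical smoke test for the POST "/" endpoint above.
url = "http://localhost:8000/"
history = []
for prompt in ["Hello!", "What did I just say?"]:
    resp = requests.post(url, json={
        "prompt": prompt,
        "history": history,      # prior turns, as returned by the server
        "max_length": 2048,      # optional; the server falls back to these defaults
        "top_p": 0.7,
        "temperature": 0.95,
    })
    data = resp.json()
    history = data["history"]    # carry the updated history into the next turn
    print(data["response"])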