|
@@ -1,5 +1,5 @@
|
|
|
from fastapi import FastAPI, Request
|
|
|
-from transformers import AutoTokenizer, AutoModel
|
|
|
+from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
|
|
|
import uvicorn, json, datetime
|
|
|
import torch
|
|
|
|
|
@@ -50,8 +50,11 @@ async def create_item(request: Request):
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
- tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
|
|
|
- model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(4).half().cuda()
|
|
|
+ model_file = "chatglm-6b"
|
|
|
+ tokenizer = AutoTokenizer.from_pretrained(model_file, trust_remote_code=True)
|
|
|
+ #quantization_config= BitsAndBytesConfig(load_in_8bit=True)
|
|
|
+
|
|
|
+ model = AutoModel.from_pretrained(model_file, trust_remote_code=True, max_memory={0: torch.cuda.get_device_properties(0).total_memory}).quantize(4).half().cuda()  # NOTE(review): max_memory must be a dict {device: capacity}; it is presumably only honored alongside a device_map — confirm
|
|
|
#model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()
|
|
|
model.eval()
|
|
|
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|