rabin há 1 ano atrás
pai
commit
5fa62657c3
1 ficheiros alterados com 6 adições e 3 exclusões
  1. 6 3
      api.py

+ 6 - 3
api.py

@@ -1,5 +1,5 @@
 from fastapi import FastAPI, Request
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
 import uvicorn, json, datetime
 import torch
 
@@ -50,8 +50,11 @@ async def create_item(request: Request):
 
 
 if __name__ == '__main__':
-    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(4).half().cuda()
+    model_file = "chatglm-6b"
+    tokenizer = AutoTokenizer.from_pretrained(model_file, trust_remote_code=True)
+    #quantization_config= BitsAndBytesConfig(load_in_8bit=True)
+
+    model = AutoModel.from_pretrained(model_file, trust_remote_code=True,max_memory=torch.cuda.get_device_properties(0).total_memory).quantize(4).half().cuda()
     #model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()
     model.eval()
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)