feat(transformers): support also text generation (#1630)
* feat(transformers): support also text generation

Signed-off-by: Ettore Di Giacinto <[email protected]>

* embedded: set seed -1

---------

Signed-off-by: Ettore Di Giacinto <[email protected]>
mudler authored Jan 23, 2024
1 parent d5d82ba commit 5e335ea
Showing 7 changed files with 51 additions and 8 deletions.
53 changes: 45 additions & 8 deletions backend/python/transformers/transformers_server.py
@@ -15,7 +15,7 @@
 
 import grpc
 import torch
-
+import torch.cuda
 from transformers import AutoTokenizer, AutoModel
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -70,14 +70,10 @@ def LoadModel(self, request, context):
         try:
             self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)  # trust_remote_code is needed to use the encode method with embeddings models like jinai-v2
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-            if request.CUDA:
+            if request.CUDA or torch.cuda.is_available():
                 try:
-                    # TODO: also tensorflow, make configurable
-                    import torch.cuda
-                    if torch.cuda.is_available():
-                        print("Loading model", model_name, "to CUDA.", file=sys.stderr)
-                        self.model = self.model.to("cuda")
+                    print("Loading model", model_name, "to CUDA.", file=sys.stderr)
+                    self.model = self.model.to("cuda")
                 except Exception as err:
                     print("Not using CUDA:", err, file=sys.stderr)
         except Exception as err:
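
The hunk above drops the per-request CUDA probe in favor of a module-level torch.cuda import and a single availability check; note that with request.CUDA or torch.cuda.is_available(), a visible GPU is used even when the request does not explicitly ask for CUDA. A rough standalone sketch of that placement pattern (the place_model helper and its name are illustrative, not part of the backend):

import sys

import torch
import torch.cuda  # harmless on CPU-only hosts; availability is only checked at call time


def place_model(model, want_cuda: bool):
    # Same fallback behavior as the diff above: try the GPU when requested or
    # detected, and quietly stay on the CPU if moving the model fails.
    if want_cuda or torch.cuda.is_available():
        try:
            return model.to("cuda")
        except Exception as err:
            print("Not using CUDA:", err, file=sys.stderr)
    return model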
@@ -113,6 +109,47 @@ def Embedding(self, request, context):
         print("Embeddings:", sentence_embeddings, file=sys.stderr)
         return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)
 
+    def Predict(self, request, context):
+        """
+        Generates text based on the given prompt and sampling parameters.
+
+        Args:
+            request: The predict request.
+            context: The gRPC context.
+
+        Returns:
+            backend_pb2.Reply: The predict result.
+        """
+        if request.TopP == 0:
+            request.TopP = 0.9
+
+        max_tokens = 200
+        if request.Tokens > 0:
+            max_tokens = request.Tokens
+
+        inputs = self.tokenizer.tokenizer(request.Prompt, return_tensors="pt").input_ids
+        outputs = self.model.generate(inputs, max_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
+
+        generated_text = self.tokenizer.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+        # Remove prompt from response if present
+        if request.Prompt in generated_text:
+            generated_text = generated_text.replace(request.Prompt, "")
+
+        return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
+
+    def PredictStream(self, request, context):
+        """
+        Generates text based on the given prompt and sampling parameters, and streams the results.
+
+        Args:
+            request: The predict stream request.
+            context: The gRPC context.
+
+        Returns:
+            backend_pb2.Result: The predict stream result.
+        """
+        yield self.Predict(request, context)
+
 
 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
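
For reference, here is a minimal standalone version of the tokenize, generate, decode flow that the new Predict method above implements. This is only a sketch under stated assumptions: the model name is a placeholder, the tokenizer object is called directly, a causal-LM head is loaded so that generate() is available, and the stock transformers keywords max_new_tokens and do_sample are used (temperature and top_p only take effect while sampling):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "sshleifer/tiny-gpt2"  # placeholder; any small causal LM works for a smoke test
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "Once upon a time"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    output_ids = model.generate(
        input_ids,
        max_new_tokens=200,  # cap on newly generated tokens, analogous to max_tokens above
        do_sample=True,      # sampling must be enabled for temperature/top_p to matter
        temperature=0.7,
        top_p=0.9,
    )

text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
# Causal LMs echo the prompt at the start of the output, so strip it before replying.
if text.startswith(prompt):
    text = text[len(prompt):]
print(text)

As added here, PredictStream yields the complete Predict reply as a single message rather than streaming token by token.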
1 change: 1 addition & 0 deletions embedded/models/dolphin-2.5-mixtral-8x7b.yaml
@@ -5,6 +5,7 @@ parameters:
   temperature: 0.2
   top_k: 40
   top_p: 0.95
+  seed: -1
 template:
   chat_message: |
     <|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "user"}}user{{end}}
1 change: 1 addition & 0 deletions embedded/models/llava.yaml
@@ -17,6 +17,7 @@ parameters:
   temperature: 0.2
   top_k: 40
   top_p: 0.95
+  seed: -1
 
 template:
   chat: |
1 change: 1 addition & 0 deletions embedded/models/mistral-openorca.yaml
@@ -5,6 +5,7 @@ parameters:
   temperature: 0.2
   top_k: 40
   top_p: 0.95
+  seed: -1
 template:
   chat_message: |
     <|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "user"}}user{{end}}
1 change: 1 addition & 0 deletions embedded/models/mixtral-instruct.yaml
@@ -4,6 +4,7 @@ parameters:
   model: huggingface://TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF/mixtral-8x7b-instruct-v0.1.Q2_K.gguf
   temperature: 0.2
   top_k: 40
+  seed: -1
   top_p: 0.95
 template:
   chat: &chat |
1 change: 1 addition & 0 deletions embedded/models/tinyllama-chat.yaml
@@ -4,6 +4,7 @@ parameters:
   model: huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q8_0.gguf
   temperature: 0.2
   top_k: 40
+  seed: -1
   top_p: 0.95
 template:
   chat_message: |
1 change: 1 addition & 0 deletions examples/configurations/phi-2.yaml
@@ -10,6 +10,7 @@ parameters:
   temperature: 0.2
   top_k: 40
   top_p: 0.95
+  seed: -1
 template:
   chat: &template |
     Instruct: {{.Input}}
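
Each of the model configs above gains a seed: -1 entry next to the existing sampling parameters; for llama.cpp-style backends, -1 conventionally means that a fresh random seed is picked for each generation. Put together, such a parameters block looks roughly like this (the name and model URI are placeholders, not files from this commit):

name: example-model
parameters:
  model: huggingface://some-org/some-model.Q4_K_M.gguf   # placeholder URI
  temperature: 0.2
  top_k: 40
  top_p: 0.95
  seed: -1   # -1: choose a random seed per request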
