Commit
Merge pull request #85 from weaviate/trust-remote-code-env-var
Add `TRUST_REMOTE_CODE` env var
cdpierse authored Jun 21, 2024
2 parents 1453612 + 7e237ca commit 9b84e55
Showing 5 changed files with 36 additions and 26 deletions.
3 changes: 3 additions & 0 deletions config.py
@@ -0,0 +1,3 @@
import os

TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", False)
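Note that `os.getenv` returns the raw string whenever the variable is set, so with the one-liner above any non-empty value, including "false" or "0", is truthy by the time it reaches `trust_remote_code`. If strict boolean parsing were wanted, a minimal sketch could look like this (the `_env_flag` helper is illustrative, not part of this commit):

```python
import os

def _env_flag(name: str, default: bool = False) -> bool:
    """Parse an environment variable as a boolean flag."""
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

TRUST_REMOTE_CODE = _env_flag("TRUST_REMOTE_CODE")
```

Either way, the value is read once at import time, so `TRUST_REMOTE_CODE` must be present in the environment before `config` is first imported.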
19 changes: 13 additions & 6 deletions meta.py
@@ -1,22 +1,29 @@
from transformers import AutoConfig

from config import TRUST_REMOTE_CODE


class Meta:
config: AutoConfig

def __init__(self, model_path: str, model_name: str, use_sentence_transformer_vectorizer: bool):
def __init__(
self,
model_path: str,
model_name: str,
use_sentence_transformer_vectorizer: bool,
):
if use_sentence_transformer_vectorizer:
self.config = {"model_name": model_name, "model_type": None}
else:
self.config = AutoConfig.from_pretrained(model_path).to_dict()
self.config = AutoConfig.from_pretrained(
model_path, trust_remote_code=TRUST_REMOTE_CODE
).to_dict()

def get(self):
return {
'model': self.config
}
return {"model": self.config}

def get_model_type(self):
return self.config['model_type']
return self.config["model_type"]

def get_architecture(self):
architecture = None
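For context, a minimal usage sketch of the `Meta` class above (the model path and name are illustrative):

```python
# Hypothetical usage; with the sentence-transformer path, no model files are read.
meta = Meta(
    model_path="./models/model",
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    use_sentence_transformer_vectorizer=True,
)
print(meta.get())             # {'model': {'model_name': '...', 'model_type': None}}
print(meta.get_model_type())  # None
```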
1 change: 1 addition & 0 deletions requirements-test.txt
@@ -9,4 +9,5 @@ sentence-transformers==2.6.1
optimum==1.17.1
onnxruntime==1.17.1
onnx==1.15.0
numpy==1.26.4
pytest
1 change: 1 addition & 0 deletions requirements.txt
@@ -8,3 +8,4 @@ sentence-transformers==2.6.1
optimum==1.17.1
onnxruntime==1.17.1
onnx==1.15.0
numpy==1.26.4
38 changes: 18 additions & 20 deletions vectorizer.py
@@ -1,24 +1,21 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
import math
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Optional

import nltk
import torch
import torch.nn.functional as F
from pathlib import Path
import nltk
from nltk.tokenize import sent_tokenize
from optimum.onnxruntime import ORTModelForFeatureExtraction
from pydantic import BaseModel
from transformers import (
AutoModel,
AutoTokenizer,
T5ForConditionalGeneration,
T5Tokenizer,
DPRContextEncoder,
DPRQuestionEncoder,
)
from sentence_transformers import SentenceTransformer
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import (AutoModel, AutoTokenizer, DPRContextEncoder,
DPRQuestionEncoder, T5ForConditionalGeneration,
T5Tokenizer)

from config import TRUST_REMOTE_CODE

# limit transformer batch size to limit parallel inference, otherwise we run
# into memory problems
@@ -78,8 +75,9 @@ class ONNXVectorizer:

def __init__(self, model_path) -> None:
onnx_path = Path(model_path)
self.model = ORTModelForFeatureExtraction.from_pretrained(onnx_path, file_name="model_quantized.onnx")
self.tokenizer = AutoTokenizer.from_pretrained(onnx_path)
self.model = ORTModelForFeatureExtraction.from_pretrained(onnx_path, file_name="model_quantized.onnx",
trust_remote_code=TRUST_REMOTE_CODE)
self.tokenizer = AutoTokenizer.from_pretrained(onnx_path, trust_remote_code=TRUST_REMOTE_CODE)

def mean_pooling(self, model_output, attention_mask):
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
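The body of `mean_pooling` is collapsed in the diff. For reference, the standard sentence-transformers-style formulation of this technique looks like the sketch below (reconstructed from the visible first line, not the literal collapsed code):

```python
import torch

def mean_pooling(model_output, attention_mask):
    # First element of model_output contains all token embeddings.
    token_embeddings = model_output[0]
    # Expand the attention mask so padding tokens are zeroed out.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Average the real (unmasked) token embeddings; clamp avoids division by zero.
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
```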
@@ -179,11 +177,11 @@ def __init__(self, cuda_support: bool, cuda_core: str):
self.cuda_core = cuda_core

def create_tokenizer(self, model_path):
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
return self.tokenizer

def create_model(self, model_path):
self.model = AutoModel.from_pretrained(model_path)
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
return self.model

def get_embeddings(self, batch_results):
@@ -236,9 +234,9 @@ def __init__(self, architecture: str, cuda_support: bool, cuda_core: str):

def create_model(self, model_path):
if self.architecture == "DPRQuestionEncoder":
self.model = DPRQuestionEncoder.from_pretrained(model_path)
self.model = DPRQuestionEncoder.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
else:
self.model = DPRContextEncoder.from_pretrained(model_path)
self.model = DPRContextEncoder.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
return self.model

def get_batch_results(self, tokens, text):
@@ -259,11 +257,11 @@ def __init__(self, cuda_support: bool, cuda_core: str):
self.cuda_core = cuda_core

def create_model(self, model_path):
self.model = T5ForConditionalGeneration.from_pretrained(model_path)
self.model = T5ForConditionalGeneration.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
return self.model

def create_tokenizer(self, model_path):
self.tokenizer = T5Tokenizer.from_pretrained(model_path)
self.tokenizer = T5Tokenizer.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
return self.tokenizer

def get_embeddings(self, batch_results):
