
Commit

Merge pull request #97 from weaviate/add-support-for-trust-remote-code-setting

Add support for trust remote code setting
antas-marcin authored Nov 14, 2024
2 parents 51e5ad0 + 51ae05f commit 1427253
Showing 9 changed files with 121 additions and 39 deletions.
1 change: 1 addition & 0 deletions .github/workflows/main.yaml
@@ -140,6 +140,7 @@ jobs:
MODEL_TAG_NAME: ${{matrix.model_tag_name}}
ONNX_RUNTIME: ${{matrix.onnx_runtime}}
USE_SENTENCE_TRANSFORMERS_VECTORIZER: ${{matrix.use_sentence_transformers_vectorizer}}
TRUST_REMOTE_CODE: ${{matrix.trust_remote_code}}
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
2 changes: 2 additions & 0 deletions Dockerfile
@@ -16,6 +16,8 @@ ARG TARGETARCH
ARG MODEL_NAME
ARG ONNX_RUNTIME
ENV ONNX_CPU=${TARGETARCH}
ARG TRUST_REMOTE_CODE
ARG USE_SENTENCE_TRANSFORMERS_VECTORIZER
RUN mkdir nltk_data
COPY download.py .
RUN ./download.py
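
Both new ARGs are declared before RUN ./download.py, so Docker exposes them to that step as build-time environment variables. A minimal sketch of how the download step could pick the flags up; the env_flag helper below is illustrative, not code from this commit:

import os

def env_flag(name: str, default: str = "false") -> bool:
    # Build args arrive as strings; treat only "true" (case-insensitive) as enabled.
    return os.getenv(name, default).strip().lower() == "true"

trust_remote_code = env_flag("TRUST_REMOTE_CODE")                     # from ARG TRUST_REMOTE_CODE
use_st_vectorizer = env_flag("USE_SENTENCE_TRANSFORMERS_VECTORIZER")  # from the second ARG
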
24 changes: 20 additions & 4 deletions app.py
@@ -1,6 +1,8 @@
import os
from logging import getLogger
from fastapi import FastAPI, Response, status
from typing import Union
from config import TRUST_REMOTE_CODE
from vectorizer import Vectorizer, VectorInput
from meta import Meta

@@ -55,7 +57,7 @@ def startup_event():

model_dir = "./models/model"

def get_model_directory() -> (str, bool):
def get_model_name() -> Union[str, bool]:
if os.path.exists(f"{model_dir}/model_name"):
with open(f"{model_dir}/model_name", "r") as f:
model_name = f.read()
@@ -70,6 +72,13 @@ def get_onnx_runtime() -> bool:
return onnx_runtime == "true"
return False

def get_trust_remote_code() -> bool:
if os.path.exists(f"{model_dir}/trust_remote_code"):
with open(f"{model_dir}/trust_remote_code", "r") as f:
trust_remote_code = f.read()
return trust_remote_code == "true"
return TRUST_REMOTE_CODE

def log_info_about_onnx(onnx_runtime: bool):
if onnx_runtime:
onnx_quantization_info = "missing"
@@ -80,11 +89,17 @@ def log_info_about_onnx(onnx_runtime: bool):
f"Running ONNX vectorizer with quantized model for {onnx_quantization_info}"
)

model_name, use_sentence_transformer_vectorizer = get_model_directory()
model_name, use_sentence_transformer_vectorizer = get_model_name()
onnx_runtime = get_onnx_runtime()
trust_remote_code = get_trust_remote_code()
log_info_about_onnx(onnx_runtime)

meta_config = Meta(model_dir, model_name, use_sentence_transformer_vectorizer)
meta_config = Meta(
model_dir,
model_name,
use_sentence_transformer_vectorizer,
trust_remote_code,
)
vec = Vectorizer(
model_dir,
cuda_support,
Expand All @@ -96,6 +111,7 @@ def log_info_about_onnx(onnx_runtime: bool):
onnx_runtime,
use_sentence_transformer_vectorizer,
model_name,
trust_remote_code,
)


@@ -112,7 +128,7 @@ def meta():

@app.post("/vectors")
@app.post("/vectors/")
async def read_item(item: VectorInput, response: Response):
async def vectorize(item: VectorInput, response: Response):
try:
vector = await vec.vectorize(item.text, item.config)
return {"text": item.text, "vector": vector.tolist(), "dim": len(vector)}
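
app.py and meta.py import TRUST_REMOTE_CODE from config, a module that is not part of this diff. A hedged sketch of what it might contain, assuming the default comes from an environment variable of the same name; the actual config.py in the repository may differ:

# config.py -- hypothetical sketch, not shown in this commit.
import os

# Fallback used by get_trust_remote_code() when the model directory has no
# trust_remote_code marker file written at download time.
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").strip().lower() == "true"

With such a default in place, get_trust_remote_code() above prefers the marker file baked into the image and only falls back to the environment-derived value.
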
4 changes: 4 additions & 0 deletions cicd/build.sh
@@ -5,8 +5,12 @@ set -eou pipefail
local_repo=${LOCAL_REPO?Variable LOCAL_REPO is required}
model_name=${MODEL_NAME?Variable MODEL_NAME is required}
onnx_runtime=${ONNX_RUNTIME?Variable ONNX_RUNTIME is required}
trust_remote_code=${TRUST_REMOTE_CODE:-false}
use_sentence_transformers_vectorizer=${USE_SENTENCE_TRANSFORMERS_VECTORIZER:-false}

docker build \
--build-arg "MODEL_NAME=$model_name" \
--build-arg "ONNX_RUNTIME=$onnx_runtime" \
--build-arg "TRUST_REMOTE_CODE=$trust_remote_code" \
--build-arg "USE_SENTENCE_TRANSFORMERS_VECTORIZER=$use_sentence_transformers_vectorizer" \
-t "$local_repo" .
5 changes: 5 additions & 0 deletions cicd/docker_push.sh
@@ -7,6 +7,8 @@ model_name=${MODEL_NAME?Variable MODEL_NAME is required}
docker_username=${DOCKER_USERNAME?Variable DOCKER_USERNAME is required}
docker_password=${DOCKER_PASSWORD?Variable DOCKER_PASSWORD is required}
onnx_runtime=${ONNX_RUNTIME?Variable ONNX_RUNTIME is required}
trust_remote_code=${TRUST_REMOTE_CODE:-false}
use_sentence_transformers_vectorizer=${USE_SENTENCE_TRANSFORMERS_VECTORIZER:-false}
original_model_name=$model_name
git_tag=$GITHUB_REF_NAME

@@ -16,6 +18,7 @@ function main() {
echo "git ref name is $GITHUB_REF_NAME"
echo "git tag is $git_tag"
echo "onnx_runtime is $onnx_runtime"
echo "trust_remote_code is $trust_remote_code"
push_tag
}

@@ -46,6 +49,8 @@ function push_tag() {
docker buildx build --platform=linux/arm64,linux/amd64 \
--build-arg "MODEL_NAME=$original_model_name" \
--build-arg "ONNX_RUNTIME=$onnx_runtime" \
--build-arg "TRUST_REMOTE_CODE=$trust_remote_code" \
--build-arg "USE_SENTENCE_TRANSFORMERS_VECTORIZER=$use_sentence_transformers_vectorizer" \
--push \
--tag "$tag_git" \
--tag "$tag_latest" \
23 changes: 20 additions & 3 deletions download.py
@@ -3,6 +3,7 @@
import os
import sys
import nltk
import json
from transformers import (
AutoModel,
AutoTokenizer,
@@ -98,6 +99,18 @@ def quantization_config(onnx_cpu_arch: str):


def download_model(model_name: str, model_dir: str, trust_remote_code: bool = False):
def save_model_name(model_name: str):
with open(f"{model_dir}/model_name", "w") as f:
f.write(model_name)

def save_trust_remote_code(trust_remote_code: bool):
with open(f"{model_dir}/trust_remote_code", "w") as f:
f.write(f"{trust_remote_code}")

def save_model_config(model_config):
with open(f"{model_dir}/model_config", "w") as f:
f.write(json.dumps(model_config))

print(
f"Downloading model {model_name} from huggingface model hub ({trust_remote_code=})"
)
@@ -107,9 +120,11 @@ def download_model(model_name: str, model_dir: str, trust_remote_code: bool = Fa
if (
model_type is not None and model_type == "t5"
) or use_sentence_transformers_vectorizer.lower() == "true":
SentenceTransformer(model_name, cache_folder=model_dir)
with open(f"{model_dir}/model_name", "w") as f:
f.write(model_name)
SentenceTransformer(
model_name, cache_folder=model_dir, trust_remote_code=trust_remote_code
)
save_model_name(model_name)
save_model_config(config.to_dict())
else:
if config.architectures and not force_automodel:
print(f"Using class {config.architectures[0]} to load model weights")
@@ -136,6 +151,8 @@ def download_model(model_name: str, model_dir: str, trust_remote_code: bool = Fa
model.save_pretrained(model_dir)
tokenizer.save_pretrained(model_dir)

save_trust_remote_code(trust_remote_code)

nltk.download("punkt", download_dir=nltk_dir)
nltk.download("punkt_tab", download_dir=nltk_dir)

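
Taken together, the new save_* helpers leave small marker files next to the downloaded weights, and the serving side reads them back at startup (get_model_name and get_trust_remote_code in app.py, model_config in meta.py). A rough read-back sketch, mirroring the paths used in this diff:

import json
import os

model_dir = "./models/model"

def read_marker(name: str, default: str = "") -> str:
    # Each marker is a small plain-text file written by download.py.
    path = f"{model_dir}/{name}"
    if os.path.exists(path):
        with open(path, "r") as f:
            return f.read()
    return default

model_name = read_marker("model_name")                        # from save_model_name
trust_marker = read_marker("trust_remote_code", "false")      # from save_trust_remote_code
model_config = json.loads(read_marker("model_config", "{}"))  # from save_model_config (JSON)
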
15 changes: 9 additions & 6 deletions meta.py
@@ -1,22 +1,25 @@
import json
import os
from transformers import AutoConfig

from config import TRUST_REMOTE_CODE


class Meta:
config: AutoConfig

def __init__(
self,
model_path: str,
model_name: str,
use_sentence_transformer_vectorizer: bool,
trust_remote_code: bool,
):
if use_sentence_transformer_vectorizer:
self.config = {"model_name": model_name, "model_type": None}
if os.path.exists(f"{model_path}/model_config"):
with open(f"{model_path}/model_config", "r") as f:
self.config = json.loads(f.read())
else:
self.config = {"model_name": model_name, "model_type": None}
else:
self.config = AutoConfig.from_pretrained(
model_path, trust_remote_code=TRUST_REMOTE_CODE
model_path, trust_remote_code=trust_remote_code
).to_dict()

def get(self):
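
A short usage sketch of the updated constructor; the model name and flag values below are illustrative, and /meta in app.py simply returns get():

# Illustrative values only.
meta_config = Meta(
    model_path="./models/model",
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    use_sentence_transformer_vectorizer=True,
    trust_remote_code=False,
)
# If ./models/model/model_config exists (written by download.py), its JSON is used;
# otherwise the config falls back to {"model_name": ..., "model_type": None}.
print(meta_config.get())
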
10 changes: 8 additions & 2 deletions smoke_test.py
@@ -37,9 +37,15 @@ def test_meta(self):
self.assertIsInstance(res.json(), dict)

def test_vectorizing(self):
def try_to_vectorize(url):
print(f"url: {url}")
def get_req_body(task_type: str = ""):
req_body = {"text": "The London Eye is a ferris wheel at the River Thames."}
if task_type != "":
req_body["config"] = {"task_type": task_type}
return req_body

def try_to_vectorize(url, task_type: str = ""):
print(f"url: {url}")
req_body = get_req_body(task_type)

res = requests.post(url, json=req_body)
resBody = res.json()
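
For reference, a hedged example of what the new helper sends over the wire; the payload shape follows get_req_body above, while the host, port, and task_type value are assumptions for illustration:

import requests

base_url = "http://localhost:8080"  # assumed address of the running inference container

# Without a task_type, the config key is omitted entirely.
requests.post(f"{base_url}/vectors", json={
    "text": "The London Eye is a ferris wheel at the River Thames.",
})

# With a task_type, the helper adds a config object (value shown is illustrative).
requests.post(f"{base_url}/vectors", json={
    "text": "The London Eye is a ferris wheel at the River Thames.",
    "config": {"task_type": "query"},
})
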