Merge pull request #89 from weaviate/fix/ntlk-vulnerability
Bump nltk dep to 3.9.1
antas-marcin authored Oct 4, 2024
2 parents 78bcdd8 + fd102cd commit 244168e
Showing 3 changed files with 34 additions and 18 deletions.
48 changes: 32 additions & 16 deletions download.py
@@ -15,36 +15,43 @@
 from pathlib import Path


-model_dir = './models/model'
-nltk_dir = './nltk_data'
-model_name = os.getenv('MODEL_NAME', None)
-force_automodel = os.getenv('FORCE_AUTOMODEL', False)
+model_dir = "./models/model"
+nltk_dir = "./nltk_data"
+model_name = os.getenv("MODEL_NAME", None)
+force_automodel = os.getenv("FORCE_AUTOMODEL", False)
 if not model_name:
     print("Fatal: MODEL_NAME is required")
-    print("Please set environment variable MODEL_NAME to a HuggingFace model name, see https://huggingface.co/models")
+    print(
+        "Please set environment variable MODEL_NAME to a HuggingFace model name, see https://huggingface.co/models"
+    )
     sys.exit(1)

 if force_automodel:
     print(f"Using AutoModel for {model_name} to instantiate model")

-onnx_runtime = os.getenv('ONNX_RUNTIME')
+onnx_runtime = os.getenv("ONNX_RUNTIME")
 if not onnx_runtime:
     onnx_runtime = "false"

-onnx_cpu_arch = os.getenv('ONNX_CPU')
+onnx_cpu_arch = os.getenv("ONNX_CPU")
 if not onnx_cpu_arch:
     onnx_cpu_arch = "arm64"

-use_sentence_transformers_vectorizer = os.getenv('USE_SENTENCE_TRANSFORMERS_VECTORIZER')
+use_sentence_transformers_vectorizer = os.getenv("USE_SENTENCE_TRANSFORMERS_VECTORIZER")
 if not use_sentence_transformers_vectorizer:
     use_sentence_transformers_vectorizer = "false"

-print(f"Downloading MODEL_NAME={model_name} with FORCE_AUTOMODEL={force_automodel} ONNX_RUNTIME={onnx_runtime} ONNX_CPU={onnx_cpu_arch}")
+print(
+    f"Downloading MODEL_NAME={model_name} with FORCE_AUTOMODEL={force_automodel} ONNX_RUNTIME={onnx_runtime} ONNX_CPU={onnx_cpu_arch}"
+)


 def download_onnx_model(model_name: str, model_dir: str):
     # Download model and tokenizer
     onnx_path = Path(model_dir)
-    ort_model = ORTModelForFeatureExtraction.from_pretrained(model_name, from_transformers=True)
+    ort_model = ORTModelForFeatureExtraction.from_pretrained(
+        model_name, from_transformers=True
+    )
     # Save model
     ort_model.save_pretrained(onnx_path)

@@ -59,7 +66,9 @@ def quantization_config(onnx_cpu_arch: str):
     if onnx_cpu_arch.lower() == "avx512_vnni":
         print("Quantize Model for x86_64 (amd64) (avx512_vnni)")
         save_quantization_info("AVX-512")
-        return AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
+        return AutoQuantizationConfig.avx512_vnni(
+            is_static=False, per_channel=False
+        )
     if onnx_cpu_arch.lower() == "arm64":
         print(f"Quantize Model for ARM64")
         save_quantization_info("ARM64")
@@ -82,24 +91,29 @@ def quantization_config(onnx_cpu_arch: str):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.save_pretrained(onnx_path)

+
 def download_model(model_name: str, model_dir: str):
     print(f"Downloading model {model_name} from huggingface model hub")
     config = AutoConfig.from_pretrained(model_name)
-    model_type = config.to_dict()['model_type']
+    model_type = config.to_dict()["model_type"]

-    if (model_type is not None and model_type == "t5") or use_sentence_transformers_vectorizer.lower() == "true":
+    if (
+        model_type is not None and model_type == "t5"
+    ) or use_sentence_transformers_vectorizer.lower() == "true":
         SentenceTransformer(model_name, cache_folder=model_dir)
         with open(f"{model_dir}/model_name", "w") as f:
             f.write(model_name)
     else:
         if config.architectures and not force_automodel:
             print(f"Using class {config.architectures[0]} to load model weights")
-            mod = __import__('transformers', fromlist=[config.architectures[0]])
+            mod = __import__("transformers", fromlist=[config.architectures[0]])
             try:
                 klass_architecture = getattr(mod, config.architectures[0])
                 model = klass_architecture.from_pretrained(model_name)
             except AttributeError:
-                print(f"{config.architectures[0]} not found in transformers, fallback to AutoModel")
+                print(
+                    f"{config.architectures[0]} not found in transformers, fallback to AutoModel"
+                )
                 model = AutoModel.from_pretrained(model_name)
         else:
             model = AutoModel.from_pretrained(model_name)
@@ -109,7 +123,9 @@ def download_model(model_name: str, model_dir: str):
     model.save_pretrained(model_dir)
     tokenizer.save_pretrained(model_dir)

-nltk.download('punkt', download_dir=nltk_dir)
+nltk.download("punkt", download_dir=nltk_dir)
+nltk.download("punkt_tab", download_dir=nltk_dir)


 if onnx_runtime == "true":
     download_onnx_model(model_name, model_dir)
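Note on the added download: nltk 3.9 replaced the pickled punkt models with the pickle-free punkt_tab package as part of its security fix, which is why the script now fetches both. A minimal sanity check that the downloaded data is usable, assuming the ./nltk_data directory produced by download.py above:

    import nltk
    from nltk.tokenize import sent_tokenize

    # Point nltk at the directory populated by download.py (assumed path).
    nltk.data.path.append("./nltk_data")

    # On nltk >= 3.9 this loads punkt_tab rather than the old pickled punkt.
    print(sent_tokenize("First sentence. Second sentence."))
    # Expected: ['First sentence.', 'Second sentence.']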
2 changes: 1 addition & 1 deletion requirements-test.txt
@@ -2,7 +2,7 @@ requests==2.32.3
 transformers==4.42.4
 fastapi==0.112.0
 uvicorn==0.30.5
-nltk==3.8.1
+nltk==3.9.1
 torch==2.4.0
 sentencepiece==0.2.0
 sentence-transformers==3.0.1
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,7 +1,7 @@
 transformers==4.42.4
 fastapi==0.112.0
 uvicorn==0.30.5
-nltk==3.8.1
+nltk==3.9.1
 torch==2.4.0
 sentencepiece==0.2.0
 sentence-transformers==3.0.1
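With both requirements files pinned to the patched release, a quick post-install assertion can catch an environment that resolved an older nltk. A sketch, not part of this PR; assumes Python 3.8+ for importlib.metadata:

    from importlib.metadata import version

    # Fail fast if the installed nltk predates the 3.9.1 security fix.
    installed = tuple(int(part) for part in version("nltk").split("."))
    assert installed >= (3, 9, 1), f"nltk {version('nltk')} is older than 3.9.1"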
