Merge pull request #91 from johnflavin/trust-remote-download
Use TRUST_REMOTE_CODE env from config when downloading custom models
antas-marcin authored Oct 4, 2024
2 parents 78bcdd8 + dc457c3 commit 3bcbb95
Showing 1 changed file with 8 additions and 7 deletions.
download.py (15 changes: 8 additions & 7 deletions)

@@ -3,6 +3,7 @@
 import os
 import sys
 import nltk
+from config import TRUST_REMOTE_CODE
 from transformers import (
     AutoModel,
     AutoTokenizer,
@@ -82,9 +83,9 @@ def quantization_config(onnx_cpu_arch: str):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.save_pretrained(onnx_path)
 
-def download_model(model_name: str, model_dir: str):
-    print(f"Downloading model {model_name} from huggingface model hub")
-    config = AutoConfig.from_pretrained(model_name)
+def download_model(model_name: str, model_dir: str, trust_remote_code: bool = False):
+    print(f"Downloading model {model_name} from huggingface model hub ({trust_remote_code=})")
+    config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
     model_type = config.to_dict()['model_type']
 
     if (model_type is not None and model_type == "t5") or use_sentence_transformers_vectorizer.lower() == "true":
@@ -100,11 +101,11 @@ def download_model(model_name: str, model_dir: str):
             model = klass_architecture.from_pretrained(model_name)
         except AttributeError:
             print(f"{config.architectures[0]} not found in transformers, fallback to AutoModel")
-            model = AutoModel.from_pretrained(model_name)
+            model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code)
     else:
-        model = AutoModel.from_pretrained(model_name)
+        model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code)
 
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
 
     model.save_pretrained(model_dir)
     tokenizer.save_pretrained(model_dir)
@@ -114,4 +115,4 @@ def download_model(model_name: str, model_dir: str):
 if onnx_runtime == "true":
     download_onnx_model(model_name, model_dir)
 else:
-    download_model(model_name, model_dir)
+    download_model(model_name, model_dir, TRUST_REMOTE_CODE)
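
For context, config.py itself is not part of this diff. A minimal sketch of how it could expose TRUST_REMOTE_CODE from the environment follows; the variable name, parsing, and default here are assumptions for illustration, not code from this commit:

import os

# Hypothetical config.py snippet (not from this commit): read TRUST_REMOTE_CODE
# from the environment, treat common truthy strings as True, default to False.
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").strip().lower() in ("1", "true", "yes")

With a setting along those lines, running download.py with TRUST_REMOTE_CODE=true in the environment lets transformers load the custom modeling code that some Hugging Face Hub models ship with, while leaving it unset keeps the previous behavior (trust_remote_code=False).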
