Skip to content

Commit

Permalink
Merge pull request #93 from weaviate/add-support-for-trust-remote-code
Browse files Browse the repository at this point in the history
Add support for trust remote code
  • Loading branch information
antas-marcin authored Oct 5, 2024
2 parents 244168e + d7d8312 commit ef1ad13
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions download.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# Directory where NLTK data is stored.
nltk_dir = "./nltk_data"
# Hugging Face model identifier; None when MODEL_NAME is not set.
model_name = os.getenv("MODEL_NAME", None)
# NOTE(review): os.getenv returns a string whenever the variable is set, so any
# non-empty value (including "false") is truthy here. Confirm how
# force_automodel is consumed before normalising it the same way as below.
force_automodel = os.getenv("FORCE_AUTOMODEL", False)
# Parse TRUST_REMOTE_CODE into a real bool. Passing the raw environment string
# through would make values such as "false" or "0" truthy and silently allow
# execution of code shipped with the model repository. Only explicit
# affirmative values enable it; unset, empty, or anything else means False.
trust_remote_code = os.getenv("TRUST_REMOTE_CODE", "").strip().lower() in (
    "1",
    "true",
    "t",
    "yes",
    "y",
    "on",
)
if not model_name:
print("Fatal: MODEL_NAME is required")
print(
Expand Down Expand Up @@ -46,11 +47,13 @@
)


def download_onnx_model(model_name: str, model_dir: str):
def download_onnx_model(
model_name: str, model_dir: str, trust_remote_code: bool = False
):
# Download model and tokenizer
onnx_path = Path(model_dir)
ort_model = ORTModelForFeatureExtraction.from_pretrained(
model_name, from_transformers=True
model_name, from_transformers=True, trust_remote_code=trust_remote_code
)
# Save model
ort_model.save_pretrained(onnx_path)
Expand Down Expand Up @@ -92,9 +95,11 @@ def quantization_config(onnx_cpu_arch: str):
tokenizer.save_pretrained(onnx_path)


def download_model(model_name: str, model_dir: str):
print(f"Downloading model {model_name} from huggingface model hub")
config = AutoConfig.from_pretrained(model_name)
def download_model(model_name: str, model_dir: str, trust_remote_code: bool = False):
print(
f"Downloading model {model_name} from huggingface model hub ({trust_remote_code=})"
)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
model_type = config.to_dict()["model_type"]

if (
Expand All @@ -114,11 +119,17 @@ def download_model(model_name: str, model_dir: str):
print(
f"{config.architectures[0]} not found in transformers, fallback to AutoModel"
)
model = AutoModel.from_pretrained(model_name)
model = AutoModel.from_pretrained(
model_name, trust_remote_code=trust_remote_code
)
else:
model = AutoModel.from_pretrained(model_name)
model = AutoModel.from_pretrained(
model_name, trust_remote_code=trust_remote_code
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(
model_name, trust_remote_code=trust_remote_code
)

model.save_pretrained(model_dir)
tokenizer.save_pretrained(model_dir)
Expand All @@ -128,6 +139,6 @@ def download_model(model_name: str, model_dir: str):


if onnx_runtime == "true":
download_onnx_model(model_name, model_dir)
download_onnx_model(model_name, model_dir, trust_remote_code)
else:
download_model(model_name, model_dir)
download_model(model_name, model_dir, trust_remote_code)

0 comments on commit ef1ad13

Please sign in to comment.