Pr139 dev oct24 #29

Merged
Merged 12 commits on Oct 20, 2024
140 changes: 50 additions & 90 deletions exo/inference/torch/model/hf.py
@@ -9,6 +9,7 @@
from exo.inference.shard import Shard
from exo.helpers import DEBUG
from exo.inference.torch.utils import extract_layers
from exo.inference.torch.model.hf_safe_tensor_shard import HFSafeTensorShard

from transformers import (
AutoModelForCausalLM,
@@ -52,124 +53,86 @@ def __init__(

# class vars
self.shard = shard
self.hidden_states = None
self.input_ids = None
self.inputs_embeds = None
self.attention_mask = None
self.position_embeddings = None
self.past_key_values = None
self.cache_position = None
self.position_ids = None
self.causal_mask = None
self.local_model_path = local_model_path
self.is_sharded_model = False

self.weight_map = weight_map
self.device = device
self.dtype = dtype
self.device_map = device_map
self.offload_buffers = offload_buffers
self.model_safetensors_path = self.local_model_path/"model.safetensors.index.json"
self.safetensor_sharder = HFSafeTensorShard(
self.local_model_path,
self.shard
)
# setup logit processors
self.logits_processor = LogitsProcessorList([
TopKLogitsWarper(top_k),
TemperatureLogitsWarper(temp),
TopPLogitsWarper(top_p)
])

self.device = device
self.dtype = dtype
self.device_map = device_map

self.offload_buffers = offload_buffers

self.model_safetensors_path = self.local_model_path/"model.safetensors.index.json"

# setup pytorch and transformer llm
# setup sharded llm
try:
if weight_map:
print("loading shard model")
self.llm_model = self.load_sharded_model(
shard,
weight_map,
offload_buffers=self.offload_buffers
)

self.is_sharded_model = True

# clear out the edited safetensors JSON; this is needed because the
# shard downloader appends to the file rather than redownloading it
os.remove(self.model_safetensors_path)

self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device)
self.model = self.llm_model.model.to(self.device)
else:
print("loading full model")
self.llm_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.local_model_path,
torch_dtype=self.dtype,
device_map=self.device_map,
offload_buffers=offload_buffers
).to(self.device)

self.llm_model = self.load_sharded_model()
self.model = self.llm_model.model.to(self.device)

# restore originals for the next run, if any
self.safetensor_sharder.restore_backups()
except Exception as err:
print(f"error loading and splitting model: {err}")
print(f"error loading and sharding model: {err}")
raise

def load_sharded_model(
self,
shard: Shard,
weight_map: dict,
offload_buffers: bool
) -> AutoModelForCausalLM:
# forward variables
self.hidden_states = None
self.input_ids = None
self.inputs_embeds = None
self.attention_mask = None
self.position_embeddings = None
self.past_key_values = None
self.cache_position = None
self.position_ids = None
self.causal_mask = None

def load_sharded_model(self) -> AutoModelForCausalLM:
"""
Loads sharded version of model where only needed
weights are loaded for necessary layers

Args:

Returns:
llm_model (AutoModelForCausalLM) - sharded llm model with only needed layers loaded
"""
if DEBUG >= 4:
print("load_sharded_model called")
print(f"shard: {shard}")

# break out layers per shard range
layer_weight_map = extract_layers(
weight_map,
shard
)

# rewrite model.safetensors.index.json for only needed layers
try:
mst_json = {}
with open(self.model_safetensors_path, "r") as mst_file:
mst_json = json.load(mst_file)
mst_json["weight_map"] = layer_weight_map

if DEBUG >= 4:
print(f"rewritten safetensor index \n{json.dumps(mst_json, indent=4)}")

os.remove(self.model_safetensors_path)

with open(self.model_safetensors_path, "w") as mst_file:
json.dump(mst_json, mst_file, indent=4)
except Exception as err:
print(f"err: {err}")
raise
# modify safetensors and rebuild the index for only this shard's layers
self.safetensor_sharder.modify_safetensor()
self.safetensor_sharder.create_safetensor_index()
self.safetensor_sharder.shard_safetensor_index(self.weight_map)

# load model
try:
shard_num_hidden_layers = shard.end_layer - shard.start_layer
shard_num_hidden_layers = (self.shard.end_layer - self.shard.start_layer) + 1
if DEBUG >= 4:
print(f"config with {shard_num_hidden_layers} layers")

llm_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.local_model_path,
device_map=self.device_map,
torch_dtype=self.dtype,
offload_buffers=offload_buffers,
offload_buffers=self.offload_buffers,
local_files_only=True,
num_hidden_layers=shard_num_hidden_layers
num_hidden_layers=shard_num_hidden_layers,
use_safetensors=True,
low_cpu_mem_usage=True
)

return llm_model.to(self.device)
# restore backup for next run
self.safetensor_sharder.restore_backups()

if self.device_map == "auto":
return llm_model
else:
return llm_model.to(self.device)

except Exception as err:
print(f"err: {err}")
@@ -265,14 +228,11 @@ def forward(
self.cache_position = model_inputs["cache_position"]
self.past_key_values = model_inputs["past_key_values"]

if DEBUG >= 4:
print(f"model_inputs: {model_inputs}")
if DEBUG >= 4:
print(f"model_inputs: {model_inputs}")

# run through decoder layers
if self.is_sharded_model:
layer_amt = range(self.shard.end_layer - self.shard.start_layer)
else:
layer_amt = range(self.shard.start_layer, self.shard.end_layer)
layer_amt = range(self.shard.end_layer - self.shard.start_layer)

if DEBUG >= 4:
print(f"hidden_states: {self.hidden_states}")
Expand Down Expand Up @@ -317,7 +277,7 @@ def forward(
# note: shard.is_last_layer() reports true at the start and is not detecting the last layer correctly
if self.shard.is_last_layer():
self.hidden_states = self.model.norm(self.hidden_states)
if use_legacy_cache and self.next_decoder_cache is not None:
if use_legacy_cache:
self.past_key_values = self.next_decoder_cache.to_legacy_cache()
else:
self.past_key_values = self.next_decoder_cache
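
For reference, a rough sketch of the shard-local layer indexing the forward pass relies on: load_sharded_model builds the model with end_layer - start_layer + 1 hidden layers, so a sharded model's decoder layers are indexed from 0 locally rather than by their global positions. The helper below is illustrative only; the Shard fields and the decoder-layer call signature are assumptions based on this diff, not a copy of hf.py.

def run_shard_layers(model, hidden_states, shard, causal_mask=None, position_ids=None):
  """Run hidden_states through the decoder layers owned by this shard.

  Illustrative sketch: assumes a sharded HF model whose model.layers list
  holds only this shard's layers, indexed locally starting at 0.
  """
  num_local_layers = shard.end_layer - shard.start_layer + 1
  for local_idx in range(num_local_layers):
    decoder_layer = model.layers[local_idx]
    # each HF decoder layer returns a tuple; element 0 is the hidden states
    layer_outputs = decoder_layer(
      hidden_states,
      attention_mask=causal_mask,
      position_ids=position_ids,
    )
    hidden_states = layer_outputs[0]
  return hidden_states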