avoid deterministic algorithm warning in inference #285

Merged (2 commits, Oct 22, 2024)
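What this PR does: `auto_round/autoround.py` previously called `torch.use_deterministic_algorithms(True, warn_only=True)` at module import time, so the flag was flipped even when a user only imported the package for inference, producing deterministic-algorithm warnings. The PR moves that call into `auto_round/calib_dataset.py` and defers the import of that module until calibration actually runs. A minimal sketch of what the flag does; the restore call at the end is illustrative and not part of this PR:

```python
import torch

# Opt in to deterministic algorithms for reproducible calibration.
# warn_only=True downgrades the "no deterministic implementation"
# error on affected ops to a UserWarning instead of raising.
torch.use_deterministic_algorithms(True, warn_only=True)

# ... calibration work that benefits from reproducibility ...

# Illustrative only: revert to the default so inference can use faster,
# possibly nondeterministic kernels without emitting warnings.
torch.use_deterministic_algorithms(False)
```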
README.md (6 changes: 3 additions & 3 deletions)
@@ -272,8 +272,8 @@ release most of the models ourselves.
| meta-llama/Meta-Llama-3.1-8B-Instruct | [model-kaitchup-autogptq-int4*](https://huggingface.co/kaitchup/Meta-Llama-3.1-8B-Instruct-autoround-gptq-4bit-asym), [model-kaitchup-autogptq-sym-int4*](https://huggingface.co/kaitchup/Meta-Llama-3.1-8B-Instruct-autoround-gptq-4bit-sym), [recipe](https://huggingface.co/Intel/Meta-Llama-3.1-8B-Instruct-int4-inc) |
| meta-llama/Meta-Llama-3.1-8B | [model-kaitchup-autogptq-sym-int4*](https://huggingface.co/kaitchup/Meta-Llama-3.1-8B-autoround-gptq-4bit-sym) |
| Qwen/Qwen-VL | [accuracy](./examples/multimodal-modeling/Qwen-VL/README.md), [recipe](./examples/multimodal-modeling/Qwen-VL/run_autoround.sh)
-| Qwen/Qwen2-7B | [model-autoround-int4](https://huggingface.co/Intel/Qwen2-7B-int4-inc) |
-| Qwen/Qwen2-57B-A14B-Instruct | [model-autoround-int4](https://huggingface.co/Intel/Qwen2-57B-A14B-Instruct-int4-inc) |
+| Qwen/Qwen2-7B | [model-autoround-sym-int4](https://huggingface.co/Intel/Qwen2-7B-int4-inc), [model-autogptq-sym-int4](https://huggingface.co/Intel/Qwen2-7B-int4-inc) |
+| Qwen/Qwen2-57B-A14B-Instruct | [model-autoround-sym-int4](https://huggingface.co/Intel/Qwen2-57B-A14B-Instruct-int4-inc), [model-autogptq-sym-int4](https://huggingface.co/Intel/Qwen2-57B-A14B-Instruct-int4-inc) |
| 01-ai/Yi-1.5-9B | [model-LnL-AI-autogptq-int4*](https://huggingface.co/LnL-AI/Yi-1.5-9B-4bit-gptq-autoround) |
| 01-ai/Yi-1.5-9B-Chat | [model-LnL-AI-autogptq-int4*](https://huggingface.co/LnL-AI/Yi-1.5-9B-Chat-4bit-gptq-autoround) |
| Intel/neural-chat-7b-v3-3 | [model-autogptq-int4](https://huggingface.co/Intel/neural-chat-7b-v3-3-int4-inc) |
@@ -283,7 +283,7 @@ release most of the models ourselves.
| google/gemma-2b | [model-autogptq-int4](https://huggingface.co/Intel/gemma-2b-int4-inc) |
| tiiuae/falcon-7b | [model-autogptq-int4-G64](https://huggingface.co/Intel/falcon-7b-int4-inc) |
| sapienzanlp/modello-italia-9b | [model-fbaldassarri-autogptq-int4*](https://huggingface.co/fbaldassarri/modello-italia-9b-autoround-w4g128-cpu) |
-| microsoft/phi-2 | [model-autogptq-sym-int4](https://huggingface.co/Intel/phi-2-int4-inc) |
+| microsoft/phi-2 | [model-autoround-sym-int4](https://huggingface.co/Intel/phi-2-int4-inc), [model-autogptq-sym-int4](https://huggingface.co/Intel/phi-2-int4-inc) |
| microsoft/Phi-3.5-mini-instruct | [model-kaitchup-autogptq-sym-int4*](https://huggingface.co/kaitchup/Phi-3.5-Mini-instruct-AutoRound-4bit) |
| microsoft/Phi-3-vision-128k-instruct | [recipe](./examples/multimodal-modeling/Phi-3-vision/run_autoround.sh)
| mistralai/Mistral-7B-Instruct-v0.2 | [accuracy](./docs/Mistral-7B-Instruct-v0.2-acc.md), [recipe](./examples/language-modeling/scripts/Mistral-7B-Instruct-v0.2.sh), [example](./examples/language-modeling/) |
auto_round/autoround.py (7 changes: 3 additions & 4 deletions)
@@ -15,17 +15,14 @@
import os
import torch
import transformers

-torch.use_deterministic_algorithms(True, warn_only=True)

import copy
import time
from typing import Optional, Union

from transformers import set_seed
from torch import autocast
from tqdm import tqdm
-from .calib_dataset import get_dataloader

from .quantizer import WrapperMultiblock, wrapper_block, unwrapper_block, WrapperLinear, unwrapper_layer, \
WrapperTransformerConv1d
from .special_model_handler import (check_hidden_state_dim,
@@ -488,8 +485,10 @@ def calib(self, nsamples, bs):
nsamples (int): The number of samples to use for calibration.
bs (int): The batch size to use for calibration.
"""
+from .calib_dataset import get_dataloader
if isinstance(self.dataset, str):
dataset = self.dataset.replace(" ", "") ##remove all whitespaces

# slow here
self.dataloader = get_dataloader(
self.tokenizer,
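The import change above is the second half of the fix: `get_dataloader` is now imported inside `calib()` rather than at the top of `autoround.py`, so importing the package for inference never imports `calib_dataset.py` and never touches the deterministic flag. A sketch of the same lazy-import pattern; the class and attribute names beyond those shown in the diff are illustrative, not the repo's exact code:

```python
class AutoRoundSketch:
    """Illustrative stand-in for the AutoRound class."""

    def __init__(self, tokenizer, seqlen=2048):
        self.tokenizer = tokenizer
        self.seqlen = seqlen

    def calib(self, nsamples, bs):
        # Importing here, not at module top level, means the side effects
        # in calib_dataset (torch.use_deterministic_algorithms) run only
        # when calibration is requested, never on plain inference imports.
        from auto_round.calib_dataset import get_dataloader
        self.dataloader = get_dataloader(
            self.tokenizer, self.seqlen, bs=bs, nsamples=nsamples
        )
```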
auto_round/calib_dataset.py (38 changes: 20 additions & 18 deletions)
@@ -16,6 +16,8 @@
import random

import torch

+torch.use_deterministic_algorithms(True, warn_only=True)
from torch.utils.data import DataLoader

from .utils import is_local_path, logger
@@ -58,15 +60,15 @@ def default_tokenizer_function(examples, apply_template=apply_template):
if not apply_template:
example = tokenizer(examples["text"], truncation=True, max_length=seqlen)
else:
from jinja2 import Template  # pylint: disable=E0401
chat_template = tokenizer.chat_template if tokenizer.chat_template is not None \
else tokenizer.default_chat_template
template = Template(chat_template)
rendered_messages = []
for text in examples["text"]:
message = [{"role": "user", "content": text}]
rendered_message = template.render(messages=message, add_generation_prompt=True, \
bos_token=tokenizer.bos_token)
rendered_messages.append(rendered_message)
example = tokenizer(rendered_messages, truncation=True, max_length=seqlen)
return example
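For context, the tokenizer function above renders each calibration sample through the model's chat template with jinja2 before tokenizing. A standalone sketch of that rendering step; the template string is a stand-in, since real templates come from `tokenizer.chat_template` and are model specific:

```python
from jinja2 import Template

# Stand-in template; a real one comes from tokenizer.chat_template.
chat_template = (
    "{{ bos_token }}{% for m in messages %}"
    "<|{{ m['role'] }}|>{{ m['content'] }}{% endfor %}"
    "{% if add_generation_prompt %}<|assistant|>{% endif %}"
)

rendered = Template(chat_template).render(
    messages=[{"role": "user", "content": "What is AutoRound?"}],
    add_generation_prompt=True,
    bos_token="<s>",
)
print(rendered)  # <s><|user|>What is AutoRound?<|assistant|>
```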
@@ -103,11 +105,11 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split

@register_dataset("madao33/new-title-chinese")
def get_new_chinese_title_dataset(
tokenizer,
seqlen,
dataset_name="madao33/new-title-chinese",
split=None,
seed=42,
apply_template=False
):
"""Returns a dataloader for the specified dataset and split.
@@ -148,7 +150,7 @@ def default_tokenizer_function(examples, apply_template=apply_template):
for text in examples["text"]:
message = [{"role": "user", "content": text}]
rendered_message = template.render(messages=message, add_generation_prompt=True, \
bos_token=tokenizer.bos_token)
rendered_messages.append(rendered_message)
example = tokenizer(rendered_messages, truncation=True, max_length=seqlen)
return example
@@ -267,12 +269,12 @@ def load_local_data(data_path):


def get_dataloader(
tokenizer,
seqlen,
dataset_name="NeelNanda/pile-10k",
seed=42,
bs=8,
nsamples=512,
):
"""Generate a DataLoader for calibration using specified parameters.

@@ -293,6 +295,7 @@
"""

dataset_names = dataset_name.split(",")

def filter_func(example):
if isinstance(example["input_ids"], list):
example["input_ids"] = torch.tensor(example["input_ids"])
@@ -316,7 +319,7 @@ def concat_dataset_element(dataset):
input_id = input_id[1:]
os_cnt, have_bos = os_cnt + 1, True
if input_id[-1] == eos_token_id:
input_id = input_id[:-1]
os_cnt, have_eos = os_cnt + 1, True

if buffer_input_id.shape[-1] + input_id.shape[-1] + os_cnt > seqlen:
@@ -326,7 +329,7 @@ def concat_dataset_element(dataset):
input_id_to_append = [torch.tensor([bos_token_id])] + input_id_to_append
if have_eos:
input_id_to_append.append(torch.tensor([eos_token_id]))

concat_input_ids.append(torch.cat(input_id_to_append).to(torch.int64))
attention_mask_list.append(attention_mask)
buffer_input_id = input_id[idx_keep:]
@@ -405,7 +408,7 @@ def concat_dataset_element(dataset):
name = dataset_names[i].split(':')[0]
if name not in data_lens:
target_cnt = (nsamples - cnt) // (len(datasets) - len(data_lens)) if data_lens \
else (nsamples - cnt) // (len(datasets) - i)
target_cnt = min(target_cnt, len(datasets[i]))
cnt += target_cnt
else:
@@ -447,4 +450,3 @@ def collate_batch(batch):

calib_dataloader = DataLoader(dataset_final, batch_size=bs, shuffle=False, collate_fn=collate_batch)
return calib_dataloader
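End to end, building a calibration dataloader looks like the sketch below. The model name is only an example, and the defaults mirror the signature above (`dataset_name="NeelNanda/pile-10k"`, `seed=42`, `bs=8`, `nsamples=512`); after this PR, importing `calib_dataset` is what enables deterministic algorithms:

```python
from transformers import AutoTokenizer

# Importing this module now runs
# torch.use_deterministic_algorithms(True, warn_only=True).
from auto_round.calib_dataset import get_dataloader

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # example model

# 512 samples from NeelNanda/pile-10k, packed/truncated to 2048 tokens.
dataloader = get_dataloader(tokenizer, seqlen=2048)

batch = next(iter(dataloader))
# Assuming batches are dicts with an "input_ids" field, as collate_batch
# in this file produces: expected shape roughly [bs, seqlen].
print(batch["input_ids"].shape)
```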