
Is split offloading supported for 8bit mode? #193

Closed
RandomInternetPreson opened this issue Mar 8, 2023 · 9 comments

@RandomInternetPreson
Contributor

It would be really great to run the LLaMA 30B model in 8-bit mode, but right now I can't get the memory to split between GPU and CPU when 8-bit mode is enabled.

I feel like if this were possible it would be revolutionary!

@sgsdxzy
Contributor

sgsdxzy commented Mar 8, 2023

It is possible. I got LLaMA 13B to work on my 3080 Ti with both RAM offload and 8-bit, and it's 3 times faster than staying in fp16, as more modules can be loaded into VRAM. But this requires some hacks:

  1. Instead of passing load_in_8bit=True directly, construct a BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True) and pass it as quantization_config to from_pretrained (a sketch follows this list). This part is relatively easy.
  2. PretrainedModel.from_pretrained does not work with load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True and device_map='auto' together; huggingface/transformers#22018, which would have added support for offloading in auto device mapping, was rejected in transformers. So we have to build our device map manually first. I'd rather monkey-patch from_pretrained in transformers to just accept auto mapping: just move the part where device_map='auto' is populated before the int8 loading part.
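
For reference, here is a minimal sketch of step 1. It is not a verbatim snippet from this thread: the checkpoint path and the device map below are placeholders, and the module names follow the Hugging Face LLaMA port, so adjust them to whatever your checkpoint actually reports (Digitous posts a full map further down).

    # Minimal sketch of step 1 (placeholder path and device map, not from this thread).
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_enable_fp32_cpu_offload=True,  # modules mapped to "cpu" stay in fp32
    )

    # Explicit GPU/CPU split for a 40-layer 13B model: first 30 decoder layers on
    # GPU 0, the remaining 10 on the CPU.
    device_map = {
        "model.embed_tokens": 0,
        "model.norm": 0,
        "lm_head": 0,
        **{f"model.layers.{i}": 0 for i in range(30)},
        **{f"model.layers.{i}": "cpu" for i in range(30, 40)},
    }

    model = AutoModelForCausalLM.from_pretrained(
        "models/llama-13b",         # placeholder path
        device_map=device_map,      # device_map="auto" only works with the patch from step 2
        quantization_config=quantization_config,
    )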

@oobabooga
Owner

That's new to me. @sgsdxzy, can you copy and paste the exact modifications that you had to make to get this working for clarity?

@oobabooga
Owner

Related to #190

@Digitous

Digitous commented Mar 8, 2023

> That's new to me. @sgsdxzy, can you copy and paste the exact modifications that you had to make to get this working for clarity?

I did a lot of hacky work on modeling_utils.py, located under the following path (or wherever your textgen venv's transformers is installed):
C:\ProgramData\Anaconda3\envs\textgen\lib\site-packages\transformers

I initially tried inserting
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
near the top, then below quantization_config = kwargs.pop("quantization_config", None) I added
AutoModelForCausalLM.from_pretrained(path, device_map='auto', quantization_config=quantization_config)
as initially suggested, but I kept getting errors about "path" not being defined.

--So what worked for me--
Directly under the quantization_config kwargs, I surgically inserted a manual device map, hard-coded for a 3090 with 24GB VRAM, followed by the quantization_config changes sgsdxzy suggested above plus pamparamm's suggested llm_int8_threshold=0 addition, which seems to speed up int8 inference (although I had to make it 0.0, as it requires a float value).
After the changes below, I was able to load LLaMA 30B split across GPU/CPU with faster inference using
python server.py --load-in-8bit --auto-devices --model llama-30b --notebook --listen
Also to note, I'm a Windows user and am using the cudaall.dll bitsandbytes hack as well. The ideal would be getting the transformers auto map to work, since this map is custom-written just for 30B (and I'm still fiddling with how many layers I can push to the GPU before OOM).

    offload_state_dict = kwargs.pop("offload_state_dict", False)
    load_in_8bit = kwargs.pop("load_in_8bit", False)
    quantization_config = kwargs.pop("quantization_config", None)
    # new stuff below: hard-coded device map for 30B on a 24GB 3090
    # (decoder layers 0-37 on GPU 0, layers 38-61 offloaded to the CPU)
    device_map = {
        "model.decoder.embed_tokens": 0,
        "model.decoder.embed_positions": 0,
        "model.decoder.norm.weight": 0,
        "model.decoder.final_layer_norm": 0,
        "lm_head": 0,
        **{f"model.decoder.layers.{i}": 0 for i in range(38)},
        **{f"model.decoder.layers.{i}": "cpu" for i in range(38, 62)},
    }
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True, llm_int8_threshold=0.0, llm_int8_enable_fp32_cpu_offload=True
    )

@ye7iaserag
Contributor

If this solution can be integrated here... the possibilities!

@sgsdxzy
Contributor

sgsdxzy commented Mar 9, 2023

@oobabooga I was originally preparing a PR for you... now that the transformers part was rejected, things get a bit difficult.
The fastest way to get this to work:

  1. Edit text-generation-webui/modules/models.py line 102: change params.append("load_in_8bit=True"... to params.append("quantization_config=BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)") (or you can just use params.extend(["load_in_8bit=True", "llm_int8_enable_fp32_cpu_offload=True"])). A sketch of this edit follows these steps.
  2. Edit Python's site-packages\transformers\modeling_utils.py: in from_pretrained, search for "# Extend the modules to not convert to keys that are supposed to be offloaded to cpu or disk" and copy the following part:
    if isinstance(device_map, str):
        if model._no_split_modules is None:
            raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
        no_split_modules = model._no_split_modules
        if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
            raise ValueError(
                "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
                "'sequential'."
            )
        elif device_map in ["balanced", "balanced_low_0"] and get_balanced_memory is None:
            raise ValueError(f"`device_map={device_map}` requires a source install of Accelerate.")
        if device_map != "sequential" and get_balanced_memory is not None:
            max_memory = get_balanced_memory(
                model,
                max_memory=max_memory,
                no_split_module_classes=no_split_modules,
                dtype=torch_dtype,
                low_zero=(device_map == "balanced_low_0"),
            )
        # Make sure tied weights are tied before creating the device map.
        model.tie_weights()
        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=no_split_modules,
            dtype=torch_dtype if not load_in_8bit else torch.int8,
            max_memory=max_memory,
        )

(this is the part that converts the device map from 'auto' to a dict) and paste it before "# Extend the modules to not convert to keys that are supposed to be offloaded to cpu or disk".

And pass --load-in-8bit --auto-devices to server.py.
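
For reference, a rough sketch of the models.py edit in step 1. The surrounding context here is assumed rather than quoted from the webui source; params stands for the list of keyword-argument strings that models.py later splices into the from_pretrained call.

    # Hypothetical, simplified stand-in for the 8-bit branch of modules/models.py.
    params = []  # in the real file this list already exists
    # Step 1's replacement for params.append("load_in_8bit=True"...):
    params.append(
        "quantization_config=BitsAndBytesConfig("
        "load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)"
    )
    # Note: BitsAndBytesConfig must be in scope wherever these strings are finally
    # evaluated, otherwise a NameError is raised (see athu16's note below).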

@RandomInternetPreson
Contributor Author

Oh my frick... thank you so much, kind internet stranger. I was able to use your code and can now bifurcate 8-bit between CPU and GPU. Wow, absolutely incredible, thank you so much!

@athu16

athu16 commented Mar 12, 2023

Thanks a lot for this!
A note for anyone who gets NameError: name 'BitsAndBytesConfig' is not defined: use the second method, i.e. add params.extend(["load_in_8bit=True", "llm_int8_enable_fp32_cpu_offload=True"]) below the pre-existing code (instead of removing it).
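
For reference, a sketch of that second method inside modules/models.py (hypothetical context; only the extend line itself comes from this thread):

    # Keep the pre-existing 8-bit handling as-is, then add this line right below it:
    params.extend(["load_in_8bit=True", "llm_int8_enable_fp32_cpu_offload=True"])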

@oobabooga
Owner

Fixed in #358. Just use --load-in-8bit --gpu-memory 10 or similar.
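
For example, using the same model name as earlier in the thread (adjust --gpu-memory to whatever fits on your card):

    python server.py --model llama-30b --load-in-8bit --gpu-memory 10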
