Tensor Parallelism Support for AffineQuantizedTensor #988

Open · 3 of 6 tasks
jerryzh168 opened this issue Oct 1, 2024 · 3 comments · Fixed by #1003
Labels: good first issue (Good for newcomers)

Comments

@jerryzh168 (Contributor) commented Oct 1, 2024

Recently we landed #939 to support tensor parallelism for int8 weight-only quantization (another example: #785).

Now we can support tensor parallelism for the other types of quantization as well.
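For context, here is a minimal sketch of what tensor parallelism over a quantized weight looks like. This is illustrative, not the exact test setup: it assumes a 2-GPU process group is already initialized (e.g. launched via torchrun), and the module size and sharding choices are placeholders.

```python
# Illustrative sketch (not the exact test setup): shard an int8 weight-only
# quantized Linear across 2 GPUs with DTensor. Assumes torch.distributed is
# already initialized with 2 ranks (e.g. launched via torchrun).
import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

from torchao.quantization import int8_weight_only, quantize_

model = nn.Sequential(nn.Linear(1024, 1024, device="cuda")).to(torch.bfloat16)
quantize_(model, int8_weight_only())  # weight becomes an AffineQuantizedTensor

mesh = DeviceMesh("cuda", list(range(2)))
# Column-wise (dim-0) sharding of the weight; this distribute_tensor call is
# exactly where missing ops (e.g. aten.slice) on the subclass will surface.
sharded_weight = distribute_tensor(model[0].weight, mesh, [Shard(0)])
model[0].weight = nn.Parameter(sharded_weight, requires_grad=False)
```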

Steps

1. Create test

Since we don't have many tests today, we can optimize for readability for now: copy-paste the test cases into https://github.com/pytorch/ao/blob/main/test/dtypes/test_affine_quantized_tensor_parallel.py instead of inheriting from them.

For new tests, you can follow this example to create your own test case:

```python
class TestFloat8dqTensorAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
    QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
    QUANT_METHOD_KWARGS = {"granularity": PerTensor()}
    COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]

    @common_utils.parametrize("dtype", COMMON_DTYPES)
    @with_comms
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_tp(self, dtype):
        return self._test_tp(dtype)


class TestFloat8dqRowAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
    QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
    QUANT_METHOD_KWARGS = {"granularity": PerRow()}
    COMMON_DTYPES = [torch.bfloat16]

    @common_utils.parametrize("dtype", COMMON_DTYPES)
    @with_comms
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_tp(self, dtype):
        return self._test_tp(dtype)
```

2. Run the test

```
python test/dtypes/test_affine_quantized_tensor_parallel.py
```

3. Add support for missing ops until the test passes

We'd expect this mostly to mean adding slicing and similar ops to the corresponding TensorImpl tensor subclass; see the sketch below.
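The sketch below shows the general shape of such a change: a wrapper tensor subclass that routes aten.slice.Tensor to its packed fields so DTensor can shard it. The class and field names (MyQuantImpl, int_data, scale) are hypothetical; this is not torchao's actual TensorImpl code, just the pattern.

```python
import torch

aten = torch.ops.aten


class MyQuantImpl(torch.Tensor):
    """Hypothetical packed-weight subclass (not torchao's actual code)."""

    @staticmethod
    def __new__(cls, int_data, scale):
        # Wrapper subclass: the outer tensor advertises the logical
        # shape/dtype while the real data lives in the inner fields.
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=scale.dtype, device=int_data.device
        )

    def __init__(self, int_data, scale):
        self.int_data = int_data  # packed int8 values, one row per output
        self.scale = scale        # per-row dequantization scales

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if func is aten.detach.default:
            self = args[0]
            return cls(self.int_data, self.scale)
        if func is aten.slice.Tensor:
            # DTensor shards a weight by slicing it, so route the slice to
            # the packed fields instead of the (data-less) outer tensor.
            self = args[0]
            dim = args[1] if len(args) > 1 else 0
            start = args[2] if len(args) > 2 else None
            end = args[3] if len(args) > 3 else None
            step = args[4] if len(args) > 4 else 1
            int_data = aten.slice.Tensor(self.int_data, dim, start, end, step)
            # Scales vary along dim 0 here, so only slice them for dim-0 slices.
            scale = (
                aten.slice.Tensor(self.scale, 0, start, end, step)
                if dim == 0
                else self.scale
            )
            return cls(int_data, scale)
        raise NotImplementedError(f"MyQuantImpl: {func} not supported yet")
```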

@jerryzh168 jerryzh168 added the good first issue Good for newcomers label Oct 1, 2024
@jainapurva jainapurva linked a pull request Oct 4, 2024 that will close this issue
@melvinebenezer (Contributor)

@jerryzh168 shouldn't this issue be kept open until the other points in the description are fixed?
I have a WIP: #1026

@jainapurva jainapurva reopened this Oct 7, 2024
@p4arth (Contributor) commented Oct 17, 2024

Hi @jerryzh168, I have two questions:

  1. What does fpx mean here, exactly?
  2. For int8 dynamic act + int8 weight, is only writing the tests left?

@jerryzh168 (Contributor, Author)

@p4arth

jerryzh168 added a commit to jerryzh168/ao that referenced this issue Oct 18, 2024
Summary:
Following pytorch#988, we added TP support for int4_weight_only quantization in torchao,
which uses TensorCoreTiledLayout.

Addresses one work item in pytorch#988.

Also clarified docs based on pytorch#386.

Also restructured the tests in test/dtypes/test_affine_quantized_tensor_parallel.py so they don't
depend on torchao/utils.py, reducing the jumps people have to make to understand what is tested.

Test Plan:
python test/dtypes/test_affine_quantized_tensor_parallel.py

jerryzh168 added a commit that referenced this issue Oct 19, 2024
* Add tensor parallelism support for int4_weight_only quantization

* typo