Tensor Parallelism Support for AffineQuantizedTensor #988
Comments
@jerryzh168 shouldn't this issue be kept open until the other points in the description are fixed?
Hi @jerryzh168, I had two questions
Summary: Following pytorch#988 we added TP support for int4_weight_only quantization in torchao that uses TensorCoreTiledLayout. Addresses one work item in pytorch#988. Also clarified docs based on pytorch#386, and restructured the tests in test/dtypes/test_affine_quantized_tensor_parallel.py to not depend on torchao/utils.py, reducing the jumps people have to make to understand what is tested.
Test Plan: python test/dtypes/test_affine_quantized_tensor_parallel.py
Recently we landed #939 to support tensor parallelism for int8 weight-only quantization (another example: #785).
Now we can support tensor parallelism for other types of quantization as well.
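The core requirement for tensor parallelism over a quantized tensor is that sharding must stay consistent across all of the tensor's inner fields. A minimal sketch below (all names are hypothetical, not the torchao API): a quantized weight carries integer data plus per-output-channel scales, and column-wise parallelism for a linear layer shards the output dimension, so both fields must be sliced with the same rows or dequantization on each rank would use the wrong scales.

```python
# Hypothetical stand-in for a quantized weight; torchao's real
# AffineQuantizedTensor stores these fields inside a TensorImpl subclass.
class FakeQuantizedWeight:
    def __init__(self, int_data, scale):
        self.int_data = int_data  # rows of quantized ints
        self.scale = scale        # one scale per output row

    def shard_rows(self, rank, world_size):
        """Return this rank's shard for a column-wise parallel linear."""
        n = len(self.int_data) // world_size
        lo, hi = rank * n, (rank + 1) * n
        # Slice int data and scales together so dequantization stays correct.
        return FakeQuantizedWeight(self.int_data[lo:hi], self.scale[lo:hi])

    def dequantize(self):
        return [[q * s for q in row]
                for row, s in zip(self.int_data, self.scale)]

w = FakeQuantizedWeight([[1, 2], [3, 4], [5, 6], [7, 8]],
                        [0.5, 1.0, 2.0, 4.0])
shard = w.shard_rows(rank=1, world_size=2)
print(shard.dequantize())  # rank 1's rows, dequantized with matching scales
```

In the real implementation this slicing is expressed through ops on the tensor subclass (step 3 below), which is what lets DTensor shard the quantized weight like a plain tensor.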
Steps
1. Create a test
Since we don't have many tests today, we can optimize for readability for now: copy-paste the test cases into https://github.com/pytorch/ao/blob/main/test/dtypes/test_affine_quantized_tensor_parallel.py instead of inheriting from the existing test cases.
For new tests you can follow ao/test/dtypes/test_affine_quantized_tensor_parallel.py, lines 133 to 153 (commit c87cc9b).
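As a rough sketch of what such a test can check without spinning up real processes: shard a (fake) quantized weight for every rank and verify the shards concatenate back to the original. The actual torchao tests instead launch torch.distributed workers and use DTensor; everything named below is illustrative only.

```python
import unittest

def shard(rows, rank, world_size):
    # Even row-wise split; the real tests rely on DTensor's sharding.
    n = len(rows) // world_size
    return rows[rank * n:(rank + 1) * n]

class TestShardRoundTrip(unittest.TestCase):
    def test_concat_recovers_weight(self):
        weight = [[r, r + 1] for r in range(8)]
        for world_size in (1, 2, 4):
            shards = [shard(weight, r, world_size)
                      for r in range(world_size)]
            merged = [row for s in shards for row in s]
            self.assertEqual(merged, weight)

unittest.main(exit=False, argv=["shard_test"])
```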
2. Run the test
python test/dtypes/test_affine_quantized_tensor_parallel.py
3. Add support for missing ops until the test passes
We'd expect people to add some slicing ops, etc., to the corresponding TensorImpl tensor subclass.
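The pattern in torchao is to register implementations for the aten ops that DTensor needs (slice, view, etc.) against the tensor subclass, routed through `__torch_dispatch__`. Here is a hedged, dependency-free sketch of that dispatch-table pattern; the names (`OP_TABLE`, `MiniImpl`, `"slice"`) are hypothetical, while the real code registers actual aten ops.

```python
# Op table mapping an op name to an implementation over the impl's fields.
OP_TABLE = {}

def register(op_name):
    def deco(fn):
        OP_TABLE[op_name] = fn
        return fn
    return deco

class MiniImpl:
    """Stand-in for a TensorImpl holding quantized data and scales."""
    def __init__(self, int_data, scale):
        self.int_data = int_data  # rows of quantized ints
        self.scale = scale        # one scale per output row

@register("slice")
def slice_impl(impl, dim, start, end):
    # Dim 0 shards output channels: int data AND scales are sliced.
    # Dim 1 slices input features only, when scales are per-output-row.
    if dim == 0:
        return MiniImpl(impl.int_data[start:end], impl.scale[start:end])
    return MiniImpl([row[start:end] for row in impl.int_data], impl.scale)

impl = MiniImpl([[1, 2, 3], [4, 5, 6]], [0.1, 0.2])
half = OP_TABLE["slice"](impl, 0, 0, 1)  # first output channel only
```

Once the subclass handles the slicing/view ops DTensor issues while sharding, the TP test from step 1 should pass without changes to the model code.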