Skip to content

Commit

Permalink
Add torchchat quantizer (pytorch#897)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#897

This diff adds a quantizer for the new torchao kernels that is similar to the Int8DynActInt4WeightQuantizer quantizer in torchchat (imported from from torchao.quantization.quant_api).  See the draft torchchat PR (pytorch/torchchat#1070) for how this can integrate with torchchat's quantization API.

I confirmed that models quantized with this are compatible with eager, compile, AOTI, and export to ExecuTorch in torchchat.  They do not run on ExecuTorch because we still have not written an ExecuTorch kernel wrapper.

jerryzh168 this does not use the new subclass API, and this is something I'd like to discuss further with you.  I'll set up a sync with you this week, but I wanted to have some API on the table to ground the discussion.

We do not currently have the required C++ methods implemented to support the new subclass API (e.g., we cannot unpack the packed weights from python; they are instead unpacked inline in the kernel).  From a torchchat user's perspective, I do not think this is important, but I'd like to discuss further.

Reviewed By: digantdesai

Differential Revision: D62394341
  • Loading branch information
metascroy authored and facebook-github-bot committed Sep 24, 2024
1 parent 728d629 commit 6d5db05
Show file tree
Hide file tree
Showing 7 changed files with 342 additions and 316 deletions.
8 changes: 4 additions & 4 deletions torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

add_library(
kernel_aarch64
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ set(CMAKE_BUILD_TYPE Release)
add_compile_options("-Wall" "-Werror")

include(CMakePrintHelpers)
message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}")
include_directories(${TORCHAO_LIBRARIES})
message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
include_directories(${TORCHAO_INCLUDE_DIRS})

add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)
add_subdirectory(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)

include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake)
include(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/Utils.cmake)

set(PLATFORM "ATEN" CACHE STRING "Choose platform surface: ATEN, EXECUTORCH")
string(TOUPPER ${PLATFORM} PLATFORM_TO_UPPER)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
# LICENSE file in the root directory of this source tree.

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../../..
export TORCHAO_INCLUDE_DIRS=${SCRIPT_DIR}/../../../../../../..

export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')"
echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples/torch_custom_op
cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
export CMAKE_OUT=/tmp/cmake-out/torchao
cmake -DTORCHAO_INCLUDE_DIRS=${TORCHAO_INCLUDE_DIRS} \
-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
-DPLATFORM="ATEN" \
-S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
-S ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
-B ${CMAKE_OUT}
cmake --build ${CMAKE_OUT}
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
# LICENSE file in the root directory of this source tree.

import copy
import glob

import sys

import torch
from torch_custom_op import (
linear_a8sz_w_lowbit_reference_impl,
replace_linear_with_quantized_linear,
)

sys.path.insert(0, "../../../../..")
from quant_api import Int8DynActIntxWeightQuantizer

libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
torch.ops.load_library(libs[0])

group_size = 256
m = 1
Expand All @@ -27,15 +33,15 @@

print("Quantizing random model")
quantized_model = copy.deepcopy(model)
quantized_model = quantized_model.eval()
replace_linear_with_quantized_linear(
quantized_model,
kwargs={
"group_size": group_size,
"nbit": nbit,
"has_weight_zeros": has_weight_zeros,
},
quantizer = Int8DynActIntxWeightQuantizer(
device="cpu",
precision=torch.float32,
bitwidth=nbit,
groupsize=group_size,
has_weight_zeros=has_weight_zeros,
)
quantized_model = quantizer.quantize(quantized_model)
quantized_model = quantized_model.eval()

print("Creating random activations")
activations = torch.randn(m, k, dtype=torch.float32)
Expand All @@ -58,44 +64,3 @@
print("Running AOTI")
fn = torch._export.aot_load("/tmp/torch_custom_op_example_model.so", "cpu")
fn(activations)


print("\nChecking correctness on layer 0")
linear = model[0]
quantized_linear = quantized_model[0]

with torch.no_grad():
result = quantized_linear(activations)
expected_result = linear_a8sz_w_lowbit_reference_impl(
linear.weight, activations, group_size, nbit, has_weight_zeros
)
non_quantized_result = linear(activations)


# Check that entries in result match entries in expected_result
num_mismatch_at_low_tol = 0
num_total = result.reshape(-1).shape[0]
for i in range(num_total):
actual_val = result.reshape(-1)[i]
expected_val = expected_result.reshape(-1)[i]
if not torch.allclose(actual_val, expected_val):
num_mismatch_at_low_tol += 1

# If results are not close at a relaxed tolerance, exit with failure
if not torch.allclose(actual_val, expected_val, atol=1e-6):
assert False, "Correctness check failed"

# Assert at most 5% of entries are not close at a low tolerance
assert num_mismatch_at_low_tol / num_total <= 0.05, "Correctness check failed"
print(
"Correctness check passed. All results are close, and ",
(num_total - num_mismatch_at_low_tol),
"/",
num_total,
" entries are close at a low tolerance.",
)
print("Quantization errors:")
print("\tL1 error: ", torch.mean(torch.abs(result - non_quantized_result)).item())
print("\tL2 error: ", torch.mean((result - non_quantized_result) ** 2).item())
print("\tquantized_result[0:5]: ", result[0][0:5])
print("\tnon_quantized_result[0:5]: ", non_quantized_result[0][0:5])
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,27 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy

import glob

import sys
import unittest

import torch
from torch_custom_op import (
linear_a8sz_w_lowbit_reference_impl,
replace_linear_with_quantized_linear,

sys.path.insert(0, "../../../../..")
from quant_api import (
_Int8DynActIntxWeightQuantizedLinearFallback,
Int8DynActIntxWeightQuantizer,
)
import copy

class TestTorchCustomOp(unittest.TestCase):
libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
torch.ops.load_library(libs[0])


class TestInt8DynActIntxWeightQuantizer(unittest.TestCase):
def test_accuracy(self):
group_size = 128
m = 1
Expand All @@ -22,24 +33,27 @@ def test_accuracy(self):
activations = torch.randn(m, k, dtype=torch.float32)
model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])

for nbit in [2, 3, 4, 5]:
for has_weight_zeros in [False, True]:
for nbit in [1, 2, 3, 4, 5, 6, 7]:
for has_weight_zeros in [True, False]:
print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
quantized_model = copy.deepcopy(model)
replace_linear_with_quantized_linear(
quantized_model,
kwargs={
"group_size": group_size,
"nbit": nbit,
"has_weight_zeros": has_weight_zeros,
},
quantizer = Int8DynActIntxWeightQuantizer(
device="cpu",
precision=torch.float32,
bitwidth=nbit,
groupsize=group_size,
has_weight_zeros=has_weight_zeros,
)
quantized_model = quantizer.quantize(quantized_model)

with torch.no_grad():
result = quantized_model(activations)
expected_result = linear_a8sz_w_lowbit_reference_impl(
model[0].weight, activations, group_size, nbit, has_weight_zeros
reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback()
reference_impl.quantize_and_pack_weights(
model[0].weight, nbit, group_size, has_weight_zeros
)

expected_result = reference_impl(activations)

num_mismatch_at_low_tol = 0
num_total = result.reshape(-1).shape[0]
for i in range(num_total):
Expand All @@ -50,7 +64,8 @@ def test_accuracy(self):
num_mismatch_at_low_tol += 1

# Assert at most 5% of entries are not close at a low tolerance
self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)

if __name__ == '__main__':
self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 6d5db05

Please sign in to comment.