Skip to content

Commit

Permalink
Add torchchat quantizer
Browse files Browse the repository at this point in the history
Differential Revision: D62394341

Pull Request resolved: pytorch#897
  • Loading branch information
metascroy authored and weifengpy committed Sep 26, 2024
1 parent a05a40f commit 334891b
Show file tree
Hide file tree
Showing 8 changed files with 432 additions and 351 deletions.
8 changes: 4 additions & 4 deletions torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

add_library(
kernel_aarch64
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ set(CMAKE_BUILD_TYPE Release)
add_compile_options("-Wall" "-Werror")

include(CMakePrintHelpers)
message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}")
include_directories(${TORCHAO_LIBRARIES})
message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
include_directories(${TORCHAO_INCLUDE_DIRS})

add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)
add_subdirectory(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)

include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake)
include(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/Utils.cmake)

set(PLATFORM "ATEN" CACHE STRING "Choose platform surface: ATEN, EXECUTORCH")
string(TOUPPER ${PLATFORM} PLATFORM_TO_UPPER)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
# LICENSE file in the root directory of this source tree.

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../../..
export TORCHAO_INCLUDE_DIRS=${SCRIPT_DIR}/../../../../../../..

export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')"
echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples/torch_custom_op
cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
export CMAKE_OUT=/tmp/cmake-out/torchao
cmake -DTORCHAO_INCLUDE_DIRS=${TORCHAO_INCLUDE_DIRS} \
-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
-DPLATFORM="ATEN" \
-S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
-S ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
-B ${CMAKE_OUT}
cmake --build ${CMAKE_OUT}
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,21 @@
# LICENSE file in the root directory of this source tree.

import copy
import glob
import os

import sys

import torch
from torch_custom_op import (
linear_a8sz_w_lowbit_reference_impl,
replace_linear_with_quantized_linear,

sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
)
from quant_api import Int8DynActIntxWeightQuantizer

libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
torch.ops.load_library(libs[0])

group_size = 256
m = 1
Expand All @@ -27,15 +36,15 @@

print("Quantizing random model")
quantized_model = copy.deepcopy(model)
quantized_model = quantized_model.eval()
replace_linear_with_quantized_linear(
quantized_model,
kwargs={
"group_size": group_size,
"nbit": nbit,
"has_weight_zeros": has_weight_zeros,
},
quantizer = Int8DynActIntxWeightQuantizer(
device="cpu",
precision=torch.float32,
bitwidth=nbit,
groupsize=group_size,
has_weight_zeros=has_weight_zeros,
)
quantized_model = quantizer.quantize(quantized_model)
quantized_model = quantized_model.eval()

print("Creating random activations")
activations = torch.randn(m, k, dtype=torch.float32)
Expand All @@ -58,44 +67,3 @@
print("Running AOTI")
fn = torch._export.aot_load("/tmp/torch_custom_op_example_model.so", "cpu")
fn(activations)


print("\nChecking correctness on layer 0")
linear = model[0]
quantized_linear = quantized_model[0]

with torch.no_grad():
result = quantized_linear(activations)
expected_result = linear_a8sz_w_lowbit_reference_impl(
linear.weight, activations, group_size, nbit, has_weight_zeros
)
non_quantized_result = linear(activations)


# Check that entries in result match entries in expected_result
num_mismatch_at_low_tol = 0
num_total = result.reshape(-1).shape[0]
for i in range(num_total):
actual_val = result.reshape(-1)[i]
expected_val = expected_result.reshape(-1)[i]
if not torch.allclose(actual_val, expected_val):
num_mismatch_at_low_tol += 1

# If results are not close at a relaxed tolerance, exit with failure
if not torch.allclose(actual_val, expected_val, atol=1e-6):
assert False, "Correctness check failed"

# Assert at most 5% of entries are not close at a low tolerance
assert num_mismatch_at_low_tol / num_total <= 0.05, "Correctness check failed"
print(
"Correctness check passed. All results are close, and ",
(num_total - num_mismatch_at_low_tol),
"/",
num_total,
" entries are close at a low tolerance.",
)
print("Quantization errors:")
print("\tL1 error: ", torch.mean(torch.abs(result - non_quantized_result)).item())
print("\tL2 error: ", torch.mean((result - non_quantized_result) ** 2).item())
print("\tquantized_result[0:5]: ", result[0][0:5])
print("\tnon_quantized_result[0:5]: ", non_quantized_result[0][0:5])

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy

import glob
import os

import sys
import unittest

import torch

sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
)
from quant_api import (
_Int8DynActIntxWeightQuantizedLinearFallback,
Int8DynActIntxWeightQuantizer,
)

libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
if len(libs) == 0:
print(
"Could not find library lowbit_op_aten; please run `sh build_custom_op.sh` to build the library. A slow fallback kernel will be used instaed."
)
else:
torch.ops.load_library(libs[0])


class TestInt8DynActIntxWeightQuantizer(unittest.TestCase):
def test_accuracy(self):
group_size = 128
m = 1
n = 1071
k = 4096
activations = torch.randn(m, k, dtype=torch.float32)
model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])

for nbit in [1, 2, 3, 4, 5, 6, 7]:
for has_weight_zeros in [True, False]:
print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
quantized_model = copy.deepcopy(model)
quantizer = Int8DynActIntxWeightQuantizer(
device="cpu",
precision=torch.float32,
bitwidth=nbit,
groupsize=group_size,
has_weight_zeros=has_weight_zeros,
)
quantized_model = quantizer.quantize(quantized_model)

with torch.no_grad():
result = quantized_model(activations)
reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback()
reference_impl.quantize_and_pack_weights(
model[0].weight, nbit, group_size, has_weight_zeros
)
expected_result = reference_impl(activations)

num_mismatch_at_low_tol = 0
num_total = result.reshape(-1).shape[0]
for i in range(num_total):
actual_val = result.reshape(-1)[i]
expected_val = expected_result.reshape(-1)[i]
self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6))
if not torch.allclose(actual_val, expected_val):
num_mismatch_at_low_tol += 1

# Assert at most 5% of entries are not close at a low tolerance
self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 334891b

Please sign in to comment.