Skip to content

Commit

Permalink
Use uppercase letters for M, N and K.
Browse files: browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Mar 11, 2024
1 parent f186004 commit a825fbb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
32 changes: 16 additions & 16 deletions js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';

// TODO support quantization bits not equal to 4
export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
k: number;
n: number;
K: number;
N: number;
accuracyLevel: number;
bits: number;
blockSize: number;
Expand All @@ -24,25 +24,25 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt
}
const a = inputs[0];
const aRank = a.dims.length;
if (a.dims[aRank - 1] !== attributes.k) {
if (a.dims[aRank - 1] !== attributes.K) {
throw new Error('The last dim of input shape does not match the k value');
}
const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
const nBlocksPerCol = Math.floor((attributes.K + attributes.blockSize - 1) / attributes.blockSize);
const blobSize = attributes.blockSize / 8 * attributes.bits;
const b = inputs[1];
if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) {
if (!ShapeUtil.areEqual(b.dims, [attributes.N, nBlocksPerCol, blobSize])) {
throw new Error('The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize');
}
const scales = inputs[2];
const scalesShape = scales.dims;
if (ShapeUtil.size(scalesShape) !== attributes.n * nBlocksPerCol) {
if (ShapeUtil.size(scalesShape) !== attributes.N * nBlocksPerCol) {
throw new Error('scales input size error.');
}
if (inputs.length === 4) {
const zeroPoints = inputs[3];
const zeroPointsShape = zeroPoints.dims;
const expectedZeroPointsSize =
attributes.bits > 4 ? (attributes.n * nBlocksPerCol) : attributes.n * Math.floor((nBlocksPerCol + 1) / 2);
attributes.bits > 4 ? (attributes.N * nBlocksPerCol) : attributes.N * Math.floor((nBlocksPerCol + 1) / 2);
if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) {
throw new Error('zeroPoints input size error.');
}
Expand All @@ -53,19 +53,19 @@ export const createMatMulNBitsProgramInfo =
(inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => {
const inputShape = inputs[0].dims;
const aRank = inputShape.length;
const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.n);
const m = inputShape[aRank - 2];
const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.N);
const M = inputShape[aRank - 2];
const blobSize = attributes.blockSize / 8 * attributes.bits;
const blobSizeInWords = blobSize / 4;
const outputNumber = getMaxComponents(m);
const outputNumber = getMaxComponents(M);
const components = 1; // getMaxComponents(attributes.n);
const aComponents = getMaxComponents(attributes.k);
const aComponents = getMaxComponents(attributes.K);
const bComponents = getMaxComponents(blobSizeInWords);
const zComponents = 1; // getMaxComponents(attributes.n / 8);
const zComponents = 1; // getMaxComponents(attributes.N / 8);
const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k},
{type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel},
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.K},
{type: DataType.uint32, data: attributes.N}, {type: DataType.uint32, data: attributes.accuracyLevel},
{type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize}
];
const getShaderSource = (shaderHelper: ShaderHelper) => {
Expand All @@ -88,7 +88,7 @@ export const createMatMulNBitsProgramInfo =
{name: 'output_size', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
{name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'}
];
const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
const nBlocksPerCol = Math.floor((attributes.K + attributes.blockSize - 1) / attributes.blockSize);
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const dequantizeArrayReturnType = (() => {
switch (aComponents) {
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/js/quantization/matmul_nbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class MatMulNBits final : public JsKernel {
ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)),
"Block size must be a power of 2 and greater than or equal to 16.");
JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({
"k" : $1,
"n" : $2,
"K" : $1,
"N" : $2,
"accuracyLevel" : $3,
"bits" : $4,
"blockSize" : $5
Expand Down

0 comments on commit a825fbb

Please sign in to comment.