From a825fbb6176e3689c2aabd349b7a5fec6b7c5693 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 11 Mar 2024 12:14:23 -0700 Subject: [PATCH] Use uppercase letters for M, N and K. --- .../lib/wasm/jsep/webgpu/ops/matmulnbits.ts | 32 +++++++++---------- .../js/quantization/matmul_nbits.h | 4 +-- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index 8b42e77900471..6d943ee676e83 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -7,12 +7,12 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; // TODO support quantization bits not equal to 4 export interface MatMulNBitsAttributes extends AttributeWithCacheKey { - k: number; - n: number; + K: number; + N: number; accuracyLevel: number; bits: number; blockSize: number; @@ -24,25 +24,25 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt } const a = inputs[0]; const aRank = a.dims.length; - if (a.dims[aRank - 1] !== attributes.k) { + if (a.dims[aRank - 1] !== attributes.K) { throw new Error('The last dim of input shape does not match the k value'); } - const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const nBlocksPerCol = Math.floor((attributes.K + attributes.blockSize - 1) / attributes.blockSize); const blobSize = attributes.blockSize / 8 * attributes.bits; const b = inputs[1]; - if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) { + if (!ShapeUtil.areEqual(b.dims, [attributes.N, nBlocksPerCol, blobSize])) { throw new Error('The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize'); } const scales = inputs[2]; const scalesShape = scales.dims; - if (ShapeUtil.size(scalesShape) !== attributes.n * nBlocksPerCol) { + if (ShapeUtil.size(scalesShape) !== attributes.N * nBlocksPerCol) { throw new Error('scales input size error.'); } if (inputs.length === 4) { const zeroPoints = inputs[3]; const zeroPointsShape = zeroPoints.dims; const expectedZeroPointsSize = - attributes.bits > 4 ? (attributes.n * nBlocksPerCol) : attributes.n * Math.floor((nBlocksPerCol + 1) / 2); + attributes.bits > 4 ? (attributes.N * nBlocksPerCol) : attributes.N * Math.floor((nBlocksPerCol + 1) / 2); if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) { throw new Error('zeroPoints input size error.'); } @@ -53,19 +53,19 @@ export const createMatMulNBitsProgramInfo = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => { const inputShape = inputs[0].dims; const aRank = inputShape.length; - const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.n); - const m = inputShape[aRank - 2]; + const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.N); + const M = inputShape[aRank - 2]; const blobSize = attributes.blockSize / 8 * attributes.bits; const blobSizeInWords = blobSize / 4; - const outputNumber = getMaxComponents(m); + const outputNumber = getMaxComponents(M); const components = 1; // getMaxComponents(attributes.n); - const aComponents = getMaxComponents(attributes.k); + const aComponents = getMaxComponents(attributes.K); const bComponents = getMaxComponents(blobSizeInWords); - const zComponents = 1; // getMaxComponents(attributes.n / 8); + const zComponents = 1; // getMaxComponents(attributes.N / 8); const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k}, - {type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel}, + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.K}, + {type: DataType.uint32, data: attributes.N}, {type: DataType.uint32, data: attributes.accuracyLevel}, {type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize} ]; const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -88,7 +88,7 @@ export const createMatMulNBitsProgramInfo = {name: 'output_size', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'} ]; - const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const nBlocksPerCol = Math.floor((attributes.K + attributes.blockSize - 1) / attributes.blockSize); const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const dequantizeArrayReturnType = (() => { switch (aComponents) { diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h index cca2c4757765b..356cdfd495e87 100644 --- a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h @@ -22,8 +22,8 @@ class MatMulNBits final : public JsKernel { ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)), "Block size must be a power of 2 and greater than or equal to 16."); JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({ - "k" : $1, - "n" : $2, + "K" : $1, + "N" : $2, "accuracyLevel" : $3, "bits" : $4, "blockSize" : $5