From cf2fda5dde1ef5fa160925d1c982635bd481b457 Mon Sep 17 00:00:00 2001 From: Georg Kunze Date: Mon, 16 Sep 2024 19:27:30 +0200 Subject: [PATCH] update cnm-to-gpu conversion --- .../cinm-mlir/Conversion/CommonPatterns.h | 9 + cinnamon/justfile | 19 +- cinnamon/lib/Conversion/CnmToGPU/CnmToGPU.cpp | 238 +++++++----------- cinnamon/lib/Conversion/CommonPatterns.cpp | 45 ++++ cinnamon/samples/gemm_cnm.mlir | 97 +++---- cinnamon/samples/gemm_gpu.mlir | 99 ++++++++ .../cinm-vulkan-runner/cinm-vulkan-runner.cpp | 2 +- 7 files changed, 310 insertions(+), 199 deletions(-) create mode 100644 cinnamon/samples/gemm_gpu.mlir diff --git a/cinnamon/include/cinm-mlir/Conversion/CommonPatterns.h b/cinnamon/include/cinm-mlir/Conversion/CommonPatterns.h index 31896e4..58cc69b 100644 --- a/cinnamon/include/cinm-mlir/Conversion/CommonPatterns.h +++ b/cinnamon/include/cinm-mlir/Conversion/CommonPatterns.h @@ -6,6 +6,8 @@ #include #include +#include +#include #include #include #include @@ -37,4 +39,11 @@ struct ConvertCnmSetZeroToAffine : public OpConversionPattern { ConversionPatternRewriter &) const override; }; +SmallVector createAffineApply(OpBuilder &builder, Location loc, + AffineMap map, ValueRange values); + +void createMemrefSubviewCopy(OpBuilder &builder, Location loc, Value src, + Value dst, ArrayRef sliceShape, + ValueRange srcOffsets, ValueRange dstOffsets); + } // namespace mlir diff --git a/cinnamon/justfile b/cinnamon/justfile index b77a765..9128b6c 100644 --- a/cinnamon/justfile +++ b/cinnamon/justfile @@ -59,11 +59,28 @@ cinm-opt-help: (cinm-opt "--help") debug-cinm-opt *ARGS: gdb --args {{build_dir}}/bin/cinm-opt {{ARGS}} +cinm-to-cnm FILE *ARGS: ( + cinm-opt FILE + "--cinm-tiling" + "--affine-loop-unroll='unroll-full unroll-full-threshold=1'" + "--convert-cinm-to-cnm" + "--lower-affine" + "--one-shot-bufferize='bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map'" + "--convert-linalg-to-affine-loops" + "--lower-affine" + "--buffer-loop-hoisting" + "--buffer-hoisting" + "--cse" + ARGS +) + cnm-to-gpu FILE *ARGS: (cinm-opt FILE "--convert-cnm-to-gpu" ARGS) +cinm-to-gpu FILE *ARGS: (cinm-to-cnm FILE "--convert-cnm-to-gpu" ARGS) cinm-vulkan-runner FILE *ARGS: {{build_dir}}/bin/cinm-vulkan-runner {{FILE}} \ - --shared-libs=../llvm-project/build/lib/libvulkan-runtime-wrappers.so,../llvm-project/build/lib/libmlir_runner_utils.so.17 \ + --shared-libs={{llvm_prefix}}/lib/libvulkan-runtime-wrappers.so,{{llvm_prefix}}/lib/libmlir_runner_utils.so \ + --entry-point-result=void \ {{ARGS}} genBench NAME: (doNinja "cinm-opt") diff --git a/cinnamon/lib/Conversion/CnmToGPU/CnmToGPU.cpp b/cinnamon/lib/Conversion/CnmToGPU/CnmToGPU.cpp index 0001257..681fc82 100644 --- a/cinnamon/lib/Conversion/CnmToGPU/CnmToGPU.cpp +++ b/cinnamon/lib/Conversion/CnmToGPU/CnmToGPU.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -41,61 +42,28 @@ namespace mlir::cnm { namespace cnmtogpu { -SmallVector getBufferTypeShape(cnm::BufferType bufferType) { - SmallVector shape{bufferType.getShape()}; - while (shape.size() < bufferType.getWorkgroupShape().size()) { - shape.insert(shape.begin(), 1); - } - return shape; -} - MemRefType convertCnmBufferToMemRefType(cnm::BufferType bufferType) { - ArrayRef workgroupShape = bufferType.getWorkgroupShape(); - SmallVector shape = getBufferTypeShape(bufferType); - for (size_t i = 0; i < workgroupShape.size(); i++) { - shape[i] *= workgroupShape[i]; - } + SmallVector shape{bufferType.getWorkgroupShape()}; + 
shape.append(bufferType.getShape().begin(), bufferType.getShape().end()); return MemRefType::get(shape, bufferType.getElementType()); } -SmallVector createCalculateScatterIndices(Location loc, - OpBuilder &builder, - const AffineMap &scatterMap, - ValueRange indices, - BufferType bufferType) { - SmallVector bufferIndices; - ArrayRef workgroupShape = bufferType.getWorkgroupShape(); - for (size_t i = 0; i < workgroupShape.size(); i++) { - const AffineExpr indexExpr = - scatterMap.getResult(i) * workgroupShape[i] + - scatterMap.getResult(workgroupShape.size() + i); - bufferIndices.push_back(builder.create( - loc, AffineMap::get(indices.size(), 0, indexExpr), indices)); - } - return bufferIndices; -} - void convertLaunchParameter(ConversionPatternRewriter &rewriter, Location loc, Value buffer, ValueRange threadIds, - ArrayRef workgroupShape, BlockArgument arg) { const BufferType bufferType = buffer.getType().dyn_cast(); - const SmallVector bufferShape = getBufferTypeShape(bufferType); + const MemRefType memrefType = convertCnmBufferToMemRefType(bufferType); const Value source = createOrFoldUnrealizedConversionCast( loc, rewriter, convertCnmBufferToMemRefType(bufferType), rewriter.getRemappedValue(buffer)); - const SmallVector staticOffsets(workgroupShape.size(), - ShapedType::kDynamic); - const SmallVector staticSizes{bufferShape}; - const SmallVector staticStrides(workgroupShape.size(), 1); - - SmallVector dynamicOffsets; - for (size_t i = 0; i < workgroupShape.size(); i++) { - const AffineExpr indexExpr = rewriter.getAffineDimExpr(0) * bufferShape[i]; - dynamicOffsets.push_back(rewriter.create( - loc, AffineMap::get(1, 0, indexExpr), ValueRange{threadIds[i]})); + SmallVector staticOffsets(memrefType.getRank(), 0); + SmallVector staticSizes{memrefType.getShape()}; + const SmallVector staticStrides(memrefType.getRank(), 1); + for (unsigned i = 0; i < threadIds.size(); i++) { + staticSizes[i] = 1; + staticOffsets[i] = ShapedType::kDynamic; } const Type resultType = memref::SubViewOp::inferRankReducedResultType( @@ -104,7 +72,7 @@ void convertLaunchParameter(ConversionPatternRewriter &rewriter, Location loc, const Value subview = rewriter - .create(loc, resultType, source, dynamicOffsets, + .create(loc, resultType, source, threadIds, ValueRange{}, ValueRange{}, staticOffsets, staticSizes, staticStrides) .getResult(); @@ -129,8 +97,15 @@ struct ConvertCnmAllocToGPU : public OpConversionPattern { LogicalResult matchAndRewrite(cnm::AllocOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, convertCnmBufferToMemRefType(op.getType())); + Type asyncToken; + ValueRange asyncDependencies; + ValueRange dynamicSizes; + ValueRange symbolOperands; + UnitAttr hostShared; + + rewriter.replaceOpWithNewOp( + op, convertCnmBufferToMemRefType(op.getType()), asyncToken, + asyncDependencies, dynamicSizes, symbolOperands, hostShared); return success(); } }; @@ -141,44 +116,29 @@ struct ConvertCnmScatterToGPU : public OpConversionPattern { LogicalResult matchAndRewrite(cnm::ScatterOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override { + const WorkgroupType workgroupType = op.getWg().getType(); + const ArrayRef workgroupShape = workgroupType.getShape(); const cnm::BufferType bufferType = op.getOperandTypes()[1].dyn_cast(); - const SmallVector bufferShape = getBufferTypeShape(bufferType); - - Value memref = rewriter.getRemappedValue(op.getOperand(1)); - memref = createOrFoldUnrealizedConversionCast( - op.getLoc(), rewriter, 
convertCnmBufferToMemRefType(bufferType), - memref); - - const Value tensor = op.getOperand(0); - const RankedTensorType tensorType = - tensor.getType().dyn_cast(); - - SmallVector loops; - SmallVector indices; - - for (int64_t size : tensorType.getShape()) { - affine::AffineForOp loop = - rewriter.create(op.getLoc(), 0, size, 1); - loops.push_back(loop); - indices.push_back(loop.getBody()->getArgument(0)); - rewriter.setInsertionPointToStart(loop.getBody()); - } - - // inner most loop body - const AffineMap scatterMap = op.getScatterMap(); - SmallVector bufferIndices = createCalculateScatterIndices( - op.getLoc(), rewriter, scatterMap, indices, bufferType); - const Value element = - rewriter.create(op.getLoc(), tensor, indices); - rewriter.create(op.getLoc(), element, memref, - bufferIndices); - - // replace token with const 0 - rewriter.setInsertionPointAfter(loops[0]); - rewriter.replaceOpWithNewOp(op, 0); + Value src = rewriter.getRemappedValue(op.getOperand(0)); + Value dst = rewriter.getRemappedValue(op.getOperand(1)); + dst = createOrFoldUnrealizedConversionCast( + op.getLoc(), rewriter, convertCnmBufferToMemRefType(bufferType), dst); + + const SmallVector loopSteps(workgroupShape.size(), 1); + createNestedAffineForLoops( + rewriter, op.getLoc(), workgroupShape, loopSteps, ValueRange{}, + [&](OpBuilder &builder, Location loc, ValueRange indices, + ValueRange) -> SmallVector { + const SmallVector mappedIndices = + createAffineApply(builder, loc, op.getScatterMap(), indices); + createMemrefSubviewCopy(builder, loc, src, dst, bufferType.getShape(), + mappedIndices, indices); + return {}; + }); + rewriter.eraseOp(op); return success(); } }; @@ -189,55 +149,29 @@ struct ConvertCnmGatherToGPU : public OpConversionPattern { LogicalResult matchAndRewrite(cnm::GatherOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override { + const WorkgroupType workgroupType = op.getWg().getType(); + const ArrayRef workgroupShape = workgroupType.getShape(); const cnm::BufferType bufferType = op.getOperandTypes()[0].dyn_cast(); - const SmallVector bufferShape = getBufferTypeShape(bufferType); - - Value memref = rewriter.getRemappedValue(op.getOperand(0)); - memref = createOrFoldUnrealizedConversionCast( - op.getLoc(), rewriter, convertCnmBufferToMemRefType(bufferType), - memref); - - const RankedTensorType tensorType = - op.getResultTypes()[0].cast(); - const Value tensor = rewriter.create( - op.getLoc(), tensorType.getShape(), tensorType.getElementType()); - - SmallVector loops; - SmallVector indices; - - for (int64_t size : tensorType.getShape()) { - const Value iterArg = - loops.empty() ? 
tensor : loops.back().getBody()->getArgument(1); - affine::AffineForOp loop = rewriter.create( - op.getLoc(), 0, size, 1, SmallVector{iterArg}); - indices.push_back(loop.getBody()->getArgument(0)); - - if (!loops.empty()) { - rewriter.create(op.getLoc(), loop.getResult(0)); - } - - rewriter.setInsertionPointToStart(loop.getBody()); - loops.push_back(loop); - } - // inner most loop body - const Value iterArg = loops.back().getBody()->getArgument(1); - - const AffineMap gatherMap = op.getGatherMap(); - SmallVector bufferIndices = createCalculateScatterIndices( - op.getLoc(), rewriter, gatherMap, indices, bufferType); - const Value element = - rewriter.create(op.getLoc(), memref, bufferIndices); - const Value result = rewriter.create(op.getLoc(), element, - iterArg, indices); - rewriter.create(op.getLoc(), result); - - // replace token with const 0 - rewriter.setInsertionPointAfter(loops[0]); - const Value token = rewriter.create(op.getLoc(), 0); - rewriter.replaceOp(op, {loops.front().getResult(0), token}); + Value src = rewriter.getRemappedValue(op.getOperand(0)); + src = createOrFoldUnrealizedConversionCast( + op.getLoc(), rewriter, convertCnmBufferToMemRefType(bufferType), src); + Value dst = rewriter.getRemappedValue(op.getOperand(2)); + + const SmallVector loopSteps(workgroupShape.size(), 1); + createNestedAffineForLoops( + rewriter, op.getLoc(), workgroupShape, loopSteps, ValueRange{}, + [&](OpBuilder &builder, Location loc, ValueRange indices, + ValueRange) -> SmallVector { + const SmallVector mappedIndices = + createAffineApply(builder, loc, op.getGatherMap(), indices); + createMemrefSubviewCopy(builder, loc, src, dst, bufferType.getShape(), + indices, mappedIndices); + return {}; + }); + rewriter.eraseOp(op); return success(); } }; @@ -252,12 +186,11 @@ struct ConvertCnmLaunchToGPU : public OpConversionPattern { const ArrayRef workgroupShape = workgroupType.getShape(); const Value one = rewriter.create(op.getLoc(), 1); - const Value gridSizeX = one, gridSizeY = one, gridSizeZ = one; - const Value blockSizeX = - rewriter.create(op.getLoc(), workgroupShape[0]); - const Value blockSizeY = - rewriter.create(op.getLoc(), workgroupShape[1]); - const Value blockSizeZ = one; + SmallVector launchDimensions(6, one); + for (size_t i = 0; i < workgroupShape.size(); i++) { + launchDimensions[i] = rewriter.create( + op.getLoc(), workgroupShape[i]); + } const Value dynamicSharedMemorySize; const Type asyncTokenType; @@ -266,23 +199,26 @@ struct ConvertCnmLaunchToGPU : public OpConversionPattern { const TypeRange privateAttributions; gpu::LaunchOp launchOp = rewriter.create( - op.getLoc(), gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, - blockSizeZ, dynamicSharedMemorySize, asyncTokenType, asyncDependencies, - workgroupAttributions, privateAttributions); - - const SmallVector threadIds{ - launchOp.getThreadIds().x, - launchOp.getThreadIds().y, - launchOp.getThreadIds().z, + op.getLoc(), launchDimensions[0], launchDimensions[1], + launchDimensions[2], launchDimensions[3], launchDimensions[4], + launchDimensions[5], dynamicSharedMemorySize, asyncTokenType, + asyncDependencies, workgroupAttributions, privateAttributions); + + const SmallVector allThreadIds{ + launchOp.getBlockIds().x, launchOp.getBlockIds().y, + launchOp.getBlockIds().z, launchOp.getThreadIds().x, + launchOp.getThreadIds().y, launchOp.getThreadIds().z, }; + const ValueRange usedThreadIds = + ValueRange{allThreadIds}.take_front(workgroupShape.size()); rewriter.setInsertionPointToEnd(&launchOp.getBody().front()); // convert 
cnm.buffer parameters to memref subviews - int64_t i = 0; + size_t i = 0; for (const Value &buffer : op.getParams()) { - convertLaunchParameter(rewriter, op.getLoc(), buffer, threadIds, - workgroupShape, op.getBody().getArgument(i++)); + convertLaunchParameter(rewriter, op.getLoc(), buffer, usedThreadIds, + op.getBody().getArgument(i++)); } launchOp.getBody().front().getOperations().splice( @@ -312,21 +248,25 @@ void populateCnmToGPUFinalTypeConversions(TypeConverter &typeConverter) { [&](cnm::BufferType bufferType) -> std::optional { return cnmtogpu::convertCnmBufferToMemRefType(bufferType); }); + + typeConverter.addConversion([&](cnm::WorkgroupType t) -> std::optional { + return IndexType::get(t.getContext()); + }); } void populateCnmToGPUConversionPatterns(RewritePatternSet &patterns, - MLIRContext *context) { + MLIRContext *ctx) { patterns .add(context); + cnmtogpu::ConvertCnmTerminatorToGPU>(ctx); } struct ConvertCnmToGPUPass : public ::impl::ConvertCnmToGPUPassBase { void runOnOperation() final { - TypeConverter converter{}; + TypeConverter converter; populateCnmToGPUFinalTypeConversions(converter); const auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { @@ -341,15 +281,7 @@ struct ConvertCnmToGPUPass populateReconcileUnrealizedCastsPatterns(patterns); ConversionTarget target(getContext()); - // target.addIllegalDialect(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - + target.addIllegalDialect(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); if (failed( diff --git a/cinnamon/lib/Conversion/CommonPatterns.cpp b/cinnamon/lib/Conversion/CommonPatterns.cpp index ac7833d..40bf147 100644 --- a/cinnamon/lib/Conversion/CommonPatterns.cpp +++ b/cinnamon/lib/Conversion/CommonPatterns.cpp @@ -112,4 +112,49 @@ LogicalResult ConvertCnmSetZeroToAffine::matchAndRewrite( return success(); } +SmallVector createAffineApply(OpBuilder &builder, Location loc, + AffineMap map, ValueRange values) { + SmallVector result; + for (unsigned i = 0; i < map.getNumResults(); i++) { + result.push_back( + builder.create(loc, map.getSubMap({i}), values)); + } + return result; +} + +void createMemrefSubviewCopy(OpBuilder &builder, Location loc, Value src, + Value dst, ArrayRef sliceShape, + ValueRange srcOffsets, ValueRange dstOffsets) { + MemRefType srcType = src.getType().cast(); + MemRefType dstType = dst.getType().cast(); + + SmallVector srcStaticOffsets(srcType.getRank(), 0); + SmallVector srcStaticSizes{srcType.getShape()}; + SmallVector srcStaticStrides(srcType.getRank(), 1); + for (unsigned i = 0; i < srcOffsets.size(); i++) { + srcStaticSizes[i] = 1; + srcStaticOffsets[i] = ShapedType::kDynamic; + } + + SmallVector dstStaticOffsets(dstType.getRank(), 0); + SmallVector dstStaticSizes{dstType.getShape()}; + SmallVector dstStaticStrides(dstType.getRank(), 1); + for (unsigned i = 0; i < dstOffsets.size(); i++) { + dstStaticSizes[i] = 1; + dstStaticOffsets[i] = ShapedType::kDynamic; + } + + const Type sliceType = memref::SubViewOp::inferRankReducedResultType( + sliceShape, dstType, dstStaticOffsets, dstStaticSizes, dstStaticStrides); + + const Value src_slice = builder.create( + loc, sliceType, src, srcOffsets, ValueRange{}, ValueRange{}, + srcStaticOffsets, srcStaticSizes, srcStaticStrides); + const Value dst_slice = builder.create( + loc, sliceType, dst, dstOffsets, ValueRange{}, ValueRange{}, + dstStaticOffsets, 
dstStaticSizes, dstStaticStrides); + + builder.create(loc, src_slice, dst_slice); +} + } // namespace mlir diff --git a/cinnamon/samples/gemm_cnm.mlir b/cinnamon/samples/gemm_cnm.mlir index 0c48894..3c7da9a 100644 --- a/cinnamon/samples/gemm_cnm.mlir +++ b/cinnamon/samples/gemm_cnm.mlir @@ -1,53 +1,62 @@ -#map = affine_map<(d0, d1) -> (d0 floordiv 16, (d0 mod 16) floordiv 2, d0 mod 2, d1)> -#map1 = affine_map<(d0) -> (d0 floordiv 16, (d0 mod 16) floordiv 2, d0 mod 2)> -#map2 = affine_map<(d0, d1, d2) -> (d0 * 16 + d1 * 2 + d2)> +#map = affine_map<(d0, d1) -> (0)> +#map1 = affine_map<(d0, d1) -> (d1 mod 128)> +#map2 = affine_map<(d0, d1) -> (d1 floordiv 128, d1 mod 128)> module { + memref.global "private" constant @__constant_1x128xi32 : memref<1x128xi32> = dense<0> {alignment = 64 : i64} func.func @main() { - %0 = tensor.empty() : tensor<1024x1024xi32> - %1 = affine.for %arg0 = 0 to 1024 step 64 iter_args(%arg1 = %0) -> (tensor<1024x1024xi32>) { - %2 = affine.for %arg2 = 0 to 1024 step 64 iter_args(%arg3 = %arg1) -> (tensor<1024x1024xi32>) { - %extracted_slice = tensor.extract_slice %0[%arg0, 0] [64, 1024] [1, 1] : tensor<1024x1024xi32> to tensor<64x1024xi32> - %extracted_slice_0 = tensor.extract_slice %0[0, %arg2] [1024, 64] [1, 1] : tensor<1024x1024xi32> to tensor<1024x64xi32> - %generated = tensor.generate { - ^bb0(%arg4: index, %arg5: index): - %extracted_slice_1 = tensor.extract_slice %extracted_slice[%arg4, 0] [1, 1024] [1, 1] : tensor<64x1024xi32> to tensor<1024xi32> - %extracted_slice_2 = tensor.extract_slice %extracted_slice_0[0, %arg5] [1024, 1] [1, 1] : tensor<1024x64xi32> to tensor<1024xi32> - %cst = arith.constant dense<0> : tensor - %cst_3 = arith.constant dense<0> : tensor<64xi32> - %cst_4 = arith.constant dense<[64, 16]> : tensor<2xi64> - %reshape = tensor.reshape %extracted_slice_1(%cst_4) : (tensor<1024xi32>, tensor<2xi64>) -> tensor<64x16xi32> - %reshape_5 = tensor.reshape %extracted_slice_2(%cst_4) : (tensor<1024xi32>, tensor<2xi64>) -> tensor<64x16xi32> - %3 = cnm.workgroup : !cnm.workgroup<4x8x2> - %4 = cnm.alloc() for %3 : !cnm.buffer<16xi32 on 4x8x2, level 0> - %5 = cnm.scatter %reshape into %4[#map] of %3 : tensor<64x16xi32> into !cnm.buffer<16xi32 on 4x8x2, level 0> - %6 = cnm.alloc() for %3 : !cnm.buffer<16xi32 on 4x8x2, level 0> - %7 = cnm.scatter %reshape_5 into %6[#map] of %3 : tensor<64x16xi32> into !cnm.buffer<16xi32 on 4x8x2, level 0> - %8 = cnm.alloc() for %3 : !cnm.buffer - %9 = cnm.scatter %cst_3 into %8[#map1] of %3 : tensor<64xi32> into !cnm.buffer - %10 = cnm.launch %3 in(%4, %6 : !cnm.buffer<16xi32 on 4x8x2, level 0>, !cnm.buffer<16xi32 on 4x8x2, level 0>) out(%8 : !cnm.buffer) on !cnm.workgroup<4x8x2> { - ^bb0(%arg6: memref<16xi32>, %arg7: memref<16xi32>, %arg8: memref): - linalg.reduce ins(%arg6, %arg7 : memref<16xi32>, memref<16xi32>) outs(%arg8 : memref) dimensions = [0] - (%in: i32, %in_6: i32, %init: i32) { - %11 = arith.muli %in, %in_6 : i32 - %12 = arith.addi %11, %init : i32 - linalg.yield %12 : i32 - } + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1024x1024xi32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1024x1024xi32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1024x1024xi32> + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x128xi32> + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<128x32xi32> + %0 
= scf.for %arg0 = %c0 to %c1024 step %c1 iter_args(%arg1 = %alloc_1) -> (memref<1024x1024xi32>) { + %1 = scf.for %arg2 = %c0 to %c1024 step %c128 iter_args(%arg3 = %arg1) -> (memref<1024x1024xi32>) { + %2 = memref.get_global @__constant_1x128xi32 : memref<1x128xi32> + memref.copy %2, %alloc_2 : memref<1x128xi32> to memref<1x128xi32> + %3 = scf.for %arg4 = %c0 to %c1024 step %c32 iter_args(%arg5 = %alloc_2) -> (memref<1x128xi32>) { + %subview_4 = memref.subview %alloc[%arg0, %arg4] [1, 32] [1, 1] : memref<1024x1024xi32> to memref<1x32xi32, strided<[1024, 1], offset: ?>> + %subview_5 = memref.subview %alloc_0[%arg4, %arg2] [32, 128] [1, 1] : memref<1024x1024xi32> to memref<32x128xi32, strided<[1024, 1], offset: ?>> + %4 = cnm.workgroup : !cnm.workgroup<1x128> + scf.for %arg6 = %c0 to %c128 step %c1 { + scf.for %arg7 = %c0 to %c32 step %c1 { + %8 = memref.load %subview_5[%arg7, %arg6] : memref<32x128xi32, strided<[1024, 1], offset: ?>> + memref.store %8, %alloc_3[%arg6, %arg7] : memref<128x32xi32> + } } - %output, %token = cnm.gather %8[#map2] of %3 : !cnm.buffer into tensor<64xi32> - %reduced = linalg.reduce ins(%output : tensor<64xi32>) outs(%cst : tensor) dimensions = [0] - (%in: i32, %init: i32) { - %11 = arith.addi %in, %init : i32 - linalg.yield %11 : i32 + %5 = cnm.alloc() for %4 : !cnm.buffer<32xi32 on 1x128, level 0> + %6 = cnm.alloc() for %4 : !cnm.buffer<32xi32 on 1x128, level 0> + %7 = cnm.alloc() for %4 : !cnm.buffer + cnm.scatter %subview_4 into %5[#map] of %4 : memref<1x32xi32, strided<[1024, 1], offset: ?>> into !cnm.buffer<32xi32 on 1x128, level 0> + cnm.scatter %alloc_3 into %6[#map1] of %4 : memref<128x32xi32> into !cnm.buffer<32xi32 on 1x128, level 0> + cnm.scatter %arg5 into %7[#map2] of %4 : memref<1x128xi32> into !cnm.buffer + cnm.launch %4 in(%5, %6 : !cnm.buffer<32xi32 on 1x128, level 0>, !cnm.buffer<32xi32 on 1x128, level 0>) out(%7 : !cnm.buffer) on !cnm.workgroup<1x128> { + ^bb0(%arg6: memref<32xi32>, %arg7: memref<32xi32>, %arg8: memref): + %c0_6 = arith.constant 0 : index + %c32_7 = arith.constant 32 : index + %c1_8 = arith.constant 1 : index + scf.for %arg9 = %c0_6 to %c32_7 step %c1_8 { + %8 = memref.load %arg6[%arg9] : memref<32xi32> + %9 = memref.load %arg7[%arg9] : memref<32xi32> + %10 = memref.load %arg8[] : memref + %11 = arith.muli %8, %9 : i32 + %12 = arith.addi %11, %10 : i32 + memref.store %12, %arg8[] : memref } - %extracted = tensor.extract %reduced[] : tensor - tensor.yield %extracted : i32 - } : tensor<64x64xi32> - %inserted_slice = tensor.insert_slice %generated into %arg3[%arg0, %arg2] [64, 64] [1, 1] : tensor<64x64xi32> into tensor<1024x1024xi32> - affine.yield %inserted_slice : tensor<1024x1024xi32> + } + cnm.gather %7[#map2] of %4 into %arg5 : !cnm.buffer into memref<1x128xi32> + scf.yield %arg5 : memref<1x128xi32> + } + %subview = memref.subview %arg3[%arg0, %arg2] [1, 128] [1, 1] : memref<1024x1024xi32> to memref<1x128xi32, strided<[1024, 1], offset: ?>> + memref.copy %3, %subview : memref<1x128xi32> to memref<1x128xi32, strided<[1024, 1], offset: ?>> + scf.yield %arg3 : memref<1024x1024xi32> } - affine.yield %2 : tensor<1024x1024xi32> + scf.yield %1 : memref<1024x1024xi32> } return } } - diff --git a/cinnamon/samples/gemm_gpu.mlir b/cinnamon/samples/gemm_gpu.mlir new file mode 100644 index 0000000..b8484e9 --- /dev/null +++ b/cinnamon/samples/gemm_gpu.mlir @@ -0,0 +1,99 @@ +#map = affine_map<(d0, d1) -> (0)> +#map1 = affine_map<(d0, d1) -> (d1 mod 128)> +#map2 = affine_map<(d0, d1) -> (d1 floordiv 128)> +module { + memref.global "private" 
constant @__constant_1x128xi32 : memref<1x128xi32> = dense<0> {alignment = 64 : i64} + func.func @main() { + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1024x1024xi32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1024x1024xi32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1024x1024xi32> + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x128xi32> + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<128x32xi32> + %0 = scf.for %arg0 = %c0 to %c1024 step %c1 iter_args(%arg1 = %alloc_1) -> (memref<1024x1024xi32>) { + %1 = scf.for %arg2 = %c0 to %c1024 step %c128 iter_args(%arg3 = %arg1) -> (memref<1024x1024xi32>) { + %2 = memref.get_global @__constant_1x128xi32 : memref<1x128xi32> + memref.copy %2, %alloc_2 : memref<1x128xi32> to memref<1x128xi32> + %3 = scf.for %arg4 = %c0 to %c1024 step %c32 iter_args(%arg5 = %alloc_2) -> (memref<1x128xi32>) { + %subview_4 = memref.subview %alloc[%arg0, %arg4] [1, 32] [1, 1] : memref<1024x1024xi32> to memref<1x32xi32, strided<[1024, 1], offset: ?>> + %subview_5 = memref.subview %alloc_0[%arg4, %arg2] [32, 128] [1, 1] : memref<1024x1024xi32> to memref<32x128xi32, strided<[1024, 1], offset: ?>> + %c0_6 = arith.constant 0 : index + scf.for %arg6 = %c0 to %c128 step %c1 { + scf.for %arg7 = %c0 to %c32 step %c1 { + %4 = memref.load %subview_5[%arg7, %arg6] : memref<32x128xi32, strided<[1024, 1], offset: ?>> + memref.store %4, %alloc_3[%arg6, %arg7] : memref<128x32xi32> + } + } + %memref = gpu.alloc () : memref<1x128x32xi32> + %memref_7 = gpu.alloc () : memref<1x128x32xi32> + %memref_8 = gpu.alloc () : memref<1x128xi32> + affine.for %arg6 = 0 to 1 { + affine.for %arg7 = 0 to 128 { + %4 = affine.apply #map(%arg6, %arg7) + %subview_12 = memref.subview %subview_4[%4, 0] [1, 32] [1, 1] : memref<1x32xi32, strided<[1024, 1], offset: ?>> to memref<32xi32, strided<[1], offset: ?>> + %subview_13 = memref.subview %memref[%arg6, %arg7, 0] [1, 1, 32] [1, 1, 1] : memref<1x128x32xi32> to memref<32xi32, strided<[1], offset: ?>> + memref.copy %subview_12, %subview_13 : memref<32xi32, strided<[1], offset: ?>> to memref<32xi32, strided<[1], offset: ?>> + } + } + affine.for %arg6 = 0 to 1 { + affine.for %arg7 = 0 to 128 { + %4 = affine.apply #map1(%arg6, %arg7) + %subview_12 = memref.subview %alloc_3[%4, 0] [1, 32] [1, 1] : memref<128x32xi32> to memref<32xi32, strided<[1], offset: ?>> + %subview_13 = memref.subview %memref_7[%arg6, %arg7, 0] [1, 1, 32] [1, 1, 1] : memref<1x128x32xi32> to memref<32xi32, strided<[1], offset: ?>> + memref.copy %subview_12, %subview_13 : memref<32xi32, strided<[1], offset: ?>> to memref<32xi32, strided<[1], offset: ?>> + } + } + affine.for %arg6 = 0 to 1 { + affine.for %arg7 = 0 to 128 { + %4 = affine.apply #map2(%arg6, %arg7) + %5 = affine.apply #map1(%arg6, %arg7) + %subview_12 = memref.subview %arg5[%4, %5] [1, 1] [1, 1] : memref<1x128xi32> to memref> + %subview_13 = memref.subview %memref_8[%arg6, %arg7] [1, 1] [1, 1] : memref<1x128xi32> to memref> + memref.copy %subview_12, %subview_13 : memref> to memref> + } + } + %c1_9 = arith.constant 1 : index + %c1_10 = arith.constant 1 : index + %c128_11 = arith.constant 128 : index + gpu.launch blocks(%arg6, %arg7, %arg8) in (%arg12 = %c1_10, %arg13 = %c128_11, %arg14 = %c1_9) threads(%arg9, %arg10, %arg11) in (%arg15 = %c1_9, %arg16 = %c1_9, %arg17 = %c1_9) { + %subview_12 = 
memref.subview %memref[%arg6, %arg7, 0] [1, 1, 32] [1, 1, 1] : memref<1x128x32xi32> to memref<32xi32, strided<[1], offset: ?>>
+            %subview_13 = memref.subview %memref_7[%arg6, %arg7, 0] [1, 1, 32] [1, 1, 1] : memref<1x128x32xi32> to memref<32xi32, strided<[1], offset: ?>>
+            %subview_14 = memref.subview %memref_8[%arg6, %arg7] [1, 1] [1, 1] : memref<1x128xi32> to memref<i32, strided<[], offset: ?>>
+            %c0_15 = arith.constant 0 : index
+            %c32_16 = arith.constant 32 : index
+            %c1_17 = arith.constant 1 : index
+            scf.for %arg18 = %c0_15 to %c32_16 step %c1_17 {
+              %4 = memref.load %subview_12[%arg18] : memref<32xi32, strided<[1], offset: ?>>
+              %5 = memref.load %subview_13[%arg18] : memref<32xi32, strided<[1], offset: ?>>
+              %6 = memref.load %subview_14[] : memref<i32, strided<[], offset: ?>>
+              %7 = arith.muli %4, %5 : i32
+              %8 = arith.addi %7, %6 : i32
+              memref.store %8, %subview_14[] : memref<i32, strided<[], offset: ?>>
+            }
+            gpu.terminator
+          }
+          affine.for %arg6 = 0 to 1 {
+            affine.for %arg7 = 0 to 128 {
+              %4 = affine.apply #map2(%arg6, %arg7)
+              %5 = affine.apply #map1(%arg6, %arg7)
+              %subview_12 = memref.subview %memref_8[%arg6, %arg7] [1, 1] [1, 1] : memref<1x128xi32> to memref<i32, strided<[], offset: ?>>
+              %subview_13 = memref.subview %arg5[%4, %5] [1, 1] [1, 1] : memref<1x128xi32> to memref<i32, strided<[], offset: ?>>
+              memref.copy %subview_12, %subview_13 : memref<i32, strided<[], offset: ?>> to memref<i32, strided<[], offset: ?>>
+            }
+          }
+          scf.yield %arg5 : memref<1x128xi32>
+        }
+        %subview = memref.subview %arg3[%arg0, %arg2] [1, 128] [1, 1] : memref<1024x1024xi32> to memref<1x128xi32, strided<[1024, 1], offset: ?>>
+        memref.copy %3, %subview : memref<1x128xi32> to memref<1x128xi32, strided<[1024, 1], offset: ?>>
+        scf.yield %arg3 : memref<1024x1024xi32>
+      }
+      scf.yield %1 : memref<1024x1024xi32>
+    }
+    return
+  }
+}
+
diff --git a/cinnamon/tools/cinm-vulkan-runner/cinm-vulkan-runner.cpp b/cinnamon/tools/cinm-vulkan-runner/cinm-vulkan-runner.cpp
index 05f5b89..ef56a1c 100644
--- a/cinnamon/tools/cinm-vulkan-runner/cinm-vulkan-runner.cpp
+++ b/cinnamon/tools/cinm-vulkan-runner/cinm-vulkan-runner.cpp
@@ -58,7 +58,7 @@ static LogicalResult runMLIRPasses(Operation *op) {
   // memref.load
   passManager.addPass(createLowerAffinePass());
   // affine.apply -> arith ops
-  passManager.addPass(createCnmSPIRVAttachTargetAttributePass(
+  passManager.addPass(cnm::createCnmSPIRVAttachTargetAttributePass(
       cnm::CnmSPIRVAttachTargetAttributePassOptions{
          .spirvCapabilities = {"Shader"},
          .spirvExtensions = {"SPV_KHR_storage_buffer_storage_class"},
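
Minimal sketch (not part of the patch; SSA names are illustrative) of what the updated convertCnmBufferToMemRefType and the alloc conversion produce for the 1x128 workgroup used in the gemm sample: the workgroup shape is now prepended to the buffer's element shape, and cnm.alloc lowers to a gpu.alloc of the combined memref.

  %wg  = cnm.workgroup : !cnm.workgroup<1x128>
  %buf = cnm.alloc() for %wg : !cnm.buffer<32xi32 on 1x128, level 0>
  // after --convert-cnm-to-gpu (workgroup shape ++ buffer shape):
  %mem = gpu.alloc () : memref<1x128x32xi32>

Inside the rewritten gpu.launch, each block id then addresses its 32-element slice of that memref through a rank-reduced memref.subview, as convertLaunchParameter does above.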