diff --git a/.gitignore b/.gitignore index cb8919a..e6ffbae 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .cache .vscode llvm -upmem \ No newline at end of file +upmem +upmem-src diff --git a/cinnamon/include/cinm-mlir/Dialect/Cnm/CMakeLists.txt b/cinnamon/include/cinm-mlir/Dialect/Cnm/CMakeLists.txt index c369955..1d10775 100644 --- a/cinnamon/include/cinm-mlir/Dialect/Cnm/CMakeLists.txt +++ b/cinnamon/include/cinm-mlir/Dialect/Cnm/CMakeLists.txt @@ -9,3 +9,4 @@ add_custom_target(CnmIncGen) # Attributes, Dialect, Operations and Types. add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/cinnamon/include/cinm-mlir/Dialect/Cnm/IR/CnmOps.td b/cinnamon/include/cinm-mlir/Dialect/Cnm/IR/CnmOps.td index 3e093c9..a72b7c3 100644 --- a/cinnamon/include/cinm-mlir/Dialect/Cnm/IR/CnmOps.td +++ b/cinnamon/include/cinm-mlir/Dialect/Cnm/IR/CnmOps.td @@ -31,6 +31,11 @@ def WorkgroupOp : Cnm_Op<"workgroup", []> { let arguments = (ins); let results = (outs WorkgroupType:$result); + let builders = [ + OpBuilder<(ins "ArrayRef":$shape), + "build($_builder, $_state, WorkgroupType::get($_builder.getContext(), shape));"> + ]; + let assemblyFormat = "attr-dict `:` qualified(type($result))"; } @@ -56,6 +61,11 @@ def AllocOp : Cnm_Op<"alloc", [InferWorkgroupTypeFromBuffer<"result", "wg">]> { let arguments = (ins WorkgroupType:$wg); let results = (outs BufferType:$result); + let builders = [ + OpBuilder<(ins "ArrayRef":$bufShape, "Type":$elementTy, "Value":$wg, CArg<"int64_t", "0">:$level), + "build($_builder, $_state, BufferType::get(bufShape, elementTy, wg.getType().cast().getShape(), level), wg);"> + ]; + let extraClassDeclaration = [{ @@ -124,7 +134,7 @@ def LaunchOp: Cnm_Op<"launch", [AttrSizedOperandSegments, IsolatedFromAbove, Sin let regions = (region SizedRegion<1>:$body); let hasVerifier = 1; - let assemblyFormat = "$wg `in` `(` $inBuffers `:` type($inBuffers) `)` `out` `(` $outBuffers `:` type($outBuffers) `)` attr-dict `on` qualified(type($wg)) $body"; + let assemblyFormat = "$wg `ins` `(` $inBuffers `:` type($inBuffers) `)` `outs` `(` $outBuffers `:` type($outBuffers) `)` attr-dict `on` qualified(type($wg)) $body"; let extraClassDeclaration = [{ SmallVector getParams() { @@ -135,6 +145,107 @@ def LaunchOp: Cnm_Op<"launch", [AttrSizedOperandSegments, IsolatedFromAbove, Sin }]; } + +def ComputeOp: Cnm_Op<"compute", [IsolatedFromAbove, + SingleBlockImplicitTerminator<"TerminatorOp">, + AttrSizedOperandSegments]> { + let summary = "Transformable operation"; + let description = [{}]; + + let arguments = (ins DenseI64ArrayAttr:$workgroupShape, + I64Attr:$numInputs, + Variadic:$symbol_bindings, + Variadic:$buffers, + AffineMapArrayAttr:$affineMaps + ); + + let results = (outs Variadic:$results); + let regions = (region SizedRegion<1>:$body); + + let builders = [ + OpBuilder<(ins + CArg<"ArrayRef">:$workgroupShape, + CArg<"ValueRange">:$allBuffers, + CArg<"uint64_t">:$numInputs, + CArg<"ArrayRef">:$affineMaps, + CArg<"ValueRange", "{}">:$symbol_bindings + )>, + OpBuilder<(ins + CArg<"ArrayRef">:$workgroupShape, + CArg<"ValueRange">:$inputs, + CArg<"ValueRange">:$inits, + CArg<"ArrayRef">:$affineMaps, + CArg<"ValueRange", "{}">:$symbol_bindings + )> + ]; + + let skipDefaultBuilders = 1; + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + SmallVector getInBuffers() { + if (getNumInputs() == 0) + return {}; + SmallVector result; + result.reserve(getNumInputs()); + auto begin = getBuffers().begin(); + 
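// the first getNumInputs() entries of getBuffers() are the inputs,
+      // the remaining ones are the outputs (see getOutBuffers below)
+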
result.append(begin, begin + getNumInputs()); + return result; + } + + SmallVector getOutBuffers() { + if (getNumOutputs() == 0) + return {}; + SmallVector result; + result.reserve(getNumOutputs()); + auto begin = getBuffers().begin() + getNumInputs(); + result.append(getBuffers().begin() + getNumInputs(), getBuffers().end()); + return result; + } + + uint64_t getNumOutputs() { + return getBuffers().size() - getNumInputs(); + } + + SmallVector getInMaps() { + if (getNumInputs() == 0) + return {}; + SmallVector result; + result.reserve(getNumInputs()); + auto begin = getAffineMaps().getAsValueRange().begin(); + result.append(begin, begin + getNumInputs()); + return result; + } + + SmallVector getOutMaps() { + if (getNumOutputs() == 0) + return {}; + SmallVector result; + result.reserve(getNumOutputs()); + auto it = getAffineMaps().getAsValueRange(); + result.append(it.begin() + getNumInputs(), it.end()); + return result; + } + + template::value> + SmallVector getAffineMapsVec() { + SmallVector result(getAffineMaps().getAsValueRange()); + return result; + } + + MutableArrayRef getKernelArgs() { + return getBody().getArguments(); + } + + uint64_t getNumSymbols() { + return getSymbolBindings().size(); + } + + }]; +} + + def TerminatorOp: Cnm_Op<"terminator", [Terminator]> { let summary = "Terminates an `upmem.launch` operator region."; let description = [{}]; @@ -144,4 +255,5 @@ def TerminatorOp: Cnm_Op<"terminator", [Terminator]> { let assemblyFormat = "attr-dict"; } + #endif diff --git a/cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/CMakeLists.txt b/cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/CMakeLists.txt new file mode 100644 index 0000000..97aa183 --- /dev/null +++ b/cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/CMakeLists.txt @@ -0,0 +1,13 @@ + +set(LLVM_TARGET_DEFINITIONS CnmTransformOps.td) +mlir_tablegen(CnmTransformOps.h.inc -gen-op-decls) +mlir_tablegen(CnmTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRCnmTransformOpsIncGen) +add_dependencies(CnmIncGen MLIRCnmTransformOpsIncGen) + + +set(LLVM_TARGET_DEFINITIONS TransformPass.td) +mlir_tablegen(TransformPass.h.inc -gen-pass-decls -name CnmTransformInterpreter) +add_public_tablegen_target(CnmTransformInterpreterPassIncGen) +add_dependencies(CnmIncGen CnmTransformInterpreterPassIncGen) +add_mlir_doc(Passes CnmTransformInterpreter ./ -gen-pass-doc) diff --git a/cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/CnmTransformOps.h b/cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/CnmTransformOps.h new file mode 100644 index 0000000..627bf19 --- /dev/null +++ b/cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/CnmTransformOps.h @@ -0,0 +1,24 @@ + +#pragma once + +#include +#include + +#include + +#include +#include +#include +#include +#include + +#define GET_OP_CLASSES +#include + + +namespace mlir::cnm { + +void registerTransformDialectExtension(::mlir::DialectRegistry ®istry); + + +} // namespace mlir::cnm \ No newline at end of file diff --git a/cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/CnmTransformOps.td b/cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/CnmTransformOps.td new file mode 100644 index 0000000..9acbacd --- /dev/null +++ b/cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/CnmTransformOps.td @@ -0,0 +1,88 @@ + + +include "mlir/Dialect/Transform/IR/MatchInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include 
"mlir/Interfaces/SideEffectInterfaces.td" + + +def CnmExpandDimOp : Op, + TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ +::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::cnm::ComputeOp op, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; + + let arguments = (ins Transform_ConcreteOpType<"cnm.compute">:$target, + I64Attr:$dim, + I64Attr:$factor); + let results = (outs); + let assemblyFormat = + "$target `dim` $dim `by` `factor` $factor attr-dict `:` functional-type(operands, results)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::cnm::ComputeOp op, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def CnmSwapDimsOp : Op, + TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + + }]; + + let arguments = (ins Transform_ConcreteOpType<"cnm.compute">:$target, + I64Attr:$dim0, + I64Attr:$dim1); + let results = (outs); + let assemblyFormat = + "$target `,` $dim0 `,` $dim1 attr-dict `:` functional-type(operands, results)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::cnm::ComputeOp op, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def CnmPeelRightOp : Op, + TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + + }]; + + let arguments = (ins Transform_ConcreteOpType<"cnm.compute">:$target); + let results = (outs); + let assemblyFormat = + "$target attr-dict `:` functional-type(operands, results)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::cnm::ComputeOp op, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} \ No newline at end of file diff --git a/cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/CnmTransformPass.h b/cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/CnmTransformPass.h new file mode 100644 index 0000000..b374175 --- /dev/null +++ b/cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/CnmTransformPass.h @@ -0,0 +1,13 @@ + + +#include +#include +#include + +namespace mlir::cnm { + +#define GEN_PASS_DECL +#define GEN_PASS_REGISTRATION +#include "cinm-mlir/Dialect/Cnm/TransformOps/TransformPass.h.inc" + +} // namespace mlir::cnm \ No newline at end of file diff --git a/cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/TransformPass.td b/cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/TransformPass.td new file mode 100644 index 0000000..601d0e2 --- /dev/null +++ b/cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/TransformPass.td @@ -0,0 +1,45 @@ +//===- Passes.td - Cnm dialect passes ---------------------*- tablegen -*-===// +// +// This is the definitions file for the Cnm dialect transform passes. 
+// +//===----------------------------------------------------------------------===// + +#ifndef CNM_TRANSFORM_PASSES +#define CNM_TRANSFORM_PASSES + +include "mlir/Pass/PassBase.td" + +def CnmApplyTransformScriptPass : Pass<"cnm-apply-transform", "ModuleOp"> { + let summary = "Apply a transform dialect script to cnm compute ops"; + let description = [{}]; +// let constructor = "cnm::createApplyTransformScriptPass()"; + let dependentDialects = ["transform::TransformDialect"]; + let options = [ + Option<"debugPayloadRootTag", "debug-payload-root-tag", "std::string", + /*default=*/[{""}], + "Select the operation with 'transform.target_tag' attribute having " + "the given value as payload IR root. If empty select the pass " + "anchor operation as the payload IR root.">, + Option<"disableExpensiveChecks", "disable-expensive-checks", "bool", + "false", + "Disable expensive checks in the interpreter for a faster run.">, + Option<"transformFileName", "transform-file-name", "std::string", "", + "File name of the transform script">, + Option<"entryPoint", "entry-point", "std::string", + /*default=*/[{transform::TransformDialect::kTransformEntryPointSymbolName.str()}], + "Entry point of the pass pipeline.">, + Option<"debugTransformRootTag", "debug-transform-root-tag", "std::string", + /*default=*/[{""}], + "Select the operation with 'transform.target_tag' attribute having " + "the given value as container IR for top-level transform ops. This " + "allows user control on what transformation to apply. If empty, " + "select the container of the top-level transform op." + >, + ListOption<"transformLibraryPaths", "transform-library-paths", "std::string", + "Optional paths to files with modules that should be " + "merged into the transform module to provide the " + "definitions of external named sequences."> + ]; +} + +#endif // CNM_TRANSFORM_PASSES diff --git a/cinnamon/include/cinm-mlir/Dialect/Cnm/Transforms/CnmComputeTransforms.h b/cinnamon/include/cinm-mlir/Dialect/Cnm/Transforms/CnmComputeTransforms.h new file mode 100644 index 0000000..cab3f41 --- /dev/null +++ b/cinnamon/include/cinm-mlir/Dialect/Cnm/Transforms/CnmComputeTransforms.h @@ -0,0 +1,185 @@ + +#include +#include +#include +#include +#include +namespace mlir::cnm { + +// Note: in-place transformations don't use a listener, +// others do. + +/// Reshape the workgroup by turning a dimension D at index `dim` +/// into two dimensions, of size `factor` and `D/factor`. Fails if +/// D is not divisible by factor. +/// This is an in-place transformation. +/// +/// ExpandWorkgroupDim(dim=0, factor=2) turns +/// ``` +/// cnm.compute +/// ins(%a[(i)->(i)]: memref<1024xi32>) +/// outs(%o[(i)->(i)]: memref<1024xi32>) +/// on hierarchy<1024> +/// do (%a1: memref, %o1: memref) { +/// %x = memref.load %a1[] +/// %t2 = arith.muli %x, 2 +/// memref.store %t2, %o1[] +/// } +/// ``` +/// into +/// ``` +/// cnm.compute +/// ins(%a[(i,j)->(i*512+j)]: memref<1024xi32>) +/// outs(%o[(i,j)->(i*512+j)]: memref<1024xi32>) +/// on hierarchy<2x512> +/// do (%a1: memref, %o1: memref) { +/// %x = memref.load %a1[] +/// %t2 = arith.muli %x, 2 +/// memref.store %t2, %o1[] +/// } +/// ``` +LogicalResult expandWorkshoupDim(cnm::ComputeOp compute, uint64_t dim, + int64_t factor); + +LogicalResult swapWorkgroupDims(cnm::ComputeOp compute, uint64_t dim0, uint64_t dim1); + +/// Turn the leftmost dimension of the workgroup into an outer parallel loop. +/// This transformation might delete the op if the workgroup has only a single +/// dimension. 
+/// ``` +/// cnm.compute +/// ins(%a[(i,j)->(i*512+j)]: memref<1024xi32>) +/// outs(%o[(i,j)->(i*512+j)]: memref<1024xi32>) +/// on hierarchy<2x512> +/// do (%a1: memref, %o1: memref) { +/// %x = memref.load %a1[] +/// %t2 = arith.muli %x, 2 +/// memref.store %t2, %o1[] +/// } +/// ``` +/// into +/// ``` +/// affine.parallel (%i) = (0) to (2) { +/// %as = memref.subview %a[%i][512][1] +/// %os = memref.subview %o[%i][512][1] +/// cnm.compute +/// ins(%as[(j)->(j)]: memref<512xi32>) +/// outs(%os[(j)->(j)]: memref<512xi32>) +/// on hierarchy<512> +/// do (%a1: memref, %o1: memref) { +/// %x = memref.load %a1[] +/// %t2 = arith.muli %x, 2 +/// memref.store %t2, %o1[] +/// } +/// } +/// ``` +FailureOr> +peelLeft(cnm::ComputeOp compute, OpBuilder::Listener *listener = nullptr); + +/// Turn the rightmost dimension of the workgroup into a +/// parallel loop within the kernel. +/// This transformation fails if the workgroup +/// has only a single dimension. +/// This is an in-place transformation. +/// ``` +/// cnm.compute +/// ins(%as[(i,j)->(i,j)]: memref<2x512xi32>) +/// outs(%os[(i,j)->(i,j)]: memref<2x512xi32>) +/// on hierarchy<2x512> +/// do (%a1: memref, %o1: memref) { +/// %x = memref.load %a1[] +/// %t2 = arith.muli %x, 2 +/// memref.store %t2, %o1[] +/// } +/// ``` +/// into +/// ``` +/// cnm.compute +/// ins(%as[(i)->(i)]: memref<2x512xi32>) +/// outs(%os[(i)->(i)]: memref<2x512xi32>) +/// on hierarchy<2> +/// do (%a1: memref<512xi32>, +/// %o1: memref<512xi32>) { +/// affine.parallel (%i) = (0) to (512) { +/// %x = memref.load %a1[%i] +/// %t2 = arith.muli %x, 2 +/// memref.store %t2, %o1[%i] +/// } +/// } +/// ``` +/// Broadcast: +/// To support broadcast semantics, we ignore those buffers that do not +/// use the last dimension of the workgroup in their scatter maps. +/// ``` +/// cnm.compute +/// ins(%arg0[(i, j) -> ()]: memref<1024xi32>) +/// outs(%arg1[(i, j) -> (i, j)]: memref<2x512xi32>) +/// on hierarchy<2x512> +/// do (%a1: memref<1024xi32>, %o1: memref) { +/// affine.for %i = 0 to 1024 { +/// %0 = memref.load %a1[%i] : memref<1024xi32> +/// %1 = memref.load %o1[] : memref +/// %2 = arith.addi %0, %1 : i32 +/// memref.store %2, %o1[] : memref +/// } +/// cnm.terminator +/// } +/// ``` +/// into +/// ``` +/// memref<2x512xi32> cnm.compute +/// ins(%arg0[(i) -> ()]: memref<1024xi32>) +/// outs(%r[(i) -> (i)]: memref<2x512xi32>) +/// on hierarchy<2> +/// do (%a1: memref<1024xi32>, %o1: memref<512xi32>) { +/// affine.for %j = 0 to 512 { +/// affine.for %i = 0 to 1024 { +/// %0 = memref.load %a1[%i] : memref<1024xi32> +/// %1 = memref.load %o1[%j] : memref<512xi32> +/// %2 = arith.addi %0, %1 : i32 +/// memref.store %2, %o1[%j] : memref<512xi32> +/// } +/// } +/// cnm.terminator +/// } +/// ``` +LogicalResult peelRight(cnm::ComputeOp compute); + +/// Reshape the inputs to so that they match the workgroup shape. +/// Currently we support that only if the structuring of the affine +/// maps into the new shape produces an identity map. +/// +/// This is not an in-place transformation. 
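+/// (It materializes `memref.reshape` ops in front of the compute op, which is
+/// why it takes a listener.)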
+/// +/// ``` +/// cnm.compute +/// ins(%a[(i,j)->(i*512+j)]: memref<1024xi32>) +/// outs(%o[(i,j)->(i*512+j)]: memref<1024xi32>) +/// on hierarchy<2x512> +/// do (%a1: memref, %o1: memref) { +/// %x = memref.load %a1[] +/// %t2 = arith.muli %x, 2 +/// memref.store %t2, %o1[] +/// } +/// ``` +/// into +/// ``` +/// %as = memref.reshape %a: memref<1024xi32> to memref<2x512xi32> +/// %os = memref.reshape %o: memref<1024xi32> to memref<2x512xi32> +/// cnm.compute +/// ins(%as[(i,j)->(i,j)]: memref<2x512xi32>) +/// outs(%os[(i,j)->(i,j)]: memref<2x512xi32>) +/// on hierarchy<2x512> +/// do (%a1: memref, %o1: memref) { +/// %x = memref.load %a1[] +/// %t2 = arith.muli %x, 2 +/// memref.store %t2, %o1[] +/// } +/// ``` +LogicalResult normalizeInputs(cnm::ComputeOp compute, + OpBuilder::Listener *listener = nullptr); + +/// Lower a cnm.compute to lower level cnm ops +void lowerComputeToLaunch(cnm::ComputeOp op, + OpBuilder::Listener *listener = nullptr); +} // namespace mlir::cnm \ No newline at end of file diff --git a/cinnamon/include/cinm-mlir/Dialect/Cnm/Transforms/Passes.h b/cinnamon/include/cinm-mlir/Dialect/Cnm/Transforms/Passes.h index 75d19b7..b9be255 100644 --- a/cinnamon/include/cinm-mlir/Dialect/Cnm/Transforms/Passes.h +++ b/cinnamon/include/cinm-mlir/Dialect/Cnm/Transforms/Passes.h @@ -5,7 +5,6 @@ #pragma once #include "mlir/Pass/Pass.h" - namespace mlir { namespace cnm { diff --git a/cinnamon/include/cinm-mlir/Dialect/Cnm/Transforms/Passes.td b/cinnamon/include/cinm-mlir/Dialect/Cnm/Transforms/Passes.td index c5792b2..2c3d2d0 100644 --- a/cinnamon/include/cinm-mlir/Dialect/Cnm/Transforms/Passes.td +++ b/cinnamon/include/cinm-mlir/Dialect/Cnm/Transforms/Passes.td @@ -74,4 +74,12 @@ def CnmHoistWorkgroupsPass : Pass<"cnm-hoist-workgroups", "func::FuncOp"> { } +def CnmLowerComputePass : Pass<"cnm-lower-compute", "func::FuncOp"> { + let summary = "Lower cnm.compute into other cnm ops."; + let description = [{ + + }]; + let dependentDialects = []; +} + #endif // CNM_TRANSFORM_PASSES diff --git a/cinnamon/include/cinm-mlir/Utils/CinmUtils.h b/cinnamon/include/cinm-mlir/Utils/CinmUtils.h index 5bc67d6..269fbcb 100644 --- a/cinnamon/include/cinm-mlir/Utils/CinmUtils.h +++ b/cinnamon/include/cinm-mlir/Utils/CinmUtils.h @@ -1,5 +1,6 @@ #include +#include #include namespace mlir { @@ -13,8 +14,8 @@ bool scatteredMemrefIsContiguous(TypedValue value, /// Simplify an affine map given static upper bounds on the inputs. /// This is used to simplify even more the affine maps on the CNM and UPMEM -/// levels, given knowledge of the workgroup shape. That makes the generated code -/// simpler, and gives more opportunities for broadcasting. +/// levels, given knowledge of the workgroup shape. That makes the generated +/// code simpler, and gives more opportunities for broadcasting. 
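+/// For example, with dimSizes = (2, 512) the map
+/// (i, j) -> ((i * 512 + j) floordiv 512) simplifies to (i, j) -> (i),
+/// because i * 512 + j is always below 1024.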
AffineMap simplifyAffineMapWithBounds(AffineMap map, llvm::ArrayRef dimSizes); } // namespace mlir diff --git a/cinnamon/lib/Dialect/Cinm/TransformOps/CnmTransformExtension.cpp b/cinnamon/lib/Dialect/Cinm/TransformOps/CnmTransformExtension.cpp new file mode 100644 index 0000000..e69de29 diff --git a/cinnamon/lib/Dialect/Cnm/CMakeLists.txt b/cinnamon/lib/Dialect/Cnm/CMakeLists.txt index 9f57627..59e31cb 100644 --- a/cinnamon/lib/Dialect/Cnm/CMakeLists.txt +++ b/cinnamon/lib/Dialect/Cnm/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) \ No newline at end of file diff --git a/cinnamon/lib/Dialect/Cnm/IR/CnmOps.cpp b/cinnamon/lib/Dialect/Cnm/IR/CnmOps.cpp index 19577a9..b78d2d6 100644 --- a/cinnamon/lib/Dialect/Cnm/IR/CnmOps.cpp +++ b/cinnamon/lib/Dialect/Cnm/IR/CnmOps.cpp @@ -6,7 +6,14 @@ #include #include +#include +#include +#include +#include +#include +#include #include +#include #include #include @@ -17,8 +24,10 @@ #include #include #include +#include #include #include +#include #include #include @@ -61,6 +70,334 @@ ::mlir::LogicalResult GatherOp::inferReturnTypeComponents( return failure(); } +static ParseResult parseAffineMapInlineOrNot(OpAsmParser &parser, + Attribute &affineMapAttr) { + + if (failed(parser.parseCustomAttributeWithFallback( + affineMapAttr, Type(), [&](Attribute &result, Type) { + AffineMap inlineMap; + if (parser.parseAffineMap(inlineMap)) + return failure(); + result = AffineMapAttr::get(inlineMap); + return success(); + }))) + return failure(); + if (!affineMapAttr.isa()) + return parser.emitError( + parser.getCurrentLocation(), + "invalid kind of attribute specified, expected affine map"); + return success(); +} +static ParseResult parseComputeOperand(OpAsmParser &parser, + Attribute &affineMapAttr, + OperationState &result, + bool canBeResult) { + OpAsmParser::UnresolvedOperand operand; + Type type; + + if (parser.parseOperand(operand) || parser.parseLSquare() || + parseAffineMapInlineOrNot(parser, affineMapAttr) || + parser.parseRSquare() || parser.parseColonType(type) || + parser.resolveOperand(operand, type, result.operands)) { + return failure(); + } + // a tensor result + if (canBeResult && type.isa()) { + result.addTypes(type); + } + + return success(); +} + +static ParseResult +parseComputeOperandList(OpAsmParser &parser, StringRef kw, + llvm::SmallVectorImpl &affineMaps, + OperationState &result, bool canBeResult = false) { + + if (parser.parseKeyword(kw) || parser.parseLParen() || + parser.parseCommaSeparatedList([&]() -> ParseResult { + return parseComputeOperand(parser, affineMaps.emplace_back(), result, + canBeResult); + }) || + parser.parseRParen()) + return failure(); + return success(); +} + +ParseResult ComputeOp::parse(OpAsmParser &parser, OperationState &result) { + /* + cnm.launch + (symbols [%O1, %O2])? 
+ ins(%as[(i)->(i)]: memref<2x512xi32>) + outs(%os[(i)->(i)]: memref<2x512xi32>) + on hierarchy<2> + do (%a1: memref<512xi32>, + %o1: memref<512xi32>) { + affine.parallel (%i) = (0) to (512) { + %x = memref.load %a1[%i] + %t2 = arith.muli %x, 2 + memref.store %t2, %o1[%i] + } + } + */ + int numSymbols = 0; + if (succeeded(parser.parseOptionalKeyword("symbols"))) { + SmallVector symbolBindings; + if (parser.parseOperandList(symbolBindings, + OpAsmParser::Delimiter::Square) || + parser.resolveOperands(symbolBindings, + parser.getBuilder().getIndexType(), + result.operands)) + return failure(); + + numSymbols = symbolBindings.size(); + } + + SmallVector affineMaps; + if (parseComputeOperandList(parser, "ins", affineMaps, result)) + return failure(); + const int64_t numInputs = affineMaps.size(); + if (parseComputeOperandList(parser, "outs", affineMaps, result, + /*canBeResult=*/true)) + return failure(); + const int numBuffers = affineMaps.size(); + result.addAttribute( + getNumInputsAttrName(result.name), + IntegerAttr::get(IntegerType::get(result.getContext(), 64), numInputs)); + result.addAttribute(getAffineMapsAttrName(result.name), + ArrayAttr::get(result.getContext(), affineMaps)); + + result.addAttribute( + getOperandSegmentSizesAttrName(result.name), + parser.getBuilder().getDenseI32ArrayAttr({numSymbols, numBuffers})); + + SmallVector workgroupDimensions; + if (parser.parseKeyword("on") || parser.parseKeyword("hierarchy") || + parser.parseLess() || + parser.parseDimensionList(workgroupDimensions, false, false) || + parser.parseGreater()) + return failure(); + + result.addAttribute( + getWorkgroupShapeAttrName(result.name), + DenseI64ArrayAttr::get(result.getContext(), workgroupDimensions)); + + SmallVector args; + if (parser.parseKeyword("do") || + parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren, + /*allowType=*/true)) { + return failure(); + } + auto ®ion = *result.addRegion(); + if (parser.parseRegion(region, args)) { + return failure(); + } + ComputeOp::ensureTerminator(region, parser.getBuilder(), result.location); + + // todo results, bufferization + + return success(); +} + +void ComputeOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, + ArrayRef workgroupShape, ValueRange allBuffers, + uint64_t numInputs, ArrayRef affineMaps, + ValueRange symbolBindings) { + if (numInputs > allBuffers.size()) { + mlir::emitError(state.location, "Invalid number of inputs ") + << numInputs << " > " << allBuffers.size(); + return; + } + + build(builder, state, workgroupShape, allBuffers.slice(0, numInputs), + allBuffers.slice(numInputs, allBuffers.size() - numInputs), affineMaps, + symbolBindings); +} + +void ComputeOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, + ArrayRef workgroupShape, ValueRange inputs, + ValueRange inits, ArrayRef affineMaps, + ValueRange symbolBindings) { + + state.addOperands(symbolBindings); + state.addOperands(inputs); + state.addOperands(inits); + + state.addAttribute(getOperandSegmentSizesAttrName(state.name), + builder.getDenseI32ArrayAttr( + {static_cast(symbolBindings.size()), + static_cast(inputs.size() + inits.size())})); + + state.addAttribute(getNumInputsAttrName(state.name), + builder.getI64IntegerAttr(inputs.size())); + state.addAttribute(getWorkgroupShapeAttrName(state.name), + builder.getDenseI64ArrayAttr(workgroupShape)); + state.addAttribute(getAffineMapsAttrName(state.name), + builder.getAffineMapArrayAttr(affineMaps)); + + auto &entry = state.addRegion()->emplaceBlock(); + for (auto [i, buf] : 
llvm::enumerate( + llvm::drop_begin(state.operands, symbolBindings.size()))) { + + if (buf.getType().isa()) { + auto bufTy = buf.getType().cast(); + auto bufShape = bufTy.getShape(); + if (i < affineMaps.size()) { + auto map = affineMaps[i]; + auto argRank = bufShape.size() - map.getNumResults(); + if (argRank >= 0) { + auto argShape = bufShape.slice(map.getNumResults()); + auto argTy = MemRefType::get(argShape, bufTy.getElementType()); + entry.addArgument(argTy, state.location); + } else { + mlir::emitError(state.location, "Buffer of type ") + << buf.getType() << " cannot be addressed by map " + << AffineMapAttr::get(map); + } + } + // tensor result + if (i >= inputs.size() && buf.getType().isa()) + state.addTypes(buf.getType()); + } + } + ComputeOp::ensureTerminator(*state.regions[0], builder, state.location); +} + +void ComputeOp::print(OpAsmPrinter &out) { + bool first = true; + out.increaseIndent(); + out.printNewline(); + auto syms = getSymbolBindings(); + if (!syms.empty()) { + out << "symbols ["; + out.printOperands(syms); + out << "]"; + out.printNewline(); + } + out << "ins("; + for (auto [buf, map, i] : + llvm::zip(getBuffers(), getAffineMaps(), llvm::seq(0UL, 100000UL))) { + if (i == getNumInputs()) { + out << ")"; + out.printNewline(); + out << "outs("; + first = true; + } + if (!first) { + out << ", "; + } + first = false; + + out.printOperand(buf); + out << "["; + map.cast().getValue().print(out.getStream()); + // out.printAttributeWithoutType(map); + out << "] : "; + out.printType(buf.getType()); + } + out << ") "; + out.printNewline(); + out << "on hierarchy<"; + out.printDimensionList(getWorkgroupShape()); + out << ">"; + out.printNewline(); + out << "do ("; + llvm::interleaveComma(getBody().getArguments(), out, + [&](auto arg) { out.printRegionArgument(arg); }); + out << ") "; + out.printRegion(getBody(), false, false); + out.decreaseIndent(); +} + +InFlightDiagnostic emitNiceError(Operation *op, Location loc, + const Twine &message) { + InFlightDiagnostic diag = mlir::emitError(loc, message); + if (op->getContext()->shouldPrintOpOnDiagnostic()) { + diag.attachNote(op->getLoc()) + .append("see current operation: ") + .appendOp(*op, OpPrintingFlags().printGenericOpForm()); + } + return diag; +} + +LogicalResult ComputeOp::verify() { + if (getWorkgroupShape().empty()) { + return emitOpError("has empty workgroup shape"); + } + if (getAffineMaps().size() != getBuffers().size()) { + return emitOpError("affine map count does not match in/out buffer count (") + << getAffineMaps().size() << " != " << getBuffers().size() << ")"; + } + auto args = getBody().getArguments(); + if (args.size() != getBuffers().size()) { + return emitOpError( + "kernel argument count does not match in/out buffer count (") + << args.size() << " != " << getBuffers().size() << ")"; + } + + // compute op may be partially bufferized + SmallVector tensorArgs; + for (auto [i, buf] : llvm::enumerate(getOutBuffers())) { + if (buf.getType().isa()) { + tensorArgs.push_back(buf.getType()); + } else if (!buf.getType().isa()) { + return emitOpError("out argument #") + << i << " should be a tensor or memref"; + } + } + if (tensorArgs != getResultTypes()) { + return emitOpError("tensor results do not match tensor arguments"); + } + + const auto symbolCount = getSymbolBindings().size(); + + for (auto [arg, buf, map, i] : llvm::zip( + args, getBuffers(), getAffineMaps().getAsValueRange(), + llvm::seq(0UL, 10000000UL))) { + if (!arg.getType().isa()) + return emitNiceError(*this, arg.getLoc(), "kernel argument #") + << i 
<< " should be a memref"; + + if (map.getNumDims() != getWorkgroupShape().size()) + return emitOpError("map for argument #") + << i << " should have " << getWorkgroupShape().size() + << " input dimensions, got " << map.getNumDims(); + + if (map.getNumSymbols() != symbolCount) + return emitOpError("map for argument #") + << i << " should have " << symbolCount << " input symbols, got " + << map.getNumSymbols(); + + auto argTy = arg.getType().cast(); + auto inputTy = buf.getType().cast(); + if (argTy.getElementType() != inputTy.getElementType()) + return emitNiceError(*this, arg.getLoc(), "Kernel argument #") + << i << " should have element type " << inputTy.getElementType(); + + auto argShape = argTy.getShape(); + auto bufShape = inputTy.getShape(); + + if (argShape.size() > bufShape.size()) + return emitNiceError(*this, arg.getLoc(), "Kernel argument #") + << i << " should have fewer than " << bufShape.size() + << " dimensions, got " << argShape.size(); + + if (map.getNumResults() + argShape.size() != bufShape.size()) + return emitNiceError(*this, arg.getLoc(), + "Buffer, map, and kernel argument #") + << i << " are incompatible"; + + if (bufShape.slice(map.getNumResults()) != argShape) { + return emitNiceError(*this, arg.getLoc(), "Kernel argument #") + << i << "shape should be suffix of corresponding buffer shape, (" + << bufShape.slice(map.getNumResults()) << " != " << argShape + << ")"; + } + } + + return success(); +} + LogicalResult LaunchOp::verify() { auto bodyArgs = getBody().getArguments(); auto operands = getParams(); diff --git a/cinnamon/lib/Dialect/Cnm/TransformOps/CMakeLists.txt b/cinnamon/lib/Dialect/Cnm/TransformOps/CMakeLists.txt new file mode 100644 index 0000000..c955508 --- /dev/null +++ b/cinnamon/lib/Dialect/Cnm/TransformOps/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_dialect_library(CnmTransformDialectExtension + CnmTransformOps.cpp + CnmTransformExtension.cpp + CnmApplyTransformPass.cpp + +DEPENDS + CnmIncGen + CnmTransformPassesIncGen + +LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMDialect + MLIRAffineDialect + MLIRLinalgDialect + MLIRTransformUtils + MLIRTransformDialect + MLIRTransformDialectUtils +) diff --git a/cinnamon/lib/Dialect/Cnm/TransformOps/CnmApplyTransformPass.cpp b/cinnamon/lib/Dialect/Cnm/TransformOps/CnmApplyTransformPass.cpp new file mode 100644 index 0000000..8fc1f41 --- /dev/null +++ b/cinnamon/lib/Dialect/Cnm/TransformOps/CnmApplyTransformPass.cpp @@ -0,0 +1,44 @@ + +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include +#include +#include +#include +#include + +namespace mlir { +namespace cnm { +#define GEN_PASS_DEF_CNMAPPLYTRANSFORMSCRIPTPASS +#include "cinm-mlir/Dialect/Cnm/TransformOps/TransformPass.h.inc" +} // namespace cnm +} // namespace mlir + +using namespace mlir; + +struct CnmTransformDialectInterpreterPass + : public transform::TransformInterpreterPassBase< + CnmTransformDialectInterpreterPass, + cnm::impl::CnmApplyTransformScriptPassBase> { + + CnmTransformDialectInterpreterPass() = default; + CnmTransformDialectInterpreterPass( + const CnmTransformDialectInterpreterPass &pass) + : TransformInterpreterPassBase(pass) { + + debugTransformRootTag = pass.debugTransformRootTag; + debugPayloadRootTag = pass.debugPayloadRootTag; + disableExpensiveChecks = pass.disableExpensiveChecks; + transformFileName = pass.transformFileName; + entryPoint = pass.entryPoint; + transformLibraryPaths = pass.transformLibraryPaths; + } + CnmTransformDialectInterpreterPass( + const cnm::CnmApplyTransformScriptPassOptions &options) { + debugTransformRootTag 
= options.debugTransformRootTag; + debugPayloadRootTag = options.debugPayloadRootTag; + disableExpensiveChecks = options.disableExpensiveChecks; + transformFileName = options.transformFileName; + entryPoint = options.entryPoint; + transformLibraryPaths = options.transformLibraryPaths; + } +}; diff --git a/cinnamon/lib/Dialect/Cnm/TransformOps/CnmTransformExtension.cpp b/cinnamon/lib/Dialect/Cnm/TransformOps/CnmTransformExtension.cpp new file mode 100644 index 0000000..ace7b3f --- /dev/null +++ b/cinnamon/lib/Dialect/Cnm/TransformOps/CnmTransformExtension.cpp @@ -0,0 +1,57 @@ + +// In CnmTransformExtension.cpp. +#include "cinm-mlir/Dialect/Cnm/IR/CnmBase.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include + +// Define a new Transform dialect extension. This uses the CRTP idiom to +// identify extensions. +class CnmTransformExtension + : public ::mlir::transform::TransformDialectExtension< + CnmTransformExtension> { +public: + // The extension must derive the base constructor. + using Base::Base; + + // This function initializes the extension, similarly to `initialize` in + // dialect definitions. List individual operations and dependent dialects + // here. + void init(); +}; + +void CnmTransformExtension::init() { + // Similarly to dialects, an extension can declare a dependent dialect. This + // dialect will be loaded along with the extension and, therefore, along with + // the Transform dialect. Only declare as dependent the dialects that contain + // the attributes or types used by transform operations. Do NOT declare as + // dependent the dialects produced during the transformation. + // + // declareDependentDialect(); + + // When transformations are applied, they may produce new operations from + // previously unloaded dialects. Typically, a pass would need to declare + // itself dependent on the dialects containing such new operations. To avoid + // confusion with the dialects the extension itself depends on, the Transform + // dialects differentiates between: + // - dependent dialects, which are used by the transform operations, and + // - generated dialects, which contain the entities (attributes, operations, + // types) that may be produced by applying the transformation even when + // not present in the original payload IR. + // In the following chapter, we will be add operations that generate function + // calls and structured control flow operations, so let's declare the + // corresponding dialects as generated. + declareGeneratedDialect<::mlir::cnm::CnmDialect>(); + declareGeneratedDialect<::mlir::memref::MemRefDialect>(); + + // Finally, we register the additional transform operations with the dialect. + registerTransformOps< +#define GET_OP_LIST +#include "cinm-mlir/Dialect/Cnm/TransformOps/CnmTransformOps.cpp.inc" + >(); +} + +void mlir::cnm::registerTransformDialectExtension( + ::mlir::DialectRegistry ®istry) { + registry.addExtensions(); +} \ No newline at end of file diff --git a/cinnamon/lib/Dialect/Cnm/TransformOps/CnmTransformOps.cpp b/cinnamon/lib/Dialect/Cnm/TransformOps/CnmTransformOps.cpp new file mode 100644 index 0000000..e787738 --- /dev/null +++ b/cinnamon/lib/Dialect/Cnm/TransformOps/CnmTransformOps.cpp @@ -0,0 +1,90 @@ + +#include +#include +#include +#include +#include +#include +#include +#include + +#define GET_OP_CLASSES +#include + +using namespace mlir; +using namespace mlir::transform; + +// Implementation of our Transform dialect operation. 
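+// (Here: CnmExpandDimOp::applyToOne; CnmSwapDimsOp and CnmPeelRightOp below
+// follow the same pattern.)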
+// This operation returns a tri-state result that can be one of: +// - success when the transformation succeeded; +// - definite failure when the transformation failed in such a way that +// following transformations are impossible or undesirable, typically it could +// have left payload IR in an invalid state; it is expected that a diagnostic +// is emitted immediately before returning the definite error; +// - silenceable failure when the transformation failed but following +// transformations are still applicable, typically this means a precondition +// for the transformation is not satisfied and the payload IR has not been +// modified. The silenceable failure additionally carries a Diagnostic that +// can be emitted to the user. +DiagnosedSilenceableFailure CnmExpandDimOp::applyToOne(TransformRewriter &, + cnm::ComputeOp compute, + ApplyToEachResultList &, + TransformState &) { + + if (failed(cnm::expandWorkshoupDim(compute, getDim(), getFactor()))) { + DiagnosedSilenceableFailure diag = emitDefaultSilenceableFailure(compute); + diag.attachNote() << "Transform failed"; + } + + return DiagnosedSilenceableFailure::success(); +} + +void CnmExpandDimOp::getEffects( + ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { + // Indicate that the `call` handle is only read by this operation because the + // associated operation is not erased but rather modified in-place, so the + // reference to it remains valid. + onlyReadsHandle(getTarget(), effects); + + // Indicate that the payload is modified by this operation. + modifiesPayload(effects); +} + +DiagnosedSilenceableFailure CnmPeelRightOp::applyToOne(TransformRewriter &, + cnm::ComputeOp compute, + ApplyToEachResultList &, + TransformState &) { + + auto res = cnm::peelRight(compute); + if (failed(res)) { + DiagnosedSilenceableFailure diag = emitDefaultSilenceableFailure(compute); + diag.attachNote() << "Transform failed"; + } + + return DiagnosedSilenceableFailure::success(); +} + +void CnmPeelRightOp::getEffects( + ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTarget(), effects); + modifiesPayload(effects); +} + +DiagnosedSilenceableFailure CnmSwapDimsOp::applyToOne(TransformRewriter &, + cnm::ComputeOp compute, + ApplyToEachResultList &, + TransformState &) { + + auto res = cnm::swapWorkgroupDims(compute, getDim0(), getDim1()); + if (failed(res)) { + DiagnosedSilenceableFailure diag = emitDefaultSilenceableFailure(compute); + diag.attachNote() << "Transform failed"; + } + return DiagnosedSilenceableFailure::success(); +} + +void CnmSwapDimsOp::getEffects( + ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTarget(), effects); + modifiesPayload(effects); +} diff --git a/cinnamon/lib/Dialect/Cnm/Transforms/Bufferize.cpp b/cinnamon/lib/Dialect/Cnm/Transforms/Bufferize.cpp index 7350692..18cd4db 100644 --- a/cinnamon/lib/Dialect/Cnm/Transforms/Bufferize.cpp +++ b/cinnamon/lib/Dialect/Cnm/Transforms/Bufferize.cpp @@ -23,6 +23,9 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Transforms/DialectConversion.h" +#include +#include +#include #include #include @@ -102,6 +105,67 @@ struct GatherOpInterface } }; +struct ComputeOpInterface + : public BufferizableOpInterface::ExternalModel { + // == is an input + bool bufferizesToMemoryRead(Operation *op0, OpOperand &operand, + const AnalysisState &) const { + auto op = cast(op0); + auto inStart = op.getNumSymbols(); + return 
operand.getOperandNumber() >= inStart && + operand.getOperandNumber() < inStart + op.getNumInputs(); + } + + // == is an output + bool bufferizesToMemoryWrite(Operation *op0, OpOperand &operand, + const AnalysisState &) const { + auto op = cast(op0); + auto outStart = op.getNumSymbols() + op.getNumInputs(); + return operand.getOperandNumber() >= outStart; + } + + AliasingValueList getAliasingValues(Operation *op0, OpOperand &operand, + const AnalysisState &state) const { + if (bufferizesToMemoryWrite(op0, operand, state)) { + auto op = cast(op0); + auto outBufs = op.getOutBuffers(); + auto i = 0; + for (auto buf : outBufs) { + if (buf == operand.get()) + return {{op->getOpResult(i), BufferRelation::Equivalent}}; + + if (buf.getType().isa()) { + i++; + } + } + } + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + auto compute = cast(op); + llvm::SmallVector newBuffers; + for (auto buf : compute.getBuffers()) { + if (buf.getType().isa()) { + FailureOr v = getBuffer(rewriter, buf, options); + if (failed(v)) { + newBuffers.push_back(buf); + } else { + newBuffers.push_back(*v); + } + } + } + + replaceOpWithNewBufferizedOp( + rewriter, op, compute.getWorkgroupShape(), + newBuffers, compute.getNumInputs(), + compute.getAffineMapsVec(), compute.getSymbolBindings()); + return success(); + } +}; + struct CnmBufferizePass : public cnm::impl::CnmBufferizePassBase { void runOnOperation() override { @@ -123,5 +187,6 @@ void cnm::registerCnmBufferizationExternalModels(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, cnm::CnmDialect *) { cnm::ScatterOp::attachInterface(*ctx); cnm::GatherOp::attachInterface(*ctx); + cnm::ComputeOp::attachInterface(*ctx); }); } \ No newline at end of file diff --git a/cinnamon/lib/Dialect/Cnm/Transforms/CMakeLists.txt b/cinnamon/lib/Dialect/Cnm/Transforms/CMakeLists.txt index 91785ab..2ea9c58 100644 --- a/cinnamon/lib/Dialect/Cnm/Transforms/CMakeLists.txt +++ b/cinnamon/lib/Dialect/Cnm/Transforms/CMakeLists.txt @@ -2,6 +2,8 @@ add_mlir_dialect_library(CnmTransforms SPIRVAttachAttributes.cpp Bufferize.cpp HoistWorkgroups.cpp + CnmComputeTransforms.cpp + LowerCnmCompute.cpp DEPENDS CnmIncGen diff --git a/cinnamon/lib/Dialect/Cnm/Transforms/CnmComputeTransforms.cpp b/cinnamon/lib/Dialect/Cnm/Transforms/CnmComputeTransforms.cpp new file mode 100644 index 0000000..af8dff2 --- /dev/null +++ b/cinnamon/lib/Dialect/Cnm/Transforms/CnmComputeTransforms.cpp @@ -0,0 +1,358 @@ + +#include "cinm-mlir/Utils/CinmUtils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; + +namespace mlir::cnm { + +LogicalResult expandWorkshoupDim(cnm::ComputeOp compute, uint64_t dim, + int64_t factor) { + auto wg = compute.getWorkgroupShape(); + if (dim >= wg.size()) { + mlir::emitWarning(compute->getLoc()) + << "Cannot expand dim #" << dim << " by factor " << factor + << " because workgroup only has " << wg.size() << " dimensions"; + return failure(); + } + if (wg[dim] % factor != 0) { + mlir::emitWarning(compute->getLoc()) + << "Cannot expand dim #" << dim << " by factor " << factor + << " because dimension (" << wg[dim] << ") is not divisible by factor"; + return failure(); + } + auto tile = wg[dim] / factor; + SmallVector newShape(wg); + 
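// split dimension `dim` of extent wg[dim] into (factor, tile),
+  // where tile = wg[dim] / factor
+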
newShape[dim] = factor; + newShape.insert(newShape.begin() + dim + 1, tile); + + auto ctx = compute.getContext(); + auto d0 = getAffineDimExpr(dim, ctx); + auto d1 = getAffineDimExpr(dim + 1, ctx); + auto newIx = d0 * tile + d1; + + // Build an identity map that looks like (a,b,d0,d1,c) -> (a,b, d0 * tile + + // d1, c) + SmallVector exprs; + auto offset = 0; + for (uint64_t i = 0; i < wg.size(); i++) { + if (i == dim) { + exprs.push_back(newIx); + offset = 1; + } else { + exprs.push_back(getAffineDimExpr(offset + i, ctx)); + } + } + AffineMap linearMap = AffineMap::get(newShape.size(), 0, exprs, ctx); + linearMap = simplifyAffineMapWithBounds(linearMap, newShape); + + // apply the linear map first, then the original map + auto affineMaps = compute.getAffineMapsVec(); + for (auto &map : affineMaps) { + map = map.compose(linearMap); + } + + OpBuilder b(ctx); + compute.setAffineMapsAttr(b.getAffineMapArrayAttr(affineMaps)); + compute.setWorkgroupShape(newShape); + + return success(); +} + +LogicalResult swapWorkgroupDims(cnm::ComputeOp compute, uint64_t dim0, + uint64_t dim1) { + auto wg = compute.getWorkgroupShape(); + if (dim0 >= wg.size() || dim1 >= wg.size()) { + mlir::emitWarning(compute->getLoc()) + << "Cannot swap dim #" << dim0 << " and #" << dim1 + << " because workgroup only has " << wg.size() << " dimensions"; + return failure(); + } + auto ctx = compute.getContext(); + // build a map where the two dims are swapped + SmallVector exprs; + for (uint64_t i = 0; i < wg.size(); i++) { + if (i == dim0) { + exprs.push_back(getAffineDimExpr(dim1, ctx)); + } else if (i == dim1) { + exprs.push_back(getAffineDimExpr(dim0, ctx)); + } else { + exprs.push_back(getAffineDimExpr(i, ctx)); + } + } + AffineMap swapMap = + AffineMap::get(compute.getWorkgroupShape().size(), 0, exprs, ctx); + + // apply the linear map first, then the original map + auto affineMaps = compute.getAffineMapsVec(); + for (auto &map : affineMaps) { + map = map.compose(swapMap); + } + + OpBuilder b(ctx); + compute.setAffineMapsAttr(b.getAffineMapArrayAttr(affineMaps)); + + SmallVector newShape(wg); + std::swap(newShape[dim0], newShape[dim1]); + + compute.setWorkgroupShape(newShape); + + return success(); +} + +/// Turn the rightmost dimension of the workgroup into a +/// parallel loop within the kernel. +/// This transformation might delete the op if the workgroup +/// has only a single dimension. +/// ``` +/// cnm.compute +/// ins(%a[(i,j)->(i*512+j)]: memref<1024xi32>) +/// outs(%o[(i,j)->(i*512+j)]: memref<1024xi32>) +/// on hierarchy<2x512> +/// do (%a1: memref, %o1: memref) { +/// %x = memref.load %a1[] +/// %t2 = arith.muli %x, 2 +/// memref.store %t2, %o1[] +/// } +/// ``` +/// into +/// ``` +/// %as = memref.reshape %a: (mr<1024xi32>) ... +/// %os = memref.reshape %o: (mr<1024xi32>) ... 
+/// cnm.compute +/// ins(%as[(i)->(i)]: memref<2x512xi32>) +/// outs(%os[(i)->(i)]: memref<2x512xi32>) +/// on hierarchy<2> +/// do (%a1: memref<512xi32>, +/// %o1: memref<512xi32>) { +/// affine.parallel (%i) = (0) to (512) { +/// %x = memref.load %a1[%i] +/// %t2 = arith.muli %x, 2 +/// memref.store %t2, %o1[%i] +/// } +/// } +/// ``` + +static void remapUse(Operation *usage, BlockArgument, Value indexVal, + OpBuilder &) { + if (auto load = dyn_cast_or_null(usage)) { + // logically this load does not verify + load.getIndicesMutable().append(indexVal); + } + if (auto store = dyn_cast_or_null(usage)) { + // logically this load does not verify + store.getIndicesMutable().append(indexVal); + } +} + +// This transfo does peelright on a compute op where the +// buffer dims are already correct (do not need reshape) +LogicalResult peelRight(cnm::ComputeOp compute) { + auto ctx = compute.getContext(); + OpBuilder builder(ctx); + auto wg = compute.getWorkgroupShape(); + assert(!wg.empty()); + if (wg.size() == 1) { + mlir::emitWarning(compute->getLoc()) + << "Cannot peel right because workgroup has only 1 dimension"; + return failure(); + } + + auto dimIx = wg.size() - 1; + auto peelDim = wg[dimIx]; + + SmallVector newAffineMaps(compute.getAffineMapsVec()); + SmallVector newArguments(compute.getBody().getArgumentTypes()); + llvm::SmallBitVector changedArgs(compute.getBody().getNumArguments()); + + for (auto [i, buf, arg, map] : + llvm::enumerate(compute.getBuffers(), compute.getKernelArgs(), + compute.getAffineMapsVec())) { + + auto bufTy = buf.getType().cast(); + auto bufShape = bufTy.getShape(); + auto argTy = arg.getType().cast(); + auto argShape = argTy.getShape(); + + bool isBroadCast; + if (map.getNumResults() > 0) { + // all the first N-1 results have to not use the result. 
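+      // (i.e. the peeled workgroup dimension may only appear in the last map
+      // result; otherwise peeling would change which elements each leaf sees)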
+ for (auto expr : map.getResults().slice(0, map.getNumResults() - 1)) { + if (expr.isFunctionOfDim(dimIx)) + return failure(); + } + auto lastRes = map.getResult(dimIx); + if (lastRes.isFunctionOfDim(dimIx)) { + if (lastRes != getAffineDimExpr(dimIx, ctx)) { + // todo we can support non-identity last dims later + return failure(); + } + if (bufShape[dimIx] != peelDim) { + mlir::emitWarning(compute->getLoc()) + << "Cannot peel right because the last dimension of buffer #" << i + << " needs to be " << peelDim << ", got " << bufShape[dimIx]; + return failure(); + } + // drop last result + map = map.dropResult(map.getNumResults() - 1); + } + isBroadCast = !lastRes.isFunctionOfDim(dimIx); + } else { + isBroadCast = true; + } + // drop last dim + MutableAffineMap mut(map); + mut.setNumDims(mut.getNumDims() - 1); + newAffineMaps[i] = mut.getAffineMap(); + + if (isBroadCast) { + // This is a broadcast, argument is untouched + continue; + } + + SmallVector newShape; + newShape.reserve(argShape.size() + 1); + newShape.push_back(peelDim); + newShape.append(argShape.begin(), argShape.end()); + newArguments[i] = MemRefType::get(newShape, argTy.getElementType()); + changedArgs.set(i); + } + + // clone the region into a swap space to be able to perform modifications in + // place + std::unique_ptr swapRegion = std::make_unique(); + swapRegion->takeBody(compute.getBody()); + + // update attributes + compute.setWorkgroupShape(wg.slice(0, wg.size() - 1)); + compute.setAffineMapsAttr(builder.getAffineMapArrayAttr(newAffineMaps)); + + // update types of kernel arguments + auto &kernelRegion = compute.getBody(); + auto &entry = kernelRegion.emplaceBlock(); + IRMapping mapping; + for (auto [newTy, arg] : + llvm::zip(newArguments, swapRegion->getArguments())) { + arg.setType(newTy); + auto newArg = entry.addArgument(newTy, arg.getLoc()); + mapping.map(arg, newArg); + } + + builder.setInsertionPointToStart(&kernelRegion.front()); + + auto loop = builder.create( + compute->getLoc(), TypeRange{}, ArrayRef{}, + ArrayRef{peelDim}); + + // clone the region, note that the output is invalid as the uses of the + // original arguments that changed type must be adapted + auto loopBuilder = loop.getBodyBuilder(); + loopBuilder.setListener(builder.getListener()); + + auto &loopRegion = loop.getBodyRegion(); + swapRegion->cloneInto(&loopRegion, mapping); + // now we need to cleanup after the clone + loopRegion.getBlocks().pop_front(); + auto &loopBody = loopRegion.getBlocks().front(); + loopBody.addArgument(builder.getIndexType(), loop->getLoc()); + loopBody.getTerminator()->erase(); // that's the cnm.terminator + builder.setInsertionPointToEnd(&loopBody); + builder.create(loop.getLoc()); + builder.setInsertionPointToEnd(loop->getBlock()); + builder.create(loop.getLoc()); + + for (auto changed : changedArgs.set_bits()) { + auto arg = kernelRegion.getArgument(changed); + for (auto user : arg.getUsers()) { + remapUse(user, arg, loopBody.getArgument(0), builder); + } + } + + return success(); +} + +LogicalResult normalizeInputs(cnm::ComputeOp op, + OpBuilder::Listener *listener) { + + OpBuilder builder0(op->getContext()); + builder0.setListener(listener); + ImplicitLocOpBuilder builder(op.getLoc(), builder0); + builder.setInsertionPoint(op); + + // ok we need to find all tensors who are not broadcast? 
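+  // TODO: transformation is not implemented yet; conservatively report
+  // failure so the function does not fall off the end without a return.
+  return failure();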
+} + +void lowerComputeToLaunch(cnm::ComputeOp op, OpBuilder::Listener *listener) { + OpBuilder builder0(op->getContext()); + builder0.setListener(listener); + ImplicitLocOpBuilder builder(op.getLoc(), builder0); + builder.setInsertionPoint(op); + + auto affineMaps = op.getAffineMapsVec(); + Value wg = builder.create(op.getWorkgroupShape()); + llvm::SmallVector cnmBuffers; + for (auto [buf, arg] : llvm::zip(op.getBuffers(), op.getKernelArgs())) { + auto argTy = arg.getType().cast(); + cnmBuffers.push_back(builder.create( + argTy.getShape(), argTy.getElementType(), wg, + 0 // level + )); + } + + for (auto [buf, map, cnmBuf] : + llvm::zip(op.getBuffers(), affineMaps, cnmBuffers)) { + builder.create(buf, cnmBuf, wg, map); + } + + const ArrayRef cnmBufferRef(cnmBuffers); + + auto launch = builder.create( + wg, ValueRange(cnmBufferRef.slice(0, op.getNumInputs())), + ValueRange(cnmBufferRef.slice(op.getNumInputs(), op.getNumOutputs()))); + + launch.getBody().takeBody(op.getBody()); + + SmallVector results; + for (auto [cnmBuf, map, init] : + llvm::drop_begin(llvm::zip(cnmBuffers, affineMaps, op.getBuffers()), + op.getNumInputs())) { + auto gather = builder.create(cnmBuf, wg, map, init); + if (gather->getNumResults() > 0) { + results.push_back(gather.getOutput()); + } + } + + builder.create(wg); + + op->replaceAllUsesWith(results); + op.erase(); +} +} // namespace mlir::cnm \ No newline at end of file diff --git a/cinnamon/lib/Dialect/Cnm/Transforms/HoistWorkgroups.cpp b/cinnamon/lib/Dialect/Cnm/Transforms/HoistWorkgroups.cpp index 955c7a4..c2985ba 100644 --- a/cinnamon/lib/Dialect/Cnm/Transforms/HoistWorkgroups.cpp +++ b/cinnamon/lib/Dialect/Cnm/Transforms/HoistWorkgroups.cpp @@ -16,7 +16,6 @@ namespace mlir::cnm { } // namespace mlir::cnm using namespace mlir; -namespace {} struct CnmHoistWorkgroupsPass : public cnm::impl::CnmHoistWorkgroupsPassBase { diff --git a/cinnamon/lib/Dialect/Cnm/Transforms/LowerCnmCompute.cpp b/cinnamon/lib/Dialect/Cnm/Transforms/LowerCnmCompute.cpp new file mode 100644 index 0000000..51a946f --- /dev/null +++ b/cinnamon/lib/Dialect/Cnm/Transforms/LowerCnmCompute.cpp @@ -0,0 +1,29 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mlir::cnm { + +#define GEN_PASS_DEF_CNMLOWERCOMPUTEPASS +#include + +} // namespace mlir::cnm + +using namespace mlir; + +struct CnmLowerComputePass + : public cnm::impl::CnmLowerComputePassBase { + void runOnOperation() override { + auto fun = getOperation(); + + fun.walk([&](cnm::ComputeOp op) { + cnm::lowerComputeToLaunch(op); + }); + } +}; \ No newline at end of file diff --git a/cinnamon/lib/Utils/CinmUtils.cpp b/cinnamon/lib/Utils/CinmUtils.cpp index 2451cde..0dfade2 100644 --- a/cinnamon/lib/Utils/CinmUtils.cpp +++ b/cinnamon/lib/Utils/CinmUtils.cpp @@ -131,7 +131,7 @@ AffineMap simplifyAffineMapWithBounds(AffineMap map, if (dim == ShapedType::kDynamic) upperBounds.push_back(std::nullopt); else - upperBounds.push_back(std::make_optional(dim - 1)); + upperBounds.push_back(std::make_optional(dim)); } llvm::SmallVector> lowerBounds; diff --git a/cinnamon/samples/cnm_conv.mlir b/cinnamon/samples/cnm_conv.mlir index 04ad1c1..ba7883c 100644 --- a/cinnamon/samples/cnm_conv.mlir +++ b/cinnamon/samples/cnm_conv.mlir @@ -56,7 +56,7 @@ func.func @conv(%img : tensor<1x128x128x3xi16>, %flt : tensor<3x3x3x8xi16>, %bia : tensor<128x32xi16> into !cnm.buffer<16x16xi16 on 8x2, level 0> %sc_b_token = cnm.scatter %B_pad into %B_buf[#scatter_map] of %wg : tensor<32x16xi16> into !cnm.buffer<16x16xi16 
on 8x2, level 0> - %e_token = cnm.launch %wg in(%A_buf, %B_buf : !cnm.buffer<16x16xi16 on 8x2, level 0>, !cnm.buffer<16x16xi16 on 8x2, level 0>) out(%C_buf : !cnm.buffer<16x16xi16 on 8x2, level 0>) on !cnm.workgroup<8x2> { + %e_token = cnm.launch %wg ins(%A_buf, %B_buf : !cnm.buffer<16x16xi16 on 8x2, level 0>, !cnm.buffer<16x16xi16 on 8x2, level 0>) outs(%C_buf : !cnm.buffer<16x16xi16 on 8x2, level 0>) on !cnm.workgroup<8x2> { ^bb0(%A_space: memref<16x16xi16>, %B_space: memref<16x16xi16>, %C_space: memref<16x16xi16>): affine.for %arg3 = 0 to 16 { affine.for %arg4 = 0 to 16 { diff --git a/cinnamon/samples/gemm_cnm.mlir b/cinnamon/samples/gemm_cnm.mlir index 0c48894..918e53d 100644 --- a/cinnamon/samples/gemm_cnm.mlir +++ b/cinnamon/samples/gemm_cnm.mlir @@ -24,7 +24,7 @@ module { %7 = cnm.scatter %reshape_5 into %6[#map] of %3 : tensor<64x16xi32> into !cnm.buffer<16xi32 on 4x8x2, level 0> %8 = cnm.alloc() for %3 : !cnm.buffer %9 = cnm.scatter %cst_3 into %8[#map1] of %3 : tensor<64xi32> into !cnm.buffer - %10 = cnm.launch %3 in(%4, %6 : !cnm.buffer<16xi32 on 4x8x2, level 0>, !cnm.buffer<16xi32 on 4x8x2, level 0>) out(%8 : !cnm.buffer) on !cnm.workgroup<4x8x2> { + %10 = cnm.launch %3 ins(%4, %6 : !cnm.buffer<16xi32 on 4x8x2, level 0>, !cnm.buffer<16xi32 on 4x8x2, level 0>) outs(%8 : !cnm.buffer) on !cnm.workgroup<4x8x2> { ^bb0(%arg6: memref<16xi32>, %arg7: memref<16xi32>, %arg8: memref): linalg.reduce ins(%arg6, %arg7 : memref<16xi32>, memref<16xi32>) outs(%arg8 : memref) dimensions = [0] (%in: i32, %in_6: i32, %init: i32) { diff --git a/cinnamon/test/Dialect/Cinm/cinm-tiling.mlir b/cinnamon/test/Dialect/Cinm/cinm-tiling.mlir index a1f73df..1b9f46c 100644 --- a/cinnamon/test/Dialect/Cinm/cinm-tiling.mlir +++ b/cinnamon/test/Dialect/Cinm/cinm-tiling.mlir @@ -1,29 +1,21 @@ -// RUN: cinm-opt %s --cinm-tiling=reduction-tile-size=16 -split-input-file | FileCheck %s +// RUN: cinm-opt %s --cinm-tiling -split-input-file | FileCheck %s // CHECK-LABEL: @gemmSquare // CHECK-SAME: (%[[A:.*]]: tensor<1024x1024xi32>, %[[B:.*]]: -// CHECK: affine.for %[[i:.*]] = 0 to 1024 step 64 -// CHECK-NEXT: affine.for %[[j:.*]] = 0 to 1024 step 64 -// CHECK-NEXT: %[[blockA:.*]] = tensor.extract_slice %[[A]][%[[i]], 0] [64, 1024] [1, 1] -// CHECK-NEXT: %[[blockB:.*]] = tensor.extract_slice %[[B]][0, %[[j]]] [1024, 64] [1, 1] -// CHECK-NEXT: tensor.generate -// CHECK-NEXT: ^{{.*}}(%[[ti:.*]]: index, %[[tj:.*]]: index): -// CHECK-NEXT: %[[row:.*]] = tensor.extract_slice %[[blockA]][%[[ti]], 0] [1, 1024] [1, 1] -// CHECK-NEXT: %[[col:.*]] = tensor.extract_slice %[[blockB]][0, %[[tj]]] [1024, 1] [1, 1] - -// CHECK: %[[batchedRow:.*]] = tensor.reshape %[[row]](%{{.*}}) : (tensor<1024xi32>, tensor<2xi64>) -> tensor<64x16xi32> -// CHECK-NEXT: %[[batchedCol:.*]] = tensor.reshape %[[col]](%{{.*}}) : (tensor<1024xi32>, tensor<2xi64>) -> tensor<64x16xi32> -// CHECK-NEXT: %[[stage1:.*]] = linalg.reduce ins(%[[batchedRow]], %[[batchedCol]] : {{.*}}) outs(%{{.*}} : tensor<64xi32>) dimensions = [1] -// CHECK-NEXT: (%[[ei:.*]]: i32, %[[ej:.*]]: i32, %[[init:.*]]: i32) -// CHECK-NEXT: %[[mul:.*]] = arith.muli %[[ei]], %[[ej]] -// CHECK-NEXT: %[[add:.*]] = arith.addi %[[mul]], %[[init]] -// CHECK-NEXT: linalg.yield %[[add]] - -// CHECK: linalg.reduce ins(%[[stage1]] : tensor<64xi32>) +// CHECK: affine.for %[[i:.*]] = 0 to 1024 step 2 +// CHECK-NEXT: affine.for %[[j:.*]] = 0 to 1024 step 64 iter_args(%[[arg0:.*]] = +// CHECK-NEXT: %[[cst0:.*]] = arith.constant dense<0> +// CHECK-NEXT: %[[res:.*]] = affine.for %[[k:.*]] = 0 to 1024 step 
@@ -36,7 +28,7 @@
 
 // CHECK-LABEL: @gemv
 func.func @gemv(%a: tensor<1024x1024xi32>, %b: tensor<1024xi32>) -> tensor<1024xi32>{
-  %res = cinm.compute -> tensor<1024xi32> {
+  %res = cinm.compute attributes { workgroupShape=array, bufferSizesInBytes=array} -> tensor<1024xi32> {
     %d = cinm.op.gemv %a, %b : (tensor<1024x1024xi32>, tensor<1024xi32>) -> tensor<1024xi32>
     cinm.yield %d: tensor<1024xi32>
   }
@@ -48,7 +40,7 @@
 
 // CHECK-LABEL: @max
 func.func @max(%a: tensor<1024xi32>) -> i32 {
-  %res = cinm.compute -> i32 {
+  %res = cinm.compute attributes { workgroupShape=array, bufferSizesInBytes=array} -> i32 {
     %d = cinm.op.reduce max (%a): tensor<1024xi32>
     cinm.yield %d : i32
   }
diff --git a/cinnamon/test/Dialect/Cnm/cnm-ops.mlir b/cinnamon/test/Dialect/Cnm/cnm-ops.mlir
index 463dd74..89a5543 100644
--- a/cinnamon/test/Dialect/Cnm/cnm-ops.mlir
+++ b/cinnamon/test/Dialect/Cnm/cnm-ops.mlir
@@ -2,61 +2,115 @@
 // RUN: cinm-opt %s --mlir-print-op-generic | cinm-opt | FileCheck %s
 
-// CHECK-LABEL: matmul
-
-#scatter_map = affine_map<(i) -> (i floordiv 64 mod 4, i floordiv 64, i mod 64)>
-#gather_map = affine_map<(d0, d1) -> (d0, d1)>
-
-func.func @matmul(%A: tensor<1024x1024xi32>, %B: tensor<1024x1024xi32>) -> tensor<1024x1024xi32> {
-
-  %c0_i32 = arith.constant 0 : i32
-
-  %generated = tensor.generate {
-  ^bb0(%i: index, %j: index):
-    %row = tensor.extract_slice %A[%i, 0] [1, 1024] [1, 1] : tensor<1024x1024xi32> to tensor<1024xi32>
-    %col = tensor.extract_slice %B[0, %j] [1024, 1] [1, 1] : tensor<1024x1024xi32> to tensor<1024xi32>
-    %3 = arith.muli %row, %col : tensor<1024xi32>
-
-    // === Lower reduction loops ===
-    // Reduction has already been split into two stages: reduce 1024 elements into 64 sums of batch=16 elements
-    // We pick a workgroup size that adds up to 64: 4x16
-    %wg = cnm.workgroup { cnm.physical_dims = ["dpu", "tasklet"] } : !cnm.workgroup<4x16>
-
-    // We alloc the buffer for the batch (the 16 here is batch size)
-    %A_buf = cnm.alloc() for %wg { cnm.physical_space = "global" } : !cnm.buffer<16xi32 on 4x16, level 0>
-    cnm.scatter %3 into %A_buf[#scatter_map] of %wg : tensor<1024xi32> into !cnm.buffer<16xi32 on 4x16, level 0>
-
-    // We alloc a buffer for the reduction result (scalar)
-    %outbuf = cnm.alloc() for %wg { cnm.physical_space = "global" } : !cnm.buffer<i32 on 4x16, level 0>
-    // Then we launch the group
-    %token2 = cnm.launch %wg in(%A_buf: !cnm.buffer<16xi32 on 4x16, level 0>) out(%outbuf : !cnm.buffer<i32 on 4x16, level 0>) on !cnm.workgroup<4x16> {
-    ^bb0(%arg0: memref<16xi32>, %arg1: memref<i32>):
-      // Here we have an affine reduction loop
-      %total = affine.for %x = 0 to 16 iter_args(%sum = %c0_i32) -> i32 {
-        %elt = affine.load %arg0[%x]: memref<16xi32>
-        %tmp = arith.addi %sum, %elt: i32
-        affine.yield %tmp: i32
-      }
-      // finally store result
-      memref.store %total, %arg1[] : memref<i32>
+
+#map1 = affine_map<(d0, d1, d2) -> (d0 * 1024 + d1 * 16 + d2)>
+#map = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: va_8
+ func.func @va_8(%arg0: tensor<8x2097152xi32>, %arg1: tensor<8x2097152xi32>) {
+    %cst = arith.constant dense<0> : tensor<16384x1024xi32>
+    %cst_0 = arith.constant dense<[16384, 1024]> : tensor<2xi64>
+    %cst_1 = arith.constant dense<16777216> : tensor<1xi64>
+    %reshape = tensor.reshape %arg0(%cst_1) : (tensor<8x2097152xi32>, tensor<1xi64>) -> tensor<16777216xi32>
+    %reshape_2 = tensor.reshape %arg1(%cst_1) : (tensor<8x2097152xi32>, tensor<1xi64>) -> tensor<16777216xi32>
+    %0 = cnm.workgroup : !cnm.workgroup<16x64x16>
+    %reshape_3 = tensor.reshape %reshape(%cst_0) : (tensor<16777216xi32>, tensor<2xi64>) -> tensor<16384x1024xi32>
+    %1 = cnm.alloc() for %0 : !cnm.buffer<1024xi32 on 16x64x16, level 0>
+    cnm.scatter %reshape_3 into %1[#map1] of %0 : tensor<16384x1024xi32> into !cnm.buffer<1024xi32 on 16x64x16, level 0>
+    %reshape_4 = tensor.reshape %reshape_2(%cst_0) : (tensor<16777216xi32>, tensor<2xi64>) -> tensor<16384x1024xi32>
+    %2 = cnm.alloc() for %0 : !cnm.buffer<1024xi32 on 16x64x16, level 0>
+    cnm.scatter %reshape_4 into %2[#map1] of %0 : tensor<16384x1024xi32> into !cnm.buffer<1024xi32 on 16x64x16, level 0>
+    %3 = cnm.alloc() for %0 : !cnm.buffer<1024xi32 on 16x64x16, level 0>
+    cnm.scatter %cst into %3[#map1] of %0 : tensor<16384x1024xi32> into !cnm.buffer<1024xi32 on 16x64x16, level 0>
+    cnm.launch %0 ins(%1, %2 : !cnm.buffer<1024xi32 on 16x64x16, level 0>, !cnm.buffer<1024xi32 on 16x64x16, level 0>) outs(%3 : !cnm.buffer<1024xi32 on 16x64x16, level 0>) on !cnm.workgroup<16x64x16> {
+    ^bb0(%arg2: memref<1024xi32>, %arg3: memref<1024xi32>, %arg4: memref<1024xi32>):
+      linalg.add ins(%arg2, %arg3 : memref<1024xi32>, memref<1024xi32>) outs(%arg4 : memref<1024xi32>)
+    }
+    %4 = tensor.empty() : tensor<16384x1024xi32>
+    %5 = cnm.gather %3[#map1] of %0 into %4 : !cnm.buffer<1024xi32 on 16x64x16, level 0> into tensor<16384x1024xi32>
+    cnm.free_workgroup %0 : !cnm.workgroup<16x64x16>
+    return
+}
+
+
+ func.func @transform(%arg0: memref<1024xi32>, %arg1: memref<1024xi32>) {
+    cnm.compute
+       ins(%arg0[#map] : memref<1024xi32>)
+       outs(%arg1[#map] : memref<1024xi32>)
+       on hierarchy<1024>
+       do (%arg2: memref<i32>, %arg3: memref<i32>) {
+        %0 = affine.load %arg2[] : memref<i32>
+        %c2_i32 = arith.constant 2 : i32
+        %1 = arith.muli %0, %c2_i32 : i32
+        affine.store %1, %arg3[] : memref<i32>
+        cnm.terminator
+    }
+
+    cnm.compute
+       ins(%arg0[(i, j) -> ()]: memref<1024xi32>)
+       outs(%arg1[(i, j) -> (i * 512 + j)]: memref<1024xi32>)
+       on hierarchy<2x512>
+       do (%a1: memref<1024xi32>, %o1: memref<i32>) {
+        affine.for %i = 0 to 1024 {
+          %0 = affine.load %a1[%i] : memref<1024xi32>
+          %1 = affine.load %o1[] : memref<i32>
+          %2 = arith.addi %0, %1 : i32
+          affine.store %2, %o1[] : memref<i32>
+        }
+        cnm.terminator
+    }
+    %r = memref.expand_shape %arg1[[0,1]] : memref<1024xi32> into memref<2x512xi32>
-    // Finally gather results into a buffer with same shape as the workgroup
-    %ReductionStage1 = cnm.gather %outbuf[#gather_map] of %wg : !cnm.buffer<i32 on 4x16, level 0> into tensor<4x16xi32>
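+    // Same kernel as above, but writing through the expanded memref %r:
+    // the outs map becomes the identity (i, j) -> (i, j) on the 2x512
+    // hierarchy instead of the flattening map (i * 512 + j).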
+    cnm.compute
+       ins(%arg0[(i, j) -> ()]: memref<1024xi32>)
+       outs(%r[(i, j) -> (i, j)]: memref<2x512xi32>)
+       on hierarchy<2x512>
+       do (%a1: memref<1024xi32>, %o1: memref<i32>) {
+        affine.for %i = 0 to 1024 {
+          %0 = affine.load %a1[%i] : memref<1024xi32>
+          %1 = affine.load %o1[] : memref<i32>
+          %2 = arith.addi %0, %1 : i32
+          affine.store %2, %o1[] : memref<i32>
+        }
+        cnm.terminator
+    }
-
-    // === Second reduction loop ===
-    // At this point there is a second linalg.reduce
-    // I think we can always assume we do this reduction on the host.
-    // Lower it to affine with --linalg-bufferize --convert-linalg-to-affine-loops
-    %from_elements = tensor.from_elements %c0_i32 : tensor<i32>
-    %reduced = linalg.reduce ins(%ReductionStage1 : tensor<4x16xi32>) outs(%from_elements : tensor<i32>) dimensions = [0, 1]
-      (%in: i32, %init: i32) {
-        %4 = arith.addi %in, %init : i32
-        linalg.yield %4 : i32
+    cnm.compute
+       ins(%arg0[(i) -> ()]: memref<1024xi32>)
+       outs(%r[(i) -> (i)]: memref<2x512xi32>)
+       on hierarchy<2>
+       do (%a1: memref<1024xi32>, %o1: memref<512xi32>) {
+        affine.for %j = 0 to 512 {
+          affine.for %i = 0 to 1024 {
+            %0 = affine.load %a1[%i] : memref<1024xi32>
+            %1 = affine.load %o1[%j] : memref<512xi32>
+            %2 = arith.addi %0, %1 : i32
+            affine.store %2, %o1[%j] : memref<512xi32>
+          }
+        }
-      }
-    %extracted = tensor.extract %reduced[] : tensor<i32>
-    tensor.yield %extracted : i32
-  } : tensor<1024x1024xi32>
-  return %generated : tensor<1024x1024xi32>
+        cnm.terminator
+    }
+    return
+
+}
+
+ func.func @compute_tensor(%arg0: tensor<1024xi32>,
+        %arg1: tensor<1024xi32>,
+        %arg0m: memref<20xi32>, %arg1m: memref<1024xi32>) {
+
+    %out = cnm.compute
+       ins(%arg0[#map] : tensor<1024xi32>, %arg0m[(i) -> ()]: memref<20xi32>)
+       outs(%arg1[#map] : tensor<1024xi32>)
+       on hierarchy<1024>
+       do (%arg2: memref<i32>, %argx: memref<20xi32>, %arg3: memref<i32>) {
+        %0 = affine.load %arg2[] : memref<i32>
+        %c2_i32 = arith.constant 2 : i32
+        %1 = arith.muli %0, %c2_i32 : i32
+        affine.store %1, %arg3[] : memref<i32>
+        cnm.terminator
+    }
+
+    return
+ }
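One thing @compute_tensor above demonstrates: cnm.compute also accepts tensor operands, and then appears to follow value semantics, returning a fresh tensor per tensor output (the lowering code earlier only keeps gather results when they exist) while memref operands are updated in place. A minimal sketch under that assumption, with invented names:

%out = cnm.compute
   ins(%t[#map] : tensor<1024xi32>)
   outs(%init[#map] : tensor<1024xi32>)
   on hierarchy<1024>
   do (%x: memref<i32>, %y: memref<i32>) {
    // ... kernel ...
}
// %out : tensor<1024xi32> carries the gathered result; %init is not mutated.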
diff --git a/cinnamon/test/Dialect/Cnm/cnm-transform.mlir b/cinnamon/test/Dialect/Cnm/cnm-transform.mlir
new file mode 100644
index 0000000..17519da
--- /dev/null
+++ b/cinnamon/test/Dialect/Cnm/cnm-transform.mlir
@@ -0,0 +1,33 @@
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @transform(%arg0: memref<1024xi32>, %arg1: memref<1024xi32>) {
+    cnm.compute
+       ins(%arg0[#map] : memref<1024xi32>)
+       outs(%arg1[#map] : memref<1024xi32>)
+       on hierarchy<1024>
+       do (%arg2: memref<i32>, %arg3: memref<i32>) {
+        %0 = affine.load %arg2[] : memref<i32>
+        %c2_i32 = arith.constant 2 : i32
+        %1 = arith.muli %0, %c2_i32 : i32
+        affine.store %1, %arg3[] : memref<i32>
+    }
+
+    cnm.compute
+       ins(%arg0[(i, j) -> ()]: memref<1024xi32>)
+       outs(%arg1[(i, j) -> (i * 512 + j)]: memref<1024xi32>)
+       on hierarchy<2x512>
+       do (%a1: memref<1024xi32>, %o1: memref<i32>) {
+        affine.for %i = 0 to 1024 {
+          %0 = affine.load %a1[%i] : memref<1024xi32>
+          %1 = affine.load %o1[] : memref<i32>
+          %2 = arith.addi %0, %1 : i32
+          affine.store %2, %o1[] : memref<i32>
+        }
+    }
+    return
+  }
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.op<"cnm.compute">):
+    transform.cnm.expand_dim %arg0 dim 0 by factor 2 : (!transform.op<"cnm.compute">) -> ()
+  }
+}
\ No newline at end of file
diff --git a/cinnamon/test/Dialect/Cnm/linalg-tile.mlir b/cinnamon/test/Dialect/Cnm/linalg-tile.mlir
new file mode 100644
index 0000000..2b0c2f8
--- /dev/null
+++ b/cinnamon/test/Dialect/Cnm/linalg-tile.mlir
@@ -0,0 +1,98 @@
+// RUN: cinm-opt %s --cnm-apply-transform | FileCheck %s
+#map = affine_map<(d0) -> (d0)>
+#matmul_trait = {
+  doc = "C() += A(k) * B(k)",
+  indexing_maps = [
+    affine_map<(k) -> (k)>,
+    affine_map<(k) -> (k)>,
+    affine_map<(k) -> ()>
+  ],
+  iterator_types = ["reduction"]
+}
+
+#linalg_trait = {
+
+}
+module {
+
+
+  func.func @matmul(%arg0: memref<1024x64xi32>, %arg1: memref<64x1024xi32>) {
+    %arg1t = memref.alloc() : memref<1024x64xi32>
+    linalg.transpose ins(%arg1: memref<64x1024xi32>) outs(%arg1t: memref<1024x64xi32>) permutation = [1,0]
+    %res = memref.alloc() : memref<1024x1024xi32>
+
+    // peel left 2
+    // This computes 16x16 tiles of the output on some DIMMs.
+    // - Can we lift the reduction loop outside of the nest? That's hard
+    // - Can we coarsen the output size (not use memref<i32> but memref<16xi32>)
+    //   - That's fine.
+    //
+    affine.parallel (%D0, %D1) = (0, 0) to (64, 64) {
+      cnm.compute
+         symbols [%D0, %D1]
+         ins(%arg0[(d2, d3)[D0, D1] -> (D0 * 16 + d2)] : memref<1024x64xi32>,
+             %arg1t[(d2, d3)[D0, D1] -> (D1 * 16 + d3)] : memref<1024x64xi32>)
+         outs(%res[(d2, d3)[D0, D1] -> (D0 * 16 + d2, D1 * 16 + d3)] : memref<1024x1024xi32>)
+         on hierarchy<16x16>
+         do (%arg2: memref<64xi32>, %arg3: memref<64xi32>, %arg4: memref<i32>) {
+          %t0 = bufferization.to_tensor %arg2: memref<64xi32>
+          %t1 = bufferization.to_tensor %arg3: memref<64xi32>
+          %t3 = bufferization.to_tensor %arg4: memref<i32>
+          %o = linalg.generic #matmul_trait
+            ins(%t0, %t1: tensor<64xi32>, tensor<64xi32>)
+            outs(%t3: tensor<i32>) {
+          ^bb0(%a: i32, %b: i32, %c: i32):
+            %0 = arith.muli %a, %b: i32
+            %1 = arith.addi %0, %c: i32
+            linalg.yield %1: i32
+          } -> tensor<i32>
+      }
+    }
+
+
+    // // Ok stepwise:
+    // affine.parallel (%D0, %D1) = (0, 0) to (64, 64) {
+    //   // first tile the inner kernel
+    //   // (todo: do that automatically with transform. Is that possible?)
+    //   cnm.compute
+    //      symbols [%D0, %D1]
+    //      ins(%arg0[(d2, d3)[D0, D1] -> (D0 * 16 + d2)] : memref<1024x64xi32>,
+    //          %arg1t[(d2, d3)[D0, D1] -> (D1 * 16 + d3)] : memref<1024x64xi32>)
+    //      outs(%res[(d2, d3)[D0, D1] -> (D0 * 16 + d2, D1 * 16 + d3)] : memref<1024x1024xi32>)
+    //      on hierarchy<16x16>
+    //      do (%arg2: memref<64xi32>, %arg3: memref<64xi32>, %arg4: memref<i32>) {
+    //       %arg20 = memref.expand_shape %arg2 [[0, 1]]: memref<64xi32> into memref<2x32xi32>
+    //       %arg30 = memref.expand_shape %arg3 [[0, 1]]: memref<64xi32> into memref<2x32xi32>
+    //       linalg.generic {
+    //         indexing_maps = [
+    //           affine_map<(ko, kt) -> (ko, kt)>,
+    //           affine_map<(ko, kt) -> (ko, kt)>,
+    //           affine_map<(ko, kt) -> ()>
+    //         ],
+    //         iterator_types = ["reduction", "reduction"]
+    //       }
+    //       ins(%arg20, %arg30: memref<2x32xi32>, memref<2x32xi32>)
+    //       outs(%arg4: memref<i32>) {
+    //       ^bb0(%a: i32, %b: i32, %c: i32):
+    //         %0 = arith.muli %a, %b: i32
+    //         %1 = arith.addi %0, %c: i32
+    //         linalg.yield %1: i32
+    //       }
+    //   }
+    // }
+
+    return
+  }
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.op<"linalg.generic">):
+    %r:4 = transform.structured.tile_reduction_using_for %arg0
+      by tile_sizes = [32]: (!transform.op<"linalg.generic">) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+  }
+
+  // transform.sequence failures(propagate) {
+  // ^bb0(%arg0: !transform.op<"cnm.compute">):
+  //   transform.cnm.expand_dim %arg0 dim 1 by factor 64: (!transform.op<"cnm.compute">) -> ()
+  //   transform.cnm.expand_dim %arg0 dim 0 by factor 64: (!transform.op<"cnm.compute">) -> ()
+  //   transform.cnm.swap_dims %arg0, 1, 2: (!transform.op<"cnm.compute">) -> ()
+  // }
+}
\ No newline at end of file
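transform.structured.tile_reduction_using_for is the upstream MLIR transform used here: it rewrites the 64-long reduction of #matmul_trait into a partial reduction carried through a loop plus a final combining reduction, with the four results bound to %r:4 being handles to the fill, the partial-reduction op, the combining op, and the loop. A conceptual, hand-written equivalent for a dot product %a, %b : tensor<64xi32> with accumulator %init : tensor<i32> (an illustration of the technique, not the literal pass output):

%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%zero = arith.constant dense<0> : tensor<32xi32>
// Partial sums: combine the two 32-element halves elementwise.
%partial = scf.for %k = %c0 to %c64 step %c32 iter_args(%acc = %zero) -> (tensor<32xi32>) {
  %sa = tensor.extract_slice %a[%k] [32] [1] : tensor<64xi32> to tensor<32xi32>
  %sb = tensor.extract_slice %b[%k] [32] [1] : tensor<64xi32> to tensor<32xi32>
  %prod = arith.muli %sa, %sb : tensor<32xi32>
  %sum = arith.addi %acc, %prod : tensor<32xi32>
  scf.yield %sum : tensor<32xi32>
}
// Final combine: fold the 32 partial sums into the scalar accumulator.
%reduced = linalg.reduce ins(%partial : tensor<32xi32>) outs(%init : tensor<i32>) dimensions = [0]
  (%in: i32, %acc0: i32) {
    %0 = arith.addi %in, %acc0 : i32
    linalg.yield %0 : i32
  }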
diff --git a/cinnamon/test/Dialect/Cnm/transform-expand-dim.mlir b/cinnamon/test/Dialect/Cnm/transform-expand-dim.mlir
new file mode 100644
index 0000000..b9c6aac
--- /dev/null
+++ b/cinnamon/test/Dialect/Cnm/transform-expand-dim.mlir
@@ -0,0 +1,65 @@
+// RUN: cinm-opt %s --cnm-apply-transform | FileCheck %s
+#map = affine_map<(d0) -> (d0)>
+module {
+
+  // CHECK-LABEL: @simple
+  // CHECK: cnm.compute
+  // CHECK-NEXT: ins(%arg0[(d0, d1) -> (d0 * 512 + d1)] : memref<1024xi32>)
+  // CHECK-NEXT: outs(%arg1[(d0, d1) -> (d0 * 512 + d1)] : memref<1024xi32>)
+  // CHECK-NEXT: on hierarchy<2x512>
+  func.func @simple(%arg0: memref<1024xi32>, %arg1: memref<1024xi32>) {
+    cnm.compute
+       ins(%arg0[(i) -> (i)] : memref<1024xi32>)
+       outs(%arg1[(i) -> (i)] : memref<1024xi32>)
+       on hierarchy<1024>
+       do (%arg2: memref<i32>, %arg3: memref<i32>) {
+        %0 = affine.load %arg2[] : memref<i32>
+        %c2_i32 = arith.constant 2 : i32
+        %1 = arith.muli %0, %c2_i32 : i32
+        affine.store %1, %arg3[] : memref<i32>
+    }
+
+    // cnm.compute
+    //    ins(%arg0[(i, j) -> ()]: memref<1024xi32>)
+    //    outs(%arg1[(i, j) -> (i * 512 + j)]: memref<1024xi32>)
+    //    on hierarchy<2x512>
+    //    do (%a1: memref<1024xi32>, %o1: memref<i32>) {
+    //     affine.for %i = 0 to 1024 {
+    //       %0 = affine.load %a1[%i] : memref<1024xi32>
+    //       %1 = affine.load %o1[] : memref<i32>
+    //       %2 = arith.addi %0, %1 : i32
+    //       affine.store %2, %o1[] : memref<i32>
+    //     }
+    //     cnm.terminator
+    // }
+
+    return
+  }
+
+  // CHECK-LABEL: @broadcast
+  // CHECK: cnm.compute
+  // CHECK-NEXT: ins(%arg0[(d0, d1, d2) -> ()] : memref<1024xi32>)
+  // CHECK-NEXT: outs(%arg1[(d0, d1, d2) -> (d0 * 512 + d2)] : memref<1024xi32>)
+  // CHECK-NEXT: on hierarchy<2x1x512>
+  func.func @broadcast(%arg0: memref<1024xi32>, %arg1: memref<1024xi32>) {
+    cnm.compute
+       ins(%arg0[(i, j) -> ()]: memref<1024xi32>)
+       outs(%arg1[(i, j) -> (i * 512 + j)]: memref<1024xi32>)
+       on hierarchy<2x512>
+       do (%a1: memref<1024xi32>, %o1: memref<i32>) {
+        affine.for %i = 0 to 1024 {
+          %0 = affine.load %a1[%i] : memref<1024xi32>
+          %1 = affine.load %o1[] : memref<i32>
+          %2 = arith.addi %0, %1 : i32
+          affine.store %2, %o1[] : memref<i32>
+        }
+    }
+
+    return
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.op<"cnm.compute">):
+    transform.cnm.expand_dim %arg0 dim 0 by factor 2: (!transform.op<"cnm.compute">) -> ()
+  }
+}
\ No newline at end of file
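The CHECK lines above pin down the rewrite rule of transform.cnm.expand_dim: splitting hierarchy dimension d of extent N by factor F yields extents F x N/F with F leading, and every affine map is rewritten so the two new indices recombine as d0 * (N/F) + d1. Side by side for @simple, taken directly from the test (bodies elided):

// Before:
cnm.compute
   ins(%arg0[(i) -> (i)] : memref<1024xi32>)
   outs(%arg1[(i) -> (i)] : memref<1024xi32>)
   on hierarchy<1024>
   // ...

// After transform.cnm.expand_dim ... dim 0 by factor 2:
cnm.compute
   ins(%arg0[(d0, d1) -> (d0 * 512 + d1)] : memref<1024xi32>)
   outs(%arg1[(d0, d1) -> (d0 * 512 + d1)] : memref<1024xi32>)
   on hierarchy<2x512>
   // ...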
diff --git a/cinnamon/test/Dialect/Cnm/transform-matmul-schedule-linalg.mlir b/cinnamon/test/Dialect/Cnm/transform-matmul-schedule-linalg.mlir
new file mode 100644
index 0000000..cae1659
--- /dev/null
+++ b/cinnamon/test/Dialect/Cnm/transform-matmul-schedule-linalg.mlir
@@ -0,0 +1,235 @@
+// RUN: cinm-opt %s --cnm-apply-transform | FileCheck %s
+#map = affine_map<(d0) -> (d0)>
+#matmul_trait = {
+  doc = "C() += A(k) * B(k)",
+  indexing_maps = [
+    affine_map<(k) -> (k)>,
+    affine_map<(k) -> (k)>,
+    affine_map<(k) -> ()>
+  ],
+  iterator_types = ["reduction"]
+}
+
+#linalg_trait = {
+
+}
+module {
+
+
+  func.func @matmul(%arg0: memref<1024x64xi32>, %arg1: memref<64x1024xi32>) {
+    %arg1t = memref.alloc() : memref<1024x64xi32>
+    linalg.transpose ins(%arg1: memref<64x1024xi32>) outs(%arg1t: memref<1024x64xi32>) permutation = [1,0]
+    %res = memref.alloc() : memref<1024x1024xi32>
+
+    // this is naive matmul
+    cnm.compute
+       ins(%arg0[(i, j) -> (i)]: memref<1024x64xi32>,
+           %arg1t[(i, j) -> (j)]: memref<1024x64xi32>)
+       outs(%res[(i, j) -> (i, j)]: memref<1024x1024xi32>)
+       on hierarchy<1024x1024>
+       do (%a1: memref<64xi32>, %b1: memref<64xi32>, %o: memref<i32>) {
+        // reduction (could be raised to linalg.reduce)
+        linalg.generic #matmul_trait
+          ins(%a1, %b1: memref<64xi32>, memref<64xi32>)
+          outs(%o: memref<i32>) {
+        ^bb0(%a: i32, %b: i32, %c: i32):
+          %0 = arith.muli %a, %b: i32
+          %1 = arith.addi %0, %c: i32
+          linalg.yield %1: i32
+        }
+    }
+
+    // expand dim 1 factor 64 -> <1024x64x16>
+    // expand dim 0 factor 64 -> <64x16x64x16>
+    // swap dim 1 and 2 -> <64x64x16x16>
+    cnm.compute
+       ins(%arg0[(d0, d1, d2, d3) -> (d0 * 16 + d2)] : memref<1024x64xi32>,
+           %arg1t[(d0, d1, d2, d3) -> (d1 * 16 + d3)] : memref<1024x64xi32>)
+       outs(%res[(d0, d1, d2, d3) -> (d0 * 16 + d2, d1 * 16 + d3)] : memref<1024x1024xi32>)
+       on hierarchy<64x64x16x16>
+       do (%a1: memref<64xi32>, %b1: memref<64xi32>, %o: memref<i32>) {
+        linalg.generic #matmul_trait
+          ins(%a1, %b1: memref<64xi32>, memref<64xi32>)
+          outs(%o: memref<i32>) {
+        ^bb0(%a: i32, %b: i32, %c: i32):
+          %0 = arith.muli %a, %b: i32
+          %1 = arith.addi %0, %c: i32
+          linalg.yield %1: i32
+        }
+    }
+
+    // peel left 2
+    // This computes 16x16 tiles of the output on some DIMMs.
+    // - Can we lift the reduction loop outside of the nest? That's hard
+    // - Can we coarsen the output size (not use memref<i32> but memref<16xi32>)
+    //   - That's fine.
+    //
+    affine.parallel (%D0, %D1) = (0, 0) to (64, 64) {
+      cnm.compute
+         symbols [%D0, %D1]
+         ins(%arg0[(d2, d3)[D0, D1] -> (D0 * 16 + d2)] : memref<1024x64xi32>,
+             %arg1t[(d2, d3)[D0, D1] -> (D1 * 16 + d3)] : memref<1024x64xi32>)
+         outs(%res[(d2, d3)[D0, D1] -> (D0 * 16 + d2, D1 * 16 + d3)] : memref<1024x1024xi32>)
+         on hierarchy<16x16>
+         do (%arg2: memref<64xi32>, %arg3: memref<64xi32>, %arg4: memref<i32>) {
+          linalg.generic #matmul_trait
+            ins(%arg2, %arg3: memref<64xi32>, memref<64xi32>)
+            outs(%arg4: memref<i32>) {
+          ^bb0(%a: i32, %b: i32, %c: i32):
+            %0 = arith.muli %a, %b: i32
+            %1 = arith.addi %0, %c: i32
+            linalg.yield %1: i32
+          }
+      }
+    }
+
+    // Then, let's say we want to lift the reduction out.
+
+    // Let's say that we want to tile the reduction loop with factor 2.
+    // We need
+    // - the tiling factor
+    // - the reduction iterator index (just do one reduction at a time)
+    // We want something like the code below. Notice:
+    // - Kernel argument types change, but not the linalg.generic (in particular, not its affine maps)
+    // - The reduction between the two iterations is done implicitly, because
+    //   the output of the first iteration is used to initialize the second.
+    // - We need to reshape inputs. Which ones? The ones that
+    //   - are inputs of the linalg.generic, and
+    //   - use the index of the targeted reduction iterator (look at the affine maps);
+    //   - also, the index space of the reduction iterator must be divisible by the factor.
+    // Simultaneously reshape 64 into 2x32, introduce a surrounding loop, and introduce a symbol in the affine maps
+    // to index the new input dimension.
+    //
+    // A more stepwise lowering would be:
+    // - Tile the linalg.generic, making a new reduction iterator appear as the outer iterator
+    // - Make a transform that extracts an outermost reduction iterator
+    //   - Q: does it need to be outermost?
+    //
+    affine.parallel (%D0, %D1) = (0, 0) to (64, 64) {
+      affine.for %k0 = 0 to 64 step 32 {
+        %arg00 = memref.expand_shape %arg0 [[0], [1, 2]]: memref<1024x64xi32> into memref<1024x2x32xi32>
+        %arg10 = memref.expand_shape %arg1t [[0], [1, 2]]: memref<1024x64xi32> into memref<1024x2x32xi32>
+
+        // this does only half of the 64 reduction, returns a partial result
+        cnm.compute
+           symbols [%D0, %D1, %k0]
+           ins(%arg00[(d2, d3)[D0, D1, K] -> (D0 * 16 + d2, K floordiv 32)] : memref<1024x2x32xi32>,
+               %arg10[(d2, d3)[D0, D1, K] -> (D1 * 16 + d3, K floordiv 32)] : memref<1024x2x32xi32>)
+           outs(%res[(d2, d3)[D0, D1, K] -> (D0 * 16 + d2, D1 * 16 + d3)] : memref<1024x1024xi32>)
+           on hierarchy<16x16>
+           do (%arg2: memref<32xi32>, %arg3: memref<32xi32>, %arg4: memref<i32>) {
+            linalg.generic #matmul_trait
+              ins(%arg2, %arg3: memref<32xi32>, memref<32xi32>)
+              outs(%arg4: memref<i32>) {
+            ^bb0(%a: i32, %b: i32, %c: i32):
+              %0 = arith.muli %a, %b: i32
+              %1 = arith.addi %0, %c: i32
+              linalg.yield %1: i32
+            }
+        }
+      }
+    }
+    // Ok stepwise:
+    affine.parallel (%D0, %D1) = (0, 0) to (64, 64) {
+      // first tile the inner kernel
+      // (todo: do that automatically with transform. Is that possible?)
+      cnm.compute
+         symbols [%D0, %D1]
+         ins(%arg0[(d2, d3)[D0, D1] -> (D0 * 16 + d2)] : memref<1024x64xi32>,
+             %arg1t[(d2, d3)[D0, D1] -> (D1 * 16 + d3)] : memref<1024x64xi32>)
+         outs(%res[(d2, d3)[D0, D1] -> (D0 * 16 + d2, D1 * 16 + d3)] : memref<1024x1024xi32>)
+         on hierarchy<16x16>
+         do (%arg2: memref<64xi32>, %arg3: memref<64xi32>, %arg4: memref<i32>) {
+          %arg20 = memref.expand_shape %arg2 [[0, 1]]: memref<64xi32> into memref<2x32xi32>
+          %arg30 = memref.expand_shape %arg3 [[0, 1]]: memref<64xi32> into memref<2x32xi32>
+          linalg.generic {
+            indexing_maps = [
+              affine_map<(ko, kt) -> (ko, kt)>,
+              affine_map<(ko, kt) -> (ko, kt)>,
+              affine_map<(ko, kt) -> ()>
+            ],
+            iterator_types = ["reduction", "reduction"]
+          }
+          ins(%arg20, %arg30: memref<2x32xi32>, memref<2x32xi32>)
+          outs(%arg4: memref<i32>) {
+          ^bb0(%a: i32, %b: i32, %c: i32):
+            %0 = arith.muli %a, %b: i32
+            %1 = arith.addi %0, %c: i32
+            linalg.yield %1: i32
+          }
+      }
+    }
+
+    // fork
+    cnm.compute
+       ins(%arg0[(d0, d1, d2, d3) -> (d0 * 16 + d2)] : memref<1024x64xi32>,
+           %arg1t[(d0, d1, d2, d3) -> (d1 * 16 + d3)] : memref<1024x64xi32>)
+       outs(%res[(d0, d1, d2, d3) -> (d0 * 16 + d2, d1 * 16 + d3)] : memref<1024x1024xi32>)
+       on hierarchy<64x64x16x16>
+       do (%arg2: memref<64xi32>, %arg3: memref<64xi32>, %arg4: memref<i32>) {
+        affine.for %arg5 = 0 to 64 {
+          %0 = affine.load %arg2[%arg5] : memref<64xi32>
+          %1 = affine.load %arg3[%arg5] : memref<64xi32>
+          %2 = arith.muli %0, %1 : i32
+          %3 = affine.load %arg4[] : memref<i32>
+          %4 = arith.addi %2, %3 : i32
+          affine.store %4, %arg4[] : memref<i32>
+        }
+    }
+
+    // peel right (need a reshape first); the reshape turns the output map into (i,j,k,l) -> (i*16+k,j,l)
+    %exp0 = memref.expand_shape %res[[0], [1, 2]] : memref<1024x1024xi32> into memref<1024x64x16xi32>
+    %expb = memref.expand_shape %arg1t[[0, 1], [2]] : memref<1024x64xi32> into memref<64x16x64xi32>
+
+    cnm.compute
+       ins(%arg0[(d0, d1, d2) -> (d0 * 16 + d2)] : memref<1024x64xi32>,
+           %expb[(d0, d1, d2) -> (d1)] : memref<64x16x64xi32>)
+       outs(%exp0[(d0, d1, d2) -> (d0 * 16 + d2, d1)] : memref<1024x64x16xi32>)
+       on hierarchy<64x64x16>
+       do (%arg2: memref<64xi32>, %arg3: memref<16x64xi32>, %arg4: memref<16xi32>) {
+        affine.parallel (%x) = (0) to (16) {
+          affine.for %arg5 = 0 to 64 {
+            %0 = affine.load %arg2[%arg5] : memref<64xi32>
+            %1 = affine.load %arg3[%x, %arg5] : memref<16x64xi32>
+            %2 = arith.muli %0, %1 : i32
+            %3 = affine.load %arg4[%x] : memref<16xi32>
+            %4 = arith.addi %2, %3 : i32
+            affine.store %4, %arg4[%x] : memref<16xi32>
+          }
+        }
+    }
+
+
+
+    // affine.parallel (%D0, %D1) = (0, 0) to (64, 64) {
+    //   cnm.compute
+    //      symbols [%D0, %D1]
+    //      ins(%arg0[(d2, d3)[D0, D1] -> (d3 * 64 + D1)] : memref<1024x64xi32>,
+    //          %arg1t[(d2, d3)[D0, D1] -> (d2 * 64 + D0)] : memref<1024x64xi32>)
+    //      outs(%res[(d2, d3)[D0, D1] -> (d3 * 64 + D1, d2 * 64 + D0)] : memref<1024x1024xi32>)
+    //      on hierarchy<16x16>
+    //      do (%arg2: memref<64xi32>, %arg3: memref<64xi32>, %arg4: memref<i32>) {
+    //       affine.for %arg5 = 0 to 64 {
+    //         %0 = affine.load %arg2[%arg5] : memref<64xi32>
+    //         %1 = affine.load %arg2[%arg5] : memref<64xi32>
+    //         %2 = arith.muli %0, %1 : i32
+    //         %3 = affine.load %arg4[] : memref<i32>
+    //         %4 = arith.addi %2, %3 : i32
+    //         affine.store %4, %arg4[] : memref<i32>
+    //       }
+    //   }
+    // }
+
+    // into
+
+    return
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.op<"cnm.compute">):
+    transform.cnm.expand_dim %arg0 dim 1 by factor 64: (!transform.op<"cnm.compute">) -> ()
+    transform.cnm.expand_dim %arg0 dim 0 by factor 64: (!transform.op<"cnm.compute">) -> ()
+    transform.cnm.swap_dims %arg0, 1, 2: (!transform.op<"cnm.compute">) -> ()
+  }
+}
\ No newline at end of file
diff --git a/cinnamon/test/Dialect/Cnm/transform-matmul-schedule.mlir b/cinnamon/test/Dialect/Cnm/transform-matmul-schedule.mlir
new file mode 100644
index 0000000..3ab1c9f
--- /dev/null
+++ b/cinnamon/test/Dialect/Cnm/transform-matmul-schedule.mlir
@@ -0,0 +1,144 @@
+// RUN: cinm-opt %s --cnm-apply-transform | FileCheck %s
+#map = affine_map<(d0) -> (d0)>
+module {
+
+
+  func.func @matmul(%arg0: memref<1024x64xi32>, %arg1: memref<64x1024xi32>) {
+    %arg1t = memref.alloc() : memref<1024x64xi32>
+    linalg.transpose ins(%arg1: memref<64x1024xi32>) outs(%arg1t: memref<1024x64xi32>) permutation = [1,0]
+    %res = memref.alloc() : memref<1024x1024xi32>
+
+    // this is naive matmul
+    cnm.compute
+       ins(%arg0[(i, j) -> (i)]: memref<1024x64xi32>,
+           %arg1t[(i, j) -> (j)]: memref<1024x64xi32>)
+       outs(%res[(i, j) -> (i, j)]: memref<1024x1024xi32>)
+       on hierarchy<1024x1024>
+       do (%a1: memref<64xi32>, %b1: memref<64xi32>, %o: memref<i32>) {
+        // reduction loop
+        affine.for %k = 0 to 64 {
+          %0 = affine.load %a1[%k] : memref<64xi32>
+          %1 = affine.load %b1[%k] : memref<64xi32>
+          %2 = arith.muli %0, %1 : i32
+          %3 = affine.load %o[] : memref<i32>
+          %4 = arith.addi %2, %3 : i32
+          affine.store %4, %o[] : memref<i32>
+        }
+    }
+
+    // expand dim 1 factor 64 -> <1024x64x16>
+    // expand dim 0 factor 64 -> <64x16x64x16>
+    // swap dim 1 and 2 -> <64x64x16x16>
+    cnm.compute
+       ins(%arg0[(d0, d1, d2, d3) -> (d0 * 16 + d2)] : memref<1024x64xi32>,
+           %arg1t[(d0, d1, d2, d3) -> (d1 * 16 + d3)] : memref<1024x64xi32>)
+       outs(%res[(d0, d1, d2, d3) -> (d0 * 16 + d2, d1 * 16 + d3)] : memref<1024x1024xi32>)
+       on hierarchy<64x64x16x16>
+       do (%arg2: memref<64xi32>, %arg3: memref<64xi32>, %arg4: memref<i32>) {
+        affine.for %arg5 = 0 to 64 {
+          %0 = affine.load %arg2[%arg5] : memref<64xi32>
+          %1 = affine.load %arg3[%arg5] : memref<64xi32>
+          %2 = arith.muli %0, %1 : i32
+          %3 = affine.load %arg4[] : memref<i32>
+          %4 = arith.addi %2, %3 : i32
+          affine.store %4, %arg4[] : memref<i32>
+        }
+    }
+
+    // peel left 2
+    // This computes 16x16 tiles of the output on some DIMMs.
+    // - Can we lift the reduction loop outside of the nest? That's hard
+    // - Can we coarsen the output size (not use memref<i32> but memref<16xi32>)
+    //   - That's fine.
+    //
+    affine.parallel (%D0, %D1) = (0, 0) to (64, 64) {
+      cnm.compute
+         symbols [%D0, %D1]
+         ins(%arg0[(d2, d3)[D0, D1] -> (D0 * 16 + d2)] : memref<1024x64xi32>,
+             %arg1t[(d2, d3)[D0, D1] -> (D1 * 16 + d3)] : memref<1024x64xi32>)
+         outs(%res[(d2, d3)[D0, D1] -> (D0 * 16 + d2, D1 * 16 + d3)] : memref<1024x1024xi32>)
+         on hierarchy<16x16>
+         do (%arg2: memref<64xi32>, %arg3: memref<64xi32>, %arg4: memref<i32>) {
+          affine.for %arg5 = 0 to 64 {
+            %0 = affine.load %arg2[%arg5] : memref<64xi32>
+            %1 = affine.load %arg3[%arg5] : memref<64xi32>
+            %2 = arith.muli %0, %1 : i32
+            %3 = affine.load %arg4[] : memref<i32>
+            %4 = arith.addi %2, %3 : i32
+            affine.store %4, %arg4[] : memref<i32>
+          }
+      }
+    }
+
+    // fork
+    cnm.compute
+       ins(%arg0[(d0, d1, d2, d3) -> (d0 * 16 + d2)] : memref<1024x64xi32>,
+           %arg1t[(d0, d1, d2, d3) -> (d1 * 16 + d3)] : memref<1024x64xi32>)
+       outs(%res[(d0, d1, d2, d3) -> (d0 * 16 + d2, d1 * 16 + d3)] : memref<1024x1024xi32>)
+       on hierarchy<64x64x16x16>
+       do (%arg2: memref<64xi32>, %arg3: memref<64xi32>, %arg4: memref<i32>) {
+        affine.for %arg5 = 0 to 64 {
+          %0 = affine.load %arg2[%arg5] : memref<64xi32>
+          %1 = affine.load %arg3[%arg5] : memref<64xi32>
+          %2 = arith.muli %0, %1 : i32
+          %3 = affine.load %arg4[] : memref<i32>
+          %4 = arith.addi %2, %3 : i32
+          affine.store %4, %arg4[] : memref<i32>
+        }
+    }
+
+    // peel right (need a reshape first); the reshape turns the output map into (i,j,k,l) -> (i*16+k,j,l)
+    %exp0 = memref.expand_shape %res[[0], [1, 2]] : memref<1024x1024xi32> into memref<1024x64x16xi32>
+    %expb = memref.expand_shape %arg1t[[0, 1], [2]] : memref<1024x64xi32> into memref<64x16x64xi32>
+
+    cnm.compute
+       ins(%arg0[(d0, d1, d2) -> (d0 * 16 + d2)] : memref<1024x64xi32>,
+           %expb[(d0, d1, d2) -> (d1)] : memref<64x16x64xi32>)
+       outs(%exp0[(d0, d1, d2) -> (d0 * 16 + d2, d1)] : memref<1024x64x16xi32>)
+       on hierarchy<64x64x16>
+       do (%arg2: memref<64xi32>, %arg3: memref<16x64xi32>, %arg4: memref<16xi32>) {
+        affine.parallel (%x) = (0) to (16) {
+          affine.for %arg5 = 0 to 64 {
+            %0 = affine.load %arg2[%arg5] : memref<64xi32>
+            %1 = affine.load %arg3[%x, %arg5] : memref<16x64xi32>
+            %2 = arith.muli %0, %1 : i32
+            %3 = affine.load %arg4[%x] : memref<16xi32>
+            %4 = arith.addi %2, %3 : i32
+            affine.store %4, %arg4[%x] : memref<16xi32>
+          }
+        }
+    }
+
+
+
+    // affine.parallel (%D0, %D1) = (0, 0) to (64, 64) {
+    //   cnm.compute
+    //      symbols [%D0, %D1]
+    //      ins(%arg0[(d2, d3)[D0, D1] -> (d3 * 64 + D1)] : memref<1024x64xi32>,
+    //          %arg1t[(d2, d3)[D0, D1] -> (d2 * 64 + D0)] : memref<1024x64xi32>)
+    //      outs(%res[(d2, d3)[D0, D1] -> (d3 * 64 + D1, d2 * 64 + D0)] : memref<1024x1024xi32>)
+    //      on hierarchy<16x16>
+    //      do (%arg2: memref<64xi32>, %arg3: memref<64xi32>, %arg4: memref<i32>) {
+    //       affine.for %arg5 = 0 to 64 {
+    //         %0 = affine.load %arg2[%arg5] : memref<64xi32>
+    //         %1 = affine.load %arg2[%arg5] : memref<64xi32>
+    //         %2 = arith.muli %0, %1 : i32
+    //         %3 = affine.load %arg4[] : memref<i32>
+    //         %4 = arith.addi %2, %3 : i32
+    //         affine.store %4, %arg4[] : memref<i32>
+    //       }
+    //   }
+    // }
+
+    // into
+
+    return
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.op<"cnm.compute">):
+    transform.cnm.expand_dim %arg0 dim 1 by factor 64: (!transform.op<"cnm.compute">) -> ()
+    transform.cnm.expand_dim %arg0 dim 0 by factor 64: (!transform.op<"cnm.compute">) -> ()
+    transform.cnm.swap_dims %arg0, 1, 2: (!transform.op<"cnm.compute">) -> ()
+  }
+}
\ No newline at end of file
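Chaining the three transform ops reproduces the hierarchies that the inline comments predict. Deriving the ins map of %arg0 step by step, using the expand_dim rewrite rule established by the transform-expand-dim test:

// start:                          on hierarchy<1024x1024>,   map (i, j) -> (i)
// expand_dim dim 1 by factor 64:  on hierarchy<1024x64x16>,  map (d0, d1, d2) -> (d0)
// expand_dim dim 0 by factor 64:  on hierarchy<64x16x64x16>, map (d0, d1, d2, d3) -> (d0 * 16 + d1)
// swap_dims 1, 2:                 on hierarchy<64x64x16x16>, map (d0, d1, d2, d3) -> (d0 * 16 + d2)

The end state matches the hand-written "expanded" op in both schedule tests, which is what makes the fork and peel experiments above comparable.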
diff --git a/cinnamon/test/Dialect/Cnm/transform-normalize-buffers.mlir b/cinnamon/test/Dialect/Cnm/transform-normalize-buffers.mlir
new file mode 100644
index 0000000..7eca558
--- /dev/null
+++ b/cinnamon/test/Dialect/Cnm/transform-normalize-buffers.mlir
@@ -0,0 +1,123 @@
+// RUN: cinm-opt %s --cnm-apply-transform | FileCheck %s
+#map = affine_map<(d0) -> (d0)>
+module {
+
+  // CHECK-LABEL: @simple
+  // CHECK: cnm.compute
+  // CHECK-NEXT: ins(%arg0[(d0, d1) -> (d0 * 512 + d1)] : memref<1024xi32>)
+  // CHECK-NEXT: outs(%arg1[(d0, d1) -> (d0 * 512 + d1)] : memref<1024xi32>)
+  // CHECK-NEXT: on hierarchy<2x512>
+  func.func @simple(%arg0: memref<1024xi32>, %arg1: memref<1024xi32>) {
+    cnm.compute
+       ins(%arg0[(i, j) -> ()]: memref<1024xi32>)
+       outs(%arg1[(i, j) -> (i * 512 + j)]: memref<1024xi32>)
+       on hierarchy<2x512>
+       do (%a1: memref<1024xi32>, %o1: memref<i32>) {
+        affine.for %i = 0 to 1024 {
+          %0 = affine.load %a1[%i] : memref<1024xi32>
+          %1 = affine.load %o1[] : memref<i32>
+          %2 = arith.addi %0, %1 : i32
+          affine.store %2, %o1[] : memref<i32>
+        }
+    }
+
+    // into
+
+    %r = memref.expand_shape %arg1[[0, 1]] : memref<1024xi32> into memref<2x512xi32>
+    cnm.compute
+       ins(%arg0[(i, j) -> ()]: memref<1024xi32>)
+       outs(%r[(i, j) -> (i, j)]: memref<2x512xi32>)
+       on hierarchy<2x512>
+       do (%a1: memref<1024xi32>, %o1: memref<i32>) {
+        affine.for %i = 0 to 1024 {
+          %0 = affine.load %a1[%i] : memref<1024xi32>
+          %1 = affine.load %o1[] : memref<i32>
+          %2 = arith.addi %0, %1 : i32
+          affine.store %2, %o1[] : memref<i32>
+        }
+    }
+
+    return
+  }
+
+  func.func @partialBroadcast(%arg0: memref<1024x64xi32>, %arg1: memref<64x1024xi32>) {
+    %arg1t = memref.alloc() : memref<1024x64xi32>
+    linalg.transpose ins(%arg1: memref<64x1024xi32>) outs(%arg1t: memref<1024x64xi32>) permutation = [1,0]
+    %res = memref.alloc() : memref<1024x1024xi32>
+
+    // this is naive matmul
+    cnm.compute
+       ins(%arg0[(i, j) -> (i)]: memref<1024x64xi32>,
+           %arg1t[(i, j) -> (j)]: memref<1024x64xi32>)
+       outs(%res[(i, j) -> (i, j)]: memref<1024x1024xi32>)
+       on hierarchy<1024x1024>
+       do (%a1: memref<64xi32>, %b1: memref<64xi32>, %o: memref<i32>) {
+        affine.for %i = 0 to 64 {
+          %0 = affine.load %a1[%i] : memref<64xi32>
+          %1 = affine.load %b1[%i] : memref<64xi32>
+          %2 = arith.muli %0, %1 : i32
+          %3 = affine.load %o[] : memref<i32>
+          %4 = arith.addi %2, %3 : i32
+          affine.store %4, %o[] : memref<i32>
+        }
+    }
+
+    // expand dim 1 factor 64 -> <1024x64x16>
+    cnm.compute
+       ins(%arg0[(i, j, k) -> (i)]: memref<1024x64xi32>,
+           %arg1t[(i, j, k) -> (j * 16 + k)]: memref<1024x64xi32>)
+       outs(%res[(i, j, k) -> (i, j * 16 + k)]: memref<1024x1024xi32>)
+       on hierarchy<1024x64x16>
+       do (%a1: memref<64xi32>, %b1: memref<64xi32>, %o: memref<i32>) {
+        affine.for %i = 0 to 64 {
+          %0 = affine.load %a1[%i] : memref<64xi32>
+          %1 = affine.load %b1[%i] : memref<64xi32>
+          %2 = arith.muli %0, %1 : i32
+          %3 = affine.load %o[] : memref<i32>
+          %4 = arith.addi %2, %3 : i32
+          affine.store %4, %o[] : memref<i32>
+        }
+    }
+    // expand dim 0 factor 64 -> <64x16x64x16>
+    cnm.compute
+       ins(%arg0[(i, j, k) -> (i)]: memref<1024x64xi32>,
+           %arg1t[(i, j, k) -> (j * 16 + k)]: memref<1024x64xi32>)
+       outs(%res[(i, j, k) -> (i, j * 16 + k)]: memref<1024x1024xi32>)
+       on hierarchy<1024x64x16>
+       do (%a1: memref<64xi32>, %b1: memref<64xi32>, %o: memref<i32>) {
+        affine.for %i = 0 to 64 {
+          %0 = affine.load %a1[%i] : memref<64xi32>
+          %1 = affine.load %b1[%i] : memref<64xi32>
+          %2 = arith.muli %0, %1 : i32
+          %3 = affine.load %o[] : memref<i32>
+          %4 = arith.addi %2, %3 : i32
+          affine.store %4, %o[] : memref<i32>
+        }
+    }
+    // swap dim 0 and 3 -> <64x64x16x16>
+    cnm.compute
+       ins(%arg0[(i, j, k, l) -> (i * 16 + k)]: memref<1024x64xi32>,
+           %arg1t[(i, j, k, l) -> (j)]: memref<1024x64xi32>)
+       outs(%res[(i, j, k, l) -> (i, j)]: memref<1024x1024xi32>)
+       on hierarchy<64x64x16x16>
+       do (%a1: memref<64xi32>, %b1: memref<64xi32>, %o: memref<i32>) {
+        affine.for %i = 0 to 64 {
+          %0 = affine.load %a1[%i] : memref<64xi32>
+          %1 = affine.load %b1[%i] : memref<64xi32>
+          %2 = arith.muli %0, %1 : i32
+          %3 = affine.load %o[] : memref<i32>
+          %4 = arith.addi %2, %3 : i32
+          affine.store %4, %o[] : memref<i32>
+        }
+    }
+
+    // into
+
+    return
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.op<"cnm.compute">):
+    transform.cnm.expand_dim %arg0 dim 0 by factor 2: (!transform.op<"cnm.compute">) -> ()
+  }
+}
\ No newline at end of file
diff --git a/cinnamon/test/Dialect/Cnm/transform-peel-right.mlir b/cinnamon/test/Dialect/Cnm/transform-peel-right.mlir
new file mode 100644
index 0000000..b76c97b
--- /dev/null
+++ b/cinnamon/test/Dialect/Cnm/transform-peel-right.mlir
@@ -0,0 +1,57 @@
+// RUN: cinm-opt %s --cnm-apply-transform | FileCheck %s
+#map = affine_map<(d0) -> (d0)>
+module {
+  // CHECK-LABEL: @simple
+  // CHECK: cnm.compute
+  // CHECK-NEXT: ins(%arg0[(d0) -> ()] : memref<1024xi32>)
+  // CHECK-NEXT: outs(%{{.*}}[(d0) -> (d0)] : memref<2x512xi32>)
+  // CHECK-NEXT: on hierarchy<2>
+  // CHECK-NEXT: do (%[[A:.*]]: memref<1024xi32>, %[[B:.*]]: memref<512xi32>) {
+  // CHECK-NEXT: affine.parallel (%[[i:.*]]) = (0) to (512) {
+  // CHECK-NEXT: affine.for %[[k:.*]] = 0 to 1024 {
+  // CHECK-NEXT: affine.load %[[A]][%[[k]]]
+  // CHECK-NEXT: affine.load %[[B]][%[[i]]]
+  // CHECK-NEXT: arith.addi
+  // CHECK-NEXT: affine.store %{{.*}}, %[[B]][%[[i]]]
+  // CHECK-NEXT: }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: }
+  func.func @simple(%arg0: memref<1024xi32>, %arg1: memref<1024xi32>) {
+
+    %r = memref.expand_shape %arg1[[0, 1]] : memref<1024xi32> into memref<2x512xi32>
+    cnm.compute
+       ins(%arg0[(i, j) -> ()]: memref<1024xi32>)
+       outs(%r[(i, j) -> (i, j)]: memref<2x512xi32>)
+       on hierarchy<2x512>
+       do (%a1: memref<1024xi32>, %o1: memref<i32>) {
+        affine.for %i = 0 to 1024 {
+          %0 = affine.load %a1[%i] : memref<1024xi32>
+          %1 = affine.load %o1[] : memref<i32>
+          %2 = arith.addi %0, %1 : i32
+          affine.store %2, %o1[] : memref<i32>
+        }
+    }
+
+
+    cnm.compute
+       ins(%arg0[(i) -> ()]: memref<1024xi32>)
+       outs(%r[(i) -> (i)]: memref<2x512xi32>)
+       on hierarchy<2>
+       do (%a1: memref<1024xi32>, %o1: memref<512xi32>) {
+        affine.parallel (%j) = (0) to (512) {
+          affine.for %i = 0 to 1024 {
+            %0 = affine.load %a1[%i] : memref<1024xi32>
+            %1 = affine.load %o1[%j] : memref<512xi32>
+            %2 = arith.addi %0, %1 : i32
+            affine.store %2, %o1[%j] : memref<512xi32>
+          }
+        }
+    }
+
+    return
+  }
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.op<"cnm.compute">):
+    transform.cnm.peel_right %arg0: (!transform.op<"cnm.compute">) -> ()
+  }
+}
\ No newline at end of file
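The CHECK block above documents transform.cnm.peel_right: the innermost hierarchy dimension moves into the kernel, so the op keeps hierarchy<2>, the peeled output argument widens from memref<i32> to memref<512xi32>, the peeled index disappears from the affine maps, and the body gains an affine.parallel over the peeled extent. Before and after for @simple, restated from the test (bodies elided):

// Before:
cnm.compute
   ins(%arg0[(i, j) -> ()]: memref<1024xi32>)
   outs(%r[(i, j) -> (i, j)]: memref<2x512xi32>)
   on hierarchy<2x512>
   do (%a1: memref<1024xi32>, %o1: memref<i32>) { /* ... */ }

// After transform.cnm.peel_right:
cnm.compute
   ins(%arg0[(d0) -> ()]: memref<1024xi32>)
   outs(%r[(d0) -> (d0)]: memref<2x512xi32>)
   on hierarchy<2>
   do (%a1: memref<1024xi32>, %o1: memref<512xi32>) {
    affine.parallel (%j) = (0) to (512) { /* ... */ }
}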
diff --git a/cinnamon/test/Dialect/Cnm/transform-schedule.mlir b/cinnamon/test/Dialect/Cnm/transform-schedule.mlir
new file mode 100644
index 0000000..e620ef7
--- /dev/null
+++ b/cinnamon/test/Dialect/Cnm/transform-schedule.mlir
@@ -0,0 +1,65 @@
+
+// RUN: cinm-opt %s --cnm-apply-transform | FileCheck %s
+#map = affine_map<(d0) -> (d0)>
+module {
+
+
+  func.func @matmul(%arg0: memref<1024x64xi32>, %arg1: memref<64x1024xi32>) {
+    %arg1t = memref.alloc() : memref<1024x64xi32>
+    linalg.transpose ins(%arg1: memref<64x1024xi32>) outs(%arg1t: memref<1024x64xi32>) permutation = [1,0]
+    %res = memref.alloc() : memref<1024x1024xi32>
+
+    // peel right (need a reshape first); the reshape turns the output map into (i,j,k,l) -> (i*16+k,j,l)
+    %exp0 = memref.expand_shape %res[[0], [1, 2]] : memref<1024x1024xi32> into memref<1024x64x16xi32>
+    %expb = memref.expand_shape %arg1t[[0, 1], [2]] : memref<1024x64xi32> into memref<64x16x64xi32>
+
+    cnm.compute
+       ins(%arg0[(d0, d1, d2) -> (d0 * 16 + d2)] : memref<1024x64xi32>,
+           %expb[(d0, d1, d2) -> (d1)] : memref<64x16x64xi32>)
+       outs(%exp0[(d0, d1, d2) -> (d0 * 16 + d2, d1)] : memref<1024x64x16xi32>)
+       on hierarchy<64x64x16>
+       do (%arg2: memref<64xi32>, %arg3: memref<16x64xi32>, %arg4: memref<16xi32>) {
+        affine.parallel (%x) = (0) to (16) {
+          affine.for %arg5 = 0 to 64 {
+            %0 = affine.load %arg2[%arg5] : memref<64xi32>
+            %1 = affine.load %arg3[%x, %arg5] : memref<16x64xi32>
+            %2 = arith.muli %0, %1 : i32
+            %3 = affine.load %arg4[%x] : memref<16xi32>
+            %4 = arith.addi %2, %3 : i32
+            affine.store %4, %arg4[%x] : memref<16xi32>
+          }
+        }
+    }
+
+
+
+    // affine.parallel (%D0, %D1) = (0, 0) to (64, 64) {
+    //   cnm.compute
+    //      symbols [%D0, %D1]
+    //      ins(%arg0[(d2, d3)[D0, D1] -> (d3 * 64 + D1)] : memref<1024x64xi32>,
+    //          %arg1t[(d2, d3)[D0, D1] -> (d2 * 64 + D0)] : memref<1024x64xi32>)
+    //      outs(%res[(d2, d3)[D0, D1] -> (d3 * 64 + D1, d2 * 64 + D0)] : memref<1024x1024xi32>)
+    //      on hierarchy<16x16>
+    //      do (%arg2: memref<64xi32>, %arg3: memref<64xi32>, %arg4: memref<i32>) {
+    //       affine.for %arg5 = 0 to 64 {
+    //         %0 = affine.load %arg2[%arg5] : memref<64xi32>
+    //         %1 = affine.load %arg2[%arg5] : memref<64xi32>
+    //         %2 = arith.muli %0, %1 : i32
+    //         %3 = affine.load %arg4[] : memref<i32>
+    //         %4 = arith.addi %2, %3 : i32
+    //         affine.store %4, %arg4[] : memref<i32>
+    //       }
+    //   }
+    // }
+
+    // into
+
+    return
+  }
+
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.op<"affine.parallel">):
+
+  }
+}
\ No newline at end of file
diff --git a/cinnamon/test/Dialect/UPMEM/upmem-ops.mlir b/cinnamon/test/Dialect/UPMEM/upmem-ops.mlir
index 32b8a52..0415486 100644
--- a/cinnamon/test/Dialect/UPMEM/upmem-ops.mlir
+++ b/cinnamon/test/Dialect/UPMEM/upmem-ops.mlir
@@ -10,8 +10,8 @@ module {
     %tasklet_count = arith.constant 16 : index
     %hierarchy = upmem.alloc_dpus : !upmem.hierarchy<2x32x16>
     %base_offset = upmem.base_dpu_mem_offset : index
-    %A_offset = upmem.scatter %A[64, #scatter_map] onto %hierarchy at %base_offset : memref<2x32x8192xi32> onto !upmem.hierarchy<2x32x16>
-    %B_offset = upmem.scatter %B[64, #scatter_map] onto %hierarchy at %A_offset : memref<2x32x8192xi32> onto !upmem.hierarchy<2x32x16>
+    upmem.scatter %A[0, 64, #scatter_map] onto %hierarchy : memref<2x32x8192xi32> onto !upmem.hierarchy<2x32x16>
+    upmem.scatter %B[256, 64, #scatter_map] onto %hierarchy : memref<2x32x8192xi32> onto !upmem.hierarchy<2x32x16>
     upmem.launch %hierarchy ranks(%arg0 upto %rank_count) dpus(%arg1 upto %dpu_count) tasklets(%arg2 upto %tasklet_count) on !upmem.hierarchy<2x32x16> {
       %cst0 = arith.constant 0 : index
       %cst1 = arith.constant 1 : index
@@ -46,7 +46,7 @@
       }
       upmem.terminator
     }
-    %C_offset = upmem.gather %C[64, #scatter_map] from %hierarchy at %base_offset : memref<2x32x8192xi32> from !upmem.hierarchy<2x32x16>
+    upmem.gather %C[512, 64, #scatter_map] from %hierarchy : memref<2x32x8192xi32> from !upmem.hierarchy<2x32x16>
     return
   }
 }
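The upmem-ops change above also shifts how transfer offsets are handled: scatter and gather no longer return a running offset to thread through SSA values, but take the start offset as an extra leading operand in the bracket. The values 0, 256, and 512 here are consistent with three 64-element i32 transfers placed back to back, though that reading of the numbers is an inference, not something this diff documents:

// Old style: offsets threaded through SSA values.
%base = upmem.base_dpu_mem_offset : index
%offA = upmem.scatter %A[64, #map] onto %h at %base : memref<2x32x8192xi32> onto !upmem.hierarchy<2x32x16>
%offB = upmem.scatter %B[64, #map] onto %h at %offA : memref<2x32x8192xi32> onto !upmem.hierarchy<2x32x16>

// New style: explicit per-transfer offsets.
upmem.scatter %A[0, 64, #map] onto %h : memref<2x32x8192xi32> onto !upmem.hierarchy<2x32x16>
upmem.scatter %B[256, 64, #map] onto %h : memref<2x32x8192xi32> onto !upmem.hierarchy<2x32x16>
upmem.gather %C[512, 64, #map] from %h : memref<2x32x8192xi32> from !upmem.hierarchy<2x32x16>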
diff --git a/cinnamon/testbench/generated/va.cnm.mlir b/cinnamon/testbench/generated/va.cnm.mlir
index aca477a..447f508 100644
--- a/cinnamon/testbench/generated/va.cnm.mlir
+++ b/cinnamon/testbench/generated/va.cnm.mlir
@@ -41,7 +41,7 @@ module {
     %12 = bufferization.to_tensor %11 : memref<1024x16xi32>
     %13 = cnm.alloc() for %3 : !cnm.buffer<16xi32 on 8x128x1, level 0>
     %14 = cnm.scatter %12 into %13[#map] of %3 : tensor<1024x16xi32> into !cnm.buffer<16xi32 on 8x128x1, level 0>
-    %15 = cnm.launch %3 in(%6, %9 : !cnm.buffer<16xi32 on 8x128x1, level 0>, !cnm.buffer<16xi32 on 8x128x1, level 0>) out(%13 : !cnm.buffer<16xi32 on 8x128x1, level 0>) on !cnm.workgroup<8x128x1> {
+    %15 = cnm.launch %3 ins(%6, %9 : !cnm.buffer<16xi32 on 8x128x1, level 0>, !cnm.buffer<16xi32 on 8x128x1, level 0>) outs(%13 : !cnm.buffer<16xi32 on 8x128x1, level 0>) on !cnm.workgroup<8x128x1> {
     ^bb0(%arg4: memref<16xi32>, %arg5: memref<16xi32>, %arg6: memref<16xi32>):
       linalg.add ins(%arg4, %arg5 : memref<16xi32>, memref<16xi32>) outs(%arg6 : memref<16xi32>)
     }
@@ -87,7 +87,7 @@ module {
     %12 = bufferization.to_tensor %11 : memref<2048x16xi32>
     %13 = cnm.alloc() for %3 : !cnm.buffer<16xi32 on 16x128x1, level 0>
     %14 = cnm.scatter %12 into %13[#map] of %3 : tensor<2048x16xi32> into !cnm.buffer<16xi32 on 16x128x1, level 0>
-    %15 = cnm.launch %3 in(%6, %9 : !cnm.buffer<16xi32 on 16x128x1, level 0>, !cnm.buffer<16xi32 on 16x128x1, level 0>) out(%13 : !cnm.buffer<16xi32 on 16x128x1, level 0>) on !cnm.workgroup<16x128x1> {
+    %15 = cnm.launch %3 ins(%6, %9 : !cnm.buffer<16xi32 on 16x128x1, level 0>, !cnm.buffer<16xi32 on 16x128x1, level 0>) outs(%13 : !cnm.buffer<16xi32 on 16x128x1, level 0>) on !cnm.workgroup<16x128x1> {
     ^bb0(%arg4: memref<16xi32>, %arg5: memref<16xi32>, %arg6: memref<16xi32>):
       linalg.add ins(%arg4, %arg5 : memref<16xi32>, memref<16xi32>) outs(%arg6 : memref<16xi32>)
     }
diff --git a/cinnamon/tools/cinm-lsp-server/cinm-lsp-server.cpp b/cinnamon/tools/cinm-lsp-server/cinm-lsp-server.cpp
index 69209e7..27505a8 100644
--- a/cinnamon/tools/cinm-lsp-server/cinm-lsp-server.cpp
+++ b/cinnamon/tools/cinm-lsp-server/cinm-lsp-server.cpp
@@ -2,6 +2,7 @@
 
 #include "cinm-mlir/Dialect/Cinm/IR/CinmDialect.h"
 #include "cinm-mlir/Dialect/Cnm/IR/CnmDialect.h"
+#include "cinm-mlir/Dialect/Cnm/TransformOps/CnmTransformOps.h"
 #include "cinm-mlir/Dialect/UPMEM/IR/UPMEMDialect.h"
 
 #include "mlir/IR/Dialect.h"
@@ -11,19 +12,18 @@
 
 using namespace mlir;
 
-static int asMainReturnCode(LogicalResult r)
-{
-    return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE;
+static int asMainReturnCode(LogicalResult r) {
+  return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE;
 }
 
-int main(int argc, char* argv[])
-{
-    DialectRegistry registry;
-    registerAllDialects(registry);
+int main(int argc, char *argv[]) {
+  DialectRegistry registry;
+  registerAllDialects(registry);
 
-    registry.insert<cinm::CinmDialect>();
-    registry.insert<cnm::CnmDialect>();
-    registry.insert<upmem::UPMEMDialect>();
+  registry.insert<cinm::CinmDialect>();
+  registry.insert<cnm::CnmDialect>();
+  registry.insert<upmem::UPMEMDialect>();
+  cnm::registerTransformDialectExtension(registry);
 
-    return asMainReturnCode(MlirLspServerMain(argc, argv, registry));
+  return asMainReturnCode(MlirLspServerMain(argc, argv, registry));
 }
diff --git a/cinnamon/tools/cinm-opt/cinm-opt.cpp b/cinnamon/tools/cinm-opt/cinm-opt.cpp
index a03ed56..5053245 100644
--- a/cinnamon/tools/cinm-opt/cinm-opt.cpp
+++ b/cinnamon/tools/cinm-opt/cinm-opt.cpp
@@ -11,6 +11,8 @@
 #include "cinm-mlir/Dialect/Cinm/IR/CinmDialect.h"
 #include "cinm-mlir/Dialect/Cinm/Transforms/Passes.h"
 #include "cinm-mlir/Dialect/Cnm/IR/CnmDialect.h"
+#include "cinm-mlir/Dialect/Cnm/TransformOps/CnmTransformOps.h"
+#include "cinm-mlir/Dialect/Cnm/TransformOps/CnmTransformPass.h"
 #include "cinm-mlir/Dialect/Cnm/Transforms/Passes.h"
 #include "cinm-mlir/Dialect/UPMEM/IR/UPMEMDialect.h"
 #include "cinm-mlir/Dialect/UPMEM/Transforms/Passes.h"
@@ -26,7 +28,6 @@
 
 using namespace mlir;
 
-
 int main(int argc, char *argv[]) {
   DialectRegistry registry;
   registerAllDialects(registry);
@@ -39,6 +40,8 @@ int main(int argc, char *argv[]) {
   registerCnmConversionPasses();
   cnm::registerCnmBufferizationExternalModels(registry);
   cnm::registerCnmTransformsPasses();
+  cnm::registerTransformDialectExtension(registry);
+  cnm::registerCnmTransformInterpreterPasses();
   cinm::registerCinmTransformsPasses();
   upmem::registerConvertUpmemToLLvmInterface(registry);
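With these registrations in place, the transform ops are exercised the way the new tests do it: a transform.sequence sits next to the payload in the same module, and the interpreter pass applies it. A minimal end-to-end sketch, assuming the interpreter is exposed as --cnm-apply-transform as in the RUN lines above (function and SSA names invented):

// RUN: cinm-opt %s --cnm-apply-transform
#map = affine_map<(d0) -> (d0)>
module {
  func.func @double(%in: memref<1024xi32>, %out: memref<1024xi32>) {
    cnm.compute
       ins(%in[#map] : memref<1024xi32>)
       outs(%out[#map] : memref<1024xi32>)
       on hierarchy<1024>
       do (%x: memref<i32>, %y: memref<i32>) {
        %0 = affine.load %x[] : memref<i32>
        %c2 = arith.constant 2 : i32
        %1 = arith.muli %0, %c2 : i32
        affine.store %1, %y[] : memref<i32>
    }
    return
  }
  transform.sequence failures(propagate) {
  ^bb0(%op: !transform.op<"cnm.compute">):
    transform.cnm.expand_dim %op dim 0 by factor 2 : (!transform.op<"cnm.compute">) -> ()
  }
}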