Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cnm.compute to support transformation #8

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.cache
.vscode
llvm
upmem
upmem
upmem-src
1 change: 1 addition & 0 deletions cinnamon/include/cinm-mlir/Dialect/Cnm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ add_custom_target(CnmIncGen)
# Attributes, Dialect, Operations and Types.
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)
114 changes: 113 additions & 1 deletion cinnamon/include/cinm-mlir/Dialect/Cnm/IR/CnmOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def WorkgroupOp : Cnm_Op<"workgroup", []> {
let arguments = (ins);
let results = (outs WorkgroupType:$result);

let builders = [
OpBuilder<(ins "ArrayRef<int64_t>":$shape),
"build($_builder, $_state, WorkgroupType::get($_builder.getContext(), shape));">
];

let assemblyFormat = "attr-dict `:` qualified(type($result))";
}

Expand All @@ -56,6 +61,11 @@ def AllocOp : Cnm_Op<"alloc", [InferWorkgroupTypeFromBuffer<"result", "wg">]> {
let arguments = (ins WorkgroupType:$wg);
let results = (outs BufferType:$result);

let builders = [
OpBuilder<(ins "ArrayRef<int64_t>":$bufShape, "Type":$elementTy, "Value":$wg, CArg<"int64_t", "0">:$level),
"build($_builder, $_state, BufferType::get(bufShape, elementTy, wg.getType().cast<cnm::WorkgroupType>().getShape(), level), wg);">
];


let extraClassDeclaration = [{

Expand Down Expand Up @@ -124,7 +134,7 @@ def LaunchOp: Cnm_Op<"launch", [AttrSizedOperandSegments, IsolatedFromAbove, Sin
let regions = (region SizedRegion<1>:$body);

let hasVerifier = 1;
let assemblyFormat = "$wg `in` `(` $inBuffers `:` type($inBuffers) `)` `out` `(` $outBuffers `:` type($outBuffers) `)` attr-dict `on` qualified(type($wg)) $body";
let assemblyFormat = "$wg `ins` `(` $inBuffers `:` type($inBuffers) `)` `outs` `(` $outBuffers `:` type($outBuffers) `)` attr-dict `on` qualified(type($wg)) $body";

let extraClassDeclaration = [{
SmallVector<Value> getParams() {
Expand All @@ -135,6 +145,107 @@ def LaunchOp: Cnm_Op<"launch", [AttrSizedOperandSegments, IsolatedFromAbove, Sin
}];
}


def ComputeOp: Cnm_Op<"compute", [IsolatedFromAbove,
SingleBlockImplicitTerminator<"TerminatorOp">,
AttrSizedOperandSegments]> {
let summary = "Transformable operation";
let description = [{}];

let arguments = (ins DenseI64ArrayAttr:$workgroupShape,
I64Attr:$numInputs,
Variadic<Index>:$symbol_bindings,
Variadic<AnyShaped>:$buffers,
AffineMapArrayAttr:$affineMaps
);

let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region SizedRegion<1>:$body);

let builders = [
OpBuilder<(ins
CArg<"ArrayRef<int64_t>">:$workgroupShape,
CArg<"ValueRange">:$allBuffers,
CArg<"uint64_t">:$numInputs,
CArg<"ArrayRef<AffineMap>">:$affineMaps,
CArg<"ValueRange", "{}">:$symbol_bindings
)>,
OpBuilder<(ins
CArg<"ArrayRef<int64_t>">:$workgroupShape,
CArg<"ValueRange">:$inputs,
CArg<"ValueRange">:$inits,
CArg<"ArrayRef<AffineMap>">:$affineMaps,
CArg<"ValueRange", "{}">:$symbol_bindings
)>
];

let skipDefaultBuilders = 1;
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
SmallVector<Value> getInBuffers() {
if (getNumInputs() == 0)
return {};
SmallVector<Value> result;
result.reserve(getNumInputs());
auto begin = getBuffers().begin();
result.append(begin, begin + getNumInputs());
return result;
}

SmallVector<Value> getOutBuffers() {
if (getNumOutputs() == 0)
return {};
SmallVector<Value> result;
result.reserve(getNumOutputs());
auto begin = getBuffers().begin() + getNumInputs();
result.append(getBuffers().begin() + getNumInputs(), getBuffers().end());
return result;
}

uint64_t getNumOutputs() {
return getBuffers().size() - getNumInputs();
}

SmallVector<AffineMap> getInMaps() {
if (getNumInputs() == 0)
return {};
SmallVector<AffineMap> result;
result.reserve(getNumInputs());
auto begin = getAffineMaps().getAsValueRange<AffineMapAttr>().begin();
result.append(begin, begin + getNumInputs());
return result;
}

SmallVector<AffineMap> getOutMaps() {
if (getNumOutputs() == 0)
return {};
SmallVector<AffineMap> result;
result.reserve(getNumOutputs());
auto it = getAffineMaps().getAsValueRange<AffineMapAttr>();
result.append(it.begin() + getNumInputs(), it.end());
return result;
}

template<unsigned N = llvm::CalculateSmallVectorDefaultInlinedElements<AffineMap>::value>
SmallVector<AffineMap, N> getAffineMapsVec() {
SmallVector<AffineMap, N> result(getAffineMaps().getAsValueRange<AffineMapAttr>());
return result;
}

MutableArrayRef<BlockArgument> getKernelArgs() {
return getBody().getArguments();
}

uint64_t getNumSymbols() {
return getSymbolBindings().size();
}

}];
}


def TerminatorOp: Cnm_Op<"terminator", [Terminator]> {
let summary = "Terminates an `upmem.launch` operator region.";
let description = [{}];
Expand All @@ -144,4 +255,5 @@ def TerminatorOp: Cnm_Op<"terminator", [Terminator]> {
let assemblyFormat = "attr-dict";
}


#endif
13 changes: 13 additions & 0 deletions cinnamon/include/cinm-mlir/Dialect/Cnm/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

set(LLVM_TARGET_DEFINITIONS CnmTransformOps.td)
mlir_tablegen(CnmTransformOps.h.inc -gen-op-decls)
mlir_tablegen(CnmTransformOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRCnmTransformOpsIncGen)
add_dependencies(CnmIncGen MLIRCnmTransformOpsIncGen)


set(LLVM_TARGET_DEFINITIONS TransformPass.td)
mlir_tablegen(TransformPass.h.inc -gen-pass-decls -name CnmTransformInterpreter)
add_public_tablegen_target(CnmTransformInterpreterPassIncGen)
add_dependencies(CnmIncGen CnmTransformInterpreterPassIncGen)
add_mlir_doc(Passes CnmTransformInterpreter ./ -gen-pass-doc)
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

#pragma once

#include <mlir/IR/DialectRegistry.h>
#include <mlir/IR/OpImplementation.h>

#include <cinm-mlir/Dialect/Cnm/IR/CnmOps.h>

#include <mlir/Dialect/Transform/IR/MatchInterfaces.h>
#include <mlir/Dialect/Transform/IR/TransformAttrs.h>
#include <mlir/Dialect/Transform/IR/TransformDialect.h>
#include <mlir/Dialect/Transform/IR/TransformInterfaces.h>
#include <mlir/Dialect/Transform/IR/TransformTypes.h>

#define GET_OP_CLASSES
#include <cinm-mlir/Dialect/Cnm/TransformOps/CnmTransformOps.h.inc>


namespace mlir::cnm {

void registerTransformDialectExtension(::mlir::DialectRegistry &registry);


} // namespace mlir::cnm
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@


include "mlir/Dialect/Transform/IR/MatchInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"


def CnmExpandDimOp : Op<Transform_Dialect, "cnm.expand_dim",
[FunctionalStyleTransformOpTrait,
TransformOpInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::cnm::ComputeOp op,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];

let arguments = (ins Transform_ConcreteOpType<"cnm.compute">:$target,
I64Attr:$dim,
I64Attr:$factor);
let results = (outs);
let assemblyFormat =
"$target `dim` $dim `by` `factor` $factor attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::cnm::ComputeOp op,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}

def CnmSwapDimsOp : Op<Transform_Dialect, "cnm.swap_dims",
[FunctionalStyleTransformOpTrait,
TransformOpInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{

}];

let arguments = (ins Transform_ConcreteOpType<"cnm.compute">:$target,
I64Attr:$dim0,
I64Attr:$dim1);
let results = (outs);
let assemblyFormat =
"$target `,` $dim0 `,` $dim1 attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::cnm::ComputeOp op,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}

def CnmPeelRightOp : Op<Transform_Dialect, "cnm.peel_right",
[FunctionalStyleTransformOpTrait,
TransformOpInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{

}];

let arguments = (ins Transform_ConcreteOpType<"cnm.compute">:$target);
let results = (outs);
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::cnm::ComputeOp op,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@


#include <mlir/IR/DialectRegistry.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/Pass/Pass.h>

namespace mlir::cnm {

#define GEN_PASS_DECL
#define GEN_PASS_REGISTRATION
#include "cinm-mlir/Dialect/Cnm/TransformOps/TransformPass.h.inc"

} // namespace mlir::cnm
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//===- Passes.td - Cnm dialect passes ---------------------*- tablegen -*-===//
//
// This is the definitions file for the Cnm dialect transform passes.
//
//===----------------------------------------------------------------------===//

#ifndef CNM_TRANSFORM_PASSES
#define CNM_TRANSFORM_PASSES

include "mlir/Pass/PassBase.td"

def CnmApplyTransformScriptPass : Pass<"cnm-apply-transform", "ModuleOp"> {
let summary = "Apply a transform dialect script to cnm compute ops";
let description = [{}];
// let constructor = "cnm::createApplyTransformScriptPass()";
let dependentDialects = ["transform::TransformDialect"];
let options = [
Option<"debugPayloadRootTag", "debug-payload-root-tag", "std::string",
/*default=*/[{""}],
"Select the operation with 'transform.target_tag' attribute having "
"the given value as payload IR root. If empty select the pass "
"anchor operation as the payload IR root.">,
Option<"disableExpensiveChecks", "disable-expensive-checks", "bool",
"false",
"Disable expensive checks in the interpreter for a faster run.">,
Option<"transformFileName", "transform-file-name", "std::string", "",
"File name of the transform script">,
Option<"entryPoint", "entry-point", "std::string",
/*default=*/[{transform::TransformDialect::kTransformEntryPointSymbolName.str()}],
"Entry point of the pass pipeline.">,
Option<"debugTransformRootTag", "debug-transform-root-tag", "std::string",
/*default=*/[{""}],
"Select the operation with 'transform.target_tag' attribute having "
"the given value as container IR for top-level transform ops. This "
"allows user control on what transformation to apply. If empty, "
"select the container of the top-level transform op."
>,
ListOption<"transformLibraryPaths", "transform-library-paths", "std::string",
"Optional paths to files with modules that should be "
"merged into the transform module to provide the "
"definitions of external named sequences.">
];
}

#endif // CNM_TRANSFORM_PASSES
Loading