
[mlir] Allow multi-result ops in reshape fusion #108576

Merged: 1 commit merged into llvm:main from fuse-multi-result-reshapes on Sep 16, 2024

Conversation

@Max191 (Contributor) commented Sep 13, 2024

The fusion-of-reshapes-by-collapsing patterns were restricted to single-result operations, but the implementation already supports multi-result ops. This PR removes the restriction, since it is unnecessary.
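For downstream users nothing new has to be enabled: multi-result linalg.generic ops simply become eligible candidates for the existing patterns. Below is a minimal sketch of wiring those patterns into a pass, assuming the upstream Linalg API at the time of this PR (linalg::populateFoldReshapeOpsByCollapsingPatterns and linalg::ControlFusionFn); the helper name and the always-true control callback are illustrative, not part of this change:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative helper: run reshape-by-collapsing fusion over `op`.
static LogicalResult runCollapsingFusion(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  // The control callback can veto individual fusions; returning true
  // everywhere fuses every legal candidate, which after this PR includes
  // multi-result linalg.generic consumers.
  linalg::ControlFusionFn fuseEverything = [](OpOperand *) { return true; };
  linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, fuseEverything);
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}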

@llvmbot (Member) commented Sep 13, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (Max191)

Changes

The fusion-of-reshapes-by-collapsing patterns were restricted to single-result operations, but the implementation already supports multi-result ops. This PR removes the restriction, since it is unnecessary.


Full diff: https://github.com/llvm/llvm-project/pull/108576.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+1-1)
  • (modified) mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir (+24-18)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index c818675993c2c3..a934e47794051c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1254,7 +1254,7 @@ static SmallVector<ReassociationIndices>
 getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                  ArrayRef<ReassociationIndices> reassociation) {
   // Some basic checks for this fusion to be valid.
-  if (!genericOp.hasPureTensorSemantics() || genericOp.getNumDpsInits() != 1)
+  if (!genericOp.hasPureTensorSemantics())
     return {};
 
   if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 600f0dea31f4a8..f17881d59a266e 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -7,49 +7,55 @@
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
 #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d0, d7, d3, d4, d5, d6)>
 func.func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>,
-    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> {
+    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
   %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
-  %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
-  %generic = linalg.generic {
-    indexing_maps = [#map0, #map1, #map2, #map3],
+  %init_0 = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %init_1 = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+  %generic:2 = linalg.generic {
+    indexing_maps = [#map0, #map1, #map2, #map3, #map4],
     iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
     ins(%expand, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
-    outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
-      ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+    outs(%init_0, %init_1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
+      ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, %b4 : i32):
         %t0 = arith.addi %b0, %b1 : i32
         %t1 = arith.addi %t0, %b2 : i32
-        linalg.yield %t1 : i32
-    } -> tensor<2x3x4x5x6x7x8x9xi32>
-  return %generic : tensor<2x3x4x5x6x7x8x9xi32>
+        linalg.yield %t1, %t1 : i32, i32
+    } -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>)
+  return %generic#0, %generic#1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>
 }
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d0, d4, d2, d3)>
 //      CHECK: func @fuse_by_collapsing(
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
 // CHECK-SAME:   %[[ARG1:.+]]: tensor<2x3x4xi32>
 // CHECK-SAME:   %[[ARG2:.+]]: tensor<5x6x7x8xi32>
-//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
+//  CHECK-DAG:   %[[INIT0:.+]] = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+//  CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
 //  CHECK-DAG:   %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2]{{\]}}
 //  CHECK-DAG:   %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
-//  CHECK-DAG:   %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
-//      CHECK:   %[[COLLAPSED_OP:.+]] = linalg.generic
-// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
+//  CHECK-DAG:   %[[INIT0_RESHAPE:.+]] = tensor.collapse_shape %[[INIT0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+//  CHECK-DAG:   %[[INIT1_RESHAPE:.+]] = tensor.collapse_shape %[[INIT1]] {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}}
+//      CHECK:   %[[COLLAPSED_OP:.+]]:2 = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]], #[[MAP3]]]
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
-// CHECK-SAME:       outs(%[[INIT_RESHAPE]] :
-//      CHECK:   %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9]
-//      CHECK:   return %[[RESULT_RESHAPE]]
+// CHECK-SAME:       outs(%[[INIT0_RESHAPE]], %[[INIT1_RESHAPE]] :
+//      CHECK:   %[[RESULT0_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]]#0 {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9]
+//      CHECK:   %[[RESULT1_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]]#1 {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}} output_shape [3, 4, 2, 9, 5, 6, 7, 8]
+//      CHECK:   return %[[RESULT0_RESHAPE]], %[[RESULT1_RESHAPE]]
 
 //      CONTROL: func @fuse_by_collapsing(
 // CONTROL-SAME:   %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
 // CONTROL-SAME:   %[[ARG1:.+]]: tensor<2x3x4xi32>
 // CONTROL-SAME:   %[[ARG2:.+]]: tensor<5x6x7x8xi32>
 //      CONTROL:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
-//      CONTROL:   %[[GENERIC:.+]] = linalg.generic
+//      CONTROL:   %[[GENERIC:.+]]:2 = linalg.generic
 // CONTROL-SAME:       ins(%[[EXPAND]],
-//      CONTROL:   return %[[GENERIC]]
+//      CONTROL:   return %[[GENERIC]]#0, %[[GENERIC]]#1
 
 // -----
 

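The CONTROL checks above come from a second RUN line in this test that drives the same patterns through a control callback which declines this particular fusion, so the tensor.expand_shape stays un-fused in front of the (now two-result) linalg.generic. A minimal sketch of such a callback, under the same API assumptions as the snippet earlier (the predicate here is illustrative, not necessarily the one the test pass uses):

// Illustrative control callback: decline fusion whenever the candidate
// operand is produced by a tensor.expand_shape, keeping the reshape and the
// generic op separate, as the CONTROL checks expect.
linalg::ControlFusionFn declineExpandShape = [](OpOperand *fusedOperand) {
  return !isa_and_present<tensor::ExpandShapeOp>(
      fusedOperand->get().getDefiningOp());
};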
@Max191 force-pushed the fuse-multi-result-reshapes branch from 848a59d to fd39676 on September 16, 2024 at 15:32
@Max191 merged commit 08efa23 into llvm:main on Sep 16, 2024
8 checks passed