Skip to content

Commit

Permalink
Lower tensor.fork to stream.fork
Browse files Browse the repository at this point in the history
  • Loading branch information
hanchenye committed Feb 28, 2024
1 parent 4614100 commit d2bfba4
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions lib/Dialect/HLS/Transforms/ReduceTensorToStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,42 @@ struct EliminateIntermediateTensor
};
} // namespace

namespace {
struct ConvertTensorForkToStreamFork
: public OpRewritePattern<hls::TensorForkOp> {
using OpRewritePattern<hls::TensorForkOp>::OpRewritePattern;

LogicalResult matchAndRewrite(hls::TensorForkOp tensorForkOp,
PatternRewriter &rewriter) const override {
auto streamToTensor =
tensorForkOp.getSource().getDefiningOp<hls::StreamToTensorOp>();
if (!streamToTensor)
return failure();

// Construct N forked stream channels.
auto loc = tensorForkOp.getLoc();
SmallVector<Value> destStreams;
for (unsigned i = 0; i < tensorForkOp.getNumResults(); i++)
destStreams.push_back(
rewriter.create<hls::StreamOp>(loc, streamToTensor.getStreamType()));

// Create the stream fork operation.
rewriter.create<hls::StreamForkOp>(loc, streamToTensor.getStream(),
destStreams);

// Replace the tensor fork results with the forked streams.
for (auto [result, destStream] :
llvm::zip(tensorForkOp.getResults(), destStreams)) {
auto destTensor = rewriter.create<hls::StreamToTensorOp>(
loc, result.getType(), destStream);
rewriter.replaceAllUsesWith(result, destTensor.getResult());
}
return success();
// love uuuuuuuu ;)
}
};
} // namespace

namespace {
struct ReduceTensorToStream
: public ReduceTensorToStreamBase<ReduceTensorToStream> {
Expand All @@ -117,6 +153,7 @@ struct ReduceTensorToStream

mlir::RewritePatternSet patterns(context);
patterns.add<EliminateIntermediateTensor>(context);
patterns.add<ConvertTensorForkToStreamFork>(context);
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}
};
Expand Down

0 comments on commit d2bfba4

Please sign in to comment.