Skip to content

Commit 98e0c11

Browse files
committed
feat(mlir): finish "toy.or" lower to StandardOp
OrOp only support int type, to lower toy.or to standardOrOp we need: \t* insert a FpToInt op before OrOp; \t* insert a IntToFp after OrOp.
1 parent 6eaffdd commit 98e0c11

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

mlir/mycode/Ch6/mlir/LowerToAffineLoops.cpp

+16-2
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ static void lowerOpToLoops(Operation *op, ArrayRef<Value> operands,
8181
// and the loop induction variables. This function will return the value
8282
// to store at the current index.
8383
Value valueToStore = processIteration(nestedBuilder, operands, ivs);
84+
// Patch to support "toy.or" Op
85+
if(nestedBuilder.getI64Type() == valueToStore.getType()) {
86+
valueToStore = nestedBuilder.create<UIToFPOp>(loc, nestedBuilder.getF64Type(), valueToStore);
87+
}
8488
nestedBuilder.create<AffineStoreOp>(loc, valueToStore, alloc, ivs);
8589
});
8690

@@ -104,7 +108,7 @@ struct BinaryOpLowering : public ConversionPattern {
104108
auto loc = op->getLoc();
105109
lowerOpToLoops(
106110
op, operands, rewriter,
107-
[loc](OpBuilder &builder, ValueRange memRefOperands,
111+
[loc, op](OpBuilder &builder, ValueRange memRefOperands,
108112
ValueRange loopIvs) {
109113
// Generate an adaptor for the remapped operands of the BinaryOp. This
110114
// allows for using the nice named accessors that are generated by the
@@ -118,6 +122,15 @@ struct BinaryOpLowering : public ConversionPattern {
118122
auto loadedRhs =
119123
builder.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);
120124

125+
// patch to support "toy.or" operation.
126+
auto opname = op->getName();
127+
if (opname.getStringRef().str() == "toy.or") {
128+
auto castLhs = builder.create<FPToUIOp>(loc, builder.getI64Type(), loadedLhs);
129+
auto castRhs = builder.create<FPToUIOp>(loc, builder.getI64Type(), loadedRhs);
130+
131+
return builder.create<LoweredBinaryOp>(loc, castLhs, castRhs);
132+
}
133+
121134
// Create the binary operation performed on the loaded values.
122135
return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
123136
});
@@ -126,6 +139,7 @@ struct BinaryOpLowering : public ConversionPattern {
126139
};
127140
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
128141
using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
142+
using OrOpLowering = BinaryOpLowering<toy::OrOp, OrOp>;
129143

130144
//===----------------------------------------------------------------------===//
131145
// ToyToAffine RewritePatterns: Constant operations
@@ -297,7 +311,7 @@ void ToyToAffineLoweringPass::runOnFunction() {
297311
// Now that the conversion target has been defined, we just need to provide
298312
// the set of patterns that will lower the Toy operations.
299313
RewritePatternSet patterns(&getContext());
300-
patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
314+
patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering, OrOpLowering,
301315
ReturnOpLowering, TransposeOpLowering>(&getContext());
302316

303317
// With the target and rewrite patterns defined, we can now attempt the

mlir/mycode/Ch6/toyc.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
8787
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
8888
llvm::MemoryBuffer::getFileOrSTDIN(filename);
8989
if (std::error_code ec = fileOrErr.getError()) {
90-
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
90+
llvm::errs() << "Could not open input file: " << filename << ec.message() << "\n";
9191
return nullptr;
9292
}
9393
auto buffer = fileOrErr.get()->getBuffer();
@@ -111,7 +111,7 @@ int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
111111
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
112112
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
113113
if (std::error_code EC = fileOrErr.getError()) {
114-
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
114+
llvm::errs() << "Could not open input file: " << inputFilename << ":" << EC.message() << "\n";
115115
return -1;
116116
}
117117

0 commit comments

Comments
 (0)