Skip to content

Commit 330dda4

Browse files
committed
fix ut
1 parent 81aa949 commit 330dda4

File tree

2 files changed

+10
-15
lines changed

2 files changed

+10
-15
lines changed

tao_compiler/mlir/disc/tests/mlir_feature_test.cc

-1
Original file line number | Diff line number | Diff line change
@@ -238,7 +238,6 @@ void addBoolFlags(EnvSettings& envSettings, const std::string& key) {
238238
} else {
239239
size_t original_size = envSettings.size();
240240
for (int i = 0; i < original_size; ++i) {
241-
envSettings[i][key].first = "false";
242241
envSettings.push_back(envSettings[i]);
243242
envSettings[i][key].first = "true";
244243
}

tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc

+10 -14
Original file line number | Diff line number | Diff line change
@@ -1303,13 +1303,6 @@ LogicalResult lowerWithScheduleParallelReduction(
13031303
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
13041304
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
13051305
Value two = b.create<arith::ConstantIndexOp>(loc, 2);
1306-
auto elemFloatType = getLhloOpsElementType(root_op).cast<FloatType>();
1307-
Value zero_f = b.create<arith::ConstantOp>(
1308-
loc, b.getFloatAttr(getLhloOpsElementType(root_op), 0));
1309-
Value one_f = b.create<arith::ConstantOp>(
1310-
loc, b.getFloatAttr(getLhloOpsElementType(root_op), 1));
1311-
1312-
// Start to emit.
13131306
Value num_blocks = b.create<arith::ConstantIndexOp>(loc, 1024);
13141307
Value block_size = b.create<arith::ConstantIndexOp>(loc, 256);
13151308

@@ -1333,8 +1326,7 @@ LogicalResult lowerWithScheduleParallelReduction(
13331326
b.create<gpu::ThreadIdOp>(loc, b.getIndexType(), gpu::Dimension::x);
13341327
Value grid_dim =
13351328
b.create<gpu::GridDimOp>(loc, b.getIndexType(), gpu::Dimension::x);
1336-
// tid = b.create<arith::RemSIOp>(loc, tid, block_dim);
1337-
// i = blockIdx.x * block_size * 2 + tid;
1329+
// i = blockIdx.x * block_size * 2 + tid;
13381330
Value i = b.create<arith::AddIOp>(
13391331
loc,
13401332
b.create<arith::MulIOp>(
@@ -1393,6 +1385,8 @@ LogicalResult lowerWithScheduleParallelReduction(
13931385
Value data = createLoadOrUseCachedValue(
13941386
loc, &b, root_op, *lhs, load_index, b.saveInsertionPoint());
13951387
Value index2 = b.create<arith::AddIOp>(loc, var_j, block_dim);
1388+
Value iter_value =
1389+
*(for_op_k.getRegionIterArgs().begin() + scalar_red_root_op_idx);
13961390
// if (i + grid_size < n)
13971391
scf::IfOp if_tid_valid_op = b.create<scf::IfOp>(
13981392
loc, /*resultTypes*/ init_values_types,
@@ -1404,17 +1398,19 @@ LogicalResult lowerWithScheduleParallelReduction(
14041398
SmallVector<Value, 2> load_index2({index2, zero});
14051399
Value data1 = createLoadOrUseCachedValue(
14061400
loc, &b, root_op, *lhs, load_index2, b.saveInsertionPoint());
1401+
data1 = (accum_factory[scalar_red_root_op_idx])(iter_value, data1);
14071402
b.setInsertionPointToEnd(&if_tid_valid_op.getThenRegion().front());
14081403
b.create<scf::YieldOp>(loc, data1);
14091404
b.setInsertionPointToStart(&if_tid_valid_op.getElseRegion().front());
1410-
b.create<scf::YieldOp>(loc, zero_f);
1405+
b.create<scf::YieldOp>(loc, iter_value);
1406+
// loc, cast<lmhlo::ReduceOp>(root_op).getInitValues().front());
14111407
b.setInsertionPointAfter(if_tid_valid_op);
1412-
Value sum = (accum_factory[scalar_red_root_op_idx])(
1408+
Value acc = (accum_factory[scalar_red_root_op_idx])(
14131409
data, if_tid_valid_op.getResults().front());
14141410

1415-
auto acc = (accum_factory[scalar_red_root_op_idx])(
1416-
*(for_op_k.getRegionIterArgs().begin() + scalar_red_root_op_idx),
1417-
sum);
1411+
// acc = (accum_factory[scalar_red_root_op_idx])(
1412+
// *(for_op_k.getRegionIterArgs().begin() + scalar_red_root_op_idx),
1413+
// acc);
14181414
yield_values_for_if.push_back(acc);
14191415
scalar_red_root_op_idx++;
14201416
} else if (isa<lmhlo::ReduceOp>(root_op)) {

0 commit comments

Comments
 (0)