@@ -1303,13 +1303,6 @@ LogicalResult lowerWithScheduleParallelReduction(
1303
1303
Value zero = b.create <arith::ConstantIndexOp>(loc, 0 );
1304
1304
Value one = b.create <arith::ConstantIndexOp>(loc, 1 );
1305
1305
Value two = b.create <arith::ConstantIndexOp>(loc, 2 );
1306
- auto elemFloatType = getLhloOpsElementType (root_op).cast <FloatType>();
1307
- Value zero_f = b.create <arith::ConstantOp>(
1308
- loc, b.getFloatAttr (getLhloOpsElementType (root_op), 0 ));
1309
- Value one_f = b.create <arith::ConstantOp>(
1310
- loc, b.getFloatAttr (getLhloOpsElementType (root_op), 1 ));
1311
-
1312
- // Start to emit.
1313
1306
Value num_blocks = b.create <arith::ConstantIndexOp>(loc, 1024 );
1314
1307
Value block_size = b.create <arith::ConstantIndexOp>(loc, 256 );
1315
1308
@@ -1333,8 +1326,7 @@ LogicalResult lowerWithScheduleParallelReduction(
1333
1326
b.create <gpu::ThreadIdOp>(loc, b.getIndexType (), gpu::Dimension::x);
1334
1327
Value grid_dim =
1335
1328
b.create <gpu::GridDimOp>(loc, b.getIndexType (), gpu::Dimension::x);
1336
- // tid = b.create<arith::RemSIOp>(loc, tid, block_dim);
1337
- // i = blockIdx.x * block_size * 2 + tid;
1329
+ // i = blockIdx.x * block_size * 2 + tid;
1338
1330
Value i = b.create <arith::AddIOp>(
1339
1331
loc,
1340
1332
b.create <arith::MulIOp>(
@@ -1393,6 +1385,8 @@ LogicalResult lowerWithScheduleParallelReduction(
1393
1385
Value data = createLoadOrUseCachedValue (
1394
1386
loc, &b, root_op, *lhs, load_index, b.saveInsertionPoint ());
1395
1387
Value index2 = b.create <arith::AddIOp>(loc, var_j, block_dim);
1388
+ Value iter_value =
1389
+ *(for_op_k.getRegionIterArgs ().begin () + scalar_red_root_op_idx);
1396
1390
// if (i + grid_size < n)
1397
1391
scf::IfOp if_tid_valid_op = b.create <scf::IfOp>(
1398
1392
loc, /* resultTypes*/ init_values_types,
@@ -1404,17 +1398,19 @@ LogicalResult lowerWithScheduleParallelReduction(
1404
1398
SmallVector<Value, 2 > load_index2 ({index2, zero});
1405
1399
Value data1 = createLoadOrUseCachedValue (
1406
1400
loc, &b, root_op, *lhs, load_index2, b.saveInsertionPoint ());
1401
+ data1 = (accum_factory[scalar_red_root_op_idx])(iter_value, data1);
1407
1402
b.setInsertionPointToEnd (&if_tid_valid_op.getThenRegion ().front ());
1408
1403
b.create <scf::YieldOp>(loc, data1);
1409
1404
b.setInsertionPointToStart (&if_tid_valid_op.getElseRegion ().front ());
1410
- b.create <scf::YieldOp>(loc, zero_f);
1405
+ b.create <scf::YieldOp>(loc, iter_value);
1406
+ // No valid second element: forward iter_value instead of the reduce init value, cast<lmhlo::ReduceOp>(root_op).getInitValues().front().
1411
1407
b.setInsertionPointAfter (if_tid_valid_op);
1412
- Value sum = (accum_factory[scalar_red_root_op_idx])(
1408
+ Value acc = (accum_factory[scalar_red_root_op_idx])(
1413
1409
data, if_tid_valid_op.getResults ().front ());
1414
1410
1415
- auto acc = (accum_factory[scalar_red_root_op_idx])(
1416
- *(for_op_k.getRegionIterArgs ().begin () + scalar_red_root_op_idx),
1417
- sum );
1411
+ // acc = (accum_factory[scalar_red_root_op_idx])(
1412
+ // *(for_op_k.getRegionIterArgs().begin() + scalar_red_root_op_idx),
1413
+ // acc );
1418
1414
yield_values_for_if.push_back (acc);
1419
1415
scalar_red_root_op_idx++;
1420
1416
} else if (isa<lmhlo::ReduceOp>(root_op)) {
0 commit comments