@@ -1305,14 +1305,14 @@ Value elementalLower<lmhlo::ConcatenateOp>(OpBuilder* b, Location loc,
13051305
13061306 b->setInsertionPointToEnd (&if_inbound_ops[i].getElseRegion ().front ());
13071307 if (i == num_input_operands - 1 ) {
1308- input_index[axis] = b-> create <arith::SubIOp>(loc, out_idx, low_bound);
1309- auto operand_memref = op. getOperand (i );
1308+ // We expect this branch to never be executed.
1309+ input_index[axis] = b-> create <arith::ConstantIndexOp>(loc, 0 );
13101310 auto ret_value =
13111311 check_cache ? createLoadOrUseCachedValue (
1312- loc, b, op.getOperation (), operand_memref ,
1312+ loc, b, op.getOperation (), op. getOperand (i) ,
13131313 input_index, b->saveInsertionPoint (), lower_config)
13141314 : createMaySpecificLoad (*b, loc, op.getOperation (),
1315- operand_memref , input_index,
1315+ op. getOperand (i) , input_index,
13161316 lower_config);
13171317 b->create <scf::YieldOp>(loc, ret_value);
13181318 } else {
@@ -1360,7 +1360,24 @@ Value elementalLower<lmhlo_disc::ConcatenateOp>(OpBuilder* b, Location loc,
13601360
13611361 auto int_ptr =
13621362 b->create <memref::LoadOp>(loc, ptr_array, ValueRange{operand_index});
1363- Type ptr_type = LLVM::LLVMPointerType::get (FloatType::getF32 (ctx));
1363+ auto elem_ty = out.getType ().cast <MemRefType>().getElementType ();
1364+ // Pick the LLVM pointer type that matches the output element type.
1365+ Type ptr_type;
1366+ if (elem_ty.isBF16 ()) {
1367+ ptr_type = LLVM::LLVMPointerType::get (FloatType::getBF16 (ctx));
1368+ } else if (elem_ty.isF16 ()) {
1369+ ptr_type = LLVM::LLVMPointerType::get (FloatType::getF16 (ctx));
1370+ } else if (elem_ty.isF32 ()) {
1371+ ptr_type = LLVM::LLVMPointerType::get (FloatType::getF32 (ctx));
1372+ } else if (elem_ty.isInteger (32 ) || elem_ty.isInteger (64 ) ||
1373+ elem_ty.isInteger (8 )) {
1374+ ptr_type = LLVM::LLVMPointerType::get (
1375+ IntegerType::get (ctx, elem_ty.getIntOrFloatBitWidth ()));
1376+ } else {
1377+ op.emitError (" unsupported element type for ConcatenateOp" );
1378+ return Value (nullptr );
1379+ }
1380+
13641381 auto llvm_ptr = b->create <LLVM::IntToPtrOp>(loc, ptr_type, int_ptr);
13651382
13661383 SmallVector<Value, 4 > input_index;
0 commit comments