@@ -34,7 +34,7 @@ func.func @dynamic_batch_matvec() {
3434}
3535
3636// CDNA3-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1, 1], [0, 0, 0, 32]{{\]}}>
37- // CDNA3-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [32, 1, 1]>
37+ // CDNA3-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [32, 1, 1] subgroup_size = 32>
3838// CDNA3-LABEL: func.func @dynamic_batch_matvec()
3939// CDNA3-SAME: translation_info = #[[$TRANSLATION]]
4040// CDNA3: linalg.batch_matmul
@@ -367,15 +367,57 @@ func.func @dynamic_parallel_dims(%dynsize : index, %input : tensor<4x?x4096xf16>
367367  return %2 : tensor<4x?xf32>
368368}
369369// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 64]{{\]}}
370- // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [64, 1, 1]>
370+ // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [64, 1, 1] subgroup_size = 64>
371371// CHECK: func @dynamic_parallel_dims
372372// CHECK-SAME: translation_info = #[[TRANSLATION]]
373373// CHECK: linalg.generic
374374// CHECK-SAME: lowering_config = #[[CONFIG]]
375375
376376// CDNA3-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 32]{{\]}}
377- // CDNA3-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [32, 1, 1]>
377+ // CDNA3-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [32, 1, 1] subgroup_size = 32>
378378// CDNA3: func @dynamic_parallel_dims
379379// CDNA3-SAME: translation_info = #[[TRANSLATION]]
380380// CDNA3: linalg.generic
381381// CDNA3-SAME: lowering_config = #[[CONFIG]]
382+
383+ // -----
384+
385+ #map = affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d2 , d3 )>
386+ #map1 = affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d2 , d3 , d1 )>
387+ #map2 = affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 )>
388+ #map3 = affine_map <(d0 , d1 ) -> (d0 , d1 )>
389+ #map4 = affine_map <(d0 , d1 ) -> ()>
390+ func.func @test_dyn_reduction (%arg0: tensor <128 x?x32 xf8 E4 M3 FNUZ>, %arg1: tensor <128 x?x32 x128 xf8 E4 M3 FNUZ>, %arg2: tensor <f32 >) -> tensor <128 x128 xf8 E4 M3 FNUZ> {
391+ %cst = arith.constant 0.000000e+00 : f32
392+ %cst_0 = arith.constant -2.400000e+02 : f8E4M3FNUZ
393+ %cst_1 = arith.constant 2.400000e+02 : f8E4M3FNUZ
394+ %0 = tensor.empty () : tensor <128 x128 xf8 E4 M3 FNUZ>
395+ %1 = tensor.empty () : tensor <128 x128 xf32 >
396+ %2 = linalg.fill ins (%cst : f32 ) outs (%1 : tensor <128 x128 xf32 >) -> tensor <128 x128 xf32 >
397+ %3 = linalg.generic {index ing_maps = [#map , #map1 , #map2 ], iterator_types = [" parallel" , " parallel" , " reduction" , " reduction" ]} ins (%arg0 , %arg1 : tensor <128 x?x32 xf8 E4 M3 FNUZ>, tensor <128 x?x32 x128 xf8 E4 M3 FNUZ>) outs (%2 : tensor <128 x128 xf32 >) {
398+ ^bb0 (%in: f8E4M3FNUZ , %in_2: f8E4M3FNUZ , %out: f32 ):
399+ %5 = arith.extf %in : f8E4M3FNUZ to f32
400+ %6 = arith.extf %in_2 : f8E4M3FNUZ to f32
401+ %7 = arith.mulf %5 , %6 : f32
402+ %8 = arith.addf %out , %7 : f32
403+ linalg.yield %8 : f32
404+ } -> tensor <128 x128 xf32 >
405+ %4 = linalg.generic {index ing_maps = [#map3 , #map4 , #map3 ], iterator_types = [" parallel" , " parallel" ]} ins (%3 , %arg2 : tensor <128 x128 xf32 >, tensor <f32 >) outs (%0 : tensor <128 x128 xf8 E4 M3 FNUZ>) {
406+ ^bb0 (%in: f32 , %in_2: f32 , %out: f8E4M3FNUZ ):
407+ %5 = arith.truncf %in : f32 to f8E4M3FNUZ
408+ %6 = arith.truncf %in_2 : f32 to f8E4M3FNUZ
409+ %7 = arith.divf %5 , %6 : f8E4M3FNUZ
410+ %8 = arith.cmpf ult , %7 , %cst_0 : f8E4M3FNUZ
411+ %9 = arith.select %8 , %cst_0 , %7 : f8E4M3FNUZ
412+ %10 = arith.cmpf ugt , %9 , %cst_1 : f8E4M3FNUZ
413+ %11 = arith.select %10 , %cst_1 , %9 : f8E4M3FNUZ
414+ linalg.yield %11 : f8E4M3FNUZ
415+ } -> tensor <128 x128 xf8 E4 M3 FNUZ>
416+ return %4 : tensor <128 x128 xf8 E4 M3 FNUZ>
417+ }
418+ // CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 1, 64]{{\]}}>
419+ // CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [64, 1, 1] subgroup_size = 64>
420+ // CHECK: func.func @test_dyn_reduction
421+ // CHECK-SAME: translation_info = #[[$TRANSLATION]]
422+ // CHECK: linalg.generic
423+ // CHECK-SAME: lowering_config = #[[$CONFIG]]