@@ -863,3 +863,51 @@ util.func public @multi_reduction(%arg0 : tensor<32x16x16384xf32>, %arg1 : tenso
863863// CHECK: %[[GEN2:.+]] = linalg.generic
864864// CHECK-SAME: ins(%[[GEN1]] : tensor<32xf32>)
865865// CHECK: flow.return %[[GEN2]]

// -----

// Tests that a dispatch region containing only a linalg.fill has the
// 4-D operand collapsed to 1-D (11*470*725*224 = 839608000 elements).
util.func public @collapse_single_fill(%arg0: tensor<11x470x725x224xf32>) -> tensor<11x470x725x224xf32> {
  %0 = flow.dispatch.region -> (tensor<11x470x725x224xf32>) {
    %cst = arith.constant 0.000000e+00 : f32
    %1 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<11x470x725x224xf32>) -> tensor<11x470x725x224xf32>
    flow.return %1 : tensor<11x470x725x224xf32>
  }
  util.return %0 : tensor<11x470x725x224xf32>
}
// CHECK-LABEL: util.func public @collapse_single_fill
//  CHECK-SAME:   %[[ARG0:[0-9a-zA-Z]+]]
//   CHECK-DAG:   %[[COLLAPSE0:.+]] = tensor.collapse_shape %[[ARG0]]
//       CHECK:   flow.dispatch.region
//       CHECK:   %[[FILL:.+]] = linalg.fill
//  CHECK-SAME:     outs(%[[COLLAPSE0]] : tensor<839608000xf32>)
//       CHECK:   flow.return %[[FILL]]


// -----

// Tests that a linalg.fill feeding a contraction-like linalg.generic inside a
// dispatch region has its parallel dims (d0, d1, d2) collapsed together with
// the generic's operands (11*470*725 = 3748250), while the reduction dim (d4)
// and the un-collapsible %arg0 operand are left unchanged.
util.func public @collapse_fill_of_arg(%arg0: tensor<224x32xf32>, %arg1: tensor<11x470x725x224xf32>, %arg2: tensor<11x470x725x32xf32>) -> tensor<11x470x725x224xf32> {
  %0 = flow.dispatch.region -> (tensor<11x470x725x224xf32>) {
    %cst = arith.constant 0.000000e+00 : f32
    %1 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<11x470x725x224xf32>) -> tensor<11x470x725x224xf32>
    %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg2, %arg0 : tensor<11x470x725x32xf32>, tensor<224x32xf32>) outs(%1 : tensor<11x470x725x224xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %3 = arith.mulf %in, %in_0 : f32
      %4 = arith.addf %out, %3 : f32
      linalg.yield %4 : f32
    } -> tensor<11x470x725x224xf32>
    flow.return %2 : tensor<11x470x725x224xf32>
  }
  util.return %0 : tensor<11x470x725x224xf32>
}
// CHECK-LABEL: util.func public @collapse_fill_of_arg
//  CHECK-SAME:   %[[ARG0:[0-9a-zA-Z]+]]
//  CHECK-SAME:   %[[ARG1:[0-9a-zA-Z]+]]
//  CHECK-SAME:   %[[ARG2:[0-9a-zA-Z]+]]
//   CHECK-DAG:   %[[COLLAPSE1:.+]] = tensor.collapse_shape %[[ARG1]]
//   CHECK-DAG:   %[[COLLAPSE2:.+]] = tensor.collapse_shape %[[ARG2]]
//       CHECK:   flow.dispatch.region
//       CHECK:   %[[FILL:.+]] = linalg.fill
//  CHECK-SAME:     outs(%[[COLLAPSE1]] : tensor<3748250x224xf32>)
//       CHECK:   %[[GEN0:.+]] = linalg.generic
//  CHECK-SAME:     ins(%[[COLLAPSE2]], %[[ARG0]] : tensor<3748250x32xf32>, tensor<224x32xf32>)
//  CHECK-SAME:     outs(%[[FILL]] : tensor<3748250x224xf32>)
//       CHECK:   flow.return %[[GEN0]]