|
| 1 | +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-codegen-propagate-dispatch-config)))" --verify-diagnostics %s | FileCheck %s |
| 2 | + |
| 3 | +// Basic: one dispatch_config paired with one export. The CHECK lines expect the config's ops (constant, affine.apply) in the export's count region, workgroup_size/subgroup_size as export attributes, and no dispatch_config left in the module. |
| 4 | +hal.executable private @basic_exe { |
| 5 | + hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) { |
| 6 | + hal.executable.export @matmul ordinal(0) |
| 7 | + layout(#hal.pipeline.layout<bindings = [ |
| 8 | + #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, |
| 9 | + #hal.pipeline.binding<storage_buffer, Indirect> |
| 10 | + ]>) |
| 11 | + count(%device: !hal.device, %w0: index, %w1: index) |
| 12 | + -> (index, index, index) { |
| 13 | + %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice(%w0, %w1) |
| 14 | + hal.return %x, %y, %z : index, index, index |
| 15 | + } |
| 16 | + builtin.module { |
| 17 | + func.func @matmul() { |
| 18 | + return |
| 19 | + } |
| 20 | + iree_codegen.dispatch_config @matmul |
| 21 | + workgroup_size = [64, 16, 1] subgroup_size = 64 { |
| 22 | + ^bb0(%w0: index, %w1: index): |
| 23 | + %c1 = arith.constant 1 : index |
| 24 | + %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 128)>()[%w0] |
| 25 | + iree_codegen.yield %0, %w1, %c1 : index, index, index |
| 26 | + } |
| 27 | + } |
| 28 | + } |
| 29 | +} |
| 30 | +// CHECK-LABEL: hal.executable private @basic_exe |
| 31 | +// CHECK: hal.executable.export public @matmul |
| 32 | +// CHECK: count(%[[DEV:.+]]: !hal.device, %[[W0:.+]]: index, %[[W1:.+]]: index) |
| 33 | +// CHECK: %[[C1:.+]] = arith.constant 1 : index |
| 34 | +// CHECK: %[[X:.+]] = affine.apply |
| 35 | +// CHECK: hal.return %[[X]], %[[W1]], %[[C1]] |
| 36 | +// CHECK: } attributes {subgroup_size = 64 : index, workgroup_size = [64 : index, 16 : index, 1 : index]} |
| 37 | +// CHECK: builtin.module |
| 38 | +// CHECK: func.func @matmul() |
| 39 | +// CHECK-NOT: iree_codegen.dispatch_config |
| 40 | + |
| 41 | +// ----- |
| 42 | + |
| 43 | +// Multiple dispatch_configs (specialization): two exports, @matmul and @matmul_0, each with its own dispatch_config. The CHECK lines expect each export to receive the count computation and size attributes of the config with the same symbol name. |
| 44 | +hal.executable private @specialized_exe { |
| 45 | + hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) { |
| 46 | + hal.executable.export @matmul ordinal(0) |
| 47 | + layout(#hal.pipeline.layout<bindings = [ |
| 48 | + #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, |
| 49 | + #hal.pipeline.binding<storage_buffer, Indirect> |
| 50 | + ]>) |
| 51 | + count(%device: !hal.device, %w0: index) |
| 52 | + -> (index, index, index) { |
| 53 | + %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice(%w0) |
| 54 | + hal.return %x, %y, %z : index, index, index |
| 55 | + } |
| 56 | + hal.executable.export @matmul_0 ordinal(1) |
| 57 | + layout(#hal.pipeline.layout<bindings = [ |
| 58 | + #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, |
| 59 | + #hal.pipeline.binding<storage_buffer, Indirect> |
| 60 | + ]>) |
| 61 | + count(%device: !hal.device, %w0: index) |
| 62 | + -> (index, index, index) { |
| 63 | + %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice(%w0) |
| 64 | + hal.return %x, %y, %z : index, index, index |
| 65 | + } |
| 66 | + builtin.module { |
| 67 | + func.func @matmul() { |
| 68 | + return |
| 69 | + } |
| 70 | + func.func @matmul_0() { |
| 71 | + return |
| 72 | + } |
| 73 | + iree_codegen.dispatch_config @matmul |
| 74 | + workgroup_size = [64, 16, 1] subgroup_size = 64 { |
| 75 | + ^bb0(%w0: index): |
| 76 | + %c1 = arith.constant 1 : index |
| 77 | + iree_codegen.yield %w0, %c1, %c1 : index, index, index |
| 78 | + } |
| 79 | + iree_codegen.dispatch_config @matmul_0 |
| 80 | + workgroup_size = [256, 1, 1] subgroup_size = 64 { |
| 81 | + ^bb0(%w0: index): |
| 82 | + %c1 = arith.constant 1 : index |
| 83 | + %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%w0] |
| 84 | + iree_codegen.yield %0, %c1, %c1 : index, index, index |
| 85 | + } |
| 86 | + } |
| 87 | + } |
| 88 | +} |
| 89 | +// CHECK-LABEL: hal.executable private @specialized_exe |
| 90 | +// CHECK: hal.executable.export public @matmul |
| 91 | +// CHECK: count(%{{.+}}: !hal.device, %[[W0A:.+]]: index) |
| 92 | +// CHECK: %[[C1A:.+]] = arith.constant 1 : index |
| 93 | +// CHECK: hal.return %[[W0A]], %[[C1A]], %[[C1A]] |
| 94 | +// CHECK: } attributes {subgroup_size = 64 : index, workgroup_size = [64 : index, 16 : index, 1 : index]} |
| 95 | +// CHECK: hal.executable.export public @matmul_0 |
| 96 | +// CHECK: count(%{{.+}}: !hal.device, %[[W0B:.+]]: index) |
| 97 | +// CHECK: %[[C1B:.+]] = arith.constant 1 : index |
| 98 | +// CHECK: %[[X:.+]] = affine.apply |
| 99 | +// CHECK: hal.return %[[X]], %[[C1B]], %[[C1B]] |
| 100 | +// CHECK: } attributes {subgroup_size = 64 : index, workgroup_size = [256 : index, 1 : index, 1 : index]} |
| 101 | + |
| 102 | +// ----- |
| 103 | + |
| 104 | +// No subgroup_size attribute: the dispatch_config sets only workgroup_size, so the export is expected to carry workgroup_size and no subgroup_size. |
| 105 | +hal.executable private @no_subgroup_exe { |
| 106 | + hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) { |
| 107 | + hal.executable.export @entry ordinal(0) |
| 108 | + layout(#hal.pipeline.layout<bindings = [ |
| 109 | + #hal.pipeline.binding<storage_buffer, Indirect> |
| 110 | + ]>) |
| 111 | + count(%device: !hal.device, %w0: index) |
| 112 | + -> (index, index, index) { |
| 113 | + %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice(%w0) |
| 114 | + hal.return %x, %y, %z : index, index, index |
| 115 | + } |
| 116 | + builtin.module { |
| 117 | + func.func @entry() { |
| 118 | + return |
| 119 | + } |
| 120 | + iree_codegen.dispatch_config @entry |
| 121 | + workgroup_size = [1024, 1, 1] { |
| 122 | + ^bb0(%w0: index): |
| 123 | + %c1 = arith.constant 1 : index |
| 124 | + iree_codegen.yield %w0, %c1, %c1 : index, index, index |
| 125 | + } |
| 126 | + } |
| 127 | + } |
| 128 | +} |
| 129 | +// CHECK-LABEL: hal.executable private @no_subgroup_exe |
| 130 | +// CHECK: hal.executable.export public @entry |
| 131 | +// CHECK: hal.return |
| 132 | +// CHECK: } attributes {workgroup_size = [1024 : index, 1 : index, 1 : index]} |
| 133 | +// CHECK-NOT: subgroup_size |
| 134 | + |
| 135 | +// ----- |
| 136 | + |
| 137 | +// No dispatch_config ops — pass is a no-op: the export's original workgroup_count_from_slice op is expected to still be present in the output. |
| 138 | +hal.executable private @noop_exe { |
| 139 | + hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) { |
| 140 | + hal.executable.export @entry ordinal(0) |
| 141 | + layout(#hal.pipeline.layout<bindings = [ |
| 142 | + #hal.pipeline.binding<storage_buffer, Indirect> |
| 143 | + ]>) |
| 144 | + count(%device: !hal.device, %w0: index) |
| 145 | + -> (index, index, index) { |
| 146 | + %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice(%w0) |
| 147 | + hal.return %x, %y, %z : index, index, index |
| 148 | + } |
| 149 | + builtin.module { |
| 150 | + func.func @entry() { |
| 151 | + return |
| 152 | + } |
| 153 | + } |
| 154 | + } |
| 155 | +} |
| 156 | +// CHECK-LABEL: hal.executable private @noop_exe |
| 157 | +// CHECK: iree_tensor_ext.dispatch.workgroup_count_from_slice |
| 158 | + |
| 159 | +// ----- |
| 160 | + |
| 161 | +// Error: arity mismatch. The export's count region takes one workload operand (%w0) while the dispatch_config block takes three, so the diagnostic named in the expected-error directive below is expected on the dispatch_config op. |
| 162 | +hal.executable private @arity_mismatch_exe { |
| 163 | + hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) { |
| 164 | + hal.executable.export @entry ordinal(0) |
| 165 | + layout(#hal.pipeline.layout<bindings = [ |
| 166 | + #hal.pipeline.binding<storage_buffer, Indirect> |
| 167 | + ]>) |
| 168 | + count(%device: !hal.device, %w0: index) |
| 169 | + -> (index, index, index) { |
| 170 | + %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice(%w0) |
| 171 | + hal.return %x, %y, %z : index, index, index |
| 172 | + } |
| 173 | + builtin.module { |
| 174 | + func.func @entry() { |
| 175 | + return |
| 176 | + } |
| 177 | + // expected-error @+1 {{workload arity mismatch}} |
| 178 | + iree_codegen.dispatch_config @entry |
| 179 | + workgroup_size = [64, 1, 1] { |
| 180 | + ^bb0(%w0: index, %w1: index, %w2: index): |
| 181 | + %c1 = arith.constant 1 : index |
| 182 | + iree_codegen.yield %w0, %c1, %c1 : index, index, index |
| 183 | + } |
| 184 | + } |
| 185 | + } |
| 186 | +} |
0 commit comments