// CHECK-LABEL: @gemm_memref

// Tiles a plain memref gemm: the cinm.compute body is rewritten into an
// affine tile-loop nest (8x128 output tile, K tiled by 32) operating on
// subviews of the compute-op's block arguments (%a0/%a1), not the function
// arguments — the compute op now captures its operands explicitly.
func.func @gemm_memref(%arg0: memref<8x1024xi32>, %arg1: memref<1024x128xi32>) -> memref<8x128xi32> {
  // CHECK: cinm.compute(%[[a0:.*]] = %{{.*}}, %[[b0:.*]] = %{{.*}}) ->
  // CHECK: %[[out:.*]] = memref.alloc()
  // CHECK: linalg.fill ins({{.*}}) outs(%[[out]] :
  // CHECK: affine.for %[[i:.*]] = 0 to 8 step 8
  // CHECK: affine.for %[[j:.*]] = 0 to 128 step 128
  // CHECK: %[[sliceOut:.*]] = memref.subview %[[out]][%[[i]], %[[j]]] [8, 128] [1, 1] :
  // CHECK: affine.for %[[k:.*]] = 0 to 1024 step 32
  // CHECK: %[[sliceA:.*]] = memref.subview %[[a0]][%[[i]], %[[k]]] [8, 32] [1, 1] :
  // CHECK: %[[sliceB:.*]] = memref.subview %[[b0]][%[[k]], %[[j]]] [32, 128] [1, 1] :
  // CHECK: cinm.op.gemm %[[sliceA]], %[[sliceB]] into %[[sliceOut]] {cinm.notile} :
  %0 = cinm.compute(%a0 = %arg0 : memref<8x1024xi32>, %a1 = %arg1 : memref<1024x128xi32>) -> memref<8x128xi32> attributes {workgroupShape = array<i64: 8, 128, 1>, bufferSizesInBytes = array<i64: 0, 0, 512>} {
    %alloc = memref.alloc() : memref<8x128xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<8x128xi32>)
    cinm.op.gemm %a0, %a1 into %alloc : memref<8x1024xi32>, memref<1024x128xi32> into memref<8x128xi32>
    cinm.yield %alloc : memref<8x128xi32>
  }
  return %0 : memref<8x128xi32>
}
// -----
// CHECK-LABEL: @gemm_memref_bias
// CHECK-SAME: ({{.*}}, %[[bias:.*]]: memref<8x128xi32>)

// Same tiling as @gemm_memref, plus a bias operand: the bias tile is added
// into the output tile (linalg.add) before the K reduction loop accumulates
// partial gemms into it. The bias is captured as %c0 on the compute op.
func.func @gemm_memref_bias(%arg0: memref<8x1024xi32>, %arg1: memref<1024x128xi32>, %bias: memref<8x128xi32>) -> memref<8x128xi32> {
  // CHECK: cinm.compute(%[[a0:.*]] = %{{.*}}, %[[b0:.*]] = %{{.*}}, %[[c0:.*]] = %{{.*}}) ->
  // CHECK: %[[out:.*]] = memref.alloc()
  // CHECK: affine.for %[[i:.*]] = 0 to 8 step 8
  // CHECK: affine.for %[[j:.*]] = 0 to 128 step 128
  // CHECK: %[[sliceBias:.*]] = memref.subview %[[c0]][%[[i]], %[[j]]] [8, 128] [1, 1] :
  // CHECK: %[[sliceOut:.*]] = memref.subview %[[out]][%[[i]], %[[j]]] [8, 128] [1, 1] :
  // CHECK: linalg.add ins(%[[sliceBias]], %[[sliceOut]] : {{.*}}) outs(%[[sliceOut]] :
  // CHECK: affine.for %[[k:.*]] = 0 to 1024 step 32
  // CHECK: %[[sliceA:.*]] = memref.subview %[[a0]][%[[i]], %[[k]]] [8, 32] [1, 1] :
  // CHECK: %[[sliceB:.*]] = memref.subview %[[b0]][%[[k]], %[[j]]] [32, 128] [1, 1] :
  // CHECK: cinm.op.gemm %[[sliceA]], %[[sliceB]] into %[[sliceOut]] {cinm.notile} :
  %0 = cinm.compute(%a0 = %arg0 : memref<8x1024xi32>, %a1 = %arg1 : memref<1024x128xi32>, %b0 = %bias : memref<8x128xi32>) -> memref<8x128xi32> attributes {workgroupShape = array<i64: 8, 128, 1>, bufferSizesInBytes = array<i64: 0, 0, 512>} {
    %alloc = memref.alloc() : memref<8x128xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<8x128xi32>)
    cinm.op.gemm %a0, %a1 plus %b0 into %alloc : memref<8x1024xi32>, memref<1024x128xi32> plus memref<8x128xi32> into memref<8x128xi32>
    cinm.yield %alloc : memref<8x128xi32>
  }
  return %0 : memref<8x128xi32>
}
4951
// -----
// CHECK-LABEL: @gemm_tensor
// Tensor (value-semantics) variant: the tile loops carry the result through
// iter_args, the K loop accumulates via `plus %[[inner]]`, and the finished
// tile is written back with tensor.insert_slice. Slices are taken from the
// compute-op block arguments %[[a0]]/%[[b0]].
// CHECK: cinm.compute(%[[a0:.*]] = %{{.*}}, %[[b0:.*]] = %{{.*}}) ->
// CHECK: affine.for %[[i:.*]] = 0 to 8 step 8 iter_args(%
// CHECK: affine.for %[[j:.*]] = 0 to 128 step 128 iter_args(%[[outer:.*]] =
// CHECK: %[[innerinit:.*]] = arith.constant dense<0> :
// CHECK: %[[x:.*]] = affine.for %[[k:.*]] = 0 to 1024 step 32 iter_args(%[[inner:.*]] = %[[innerinit]])

// CHECK: %[[sliceA:.*]] = tensor.extract_slice %[[a0]][%[[i]], %[[k]]] [8, 32] [1, 1] :
// CHECK: %[[sliceB:.*]] = tensor.extract_slice %[[b0]][%[[k]], %[[j]]] [32, 128] [1, 1] :
// CHECK: %[[r:.*]] = cinm.op.gemm %[[sliceA]], %[[sliceB]] plus %[[inner]] {cinm.notile} :
// CHECK: affine.yield %[[r]]
// CHECK: tensor.insert_slice %[[x]] into %[[outer]][%[[i]], %[[j]]] [8, 128] [1, 1] :
func.func @gemm_tensor(%A: tensor<8x1024xi32>, %B: tensor<1024x128xi32>) -> tensor<8x128xi32> {
  %r0 = cinm.compute(%a = %A : tensor<8x1024xi32>, %b = %B : tensor<1024x128xi32>) -> tensor<8x128xi32> attributes {workgroupShape = array<i64: 8, 128, 1>, bufferSizesInBytes = array<i64: 0, 0, 512>} {
    %r = cinm.op.gemm %a, %b : tensor<8x1024xi32>, tensor<1024x128xi32> -> tensor<8x128xi32>
    cinm.yield %r : tensor<8x128xi32>
  }
  func.return %r0 : tensor<8x128xi32>
}
// -----
// CHECK-LABEL: @gemm_tensor_bias
// Tensor variant with bias: the K-loop accumulator is seeded with the bias
// tile (extract_slice of %[[bias]]) instead of a zero constant. All operands,
// including the bias, are captured on the cinm.compute op.
// CHECK: cinm.compute(%[[a0:.*]] = %{{.*}}, %[[b0:.*]] = %{{.*}}, %[[bias:.*]] = %{{.*}}) ->
// CHECK: affine.for %[[i:.*]] = 0 to 8 step 8 iter_args(%
// CHECK: affine.for %[[j:.*]] = 0 to 128 step 128 iter_args(%[[outer:.*]] =
// CHECK: %[[innerinit:.*]] = tensor.extract_slice %[[bias]][%[[i]], %[[j]]] [8, 128] [1, 1] :
// CHECK: %[[x:.*]] = affine.for %[[k:.*]] = 0 to 1024 step 32 iter_args(%[[inner:.*]] = %[[innerinit]])

// CHECK: %[[sliceA:.*]] = tensor.extract_slice %[[a0]][%[[i]], %[[k]]] [8, 32] [1, 1] :
// CHECK: %[[sliceB:.*]] = tensor.extract_slice %[[b0]][%[[k]], %[[j]]] [32, 128] [1, 1] :
// CHECK: %[[r:.*]] = cinm.op.gemm %[[sliceA]], %[[sliceB]] plus %[[inner]] {cinm.notile} :
// CHECK: affine.yield %[[r]]
// CHECK: tensor.insert_slice %[[x]] into %[[outer]][%[[i]], %[[j]]] [8, 128] [1, 1] :
func.func @gemm_tensor_bias(%A: tensor<8x1024xi32>, %B: tensor<1024x128xi32>, %bias: tensor<8x128xi32>) -> tensor<8x128xi32> {
  %r0 = cinm.compute(%a = %A : tensor<8x1024xi32>, %b = %B : tensor<1024x128xi32>, %c = %bias : tensor<8x128xi32>) -> tensor<8x128xi32> attributes {workgroupShape = array<i64: 8, 128, 1>, bufferSizesInBytes = array<i64: 0, 0, 512>} {
    %r = cinm.op.gemm %a, %b plus %c : tensor<8x1024xi32>, tensor<1024x128xi32> plus tensor<8x128xi32> -> tensor<8x128xi32>
    cinm.yield %r : tensor<8x128xi32>
  }
  func.return %r0 : tensor<8x128xi32>
}