You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
120 lines
5.8 KiB
120 lines
5.8 KiB
// RUN: mlir-opt %s -convert-linalg-to-std | FileCheck %s

// Affine layout maps expected in the lowered output. The matchers below form a
// single unordered group, so the printer may emit these maps in any order.
//
// One-dimensional and zero-dimensional layouts (used by @dot):
//   map0 - layout of `memref<?xf32, offset: ?, strides: [1]>` arguments.
//   map6 - 1-D target of memref_cast: symbolic offset and symbolic stride.
//   map7 - 0-D target of memref_cast: symbolic offset only.
// Three-dimensional layouts (used by @copy and @copy_transpose):
//   map1 - layout of `memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>`.
//   map2 - layout of the view permuted by (d0, d1, d2) -> (d0, d2, d1).
//   map4 - layout of the view permuted by (d0, d1, d2) -> (d2, d1, d0).
//   map8 - 3-D target of memref_cast: symbolic offset and all-symbolic strides.
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK-DAG: #[[$map6:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
// CHECK-DAG: #[[$map7:.*]] = affine_map<()[s0] -> (s0)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d2 * s2 + d1)>
// CHECK-DAG: #[[$map4:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>
// CHECK-DAG: #[[$map8:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
|
|
|
|
// linalg.dot on two 1-D unit-stride views, accumulating into a 0-D memref.
// The lowering is expected to memref_cast every operand to a layout with
// symbolic offset/strides and pass the results, in operand order, to
// @linalg_dot_viewsxf32_viewsxf32_viewf32.
func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>,
          %arg1: memref<?xf32, offset: ?, strides: [1]>,
          %arg2: memref<f32>) {
  linalg.dot ins(%arg0, %arg1: memref<?xf32, offset: ?, strides: [1]>,
                               memref<?xf32, offset: ?, strides: [1]>)
             outs(%arg2: memref<f32>)
  return
}

// Function arguments are captured as alphanumeric names (letters A-Z/a-z and
// digits only) so each cast can be tied back to the argument it consumes.
// CHECK-LABEL: func @dot(
// CHECK-SAME: %[[arg0:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map0]]>,
// CHECK-SAME: %[[arg1:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map0]]>,
// CHECK-SAME: %[[arg2:[a-zA-Z0-9]*]]: memref<f32>) {
// CHECK: %[[o0:.*]] = memref_cast %[[arg0]] :
// CHECK-SAME: memref<?xf32, #[[$map0]]> to memref<?xf32, #[[$map6]]>
// CHECK: %[[o1:.*]] = memref_cast %[[arg1]] :
// CHECK-SAME: memref<?xf32, #[[$map0]]> to memref<?xf32, #[[$map6]]>
// CHECK: %[[o2:.*]] = memref_cast %[[arg2]] :
// CHECK-SAME: memref<f32> to memref<f32, #[[$map7]]>
// CHECK: call @linalg_dot_viewsxf32_viewsxf32_viewf32(
// CHECK-SAME: %[[o0]], %[[o1]], %[[o2]]) :
// CHECK-SAME: memref<?xf32, #[[$map6]]>, memref<?xf32, #[[$map6]]>, memref<f32, #[[$map7]]>
|
|
|
|
// linalg.copy between two 3-D views with strides [?, ?, 1] and no permutation.
// The lowering is expected to memref_cast both operands to a fully symbolic
// strided layout and call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32.
func @copy(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
           %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
  linalg.copy(%arg0, %arg1)
    : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
      memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
  return
}

// Function arguments are captured as alphanumeric names (letters A-Z/a-z and
// digits only) so each cast can be tied back to the argument it consumes.
// CHECK-LABEL: func @copy(
// CHECK-SAME: %[[arg0:[a-zA-Z0-9]*]]: memref<?x?x?xf32, #[[$map1]]>,
// CHECK-SAME: %[[arg1:[a-zA-Z0-9]*]]: memref<?x?x?xf32, #[[$map1]]>) {
// CHECK: %[[o0:.*]] = memref_cast %[[arg0]] :
// CHECK-SAME: memref<?x?x?xf32, #[[$map1]]> to memref<?x?x?xf32, #[[$map8]]>
// CHECK: %[[o1:.*]] = memref_cast %[[arg1]] :
// CHECK-SAME: memref<?x?x?xf32, #[[$map1]]> to memref<?x?x?xf32, #[[$map8]]>
// CHECK: call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32(%[[o0]], %[[o1]]) :
// CHECK-SAME: memref<?x?x?xf32, #[[$map8]]>, memref<?x?x?xf32, #[[$map8]]>
|
|
|
|
// linalg.copy with input and output permutations. The lowering is expected to
// materialize each permutation as a transpose of the corresponding argument,
// memref_cast the transposed views (layouts map2 and map4) to the fully
// symbolic strided layout, and call the same library function as plain copy.
func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
                     %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
  linalg.copy(%arg0, %arg1) {inputPermutation = affine_map<(i, j, k) -> (i, k, j)>,
                             outputPermutation = affine_map<(i, j, k) -> (k, j, i)>}
    : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
      memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
  return
}

// Function arguments are captured as alphanumeric names (letters A-Z/a-z and
// digits only) so each transpose can be tied back to the argument it consumes.
// CHECK-LABEL: func @copy_transpose(
// CHECK-SAME: %[[arg0:[a-zA-Z0-9]*]]: memref<?x?x?xf32, #[[$map1]]>,
// CHECK-SAME: %[[arg1:[a-zA-Z0-9]*]]: memref<?x?x?xf32, #[[$map1]]>) {
// CHECK: %[[t0:.*]] = transpose %[[arg0]]
// CHECK-SAME: (d0, d1, d2) -> (d0, d2, d1) : memref<?x?x?xf32, #[[$map1]]>
// CHECK: %[[t1:.*]] = transpose %[[arg1]]
// CHECK-SAME: (d0, d1, d2) -> (d2, d1, d0) : memref<?x?x?xf32, #[[$map1]]>
// CHECK: %[[o0:.*]] = memref_cast %[[t0]] :
// CHECK-SAME: memref<?x?x?xf32, #[[$map2]]> to memref<?x?x?xf32, #[[$map8]]>
// CHECK: %[[o1:.*]] = memref_cast %[[t1]] :
// CHECK-SAME: memref<?x?x?xf32, #[[$map4]]> to memref<?x?x?xf32, #[[$map8]]>
// CHECK: call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32(%[[o0]], %[[o1]]) :
// CHECK-SAME: memref<?x?x?xf32, #[[$map8]]>, memref<?x?x?xf32, #[[$map8]]>
|
|
|
|
// Operand access pattern for C[i, j] (+)= A[i, r] op B[r, j] over the
// (i, j, r) iteration space. The three maps are for A, B and C, in that order.
#matmul_accesses = [
  affine_map<(i, j, r) -> (i, r)>,
  affine_map<(i, j, r) -> (r, j)>,
  affine_map<(i, j, r) -> (i, j)>
]

// Trait for a linalg.generic matmul. The first two dimensions are parallel
// and the third is a reduction. The op carries the name of an external
// library function, which the test expects the lowering to call.
#matmul_trait = {
  library_call = "external_outerproduct_matmul",
  indexing_maps = #matmul_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}
|
|
|
|
// Operand types for the vectorized matmul tests. A and B are 2-D memrefs of
// 4-element vectors, and C is a 2-D memref of 4x4 vector tiles. Each matrix
// alias follows the element alias it refers to.
!vector_type_A = type vector<4xf32>
!matrix_type_A = type memref<?x?x!vector_type_A>

!vector_type_B = type vector<4xf32>
!matrix_type_B = type memref<?x?x!vector_type_B>

!vector_type_C = type vector<4x4xf32>
!matrix_type_C = type memref<?x?x!vector_type_C>
|
|
|
|
// linalg.generic whose payload accumulates a vector outer product of one A
// element and one B element into one C element. Because #matmul_trait names a
// library_call, the lowering is expected to emit a call to
// @external_outerproduct_matmul.
func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
  linalg.generic #matmul_trait
    ins(%A, %B : !matrix_type_A, !matrix_type_B)
    outs(%C : !matrix_type_C) {
    ^bb0(%lhs: !vector_type_A, %rhs: !vector_type_B, %acc: !vector_type_C):
      %prod = vector.outerproduct %lhs, %rhs, %acc : !vector_type_A, !vector_type_B
      linalg.yield %prod : !vector_type_C
  }
  return
}

// CHECK-LABEL: func @matmul_vec_impl(
// CHECK: call @external_outerproduct_matmul(%{{.*}}) :
|
|
|
|
// Trait for the indexed variant of the vectorized matmul. It uses the same
// access maps and iterator types as the plain generic matmul but names a
// different external library function.
#indexed_matmul_trait = {
  library_call = "external_indexed_outerproduct_matmul",
  indexing_maps = #matmul_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// linalg.indexed_generic variant: the payload also receives the three loop
// indices (unused here) ahead of the element values. The lowering is expected
// to pass one `constant 0 : index` value for each of the three index arguments,
// followed by the operands, to @external_indexed_outerproduct_matmul.
func @matmul_vec_indexed(%A: !matrix_type_A,
                         %B: !matrix_type_B,
                         %C: !matrix_type_C) {
  linalg.indexed_generic #indexed_matmul_trait
    ins(%A, %B : !matrix_type_A, !matrix_type_B)
    outs(%C : !matrix_type_C) {
    ^bb0(%m: index, %n: index, %r: index,
         %lhs: !vector_type_A, %rhs: !vector_type_B, %acc: !vector_type_C):
      %prod = vector.outerproduct %lhs, %rhs, %acc : !vector_type_A, !vector_type_B
      linalg.yield %prod : !vector_type_C
  }
  return
}

// CHECK-LABEL: func @matmul_vec_indexed(
// CHECK: %[[ZERO:.*]] = constant 0 : index
// CHECK: call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}})
|