// RUN: mlir-opt %s -inline | FileCheck %s

// These tests verify that regions containing operations from the TOSA
// dialect can be inlined: calls nested inside the regions of "tosa.cond_if"
// and "tosa.while_loop" should be replaced by the bodies of their callees.

// CHECK-LABEL: func @inlined_if_fn
// Check that both the calls and the functions are eliminated after inlining:
// CHECK-NOT: @add
// CHECK-NOT: @sub
func @inlined_if_fn(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
  // %arg2 is the first operand of "tosa.cond_if" (the predicate); %arg0 and
  // %arg1 are forwarded to both regions as block arguments.
  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ( {
  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):  // no predecessors
    %1 = call @add(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "tosa.yield"(%1) : (tensor<f32>) -> ()
  },  {
  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):  // no predecessors
    %1 = call @sub(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "tosa.yield"(%1) : (tensor<f32>) -> ()
  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// Private callees of @inlined_if_fn; the negative checks above expect these
// definitions to be gone from the output once their only calls are inlined.
func @add(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> attributes {sym_visibility = "private"} {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}
func @sub(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> attributes {sym_visibility = "private"} {
  %0 = "tosa.sub"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// -----
// (The command line above does not pass -split-input-file, so the marker is
// only a visual separator; both tests are parsed as a single module.)

// CHECK-LABEL: func @inlined_while_fn
func @inlined_while_fn(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> tensor<10xi32> {
  // Check that calls are inlined and functions eliminated:
  // CHECK-NOT: @while
  // The first region calls @while_cond_40 and yields its single i1 result;
  // the second region calls @while_body_50 and yields all four of its results.
  %1:4 = "tosa.while_loop"(%arg0, %arg1, %arg2, %arg3) ( {
  ^bb0(%arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>, %arg7: tensor<10xi32>):  // no predecessors
    %2 = call @while_cond_40(%arg4, %arg5, %arg6, %arg7) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) -> tensor<i1>
    "tosa.yield"(%2) : (tensor<i1>) -> ()
  },  {
  ^bb0(%arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>, %arg7: tensor<10xi32>):  // no predecessors
    %2:4 = call @while_body_50(%arg4, %arg5, %arg6, %arg7) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>)
    "tosa.yield"(%2#0, %2#1, %2#2, %2#3) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) -> ()
  }) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>)
  return %1#3 : tensor<10xi32>
}

// Loop body: returns (%arg0 + %arg1, %arg1, %arg2, %arg3 + (%arg0 + %arg1)).
func @while_body_50(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) attributes {sym_visibility = "private"} {
  %1 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %2 = "tosa.add"(%arg3, %1) : (tensor<10xi32>, tensor<i32>) -> tensor<10xi32>
  return %1, %arg1, %arg2, %2 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>
}

// Loop predicate: computes logical_not(%arg0 >= %arg1); %arg2 and %arg3 are
// unused.
func @while_cond_40(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> tensor<i1> attributes {sym_visibility = "private"} {
  %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
  %1 = "tosa.logical_not"(%0) : (tensor<i1>) -> tensor<i1>
  return %1 : tensor<i1>
}