//===- OpReducer.cpp - Operation Reducer ------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the OpReducer class. It defines a variant generator method // with the purpose of producing different variants by eliminating a // parameterizable type of operations from the parent module. // //===----------------------------------------------------------------------===// #include "mlir/Reducer/Passes/OpReducer.h" using namespace mlir; OpReducerImpl::OpReducerImpl( llvm::function_ref(ModuleOp)> getSpecificOps) : getSpecificOps(getSpecificOps) {} /// Return the name of this reducer class. StringRef OpReducerImpl::getName() { return StringRef("High Level Operation Reduction"); } /// Return the initial transformSpace containing the transformable indices. std::vector OpReducerImpl::initTransformSpace(ModuleOp module) { auto ops = getSpecificOps(module); int numOps = std::distance(ops.begin(), ops.end()); return ReductionTreeUtils::createTransformSpace(module, numOps); } /// Generate variants by removing opType operations from the module in the /// parent and link the variants as childs in the Reduction Tree Pass. void OpReducerImpl::generateVariants( ReductionNode *parent, const Tester &test, int numVariants, llvm::function_ref transform) { ReductionTreeUtils::createVariants(parent, test, numVariants, transform, true); }