You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
208 lines
7.8 KiB
208 lines
7.8 KiB
// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
// dispatch_gemm_shape.h: dispatch GEMM calls according to their shape
|
|
|
|
#ifndef GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
|
|
#define GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
|
|
|
|
#include <cassert>
#include <cstddef>
#include <tuple>

#include "../internal/kernel_default.h"
#include "../public/map.h"
#include "../public/output_stages.h"
#include "multi_thread_gemm.h"
|
|
|
|
namespace gemmlowp {
|
|
|
|
// Primary TransposeImpl template: the fallback used for any type that has no
// dedicated specialization. Transposition is the identity here: DstType is
// the source type itself and Run() returns a copy of its argument.
template <typename T>
struct TransposeImpl {
  using DstType = T;

  static DstType Run(const T& value) { return value; }
};
|
|
|
|
template <typename T>
|
|
using TransposeType = typename TransposeImpl<T>::DstType;
|
|
|
|
template <typename T>
|
|
TransposeType<T> Transpose(const T& t) {
|
|
return TransposeImpl<T>::Run(t);
|
|
}
|
|
|
|
template <MapOrder Order>
|
|
struct TransposeMapOrder {
|
|
static constexpr MapOrder Value =
|
|
Order == MapOrder::RowMajor ? MapOrder::ColMajor : MapOrder::RowMajor;
|
|
};
|
|
|
|
template <VectorShape Shape>
|
|
struct TransposeVectorShape {
|
|
static constexpr VectorShape Value =
|
|
Shape == VectorShape::Row ? VectorShape::Col : VectorShape::Row;
|
|
};
|
|
|
|
template <typename Scalar, VectorShape Shape>
|
|
struct TransposeImpl<VectorMap<Scalar, Shape>> {
|
|
typedef VectorMap<Scalar, Shape> SrcType;
|
|
static constexpr VectorShape TransposedShape =
|
|
TransposeVectorShape<Shape>::Value;
|
|
typedef VectorMap<Scalar, TransposedShape> DstType;
|
|
static DstType Run(const SrcType& src) {
|
|
return DstType(src.data(), src.size());
|
|
}
|
|
};
|
|
|
|
template <typename Scalar, MapOrder Order>
|
|
struct TransposeImpl<MatrixMap<Scalar, Order>> {
|
|
typedef MatrixMap<Scalar, Order> SrcType;
|
|
static constexpr MapOrder TransposedOrder = TransposeMapOrder<Order>::Value;
|
|
typedef MatrixMap<Scalar, TransposedOrder> DstType;
|
|
static DstType Run(const SrcType& src) {
|
|
return DstType(src.data(), src.cols(), src.rows(), src.stride());
|
|
}
|
|
};
|
|
|
|
template <VectorShape Shape>
|
|
struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
|
|
typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> SrcType;
|
|
static constexpr VectorShape TransposedShape =
|
|
TransposeVectorShape<Shape>::Value;
|
|
typedef OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape> DstType;
|
|
static DstType Run(const SrcType& src) {
|
|
DstType dst;
|
|
dst.result_shift = src.result_shift;
|
|
dst.result_offset = Transpose(src.result_offset);
|
|
dst.result_mult_int = Transpose(src.result_mult_int);
|
|
return dst;
|
|
}
|
|
};
|
|
|
|
template <VectorShape Shape>
|
|
struct TransposeImpl<OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>> {
|
|
typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> SrcType;
|
|
static constexpr VectorShape TransposedShape =
|
|
TransposeVectorShape<Shape>::Value;
|
|
typedef OutputStageScaleInt32ByFixedPointAndExponentPC<TransposedShape>
|
|
DstType;
|
|
static DstType Run(const SrcType& src) {
|
|
DstType dst;
|
|
dst.result_fixedpoint_multiplier =
|
|
Transpose(src.result_fixedpoint_multiplier);
|
|
dst.result_exponent = Transpose(src.result_exponent);
|
|
dst.result_offset_after_shift = src.result_offset_after_shift;
|
|
return dst;
|
|
}
|
|
};
|
|
|
|
template <typename VectorMapType>
|
|
struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> {
|
|
typedef OutputStageBiasAddition<VectorMapType> SrcType;
|
|
typedef TransposeType<VectorMapType> TransposedVectorMapType;
|
|
typedef OutputStageBiasAddition<TransposedVectorMapType> DstType;
|
|
static DstType Run(const SrcType& src) {
|
|
DstType dst;
|
|
dst.bias_vector = Transpose(src.bias_vector);
|
|
return dst;
|
|
}
|
|
};
|
|
|
|
// TODO(benoitjacob) - does anyone understand C++ variadic templates?
|
|
// How to use them to implement TransposeTuple? Note: there are lots
|
|
// of answers on StackOverflow but they seem to all involve either
|
|
// C++14/C++17 (we can only use C++11) or lots of abstract nonsense.
|
|
inline std::tuple<> TransposeTuple(const std::tuple<>& t) { return t; }
|
|
|
|
template <typename T0>
|
|
std::tuple<TransposeType<T0>> TransposeTuple(const std::tuple<T0>& t) {
|
|
return std::make_tuple(Transpose(std::get<0>(t)));
|
|
}
|
|
|
|
template <typename T0, typename T1>
|
|
std::tuple<TransposeType<T0>, TransposeType<T1>> TransposeTuple(
|
|
const std::tuple<T0, T1>& t) {
|
|
return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)));
|
|
}
|
|
|
|
template <typename T0, typename T1, typename T2>
|
|
std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>>
|
|
TransposeTuple(const std::tuple<T0, T1, T2>& t) {
|
|
return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
|
|
Transpose(std::get<2>(t)));
|
|
}
|
|
|
|
template <typename T0, typename T1, typename T2, typename T3>
|
|
std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
|
|
TransposeType<T3>>
|
|
TransposeTuple(const std::tuple<T0, T1, T2, T3>& t) {
|
|
return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
|
|
Transpose(std::get<2>(t)), Transpose(std::get<3>(t)));
|
|
}
|
|
|
|
template <typename T0, typename T1, typename T2, typename T3, typename T4>
|
|
std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
|
|
TransposeType<T3>, TransposeType<T4>>
|
|
TransposeTuple(const std::tuple<T0, T1, T2, T3, T4>& t) {
|
|
return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
|
|
Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
|
|
Transpose(std::get<4>(t)));
|
|
}
|
|
|
|
template <typename T0, typename T1, typename T2, typename T3, typename T4,
|
|
typename T5>
|
|
std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
|
|
TransposeType<T3>, TransposeType<T4>, TransposeType<T5>>
|
|
TransposeTuple(const std::tuple<T0, T1, T2, T3, T4, T5>& t) {
|
|
return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
|
|
Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
|
|
Transpose(std::get<4>(t)), Transpose(std::get<5>(t)));
|
|
}
|
|
|
|
template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
|
|
MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
|
|
typename LhsOffset, typename RhsOffset, typename OutputPipelineType,
|
|
typename GemmContextType>
|
|
void DispatchGemmShape(GemmContextType* context,
|
|
const MatrixMap<const InputScalar, LhsOrder>& lhs,
|
|
const MatrixMap<const InputScalar, RhsOrder>& rhs,
|
|
MatrixMap<OutputScalar, ResultOrder>* result,
|
|
const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
|
|
const OutputPipelineType& output_pipeline) {
|
|
assert(lhs.cols() == rhs.rows());
|
|
|
|
int rows = result->rows();
|
|
int cols = result->cols();
|
|
int depth = lhs.cols();
|
|
|
|
if (rows == 0 || cols == 0 || depth == 0) {
|
|
// Vacuous GEMM, return early to avoid having to deal with
|
|
// zero sizes below.
|
|
return;
|
|
}
|
|
|
|
if (rows < cols) {
|
|
auto transposed_result_map = Transpose(*result);
|
|
return DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>(
|
|
context, Transpose(rhs), Transpose(lhs), &transposed_result_map,
|
|
Transpose(rhs_offset), Transpose(lhs_offset),
|
|
TransposeTuple(output_pipeline));
|
|
}
|
|
|
|
typedef DefaultKernel<BitDepthParams> Kernel;
|
|
MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
|
|
BitDepthParams>(context, Kernel(), lhs, rhs, result,
|
|
lhs_offset, rhs_offset, output_pipeline);
|
|
}
|
|
|
|
} // end namespace gemmlowp
|
|
|
|
#endif // GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
|