// Copyright 2017 The Gemmlowp Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // simd_wrappers.h: some inline functions wrapping SIMD intrinsics, // extending the set of such functions from fixedpoint.h. #ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ #define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ #include #include #include "../fixedpoint/fixedpoint.h" namespace gemmlowp { template struct RegisterType { using Type = ScalarType; }; inline std::int32_t Min(std::int32_t a, std::int32_t b) { return std::min(a, b); } inline std::int32_t Max(std::int32_t a, std::int32_t b) { return std::max(a, b); } inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) { *acc += lhs * rhs; } template struct RegisterBuffer { using ScalarType = tScalarType; static constexpr int kScalarCount = tScalarCount; using RegisterType = typename RegisterType::Type; static_assert((kScalarCount & (kScalarCount - 1)) == 0, "kScalarCount must be a power of two"); static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, ""); static constexpr int kRegisterLanes = sizeof(RegisterType) / sizeof(ScalarType); static constexpr int kRegisterCount = (kScalarCount * sizeof(ScalarType) + sizeof(RegisterType) - 1) / sizeof(RegisterType); RegisterType reg[kRegisterCount]; }; template struct RegisterBlock { using ScalarType = tScalarType; static constexpr int kRows = tRows; static constexpr int kCols = tCols; static constexpr int kScalarCount = kRows * kCols; using BufferType = RegisterBuffer; using RegisterType = typename BufferType::RegisterType; static constexpr int kRegisterCount = BufferType::kRegisterCount; static constexpr int kRegisterLanes = BufferType::kRegisterLanes; BufferType buf; }; template struct RegisterBlockAddImpl { static RegisterBlockType Run(const RegisterBlockType& lhs, const RegisterBlockType& rhs) { RegisterBlockType result; for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) { result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]); } return result; } }; template RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs, const RegisterBlockType& rhs) { return RegisterBlockAddImpl::Run(lhs, rhs); } template struct ShouldFlipLhsRhs { static constexpr bool kValue = (LhsType::kScalarCount < RhsType::kScalarCount) || (LhsType::kScalarCount == RhsType::kScalarCount && (LhsType::kRows < RhsType::kRows)); }; template ::kValue> struct FlipLhsRhs { using FlippedLhsType = LhsType; using FlippedRhsType = RhsType; static const FlippedLhsType& FlippedLhs(const LhsType& lhs, const RhsType& rhs) { (void)rhs; return lhs; } static const FlippedRhsType& FlippedRhs(const LhsType& lhs, const RhsType& rhs) { (void)lhs; return rhs; } }; template struct FlipLhsRhs { using FlippedLhsType = RhsType; using FlippedRhsType = LhsType; static const FlippedLhsType& FlippedLhs(const LhsType& lhs, const RhsType& rhs) { (void)lhs; return rhs; } static const FlippedRhsType& FlippedRhs(const LhsType& lhs, const RhsType& rhs) { (void)rhs; return lhs; } }; template struct BroadcastBinaryOpShape { static constexpr int kRows = Lhs::kRows > Rhs::kRows ? Lhs::kRows : Rhs::kRows; static constexpr int kCols = Lhs::kCols > Rhs::kCols ? Lhs::kCols : Rhs::kCols; }; template struct BroadcastBinaryOpRegisterBlock { using Shape = BroadcastBinaryOpShape; using ScalarType = typename Lhs::ScalarType; using Type = RegisterBlock; }; template struct BroadcastAddImpl { using ResultBlockType = typename BroadcastBinaryOpRegisterBlock::Type; static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { ResultBlockType result; static constexpr int Rows = ResultBlockType::kRows; static constexpr int Cols = ResultBlockType::kCols; static constexpr int LhsRows = Lhs::kRows; static constexpr int LhsCols = Lhs::kCols; static constexpr int RhsRows = Rhs::kRows; static constexpr int RhsCols = Rhs::kCols; static_assert(LhsRows == Rows || LhsRows == 1, ""); static_assert(RhsRows == Rows || RhsRows == 1, ""); static_assert(LhsCols == Cols || LhsCols == 1, ""); static_assert(RhsCols == Cols || RhsCols == 1, ""); static_assert(ResultBlockType::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Lhs::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Rhs::kRegisterLanes == 1, "This path is only for scalar values"); for (int c = 0; c < Cols; c++) { const int lhs_c = LhsCols == Cols ? c : 0; const int rhs_c = RhsCols == Cols ? c : 0; for (int r = 0; r < Rows; r++) { const int lhs_r = LhsRows == Rows ? r : 0; const int rhs_r = RhsRows == Rows ? r : 0; result.buf.reg[r + c * Rows] = Add(lhs.buf.reg[lhs_r + lhs_c * LhsRows], rhs.buf.reg[rhs_r + rhs_c * RhsRows]); } } return result; } }; template typename BroadcastBinaryOpRegisterBlock::Type BroadcastAdd( const Lhs& lhs, const Rhs& rhs) { using Flip = FlipLhsRhs; return BroadcastAddImpl< typename Flip::FlippedLhsType, typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), Flip::FlippedRhs(lhs, rhs)); } template struct BroadcastShiftLeftImpl { using ResultBlockType = typename BroadcastBinaryOpRegisterBlock::Type; static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { ResultBlockType result; static constexpr int Rows = ResultBlockType::kRows; static constexpr int Cols = ResultBlockType::kCols; static constexpr int LhsRows = Lhs::kRows; static constexpr int LhsCols = Lhs::kCols; static constexpr int RhsRows = Rhs::kRows; static constexpr int RhsCols = Rhs::kCols; static_assert(LhsRows == Rows || LhsRows == 1, ""); static_assert(RhsRows == Rows || RhsRows == 1, ""); static_assert(LhsCols == Cols || LhsCols == 1, ""); static_assert(RhsCols == Cols || RhsCols == 1, ""); static_assert(ResultBlockType::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Lhs::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Rhs::kRegisterLanes == 1, "This path is only for scalar values"); for (int c = 0; c < Cols; c++) { const int lhs_c = LhsCols == Cols ? c : 0; const int rhs_c = RhsCols == Cols ? c : 0; for (int r = 0; r < Rows; r++) { const int lhs_r = LhsRows == Rows ? r : 0; const int rhs_r = RhsRows == Rows ? r : 0; result.buf.reg[r + c * Rows] = ShiftLeft(lhs.buf.reg[lhs_r + lhs_c * LhsRows], rhs.buf.reg[rhs_r + rhs_c * RhsRows]); } } return result; } }; template typename BroadcastBinaryOpRegisterBlock::Type BroadcastShiftLeft( const Lhs& lhs, const Rhs& rhs) { using Flip = FlipLhsRhs; return BroadcastShiftLeftImpl< typename Flip::FlippedLhsType, typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), Flip::FlippedRhs(lhs, rhs)); } template struct BroadcastSaturatingRoundingDoublingHighMulImpl { using ResultBlockType = typename BroadcastBinaryOpRegisterBlock::Type; static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { ResultBlockType result; static constexpr int Rows = ResultBlockType::kRows; static constexpr int Cols = ResultBlockType::kCols; static constexpr int LhsRows = Lhs::kRows; static constexpr int LhsCols = Lhs::kCols; static constexpr int RhsRows = Rhs::kRows; static constexpr int RhsCols = Rhs::kCols; static_assert(LhsRows == Rows || LhsRows == 1, ""); static_assert(RhsRows == Rows || RhsRows == 1, ""); static_assert(LhsCols == Cols || LhsCols == 1, ""); static_assert(RhsCols == Cols || RhsCols == 1, ""); static_assert(ResultBlockType::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Lhs::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Rhs::kRegisterLanes == 1, "This path is only for scalar values"); for (int c = 0; c < Cols; c++) { const int lhs_c = LhsCols == Cols ? c : 0; const int rhs_c = RhsCols == Cols ? c : 0; for (int r = 0; r < Rows; r++) { const int lhs_r = LhsRows == Rows ? r : 0; const int rhs_r = RhsRows == Rows ? r : 0; result.buf.reg[r + c * Rows] = SaturatingRoundingDoublingHighMul( lhs.buf.reg[lhs_r + lhs_c * LhsRows], rhs.buf.reg[rhs_r + rhs_c * RhsRows]); } } return result; } }; template typename BroadcastBinaryOpRegisterBlock::Type BroadcastSaturatingRoundingDoublingHighMul(const Lhs& lhs, const Rhs& rhs) { using Flip = FlipLhsRhs; return BroadcastSaturatingRoundingDoublingHighMulImpl< typename Flip::FlippedLhsType, typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), Flip::FlippedRhs(lhs, rhs)); } template struct BroadcastRoundingDivideByPOTImpl { using ResultBlockType = typename BroadcastBinaryOpRegisterBlock::Type; static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { ResultBlockType result; static constexpr int Rows = ResultBlockType::kRows; static constexpr int Cols = ResultBlockType::kCols; static constexpr int LhsRows = Lhs::kRows; static constexpr int LhsCols = Lhs::kCols; static constexpr int RhsRows = Rhs::kRows; static constexpr int RhsCols = Rhs::kCols; static_assert(LhsRows == Rows || LhsRows == 1, ""); static_assert(RhsRows == Rows || RhsRows == 1, ""); static_assert(LhsCols == Cols || LhsCols == 1, ""); static_assert(RhsCols == Cols || RhsCols == 1, ""); static_assert(ResultBlockType::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Lhs::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Rhs::kRegisterLanes == 1, "This path is only for scalar values"); for (int c = 0; c < Cols; c++) { const int lhs_c = LhsCols == Cols ? c : 0; const int rhs_c = RhsCols == Cols ? c : 0; for (int r = 0; r < Rows; r++) { const int lhs_r = LhsRows == Rows ? r : 0; const int rhs_r = RhsRows == Rows ? r : 0; result.buf.reg[r + c * Rows] = RoundingDivideByPOT(lhs.buf.reg[lhs_r + lhs_c * LhsRows], rhs.buf.reg[rhs_r + rhs_c * RhsRows]); } } return result; } }; template typename BroadcastBinaryOpRegisterBlock::Type BroadcastRoundingDivideByPOT(const Lhs& lhs, const Rhs& rhs) { using Flip = FlipLhsRhs; return BroadcastRoundingDivideByPOTImpl< typename Flip::FlippedLhsType, typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), Flip::FlippedRhs(lhs, rhs)); } template struct BroadcastMulImpl { using ResultBlockType = typename BroadcastBinaryOpRegisterBlock::Type; static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { ResultBlockType result; static constexpr int Rows = ResultBlockType::kRows; static constexpr int Cols = ResultBlockType::kCols; static constexpr int LhsRows = Lhs::kRows; static constexpr int LhsCols = Lhs::kCols; static constexpr int RhsRows = Rhs::kRows; static constexpr int RhsCols = Rhs::kCols; static_assert(ResultBlockType::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Lhs::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Rhs::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(LhsRows == Rows || LhsRows == 1, ""); static_assert(RhsRows == Rows || RhsRows == 1, ""); static_assert(LhsCols == Cols || LhsCols == 1, ""); static_assert(RhsCols == Cols || RhsCols == 1, ""); for (int c = 0; c < Cols; c++) { const int lhs_c = LhsCols == Cols ? c : 0; const int rhs_c = RhsCols == Cols ? c : 0; for (int r = 0; r < Rows; r++) { const int lhs_r = LhsRows == Rows ? r : 0; const int rhs_r = RhsRows == Rows ? r : 0; result.buf.reg[r + c * Rows] = Mul(lhs.buf.reg[lhs_r + lhs_c * LhsRows], rhs.buf.reg[rhs_r + rhs_c * RhsRows]); } } return result; } }; template typename BroadcastBinaryOpRegisterBlock::Type BroadcastMul( const Lhs& lhs, const Rhs& rhs) { using Flip = FlipLhsRhs; return BroadcastMulImpl< typename Flip::FlippedLhsType, typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), Flip::FlippedRhs(lhs, rhs)); } template struct BroadcastMulAddImpl { static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) { static constexpr int Rows = Acc::kRows; static constexpr int Cols = Acc::kCols; static constexpr int LhsRows = Lhs::kRows; static constexpr int LhsCols = Lhs::kCols; static constexpr int RhsRows = Rhs::kRows; static constexpr int RhsCols = Rhs::kCols; static_assert(Acc::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Lhs::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Rhs::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(LhsRows == Rows || LhsRows == 1, ""); static_assert(RhsRows == Rows || RhsRows == 1, ""); static_assert(LhsCols == Cols || LhsCols == 1, ""); static_assert(RhsCols == Cols || RhsCols == 1, ""); for (int c = 0; c < Cols; c++) { const int lhs_c = LhsCols == Cols ? c : 0; const int rhs_c = RhsCols == Cols ? c : 0; for (int r = 0; r < Rows; r++) { const int lhs_r = LhsRows == Rows ? r : 0; const int rhs_r = RhsRows == Rows ? r : 0; MulAdd(lhs.buf.reg[lhs_r + lhs_c * LhsRows], rhs.buf.reg[rhs_r + rhs_c * RhsRows], &acc->buf.reg[r + c * Rows]); } } } }; template void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) { using Flip = FlipLhsRhs; BroadcastMulAddImpl::Run(Flip::FlippedLhs(lhs, rhs), Flip::FlippedRhs(lhs, rhs), acc); } template struct LoadImpl { static_assert(std::is_same::value, "This generic impl should never be hit"); }; template struct LoadImpl, MatrixMap> { using RegisterBlockType = RegisterBlock; using SrcObjectType = MatrixMap; static RegisterBlockType Run(const SrcObjectType& src, int row, int col) { RegisterBlockType result; int i = 0; for (int c = 0; c < Cols; c++) { const ScalarType* src_ptr = src.data(row, col + c); for (int r = 0; r < Rows; r++) { result.buf.reg[i++] = *src_ptr++; } } return result; } }; template struct LoadImpl, VectorMap> { using RegisterBlockType = RegisterBlock; using SrcObjectType = VectorMap; static RegisterBlockType Run(const SrcObjectType& src, int pos) { static_assert(Shape == VectorShape::Col || Rows == 1, ""); static_assert(Shape == VectorShape::Row || Cols == 1, ""); RegisterBlockType result; for (int i = 0; i < Rows * Cols; i++) { result.buf.reg[i] = src(pos + i); } return result; } }; template struct LoadImpl, VectorDup> { using RegisterBlockType = RegisterBlock; using SrcObjectType = VectorDup; static RegisterBlockType Run(const SrcObjectType& src, int) { static_assert(Shape == VectorShape::Col || Rows == 1, ""); static_assert(Shape == VectorShape::Row || Cols == 1, ""); RegisterBlockType result; for (int i = 0; i < Rows * Cols; i++) { result.buf.reg[i] = src(0); } return result; } }; template RegisterBlockType Load(const SrcObjectType& src, int row, int col) { return LoadImpl::Run(src, row, col); } template RegisterBlockType Load(const SrcObjectType& src, int pos) { return LoadImpl::Run(src, pos); } template struct LoadContiguousImpl { using ScalarType = typename RegisterBlockType::ScalarType; static_assert(RegisterBlockType::kRegisterLanes == 1, "This path is only for scalar values"); static RegisterBlockType Run(const ScalarType* src) { RegisterBlockType result; for (int i = 0; i < RegisterBlockType::kScalarCount; i++) { result.buf.reg[i] = src[i]; } return result; } }; template RegisterBlockType LoadContiguous( const typename RegisterBlockType::ScalarType* src) { return LoadContiguousImpl::Run(src); } template struct LoadForBroadcastingShape {}; template struct LoadForBroadcastingShape> { static constexpr int kRows = Shape == VectorShape::Col ? BroadcastRows : 1; static constexpr int kCols = Shape == VectorShape::Row ? BroadcastCols : 1; }; template struct LoadForBroadcastingShape> { static constexpr int kRows = 1; static constexpr int kCols = 1; }; template struct LoadForBroadcastingRegisterBlock { using Shape = LoadForBroadcastingShape; using ScalarType = typename RegisterBlockType::ScalarType; using Type = RegisterBlock; }; template struct LoadForBroadcastingImpl { static_assert(std::is_same::value, "This generic impl should never be hit"); }; template struct LoadForBroadcastingImpl, VectorMap> { using RegisterBlockType = RegisterBlock; using SrcObjectType = VectorMap; using ResultBlockType = typename LoadForBroadcastingRegisterBlock::Type; static_assert(ResultBlockType::kRegisterLanes == 1, "This path is only for scalar values"); static ResultBlockType Run(const SrcObjectType& src, int pos) { ResultBlockType result; for (int c = 0; c < ResultBlockType::kCols; c++) { for (int r = 0; r < ResultBlockType::kRows; r++) { const int i = Shape == VectorShape::Col ? r : c; result.buf.reg[r + c * ResultBlockType::kRows] = src(pos + i); } } return result; } }; template struct LoadForBroadcastingImpl, VectorDup> { using RegisterBlockType = RegisterBlock; using SrcObjectType = VectorDup; using ResultBlockType = typename LoadForBroadcastingRegisterBlock::Type; static_assert(ResultBlockType::kRegisterLanes == 1, "This path is only for scalar values"); static ResultBlockType Run(const SrcObjectType& src, int) { ResultBlockType result; for (int c = 0; c < ResultBlockType::kCols; c++) { for (int r = 0; r < ResultBlockType::kRows; r++) { result.buf.reg[r + c * ResultBlockType::kRows] = src(0); } } return result; } }; template typename LoadForBroadcastingRegisterBlock::Type LoadForBroadcasting(const SrcObjectType& src, int row, int col) { return LoadForBroadcastingImpl::Run( src, row, col); } template typename LoadForBroadcastingRegisterBlock::Type LoadForBroadcasting(const SrcObjectType& src, int pos) { return LoadForBroadcastingImpl::Run(src, pos); } template struct AddConstantImpl { static void Run(RegisterBlockType* block) { using RegisterType = typename RegisterBlockType::RegisterType; const RegisterType dup = Dup(ConstantValue); for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) { block->buf.reg[i] = Add(block->buf.reg[i], dup); } } }; template struct AddConstantImpl<0, RegisterBlockType> { static void Run(RegisterBlockType*) { // This is a no-op. } }; template void AddConstant(RegisterBlockType* block) { AddConstantImpl::Run(block); } template using RegBufferInt32 = RegisterBuffer; template using RegBufferInt16 = RegisterBuffer; template using RegBufferUint8 = RegisterBuffer; template using RegBufferInt8 = RegisterBuffer; template using RegBlockInt32 = RegisterBlock; template using RegBlockInt16 = RegisterBlock; template using RegBlockUint8 = RegisterBlock; template using RegBlockInt8 = RegisterBlock; } // end namespace gemmlowp #if defined GEMMLOWP_NEON #include "simd_wrappers_neon.h" #elif defined GEMMLOWP_SSE4 #include "simd_wrappers_sse.h" #elif defined GEMMLOWP_MSA #include "simd_wrappers_msa.h" #endif #endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_