You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
670 lines
25 KiB
670 lines
25 KiB
// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
// simd_wrappers.h: some inline functions wrapping SIMD intrinsics,
|
|
// extending the set of such functions from fixedpoint.h.
|
|
|
|
#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
|
|
#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
|
|
|
|
#include <algorithm>
|
|
#include <type_traits>
|
|
#include "../fixedpoint/fixedpoint.h"
|
|
|
|
namespace gemmlowp {
|
|
|
|
template <typename ScalarType, int ScalarCount>
|
|
struct RegisterType {
|
|
using Type = ScalarType;
|
|
};
|
|
|
|
inline std::int32_t Min(std::int32_t a, std::int32_t b) {
|
|
return std::min(a, b);
|
|
}
|
|
|
|
inline std::int32_t Max(std::int32_t a, std::int32_t b) {
|
|
return std::max(a, b);
|
|
}
|
|
|
|
inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) {
|
|
*acc += lhs * rhs;
|
|
}
|
|
|
|
template <typename tScalarType, int tScalarCount>
|
|
struct RegisterBuffer {
|
|
using ScalarType = tScalarType;
|
|
static constexpr int kScalarCount = tScalarCount;
|
|
using RegisterType = typename RegisterType<ScalarType, kScalarCount>::Type;
|
|
static_assert((kScalarCount & (kScalarCount - 1)) == 0,
|
|
"kScalarCount must be a power of two");
|
|
static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, "");
|
|
static constexpr int kRegisterLanes =
|
|
sizeof(RegisterType) / sizeof(ScalarType);
|
|
static constexpr int kRegisterCount =
|
|
(kScalarCount * sizeof(ScalarType) + sizeof(RegisterType) - 1) /
|
|
sizeof(RegisterType);
|
|
|
|
RegisterType reg[kRegisterCount];
|
|
};
|
|
|
|
template <typename tScalarType, int tRows, int tCols>
|
|
struct RegisterBlock {
|
|
using ScalarType = tScalarType;
|
|
static constexpr int kRows = tRows;
|
|
static constexpr int kCols = tCols;
|
|
static constexpr int kScalarCount = kRows * kCols;
|
|
using BufferType = RegisterBuffer<ScalarType, kScalarCount>;
|
|
using RegisterType = typename BufferType::RegisterType;
|
|
static constexpr int kRegisterCount = BufferType::kRegisterCount;
|
|
static constexpr int kRegisterLanes = BufferType::kRegisterLanes;
|
|
|
|
BufferType buf;
|
|
};
|
|
|
|
template <typename RegisterBlockType>
|
|
struct RegisterBlockAddImpl {
|
|
static RegisterBlockType Run(const RegisterBlockType& lhs,
|
|
const RegisterBlockType& rhs) {
|
|
RegisterBlockType result;
|
|
for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) {
|
|
result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]);
|
|
}
|
|
return result;
|
|
}
|
|
};
|
|
|
|
template <typename RegisterBlockType>
|
|
RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs,
|
|
const RegisterBlockType& rhs) {
|
|
return RegisterBlockAddImpl<RegisterBlockType>::Run(lhs, rhs);
|
|
}
|
|
|
|
template <typename LhsType, typename RhsType>
|
|
struct ShouldFlipLhsRhs {
|
|
static constexpr bool kValue =
|
|
(LhsType::kScalarCount < RhsType::kScalarCount) ||
|
|
(LhsType::kScalarCount == RhsType::kScalarCount &&
|
|
(LhsType::kRows < RhsType::kRows));
|
|
};
|
|
|
|
template <typename LhsType, typename RhsType,
|
|
bool Flip = ShouldFlipLhsRhs<LhsType, RhsType>::kValue>
|
|
struct FlipLhsRhs {
|
|
using FlippedLhsType = LhsType;
|
|
using FlippedRhsType = RhsType;
|
|
static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
|
|
const RhsType& rhs) {
|
|
(void)rhs;
|
|
return lhs;
|
|
}
|
|
static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
|
|
const RhsType& rhs) {
|
|
(void)lhs;
|
|
return rhs;
|
|
}
|
|
};
|
|
|
|
template <typename LhsType, typename RhsType>
|
|
struct FlipLhsRhs<LhsType, RhsType, true> {
|
|
using FlippedLhsType = RhsType;
|
|
using FlippedRhsType = LhsType;
|
|
static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
|
|
const RhsType& rhs) {
|
|
(void)lhs;
|
|
return rhs;
|
|
}
|
|
static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
|
|
const RhsType& rhs) {
|
|
(void)rhs;
|
|
return lhs;
|
|
}
|
|
};
|
|
|
|
template <typename Lhs, typename Rhs>
|
|
struct BroadcastBinaryOpShape {
|
|
static constexpr int kRows =
|
|
Lhs::kRows > Rhs::kRows ? Lhs::kRows : Rhs::kRows;
|
|
static constexpr int kCols =
|
|
Lhs::kCols > Rhs::kCols ? Lhs::kCols : Rhs::kCols;
|
|
};
|
|
|
|
template <typename Lhs, typename Rhs>
|
|
struct BroadcastBinaryOpRegisterBlock {
|
|
using Shape = BroadcastBinaryOpShape<Lhs, Rhs>;
|
|
using ScalarType = typename Lhs::ScalarType;
|
|
using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
|
|
};
|
|
|
|
template <typename Lhs, typename Rhs>
|
|
struct BroadcastAddImpl {
|
|
using ResultBlockType =
|
|
typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
|
|
static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
|
|
ResultBlockType result;
|
|
static constexpr int Rows = ResultBlockType::kRows;
|
|
static constexpr int Cols = ResultBlockType::kCols;
|
|
static constexpr int LhsRows = Lhs::kRows;
|
|
static constexpr int LhsCols = Lhs::kCols;
|
|
static constexpr int RhsRows = Rhs::kRows;
|
|
static constexpr int RhsCols = Rhs::kCols;
|
|
|
|
static_assert(LhsRows == Rows || LhsRows == 1, "");
|
|
static_assert(RhsRows == Rows || RhsRows == 1, "");
|
|
static_assert(LhsCols == Cols || LhsCols == 1, "");
|
|
static_assert(RhsCols == Cols || RhsCols == 1, "");
|
|
static_assert(ResultBlockType::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
static_assert(Lhs::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
static_assert(Rhs::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
|
|
for (int c = 0; c < Cols; c++) {
|
|
const int lhs_c = LhsCols == Cols ? c : 0;
|
|
const int rhs_c = RhsCols == Cols ? c : 0;
|
|
for (int r = 0; r < Rows; r++) {
|
|
const int lhs_r = LhsRows == Rows ? r : 0;
|
|
const int rhs_r = RhsRows == Rows ? r : 0;
|
|
result.buf.reg[r + c * Rows] =
|
|
Add(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
|
|
rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
};
|
|
|
|
template <typename Lhs, typename Rhs>
|
|
typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd(
|
|
const Lhs& lhs, const Rhs& rhs) {
|
|
using Flip = FlipLhsRhs<Lhs, Rhs>;
|
|
return BroadcastAddImpl<
|
|
typename Flip::FlippedLhsType,
|
|
typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
|
|
Flip::FlippedRhs(lhs, rhs));
|
|
}
|
|
|
|
template <typename Lhs, typename Rhs>
|
|
struct BroadcastShiftLeftImpl {
|
|
using ResultBlockType =
|
|
typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
|
|
static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
|
|
ResultBlockType result;
|
|
static constexpr int Rows = ResultBlockType::kRows;
|
|
static constexpr int Cols = ResultBlockType::kCols;
|
|
static constexpr int LhsRows = Lhs::kRows;
|
|
static constexpr int LhsCols = Lhs::kCols;
|
|
static constexpr int RhsRows = Rhs::kRows;
|
|
static constexpr int RhsCols = Rhs::kCols;
|
|
|
|
static_assert(LhsRows == Rows || LhsRows == 1, "");
|
|
static_assert(RhsRows == Rows || RhsRows == 1, "");
|
|
static_assert(LhsCols == Cols || LhsCols == 1, "");
|
|
static_assert(RhsCols == Cols || RhsCols == 1, "");
|
|
static_assert(ResultBlockType::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
static_assert(Lhs::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
static_assert(Rhs::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
|
|
for (int c = 0; c < Cols; c++) {
|
|
const int lhs_c = LhsCols == Cols ? c : 0;
|
|
const int rhs_c = RhsCols == Cols ? c : 0;
|
|
for (int r = 0; r < Rows; r++) {
|
|
const int lhs_r = LhsRows == Rows ? r : 0;
|
|
const int rhs_r = RhsRows == Rows ? r : 0;
|
|
result.buf.reg[r + c * Rows] =
|
|
ShiftLeft(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
|
|
rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
};
|
|
|
|
template <typename Lhs, typename Rhs>
|
|
typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastShiftLeft(
|
|
const Lhs& lhs, const Rhs& rhs) {
|
|
using Flip = FlipLhsRhs<Lhs, Rhs>;
|
|
return BroadcastShiftLeftImpl<
|
|
typename Flip::FlippedLhsType,
|
|
typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
|
|
Flip::FlippedRhs(lhs, rhs));
|
|
}
|
|
|
|
template <typename Lhs, typename Rhs>
|
|
struct BroadcastSaturatingRoundingDoublingHighMulImpl {
|
|
using ResultBlockType =
|
|
typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
|
|
static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
|
|
ResultBlockType result;
|
|
static constexpr int Rows = ResultBlockType::kRows;
|
|
static constexpr int Cols = ResultBlockType::kCols;
|
|
static constexpr int LhsRows = Lhs::kRows;
|
|
static constexpr int LhsCols = Lhs::kCols;
|
|
static constexpr int RhsRows = Rhs::kRows;
|
|
static constexpr int RhsCols = Rhs::kCols;
|
|
|
|
static_assert(LhsRows == Rows || LhsRows == 1, "");
|
|
static_assert(RhsRows == Rows || RhsRows == 1, "");
|
|
static_assert(LhsCols == Cols || LhsCols == 1, "");
|
|
static_assert(RhsCols == Cols || RhsCols == 1, "");
|
|
static_assert(ResultBlockType::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
static_assert(Lhs::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
static_assert(Rhs::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
|
|
for (int c = 0; c < Cols; c++) {
|
|
const int lhs_c = LhsCols == Cols ? c : 0;
|
|
const int rhs_c = RhsCols == Cols ? c : 0;
|
|
for (int r = 0; r < Rows; r++) {
|
|
const int lhs_r = LhsRows == Rows ? r : 0;
|
|
const int rhs_r = RhsRows == Rows ? r : 0;
|
|
result.buf.reg[r + c * Rows] = SaturatingRoundingDoublingHighMul(
|
|
lhs.buf.reg[lhs_r + lhs_c * LhsRows],
|
|
rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
};
|
|
|
|
template <typename Lhs, typename Rhs>
|
|
typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
|
|
BroadcastSaturatingRoundingDoublingHighMul(const Lhs& lhs, const Rhs& rhs) {
|
|
using Flip = FlipLhsRhs<Lhs, Rhs>;
|
|
return BroadcastSaturatingRoundingDoublingHighMulImpl<
|
|
typename Flip::FlippedLhsType,
|
|
typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
|
|
Flip::FlippedRhs(lhs, rhs));
|
|
}
|
|
|
|
template <typename Lhs, typename Rhs>
|
|
struct BroadcastRoundingDivideByPOTImpl {
|
|
using ResultBlockType =
|
|
typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
|
|
static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
|
|
ResultBlockType result;
|
|
static constexpr int Rows = ResultBlockType::kRows;
|
|
static constexpr int Cols = ResultBlockType::kCols;
|
|
static constexpr int LhsRows = Lhs::kRows;
|
|
static constexpr int LhsCols = Lhs::kCols;
|
|
static constexpr int RhsRows = Rhs::kRows;
|
|
static constexpr int RhsCols = Rhs::kCols;
|
|
|
|
static_assert(LhsRows == Rows || LhsRows == 1, "");
|
|
static_assert(RhsRows == Rows || RhsRows == 1, "");
|
|
static_assert(LhsCols == Cols || LhsCols == 1, "");
|
|
static_assert(RhsCols == Cols || RhsCols == 1, "");
|
|
static_assert(ResultBlockType::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
static_assert(Lhs::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
static_assert(Rhs::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
|
|
for (int c = 0; c < Cols; c++) {
|
|
const int lhs_c = LhsCols == Cols ? c : 0;
|
|
const int rhs_c = RhsCols == Cols ? c : 0;
|
|
for (int r = 0; r < Rows; r++) {
|
|
const int lhs_r = LhsRows == Rows ? r : 0;
|
|
const int rhs_r = RhsRows == Rows ? r : 0;
|
|
result.buf.reg[r + c * Rows] =
|
|
RoundingDivideByPOT(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
|
|
rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
};
|
|
|
|
template <typename Lhs, typename Rhs>
|
|
typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
|
|
BroadcastRoundingDivideByPOT(const Lhs& lhs, const Rhs& rhs) {
|
|
using Flip = FlipLhsRhs<Lhs, Rhs>;
|
|
return BroadcastRoundingDivideByPOTImpl<
|
|
typename Flip::FlippedLhsType,
|
|
typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
|
|
Flip::FlippedRhs(lhs, rhs));
|
|
}
|
|
|
|
template <typename Lhs, typename Rhs>
|
|
struct BroadcastMulImpl {
|
|
using ResultBlockType =
|
|
typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
|
|
static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
|
|
ResultBlockType result;
|
|
static constexpr int Rows = ResultBlockType::kRows;
|
|
static constexpr int Cols = ResultBlockType::kCols;
|
|
static constexpr int LhsRows = Lhs::kRows;
|
|
static constexpr int LhsCols = Lhs::kCols;
|
|
static constexpr int RhsRows = Rhs::kRows;
|
|
static constexpr int RhsCols = Rhs::kCols;
|
|
static_assert(ResultBlockType::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
static_assert(Lhs::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
static_assert(Rhs::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
|
|
static_assert(LhsRows == Rows || LhsRows == 1, "");
|
|
static_assert(RhsRows == Rows || RhsRows == 1, "");
|
|
static_assert(LhsCols == Cols || LhsCols == 1, "");
|
|
static_assert(RhsCols == Cols || RhsCols == 1, "");
|
|
for (int c = 0; c < Cols; c++) {
|
|
const int lhs_c = LhsCols == Cols ? c : 0;
|
|
const int rhs_c = RhsCols == Cols ? c : 0;
|
|
for (int r = 0; r < Rows; r++) {
|
|
const int lhs_r = LhsRows == Rows ? r : 0;
|
|
const int rhs_r = RhsRows == Rows ? r : 0;
|
|
result.buf.reg[r + c * Rows] =
|
|
Mul(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
|
|
rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
};
|
|
|
|
template <typename Lhs, typename Rhs>
|
|
typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastMul(
|
|
const Lhs& lhs, const Rhs& rhs) {
|
|
using Flip = FlipLhsRhs<Lhs, Rhs>;
|
|
return BroadcastMulImpl<
|
|
typename Flip::FlippedLhsType,
|
|
typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
|
|
Flip::FlippedRhs(lhs, rhs));
|
|
}
|
|
|
|
template <typename Lhs, typename Rhs, typename Acc>
|
|
struct BroadcastMulAddImpl {
|
|
static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
|
|
static constexpr int Rows = Acc::kRows;
|
|
static constexpr int Cols = Acc::kCols;
|
|
static constexpr int LhsRows = Lhs::kRows;
|
|
static constexpr int LhsCols = Lhs::kCols;
|
|
static constexpr int RhsRows = Rhs::kRows;
|
|
static constexpr int RhsCols = Rhs::kCols;
|
|
static_assert(Acc::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
static_assert(Lhs::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
static_assert(Rhs::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
|
|
static_assert(LhsRows == Rows || LhsRows == 1, "");
|
|
static_assert(RhsRows == Rows || RhsRows == 1, "");
|
|
static_assert(LhsCols == Cols || LhsCols == 1, "");
|
|
static_assert(RhsCols == Cols || RhsCols == 1, "");
|
|
for (int c = 0; c < Cols; c++) {
|
|
const int lhs_c = LhsCols == Cols ? c : 0;
|
|
const int rhs_c = RhsCols == Cols ? c : 0;
|
|
for (int r = 0; r < Rows; r++) {
|
|
const int lhs_r = LhsRows == Rows ? r : 0;
|
|
const int rhs_r = RhsRows == Rows ? r : 0;
|
|
MulAdd(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
|
|
rhs.buf.reg[rhs_r + rhs_c * RhsRows],
|
|
&acc->buf.reg[r + c * Rows]);
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
template <typename Lhs, typename Rhs, typename Acc>
|
|
void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
|
|
using Flip = FlipLhsRhs<Lhs, Rhs>;
|
|
BroadcastMulAddImpl<typename Flip::FlippedLhsType,
|
|
typename Flip::FlippedRhsType,
|
|
Acc>::Run(Flip::FlippedLhs(lhs, rhs),
|
|
Flip::FlippedRhs(lhs, rhs), acc);
|
|
}
|
|
|
|
template <typename RegisterBlockType, typename SrcObjectType>
|
|
struct LoadImpl {
|
|
static_assert(std::is_same<SrcObjectType, void>::value,
|
|
"This generic impl should never be hit");
|
|
};
|
|
|
|
template <typename ScalarType, int Rows, int Cols, typename SrcScalarType>
|
|
struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
|
|
MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
|
|
using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
|
|
using SrcObjectType = MatrixMap<SrcScalarType, MapOrder::ColMajor>;
|
|
static RegisterBlockType Run(const SrcObjectType& src, int row, int col) {
|
|
RegisterBlockType result;
|
|
int i = 0;
|
|
for (int c = 0; c < Cols; c++) {
|
|
const ScalarType* src_ptr = src.data(row, col + c);
|
|
for (int r = 0; r < Rows; r++) {
|
|
result.buf.reg[i++] = *src_ptr++;
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
};
|
|
|
|
template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
|
|
VectorShape Shape>
|
|
struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
|
|
VectorMap<SrcScalarType, Shape>> {
|
|
using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
|
|
using SrcObjectType = VectorMap<SrcScalarType, Shape>;
|
|
static RegisterBlockType Run(const SrcObjectType& src, int pos) {
|
|
static_assert(Shape == VectorShape::Col || Rows == 1, "");
|
|
static_assert(Shape == VectorShape::Row || Cols == 1, "");
|
|
RegisterBlockType result;
|
|
for (int i = 0; i < Rows * Cols; i++) {
|
|
result.buf.reg[i] = src(pos + i);
|
|
}
|
|
return result;
|
|
}
|
|
};
|
|
|
|
template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
|
|
VectorShape Shape>
|
|
struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
|
|
VectorDup<SrcScalarType, Shape>> {
|
|
using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
|
|
using SrcObjectType = VectorDup<SrcScalarType, Shape>;
|
|
static RegisterBlockType Run(const SrcObjectType& src, int) {
|
|
static_assert(Shape == VectorShape::Col || Rows == 1, "");
|
|
static_assert(Shape == VectorShape::Row || Cols == 1, "");
|
|
RegisterBlockType result;
|
|
for (int i = 0; i < Rows * Cols; i++) {
|
|
result.buf.reg[i] = src(0);
|
|
}
|
|
return result;
|
|
}
|
|
};
|
|
|
|
template <typename RegisterBlockType, typename SrcObjectType>
|
|
RegisterBlockType Load(const SrcObjectType& src, int row, int col) {
|
|
return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, row, col);
|
|
}
|
|
|
|
template <typename RegisterBlockType, typename SrcObjectType>
|
|
RegisterBlockType Load(const SrcObjectType& src, int pos) {
|
|
return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, pos);
|
|
}
|
|
|
|
template <typename RegisterBlockType>
|
|
struct LoadContiguousImpl {
|
|
using ScalarType = typename RegisterBlockType::ScalarType;
|
|
static_assert(RegisterBlockType::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
static RegisterBlockType Run(const ScalarType* src) {
|
|
RegisterBlockType result;
|
|
for (int i = 0; i < RegisterBlockType::kScalarCount; i++) {
|
|
result.buf.reg[i] = src[i];
|
|
}
|
|
return result;
|
|
}
|
|
};
|
|
|
|
template <typename RegisterBlockType>
|
|
RegisterBlockType LoadContiguous(
|
|
const typename RegisterBlockType::ScalarType* src) {
|
|
return LoadContiguousImpl<RegisterBlockType>::Run(src);
|
|
}
|
|
|
|
template <int BroadcastRows, int BroadcastCols, typename SrcObjectType>
|
|
struct LoadForBroadcastingShape {};
|
|
|
|
template <int BroadcastRows, int BroadcastCols, typename ScalarType,
|
|
VectorShape Shape>
|
|
struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
|
|
VectorMap<ScalarType, Shape>> {
|
|
static constexpr int kRows = Shape == VectorShape::Col ? BroadcastRows : 1;
|
|
static constexpr int kCols = Shape == VectorShape::Row ? BroadcastCols : 1;
|
|
};
|
|
|
|
template <int BroadcastRows, int BroadcastCols, typename ScalarType,
|
|
VectorShape Shape>
|
|
struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
|
|
VectorDup<ScalarType, Shape>> {
|
|
static constexpr int kRows = 1;
|
|
static constexpr int kCols = 1;
|
|
};
|
|
|
|
template <typename RegisterBlockType, typename SrcObjectType>
|
|
struct LoadForBroadcastingRegisterBlock {
|
|
using Shape =
|
|
LoadForBroadcastingShape<RegisterBlockType::kRows,
|
|
RegisterBlockType::kCols, SrcObjectType>;
|
|
using ScalarType = typename RegisterBlockType::ScalarType;
|
|
using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
|
|
};
|
|
|
|
template <typename RegisterBlockType, typename SrcObjectType>
|
|
struct LoadForBroadcastingImpl {
|
|
static_assert(std::is_same<SrcObjectType, void>::value,
|
|
"This generic impl should never be hit");
|
|
};
|
|
|
|
template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
|
|
VectorShape Shape>
|
|
struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
|
|
VectorMap<SrcScalarType, Shape>> {
|
|
using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
|
|
using SrcObjectType = VectorMap<SrcScalarType, Shape>;
|
|
using ResultBlockType =
|
|
typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
|
|
SrcObjectType>::Type;
|
|
static_assert(ResultBlockType::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
static ResultBlockType Run(const SrcObjectType& src, int pos) {
|
|
ResultBlockType result;
|
|
for (int c = 0; c < ResultBlockType::kCols; c++) {
|
|
for (int r = 0; r < ResultBlockType::kRows; r++) {
|
|
const int i = Shape == VectorShape::Col ? r : c;
|
|
result.buf.reg[r + c * ResultBlockType::kRows] = src(pos + i);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
};
|
|
|
|
template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
|
|
VectorShape Shape>
|
|
struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
|
|
VectorDup<SrcScalarType, Shape>> {
|
|
using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
|
|
using SrcObjectType = VectorDup<SrcScalarType, Shape>;
|
|
using ResultBlockType =
|
|
typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
|
|
SrcObjectType>::Type;
|
|
static_assert(ResultBlockType::kRegisterLanes == 1,
|
|
"This path is only for scalar values");
|
|
static ResultBlockType Run(const SrcObjectType& src, int) {
|
|
ResultBlockType result;
|
|
for (int c = 0; c < ResultBlockType::kCols; c++) {
|
|
for (int r = 0; r < ResultBlockType::kRows; r++) {
|
|
result.buf.reg[r + c * ResultBlockType::kRows] = src(0);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
};
|
|
|
|
template <typename RegisterBlockType, typename SrcObjectType>
|
|
typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
|
|
SrcObjectType>::Type
|
|
LoadForBroadcasting(const SrcObjectType& src, int row, int col) {
|
|
return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(
|
|
src, row, col);
|
|
}
|
|
|
|
template <typename RegisterBlockType, typename SrcObjectType>
|
|
typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
|
|
SrcObjectType>::Type
|
|
LoadForBroadcasting(const SrcObjectType& src, int pos) {
|
|
return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src,
|
|
pos);
|
|
}
|
|
|
|
template <int ConstantValue, typename RegisterBlockType>
|
|
struct AddConstantImpl {
|
|
static void Run(RegisterBlockType* block) {
|
|
using RegisterType = typename RegisterBlockType::RegisterType;
|
|
const RegisterType dup = Dup<RegisterType>(ConstantValue);
|
|
for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) {
|
|
block->buf.reg[i] = Add(block->buf.reg[i], dup);
|
|
}
|
|
}
|
|
};
|
|
|
|
template <typename RegisterBlockType>
|
|
struct AddConstantImpl<0, RegisterBlockType> {
|
|
static void Run(RegisterBlockType*) {
|
|
// This is a no-op.
|
|
}
|
|
};
|
|
|
|
template <int ConstantValue, typename RegisterBlockType>
|
|
void AddConstant(RegisterBlockType* block) {
|
|
AddConstantImpl<ConstantValue, RegisterBlockType>::Run(block);
|
|
}
|
|
|
|
template <int N>
|
|
using RegBufferInt32 = RegisterBuffer<std::int32_t, N>;
|
|
template <int N>
|
|
using RegBufferInt16 = RegisterBuffer<std::int16_t, N>;
|
|
template <int N>
|
|
using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>;
|
|
template <int N>
|
|
using RegBufferInt8 = RegisterBuffer<std::int8_t, N>;
|
|
template <int R, int C>
|
|
using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>;
|
|
template <int R, int C>
|
|
using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>;
|
|
template <int R, int C>
|
|
using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>;
|
|
template <int R, int C>
|
|
using RegBlockInt8 = RegisterBlock<std::int8_t, R, C>;
|
|
|
|
} // end namespace gemmlowp
|
|
|
|
#if defined GEMMLOWP_NEON
|
|
#include "simd_wrappers_neon.h"
|
|
#elif defined GEMMLOWP_SSE4
|
|
#include "simd_wrappers_sse.h"
|
|
#elif defined GEMMLOWP_MSA
|
|
#include "simd_wrappers_msa.h"
|
|
#endif
|
|
|
|
#endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
|