/* Copyright 2020 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef RUY_RUY_REFERENCE_MUL_H_
#define RUY_RUY_REFERENCE_MUL_H_
#include <algorithm>
#include "ruy/apply_multiplier.h"
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
namespace ruy {
|
|
|
|
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
|
|
typename DstScalar>
|
|
void ReferenceMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
|
|
const MulParams<AccumScalar, DstScalar>& mul_params,
|
|
Matrix<DstScalar>* dst) {
|
|
for (int i = 0; i < lhs.layout().rows(); i++) {
|
|
for (int j = 0; j < rhs.layout().cols(); j++) {
|
|
AccumScalar accum = 0;
|
|
for (int k = 0; k < lhs.layout().cols(); k++) {
|
|
AccumScalar lhs_val = Element(lhs, i, k);
|
|
AccumScalar rhs_val = Element(rhs, k, j);
|
|
accum += (lhs_val - lhs.zero_point()) * (rhs_val - rhs.zero_point());
|
|
}
|
|
int channel =
|
|
mul_params.channel_dimension() == ChannelDimension::kRow ? i : j;
|
|
if (mul_params.bias()) {
|
|
accum += mul_params.bias()[channel];
|
|
}
|
|
ApplyMultiplier(mul_params, channel, &accum);
|
|
accum += dst->zero_point();
|
|
accum = std::min<AccumScalar>(accum, mul_params.clamp_max());
|
|
accum = std::max<AccumScalar>(accum, mul_params.clamp_min());
|
|
*ElementPtr(dst, i, j) = static_cast<DstScalar>(accum);
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace ruy
#endif // RUY_RUY_REFERENCE_MUL_H_