You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1226 lines
51 KiB
1226 lines
51 KiB
/*
|
|
* Copyright (C) 2019 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "RandomVariable.h"
|
|
|
|
#include <algorithm>
|
|
#include <memory>
|
|
#include <set>
|
|
#include <string>
|
|
#include <unordered_map>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "RandomGraphGeneratorUtils.h"
|
|
|
|
namespace android {
|
|
namespace nn {
|
|
namespace fuzzing_test {
|
|
|
|
unsigned int RandomVariableBase::globalIndex = 0;
|
|
int RandomVariable::defaultValue = 10;
|
|
|
|
// Creates a CONST node whose range and value are both the single given value.
// The node takes the next global index and is stamped with the network's current global time.
RandomVariableBase::RandomVariableBase(int value)
    : index(globalIndex++),
      type(RandomVariableType::CONST),
      range(value),
      value(value),
      timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}
|
|
|
|
// Creates a FREE node whose range is built from the bounds (lower, upper).
// The node takes the next global index and is stamped with the network's current global time.
RandomVariableBase::RandomVariableBase(int lower, int upper)
    : index(globalIndex++),
      type(RandomVariableType::FREE),
      range(lower, upper),
      timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}
|
|
|
|
// Creates a FREE node whose range is built from an explicit list of choices.
// The node takes the next global index and is stamped with the network's current global time.
RandomVariableBase::RandomVariableBase(const std::vector<int>& choices)
    : index(globalIndex++),
      type(RandomVariableType::FREE),
      range(choices),
      timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}
|
|
|
|
// Creates an OP node computed from one or two parents.
// The initial range comes from op->getInitRange(); when rhs is null (unary op) a constant
// range of 0 stands in for the missing second operand.
RandomVariableBase::RandomVariableBase(const RandomVariableNode& lhs, const RandomVariableNode& rhs,
                                       const std::shared_ptr<const IRandomVariableOp>& op)
    : index(globalIndex++),
      type(RandomVariableType::OP),
      range(op->getInitRange(lhs->range, rhs == nullptr ? RandomVariableRange(0) : rhs->range)),
      op(op),
      parent1(lhs),
      parent2(rhs),
      timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}
|
|
|
|
void RandomVariableRange::setRange(int lower, int upper) {
|
|
// kInvalidValue indicates unlimited bound.
|
|
auto head = lower == kInvalidValue ? mChoices.begin()
|
|
: std::lower_bound(mChoices.begin(), mChoices.end(), lower);
|
|
auto tail = upper == kInvalidValue ? mChoices.end()
|
|
: std::upper_bound(mChoices.begin(), mChoices.end(), upper);
|
|
NN_FUZZER_CHECK(head <= tail) << "Invalid range!";
|
|
if (head != mChoices.begin() || tail != mChoices.end()) {
|
|
mChoices = std::vector<int>(head, tail);
|
|
}
|
|
}
|
|
|
|
int RandomVariableRange::toConst() {
|
|
if (mChoices.size() > 1) mChoices = {getRandomChoice(mChoices)};
|
|
return mChoices[0];
|
|
}
|
|
|
|
// Returns the range holding the choices present in both lhs and rhs.
// NOTE(review): std::set_intersection requires both choice lists to be sorted -- confirm
// that mChoices is always kept sorted.
RandomVariableRange operator&(const RandomVariableRange& lhs, const RandomVariableRange& rhs) {
    // An intersection can never hold more elements than the smaller operand, so that is
    // all the scratch space needed.
    std::vector<int> common(std::min(lhs.mChoices.size(), rhs.mChoices.size()));
    const auto commonEnd =
            std::set_intersection(lhs.mChoices.begin(), lhs.mChoices.end(), rhs.mChoices.begin(),
                                  rhs.mChoices.end(), common.begin());
    // Drop the unused tail of the scratch buffer.
    common.erase(commonEnd, common.end());
    return RandomVariableRange(std::move(common));
}
|
|
|
|
// Turns this node into a CONST: the range is collapsed to one value via toConst() and
// that value is stored. Nodes that are already CONST are left untouched.
void RandomVariableBase::freeze() {
    if (type != RandomVariableType::CONST) {
        value = range.toConst();
        type = RandomVariableType::CONST;
    }
}
|
|
|
|
int RandomVariableBase::getValue() const {
|
|
switch (type) {
|
|
case RandomVariableType::CONST:
|
|
return value;
|
|
case RandomVariableType::OP:
|
|
return op->eval(parent1->getValue(), parent2 == nullptr ? 0 : parent2->getValue());
|
|
default:
|
|
NN_FUZZER_CHECK(false) << "Invalid type when getting value of var" << index;
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
// Stamps this node with the network's current global time and logs the new timestamp.
void RandomVariableBase::updateTimestamp() {
    const auto now = RandomVariableNetwork::get()->getGlobalTime();
    timestamp = now;
    NN_FUZZER_LOG << "Update timestamp of var" << index << " to " << timestamp;
}
|
|
|
|
// Wraps a new CONST node holding `value` and registers it with the global network.
RandomVariable::RandomVariable(int value) : mVar(new RandomVariableBase(value)) {
    NN_FUZZER_LOG << "New RandomVariable " << mVar;
    auto* network = RandomVariableNetwork::get();
    network->add(mVar);
}
|
|
// Wraps a new FREE node built from the bounds (lower, upper) and registers it with the
// global network.
RandomVariable::RandomVariable(int lower, int upper) : mVar(new RandomVariableBase(lower, upper)) {
    NN_FUZZER_LOG << "New RandomVariable " << mVar;
    auto* network = RandomVariableNetwork::get();
    network->add(mVar);
}
|
|
// Wraps a new FREE node restricted to the given choices and registers it with the
// global network.
RandomVariable::RandomVariable(const std::vector<int>& choices)
    : mVar(new RandomVariableBase(choices)) {
    NN_FUZZER_LOG << "New RandomVariable " << mVar;
    auto* network = RandomVariableNetwork::get();
    network->add(mVar);
}
|
|
// Wraps a new FREE node built from the bounds (1, defaultValue) and registers it with the
// global network. Only RandomVariableType::FREE is accepted as the requested type.
RandomVariable::RandomVariable(RandomVariableType type)
    : mVar(new RandomVariableBase(1, defaultValue)) {
    NN_FUZZER_CHECK(type == RandomVariableType::FREE);
    NN_FUZZER_LOG << "New RandomVariable " << mVar;
    auto* network = RandomVariableNetwork::get();
    network->add(mVar);
}
|
|
// Wraps a new OP node computing op(lhs, rhs), links it as a child of its parents and
// registers it with the global network. rhs may wrap a null node for unary ops.
RandomVariable::RandomVariable(const RandomVariable& lhs, const RandomVariable& rhs,
                               const std::shared_ptr<const IRandomVariableOp>& op)
    : mVar(new RandomVariableBase(lhs.get(), rhs.get(), op)) {
    auto& first = mVar->parent1;
    auto& second = mVar->parent2;

    // A CONST parent is swapped for a fresh CONST copy of the same value. This resolves
    // the fake dependency problem that sharing one CONST node would create.
    if (first->type == RandomVariableType::CONST) {
        first = RandomVariable(first->value).get();
    }
    if (second != nullptr && second->type == RandomVariableType::CONST) {
        second = RandomVariable(second->value).get();
    }

    // Register the new node as a child of each (possibly replaced) parent.
    first->children.push_back(mVar);
    if (second != nullptr) {
        second->children.push_back(mVar);
    }

    RandomVariableNetwork::get()->add(mVar);
    NN_FUZZER_LOG << "New RandomVariable " << mVar;
}
|
|
|
|
// Narrows the underlying node's range to [lower, upper] and refreshes the node's
// timestamp when the number of remaining choices actually changed.
void RandomVariable::setRange(int lower, int upper) {
    NN_FUZZER_CHECK(mVar != nullptr) << "setRange() on nullptr";
    NN_FUZZER_LOG << "Set range [" << lower << ", " << upper << "] on var" << mVar->index;

    const size_t sizeBefore = mVar->range.size();
    mVar->range.setRange(lower, upper);

    // A changed size is the signal that the range was really narrowed.
    const bool narrowed = mVar->range.size() != sizeBefore;
    if (narrowed) {
        mVar->updateTimestamp();
    }
}
|
|
|
|
// Default initial range of an op: evaluate every (lhs, rhs) pair of choices and keep
// each result that lies within [-kMaxValue, kMaxValue].
RandomVariableRange IRandomVariableOp::getInitRange(const RandomVariableRange& lhs,
                                                    const RandomVariableRange& rhs) const {
    std::set<int> results;
    for (const auto left : lhs.getChoices()) {
        for (const auto right : rhs.getChoices()) {
            const int value = this->eval(left, right);
            const bool withinLimits = value >= -kMaxValue && value <= kMaxValue;
            if (withinLimits) {
                results.insert(value);
            }
        }
    }
    return RandomVariableRange(results);
}
|
|
|
|
// Check if the range contains exactly all values in [min, max].
// An empty set has no [min, max] and is reported as not continuous. The span is computed
// in 64-bit arithmetic so that extreme int values cannot overflow.
static inline bool isContinuous(const std::set<int>* range) {
    if (range->empty()) return false;
    const long long lowest = *range->begin();
    const long long highest = *range->rbegin();
    // A set holds distinct values, so it is gap-free exactly when its element count
    // equals the width of the interval it spans.
    return (highest - lowest + 1) == static_cast<long long>(range->size());
}
|
|
|
|
// Insert every integer of the closed interval [lower, upper] into the set.
// Nothing is inserted when lower > upper; existing elements are kept.
static inline void fillRange(std::set<int>* range, int lower, int upper) {
    int next = lower;
    while (next <= upper) {
        range->insert(next);
        next++;
    }
}
|
|
|
|
// The slowest algorithm: iterate through every combinations of parents and save the valid pairs.
|
|
void IRandomVariableOp::eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
|
|
const std::set<int>* childIn, std::set<int>* parent1Out,
|
|
std::set<int>* parent2Out, std::set<int>* childOut) const {
|
|
// Avoid the binary search if the child is a closed range.
|
|
bool isChildInContinuous = isContinuous(childIn);
|
|
std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
|
|
for (auto i : *parent1In) {
|
|
bool valid = false;
|
|
for (auto j : *parent2In) {
|
|
int res = this->eval(i, j);
|
|
// Avoid the binary search if obviously out of range.
|
|
if (res > child.second || res < child.first) continue;
|
|
if (isChildInContinuous || childIn->find(res) != childIn->end()) {
|
|
parent2Out->insert(j);
|
|
childOut->insert(res);
|
|
valid = true;
|
|
}
|
|
}
|
|
if (valid) parent1Out->insert(i);
|
|
}
|
|
}
|
|
|
|
// Helper template exposing one shared, lazily created instance of T.
// get() always returns a reference to the same shared_ptr-to-const.
template <class T>
class Singleton : public T {
   public:
    static const std::shared_ptr<const T>& get() {
        // Function-local static: constructed on the first call only.
        static std::shared_ptr<const T> sharedInstance{new T};
        return sharedInstance;
    }
};
|
|
|
|
// A set of operations that only compute on a single input value.
|
|
class IUnaryOp : public IRandomVariableOp {
|
|
public:
|
|
using IRandomVariableOp::eval;
|
|
virtual int eval(int val) const = 0;
|
|
virtual int eval(int lhs, int) const override { return eval(lhs); }
|
|
// The slowest algorithm: iterate through every value of the parent and save the valid one.
|
|
virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
|
|
const std::set<int>* childIn, std::set<int>* parent1Out,
|
|
std::set<int>* parent2Out, std::set<int>* childOut) const override {
|
|
NN_FUZZER_CHECK(parent2In == nullptr);
|
|
NN_FUZZER_CHECK(parent2Out == nullptr);
|
|
bool isChildInContinuous = isContinuous(childIn);
|
|
std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
|
|
for (auto i : *parent1In) {
|
|
int res = this->eval(i);
|
|
if (res > child.second || res < child.first) continue;
|
|
if (isChildInContinuous || childIn->find(res) != childIn->end()) {
|
|
parent1Out->insert(i);
|
|
childOut->insert(res);
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
// A set of operations that only check conditional constraints.
|
|
class IConstraintOp : public IRandomVariableOp {
|
|
public:
|
|
using IRandomVariableOp::eval;
|
|
virtual bool check(int lhs, int rhs) const = 0;
|
|
virtual int eval(int lhs, int rhs) const override {
|
|
return check(lhs, rhs) ? 0 : kInvalidValue;
|
|
}
|
|
// The range for a constraint op is always {0}.
|
|
virtual RandomVariableRange getInitRange(const RandomVariableRange&,
|
|
const RandomVariableRange&) const override {
|
|
return RandomVariableRange(0);
|
|
}
|
|
// The slowest algorithm:
|
|
// iterate through every combinations of parents and save the valid pairs.
|
|
virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
|
|
const std::set<int>*, std::set<int>* parent1Out, std::set<int>* parent2Out,
|
|
std::set<int>* childOut) const override {
|
|
for (auto i : *parent1In) {
|
|
bool valid = false;
|
|
for (auto j : *parent2In) {
|
|
if (this->check(i, j)) {
|
|
parent2Out->insert(j);
|
|
valid = true;
|
|
}
|
|
}
|
|
if (valid) parent1Out->insert(i);
|
|
}
|
|
if (!parent1Out->empty()) childOut->insert(0);
|
|
}
|
|
};
|
|
|
|
class Addition : public IRandomVariableOp {
|
|
public:
|
|
virtual int eval(int lhs, int rhs) const override { return lhs + rhs; }
|
|
virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
|
|
const RandomVariableRange& rhs) const override {
|
|
return RandomVariableRange(lhs.min() + rhs.min(), lhs.max() + rhs.max());
|
|
}
|
|
virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
|
|
const std::set<int>* childIn, std::set<int>* parent1Out,
|
|
std::set<int>* parent2Out, std::set<int>* childOut) const override {
|
|
if (!isContinuous(parent1In) || !isContinuous(parent2In) || !isContinuous(childIn)) {
|
|
IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
|
|
childOut);
|
|
} else {
|
|
// For parents and child with close range, the out range can be computed directly
|
|
// without iterations.
|
|
std::pair<int, int> parent1 = {*parent1In->begin(), *parent1In->rbegin()};
|
|
std::pair<int, int> parent2 = {*parent2In->begin(), *parent2In->rbegin()};
|
|
std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
|
|
|
|
// From ranges for parent, evaluate range for child.
|
|
// [a, b] + [c, d] -> [a + c, b + d]
|
|
fillRange(childOut, std::max(child.first, parent1.first + parent2.first),
|
|
std::min(child.second, parent1.second + parent2.second));
|
|
|
|
// From ranges for child and one parent, evaluate range for another parent.
|
|
// [a, b] - [c, d] -> [a - d, b - c]
|
|
fillRange(parent1Out, std::max(parent1.first, child.first - parent2.second),
|
|
std::min(parent1.second, child.second - parent2.first));
|
|
fillRange(parent2Out, std::max(parent2.first, child.first - parent1.second),
|
|
std::min(parent2.second, child.second - parent1.first));
|
|
}
|
|
}
|
|
virtual const char* getName() const override { return "ADD"; }
|
|
};
|
|
|
|
class Subtraction : public IRandomVariableOp {
|
|
public:
|
|
virtual int eval(int lhs, int rhs) const override { return lhs - rhs; }
|
|
virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
|
|
const RandomVariableRange& rhs) const override {
|
|
return RandomVariableRange(lhs.min() - rhs.max(), lhs.max() - rhs.min());
|
|
}
|
|
virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
|
|
const std::set<int>* childIn, std::set<int>* parent1Out,
|
|
std::set<int>* parent2Out, std::set<int>* childOut) const override {
|
|
if (!isContinuous(parent1In) || !isContinuous(parent2In) || !isContinuous(childIn)) {
|
|
IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
|
|
childOut);
|
|
} else {
|
|
// Similar algorithm as Addition.
|
|
std::pair<int, int> parent1 = {*parent1In->begin(), *parent1In->rbegin()};
|
|
std::pair<int, int> parent2 = {*parent2In->begin(), *parent2In->rbegin()};
|
|
std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
|
|
fillRange(childOut, std::max(child.first, parent1.first - parent2.second),
|
|
std::min(child.second, parent1.second - parent2.first));
|
|
fillRange(parent1Out, std::max(parent1.first, child.first + parent2.first),
|
|
std::min(parent1.second, child.second + parent2.second));
|
|
fillRange(parent2Out, std::max(parent2.first, parent1.first - child.second),
|
|
std::min(parent2.second, parent1.second - child.first));
|
|
}
|
|
}
|
|
virtual const char* getName() const override { return "SUB"; }
|
|
};
|
|
|
|
class Multiplication : public IRandomVariableOp {
|
|
public:
|
|
virtual int eval(int lhs, int rhs) const override { return lhs * rhs; }
|
|
virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
|
|
const RandomVariableRange& rhs) const override {
|
|
if (lhs.min() < 0 || rhs.min() < 0) {
|
|
return IRandomVariableOp::getInitRange(lhs, rhs);
|
|
} else {
|
|
int lower = std::min(lhs.min() * rhs.min(), kMaxValue);
|
|
int upper = std::min(lhs.max() * rhs.max(), kMaxValue);
|
|
return RandomVariableRange(lower, upper);
|
|
}
|
|
}
|
|
virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
|
|
const std::set<int>* childIn, std::set<int>* parent1Out,
|
|
std::set<int>* parent2Out, std::set<int>* childOut) const override {
|
|
if (*parent1In->begin() < 0 || *parent2In->begin() < 0 || *childIn->begin() < 0) {
|
|
IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
|
|
childOut);
|
|
} else {
|
|
bool isChildInContinuous = isContinuous(childIn);
|
|
std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
|
|
for (auto i : *parent1In) {
|
|
bool valid = false;
|
|
for (auto j : *parent2In) {
|
|
int res = this->eval(i, j);
|
|
// Since MUL increases monotonically with one value, break the loop if the
|
|
// result is larger than the limit.
|
|
if (res > child.second) break;
|
|
if (res < child.first) continue;
|
|
if (isChildInContinuous || childIn->find(res) != childIn->end()) {
|
|
valid = true;
|
|
parent2Out->insert(j);
|
|
childOut->insert(res);
|
|
}
|
|
}
|
|
if (valid) parent1Out->insert(i);
|
|
}
|
|
}
|
|
}
|
|
virtual const char* getName() const override { return "MUL"; }
|
|
};
|
|
|
|
class Division : public IRandomVariableOp {
|
|
public:
|
|
virtual int eval(int lhs, int rhs) const override {
|
|
return rhs == 0 ? kInvalidValue : lhs / rhs;
|
|
}
|
|
virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
|
|
const RandomVariableRange& rhs) const override {
|
|
if (lhs.min() < 0 || rhs.min() <= 0) {
|
|
return IRandomVariableOp::getInitRange(lhs, rhs);
|
|
} else {
|
|
return RandomVariableRange(lhs.min() / rhs.max(), lhs.max() / rhs.min());
|
|
}
|
|
}
|
|
virtual const char* getName() const override { return "DIV"; }
|
|
};
|
|
|
|
class ExactDivision : public Division {
|
|
public:
|
|
virtual int eval(int lhs, int rhs) const override {
|
|
return (rhs == 0 || lhs % rhs != 0) ? kInvalidValue : lhs / rhs;
|
|
}
|
|
virtual const char* getName() const override { return "EXACT_DIV"; }
|
|
};
|
|
|
|
class Modulo : public IRandomVariableOp {
|
|
public:
|
|
virtual int eval(int lhs, int rhs) const override {
|
|
return rhs == 0 ? kInvalidValue : lhs % rhs;
|
|
}
|
|
virtual RandomVariableRange getInitRange(const RandomVariableRange&,
|
|
const RandomVariableRange& rhs) const override {
|
|
return RandomVariableRange(0, rhs.max());
|
|
}
|
|
virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
|
|
const std::set<int>* childIn, std::set<int>* parent1Out,
|
|
std::set<int>* parent2Out, std::set<int>* childOut) const override {
|
|
if (*childIn->begin() != 0 || childIn->size() != 1u) {
|
|
IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
|
|
childOut);
|
|
} else {
|
|
// For the special case that child is a const 0, it would be faster if the range for
|
|
// parents are evaluated separately.
|
|
|
|
// Evaluate parent1 directly.
|
|
for (auto i : *parent1In) {
|
|
for (auto j : *parent2In) {
|
|
if (i % j == 0) {
|
|
parent1Out->insert(i);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
// Evaluate parent2, see if a multiple of parent2 value can be found in parent1.
|
|
int parent1Max = *parent1In->rbegin();
|
|
for (auto i : *parent2In) {
|
|
int jMax = parent1Max / i;
|
|
for (int j = 1; j <= jMax; j++) {
|
|
if (parent1In->find(i * j) != parent1In->end()) {
|
|
parent2Out->insert(i);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
if (!parent1Out->empty()) childOut->insert(0);
|
|
}
|
|
}
|
|
virtual const char* getName() const override { return "MOD"; }
|
|
};
|
|
|
|
class Maximum : public IRandomVariableOp {
|
|
public:
|
|
virtual int eval(int lhs, int rhs) const override { return std::max(lhs, rhs); }
|
|
virtual const char* getName() const override { return "MAX"; }
|
|
};
|
|
|
|
class Minimum : public IRandomVariableOp {
|
|
public:
|
|
virtual int eval(int lhs, int rhs) const override { return std::min(lhs, rhs); }
|
|
virtual const char* getName() const override { return "MIN"; }
|
|
};
|
|
|
|
class Square : public IUnaryOp {
|
|
public:
|
|
virtual int eval(int val) const override { return val * val; }
|
|
virtual const char* getName() const override { return "SQUARE"; }
|
|
};
|
|
|
|
class UnaryEqual : public IUnaryOp {
|
|
public:
|
|
virtual int eval(int val) const override { return val; }
|
|
virtual const char* getName() const override { return "UNARY_EQUAL"; }
|
|
};
|
|
|
|
class Equal : public IConstraintOp {
|
|
public:
|
|
virtual bool check(int lhs, int rhs) const override { return lhs == rhs; }
|
|
virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
|
|
const std::set<int>* childIn, std::set<int>* parent1Out,
|
|
std::set<int>* parent2Out, std::set<int>* childOut) const override {
|
|
NN_FUZZER_CHECK(childIn->size() == 1u && *childIn->begin() == 0);
|
|
// The intersection of two sets can be found in O(n).
|
|
std::set_intersection(parent1In->begin(), parent1In->end(), parent2In->begin(),
|
|
parent2In->end(), std::inserter(*parent1Out, parent1Out->begin()));
|
|
*parent2Out = *parent1Out;
|
|
childOut->insert(0);
|
|
}
|
|
virtual const char* getName() const override { return "EQUAL"; }
|
|
};
|
|
|
|
class GreaterThan : public IConstraintOp {
|
|
public:
|
|
virtual bool check(int lhs, int rhs) const override { return lhs > rhs; }
|
|
virtual const char* getName() const override { return "GREATER_THAN"; }
|
|
};
|
|
|
|
class GreaterEqual : public IConstraintOp {
|
|
public:
|
|
virtual bool check(int lhs, int rhs) const override { return lhs >= rhs; }
|
|
virtual const char* getName() const override { return "GREATER_EQUAL"; }
|
|
};
|
|
|
|
class FloatMultiplication : public IUnaryOp {
|
|
public:
|
|
FloatMultiplication(float multiplicand) : mMultiplicand(multiplicand) {}
|
|
virtual int eval(int val) const override {
|
|
return static_cast<int>(std::floor(static_cast<float>(val) * mMultiplicand));
|
|
}
|
|
virtual const char* getName() const override { return "MUL_FLOAT"; }
|
|
|
|
private:
|
|
float mMultiplicand;
|
|
};
|
|
|
|
// Arithmetic operators and methods on RandomVariables will create OP RandomVariableNodes.
// Since there must be at most one edge between two RandomVariableNodes, we have to do something
// special when both sides are referring to the same node.
|
|
|
|
// x + y creates an ADD node. x + x is expressed as x * 2 so that the two nodes stay
// connected by a single edge.
RandomVariable operator+(const RandomVariable& lhs, const RandomVariable& rhs) {
    const bool sameNode = lhs.get() == rhs.get();
    if (sameNode) {
        return RandomVariable(lhs, 2, Singleton<Multiplication>::get());
    }
    return RandomVariable(lhs, rhs, Singleton<Addition>::get());
}
|
|
// x - y creates a SUB node. x - x is the constant 0.
RandomVariable operator-(const RandomVariable& lhs, const RandomVariable& rhs) {
    const bool sameNode = lhs.get() == rhs.get();
    if (sameNode) {
        return RandomVariable(0);
    }
    return RandomVariable(lhs, rhs, Singleton<Subtraction>::get());
}
|
|
// x * y creates a MUL node. x * x is expressed as the unary SQUARE op on x so that the
// two nodes stay connected by a single edge.
RandomVariable operator*(const RandomVariable& lhs, const RandomVariable& rhs) {
    const bool sameNode = lhs.get() == rhs.get();
    if (sameNode) {
        return RandomVariable(lhs, RandomVariable(), Singleton<Square>::get());
    }
    return RandomVariable(lhs, rhs, Singleton<Multiplication>::get());
}
|
|
// x * f creates a unary MUL_FLOAT node with its own op instance carrying the factor f.
RandomVariable operator*(const RandomVariable& lhs, const float& rhs) {
    auto scaleOp = std::make_shared<FloatMultiplication>(rhs);
    return RandomVariable(lhs, RandomVariable(), scaleOp);
}
|
|
// x / y creates a DIV node. x / x is the constant 1.
RandomVariable operator/(const RandomVariable& lhs, const RandomVariable& rhs) {
    const bool sameNode = lhs.get() == rhs.get();
    if (sameNode) {
        return RandomVariable(1);
    }
    return RandomVariable(lhs, rhs, Singleton<Division>::get());
}
|
|
// x % y creates a MOD node. x % x is the constant 0.
RandomVariable operator%(const RandomVariable& lhs, const RandomVariable& rhs) {
    const bool sameNode = lhs.get() == rhs.get();
    if (sameNode) {
        return RandomVariable(0);
    }
    return RandomVariable(lhs, rhs, Singleton<Modulo>::get());
}
|
|
// max(x, y) creates a MAX node. max(x, x) is x itself.
RandomVariable max(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) {
        return lhs;
    }
    return RandomVariable(lhs, rhs, Singleton<Maximum>::get());
}
|
|
// min(x, y) creates a MIN node. min(x, x) is x itself.
RandomVariable min(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) {
        return lhs;
    }
    return RandomVariable(lhs, rhs, Singleton<Minimum>::get());
}
|
|
|
|
// this / other with the extra requirement that the division is exact (EXACT_DIV node).
// Dividing a variable by itself is the constant 1.
RandomVariable RandomVariable::exactDiv(const RandomVariable& other) {
    if (mVar == other.get()) {
        return RandomVariable(1);
    }
    return RandomVariable(*this, other, Singleton<ExactDivision>::get());
}
|
|
|
|
// Constrains this variable and `other` to hold the same value.
// Returns an empty RandomVariable when the pair is already tied together, the variable
// that was turned into a UNARY_EQUAL child when that shortcut applies, or a new EQUAL
// constraint node otherwise.
RandomVariable RandomVariable::setEqual(const RandomVariable& other) const {
    RandomVariableNode node1 = mVar, node2 = other.get();
    NN_FUZZER_LOG << "Set equality of var" << node1->index << " and var" << node2->index;

    // Do not setEqual on the same pair twice.
    const bool node1FollowsNode2 =
            node1->op == Singleton<UnaryEqual>::get() && node1->parent1 == node2;
    if (node1 == node2 || node1FollowsNode2 ||
        (node2->op == Singleton<UnaryEqual>::get() && node2->parent1 == node1)) {
        NN_FUZZER_LOG << "Already equal. Return.";
        return RandomVariable();
    }

    // Rewires `child` into a UNARY_EQUAL OP node under `parent` and merges their subnets.
    const auto makeUnaryEqualChild = [](const RandomVariableNode& parent,
                                        const RandomVariableNode& child) {
        child->type = RandomVariableType::OP;
        child->parent1 = parent;
        child->op = Singleton<UnaryEqual>::get();
        parent->children.push_back(child);
        RandomVariableNetwork::get()->join(parent, child);
    };

    // If possible, always try UnaryEqual first to reduce the search space.
    // UnaryEqual can be used if node B is FREE and is evaluated later than node A.
    if (RandomVariableNetwork::get()->isSubordinate(node1, node2)) {
        NN_FUZZER_LOG << " Make var" << node2->index << " a child of var" << node1->index;
        makeUnaryEqualChild(node1, node2);
        node1->updateTimestamp();
        return other;
    }
    if (RandomVariableNetwork::get()->isSubordinate(node2, node1)) {
        NN_FUZZER_LOG << " Make var" << node1->index << " a child of var" << node2->index;
        makeUnaryEqualChild(node2, node1);
        // NOTE(review): this stamps node1 (the new child) while the branch above stamps
        // the parent -- confirm that the asymmetry is intended.
        node1->updateTimestamp();
        return *this;
    }

    // Neither side can become a child of the other: fall back to an EQUAL constraint.
    return RandomVariable(*this, other, Singleton<Equal>::get());
}
|
|
|
|
// Adds the constraint this > other and returns the GREATER_THAN constraint node.
// The two variables must not refer to the same node.
RandomVariable RandomVariable::setGreaterThan(const RandomVariable& other) const {
    NN_FUZZER_CHECK(mVar != other.get());
    const auto& constraintOp = Singleton<GreaterThan>::get();
    return RandomVariable(*this, other, constraintOp);
}
|
|
// Adds the constraint this >= other and returns the GREATER_EQUAL constraint node.
// When both sides are the same node no node is created and *this is returned.
RandomVariable RandomVariable::setGreaterEqual(const RandomVariable& other) const {
    if (mVar == other.get()) {
        return *this;
    }
    return RandomVariable(*this, other, Singleton<GreaterEqual>::get());
}
|
|
|
|
void DisjointNetwork::add(const RandomVariableNode& var) {
|
|
// Find the subnet index of the parents and decide the index for var.
|
|
int ind1 = var->parent1 == nullptr ? -1 : mIndexMap[var->parent1];
|
|
int ind2 = var->parent2 == nullptr ? -1 : mIndexMap[var->parent2];
|
|
int ind = join(ind1, ind2);
|
|
// If no parent, put it into a new subnet component.
|
|
if (ind == -1) ind = mNextIndex++;
|
|
NN_FUZZER_LOG << "Add RandomVariable var" << var->index << " to network #" << ind;
|
|
mIndexMap[var] = ind;
|
|
mEvalOrderMap[ind].push_back(var);
|
|
}
|
|
|
|
int DisjointNetwork::join(int ind1, int ind2) {
|
|
if (ind1 == -1) return ind2;
|
|
if (ind2 == -1) return ind1;
|
|
if (ind1 == ind2) return ind1;
|
|
NN_FUZZER_LOG << "Join network #" << ind1 << " and #" << ind2;
|
|
auto &order1 = mEvalOrderMap[ind1], &order2 = mEvalOrderMap[ind2];
|
|
// Append every node in ind2 to the end of ind1
|
|
for (const auto& var : order2) {
|
|
order1.push_back(var);
|
|
mIndexMap[var] = ind1;
|
|
}
|
|
// Remove ind2 from mEvalOrderMap.
|
|
mEvalOrderMap.erase(mEvalOrderMap.find(ind2));
|
|
return ind1;
|
|
}
|
|
|
|
// Returns the process-wide network instance, created on first use.
RandomVariableNetwork* RandomVariableNetwork::get() {
    static RandomVariableNetwork network;
    return &network;
}
|
|
|
|
// Resets the network to an empty state and installs a new default upper bound for
// default FREE RandomVariables.
void RandomVariableNetwork::initialize(int defaultValue) {
    // Global counters and defaults shared by all variables.
    RandomVariableBase::globalIndex = 0;
    RandomVariable::defaultValue = defaultValue;

    // Bookkeeping containers.
    mIndexMap.clear();
    mEvalOrderMap.clear();
    mDimProd.clear();

    // Subnet numbering and clocks.
    mNextIndex = 0;
    mGlobalTime = 0;
    mTimestamp = -1;
}
|
|
|
|
bool RandomVariableNetwork::isSubordinate(const RandomVariableNode& node1,
|
|
const RandomVariableNode& node2) {
|
|
if (node2->type != RandomVariableType::FREE) return false;
|
|
int ind1 = mIndexMap[node1];
|
|
// node2 is of a different subnet.
|
|
if (ind1 != mIndexMap[node2]) return true;
|
|
for (const auto& node : mEvalOrderMap[ind1]) {
|
|
if (node == node2) return false;
|
|
// node2 is of the same subnet but evaluated later than node1.
|
|
if (node == node1) return true;
|
|
}
|
|
NN_FUZZER_CHECK(false) << "Code executed in non-reachable region.";
|
|
return false;
|
|
}
|
|
|
|
struct EvalInfo {
|
|
// The RandomVariableNode that this EvalInfo is associated with.
|
|
// var->value is the current value during evaluation.
|
|
RandomVariableNode var;
|
|
|
|
// The RandomVariable value is staged when a valid combination is found.
|
|
std::set<int> staging;
|
|
|
|
// The staging values are committed after a subnet evaluation.
|
|
std::set<int> committed;
|
|
|
|
// Keeps track of the latest timestamp that committed is updated.
|
|
int timestamp;
|
|
|
|
// For evalSubnetWithLocalNetwork.
|
|
RandomVariableType originalType;
|
|
|
|
// Should only invoke eval on OP RandomVariable.
|
|
bool eval() {
|
|
NN_FUZZER_CHECK(var->type == RandomVariableType::OP);
|
|
var->value = var->op->eval(var->parent1->value,
|
|
var->parent2 == nullptr ? 0 : var->parent2->value);
|
|
if (var->value == kInvalidValue) return false;
|
|
return committed.find(var->value) != committed.end();
|
|
}
|
|
void stage() { staging.insert(var->value); }
|
|
void commit() {
|
|
// Only update committed and timestamp if the range is *indeed* changed.
|
|
if (staging.size() != committed.size()) {
|
|
committed = std::move(staging);
|
|
timestamp = RandomVariableNetwork::get()->getGlobalTime();
|
|
}
|
|
staging.clear();
|
|
}
|
|
void updateRange() {
|
|
// Only update range and timestamp if the range is *indeed* changed.
|
|
if (committed.size() != var->range.size()) {
|
|
var->range = RandomVariableRange(committed);
|
|
var->timestamp = timestamp;
|
|
}
|
|
committed.clear();
|
|
}
|
|
|
|
EvalInfo(const RandomVariableNode& var)
|
|
: var(var),
|
|
committed(var->range.getChoices().begin(), var->range.getChoices().end()),
|
|
timestamp(var->timestamp) {}
|
|
};
|
|
using EvalContext = std::unordered_map<RandomVariableNode, EvalInfo>;
|
|
|
|
// For logging only.
|
|
inline std::string toString(const RandomVariableNode& var, EvalContext* context) {
|
|
std::stringstream ss;
|
|
ss << "var" << var->index << " = ";
|
|
const auto& committed = context->at(var).committed;
|
|
switch (var->type) {
|
|
case RandomVariableType::FREE:
|
|
ss << "FREE ["
|
|
<< joinStr(", ", 20, std::vector<int>(committed.begin(), committed.end())) << "]";
|
|
break;
|
|
case RandomVariableType::CONST:
|
|
ss << "CONST " << var->value;
|
|
break;
|
|
case RandomVariableType::OP:
|
|
ss << "var" << var->parent1->index << " " << var->op->getName();
|
|
if (var->parent2 != nullptr) ss << " var" << var->parent2->index;
|
|
ss << ", [" << joinStr(", ", 20, std::vector<int>(committed.begin(), committed.end()))
|
|
<< "]";
|
|
break;
|
|
default:
|
|
NN_FUZZER_CHECK(false);
|
|
}
|
|
ss << ", timestamp = " << context->at(var).timestamp;
|
|
return ss.str();
|
|
}
|
|
|
|
// Check if the subnet needs to be re-evaluated by comparing the timestamps.
|
|
static inline bool needEvaluate(const EvaluationOrder& evalOrder, int subnetTime,
|
|
EvalContext* context = nullptr) {
|
|
for (const auto& var : evalOrder) {
|
|
int timestamp = context == nullptr ? var->timestamp : context->at(var).timestamp;
|
|
// If we find a node that has been modified since last evaluation, the subnet needs to be
|
|
// re-evaluated.
|
|
if (timestamp > subnetTime) return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Helper function to evaluate the subnet recursively.
|
|
// Iterate through all combinations of FREE RandomVariables choices.
|
|
static void evalSubnetHelper(const EvaluationOrder& evalOrder, EvalContext* context, size_t i = 0) {
|
|
if (i == evalOrder.size()) {
|
|
// Reach the end of the evaluation, find a valid combination.
|
|
for (auto& var : evalOrder) context->at(var).stage();
|
|
return;
|
|
}
|
|
const auto& var = evalOrder[i];
|
|
if (var->type == RandomVariableType::FREE) {
|
|
// For FREE RandomVariable, iterate through all valid choices.
|
|
for (int val : context->at(var).committed) {
|
|
var->value = val;
|
|
evalSubnetHelper(evalOrder, context, i + 1);
|
|
}
|
|
return;
|
|
} else if (var->type == RandomVariableType::OP) {
|
|
// For OP RandomVariable, evaluate from parents and terminate if the result is invalid.
|
|
if (!context->at(var).eval()) return;
|
|
}
|
|
evalSubnetHelper(evalOrder, context, i + 1);
|
|
}
|
|
|
|
// Check if the subnet has only one single OP RandomVariable.
|
|
static inline bool isSingleOpSubnet(const EvaluationOrder& evalOrder) {
|
|
int numOp = 0;
|
|
for (const auto& var : evalOrder) {
|
|
if (var->type == RandomVariableType::OP) numOp++;
|
|
if (numOp > 1) return false;
|
|
}
|
|
return numOp != 0;
|
|
}
|
|
|
|
// Evaluate with a potentially faster approach provided by IRandomVariableOp.
|
|
static inline void evalSubnetSingleOpHelper(const EvaluationOrder& evalOrder,
|
|
EvalContext* context) {
|
|
NN_FUZZER_LOG << "Identified as single op subnet";
|
|
const auto& var = evalOrder.back();
|
|
NN_FUZZER_CHECK(var->type == RandomVariableType::OP);
|
|
var->op->eval(&context->at(var->parent1).committed,
|
|
var->parent2 == nullptr ? nullptr : &context->at(var->parent2).committed,
|
|
&context->at(var).committed, &context->at(var->parent1).staging,
|
|
var->parent2 == nullptr ? nullptr : &context->at(var->parent2).staging,
|
|
&context->at(var).staging);
|
|
}
|
|
|
|
// Check if the number of combinations of FREE RandomVariables exceeds the limit.
|
|
static inline uint64_t getNumCombinations(const EvaluationOrder& evalOrder,
|
|
EvalContext* context = nullptr) {
|
|
constexpr uint64_t kLimit = 1e8;
|
|
uint64_t numCombinations = 1;
|
|
for (const auto& var : evalOrder) {
|
|
if (var->type == RandomVariableType::FREE) {
|
|
size_t size =
|
|
context == nullptr ? var->range.size() : context->at(var).committed.size();
|
|
numCombinations *= size;
|
|
// To prevent overflow.
|
|
if (numCombinations > kLimit) return kLimit;
|
|
}
|
|
}
|
|
return numCombinations;
|
|
}
|
|
|
|
// Evaluate the subnet recursively. Will return fail if the number of combinations of FREE
|
|
// RandomVariable exceeds the threshold kMaxNumCombinations.
|
|
static bool evalSubnetWithBruteForce(const EvaluationOrder& evalOrder, EvalContext* context) {
|
|
constexpr uint64_t kMaxNumCombinations = 1e7;
|
|
NN_FUZZER_LOG << "Evaluate with brute force";
|
|
if (isSingleOpSubnet(evalOrder)) {
|
|
// If the network only have one single OP, dispatch to a faster evaluation.
|
|
evalSubnetSingleOpHelper(evalOrder, context);
|
|
} else {
|
|
if (getNumCombinations(evalOrder, context) > kMaxNumCombinations) {
|
|
NN_FUZZER_LOG << "Terminate the evaluation because of large search range";
|
|
std::cout << "[ ] Terminate the evaluation because of large search range"
|
|
<< std::endl;
|
|
return false;
|
|
}
|
|
evalSubnetHelper(evalOrder, context);
|
|
}
|
|
for (auto& var : evalOrder) {
|
|
if (context->at(var).staging.empty()) {
|
|
NN_FUZZER_LOG << "Evaluation failed at " << toString(var, context);
|
|
return false;
|
|
}
|
|
context->at(var).commit();
|
|
}
|
|
return true;
|
|
}
|
|
|
|
struct LocalNetwork {
|
|
EvaluationOrder evalOrder;
|
|
std::vector<RandomVariableNode> bridgeNodes;
|
|
int timestamp = 0;
|
|
|
|
bool eval(EvalContext* context) {
|
|
NN_FUZZER_LOG << "Evaluate local network with timestamp = " << timestamp;
|
|
// Temporarily treat bridge nodes as FREE RandomVariables.
|
|
for (const auto& var : bridgeNodes) {
|
|
context->at(var).originalType = var->type;
|
|
var->type = RandomVariableType::FREE;
|
|
}
|
|
for (const auto& var : evalOrder) {
|
|
context->at(var).staging.clear();
|
|
NN_FUZZER_LOG << " - " << toString(var, context);
|
|
}
|
|
bool success = evalSubnetWithBruteForce(evalOrder, context);
|
|
// Reset the RandomVariable types for bridge nodes.
|
|
for (const auto& var : bridgeNodes) var->type = context->at(var).originalType;
|
|
return success;
|
|
}
|
|
};
|
|
|
|
// Partition the network further into LocalNetworks based on the result from bridge annotation
|
|
// algorithm.
|
|
class GraphPartitioner : public DisjointNetwork {
|
|
public:
|
|
GraphPartitioner() = default;
|
|
|
|
std::vector<LocalNetwork> partition(const EvaluationOrder& evalOrder, int timestamp) {
|
|
annotateBridge(evalOrder);
|
|
for (const auto& var : evalOrder) add(var);
|
|
return get(timestamp);
|
|
}
|
|
|
|
private:
|
|
GraphPartitioner(const GraphPartitioner&) = delete;
|
|
GraphPartitioner& operator=(const GraphPartitioner&) = delete;
|
|
|
|
// Find the parent-child relationship between var1 and var2, and reset the bridge.
|
|
void setBridgeFlag(const RandomVariableNode& var1, const RandomVariableNode& var2) {
|
|
if (var1->parent1 == var2) {
|
|
mBridgeInfo[var1].isParent1Bridge = true;
|
|
} else if (var1->parent2 == var2) {
|
|
mBridgeInfo[var1].isParent2Bridge = true;
|
|
} else {
|
|
setBridgeFlag(var2, var1);
|
|
}
|
|
}
|
|
|
|
// Annoate the bridges with DFS -- an edge [u, v] is a bridge if none of u's ancestor is
|
|
// reachable from a node in the subtree of b. The complexity is O(V + E).
|
|
// discoveryTime: The timestamp a node is visited
|
|
// lowTime: The min discovery time of all reachable nodes from the subtree of the node.
|
|
void annotateBridgeHelper(const RandomVariableNode& var, int* time) {
|
|
mBridgeInfo[var].visited = true;
|
|
mBridgeInfo[var].discoveryTime = mBridgeInfo[var].lowTime = (*time)++;
|
|
|
|
// The algorithm operates on undirected graph. First find all adjacent nodes.
|
|
auto adj = var->children;
|
|
if (var->parent1 != nullptr) adj.push_back(var->parent1);
|
|
if (var->parent2 != nullptr) adj.push_back(var->parent2);
|
|
|
|
for (const auto& weakChild : adj) {
|
|
auto child = weakChild.lock();
|
|
NN_FUZZER_CHECK(child != nullptr);
|
|
if (mBridgeInfo.find(child) == mBridgeInfo.end()) continue;
|
|
if (!mBridgeInfo[child].visited) {
|
|
mBridgeInfo[child].parent = var;
|
|
annotateBridgeHelper(child, time);
|
|
|
|
// If none of nodes in the subtree of child is connected to any ancestors of var,
|
|
// then it is a bridge.
|
|
mBridgeInfo[var].lowTime =
|
|
std::min(mBridgeInfo[var].lowTime, mBridgeInfo[child].lowTime);
|
|
if (mBridgeInfo[child].lowTime > mBridgeInfo[var].discoveryTime)
|
|
setBridgeFlag(var, child);
|
|
} else if (mBridgeInfo[var].parent != child) {
|
|
mBridgeInfo[var].lowTime =
|
|
std::min(mBridgeInfo[var].lowTime, mBridgeInfo[child].discoveryTime);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Find all bridges in the subnet with DFS.
|
|
void annotateBridge(const EvaluationOrder& evalOrder) {
|
|
for (const auto& var : evalOrder) mBridgeInfo[var];
|
|
int time = 0;
|
|
for (const auto& var : evalOrder) {
|
|
if (!mBridgeInfo[var].visited) annotateBridgeHelper(var, &time);
|
|
}
|
|
}
|
|
|
|
// Re-partition the network by treating bridges as no edge.
|
|
void add(const RandomVariableNode& var) {
|
|
auto parent1 = var->parent1;
|
|
auto parent2 = var->parent2;
|
|
if (mBridgeInfo[var].isParent1Bridge) var->parent1 = nullptr;
|
|
if (mBridgeInfo[var].isParent2Bridge) var->parent2 = nullptr;
|
|
DisjointNetwork::add(var);
|
|
var->parent1 = parent1;
|
|
var->parent2 = parent2;
|
|
}
|
|
|
|
// Add bridge nodes to the local network and remove single node subnet.
|
|
std::vector<LocalNetwork> get(int timestamp) {
|
|
std::vector<LocalNetwork> res;
|
|
for (auto& pair : mEvalOrderMap) {
|
|
// We do not need to evaluate subnet with only a single node.
|
|
if (pair.second.size() == 1 && pair.second[0]->parent1 == nullptr) continue;
|
|
res.emplace_back();
|
|
for (const auto& var : pair.second) {
|
|
if (mBridgeInfo[var].isParent1Bridge) {
|
|
res.back().evalOrder.push_back(var->parent1);
|
|
res.back().bridgeNodes.push_back(var->parent1);
|
|
}
|
|
if (mBridgeInfo[var].isParent2Bridge) {
|
|
res.back().evalOrder.push_back(var->parent2);
|
|
res.back().bridgeNodes.push_back(var->parent2);
|
|
}
|
|
res.back().evalOrder.push_back(var);
|
|
}
|
|
res.back().timestamp = timestamp;
|
|
}
|
|
return res;
|
|
}
|
|
|
|
// For bridge discovery algorithm.
|
|
struct BridgeInfo {
|
|
bool isParent1Bridge = false;
|
|
bool isParent2Bridge = false;
|
|
int discoveryTime = 0;
|
|
int lowTime = 0;
|
|
bool visited = false;
|
|
std::shared_ptr<RandomVariableBase> parent = nullptr;
|
|
};
|
|
std::unordered_map<RandomVariableNode, BridgeInfo> mBridgeInfo;
|
|
};
|
|
|
|
// Evaluate subnets repeatedly until converge.
|
|
// Class T_Subnet must have member evalOrder, timestamp, and member function eval.
|
|
template <class T_Subnet>
|
|
inline bool evalSubnetsRepeatedly(std::vector<T_Subnet>* subnets, EvalContext* context) {
|
|
bool terminate = false;
|
|
while (!terminate) {
|
|
terminate = true;
|
|
for (auto& subnet : *subnets) {
|
|
if (needEvaluate(subnet.evalOrder, subnet.timestamp, context)) {
|
|
if (!subnet.eval(context)) return false;
|
|
subnet.timestamp = RandomVariableNetwork::get()->getGlobalTime();
|
|
terminate = false;
|
|
}
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Evaluate the subnet by first partitioning it further into LocalNetworks.
|
|
static bool evalSubnetWithLocalNetwork(const EvaluationOrder& evalOrder, int timestamp,
|
|
EvalContext* context) {
|
|
NN_FUZZER_LOG << "Evaluate with local network";
|
|
auto localNetworks = GraphPartitioner().partition(evalOrder, timestamp);
|
|
return evalSubnetsRepeatedly(&localNetworks, context);
|
|
}
|
|
|
|
struct LeafNetwork {
|
|
EvaluationOrder evalOrder;
|
|
int timestamp = 0;
|
|
LeafNetwork(const RandomVariableNode& var, int timestamp) : timestamp(timestamp) {
|
|
std::set<RandomVariableNode> visited;
|
|
constructorHelper(var, &visited);
|
|
}
|
|
// Construct the leaf network by recursively including parent nodes.
|
|
void constructorHelper(const RandomVariableNode& var, std::set<RandomVariableNode>* visited) {
|
|
if (var == nullptr || visited->find(var) != visited->end()) return;
|
|
constructorHelper(var->parent1, visited);
|
|
constructorHelper(var->parent2, visited);
|
|
visited->insert(var);
|
|
evalOrder.push_back(var);
|
|
}
|
|
bool eval(EvalContext* context) {
|
|
return evalSubnetWithLocalNetwork(evalOrder, timestamp, context);
|
|
}
|
|
};
|
|
|
|
// Evaluate the subnet by leaf network.
|
|
// NOTE: This algorithm will only produce correct result for *most* of the time (> 99%).
|
|
// The random graph generator is expected to retry if it fails.
|
|
static bool evalSubnetWithLeafNetwork(const EvaluationOrder& evalOrder, int timestamp,
|
|
EvalContext* context) {
|
|
NN_FUZZER_LOG << "Evaluate with leaf network";
|
|
// Construct leaf networks.
|
|
std::vector<LeafNetwork> leafNetworks;
|
|
for (const auto& var : evalOrder) {
|
|
if (var->children.empty()) {
|
|
NN_FUZZER_LOG << "Found leaf " << toString(var, context);
|
|
leafNetworks.emplace_back(var, timestamp);
|
|
}
|
|
}
|
|
return evalSubnetsRepeatedly(&leafNetworks, context);
|
|
}
|
|
|
|
void RandomVariableNetwork::addDimensionProd(const std::vector<RandomVariable>& dims) {
|
|
if (dims.size() <= 1) return;
|
|
EvaluationOrder order;
|
|
for (const auto& dim : dims) order.push_back(dim.get());
|
|
mDimProd.push_back(order);
|
|
}
|
|
|
|
bool enforceDimProd(const std::vector<EvaluationOrder>& mDimProd,
|
|
const std::unordered_map<RandomVariableNode, int>& indexMap,
|
|
EvalContext* context, std::set<int>* dirtySubnets) {
|
|
for (auto& evalOrder : mDimProd) {
|
|
NN_FUZZER_LOG << " Dimension product network size = " << evalOrder.size();
|
|
// Initialize EvalInfo of each RandomVariable.
|
|
for (auto& var : evalOrder) {
|
|
if (context->find(var) == context->end()) context->emplace(var, var);
|
|
NN_FUZZER_LOG << " - " << toString(var, context);
|
|
}
|
|
|
|
// Enforce the product of the dimension values below kMaxValue:
|
|
// max(dimA) = kMaxValue / (min(dimB) * min(dimC) * ...)
|
|
int prod = 1;
|
|
for (const auto& var : evalOrder) prod *= (*context->at(var).committed.begin());
|
|
for (auto& var : evalOrder) {
|
|
auto& committed = context->at(var).committed;
|
|
int maxValue = kMaxValue / (prod / *committed.begin());
|
|
auto it = committed.upper_bound(maxValue);
|
|
// var has empty range -> no solution.
|
|
if (it == committed.begin()) return false;
|
|
// The range is not modified -> continue.
|
|
if (it == committed.end()) continue;
|
|
// The range is modified -> the subnet of var is dirty, i.e. needs re-evaluation.
|
|
committed.erase(it, committed.end());
|
|
context->at(var).timestamp = RandomVariableNetwork::get()->getGlobalTime();
|
|
dirtySubnets->insert(indexMap.at(var));
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool RandomVariableNetwork::evalRange() {
|
|
constexpr uint64_t kMaxNumCombinationsWithBruteForce = 500;
|
|
constexpr uint64_t kMaxNumCombinationsWithLocalNetwork = 1e5;
|
|
NN_FUZZER_LOG << "Evaluate on " << mEvalOrderMap.size() << " sub-networks";
|
|
EvalContext context;
|
|
std::set<int> dirtySubnets; // Which subnets needs evaluation.
|
|
for (auto& pair : mEvalOrderMap) {
|
|
const auto& evalOrder = pair.second;
|
|
// Decide whether needs evaluation by timestamp -- if no range has changed after the last
|
|
// evaluation, then the subnet does not need re-evaluation.
|
|
if (evalOrder.size() == 1 || !needEvaluate(evalOrder, mTimestamp)) continue;
|
|
dirtySubnets.insert(pair.first);
|
|
}
|
|
if (!enforceDimProd(mDimProd, mIndexMap, &context, &dirtySubnets)) return false;
|
|
|
|
// Repeat until the ranges converge.
|
|
while (!dirtySubnets.empty()) {
|
|
for (int ind : dirtySubnets) {
|
|
const auto& evalOrder = mEvalOrderMap[ind];
|
|
NN_FUZZER_LOG << " Sub-network #" << ind << " size = " << evalOrder.size();
|
|
|
|
// Initialize EvalInfo of each RandomVariable.
|
|
for (auto& var : evalOrder) {
|
|
if (context.find(var) == context.end()) context.emplace(var, var);
|
|
NN_FUZZER_LOG << " - " << toString(var, &context);
|
|
}
|
|
|
|
// Dispatch to different algorithm according to search range.
|
|
bool success;
|
|
uint64_t numCombinations = getNumCombinations(evalOrder);
|
|
if (numCombinations <= kMaxNumCombinationsWithBruteForce) {
|
|
success = evalSubnetWithBruteForce(evalOrder, &context);
|
|
} else if (numCombinations <= kMaxNumCombinationsWithLocalNetwork) {
|
|
success = evalSubnetWithLocalNetwork(evalOrder, mTimestamp, &context);
|
|
} else {
|
|
success = evalSubnetWithLeafNetwork(evalOrder, mTimestamp, &context);
|
|
}
|
|
if (!success) return false;
|
|
}
|
|
dirtySubnets.clear();
|
|
if (!enforceDimProd(mDimProd, mIndexMap, &context, &dirtySubnets)) return false;
|
|
}
|
|
// A successful evaluation, update RandomVariables from EvalContext.
|
|
for (auto& pair : context) pair.second.updateRange();
|
|
mTimestamp = getGlobalTime();
|
|
NN_FUZZER_LOG << "Finish range evaluation";
|
|
return true;
|
|
}
|
|
|
|
static void unsetEqual(const RandomVariableNode& node) {
|
|
if (node == nullptr) return;
|
|
NN_FUZZER_LOG << "Unset equality of var" << node->index;
|
|
auto weakPtrEqual = [&node](const std::weak_ptr<RandomVariableBase>& ptr) {
|
|
return ptr.lock() == node;
|
|
};
|
|
RandomVariableNode parent1 = node->parent1, parent2 = node->parent2;
|
|
parent1->children.erase(
|
|
std::find_if(parent1->children.begin(), parent1->children.end(), weakPtrEqual));
|
|
node->parent1 = nullptr;
|
|
if (parent2 != nullptr) {
|
|
// For Equal.
|
|
parent2->children.erase(
|
|
std::find_if(parent2->children.begin(), parent2->children.end(), weakPtrEqual));
|
|
node->parent2 = nullptr;
|
|
} else {
|
|
// For UnaryEqual.
|
|
node->type = RandomVariableType::FREE;
|
|
node->op = nullptr;
|
|
}
|
|
}
|
|
|
|
// A class to revert all the changes made to RandomVariableNetwork since the Reverter object is
|
|
// constructed. Only used when setEqualIfCompatible results in incompatible.
|
|
class RandomVariableNetwork::Reverter {
|
|
public:
|
|
// Take a snapshot of RandomVariableNetwork when Reverter is constructed.
|
|
Reverter() : mSnapshot(*RandomVariableNetwork::get()) {}
|
|
// Add constraint (Equal) nodes to the reverter.
|
|
void addNode(const RandomVariableNode& node) { mEqualNodes.push_back(node); }
|
|
void revert() {
|
|
NN_FUZZER_LOG << "Revert RandomVariableNetwork";
|
|
// Release the constraints.
|
|
for (const auto& node : mEqualNodes) unsetEqual(node);
|
|
// Reset all member variables.
|
|
*RandomVariableNetwork::get() = std::move(mSnapshot);
|
|
}
|
|
|
|
private:
|
|
Reverter(const Reverter&) = delete;
|
|
Reverter& operator=(const Reverter&) = delete;
|
|
RandomVariableNetwork mSnapshot;
|
|
std::vector<RandomVariableNode> mEqualNodes;
|
|
};
|
|
|
|
bool RandomVariableNetwork::setEqualIfCompatible(const std::vector<RandomVariable>& lhs,
|
|
const std::vector<RandomVariable>& rhs) {
|
|
NN_FUZZER_LOG << "Check compatibility of {" << joinStr(", ", lhs) << "} and {"
|
|
<< joinStr(", ", rhs) << "}";
|
|
if (lhs.size() != rhs.size()) return false;
|
|
Reverter reverter;
|
|
bool result = true;
|
|
for (size_t i = 0; i < lhs.size(); i++) {
|
|
auto node = lhs[i].setEqual(rhs[i]).get();
|
|
reverter.addNode(node);
|
|
// Early terminate if there is no common choice between two ranges.
|
|
if (node != nullptr && node->range.empty()) result = false;
|
|
}
|
|
result = result && evalRange();
|
|
if (!result) reverter.revert();
|
|
NN_FUZZER_LOG << "setEqualIfCompatible: " << (result ? "[COMPATIBLE]" : "[INCOMPATIBLE]");
|
|
return result;
|
|
}
|
|
|
|
bool RandomVariableNetwork::freeze() {
|
|
NN_FUZZER_LOG << "Freeze the random network";
|
|
if (!evalRange()) return false;
|
|
|
|
std::vector<RandomVariableNode> nodes;
|
|
for (const auto& pair : mEvalOrderMap) {
|
|
// Find all FREE RandomVariables in the subnet.
|
|
for (const auto& var : pair.second) {
|
|
if (var->type == RandomVariableType::FREE) nodes.push_back(var);
|
|
}
|
|
}
|
|
|
|
// Randomly shuffle the order, this is for a more uniform randomness.
|
|
randomShuffle(&nodes);
|
|
|
|
// An inefficient algorithm that does freeze -> re-evaluate for every FREE RandomVariable.
|
|
// TODO: Might be able to optimize this.
|
|
for (const auto& var : nodes) {
|
|
if (var->type != RandomVariableType::FREE) continue;
|
|
size_t size = var->range.size();
|
|
NN_FUZZER_LOG << "Freeze " << var;
|
|
var->freeze();
|
|
NN_FUZZER_LOG << " " << var;
|
|
// There is no need to re-evaluate if the FREE RandomVariable have only one choice.
|
|
if (size > 1) {
|
|
var->updateTimestamp();
|
|
if (!evalRange()) {
|
|
NN_FUZZER_LOG << "Freeze failed at " << var;
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
NN_FUZZER_LOG << "Finish freezing the random network";
|
|
return true;
|
|
}
|
|
|
|
} // namespace fuzzing_test
|
|
} // namespace nn
|
|
} // namespace android
|