You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
277 lines
12 KiB
277 lines
12 KiB
/*
|
|
* Copyright (C) 2019 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_FUZZING_RANDOM_VARIABLE_H
|
|
#define ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_FUZZING_RANDOM_VARIABLE_H
|
|
|
|
#include <algorithm>
#include <climits>
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
|
|
|
|
namespace android {
|
|
namespace nn {
|
|
namespace fuzzing_test {
|
|
|
|
// Named limit shared by the fuzzing test code (presumably the largest value a random variable
// may take -- confirm against the users of this constant).
constexpr int kMaxValue = 10000;
// Sentinel meaning "no value". The setRange() methods declared in this header also accept it as
// a bound to indicate that the corresponding side is unlimited. INT_MIN comes from <climits>.
constexpr int kInvalidValue = INT_MIN;
|
|
|
|
// Describes the search range for the value of a random variable as an explicit list of candidate
// values. The list is always kept in ascending order; has(), min() and max() rely on that.
class RandomVariableRange {
   public:
    // Constructs an empty range.
    RandomVariableRange() = default;

    // Constructs a range whose only choice is the given value.
    explicit RandomVariableRange(int value) : mChoices{value} {}

    // Constructs the range of consecutive integers [lower, upper]. An inverted interval
    // (upper < lower) produces an empty range rather than asking the vector for a negative
    // (i.e. enormous) number of elements.
    RandomVariableRange(int lower, int upper) : mChoices(intervalSize(lower, upper)) {
        std::iota(mChoices.begin(), mChoices.end(), lower);
    }

    // Constructs a range from an arbitrary list of choices. The copy is sorted so that the
    // ascending-order invariant holds even when the caller passes unsorted values.
    explicit RandomVariableRange(const std::vector<int>& vec) : mChoices(vec) {
        std::sort(mChoices.begin(), mChoices.end());
    }

    // Constructs a range from a set; std::set already iterates in ascending order.
    explicit RandomVariableRange(const std::set<int>& st) : mChoices(st.begin(), st.end()) {}

    RandomVariableRange(const RandomVariableRange&) = default;
    RandomVariableRange& operator=(const RandomVariableRange&) = default;

    // Returns true if there is no choice left.
    bool empty() const { return mChoices.empty(); }

    // Returns true if value is one of the choices (binary search over the sorted choices).
    bool has(int value) const {
        return std::binary_search(mChoices.begin(), mChoices.end(), value);
    }

    // Returns the number of choices.
    size_t size() const { return mChoices.size(); }

    // Smallest / largest choice. The range must not be empty: both accessors read an element
    // of the underlying vector without checking.
    int min() const { return mChoices.front(); }
    int max() const { return mChoices.back(); }

    // Returns the choices in ascending order.
    const std::vector<int>& getChoices() const { return mChoices; }

    // Narrow down the range to fit [lower, upper]. Use kInvalidValue to indicate unlimited bound.
    void setRange(int lower, int upper);

    // Narrow down the range to a random selected choice. Return the chosen value.
    int toConst();

    // Calculate the intersection of two ranges.
    friend RandomVariableRange operator&(const RandomVariableRange& lhs,
                                         const RandomVariableRange& rhs);

   private:
    // Number of integers in [lower, upper], or 0 for an inverted interval. The arithmetic is
    // done in long long so that extreme bounds cannot overflow int.
    static size_t intervalSize(int lower, int upper) {
        if (upper < lower) return 0;
        return static_cast<size_t>(static_cast<long long>(upper) - lower + 1);
    }

    // Always in ascending order.
    std::vector<int> mChoices;
};
|
|
|
|
// Defines the interface for an operation applying to RandomVariables.
|
|
class IRandomVariableOp {
|
|
public:
|
|
virtual ~IRandomVariableOp() {}
|
|
// Forward evaluation of two values.
|
|
virtual int eval(int lhs, int rhs) const = 0;
|
|
// Gets the range of the operation outcomes. The returned range must include all possible
|
|
// outcomes of this operation, but may contain invalid results.
|
|
virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
|
|
const RandomVariableRange& rhs) const;
|
|
// Provides faster range evaluation for evalSubnetSingleOpHelper if possible.
|
|
virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
|
|
const std::set<int>* childIn, std::set<int>* parent1Out,
|
|
std::set<int>* parent2Out, std::set<int>* childOut) const;
|
|
// For debugging purpose.
|
|
virtual const char* getName() const = 0;
|
|
};
|
|
|
|
// Kind of a random variable:
//   FREE  - its value is still to be chosen from a range of choices,
//   CONST - it holds a specified value,
//   OP    - it is the result of an operation between two other random variables.
enum class RandomVariableType {
    FREE = 0,
    CONST = 1,
    OP = 2,
};
|
|
|
|
struct RandomVariableBase {
|
|
// Each RandomVariableBase is assigned an unique index for debugging purpose.
|
|
static unsigned int globalIndex;
|
|
int index;
|
|
|
|
RandomVariableType type;
|
|
RandomVariableRange range;
|
|
int value = 0;
|
|
std::shared_ptr<const IRandomVariableOp> op = nullptr;
|
|
|
|
// Network structural information.
|
|
std::shared_ptr<RandomVariableBase> parent1 = nullptr;
|
|
std::shared_ptr<RandomVariableBase> parent2 = nullptr;
|
|
std::vector<std::weak_ptr<RandomVariableBase>> children;
|
|
|
|
// The last time that this RandomVariableBase is modified.
|
|
int timestamp;
|
|
|
|
explicit RandomVariableBase(int value);
|
|
RandomVariableBase(int lower, int upper);
|
|
explicit RandomVariableBase(const std::vector<int>& choices);
|
|
RandomVariableBase(const std::shared_ptr<RandomVariableBase>& lhs,
|
|
const std::shared_ptr<RandomVariableBase>& rhs,
|
|
const std::shared_ptr<const IRandomVariableOp>& op);
|
|
RandomVariableBase(const RandomVariableBase&) = delete;
|
|
RandomVariableBase& operator=(const RandomVariableBase&) = delete;
|
|
|
|
// Freeze FREE RandomVariable to one valid choice.
|
|
// Should only invoke on FREE RandomVariable.
|
|
void freeze();
|
|
|
|
// Get CONST value or calculate from parents.
|
|
// Should not invoke on FREE RandomVariable.
|
|
int getValue() const;
|
|
|
|
// Update the timestamp to the latest global time.
|
|
void updateTimestamp();
|
|
};
|
|
|
|
using RandomVariableNode = std::shared_ptr<RandomVariableBase>;
|
|
|
|
// A wrapper class of RandomVariableBase that manages RandomVariableBase with shared_ptr and
|
|
// provides useful methods and operator overloading to build the random variable network.
|
|
class RandomVariable {
|
|
public:
|
|
// Construct a placeholder RandomVariable with nullptr.
|
|
RandomVariable() : mVar(nullptr) {}
|
|
|
|
// Construct a CONST RandomVariable with specified value.
|
|
/* implicit */ RandomVariable(int value);
|
|
|
|
// Construct a FREE RandomVariable with range [lower, upper].
|
|
RandomVariable(int lower, int upper);
|
|
|
|
// Construct a FREE RandomVariable with specified value choices.
|
|
explicit RandomVariable(const std::vector<int>& choices);
|
|
|
|
// This is for RandomVariableType::FREE only.
|
|
// Construct a FREE RandomVariable with default range [1, defaultValue].
|
|
/* implicit */ RandomVariable(RandomVariableType type);
|
|
|
|
// RandomVariables share the same RandomVariableBase if copied or copy-assigned.
|
|
RandomVariable(const RandomVariable& other) = default;
|
|
RandomVariable& operator=(const RandomVariable& other) = default;
|
|
|
|
// Get the value of the RandomVariable, the value must be deterministic.
|
|
int getValue() const { return mVar->getValue(); }
|
|
|
|
// Get the underlying managed RandomVariableNode.
|
|
RandomVariableNode get() const { return mVar; };
|
|
|
|
bool operator==(std::nullptr_t) const { return mVar == nullptr; }
|
|
bool operator!=(std::nullptr_t) const { return mVar != nullptr; }
|
|
|
|
// Arithmetic operators and methods on RandomVariables.
|
|
friend RandomVariable operator+(const RandomVariable& lhs, const RandomVariable& rhs);
|
|
friend RandomVariable operator-(const RandomVariable& lhs, const RandomVariable& rhs);
|
|
friend RandomVariable operator*(const RandomVariable& lhs, const RandomVariable& rhs);
|
|
friend RandomVariable operator*(const RandomVariable& lhs, const float& rhs);
|
|
friend RandomVariable operator/(const RandomVariable& lhs, const RandomVariable& rhs);
|
|
friend RandomVariable operator%(const RandomVariable& lhs, const RandomVariable& rhs);
|
|
friend RandomVariable max(const RandomVariable& lhs, const RandomVariable& rhs);
|
|
friend RandomVariable min(const RandomVariable& lhs, const RandomVariable& rhs);
|
|
RandomVariable exactDiv(const RandomVariable& other);
|
|
|
|
// Set constraints on the RandomVariable. Use kInvalidValue to indicate unlimited bound.
|
|
void setRange(int lower, int upper);
|
|
RandomVariable setEqual(const RandomVariable& other) const;
|
|
RandomVariable setGreaterThan(const RandomVariable& other) const;
|
|
RandomVariable setGreaterEqual(const RandomVariable& other) const;
|
|
|
|
// A FREE RandomVariable is constructed with default range [1, defaultValue].
|
|
static int defaultValue;
|
|
|
|
private:
|
|
// Construct a RandomVariable as the result of an OP between two other RandomVariables.
|
|
RandomVariable(const RandomVariable& lhs, const RandomVariable& rhs,
|
|
const std::shared_ptr<const IRandomVariableOp>& op);
|
|
RandomVariableNode mVar;
|
|
};
|
|
|
|
using EvaluationOrder = std::vector<RandomVariableNode>;
|
|
|
|
// The base class of a network consisting of disjoint subnets.
|
|
class DisjointNetwork {
|
|
public:
|
|
// Add a node to the network, join the parent subnets if needed.
|
|
void add(const RandomVariableNode& var);
|
|
|
|
// Similar to join(int, int), but accept RandomVariableNodes.
|
|
int join(const RandomVariableNode& var1, const RandomVariableNode& var2) {
|
|
return DisjointNetwork::join(mIndexMap[var1], mIndexMap[var2]);
|
|
}
|
|
|
|
protected:
|
|
DisjointNetwork() = default;
|
|
DisjointNetwork(const DisjointNetwork&) = default;
|
|
DisjointNetwork& operator=(const DisjointNetwork&) = default;
|
|
|
|
// Join two subnets by appending every node in ind2 after ind1, return the resulting subnet
|
|
// index. Use -1 for invalid subnet index.
|
|
int join(int ind1, int ind2);
|
|
|
|
// A map from the network node to the corresponding subnet index.
|
|
std::unordered_map<RandomVariableNode, int> mIndexMap;
|
|
|
|
// A map from the subnet index to the set of nodes within the subnet. The nodes are maintained
|
|
// in a valid evaluation order, that is, a valid topological sort.
|
|
std::map<int, EvaluationOrder> mEvalOrderMap;
|
|
|
|
// The next index for a new disjoint subnet component.
|
|
int mNextIndex = 0;
|
|
};
|
|
|
|
// Manages the active RandomVariable network. Only one instance of this class will exist.
|
|
class RandomVariableNetwork : public DisjointNetwork {
|
|
public:
|
|
// Returns the singleton network instance.
|
|
static RandomVariableNetwork* get();
|
|
|
|
// Re-initialization. Should be called every time a new random graph is being generated.
|
|
void initialize(int defaultValue);
|
|
|
|
// Set the elementwise equality of the two vectors of RandomVariables iff it results in a
|
|
// soluble network.
|
|
bool setEqualIfCompatible(const std::vector<RandomVariable>& lhs,
|
|
const std::vector<RandomVariable>& rhs);
|
|
|
|
// Freeze all FREE RandomVariables in the network to a random valid combination.
|
|
bool freeze();
|
|
|
|
// Check if node2 is FREE and can be evaluated after node1.
|
|
bool isSubordinate(const RandomVariableNode& node1, const RandomVariableNode& node2);
|
|
|
|
// Get and then advance the current global timestamp.
|
|
int getGlobalTime() { return mGlobalTime++; }
|
|
|
|
// Add a special constraint on dimension product.
|
|
void addDimensionProd(const std::vector<RandomVariable>& dims);
|
|
|
|
private:
|
|
RandomVariableNetwork() = default;
|
|
RandomVariableNetwork(const RandomVariableNetwork&) = default;
|
|
RandomVariableNetwork& operator=(const RandomVariableNetwork&) = default;
|
|
|
|
// A class to revert all the changes made to RandomVariableNetwork since the Reverter object is
|
|
// constructed. Only used when setEqualIfCompatible results in incompatible.
|
|
class Reverter;
|
|
|
|
// Find valid choices for all RandomVariables in the network. Update the RandomVariableRange
|
|
// if the network is soluble, otherwise, return false and leave the ranges unchanged.
|
|
bool evalRange();
|
|
|
|
int mGlobalTime = 0;
|
|
int mTimestamp = -1;
|
|
|
|
std::vector<EvaluationOrder> mDimProd;
|
|
};
|
|
|
|
} // namespace fuzzing_test
|
|
} // namespace nn
|
|
} // namespace android
|
|
|
|
#endif // ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_FUZZING_RANDOM_VARIABLE_H
|