Removed redundant methods from function and util (#1505)
This commit is contained in:
parent
18836f53cd
commit
5b918810d0
@ -35,39 +35,33 @@ Function::Function(const ResultVector& results,
|
||||
const ParameterVector& parameters,
|
||||
const std::string& name)
|
||||
: Lambda(results, parameters)
|
||||
, m_temporary_pool_size(0)
|
||||
, m_instance_id(m_next_instance_id.fetch_add(1))
|
||||
, m_name(name)
|
||||
, m_unique_name("Function_" + to_string(m_instance_id))
|
||||
, m_unique_name("Function_" + to_string(m_next_instance_id.fetch_add(1)))
|
||||
, m_topological_sorter(topological_sort<std::vector<std::shared_ptr<Node>>>)
|
||||
{
|
||||
init();
|
||||
validate_nodes_and_infer_types();
|
||||
}
|
||||
|
||||
Function::Function(const OutputVector& results,
|
||||
const ParameterVector& parameters,
|
||||
const std::string& name)
|
||||
: Lambda(results, parameters)
|
||||
, m_temporary_pool_size(0)
|
||||
, m_instance_id(m_next_instance_id.fetch_add(1))
|
||||
, m_name(name)
|
||||
, m_unique_name("Function_" + to_string(m_instance_id))
|
||||
, m_unique_name("Function_" + to_string(m_next_instance_id.fetch_add(1)))
|
||||
, m_topological_sorter(topological_sort<std::vector<std::shared_ptr<Node>>>)
|
||||
{
|
||||
init();
|
||||
validate_nodes_and_infer_types();
|
||||
}
|
||||
|
||||
Function::Function(const NodeVector& results,
|
||||
const ParameterVector& parameters,
|
||||
const std::string& name)
|
||||
: Lambda(as_output_vector(results), parameters)
|
||||
, m_temporary_pool_size(0)
|
||||
, m_instance_id(m_next_instance_id.fetch_add(1))
|
||||
, m_name(name)
|
||||
, m_unique_name("Function_" + to_string(m_instance_id))
|
||||
, m_unique_name("Function_" + to_string(m_next_instance_id.fetch_add(1)))
|
||||
, m_topological_sorter(topological_sort<std::vector<std::shared_ptr<Node>>>)
|
||||
{
|
||||
init();
|
||||
validate_nodes_and_infer_types();
|
||||
}
|
||||
|
||||
Function::Function(const std::shared_ptr<Node>& result,
|
||||
@ -95,11 +89,6 @@ void Function::validate_nodes_and_infer_types()
|
||||
}
|
||||
}
|
||||
|
||||
void Function::init()
|
||||
{
|
||||
validate_nodes_and_infer_types();
|
||||
}
|
||||
|
||||
std::vector<shared_ptr<Node>> Function::get_ordered_ops() const
|
||||
{
|
||||
vector<shared_ptr<Node>> nodes;
|
||||
@ -172,16 +161,6 @@ void Function::set_friendly_name(const string& name)
|
||||
}
|
||||
}
|
||||
|
||||
size_t Function::get_temporary_pool_size()
|
||||
{
|
||||
return m_temporary_pool_size;
|
||||
}
|
||||
|
||||
void Function::set_temporary_pool_size(size_t size)
|
||||
{
|
||||
m_temporary_pool_size = size;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Function& f)
|
||||
{
|
||||
out << "Function(" << f.get_name() << ")";
|
||||
|
@ -52,10 +52,7 @@ namespace ngraph
|
||||
const ParameterVector& parameters,
|
||||
const std::string& name = "");
|
||||
|
||||
void init();
|
||||
|
||||
virtual ~Function() {}
|
||||
public:
|
||||
/// Return the number of outputs for this function.
|
||||
size_t get_output_size() const;
|
||||
|
||||
@ -97,9 +94,6 @@ namespace ngraph
|
||||
void map_unordered_ops(std::function<void(Node*)> f) const;
|
||||
|
||||
friend std::ostream& operator<<(std::ostream&, const Function&);
|
||||
size_t get_instance_id() { return m_instance_id; }
|
||||
size_t get_temporary_pool_size();
|
||||
void set_temporary_pool_size(size_t);
|
||||
// updates graph and m_results list
|
||||
void replace_node(std::shared_ptr<Node> old, std::shared_ptr<Node> repl);
|
||||
|
||||
@ -127,16 +121,12 @@ namespace ngraph
|
||||
const std::vector<std::shared_ptr<Node>>& root_nodes)>;
|
||||
void set_topological_sort(topological_sort_t);
|
||||
|
||||
protected:
|
||||
size_t m_temporary_pool_size;
|
||||
|
||||
private:
|
||||
Function(const Function&) = delete;
|
||||
Function(const Function&&) = delete;
|
||||
Function& operator=(const Function&) = delete;
|
||||
|
||||
static std::atomic<size_t> m_next_instance_id;
|
||||
size_t m_instance_id;
|
||||
std::string m_name;
|
||||
const std::string m_unique_name;
|
||||
size_t m_placement{0};
|
||||
|
@ -37,11 +37,6 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
std::string ngraph::to_cplusplus_sourcecode_literal(bool val)
|
||||
{
|
||||
return val ? "true" : "false";
|
||||
}
|
||||
|
||||
void ngraph::dump(ostream& out, const void* _data, size_t _size)
|
||||
{
|
||||
auto flags = out.flags();
|
||||
@ -195,128 +190,6 @@ size_t ngraph::round_up(size_t size, size_t alignment)
|
||||
return size + alignment - remainder;
|
||||
}
|
||||
|
||||
ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
|
||||
std::shared_ptr<ngraph::Function> bprop)
|
||||
{
|
||||
using namespace ngraph;
|
||||
|
||||
// Create a fprop_cache object to store the results of this analysis
|
||||
FpropCache fprop_cache;
|
||||
|
||||
// Traverse bprop to find all of the nodes in the bprop graph
|
||||
std::set<Output<Node>> in_bprop;
|
||||
ngraph::traverse_nodes(bprop, [&in_bprop](std::shared_ptr<Node> node) {
|
||||
for (auto value : node->outputs())
|
||||
{
|
||||
in_bprop.insert(value);
|
||||
}
|
||||
});
|
||||
|
||||
// Traverse fprop to make a map that stores parameters with the same
|
||||
// shape and element type as the nodes in fprop iff they are in bprop
|
||||
// and aren't inputs to bprop
|
||||
vector<Output<Node>> bprop_inputs;
|
||||
for (auto param : bprop->get_parameters())
|
||||
{
|
||||
bprop_inputs.push_back(param);
|
||||
}
|
||||
ngraph::traverse_nodes(
|
||||
fprop, [&fprop_cache, &in_bprop, &bprop_inputs](std::shared_ptr<Node> node) {
|
||||
for (auto value : node->outputs())
|
||||
{
|
||||
if (in_bprop.count(value) != 0 &&
|
||||
std::find(bprop_inputs.begin(), bprop_inputs.end(), value) ==
|
||||
bprop_inputs.end())
|
||||
{
|
||||
fprop_cache.node_param_map[value] = std::make_shared<op::Parameter>(
|
||||
value.get_element_type(), value.get_shape());
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// clone the nodes in bprop, replacing fprop-related nodes with the
|
||||
// intermediate parameters from fprop_cache. This breaks connections in the
|
||||
// bprop graph such that only intermediate values from fprop needed by bprop
|
||||
// are still connected to the bprop graph as parameters
|
||||
ngraph::clone_nodes(bprop->get_ops(), fprop_cache.node_param_map);
|
||||
|
||||
// invert the fprop_cache cloned node map for easy back and for acces.
|
||||
std::map<Output<Node>, RawNodeOutput> inverted_node_map;
|
||||
for (auto kv : fprop_cache.node_param_map)
|
||||
{
|
||||
inverted_node_map[kv.second] = kv.first;
|
||||
}
|
||||
|
||||
// get cloned bprop results
|
||||
ResultVector cloned_results;
|
||||
NodeVector result_nodes;
|
||||
for (auto node : bprop->get_results())
|
||||
{
|
||||
auto result = as_type_ptr<op::Result>(
|
||||
fprop_cache.node_param_map.at(Output<Node>(node)).get_node_shared_ptr());
|
||||
if (!result)
|
||||
{
|
||||
throw ngraph_error("Expected op::Result values for op::Result keys in node_param_map");
|
||||
}
|
||||
cloned_results.push_back(result);
|
||||
result_nodes.push_back(result);
|
||||
}
|
||||
|
||||
// Utility for getting bprop parameters with fprop cache.
|
||||
auto get_bprop_params = [&bprop_inputs, &fprop_cache]() {
|
||||
// get cloned bprop parameters
|
||||
ParameterVector bprop_input_params;
|
||||
for (auto param : bprop_inputs)
|
||||
{
|
||||
bprop_input_params.push_back(as_type_ptr<op::Parameter>(
|
||||
fprop_cache.node_param_map.at(Output<Node>(param)).get_node_shared_ptr()));
|
||||
}
|
||||
|
||||
// add the cached fprop nodes as inputs to bprop
|
||||
for (auto x : fprop_cache.fprop_output_nodes)
|
||||
{
|
||||
bprop_input_params.push_back(
|
||||
as_type_ptr<op::Parameter>(fprop_cache.node_param_map.at(x).get_node_shared_ptr()));
|
||||
}
|
||||
return bprop_input_params;
|
||||
};
|
||||
|
||||
// Traverse the graph from the cloned results of bprop. If we find a parameter
|
||||
// that's not an original input of bprop, this is an intermediate value of
|
||||
// fprop that needs to be returned from fprop and send to bprop
|
||||
auto cloned_bprop_inputs = get_bprop_params();
|
||||
ngraph::traverse_nodes(
|
||||
result_nodes,
|
||||
[&cloned_bprop_inputs, &fprop_cache, &inverted_node_map](std::shared_ptr<Node> node) {
|
||||
auto pnode = as_type_ptr<op::Parameter>(node);
|
||||
if (pnode &&
|
||||
std::find(cloned_bprop_inputs.begin(), cloned_bprop_inputs.end(), pnode) ==
|
||||
cloned_bprop_inputs.end())
|
||||
{
|
||||
fprop_cache.fprop_output_nodes.push_back(inverted_node_map.at(Output<Node>(node)));
|
||||
}
|
||||
});
|
||||
|
||||
// create the new outputs for fprop and the new fprop function
|
||||
ResultVector fprop_outputs = fprop->get_results();
|
||||
|
||||
for (auto fpirn : fprop_cache.fprop_output_nodes)
|
||||
{
|
||||
if (as_type_ptr<op::Result>(fpirn.node->shared_from_this()))
|
||||
{
|
||||
throw ngraph_error("Unexpected op::Result in fprop->get_results()");
|
||||
}
|
||||
fprop_outputs.push_back(std::make_shared<op::Result>(fpirn));
|
||||
}
|
||||
|
||||
fprop_cache.fprop = std::make_shared<Function>(fprop_outputs, fprop->get_parameters());
|
||||
|
||||
// Create the new bprop function with cloned results and cached parameters.
|
||||
fprop_cache.bprop = std::make_shared<Function>(cloned_results, get_bprop_params());
|
||||
|
||||
return fprop_cache;
|
||||
}
|
||||
|
||||
size_t stopwatch::get_call_count() const
|
||||
{
|
||||
return m_total_count;
|
||||
@ -444,50 +317,6 @@ std::ostream& operator<<(std::ostream& os, const ngraph::NodeVector& nv)
|
||||
return os;
|
||||
}
|
||||
|
||||
void ngraph::check_fp_values_isinf(const char* name, const float* array, size_t n)
|
||||
{
|
||||
for (size_t i = 0; i < n; i++)
|
||||
{
|
||||
if (std::isinf(array[i]))
|
||||
{
|
||||
throw std::runtime_error("Discovered Inf in '" + string(name) + "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ngraph::check_fp_values_isinf(const char* name, const double* array, size_t n)
|
||||
{
|
||||
for (size_t i = 0; i < n; i++)
|
||||
{
|
||||
if (std::isinf(array[i]))
|
||||
{
|
||||
throw std::runtime_error("Discovered Inf in '" + string(name) + "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ngraph::check_fp_values_isnan(const char* name, const float* array, size_t n)
|
||||
{
|
||||
for (size_t i = 0; i < n; i++)
|
||||
{
|
||||
if (std::isnan(array[i]))
|
||||
{
|
||||
throw std::runtime_error("Discovered NaN in '" + string(name) + "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ngraph::check_fp_values_isnan(const char* name, const double* array, size_t n)
|
||||
{
|
||||
for (size_t i = 0; i < n; i++)
|
||||
{
|
||||
if (std::isnan(array[i]))
|
||||
{
|
||||
throw std::runtime_error("Discovered NaN in '" + string(name) + "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool ngraph::is_valid_permutation(ngraph::AxisVector permutation, ngraph::Rank rank)
|
||||
{
|
||||
std::vector<bool> axis_occurs(permutation.size(), false);
|
||||
@ -591,16 +420,6 @@ AxisVector ngraph::get_default_order(size_t rank)
|
||||
return default_order;
|
||||
}
|
||||
|
||||
AxisVector ngraph::get_permutation_to_default_order(const AxisVector& axis_order)
|
||||
{
|
||||
AxisVector out(axis_order.size());
|
||||
for (size_t i = 0; i < axis_order.size(); i++)
|
||||
{
|
||||
out.at(axis_order[i]) = i;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
void ngraph::parse_version_string(
|
||||
std::string version, size_t& major, size_t& minor, size_t& patch, string& extra)
|
||||
{
|
||||
|
@ -50,9 +50,6 @@ namespace ngraph
|
||||
class Tensor;
|
||||
}
|
||||
|
||||
NGRAPH_API
|
||||
std::string to_cplusplus_sourcecode_literal(bool val);
|
||||
|
||||
template <typename T>
|
||||
std::string join(const T& v, const std::string& sep = ", ")
|
||||
{
|
||||
@ -202,11 +199,6 @@ namespace ngraph
|
||||
return y > x ? 0 : x - y;
|
||||
}
|
||||
|
||||
void check_fp_values_isinf(const char* name, const float* array, size_t n);
|
||||
void check_fp_values_isinf(const char* name, const double* array, size_t n);
|
||||
void check_fp_values_isnan(const char* name, const float* array, size_t n);
|
||||
void check_fp_values_isnan(const char* name, const double* array, size_t n);
|
||||
|
||||
NGRAPH_API
|
||||
void* ngraph_malloc(size_t size);
|
||||
NGRAPH_API
|
||||
@ -237,34 +229,6 @@ namespace ngraph
|
||||
NGRAPH_API
|
||||
AxisVector get_default_order(const Shape& shape);
|
||||
|
||||
NGRAPH_API
|
||||
AxisVector get_permutation_to_default_order(const AxisVector& axis_order);
|
||||
|
||||
//
|
||||
// Return type struct for cache_fprop, with the modified fprop and bprop
|
||||
// functions
|
||||
// and a list of the nodes that have been appended to fprop output/bprop
|
||||
// input
|
||||
//
|
||||
struct FpropCache
|
||||
{
|
||||
std::shared_ptr<Function> fprop;
|
||||
std::shared_ptr<Function> bprop;
|
||||
std::vector<RawNodeOutput> fprop_output_nodes;
|
||||
RawNodeOutputMap node_param_map;
|
||||
};
|
||||
|
||||
//
|
||||
// This utility takes forward-propogation and back-propagation functions
|
||||
// and turns them into clone functions where the intermediate values of
|
||||
// the forward prop are added to the output of fprop and the input of the bprop
|
||||
// to avoid repeat calculations.
|
||||
// The last argument is the adjoints coming into the bprop function, the output
|
||||
// bprop function will have these nodes as the first N input parameters
|
||||
//
|
||||
NGRAPH_API
|
||||
FpropCache cache_fprop(std::shared_ptr<Function> fprop, std::shared_ptr<Function> bprop);
|
||||
|
||||
// NodeExecutors are used in compiler optimization passes like ConstantFolding to execute a node
|
||||
// using the supplied input and output memory locations.
|
||||
// A BuildNodeExecutor returns a backend-specific NodeExecutor for a given Node type
|
||||
@ -274,15 +238,6 @@ namespace ngraph
|
||||
|
||||
using BuildNodeExecutorMap = std::unordered_map<std::type_index, BuildNodeExecutor>;
|
||||
|
||||
enum class TensorRole
|
||||
{
|
||||
INPUT,
|
||||
CONSTANT,
|
||||
OUTPUT,
|
||||
INTERMEDIATE,
|
||||
UNKNOWN
|
||||
};
|
||||
|
||||
//
|
||||
// EnumMask is intended to work with a scoped enum type. It's used to store
|
||||
// a combination of enum values and provides easy access and manipulation
|
||||
|
Loading…
Reference in New Issue
Block a user