Optimize Function Topological Sort (#8519)
* Initial optimization implementation * Add tests * Updated function and node accessors * Avoid using static mutex in get_ordered_ops * Apply comments * resolve PR comments * Move mutex up * Simplify shared info * Fix info copying in node copy c-tor * Move friend declarations to public * Fix to support node access from multiple functions in different threads
This commit is contained in:
parent
14d1f7c844
commit
c307f206dc
@ -8,6 +8,7 @@
|
||||
#include <initializer_list>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -24,6 +25,7 @@
|
||||
#include "openvino/runtime/tensor.hpp"
|
||||
|
||||
namespace ov {
|
||||
class FunctionAccessor;
|
||||
/// A user-defined function.
|
||||
class OPENVINO_API Function : public std::enable_shared_from_this<Function> {
|
||||
public:
|
||||
@ -36,6 +38,7 @@ public:
|
||||
}
|
||||
OPENVINO_DEPRECATED("This member was deprecated. Please use ::get_type_info_static() instead.")
|
||||
static const ov::DiscreteTypeInfo type_info;
|
||||
|
||||
Function(const ov::NodeVector& results, const ov::ParameterVector& parameters, const std::string& name = "");
|
||||
|
||||
Function(const ov::OutputVector& results, const ov::ParameterVector& parameters, const std::string& name = "");
|
||||
@ -293,6 +296,8 @@ public:
|
||||
Function& operator=(Function&&) = delete;
|
||||
|
||||
private:
|
||||
friend class ov::FunctionAccessor;
|
||||
|
||||
/// \brief Depending on the options selected,
|
||||
/// checks all the Parameter/Variables are registered in the list of Function
|
||||
/// parameters/variables or finds all Parameters/Variables in a function and registers them.
|
||||
@ -315,6 +320,17 @@ private:
|
||||
ov::ParameterVector m_parameters;
|
||||
ov::op::util::VariableVector m_variables;
|
||||
RTMap m_rt_info;
|
||||
|
||||
// Cache of topologically sorted nodes which is stored as a vector
|
||||
// of weak_ptr not to increase node ref counter to prevent the situation when
|
||||
// node has no consumers but still exists in a graph.
|
||||
mutable std::vector<std::weak_ptr<Node>> m_cached_ordered_ops;
|
||||
|
||||
// Private runtime info which is shared across nodes and used only
|
||||
// for internal purposes.
|
||||
std::shared_ptr<SharedRTInfo> m_shared_rt_info;
|
||||
|
||||
mutable std::mutex m_topological_sort_mutex;
|
||||
};
|
||||
|
||||
OPENVINO_API
|
||||
|
@ -71,6 +71,10 @@ class Output;
|
||||
|
||||
class Node;
|
||||
|
||||
class Function;
|
||||
|
||||
class SharedRTInfo;
|
||||
|
||||
/// EvaluationContext stores and manages a context (additional parameters, values and
|
||||
/// environment) for evaluating ov::Function.
|
||||
using EvaluationContext = std::map<std::string, std::shared_ptr<Variant>>;
|
||||
@ -100,6 +104,8 @@ std::string node_validation_failure_loc_string(const Node* node);
|
||||
case element::Type_t::a: \
|
||||
rc = evaluate<element::Type_t::a>
|
||||
|
||||
class NodeAccessor;
|
||||
|
||||
/// Nodes are the backbone of the graph of Value dataflow. Every node has
|
||||
/// zero or more nodes as arguments and one value, which is either a tensor
|
||||
/// or a (possibly empty) tuple of values.
|
||||
@ -115,18 +121,20 @@ class OPENVINO_API Node : public std::enable_shared_from_this<Node> {
|
||||
template <typename NodeType>
|
||||
friend class Output;
|
||||
|
||||
friend class Function;
|
||||
|
||||
protected:
|
||||
descriptor::Input& get_input_descriptor(size_t position);
|
||||
descriptor::Output& get_output_descriptor(size_t position);
|
||||
|
||||
/// \brief Construct an unitialized Node
|
||||
/// \brief Construct an uninitialized Node
|
||||
Node() = default;
|
||||
/// \brief Copying a node
|
||||
Node(const Node&);
|
||||
/// \brief Assignment operator
|
||||
Node& operator=(const Node&);
|
||||
|
||||
/// \brief Construct an unitialized Node
|
||||
/// \brief Construct an uninitialized Node
|
||||
/// \param output_size Number of outputs for this node
|
||||
Node(size_t output_size);
|
||||
|
||||
@ -383,9 +391,6 @@ public:
|
||||
OPENVINO_DEPRECATED("The tensor name was deprecated. Use get_input_tensor(i).get_names() instead.")
|
||||
const std::string& get_input_tensor_name(size_t i) const;
|
||||
|
||||
std::unordered_set<descriptor::Tensor*> liveness_new_list;
|
||||
std::unordered_set<descriptor::Tensor*> liveness_free_list;
|
||||
|
||||
Node* get_input_node_ptr(size_t index) const;
|
||||
std::shared_ptr<Node> get_input_node_shared_ptr(size_t index) const;
|
||||
Output<Node> get_input_source_output(size_t i) const;
|
||||
@ -478,9 +483,9 @@ public:
|
||||
virtual bool match_node(ov::pass::pattern::Matcher* matcher, const Output<Node>& graph_value);
|
||||
|
||||
private:
|
||||
friend class ov::NodeAccessor;
|
||||
std::vector<Node*> m_control_dependents;
|
||||
std::vector<std::shared_ptr<Node>> m_control_dependencies;
|
||||
std::string m_node_type;
|
||||
size_t m_instance_id{m_next_instance_id.fetch_add(1)};
|
||||
std::string m_friendly_name;
|
||||
mutable std::string m_unique_name;
|
||||
@ -491,7 +496,20 @@ private:
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
std::shared_ptr<ngraph::op::util::OpAnnotations> m_op_annotations;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
std::map<std::string, std::shared_ptr<Variant>> m_rt_info;
|
||||
RTMap m_rt_info;
|
||||
|
||||
// The set of SharedRTInfo attributes associated with the Functions
// this node belongs to. SharedRTInfo is a private field which
// is used for internal purposes, for example tracking changes
// during graph transformations.
|
||||
std::set<std::shared_ptr<SharedRTInfo>> m_shared_rt_info;
|
||||
|
||||
// A node can be included in several Functions, and those Functions can be
// processed in different threads, so m_shared_rt_info can be updated
// simultaneously. To guarantee exclusive updates of this field, insertion
// goes through a dedicated method protected by a mutex.
|
||||
void insert_info(std::shared_ptr<SharedRTInfo> info);
|
||||
std::mutex m_insert_mutex;
|
||||
};
|
||||
|
||||
using NodeTypeInfo = Node::type_info_t;
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include "openvino/core/descriptor/output.hpp"
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/core/type/element_type.hpp"
|
||||
#include "shared_node_info.hpp"
|
||||
|
||||
ov::descriptor::Input::Input(ov::Node* node, size_t index, Output& output)
|
||||
: m_node(node),
|
||||
@ -44,6 +45,14 @@ void ov::descriptor::Input::replace_output(Output& new_output) {
|
||||
// if a new input violates one of the type checks in the c-tor.
|
||||
m_node->clone_with_new_inputs(m_node->input_values());
|
||||
}
|
||||
|
||||
// Output replacement may change the topological order of nodes,
|
||||
// so we have to reset cache by setting a flag into shared node info.
|
||||
for_each(m_node->m_shared_rt_info.cbegin(),
|
||||
m_node->m_shared_rt_info.cend(),
|
||||
[](const std::shared_ptr<SharedRTInfo>& info) {
|
||||
info->set_use_topological_cache(false);
|
||||
});
|
||||
}
|
||||
|
||||
void ov::descriptor::Input::replace_output(const std::shared_ptr<ov::Node>& node, size_t i) {
|
||||
|
@ -25,6 +25,7 @@
|
||||
#include "openvino/op/util/variable_context.hpp"
|
||||
#include "openvino/op/util/variable_extension.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "shared_node_info.hpp"
|
||||
#include "transformations/smart_reshape/smart_reshape.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -191,6 +192,8 @@ ov::Function::Function(const OutputVector& results, const string& name)
|
||||
void ov::Function::prerequirements(bool detect_variables, bool detect_parameters) {
|
||||
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "Function::prerequirements");
|
||||
|
||||
m_shared_rt_info = std::make_shared<SharedRTInfo>();
|
||||
|
||||
const auto& ordered_ops = get_ordered_ops();
|
||||
if (detect_parameters)
|
||||
m_parameters = auto_detect_parameters(ordered_ops);
|
||||
@ -254,10 +257,20 @@ void ov::Function::validate_nodes_and_infer_types() const {
|
||||
|
||||
std::vector<shared_ptr<ov::Node>> ov::Function::get_ordered_ops() const {
|
||||
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "Function::get_ordered_ops");
|
||||
lock_guard<mutex> lock(m_topological_sort_mutex);
|
||||
|
||||
vector<shared_ptr<Node>> nodes;
|
||||
for (auto& r : get_results()) {
|
||||
nodes.push_back(r);
|
||||
NodeVector nodes;
|
||||
if (m_shared_rt_info->get_use_topological_cache()) {
|
||||
for (const auto& node : m_cached_ordered_ops) {
|
||||
if (auto locked_node = node.lock()) {
|
||||
nodes.emplace_back(locked_node);
|
||||
}
|
||||
}
|
||||
return nodes;
|
||||
}
|
||||
|
||||
for (const auto& r : get_results()) {
|
||||
nodes.emplace_back(r);
|
||||
}
|
||||
for (auto& r : get_sinks()) {
|
||||
nodes.emplace_back(r);
|
||||
@ -266,7 +279,18 @@ std::vector<shared_ptr<ov::Node>> ov::Function::get_ordered_ops() const {
|
||||
nodes.push_back(param);
|
||||
}
|
||||
|
||||
return m_topological_sorter(nodes);
|
||||
auto order = m_topological_sorter(nodes);
|
||||
|
||||
// Update nodes cache and update all nodes to have shared rt info
|
||||
// which belongs to the current Function.
|
||||
m_cached_ordered_ops.clear();
|
||||
for_each(order.cbegin(), order.cend(), [this](const shared_ptr<Node>& node) {
|
||||
m_cached_ordered_ops.push_back(node);
|
||||
node->insert_info(m_shared_rt_info);
|
||||
});
|
||||
m_shared_rt_info->set_use_topological_cache(true);
|
||||
|
||||
return order;
|
||||
}
|
||||
|
||||
void ov::Function::map_unordered_ops(std::function<void(Node*)> f) const {
|
||||
@ -396,6 +420,8 @@ void ov::Function::replace_parameter(size_t parameter_index, const shared_ptr<ng
|
||||
|
||||
// Installs a new topological sorter for this Function.
// The cached node order was produced by the previous sorter, so the cache flag
// is dropped and the next get_ordered_ops() call has to sort again.
void ov::Function::set_topological_sort(topological_sort_t sorter) {
    this->m_topological_sorter = sorter;
    this->m_shared_rt_info->set_use_topological_cache(false);
}
|
||||
|
||||
int64_t ov::Function::get_parameter_index(const std::shared_ptr<ngraph::op::Parameter>& parameter) const {
|
||||
@ -558,6 +584,9 @@ void ov::Function::add_sinks(const ngraph::SinkVector& sinks) {
|
||||
}
|
||||
}
|
||||
}
|
||||
// reset topological nodes order cache as new sinks/results/parameters
|
||||
// can be in a separate connectivity component.
|
||||
m_shared_rt_info->set_use_topological_cache(false);
|
||||
}
|
||||
|
||||
void ov::Function::remove_sink(const std::shared_ptr<ngraph::op::Sink>& sink) {
|
||||
@ -567,10 +596,14 @@ void ov::Function::remove_sink(const std::shared_ptr<ngraph::op::Sink>& sink) {
|
||||
return s == sink;
|
||||
}),
|
||||
m_sinks.end());
|
||||
m_shared_rt_info->set_use_topological_cache(false);
|
||||
}
|
||||
|
||||
void ov::Function::add_results(const ResultVector& results) {
|
||||
m_results.insert(m_results.end(), results.begin(), results.end());
|
||||
// reset topological nodes order cache as new sinks/results/parameters
|
||||
// can be in a separate connectivity component.
|
||||
m_shared_rt_info->set_use_topological_cache(false);
|
||||
}
|
||||
|
||||
void ov::Function::remove_result(const std::shared_ptr<ngraph::op::Result>& result) {
|
||||
@ -580,6 +613,7 @@ void ov::Function::remove_result(const std::shared_ptr<ngraph::op::Result>& resu
|
||||
return r == result;
|
||||
}),
|
||||
m_results.end());
|
||||
m_shared_rt_info->set_use_topological_cache(false);
|
||||
}
|
||||
|
||||
void ov::Function::add_parameters(const ngraph::ParameterVector& params) {
|
||||
@ -593,6 +627,9 @@ void ov::Function::add_parameters(const ngraph::ParameterVector& params) {
|
||||
}
|
||||
}
|
||||
m_parameters.insert(m_parameters.end(), params.begin(), params.end());
|
||||
// reset topological nodes order cache as new sinks/results/parameters
|
||||
// can be in a separate connectivity component.
|
||||
m_shared_rt_info->set_use_topological_cache(false);
|
||||
}
|
||||
|
||||
void ov::Function::remove_parameter(const std::shared_ptr<ngraph::op::Parameter>& param) {
|
||||
@ -602,6 +639,7 @@ void ov::Function::remove_parameter(const std::shared_ptr<ngraph::op::Parameter>
|
||||
return r == param;
|
||||
}),
|
||||
m_parameters.end());
|
||||
m_shared_rt_info->set_use_topological_cache(false);
|
||||
}
|
||||
|
||||
void ov::Function::add_variables(const op::util::VariableVector& variables) {
|
||||
|
@ -18,6 +18,7 @@
|
||||
#include "ngraph/op/result.hpp"
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
#include "openvino/core/descriptor/input.hpp"
|
||||
#include "shared_node_info.hpp"
|
||||
|
||||
using namespace std;
|
||||
|
||||
@ -25,9 +26,7 @@ atomic<size_t> ov::Node::m_next_instance_id(0);
|
||||
|
||||
ov::Node::Node(const Node& node)
|
||||
: m_control_dependents(node.m_control_dependents),
|
||||
m_control_dependencies(node.m_control_dependencies)
|
||||
// skip m_node_type -- will be generated automatically
|
||||
,
|
||||
m_control_dependencies(node.m_control_dependencies),
|
||||
m_instance_id(m_next_instance_id.fetch_add(1)),
|
||||
m_friendly_name(node.m_friendly_name)
|
||||
// skip m_unique_name -- will be generated automatically
|
||||
@ -60,6 +59,11 @@ ov::Node& ov::Node::operator=(const Node& node) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
void ov::Node::insert_info(std::shared_ptr<SharedRTInfo> info) {
|
||||
std::lock_guard<std::mutex> lock(m_insert_mutex);
|
||||
m_shared_rt_info.insert(std::move(info));
|
||||
}
|
||||
|
||||
// Delegates to the default constructor and then sets the number of outputs.
ov::Node::Node(size_t output_size) : Node() {
    this->set_output_size(output_size);
}
|
||||
@ -70,6 +74,11 @@ ov::Node::Node(const OutputVector& arguments, size_t output_size) : Node() {
|
||||
}
|
||||
|
||||
ov::Node::~Node() {
|
||||
// raise a flag to reset nodes cache
|
||||
for_each(m_shared_rt_info.cbegin(), m_shared_rt_info.cend(), [](const std::shared_ptr<SharedRTInfo>& info) {
|
||||
info->set_use_topological_cache(false);
|
||||
});
|
||||
|
||||
for (descriptor::Input& input : m_inputs) {
|
||||
if (input.has_output()) {
|
||||
// This test adds 1 to the actual count, so a count of 2 means this input is the only
|
||||
@ -162,6 +171,11 @@ void ov::Node::set_arguments(const OutputVector& arguments) {
|
||||
auto& output_descriptor = output_node->m_outputs.at(output.get_index());
|
||||
m_inputs.emplace_back(this, i++, output_descriptor);
|
||||
}
|
||||
|
||||
// set_arguments doesn't use replace_output method, so we have to reset cache manually here
|
||||
for_each(this->m_shared_rt_info.cbegin(), this->m_shared_rt_info.cend(), [](std::shared_ptr<SharedRTInfo> info) {
|
||||
info->set_use_topological_cache(false);
|
||||
});
|
||||
}
|
||||
|
||||
ov::descriptor::Input& ov::Node::get_input_descriptor(size_t position) {
|
||||
@ -279,6 +293,12 @@ void ov::Node::add_control_dependency(std::shared_ptr<Node> node) {
|
||||
node->m_control_dependents.push_back(this);
|
||||
}
|
||||
}
|
||||
|
||||
// control dependency may change the topological order so we have to reset cache
|
||||
// by setting a flag into shared node info.
|
||||
for_each(node->m_shared_rt_info.cbegin(), node->m_shared_rt_info.cend(), [](std::shared_ptr<SharedRTInfo> info) {
|
||||
info->set_use_topological_cache(false);
|
||||
});
|
||||
}
|
||||
|
||||
void ov::Node::add_node_control_dependencies(std::shared_ptr<Node> source_node) {
|
||||
|
28
ngraph/core/src/shared_node_info.hpp
Normal file
28
ngraph/core/src/shared_node_info.hpp
Normal file
@ -0,0 +1,28 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
#include <memory>
#include <openvino/core/except.hpp>
#include <openvino/core/node.hpp>
#include <openvino/core/variant.hpp>
|
||||
|
||||
namespace ov {
|
||||
/// \brief Private runtime information shared between a Function and its nodes.
///
/// It carries a single flag telling whether the cached topological order of the
/// owning Function may be reused. The flag is lowered from node code paths
/// (input replacement, argument changes, destruction) and raised by the Function
/// after a fresh sort. A node may belong to Functions that are processed in
/// different threads, so the flag is atomic: concurrent reads and writes of it
/// are then well defined instead of being a data race.
class SharedRTInfo {
public:
    /// \brief Creates the info with the cache marked as unusable.
    SharedRTInfo() : m_use_topological_cache(false) {}

    /// \brief Marks the cached topological order as usable (true) or stale (false).
    void set_use_topological_cache(bool status) {
        m_use_topological_cache = status;
    }

    /// \brief Returns true if the cached topological order may be reused.
    bool get_use_topological_cache() const {
        return m_use_topological_cache;
    }

private:
    std::atomic<bool> m_use_topological_cache;
};
|
||||
} // namespace ov
|
@ -6,6 +6,9 @@
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <shared_node_info.hpp>
|
||||
#include <test_common.hpp>
|
||||
|
||||
#include "openvino/core/partial_shape.hpp"
|
||||
#include "openvino/opsets/opset8.hpp"
|
||||
|
||||
@ -966,3 +969,262 @@ TEST(function, add_output_port_to_result) {
|
||||
EXPECT_NO_THROW(f->add_output(result->output(0)));
|
||||
EXPECT_EQ(f->get_results().size(), 1);
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool all_ops_have_same_info(const std::shared_ptr<ov::Function>& f) {
|
||||
auto shared_info = ov::FunctionAccessor(f).get_shared_info();
|
||||
for (auto&& op : f->get_ordered_ops()) {
|
||||
if (std::set<std::shared_ptr<ov::SharedRTInfo>>({shared_info}) != ov::NodeAccessor(op).get_shared_info()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(function, topological_sort_caching_basic) {
    // Graph under test: Parameter -> Relu -> Relu -> Result
    auto param = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
    auto first_relu = std::make_shared<ov::opset8::Relu>(param);
    auto second_relu = std::make_shared<ov::opset8::Relu>(first_relu);
    auto res = std::make_shared<ov::opset8::Result>(second_relu);
    auto f = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{param});

    // Function creation calls get_ordered_ops, so the cache flag
    // is expected to be already raised at this point.
    auto function_info = ov::FunctionAccessor(f).get_shared_info();
    ASSERT_TRUE(function_info->get_use_topological_cache());

    // After function creation each node must hold exactly one shared info,
    // and it must be the function's one.
    ASSERT_EQ(ov::NodeAccessor(param).get_shared_info().size(), 1);
    ASSERT_TRUE(ov::NodeAccessor(param).get_shared_info().count(function_info));

    ASSERT_EQ(ov::NodeAccessor(first_relu).get_shared_info().size(), 1);
    ASSERT_TRUE(ov::NodeAccessor(first_relu).get_shared_info().count(function_info));

    ASSERT_EQ(ov::NodeAccessor(second_relu).get_shared_info().size(), 1);
    ASSERT_TRUE(ov::NodeAccessor(second_relu).get_shared_info().count(function_info));

    ASSERT_EQ(ov::NodeAccessor(res).get_shared_info().size(), 1);
    ASSERT_TRUE(ov::NodeAccessor(res).get_shared_info().count(function_info));

    ASSERT_EQ(f->get_ordered_ops().size(), 4);
}
|
||||
|
||||
TEST(function, topological_sort_caching_replace_node) {
    // Graph under test: Parameter -> Relu -> Relu -> Result
    auto param = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
    auto first_relu = std::make_shared<ov::opset8::Relu>(param);
    auto second_relu = std::make_shared<ov::opset8::Relu>(first_relu);
    auto res = std::make_shared<ov::opset8::Result>(second_relu);
    auto f = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{param});

    auto function_info = ov::FunctionAccessor(f).get_shared_info();
    ASSERT_TRUE(function_info->get_use_topological_cache());

    auto replacement = std::make_shared<ov::opset8::Relu>(first_relu);
    ov::replace_node(second_relu, replacement);

    // The function has changed, so the cache flag must have been dropped.
    ASSERT_FALSE(function_info->get_use_topological_cache());

    // The replacement node gets the function's shared info only once
    // get_ordered_ops has run; after that the cache is in use again.
    ASSERT_FALSE(ov::NodeAccessor(replacement).get_shared_info().count(function_info));
    ASSERT_EQ(f->get_ordered_ops().size(), 4);
    ASSERT_TRUE(ov::NodeAccessor(replacement).get_shared_info().count(function_info));
    ASSERT_TRUE(function_info->get_use_topological_cache());
    ASSERT_TRUE(all_ops_have_same_info(f));
}
|
||||
|
||||
TEST(function, topological_sort_caching_replace_source_output) {
    // Graph under test: Parameter -> Relu -> Relu -> Result
    auto param = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
    auto first_relu = std::make_shared<ov::opset8::Relu>(param);
    auto second_relu = std::make_shared<ov::opset8::Relu>(first_relu);
    auto res = std::make_shared<ov::opset8::Result>(second_relu);
    auto f = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{param});

    auto function_info = ov::FunctionAccessor(f).get_shared_info();
    ASSERT_TRUE(function_info->get_use_topological_cache());

    second_relu->input(0).replace_source_output(first_relu);

    // Replacing a source output must drop the cache flag.
    ASSERT_FALSE(function_info->get_use_topological_cache());

    // The next get_ordered_ops call rebuilds the cache.
    ASSERT_EQ(f->get_ordered_ops().size(), 4);
    ASSERT_TRUE(function_info->get_use_topological_cache());
    ASSERT_TRUE(all_ops_have_same_info(f));
}
|
||||
|
||||
TEST(function, topological_sort_caching_dangling_node) {
    // Graph under test: Parameter -> Relu -> Relu -> Result
    auto param = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
    auto first_relu = std::make_shared<ov::opset8::Relu>(param);
    auto second_relu = std::make_shared<ov::opset8::Relu>(first_relu);
    auto res = std::make_shared<ov::opset8::Result>(second_relu);
    auto f = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{param});

    auto function_info = ov::FunctionAccessor(f).get_shared_info();
    ASSERT_TRUE(function_info->get_use_topological_cache());

    // An extra consumer of first_relu that is not reachable from any result.
    auto dangling_relu = std::make_shared<ov::opset8::Relu>(first_relu);

    // The function itself has not changed, so the cache must stay valid.
    ASSERT_TRUE(function_info->get_use_topological_cache());
    // The dangling node is not a part of the function.
    ASSERT_EQ(f->get_ordered_ops().size(), 4);
}
|
||||
|
||||
TEST(function, topological_sort_caching_replace_output) {
    // Graph under test: Parameter -> Relu -> Relu -> Result
    auto param = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
    auto first_relu = std::make_shared<ov::opset8::Relu>(param);
    auto second_relu = std::make_shared<ov::opset8::Relu>(first_relu);
    auto res = std::make_shared<ov::opset8::Result>(second_relu);
    auto f = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{param});

    auto function_info = ov::FunctionAccessor(f).get_shared_info();
    ASSERT_TRUE(function_info->get_use_topological_cache());

    auto replacement = std::make_shared<ov::opset8::Relu>(first_relu);
    second_relu->output(0).replace(replacement);

    // Replacing an output must drop the cache flag.
    ASSERT_FALSE(function_info->get_use_topological_cache());

    // The next get_ordered_ops call rebuilds the cache.
    ASSERT_EQ(f->get_ordered_ops().size(), 4);
    ASSERT_TRUE(function_info->get_use_topological_cache());
    ASSERT_TRUE(all_ops_have_same_info(f));
}
|
||||
|
||||
TEST(function, topological_sort_caching_set_argument) {
    // Graph under test: Parameter -> Relu -> Relu -> Result
    auto param = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
    auto first_relu = std::make_shared<ov::opset8::Relu>(param);
    auto second_relu = std::make_shared<ov::opset8::Relu>(first_relu);
    auto res = std::make_shared<ov::opset8::Result>(second_relu);
    auto f = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{param});

    auto function_info = ov::FunctionAccessor(f).get_shared_info();
    ASSERT_TRUE(function_info->get_use_topological_cache());

    // Connect the second Relu directly to the parameter.
    second_relu->set_argument(0, param);

    // Changing an argument must drop the cache flag.
    ASSERT_FALSE(function_info->get_use_topological_cache());

    // The next get_ordered_ops call rebuilds the cache.
    ASSERT_EQ(f->get_ordered_ops().size(), 3);
    ASSERT_TRUE(function_info->get_use_topological_cache());
    ASSERT_TRUE(all_ops_have_same_info(f));
}
|
||||
|
||||
TEST(function, topological_sort_caching_set_arguments) {
    // Graph under test: Parameter -> Relu -> Relu -> Result
    auto param = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
    auto first_relu = std::make_shared<ov::opset8::Relu>(param);
    auto second_relu = std::make_shared<ov::opset8::Relu>(first_relu);
    auto res = std::make_shared<ov::opset8::Result>(second_relu);
    auto f = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{param});

    auto function_info = ov::FunctionAccessor(f).get_shared_info();
    ASSERT_TRUE(function_info->get_use_topological_cache());

    // Connect the second Relu directly to the parameter output.
    second_relu->set_arguments({param->output(0)});

    // Changing the arguments must drop the cache flag.
    ASSERT_FALSE(function_info->get_use_topological_cache());

    // The next get_ordered_ops call rebuilds the cache.
    ASSERT_EQ(f->get_ordered_ops().size(), 3);
    ASSERT_TRUE(function_info->get_use_topological_cache());
    ASSERT_TRUE(all_ops_have_same_info(f));
}
|
||||
|
||||
TEST(function, topological_sort_caching_add_cf) {
    // Graph under test: Parameter -> Relu -> Relu -> Result
    auto param = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
    auto first_relu = std::make_shared<ov::opset8::Relu>(param);
    auto second_relu = std::make_shared<ov::opset8::Relu>(first_relu);
    auto res = std::make_shared<ov::opset8::Result>(second_relu);
    auto f = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{param});

    auto function_info = ov::FunctionAccessor(f).get_shared_info();
    ASSERT_TRUE(function_info->get_use_topological_cache());

    second_relu->add_control_dependency(param);

    // Adding a control dependency must drop the cache flag.
    ASSERT_FALSE(function_info->get_use_topological_cache());

    // The next get_ordered_ops call rebuilds the cache.
    ASSERT_EQ(f->get_ordered_ops().size(), 4);
    ASSERT_TRUE(function_info->get_use_topological_cache());
    ASSERT_TRUE(all_ops_have_same_info(f));
}
|
||||
|
||||
// Checks that adding or removing results, parameters and sinks drops the
// topological cache flag and that the next get_ordered_ops() call restores it.
TEST(function, topological_sort_caching_result_parameter_sink) {
    auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
    auto relu1 = std::make_shared<ov::opset8::Relu>(arg0);
    auto relu2 = std::make_shared<ov::opset8::Relu>(relu1);
    auto result = std::make_shared<ov::opset8::Result>(relu2);
    auto f = std::make_shared<ov::Function>(ov::ResultVector{result}, ov::ParameterVector{arg0});

    auto shared_info = ov::FunctionAccessor(f).get_shared_info();
    ASSERT_TRUE(shared_info->get_use_topological_cache());

    // Verifies one "modify -> re-sort" cycle: the flag must be dropped before
    // get_ordered_ops() and raised again after it. The expected count is size_t
    // so that it is compared with size() without a signed/unsigned mismatch.
    auto check_caching_status = [=](size_t expected_number_of_ops) {
        ASSERT_FALSE(shared_info->get_use_topological_cache());
        ASSERT_EQ(f->get_ordered_ops().size(), expected_number_of_ops);
        ASSERT_TRUE(shared_info->get_use_topological_cache());
        ASSERT_TRUE(all_ops_have_same_info(f));
    };

    // A fatal assertion inside the lambda only returns from the lambda itself.
    // ASSERT_NO_FATAL_FAILURE makes such a failure stop the whole test instead of
    // letting the following steps run on an already broken state.
    auto result2 = std::make_shared<ov::opset8::Result>(relu2);
    f->add_results({result2});
    ASSERT_NO_FATAL_FAILURE(check_caching_status(5));

    f->remove_result(result2);
    ASSERT_NO_FATAL_FAILURE(check_caching_status(4));

    auto arg1 = std::make_shared<ov::opset8::Parameter>();
    f->add_parameters({arg1});
    ASSERT_NO_FATAL_FAILURE(check_caching_status(5));

    f->remove_parameter(arg1);
    ASSERT_NO_FATAL_FAILURE(check_caching_status(4));

    auto assign = std::make_shared<ov::opset8::Assign>();
    f->add_sinks({assign});
    ASSERT_NO_FATAL_FAILURE(check_caching_status(5));

    f->remove_sink(assign);
    ASSERT_NO_FATAL_FAILURE(check_caching_status(4));
}
|
||||
|
||||
TEST(function, topological_sort_caching_multiple_components) {
    // First component: Parameter -> Relu -> Result
    auto first_param = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
    auto first_relu = std::make_shared<ov::opset8::Relu>(first_param);
    auto first_result = std::make_shared<ov::opset8::Result>(first_relu);

    // Second component, not connected to the first one: Parameter -> Relu -> Result
    auto second_param = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
    auto second_relu = std::make_shared<ov::opset8::Relu>(second_param);
    auto second_result = std::make_shared<ov::opset8::Result>(second_relu);

    auto f = std::make_shared<ov::Function>(ov::ResultVector{first_result, second_result},
                                            ov::ParameterVector{first_param, second_param});

    // Nodes of both components must share the function's info.
    auto function_info = ov::FunctionAccessor(f).get_shared_info();
    ASSERT_TRUE(function_info->get_use_topological_cache());
    ASSERT_TRUE(all_ops_have_same_info(f));
}
|
||||
|
||||
TEST(function, topological_sort_caching_shared_nodes) {
    // Graph under test: Parameter -> Relu -> Result, owned by two functions at once.
    auto param = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
    auto relu = std::make_shared<ov::opset8::Relu>(param);
    auto res = std::make_shared<ov::opset8::Result>(relu);

    auto f1 = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{param});
    auto f2 = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{param});

    auto f1_info = ov::FunctionAccessor(f1).get_shared_info();
    auto f2_info = ov::FunctionAccessor(f2).get_shared_info();

    // Both functions own the same nodes, so every node carries two shared info objects.
    for (auto&& op : f1->get_ordered_ops()) {
        auto op_info = ov::NodeAccessor(op).get_shared_info();
        ASSERT_EQ(op_info.size(), 2);
        ASSERT_TRUE(op_info.count(f1_info));
        ASSERT_TRUE(op_info.count(f2_info));
    }

    // Any graph modification is enough here: the cache must be dropped for both functions.
    relu->add_control_dependency(param);
    ASSERT_FALSE(f1_info->get_use_topological_cache());
    ASSERT_FALSE(f2_info->get_use_topological_cache());
}
|
||||
|
@ -16,5 +16,7 @@ target_include_directories(ngraph_test_util PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
file(GLOB_RECURSE all_util_src "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp")
|
||||
add_clang_format_target(ngraph_test_util_clang FOR_SOURCES ${all_util_src})
|
||||
|
||||
set_source_files_properties(${all_util_src} PROPERTIES INCLUDE_DIRECTORIES ${CMAKE_CURRENT_SOURCE_DIR}/../../core/src/)
|
||||
|
||||
# developer package
|
||||
openvino_developer_export_targets(COMPONENT ngraph TARGETS ngraph_test_util)
|
||||
|
@ -86,4 +86,17 @@ std::string TestsCommon::GetTestName() const {
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
std::shared_ptr<SharedRTInfo> FunctionAccessor::get_shared_info() const {
|
||||
if (auto f = m_function.lock()) {
|
||||
return f->m_shared_rt_info;
|
||||
}
|
||||
throw ngraph::ngraph_error("Original function is not available");
|
||||
}
|
||||
|
||||
std::set<std::shared_ptr<SharedRTInfo>> NodeAccessor::get_shared_info() const {
|
||||
if (auto node = m_node.lock()) {
|
||||
return node->m_shared_rt_info;
|
||||
}
|
||||
throw ngraph::ngraph_error("Original node is not available");
|
||||
}
|
||||
} // namespace ov
|
||||
|
@ -7,7 +7,10 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <memory>
|
||||
#include <openvino/core/function.hpp>
|
||||
#include <shared_node_info.hpp>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
@ -22,4 +25,22 @@ protected:
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
|
||||
class FunctionAccessor {
|
||||
std::weak_ptr<Function> m_function;
|
||||
|
||||
public:
|
||||
FunctionAccessor(std::weak_ptr<Function> f) : m_function(std::move(f)) {}
|
||||
|
||||
std::shared_ptr<SharedRTInfo> get_shared_info() const;
|
||||
};
|
||||
|
||||
class NodeAccessor {
|
||||
std::weak_ptr<Node> m_node;
|
||||
|
||||
public:
|
||||
NodeAccessor(std::weak_ptr<Node> node) : m_node(std::move(node)) {}
|
||||
|
||||
std::set<std::shared_ptr<SharedRTInfo>> get_shared_info() const;
|
||||
};
|
||||
} // namespace ov
|
||||
|
Loading…
Reference in New Issue
Block a user