Fix clone affinity (#4577)
* Added test
* Fixed code style
* Removed propagate_rt_info function
* Revert "Removed propagate_rt_info function"
This reverts commit 514a21b93c
.
* Don't skip conflicted RT info
* Fixed code style
This commit is contained in:
parent
2983962449
commit
fff0a85910
@ -1248,7 +1248,19 @@ void propagate_rt_info(Node* node, const Output<Node>& final_port)
|
||||
if (stop_nodes.count(in.get_node()))
|
||||
continue;
|
||||
auto consumer = in.get_node()->shared_from_this();
|
||||
// FIXME: Here we have a WA in order to save some original fields
|
||||
// if we have conflicts because Variant merge doesn't work.
|
||||
// We can restore original fields because we don't change the operation
|
||||
auto orig_rt_info = consumer->get_rt_info();
|
||||
|
||||
copy_runtime_info({curr_node, consumer}, consumer);
|
||||
|
||||
auto& rt_info = consumer->get_rt_info();
|
||||
for (const auto& it : orig_rt_info)
|
||||
{
|
||||
if (rt_info.find(it.first) == rt_info.end())
|
||||
rt_info[it.first] = it.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -26,6 +26,7 @@
|
||||
#include "ngraph/graph_util.hpp"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/op/util/op_annotations.hpp"
|
||||
#include "ngraph/opsets/opset6.hpp"
|
||||
#include "ngraph/pass/manager.hpp"
|
||||
#include "ngraph/pass/visualize_tree.hpp"
|
||||
#include "util/all_close.hpp"
|
||||
@ -261,6 +262,61 @@ TEST(graph_util, clone_multiple_results)
|
||||
auto copy = clone_function(*f);
|
||||
}
|
||||
|
||||
TEST(graph_util, clone_rt_info)
|
||||
{
|
||||
const std::string testAffinity = "CPU";
|
||||
std::shared_ptr<ngraph::Function> original_f;
|
||||
{
|
||||
ngraph::PartialShape shape({1, 84});
|
||||
ngraph::element::Type type(ngraph::element::Type_t::f32);
|
||||
auto param = std::make_shared<ngraph::opset6::Parameter>(type, shape);
|
||||
auto matMulWeights =
|
||||
ngraph::opset6::Constant::create(ngraph::element::Type_t::f32, {10, 84}, {1});
|
||||
auto shapeOf = std::make_shared<ngraph::opset6::ShapeOf>(matMulWeights);
|
||||
auto gConst1 = ngraph::opset6::Constant::create(ngraph::element::Type_t::i32, {1}, {1});
|
||||
auto gConst2 = ngraph::opset6::Constant::create(ngraph::element::Type_t::i64, {}, {0});
|
||||
auto gather = std::make_shared<ngraph::opset6::Gather>(shapeOf, gConst1, gConst2);
|
||||
auto concatConst = ngraph::opset6::Constant::create(ngraph::element::Type_t::i64, {1}, {1});
|
||||
auto concat =
|
||||
std::make_shared<ngraph::opset6::Concat>(ngraph::NodeVector{concatConst, gather}, 0);
|
||||
auto relu = std::make_shared<ngraph::opset6::Relu>(param);
|
||||
auto reshape = std::make_shared<ngraph::opset6::Reshape>(relu, concat, false);
|
||||
auto matMul = std::make_shared<ngraph::opset6::MatMul>(reshape, matMulWeights, false, true);
|
||||
auto matMulBias =
|
||||
ngraph::opset6::Constant::create(ngraph::element::Type_t::f32, {1, 10}, {1});
|
||||
auto addBias = std::make_shared<ngraph::opset6::Add>(matMul, matMulBias);
|
||||
auto result = std::make_shared<ngraph::opset6::Result>(addBias);
|
||||
|
||||
ngraph::ParameterVector params = {param};
|
||||
ngraph::ResultVector results = {result};
|
||||
|
||||
original_f = std::make_shared<ngraph::Function>(results, params);
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, std::string> affinity;
|
||||
|
||||
for (auto&& node : original_f->get_ordered_ops())
|
||||
{
|
||||
auto& nodeInfo = node->get_rt_info();
|
||||
|
||||
nodeInfo["affinity"] = std::make_shared<ngraph::VariantWrapper<std::string>>(testAffinity);
|
||||
affinity[node->get_friendly_name()] = testAffinity;
|
||||
}
|
||||
|
||||
auto clonedFunction = ngraph::clone_function(*original_f);
|
||||
|
||||
for (auto&& node : clonedFunction->get_ordered_ops())
|
||||
{
|
||||
auto& nodeInfo = node->get_rt_info();
|
||||
auto itInfo = nodeInfo.find("affinity");
|
||||
ASSERT_TRUE(itInfo != nodeInfo.end());
|
||||
auto value =
|
||||
ngraph::as_type_ptr<ngraph::VariantWrapper<std::string>>(itInfo->second)->get();
|
||||
ASSERT_TRUE(affinity.find(node->get_friendly_name()) != affinity.end());
|
||||
ASSERT_TRUE(affinity[node->get_friendly_name()] == value);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(util, round_up)
|
||||
{
|
||||
EXPECT_EQ(0, round_up(0, 4));
|
||||
|
Loading…
Reference in New Issue
Block a user