Fix clone affinity (#4577)

* Added test

* Fixed code style

* Removed propagate_rt_info function

* Revert "Removed propagate_rt_info function"

This reverts commit 514a21b93c.

* Don't skip conflicted RT info

* Fixed code style
This commit is contained in:
Ilya Churaev 2021-03-04 13:01:44 +03:00 committed by GitHub
parent 2983962449
commit fff0a85910
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 68 additions and 0 deletions

View File

@ -1248,7 +1248,19 @@ void propagate_rt_info(Node* node, const Output<Node>& final_port)
if (stop_nodes.count(in.get_node()))
continue;
auto consumer = in.get_node()->shared_from_this();
// FIXME: Here we have a WA in order to save some original fields
// if we have conflicts because Variant merge doesn't work.
// We can restore original fields because we don't change the operation
auto orig_rt_info = consumer->get_rt_info();
copy_runtime_info({curr_node, consumer}, consumer);
auto& rt_info = consumer->get_rt_info();
for (const auto& it : orig_rt_info)
{
if (rt_info.find(it.first) == rt_info.end())
rt_info[it.first] = it.second;
}
}
}
}

View File

@ -26,6 +26,7 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/op_annotations.hpp"
#include "ngraph/opsets/opset6.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "util/all_close.hpp"
@ -261,6 +262,61 @@ TEST(graph_util, clone_multiple_results)
auto copy = clone_function(*f);
}
TEST(graph_util, clone_rt_info)
{
const std::string testAffinity = "CPU";
std::shared_ptr<ngraph::Function> original_f;
{
ngraph::PartialShape shape({1, 84});
ngraph::element::Type type(ngraph::element::Type_t::f32);
auto param = std::make_shared<ngraph::opset6::Parameter>(type, shape);
auto matMulWeights =
ngraph::opset6::Constant::create(ngraph::element::Type_t::f32, {10, 84}, {1});
auto shapeOf = std::make_shared<ngraph::opset6::ShapeOf>(matMulWeights);
auto gConst1 = ngraph::opset6::Constant::create(ngraph::element::Type_t::i32, {1}, {1});
auto gConst2 = ngraph::opset6::Constant::create(ngraph::element::Type_t::i64, {}, {0});
auto gather = std::make_shared<ngraph::opset6::Gather>(shapeOf, gConst1, gConst2);
auto concatConst = ngraph::opset6::Constant::create(ngraph::element::Type_t::i64, {1}, {1});
auto concat =
std::make_shared<ngraph::opset6::Concat>(ngraph::NodeVector{concatConst, gather}, 0);
auto relu = std::make_shared<ngraph::opset6::Relu>(param);
auto reshape = std::make_shared<ngraph::opset6::Reshape>(relu, concat, false);
auto matMul = std::make_shared<ngraph::opset6::MatMul>(reshape, matMulWeights, false, true);
auto matMulBias =
ngraph::opset6::Constant::create(ngraph::element::Type_t::f32, {1, 10}, {1});
auto addBias = std::make_shared<ngraph::opset6::Add>(matMul, matMulBias);
auto result = std::make_shared<ngraph::opset6::Result>(addBias);
ngraph::ParameterVector params = {param};
ngraph::ResultVector results = {result};
original_f = std::make_shared<ngraph::Function>(results, params);
}
std::unordered_map<std::string, std::string> affinity;
for (auto&& node : original_f->get_ordered_ops())
{
auto& nodeInfo = node->get_rt_info();
nodeInfo["affinity"] = std::make_shared<ngraph::VariantWrapper<std::string>>(testAffinity);
affinity[node->get_friendly_name()] = testAffinity;
}
auto clonedFunction = ngraph::clone_function(*original_f);
for (auto&& node : clonedFunction->get_ordered_ops())
{
auto& nodeInfo = node->get_rt_info();
auto itInfo = nodeInfo.find("affinity");
ASSERT_TRUE(itInfo != nodeInfo.end());
auto value =
ngraph::as_type_ptr<ngraph::VariantWrapper<std::string>>(itInfo->second)->get();
ASSERT_TRUE(affinity.find(node->get_friendly_name()) != affinity.end());
ASSERT_TRUE(affinity[node->get_friendly_name()] == value);
}
}
TEST(util, round_up)
{
EXPECT_EQ(0, round_up(0, 4));