Fix clone_function method in case of Assign/ReadValue v3 (#7406)

This commit is contained in:
Ivan Tikhonov 2021-09-08 20:35:24 +03:00 committed by GitHub
parent f89b3d770b
commit 7bc6a8ea13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 3 deletions

View File

@ -348,8 +348,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
VariableVector cloned_vars;
std::map<std::string, std::shared_ptr<Variable>> var_map;
for (const auto& var : variables) {
auto cloned_var = std::make_shared<Variable>(
VariableInfo{PartialShape::dynamic(), element::dynamic, var->get_info().variable_id});
auto cloned_var = std::make_shared<Variable>(var->get_info());
cloned_vars.push_back(cloned_var);
var_map[cloned_var->get_info().variable_id] = cloned_var;
}

View File

@ -13,6 +13,7 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/op_annotations.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/opsets/opset6.hpp"
#include "ngraph/opsets/opset8.hpp"
#include "ngraph/pass/manager.hpp"
@ -236,7 +237,7 @@ TEST(graph_util, clone_multiple_results) {
auto copy = clone_function(*f);
}
TEST(graph_util, clone_function_variables) {
TEST(graph_util, clone_function_variables_dynamic) {
auto c_fp16 = make_shared<opset8::Constant>(element::f16, Shape{3}, std::vector<float>{0});
auto variable = make_shared<Variable>(VariableInfo{PartialShape::dynamic(), element::dynamic, "var_1"});
auto read_value = make_shared<opset8::ReadValue>(c_fp16, variable);
@ -253,6 +254,19 @@ TEST(graph_util, clone_function_variables) {
copy = clone_function(*f);
}
TEST(graph_util, clone_function_variables_validate_partially) {
auto c_fp16 = make_shared<opset8::Constant>(element::f16, Shape{3}, std::vector<float>{0});
auto read_value = make_shared<opset3::ReadValue>(c_fp16, "var_1");
auto assign = make_shared<opset3::Assign>(read_value, "var_1");
auto res = make_shared<opset3::Result>(read_value);
auto f = make_shared<Function>(ResultVector{res}, SinkVector{assign}, ParameterVector{});
f->validate_nodes_and_infer_types();
NodeMap nm;
auto copy = clone_function(*f, nm);
nm[assign.get()]->validate_and_infer_types();
}
TEST(graph_util, clone_rt_info) {
const std::string testAffinity = "CPU";
std::shared_ptr<ngraph::Function> original_f;