Test Assign-ReadValue shape propagation while function changes input shape (#4515)

This commit is contained in:
Evgenya Stepyreva
2021-03-01 07:18:30 +03:00
committed by GitHub
parent cadcd7c926
commit a3458a2e0c

View File

@@ -16,7 +16,9 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/variable.hpp"
#include "ngraph/opsets/opset5.hpp"
#include "ngraph/opsets/opset6.hpp"
#include "util/type_prop.hpp"
using namespace std;
@@ -51,3 +53,26 @@ TEST(type_prop, assign_deduce)
ASSERT_EQ(assign->get_element_type(), element::f32);
ASSERT_EQ(assign->get_shape(), (Shape{1, 2, 64, 64}));
}
TEST(type_prop, assign_read_value_new_shape)
{
auto input = make_shared<op::Parameter>(element::f16, Shape{4, 3, 2, 1});
auto variable =
std::make_shared<Variable>(VariableInfo{PartialShape::dynamic(), element::dynamic, "ID"});
auto read_value = make_shared<opset6::ReadValue>(input, variable);
auto assign = make_shared<opset6::Assign>(read_value, variable);
ASSERT_EQ(assign->get_element_type(), element::f16);
ASSERT_EQ(assign->get_shape(), (Shape{4, 3, 2, 1}));
auto f = std::make_shared<Function>(ResultVector{}, SinkVector{assign}, ParameterVector{input});
input->set_partial_shape({3, {4, 5}, 8});
f->validate_nodes_and_infer_types();
ASSERT_EQ(assign->get_element_type(), element::f16);
ASSERT_EQ(assign->get_output_partial_shape(0), (PartialShape{3, {4, 5}, 8}));
ASSERT_EQ(variable->get_info().data_type, element::f16);
ASSERT_EQ(variable->get_info().data_shape, (PartialShape{3, {4, 5}, 8}));
}