Remove GetOutputElement op (#1604)
This commit is contained in:
@@ -120,59 +120,3 @@ TEST(replace_node, replace_nodes)
|
||||
ASSERT_EQ(z_replacement->get_input_node_shared_ptr(0), x_replacement);
|
||||
ASSERT_EQ(z_replacement->get_input_node_shared_ptr(1), mul);
|
||||
}
|
||||
|
||||
TEST(replace_node, replace_nodes_output_order)
|
||||
{
|
||||
auto data = make_shared<op::Parameter>(element::f16, Shape{4, 3});
|
||||
auto topk_v0 = make_shared<op::v0::TopK>(data, 0, element::i32, 2, true);
|
||||
|
||||
auto topk_v1 = make_shared<op::v1::TopK>(data,
|
||||
op::Constant::create(element::i32, Shape{}, {2}),
|
||||
0,
|
||||
op::v1::TopK::Mode::MAX,
|
||||
op::v1::TopK::SortType::SORT_VALUES,
|
||||
element::i32);
|
||||
|
||||
auto values = make_shared<op::GetOutputElement>(topk_v1, 0);
|
||||
auto indices = make_shared<op::GetOutputElement>(topk_v1, 1);
|
||||
|
||||
ASSERT_EQ(values->get_input_element_type(0), element::f16);
|
||||
ASSERT_EQ(indices->get_input_element_type(0), element::i32);
|
||||
|
||||
std::vector<int64_t> output_order{1, 0};
|
||||
replace_node(topk_v1, topk_v0, output_order);
|
||||
|
||||
ASSERT_EQ(values->get_input_element_type(0), element::f16);
|
||||
ASSERT_EQ(indices->get_input_element_type(0), element::i32);
|
||||
}
|
||||
|
||||
TEST(replace_node, replace_nodes_output_order_incorrect_size)
|
||||
{
|
||||
auto data = make_shared<op::Parameter>(element::f16, Shape{4, 3});
|
||||
auto topk_v0 = make_shared<op::v0::TopK>(data, 0, element::i32, 2, true);
|
||||
|
||||
auto topk_v1 = make_shared<op::v1::TopK>(data,
|
||||
op::Constant::create(element::i32, Shape{}, {2}),
|
||||
0,
|
||||
op::v1::TopK::Mode::MAX,
|
||||
op::v1::TopK::SortType::SORT_VALUES,
|
||||
element::i32);
|
||||
|
||||
auto values = make_shared<op::GetOutputElement>(topk_v1, 0);
|
||||
auto indices = make_shared<op::GetOutputElement>(topk_v1, 1);
|
||||
|
||||
std::vector<int64_t> output_order{2, 1, 0};
|
||||
try
|
||||
{
|
||||
replace_node(topk_v1, topk_v0, output_order);
|
||||
FAIL() << "Incorrect output order size exception not detected";
|
||||
}
|
||||
catch (const ngraph_error& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Target output size: "));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Incorrect output order size exception not thrown for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user