GNA add RemoveExtraReshapes tranformation tests (#6461)
* add RemoveExtraReshapes transformation tests * use clone function instead of creating reference function code duplicate
This commit is contained in:
parent
b231b2b576
commit
30b4b4881e
@ -0,0 +1,89 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "transformations/remove_extra_reshapes.hpp"
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
#include <ngraph/function.hpp>
|
||||||
|
#include <ngraph/opsets/opset7.hpp>
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
|
||||||
|
namespace testing {
|
||||||
|
|
||||||
|
TEST(TransformationTests, RemoveExtraReshapesTestReshapeNotEqualInputOutput) {
|
||||||
|
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
|
||||||
|
const ngraph::Shape data_shape{1, 3, 64, 64};
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, data_shape);
|
||||||
|
auto new_shape = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 3, 64 * 64});
|
||||||
|
auto reshape_operation = std::make_shared<ngraph::opset7::Reshape>(input_params, new_shape, true);
|
||||||
|
auto max_pool_operation = std::make_shared<ngraph::opset7::MaxPool>(reshape_operation,
|
||||||
|
ngraph::Strides{1},
|
||||||
|
ngraph::Shape{0},
|
||||||
|
ngraph::Shape{0},
|
||||||
|
ngraph::Shape{3});
|
||||||
|
auto result = std::make_shared<ngraph::opset7::Result>(max_pool_operation);
|
||||||
|
func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
|
||||||
|
ngraph::ParameterVector{input_params});
|
||||||
|
|
||||||
|
reference_func = ngraph::clone_function(*func);
|
||||||
|
|
||||||
|
ngraph::pass::Manager m;
|
||||||
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
m.register_pass<GNAPluginNS::RemoveExtraReshapes>();
|
||||||
|
m.run_passes(func);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(func));
|
||||||
|
}
|
||||||
|
|
||||||
|
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
|
||||||
|
const FunctionsComparator::Result result = func_comparator(func, reference_func);
|
||||||
|
ASSERT_TRUE(result.valid);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, RemoveExtraReshapesTestReshapeEqualInputOutput) {
|
||||||
|
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
|
||||||
|
const ngraph::Shape data_shape{1, 3, 64, 64};
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, data_shape);
|
||||||
|
auto new_shape = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 3, 64, 64});
|
||||||
|
auto reshape_operation = std::make_shared<ngraph::opset7::Reshape>(input_params, new_shape, true);
|
||||||
|
auto max_pool_operation = std::make_shared<ngraph::opset7::MaxPool>(reshape_operation,
|
||||||
|
ngraph::Strides{1, 1},
|
||||||
|
ngraph::Shape{0, 0},
|
||||||
|
ngraph::Shape{0, 0},
|
||||||
|
ngraph::Shape{3, 3});
|
||||||
|
auto result = std::make_shared<ngraph::opset7::Result>(max_pool_operation);
|
||||||
|
func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
|
||||||
|
ngraph::ParameterVector{input_params});
|
||||||
|
|
||||||
|
ngraph::pass::Manager m;
|
||||||
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
m.register_pass<GNAPluginNS::RemoveExtraReshapes>();
|
||||||
|
m.run_passes(func);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(func));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, data_shape);
|
||||||
|
auto max_pool_operation = std::make_shared<ngraph::opset7::MaxPool>(input_params,
|
||||||
|
ngraph::Strides{1, 1},
|
||||||
|
ngraph::Shape{0, 0},
|
||||||
|
ngraph::Shape{1, 1},
|
||||||
|
ngraph::Shape{4, 4});
|
||||||
|
auto result = std::make_shared<ngraph::opset7::Result>(max_pool_operation);
|
||||||
|
reference_func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
|
||||||
|
ngraph::ParameterVector{input_params});
|
||||||
|
}
|
||||||
|
|
||||||
|
const FunctionsComparator func_comparator = FunctionsComparator::with_default();
|
||||||
|
const FunctionsComparator::Result result = func_comparator(func, reference_func);
|
||||||
|
ASSERT_TRUE(result.valid);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace testing
|
Loading…
Reference in New Issue
Block a user