GNA add RemoveExtraReshapes tranformation tests (#6461)
* add RemoveExtraReshapes transformation tests * use clone function instead of creating reference function code duplicate
This commit is contained in:
parent
b231b2b576
commit
30b4b4881e
@ -0,0 +1,89 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "transformations/remove_extra_reshapes.hpp"
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
namespace testing {
|
||||
|
||||
TEST(TransformationTests, RemoveExtraReshapesTestReshapeNotEqualInputOutput) {
|
||||
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
|
||||
const ngraph::Shape data_shape{1, 3, 64, 64};
|
||||
|
||||
{
|
||||
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, data_shape);
|
||||
auto new_shape = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 3, 64 * 64});
|
||||
auto reshape_operation = std::make_shared<ngraph::opset7::Reshape>(input_params, new_shape, true);
|
||||
auto max_pool_operation = std::make_shared<ngraph::opset7::MaxPool>(reshape_operation,
|
||||
ngraph::Strides{1},
|
||||
ngraph::Shape{0},
|
||||
ngraph::Shape{0},
|
||||
ngraph::Shape{3});
|
||||
auto result = std::make_shared<ngraph::opset7::Result>(max_pool_operation);
|
||||
func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
|
||||
ngraph::ParameterVector{input_params});
|
||||
|
||||
reference_func = ngraph::clone_function(*func);
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<GNAPluginNS::RemoveExtraReshapes>();
|
||||
m.run_passes(func);
|
||||
ASSERT_NO_THROW(check_rt_info(func));
|
||||
}
|
||||
|
||||
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
|
||||
const FunctionsComparator::Result result = func_comparator(func, reference_func);
|
||||
ASSERT_TRUE(result.valid);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, RemoveExtraReshapesTestReshapeEqualInputOutput) {
|
||||
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
|
||||
const ngraph::Shape data_shape{1, 3, 64, 64};
|
||||
|
||||
{
|
||||
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, data_shape);
|
||||
auto new_shape = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 3, 64, 64});
|
||||
auto reshape_operation = std::make_shared<ngraph::opset7::Reshape>(input_params, new_shape, true);
|
||||
auto max_pool_operation = std::make_shared<ngraph::opset7::MaxPool>(reshape_operation,
|
||||
ngraph::Strides{1, 1},
|
||||
ngraph::Shape{0, 0},
|
||||
ngraph::Shape{0, 0},
|
||||
ngraph::Shape{3, 3});
|
||||
auto result = std::make_shared<ngraph::opset7::Result>(max_pool_operation);
|
||||
func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
|
||||
ngraph::ParameterVector{input_params});
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<GNAPluginNS::RemoveExtraReshapes>();
|
||||
m.run_passes(func);
|
||||
ASSERT_NO_THROW(check_rt_info(func));
|
||||
}
|
||||
|
||||
{
|
||||
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, data_shape);
|
||||
auto max_pool_operation = std::make_shared<ngraph::opset7::MaxPool>(input_params,
|
||||
ngraph::Strides{1, 1},
|
||||
ngraph::Shape{0, 0},
|
||||
ngraph::Shape{1, 1},
|
||||
ngraph::Shape{4, 4});
|
||||
auto result = std::make_shared<ngraph::opset7::Result>(max_pool_operation);
|
||||
reference_func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
|
||||
ngraph::ParameterVector{input_params});
|
||||
}
|
||||
|
||||
const FunctionsComparator func_comparator = FunctionsComparator::with_default();
|
||||
const FunctionsComparator::Result result = func_comparator(func, reference_func);
|
||||
ASSERT_TRUE(result.valid);
|
||||
}
|
||||
|
||||
} // namespace testing
|
Loading…
Reference in New Issue
Block a user