diff --git a/inference-engine/tests/functional/inference_engine/transformations/eliminate_split_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/eliminate_split_test.cpp new file mode 100644 index 00000000000..e6f17ba7774 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/transformations/eliminate_split_test.cpp @@ -0,0 +1,81 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + +using namespace testing; + +TEST_F(TransformationTestsF, EliminateSplit) { + { + auto input = std::make_shared(ngraph::element::f32, ngraph::PartialShape::dynamic()); + auto mul_constant = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {89.2}); + auto mul = std::make_shared(input, mul_constant); + auto axis_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {2}); + auto split = std::make_shared(mul, axis_const, 1); + auto res = std::make_shared(split); + function = std::make_shared(ngraph::NodeVector{res}, ngraph::ParameterVector{input}); + + manager.register_pass(); + } + { + auto input = std::make_shared(ngraph::element::f32, ngraph::PartialShape::dynamic()); + auto mul_constant = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {89.2}); + auto mul = std::make_shared(input, mul_constant); + auto res = std::make_shared(mul); + function_ref = std::make_shared(ngraph::NodeVector{res}, ngraph::ParameterVector{input}); + } +} + +TEST_F(TransformationTestsF, EliminateSplitNegative) { + { + auto input = std::make_shared(ngraph::element::f32, ngraph::PartialShape::dynamic()); + auto mul_constant = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {89.2}); + auto mul = std::make_shared(input, mul_constant); + auto axis_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {2}); + auto split = std::make_shared(mul, axis_const, 3); + auto res1 = std::make_shared(split->output(0)); + auto res2 = std::make_shared(split->output(1)); + auto res3 = std::make_shared(split->output(2)); + function = std::make_shared(ngraph::NodeVector{res1, res2, res3}, ngraph::ParameterVector{input}); + + manager.register_pass(); + } +} + +TEST_F(TransformationTestsF, EliminateSequenceOfSplits) { + { + auto input = std::make_shared(ngraph::element::f32, ngraph::PartialShape::dynamic()); + auto axis_const1 = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}); + auto axis_const2 = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1}); + auto axis_const3 = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {2}); + auto split1 = std::make_shared(input, axis_const1, 1); + auto split2 = std::make_shared(split1, axis_const2, 1); + auto split3 = std::make_shared(split2, axis_const3, 1); + auto axis_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {2}); + auto true_split = std::make_shared(split3, axis_const, 3); + auto res1 = std::make_shared(true_split->output(0)); + auto res2 = std::make_shared(true_split->output(1)); + auto res3 = std::make_shared(true_split->output(2)); + function = std::make_shared(ngraph::NodeVector{res1, res2, res3}, ngraph::ParameterVector{input}); + + manager.register_pass(); + } + + { + auto input = std::make_shared(ngraph::element::f32, ngraph::PartialShape::dynamic()); + auto axis_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {2}); + auto split = std::make_shared(input, axis_const, 3); + auto res1 = std::make_shared(split->output(0)); + auto res2 = std::make_shared(split->output(1)); + auto res3 = std::make_shared(split->output(2)); + function_ref = std::make_shared(ngraph::NodeVector{res1, res2, res3}, ngraph::ParameterVector{input}); + } +} \ No newline at end of file