nGraph version of the MO transformation SplitConcatPairToInte… (#7850)

* Written nGraph version of the MO transformation SplitConcatPairToInterpolate.

* Small fix.

* Started to write tests for the transformation.

* Small fixes.

* Written more tests.

* Deleted commented code.

* Deleted debug prints.

* Added the transformation SplitConcatPairToInterpolateFusion into common_fusions.

* Small fix.

* Relaced std::set by std::unordered_set.

* Now the function grouped_vector is not template function.

* Small simplification.

* Deleted commented code.

* Now std::pair is used instead of SplitAndScale.

* Enabled the transformation SplitConcatPairToInterpolateFusion and added it into MOCTransformations pipeline.

* Removed the transformation from common_optimization.cpp.

* Small fixes.

* Added comment to the function grouped_vector.

* Deleted the local variable result from the function get_split_before_concat().

* Small change.

* Added comments for the conditions on input rank.

* Used std::tie instead of .first and .second.

* Skipped the ngraph namespace specification in the transformation callback.

* Got rid of std::unordered_set of std::shared_ptr<ngraph::opset8::Split>.

* size_t was replaced with uint64_t.

* Added descrption of the transformation.

* Small fix.

* Added condition that scale_factor != 1.

* Added more tests. Also the transformation has boolean parameter use_shape_for_elimination.

* Small fix.

* Written CPU layer tests for the nGraph transformation SplitConcatPairToInterpolateFusion.

* Some fixes.

* Added tests for the case of dynamic input shapes.
This commit is contained in:
Vladimir Gavrilov 2021-10-25 13:59:42 +03:00 committed by GitHub
parent 8741b1504a
commit 86208efd56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 1048 additions and 0 deletions

View File

@ -0,0 +1,33 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/ngraph.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pattern/matcher.hpp"
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API SplitConcatPairToInterpolateFusion;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief SplitConcatPairToInterpolateFusion transformation replaces group of
* operations: Split -> Concat to Interpolate op.
*/
class ngraph::pass::SplitConcatPairToInterpolateFusion : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
SplitConcatPairToInterpolateFusion(bool use_shape_for_elimination = true);
};

View File

@ -48,6 +48,7 @@
#include <transformations/common_optimizations/transpose_to_reshape.hpp>
#include <transformations/common_optimizations/batch_to_space_fusion.hpp>
#include <transformations/common_optimizations/mul_conv_fusion.hpp>
#include "transformations/common_optimizations/split_concat_pair_to_interpolate_fusion.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0);
@ -120,6 +121,7 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
common_fusions->add_matcher<ngraph::pass::GeluFusion>();
common_fusions->add_matcher<ngraph::pass::LeakyReluFusion>();
common_fusions->add_matcher<ngraph::pass::RandomUniformFusion>();
common_fusions->add_matcher<ngraph::pass::SplitConcatPairToInterpolateFusion>(m_use_shapes);
common_fusions->set_name("ngraph::pass::CommonFusions");
manager.register_pass<ngraph::pass::BinarizeWeights>();

View File

@ -0,0 +1,205 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "itt.hpp"
#include "transformations/common_optimizations/split_concat_pair_to_interpolate_fusion.hpp"
#include <algorithm>
#include <memory>
#include <numeric>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <transformations/rt_info/disable_constant_folding.hpp>
namespace {
// This function creates a partition of its argument into groups consisting of adjacent identical elements.
// Argument: std::vector<uint64_t> v
// Returns: partition of the argument
std::vector<std::vector<uint64_t>> grouped_vector(const std::vector<uint64_t>& v) {
std::vector<std::vector<uint64_t>> result;
if (v.empty()) return result;
uint64_t prev = v[0];
std::vector<uint64_t> group;
for (const auto& x : v) {
if (prev != x) {
result.emplace_back(group);
group.clear();
prev = x;
}
group.emplace_back(x);
}
result.emplace_back(group);
return result;
}
std::pair<std::shared_ptr<ngraph::opset8::Split>, uint64_t> get_split_before_concat(const std::shared_ptr<ngraph::opset8::Concat>& concat) {
// This function gets producers of the 'concat' node, checks that the following conditions are fulfilled:
// 1) all producers for 'concat' are Split nodes;
// 2) 'concat' has only one unique producer ('split');
// 3) 'split' node has only one consumer;
// 4) for any output port of 'split', number of corresponding input ports of the consumer is the same;
// 5) for any output port 'i' of the 'split', corresponding input ports of the consumer are
// [i * m, ..., i * m + (m - 1)], where 'm' is the same for all 'i';
// and, if all these conditions are fulfilled, returns the above mentioned 'Concat' node. Otherwise, if some of these
// conditions is false, this functions returns nullptr.
std::vector<uint64_t> idx;
std::shared_ptr<ngraph::opset8::Split> split;
for (const auto& input : concat->input_values()) {
// If 'concat' has some non-Split producer, then the transformation is not applicable.
auto split_op = std::dynamic_pointer_cast<ngraph::opset8::Split>(input.get_node_shared_ptr());
if (!split)
split = split_op;
if (!split_op || split != split_op) return {};
idx.emplace_back(static_cast<uint64_t>(input.get_index()));
}
// If 'split' node has more than one consumer, then the transformation is not applicable.
for (const auto& output : split->outputs()) {
for (const auto& consumer : output.get_target_inputs()) {
if (consumer.get_node() != concat.get()) return {};
}
}
// If numbers of consumer ports are various for various output ports of 'split', then the transformation is not applicable.
auto grouped_idx = grouped_vector(idx);
std::unordered_set<uint64_t> sizes_of_groups;
for (const auto& group : grouped_idx) {
sizes_of_groups.insert(static_cast<uint64_t>(group.size()));
}
if (sizes_of_groups.size() != 1) return {};
uint64_t size_of_group = *(sizes_of_groups.begin());
// The transformation is applicable if output port 0 of 'split' goes to ports [0, ..., m-1] of next node,
// output port 1 of 'split' goes to ports [m, ..., m + (m-1)] of next node, ..., output port i of 'split'
// goes to ports [i * m, ..., i * m + (m - 1)], and so on.
for (uint64_t i = 0; i < static_cast<uint64_t>(grouped_idx.size()); ++i) {
const auto& current_group = grouped_idx[i];
if (std::any_of(current_group.begin(), current_group.end(), [i](uint64_t j){ return j != i; })) { return {}; }
}
return {split, size_of_group};
}
} // namespace
NGRAPH_RTTI_DEFINITION(ngraph::pass::SplitConcatPairToInterpolateFusion, "SplitConcatPairToInterpolateFusion", 0);
ngraph::pass::SplitConcatPairToInterpolateFusion::SplitConcatPairToInterpolateFusion(bool use_shape_for_elimination) {
MATCHER_SCOPE(SplitConcatPairToInterpolateFusion);
// This transformation looks for Interpolate layer implemented using simple operations, namely Split and Concat,
// and replaces found pattern with a sequence of Shape, StridedSlice, Const, Mul, Interpolate.
// Found pattern:
// Split -> Concat
// Here we assume that
// 1) input data of Split is 4D or 5D tensor;
// 2) split dimensions for 'split' belongs to {1, 2, 3};
// 3) all outputs of 'split' go to only inputs of 'concat';
// 4) 'concat' takes inputs only from 'split';
// 5) split_dim of 'split' is equal to axis of 'concat';
// 6) output port 0 of 'split' goes to ports [0, ..., m-1] of next node, output port 1 of 'split' goes to ports
// [m, ..., m + (m-1)] of next node, ..., output port i of 'split' goes to ports [i * m, ..., i * m + (m - 1)],
// and so on;
// 7) number of outputs of 'split' is equal to the length of the split axis.
// Such subgraph
// Split -> Concat
// can be replaced with the Interpolate layer with the following attributes:
// mode = nearest
// shape_calculation_mode = scales
// nearest_mode = round_prefer_floor
// pads_begin = {0}
// pads_end = {0}
// antialias = false
// coordinate_transformation_mode = half_pixel
// cube_coeff = -0.75
// Next, the scaling factor in Interpolate is equal to a quotient of dividing number of input ports of 'concat'
// by number of output ports of 'split'.
//
// Detect only concat, because we don't know how many inputs will go into concat.
auto concat_pattern = ngraph::pattern::wrap_type<ngraph::opset8::Concat>();
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto concat = std::dynamic_pointer_cast<opset8::Concat>(m.get_match_root());
if (!concat) return false;
uint64_t scale_factor;
std::shared_ptr<opset8::Split> split;
std::tie(split, scale_factor) = get_split_before_concat(concat);
// If scale_factor == 1, then output data of Interpolate are equal to input data. Hence, we should not replace
// Split->Concat pair with Interpolate.
if (!split || !scale_factor || scale_factor == 1) return false;
if (split->get_input_partial_shape(0).rank().is_dynamic()) return false;
int64_t split_input_rank = split->get_input_partial_shape(0).rank().get_length();
// If this transformation is applied in the case of the the rank is less than 4, we have a performance degradation.
// And, at this time, we have no models with Split->Concat pattern when this transformation is applicable and
// input rank of Split is greater than 5.
if (split_input_rank != 4 && split_input_rank != 5) return false;
auto split_axis_const = std::dynamic_pointer_cast<opset8::Constant>(split->input_value(1).get_node_shared_ptr());
if (!split_axis_const) return false;
int64_t axis = split_axis_const->cast_vector<int64_t>()[0];
if (split->get_input_partial_shape(0)[axis].is_static() &&
split->get_input_partial_shape(0)[axis].get_length() != static_cast<int64_t>(split->outputs().size()))
return false;
opset8::Interpolate::InterpolateAttrs attrs;
attrs.mode = opset8::Interpolate::InterpolateMode::NEAREST;
attrs.shape_calculation_mode = opset8::Interpolate::ShapeCalcMode::SCALES;
attrs.nearest_mode = opset8::Interpolate::NearestMode::ROUND_PREFER_FLOOR;
attrs.pads_begin = std::vector<size_t>{0};
attrs.pads_end = std::vector<size_t>{0};
attrs.antialias = false;
attrs.coordinate_transformation_mode = opset8::Interpolate::CoordinateTransformMode::HALF_PIXEL;
attrs.cube_coeff = -0.75f;
auto scales_node = opset8::Constant::create(element::f32, {1}, std::vector<float>{static_cast<float>(scale_factor)});
auto axis_node = opset8::Constant::create(element::i64, {1}, std::vector<int64_t>{axis});
auto shape_node = std::make_shared<opset8::ShapeOf>(split->input_value(0));
auto sslice_begin = opset8::Constant::create(element::i64, {1}, std::vector<int64_t>{axis});
auto sslice_end = opset8::Constant::create(element::i64, {1}, std::vector<int64_t>{axis + 1});
std::vector<int64_t> begin_mask = {0};
std::vector<int64_t> end_mask = {0};
auto strided_slice_node = std::make_shared<opset8::StridedSlice>(shape_node, sslice_begin, sslice_end, begin_mask, end_mask);
auto cast_shape_to_float = std::make_shared<opset8::Convert>(strided_slice_node, element::f32);
auto mul_node = std::make_shared<opset8::Multiply>(cast_shape_to_float, scales_node);
auto floor_node = std::make_shared<opset8::Floor>(mul_node);
auto cast_mul_result_to_int = std::make_shared<opset8::Convert>(floor_node, element::i64);
std::shared_ptr<Node> sizes_node;
if (use_shape_for_elimination) {
sizes_node = get_constant_from_source(cast_mul_result_to_int);
} else {
disable_constant_folding(shape_node);
}
if (!sizes_node)
sizes_node = cast_mul_result_to_int;
auto interpolate = register_new_node<opset8::Interpolate>(split->input_value(0), sizes_node, scales_node, axis_node, attrs);
interpolate->set_friendly_name(concat->get_friendly_name());
copy_runtime_info({split, concat}, {scales_node, axis_node, shape_node, sslice_begin, sslice_end, strided_slice_node, cast_shape_to_float, mul_node,
floor_node, sizes_node, interpolate});
replace_node(concat, interpolate);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(concat_pattern, matcher_name);
register_matcher(m, callback);
}

View File

@ -0,0 +1,684 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/visualize_tree.hpp>
#include <transformations/common_optimizations/split_concat_pair_to_interpolate_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST_F(TransformationTestsF, SplitConcatPairToInterpolateFusionSpatial2D1) {
ngraph::Shape input_shape { 1, 100, 120, 150 };
int64_t axis = 3;
size_t num_splits = input_shape[axis];
size_t scale_factor = 2;
size_t num_of_concat_inputs = num_splits * scale_factor;
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto split_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split = std::make_shared<ngraph::opset8::Split>(input, split_axis, num_splits);
ngraph::OutputVector concat_inputs_vec(num_of_concat_inputs);
for (size_t split_output_port = 0; split_output_port < num_splits; ++split_output_port) {
for (size_t j = 0; j < scale_factor; ++j) {
concat_inputs_vec[split_output_port * scale_factor + j] = split->output(split_output_port);
}
}
auto concat = std::make_shared<ngraph::opset8::Concat>(concat_inputs_vec, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SplitConcatPairToInterpolateFusion>(false);
}
{
ngraph::opset8::Interpolate::InterpolateAttrs attrs;
attrs.mode = ngraph::opset8::Interpolate::InterpolateMode::NEAREST;
attrs.shape_calculation_mode = ngraph::opset8::Interpolate::ShapeCalcMode::SCALES;
attrs.nearest_mode = ngraph::opset8::Interpolate::NearestMode::ROUND_PREFER_FLOOR;
attrs.pads_begin = std::vector<size_t>{0};
attrs.pads_end = std::vector<size_t>{0};
attrs.antialias = false;
attrs.coordinate_transformation_mode = ngraph::opset8::Interpolate::CoordinateTransformMode::HALF_PIXEL;
attrs.cube_coeff = -0.75f;
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, {1}, std::vector<float>{static_cast<float>(scale_factor)});
auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto shape_node = std::make_shared<ngraph::opset8::ShapeOf>(input);
auto sslice_begin = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto sslice_end = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis + 1});
std::vector<int64_t> begin_mask = {0};
std::vector<int64_t> end_mask = {0};
auto strided_slice_node = std::make_shared<ngraph::opset8::StridedSlice>(shape_node, sslice_begin, sslice_end, begin_mask, end_mask);
auto cast_shape_to_float = std::make_shared<ngraph::opset8::Convert>(strided_slice_node, ngraph::element::f32);
auto mul_node = std::make_shared<ngraph::opset8::Multiply>(cast_shape_to_float, scales_node);
auto floor_node = std::make_shared<ngraph::opset8::Floor>(mul_node);
auto cast_mul_result_to_int = std::make_shared<ngraph::opset8::Convert>(floor_node, ngraph::element::i64);
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, cast_mul_result_to_int, scales_node, axis_node, attrs);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
}
}
TEST_F(TransformationTestsF, SplitConcatPairToInterpolateFusionSpatial2D2) {
ngraph::Shape input_shape { 1, 100, 120, 150 };
int64_t axis = 2;
size_t num_splits = input_shape[axis];
size_t scale_factor = 2;
size_t num_of_concat_inputs = num_splits * scale_factor;
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto split_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split = std::make_shared<ngraph::opset8::Split>(input, split_axis, num_splits);
ngraph::OutputVector concat_inputs_vec(num_of_concat_inputs);
for (size_t split_output_port = 0; split_output_port < num_splits; ++split_output_port) {
for (size_t j = 0; j < scale_factor; ++j) {
concat_inputs_vec[split_output_port * scale_factor + j] = split->output(split_output_port);
}
}
auto concat = std::make_shared<ngraph::opset8::Concat>(concat_inputs_vec, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SplitConcatPairToInterpolateFusion>(false);
}
{
ngraph::opset8::Interpolate::InterpolateAttrs attrs;
attrs.mode = ngraph::opset8::Interpolate::InterpolateMode::NEAREST;
attrs.shape_calculation_mode = ngraph::opset8::Interpolate::ShapeCalcMode::SCALES;
attrs.nearest_mode = ngraph::opset8::Interpolate::NearestMode::ROUND_PREFER_FLOOR;
attrs.pads_begin = std::vector<size_t>{0};
attrs.pads_end = std::vector<size_t>{0};
attrs.antialias = false;
attrs.coordinate_transformation_mode = ngraph::opset8::Interpolate::CoordinateTransformMode::HALF_PIXEL;
attrs.cube_coeff = -0.75f;
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, {1}, std::vector<float>{static_cast<float>(scale_factor)});
auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto shape_node = std::make_shared<ngraph::opset8::ShapeOf>(input);
auto sslice_begin = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto sslice_end = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis + 1});
std::vector<int64_t> begin_mask = {0};
std::vector<int64_t> end_mask = {0};
auto strided_slice_node = std::make_shared<ngraph::opset8::StridedSlice>(shape_node, sslice_begin, sslice_end, begin_mask, end_mask);
auto cast_shape_to_float = std::make_shared<ngraph::opset8::Convert>(strided_slice_node, ngraph::element::f32);
auto mul_node = std::make_shared<ngraph::opset8::Multiply>(cast_shape_to_float, scales_node);
auto floor_node = std::make_shared<ngraph::opset8::Floor>(mul_node);
auto cast_mul_result_to_int = std::make_shared<ngraph::opset8::Convert>(floor_node, ngraph::element::i64);
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, cast_mul_result_to_int, scales_node, axis_node, attrs);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
}
}
TEST_F(TransformationTestsF, SplitConcatPairToInterpolateFusionSpatial3D1) {
ngraph::Shape input_shape { 1, 3, 100, 120, 150 };
int64_t axis = 4;
size_t num_splits = input_shape[axis];
size_t scale_factor = 2;
size_t num_of_concat_inputs = num_splits * scale_factor;
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto split_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split = std::make_shared<ngraph::opset8::Split>(input, split_axis, num_splits);
ngraph::OutputVector concat_inputs_vec(num_of_concat_inputs);
for (size_t split_output_port = 0; split_output_port < num_splits; ++split_output_port) {
for (size_t j = 0; j < scale_factor; ++j) {
concat_inputs_vec[split_output_port * scale_factor + j] = split->output(split_output_port);
}
}
auto concat = std::make_shared<ngraph::opset8::Concat>(concat_inputs_vec, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SplitConcatPairToInterpolateFusion>(false);
}
{
ngraph::opset8::Interpolate::InterpolateAttrs attrs;
attrs.mode = ngraph::opset8::Interpolate::InterpolateMode::NEAREST;
attrs.shape_calculation_mode = ngraph::opset8::Interpolate::ShapeCalcMode::SCALES;
attrs.nearest_mode = ngraph::opset8::Interpolate::NearestMode::ROUND_PREFER_FLOOR;
attrs.pads_begin = std::vector<size_t>{0};
attrs.pads_end = std::vector<size_t>{0};
attrs.antialias = false;
attrs.coordinate_transformation_mode = ngraph::opset8::Interpolate::CoordinateTransformMode::HALF_PIXEL;
attrs.cube_coeff = -0.75f;
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, {1}, std::vector<float>{static_cast<float>(scale_factor)});
auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto shape_node = std::make_shared<ngraph::opset8::ShapeOf>(input);
auto sslice_begin = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto sslice_end = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis + 1});
std::vector<int64_t> begin_mask = {0};
std::vector<int64_t> end_mask = {0};
auto strided_slice_node = std::make_shared<ngraph::opset8::StridedSlice>(shape_node, sslice_begin, sslice_end, begin_mask, end_mask);
auto cast_shape_to_float = std::make_shared<ngraph::opset8::Convert>(strided_slice_node, ngraph::element::f32);
auto mul_node = std::make_shared<ngraph::opset8::Multiply>(cast_shape_to_float, scales_node);
auto floor_node = std::make_shared<ngraph::opset8::Floor>(mul_node);
auto cast_mul_result_to_int = std::make_shared<ngraph::opset8::Convert>(floor_node, ngraph::element::i64);
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, cast_mul_result_to_int, scales_node, axis_node, attrs);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
}
}
TEST_F(TransformationTestsF, SplitConcatPairToInterpolateFusionSpatial3D2) {
ngraph::Shape input_shape { 1, 3, 100, 120, 150 };
int64_t axis = 3;
size_t num_splits = input_shape[axis];
size_t scale_factor = 2;
size_t num_of_concat_inputs = num_splits * scale_factor;
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto split_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split = std::make_shared<ngraph::opset8::Split>(input, split_axis, num_splits);
ngraph::OutputVector concat_inputs_vec(num_of_concat_inputs);
for (size_t split_output_port = 0; split_output_port < num_splits; ++split_output_port) {
for (size_t j = 0; j < scale_factor; ++j) {
concat_inputs_vec[split_output_port * scale_factor + j] = split->output(split_output_port);
}
}
auto concat = std::make_shared<ngraph::opset8::Concat>(concat_inputs_vec, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SplitConcatPairToInterpolateFusion>(false);
}
{
ngraph::opset8::Interpolate::InterpolateAttrs attrs;
attrs.mode = ngraph::opset8::Interpolate::InterpolateMode::NEAREST;
attrs.shape_calculation_mode = ngraph::opset8::Interpolate::ShapeCalcMode::SCALES;
attrs.nearest_mode = ngraph::opset8::Interpolate::NearestMode::ROUND_PREFER_FLOOR;
attrs.pads_begin = std::vector<size_t>{0};
attrs.pads_end = std::vector<size_t>{0};
attrs.antialias = false;
attrs.coordinate_transformation_mode = ngraph::opset8::Interpolate::CoordinateTransformMode::HALF_PIXEL;
attrs.cube_coeff = -0.75f;
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, {1}, std::vector<float>{static_cast<float>(scale_factor)});
auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto shape_node = std::make_shared<ngraph::opset8::ShapeOf>(input);
auto sslice_begin = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto sslice_end = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis + 1});
std::vector<int64_t> begin_mask = {0};
std::vector<int64_t> end_mask = {0};
auto strided_slice_node = std::make_shared<ngraph::opset8::StridedSlice>(shape_node, sslice_begin, sslice_end, begin_mask, end_mask);
auto cast_shape_to_float = std::make_shared<ngraph::opset8::Convert>(strided_slice_node, ngraph::element::f32);
auto mul_node = std::make_shared<ngraph::opset8::Multiply>(cast_shape_to_float, scales_node);
auto floor_node = std::make_shared<ngraph::opset8::Floor>(mul_node);
auto cast_mul_result_to_int = std::make_shared<ngraph::opset8::Convert>(floor_node, ngraph::element::i64);
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, cast_mul_result_to_int, scales_node, axis_node, attrs);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
}
}
TEST_F(TransformationTestsF, SplitConcatPairToInterpolateFusionTwoSplitsOneConcat) {
size_t num_splits = 2;
int64_t axis = 4;
{
auto input1 = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 13, 13, 3, 2 });
auto input2 = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 13, 13, 3, 2 });
auto split1_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split1 = std::make_shared<ngraph::opset8::Split>(input1, split1_axis, num_splits);
auto split2_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split2 = std::make_shared<ngraph::opset8::Split>(input2, split2_axis, num_splits);
ngraph::OutputVector concat_inputs_vec{split1->output(0), split1->output(1), split2->output(0), split2->output(1)};
auto concat = std::make_shared<ngraph::opset8::Concat>(concat_inputs_vec, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input1, input2 });
manager.register_pass<ngraph::pass::SplitConcatPairToInterpolateFusion>();
}
{
auto input1 = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 13, 13, 3, 2 });
auto input2 = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 13, 13, 3, 2 });
auto split1_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split1 = std::make_shared<ngraph::opset8::Split>(input1, split1_axis, num_splits);
auto split2_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split2 = std::make_shared<ngraph::opset8::Split>(input2, split2_axis, num_splits);
ngraph::OutputVector concat_inputs_vec{split1->output(0), split1->output(1), split2->output(0), split2->output(1)};
auto concat = std::make_shared<ngraph::opset8::Concat>(concat_inputs_vec, axis);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input1, input2 });
}
}
TEST_F(TransformationTestsF, SplitConcatPairToInterpolateFusionSpatial2D1WithConstantFolding) {
ngraph::Shape input_shape { 1, 100, 120, 150 };
int64_t axis = 3;
size_t num_splits = input_shape[axis];
size_t scale_factor = 2;
size_t num_of_concat_inputs = num_splits * scale_factor;
int64_t target_size = static_cast<int64_t>(input_shape[axis]) * scale_factor;
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto split_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split = std::make_shared<ngraph::opset8::Split>(input, split_axis, num_splits);
ngraph::OutputVector concat_inputs_vec(num_of_concat_inputs);
for (size_t split_output_port = 0; split_output_port < num_splits; ++split_output_port) {
for (size_t j = 0; j < scale_factor; ++j) {
concat_inputs_vec[split_output_port * scale_factor + j] = split->output(split_output_port);
}
}
auto concat = std::make_shared<ngraph::opset8::Concat>(concat_inputs_vec, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SplitConcatPairToInterpolateFusion>();
}
{
ngraph::opset8::Interpolate::InterpolateAttrs attrs;
attrs.mode = ngraph::opset8::Interpolate::InterpolateMode::NEAREST;
attrs.shape_calculation_mode = ngraph::opset8::Interpolate::ShapeCalcMode::SCALES;
attrs.nearest_mode = ngraph::opset8::Interpolate::NearestMode::ROUND_PREFER_FLOOR;
attrs.pads_begin = std::vector<size_t>{0};
attrs.pads_end = std::vector<size_t>{0};
attrs.antialias = false;
attrs.coordinate_transformation_mode = ngraph::opset8::Interpolate::CoordinateTransformMode::HALF_PIXEL;
attrs.cube_coeff = -0.75f;
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, {1}, std::vector<float>{static_cast<float>(scale_factor)});
auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{target_size});
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, sizes_node, scales_node, axis_node, attrs);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
}
}
TEST_F(TransformationTestsF, SplitConcatPairToInterpolateFusionSpatial2D2WithConstantFolding) {
ngraph::Shape input_shape { 1, 100, 120, 150 };
int64_t axis = 2;
size_t num_splits = input_shape[axis];
size_t scale_factor = 2;
size_t num_of_concat_inputs = num_splits * scale_factor;
int64_t target_size = static_cast<int64_t>(input_shape[axis]) * scale_factor;
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto split_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split = std::make_shared<ngraph::opset8::Split>(input, split_axis, num_splits);
ngraph::OutputVector concat_inputs_vec(num_of_concat_inputs);
for (size_t split_output_port = 0; split_output_port < num_splits; ++split_output_port) {
for (size_t j = 0; j < scale_factor; ++j) {
concat_inputs_vec[split_output_port * scale_factor + j] = split->output(split_output_port);
}
}
auto concat = std::make_shared<ngraph::opset8::Concat>(concat_inputs_vec, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SplitConcatPairToInterpolateFusion>();
}
{
ngraph::opset8::Interpolate::InterpolateAttrs attrs;
attrs.mode = ngraph::opset8::Interpolate::InterpolateMode::NEAREST;
attrs.shape_calculation_mode = ngraph::opset8::Interpolate::ShapeCalcMode::SCALES;
attrs.nearest_mode = ngraph::opset8::Interpolate::NearestMode::ROUND_PREFER_FLOOR;
attrs.pads_begin = std::vector<size_t>{0};
attrs.pads_end = std::vector<size_t>{0};
attrs.antialias = false;
attrs.coordinate_transformation_mode = ngraph::opset8::Interpolate::CoordinateTransformMode::HALF_PIXEL;
attrs.cube_coeff = -0.75f;
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, {1}, std::vector<float>{static_cast<float>(scale_factor)});
auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{target_size});
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, sizes_node, scales_node, axis_node, attrs);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
}
}
TEST_F(TransformationTestsF, SplitConcatPairToInterpolateFusionSpatial3D1WithConstantFolding) {
ngraph::Shape input_shape { 1, 3, 100, 120, 150 };
int64_t axis = 4;
size_t num_splits = input_shape[axis];
size_t scale_factor = 2;
size_t num_of_concat_inputs = num_splits * scale_factor;
int64_t target_size = static_cast<int64_t>(input_shape[axis]) * scale_factor;
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto split_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split = std::make_shared<ngraph::opset8::Split>(input, split_axis, num_splits);
ngraph::OutputVector concat_inputs_vec(num_of_concat_inputs);
for (size_t split_output_port = 0; split_output_port < num_splits; ++split_output_port) {
for (size_t j = 0; j < scale_factor; ++j) {
concat_inputs_vec[split_output_port * scale_factor + j] = split->output(split_output_port);
}
}
auto concat = std::make_shared<ngraph::opset8::Concat>(concat_inputs_vec, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SplitConcatPairToInterpolateFusion>();
}
{
ngraph::opset8::Interpolate::InterpolateAttrs attrs;
attrs.mode = ngraph::opset8::Interpolate::InterpolateMode::NEAREST;
attrs.shape_calculation_mode = ngraph::opset8::Interpolate::ShapeCalcMode::SCALES;
attrs.nearest_mode = ngraph::opset8::Interpolate::NearestMode::ROUND_PREFER_FLOOR;
attrs.pads_begin = std::vector<size_t>{0};
attrs.pads_end = std::vector<size_t>{0};
attrs.antialias = false;
attrs.coordinate_transformation_mode = ngraph::opset8::Interpolate::CoordinateTransformMode::HALF_PIXEL;
attrs.cube_coeff = -0.75f;
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, {1}, std::vector<float>{static_cast<float>(scale_factor)});
auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{target_size});
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, sizes_node, scales_node, axis_node, attrs);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
}
}
TEST_F(TransformationTestsF, SplitConcatPairToInterpolateFusionSpatial3D2WithConstantFolding) {
ngraph::Shape input_shape { 1, 3, 100, 120, 150 };
int64_t axis = 3;
size_t num_splits = input_shape[axis];
size_t scale_factor = 2;
size_t num_of_concat_inputs = num_splits * scale_factor;
int64_t target_size = static_cast<int64_t>(input_shape[axis]) * scale_factor;
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto split_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split = std::make_shared<ngraph::opset8::Split>(input, split_axis, num_splits);
ngraph::OutputVector concat_inputs_vec(num_of_concat_inputs);
for (size_t split_output_port = 0; split_output_port < num_splits; ++split_output_port) {
for (size_t j = 0; j < scale_factor; ++j) {
concat_inputs_vec[split_output_port * scale_factor + j] = split->output(split_output_port);
}
}
auto concat = std::make_shared<ngraph::opset8::Concat>(concat_inputs_vec, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SplitConcatPairToInterpolateFusion>();
}
{
ngraph::opset8::Interpolate::InterpolateAttrs attrs;
attrs.mode = ngraph::opset8::Interpolate::InterpolateMode::NEAREST;
attrs.shape_calculation_mode = ngraph::opset8::Interpolate::ShapeCalcMode::SCALES;
attrs.nearest_mode = ngraph::opset8::Interpolate::NearestMode::ROUND_PREFER_FLOOR;
attrs.pads_begin = std::vector<size_t>{0};
attrs.pads_end = std::vector<size_t>{0};
attrs.antialias = false;
attrs.coordinate_transformation_mode = ngraph::opset8::Interpolate::CoordinateTransformMode::HALF_PIXEL;
attrs.cube_coeff = -0.75f;
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, {1}, std::vector<float>{static_cast<float>(scale_factor)});
auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{target_size});
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, sizes_node, scales_node, axis_node, attrs);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
}
}
TEST_F(TransformationTestsF, SplitConcatPairToInterpolateFusionSpatial2D1Dynamic) {
ngraph::PartialShape input_shape { ngraph::Dimension::dynamic(), ngraph::Dimension::dynamic(), ngraph::Dimension::dynamic(), 150 };
int64_t axis = 3;
size_t num_splits = input_shape[axis].get_length();
size_t scale_factor = 2;
size_t num_of_concat_inputs = num_splits * scale_factor;
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto split_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split = std::make_shared<ngraph::opset8::Split>(input, split_axis, num_splits);
ngraph::OutputVector concat_inputs_vec(num_of_concat_inputs);
for (size_t split_output_port = 0; split_output_port < num_splits; ++split_output_port) {
for (size_t j = 0; j < scale_factor; ++j) {
concat_inputs_vec[split_output_port * scale_factor + j] = split->output(split_output_port);
}
}
auto concat = std::make_shared<ngraph::opset8::Concat>(concat_inputs_vec, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SplitConcatPairToInterpolateFusion>(false);
}
{
ngraph::opset8::Interpolate::InterpolateAttrs attrs;
attrs.mode = ngraph::opset8::Interpolate::InterpolateMode::NEAREST;
attrs.shape_calculation_mode = ngraph::opset8::Interpolate::ShapeCalcMode::SCALES;
attrs.nearest_mode = ngraph::opset8::Interpolate::NearestMode::ROUND_PREFER_FLOOR;
attrs.pads_begin = std::vector<size_t>{0};
attrs.pads_end = std::vector<size_t>{0};
attrs.antialias = false;
attrs.coordinate_transformation_mode = ngraph::opset8::Interpolate::CoordinateTransformMode::HALF_PIXEL;
attrs.cube_coeff = -0.75f;
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, {1}, std::vector<float>{static_cast<float>(scale_factor)});
auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto shape_node = std::make_shared<ngraph::opset8::ShapeOf>(input);
auto sslice_begin = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto sslice_end = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis + 1});
std::vector<int64_t> begin_mask = {0};
std::vector<int64_t> end_mask = {0};
auto strided_slice_node = std::make_shared<ngraph::opset8::StridedSlice>(shape_node, sslice_begin, sslice_end, begin_mask, end_mask);
auto cast_shape_to_float = std::make_shared<ngraph::opset8::Convert>(strided_slice_node, ngraph::element::f32);
auto mul_node = std::make_shared<ngraph::opset8::Multiply>(cast_shape_to_float, scales_node);
auto floor_node = std::make_shared<ngraph::opset8::Floor>(mul_node);
auto cast_mul_result_to_int = std::make_shared<ngraph::opset8::Convert>(floor_node, ngraph::element::i64);
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, cast_mul_result_to_int, scales_node, axis_node, attrs);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
}
}
TEST_F(TransformationTestsF, SplitConcatPairToInterpolateFusionSpatial2D2Dynamic) {
ngraph::PartialShape input_shape { ngraph::Dimension::dynamic(), ngraph::Dimension::dynamic(), 120, ngraph::Dimension::dynamic() };
int64_t axis = 2;
size_t num_splits = input_shape[axis].get_length();
size_t scale_factor = 2;
size_t num_of_concat_inputs = num_splits * scale_factor;
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto split_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split = std::make_shared<ngraph::opset8::Split>(input, split_axis, num_splits);
ngraph::OutputVector concat_inputs_vec(num_of_concat_inputs);
for (size_t split_output_port = 0; split_output_port < num_splits; ++split_output_port) {
for (size_t j = 0; j < scale_factor; ++j) {
concat_inputs_vec[split_output_port * scale_factor + j] = split->output(split_output_port);
}
}
auto concat = std::make_shared<ngraph::opset8::Concat>(concat_inputs_vec, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SplitConcatPairToInterpolateFusion>(false);
}
{
ngraph::opset8::Interpolate::InterpolateAttrs attrs;
attrs.mode = ngraph::opset8::Interpolate::InterpolateMode::NEAREST;
attrs.shape_calculation_mode = ngraph::opset8::Interpolate::ShapeCalcMode::SCALES;
attrs.nearest_mode = ngraph::opset8::Interpolate::NearestMode::ROUND_PREFER_FLOOR;
attrs.pads_begin = std::vector<size_t>{0};
attrs.pads_end = std::vector<size_t>{0};
attrs.antialias = false;
attrs.coordinate_transformation_mode = ngraph::opset8::Interpolate::CoordinateTransformMode::HALF_PIXEL;
attrs.cube_coeff = -0.75f;
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, {1}, std::vector<float>{static_cast<float>(scale_factor)});
auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto shape_node = std::make_shared<ngraph::opset8::ShapeOf>(input);
auto sslice_begin = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto sslice_end = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis + 1});
std::vector<int64_t> begin_mask = {0};
std::vector<int64_t> end_mask = {0};
auto strided_slice_node = std::make_shared<ngraph::opset8::StridedSlice>(shape_node, sslice_begin, sslice_end, begin_mask, end_mask);
auto cast_shape_to_float = std::make_shared<ngraph::opset8::Convert>(strided_slice_node, ngraph::element::f32);
auto mul_node = std::make_shared<ngraph::opset8::Multiply>(cast_shape_to_float, scales_node);
auto floor_node = std::make_shared<ngraph::opset8::Floor>(mul_node);
auto cast_mul_result_to_int = std::make_shared<ngraph::opset8::Convert>(floor_node, ngraph::element::i64);
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, cast_mul_result_to_int, scales_node, axis_node, attrs);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
}
}
TEST_F(TransformationTestsF, SplitConcatPairToInterpolateFusionSpatial3D1Dynamic) {
ngraph::PartialShape input_shape { ngraph::Dimension::dynamic(), ngraph::Dimension::dynamic(), ngraph::Dimension::dynamic(),
ngraph::Dimension::dynamic(), 150 };
int64_t axis = 4;
size_t num_splits = input_shape[axis].get_length();
size_t scale_factor = 2;
size_t num_of_concat_inputs = num_splits * scale_factor;
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto split_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split = std::make_shared<ngraph::opset8::Split>(input, split_axis, num_splits);
ngraph::OutputVector concat_inputs_vec(num_of_concat_inputs);
for (size_t split_output_port = 0; split_output_port < num_splits; ++split_output_port) {
for (size_t j = 0; j < scale_factor; ++j) {
concat_inputs_vec[split_output_port * scale_factor + j] = split->output(split_output_port);
}
}
auto concat = std::make_shared<ngraph::opset8::Concat>(concat_inputs_vec, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SplitConcatPairToInterpolateFusion>(false);
}
{
ngraph::opset8::Interpolate::InterpolateAttrs attrs;
attrs.mode = ngraph::opset8::Interpolate::InterpolateMode::NEAREST;
attrs.shape_calculation_mode = ngraph::opset8::Interpolate::ShapeCalcMode::SCALES;
attrs.nearest_mode = ngraph::opset8::Interpolate::NearestMode::ROUND_PREFER_FLOOR;
attrs.pads_begin = std::vector<size_t>{0};
attrs.pads_end = std::vector<size_t>{0};
attrs.antialias = false;
attrs.coordinate_transformation_mode = ngraph::opset8::Interpolate::CoordinateTransformMode::HALF_PIXEL;
attrs.cube_coeff = -0.75f;
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, {1}, std::vector<float>{static_cast<float>(scale_factor)});
auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto shape_node = std::make_shared<ngraph::opset8::ShapeOf>(input);
auto sslice_begin = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto sslice_end = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis + 1});
std::vector<int64_t> begin_mask = {0};
std::vector<int64_t> end_mask = {0};
auto strided_slice_node = std::make_shared<ngraph::opset8::StridedSlice>(shape_node, sslice_begin, sslice_end, begin_mask, end_mask);
auto cast_shape_to_float = std::make_shared<ngraph::opset8::Convert>(strided_slice_node, ngraph::element::f32);
auto mul_node = std::make_shared<ngraph::opset8::Multiply>(cast_shape_to_float, scales_node);
auto floor_node = std::make_shared<ngraph::opset8::Floor>(mul_node);
auto cast_mul_result_to_int = std::make_shared<ngraph::opset8::Convert>(floor_node, ngraph::element::i64);
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, cast_mul_result_to_int, scales_node, axis_node, attrs);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
}
}
TEST_F(TransformationTestsF, SplitConcatPairToInterpolateFusionSpatial3D2Dynamic) {
ngraph::PartialShape input_shape { ngraph::Dimension::dynamic(), ngraph::Dimension::dynamic(), ngraph::Dimension::dynamic(),
120, ngraph::Dimension::dynamic() };
int64_t axis = 3;
size_t num_splits = input_shape[axis].get_length();
size_t scale_factor = 2;
size_t num_of_concat_inputs = num_splits * scale_factor;
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto split_axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, { axis });
auto split = std::make_shared<ngraph::opset8::Split>(input, split_axis, num_splits);
ngraph::OutputVector concat_inputs_vec(num_of_concat_inputs);
for (size_t split_output_port = 0; split_output_port < num_splits; ++split_output_port) {
for (size_t j = 0; j < scale_factor; ++j) {
concat_inputs_vec[split_output_port * scale_factor + j] = split->output(split_output_port);
}
}
auto concat = std::make_shared<ngraph::opset8::Concat>(concat_inputs_vec, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SplitConcatPairToInterpolateFusion>(false);
}
{
ngraph::opset8::Interpolate::InterpolateAttrs attrs;
attrs.mode = ngraph::opset8::Interpolate::InterpolateMode::NEAREST;
attrs.shape_calculation_mode = ngraph::opset8::Interpolate::ShapeCalcMode::SCALES;
attrs.nearest_mode = ngraph::opset8::Interpolate::NearestMode::ROUND_PREFER_FLOOR;
attrs.pads_begin = std::vector<size_t>{0};
attrs.pads_end = std::vector<size_t>{0};
attrs.antialias = false;
attrs.coordinate_transformation_mode = ngraph::opset8::Interpolate::CoordinateTransformMode::HALF_PIXEL;
attrs.cube_coeff = -0.75f;
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, {1}, std::vector<float>{static_cast<float>(scale_factor)});
auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto shape_node = std::make_shared<ngraph::opset8::ShapeOf>(input);
auto sslice_begin = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis});
auto sslice_end = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{axis + 1});
std::vector<int64_t> begin_mask = {0};
std::vector<int64_t> end_mask = {0};
auto strided_slice_node = std::make_shared<ngraph::opset8::StridedSlice>(shape_node, sslice_begin, sslice_end, begin_mask, end_mask);
auto cast_shape_to_float = std::make_shared<ngraph::opset8::Convert>(strided_slice_node, ngraph::element::f32);
auto mul_node = std::make_shared<ngraph::opset8::Multiply>(cast_shape_to_float, scales_node);
auto floor_node = std::make_shared<ngraph::opset8::Floor>(mul_node);
auto cast_mul_result_to_int = std::make_shared<ngraph::opset8::Convert>(floor_node, ngraph::element::i64);
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, cast_mul_result_to_int, scales_node, axis_node, attrs);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
}
}

View File

@ -0,0 +1,124 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "shared_test_classes/base/layer_test_utils.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
using namespace ngraph;
using FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc;
namespace CPUSubgraphTestsDefinitions {
typedef std::tuple<
Shape, // Input shape
element::Type, // Input precision
int, // Axis
size_t, // num_splits
size_t, // scale_factor
std::string // Device name
> FuseSplitConcatPairToInterpolateTuple;
class FuseSplitConcatPairToInterpolateTest : public testing::WithParamInterface<FuseSplitConcatPairToInterpolateTuple>,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<FuseSplitConcatPairToInterpolateTuple> &obj) {
Shape inputShape;
element::Type inputPrecision;
int axis;
size_t num_splits;
size_t scale_factor;
std::string targetName;
std::tie(inputShape, inputPrecision, axis, num_splits, scale_factor, targetName) = obj.param;
std::ostringstream results;
results << "IS=" << inputShape
<< "_InPRC=" << inputPrecision
<< "_Axis=" << axis
<< "_Num_splits=" << num_splits
<< "_Scale_factor=" << scale_factor;
results << "_targetDevice=" << targetName;
return results.str();
}
protected:
void SetUp() override {
Shape inputShape;
element::Type inputPrecision;
int axis;
size_t num_splits;
size_t scale_factor;
std::tie(inputShape, inputPrecision, axis, num_splits, scale_factor, targetDevice) = this->GetParam();
size_t num_of_concat_inputs = num_splits * scale_factor;
const auto param = std::make_shared<opset6::Parameter>(inputPrecision, inputShape);
const auto split = builder::makeSplit(param, inputPrecision, num_splits, static_cast<int64_t>(axis));
ngraph::OutputVector concat_inputs_vec(num_of_concat_inputs);
for (size_t split_output_port = 0; split_output_port < num_splits; ++split_output_port) {
for (size_t j = 0; j < scale_factor; ++j) {
concat_inputs_vec[split_output_port * scale_factor + j] = split->output(split_output_port);
}
}
const auto concat = builder::makeConcat(concat_inputs_vec, axis);
ngraph::ResultVector results{std::make_shared<ngraph::opset6::Result>(concat)};
function = std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{param}, "FuseSplitConcatPairToInterpolate");
}
};
TEST_P(FuseSplitConcatPairToInterpolateTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
Run();
}
namespace {
std::vector<Shape> inputShapes4D {
{1, 2, 6, 6}
};
std::vector<size_t> num_of_outputs_of_split {
2, 3, 6
};
std::vector<size_t> scale_factors {
2, 3, 4
};
std::vector<int> axes4D {
2, 3
};
std::vector<Shape> inputShapes5D {
{1, 3, 10, 6, 6}
};
std::vector<int> axes5D {
3, 4
};
INSTANTIATE_TEST_SUITE_P(smoke_FuseSplitConcatPairToInterpolate4D, FuseSplitConcatPairToInterpolateTest,
::testing::Combine(
::testing::ValuesIn(inputShapes4D),
::testing::Values(element::f32),
::testing::ValuesIn(axes4D),
::testing::ValuesIn(num_of_outputs_of_split),
::testing::ValuesIn(scale_factors),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
FuseSplitConcatPairToInterpolateTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_FuseSplitConcatPairToInterpolate5D, FuseSplitConcatPairToInterpolateTest,
::testing::Combine(
::testing::ValuesIn(inputShapes5D),
::testing::Values(element::f32),
::testing::ValuesIn(axes5D),
::testing::ValuesIn(num_of_outputs_of_split),
::testing::ValuesIn(scale_factors),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
FuseSplitConcatPairToInterpolateTest::getTestCaseName);
} // namespace
} // namespace CPUSubgraphTestsDefinitions