nGraph version of the MO transformation InterpolateSequenceToInterpolate (#8397)
* Started to write transformation to fuse Interpolate sequence. * Some changes. * Written the transformation to fuse two Interpolate layers. * Deleted commented code. * Small fixes. * Some fixes. * Started to write tests. * Small fix. * Added more tests. Deleted commented code. * Deleted redundant headers. * Small fix. * Fixes in the function can_be_fused(): the last statement was decomposed. * Added operators == and != for op::v4::Interpolate::InterpolateAttrs. * Added more checks for nullptr. * Fixed codestyle. * Added Interpolate registration. * Small change. * Implementation of operator== of InterpolateAttrs was moved into header file.
This commit is contained in:
committed by
GitHub
parent
e1abe32053
commit
edb98aeb8e
@@ -0,0 +1,33 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API InterpolateSequenceFusion;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief InterpolateSequenceFusion transformation replaces a sequence of
|
||||
* operations to Interpolate op.
|
||||
*/
|
||||
class ngraph::pass::InterpolateSequenceFusion : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
InterpolateSequenceFusion();
|
||||
};
|
||||
@@ -49,6 +49,7 @@
|
||||
#include "transformations/common_optimizations/strides_optimization.hpp"
|
||||
#include "transformations/common_optimizations/convert_nms_gather_path_to_unsigned.hpp"
|
||||
#include "transformations/common_optimizations/mul_conv_fusion.hpp"
|
||||
#include "transformations/common_optimizations/interpolate_sequence_fusion.hpp"
|
||||
#include "transformations/common_optimizations/convert_compression_only_to_legacy.hpp"
|
||||
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
|
||||
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
|
||||
@@ -114,6 +115,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
common_fusions->add_matcher<ngraph::pass::ShuffleChannelsFusion>(false);
|
||||
common_fusions->add_matcher<ngraph::pass::SpaceToBatchFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::BatchToSpaceFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::InterpolateSequenceFusion>();
|
||||
common_fusions->set_name("ngraph::pass::CommonFusions");
|
||||
|
||||
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
|
||||
|
||||
@@ -0,0 +1,203 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/common_optimizations/interpolate_sequence_fusion.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
// #include <ngraph/op/interpolate.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
namespace {
|
||||
using namespace ngraph;
|
||||
|
||||
bool compatible_axes(const std::vector<int64_t>& fst_axes_vector, const std::vector<int64_t>& snd_axes_vector) {
|
||||
std::set<int64_t> fst_axes_set(fst_axes_vector.begin(), fst_axes_vector.end());
|
||||
for (const auto& a : snd_axes_vector) {
|
||||
if (fst_axes_set.count(a) != 0) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool shape_calculation_mode_can_use_constant_inputs(const std::shared_ptr<opset8::Interpolate>& interpolate) {
|
||||
const auto& attrs = interpolate->get_attrs();
|
||||
if (attrs.shape_calculation_mode == ngraph::opset8::Interpolate::ShapeCalcMode::SIZES) {
|
||||
return std::dynamic_pointer_cast<opset8::Constant>(interpolate->input_value(1).get_node_shared_ptr()) != nullptr;
|
||||
}
|
||||
return std::dynamic_pointer_cast<opset8::Constant>(interpolate->input_value(2).get_node_shared_ptr()) != nullptr;
|
||||
}
|
||||
|
||||
bool is_candidate_for_fusion(const std::shared_ptr<opset8::Interpolate>& interpolate) {
|
||||
return (interpolate->get_input_partial_shape(0).rank().is_static()) &&
|
||||
(interpolate->inputs().size() != 4 || std::dynamic_pointer_cast<opset8::Constant>(interpolate->input_value(3).get_node_shared_ptr())) &&
|
||||
shape_calculation_mode_can_use_constant_inputs(interpolate);
|
||||
}
|
||||
|
||||
std::vector<int64_t> get_interpolated_axes(const std::shared_ptr<opset8::Interpolate>& interpolate) {
|
||||
if (interpolate->inputs().size() != 4) {
|
||||
const auto input_rank = interpolate->get_input_partial_shape(0).rank().get_length();
|
||||
|
||||
std::vector<int64_t> default_value(input_rank);
|
||||
std::iota(default_value.begin(), default_value.end(), 0);
|
||||
|
||||
return default_value;
|
||||
}
|
||||
return std::dynamic_pointer_cast<opset8::Constant>(interpolate->input_value(3).get_node_shared_ptr())->cast_vector<int64_t>();
|
||||
}
|
||||
|
||||
bool can_be_fused(const std::shared_ptr<opset8::Interpolate>& fst, const std::shared_ptr<opset8::Interpolate>& snd) {
|
||||
// The first Interpolate (fst) must have only one consumer.
|
||||
for (const auto& output : fst->outputs()) {
|
||||
for (const auto& consumer : output.get_target_inputs()) {
|
||||
if (consumer.get_node() != snd.get()) return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (fst->get_attrs() != snd->get_attrs() || !is_candidate_for_fusion(fst) || !is_candidate_for_fusion(snd)) return false;
|
||||
|
||||
const auto fst_axes = get_interpolated_axes(fst);
|
||||
const auto snd_axes = get_interpolated_axes(snd);
|
||||
return compatible_axes(fst_axes, snd_axes);
|
||||
}
|
||||
|
||||
ngraph::NodeVector subgraph_for_sizes_calculation_mode(const std::shared_ptr<opset8::Interpolate>& fst, const std::shared_ptr<opset8::Interpolate>& snd,
|
||||
pass::MatcherPass* matcherPass) {
|
||||
const auto fst_axes = get_interpolated_axes(fst);
|
||||
const auto snd_axes = get_interpolated_axes(snd);
|
||||
const auto fst_sizes_node = std::dynamic_pointer_cast<opset8::Constant>(fst->input_value(1).get_node_shared_ptr());
|
||||
const auto snd_sizes_node = std::dynamic_pointer_cast<opset8::Constant>(snd->input_value(1).get_node_shared_ptr());
|
||||
if (!fst_sizes_node || !snd_sizes_node) return {};
|
||||
|
||||
const auto fst_sizes = fst_sizes_node->cast_vector<int64_t>();
|
||||
const auto snd_sizes = snd_sizes_node->cast_vector<int64_t>();
|
||||
std::vector<std::pair<int64_t, int64_t>> axes_and_sizes;
|
||||
for (size_t i = 0; i < fst_axes.size(); ++i) {
|
||||
axes_and_sizes.emplace_back(std::make_pair(fst_axes[i], fst_sizes[i]));
|
||||
}
|
||||
for (size_t i = 0; i < snd_axes.size(); ++i) {
|
||||
axes_and_sizes.emplace_back(std::make_pair(snd_axes[i], snd_sizes[i]));
|
||||
}
|
||||
std::sort(axes_and_sizes.begin(),
|
||||
axes_and_sizes.end(),
|
||||
[](const std::pair<int64_t, int64_t>& a, const std::pair<int64_t, int64_t>& b) {
|
||||
return a.first < b.first;
|
||||
});
|
||||
std::vector<int64_t> new_axes;
|
||||
std::vector<int64_t> new_sizes;
|
||||
for (const auto& as : axes_and_sizes) {
|
||||
new_axes.emplace_back(as.first);
|
||||
new_sizes.emplace_back(as.second);
|
||||
}
|
||||
|
||||
auto new_sizes_node = opset8::Constant::create(element::i64, {new_sizes.size()}, new_sizes);
|
||||
auto new_axes_node = opset8::Constant::create(element::i64, {new_axes.size()}, new_axes);
|
||||
auto new_sizes_cast = std::make_shared<opset8::Convert>(new_sizes_node, element::f32);
|
||||
auto shape_node = std::make_shared<opset8::ShapeOf>(fst->input_value(0));
|
||||
|
||||
auto gather_axis_node = opset8::Constant::create(element::i64, {1}, std::vector<int64_t>{0});
|
||||
auto gather_node = std::make_shared<opset8::Gather>(shape_node, new_axes_node, gather_axis_node);
|
||||
auto cast_shape_to_float = std::make_shared<opset8::Convert>(gather_node, element::f32);
|
||||
|
||||
auto div_node = std::make_shared<opset8::Divide>(new_sizes_cast, cast_shape_to_float);
|
||||
|
||||
const auto new_interpolate = ov::as_type_ptr<opset8::Interpolate>(fst->clone_with_new_inputs({fst->input_value(0), new_sizes_node, div_node,
|
||||
new_axes_node}));
|
||||
matcherPass->register_new_node(new_interpolate);
|
||||
|
||||
return {new_sizes_node, new_axes_node, new_sizes_cast, shape_node, gather_axis_node, gather_node, cast_shape_to_float, div_node, new_interpolate};
|
||||
}
|
||||
|
||||
ngraph::NodeVector subgraph_for_scales_calculation_mode(const std::shared_ptr<opset8::Interpolate>& fst, const std::shared_ptr<opset8::Interpolate>& snd,
|
||||
pass::MatcherPass* matcherPass) {
|
||||
const auto fst_axes = get_interpolated_axes(fst);
|
||||
const auto snd_axes = get_interpolated_axes(snd);
|
||||
const auto fst_scales_node = std::dynamic_pointer_cast<opset8::Constant>(fst->input_value(2).get_node_shared_ptr());
|
||||
const auto snd_scales_node = std::dynamic_pointer_cast<opset8::Constant>(snd->input_value(2).get_node_shared_ptr());
|
||||
if (!fst_scales_node || !snd_scales_node) return {};
|
||||
|
||||
const auto fst_scales = fst_scales_node->cast_vector<float>();
|
||||
const auto snd_scales = snd_scales_node->cast_vector<float>();
|
||||
std::vector<std::pair<int64_t, float>> axes_and_scales;
|
||||
for (size_t i = 0; i < fst_axes.size(); ++i) {
|
||||
axes_and_scales.emplace_back(std::make_pair(fst_axes[i], fst_scales[i]));
|
||||
}
|
||||
for (size_t i = 0; i < snd_axes.size(); ++i) {
|
||||
axes_and_scales.emplace_back(std::make_pair(snd_axes[i], snd_scales[i]));
|
||||
}
|
||||
std::sort(axes_and_scales.begin(),
|
||||
axes_and_scales.end(),
|
||||
[](const std::pair<int64_t, float>& a, const std::pair<int64_t, float>& b) {
|
||||
return a.first < b.first;
|
||||
});
|
||||
std::vector<int64_t> new_axes;
|
||||
std::vector<float> new_scales;
|
||||
for (const auto& as : axes_and_scales) {
|
||||
new_axes.emplace_back(as.first);
|
||||
new_scales.emplace_back(as.second);
|
||||
}
|
||||
|
||||
auto new_scales_node = opset8::Constant::create(element::f32, {new_scales.size()}, new_scales);
|
||||
auto new_axes_node = opset8::Constant::create(element::i64, {new_axes.size()}, new_axes);
|
||||
auto shape_node = std::make_shared<opset8::ShapeOf>(fst->input_value(0));
|
||||
|
||||
auto gather_axis_node = opset8::Constant::create(element::i64, {1}, std::vector<int64_t>{0});
|
||||
auto gather_node = std::make_shared<opset8::Gather>(shape_node, new_axes_node, gather_axis_node);
|
||||
auto cast_shape_to_float = std::make_shared<opset8::Convert>(gather_node, element::f32);
|
||||
|
||||
auto mul_node = std::make_shared<opset8::Multiply>(cast_shape_to_float, new_scales_node);
|
||||
auto eps_node = opset8::Constant::create(element::f32, {}, std::vector<float>{1.0e-5f});
|
||||
auto add_node = std::make_shared<opset8::Multiply>(mul_node, eps_node);
|
||||
auto floor_node = std::make_shared<opset8::Floor>(add_node);
|
||||
auto cast_mul_result_to_int = std::make_shared<opset8::Convert>(floor_node, element::i64);
|
||||
|
||||
const auto new_interpolate = ov::as_type_ptr<opset8::Interpolate>(fst->clone_with_new_inputs({fst->input_value(0), cast_mul_result_to_int,
|
||||
new_scales_node, new_axes_node}));
|
||||
matcherPass->register_new_node(new_interpolate);
|
||||
|
||||
return {new_scales_node, new_axes_node, shape_node, gather_axis_node, gather_node, cast_shape_to_float, mul_node, eps_node,
|
||||
add_node, floor_node, cast_mul_result_to_int, new_interpolate};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::InterpolateSequenceFusion, "InterpolateSequenceFusion", 0);
|
||||
|
||||
ngraph::pass::InterpolateSequenceFusion::InterpolateSequenceFusion() {
|
||||
MATCHER_SCOPE(InterpolateSequenceFusion);
|
||||
auto interpolate_pattern = ngraph::pattern::wrap_type<ngraph::opset8::Interpolate>();
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
auto snd_interpolate = std::dynamic_pointer_cast<opset8::Interpolate>(m.get_match_root());
|
||||
if (!snd_interpolate) return false;
|
||||
|
||||
auto fst_interpolate = std::dynamic_pointer_cast<opset8::Interpolate>(snd_interpolate->input_value(0).get_node_shared_ptr());
|
||||
if (!fst_interpolate) return false;
|
||||
|
||||
if (!can_be_fused(fst_interpolate, snd_interpolate)) return false;
|
||||
|
||||
NodeVector new_subgraph;
|
||||
if (fst_interpolate->get_attrs().shape_calculation_mode == ngraph::opset8::Interpolate::ShapeCalcMode::SIZES) {
|
||||
new_subgraph = subgraph_for_sizes_calculation_mode(fst_interpolate, snd_interpolate, this);
|
||||
} else {
|
||||
new_subgraph = subgraph_for_scales_calculation_mode(fst_interpolate, snd_interpolate, this);
|
||||
}
|
||||
if (new_subgraph.empty()) return false;
|
||||
|
||||
auto& new_interpolate = new_subgraph.back();
|
||||
new_interpolate->set_friendly_name(snd_interpolate->get_friendly_name());
|
||||
copy_runtime_info({fst_interpolate, snd_interpolate}, new_subgraph);
|
||||
replace_node(snd_interpolate, new_interpolate);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(interpolate_pattern, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
@@ -0,0 +1,330 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/visualize_tree.hpp>
|
||||
#include <transformations/common_optimizations/interpolate_sequence_fusion.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
using Attrs = ngraph::opset8::Interpolate::InterpolateAttrs;
|
||||
using ShapeCalcMode = ngraph::opset8::Interpolate::ShapeCalcMode;
|
||||
using InterpolateMode = ngraph::opset8::Interpolate::InterpolateMode;
|
||||
using CoordinateTransformMode = ngraph::opset8::Interpolate::CoordinateTransformMode;
|
||||
using NearestMode = ngraph::opset8::Interpolate::NearestMode;
|
||||
|
||||
TEST_F(TransformationTestsF, InterpolateSequenceFusion4D1) {
|
||||
ngraph::Shape input_shape { 1, 4, 220, 350 };
|
||||
std::vector<Attrs> attributes = {
|
||||
Attrs{InterpolateMode::NEAREST, ShapeCalcMode::SCALES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f},
|
||||
Attrs{InterpolateMode::NEAREST, ShapeCalcMode::SCALES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f}
|
||||
};
|
||||
std::vector<std::vector<int64_t>> sizes_vector = {
|
||||
{660}, {700}
|
||||
};
|
||||
std::vector<std::vector<float>> scales_vector = {
|
||||
{3.0f}, {2.0f}
|
||||
};
|
||||
std::vector<std::vector<int64_t>> axes_vector = {
|
||||
{2}, {3}
|
||||
};
|
||||
Attrs ref_attrs{InterpolateMode::NEAREST, ShapeCalcMode::SCALES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f};
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
|
||||
auto fst_sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{sizes_vector[0].size()}, sizes_vector[0]);
|
||||
auto fst_scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{scales_vector[0].size()}, scales_vector[0]);
|
||||
auto fst_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{axes_vector[0].size()}, axes_vector[0]);
|
||||
auto fst_interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, fst_sizes_node, fst_scales_node, fst_axis_node, attributes[0]);
|
||||
|
||||
auto snd_sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{sizes_vector[1].size()}, sizes_vector[1]);
|
||||
auto snd_scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{scales_vector[1].size()}, scales_vector[1]);
|
||||
auto snd_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{axes_vector[1].size()}, axes_vector[1]);
|
||||
auto snd_interpolate = std::make_shared<ngraph::opset8::Interpolate>(fst_interpolate, snd_sizes_node, snd_scales_node, snd_axis_node, attributes[1]);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ snd_interpolate }, ngraph::ParameterVector{ input });
|
||||
manager.register_pass<ngraph::pass::InterpolateSequenceFusion>();
|
||||
}
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{2}, std::vector<float>{3.0f, 2.0f});
|
||||
auto axes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, std::vector<int64_t>{2, 3});
|
||||
|
||||
auto shape_node = std::make_shared<ngraph::opset8::ShapeOf>(input);
|
||||
auto gather_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{0});
|
||||
auto gather_node = std::make_shared<ngraph::opset8::Gather>(shape_node, axes_node, gather_axis_node);
|
||||
auto cast_shape_to_float = std::make_shared<ngraph::opset8::Convert>(gather_node, ngraph::element::f32);
|
||||
|
||||
auto mul_node = std::make_shared<ngraph::opset8::Multiply>(cast_shape_to_float, scales_node);
|
||||
auto eps_node = ngraph::opset8::Constant::create(ngraph::element::f32, {}, std::vector<float>{1.0e-5f});
|
||||
auto add_node = std::make_shared<ngraph::opset8::Multiply>(mul_node, eps_node);
|
||||
auto floor_node = std::make_shared<ngraph::opset8::Floor>(add_node);
|
||||
auto cast_mul_result_to_int = std::make_shared<ngraph::opset8::Convert>(floor_node, ngraph::element::i64);
|
||||
|
||||
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, cast_mul_result_to_int, scales_node, axes_node, ref_attrs);
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, InterpolateSequenceFusion4D2) {
|
||||
ngraph::Shape input_shape { 1, 4, 220, 350 };
|
||||
std::vector<Attrs> attributes = {
|
||||
Attrs{InterpolateMode::NEAREST, ShapeCalcMode::SCALES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f},
|
||||
Attrs{InterpolateMode::NEAREST, ShapeCalcMode::SCALES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f},
|
||||
Attrs{InterpolateMode::NEAREST, ShapeCalcMode::SCALES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f}
|
||||
};
|
||||
std::vector<std::vector<int64_t>> sizes_vector = {
|
||||
{660}, {700}, {1320}
|
||||
};
|
||||
std::vector<std::vector<float>> scales_vector = {
|
||||
{3.0f}, {2.0f}, {2.0f}
|
||||
};
|
||||
std::vector<std::vector<int64_t>> axes_vector = {
|
||||
{2}, {3}, {2}
|
||||
};
|
||||
Attrs ref_attrs{InterpolateMode::NEAREST, ShapeCalcMode::SCALES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f};
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
|
||||
auto fst_sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{sizes_vector[0].size()}, sizes_vector[0]);
|
||||
auto fst_scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{scales_vector[0].size()}, scales_vector[0]);
|
||||
auto fst_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{axes_vector[0].size()}, axes_vector[0]);
|
||||
auto fst_interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, fst_sizes_node, fst_scales_node, fst_axis_node, attributes[0]);
|
||||
|
||||
auto snd_sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{sizes_vector[1].size()}, sizes_vector[1]);
|
||||
auto snd_scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{scales_vector[1].size()}, scales_vector[1]);
|
||||
auto snd_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{axes_vector[1].size()}, axes_vector[1]);
|
||||
auto snd_interpolate = std::make_shared<ngraph::opset8::Interpolate>(fst_interpolate, snd_sizes_node, snd_scales_node, snd_axis_node, attributes[1]);
|
||||
|
||||
auto third_sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{sizes_vector[2].size()}, sizes_vector[2]);
|
||||
auto third_scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{scales_vector[2].size()}, scales_vector[2]);
|
||||
auto third_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{axes_vector[2].size()}, axes_vector[2]);
|
||||
auto third_interpolate = std::make_shared<ngraph::opset8::Interpolate>(snd_interpolate, third_sizes_node, third_scales_node, third_axis_node,
|
||||
attributes[2]);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ third_interpolate }, ngraph::ParameterVector{ input });
|
||||
manager.register_pass<ngraph::pass::InterpolateSequenceFusion>();
|
||||
}
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{2}, std::vector<float>{3.0f, 2.0f});
|
||||
auto axes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, std::vector<int64_t>{2, 3});
|
||||
|
||||
auto shape_node = std::make_shared<ngraph::opset8::ShapeOf>(input);
|
||||
auto gather_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{0});
|
||||
auto gather_node = std::make_shared<ngraph::opset8::Gather>(shape_node, axes_node, gather_axis_node);
|
||||
auto cast_shape_to_float = std::make_shared<ngraph::opset8::Convert>(gather_node, ngraph::element::f32);
|
||||
|
||||
auto mul_node = std::make_shared<ngraph::opset8::Multiply>(cast_shape_to_float, scales_node);
|
||||
auto eps_node = ngraph::opset8::Constant::create(ngraph::element::f32, {}, std::vector<float>{1.0e-5f});
|
||||
auto add_node = std::make_shared<ngraph::opset8::Multiply>(mul_node, eps_node);
|
||||
auto floor_node = std::make_shared<ngraph::opset8::Floor>(add_node);
|
||||
auto cast_mul_result_to_int = std::make_shared<ngraph::opset8::Convert>(floor_node, ngraph::element::i64);
|
||||
|
||||
auto fst_interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, cast_mul_result_to_int, scales_node, axes_node, ref_attrs);
|
||||
|
||||
auto snd_sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{sizes_vector[2].size()}, sizes_vector[2]);
|
||||
auto snd_scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{scales_vector[2].size()}, scales_vector[2]);
|
||||
auto snd_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{axes_vector[2].size()}, axes_vector[2]);
|
||||
|
||||
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(fst_interpolate, snd_sizes_node, snd_scales_node, snd_axis_node, ref_attrs);
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, InterpolateSequenceFusion4D3) {
|
||||
ngraph::Shape input_shape { 1, 4, 220, 350 };
|
||||
std::vector<Attrs> attributes = {
|
||||
Attrs{InterpolateMode::NEAREST, ShapeCalcMode::SIZES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f},
|
||||
Attrs{InterpolateMode::NEAREST, ShapeCalcMode::SIZES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f}
|
||||
};
|
||||
std::vector<std::vector<int64_t>> sizes_vector = {
|
||||
{700}, {660}
|
||||
};
|
||||
std::vector<std::vector<float>> scales_vector = {
|
||||
{2.0f}, {3.0f}
|
||||
};
|
||||
std::vector<std::vector<int64_t>> axes_vector = {
|
||||
{3}, {2}
|
||||
};
|
||||
Attrs ref_attrs{InterpolateMode::NEAREST, ShapeCalcMode::SIZES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f};
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
|
||||
auto fst_sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{sizes_vector[0].size()}, sizes_vector[0]);
|
||||
auto fst_scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{scales_vector[0].size()}, scales_vector[0]);
|
||||
auto fst_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{axes_vector[0].size()}, axes_vector[0]);
|
||||
auto fst_interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, fst_sizes_node, fst_scales_node, fst_axis_node, attributes[0]);
|
||||
|
||||
auto snd_sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{sizes_vector[1].size()}, sizes_vector[1]);
|
||||
auto snd_scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{scales_vector[1].size()}, scales_vector[1]);
|
||||
auto snd_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{axes_vector[1].size()}, axes_vector[1]);
|
||||
auto snd_interpolate = std::make_shared<ngraph::opset8::Interpolate>(fst_interpolate, snd_sizes_node, snd_scales_node, snd_axis_node, attributes[1]);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ snd_interpolate }, ngraph::ParameterVector{ input });
|
||||
manager.register_pass<ngraph::pass::InterpolateSequenceFusion>();
|
||||
}
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
|
||||
auto sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, {2}, std::vector<int64_t>{660, 700});
|
||||
auto axes_node = ngraph::opset8::Constant::create(ngraph::element::i64, {2}, std::vector<int64_t>{2, 3});
|
||||
auto sizes_cast = std::make_shared<ngraph::opset8::Convert>(sizes_node, ngraph::element::f32);
|
||||
auto shape_node = std::make_shared<ngraph::opset8::ShapeOf>(input);
|
||||
|
||||
auto gather_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{0});
|
||||
auto gather_node = std::make_shared<ngraph::opset8::Gather>(shape_node, axes_node, gather_axis_node);
|
||||
auto cast_shape_to_float = std::make_shared<ngraph::opset8::Convert>(gather_node, ngraph::element::f32);
|
||||
auto div_node = std::make_shared<ngraph::opset8::Divide>(sizes_cast, cast_shape_to_float);
|
||||
|
||||
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, sizes_node, div_node, axes_node, ref_attrs);
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, InterpolateSequenceFusion5D1) {
|
||||
ngraph::Shape input_shape { 1, 5, 417, 256, 800 };
|
||||
std::vector<Attrs> attributes = {
|
||||
Attrs{InterpolateMode::NEAREST, ShapeCalcMode::SCALES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f},
|
||||
Attrs{InterpolateMode::NEAREST, ShapeCalcMode::SCALES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f},
|
||||
Attrs{InterpolateMode::NEAREST, ShapeCalcMode::SCALES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f}
|
||||
};
|
||||
std::vector<std::vector<int64_t>> sizes_vector = {
|
||||
{600}, {100}, {834}
|
||||
};
|
||||
std::vector<std::vector<float>> scales_vector = {
|
||||
{0.75f}, {20.0f}, {2.0f}
|
||||
};
|
||||
std::vector<std::vector<int64_t>> axes_vector = {
|
||||
{4}, {1}, {2}
|
||||
};
|
||||
Attrs ref_attrs{InterpolateMode::NEAREST, ShapeCalcMode::SCALES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f};
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
|
||||
auto fst_sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{sizes_vector[0].size()}, sizes_vector[0]);
|
||||
auto fst_scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{scales_vector[0].size()}, scales_vector[0]);
|
||||
auto fst_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{axes_vector[0].size()}, axes_vector[0]);
|
||||
auto fst_interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, fst_sizes_node, fst_scales_node, fst_axis_node, attributes[0]);
|
||||
|
||||
auto snd_sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{sizes_vector[1].size()}, sizes_vector[1]);
|
||||
auto snd_scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{scales_vector[1].size()}, scales_vector[1]);
|
||||
auto snd_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{axes_vector[1].size()}, axes_vector[1]);
|
||||
auto snd_interpolate = std::make_shared<ngraph::opset8::Interpolate>(fst_interpolate, snd_sizes_node, snd_scales_node, snd_axis_node, attributes[1]);
|
||||
|
||||
auto third_sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{sizes_vector[2].size()}, sizes_vector[2]);
|
||||
auto third_scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{scales_vector[2].size()}, scales_vector[2]);
|
||||
auto third_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{axes_vector[2].size()}, axes_vector[2]);
|
||||
auto third_interpolate = std::make_shared<ngraph::opset8::Interpolate>(snd_interpolate, third_sizes_node, third_scales_node, third_axis_node,
|
||||
attributes[2]);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ third_interpolate }, ngraph::ParameterVector{ input });
|
||||
manager.register_pass<ngraph::pass::InterpolateSequenceFusion>();
|
||||
}
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{3}, std::vector<float>{20.0f, 2.0f, 0.75f});
|
||||
auto axes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{3}, std::vector<int64_t>{1, 2, 4});
|
||||
|
||||
auto shape_node = std::make_shared<ngraph::opset8::ShapeOf>(input);
|
||||
auto gather_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{0});
|
||||
auto gather_node = std::make_shared<ngraph::opset8::Gather>(shape_node, axes_node, gather_axis_node);
|
||||
auto cast_shape_to_float = std::make_shared<ngraph::opset8::Convert>(gather_node, ngraph::element::f32);
|
||||
|
||||
auto mul_node = std::make_shared<ngraph::opset8::Multiply>(cast_shape_to_float, scales_node);
|
||||
auto eps_node = ngraph::opset8::Constant::create(ngraph::element::f32, {}, std::vector<float>{1.0e-5f});
|
||||
auto add_node = std::make_shared<ngraph::opset8::Multiply>(mul_node, eps_node);
|
||||
auto floor_node = std::make_shared<ngraph::opset8::Floor>(add_node);
|
||||
auto cast_mul_result_to_int = std::make_shared<ngraph::opset8::Convert>(floor_node, ngraph::element::i64);
|
||||
|
||||
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, cast_mul_result_to_int, scales_node, axes_node, ref_attrs);
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, InterpolateSequenceFusion5D2) {
|
||||
ngraph::Shape input_shape { 1, 5, 417, 256, 800 };
|
||||
std::vector<Attrs> attributes = {
|
||||
Attrs{InterpolateMode::NEAREST, ShapeCalcMode::SIZES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f},
|
||||
Attrs{InterpolateMode::NEAREST, ShapeCalcMode::SIZES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f},
|
||||
Attrs{InterpolateMode::NEAREST, ShapeCalcMode::SIZES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f}
|
||||
};
|
||||
std::vector<std::vector<int64_t>> sizes_vector = {
|
||||
{600}, {100}, {834}
|
||||
};
|
||||
std::vector<std::vector<float>> scales_vector = {
|
||||
{0.75f}, {20.0f}, {2.0f}
|
||||
};
|
||||
std::vector<std::vector<int64_t>> axes_vector = {
|
||||
{4}, {1}, {2}
|
||||
};
|
||||
Attrs ref_attrs{InterpolateMode::NEAREST, ShapeCalcMode::SIZES, std::vector<size_t>{0}, std::vector<size_t>{0}, CoordinateTransformMode::HALF_PIXEL,
|
||||
NearestMode::ROUND_PREFER_FLOOR, false, -0.75f};
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
|
||||
auto fst_sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{sizes_vector[0].size()}, sizes_vector[0]);
|
||||
auto fst_scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{scales_vector[0].size()}, scales_vector[0]);
|
||||
auto fst_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{axes_vector[0].size()}, axes_vector[0]);
|
||||
auto fst_interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, fst_sizes_node, fst_scales_node, fst_axis_node, attributes[0]);
|
||||
|
||||
auto snd_sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{sizes_vector[1].size()}, sizes_vector[1]);
|
||||
auto snd_scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{scales_vector[1].size()}, scales_vector[1]);
|
||||
auto snd_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{axes_vector[1].size()}, axes_vector[1]);
|
||||
auto snd_interpolate = std::make_shared<ngraph::opset8::Interpolate>(fst_interpolate, snd_sizes_node, snd_scales_node, snd_axis_node, attributes[1]);
|
||||
|
||||
auto third_sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{sizes_vector[2].size()}, sizes_vector[2]);
|
||||
auto third_scales_node = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{scales_vector[2].size()}, scales_vector[2]);
|
||||
auto third_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{axes_vector[2].size()}, axes_vector[2]);
|
||||
auto third_interpolate = std::make_shared<ngraph::opset8::Interpolate>(snd_interpolate, third_sizes_node, third_scales_node, third_axis_node,
|
||||
attributes[2]);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ third_interpolate }, ngraph::ParameterVector{ input });
|
||||
manager.register_pass<ngraph::pass::InterpolateSequenceFusion>();
|
||||
}
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
|
||||
auto sizes_node = ngraph::opset8::Constant::create(ngraph::element::i64, {3}, std::vector<int64_t>{100, 834, 600});
|
||||
auto axes_node = ngraph::opset8::Constant::create(ngraph::element::i64, {3}, std::vector<int64_t>{1, 2, 4});
|
||||
auto sizes_cast = std::make_shared<ngraph::opset8::Convert>(sizes_node, ngraph::element::f32);
|
||||
auto shape_node = std::make_shared<ngraph::opset8::ShapeOf>(input);
|
||||
|
||||
auto gather_axis_node = ngraph::opset8::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{0});
|
||||
auto gather_node = std::make_shared<ngraph::opset8::Gather>(shape_node, axes_node, gather_axis_node);
|
||||
auto cast_shape_to_float = std::make_shared<ngraph::opset8::Convert>(gather_node, ngraph::element::f32);
|
||||
auto div_node = std::make_shared<ngraph::opset8::Divide>(sizes_cast, cast_shape_to_float);
|
||||
|
||||
auto interpolate = std::make_shared<ngraph::opset8::Interpolate>(input, sizes_node, div_node, axes_node, ref_attrs);
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ interpolate }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user