Slice-8 to StridedSlice transformation (#8295)

* SliceToStridedSlice transformation

* Slice SLT

* ONNX import

* Disable throw

* Add Slice evaluate, re-enable mkldnn graph throw

* Use ScatterUpdate instead of Gather to adjust indices

* Add CmpValues::CONST_VALUES to Slice transformation tests

* Apply smaller review comments

* Adjust indices lenght type

* Use ov namespace

* Refactor indices alignment function

* Move SliceToStridedSlice transformation to separate file

* Style alignment

* Resolve xfails

* Update tests and remove redundant const folding

* Remove evaluate and onnx changes

* Add use_shapes switch to the Slice transformation

* Style fix
This commit is contained in:
Katarzyna Mitrus 2021-11-16 20:40:34 +01:00 committed by GitHub
parent 3d68ba6480
commit c171be238f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 1123 additions and 20 deletions

View File

@ -68,6 +68,11 @@ public:
*/ */
class ngraph::pass::StridedSliceOptimization: public ngraph::pass::FunctionPass { class ngraph::pass::StridedSliceOptimization: public ngraph::pass::FunctionPass {
public: public:
StridedSliceOptimization(bool use_shapes = true);
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function> f) override; bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
private:
bool m_use_shapes = true;
}; };

View File

@ -0,0 +1,27 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API SliceToStridedSlice;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief SliceToStridedSlice transformation convert v8::Slice to v1::StridedSlice
*/
class ngraph::pass::SliceToStridedSlice: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
SliceToStridedSlice(bool use_shapes);
};

View File

@ -91,9 +91,7 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
// workaround until dynamism in NMS is not supported // workaround until dynamism in NMS is not supported
manager.register_pass<ngraph::pass::ConvertNmsGatherPathToUnsigned>(); manager.register_pass<ngraph::pass::ConvertNmsGatherPathToUnsigned>();
if (m_use_shapes) { manager.register_pass<ngraph::pass::StridedSliceOptimization>(m_use_shapes);
manager.register_pass<ngraph::pass::StridedSliceOptimization>();
}
manager.register_pass<ngraph::pass::BroadcastElementwiseFusion>(); manager.register_pass<ngraph::pass::BroadcastElementwiseFusion>();

View File

@ -6,11 +6,15 @@
#include <vector> #include <vector>
#include "itt.hpp" #include "itt.hpp"
#include "transformations/op_conversions/convert_slice_to_strided_slice.hpp"
#include <transformations/common_optimizations/optimize_strided_slice.hpp> #include <transformations/common_optimizations/optimize_strided_slice.hpp>
#include <ngraph/opsets/opset1.hpp> #include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp> #include <ngraph/opsets/opset3.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/rt_info.hpp> #include <ngraph/rt_info.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::StridedSliceOptimization, "StridedSliceOptimization", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::StridedSliceOptimization, "StridedSliceOptimization", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::UselessStridedSliceEraser, "UselessStridedSliceEraser", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::UselessStridedSliceEraser, "UselessStridedSliceEraser", 0);
@ -245,10 +249,21 @@ bool ngraph::pass::GroupedStridedSliceOptimizer::run_on_function(std::shared_ptr
return graph_rewritten; return graph_rewritten;
} }
ngraph::pass::StridedSliceOptimization::StridedSliceOptimization(bool use_shapes) {
m_use_shapes = use_shapes;
}
bool ngraph::pass::StridedSliceOptimization::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ngraph::pass::StridedSliceOptimization::run_on_function(std::shared_ptr<ngraph::Function> f) {
RUN_ON_FUNCTION_SCOPE(StridedSliceOptimization); RUN_ON_FUNCTION_SCOPE(StridedSliceOptimization);
bool rewritten = UselessStridedSliceEraser().run_on_function(f); ngraph::pass::Manager manager(get_pass_config());
rewritten |= SharedStridedSliceEraser().run_on_function(f); manager.register_pass<ngraph::pass::SliceToStridedSlice>(m_use_shapes);
rewritten |= GroupedStridedSliceOptimizer().run_on_function(f); manager.run_passes(f);
bool rewritten = false;
if (m_use_shapes) {
rewritten = UselessStridedSliceEraser().run_on_function(f);
rewritten |= SharedStridedSliceEraser().run_on_function(f);
rewritten |= GroupedStridedSliceOptimizer().run_on_function(f);
}
return rewritten; return rewritten;
} }

View File

@ -0,0 +1,149 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <memory>
#include <vector>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include "transformations/op_conversions/convert_slice_to_strided_slice.hpp"
#include "transformations/utils/utils.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/validation_util.hpp"
#include "itt.hpp"
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(ngraph::pass::SliceToStridedSlice, "SliceToStridedSlice", 0);
namespace {
Output<ngraph::Node> align_indices(const Output<ngraph::Node>& indices,
const Output<ngraph::Node>& slice_axes,
const Output<ngraph::Node>& scatter_axis,
size_t slice_indices_length,
int64_t fill_in_value,
NodeVector& new_ops) {
// Handle a case when starts/ends/steps lengths are less than provided axes
// in order to ensure compatibility with `StridedSlice:v1` interface
// Example:
// data_shape: {3, 3, 3, 3}
// starts: [1, 1] - after extending --> [0, 0, 1, 1]
// ends: [2, 2] - after extending --> [0, 0, 2, 2]
// steps : [1, 1] - after extending --> [1, 1, 1, 1]
// axes: [2, 3] - apply slice values to 2 and 3 dimension of input data
// expected_output_shape: {3, 3, 1, 1}
const auto default_indices = ngraph::opset8::Constant::create(indices.get_element_type(), Shape{slice_indices_length}, {fill_in_value});
std::shared_ptr<ngraph::Node> adjusted_indices = ngraph::op::util::make_try_fold<ngraph::opset8::ScatterUpdate>(
default_indices,
slice_axes,
indices, // updates
scatter_axis);
if (!ngraph::op::is_constant(adjusted_indices)) {
new_ops.push_back(default_indices);
}
return adjusted_indices;
}
std::vector<int64_t> axes_to_mask(const std::vector<int64_t>& axes, size_t slice_indices_length) {
std::vector<int64_t> mask(slice_indices_length, 1);
for (auto axis : axes) {
mask[axis] = 0;
}
return mask;
}
} // namespace
ngraph::pass::SliceToStridedSlice::SliceToStridedSlice(bool use_shapes) {
MATCHER_SCOPE(SliceToStridedSlice);
auto slice = pattern::wrap_type<opset8::Slice>();
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto slice_node = std::dynamic_pointer_cast<opset8::Slice>(m.get_match_root());
if (!slice_node)
return false;
if (slice_node->get_input_size() < 4)
return false;
auto arg = slice_node->input_value(0);
std::shared_ptr<opset8::Constant> start_const;
std::shared_ptr<opset8::Constant> stop_const;
std::shared_ptr<opset8::Constant> step_const;
if (use_shapes) {
start_const = get_constant_from_source(slice_node->input_value(1));
stop_const = get_constant_from_source(slice_node->input_value(2));
step_const = get_constant_from_source(slice_node->input_value(3));
} else {
start_const = std::dynamic_pointer_cast<opset8::Constant>(slice_node->input_value(1).get_node_shared_ptr());
stop_const = std::dynamic_pointer_cast<opset8::Constant>(slice_node->input_value(2).get_node_shared_ptr());
step_const = std::dynamic_pointer_cast<opset8::Constant>(slice_node->input_value(3).get_node_shared_ptr());
}
auto start_input = start_const ? start_const : slice_node->input_value(1);
auto stop_input = stop_const ? stop_const : slice_node->input_value(2);
auto step_input = step_const ? step_const : slice_node->input_value(3);
std::shared_ptr<opset8::Constant> axes_const;
if (slice_node->get_input_size() > 4) {
axes_const = use_shapes ? get_constant_from_source(slice_node->input_value(4))
: std::dynamic_pointer_cast<opset8::Constant>(slice_node->input_value(4).get_node_shared_ptr());
} else {
axes_const = slice_node->get_default_const_axes(start_input);
}
if (!axes_const)
return false;
const auto& data_shape = slice_node->get_input_partial_shape(0);
auto axes_vec = axes_const->cast_vector<int64_t>();
if (data_shape.rank().is_static()) {
auto norm_axes_vec = normalize_axes(slice_node->get_friendly_name(), axes_vec, data_shape.rank());
axes_vec = std::vector<int64_t>(norm_axes_vec.begin(), norm_axes_vec.end());
} else {
const bool need_normalization = std::any_of(axes_vec.begin(),
axes_vec.end(),
[](int64_t axis) {
return axis < 0;
});
if (need_normalization)
return false;
}
const size_t slice_indices_length = *std::max_element(std::begin(axes_vec), std::end(axes_vec)) + 1;
const auto begin_end_mask = axes_to_mask(axes_vec, slice_indices_length);
const bool are_axes_sorted = std::is_sorted(axes_vec.begin(), axes_vec.end());
const bool are_indices_aligned = are_axes_sorted && (axes_vec.size() == slice_indices_length);
NodeVector new_ops;
if (!are_indices_aligned) {
const auto scatter_axis = opset8::Constant::create(element::i32, Shape{1}, {0});
const auto slice_axes = opset8::Constant::create(element::i64, Shape{axes_vec.size()}, axes_vec);
new_ops.insert(new_ops.end(), {scatter_axis, slice_axes});
start_input = align_indices(start_input, slice_axes, scatter_axis, slice_indices_length, 0, new_ops);
stop_input = align_indices(stop_input, slice_axes, scatter_axis, slice_indices_length, 0, new_ops);
step_input = align_indices(step_input, slice_axes, scatter_axis, slice_indices_length, 1, new_ops);
}
new_ops.insert(new_ops.end(), {start_input.get_node_shared_ptr(), stop_input.get_node_shared_ptr(), step_input.get_node_shared_ptr()});
const auto strided_slice = std::make_shared<opset8::StridedSlice>(arg, start_input, stop_input, step_input, begin_end_mask, begin_end_mask);
new_ops.push_back(strided_slice);
strided_slice->set_friendly_name(slice_node->get_friendly_name());
ngraph::copy_runtime_info(slice_node, new_ops);
ngraph::replace_node(slice_node, strided_slice);
return true;
};
auto m = std::make_shared<pattern::Matcher>(slice, matcher_name);
register_matcher(m, callback);
}

View File

@ -15,9 +15,11 @@
#include <ngraph/function.hpp> #include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp> #include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp> #include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/constant_folding.hpp> #include <ngraph/pass/constant_folding.hpp>
#include <transformations/common_optimizations/optimize_strided_slice.hpp> #include <transformations/common_optimizations/optimize_strided_slice.hpp>
#include <transformations/utils/utils.hpp> #include <transformations/utils/utils.hpp>
#include "openvino/core/partial_shape.hpp"
#include "common_test_utils/ngraph_test_utils.hpp" #include "common_test_utils/ngraph_test_utils.hpp"
@ -277,3 +279,707 @@ TEST_F(TransformationTestsF, OptimizeSS_Groupped_Test) {
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{concat}, ngraph::ParameterVector{source}); function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{concat}, ngraph::ParameterVector{source});
} }
} }
TEST_F(TransformationTestsF, OptimizeSS_UselessDeletion_use_shapes_false) {
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{5, 5, 5, 5});
auto relu = std::make_shared<ngraph::opset1::Relu>(data);
auto begin = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 0, 0, 0});
auto end = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-1, -1, -1, -1});
auto stride = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1});
std::vector<int64_t> begin_mask = {0, 0, 0, 0};
std::vector<int64_t> end_mask = {1, 1, 1, 1}; // ignoring end -- slicing to the end
auto ss = std::make_shared<ngraph::opset1::StridedSlice>(relu, begin, end, stride, begin_mask, end_mask);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ss}, ngraph::ParameterVector{data});
manager.register_pass<ngraph::pass::StridedSliceOptimization>(false);
manager.register_pass<ngraph::pass::ConstantFolding>();
}
// No UselessStridedSliceEraser transformation if use_shapes == false
}
TEST_F(TransformationTestsF, OptimizeSS_Shared_Test_use_shapes_false) {
{
auto source = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{5, 5, 5, 5});
auto begin1 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 0, 0, 0});
auto end1 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-1, -1, -1, -1});
auto stride1 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1});
std::vector<int64_t> begin_mask1 = {0, 0, 0, 0};
std::vector<int64_t> end_mask1 = {0, 0, 0, 0};
auto ss1 = std::make_shared<ngraph::opset1::StridedSlice>(source, begin1, end1, stride1, begin_mask1, end_mask1);
auto begin2 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 0, 0, 0});
auto end2 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-1, -1, -1, -1});
auto stride2 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1});
std::vector<int64_t> begin_mask2 = {0, 0, 0, 0};
std::vector<int64_t> end_mask2 = {0, 0, 0, 0};
auto ss2 = std::make_shared<ngraph::opset1::StridedSlice>(source, begin2, end2, stride2, begin_mask2, end_mask2);
auto concat = std::make_shared<ngraph::opset1::Concat>(ngraph::NodeVector{ss1, ss2}, 0);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{concat}, ngraph::ParameterVector{source});
manager.register_pass<ngraph::pass::StridedSliceOptimization>(false);
manager.register_pass<ngraph::pass::ConstantFolding>();
}
// No SharedStridedSliceEraser transformation if use_shapes == false
}
TEST_F(TransformationTestsF, OptimizeSS_Groupped_Test_use_shapes_false) {
{
auto source = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{5, 5, 5, 5});
auto begin1 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 0, 0, 0});
auto end1 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {5, 3, 5, 5});
auto stride1 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1});
std::vector<int64_t> begin_mask1 = {0, 0, 0, 0};
std::vector<int64_t> end_mask1 = {0, 0, 0, 0};
auto ss1 = std::make_shared<ngraph::opset1::StridedSlice>(source, begin1, end1, stride1, begin_mask1, end_mask1);
auto begin2 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 3, 0, 0});
auto end2 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {5, 5, 5, 5});
auto stride2 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1});
std::vector<int64_t> begin_mask2 = {0, 0, 0, 0};
std::vector<int64_t> end_mask2 = {0, 0, 0, 0};
auto ss2 = std::make_shared<ngraph::opset1::StridedSlice>(source, begin2, end2, stride2, begin_mask2, end_mask2);
auto concat = std::make_shared<ngraph::opset1::Concat>(ngraph::NodeVector{ss1, ss2}, 1);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{concat}, ngraph::ParameterVector{source});
manager.register_pass<ngraph::pass::StridedSliceOptimization>(false);
manager.register_pass<ngraph::pass::ConstantFolding>();
}
// No GroupedStridedSliceOptimizer transformation if use_shapes == false
}
TEST_F(TransformationTestsF, SliceToStridedSlice_default_axes) {
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 4, 3, 5});
auto begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 0, 0, 0});
auto end = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-1, -1, -1, -1});
auto step = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 1, 1, 1});
auto slice = std::make_shared<ngraph::opset8::Slice>(data, begin, end, step);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data});
manager.register_pass<ngraph::pass::StridedSliceOptimization>();
}
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 4, 3, 5});
auto begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 0, 0, 0});
auto end = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-1, -1, -1, -1});
auto stride = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 1, 1, 1});
std::vector<int64_t> begin_mask = {0, 0, 0, 0};
std::vector<int64_t> end_mask = {0, 0, 0, 0};
auto strided_slice = std::make_shared<ngraph::opset8::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{strided_slice}, ngraph::ParameterVector{data});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, SliceToStridedSlice_axes_const_sorted_full) {
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 4, 3, 5});
auto begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 0, 0, 0});
auto end = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-1, -1, -1, -1});
auto step = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 1, 1, 1});
auto axes = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 1, 2, 3});
auto slice = std::make_shared<ngraph::opset8::Slice>(data, begin, end, step, axes);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data});
manager.register_pass<ngraph::pass::StridedSliceOptimization>();
}
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 4, 3, 5});
auto begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 0, 0, 0});
auto end = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-1, -1, -1, -1});
auto stride = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 1, 1, 1});
std::vector<int64_t> begin_mask = {0, 0, 0, 0};
std::vector<int64_t> end_mask = {0, 0, 0, 0};
auto strided_slice = std::make_shared<ngraph::opset1::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{strided_slice}, ngraph::ParameterVector{data});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, SliceToStridedSlice_all_const) {
{
auto data = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{4}, {2, 3, 4, 5});
auto begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto end = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {-1});
auto step = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto axes = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {-1});
auto slice = std::make_shared<ngraph::opset8::Slice>(data, begin, end, step, axes);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{});
manager.register_pass<ngraph::pass::StridedSliceOptimization>();
}
{
auto data = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{4}, {2, 3, 4, 5});
auto begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto end = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {-1});
auto stride = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
std::vector<int64_t> begin_end_mask = {0};
auto strided_slice = std::make_shared<ngraph::opset1::StridedSlice>(data, begin, end, stride, begin_end_mask, begin_end_mask);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{strided_slice}, ngraph::ParameterVector{});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, SliceToStridedSlice_all_const_fold) {
{
auto data = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{4}, {2, 3, 4, 5});
auto begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto end = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {-1});
auto step = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto axes = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {-1});
auto slice = std::make_shared<ngraph::opset8::Slice>(data, begin, end, step, axes);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{});
manager.register_pass<ngraph::pass::StridedSliceOptimization>();
manager.register_pass<ngraph::pass::ConstantFolding>();
}
{
auto sliced_const = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{2}, {3, 4});
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{sliced_const}, ngraph::ParameterVector{});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, SliceToStridedSlice_sss_params_axes_const_sorted_less) {
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 4, 3, 5});
auto begin = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto end = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto step = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto axes = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 2});
auto slice = std::make_shared<ngraph::opset8::Slice>(data, begin, end, step, axes);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data, begin, end, step});
manager.register_pass<ngraph::pass::StridedSliceOptimization>();
}
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 4, 3, 5});
auto start = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto stop = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto step = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto axes = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 2});
auto zero = ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
const auto default_begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0});
const auto begin = std::make_shared<ngraph::opset8::ScatterUpdate>(default_begin,
axes,
start,
zero);
const auto default_end = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0});
const auto end = std::make_shared<ngraph::opset8::ScatterUpdate>(default_end,
axes,
stop,
zero);
const auto default_stride = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1});
const auto stride = std::make_shared<ngraph::opset8::ScatterUpdate>(default_stride,
axes,
step,
zero);
std::vector<int64_t> begin_end_mask = {1, 0, 0};
auto strided_slice = std::make_shared<ngraph::opset1::StridedSlice>(data, begin, end, stride, begin_end_mask, begin_end_mask);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{strided_slice}, ngraph::ParameterVector{data, start, stop, step});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, SliceToStridedSlice_sss_params_axes_const_unsorted) {
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 4, 3, 5});
auto begin = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto end = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto step = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto axes = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {3, 1});
auto slice = std::make_shared<ngraph::opset8::Slice>(data, begin, end, step, axes);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data, begin, end, step});
manager.register_pass<ngraph::pass::StridedSliceOptimization>();
}
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 4, 3, 5});
auto start = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto stop = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto step = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto zero = ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
auto axes = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {3, 1});
const auto default_begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0});
const auto begin = std::make_shared<ngraph::opset8::ScatterUpdate>(default_begin,
axes,
start,
zero);
const auto default_end = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0});
const auto end = std::make_shared<ngraph::opset8::ScatterUpdate>(default_end,
axes,
stop,
zero);
const auto default_stride = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1});
const auto stride = std::make_shared<ngraph::opset8::ScatterUpdate>(default_stride,
axes,
step,
zero);
std::vector<int64_t> begin_end_mask = {1, 0, 1, 0};
auto strided_slice = std::make_shared<ngraph::opset8::StridedSlice>(data, begin, end, stride, begin_end_mask, begin_end_mask);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{strided_slice}, ngraph::ParameterVector{data, start, stop, step});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, SliceToStridedSlice_sss_params_axes_const_negative_sorted) {
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 4, 3, 5});
auto begin = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
auto end = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
auto step = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
auto axes = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, -3, 2, -1});
auto slice = std::make_shared<ngraph::opset8::Slice>(data, begin, end, step, axes);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data, begin, end, step});
manager.register_pass<ngraph::pass::StridedSliceOptimization>();
}
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 4, 3, 5});
auto begin = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
auto end = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
auto stride = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
std::vector<int64_t> begin_mask = {0, 0, 0, 0};
std::vector<int64_t> end_mask = {0, 0, 0, 0};
auto strided_slice = std::make_shared<ngraph::opset8::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{strided_slice}, ngraph::ParameterVector{data, begin, end, stride});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, SliceToStridedSlice_sss_params_dyn_shape_axes_const_negative_unsorted) {
{
auto data_shape = ov::PartialShape{ov::Dimension(-1), ov::Dimension(2, 6), 4, ov::Dimension(-1)};
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, data_shape);
auto begin = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::PartialShape{ov::Dimension(-1)});
auto end = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::PartialShape{ov::Dimension(-1)});
auto step = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::PartialShape{ov::Dimension(-1)});
auto axes = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {-1, -3});
auto slice = std::make_shared<ngraph::opset8::Slice>(data, begin, end, step, axes);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data, begin, end, step});
manager.register_pass<ngraph::pass::StridedSliceOptimization>();
}
{
auto data_shape = ov::PartialShape{ov::Dimension(-1), ov::Dimension(2, 6), 4, ov::Dimension(-1)};
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, data_shape);
auto start = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::PartialShape{ov::Dimension(-1)});
auto stop = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::PartialShape{ov::Dimension(-1)});
auto step = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::PartialShape{ov::Dimension(-1)});
auto zero = ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
auto axes = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {3, 1});
const auto default_begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0});
const auto begin = std::make_shared<ngraph::opset8::ScatterUpdate>(default_begin,
axes,
start,
zero);
const auto default_end = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0});
const auto end = std::make_shared<ngraph::opset8::ScatterUpdate>(default_end,
axes,
stop,
zero);
const auto default_stride = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1});
const auto stride = std::make_shared<ngraph::opset8::ScatterUpdate>(default_stride,
axes,
step,
zero);
std::vector<int64_t> begin_end_mask = {1, 0, 1, 0};
auto strided_slice = std::make_shared<ngraph::opset8::StridedSlice>(data, begin, end, stride, begin_end_mask, begin_end_mask);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{strided_slice}, ngraph::ParameterVector{data, start, stop, step});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, SliceToStridedSlice_sss_params_static_shape_axes_const_negative_unsorted) {
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 4, 3, 5});
auto begin = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto end = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto step = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto axes = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {-1, -3});
auto slice = std::make_shared<ngraph::opset8::Slice>(data, begin, end, step, axes);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data, begin, end, step});
manager.register_pass<ngraph::pass::StridedSliceOptimization>();
}
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 4, 3, 5});
auto start = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto stop = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto step = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{2});
auto zero = ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
auto axes = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {3, 1});
const auto default_begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0});
const auto begin = std::make_shared<ngraph::opset8::ScatterUpdate>(default_begin,
axes,
start,
zero);
const auto default_end = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0});
const auto end = std::make_shared<ngraph::opset8::ScatterUpdate>(default_end,
axes,
stop,
zero);
const auto default_stride = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1});
const auto stride = std::make_shared<ngraph::opset8::ScatterUpdate>(default_stride,
axes,
step,
zero);
std::vector<int64_t> begin_end_mask = {1, 0, 1, 0};
auto strided_slice = std::make_shared<ngraph::opset8::StridedSlice>(data, begin, end, stride, begin_end_mask, begin_end_mask);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{strided_slice}, ngraph::ParameterVector{data, start, stop, step});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, SliceToStridedSlice_dyn_rank_axes_const_positive) {
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ov::PartialShape::dynamic());
auto begin = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
auto end = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
auto step = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
auto axes = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 1, 2, 3});
auto slice = std::make_shared<ngraph::opset8::Slice>(data, begin, end, step, axes);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data, begin, end, step});
manager.register_pass<ngraph::pass::StridedSliceOptimization>();
}
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ov::PartialShape::dynamic());
auto begin = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
auto end = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
auto stride = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
std::vector<int64_t> begin_mask = {0, 0, 0, 0};
std::vector<int64_t> end_mask = {0, 0, 0, 0};
auto strided_slice = std::make_shared<ngraph::opset8::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{strided_slice}, ngraph::ParameterVector{data, begin, end, stride});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, SliceToStridedSlice_dyn_rank_axes_const_negative) {
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ov::PartialShape::dynamic());
auto begin = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
auto end = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
auto step = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
auto axes = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, -3, 2, -1});
auto slice = std::make_shared<ngraph::opset8::Slice>(data, begin, end, step, axes);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data, begin, end, step});
manager.register_pass<ngraph::pass::StridedSliceOptimization>();
manager.register_pass<ngraph::pass::ConstantFolding>();
}
// No transformation for negative axes and dynamic data rank
}
TEST_F(TransformationTestsF, SliceToStridedSlice_axes_param) {
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 4, 3, 5});
auto begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 0, 0, 0});
auto end = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-1, -1, -1, -1});
auto step = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 1, 1, 1});
auto axes = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{4});
auto slice = std::make_shared<ngraph::opset8::Slice>(data, begin, end, step, axes);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data, axes});
manager.register_pass<ngraph::pass::StridedSliceOptimization>();
manager.register_pass<ngraph::pass::ConstantFolding>();
}
// No transformation for non-const axes input
}
TEST_F(TransformationTestsF, SliceToStridedSlice_begin_param_shape_of_use_shapes_true) {
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3, 4, 5});
auto shape_of_data = std::make_shared<ngraph::opset8::ShapeOf>(data, ngraph::element::i64);
auto data_rank = std::make_shared<ngraph::opset8::ShapeOf>(shape_of_data, ngraph::element::i64);
auto zero_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0});
auto one_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1});
auto three_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {3});
auto begin = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{1});
auto end = std::make_shared<ngraph::opset8::Broadcast>(three_const, begin);
auto step = std::make_shared<ngraph::opset8::Broadcast>(one_const, begin);
auto axes = std::make_shared<ngraph::opset8::Range>(zero_const, one_const, one_const, ngraph::element::i64);
auto slice = std::make_shared<ngraph::opset8::Slice>(shape_of_data, begin, end, step, axes);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data, begin});
manager.register_pass<ngraph::pass::StridedSliceOptimization>(true);
manager.register_pass<ngraph::pass::ConstantFolding>();
}
{
auto one_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1});
auto three_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {3});
auto data = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {2, 3, 4, 5});
auto begin = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{1});
auto end = std::make_shared<ngraph::opset8::Broadcast>(three_const, begin);
auto stride = std::make_shared<ngraph::opset8::Broadcast>(one_const, begin);
std::vector<int64_t> begin_end_mask = {0};
auto strided_slice = std::make_shared<ngraph::opset1::StridedSlice>(data, begin, end, stride, begin_end_mask, begin_end_mask);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{strided_slice}, ngraph::ParameterVector{begin});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, SliceToStridedSlice_begin_param_shape_of_use_shapes_false) {
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3, 4, 5});
auto shape_of_data = std::make_shared<ngraph::opset8::ShapeOf>(data, ngraph::element::i64);
auto data_rank = std::make_shared<ngraph::opset8::ShapeOf>(shape_of_data, ngraph::element::i64);
auto zero_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0});
auto one_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1});
auto three_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {3});
auto begin = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{1});
auto end = std::make_shared<ngraph::opset8::Broadcast>(three_const, begin);
auto step = std::make_shared<ngraph::opset8::Broadcast>(one_const, begin);
auto axes = std::make_shared<ngraph::opset8::Range>(zero_const, one_const, one_const, ngraph::element::i64);
auto slice = std::make_shared<ngraph::opset8::Slice>(shape_of_data, begin, end, step, axes);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data, begin});
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::StridedSliceOptimization>(false);
manager.register_pass<ngraph::pass::ConstantFolding>();
}
{
auto one_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1});
auto three_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {3});
auto data = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {2, 3, 4, 5});
auto begin = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{1});
auto end = std::make_shared<ngraph::opset8::Broadcast>(three_const, begin);
auto stride = std::make_shared<ngraph::opset8::Broadcast>(one_const, begin);
std::vector<int64_t> begin_end_mask = {0};
auto strided_slice = std::make_shared<ngraph::opset1::StridedSlice>(data, begin, end, stride, begin_end_mask, begin_end_mask);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{strided_slice}, ngraph::ParameterVector{begin});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, SliceToStridedSlice_const_fold_params_slice_shape_of_use_shapes_true) {
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3, 4, 5});
auto shape_of_data = std::make_shared<ngraph::opset8::ShapeOf>(data, ngraph::element::i64);
auto data_rank = std::make_shared<ngraph::opset8::ShapeOf>(shape_of_data, ngraph::element::i64);
auto zero_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0});
auto one_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1});
auto three_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {3});
auto begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto end = std::make_shared<ngraph::opset8::Broadcast>(three_const, begin);
auto step = std::make_shared<ngraph::opset8::Broadcast>(one_const, begin);
auto axes = std::make_shared<ngraph::opset8::Range>(zero_const, one_const, one_const, ngraph::element::i64);
auto slice = std::make_shared<ngraph::opset8::Slice>(shape_of_data, begin, end, step, axes);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data});
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::StridedSliceOptimization>(true);
manager.register_pass<ngraph::pass::ConstantFolding>();
}
{
auto sliced_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {3, 4});
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{sliced_const}, ngraph::ParameterVector{});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, SliceToStridedSlice_const_fold_params_slice_shape_of_use_shapes_false) {
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3, 4, 5});
auto shape_of_data = std::make_shared<ngraph::opset8::ShapeOf>(data, ngraph::element::i64);
auto data_rank = std::make_shared<ngraph::opset8::ShapeOf>(shape_of_data, ngraph::element::i64);
auto zero_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0});
auto one_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1});
auto three_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {3});
auto begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto end = std::make_shared<ngraph::opset8::Broadcast>(three_const, begin);
auto step = std::make_shared<ngraph::opset8::Broadcast>(one_const, begin);
auto axes = std::make_shared<ngraph::opset8::Range>(zero_const, one_const, one_const, ngraph::element::i64);
auto slice = std::make_shared<ngraph::opset8::Slice>(shape_of_data, begin, end, step, axes);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data});
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::StridedSliceOptimization>(false);
manager.register_pass<ngraph::pass::ConstantFolding>();
}
{
auto sliced_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {3, 4});
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{sliced_const}, ngraph::ParameterVector{});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, SliceToStridedSlice_slice_all_use_shapes_true) {
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3, 4, 5});
auto relu = std::make_shared<ngraph::opset8::Relu>(data);
auto shape_of_data = std::make_shared<ngraph::opset8::ShapeOf>(relu, ngraph::element::i64);
auto data_rank = std::make_shared<ngraph::opset8::ShapeOf>(shape_of_data, ngraph::element::i64);
auto zero_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0});
auto one_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1});
auto begin = std::make_shared<ngraph::opset8::Broadcast>(zero_const, data_rank);
auto end = std::make_shared<ngraph::opset8::Broadcast>(data_rank, data_rank);
auto step = std::make_shared<ngraph::opset8::Broadcast>(one_const, data_rank);
auto slice = std::make_shared<ngraph::opset8::Slice>(relu, begin, end, step);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data});
manager.register_pass<ngraph::pass::StridedSliceOptimization>(true);
manager.register_pass<ngraph::pass::ConstantFolding>();
}
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3, 4, 5});
auto relu = std::make_shared<ngraph::opset8::Relu>(data);
auto begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0});
auto end = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {4});
auto stride = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1});
std::vector<int64_t> begin_end_mask = {0, 0, 0, 0};
auto strided_slice = std::make_shared<ngraph::opset8::StridedSlice>(relu, begin, end, stride, begin_end_mask, begin_end_mask);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{strided_slice}, ngraph::ParameterVector{data});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, SliceToStridedSlice_slice_all_use_shapes_false) {
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3, 4, 5});
auto relu = std::make_shared<ngraph::opset8::Relu>(data);
auto shape_of_data = std::make_shared<ngraph::opset8::ShapeOf>(relu, ngraph::element::i64);
auto data_rank = std::make_shared<ngraph::opset8::ShapeOf>(shape_of_data, ngraph::element::i64);
auto zero_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0});
auto one_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1});
auto begin = std::make_shared<ngraph::opset8::Broadcast>(zero_const, data_rank);
auto end = std::make_shared<ngraph::opset8::Broadcast>(data_rank, data_rank);
auto step = std::make_shared<ngraph::opset8::Broadcast>(one_const, data_rank);
auto slice = std::make_shared<ngraph::opset8::Slice>(relu, begin, end, step);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{slice}, ngraph::ParameterVector{data});
manager.register_pass<ngraph::pass::StridedSliceOptimization>(false);
manager.register_pass<ngraph::pass::ConstantFolding>();
}
{
auto data = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3, 4, 5});
auto relu = std::make_shared<ngraph::opset8::Relu>(data);
auto begin = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0});
auto end = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {4});
auto stride = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1});
std::vector<int64_t> begin_end_mask = {0, 0, 0, 0};
auto strided_slice = std::make_shared<ngraph::opset8::StridedSlice>(relu, begin, end, stride, begin_end_mask, begin_end_mask);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{strided_slice}, ngraph::ParameterVector{data});
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}

View File

@ -0,0 +1,82 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include "single_layer_tests/slice.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
namespace {
const std::vector<InferenceEngine::Precision> inputPrecision = {
InferenceEngine::Precision::I8,
InferenceEngine::Precision::U8,
InferenceEngine::Precision::I16,
InferenceEngine::Precision::I32,
InferenceEngine::Precision::FP32
};
std::vector<SliceSpecificParams> test_cases = {
SliceSpecificParams{ { 16 }, { 4 }, { 12 }, { 1 }, { 0 } },
SliceSpecificParams{ { 16 }, { 0 }, { 8 }, { 2 }, { 0 } },
SliceSpecificParams{ { 20, 10, 5 }, { 0, 0}, { 10, 20}, { 1, 1 }, { 1, 0 } },
SliceSpecificParams{ { 1, 2, 12, 100 }, { 0, 1, 0, 1 }, { 1, 2, 5, 100 }, { 1, 1, 1, 10 }, {} },
SliceSpecificParams{ { 1, 12, 100 }, { 0, 9, 0 }, { 1, 11, 1 }, { 1, 1, 1 }, {} },
SliceSpecificParams{ { 1, 12, 100 }, { 0, 1, 0 }, { 10, -1, 10 }, { 1, 1, 1 }, {} },
SliceSpecificParams{ { 2, 12, 100 }, { 1, 12, 100 }, { 0, 7, 0 }, { -1, -1, -1 }, {} },
SliceSpecificParams{ { 2, 12, 100 }, { 1, 4, 99 }, { 0, 9, 0 }, { -1, 2, -1 }, {} },
SliceSpecificParams{ { 2, 12, 100 }, { -1, -1, -1 }, { 0, 4, 0 }, { -1, -2, -1 }, {} },
SliceSpecificParams{ { 2, 12, 100 }, { -1, -1, -1 }, { 0, 0, 4 }, { -1, -1, -1 }, {2, 0, 1} },
SliceSpecificParams{ { 2, 12, 100 }, { 0, 0, 4 }, { -5, -1, -1 }, { 1, 2, 1 }, {2, 0, 1} },
SliceSpecificParams{ { 2, 2, 2, 2 }, { 0, 0, 0, 0 }, { 2, 2, 2, 2 }, { 1, 1, 1, 1 }, {} },
SliceSpecificParams{ { 2, 2, 2, 2 }, { 1, 1, 1, 1 }, { 2, 2, 2, 2 }, { 1, 1, 1, 1 }, {} },
SliceSpecificParams{ { 2, 2, 4, 3 }, { 0, 0, 0, 0 }, { 2, 2, 4, 3 }, { 1, 1, 2, 1 }, {} },
SliceSpecificParams{ { 2, 2, 4, 2 }, { 1, 0, 0, 1 }, { 2, 2, 4, 2 }, { 1, 1, 2, 1 }, {} },
SliceSpecificParams{ { 1, 2, 4, 2 }, { 0, 1, 0, 1 }, { 10, 2, 4, 2 }, { 1, 1, 2, 1 }, {} },
SliceSpecificParams{ { 10, 2, 4, 2 }, { 9, 1, 3, 0 }, { 0, 0, 0, 1 }, { -1, -1, -1, 1 }, {} },
SliceSpecificParams{ { 10, 2, 4, 2 }, { 19, 1, -1, 0 }, { -10, 0, 0, -1 }, { -1, -1, -1, 1 }, {} },
SliceSpecificParams{ { 3, 2, 4, 200 }, { 0, 1, -1, -1 }, { 3, 2, 0, 0 }, { 1, 1, -2, -1 }, {} },
SliceSpecificParams{ { 2, 4, 5, 5, 68 }, { 0, 1, 0, 0, 0 }, {
std::numeric_limits<std::int64_t>::max(),
std::numeric_limits<std::int64_t>::max(),
std::numeric_limits<std::int64_t>::max(),
std::numeric_limits<std::int64_t>::max(),
std::numeric_limits<std::int64_t>::max() }, { 1, 1, 1, 1, 16 }, {} },
SliceSpecificParams{ { 10, 12 }, { -1, 1 }, { -9999, 10 }, { -1, 1 }, {} },
SliceSpecificParams{ { 5, 5, 5, 5 }, { -1, 0, -1, 0 }, { -50, -1, -60, -1 }, { -1, 1, -1, 1 }, {} },
SliceSpecificParams{ { 1, 5, 32, 32 }, { 0, 2, 5, 4 }, { 1, 4, 28, 27 }, { 1, 1, 1, 1 }, { 0, 1, 2, 3 } },
SliceSpecificParams{ { 1, 5, 32, 20 }, { 0, 1, 0, 0 }, { 1, 3, 32, 20 }, { 1, 1, 1, 1 }, { 0, 1, 2, 3 } },
SliceSpecificParams{ { 2, 5, 32, 20 }, { 0, 0, 10, 0 }, { 1, 3, 20, 20 }, { 1, 1, 1, 1 }, { 0, 1, 2, 3 } },
SliceSpecificParams{ { 1, 5, 32, 32 }, { 0, 0, 20, 20 }, { 1, 5, 25, 26 }, { 1, 1, 1, 2 }, { 0, 1, 2, 3 } },
SliceSpecificParams{ { 2, 5, 32, 32 }, { 0, 0, 0, 20 }, { 1, 2, 30, 30 }, { 1, 1, 2, 1 }, { 0, 1, 2, 3 } },
SliceSpecificParams{ { 1, 5, 32, 20 }, { 0, 0, 2, 10 }, { 1, 3, 32, 20 }, { 1, 1, 1, 1 }, { 0, 1, 2, 3 } },
SliceSpecificParams{ { 2, 5, 32, 32 }, { 0, 1, 0, 10 }, { 1, 5, 32, 30 }, { 1, 1, 1, 1 }, { 0, 1, 2, 3 } },
SliceSpecificParams{ { 1, 5, 32, 20 }, { 0, 1, 2, 10 }, { 1, 5, 32, 18 }, { 1, 1, 1, 2 }, { 0, 1, 2, 3 } },
SliceSpecificParams{ { 2, 8, 32, 20 }, { 0, 0, 2, 10 }, { 1, 8, 32, 18 }, { 1, 2, 1, 2 }, { 0, 1, 2, 3 } },
SliceSpecificParams{ { 2, 8, 32, 20 }, { 0, -20, -15 }, { 2, -5, 3 }, { 1, 1, 1 }, { 0, 2, 1 } },
// Plugin Error: Slice has zero dimension which is not allowed
// SliceSpecificParams{ { 2, 8, 32, 20 }, { 0, 0, 10 }, { 0, 32, 18 }, { 1, 1, 1 }, { 0, 1, 2 } },
// SliceSpecificParams{ { 2, 8, 32, 20 }, { 0, 0, 10 }, { 1, 0, 20 }, { 1, 1, 1 }, { 0, 1, 2 } },
// SliceSpecificParams{ { 2, 8, 32, 20 }, { 0, 4, 10 }, { 2, 8, 0 }, { 1, 1, 1 }, { 0, 1, 2 } },
// SliceSpecificParams{ { 2, 8, 32, 20 }, { 0, 4, 10 }, { 2, 8, 0 }, { 1, 1, 1 }, { 0, 2, 1 } },
// SliceSpecificParams{ { 2, 8, 32, 20 }, { 0, 4, 10 }, { 2, 8, 0 }, { 1, 1, 1 }, { 0, -2, -1 } },
};
INSTANTIATE_TEST_SUITE_P(
smoke_MKLDNN, SliceLayerTest,
::testing::Combine(
::testing::ValuesIn(test_cases),
::testing::ValuesIn(inputPrecision),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>())),
SliceLayerTest::getTestCaseName);
} // namespace

View File

@ -0,0 +1,13 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "shared_test_classes/single_layer/slice.hpp"
namespace LayerTestsDefinitions {
TEST_P(SliceLayerTest, CompareWithRefs) {
Run();
}
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,43 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include "shared_test_classes/base/layer_test_utils.hpp"
namespace LayerTestsDefinitions {
struct SliceSpecificParams {
InferenceEngine::SizeVector inputShape;
std::vector<int64_t> start;
std::vector<int64_t> stop;
std::vector<int64_t> step;
std::vector<int64_t> axes;
};
using SliceParams = std::tuple<
SliceSpecificParams,
InferenceEngine::Precision, // Net precision
InferenceEngine::Precision, // Input precision
InferenceEngine::Precision, // Output precision
InferenceEngine::Layout, // Input layout
InferenceEngine::Layout, // Output layout
std::string, // Device name
std::map<std::string, std::string> // Additional network configuration
>;
class SliceLayerTest : public testing::WithParamInterface<SliceParams>,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<SliceParams> &obj);
protected:
void SetUp() override;
};
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,61 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph_functions/builders.hpp"
#include "ngraph/ngraph.hpp"
#include "shared_test_classes/single_layer/slice.hpp"
using namespace ngraph;
namespace LayerTestsDefinitions {
std::string SliceLayerTest::getTestCaseName(const testing::TestParamInfo<SliceParams> &obj) {
SliceSpecificParams params;
InferenceEngine::Precision netPrc;
InferenceEngine::Precision inPrc, outPrc;
InferenceEngine::Layout inLayout, outLayout;
std::string targetName;
std::map<std::string, std::string> additionalConfig;
std::tie(params, netPrc, inPrc, outPrc, inLayout, outLayout, targetName, additionalConfig) = obj.param;
std::ostringstream result;
result << "inShape=" << CommonTestUtils::vec2str(params.inputShape) << "_";
result << "netPRC=" << netPrc.name() << "_";
result << "start=" << CommonTestUtils::vec2str(params.start) << "_";
result << "stop=" << CommonTestUtils::vec2str(params.stop) << "_";
result << "step=" << CommonTestUtils::vec2str(params.step) << "_";
result << "axes=" << CommonTestUtils::vec2str(params.axes) << "_";
result << "trgDev=" << targetName;
return result.str();
}
void SliceLayerTest::SetUp() {
SliceSpecificParams sliceParams;
InferenceEngine::Precision netPrecision;
std::map<std::string, std::string> additionalConfig;
std::tie(sliceParams, netPrecision, inPrc, outPrc, inLayout, outLayout, targetDevice, additionalConfig) = this->GetParam();
configuration.insert(additionalConfig.begin(), additionalConfig.end());
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
element::Type_t et = element::i32;
const auto data = std::make_shared<opset8::Parameter>(ngPrc, Shape(sliceParams.inputShape));
const auto start = std::make_shared<opset8::Constant>(et, Shape{sliceParams.start.size()}, sliceParams.start);
const auto stop = std::make_shared<opset8::Constant>(et, Shape{sliceParams.stop.size()}, sliceParams.stop);
const auto step = std::make_shared<opset8::Constant>(et, Shape{sliceParams.step.size()}, sliceParams.step);
Output<Node> slice;
if (sliceParams.axes.empty()) {
slice = std::make_shared<opset8::Slice>(data, start, stop, step);
} else {
const auto axes = std::make_shared<opset8::Constant>(et, Shape{sliceParams.axes.size()}, sliceParams.axes);
slice = std::make_shared<opset8::Slice>(data, start, stop, step, axes);
}
ResultVector results{std::make_shared<opset8::Result>(slice)};
function = std::make_shared<Function>(results, ov::ParameterVector{data}, "Slice");
}
} // namespace LayerTestsDefinitions

View File

@ -32,6 +32,7 @@ public:
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override; std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
std::shared_ptr<ngraph::op::v0::Constant> get_default_const_axes(const Output<Node>& start) const;
PartialShape calculate_output_shape(const std::vector<int64_t>& starts, PartialShape calculate_output_shape(const std::vector<int64_t>& starts,
const std::vector<int64_t>& stops, const std::vector<int64_t>& stops,
const std::vector<int64_t>& steps, const std::vector<int64_t>& steps,

View File

@ -10,6 +10,7 @@
#include "ngraph/attribute_visitor.hpp" #include "ngraph/attribute_visitor.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/validation_util.hpp" #include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
@ -34,19 +35,6 @@ op::v8::Slice::Slice(const Output<Node>& data,
namespace { namespace {
std::shared_ptr<ngraph::op::v0::Constant> get_default_const_axes(const Output<Node>& start) {
const auto start_pshape = start.get_partial_shape();
// Static case
if (start_pshape.rank().is_static() && start_pshape.rank().get_length() == 1 && start_pshape[0].is_static()) {
size_t axes_length = start_pshape[0].get_length();
std::vector<int64_t> axes(axes_length);
std::iota(axes.begin(), axes.end(), 0);
return op::v0::Constant::create(element::i64, Shape{axes_length}, axes);
}
// Dynamic case
return nullptr;
}
int64_t get_sliced_dim_size(int64_t start, int64_t stop, int64_t step, int64_t dim_size) { int64_t get_sliced_dim_size(int64_t start, int64_t stop, int64_t step, int64_t dim_size) {
// Normalize index // Normalize index
start = start < 0 ? dim_size + start : start; start = start < 0 ? dim_size + start : start;
@ -64,7 +52,9 @@ int64_t get_sliced_dim_size(int64_t start, int64_t stop, int64_t step, int64_t d
// Clip max stop index (last element exclusively) // Clip max stop index (last element exclusively)
elements_in_range = std::max(int64_t(0), std::min(dim_size, stop) - start); elements_in_range = std::max(int64_t(0), std::min(dim_size, stop) - start);
} }
const int64_t sliced_dim_size = std::ceil(elements_in_range / std::fabs(step)); const int64_t rest = elements_in_range % std::abs(step);
const int64_t integer_div = elements_in_range / std::abs(step);
const int64_t sliced_dim_size = !rest ? integer_div : integer_div + 1;
return sliced_dim_size; return sliced_dim_size;
} }
@ -75,6 +65,19 @@ bool op::v8::Slice::visit_attributes(AttributeVisitor& visitor) {
return true; return true;
} }
std::shared_ptr<ngraph::op::v0::Constant> op::v8::Slice::get_default_const_axes(const Output<Node>& start) const {
const auto start_pshape = start.get_partial_shape();
// Static case
if (start_pshape.rank().is_static() && start_pshape.rank().get_length() == 1 && start_pshape[0].is_static()) {
size_t axes_length = start_pshape[0].get_length();
std::vector<int64_t> axes(axes_length);
std::iota(axes.begin(), axes.end(), 0);
return op::v0::Constant::create(element::i64, Shape{axes_length}, axes);
}
// Dynamic case
return nullptr;
}
void op::v8::Slice::validate_and_infer_types() { void op::v8::Slice::validate_and_infer_types() {
NGRAPH_OP_SCOPE(v8_Slice_validate_and_infer_types); NGRAPH_OP_SCOPE(v8_Slice_validate_and_infer_types);