StridedSlice default shape inference (#11292)

This commit is contained in:
Evgenya Stepyreva 2022-03-29 16:42:52 +03:00 committed by GitHub
parent 5b0a1fe7bb
commit ed030e113e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 0 deletions

View File

@ -79,18 +79,34 @@ public:
const std::vector<int64_t>& get_begin_mask() const {
return m_begin_mask;
}
void set_begin_mask(const std::vector<int64_t>& vec) {
m_begin_mask = vec;
}
const std::vector<int64_t>& get_end_mask() const {
return m_end_mask;
}
void set_end_mask(const std::vector<int64_t>& vec) {
m_end_mask = vec;
}
const std::vector<int64_t>& get_new_axis_mask() const {
return m_new_axis_mask;
}
void set_new_axis_mask(const std::vector<int64_t>& vec) {
m_new_axis_mask = vec;
}
const std::vector<int64_t>& get_shrink_axis_mask() const {
return m_shrink_axis_mask;
}
void set_shrink_axis_mask(const std::vector<int64_t>& vec) {
m_shrink_axis_mask = vec;
}
const std::vector<int64_t>& get_ellipsis_mask() const {
return m_ellipsis_mask;
}
void set_ellipsis_mask_mask(const std::vector<int64_t>& vec) {
m_ellipsis_mask = vec;
}
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
OPENVINO_SUPPRESS_DEPRECATED_START

View File

@ -585,6 +585,7 @@ target_link_libraries(ov_core_unit_tests PRIVATE ngraph_test_util
ngraph_reference
ngraph::builder
openvino::util
ov_shape_inference
pugixml::static
${CMAKE_DL_LIBS}
Threads::Threads

View File

@ -4,6 +4,7 @@
#include <dimension_tracker.hpp>
#include <memory>
#include <strided_slice_shape_inference.hpp>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
@ -191,3 +192,20 @@ TEST(type_prop, strided_slice_dynamic_value_and_label_propagation) {
const auto& output_shape = bc->get_output_partial_shape(0);
ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[0]), 10);
}
TEST(type_prop, default_strided_slice_shape_inference) {
auto slice = new op::v1::StridedSlice;
slice->set_begin_mask({0, 0, 0});
slice->set_end_mask({0, 0, 0});
slice->set_new_axis_mask({1, 0, 0});
slice->set_shrink_axis_mask({0, 0, 0, 1});
slice->set_ellipsis_mask_mask({0, 0, 0});
std::vector<ov::PartialShape> in = {{10, 11, 12}, {3}, {3}, {3}}, out = {PartialShape()};
int64_t begin_data[] = {0, 0, 0, 0}, end_data[] = {1, 1, 5, 1}, stride_data[] = {1, 1, 1, 1};
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> const_data = {
{1, std::make_shared<ov::HostTensor>(element::i64, Shape{4}, begin_data)},
{2, std::make_shared<ov::HostTensor>(element::i64, Shape{4}, end_data)},
{3, std::make_shared<ov::HostTensor>(element::i64, Shape{4}, stride_data)}};
ov::op::v1::shape_infer(slice, in, out, const_data);
ASSERT_EQ(out[0], PartialShape({1, 1, 5}));
}