StridedSlice default shape inference (#11292)
This commit is contained in:
parent
5b0a1fe7bb
commit
ed030e113e
@ -79,18 +79,34 @@ public:
|
||||
const std::vector<int64_t>& get_begin_mask() const {
|
||||
return m_begin_mask;
|
||||
}
|
||||
void set_begin_mask(const std::vector<int64_t>& vec) {
|
||||
m_begin_mask = vec;
|
||||
}
|
||||
const std::vector<int64_t>& get_end_mask() const {
|
||||
return m_end_mask;
|
||||
}
|
||||
void set_end_mask(const std::vector<int64_t>& vec) {
|
||||
m_end_mask = vec;
|
||||
}
|
||||
const std::vector<int64_t>& get_new_axis_mask() const {
|
||||
return m_new_axis_mask;
|
||||
}
|
||||
void set_new_axis_mask(const std::vector<int64_t>& vec) {
|
||||
m_new_axis_mask = vec;
|
||||
}
|
||||
const std::vector<int64_t>& get_shrink_axis_mask() const {
|
||||
return m_shrink_axis_mask;
|
||||
}
|
||||
void set_shrink_axis_mask(const std::vector<int64_t>& vec) {
|
||||
m_shrink_axis_mask = vec;
|
||||
}
|
||||
const std::vector<int64_t>& get_ellipsis_mask() const {
|
||||
return m_ellipsis_mask;
|
||||
}
|
||||
void set_ellipsis_mask_mask(const std::vector<int64_t>& vec) {
|
||||
m_ellipsis_mask = vec;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
void validate_and_infer_types() override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
|
@ -585,6 +585,7 @@ target_link_libraries(ov_core_unit_tests PRIVATE ngraph_test_util
|
||||
ngraph_reference
|
||||
ngraph::builder
|
||||
openvino::util
|
||||
ov_shape_inference
|
||||
pugixml::static
|
||||
${CMAKE_DL_LIBS}
|
||||
Threads::Threads
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include <dimension_tracker.hpp>
|
||||
#include <memory>
|
||||
#include <strided_slice_shape_inference.hpp>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
@ -191,3 +192,20 @@ TEST(type_prop, strided_slice_dynamic_value_and_label_propagation) {
|
||||
const auto& output_shape = bc->get_output_partial_shape(0);
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[0]), 10);
|
||||
}
|
||||
|
||||
TEST(type_prop, default_strided_slice_shape_inference) {
|
||||
auto slice = new op::v1::StridedSlice;
|
||||
slice->set_begin_mask({0, 0, 0});
|
||||
slice->set_end_mask({0, 0, 0});
|
||||
slice->set_new_axis_mask({1, 0, 0});
|
||||
slice->set_shrink_axis_mask({0, 0, 0, 1});
|
||||
slice->set_ellipsis_mask_mask({0, 0, 0});
|
||||
std::vector<ov::PartialShape> in = {{10, 11, 12}, {3}, {3}, {3}}, out = {PartialShape()};
|
||||
int64_t begin_data[] = {0, 0, 0, 0}, end_data[] = {1, 1, 5, 1}, stride_data[] = {1, 1, 1, 1};
|
||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> const_data = {
|
||||
{1, std::make_shared<ov::HostTensor>(element::i64, Shape{4}, begin_data)},
|
||||
{2, std::make_shared<ov::HostTensor>(element::i64, Shape{4}, end_data)},
|
||||
{3, std::make_shared<ov::HostTensor>(element::i64, Shape{4}, stride_data)}};
|
||||
ov::op::v1::shape_infer(slice, in, out, const_data);
|
||||
ASSERT_EQ(out[0], PartialShape({1, 1, 5}));
|
||||
}
|
Loading…
Reference in New Issue
Block a user