[Shape inference] Pad_1/Topk_3/Split_1/VariadicSplit_1/ExperimentalDetectronROIFeatureExtractor_6/Bucketize_3/EmbeddingBagOffsetsSum_3/EmbeddingSegmentsSum_3/Range_4/RegionYolo_0/ReorgYolo_0 (#8413)

* [shape_infer] add shape_infer for ExperimentalDetectronROIFeatureExtractor op

Signed-off-by: Li, Tingqian <tingqian.li@intel.com>

* add test

* Use compatible & merge for intersection checks

* Update

Signed-off-by: Li, Tingqian <tingqian.li@intel.com>

* Add perf_test

Signed-off-by: Li, Tingqian <tingqian.li@intel.com>

* Initial commit

* fix compile issue

* Add test

* fix clang format issue

* support for pads_begin/pads_end with different sizes

* fix bug in EDGE mode checking

* fix padding mode checks

* fix according to jane's review comment

* fix const reference

Signed-off-by: Li, Tingqian <tingqian.li@intel.com>

* Initial commit

Signed-off-by: Li, Tingqian <tingqian.li@intel.com>

* fix bugs

Signed-off-by: Li, Tingqian <tingqian.li@intel.com>

* Switch to use single generic code with small helper template

Signed-off-by: Li, Tingqian <tingqian.li@intel.com>

* Initial commit on Split

Signed-off-by: Li, Tingqian <tingqian.li@intel.com>

* Convolution update

* Adds pragma once

* Reductions shape infer

* Shape nodes

* style

* Update

* add exp detectron roi feature

* Update

Signed-off-by: Li, Tingqian <tingqian.li@intel.com>

* Use get_data_as_int64 + constant_data

* Add test

* Add utils.hpp into cpuUnit shape inference test

* avoid using friend template function

* fix topk axis bug

* Add bucketize

* Add embeddingbag offsets sum

* Add embedding segments sum

* fix code style issue

* Add Range_4

* Update tests

* Add range

* Add region Yolo

* Add reorg

* fix according to Globev's comment

* call shape_infer in evaluate_variadic_split()

* fix CI issue

* fix CI issue

* fix CI issue, topk change revert

* fix flake8 E302

* fix myriad smoke test issue

* fix according to Vladislav's second round review

* fix format

* Add StridedSlice & Einsum

* fix pad_test.cpp build issue

* fix according to review comment

* insert directly into output shape

* revert infer_slice_shape() change since vpux compiler uses this function

* move tests

Co-authored-by: Stepyreva, Evgenya <evgenya.stepyreva@intel.com>
This commit is contained in:
Tingqian Li 2021-12-22 11:54:12 +08:00 committed by GitHub
parent 7bcca1b82d
commit b8e6b6368c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
52 changed files with 2097 additions and 714 deletions

View File

@ -10,6 +10,7 @@
#include <openvino/opsets/opset4.hpp>
#include <openvino/opsets/opset5.hpp>
#include <openvino/opsets/opset6.hpp>
#include <openvino/opsets/opset7.hpp>
#include <openvino/opsets/opset8.hpp>
#include "assign_shape_inference.hpp"
@ -30,6 +31,21 @@
#include "scatter_nd_base_shape_inference.hpp"
#include "shape_inference.hpp"
#include "shape_nodes.hpp"
#include "fake_quantize.hpp"
#include "experimental_detectron_detection_output_shape_inference.hpp"
#include "bucketize_shape_inference.hpp"
#include "embedding_segments_sum_shape_inference.hpp"
#include "embeddingbag_offsets_shape_inference.hpp"
#include "experimental_detectron_roi_feature_shape_inference.hpp"
#include "pad_shape_inference.hpp"
#include "range_shape_inference.hpp"
#include "region_yolo_shape_inference.hpp"
#include "reorg_yolo_shape_inference.hpp"
#include "split_shape_inference.hpp"
#include "topk_shape_inference.hpp"
#include "variadic_split_shape_inference.hpp"
#include "einsum_shape_inference.hpp"
#include "strided_slice_shape_inference.hpp"
#include "static_shape.hpp"
#include "tile_shape_inference.hpp"
#include "utils.hpp"
@ -104,6 +120,34 @@ void shape_inference(ov::Node* op,
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset6::ExperimentalDetectronDetectionOutput>(op)) {
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset3::TopK>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (auto node = ov::as_type<ov::opset3::Bucketize>(op)) {
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset3::EmbeddingSegmentsSum>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (auto node = ov::as_type<ov::opset3::EmbeddingBagOffsetsSum>(op)) {
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset6::ExperimentalDetectronROIFeatureExtractor>(op)) {
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset1::Pad>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (auto node = ov::as_type<ov::opset4::Range>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (auto node = ov::as_type<ov::opset1::Range>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (auto node = ov::as_type<ov::opset1::RegionYolo>(op)) {
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset2::ReorgYolo>(op)) {
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset1::Split>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (auto node = ov::as_type<ov::opset1::VariadicSplit>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (auto node = ov::as_type<ov::opset7::Einsum>(op)) {
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset1::StridedSlice>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (auto node = ov::as_type<ov::opset3::Assign>(op)) {
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset6::Assign>(op)) {

View File

@ -77,7 +77,6 @@ xfail_issue_38735 = xfail_test(reason="RuntimeError: nGraph does not support the
xfail_issue_48052 = xfail_test(reason="Dropout op is not supported in traning mode")
xfail_issue_45180 = xfail_test(reason="RuntimeError: Unsupported dynamic op: ReduceSum")
xfail_issue_44851 = xfail_test(reason="Expected: Unsupported dynamic op: Broadcast")
xfail_issue_44854 = xfail_test(reason="Expected: Unsupported dynamic op: VariadicSplit")
xfail_issue_44858 = xfail_test(reason="Expected: Unsupported dynamic op: Unsqueeze")
xfail_issue_44956 = xfail_test(reason="Expected: Unsupported dynamic op: Loop")
xfail_issue_44957 = xfail_test(reason="Expected: Unsupported dynamic op: NonZero")

View File

@ -27,7 +27,6 @@ from tests import (
xfail_issue_38735,
xfail_issue_39658,
xfail_issue_39662,
xfail_issue_44854,
xfail_issue_44858,
xfail_issue_44956,
xfail_issue_44965,
@ -264,12 +263,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_example_cpu",
"OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_random_cpu",
),
(
xfail_issue_44854,
"OnnxBackendNodeModelTest.test_split_variable_parts_1d_cpu",
"OnnxBackendNodeModelTest.test_split_variable_parts_2d_cpu",
"OnnxBackendNodeModelTest.test_split_variable_parts_default_axis_cpu",
),
(
xfail_issue_44858,
"OnnxBackendNodeModelTest.test_unsqueeze_axis_0_cpu",

View File

@ -15,7 +15,6 @@ from tests.test_onnx.utils import (
run_node,
)
from tests import (xfail_issue_35927,
xfail_issue_44854,
xfail_issue_44858,
xfail_issue_44968)
@ -306,7 +305,6 @@ def test_split_2d(node, expected_output):
assert all_arrays_equal(ng_results, expected_output)
@xfail_issue_44854
def test_split_2d_splits_input():
data = np.arange(8, dtype=np.int32).reshape(2, 4)
splits = np.array([3, 1]).astype(np.int64)
@ -321,7 +319,6 @@ def test_split_2d_splits_input():
assert all_arrays_equal(ng_results, expected_outputs)
@xfail_issue_44854
def test_split_1d():
# 1D
data = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).astype(np.float32)
@ -360,7 +357,7 @@ def test_split_1d():
splits = np.array([2, 4]).astype(np.int64)
node = onnx.helper.make_node(
"Split", inputs=["input", "splits"], outputs=["y", "z"], split=[2, 4]
"Split", inputs=["input", "splits"], outputs=["y", "z"]
)
expected_outputs = [
np.array([1.0, 2.0]).astype(np.float32),

View File

@ -88,7 +88,6 @@ xfail_issue_38735 = xfail_test(reason="RuntimeError: nGraph does not support the
xfail_issue_48052 = xfail_test(reason="Dropout op is not supported in traning mode")
xfail_issue_45180 = xfail_test(reason="RuntimeError: Unsupported dynamic op: ReduceSum")
xfail_issue_44851 = xfail_test(reason="Expected: Unsupported dynamic op: Broadcast")
xfail_issue_44854 = xfail_test(reason="Expected: Unsupported dynamic op: VariadicSplit")
xfail_issue_44858 = xfail_test(reason="Expected: Unsupported dynamic op: Unsqueeze")
xfail_issue_44956 = xfail_test(reason="Expected: Unsupported dynamic op: Loop")
xfail_issue_44957 = xfail_test(reason="Expected: Unsupported dynamic op: NonZero")

View File

@ -26,7 +26,6 @@ from tests_compatibility import (
xfail_issue_38735,
xfail_issue_39658,
xfail_issue_39662,
xfail_issue_44854,
xfail_issue_44858,
xfail_issue_44956,
xfail_issue_44965,
@ -256,12 +255,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_example_cpu",
"OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_random_cpu",
),
(
xfail_issue_44854,
"OnnxBackendNodeModelTest.test_split_variable_parts_1d_cpu",
"OnnxBackendNodeModelTest.test_split_variable_parts_2d_cpu",
"OnnxBackendNodeModelTest.test_split_variable_parts_default_axis_cpu",
),
(
xfail_issue_44858,
"OnnxBackendNodeModelTest.test_unsqueeze_axis_0_cpu",

View File

@ -15,7 +15,6 @@ from tests_compatibility.test_onnx.utils import (
run_node,
)
from tests_compatibility import (xfail_issue_35927,
xfail_issue_44854,
xfail_issue_44858,
xfail_issue_44968)
@ -306,7 +305,6 @@ def test_split_2d(node, expected_output):
assert all_arrays_equal(ng_results, expected_output)
@xfail_issue_44854
def test_split_2d_splits_input():
data = np.arange(8, dtype=np.int32).reshape(2, 4)
splits = np.array([3, 1]).astype(np.int64)
@ -321,7 +319,6 @@ def test_split_2d_splits_input():
assert all_arrays_equal(ng_results, expected_outputs)
@xfail_issue_44854
def test_split_1d():
# 1D
data = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).astype(np.float32)
@ -360,7 +357,7 @@ def test_split_1d():
splits = np.array([2, 4]).astype(np.int64)
node = onnx.helper.make_node(
"Split", inputs=["input", "splits"], outputs=["y", "z"], split=[2, 4]
"Split", inputs=["input", "splits"], outputs=["y", "z"]
)
expected_outputs = [
np.array([1.0, 2.0]).astype(np.float32),

View File

@ -37,7 +37,7 @@ public:
///
/// \return Einsum equation
///
std::string get_equation() const {
const std::string& get_equation() const {
return m_equation;
}

View File

@ -54,6 +54,11 @@ public:
private:
Attributes m_attrs;
template <class T>
friend void shape_infer(const ExperimentalDetectronROIFeatureExtractor* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes);
};
} // namespace v6
} // namespace op

View File

@ -39,6 +39,11 @@ public:
void set_output_type(element::Type output_type) {
m_output_type = output_type;
}
const element::Type& get_output_type() const {
return m_output_type;
}
// Overload collision with method on Node
using Node::set_output_type;

View File

@ -79,6 +79,9 @@ private:
std::vector<float> m_anchors{};
int m_axis;
int m_end_axis;
template <class T>
friend void shape_infer(const RegionYolo* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes);
};
} // namespace v0
} // namespace op

View File

@ -12,6 +12,7 @@
namespace ov {
namespace op {
namespace v1 {
/// \brief Splits the input tensor into a list of equal sized tensors
class OPENVINO_API Split : public Op {
public:

View File

@ -0,0 +1,30 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/core/validation_util.hpp>
#include <openvino/op/bucketize.hpp>
#include "utils.hpp"
namespace ov {
namespace op {
namespace v3 {
template <class T>
void shape_infer(const Bucketize* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 2) && output_shapes.size() == 1);
const auto& data_shape = input_shapes[0];
const auto& buckets_shape = input_shapes[1];
NODE_VALIDATION_CHECK(op,
buckets_shape.rank().compatible(1),
"Buckets input must be a 1D tensor. Got: ",
buckets_shape);
output_shapes[0] = data_shape;
}
} // namespace v3
} // namespace op
} // namespace ov

View File

@ -0,0 +1,112 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/core/validation_util.hpp>
#include <openvino/op/einsum.hpp>
#include "utils.hpp"
namespace ov {
namespace op {
namespace v7 {
template <class T>
void shape_infer(const Einsum* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
// check that equation has correct format and extract input and output subscripts
std::vector<std::string> input_subscripts;
std::string output_subscript;
Einsum::parse_equation(op->get_equation(), input_subscripts, output_subscript);
// a number of input subscripts must match with a number of input tensors
NODE_VALIDATION_CHECK(op,
input_subscripts.size() == input_shapes.size(),
"Equation must contain a number of subscripts equal to a number of Einsum inputs.");
NODE_VALIDATION_CHECK(op, output_shapes.size() == 1);
// create a dictionary with dimension sizes (or ranges in case of dynamic shapes) for each label
// and check their compatibility in case of repeating labels
std::unordered_map<std::string, T> label_to_shape;
for (size_t input_idx = 0; input_idx < input_shapes.size(); ++input_idx) {
const auto& pshape = input_shapes[input_idx];
std::vector<std::string> labels;
labels = Einsum::extract_labels(input_subscripts[input_idx]);
if (pshape.rank().is_static()) {
size_t input_rank = pshape.size();
// check that a rank is greater or equal to a number of labels
// these numbers are always equal if there is no ellipsis in the subscript
NODE_VALIDATION_CHECK(op,
input_rank >= labels.size(),
"Input rank must be greater or equal to a number of labels in the "
"corresponding input subscript.");
for (size_t label_ind = 0, dim_ind = 0; label_ind < labels.size() && dim_ind < input_rank; ++label_ind) {
auto const& label = labels[label_ind];
if (label.compare("...") == 0) {
size_t num_broadcasted_dims = input_rank - labels.size() + 1;
auto current_sub_pshape = T(std::vector<DimType>(pshape.begin() + dim_ind,
pshape.begin() + dim_ind + num_broadcasted_dims));
if (label_to_shape.find(label) == label_to_shape.end()) {
label_to_shape[label] = current_sub_pshape;
} else {
bool is_broadcast_success = T::broadcast_merge_into(label_to_shape[label],
current_sub_pshape,
op::AutoBroadcastType::NUMPY);
NODE_VALIDATION_CHECK(op,
is_broadcast_success,
"Input dimensions labeled with ellipsis for Einsum "
"must be broadcastable.");
}
dim_ind += num_broadcasted_dims;
} else {
if (label_to_shape.find(label) == label_to_shape.end()) {
label_to_shape[label] = T{pshape[dim_ind]};
} else {
NODE_VALIDATION_CHECK(op,
label_to_shape[label].compatible(T{pshape[label_ind]}),
"Different input dimensions indicated by the same labels for Einsum "
"must be compatible.");
T::merge_into(label_to_shape[label], T{pshape[dim_ind]});
}
++dim_ind;
}
}
} else {
for (auto const& label : labels) {
NODE_VALIDATION_CHECK(op,
label != "...",
"The subscript corresponding to a dynamic rank input must "
"not contain ellipsis.");
if (label_to_shape.find(label) == label_to_shape.end()) {
label_to_shape[label] = ov::PartialShape{Dimension::dynamic()};
}
}
}
}
// compute the output shape
std::vector<std::string> output_labels;
output_labels = Einsum::extract_labels(output_subscript);
auto& output_shape = output_shapes[0];
output_shape.resize(0);
for (auto const& output_label : output_labels) {
NODE_VALIDATION_CHECK(op,
label_to_shape.find(output_label) != label_to_shape.end(),
"Label in output subscript of Einsum equation must enter at least "
"one input subscript.");
output_shape.insert(output_shape.end(),
label_to_shape[output_label].begin(),
label_to_shape[output_label].end());
}
}
} // namespace v7
} // namespace op
} // namespace ov

View File

@ -0,0 +1,71 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <bitset>
#include <openvino/core/validation_util.hpp>
#include <openvino/op/embeddingbag_offsets_sum.hpp>
#include "utils.hpp"
namespace ov {
namespace op {
namespace v3 {
template <class T>
void shape_infer(const EmbeddingSegmentsSum* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
const auto input_size = input_shapes.size();
NODE_VALIDATION_CHECK(op, (input_size >= 4 && input_size <= 6) && output_shapes.size() == 1);
static constexpr int EMB_TABLE = 0;
static constexpr int INDICES = 1;
static constexpr int SEGMENT_IDS = 2;
static constexpr int NUM_SEGMENTS = 3;
static constexpr int DEFAULT_INDEX = 4;
static constexpr int PER_SAMPLE_WEIGHTS = 5;
NODE_VALIDATION_CHECK(op, input_shapes[INDICES].rank().compatible(1), "INDICES must be 1D");
NODE_VALIDATION_CHECK(op, input_shapes[SEGMENT_IDS].rank().compatible(1), "SEGMENT_IDS must be 1D");
NODE_VALIDATION_CHECK(op,
input_shapes[INDICES].compatible(input_shapes[SEGMENT_IDS]),
"INDICES and SEGMENT_IDS shape must be same");
NODE_VALIDATION_CHECK(op, input_shapes[NUM_SEGMENTS].compatible(T{}), "NUM_SEGMENTS must be a scalar");
if (input_size >= 5) {
NODE_VALIDATION_CHECK(op, input_shapes[DEFAULT_INDEX].compatible(T{}), "DEFAULT_INDEX must be a scalar");
}
if (input_size == 6) {
NODE_VALIDATION_CHECK(op,
input_shapes[PER_SAMPLE_WEIGHTS].rank().compatible(1),
"PER_SAMPLE_WEIGHTS must be 1D");
NODE_VALIDATION_CHECK(op,
input_shapes[INDICES].compatible(input_shapes[PER_SAMPLE_WEIGHTS]),
"INDICES and PER_SAMPLE_WEIGHTS shape must be same");
}
const auto& emb_table_shape = input_shapes[EMB_TABLE];
auto& result_shape = output_shapes[0];
if (emb_table_shape.rank().is_static()) {
result_shape = emb_table_shape;
std::vector<int64_t> segments_value;
if (get_data_as_int64<T>(NUM_SEGMENTS, op, segments_value, constant_data)) {
result_shape[0] = segments_value[0];
} else {
result_shape[0] = Dimension::dynamic();
}
} else {
result_shape = ov::PartialShape::dynamic();
}
}
} // namespace v3
} // namespace op
} // namespace ov

View File

@ -0,0 +1,58 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/core/validation_util.hpp>
#include <openvino/op/embeddingbag_offsets_sum.hpp>
#include "utils.hpp"
namespace ov {
namespace op {
namespace util {
template <class T>
void shape_infer(const ov::op::util::EmbeddingBagOffsetsBase* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes) {
const auto input_size = input_shapes.size();
NODE_VALIDATION_CHECK(op, (input_size >= 3 && input_size <= 5) && output_shapes.size() == 1);
static constexpr int EMB_TABLE = 0;
static constexpr int INDICES = 1;
static constexpr int OFFSETS = 2;
static constexpr int DEFAULT_INDEX = 3;
static constexpr int PER_SAMPLE_WEIGHTS = 4;
NODE_VALIDATION_CHECK(op, input_shapes[INDICES].rank().compatible(1), "INDICES must be 1D");
NODE_VALIDATION_CHECK(op, input_shapes[OFFSETS].rank().compatible(1), "OFFSETS must be 1D");
if (input_size >= 4) {
NODE_VALIDATION_CHECK(op, input_shapes[DEFAULT_INDEX].rank().compatible(0), "DEFAULT_INDEX must be a scalar");
}
if (input_size == 5) {
NODE_VALIDATION_CHECK(op,
input_shapes[PER_SAMPLE_WEIGHTS].rank().compatible(1),
"PER_SAMPLE_WEIGHTS must be 1D");
NODE_VALIDATION_CHECK(op,
input_shapes[INDICES].compatible(input_shapes[PER_SAMPLE_WEIGHTS]),
"INDICES and PER_SAMPLE_WEIGHTS shape must be same");
}
const auto& emb_table_shape = input_shapes[EMB_TABLE];
const auto& offsets_shape = input_shapes[OFFSETS];
if (emb_table_shape.rank().is_static()) {
output_shapes[0] = emb_table_shape;
output_shapes[0][0] = offsets_shape.rank().is_static() ? offsets_shape[0] : Dimension::dynamic();
} else {
output_shapes[0] = PartialShape::dynamic();
}
}
} // namespace util
} // namespace op
} // namespace ov

View File

@ -0,0 +1,89 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/op/experimental_detectron_roi_feature.hpp>
#include "utils.hpp"
namespace ov {
namespace op {
namespace v6 {
// by definition:
// inputs:
// 1. [number_of_ROIs, 4]
// 2..L [1, number_of_channels, layer_size[l], layer_size[l]]
// outputs:
// 1. out_shape = [number_of_ROIs, number_of_channels, output_size, output_size]
// 2. out_rois_shape = [number_of_ROIs, 4]
template <class T>
void shape_infer(const ExperimentalDetectronROIFeatureExtractor* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
NODE_VALIDATION_CHECK(op, input_shapes.size() >= 2 && output_shapes.size() == 2);
const auto& rois_shape = input_shapes[0];
auto& out_shape = output_shapes[0];
auto& out_rois_shape = output_shapes[1];
// all dimensions is initialized by-default as dynamic
out_shape.resize(4);
out_rois_shape.resize(2);
// infer static dimensions
out_shape[2] = op->get_attrs().output_size;
out_shape[3] = op->get_attrs().output_size;
out_rois_shape[1] = 4;
// infer number_of_ROIs (which may be dynamic/static)
auto rois_shape_rank = rois_shape.rank();
NODE_VALIDATION_CHECK(op, rois_shape_rank.compatible(2), "Input rois rank must be equal to 2.");
if (rois_shape_rank.is_static()) {
NODE_VALIDATION_CHECK(op,
rois_shape[1].compatible(4),
"The last dimension of the 'input_rois' input must be equal to 4. "
"Got: ",
rois_shape[1]);
out_shape[0] = rois_shape[0];
out_rois_shape[0] = rois_shape[0];
}
// infer number_of_channels;
// by definition, all shapes starting from input 2 must have same number_of_channels
DimType channels_intersection;
bool channels_intersection_initialized = false;
for (size_t i = 1; i < input_shapes.size(); i++) {
const auto& current_shape = input_shapes[i];
auto current_rank = current_shape.rank();
NODE_VALIDATION_CHECK(op,
current_rank.compatible(4),
"Rank of each element of the pyramid must be equal to 4. Got: ",
current_rank);
if (current_rank.is_static()) {
NODE_VALIDATION_CHECK(op,
current_shape[0].compatible(1),
"The first dimension of each pyramid element must be equal to 1. "
"Got: ",
current_shape[0]);
if (channels_intersection_initialized) {
NODE_VALIDATION_CHECK(op,
DimType::merge(channels_intersection, channels_intersection, current_shape[1]),
"The number of channels must be the same for all layers of the pyramid.");
} else {
channels_intersection = current_shape[1];
channels_intersection_initialized = true;
}
}
}
out_shape[1] = channels_intersection;
}
} // namespace v6
} // namespace op
} // namespace ov

View File

@ -0,0 +1,148 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/core/validation_util.hpp>
#include <openvino/op/pad.hpp>
#include "utils.hpp"
namespace ov {
namespace op {
namespace v1 {
template <class T>
void shape_infer(const Pad* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
constexpr bool is_dynamic_shape = std::is_base_of<ov::PartialShape, T>::value;
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 3 || input_shapes.size() == 4) && output_shapes.size() == 1);
auto& output_shape = output_shapes[0];
auto pad_mode = op->get_pad_mode();
// Check the shape of pad_value
if (pad_mode == PadMode::CONSTANT && input_shapes.size() == 4) {
const auto& pad_value_shape = input_shapes[3];
NODE_VALIDATION_CHECK(op,
pad_value_shape.rank().compatible(0),
"Argument for padding value is not a scalar (shape: ",
pad_value_shape,
").");
}
const auto& pads_begin_shape = input_shapes[1];
const auto& pads_begin_rank = pads_begin_shape.rank();
NODE_VALIDATION_CHECK(op,
pads_begin_rank.compatible(1),
"Argument for pads_begin is not 1D (shape: ",
pads_begin_rank,
").");
const auto& pads_end_shape = input_shapes[2];
const auto& pads_end_rank = pads_end_shape.rank();
NODE_VALIDATION_CHECK(op,
pads_end_rank.compatible(1),
"Argument for pads_end is not 1D (shape: ",
pads_end_rank,
").");
const auto& arg_shape = input_shapes[0];
const auto& arg_shape_rank = arg_shape.rank();
if (arg_shape_rank.is_static()) {
if (pads_begin_shape.is_static()) {
NODE_VALIDATION_CHECK(op,
pads_begin_shape[0].get_length() <= arg_shape_rank.get_length(),
"Number of elements of pads_begin must be >= 0 and <= arg rank "
"(pads_begin_shape[0]: ",
pads_begin_shape[0],
").");
}
if (pads_end_shape.is_static()) {
NODE_VALIDATION_CHECK(op,
pads_end_shape[0].get_length() <= arg_shape_rank.get_length(),
"Number of elements of pads_end must be >= 0 and <= arg rank (pads_end_shape[0]: ",
pads_end_shape[0],
").");
}
output_shape.resize(arg_shape_rank.get_length());
std::vector<int64_t> pads_begin_coord;
std::vector<int64_t> pads_end_coord;
get_data_as_int64<T>(1, op, pads_begin_coord, constant_data);
get_data_as_int64<T>(2, op, pads_end_coord, constant_data);
// special check for static shape inference
NODE_VALIDATION_CHECK(op,
is_dynamic_shape || (!pads_begin_coord.empty()),
"Cannot determined static output shape when pads_begin is not determined.");
NODE_VALIDATION_CHECK(op,
is_dynamic_shape || (!pads_end_coord.empty()),
"Cannot determined static output shape when pads_begin is not determined.");
if (!pads_begin_coord.empty() && !pads_end_coord.empty()) {
NODE_VALIDATION_CHECK(op,
(output_shape.size() == pads_begin_coord.size()),
"length of pads_begin mismatches with rank of input, expect ",
output_shape.size(),
", but got ",
pads_begin_coord.size());
NODE_VALIDATION_CHECK(op,
(output_shape.size() == pads_end_coord.size()),
"length of pads_end mismatches with rank of input, expect ",
output_shape.size(),
", but got ",
pads_end_coord.size());
for (size_t i = 0; i < output_shape.size(); i++) {
ptrdiff_t begin = pads_begin_coord[i];
ptrdiff_t end = pads_end_coord[i];
if (arg_shape[i].is_static()) {
const auto& dim = arg_shape[i].get_length();
output_shape[i] = static_cast<size_t>(begin + dim + end);
if (i > 1) {
NODE_VALIDATION_CHECK(op,
pad_mode != op::PadMode::EDGE || arg_shape[i].get_length() >= 1,
"EDGE padding mode requires an input of dimension of "
"at least 1 at each "
"spatial axis.");
NODE_VALIDATION_CHECK(op,
pad_mode != op::PadMode::REFLECT || arg_shape[i].get_length() >= 2,
"REFLECT padding mode requires an input of dimension "
"of at least 2 at each "
"spatial axis.");
}
NODE_VALIDATION_CHECK(op,
pad_mode != op::PadMode::REFLECT || (begin < dim && end < dim),
"REFLECT padding mode requires that 'pads_begin[D]' and 'pads_end[D]' "
"must be not greater than 'data_shape[D] - 1'.");
NODE_VALIDATION_CHECK(op,
pad_mode != op::PadMode::SYMMETRIC || (begin <= dim && end <= dim),
"SYMMETRIC padding mode requires that 'pads_begin[D]' and 'pads_end[D]' "
"must be not greater than 'data_shape[D]'.");
} else {
output_shape[i] = arg_shape[i] + (begin + end);
}
}
} else {
output_shape = ov::PartialShape::dynamic(arg_shape_rank);
}
} else {
output_shape = ov::PartialShape::dynamic(arg_shape_rank);
}
}
} // namespace v1
} // namespace op
} // namespace ov

View File

@ -0,0 +1,152 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/core/validation_util.hpp>
#include <openvino/op/range.hpp>
#include "utils.hpp"
namespace ov {
namespace op {
namespace ShapeInferRange {
template <class T>
inline bool get_data_as_double(
size_t idx,
const ov::Node* op,
std::vector<double>& axes_value,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
if (constant_data.count(idx)) {
axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector<double>();
} else {
const auto& constant = ov::as_type_ptr<ov::opset1::Constant>(op->get_input_node_shared_ptr(idx));
NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx);
axes_value = constant->cast_vector<double>();
}
return true;
}
template <>
inline bool get_data_as_double<ov::PartialShape>(
size_t idx,
const ov::Node* op,
std::vector<double>& axes_value,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data) {
if (constant_data.count(idx)) {
axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector<double>();
} else if (const auto& constant = ov::get_constant_from_source(op->input_value(idx))) {
axes_value = constant->cast_vector<double>();
} else {
return false;
}
return true;
}
template <class T>
void range_shape_infer(const Node* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
bool output_is_integral,
bool step_allows_zero,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data) {
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 3) && output_shapes.size() == 1);
NODE_VALIDATION_CHECK(op, input_shapes[0].rank().compatible(0), "'start' input is not a scalar");
NODE_VALIDATION_CHECK(op, input_shapes[1].rank().compatible(0), "'stop' input is not a scalar");
NODE_VALIDATION_CHECK(op, input_shapes[2].rank().compatible(0), "'step' input is not a scalar");
std::vector<double> start_val;
std::vector<double> stop_val;
std::vector<double> step_val;
double start = 0;
double stop = 0;
double step = 0;
if (get_data_as_double<T>(0, op, start_val, constant_data)) {
NODE_VALIDATION_CHECK(op, start_val.size() == 1);
start = start_val[0];
NODE_VALIDATION_CHECK(op, std::isfinite(start) && !std::isnan(start), "'start' cannot be nan or infinite.");
}
if (get_data_as_double<T>(1, op, stop_val, constant_data)) {
NODE_VALIDATION_CHECK(op, stop_val.size() == 1);
stop = stop_val[0];
NODE_VALIDATION_CHECK(op, std::isfinite(stop) && !std::isnan(stop), "'stop' cannot be nan or infinite.");
}
if (get_data_as_double<T>(2, op, step_val, constant_data)) {
NODE_VALIDATION_CHECK(op, step_val.size() == 1);
step = step_val[0];
if (step_allows_zero)
NODE_VALIDATION_CHECK(op, std::isfinite(step) && !std::isnan(step), "'step' cannot be nan or infinite.");
else
NODE_VALIDATION_CHECK(op,
std::isfinite(step) && !std::isnan(step) && step != 0,
"'step' cannot be zero, nan, or infinite.");
}
if (start_val.size() == 1 && stop_val.size() == 1 && step_val.size() == 1) {
// all inputs must be casted to output_type before
// the rounding for casting values are done towards zero
if (output_is_integral) {
start = std::trunc(start);
stop = std::trunc(stop);
step = std::trunc(step);
}
// the number of elements is: max(ceil((stop start) / step), 0)
double span;
if ((step > 0 && start >= stop) || (step < 0 && start <= stop)) {
span = 0;
} else {
span = stop - start;
}
double strided = ceil(fabs(span) / fabs(step));
output_shapes[0] = T{static_cast<uint32_t>(strided)};
} else {
output_shapes[0] = ov::PartialShape::dynamic(1);
}
}
} // namespace ShapeInferRange
namespace v0 {
template <class T>
void shape_infer(const Range* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
ShapeInferRange::range_shape_infer(op,
input_shapes,
output_shapes,
op->get_input_element_type(0).is_integral_number(),
false,
constant_data);
}
} // namespace v0
namespace v4 {
template <class T>
void shape_infer(const Range* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
ShapeInferRange::range_shape_infer(op,
input_shapes,
output_shapes,
op->get_output_type().is_integral_number(),
true,
constant_data);
}
} // namespace v4
} // namespace op
} // namespace ov

View File

@ -0,0 +1,59 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/core/validation_util.hpp>
#include <openvino/op/region_yolo.hpp>
#include "utils.hpp"
namespace ov {
namespace op {
namespace v0 {
template <class T>
void shape_infer(const RegionYolo* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 1) && output_shapes.size() == 1);
const auto& input_shape = input_shapes[0];
const auto& input_rank = input_shape.rank();
auto& output_shape = output_shapes[0];
NODE_VALIDATION_CHECK(op, input_rank.compatible(4), "Input must be a tensor of rank 4, but got ", input_rank);
if (input_rank.is_static()) {
int end_axis = op->m_end_axis;
if (end_axis < 0) {
end_axis += input_shape.size();
}
if (op->m_do_softmax) {
output_shape.resize(0);
auto axis = ov::normalize_axis(op, op->m_axis, input_rank);
DimType flat_dim = 1;
for (int64_t i = 0; i < axis; i++) {
output_shape.push_back(input_shape[i]);
}
for (int64_t i = axis; i < end_axis + 1; i++) {
flat_dim *= input_shape[i];
}
output_shape.push_back(flat_dim);
for (size_t i = end_axis + 1; i < input_shape.size(); i++) {
output_shape.push_back(input_shape[i]);
}
} else {
output_shape = T({input_shape[0],
static_cast<typename DimType::value_type>(
(op->get_num_classes() + op->get_num_coords() + 1) * op->get_mask().size()),
input_shape[2],
input_shape[3]});
}
} else {
output_shape = ov::PartialShape::dynamic(ov::Rank(1, 4));
}
}
} // namespace v0
} // namespace op
} // namespace ov

View File

@ -0,0 +1,58 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/core/validation_util.hpp>
#include <openvino/op/reorg_yolo.hpp>
#include "utils.hpp"
namespace ov {
namespace op {
namespace v0 {
template <class T>
void shape_infer(const ReorgYolo* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 1) && output_shapes.size() == 1);
const auto& input_shape = input_shapes[0];
auto& output_shape = output_shapes[0];
const auto & strides = op->get_strides();
if (input_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(op, input_shape.size() == 4, "[N, C, H, W] input shape is required.");
NODE_VALIDATION_CHECK(op,
input_shape[2].is_dynamic() || (input_shape[2].get_length() % strides[0]) == 0,
"For [N, C, H, W] input shape, H should be divisible by stride.");
NODE_VALIDATION_CHECK(op,
input_shape[3].is_dynamic() || (input_shape[3].get_length() % strides[0]) == 0,
"For [N, C, H, W] input shape, W should be divisible by stride.");
NODE_VALIDATION_CHECK(op,
input_shape[1].is_dynamic() || input_shape[1].get_length() >= (strides[0] * strides[0]),
"For [N, C, H, W] input shape, C >= (stride*stride) is required.");
output_shape = T({input_shape[0], input_shape[1]});
for (size_t i = 2; i < input_shape.size(); i++) {
if (input_shape[i].is_static())
output_shape.push_back(input_shape[i].get_length() / strides[0]);
else {
const auto& interval = input_shape[i].get_interval();
if (interval.has_upper_bound()) {
output_shape.push_back(
ov::Dimension(interval.get_max_val() / strides[0], interval.get_min_val() / strides[0]));
} else {
output_shape.push_back(ov::Dimension::dynamic());
}
}
output_shape[1] *= strides[0];
}
} else {
output_shape = ov::PartialShape::dynamic(input_shape.rank());
}
}
} // namespace v0
} // namespace op
} // namespace ov

View File

@ -0,0 +1,85 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/core/validation_util.hpp>
#include <openvino/op/split.hpp>
#include "utils.hpp"
namespace ov {
namespace op {
namespace v1 {
template <typename T>
void shape_infer(const Split* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 2));
output_shapes.clear();
const auto& data_ps = input_shapes[0];
const auto& axis_ps = input_shapes[1];
NODE_VALIDATION_CHECK(op, axis_ps.rank().compatible(0), "'axis' input must be a scalar. Got: ", axis_ps);
auto each_output_shape = data_ps;
const auto data_rank = data_ps.rank();
std::vector<int64_t> axes_values;
const auto & num_splits = op->get_num_splits();
if (get_data_as_int64<T>(1, op, axes_values, constant_data) && data_rank.is_static()) {
NODE_VALIDATION_CHECK(op,
axes_values.size() == 1,
"a scalar axis value is expected. Got: ",
axes_values.size(),
" axes");
auto axis = ov::normalize_axis(op, axes_values[0], data_rank);
if (data_ps[axis].is_static()) {
const auto dimension_at_axis = data_ps[axis].get_length();
NODE_VALIDATION_CHECK(op,
dimension_at_axis % num_splits == 0,
"Dimension of data input shape along 'axis': ",
dimension_at_axis,
" must be evenly divisible by 'num_splits' attribute value: ",
num_splits);
each_output_shape[axis] = dimension_at_axis / num_splits;
} else {
const auto dim_interval_at_axis = data_ps[axis].get_interval();
NODE_VALIDATION_CHECK(op,
dim_interval_at_axis.get_max_val() >= static_cast<int64_t>(num_splits),
"The interval maximum of the dimension for data "
"input shape along 'axis' must be "
"greater or equal to 'num_splits' attribute. Got: ",
dim_interval_at_axis,
" and ",
num_splits);
auto dim_interval_at_axis_min =
static_cast<int64_t>(dim_interval_at_axis.get_min_val() * (1.0f / num_splits));
auto dim_interval_at_axis_max = dim_interval_at_axis.get_max_val();
if (dim_interval_at_axis.has_upper_bound()) {
dim_interval_at_axis_max = static_cast<int64_t>(dim_interval_at_axis_max * (1.0f / num_splits));
}
each_output_shape[axis] = Dimension(dim_interval_at_axis_min, dim_interval_at_axis_max);
}
} else {
each_output_shape = ov::PartialShape::dynamic(data_ps.rank());
}
for (size_t i = 0; i < num_splits; ++i)
output_shapes.push_back(each_output_shape);
}
} // namespace v1
} // namespace op
} // namespace ov

View File

@ -0,0 +1,236 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/core/validation_util.hpp>
#include <openvino/op/strided_slice.hpp>
#include "utils.hpp"
namespace ov {
namespace op {
namespace v1 {
template <class T>
void shape_infer(const StridedSlice* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 3 || input_shapes.size() == 4) && output_shapes.size() == 1);
const auto& input_shape = input_shapes[0];
const auto& begin_shape = input_shapes[1];
NODE_VALIDATION_CHECK(op,
begin_shape.rank().compatible(1),
"Begin input must be 1D (begin rank: ",
begin_shape.rank(),
").");
const auto& end_shape = input_shapes[2];
NODE_VALIDATION_CHECK(op,
end_shape.rank().compatible(1),
"End input must be 1D (end rank: ",
end_shape.rank(),
").");
std::vector<int64_t> begin;
std::vector<int64_t> end;
std::vector<int64_t> strides;
bool got_begin = get_data_as_int64<T>(1, op, begin, constant_data);
bool got_end = get_data_as_int64<T>(2, op, end, constant_data);
bool got_strides = false;
if (input_shapes.size() > 3) {
got_strides = get_data_as_int64<T>(3, op, strides, constant_data);
} else if (got_begin) {
// generate default strides
strides.resize(begin.size(), 1);
got_strides = true;
}
if (got_begin && got_end && got_strides) {
if (begin.size() && end.size()) {
NODE_VALIDATION_CHECK(op,
begin.size() == end.size(),
"Lower bounds and Upper bounds needs to have same number of values");
}
if (begin.size() && strides.size()) {
NODE_VALIDATION_CHECK(op,
begin.size() == strides.size(),
"Lower bounds and strides needs to have same number of values");
}
if (end.size() && strides.size()) {
NODE_VALIDATION_CHECK(op,
end.size() == strides.size(),
"Upper bounds and strides needs to have same number of values");
}
auto convert_mask_to_axis_set = [](const std::vector<int64_t>& mask) {
AxisSet axis_set{};
for (size_t i = 0; i < mask.size(); ++i) {
if (mask[i] == 1) {
axis_set.emplace(i);
}
}
return axis_set;
};
AxisSet ellipsis_mask = convert_mask_to_axis_set(op->get_ellipsis_mask());
NODE_VALIDATION_CHECK(op, ellipsis_mask.size() <= 1, "At most one ellipsis is allowed.");
if (input_shape.rank().is_dynamic()) {
output_shapes[0] = ov::PartialShape::dynamic();
return;
}
auto input_rank = input_shape.size();
AxisSet new_axis_mask = convert_mask_to_axis_set(op->get_new_axis_mask());
NODE_VALIDATION_CHECK(op,
input_rank + new_axis_mask.size() >= begin.size(),
"Input rank plus number of new axis has to be at least the size of Lower "
"and Upper bounds vector.");
AxisSet begin_mask = convert_mask_to_axis_set(op->get_begin_mask());
AxisSet end_mask = convert_mask_to_axis_set(op->get_end_mask());
AxisSet shrink_axis_mask = convert_mask_to_axis_set(op->get_shrink_axis_mask());
std::vector<DimType> dim;
int64_t input_shape_idx = 0;
for (size_t axis = 0; axis < begin.size(); ++axis) {
// add all dimensions hidden under the ellipsis mask if ellipsis mask is set
if (ellipsis_mask.count(axis)) {
// only one bit in ellipsis mask is allowed
int num_new_axis_after_ellipses = 0;
int num_input_axis_before_ellipses = 0;
for (size_t i = 0; i < axis; ++i) {
if (!new_axis_mask.count(i)) {
num_input_axis_before_ellipses++;
}
}
for (size_t i = axis + 1; i < begin.size(); ++i) {
if (new_axis_mask.count(i)) {
num_new_axis_after_ellipses++;
}
}
int64_t num_input_axis_after_ellipses =
(begin.size() - axis - num_new_axis_after_ellipses - 1); // -1 because it's a position of ellipses
int64_t num_of_hidden_dims =
input_rank - num_input_axis_after_ellipses - num_input_axis_before_ellipses;
for (int64_t i = 0; i < num_of_hidden_dims; ++i) {
dim.emplace_back(input_shape[input_shape_idx]);
input_shape_idx++;
}
} else {
// add new single dimension if new_axis_mask is set
if (new_axis_mask.count(axis)) {
dim.emplace_back(1);
}
// skip this dimension if shrink_axis_mask is set
else if (shrink_axis_mask.count(axis)) {
input_shape_idx++;
}
// calculating dimension (begin, end, begin_mask, end_mask, stride)
else {
const int64_t lb0 = begin[axis];
const int64_t ub0 = end[axis];
// set default value for stride or use given value
int64_t stride = 1;
if (strides.size() > axis) {
stride = strides[axis];
}
NODE_VALIDATION_CHECK(op, stride != 0, "Stride must be non-zero");
auto get_output_dim = [&](int64_t input_dim) {
// make a mutable copy
auto lb = lb0;
auto ub = ub0;
// convert negative indexes to positive
// take max for this case: if abs(lb) > input_shape[input_shape_idx],then after
// conversion lb < 0
// so according to tensorflow and numpy we just get 0
if (lb < 0) {
lb = std::max(input_dim + lb, int64_t(0));
}
if (ub < 0) {
ub = std::max(input_dim + ub, stride > 0 ? int64_t(0) : int64_t(-1));
}
// apply restrictions when begin or end values more than max possible values.
lb = std::min(input_dim, lb);
ub = std::min(input_dim, ub);
int64_t dimension = 0;
if (stride < 0) {
// apply masks
if (begin_mask.count(axis)) {
lb = input_dim - 1;
}
if (end_mask.count(axis)) {
ub = -1;
}
lb = std::min(lb, input_dim - 1);
lb -= 1; // we always get 1st element, so we need decrease range
if (ub <= lb) {
dimension = (ub - lb) / stride + 1;
}
} else {
// apply masks
if (begin_mask.count(axis)) {
lb = 0;
}
if (end_mask.count(axis)) {
ub = input_dim;
}
lb += 1; // we always get 1st element, so we need decrease range
if (ub >= lb) {
dimension = (ub - lb) / stride + 1;
}
}
return dimension;
};
if (input_shape[input_shape_idx].is_dynamic()) {
// the relationship between input and output length is monotonically increasing
// so we repeat the dimension inference twice to infer dynamic dimension
const Interval& interval = input_shape[input_shape_idx].get_interval();
int64_t odim_min = get_output_dim(interval.get_min_val());
int64_t odim_max;
if (interval.has_upper_bound())
odim_max = get_output_dim(interval.get_max_val());
else
odim_max = -1;
dim.emplace_back(ov::Dimension(odim_min, odim_max));
} else {
int64_t dimension = get_output_dim(input_shape[input_shape_idx].get_length());
dim.emplace_back(dimension);
}
input_shape_idx++;
}
}
}
// get remaining values
for (; input_shape_idx < input_shape.rank().get_length(); ++input_shape_idx) {
dim.emplace_back(input_shape[input_shape_idx]);
}
output_shapes[0] = T(dim);
} else {
output_shapes[0] = ov::PartialShape::dynamic(input_shape.rank());
}
}
} // namespace v1
} // namespace op
} // namespace ov

View File

@ -0,0 +1,93 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/core/validation_util.hpp>
#include <openvino/op/topk.hpp>
#include "utils.hpp"
namespace ov {
namespace op {
namespace v1 {
template <typename T>
void shape_infer(const TopK* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
constexpr bool is_dynamic_shape = std::is_base_of<ov::PartialShape, T>::value;
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 2 && output_shapes.size() == 2));
const auto& input_shape = input_shapes[0];
const auto input_rank = input_shape.rank();
NODE_VALIDATION_CHECK(op,
input_rank.is_dynamic() || input_rank.get_length() > 0,
"Input rank must be greater than 0.");
const auto& k_shape = input_shapes[1];
NODE_VALIDATION_CHECK(op, k_shape.rank().compatible(0), "The 'K' input must be a scalar.");
auto output_shape = input_shape;
if (input_shape.rank().is_static()) {
ov::PartialShape k_as_shape;
auto input_rank = static_cast<int64_t>(input_shape.size());
auto normalized_axis = ov::normalize_axis(op, op->get_provided_axis(), input_rank, -input_rank, input_rank - 1);
auto& dim_axis = output_shape[normalized_axis];
if (!is_dynamic_shape) {
std::vector<int64_t> k_val;
NODE_VALIDATION_CHECK(op,
get_data_as_int64<T>(1, op, k_val, constant_data),
"determined k is required to infer static shape");
NODE_VALIDATION_CHECK(op,
k_val.size() == 1,
"Only one value (scalar) should be provided as the 'K' input to TopK",
" (got ",
k_val.size(),
" elements).");
dim_axis = k_val[0];
} else if (ov::evaluate_as_partial_shape(op->input_value(1), k_as_shape)) {
NODE_VALIDATION_CHECK(op,
k_as_shape.size() == 1,
"Only one value (scalar) should be provided as the 'K' input to TopK",
" (got ",
k_as_shape.size(),
" elements).");
if (k_as_shape[0].is_static()) {
NODE_VALIDATION_CHECK(op,
k_as_shape[0].get_max_length() >= 0,
"The value of 'K' must not be a negative number.",
" (got ",
k_as_shape[0].get_max_length(),
").");
dim_axis = k_as_shape[0].get_length();
} else {
// in this dynamic branch we are sure of dim_axis's type
const auto in_min = dim_axis.get_min_length();
const auto in_max = dim_axis.get_max_length();
const auto k_min = k_as_shape[0].get_min_length();
const auto k_max = k_as_shape[0].get_max_length();
const auto lower = std::min<Dimension::value_type>(in_min, k_min);
const auto upper =
in_max < 0 ? Dimension::dynamic().get_max_length() : std::max<Dimension::value_type>(in_max, k_max);
dim_axis = Dimension(lower, upper);
}
} else {
dim_axis = Dimension(0, dim_axis.get_max_length());
}
}
output_shapes[0] = output_shape;
output_shapes[1] = output_shape;
} // namespace
} // namespace v1
} // namespace op
} // namespace ov

View File

@ -0,0 +1,131 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/core/validation_util.hpp>
#include <openvino/op/variadic_split.hpp>
#include "utils.hpp"
namespace ov {
namespace op {
namespace v1 {
template <typename T>
void shape_infer(const VariadicSplit* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
constexpr bool is_dynamic_shape = std::is_base_of<ov::PartialShape, T>::value;
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 3));
output_shapes.clear();
auto axis_pshape = input_shapes[1];
auto split_lengths_pshape = input_shapes[2];
NODE_VALIDATION_CHECK(op,
axis_pshape.rank().compatible(0) || axis_pshape.compatible({1}),
"Axis should be a scalar or of shape [1]. Got ",
axis_pshape,
" instead.");
if (split_lengths_pshape.is_static()) {
NODE_VALIDATION_CHECK(op,
split_lengths_pshape.size() == 1,
"Split lengths should be a 1-D tensor. Got ",
split_lengths_pshape.size(),
" instead.");
const auto num_outputs = split_lengths_pshape[0].get_length();
const auto& data_shape = input_shapes[0];
std::vector<int64_t> axis_values;
std::vector<int64_t> split_lengths;
if (data_shape.rank().is_static() && get_data_as_int64<T>(1, op, axis_values, constant_data)) {
NODE_VALIDATION_CHECK(op,
axis_values.size() == 1,
"a scalar axis value is expected. Got: ",
axis_values.size(),
" axes");
const auto axis_val = axis_values[0];
// Adjust split axis in case of negatives
const int64_t axis = ov::normalize_axis(op, axis_val, data_shape.rank());
if (get_data_as_int64<T>(2, op, split_lengths, constant_data)) {
// Adjust split lengths in case of negatives
int64_t sum_of_splits = 0;
int64_t negative_one_idx = -1;
for (size_t i = 0; i < split_lengths.size(); i++) {
NODE_VALIDATION_CHECK(op,
split_lengths[i] >= -1,
"Invalid value ",
split_lengths[i],
" in split lengths input. Should be >= -1.");
if (split_lengths[i] == -1) {
NODE_VALIDATION_CHECK(op,
negative_one_idx == -1,
"Cannot infer split with multiple -1 values at ",
negative_one_idx,
" and ",
i);
negative_one_idx = i;
} else {
sum_of_splits += split_lengths[i];
}
}
const auto dimension_at_axis = data_shape[axis];
if (negative_one_idx >= 0 && dimension_at_axis.is_static()) {
split_lengths[negative_one_idx] = dimension_at_axis.get_length() - sum_of_splits;
sum_of_splits += split_lengths[negative_one_idx];
}
if (data_shape[axis].is_static()) {
NODE_VALIDATION_CHECK(op,
sum_of_splits == data_shape[axis].get_length(),
"Total length of splits: ",
sum_of_splits,
" must match the length of the chosen axis: ",
data_shape[axis]);
}
for (int64_t output = 0; output < num_outputs; ++output) {
if (split_lengths.at(output) == -1) {
auto out_shape = data_shape;
out_shape[axis] = Dimension::dynamic();
output_shapes.push_back(out_shape);
} else {
auto out_shape = data_shape;
out_shape[axis] = split_lengths.at(output);
output_shapes.push_back(out_shape);
}
}
} else {
// we know num_outputs & axis but split_lengths, pass other dimensions besides axis in dynamic shape
// case
NODE_VALIDATION_CHECK(op, is_dynamic_shape, "Cannot infer static shape due to lack of split_lengths.");
auto out_shape = data_shape;
out_shape[axis] = Dimension::dynamic();
for (int64_t output = 0; output < num_outputs; ++output)
output_shapes.push_back(out_shape);
}
} else {
// we only know num_outputs, only predict the rank
auto out_shape = ov::PartialShape::dynamic(data_shape.rank());
for (int64_t output = 0; output < num_outputs; ++output)
output_shapes.push_back(out_shape);
}
} else {
// we don't even known the number of outputs in this case.
// just leave output_shapes as empty.
}
}
} // namespace v1
} // namespace op
} // namespace ov

View File

@ -4,6 +4,7 @@
#include "ngraph/op/bucketize.hpp"
#include "bucketize_shape_inference.hpp"
#include "itt.hpp"
using namespace ngraph;
@ -51,17 +52,16 @@ void op::v3::Bucketize::validate_and_infer_types() {
"Output type must be i32 or i64. Got: ",
m_output_type);
NODE_VALIDATION_CHECK(this,
buckets_pshape.rank().compatible(1),
"Buckets input must be a 1D tensor. Got: ",
buckets_pshape);
std::vector<ov::PartialShape> input_shapes = {data_pshape, buckets_pshape};
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic()};
shape_infer(this, input_shapes, output_shapes);
if (data_pshape.is_dynamic()) {
set_input_is_relevant_to_shape(0);
}
set_output_size(1);
set_output_type(0, m_output_type, data_pshape);
set_output_type(0, m_output_type, output_shapes[0]);
}
shared_ptr<Node> op::v3::Bucketize::clone_with_new_inputs(const OutputVector& inputs) const {

View File

@ -10,6 +10,7 @@
#include <string>
#include <unordered_map>
#include "einsum_shape_inference.hpp"
#include "itt.hpp"
using namespace std;
@ -180,97 +181,15 @@ void op::v7::Einsum::validate_and_infer_types() {
"Inputs to Einsum operation must have the same type.");
}
// check that equation has correct format and extract input and output subscripts
std::vector<std::string> input_subscripts;
std::string output_subscript;
parse_equation(m_equation, input_subscripts, output_subscript);
// a number of input subscripts must match with a number of input tensors
NODE_VALIDATION_CHECK(this,
input_subscripts.size() == num_inputs,
"Equation must contain a number of subscripts equal to a number of Einsum inputs.");
// create a dictionary with dimension sizes (or ranges in case dynamic shapes) for each label
// and check their compatibility in case repeating labels
unordered_map<string, ov::PartialShape> label_to_shape;
label_to_shape.clear();
std::vector<ov::PartialShape> input_shapes;
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic()};
for (size_t input_idx = 0; input_idx < num_inputs; ++input_idx) {
const auto& pshape = get_input_partial_shape(input_idx);
std::vector<std::string> labels;
labels = extract_labels(input_subscripts[input_idx]);
if (pshape.rank().is_static()) {
size_t input_rank = pshape.rank().get_length();
// check that a rank is greater or equal to a number of labels
// these numbers are always equal if there is no ellipsis in the subscript
NODE_VALIDATION_CHECK(this,
input_rank >= labels.size(),
"Input rank must be greater or equal to a number of labels in the "
"corresponding input subscript.");
for (size_t label_ind = 0, dim_ind = 0; label_ind < labels.size() && dim_ind < input_rank; ++label_ind) {
auto const& label = labels[label_ind];
if (label.compare("...") == 0) {
size_t num_broadcasted_dims = input_rank - labels.size() + 1;
auto current_sub_pshape =
ov::PartialShape(std::vector<Dimension>(pshape.begin() + dim_ind,
pshape.begin() + dim_ind + num_broadcasted_dims));
if (label_to_shape.find(label) == label_to_shape.end()) {
label_to_shape[label] = current_sub_pshape;
} else {
bool is_broadcast_success =
ov::PartialShape::broadcast_merge_into(label_to_shape[label],
current_sub_pshape,
op::AutoBroadcastType::NUMPY);
NODE_VALIDATION_CHECK(this,
is_broadcast_success,
"Input dimensions labeled with ellipsis for Einsum "
"must be broadcastable.");
}
dim_ind += num_broadcasted_dims;
} else {
if (label_to_shape.find(label) == label_to_shape.end()) {
label_to_shape[label] = ov::PartialShape{pshape[dim_ind]};
} else {
NODE_VALIDATION_CHECK(this,
label_to_shape[label].compatible(ov::PartialShape{pshape[label_ind]}),
"Different input dimensions indicated by the same labels for Einsum "
"must be compatible.");
ov::PartialShape::merge_into(label_to_shape[label], ov::PartialShape{pshape[dim_ind]});
}
++dim_ind;
}
}
} else {
for (auto const& label : labels) {
NODE_VALIDATION_CHECK(this,
label != "...",
"The subscript corresponding to a dynamic rank input must "
"not contain ellipsis.");
if (label_to_shape.find(label) == label_to_shape.end()) {
label_to_shape[label] = ov::PartialShape{Dimension::dynamic()};
}
}
}
input_shapes.push_back(get_input_partial_shape(input_idx));
}
// compute the output shape
std::vector<std::string> output_labels;
output_labels = extract_labels(output_subscript);
std::vector<Dimension> output_pshape_vector;
shape_infer(this, input_shapes, output_shapes);
for (auto const& output_label : output_labels) {
NODE_VALIDATION_CHECK(this,
label_to_shape.find(output_label) != label_to_shape.end(),
"Label in output subscript of Einsum equation must enter at least "
"one input subscript.");
output_pshape_vector.insert(output_pshape_vector.end(),
label_to_shape[output_label].begin(),
label_to_shape[output_label].end());
}
set_output_type(0, input_type_0, ov::PartialShape(output_pshape_vector));
set_output_type(0, input_type_0, output_shapes[0]);
}
bool op::v7::Einsum::visit_attributes(AttributeVisitor& visitor) {

View File

@ -6,6 +6,7 @@
#include <ngraph/validation_util.hpp>
#include "embedding_segments_sum_shape_inference.hpp"
#include "itt.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/opsets/opset3.hpp"
@ -75,24 +76,6 @@ void op::v3::EmbeddingSegmentsSum::validate_and_infer_types() {
get_input_element_type(SEGMENT_IDS),
")");
NODE_VALIDATION_CHECK(
this,
get_input_partial_shape(INDICES).is_dynamic() || get_input_partial_shape(INDICES).to_shape().size() == 1,
"INDICES must be 1D");
NODE_VALIDATION_CHECK(this,
get_input_partial_shape(SEGMENT_IDS).is_dynamic() ||
get_input_partial_shape(SEGMENT_IDS).to_shape().size() == 1,
"SEGMENT_IDS must be 1D");
NODE_VALIDATION_CHECK(this,
get_input_partial_shape(INDICES).compatible(get_input_partial_shape(SEGMENT_IDS)),
"INDICES and SEGMENT_IDS shape must be same");
NODE_VALIDATION_CHECK(this,
get_input_partial_shape(NUM_SEGMENTS).compatible(ov::PartialShape{}),
"NUM_SEGMENTS must be a scalar");
if (get_input_size() >= 5) {
NODE_VALIDATION_CHECK(this,
get_input_element_type(DEFAULT_INDEX) == element::i64 ||
@ -106,10 +89,6 @@ void op::v3::EmbeddingSegmentsSum::validate_and_infer_types() {
") must match indices element type (",
get_input_element_type(INDICES),
")");
NODE_VALIDATION_CHECK(this,
get_input_partial_shape(DEFAULT_INDEX).compatible(ov::PartialShape{}),
"DEFAULT_INDEX must be a scalar");
}
if (get_input_size() == 6) {
@ -120,36 +99,21 @@ void op::v3::EmbeddingSegmentsSum::validate_and_infer_types() {
") must match embedding table element type (",
get_input_element_type(EMB_TABLE),
")");
NODE_VALIDATION_CHECK(this,
get_input_partial_shape(PER_SAMPLE_WEIGHTS).is_dynamic() ||
get_input_partial_shape(PER_SAMPLE_WEIGHTS).to_shape().size() == 1,
"PER_SAMPLE_WEIGHTS must be 1D");
NODE_VALIDATION_CHECK(this,
get_input_partial_shape(INDICES).compatible(get_input_partial_shape(PER_SAMPLE_WEIGHTS)),
"INDICES and PER_SAMPLE_WEIGHTS shape must be same");
}
element::Type result_et = get_input_element_type(EMB_TABLE);
const ov::PartialShape& emb_table_shape = get_input_partial_shape(EMB_TABLE);
std::vector<PartialShape> result_shapes = {PartialShape::dynamic()};
std::vector<PartialShape> input_shapes;
for (int i = 0; i < get_input_size(); i++)
input_shapes.push_back(get_input_partial_shape(i));
ov::PartialShape result_shape;
if (emb_table_shape.rank().is_static()) {
result_shape = emb_table_shape;
if (const auto& num_segments_const = get_constant_from_source(input_value(NUM_SEGMENTS))) {
result_shape[0] = num_segments_const->cast_vector<int64_t>()[0];
} else {
result_shape[0] = Dimension::dynamic();
set_input_is_relevant_to_shape(NUM_SEGMENTS);
}
} else {
result_shape = ov::PartialShape::dynamic();
set_input_is_relevant_to_shape(NUM_SEGMENTS);
shape_infer(this, input_shapes, result_shapes);
if (result_shapes[EMB_TABLE].rank().is_dynamic() || result_shapes[EMB_TABLE][0].is_dynamic()) {
set_input_is_relevant_to_shape(NUM_SEGMENTS, true);
}
set_output_type(0, result_et, result_shape);
set_output_type(0, result_et, result_shapes[0]);
}
shared_ptr<Node> op::v3::EmbeddingSegmentsSum::clone_with_new_inputs(const OutputVector& new_args) const {

View File

@ -5,6 +5,7 @@
#include "ngraph/op/experimental_detectron_roi_feature.hpp"
#include <algorithm>
#include <experimental_detectron_roi_feature_shape_inference.hpp>
#include <memory>
#include <utility>
@ -40,59 +41,18 @@ void op::v6::ExperimentalDetectronROIFeatureExtractor::validate_and_infer_types(
NGRAPH_OP_SCOPE(v6_ExperimentalDetectronROIFeatureExtractor_validate_and_infer_types);
NODE_VALIDATION_CHECK(this, get_input_size() >= 2, "At least two argument required.");
auto rois_shape = get_input_partial_shape(0);
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}, ov::PartialShape{}};
std::vector<ov::PartialShape> input_shapes;
for (size_t i = 0; i < get_input_size(); i++)
input_shapes.push_back(get_input_partial_shape(i));
shape_infer(this, input_shapes, output_shapes);
auto input_et = get_input_element_type(0);
ov::PartialShape out_shape = {Dimension::dynamic(), Dimension::dynamic(), m_attrs.output_size, m_attrs.output_size};
ov::PartialShape out_rois_shape = {Dimension::dynamic(), 4};
if (rois_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(this, rois_shape.rank().get_length() == 2, "Input rois rank must be equal to 2.");
auto input_rois_last_dim_intersection_with_4 = rois_shape[1] & Dimension(4);
NODE_VALIDATION_CHECK(this,
!input_rois_last_dim_intersection_with_4.get_interval().empty(),
"The last dimension of the 'input_rois' input must be equal to 4. "
"Got: ",
rois_shape[1]);
out_shape[0] = rois_shape[0];
out_rois_shape[0] = rois_shape[0];
}
size_t num_of_inputs = get_input_size();
Dimension channels_intersection;
for (size_t i = 1; i < num_of_inputs; i++) {
auto current_shape = get_input_partial_shape(i);
auto current_rank = current_shape.rank();
if (current_rank.is_static()) {
NODE_VALIDATION_CHECK(this,
current_rank.get_length() == 4,
"Rank of each element of the pyramid must be equal to 4. Got: ",
current_rank);
auto first_dim_intersection_with_1 = current_shape[0] & Dimension(1);
NODE_VALIDATION_CHECK(this,
!first_dim_intersection_with_1.get_interval().empty(),
"The first dimension of each pyramid element must be equal to 1. "
"Got: ",
current_shape[0]);
channels_intersection &= current_shape[1];
}
}
NODE_VALIDATION_CHECK(this,
!channels_intersection.get_interval().empty(),
"The number of channels must be the same for all layers of the pyramid.");
out_shape[1] = channels_intersection;
set_output_size(2);
set_output_type(0, input_et, out_shape);
set_output_type(1, input_et, out_rois_shape);
set_output_size(output_shapes.size());
for (size_t i = 0; i < output_shapes.size(); i++)
set_output_type(i, input_et, output_shapes[i]);
}
shared_ptr<Node> op::v6::ExperimentalDetectronROIFeatureExtractor::clone_with_new_inputs(

View File

@ -14,6 +14,7 @@
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/runtime/reference/pad.hpp"
#include "openvino/op/util/precision_sensitive_attribute.hpp"
#include "pad_shape_inference.hpp"
using namespace std;
using namespace ngraph;
@ -75,7 +76,6 @@ void op::v1::Pad::validate_and_infer_types() {
if (m_pad_mode == PadMode::CONSTANT && get_input_size() == 4) {
const auto& arg_pad_element_type = get_input_element_type(3);
const auto& arg_pad_shape = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
element::Type::merge(result_et, arg_element_type, arg_pad_element_type),
"Argument element types do not match (input arg element type: ",
@ -83,12 +83,6 @@ void op::v1::Pad::validate_and_infer_types() {
", arg_pad element type: ",
arg_pad_element_type,
").");
NODE_VALIDATION_CHECK(this,
arg_pad_shape.compatible(ov::PartialShape{}),
"Argument for padding value is not a scalar (shape: ",
arg_pad_shape,
").");
}
NODE_VALIDATION_CHECK(this,
@ -103,80 +97,12 @@ void op::v1::Pad::validate_and_infer_types() {
pads_end_element_type,
").");
const auto& pads_begin_shape = get_input_partial_shape(1);
const auto& pads_begin_rank = pads_begin_shape.rank();
NODE_VALIDATION_CHECK(this,
pads_begin_rank.compatible(1),
"Argument for pads_begin is not 1D (shape: ",
pads_begin_rank,
").");
const auto& pads_end_shape = get_input_partial_shape(2);
const auto& pads_end_rank = pads_end_shape.rank();
NODE_VALIDATION_CHECK(this,
pads_end_rank.compatible(1),
"Argument for pads_end is not 1D (shape: ",
pads_end_rank,
").");
const auto& arg_shape = get_input_partial_shape(0);
const auto& arg_shape_rank = arg_shape.rank();
if (arg_shape_rank.is_static() && pads_begin_shape.is_static()) {
NODE_VALIDATION_CHECK(this,
pads_begin_shape[0].get_length() <= arg_shape_rank.get_length(),
"Number of elements of pads_begin must be >= 0 and <= arg rank "
"(pads_begin_shape[0]: ",
pads_begin_shape[0],
").");
}
if (arg_shape_rank.is_static() && pads_end_shape.is_static()) {
NODE_VALIDATION_CHECK(this,
pads_end_shape[0].get_length() <= arg_shape_rank.get_length(),
"Number of elements of pads_end must be >= 0 and <= arg rank (pads_end_shape[0]: ",
pads_end_shape[0],
").");
}
const auto& pads_begin_coord = get_pads_begin();
const auto& pads_end_coord = get_pads_end();
if (arg_shape_rank.is_static() && !pads_begin_coord.empty() && !pads_end_coord.empty()) {
const auto implied_rank = pads_begin_coord.size();
std::vector<Dimension> result_dims(implied_rank, Dimension::dynamic());
for (size_t i = 0; i < implied_rank; i++) {
if (arg_shape[i].is_static()) {
ptrdiff_t result_dim = pads_begin_coord[i] + arg_shape[i].get_length() + pads_end_coord[i];
result_dims[i] = static_cast<size_t>(result_dim);
if (i > 1) {
NODE_VALIDATION_CHECK(this,
m_pad_mode != op::PadMode::EDGE || arg_shape[i].get_length() >= 1,
"EDGE padding mode requires an input of dimension of "
"at least 1 at each "
"spatial axis.");
NODE_VALIDATION_CHECK(this,
m_pad_mode != op::PadMode::REFLECT || arg_shape[i].get_length() >= 2,
"REFLECT padding mode requires an input of dimension "
"of at least 2 at each "
"spatial axis.");
}
NODE_VALIDATION_CHECK(
this,
m_pad_mode != op::PadMode::REFLECT || (pads_begin_coord[i] < arg_shape[i].get_length() &&
pads_end_coord[i] < arg_shape[i].get_length()),
"REFLECT padding mode requires that 'pads_begin[D]' and 'pads_end[D]' "
"must be not greater than 'data_shape[D] - 1'.");
NODE_VALIDATION_CHECK(
this,
m_pad_mode != op::PadMode::SYMMETRIC || (pads_begin_coord[i] <= arg_shape[i].get_length() &&
pads_end_coord[i] <= arg_shape[i].get_length()),
"SYMMETRIC padding mode requires that 'pads_begin[D]' and 'pads_end[D]' "
"must be not greater than 'data_shape[D]'.");
}
}
set_output_type(0, get_input_element_type(0), result_dims);
} else {
set_output_type(0, get_input_element_type(0), ov::PartialShape::dynamic(arg_shape_rank));
}
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic()};
std::vector<ov::PartialShape> input_shapes;
for (size_t i = 0; i < get_input_size(); i++)
input_shapes.push_back(get_input_partial_shape(i));
shape_infer(this, input_shapes, output_shapes);
set_output_type(0, get_input_element_type(0), output_shapes[0]);
}
shared_ptr<Node> op::v1::Pad::clone_with_new_inputs(const OutputVector& new_args) const {

View File

@ -12,6 +12,7 @@
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/range.hpp"
#include "ngraph/type/element_type_traits.hpp"
#include "range_shape_inference.hpp"
using namespace std;
using namespace ngraph;
@ -71,10 +72,6 @@ void op::v4::Range::validate_and_infer_types() {
set_input_is_relevant_to_shape(1);
set_input_is_relevant_to_shape(2);
NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).compatible(ov::Shape{}), "'start' input is not a scalar");
NODE_VALIDATION_CHECK(this, get_input_partial_shape(1).compatible(ov::Shape{}), "'stop' input is not a scalar");
NODE_VALIDATION_CHECK(this, get_input_partial_shape(2).compatible(ov::Shape{}), "'step' input is not a scalar");
NODE_VALIDATION_CHECK(this,
get_input_element_type(0).is_integral_number() || get_input_element_type(0).is_real(),
"'start' input scalar should be a numeric type. Got: ",
@ -88,63 +85,14 @@ void op::v4::Range::validate_and_infer_types() {
"'step' input scalar should be a numeric type. Got: ",
get_input_element_type(2));
auto const_start = get_constant_from_source(input_value(0));
auto const_stop = get_constant_from_source(input_value(1));
auto const_step = get_constant_from_source(input_value(2));
std::vector<PartialShape> result_shapes = {PartialShape::dynamic()};
std::vector<PartialShape> input_shapes;
for (int i = 0; i < get_input_size(); i++)
input_shapes.push_back(get_input_partial_shape(i));
double start = 0;
double stop = 0;
double step = 0;
op::v4::shape_infer(this, input_shapes, result_shapes);
if (const_start != nullptr) {
std::vector<double> start_val = const_start->cast_vector<double>();
NODE_VALIDATION_CHECK(this, start_val.size() == 1);
start = start_val[0];
NODE_VALIDATION_CHECK(this, std::isfinite(start) && !std::isnan(start), "'start' cannot be nan or infinite.");
}
if (const_stop != nullptr) {
std::vector<double> stop_val = const_stop->cast_vector<double>();
NODE_VALIDATION_CHECK(this, stop_val.size() == 1);
stop = stop_val[0];
NODE_VALIDATION_CHECK(this, std::isfinite(stop) && !std::isnan(stop), "'stop' cannot be nan or infinite.");
}
if (const_step != nullptr) {
std::vector<double> step_val = const_step->cast_vector<double>();
NODE_VALIDATION_CHECK(this, step_val.size() == 1);
step = step_val[0];
NODE_VALIDATION_CHECK(this, std::isfinite(step) && !std::isnan(step), "'step' cannot be nan or infinite.");
}
ov::PartialShape result{ov::PartialShape::dynamic(1)};
if (const_start != nullptr && const_stop != nullptr && const_step != nullptr) {
// all inputs must be casted to output_type before
// the rounding for casting values are done towards zero
if (m_output_type.is_integral_number() && get_input_element_type(0).is_real()) {
start = std::trunc(start);
}
if (m_output_type.is_integral_number() && get_input_element_type(1).is_real()) {
stop = std::trunc(stop);
}
if (m_output_type.is_integral_number() && get_input_element_type(2).is_real()) {
step = std::trunc(step);
}
// the number of elements is: max(ceil((stop start) / step), 0)
double span;
if ((step > 0 && start >= stop) || (step < 0 && start <= stop)) {
span = 0;
} else {
span = stop - start;
}
double strided = ceil(fabs(span) / fabs(step));
result = ov::PartialShape{Dimension(static_cast<int64_t>(strided))};
}
set_output_type(0, m_output_type, result);
set_output_type(0, m_output_type, result_shapes[0]);
}
shared_ptr<Node> op::v4::Range::clone_with_new_inputs(const OutputVector& new_args) const {
@ -401,70 +349,24 @@ void op::v0::Range::validate_and_infer_types() {
result_et != element::boolean,
"Element type for start, stop, and step, must not be boolean.");
NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).compatible(ov::Shape{}), "'start' input is not a scalar");
NODE_VALIDATION_CHECK(this, get_input_partial_shape(1).compatible(ov::Shape{}), "'stop' input is not a scalar");
NODE_VALIDATION_CHECK(this, get_input_partial_shape(2).compatible(ov::Shape{}), "'step' input is not a scalar");
NODE_VALIDATION_CHECK(this,
result_et != element::Type_t::u1 && result_et != element::Type_t::i4 &&
result_et != element::Type_t::u4 && result_et != element::Type_t::undefined,
"Internal nGraph error: unsupported element type: ",
result_et);
ov::PartialShape result_shape;
if (result_et == element::Type_t::dynamic) {
set_output_type(0, result_et, ov::PartialShape::dynamic(1));
} else {
std::vector<PartialShape> result_shapes = {PartialShape::dynamic()};
std::vector<PartialShape> input_shapes;
for (int i = 0; i < get_input_size(); i++)
input_shapes.push_back(get_input_partial_shape(i));
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
# pragma GCC diagnostic push
# pragma GCC diagnostic error "-Wswitch"
# pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (result_et) {
case element::Type_t::bf16:
result_shape = infer_output_shape<bfloat16>(this, result_et);
break;
case element::Type_t::f16:
result_shape = infer_output_shape<float16>(this, result_et);
break;
case element::Type_t::f32:
result_shape = infer_output_shape<float>(this, result_et);
break;
case element::Type_t::f64:
result_shape = infer_output_shape<double>(this, result_et);
break;
case element::Type_t::i8:
result_shape = infer_output_shape<int8_t>(this, result_et);
break;
case element::Type_t::i16:
result_shape = infer_output_shape<int16_t>(this, result_et);
break;
case element::Type_t::i32:
result_shape = infer_output_shape<int32_t>(this, result_et);
break;
case element::Type_t::i64:
result_shape = infer_output_shape<int64_t>(this, result_et);
break;
case element::Type_t::u8:
result_shape = infer_output_shape<uint8_t>(this, result_et);
break;
case element::Type_t::u16:
result_shape = infer_output_shape<uint16_t>(this, result_et);
break;
case element::Type_t::u32:
result_shape = infer_output_shape<uint32_t>(this, result_et);
break;
case element::Type_t::u64:
result_shape = infer_output_shape<uint64_t>(this, result_et);
break;
case element::Type_t::dynamic:
result_shape = ov::PartialShape::dynamic(1);
break;
case element::Type_t::u1:
case element::Type_t::i4:
case element::Type_t::u4:
case element::Type_t::undefined:
case element::Type_t::boolean:
NODE_VALIDATION_CHECK(this, false, "Internal nGraph error: unsupported element type: ", result_et);
break;
op::v0::shape_infer(this, input_shapes, result_shapes);
set_output_type(0, result_et, result_shapes[0]);
}
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
# pragma GCC diagnostic pop
#endif
set_output_type(0, result_et, result_shape);
}
shared_ptr<Node> op::v0::Range::clone_with_new_inputs(const OutputVector& new_args) const {

View File

@ -6,6 +6,7 @@
#include "itt.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "region_yolo_shape_inference.hpp"
using namespace std;
using namespace ngraph;
@ -54,42 +55,11 @@ void op::RegionYolo::validate_and_infer_types() {
input_et.is_real(),
"Type of input is expected to be a floating point type. Got: ",
input_et);
std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0)};
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
shape_infer(this, input_shapes, output_shapes);
const auto& input_partial_shape = get_input_partial_shape(0);
if (input_partial_shape.rank().is_static()) {
ov::PartialShape input_shape = get_input_partial_shape(0);
ov::PartialShape output_shape;
int end_axis = m_end_axis;
if (m_end_axis < 0) {
m_end_axis += input_shape.size();
}
if (m_do_softmax) {
size_t flat_dim = 1;
for (int64_t i = 0; i < m_axis; i++) {
output_shape.push_back(input_shape[i]);
}
for (int64_t i = m_axis; i < end_axis + 1; i++) {
if (input_shape[i].is_dynamic()) {
flat_dim = -1;
break;
}
flat_dim *= input_shape[i].get_length();
}
output_shape.push_back(flat_dim);
for (size_t i = end_axis + 1; i < input_shape.size(); i++) {
output_shape.push_back(input_shape[i]);
}
} else {
output_shape = {input_shape[0],
ov::Dimension((m_num_classes + m_num_coords + 1) * m_mask.size()),
input_shape[2],
input_shape[3]};
}
set_output_type(0, input_et, output_shape);
} else {
set_output_type(0, input_et, ov::PartialShape::dynamic());
}
set_output_type(0, input_et, output_shapes[0]);
}
shared_ptr<Node> op::RegionYolo::clone_with_new_inputs(const OutputVector& new_args) const {

View File

@ -6,6 +6,7 @@
#include "itt.hpp"
#include "ngraph/runtime/reference/reorg_yolo.hpp"
#include "reorg_yolo_shape_inference.hpp"
using namespace std;
using namespace ngraph;
@ -27,32 +28,12 @@ void op::ReorgYolo::validate_and_infer_types() {
NODE_VALIDATION_CHECK(this, !m_strides.empty(), "Stride attribute is required.");
auto input_et = get_input_element_type(0);
if (get_input_partial_shape(0).is_static()) {
auto input_shape = get_input_partial_shape(0).to_shape();
NODE_VALIDATION_CHECK(this, input_shape.size() == 4, "[N, C, H, W] input shape is required.");
NODE_VALIDATION_CHECK(this,
(input_shape[2] % m_strides[0]) == 0,
"For [N, C, H, W] input shape, H should be divisible by stride.");
std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0)};
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
shape_infer(this, input_shapes, output_shapes);
NODE_VALIDATION_CHECK(this,
(input_shape[3] % m_strides[0]) == 0,
"For [N, C, H, W] input shape, W should be divisible by stride.");
NODE_VALIDATION_CHECK(this,
input_shape[1] >= (m_strides[0] * m_strides[0]),
"For [N, C, H, W] input shape, C >= (stride*stride) is required.");
ov::Shape output_shape{input_shape[0], input_shape[1]};
for (size_t i = 2; i < input_shape.size(); i++) {
output_shape.push_back(input_shape[i] / m_strides[0]);
output_shape[1] *= m_strides[0];
}
set_output_type(0, input_et, output_shape);
} else {
auto input_shape = get_input_partial_shape(0);
set_output_type(0, input_et, ov::PartialShape::dynamic(input_shape.rank()));
}
set_output_type(0, input_et, output_shapes[0]);
}
shared_ptr<Node> op::ReorgYolo::clone_with_new_inputs(const OutputVector& new_args) const {

View File

@ -5,6 +5,7 @@
#include "ngraph/runtime/reference/split.hpp"
#include <numeric>
#include <split_shape_inference.hpp>
#include "itt.hpp"
#include "ngraph/attribute_visitor.hpp"
@ -34,12 +35,8 @@ bool ngraph::op::v1::Split::visit_attributes(AttributeVisitor& visitor) {
void op::v1::Split::validate_and_infer_types() {
NGRAPH_OP_SCOPE(v1_Split_validate_and_infer_types);
const ov::PartialShape& data_ps = get_input_partial_shape(0);
const ov::PartialShape& axis_ps = get_input_partial_shape(1);
const element::Type& axis_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this, axis_ps.rank().compatible(0), "'axis' input must be a scalar. Got: ", axis_ps);
NODE_VALIDATION_CHECK(this,
axis_et.is_integral_number(),
"Element type of 'axis' input must be integer. Got: ",
@ -50,48 +47,13 @@ void op::v1::Split::validate_and_infer_types() {
"Attribute 'num_splits' must be greater than zero. Got: ",
m_num_splits);
ov::PartialShape each_output_shape{data_ps};
const Rank data_rank = data_ps.rank();
const auto axis_input = get_constant_from_source(input_value(1));
if (axis_input && data_rank.is_static()) {
auto axis = axis_input->cast_vector<int64_t>()[0];
axis = ngraph::normalize_axis(this, axis, data_rank);
if (data_ps[axis].is_static()) {
const auto dimension_at_axis = data_ps[axis].get_length();
NODE_VALIDATION_CHECK(this,
dimension_at_axis % m_num_splits == 0,
"Dimension of data input shape along 'axis': ",
dimension_at_axis,
" must be evenly divisible by 'num_splits' attribute value: ",
m_num_splits);
each_output_shape[axis] = dimension_at_axis / m_num_splits;
} else {
const auto dim_interval_at_axis = data_ps[axis].get_interval();
NODE_VALIDATION_CHECK(this,
dim_interval_at_axis.get_max_val() >= static_cast<int64_t>(m_num_splits),
"The interval maximum of the dimension for data input shape along 'axis' must be "
"greater or equal to 'num_splits' attribute. Got: ",
dim_interval_at_axis,
" and ",
m_num_splits);
auto dim_interval_at_axis_min =
static_cast<int64_t>(dim_interval_at_axis.get_min_val() * (1.0f / m_num_splits));
auto dim_interval_at_axis_max = dim_interval_at_axis.get_max_val();
if (dim_interval_at_axis.has_upper_bound()) {
dim_interval_at_axis_max = static_cast<int64_t>(dim_interval_at_axis_max * (1.0f / m_num_splits));
}
each_output_shape[axis] = Dimension(dim_interval_at_axis_min, dim_interval_at_axis_max);
}
} else {
each_output_shape = ov::PartialShape::dynamic(data_ps.rank());
}
std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0), get_input_partial_shape(1)};
std::vector<ov::PartialShape> output_shapes;
shape_infer(this, input_shapes, output_shapes);
set_output_size(m_num_splits);
for (size_t i = 0; i < m_num_splits; ++i) {
set_output_type(i, get_input_element_type(0), each_output_shape);
set_output_type(i, get_input_element_type(0), output_shapes[i]);
}
set_input_is_relevant_to_shape(0);

View File

@ -20,6 +20,7 @@
#include "ngraph/util.hpp"
#include "ngraph/validation_util.hpp"
#include "openvino/op/util/precision_sensitive_attribute.hpp"
#include "strided_slice_shape_inference.hpp"
using namespace std;
using namespace ngraph;
@ -132,24 +133,6 @@ void op::v1::StridedSlice::validate_and_infer_types() {
});
NODE_VALIDATION_CHECK(this, are_attr_sizes_eq, "All masks of StridedSlice must have the same size");
const auto& data_rank = get_input_partial_shape(0).rank();
const auto& begin_shape = get_input_partial_shape(1);
if (begin_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(this,
begin_shape.rank().get_length() == 1,
"Begin input must be 1D (begin rank: ",
begin_shape.rank(),
").");
}
const auto& end_shape = get_input_partial_shape(2);
if (end_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(this,
end_shape.rank().get_length() == 1,
"End input must be 1D (end rank: ",
end_shape.rank(),
").");
}
// Fill up strides input with default strides if not set by this point.
if (get_input_size() < 4) {
set_argument(3, calculate_default_strides(get_input_node_ptr(1)->output(0), get_input_node_ptr(2)->output(0)));
@ -159,26 +142,15 @@ void op::v1::StridedSlice::validate_and_infer_types() {
set_input_is_relevant_to_shape(2);
set_input_is_relevant_to_shape(3);
auto begin_const = get_constant_from_source(input_value(1));
auto end_const = get_constant_from_source(input_value(2));
auto strides = get_constant_from_source(input_value(3));
if (begin_const && end_const && strides) {
set_output_type(0,
get_input_element_type(0),
infer_slice_shape(this,
get_input_partial_shape(0),
begin_const->cast_vector<int64_t>(),
end_const->cast_vector<int64_t>(),
strides->cast_vector<int64_t>(),
convert_mask_to_axis_set(get_begin_mask()),
convert_mask_to_axis_set(get_end_mask()),
convert_mask_to_axis_set(get_new_axis_mask()),
convert_mask_to_axis_set(get_shrink_axis_mask()),
convert_mask_to_axis_set(get_ellipsis_mask())));
} else {
set_output_type(0, get_input_element_type(0), ov::PartialShape::dynamic(data_rank));
std::vector<ov::PartialShape> input_shapes;
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic()};
for (size_t input_idx = 0; input_idx < get_input_size(); ++input_idx) {
input_shapes.push_back(get_input_partial_shape(input_idx));
}
shape_infer(this, input_shapes, output_shapes);
set_output_type(0, get_input_element_type(0), output_shapes[0]);
}
AxisSet op::v1::StridedSlice::convert_mask_to_axis_set(const std::vector<int64_t>& mask) const {

View File

@ -5,6 +5,7 @@
#include "ngraph/op/topk.hpp"
#include <memory>
#include <topk_shape_inference.hpp>
#include "itt.hpp"
#include "ngraph/attribute_visitor.hpp"
@ -186,15 +187,6 @@ bool ngraph::op::v1::TopK::visit_attributes(AttributeVisitor& visitor) {
void op::v1::TopK::validate_and_infer_types() {
NGRAPH_OP_SCOPE(v1_TopK_validate_and_infer_types);
const auto& input_partial_shape = get_input_partial_shape(0);
const auto input_rank = input_partial_shape.rank();
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || input_rank.get_length() > 0,
"Input rank must be greater than 0.");
const auto& k_partial_shape = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this, k_partial_shape.rank().compatible(0), "The 'K' input must be a scalar.");
NODE_VALIDATION_CHECK(this,
m_index_element_type == element::i32 || m_index_element_type == element::i64,
@ -206,35 +198,15 @@ void op::v1::TopK::validate_and_infer_types() {
read_k_from_constant_node(input_value(1).get_node_shared_ptr(), get_input_element_type(1));
}
ov::PartialShape output_shape{input_partial_shape};
set_axis(get_input_partial_shape(0).rank(), get_provided_axis());
if (output_shape.rank().is_static()) {
m_normalized_axis = ngraph::normalize_axis(this, m_axis, output_shape.rank());
ov::PartialShape k_as_shape;
if (evaluate_as_partial_shape(input_value(1), k_as_shape)) {
if (k_as_shape.is_static()) {
output_shape[m_normalized_axis] = k_as_shape[0];
} else {
const auto in_min = output_shape[m_normalized_axis].get_min_length();
const auto in_max = output_shape[m_normalized_axis].get_max_length();
const auto k_min = k_as_shape[0].get_min_length();
const auto k_max = k_as_shape[0].get_max_length();
const auto lower = std::min<Dimension::value_type>(in_min, k_min);
const auto upper =
in_max < 0 ? Dimension::dynamic().get_max_length() : std::max<Dimension::value_type>(in_max, k_max);
output_shape[m_normalized_axis] = Dimension(lower, upper);
}
} else {
output_shape[m_normalized_axis] = Dimension(0, input_partial_shape[m_normalized_axis].get_max_length());
}
}
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}, ov::PartialShape{}};
std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0), get_input_partial_shape(1)};
shape_infer(this, input_shapes, output_shapes);
set_output_size(2);
set_output_type(0, get_input_element_type(0), output_shape);
set_output_type(1, m_index_element_type, output_shape);
set_output_type(0, get_input_element_type(0), output_shapes[0]);
set_output_type(1, m_index_element_type, output_shapes[1]);
}
ov::Shape op::v1::TopK::compute_output_shape(const std::string& node_description,

View File

@ -4,6 +4,7 @@
#include "ngraph/op/util/embeddingbag_offsets_base.hpp"
#include "embeddingbag_offsets_shape_inference.hpp"
#include "itt.hpp"
#include "ngraph/op/constant.hpp"
@ -55,16 +56,6 @@ void ov::op::util::EmbeddingBagOffsetsBase::validate_and_infer_types() {
get_input_element_type(INDICES),
")");
NODE_VALIDATION_CHECK(
this,
get_input_partial_shape(INDICES).is_dynamic() || get_input_partial_shape(INDICES).to_shape().size() == 1,
"INDICES must be 1D");
NODE_VALIDATION_CHECK(
this,
get_input_partial_shape(OFFSETS).is_dynamic() || get_input_partial_shape(OFFSETS).to_shape().size() == 1,
"OFFSETS must be 1D");
if (get_input_size() >= 4) {
NODE_VALIDATION_CHECK(this,
get_input_element_type(DEFAULT_INDEX) == element::i64 ||
@ -78,10 +69,6 @@ void ov::op::util::EmbeddingBagOffsetsBase::validate_and_infer_types() {
") must match indices element type (",
get_input_element_type(INDICES),
")");
NODE_VALIDATION_CHECK(this,
get_input_partial_shape(DEFAULT_INDEX).compatible(PartialShape{}),
"DEFAULT_INDEX must be a scalar");
}
if (get_input_size() == 5) {
@ -92,31 +79,18 @@ void ov::op::util::EmbeddingBagOffsetsBase::validate_and_infer_types() {
") must match embedding table element type (",
get_input_element_type(EMB_TABLE),
")");
NODE_VALIDATION_CHECK(this,
get_input_partial_shape(PER_SAMPLE_WEIGHTS).is_dynamic() ||
get_input_partial_shape(PER_SAMPLE_WEIGHTS).to_shape().size() == 1,
"PER_SAMPLE_WEIGHTS must be 1D");
NODE_VALIDATION_CHECK(this,
get_input_partial_shape(INDICES).compatible(get_input_partial_shape(PER_SAMPLE_WEIGHTS)),
"INDICES and PER_SAMPLE_WEIGHTS shape must be same");
}
element::Type result_et = get_input_element_type(EMB_TABLE);
const PartialShape& emb_table_shape = get_input_partial_shape(EMB_TABLE);
const PartialShape& offsets_shape = get_input_partial_shape(OFFSETS);
std::vector<PartialShape> result_shapes = {PartialShape::dynamic()};
std::vector<PartialShape> input_shapes;
for (int i = 0; i < get_input_size(); i++)
input_shapes.push_back(get_input_partial_shape(i));
PartialShape result_shape;
if (emb_table_shape.rank().is_static()) {
result_shape = emb_table_shape;
result_shape[0] = offsets_shape.rank().is_static() ? offsets_shape[0] : Dimension::dynamic();
} else {
result_shape = PartialShape::dynamic();
}
shape_infer(this, input_shapes, result_shapes);
set_output_type(0, result_et, result_shape);
set_output_type(0, result_et, result_shapes[0]);
}
bool ov::op::util::EmbeddingBagOffsetsBase::visit_attributes(AttributeVisitor& visitor) {

View File

@ -9,6 +9,7 @@
#include "itt.hpp"
#include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/validation_util.hpp"
#include "variadic_split_shape_inference.hpp"
using namespace std;
using namespace ngraph;
@ -33,81 +34,15 @@ void ngraph::op::v1::VariadicSplit::validate_and_infer_types() {
set_input_is_relevant_to_value(1);
set_input_is_relevant_to_value(2);
auto split_lengths_pshape = get_input_partial_shape(2);
std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0),
get_input_partial_shape(1),
get_input_partial_shape(2)};
std::vector<ov::PartialShape> output_shapes;
shape_infer(this, input_shapes, output_shapes);
if (split_lengths_pshape.is_static()) {
NODE_VALIDATION_CHECK(this,
split_lengths_pshape.rank().get_length() == 1,
"Split lengths should be a 1-D tensor. Got ",
split_lengths_pshape.rank(),
" instead.");
const auto num_outputs = split_lengths_pshape[0].get_length();
const auto data = input_value(0);
const auto axis_source = input_value(1);
const auto split_lengths_source = input_value(2);
const auto data_shape = data.get_partial_shape();
const auto& data_type = data.get_element_type();
set_output_size(num_outputs);
const auto& axis_input_constant = get_constant_from_source(axis_source);
const auto& split_lengths_constant = get_constant_from_source(split_lengths_source);
if (data_shape.rank().is_static() && axis_input_constant && split_lengths_constant) {
const auto axis_val = axis_input_constant->cast_vector<int64_t>()[0];
// Adjust split axis in case of negatives
const int64_t axis = ngraph::normalize_axis(this, axis_val, data_shape.rank());
auto split_lengths = split_lengths_constant->cast_vector<int64_t>();
// Adjust split lengths in case of negatives
int64_t sum_of_splits = 0;
int64_t negative_one = -1;
for (size_t i = 0; i < split_lengths.size(); i++) {
NODE_VALIDATION_CHECK(this,
split_lengths[i] >= -1,
"Invalid value ",
split_lengths[i],
" in split lengths input. Should be >= -1.");
if (split_lengths[i] == -1) {
NODE_VALIDATION_CHECK(this,
negative_one == -1,
"Cannot infer split with multiple -1 values at ",
negative_one,
" and ",
i);
negative_one = i;
} else {
sum_of_splits += split_lengths[i];
}
}
const auto data_shape_dims = vector<Dimension>{data.get_partial_shape()};
const auto dimension_at_axis = data_shape_dims.at(axis);
if (negative_one >= 0 && dimension_at_axis.is_static()) {
split_lengths[negative_one] = dimension_at_axis.get_length() - sum_of_splits;
sum_of_splits += split_lengths[negative_one];
}
if (data_shape[axis].is_static()) {
NODE_VALIDATION_CHECK(this,
sum_of_splits == data_shape[axis].get_length(),
"Total length of splits: ",
sum_of_splits,
" must match the length of the chosen axis: ",
data_shape[axis]);
}
for (int64_t output{0}; output < num_outputs; ++output) {
const auto output_split_dim =
split_lengths.at(output) == -1 ? Dimension::dynamic() : split_lengths.at(output);
auto tmp_shape = data_shape_dims;
tmp_shape.at(axis) = output_split_dim;
set_output_type(output, data_type, ov::PartialShape{tmp_shape});
}
} else {
for (int64_t output{0}; output < num_outputs; ++output) {
set_output_type(output, data_type, ov::PartialShape::dynamic());
}
}
const auto& data_type = get_input_element_type(0);
for (size_t i = 0; i < output_shapes.size(); ++i) {
set_output_type(i, data_type, output_shapes[i]);
}
}
@ -143,37 +78,30 @@ bool op::v1::VariadicSplit::evaluate_variadic_split(const HostTensorVector& inpu
const auto& axis_tensor = inputs[1];
const auto& split_lengths_tensor = inputs[2];
NGRAPH_CHECK(axis_tensor->get_element_type().is_integral_number(), "axis element type is not integral data type");
NGRAPH_CHECK(split_lengths_tensor->get_element_type().is_integral_number(),
"split_lengths element type is not integral data type");
int64_t axis = host_tensor_2_vector<int64_t>(axis_tensor)[0];
axis = ngraph::normalize_axis(this, axis, data_tensor->get_partial_shape().rank());
NGRAPH_CHECK(split_lengths_tensor->get_element_type().is_integral_number(),
"axis element type is not integral data type");
std::vector<int64_t> split_lengths = host_tensor_2_vector<int64_t>(split_lengths_tensor);
std::vector<ov::PartialShape> input_shapes = {data_tensor->get_partial_shape(),
axis_tensor->get_partial_shape(),
split_lengths_tensor->get_partial_shape()};
std::vector<ov::PartialShape> output_shapes;
shape_infer(this, input_shapes, output_shapes, {{1, axis_tensor}, {2, split_lengths_tensor}});
const auto data_shape = data_tensor->get_shape();
const auto neg_one = std::find(std::begin(split_lengths), std::end(split_lengths), -1);
if (neg_one != std::end(split_lengths)) // negative length set
{
const auto sum_of_known_splits = std::accumulate(std::begin(split_lengths), std::end(split_lengths), 0) + 1;
split_lengths[std::distance(std::begin(split_lengths), neg_one)] = data_shape[axis] - sum_of_known_splits;
}
ov::Shape output_shape = data_shape;
std::vector<size_t> lower_bounds(data_shape.size(), 0);
std::vector<size_t> upper_bounds = data_shape;
upper_bounds.at(axis) = split_lengths[0];
upper_bounds[axis] = 0;
size_t split_pos = 0;
for (const auto& output : outputs) {
output_shape.at(axis) = split_lengths[split_pos++];
ov::Shape output_shape = output_shapes[split_pos++].get_shape();
upper_bounds[axis] += output_shape[axis];
output->set_shape(output_shape);
variadic_split::evaluate(data_tensor, output, lower_bounds, upper_bounds);
lower_bounds.at(axis) = upper_bounds.at(axis);
if (split_pos < split_lengths.size())
upper_bounds.at(axis) += split_lengths[split_pos];
}
return true;

View File

@ -0,0 +1,18 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <bucketize_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
using namespace std;
TEST(StaticShapeInferenceTest, BucketizeV3) {
auto data = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape{-1, -1, -1});
auto buckets = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape{-1});
auto bucketize = make_shared<op::v3::Bucketize>(data, buckets);
check_static_shape(bucketize.get(), {ov::StaticShape{2, 3, 2}, ov::StaticShape{4}}, {ov::StaticShape{2, 3, 2}});
}

View File

@ -0,0 +1,50 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <einsum_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
TEST(StaticShapeInferenceTest, Einsum1) {
auto I1 = std::make_shared<op::v0::Parameter>(element::f32, ov::PartialShape::dynamic());
auto I2 = std::make_shared<op::v0::Parameter>(element::f32, ov::PartialShape::dynamic());
auto O = std::make_shared<op::v7::Einsum>(OutputVector{I1, I2}, "i,i->");
check_static_shape(O.get(), {ov::StaticShape{3}, ov::StaticShape{3}}, {ov::StaticShape{}});
}
TEST(StaticShapeInferenceTest, Einsum2) {
auto I1 = std::make_shared<op::v0::Parameter>(element::f32, ov::PartialShape::dynamic());
auto I2 = std::make_shared<op::v0::Parameter>(element::f32, ov::PartialShape::dynamic());
auto O = std::make_shared<op::v7::Einsum>(OutputVector{I1, I2}, "ab,bc->ac");
check_static_shape(O.get(), {ov::StaticShape{2, 3}, ov::StaticShape{3, 4}}, {ov::StaticShape{2, 4}});
}
TEST(StaticShapeInferenceTest, Einsum3) {
auto I1 = std::make_shared<op::v0::Parameter>(element::f32, ov::PartialShape::dynamic());
auto O = std::make_shared<op::v7::Einsum>(OutputVector{I1}, "kii->k");
check_static_shape(O.get(), {ov::StaticShape{2, 3, 3}}, {ov::StaticShape{2}});
}
TEST(StaticShapeInferenceTest, Einsum4) {
auto I1 = std::make_shared<op::v0::Parameter>(element::f32, ov::PartialShape::dynamic());
auto O = std::make_shared<op::v7::Einsum>(OutputVector{I1}, "ijk->kij");
check_static_shape(O.get(), {ov::StaticShape{1, 2, 3}}, {ov::StaticShape{3, 1, 2}});
}
TEST(StaticShapeInferenceTest, Einsum5) {
auto I1 = std::make_shared<op::v0::Parameter>(element::i32, ov::PartialShape::dynamic());
auto I2 = std::make_shared<op::v0::Parameter>(element::i32, ov::PartialShape::dynamic());
auto I3 = std::make_shared<op::v0::Parameter>(element::i32, ov::PartialShape::dynamic());
auto O = std::make_shared<op::v7::Einsum>(OutputVector{I1, I2, I3}, "ab,bcd,bc->ca");
check_static_shape(O.get(),
{ov::StaticShape{2, 5}, ov::StaticShape{5, 3, 6}, ov::StaticShape{5, 3}},
{ov::StaticShape{3, 2}});
}

View File

@ -0,0 +1,34 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <embedding_segments_sum_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
using namespace std;
TEST(StaticShapeInferenceTest, EmbeddingSegmentsSum) {
auto emb_table = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape{-1, -1});
auto indices = make_shared<op::v0::Parameter>(element::i64, ov::PartialShape{-1});
auto segment_ids = make_shared<op::v0::Parameter>(element::i64, ov::PartialShape{-1});
auto num_segments = op::v0::Constant::create(element::i64, ov::Shape{}, {3});
auto default_index = make_shared<op::v0::Parameter>(element::i64, ov::PartialShape{});
auto per_sample_weights = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape{-1});
auto ess = make_shared<op::v3::EmbeddingSegmentsSum>(emb_table,
indices,
segment_ids,
num_segments,
default_index,
per_sample_weights);
check_static_shape(
ess.get(),
{StaticShape{5, 2}, StaticShape{4}, StaticShape{4}, StaticShape{}, StaticShape{}, StaticShape{4}},
{StaticShape{3, 2}});
check_static_shape(ess.get(),
{StaticShape{5, 2}, StaticShape{4}, StaticShape{4}, 8, StaticShape{}, StaticShape{4}},
{StaticShape{8, 2}});
}

View File

@ -0,0 +1,26 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <embeddingbag_offsets_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
using namespace std;
TEST(StaticShapeInferenceTest, EmbeddingBagOffsetsSumV3) {
auto emb_table = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape::dynamic());
auto indices = make_shared<op::v0::Parameter>(element::i64, ov::PartialShape::dynamic());
auto offsets = make_shared<op::v0::Parameter>(element::i64, ov::PartialShape::dynamic());
auto default_index = make_shared<op::v0::Parameter>(element::i64, ov::PartialShape::dynamic());
auto per_sample_weights = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape::dynamic());
auto ebos =
make_shared<op::v3::EmbeddingBagOffsetsSum>(emb_table, indices, offsets, default_index, per_sample_weights);
check_static_shape(
ebos.get(),
{ov::StaticShape{5, 2}, ov::StaticShape{4}, ov::StaticShape{3}, ov::StaticShape{}, ov::StaticShape{4}},
{ov::StaticShape{3, 2}});
}

View File

@ -0,0 +1,44 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <convolution_shape_inference.hpp>
#include <experimental_detectron_roi_feature_shape_inference.hpp>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include "utils/shape_inference/shape_inference.hpp"
#include "utils/shape_inference/static_shape.hpp"
using namespace ov;
TEST(StaticShapeInferenceTest, ExperimentalDetectronROIFeatureExtractor) {
op::v6::ExperimentalDetectronROIFeatureExtractor::Attributes attrs;
attrs.aligned = false;
attrs.output_size = 14;
attrs.sampling_ratio = 2;
attrs.pyramid_scales = {4, 8, 16, 32};
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
auto pyramid_layer0 = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{1, -1, -1, -1});
auto pyramid_layer1 = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{1, -1, -1, -1});
auto pyramid_layer2 = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{1, -1, -1, -1});
auto pyramid_layer3 = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{1, -1, -1, -1});
auto roi = std::make_shared<op::v6::ExperimentalDetectronROIFeatureExtractor>(
NodeVector{input, pyramid_layer0, pyramid_layer1, pyramid_layer2, pyramid_layer3},
attrs);
std::vector<StaticShape> input_shapes = {StaticShape{1000, 4},
StaticShape{1, 256, 200, 336},
StaticShape{1, 256, 100, 168},
StaticShape{1, 256, 50, 84},
StaticShape{1, 256, 25, 42}};
std::vector<StaticShape> output_shapes = {StaticShape{}, StaticShape{}};
shape_inference(roi.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], (StaticShape{1000, 256, 14, 14}));
EXPECT_EQ(output_shapes[1], (StaticShape{1000, 4}));
}

View File

@ -0,0 +1,26 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <pad_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
TEST(StaticShapeInferenceTest, Padv1) {
const auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape::dynamic());
const auto pads_begin = ov::op::v0::Constant::create(element::i64, ov::Shape{4}, {3, 2, 1, 0});
const auto pads_end = ov::op::v0::Constant::create(element::i64, ov::Shape{4}, {0, 1, 2, 3});
const auto pad_val = ov::op::v0::Constant::create(element::f32, ov::Shape{}, {2112});
const auto pad = std::make_shared<ov::op::v1::Pad>(data, pads_begin, pads_end, pad_val, op::PadMode::CONSTANT);
check_static_shape(pad.get(),
{ov::StaticShape{3, 6, 5, 5},
ov::StaticShape{4},
ov::StaticShape{4},
ov::StaticShape()},
{ov::StaticShape({6, 9, 8, 8})});
}

View File

@ -0,0 +1,36 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <range_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
using namespace std;
TEST(StaticShapeInferenceTest, Rangev4_i32) {
auto start = make_shared<op::v0::Parameter>(element::i32, ov::PartialShape{});
auto stop = make_shared<op::v0::Parameter>(element::i32, ov::PartialShape{});
auto step = make_shared<op::v0::Parameter>(element::i32, ov::PartialShape{});
auto range = make_shared<op::v4::Range>(start, stop, step, element::i32);
check_static_shape(range.get(), {2, 0, -2}, {ov::StaticShape{1}});
check_static_shape(range.get(), {2, 0, -1}, {ov::StaticShape{2}});
check_static_shape(range.get(), {-19, 19, 1}, {ov::StaticShape{38}});
check_static_shape(range.get(), {-19, 19, 3}, {ov::StaticShape{13}});
check_static_shape(range.get(), {20, -19, 1}, {ov::StaticShape{0}});
}
TEST(StaticShapeInferenceTest, Rangev4_f32) {
auto start = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape{});
auto stop = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape{});
auto step = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape{});
auto range = make_shared<op::v4::Range>(start, stop, step, element::f32);
check_static_shape(range.get(), {0., 1., 0.25}, {ov::StaticShape{4}});
check_static_shape(range.get(), {-1., 1., 0.25}, {ov::StaticShape{8}});
check_static_shape(range.get(), {-1., 0.875, 0.25}, {ov::StaticShape{8}});
}

View File

@ -0,0 +1,27 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <region_yolo_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
using namespace std;
TEST(StaticShapeInferenceTest, RegionYoloV0) {
auto inputs = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape{-1, -1, -1, -1});
auto op = make_shared<op::v0::RegionYolo>(inputs, 0, 0, 0, true, std::vector<int64_t>{}, 0, 1);
check_static_shape(op.get(), {ov::StaticShape{1, 125, 13, 13}}, {ov::StaticShape{1 * 125, 13, 13}});
}
TEST(StaticShapeInferenceTest, RegionYoloV0Dynamic) {
auto inputs = make_shared<op::v0::Parameter>(element::f32,
ov::PartialShape{{1, 11}, {2, 12}, ov::Dimension::dynamic(), {4, 14}});
auto op = make_shared<op::v0::RegionYolo>(inputs, 4, 80, 5, true, std::vector<int64_t>{}, 1, 3);
EXPECT_EQ(op->get_output_partial_shape(0), ov::PartialShape({{1, 11}, ov::Dimension::dynamic()}));
check_static_shape(op.get(), {ov::StaticShape{10, 125, 13, 13}}, {ov::StaticShape{10, 125 * 13 * 13}});
}

View File

@ -0,0 +1,19 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <reorg_yolo_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
using namespace std;
TEST(StaticShapeInferenceTest, ReorgYoloV0) {
size_t stride = 2;
auto data_param = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape{-1, -1, -1, -1});
auto op = make_shared<op::v0::ReorgYolo>(data_param, stride);
check_static_shape(op.get(), {ov::StaticShape{1, 64, 26, 26}}, {ov::StaticShape{1, 256, 13, 13}});
}

View File

@ -0,0 +1,41 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <split_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
static std::shared_ptr<op::v1::Split> build_split(PartialShape data_shape,
std::initializer_list<int64_t> axis_value,
size_t num_splits) {
std::shared_ptr<ov::Node> axis;
const auto data = std::make_shared<op::v0::Parameter>(element::f32, data_shape);
if (axis_value.size())
axis = op::v0::Constant::create(element::i64, ov::Shape{}, {*axis_value.begin()});
else
axis = std::make_shared<op::v0::Parameter>(element::i64, ov::PartialShape{});
return std::make_shared<op::v1::Split>(data, axis, num_splits);
}
TEST(StaticShapeInferenceTest, SplitV1) {
const auto op = build_split(PartialShape{-1, -1, -1}, {}, 3);
check_static_shape(op.get(), {StaticShape{2, 3, 4}, 1}, {{2, 1, 4}, {2, 1, 4}, {2, 1, 4}});
}
TEST(StaticShapeInferenceTest, SplitV1_Dynamic) {
check_output_shape(build_split(PartialShape({2, 8, 4}), {}, 4).get(),
{ov::PartialShape::dynamic(ov::Rank(3)),
ov::PartialShape::dynamic(ov::Rank(3)),
ov::PartialShape::dynamic(ov::Rank(3)),
ov::PartialShape::dynamic(ov::Rank(3))});
}
TEST(StaticShapeInferenceTest, SplitV1_StaticWithConstMap) {
check_static_shape(build_split(PartialShape({-1, -1, -1}), {}, 4).get(),
{StaticShape{2, 8, 4}, 2},
{{2, 8, 1}, {2, 8, 1}, {2, 8, 1}, {2, 8, 1}});
}

View File

@ -0,0 +1,97 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <strided_slice_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
TEST(StaticShapeInferenceTest, StridedSlice1) {
auto data = std::make_shared<op::v0::Parameter>(ngraph::element::f32, ov::PartialShape::dynamic());
auto begin = op::v0::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {100});
auto end = op::v0::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {-100});
auto stride = op::v0::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {-1});
std::vector<int64_t> begin_mask = {0, 0, 0, 0};
std::vector<int64_t> end_mask = {0, 0, 0, 0};
auto ss = std::make_shared<op::v1::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
check_static_shape(ss.get(),
{ov::StaticShape{3, 4, 5}, ov::StaticShape{3}, ov::StaticShape{3}, ov::StaticShape{3}},
{ov::StaticShape{3, 4, 5}});
}
TEST(StaticShapeInferenceTest, StridedSlice2) {
auto data = std::make_shared<op::v0::Parameter>(ngraph::element::f32, ov::PartialShape::dynamic());
auto begin = std::make_shared<op::v0::Parameter>(ngraph::element::i64, ngraph::Shape{3});
auto end = std::make_shared<op::v0::Parameter>(ngraph::element::i64, ngraph::Shape{3});
auto stride = std::make_shared<op::v0::Parameter>(ngraph::element::i64, ngraph::Shape{3});
std::vector<int64_t> begin_mask(3, 0);
std::vector<int64_t> end_mask(3, 0);
auto ss = std::make_shared<op::v1::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
check_static_shape(ss.get(),
{ov::StaticShape{3, 2, 3}, {1, 0, 0}, {2, 1, 3}, {1, 1, 1}},
{ov::StaticShape{1, 1, 3}});
check_static_shape(ss.get(),
{ov::StaticShape{3, 2, 3}, {1, 0, 0}, {2, 2, 3}, {1, 1, 1}},
{ov::StaticShape{1, 2, 3}});
check_static_shape(ss.get(),
{ov::StaticShape{3, 2, 3}, {2, 0, 0}, {3, 2, 3}, {1, 1, 2}},
{ov::StaticShape{1, 2, 2}});
}
TEST(StaticShapeInferenceTest, StridedSlice3) {
auto data = std::make_shared<op::v0::Parameter>(ngraph::element::f32, ov::PartialShape::dynamic());
auto begin = std::make_shared<op::v0::Parameter>(ngraph::element::i64, ngraph::Shape{3});
auto end = std::make_shared<op::v0::Parameter>(ngraph::element::i64, ngraph::Shape{3});
auto stride = std::make_shared<op::v0::Parameter>(ngraph::element::i64, ngraph::Shape{3});
std::vector<int64_t> begin_mask{0, 1, 1};
std::vector<int64_t> end_mask{1, 1, 1};
auto ss = std::make_shared<op::v1::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
check_static_shape(ss.get(),
{ov::StaticShape{3, 2, 3}, {1, 0, 0}, {0, 0, 0}, {1, 1, 1}},
{ov::StaticShape{2, 2, 3}});
}
TEST(StaticShapeInferenceTest, StridedSlice4) {
auto data = std::make_shared<op::v0::Parameter>(ngraph::element::f32, ov::PartialShape::dynamic());
auto begin = std::make_shared<op::v0::Parameter>(ngraph::element::i64, ngraph::Shape{3});
auto end = std::make_shared<op::v0::Parameter>(ngraph::element::i64, ngraph::Shape{3});
auto stride = std::make_shared<op::v0::Parameter>(ngraph::element::i64, ngraph::Shape{3});
std::vector<int64_t> begin_mask{1, 0, 1};
std::vector<int64_t> end_mask{0, 1, 1};
auto ss = std::make_shared<op::v1::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
check_static_shape(ss.get(),
{ov::StaticShape{3, 2, 3}, {0, 1, 0}, {2, 0, 0}, {1, 1, 2}},
{ov::StaticShape{2, 1, 2}});
}
TEST(StaticShapeInferenceTest, StridedSlice5) {
auto data = std::make_shared<op::v0::Parameter>(ngraph::element::f32, ov::PartialShape::dynamic());
auto begin = std::make_shared<op::v0::Parameter>(ngraph::element::i64, ngraph::Shape{3});
auto end = std::make_shared<op::v0::Parameter>(ngraph::element::i64, ngraph::Shape{3});
auto stride = std::make_shared<op::v0::Parameter>(ngraph::element::i64, ngraph::Shape{3});
std::vector<int64_t> begin_mask{0, 1, 1};
std::vector<int64_t> end_mask{0, 1, 1};
auto ss = std::make_shared<op::v1::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
check_static_shape(ss.get(),
{ov::StaticShape{3, 2, 3}, {0, 0, 0}, {1, 0, 0}, {1, 1, -1}},
{ov::StaticShape{1, 2, 3}});
}

View File

@ -0,0 +1,43 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <topk_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
static std::shared_ptr<op::v3::TopK> build_topk(PartialShape data_shape = PartialShape::dynamic(),
int64_t axis = 1,
int k_value = -1) {
std::shared_ptr<ov::Node> k;
const auto data = std::make_shared<op::v0::Parameter>(element::f32, data_shape);
if (k_value >= 0)
k = op::v0::Constant::create(element::i64, ov::Shape{}, {2});
else
k = std::make_shared<op::v0::Parameter>(element::i64, ov::PartialShape{});
return std::make_shared<op::v3::TopK>(data, k, axis, "max", "value");
}
TEST(StaticShapeInferenceTest, TopKv3) {
const auto topk = build_topk(PartialShape::dynamic(), 1, 2);
check_static_shape(topk.get(),
{StaticShape{1, 10, 100}, StaticShape{}},
{StaticShape({1, 2, 100}), StaticShape({1, 2, 100})});
}
TEST(StaticShapeInferenceTest, TopKv3_StaticNoConstMap) {
const auto topk = build_topk();
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 10, 100}, StaticShape{}};
std::vector<StaticShape> static_output_shapes = {StaticShape{}, StaticShape{}};
EXPECT_THROW(shape_inference(topk.get(), static_input_shapes, static_output_shapes), NodeValidationFailure);
}
TEST(StaticShapeInferenceTest, TopKv3_StaticWithConstMap) {
const auto topk = build_topk();
check_static_shape(topk.get(), {StaticShape{1, 10, 100}, 2}, {StaticShape{1, 2, 100}, StaticShape{1, 2, 100}});
}

View File

@ -0,0 +1,84 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>
#pragma once
struct TestTensor {
std::shared_ptr<ngraph::runtime::HostTensor> tensor;
ov::StaticShape static_shape;
template <typename T>
TestTensor(std::initializer_list<T> values) : TestTensor(ov::StaticShape({values.size()}), values) {}
template <typename T>
TestTensor(T scalar) : TestTensor(ov::StaticShape({}), {scalar}) {}
TestTensor(ov::StaticShape shape) : static_shape(shape) {}
template <typename T>
TestTensor(ov::StaticShape shape, std::initializer_list<T> values) {
static_shape = shape;
ov::Shape s;
for (auto dim : shape)
s.push_back(dim.get_length());
if (values.size() > 0) {
tensor = std::make_shared<ngraph::runtime::HostTensor>(ov::element::from<T>(), s);
T* ptr = tensor->get_data_ptr<T>();
int i = 0;
for (auto& v : values)
ptr[i++] = v;
}
}
};
// TestTensor can be constructed from initializer_list<T>/int64_t/Shape/Shape+initializer_list
// so each element of inputs can be:
// {1,2,3,4} tensor of shape [4] and values (1,2,3,4)
// 2 tensor of scalar with value 2
// Shape{2,2} tensor of shape [2,2] and value unknown
// {Shape{2,2}, {1,2,3,4}} tensor of shape [2,2] and values (1,2,3,4)
static void check_static_shape(ov::Node* op,
std::initializer_list<TestTensor> inputs,
std::initializer_list<ov::StaticShape> expect_shapes) {
std::vector<ov::StaticShape> output_shapes;
std::vector<ov::StaticShape> input_shapes;
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constData;
int index = 0;
std::for_each(inputs.begin(), inputs.end(), [&](TestTensor t) {
input_shapes.push_back(t.static_shape);
if (t.tensor)
constData[index] = t.tensor;
index++;
});
output_shapes.resize(expect_shapes.size(), ov::StaticShape{});
shape_inference(op, input_shapes, output_shapes, constData);
EXPECT_EQ(output_shapes.size(), expect_shapes.size());
int id = 0;
for (auto& shape : expect_shapes) {
EXPECT_EQ(output_shapes[id], shape);
id++;
}
}
static void check_output_shape(ov::Node* op, std::initializer_list<ov::PartialShape> expect_shapes) {
int id = 0;
EXPECT_EQ(op->outputs().size(), expect_shapes.size());
for (auto& shape : expect_shapes) {
EXPECT_EQ(op->get_output_partial_shape(id), shape);
id++;
}
}

View File

@ -0,0 +1,50 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <variadic_split_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
static std::shared_ptr<op::v1::VariadicSplit> build_variadic_split(PartialShape data_shape,
std::initializer_list<int64_t> axis_value,
std::initializer_list<int64_t> splits) {
std::shared_ptr<ov::Node> axis;
std::shared_ptr<ov::Node> splits_len;
const auto data = std::make_shared<op::v0::Parameter>(element::i32, data_shape);
if (axis_value.size())
axis = op::v0::Constant::create(element::i64, ov::Shape{}, {*axis_value.begin()});
else
axis = std::make_shared<op::v0::Parameter>(element::i64, ov::PartialShape::dynamic(ov::Rank(0)));
if (splits.size())
splits_len = op::v0::Constant::create(element::i64, ov::Shape{splits.size()}, splits);
else
splits_len = std::make_shared<op::v0::Parameter>(element::i64, ov::PartialShape::dynamic(ov::Rank(1)));
return std::make_shared<op::v1::VariadicSplit>(data, axis, splits_len);
}
TEST(StaticShapeInferenceTest, VariadicSplitV1) {
const auto split = build_variadic_split(ov::PartialShape::dynamic(), {}, {});
check_static_shape(split.get(),
{StaticShape{12, 6}, {-2}, {7, -1, 2}},
{{7, 6}, {3, 6}, {2, 6}});
check_static_shape(split.get(),
{StaticShape{12, 6}, {-2}, {-1, 7, 2}},
{{3, 6}, {7, 6}, {2, 6}});
check_static_shape(split.get(),
{StaticShape{12, 1, 6}, {2}, {3, 1, 2}},
{{12, 1, 3}, {12, 1, 1}, {12, 1, 2}});
check_static_shape(split.get(), {StaticShape{12, 6}, {1}, {6, 0}}, {{12, 6}, {12, 0}});
}
TEST(StaticShapeInferenceTest, VariadicSplitV1_StaticWithConstMap) {
check_static_shape(build_variadic_split(ov::PartialShape{-1, -1}, {}, {}).get(),
{StaticShape{12, 6}, {-2}, {7, -1, 2}},
{{7, 6}, {3, 6}, {2, 6}});
}