Shape inference 2nd batch (#8781)

Evgenya Stepyreva 2021-12-02 10:47:41 +03:00 committed by GitHub
parent e23e6f3628
commit 84ebe77f62
17 changed files with 1052 additions and 859 deletions

@@ -5,13 +5,18 @@
#include <openvino/core/node.hpp>
#include <ngraph/runtime/host_tensor.hpp>
#include <openvino/opsets/opset1.hpp>
#include <openvino/opsets/opset2.hpp>
#include <openvino/opsets/opset4.hpp>
#include <openvino/opsets/opset5.hpp>
#include <openvino/opsets/opset6.hpp>
#include <openvino/opsets/opset8.hpp>
#include "static_shape.hpp"
#include "utils.hpp"
#include "shape_inference.hpp"
#include "convolution_shape_inference.hpp"
#include "reduce_shape_inference.hpp"
#include "shape_nodes.hpp"
#include "fake_quantize.hpp"
#include "experimental_detectron_detection_output_shape_inference.hpp"
@@ -24,10 +29,45 @@ void shape_inference(ov::Node* op,
bool status = resolve_auto_pad_for_shape(node, pads_begin, pads_end, input_shapes, 2, 2);
OPENVINO_ASSERT(status, "Convolution shape inference doesn't have enough information to calculate static shapes");
shape_infer(node, pads_begin, pads_end, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset8::GroupConvolution>(op)) {
ov::CoordinateDiff pads_begin, pads_end;
bool status = resolve_auto_pad_for_shape(node, pads_begin, pads_end, input_shapes, 2, 3);
OPENVINO_ASSERT(status, "GroupConvolution shape inference doesn't have enough information to calculate static shapes");
shape_infer(node, pads_begin, pads_end, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset8::ConvolutionBackpropData>(op)) {
ov::CoordinateDiff pads_begin, pads_end;
ov::StaticShape output_shape_input;
if (node->get_input_size() == 3)
get_data_as_shape<ov::StaticShape>(2, op, output_shape_input, constant_data);
bool status = resolve_auto_pad_for_shape_back_prop(node, pads_begin, pads_end, input_shapes, output_shape_input, 2, 2);
OPENVINO_ASSERT(status, "ConvolutionBackpropData shape inference doesn't have enough information to calculate static shapes");
shape_infer(node, pads_begin, pads_end, output_shape_input, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset8::GroupConvolutionBackpropData>(op)) {
ov::CoordinateDiff pads_begin, pads_end;
ov::StaticShape output_shape_input;
if (node->get_input_size() == 3)
get_data_as_shape<ov::StaticShape>(2, op, output_shape_input, constant_data);
bool status = resolve_auto_pad_for_shape_back_prop(node, pads_begin, pads_end, input_shapes, output_shape_input, 2, 3);
OPENVINO_ASSERT(status, "GroupConvolutionBackpropData shape inference doesn't have enough information to calculate static shapes");
shape_infer(node, pads_begin, pads_end, output_shape_input, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::op::util::ArithmeticReductionKeepDims>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (auto node = ov::as_type<ov::op::util::LogicalReductionKeepDims>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (ov::is_type<ov::op::util::UnaryElementwiseArithmetic>(op) ||
ov::is_type<ov::opset1::Convert>(op) || ov::is_type<ov::opset1::Clamp>(op) ||
ov::is_type<ov::opset1::GRN>(op) || ov::is_type<ov::opset1::LRN>(op) ||
ov::is_type<ov::opset1::LogicalNot>(op) || ov::is_type<ov::opset4::Mish>(op) ||
ov::is_type<ov::opset2::MVN>(op) || ov::is_type<ov::opset6::MVN>(op) ||
ov::is_type<ov::opset1::PRelu>(op) || ov::is_type<ov::opset1::Relu>(op) ||
ov::is_type<ov::opset4::Swish>(op) || ov::is_type<ov::opset1::Softmax>(op) ||
ov::is_type<ov::opset1::Elu>(op) || ov::is_type<ov::opset5::Round>(op)) {
copy_shape_infer(node, input_shapes, output_shapes);
} else if (ov::is_type<ov::op::util::BinaryElementwiseArithmetic>(op) ||
ov::is_type<ov::op::util::BinaryElementwiseComparison>(op) || ov::is_type<ov::op::util::BinaryElementwiseLogical>(op)) {
eltwise_shape_infer(op, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset1::FakeQuantize>(op)) {
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset1::Reshape>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (auto node = ov::as_type<ov::opset1::Squeeze>(op)) {

@@ -6,8 +6,10 @@
#include <openvino/core/coordinate_diff.hpp>
#include <openvino/op/convolution.hpp>
#include <openvino/op/group_conv.hpp>
#include <openvino/op/parameter.hpp>
#include <openvino/op/relu.hpp>
#include <openvino/op/constant.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>
@@ -32,6 +34,89 @@ TEST(StaticShapeInferenceTest, ConvolutionTest) {
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 7, 5, 5}));
}
TEST(StaticShapeInferenceTest, GroupConvolutionTest) {
Strides strides{1, 1};
CoordinateDiff pads_begin{0, 0};
CoordinateDiff pads_end{0, 0};
Strides dilations{1, 1};
const auto auto_pad = op::PadType::SAME_LOWER;
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto filters = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1, -1});
auto conv =
std::make_shared<op::v1::GroupConvolution>(data, filters, strides, pads_begin, pads_end, dilations, auto_pad);
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 4, 5, 5}, StaticShape{2, 1, 2, 3, 3}},
static_output_shapes = {StaticShape{}};
shape_inference(conv.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 2, 5, 5}));
}
TEST(StaticShapeInferenceTest, ConvolutionBackPropDataTest) {
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto filters = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
const Strides strides{2, 2};
const Strides dilations{1, 1};
const CoordinateDiff padding_begin{1, 1};
const CoordinateDiff padding_end{1, 1};
const CoordinateDiff output_padding{1, 1};
const op::PadType auto_pad = op::PadType::SAME_LOWER;
auto output_shape = std::make_shared<op::v0::Constant>(
ov::element::i64, ov::Shape{2}, std::vector<int64_t>({3, 3}));
auto conv = std::make_shared<op::v1::ConvolutionBackpropData>(data,
filters,
output_shape,
strides,
padding_begin,
padding_end,
dilations,
auto_pad,
output_padding);
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 16, 2, 2}, StaticShape{16, 6, 3, 3}, StaticShape{2}},
static_output_shapes = {StaticShape{}};
shape_inference(conv.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 6, 3, 3}));
}
TEST(StaticShapeInferenceTest, GroupConvolutionBackPropDataTest) {
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto filters = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1, -1});
const Strides strides{2, 2};
const Strides dilations{1, 1};
const CoordinateDiff padding_begin{1, 1};
const CoordinateDiff padding_end{1, 1};
const CoordinateDiff output_padding{1, 1};
const op::PadType auto_pad = op::PadType::SAME_LOWER;
auto output_shape = std::make_shared<op::v0::Constant>(
ov::element::i64, ov::Shape{2}, std::vector<int64_t>({3, 3}));
auto conv = std::make_shared<op::v1::GroupConvolutionBackpropData>(data,
filters,
output_shape,
strides,
padding_begin,
padding_end,
dilations,
auto_pad,
output_padding);
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 16, 2, 2}, StaticShape{4, 4, 6, 3, 3}, StaticShape{2}},
static_output_shapes = {StaticShape{}};
shape_inference(conv.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 24, 3, 3}));
}
#if 0
TEST(StaticShapeInferenceTest, ConvolutionTimeTest) {
Strides strides{1, 1};

@@ -0,0 +1,61 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <openvino/core/coordinate_diff.hpp>
#include <openvino/op/parameter.hpp>
#include <openvino/op/add.hpp>
#include <openvino/op/relu.hpp>
#include <openvino/op/fake_quantize.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>
using namespace ov;
TEST(StaticShapeInferenceTest, UnaryEltwiseTest) {
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto node = std::make_shared<op::v0::Relu>(data);
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 6, 5, 5}},
static_output_shapes = {StaticShape{}};
shape_inference(node.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 6, 5, 5}));
}
TEST(StaticShapeInferenceTest, BinaryEltwiseTest) {
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto data_1 = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto node = std::make_shared<op::v1::Add>(data, data_1);
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 6, 1, 5}, StaticShape{1, 3, 5}},
static_output_shapes = {StaticShape{}};
shape_inference(node.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 6, 3, 5}));
}
TEST(StaticShapeInferenceTest, FakeQuantizeTest) {
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto il = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto ih = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto ol = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto oh = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto node = std::make_shared<op::v0::FakeQuantize>(data, il, ih, ol, oh, 256);
std::vector<StaticShape> static_input_shapes = {
StaticShape{3, 6, 3, 5},
StaticShape{1, 3, 1},
StaticShape{1},
StaticShape{5},
StaticShape{1, 1, 1, 1}
},
static_output_shapes = {StaticShape{}};
shape_inference(node.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 6, 3, 5}));
}

@@ -101,19 +101,21 @@ protected:
int64_t m_num_spatial = -1;
private:
friend int64_t calculate_num_spatial(const Convolution* op,
template <class ConvType>
friend int64_t calculate_num_spatial(const ConvType* op,
const PartialShape& input_shape,
const PartialShape& filters_shape,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims);
friend void update_and_validate_attributes(Convolution* op);
template <class ConvType>
friend void update_and_validate_attributes(ConvType* op);
template <class T>
friend bool resolve_auto_pad_for_shape(const Convolution* op,
template <class ConvType, class ShapeType>
friend bool resolve_auto_pad_for_shape(const ConvType* op,
CoordinateDiff& pads_begin,
CoordinateDiff& pads_end,
const std::vector<T>& input_shapes,
const std::vector<ShapeType>& input_shapes,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims);
template <class T>
@@ -266,6 +268,47 @@ protected:
CoordinateDiff m_pads_end;
PadType m_auto_pad;
CoordinateDiff m_output_padding;
int64_t m_num_spatial = -1;
private:
template <class ConvType>
friend int64_t calculate_num_spatial(const ConvType* op,
const PartialShape& input_shape,
const PartialShape& filters_shape,
const PartialShape& output_shapes_shape,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims);
template <class ConvType>
friend int64_t calculate_num_spatial(const ConvType* op,
const PartialShape& input_shape,
const PartialShape& filters_shape,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims);
template <class ConvType>
friend void update_and_validate_attributes(ConvType* op);
template <class ConvType>
friend void update_and_validate_attributes_back_prop(ConvType* op);
template <class ConvType, class ShapeType>
friend bool resolve_auto_pad_for_shape_back_prop(const ConvType* op,
CoordinateDiff& pads_begin,
CoordinateDiff& pads_end,
const std::vector<ShapeType>& input_shapes,
ShapeType& output_spatial_shape,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims);
template <class T>
friend void shape_infer(const ConvolutionBackpropData* op,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const T& output_shape_from_input,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes);
};
} // namespace v1
} // namespace op

@@ -96,6 +96,33 @@ protected:
CoordinateDiff m_pads_begin;
CoordinateDiff m_pads_end;
PadType m_auto_pad;
int64_t m_num_spatial = -1;
private:
template <class ConvType>
friend int64_t calculate_num_spatial(const ConvType* op,
const PartialShape& input_shape,
const PartialShape& filters_shape,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims);
template <class ConvType>
friend void update_and_validate_attributes(ConvType* op);
template <class ConvType, class ShapeType>
friend bool resolve_auto_pad_for_shape(const ConvType* op,
CoordinateDiff& pads_begin,
CoordinateDiff& pads_end,
const std::vector<ShapeType>& input_shapes,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims);
template <class T>
friend void shape_infer(const GroupConvolution* op,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes);
};
/// \brief Data batch backprop for batched convolution operation.
@@ -269,6 +296,47 @@ protected:
CoordinateDiff m_pads_end;
PadType m_auto_pad;
CoordinateDiff m_output_padding;
int64_t m_num_spatial = -1;
private:
template <class ConvType>
friend int64_t calculate_num_spatial(const ConvType* op,
const PartialShape& input_shape,
const PartialShape& filters_shape,
const PartialShape& output_shapes_shape,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims);
template <class ConvType>
friend int64_t calculate_num_spatial(const ConvType* op,
const PartialShape& input_shape,
const PartialShape& filters_shape,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims);
template <class ConvType>
friend void update_and_validate_attributes(ConvType* op);
template <class ConvType>
friend void update_and_validate_attributes_back_prop(ConvType* op);
template <class ConvType, class ShapeType>
friend bool resolve_auto_pad_for_shape_back_prop(const ConvType* op,
CoordinateDiff& pads_begin,
CoordinateDiff& pads_end,
const std::vector<ShapeType>& input_shapes,
ShapeType& output_spatial_shape,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims);
template <class T>
friend void shape_infer(const GroupConvolutionBackpropData* op,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const T& output_shape_from_input,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes);
};
} // namespace v1
} // namespace op

@@ -37,9 +37,6 @@ namespace util {
// clang-format on
class OPENVINO_API BinaryElementwiseLogical : public Op {
protected:
OPENVINO_OP("BinaryElementwiseLogical", "util");
BWDCMP_RTTI_DECLARATION;
BinaryElementwiseLogical();
/// \brief Constructs a binary elementwise logical operation.
@@ -51,6 +48,9 @@ protected:
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
public:
OPENVINO_OP("BinaryElementwiseLogical", "util");
BWDCMP_RTTI_DECLARATION;
void validate_and_infer_types() override;
const AutoBroadcastSpec& get_autob() const override {

@@ -4,17 +4,102 @@
#pragma once
#include <openvino/op/convolution.hpp>
#include <openvino/op/group_conv.hpp>
#include "utils.hpp"
namespace ov {
namespace op {
namespace v1 {
int64_t calculate_num_spatial(const Convolution* op,
template<class ConvType>
int64_t calculate_num_spatial(const ConvType* op,
const PartialShape& input_shape,
const PartialShape& filters_shape,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims);
void update_and_validate_attributes(Convolution* op);
const int64_t& num_non_spatial_filter_dims) {
int64_t num_spatial = op->m_num_spatial;
if (num_spatial == -1) {
const auto &input_rank = input_shape.rank();
const auto &filters_rank = filters_shape.rank();
if (input_rank.is_static())
num_spatial = input_rank.get_length() - num_non_spatial_data_dims;
if (filters_rank.is_static())
num_spatial = filters_rank.get_length() - num_non_spatial_filter_dims;
if (const auto &size = op->m_dilations.size()) {
NODE_VALIDATION_CHECK(op, num_spatial == -1 || num_spatial == size,
"Dilations should be defined for all and only spatial dimensions.");
num_spatial = static_cast<int64_t>(size);
}
if (const auto &size = op->m_strides.size()) {
NODE_VALIDATION_CHECK(op, num_spatial == -1 || num_spatial == size,
"Strides should be defined for all and only spatial dimensions.");
num_spatial = static_cast<int64_t>(size);
}
if (const auto &size = op->m_pads_begin.size()) {
NODE_VALIDATION_CHECK(op, num_spatial == -1 || num_spatial == size,
"Pads begin should be defined for all and only spatial dimensions.");
num_spatial = static_cast<int64_t>(size);
}
if (const auto &size = op->m_pads_end.size()) {
NODE_VALIDATION_CHECK(op, num_spatial == -1 || num_spatial == size,
"Pads begin should be defined for all and only spatial dimensions.");
num_spatial = static_cast<int64_t>(size);
}
}
return num_spatial;
}
template<class ConvType>
void update_and_validate_attributes(ConvType* op) {
const auto& num_spatial = op->m_num_spatial;
if (num_spatial != -1) {
auto& strides = op->m_strides;
auto& dilations = op->m_dilations;
auto& pad_begin = op->m_pads_begin;
auto& pad_end = op->m_pads_end;
auto& auto_pad = op->m_auto_pad;
if (strides.empty())
strides = Strides(num_spatial, 1);
if (dilations.empty())
dilations = Strides(num_spatial, 1);
if (pad_begin.empty() || auto_pad == op::PadType::VALID)
pad_begin = CoordinateDiff(num_spatial, 0);
if (pad_end.empty() || auto_pad == op::PadType::VALID)
pad_end = CoordinateDiff(num_spatial, 0);
NODE_VALIDATION_CHECK(op,
static_cast<int64_t>(strides.size()) == num_spatial,
"Strides should be defined for all and only spatial dimensions..");
NODE_VALIDATION_CHECK(op,
static_cast<int64_t>(dilations.size()) == num_spatial,
"Dilations should be defined for all and only spatial dimensions..");
NODE_VALIDATION_CHECK(op,
static_cast<int64_t>(pad_begin.size()) == num_spatial &&
static_cast<int64_t>(pad_end.size()) == num_spatial,
"Pads should be defined for all and only spatial dimensions..");
NODE_VALIDATION_CHECK(op,
std::all_of(dilations.begin(),
dilations.end(),
[](const size_t &i) {
return i > 0;
}),
"Filter dilation (",
dilations,
") has zero dimension.");
NODE_VALIDATION_CHECK(op,
std::all_of(strides.begin(),
strides.end(),
[](const size_t &i) {
return i > 0;
}),
"Filter strides (",
strides,
") has zero dimension.");
}
}
template <class T>
inline bool dynamic_check(const int64_t& num_spatial) {
@@ -28,14 +113,14 @@ inline bool dynamic_check<PartialShape>(const int64_t& num_spatial) {
return num_spatial != -1;
}
template<class T>
bool resolve_auto_pad_for_shape(const Convolution* op,
template<class ConvType, class ShapeType>
bool resolve_auto_pad_for_shape(const ConvType* op,
CoordinateDiff& pads_begin,
CoordinateDiff& pads_end,
const std::vector<T> &input_shapes,
const std::vector<ShapeType> &input_shapes,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims) {
const auto& auto_pad = op->get_auto_pad();
const auto& auto_pad = op->m_auto_pad;
if (auto_pad != op::PadType::SAME_UPPER && auto_pad != op::PadType::SAME_LOWER) {
pads_begin = op->m_pads_begin;
pads_end = op->m_pads_end;
@@ -43,7 +128,7 @@ bool resolve_auto_pad_for_shape(const Convolution* op,
}
auto& num_spatial = op->m_num_spatial;
if (!dynamic_check<T>(num_spatial))
if (!dynamic_check<ShapeType>(num_spatial))
return false;
auto input_shape = input_shapes[0];
@@ -61,8 +146,8 @@ bool resolve_auto_pad_for_shape(const Convolution* op,
bool status = true;
for (int64_t i = 0; i < num_spatial; ++i) {
const auto& input_dim = input_shape[i + 2];
const auto& filters_dim = filters_shape[i + 2];
const auto& input_dim = input_shape[i + num_non_spatial_data_dims];
const auto& filters_dim = filters_shape[i + num_non_spatial_filter_dims];
if (input_dim.is_static() && filters_dim.is_static()) {
const int64_t& window_dilated_dim = (filters_dim.get_length() - 1) * dilations[i] + 1;
NODE_VALIDATION_CHECK(op,
@@ -93,6 +178,49 @@ bool resolve_auto_pad_for_shape(const Convolution* op,
}
template<class ConvType, class ShapeType>
void calculate_output_spatial_dims_for_convolution(
const ConvType* op,
const ShapeType& input_shape,
const ShapeType& filters_shape,
ShapeType& output_shape,
const int64_t& num_spatial,
const Strides& strides,
const Strides& dilations,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims) {
for (int64_t i = 0; i < num_spatial; ++i) {
const auto& input_dim = input_shape[i + num_non_spatial_data_dims];
const auto& filters_dim = filters_shape[i + num_non_spatial_filter_dims];
if (input_dim.is_static() && filters_dim.is_static()) {
const int64_t& window_dilated_dim = (filters_dim.get_length() - 1) * dilations[i] + 1;
NODE_VALIDATION_CHECK(op,
window_dilated_dim > 0,
"Window after dilation has dimension less than 1 (dim: ",
window_dilated_dim,
") at axis ",
i,
".");
const int64_t& data_padded_dilated_dim = input_dim.get_length() + pads_begin[i] + pads_end[i];
NODE_VALIDATION_CHECK(op,
window_dilated_dim <= data_padded_dilated_dim,
"Window after dilation has dimension (dim: ",
window_dilated_dim,
") larger than the data shape after padding (dim: ",
data_padded_dilated_dim,
") at axis ",
i,
".");
output_shape[i + num_non_spatial_data_dims] = (data_padded_dilated_dim - window_dilated_dim) / strides[i] + 1;
}
}
}
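// For intuition, a hypothetical single-axis trace of the loop above
// (illustrative values, not taken from this diff: input_dim = 5, filter_dim = 3,
// dilations[i] = 1, strides[i] = 1, pads_begin[i] = pads_end[i] = 1):
//   window_dilated_dim      = (3 - 1) * 1 + 1 = 3
//   data_padded_dilated_dim = 5 + 1 + 1       = 7
//   output_shape[i + 2]     = (7 - 3) / 1 + 1 = 5
// i.e. a 3x3 kernel with unit stride and symmetric padding of 1 preserves the spatial size.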
template<class T>
void shape_infer(const Convolution* op,
const CoordinateDiff& pads_begin,
@@ -136,33 +264,349 @@ void shape_infer(const Convolution* op,
filters_shape[1],
").");
const auto& dilations = op->m_dilations;
const auto& strides = op->m_strides;
calculate_output_spatial_dims_for_convolution(
op, input_shape, filters_shape, output_shape,
num_spatial, op->m_strides, op->m_dilations, pads_begin, pads_end, 2, 2
);
}
for (int64_t i = 0; i < num_spatial; ++i) {
const auto& input_dim = input_shape[i + 2];
const auto& filters_dim = filters_shape[i + 2];
if (input_dim.is_static() && filters_dim.is_static()) {
const int64_t& window_dilated_dim = (filters_dim.get_length() - 1) * dilations[i] + 1;
NODE_VALIDATION_CHECK(op,
window_dilated_dim > 0,
"Window after dilation has dimension less than 1 (dim: ",
window_dilated_dim,
") at axis ",
i,
".");
template<class T>
void shape_infer(const GroupConvolution* op,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const std::vector<T> &input_shapes,
std::vector<T> &output_shapes) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 && output_shapes.size() == 1);
auto input_shape = input_shapes[0], filters_shape = input_shapes[1];
const int64_t& data_padded_dilated_dim = input_dim.get_length() + pads_begin[i] + pads_end[i];
NODE_VALIDATION_CHECK(op,
window_dilated_dim <= data_padded_dilated_dim,
"Window after dilation has dimension (dim: ",
window_dilated_dim,
") larger than the data shape after padding (dim: ",
data_padded_dilated_dim,
") at axis ",
i,
".");
output_shape[i + 2] = (data_padded_dilated_dim - window_dilated_dim) / strides[i] + 1;
const auto& num_spatial = op->m_num_spatial;
NODE_VALIDATION_CHECK(op, num_spatial != -1,
"GroupConvolution shape_infer should be provided with correct num_spatial attribute");
if (input_shape.rank().is_dynamic())
input_shape.resize(num_spatial + 2);
if (filters_shape.rank().is_dynamic())
filters_shape.resize(num_spatial + 3);
NODE_VALIDATION_CHECK(op,
(static_cast<int64_t>(input_shape.size()) == (num_spatial + 2)) &&
(static_cast<int64_t>(filters_shape.size()) == (num_spatial + 3)),
"Data batch and filters rank do not match (data batch shape: ",
input_shape,
", filters shape: ",
filters_shape,
").");
// ranks are originally static or aligned with num_spatial, attributes assumed to be valid
auto& output_shape = output_shapes[0];
output_shape.resize(num_spatial + 2);
output_shape[0] = input_shape[0];
auto groups = filters_shape[0];
if (groups.is_dynamic()) {
// [N, GROUPS * C_IN, ...] x [GROUPS, C_OUT, C_IN, ...] = [N, GROUPS * C_OUT, ...]
if (input_shape[1].is_static() && filters_shape[2].is_static()) {
using DimensionType = typename std::iterator_traits<typename T::iterator>::value_type;
auto n_data_channels = input_shape[1].get_length();
auto input_channels = filters_shape[2].get_length();
NODE_VALIDATION_CHECK(op, (n_data_channels % input_channels) == 0);
groups = DimensionType(n_data_channels / input_channels);
}
}
if (input_shape[1].is_static()) {
// GROUPS and C_IN consistency checks
if (groups.is_static() && filters_shape[2].is_static()) {
NODE_VALIDATION_CHECK(
op,
input_shape[1].get_length() / groups.get_length() == filters_shape[2].get_length(),
"Input channels dimension of data batch has incompatible value "
"with filter shape.");
} else if (groups.is_static()) {
NODE_VALIDATION_CHECK(
op,
input_shape[1].get_length() % groups.get_length() == 0,
"Input channels dimension of data batch not a multiple of group size.");
}
}
output_shape[1] = groups * filters_shape[1];
calculate_output_spatial_dims_for_convolution(
op, input_shape, filters_shape, output_shape,
num_spatial, op->m_strides, op->m_dilations, pads_begin, pads_end, 2, 3
);
}
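// The channel arithmetic above matches GroupConvolutionTest earlier in this diff:
// data = {1, 4, 5, 5}, filters = {GROUPS = 2, C_OUT = 1, C_IN = 2, 3, 3}, so
// output_shape[1] = groups * filters_shape[1] = 2 * 1 = 2, and SAME_LOWER
// auto-padding with stride 1 preserves the 5x5 spatial size, hence the expected {1, 2, 5, 5}.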
template<class ConvType>
int64_t calculate_num_spatial(const ConvType* op,
const PartialShape& input_shape,
const PartialShape& filters_shape,
const PartialShape& output_shapes_shape,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims) {
auto num_spatial = op->m_num_spatial;
if (num_spatial == -1) {
num_spatial = calculate_num_spatial(
op, input_shape, filters_shape, num_non_spatial_data_dims, num_non_spatial_filter_dims);
if (const auto &size = op->m_output_padding.size()) {
NODE_VALIDATION_CHECK(op, num_spatial == -1 || num_spatial == size,
"Output padding should be defined for all and only spatial dimensions.");
num_spatial = static_cast<int64_t>(size);
}
if (output_shapes_shape.is_static()) {
NODE_VALIDATION_CHECK(op, output_shapes_shape.size() == 1, "Input delivering output shape must have rank 1");
NODE_VALIDATION_CHECK(op, num_spatial == -1 || num_spatial == output_shapes_shape[0].get_length(),
"Output shape should be specified only and for all spatial dimensions.");
num_spatial = static_cast<int64_t>(output_shapes_shape[0].get_length());
}
}
return num_spatial;
}
template<class ConvType>
void update_and_validate_attributes_back_prop(ConvType* op) {
const auto& num_spatial = op->m_num_spatial;
if (num_spatial != -1) {
update_and_validate_attributes(op);
auto& output_padding = op->m_output_padding;
if (output_padding.empty())
output_padding = CoordinateDiff(num_spatial, 0);
NODE_VALIDATION_CHECK(op,
static_cast<int64_t>(output_padding.size()) == num_spatial,
"Output padding should be defined for all and only "
"spatial dimensions..");
}
}
template<class ConvType, class ShapeType>
bool resolve_auto_pad_for_shape_back_prop(const ConvType* op,
CoordinateDiff& pads_begin,
CoordinateDiff& pads_end,
const std::vector<ShapeType> &input_shapes,
ShapeType& output_spatial_shape,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims) {
const auto& auto_pad = op->m_auto_pad;
if (auto_pad != PadType::SAME_UPPER && auto_pad != PadType::SAME_LOWER) {
pads_begin = op->m_pads_begin;
pads_end = op->m_pads_end;
return true;
}
const auto& num_spatial = op->m_num_spatial;
if (!dynamic_check<ShapeType>(num_spatial))
return false;
if (input_shapes.size() != 3) {
pads_begin = CoordinateDiff(num_spatial, 0);
pads_end = CoordinateDiff(num_spatial, 0);
return true;
}
OPENVINO_ASSERT(input_shapes.size() == 3 && (auto_pad == PadType::SAME_UPPER || auto_pad == PadType::SAME_LOWER));
pads_begin = CoordinateDiff(num_spatial, 0);
pads_end = CoordinateDiff(num_spatial, 0);
if (output_spatial_shape.rank().is_dynamic())
output_spatial_shape.resize(num_spatial);
auto input_shape = input_shapes[0];
auto filters_shape = input_shapes[1];
if (input_shape.rank().is_dynamic())
input_shape.resize(num_spatial + num_non_spatial_data_dims);
if (filters_shape.rank().is_dynamic())
filters_shape.resize(num_spatial + num_non_spatial_filter_dims);
bool status = true;
for (auto i = 0; i < num_spatial; ++i) {
const auto& data_dim = input_shape[i + num_non_spatial_data_dims];
const auto& filter_dim = filters_shape[i + num_non_spatial_filter_dims];
const auto& output_dim = output_spatial_shape[i];
const auto& output_padding = op->m_output_padding[i];
if (data_dim.is_static() && filter_dim.is_static() && output_dim.is_static()) {
const auto& strides = op->m_strides[i];
const auto& dilations = op->m_dilations[i];
int total_padding = std::max<int>(
strides * (data_dim.get_length() - 1) + dilations * (filter_dim.get_length() - 1) + 1 - output_dim.get_length() + output_padding, 0);
if (auto_pad != op::PadType::SAME_UPPER) {
pads_begin[i] = total_padding / 2;
pads_end[i] = total_padding - pads_begin[i];
} else {
pads_end[i] = total_padding / 2;
pads_begin[i] = total_padding - pads_end[i];
}
} else {
status = false;
}
}
return status;
}
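// A worked trace of the SAME_LOWER branch above, using the per-axis values of
// ConvolutionBackPropDataTest earlier in this diff (data_dim = 2, filter_dim = 3,
// output_dim = 3, stride = 2, dilation = 1, output_padding = 1):
//   total_padding = max(2 * (2 - 1) + 1 * (3 - 1) + 1 - 3 + 1, 0) = 3
//   pads_begin[i] = 3 / 2 = 1
//   pads_end[i]   = 3 - 1 = 2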
template<class ConvType, class ShapeType>
void calculate_output_spatial_dims_for_convolution_back_prop_data(
const ConvType* op,
const ShapeType& input_shape,
const ShapeType& filters_shape,
const ShapeType& output_rank,
ShapeType& output_shape,
const int64_t& num_spatial,
const Strides& strides,
const Strides& dilations,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const CoordinateDiff& output_padding,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims) {
}
template <class T>
void shape_infer(const ConvolutionBackpropData* op,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const T& output_shape_from_input,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes) {
size_t input_size = input_shapes.size();
NODE_VALIDATION_CHECK(op, (input_size == 2 || input_size == 3) && output_shapes.size() == 1);
auto input_shape = input_shapes[0], filters_shape = input_shapes[1];
const auto& num_spatial = op->m_num_spatial;
NODE_VALIDATION_CHECK(op, num_spatial != -1,
"ConvolutionBackpropData shape_infer should be provided with correct num_spatial attribute");
NODE_VALIDATION_CHECK(op, num_spatial == 1 || num_spatial == 2 || num_spatial == 3,
"Data and filters inputs must have rank 3, 4 or 5");
if (input_shape.rank().is_dynamic())
input_shape.resize(num_spatial + 2);
if (filters_shape.rank().is_dynamic())
filters_shape.resize(num_spatial + 2);
NODE_VALIDATION_CHECK(op,
(static_cast<int64_t>(input_shape.size()) == (num_spatial + 2)) &&
(static_cast<int64_t>(filters_shape.size()) == (num_spatial + 2)),
"Data and filters rank do not match (data batch shape: ",
input_shape,
", filters shape: ",
filters_shape,
").");
// ranks are originally static or aligned with num_spatial, attributes assumed to be valid
auto& output_shape = output_shapes[0];
output_shape.resize(num_spatial + 2);
output_shape[0] = input_shape[0];
output_shape[1] = filters_shape[1];
NODE_VALIDATION_CHECK(op, input_shape[1].compatible(filters_shape[0]),
"Input channels dimension of data and filters inputs must be equal");
if (input_size == 3) {
if (output_shape_from_input.rank().is_static()) {
NODE_VALIDATION_CHECK(op, output_shape_from_input.size() == num_spatial,
"Output shape should be specified only and for all spatial dimensions.");
for (size_t i = 0; i < num_spatial; ++i)
output_shape[i + 2] = output_shape_from_input[i];
}
} else {
const auto& strides = op->m_strides;
const auto& dilations = op->m_dilations;
const auto& output_padding = op->m_output_padding;
for (size_t i = 0; i < num_spatial; ++i) {
if (filters_shape[i + 2].is_static() && input_shape[i + 2].is_static())
output_shape[i + 2] = strides[i] * (input_shape[i + 2].get_length() - 1) +
dilations[i] * (filters_shape[i + 2].get_length() - 1) + 1 - pads_begin[i] - pads_end[i] +
output_padding[i];
}
}
}
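// Without the third input, the else branch above applies the transposed-convolution
// size formula; a hypothetical single-axis trace (illustrative values, not from this
// diff: input_dim = 3, filter_dim = 3, stride = 2, dilation = 1,
// pads_begin = pads_end = 1, output_padding = 1):
//   output_dim = 2 * (3 - 1) + 1 * (3 - 1) + 1 - 1 - 1 + 1 = 6
// i.e. stride 2 roughly doubles the spatial size, as expected for a deconvolution.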
template <class T>
void shape_infer(const GroupConvolutionBackpropData* op,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const T& output_shape_from_input,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes) {
size_t input_size = input_shapes.size();
NODE_VALIDATION_CHECK(op, (input_size == 2 || input_size == 3) && output_shapes.size() == 1);
auto input_shape = input_shapes[0], filters_shape = input_shapes[1];
const auto& num_spatial = op->m_num_spatial;
NODE_VALIDATION_CHECK(op, num_spatial != -1,
"GroupConvolutionBackpropData shape_infer should be provided with correct num_spatial attribute");
NODE_VALIDATION_CHECK(op, num_spatial == 1 || num_spatial == 2 || num_spatial == 3,
"Data and filters inputs must have rank 3, 4 or 5");
if (input_shape.rank().is_dynamic())
input_shape.resize(num_spatial + 2);
if (filters_shape.rank().is_dynamic())
filters_shape.resize(num_spatial + 3);
NODE_VALIDATION_CHECK(op,
(static_cast<int64_t>(input_shape.size()) == (num_spatial + 2)) &&
(static_cast<int64_t>(filters_shape.size()) == (num_spatial + 3)),
"Data and filters rank do not match (data batch shape: ",
input_shape,
", filters shape: ",
filters_shape,
").");
// ranks are originally static or aligned with num_spatial, attributes assumed to be valid
auto& output_shape = output_shapes[0];
output_shape.resize(num_spatial + 2);
output_shape[0] = input_shape[0];
auto groups = filters_shape[0];
if (groups.is_dynamic()) {
// [N, GROUPS * C_IN, ...] x [GROUPS, C_IN, C_OUT, ...] = [N, GROUPS * C_OUT, ...]
if (input_shape[1].is_static() && filters_shape[1].is_static()) {
using DimensionType = typename std::iterator_traits<typename T::iterator>::value_type;
auto n_data_channels = input_shape[1].get_length();
auto input_channels = filters_shape[1].get_length();
NODE_VALIDATION_CHECK(op, (n_data_channels % input_channels) == 0);
groups = DimensionType(n_data_channels / input_channels);
}
}
if (input_shape[1].is_static()) {
// GROUPS and C_IN consistency checks
if (groups.is_static() && filters_shape[1].is_static()) {
NODE_VALIDATION_CHECK(
op,
input_shape[1].get_length() / groups.get_length() == filters_shape[1].get_length(),
"Input channels dimension of data batch has incompatible value "
"with filter shape.");
} else if (groups.is_static()) {
NODE_VALIDATION_CHECK(
op,
input_shape[1].get_length() % groups.get_length() == 0,
"Input channels dimension of data batch not a multiple of group size.");
}
}
output_shape[1] = filters_shape[2] * groups;
if (input_size == 3) {
if (output_shape_from_input.rank().is_static()) {
NODE_VALIDATION_CHECK(op, output_shape_from_input.size() == num_spatial,
"Output shape should be specified only and for all spatial dimensions.");
for (size_t i = 0; i < num_spatial; ++i)
output_shape[i + 2] = output_shape_from_input[i];
}
} else {
const auto& strides = op->m_strides;
const auto& dilations = op->m_dilations;
const auto& output_padding = op->m_output_padding;
for (size_t i = 0; i < num_spatial; ++i) {
if (filters_shape[i + 3].is_static() && input_shape[i + 2].is_static())
output_shape[i + 2] = strides[i] * (input_shape[i + 2].get_length() - 1) +
dilations[i] * (filters_shape[i + 3].get_length() - 1) + 1 - pads_begin[i] - pads_end[i] +
output_padding[i];
}
}
}
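// Note the filters layout here is [GROUPS, C_IN, C_OUT, ...], transposed relative to
// the forward GroupConvolution above. The channel arithmetic matches
// GroupConvolutionBackPropDataTest earlier in this diff: data = {1, 16, 2, 2},
// filters = {GROUPS = 4, C_IN = 4, C_OUT = 6, 3, 3}, so
// output_shape[1] = filters_shape[2] * groups = 6 * 4 = 24, hence {1, 24, 3, 3}.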

@@ -0,0 +1,41 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/op/fake_quantize.hpp>
#include "utils.hpp"
template <class T>
void shape_infer(const ov::op::v0::FakeQuantize* op,
const std::vector<T> &input_shapes,
std::vector<T> &output_shapes) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 5 && output_shapes.size() == 1);
T data_pshape = input_shapes[0];
ov::op::AutoBroadcastSpec auto_broadcast = op->get_auto_broadcast();
for (size_t i = 1; i <= 4; ++i) {
if (auto_broadcast.m_type == ov::op::AutoBroadcastType::NONE) {
NODE_VALIDATION_CHECK(op,
T::merge_into(data_pshape, input_shapes[i]),
"Argument shapes are inconsistent.");
} else if (auto_broadcast.m_type == ov::op::AutoBroadcastType::NUMPY ||
auto_broadcast.m_type == ov::op::AutoBroadcastType::PDPD) {
NODE_VALIDATION_CHECK(
op,
T::broadcast_merge_into(data_pshape, input_shapes[i], auto_broadcast),
"Argument shapes are inconsistent.");
} else {
NODE_VALIDATION_CHECK(op, false, "Unsupported auto broadcast specification");
}
}
// NOTE: kept as a first-shape passthrough because, per spec, broadcasting is uni-directional,
// meaning that the limit inputs do not affect the output shape.
// BUT: inference will not fail in a case such as
// input[0].shape = [1, 3, 1, 1]
// input[1].shape = [1, 3, 4, 5]
// This controversial behavior is kept for backward compatibility, and since frameworks do not
// allow such configurations either, the chance of encountering such an FQ is minimal
first_input_passthrough_infer(op, input_shapes, output_shapes);
}
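A minimal sketch of the case the NOTE above describes, in the style of the static-shape tests from this commit (it assumes an `fq` node built as in FakeQuantizeTest; the shapes are hypothetical):
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 3, 1, 1},   // data
                                                StaticShape{1, 3, 4, 5},   // input_low: broadcasts over data
                                                StaticShape{1},
                                                StaticShape{1},
                                                StaticShape{1}},
                         static_output_shapes = {StaticShape{}};
shape_inference(fq.get(), static_input_shapes, static_output_shapes);
// Succeeds, and static_output_shapes[0] == StaticShape({1, 3, 1, 1}) -- the limit
// inputs never leak into the output shape.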

@@ -6,6 +6,39 @@
#include <openvino/opsets/opset1.hpp>
#include <openvino/core/validation_util.hpp>
template <class OpType, class T>
void copy_shape_infer(const OpType* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 1 && output_shapes.size() == 1,
"Incorrect number of input/output shapes");
output_shapes[0] = input_shapes[0];
}
template <class OpType, class T>
void first_input_passthrough_infer(const OpType* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
NODE_VALIDATION_CHECK(op, output_shapes.size() == 1, "Incorrect number of output shapes");
output_shapes[0] = input_shapes[0];
}
template <class OpType, class T>
void eltwise_shape_infer(const OpType* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 && output_shapes.size() == 1,
"Incorrect number of input/output shapes");
T output_shape = input_shapes[0];
ov::op::AutoBroadcastSpec autob = op->get_autob();
if (autob.m_type == ov::op::AutoBroadcastType::NONE) {
NODE_VALIDATION_CHECK(op, T::merge_into(output_shape, input_shapes[1]),
"Argument shapes are inconsistent.");
} else if (autob.m_type == ov::op::AutoBroadcastType::NUMPY || autob.m_type == ov::op::AutoBroadcastType::PDPD) {
NODE_VALIDATION_CHECK(op, T::broadcast_merge_into(output_shape, input_shapes[1], autob),
"Argument shapes are inconsistent.");
} else {
NODE_VALIDATION_CHECK(op, false, "Unsupported auto broadcast specification");
}
output_shapes[0] = output_shape;
}
template <class T>
inline bool get_data_as_int64(
size_t idx, const ov::Node* op, std::vector<int64_t>& axes_value,
@@ -33,3 +66,29 @@ inline bool get_data_as_int64<ov::PartialShape>(
}
return true;
}
template <class T>
inline bool get_data_as_shape(
size_t idx, const ov::Node* op, T& shape,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
if (constant_data.count(idx)) {
shape = T(ov::opset1::Constant(constant_data.at(idx)).cast_vector<size_t>());
} else {
const auto& constant = ov::as_type_ptr<ov::opset1::Constant>(op->get_input_node_shared_ptr(idx));
NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx);
shape = T(constant->cast_vector<size_t>());
}
return true;
}
template <>
inline bool get_data_as_shape<ov::PartialShape>(
size_t idx, const ov::Node* op, ov::PartialShape& shape,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data) {
if (constant_data.count(idx)) {
shape = ov::PartialShape(ov::opset1::Constant(constant_data.at(idx)).cast_vector<int64_t>());
return true;
} else {
return ov::evaluate_as_partial_shape(op->input_value(idx), shape);
}
}
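A sketch of how the shape_inference dispatcher feeds constant_data into get_data_as_shape for a ConvolutionBackpropData output-shape port. The HostTensor construction below is an assumption of the usual ngraph API, and `node` stands for a hypothetical op whose port 2 carries the target shape:
// Hypothetical call site: port 2 carries the target spatial shape {3, 3}.
std::vector<int64_t> target{3, 3};
auto tensor = std::make_shared<ngraph::runtime::HostTensor>(ov::element::i64, ov::Shape{2}, target.data());
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data{{2, tensor}};
ov::StaticShape output_shape_input;
get_data_as_shape<ov::StaticShape>(2, node.get(), output_shape_input, constant_data);
// output_shape_input == StaticShape{3, 3}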

@@ -1,85 +0,0 @@
#include <convolution_shape_inference.hpp>
namespace ov {
namespace op {
namespace v1 {
int64_t calculate_num_spatial(const Convolution* op,
const PartialShape& input_shape,
const PartialShape& filters_shape,
const int64_t& num_non_spatial_data_dims,
const int64_t& num_non_spatial_filter_dims) {
int64_t num_spatial = op->m_num_spatial;
if (num_spatial == -1) {
const auto &input_rank = input_shape.rank();
const auto &filters_rank = filters_shape.rank();
if (const auto &size = op->m_dilations.size())
num_spatial = static_cast<int64_t>(size);
if (const auto &size = op->m_strides.size())
num_spatial = static_cast<int64_t>(size);
if (const auto &size = op->m_pads_begin.size())
num_spatial = static_cast<int64_t>(size);
if (const auto &size = op->m_pads_end.size())
num_spatial = static_cast<int64_t>(size);
if (input_rank.is_static())
num_spatial = input_rank.get_length() - num_non_spatial_data_dims;
if (filters_rank.is_static())
num_spatial = filters_rank.get_length() - num_non_spatial_filter_dims;
}
return num_spatial;
}
void update_and_validate_attributes(Convolution* op) {
const auto& num_spatial = op->m_num_spatial;
if (num_spatial != -1) {
auto& strides = op->m_strides;
auto& dilations = op->m_dilations;
auto& pad_begin = op->m_pads_begin;
auto& pad_end = op->m_pads_end;
auto& auto_pad = op->m_auto_pad;
if (strides.empty())
strides = Strides(num_spatial, 1);
if (dilations.empty())
dilations = Strides(num_spatial, 1);
if (pad_begin.empty() || auto_pad == op::PadType::VALID)
pad_begin = CoordinateDiff(num_spatial, 0);
if (pad_end.empty() || auto_pad == op::PadType::VALID)
pad_end = CoordinateDiff(num_spatial, 0);
NODE_VALIDATION_CHECK(op,
static_cast<int64_t>(strides.size()) == num_spatial,
"Strides should be defined for all and only spatial features.");
NODE_VALIDATION_CHECK(op,
static_cast<int64_t>(dilations.size()) == num_spatial,
"Dilations should be defined for all and only spatial features.");
NODE_VALIDATION_CHECK(op,
static_cast<int64_t>(pad_begin.size()) == num_spatial &&
static_cast<int64_t>(pad_end.size()) == num_spatial,
"Pads should be defined for all and only spatial features.");
NODE_VALIDATION_CHECK(op,
std::all_of(dilations.begin(),
dilations.end(),
[](const size_t &i) {
return i > 0;
}),
"Filter dilation (",
dilations,
") has zero dimension.");
NODE_VALIDATION_CHECK(op,
std::all_of(strides.begin(),
strides.end(),
[](const size_t &i) {
return i > 0;
}),
"Filter strides (",
strides,
") has zero dimension.");
}
}
}
}
}

@@ -71,7 +71,7 @@ void op::v1::Convolution::validate_and_infer_types() {
update_and_validate_attributes(this);
std::vector<ov::PartialShape> input_shapes = {data_shape, filter_shape};
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic()};
if (m_num_spatial != -1) {
resolve_auto_pad_for_shape(this, m_pads_begin, m_pads_end, input_shapes, 2, 2);
@@ -160,24 +160,19 @@ bool op::v1::ConvolutionBackpropData::is_dynamic() const {
}
const ov::PartialShape op::v1::ConvolutionBackpropData::get_output_shape() const {
ov::PartialShape shape;
if (get_input_size() == 3 && evaluate_as_partial_shape(input_value(2), shape))
return shape;
auto data_pshape = get_input_partial_shape(0);
auto filter_pshape = get_input_partial_shape(1);
ov::PartialShape shape;
bool is_output_shape_present = inputs().size() == 3;
if (is_output_shape_present) {
if (const auto& const_op = get_constant_from_source(input_value(2))) {
return ov::PartialShape{const_op->get_shape_val()};
}
}
if (data_pshape.rank().is_static()) {
shape = ov::PartialShape{vector<Dimension>(data_pshape.rank().get_length() - 2)};
} else if (filter_pshape.rank().is_static()) {
shape = ov::PartialShape{vector<Dimension>(filter_pshape.rank().get_length() - 2)};
} else {
if (data_pshape.rank().is_static())
shape = ov::PartialShape::dynamic(data_pshape.rank().get_length() - 2);
else if (filter_pshape.rank().is_static())
shape = ov::PartialShape::dynamic(filter_pshape.rank().get_length() - 2);
else
shape = ov::PartialShape::dynamic();
}
return shape;
}
@@ -215,9 +210,7 @@ void op::v1::ConvolutionBackpropData::infer_conv_backprop_output_spatial_shape(
void op::v1::ConvolutionBackpropData::validate_and_infer_types() {
NGRAPH_OP_SCOPE(v1_ConvolutionBackpropData_validate_and_infer_types);
const ov::PartialShape& data_pshape = get_input_partial_shape(0);
element::Type delta_et = get_input_element_type(0);
const ov::PartialShape& filters_pshape = get_input_partial_shape(1);
element::Type filters_et = get_input_element_type(1);
element::Type result_et;
@@ -234,186 +227,40 @@ void op::v1::ConvolutionBackpropData::validate_and_infer_types() {
"Element type of inputs must be numeric. Got: ",
result_et);
Rank result_ps_rank;
NODE_VALIDATION_CHECK(this,
Rank::merge(result_ps_rank, data_pshape.rank(), filters_pshape.rank()),
"Data and filters inputs must have same rank. Got: ",
data_pshape,
" and ",
filters_pshape);
NODE_VALIDATION_CHECK(this,
result_ps_rank.compatible(3) || result_ps_rank.compatible(4) || result_ps_rank.compatible(5),
"Data and filters inputs must have rank 3, 4 or 5. Got: ",
result_ps_rank);
if (data_pshape.rank().is_static() && filters_pshape.rank().is_static()) {
NODE_VALIDATION_CHECK(this,
data_pshape[1].compatible(filters_pshape[0]),
"Input channels dimension of data and filters inputs must be equal. Got: ",
data_pshape,
" and ",
filters_pshape);
}
bool is_output_shape_present = inputs().size() == 3;
if (is_output_shape_present) {
const ov::PartialShape& output_shape_pshape = get_input_partial_shape(2);
const element::Type output_shape_et = get_input_element_type(2);
NODE_VALIDATION_CHECK(this,
output_shape_et.is_integral_number(),
"Element type for output shape should be of integer type ",
"(output_shape element type: ",
output_shape_et,
").");
NODE_VALIDATION_CHECK(this,
output_shape_pshape.rank().compatible(1),
"Spatial shape of output input must be of rank 1 ",
"(output_shape shape: ",
output_shape_pshape,
").");
}
ov::PartialShape output_spatial_pshape = get_output_shape();
if (result_ps_rank.is_static()) {
const auto num_spatial_dims = result_ps_rank.get_length() - 2;
if (m_strides.size() == 0) {
m_strides = Strides(num_spatial_dims, 1);
}
if (m_dilations.size() == 0) {
m_dilations = Strides(num_spatial_dims, 1);
}
if (m_pads_begin.size() == 0 || m_auto_pad == PadType::VALID) {
m_pads_begin = CoordinateDiff(num_spatial_dims, 0);
}
if (m_pads_end.size() == 0 || m_auto_pad == PadType::VALID) {
m_pads_end = CoordinateDiff(num_spatial_dims, 0);
}
if (m_output_padding.size() == 0) {
m_output_padding = CoordinateDiff(num_spatial_dims, 0);
}
NODE_VALIDATION_CHECK(this,
static_cast<int64_t>(m_strides.size()) == num_spatial_dims,
"Strides should be defined for all and only spatial features.");
NODE_VALIDATION_CHECK(this,
static_cast<int64_t>(m_dilations.size()) == num_spatial_dims,
"Dilations should be defined for all and only spatial features.");
NODE_VALIDATION_CHECK(this,
static_cast<int64_t>(m_pads_begin.size()) == num_spatial_dims &&
static_cast<int64_t>(m_pads_end.size()) == num_spatial_dims,
"Pads should be defined for all and only spatial features.");
NODE_VALIDATION_CHECK(this,
static_cast<int64_t>(m_output_padding.size()) == num_spatial_dims,
"Output padding should be defined for all and only "
"spatial features.");
if (is_output_shape_present && output_spatial_pshape.is_static()) {
ov::Shape output_shape = output_spatial_pshape.to_shape();
NODE_VALIDATION_CHECK(this,
static_cast<int64_t>(output_shape.size()) == num_spatial_dims,
"Output shape should be specified only and for "
"all spatial dimensions.");
}
}
ov::PartialShape result_pshape{ov::PartialShape::dynamic()};
// If output shape is provided, ignore current values for padding begin/end
// and infer them.
if (is_output_shape_present) {
if (output_spatial_pshape.rank().is_static()) {
if (data_pshape.rank().is_static() && filters_pshape.rank().is_static()) {
const ov::PartialShape data_spatial_shape = [data_pshape]() {
vector<Dimension> data_dims{data_pshape};
data_dims.erase(data_dims.begin(), data_dims.begin() + 2); // remove {N, C_IN}
return ov::PartialShape{data_dims};
}();
bool output_shape_input_present = get_input_size() == 3;
const ov::PartialShape filters_spatial_shape = [filters_pshape]() {
vector<Dimension> filters_dims{filters_pshape};
filters_dims.erase(filters_dims.begin(),
filters_dims.begin() + 2); // remove {C_IN, C_OUT}
return ov::PartialShape{filters_dims};
}();
const auto& data_shape = get_input_partial_shape(0);
const auto& filter_shape = get_input_partial_shape(1);
// If auto_pad has one of following mode we infer paddings. Otherwise in
// EXPLICIT auto_pad mode we use what is provided.
if ((m_auto_pad == PadType::SAME_UPPER || m_auto_pad == PadType::SAME_LOWER) &&
(data_spatial_shape.is_static() && filters_spatial_shape.is_static() &&
output_spatial_pshape.is_static())) {
opset1::infer_conv_backprop_auto_padding(data_spatial_shape.to_shape(),
filters_spatial_shape.to_shape(),
output_spatial_pshape.to_shape(),
m_strides,
m_dilations,
m_auto_pad,
m_output_padding,
m_pads_begin,
m_pads_end);
}
}
vector<Dimension> output_pshape{output_spatial_pshape};
// C_OUT
auto n_out_channels = filters_pshape.rank().is_static() ? filters_pshape[1] : Dimension::dynamic();
output_pshape.insert(output_pshape.begin(), n_out_channels);
// N
auto batches = data_pshape.rank().is_static() ? data_pshape[0] : Dimension::dynamic();
output_pshape.insert(output_pshape.begin(), batches);
result_pshape = ov::PartialShape{output_pshape};
}
set_input_is_relevant_to_shape(2);
auto& output_shapes_shape = output_shape_input_present ? get_input_partial_shape(2) : PartialShape::dynamic();
m_num_spatial = calculate_num_spatial(this, data_shape, filter_shape, output_shapes_shape, 2, 2);
update_and_validate_attributes_back_prop(this);
std::vector<ov::PartialShape> input_shapes = {data_shape, filter_shape};
if (output_shape_input_present)
input_shapes.push_back(get_input_partial_shape(2));
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic()};
if (m_num_spatial != -1) {
ov::PartialShape output_spatial_shape = get_output_shape();
resolve_auto_pad_for_shape_back_prop(this, m_pads_begin, m_pads_end, input_shapes, output_spatial_shape, 2, 2);
shape_infer(this, m_pads_begin, m_pads_end, output_spatial_shape, input_shapes, output_shapes);
}
// Deduce output shape from input spatial shape, strides, dilations, output padding
// and padding values.
else {
if (m_auto_pad == PadType::SAME_UPPER || m_auto_pad == PadType::SAME_LOWER || m_auto_pad == PadType::VALID) {
m_pads_begin.assign(m_pads_begin.size(), 0);
m_pads_end.assign(m_pads_end.size(), 0);
}
set_output_type(0, result_et, output_shapes[0]);
vector<Dimension> output_pshape;
if (data_pshape.rank().is_static() && filters_pshape.rank().is_static()) {
auto data_spatial_shape = [data_pshape]() {
vector<Dimension> data_dims{data_pshape};
return vector<Dimension>{std::next(data_dims.begin(), 2), std::end(data_dims)}; // remove {N, C_IN}
}();
auto filters_spatial_shape = [filters_pshape]() {
vector<Dimension> filters_dims{filters_pshape};
return vector<Dimension>{std::next(filters_dims.begin(), 2), // remove {C_IN, C_OUT}
std::end(filters_dims)};
}();
infer_conv_backprop_output_spatial_shape(data_spatial_shape,
filters_spatial_shape,
m_strides,
m_dilations,
m_pads_begin,
m_pads_end,
m_output_padding,
output_pshape);
} else {
output_pshape = vector<Dimension>{output_spatial_pshape};
}
if (output_pshape.size()) {
// C_OUT
auto n_out_channels = filters_pshape.rank().is_static() ? filters_pshape[1] : Dimension::dynamic();
output_pshape.insert(output_pshape.begin(), n_out_channels);
// N
auto batches = data_pshape.rank().is_static() ? data_pshape[0] : Dimension::dynamic();
output_pshape.insert(output_pshape.begin(), batches);
result_pshape = ov::PartialShape{output_pshape};
}
}
set_input_is_relevant_to_shape(0);
set_input_is_relevant_to_shape(1);
set_output_type(0, result_et, result_pshape);
}
shared_ptr<Node> op::v1::ConvolutionBackpropData::clone_with_new_inputs(const OutputVector& new_args) const {

@@ -4,15 +4,10 @@
#include "ngraph/op/group_conv.hpp"
#include <numeric>
#include <convolution_shape_inference.hpp>
#include "itt.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/validation_util.hpp"
#include "openvino/op/util/precision_sensitive_attribute.hpp"
@@ -72,8 +67,6 @@ static Dimension infer_group_from_input_shapes(const ov::PartialShape& data_psha
void op::v1::GroupConvolution::validate_and_infer_types() {
NGRAPH_OP_SCOPE(v1_GroupConvolution_validate_and_infer_types);
ov::PartialShape data_batch_pshape = get_input_partial_shape(0);
ov::PartialShape filters_pshape = get_input_partial_shape(1);
element::Type data_batch_et = get_input_element_type(0);
element::Type filters_et = get_input_element_type(1);
@@ -91,148 +84,21 @@ void op::v1::GroupConvolution::validate_and_infer_types() {
"Element type of inputs must be numeric. Got: ",
result_et);
NODE_VALIDATION_CHECK(this,
(data_batch_pshape.rank().compatible(5) && filters_pshape.rank().compatible(6)) ||
(data_batch_pshape.rank().compatible(4) && filters_pshape.rank().compatible(5)) ||
(data_batch_pshape.rank().compatible(3) && filters_pshape.rank().compatible(4)),
"Shapes for data batch and filters do not match. (data batch shape: ",
data_batch_pshape,
", filters shape: ",
filters_pshape,
").");
auto& data_shape = get_input_partial_shape(0);
auto& filter_shape = get_input_partial_shape(1);
ov::PartialShape result_shape{ov::PartialShape::dynamic()};
if (data_batch_pshape.rank().is_static() || filters_pshape.rank().is_static()) {
const bool is_data_batch_ps_static = data_batch_pshape.rank().is_static();
const auto output_ps_rank =
is_data_batch_ps_static ? data_batch_pshape.rank().get_length() : filters_pshape.rank().get_length() - 1;
const size_t num_spatial_dims = output_ps_rank - 2;
m_num_spatial = calculate_num_spatial(this, data_shape, filter_shape, 2, 3);
update_and_validate_attributes(this);
if (m_strides.size() == 0) {
m_strides = Strides(num_spatial_dims, 1);
}
std::vector<ov::PartialShape> input_shapes = {data_shape, filter_shape};
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic()};
if (m_dilations.size() == 0) {
m_dilations = Strides(num_spatial_dims, 1);
}
if (m_pads_begin.size() == 0 || m_auto_pad == PadType::VALID) {
m_pads_begin = CoordinateDiff(num_spatial_dims, 0);
}
if (m_pads_end.size() == 0 || m_auto_pad == PadType::VALID) {
m_pads_end = CoordinateDiff(num_spatial_dims, 0);
}
NODE_VALIDATION_CHECK(this,
m_strides.size() == num_spatial_dims,
"Strides should be defined for all and only spatial features.");
NODE_VALIDATION_CHECK(this,
m_dilations.size() == num_spatial_dims,
"Dilations should be defined for all and only spatial features.");
NODE_VALIDATION_CHECK(this,
m_pads_begin.size() == num_spatial_dims && m_pads_end.size() == num_spatial_dims,
"Pads should be defined for all and only spatial features.");
if (data_batch_pshape.rank().is_static() && filters_pshape.rank().is_static()) {
auto data_in_channels_dim = data_batch_pshape[1];
if (data_in_channels_dim.is_static()) {
auto groups_dim = filters_pshape[0];
if (groups_dim.is_static() && filters_pshape[2].is_static()) {
NODE_VALIDATION_CHECK(
this,
data_in_channels_dim.get_length() / groups_dim.get_length() == filters_pshape[2].get_length(),
"Input channels dimension of data batch has incompatible value "
"with filter shape.");
} else if (groups_dim.is_static()) {
NODE_VALIDATION_CHECK(this,
data_in_channels_dim.get_length() % groups_dim.get_length() == 0,
"Input channels dimension of data batch not a multiple of group size.");
}
}
}
result_shape = std::vector<Dimension>(output_ps_rank, Dimension::dynamic());
if (data_batch_pshape.rank().is_static()) {
result_shape[0] = data_batch_pshape[0]; // batch size
}
if (filters_pshape.rank().is_static() && filters_pshape.rank().get_length() > 2) {
result_shape[1] = filters_pshape[0] * filters_pshape[1];
}
if (m_auto_pad == PadType::SAME_UPPER || m_auto_pad == PadType::SAME_LOWER) {
bool auto_padding_applied = false;
if (filters_pshape.rank().is_static() && filters_pshape.rank().get_length() > 2) {
m_pads_begin.clear();
m_pads_end.clear();
const ov::PartialShape filter_spatial_shape = [filters_pshape]() {
vector<Dimension> filter_dims{filters_pshape};
filter_dims.erase(filter_dims.begin(),
filter_dims.begin() + 3); // Remove {GROUP, C_OUT, C_IN}
return ov::PartialShape{filter_dims};
}();
if (filter_spatial_shape.is_static()) {
auto_padding_applied = try_apply_auto_padding(data_batch_pshape,
filter_spatial_shape.to_shape(),
m_strides,
m_dilations,
m_auto_pad,
m_pads_end,
m_pads_begin);
}
}
if (!auto_padding_applied) {
set_output_type(0, result_et, result_shape);
return;
}
}
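
// [Editor's sketch, hedged] The SAME_UPPER arithmetic that
// try_apply_auto_padding performs per spatial axis, reproduced here from the
// usual SAME-padding definition; names and values are illustrative only.
#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
    const int64_t in = 5, kernel = 3, stride = 1, dilation = 1;
    const int64_t out = (in + stride - 1) / stride;  // ceil(in / stride)
    const int64_t total =
        std::max<int64_t>(0, (out - 1) * stride + (kernel - 1) * dilation + 1 - in);
    const int64_t pads_begin = total / 2;        // SAME_UPPER: extra pad at end
    const int64_t pads_end = total - pads_begin;
    assert(pads_begin == 1 && pads_end == 1);    // matches CoordinateDiff{1, 1}
}
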
// we need to adjust the input channels dimension to reuse helpers for regular convolution
ov::PartialShape data_batch_ps = [&]() {
auto shape = ov::PartialShape{data_batch_pshape};
auto groups = filters_pshape.rank().is_static() ? filters_pshape[0] : Dimension();
if (groups.is_dynamic()) {
groups = infer_group_from_input_shapes(data_batch_pshape, filters_pshape);
}
if (data_batch_pshape.rank().is_static() && data_batch_pshape.rank().get_length()) {
if (data_batch_pshape[1].is_static() && groups.is_static()) {
shape[1] = Dimension(data_batch_pshape[1].get_length() / groups.get_length());
} else {
shape[1] = Dimension();
}
}
return shape;
}();
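
// [Editor's sketch, assuming the public ov::PartialShape API] Illustration of
// the channel adjustment above: dividing C_IN by GROUPS lets the regular
// convolution shape helpers operate on a per-group view of the data batch.
#include <openvino/core/partial_shape.hpp>
#include <iostream>

int main() {
    ov::PartialShape data{1, 8, 224, 224};  // N, C_IN, H, W
    const ov::Dimension groups{2};          // filters_pshape[0]
    if (data[1].is_static() && groups.is_static())
        data[1] = ov::Dimension(data[1].get_length() / groups.get_length());
    std::cout << data << std::endl;         // prints {1,4,224,224}
}
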
// we need to adjust filters shape to reuse helpers for regular convolution
ov::PartialShape filters_ps = [&]() {
auto shape = ov::PartialShape{filters_pshape};
if (shape.rank().is_static() && shape.rank().get_length() > 2) {
auto groups = filters_pshape.rank().is_static() ? filters_pshape[0] : Dimension();
if (groups.is_dynamic()) {
groups = infer_group_from_input_shapes(data_batch_pshape, filters_pshape);
}
shape[1] = groups * shape[1];
vector<Dimension> dim_vec{shape};
dim_vec.erase(dim_vec.begin());
shape = ov::PartialShape{dim_vec};
}
return shape;
}();
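
// [Editor's sketch, hedged] The filters adjustment above: folding GROUPS into
// C_OUT turns per-group filters {GROUPS, C_OUT, C_IN, kH, kW} into
// regular-convolution filters {GROUPS * C_OUT, C_IN, kH, kW}.
#include <openvino/core/partial_shape.hpp>
#include <iostream>
#include <vector>

int main() {
    ov::PartialShape filters{2, 4, 4, 3, 3};           // GROUPS, C_OUT, C_IN, kH, kW
    const ov::Dimension groups = filters[0];
    filters[1] = groups * filters[1];                  // C_OUT -> GROUPS * C_OUT
    std::vector<ov::Dimension> dims{filters};
    dims.erase(dims.begin());                          // drop the GROUPS axis
    std::cout << ov::PartialShape{dims} << std::endl;  // prints {8,4,3,3}
}
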
result_shape = infer_convolution_forward(this,
data_batch_ps,
Strides(m_strides.size(), 1), // dummy data dilations
m_pads_begin,
m_pads_end,
filters_ps,
m_strides,
m_dilations);
if (m_num_spatial != -1) {
resolve_auto_pad_for_shape(this, m_pads_begin, m_pads_end, input_shapes, 2, 3);
shape_infer(this, m_pads_begin, m_pads_end, input_shapes, output_shapes);
}
set_output_type(0, result_et, result_shape);
set_output_type(0, result_et, output_shapes[0]);
}
shared_ptr<Node> op::v1::GroupConvolution::clone_with_new_inputs(const OutputVector& new_args) const {
@ -352,23 +218,19 @@ static Dimension infer_backprop_group_from_input_shapes(const ov::PartialShape&
}
const ov::PartialShape op::v1::GroupConvolutionBackpropData::get_convolution_output_shape() const {
ov::PartialShape shape;
if (get_input_size() == 3 && evaluate_as_partial_shape(input_value(2), shape))
return shape;
auto data_pshape = get_input_partial_shape(0);
auto filter_pshape = get_input_partial_shape(1);
ov::PartialShape shape;
if (inputs().size() == 3) {
if (const auto& const_op = get_constant_from_source(input_value(2))) {
return ov::PartialShape{const_op->get_shape_val()};
}
}
if (data_pshape.rank().is_static()) {
shape = ov::PartialShape{vector<Dimension>(data_pshape.rank().get_length() - 2)};
} else if (filter_pshape.rank().is_static()) {
shape = ov::PartialShape{vector<Dimension>(filter_pshape.rank().get_length() - 3)};
} else {
if (data_pshape.rank().is_static())
shape = ov::PartialShape::dynamic(data_pshape.rank().get_length() - 2);
else if (filter_pshape.rank().is_static())
shape = ov::PartialShape::dynamic(filter_pshape.rank().get_length() - 2);
else
shape = ov::PartialShape::dynamic();
}
return shape;
}
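
// [Editor's sketch, illustrative] The fallback behaviour of
// get_convolution_output_shape() above: a constant third input is returned
// verbatim as the spatial shape; otherwise the rank is derived from the data
// (rank - 2) or filters rank with every dimension left dynamic.
#include <openvino/core/partial_shape.hpp>
#include <iostream>

int main() {
    // No output-shape input: 4-D data implies 2 dynamic spatial dims.
    std::cout << ov::PartialShape::dynamic(4 - 2) << std::endl;  // {?,?}
    // A constant output-shape input {3, 3} would be returned as-is.
    std::cout << ov::PartialShape{3, 3} << std::endl;            // {3,3}
}
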
@ -405,9 +267,7 @@ void op::v1::GroupConvolutionBackpropData::infer_conv_backprop_output_spatial_sh
void op::v1::GroupConvolutionBackpropData::validate_and_infer_types() {
NGRAPH_OP_SCOPE(v1_GroupConvolutionBackpropData_validate_and_infer_types);
const ov::PartialShape& data_pshape = get_input_partial_shape(0);
element::Type data_et = get_input_element_type(0);
const ov::PartialShape& filters_pshape = get_input_partial_shape(1);
element::Type filters_et = get_input_element_type(1);
element::Type result_et;
@ -424,207 +284,39 @@ void op::v1::GroupConvolutionBackpropData::validate_and_infer_types() {
"Element type of inputs must be numeric. Got: ",
result_et);
NODE_VALIDATION_CHECK(this,
(data_pshape.rank().compatible(5) && filters_pshape.rank().compatible(6)) ||
(data_pshape.rank().compatible(4) && filters_pshape.rank().compatible(5)) ||
(data_pshape.rank().compatible(3) && filters_pshape.rank().compatible(4)),
"Shapes for data batch and filters do not match. (data batch shape: ",
data_pshape,
", filters shape: ",
filters_pshape,
").");
bool is_output_shape_present = inputs().size() == 3;
if (is_output_shape_present) {
const ov::PartialShape& output_shape_pshape = get_input_partial_shape(2);
bool output_shape_input_present = get_input_size() == 3;
if (output_shape_input_present) {
const element::Type output_shape_et = get_input_element_type(2);
NODE_VALIDATION_CHECK(this,
output_shape_et.is_integral_number(),
"Element type for output shape should be of integer type ",
"(output_shape element type: ",
output_shape_et,
").");
NODE_VALIDATION_CHECK(this,
output_shape_pshape.rank().compatible(1),
"Spatial shape of output input must be of rank 1 ",
"(output_shape shape: ",
output_shape_pshape,
").");
}
ov::PartialShape output_spatial_pshape = get_convolution_output_shape();
if (data_pshape.rank().is_static() || filters_pshape.rank().is_static()) {
const bool is_data_ps_static = data_pshape.rank().is_static();
const auto output_ps_rank =
is_data_ps_static ? data_pshape.rank().get_length() : filters_pshape.rank().get_length() - 1;
const size_t num_spatial_dims = output_ps_rank - 2;
if (m_strides.size() == 0) {
m_strides = Strides(num_spatial_dims, 1);
}
if (m_dilations.size() == 0) {
m_dilations = Strides(num_spatial_dims, 1);
}
if (m_pads_begin.size() == 0 || m_auto_pad == PadType::VALID) {
m_pads_begin = CoordinateDiff(num_spatial_dims, 0);
}
if (m_pads_end.size() == 0 || m_auto_pad == PadType::VALID) {
m_pads_end = CoordinateDiff(num_spatial_dims, 0);
}
if (m_output_padding.size() == 0) {
m_output_padding = CoordinateDiff(num_spatial_dims, 0);
}
NODE_VALIDATION_CHECK(this,
m_strides.size() == num_spatial_dims,
"Strides should be defined for all and only spatial features.");
NODE_VALIDATION_CHECK(this,
m_dilations.size() == num_spatial_dims,
"Dilations should be defined for all and only spatial features.");
NODE_VALIDATION_CHECK(this,
m_pads_begin.size() == num_spatial_dims && m_pads_end.size() == num_spatial_dims,
"Pads should be defined for all and only spatial features.");
NODE_VALIDATION_CHECK(this,
m_output_padding.size() == num_spatial_dims,
"Output padding should be defined for all and only "
"spatial features.");
if (data_pshape.rank().is_static() && filters_pshape.rank().is_static()) {
if (filters_pshape[0].is_static() && filters_pshape[1].is_static() && data_pshape[1].is_static()) {
auto groups = filters_pshape[0].get_length();
auto input_channels = filters_pshape[1].get_length();
auto n_data_channels = data_pshape[1].get_length();
NODE_VALIDATION_CHECK(this,
n_data_channels % groups == 0,
"Number of data channels not a multiple of group size.");
NODE_VALIDATION_CHECK(this,
n_data_channels / groups == input_channels,
"Data second dimension has incompatible value "
"with number of input channels.");
}
}
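
// [Editor's sketch, illustrative numbers] The two channel checks above worked
// through: data batch {1, 16, H, W} with filters
// {GROUPS=4, C_IN=4, C_OUT=2, kH, kW} passes, since 16 % 4 == 0 and 16 / 4 == 4.
#include <cassert>
#include <cstdint>

int main() {
    const int64_t n_data_channels = 16, groups = 4, input_channels = 4;
    assert(n_data_channels % groups == 0);              // divisibility
    assert(n_data_channels / groups == input_channels); // per-group channels
}
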
if (is_output_shape_present && output_spatial_pshape.is_static()) {
ov::Shape output_shape = output_spatial_pshape.to_shape();
NODE_VALIDATION_CHECK(this,
output_shape.size() == num_spatial_dims,
"Output shape should be specified only and for "
"all spatial dimensions.");
}
}
ov::PartialShape result_pshape{ov::PartialShape::dynamic()};
// If output shape is provided, ignore current values for padding begin/end
// and infer them.
if (is_output_shape_present) {
if (output_spatial_pshape.rank().is_static()) {
if (data_pshape.rank().is_static() && filters_pshape.rank().is_static()) {
const ov::PartialShape data_spatial_shape = [data_pshape]() {
vector<Dimension> data_dims{data_pshape};
data_dims.erase(data_dims.begin(), data_dims.begin() + 2); // remove {N, C_IN}
return ov::PartialShape{data_dims};
}();
const auto& data_shape = get_input_partial_shape(0);
const auto& filter_shape = get_input_partial_shape(1);
const ov::PartialShape filters_spatial_shape = [filters_pshape]() {
vector<Dimension> filters_dims{filters_pshape};
filters_dims.erase(filters_dims.begin(),
filters_dims.begin() + 3); // remove {GROUPS, C_IN, C_OUT}
return ov::PartialShape{filters_dims};
}();
const auto& output_shapes_shape = output_shape_input_present ? get_input_partial_shape(2) : PartialShape::dynamic();
// If auto_pad is one of the following modes, we infer paddings. Otherwise, in
// EXPLICIT auto_pad mode, we use what is provided.
if ((m_auto_pad == PadType::SAME_UPPER || m_auto_pad == PadType::SAME_LOWER) &&
(data_spatial_shape.is_static() && filters_spatial_shape.is_static() &&
output_spatial_pshape.is_static())) {
opset1::infer_conv_backprop_auto_padding(data_spatial_shape.to_shape(),
filters_spatial_shape.to_shape(),
output_spatial_pshape.to_shape(),
m_strides,
m_dilations,
m_auto_pad,
m_output_padding,
m_pads_begin,
m_pads_end);
}
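
// [Editor's sketch, hedged] The backprop auto-padding arithmetic per spatial
// axis, following the classic transposed-convolution SAME rule; variable
// names and values here are illustrative, not taken from the patch.
#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
    const int64_t in = 3, kernel = 3, stride = 2, dilation = 1;
    const int64_t out = 6, output_padding = 1;
    const int64_t total = std::max<int64_t>(
        stride * (in - 1) + dilation * (kernel - 1) + 1 - out + output_padding, 0);
    const int64_t pads_begin = total / 2;        // SAME_UPPER split
    const int64_t pads_end = total - pads_begin;
    assert(pads_begin == 1 && pads_end == 1);
}
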
}
m_num_spatial = calculate_num_spatial(this, data_shape, filter_shape, output_shapes_shape, 2, 3);
update_and_validate_attributes_back_prop(this);
vector<Dimension> output_pshape{output_spatial_pshape};
// GROUPS * C_OUT
auto n_out_channels = Dimension::dynamic();
if (filters_pshape.rank().is_static()) {
auto group_dim = filters_pshape[0];
if (!group_dim.is_static()) {
group_dim = infer_backprop_group_from_input_shapes(data_pshape, filters_pshape);
}
n_out_channels = group_dim * filters_pshape[2];
}
output_pshape.insert(output_pshape.begin(), n_out_channels);
// N
auto batches = data_pshape.rank().is_static() ? data_pshape[0] : Dimension::dynamic();
output_pshape.insert(output_pshape.begin(), batches);
result_pshape = ov::PartialShape{output_pshape};
}
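
// [Editor's sketch, illustrative] Assembling the result shape above: with
// data batch {1, 16, 5, 5}, filters {GROUPS=4, C_IN=4, C_OUT=2, 3, 3} and
// spatial output {3, 3}, the result is {N, GROUPS * C_OUT, 3, 3} = {1, 8, 3, 3}.
#include <openvino/core/partial_shape.hpp>
#include <iostream>
#include <vector>

int main() {
    std::vector<ov::Dimension> out{3, 3};            // output_spatial_pshape
    out.insert(out.begin(), ov::Dimension(4) * 2);   // GROUPS * C_OUT
    out.insert(out.begin(), 1);                      // N from the data batch
    std::cout << ov::PartialShape{out} << std::endl; // prints {1,8,3,3}
}
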
set_input_is_relevant_to_shape(2);
std::vector<ov::PartialShape> input_shapes = {data_shape, filter_shape};
if (output_shape_input_present)
input_shapes.push_back(get_input_partial_shape(2));
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic()};
if (m_num_spatial != -1) {
ov::PartialShape output_spatial_shape = get_convolution_output_shape();
resolve_auto_pad_for_shape_back_prop(this, m_pads_begin, m_pads_end, input_shapes, output_spatial_shape, 2, 3);
shape_infer(this, m_pads_begin, m_pads_end, output_spatial_shape, input_shapes, output_shapes);
}
// Deduce output shape from input spatial shape, strides, dilations, output padding
// and padding values.
else {
if (m_auto_pad == PadType::SAME_UPPER || m_auto_pad == PadType::SAME_LOWER || m_auto_pad == PadType::VALID) {
m_pads_begin.assign(m_pads_begin.size(), 0);
m_pads_end.assign(m_pads_end.size(), 0);
}
set_output_type(0, result_et, output_shapes[0]);
vector<Dimension> output_pshape;
if (data_pshape.rank().is_static() && filters_pshape.rank().is_static()) {
auto data_spatial_shape = [data_pshape]() {
vector<Dimension> data_dims{data_pshape};
return vector<Dimension>{std::next(data_dims.begin(), 2), std::end(data_dims)};
}();
auto filters_spatial_shape = [filters_pshape]() {
vector<Dimension> filters_dims{filters_pshape};
return vector<Dimension>{std::next(filters_dims.begin(), 3), std::end(filters_dims)};
}();
infer_conv_backprop_output_spatial_shape(data_spatial_shape,
filters_spatial_shape,
m_strides,
m_dilations,
m_pads_begin,
m_pads_end,
m_output_padding,
output_pshape);
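
// [Editor's sketch, hedged] The explicit-padding branch above: each output
// spatial extent follows the standard transposed-convolution formula; the
// numbers below are illustrative only.
#include <cassert>
#include <cstdint>

int main() {
    const int64_t in = 5, kernel = 3, stride = 2, dilation = 1;
    const int64_t pads_begin = 1, pads_end = 1, output_padding = 0;
    const int64_t out = stride * (in - 1) + dilation * (kernel - 1) + 1
                        - pads_begin - pads_end + output_padding;
    assert(out == 9);  // spatial dim 5 maps to 9
}
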
} else {
output_pshape = vector<Dimension>{output_spatial_pshape};
}
if (output_pshape.size()) {
// GROUPS * C_OUT
auto n_out_channels = Dimension::dynamic();
if (filters_pshape.rank().is_static()) {
auto group_dim = filters_pshape[0];
if (!group_dim.is_static()) {
group_dim = infer_backprop_group_from_input_shapes(data_pshape, filters_pshape);
}
n_out_channels = group_dim * filters_pshape[2];
}
output_pshape.insert(output_pshape.begin(), n_out_channels);
// N
auto batches = data_pshape.rank().is_static() ? data_pshape[0] : Dimension::dynamic();
output_pshape.insert(output_pshape.begin(), batches);
result_pshape = ov::PartialShape{output_pshape};
}
}
set_input_is_relevant_to_shape(0);
set_input_is_relevant_to_shape(1);
set_output_type(0, result_et, result_pshape);
}
shared_ptr<Node> op::v1::GroupConvolutionBackpropData::clone_with_new_inputs(const OutputVector& new_args) const {


@ -93,7 +93,7 @@ TEST(type_prop, convolution_backprop_data_auto_pad_explicit_with_output_padding)
auto_pad,
output_padding);
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(PartialShape{1, 6, 4, 4}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0), PartialShape(PartialShape{1, 6, 4, 4}));
ASSERT_EQ(conv_backprop->get_pads_begin(), (CoordinateDiff{1, 1}));
ASSERT_EQ(conv_backprop->get_pads_end(), (CoordinateDiff{1, 1}));
ASSERT_EQ(conv_backprop->get_output_padding(), (CoordinateDiff{1, 1}));
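
// [Editor's note, hedged] The assertions in this batch move from
// same_scheme() to operator==. A small self-contained contrast, assuming
// OpenVINO's public PartialShape API: same_scheme() matches dimensions
// scheme-wise (dynamic matches only dynamic), and operator== agrees for fully
// matching schemes while reading more directly in test failure output.
#include <openvino/core/partial_shape.hpp>
#include <cassert>

int main() {
    const ov::PartialShape a{1, ov::Dimension::dynamic(), 3, 3};
    const ov::PartialShape b{1, ov::Dimension::dynamic(), 3, 3};
    assert(a.same_scheme(b));  // schemes match, including the dynamic dim
    assert(a == b);            // equality agrees for these shapes
}
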
@ -124,7 +124,7 @@ TEST(type_prop, convolution_backprop_data_auto_pad_same_with_output_padding_and_
auto_pad,
output_padding);
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(PartialShape{1, 6, 3, 3}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0), PartialShape(PartialShape{1, 6, 3, 3}));
ASSERT_EQ(conv_backprop->get_pads_begin(), (CoordinateDiff{1, 1}));
ASSERT_EQ(conv_backprop->get_pads_end(), (CoordinateDiff{2, 2}));
ASSERT_EQ(conv_backprop->get_output_padding(), (CoordinateDiff{1, 1}));
@ -148,7 +148,7 @@ TEST(type_prop, convolution_backprop_data_output_shape_as_const) {
op::PadType::SAME_UPPER);
EXPECT_EQ(conv_backprop->get_element_type(), element::f32);
EXPECT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(PartialShape{1, 2, 3, 3}));
EXPECT_EQ(conv_backprop->get_output_partial_shape(0), PartialShape(PartialShape{1, 2, 3, 3}));
EXPECT_EQ(conv_backprop->get_strides(), (Strides{1, 1}));
EXPECT_EQ(conv_backprop->get_dilations(), (Strides{1, 1}));
EXPECT_EQ(conv_backprop->get_pads_begin(), (CoordinateDiff{2, 2}));
@ -176,8 +176,8 @@ TEST(type_prop, convolution_backprop_data_output_shape_as_param) {
EXPECT_EQ(conv_backprop->get_element_type(), element::f32);
EXPECT_EQ(conv_backprop->get_auto_pad(), op::PadType::SAME_UPPER);
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(
PartialShape{1, 2, Dimension::dynamic(), Dimension::dynamic()}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0),
PartialShape(PartialShape{1, 2, Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop, convolution_backprop_data_with_output_shape_dyn_static_ranks_data_nc_dyn) {
@ -197,10 +197,7 @@ TEST(type_prop, convolution_backprop_data_with_output_shape_dyn_static_ranks_dat
Strides{},
op::PadType::SAME_UPPER);
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 2, 3, 3}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0), PartialShape(PartialShape{Dimension::dynamic(), 2, 3, 3}));
}
TEST(type_prop, convolution_backprop_data_with_output_shape_dyn_static_ranks_filters_cin_dyn) {
@ -220,10 +217,7 @@ TEST(type_prop, convolution_backprop_data_with_output_shape_dyn_static_ranks_fil
Strides{},
op::PadType::SAME_UPPER);
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 6, 3, 3}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0), PartialShape(PartialShape{Dimension::dynamic(), 6, 3, 3}));
}
TEST(type_prop, convolution_backprop_data_with_output_shape_dyn_static_ranks_filters_cin_cout_dyn) {
@ -243,11 +237,8 @@ TEST(type_prop, convolution_backprop_data_with_output_shape_dyn_static_ranks_fil
Strides{},
op::PadType::SAME_UPPER);
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3, 3}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0),
PartialShape(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3, 3}));
}
TEST(type_prop, convolution_backprop_data_dyn_static_ranks_data_nc_dyn) {
@ -265,11 +256,8 @@ TEST(type_prop, convolution_backprop_data_dyn_static_ranks_data_nc_dyn) {
auto conv_backprop =
make_shared<op::v1::ConvolutionBackpropData>(data, filters, strides, padding_begin, padding_end, dilations);
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(
conv_backprop->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 2, 447, 447}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0),
PartialShape(PartialShape{Dimension::dynamic(), 2, 447, 447}));
}
TEST(type_prop, convolution_backprop_data_dyn_static_ranks_filters_cin_dyn) {
@ -287,11 +275,8 @@ TEST(type_prop, convolution_backprop_data_dyn_static_ranks_filters_cin_dyn) {
auto conv_backprop =
make_shared<op::v1::ConvolutionBackpropData>(data, filters, strides, padding_begin, padding_end, dilations);
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(
conv_backprop->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 2, 447, 447}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0),
PartialShape(PartialShape{Dimension::dynamic(), 2, 447, 447}));
}
TEST(type_prop, convolution_backprop_data_dyn_static_ranks_filters_cin_cout_dyn) {
@ -309,11 +294,8 @@ TEST(type_prop, convolution_backprop_data_dyn_static_ranks_filters_cin_cout_dyn)
auto conv_backprop =
make_shared<op::v1::ConvolutionBackpropData>(data, filters, strides, padding_begin, padding_end, dilations);
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), 447, 447}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0),
PartialShape(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 447, 447}));
}
TEST(type_prop, convolution_backprop_data_dyn_static_ranks_data_spatial_dims_dyn) {
@ -331,11 +313,8 @@ TEST(type_prop, convolution_backprop_data_dyn_static_ranks_data_spatial_dims_dyn
auto conv_backprop =
make_shared<op::v1::ConvolutionBackpropData>(data, filters, strides, padding_begin, padding_end, dilations);
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), 16, Dimension::dynamic(), 447}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0),
PartialShape(PartialShape{Dimension::dynamic(), 16, Dimension::dynamic(), 447}));
}
TEST(type_prop, convolution_backprop_data_dyn_static_ranks_filters_spatial_dims_dyn) {
@ -353,11 +332,8 @@ TEST(type_prop, convolution_backprop_data_dyn_static_ranks_filters_spatial_dims_
auto conv_backprop =
make_shared<op::v1::ConvolutionBackpropData>(data, filters, strides, padding_begin, padding_end, dilations);
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), 16, 447, Dimension::dynamic()}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0),
PartialShape(PartialShape{Dimension::dynamic(), 16, 447, Dimension::dynamic()}));
}
TEST(type_prop, convolution_backprop_data_with_output_shape_dyn_data_batch) {
@ -376,10 +352,7 @@ TEST(type_prop, convolution_backprop_data_with_output_shape_dyn_data_batch) {
CoordinateDiff{},
Strides{});
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 2, 3, 3}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0), PartialShape(PartialShape{Dimension::dynamic(), 2, 3, 3}));
}
TEST(type_prop, convolution_backprop_data_with_output_shape_dyn_filters) {
@ -398,10 +371,7 @@ TEST(type_prop, convolution_backprop_data_with_output_shape_dyn_filters) {
CoordinateDiff{},
Strides{});
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(PartialShape{1, Dimension::dynamic(), 3, 3}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0), PartialShape(PartialShape{1, Dimension::dynamic(), 3, 3}));
}
TEST(type_prop, convolution_backprop_data_with_output_shape_as_const_dyn_data_and_filters) {
@ -420,11 +390,8 @@ TEST(type_prop, convolution_backprop_data_with_output_shape_as_const_dyn_data_an
CoordinateDiff{},
Strides{});
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().same_scheme(Rank{5}));
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3, 3, 3}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0),
PartialShape(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3, 3, 3}));
}
TEST(type_prop, convolution_backprop_data_with_output_shape_as_param_dyn_data_and_filters) {
@ -443,9 +410,7 @@ TEST(type_prop, convolution_backprop_data_with_output_shape_as_param_dyn_data_an
CoordinateDiff{},
Strides{});
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().is_dynamic());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0), PartialShape(PartialShape::dynamic(5)));
}
TEST(type_prop, convolution_backprop_data_shape_dyn_data) {
@ -462,11 +427,8 @@ TEST(type_prop, convolution_backprop_data_shape_dyn_data) {
CoordinateDiff{},
Strides{});
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), 2, Dimension::dynamic(), Dimension::dynamic()}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0),
PartialShape(PartialShape{Dimension::dynamic(), 2, Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop, convolution_backprop_data_shape_dyn_filters) {
@ -483,11 +445,8 @@ TEST(type_prop, convolution_backprop_data_shape_dyn_filters) {
CoordinateDiff{},
Strides{});
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(
PartialShape{1, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0),
PartialShape(PartialShape{1, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop, convolution_backprop_data_dyn_data_and_filters) {
@ -504,9 +463,7 @@ TEST(type_prop, convolution_backprop_data_dyn_data_and_filters) {
CoordinateDiff{},
Strides{});
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).rank().is_dynamic());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(conv_backprop->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
ASSERT_EQ(conv_backprop->get_output_partial_shape(0), PartialShape(PartialShape::dynamic()));
}
TEST(type_prop, convolution_backprop_data_invalid_et_inputs) {
@ -590,7 +547,7 @@ TEST(type_prop, convolution_backprop_data_invalid_input_ranks) {
Strides{});
FAIL() << "Incompatible input ranks not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Data and filters inputs must have same rank");
EXPECT_HAS_SUBSTRING(error.what(), "Data and filters rank do not match");
} catch (...) {
FAIL() << "Rank validation check of inputs failed for unexpected reason";
}
@ -652,7 +609,7 @@ TEST(type_prop, convolution_backprop_data_invalid_input_ranks) {
// output_shape has rank 2, should be rank 1
FAIL() << "Incompatible rank of output shape optional input not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Spatial shape of output input must be of rank 1"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input delivering output shape must have rank 1"));
} catch (...) {
FAIL() << "Output shape rank validation check failed for unexpected reason.";
}
@ -726,7 +683,7 @@ TEST(type_prop, convolution_backprop_data_invalid_conv_param_spatial_dims) {
make_shared<op::v1::ConvolutionBackpropData>(data, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid strides spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Strides should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Strides should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Strides spatial dimensions validation check failed for unexpected reason";
}
@ -742,7 +699,7 @@ TEST(type_prop, convolution_backprop_data_invalid_conv_param_spatial_dims) {
make_shared<op::v1::ConvolutionBackpropData>(data, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid strides spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Strides should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Strides should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Strides spatial dimensions validation check failed for unexpected reason";
}
@ -760,7 +717,7 @@ TEST(type_prop, convolution_backprop_data_invalid_conv_param_spatial_dims) {
make_shared<op::v1::ConvolutionBackpropData>(data, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid dilations spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Dilations should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Dilations should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Dilations spatial dimensions validation check failed for unexpected reason";
}
@ -776,7 +733,7 @@ TEST(type_prop, convolution_backprop_data_invalid_conv_param_spatial_dims) {
make_shared<op::v1::ConvolutionBackpropData>(data, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid dilations spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Dilations should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Dilations should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Dilations spatial dimensions validation check failed for unexpected reason";
}
@ -794,7 +751,7 @@ TEST(type_prop, convolution_backprop_data_invalid_conv_param_spatial_dims) {
make_shared<op::v1::ConvolutionBackpropData>(data, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid padding spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Pads should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Pads begin should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Padding spatial dimensions validation check failed for unexpected reason";
}
@ -810,7 +767,7 @@ TEST(type_prop, convolution_backprop_data_invalid_conv_param_spatial_dims) {
make_shared<op::v1::ConvolutionBackpropData>(data, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid padding spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Pads should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Pads begin should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Padding spatial dimensions validation check failed for unexpected reason";
}
@ -836,7 +793,7 @@ TEST(type_prop, convolution_backprop_data_invalid_conv_param_spatial_dims) {
output_padding);
FAIL() << "Invalid output padding spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Output padding should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Output padding should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Output padding spatial dimensions validation check failed for unexpected reason";
}
@ -860,7 +817,7 @@ TEST(type_prop, convolution_backprop_data_invalid_conv_param_spatial_dims) {
output_padding);
FAIL() << "Invalid output padding spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Output padding should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Output padding should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Output padding spatial dimensions validation check failed for unexpected reason";
}


@ -25,7 +25,7 @@ TEST(type_prop, group_convolution_auto_padding_same_lower) {
auto groupConv =
make_shared<op::v1::GroupConvolution>(data_batch, filters, strides, pads_begin, pads_end, dilations, auto_pad);
ASSERT_TRUE(groupConv->get_output_partial_shape(0).same_scheme(PartialShape{1, 2, 5, 5}));
ASSERT_EQ(groupConv->get_output_partial_shape(0), PartialShape({1, 2, 5, 5}));
ASSERT_EQ(groupConv->get_pads_begin(), (CoordinateDiff{1, 1}));
ASSERT_EQ(groupConv->get_pads_end(), (CoordinateDiff{1, 1}));
}
@ -46,7 +46,7 @@ TEST(type_prop, group_convolution_auto_padding_same_upper) {
auto conv =
make_shared<op::v1::GroupConvolution>(data_batch, filters, strides, pads_begin, pads_end, dilations, auto_pad);
ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(PartialShape{1, 2, 5, 5}));
ASSERT_EQ(conv->get_output_partial_shape(0), PartialShape({1, 2, 5, 5}));
ASSERT_EQ(conv->get_pads_begin(), (CoordinateDiff{0, 0}));
ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{1, 1}));
}
@ -67,7 +67,7 @@ TEST(type_prop, group_convolution_auto_padding_same_lower_spatial_dims_static) {
Strides{},
auto_pad);
ASSERT_TRUE(groupConv->get_output_partial_shape(0).same_scheme({Dimension::dynamic(), Dimension::dynamic(), 5, 5}));
ASSERT_EQ(groupConv->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), Dimension::dynamic(), 5, 5}));
ASSERT_EQ(groupConv->get_pads_begin(), (CoordinateDiff{1, 1}));
ASSERT_EQ(groupConv->get_pads_end(), (CoordinateDiff{1, 1}));
}
@ -88,7 +88,7 @@ TEST(type_prop, group_convolution_auto_padding_same_upper_spatial_dims_static) {
Strides{},
auto_pad);
ASSERT_TRUE(groupConv->get_output_partial_shape(0).same_scheme(PartialShape{1, Dimension::dynamic(), 5, 5}));
ASSERT_EQ(groupConv->get_output_partial_shape(0), PartialShape({1, Dimension::dynamic(), 5, 5}));
ASSERT_EQ(groupConv->get_pads_begin(), (CoordinateDiff{0, 0}));
ASSERT_EQ(groupConv->get_pads_end(), (CoordinateDiff{1, 1}));
}
@ -109,7 +109,7 @@ TEST(type_prop, group_convolution_static_ranks_filters_groups_dyn) {
Strides{},
auto_pad);
ASSERT_TRUE(groupConv->get_output_partial_shape(0).same_scheme({Dimension::dynamic(), 2, 5, 5}));
ASSERT_EQ(groupConv->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), 2, 5, 5}));
ASSERT_EQ(groupConv->get_pads_begin(), (CoordinateDiff{1, 1}));
ASSERT_EQ(groupConv->get_pads_end(), (CoordinateDiff{1, 1}));
}
@ -130,7 +130,7 @@ TEST(type_prop, group_convolution_static_ranks_filters_groups_cout_dyn) {
Strides{},
auto_pad);
ASSERT_TRUE(groupConv->get_output_partial_shape(0).same_scheme({Dimension::dynamic(), Dimension::dynamic(), 5, 5}));
ASSERT_EQ(groupConv->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), Dimension::dynamic(), 5, 5}));
ASSERT_EQ(groupConv->get_pads_begin(), (CoordinateDiff{1, 1}));
ASSERT_EQ(groupConv->get_pads_end(), (CoordinateDiff{1, 1}));
}
@ -151,7 +151,7 @@ TEST(type_prop, group_convolution_static_ranks_data_cin_filters_group_dyn) {
Strides{},
auto_pad);
ASSERT_TRUE(groupConv->get_output_partial_shape(0).same_scheme({Dimension::dynamic(), Dimension::dynamic(), 5, 5}));
ASSERT_EQ(groupConv->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), Dimension::dynamic(), 5, 5}));
ASSERT_EQ(groupConv->get_pads_begin(), (CoordinateDiff{1, 1}));
ASSERT_EQ(groupConv->get_pads_end(), (CoordinateDiff{1, 1}));
}
@ -172,7 +172,7 @@ TEST(type_prop, group_convolution_auto_padding_same_spatial_dims_dynamic) {
Strides{},
auto_pad);
ASSERT_TRUE(groupConv->get_output_partial_shape(0).same_scheme({1, 2, Dimension::dynamic(), 5}));
ASSERT_EQ(groupConv->get_output_partial_shape(0), PartialShape({1, 2, Dimension::dynamic(), 5}));
ASSERT_EQ(groupConv->get_pads_begin(), (CoordinateDiff{0, 1}));
ASSERT_EQ(groupConv->get_pads_end(), (CoordinateDiff{0, 1}));
}
@ -196,8 +196,8 @@ TEST(type_prop, group_convolution_data_batch_dynamic) {
ASSERT_EQ(groupConv->get_dilations(), (Strides{1, 1}));
ASSERT_EQ(groupConv->get_pads_begin(), (CoordinateDiff{0, 0}));
ASSERT_EQ(groupConv->get_pads_end(), (CoordinateDiff{0, 0}));
ASSERT_TRUE(groupConv->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), 2, Dimension::dynamic(), Dimension::dynamic()}));
ASSERT_EQ(groupConv->get_output_partial_shape(0),
PartialShape({Dimension::dynamic(), 2, Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop, group_convolution_filters_dynamic_auto_pad_explicit) {
@ -219,8 +219,8 @@ TEST(type_prop, group_convolution_filters_dynamic_auto_pad_explicit) {
ASSERT_EQ(groupConv->get_dilations(), (Strides{1, 1}));
ASSERT_EQ(groupConv->get_pads_begin(), (CoordinateDiff{0, 0}));
ASSERT_EQ(groupConv->get_pads_end(), (CoordinateDiff{0, 0}));
ASSERT_TRUE(groupConv->get_output_partial_shape(0).same_scheme(
PartialShape{1, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
ASSERT_EQ(groupConv->get_output_partial_shape(0),
PartialShape({1, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop, group_convolution_filters_dynamic_auto_pad_same) {
@ -243,8 +243,8 @@ TEST(type_prop, group_convolution_filters_dynamic_auto_pad_same) {
// pads should stay at their defaults since the filters shape is dynamic
ASSERT_EQ(groupConv->get_pads_begin(), (CoordinateDiff{0, 0}));
ASSERT_EQ(groupConv->get_pads_end(), (CoordinateDiff{0, 0}));
ASSERT_TRUE(groupConv->get_output_partial_shape(0).same_scheme(
PartialShape{1, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
ASSERT_EQ(groupConv->get_output_partial_shape(0),
PartialShape({1, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop, group_convolution_data_batch_and_filters_dynamic) {
@ -260,7 +260,7 @@ TEST(type_prop, group_convolution_data_batch_and_filters_dynamic) {
CoordinateDiff{},
Strides{});
ASSERT_TRUE(groupConv->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
ASSERT_EQ(groupConv->get_output_partial_shape(0), PartialShape::dynamic());
}
TEST(type_prop, group_convolution_invalid_et_inputs) {
@ -322,7 +322,7 @@ TEST(type_prop, group_convolution_invalid_input_ranks) {
// data and weight have incompatible ranks
FAIL() << "Incompatible input ranks not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for data batch and filters do not match."));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch and filters rank do not match"));
} catch (...) {
FAIL() << "Rank validation check of inputs failed for unexpected reason";
}
@ -341,7 +341,7 @@ TEST(type_prop, group_convolution_invalid_input_ranks) {
// data and weight have incompatible ranks
FAIL() << "Incompatible input ranks not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for data batch and filters do not match."));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch and filters rank do not match"));
} catch (...) {
FAIL() << "Rank validation check of inputs failed for unexpected reason";
}
@ -413,7 +413,7 @@ TEST(type_prop, group_convolution_invalid_conv_param_spatial_dims) {
make_shared<op::v1::GroupConvolution>(data_batch, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid strides spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Strides should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Strides should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Strides spatial dimensions validation check failed for unexpected reason";
}
@ -429,7 +429,7 @@ TEST(type_prop, group_convolution_invalid_conv_param_spatial_dims) {
make_shared<op::v1::GroupConvolution>(data_batch, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid strides spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Strides should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Strides should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Strides spatial dimensions validation check failed for unexpected reason";
}
@ -447,7 +447,7 @@ TEST(type_prop, group_convolution_invalid_conv_param_spatial_dims) {
make_shared<op::v1::GroupConvolution>(data_batch, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid dilations spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Dilations should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Dilations should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Dilations spatial dimensions validation check failed for unexpected reason";
}
@ -463,7 +463,7 @@ TEST(type_prop, group_convolution_invalid_conv_param_spatial_dims) {
make_shared<op::v1::GroupConvolution>(data_batch, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid dilations spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Dilations should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Dilations should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Dilations spatial dimensions validation check failed for unexpected reason";
}
@ -481,7 +481,7 @@ TEST(type_prop, group_convolution_invalid_conv_param_spatial_dims) {
make_shared<op::v1::GroupConvolution>(data_batch, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid padding spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Pads should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Pads begin should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Padding spatial dimensions validation check failed for unexpected reason";
}
@ -497,7 +497,7 @@ TEST(type_prop, group_convolution_invalid_conv_param_spatial_dims) {
make_shared<op::v1::GroupConvolution>(data_batch, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid padding spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Pads should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Pads begin should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Padding spatial dimensions validation check failed for unexpected reason";
}


@ -75,8 +75,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_as
ASSERT_EQ(gcbd->get_element_type(), element::f32);
ASSERT_EQ(gcbd->get_auto_pad(), op::PadType::SAME_UPPER);
ASSERT_TRUE(
gcbd->get_output_partial_shape(0).same_scheme(PartialShape{1, 2, Dimension::dynamic(), Dimension::dynamic()}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{1, 2, Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_static_ranks_data_nc_dyn) {
@ -94,10 +93,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_st
Strides{},
op::PadType::SAME_UPPER);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 2, 3, 3}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 2, 3, 3}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_static_ranks_filters_group_dyn) {
@ -115,10 +111,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_st
Strides{},
op::PadType::SAME_UPPER);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 2, 3, 3}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 2, 3, 3}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_static_ranks_filters_group_cin_dyn) {
@ -140,11 +133,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_st
Strides{},
op::PadType::SAME_UPPER);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(
gcbd->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3, 3}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3, 3}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_static_ranks_data_cin_filters_group_dyn) {
@ -162,10 +151,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_st
Strides{},
op::PadType::SAME_UPPER);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(PartialShape{1, Dimension::dynamic(), 3, 3}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic(), 3, 3}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_static_ranks_filters_group_cout_dyn) {
@ -187,11 +173,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_st
Strides{},
op::PadType::SAME_UPPER);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(
gcbd->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3, 3}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3, 3}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_static_ranks_data_nc_dyn) {
@ -213,10 +195,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_static_ranks_data_nc
padding_end,
dilations);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 8, 447, 447}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 8, 447, 447}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_static_ranks_filters_group_dyn) {
@ -238,10 +217,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_static_ranks_filters
padding_end,
dilations);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(PartialShape{1, 8, 447, 447}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{1, 8, 447, 447}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_static_ranks_filters_group_cin_dyn) {
@ -267,11 +243,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_static_ranks_filters
padding_end,
dilations);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), 447, 447}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), Dimension::dynamic(), 447, 447}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_static_ranks_data_cin_filters_group_dyn) {
@ -293,10 +265,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_static_ranks_data_ci
padding_end,
dilations);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(PartialShape{1, Dimension::dynamic(), 447, 447}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic(), 447, 447}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_static_ranks_filters_group_cout_dyn) {
@ -322,10 +291,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_static_ranks_filters
padding_end,
dilations);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(PartialShape{1, Dimension::dynamic(), 447, 447}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic(), 447, 447}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_static_ranks_data_spatial_dim_dyn) {
@ -347,10 +313,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_static_ranks_data_sp
padding_end,
dilations);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(PartialShape{1, 8, Dimension::dynamic(), 447}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{1, 8, Dimension::dynamic(), 447}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_static_ranks_filters_spatial_dim_dyn) {
@ -372,11 +335,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_static_ranks_filters
padding_end,
dilations);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), 8, 447, Dimension::dynamic()}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 8, 447, Dimension::dynamic()}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_data_dyn) {
@ -394,10 +353,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_da
Strides{},
op::PadType::SAME_UPPER);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 2, 3, 3}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 2, 3, 3}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_data_dyn) {
@ -414,11 +370,8 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_data_dyn) {
CoordinateDiff{},
Strides{});
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), 8, Dimension::dynamic(), Dimension::dynamic()}));
ASSERT_EQ(gcbd->get_output_partial_shape(0),
(PartialShape{Dimension::dynamic(), 8, Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_filters_dyn) {
@ -436,10 +389,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_fi
Strides{},
op::PadType::SAME_UPPER);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(PartialShape{1, Dimension::dynamic(), 3, 3}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic(), 3, 3}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_filters_dyn) {
@ -456,11 +406,8 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_filters_dyn) {
CoordinateDiff{},
Strides{});
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(
PartialShape{1, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
ASSERT_EQ(gcbd->get_output_partial_shape(0),
(PartialShape{1, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_as_const_data_and_filters_dyn) {
@ -478,11 +425,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_as
Strides{},
op::PadType::SAME_UPPER);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().same_scheme(Rank{5}));
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3, 3, 3}));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3, 3, 3}));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_as_param_data_and_filters_dyn) {
@ -500,9 +443,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_with_output_shape_as
Strides{},
op::PadType::SAME_UPPER);
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape::dynamic(5)));
}
TEST(type_prop, group_convolution_backprop_data_shape_infer_data_and_filters_dyn) {
@ -519,9 +460,7 @@ TEST(type_prop, group_convolution_backprop_data_shape_infer_data_and_filters_dyn
CoordinateDiff{},
Strides{});
ASSERT_TRUE(gcbd->get_output_partial_shape(0).rank().is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(gcbd->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
ASSERT_EQ(gcbd->get_output_partial_shape(0), (PartialShape::dynamic()));
}
TEST(type_prop, group_convolution_backprop_data_invalid_et_inputs) {
@ -615,7 +554,7 @@ TEST(type_prop, group_convolution_backprop_data_invalid_input_ranks) {
// data and filters have incompatible ranks
FAIL() << "Incompatible input ranks not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for data batch and filters do not match."));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data and filters rank do not match"));
} catch (...) {
FAIL() << "Rank validation check of inputs failed for unexpected reason";
}
@ -639,7 +578,7 @@ TEST(type_prop, group_convolution_backprop_data_invalid_input_ranks) {
// data and weight have incompatible ranks
FAIL() << "Incompatible input ranks not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for data batch and filters do not match."));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data and filters rank do not match"));
} catch (...) {
FAIL() << "Rank validation check of inputs failed for unexpected reason";
}
@ -662,7 +601,7 @@ TEST(type_prop, group_convolution_backprop_data_invalid_input_ranks) {
// Output shape optional input must be of rank 1
FAIL() << "Incompatible output shape input rank not detected.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Spatial shape of output input must be of rank 1"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input delivering output shape must have rank 1"));
} catch (...) {
FAIL() << "Rank validation check of inputs failed for unexpected reason";
}
@ -685,7 +624,9 @@ TEST(type_prop, group_convolution_backprop_data_invalid_input_channel_dims) {
// data batch shape does not have correct dimension C_IN * GROUPS
FAIL() << "Incompatibile input shapes not detected.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Number of data channels not a multiple of group size."));
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Input channels dimension of data batch has incompatible value with filter shape."));
} catch (...) {
FAIL() << "Input shapes validation check failed for unexpected reason.";
}
@ -702,9 +643,9 @@ TEST(type_prop, group_convolution_backprop_data_invalid_input_channel_dims) {
// dimension C_IN * GROUPS = 16
FAIL() << "Incompatibile input shapes not detected.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data second dimension has incompatible value "
"with number of input channels."));
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Input channels dimension of data batch has incompatible value with filter shape."));
} catch (...) {
FAIL() << "Input shapes validation check failed for unexpected reason.";
}
@ -753,7 +694,7 @@ TEST(type_prop, group_convolution_backprop_data_invalid_conv_param_spatial_dims)
make_shared<op::v1::GroupConvolutionBackpropData>(data, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid strides spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Strides should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Strides should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Strides spatial dimensions validation check failed for unexpected reason";
}
@ -769,7 +710,7 @@ TEST(type_prop, group_convolution_backprop_data_invalid_conv_param_spatial_dims)
make_shared<op::v1::GroupConvolutionBackpropData>(data, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid strides spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Strides should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Strides should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Strides spatial dimensions validation check failed for unexpected reason";
}
@ -787,7 +728,7 @@ TEST(type_prop, group_convolution_backprop_data_invalid_conv_param_spatial_dims)
make_shared<op::v1::GroupConvolutionBackpropData>(data, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid dilations spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Dilations should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Dilations should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Dilations spatial dimensions validation check failed for unexpected reason";
}
@ -803,7 +744,7 @@ TEST(type_prop, group_convolution_backprop_data_invalid_conv_param_spatial_dims)
make_shared<op::v1::GroupConvolutionBackpropData>(data, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid dilations spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Dilations should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Dilations spatial dimensions validation check failed for unexpected reason";
}
@ -821,7 +762,7 @@ TEST(type_prop, group_convolution_backprop_data_invalid_conv_param_spatial_dims)
make_shared<op::v1::GroupConvolutionBackpropData>(data, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid padding spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Pads should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Pads begin should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Padding spatial dimensions validation check failed for unexpected reason";
}
@ -837,7 +778,7 @@ TEST(type_prop, group_convolution_backprop_data_invalid_conv_param_spatial_dims)
make_shared<op::v1::GroupConvolutionBackpropData>(data, filters, strides, pads_begin, pads_end, dilations);
FAIL() << "Invalid padding spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Pads should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "Pads begin should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Padding spatial dimensions validation check failed for unexpected reason";
}
@ -863,7 +804,7 @@ TEST(type_prop, group_convolution_backprop_data_invalid_conv_param_spatial_dims)
output_padding);
FAIL() << "Invalid output padding spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Output padding should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Output padding spatial dimensions validation check failed for unexpected reason";
}
@ -887,7 +828,7 @@ TEST(type_prop, group_convolution_backprop_data_invalid_conv_param_spatial_dims)
output_padding);
FAIL() << "Invalid output padding spatial dimensions not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Output padding should be defined for all and only spatial features.");
EXPECT_HAS_SUBSTRING(error.what(), "should be defined for all and only spatial dimensions.");
} catch (...) {
FAIL() << "Output padding spatial dimensions validation check failed for unexpected reason";
}
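Taken together, the hunks above align a whole family of messages around one rule: with 4-D NCHW data there are exactly two spatial dimensions, so Strides, Dilations, each pads vector, and the output padding must all contain exactly two elements; an under- or over-sized vector (e.g. Strides{1} or Strides{1, 1, 1}) trips the corresponding "all and only spatial dimensions" check.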


@ -49,7 +49,7 @@ TEST(attributes, convolution_backprop_output_shape_output_padding) {
NodeBuilder::get_ops().register_factory<opset1::ConvolutionBackpropData>();
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 16, 124, 124});
const auto filter = make_shared<op::Parameter>(element::f32, Shape{16, 2, 3, 3});
const auto output_shape = make_shared<op::Parameter>(element::i32, Shape{1});
const auto output_shape = make_shared<op::Parameter>(element::i32, Shape{2});
const auto strides = Strides{2, 1};
const auto pads_begin = CoordinateDiff{3, 4};


@ -44,7 +44,7 @@ TEST(attributes, group_conv_backprop_data_op) {
NodeBuilder::get_ops().register_factory<opset1::GroupConvolutionBackpropData>();
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 20, 224, 224});
const auto filter = make_shared<op::Parameter>(element::f32, Shape{4, 5, 2, 3, 3});
const auto output_shape = make_shared<op::Parameter>(element::i32, Shape{1});
const auto output_shape = make_shared<op::Parameter>(element::i32, Shape{2});
const auto strides = Strides{2, 1};
const auto pads_begin = CoordinateDiff{3, 4};
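A closing note on the two attribute-test fixes: the optional output-shape input is a rank-1 tensor holding one element per spatial dimension, so for these 2-D convolutions it must have two elements; Shape{1} under-specified it, and Shape{2} matches the Strides{2, 1} and CoordinateDiff{3, 4} attributes, which likewise carry one value per spatial dimension. A minimal sketch of how that input plugs in, assuming the opset1 API (shapes and strides are taken from the hunk above; the pads_end value is illustrative):

// Sketch: rank-1, two-element output-shape input for a 2-D backprop convolution.
#include <openvino/opsets/opset1.hpp>
#include <memory>

int main() {
    using namespace ov;
    auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 16, 124, 124});
    auto filters = std::make_shared<opset1::Parameter>(element::f32, Shape{16, 2, 3, 3});
    // one element per spatial dimension (H, W) -> Shape{2}
    auto output_shape = std::make_shared<opset1::Parameter>(element::i32, Shape{2});
    auto cbd = std::make_shared<opset1::ConvolutionBackpropData>(
        data, filters, output_shape, Strides{2, 1}, CoordinateDiff{3, 4}, CoordinateDiff{4, 3},
        Strides{1, 1}, op::PadType::EXPLICIT, CoordinateDiff{0, 0});
    return 0;
}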