Convolution: fast shape inference (#7523)

* Convolution: fast shape inference

* StaticShape and StaticDimension + static shape infer time test in comparison to Convolution

* Review comments
This commit is contained in:
Evgenya Stepyreva 2021-09-27 00:14:50 +03:00 committed by GitHub
parent c92988c8e9
commit ef028a567e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 681 additions and 30 deletions

View File

@ -42,7 +42,8 @@ endif()
target_link_libraries(${TARGET_NAME} PRIVATE mkldnn
inference_engine
inference_engine_transformations
inference_engine_lp_transformations)
inference_engine_lp_transformations
ov_shape_inference)
target_compile_definitions(${TARGET_NAME} PRIVATE IMPLEMENT_INFERENCE_EXTENSION_API)
@ -73,6 +74,7 @@ target_include_directories(${TARGET_NAME}_obj PRIVATE $<TARGET_PROPERTY:inferenc
$<TARGET_PROPERTY:inference_engine_transformations,INTERFACE_INCLUDE_DIRECTORIES>
$<TARGET_PROPERTY:openvino::itt,INTERFACE_INCLUDE_DIRECTORIES>
$<TARGET_PROPERTY:inference_engine_lp_transformations,INTERFACE_INCLUDE_DIRECTORIES>
$<TARGET_PROPERTY:ov_shape_inference,INTERFACE_INCLUDE_DIRECTORIES>
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}
$<TARGET_PROPERTY:openvino::conditional_compilation,INTERFACE_INCLUDE_DIRECTORIES>)

View File

@ -0,0 +1,91 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "static_dimension.hpp"
using namespace ov;
std::ostream& ov::operator<<(std::ostream& str, const StaticDimension& dimension) {
return str << dimension.get_length();
}
StaticDimension::StaticDimension(value_type dimension)
: m_dimension(dimension) {}
bool StaticDimension::operator==(const StaticDimension& dim) const {
return m_dimension == dim.m_dimension;
}
bool StaticDimension::operator!=(const StaticDimension& dim) const {
return m_dimension != dim.m_dimension;
}
StaticDimension StaticDimension::operator+(const StaticDimension& dim) const {
return StaticDimension(m_dimension + dim.m_dimension);
}
StaticDimension& StaticDimension::operator+=(const StaticDimension& dim) {
return (*this = *this + dim);
}
StaticDimension StaticDimension::operator-(const StaticDimension& dim) const {
return StaticDimension(m_dimension - dim.m_dimension);
}
StaticDimension StaticDimension::operator*(const StaticDimension& dim) const {
return StaticDimension(m_dimension * dim.m_dimension);
}
StaticDimension& StaticDimension::operator*=(const StaticDimension& dim) {
return (*this = *this * dim);
}
StaticDimension StaticDimension::operator&(const StaticDimension& dim) const {
return (*this == dim) ? dim : 0;
}
StaticDimension& StaticDimension::operator&=(const StaticDimension& dim) {
if (*this != dim)
m_dimension = 0;
return *this;
}
bool StaticDimension::compatible(const StaticDimension& dim) const {
return m_dimension == dim.m_dimension;
}
bool StaticDimension::same_scheme(const StaticDimension& dim) const {
return m_dimension == dim.m_dimension;
}
bool StaticDimension::merge(StaticDimension& dst, const StaticDimension& d1, const StaticDimension& d2) {
if (d1 != d2)
return false;
dst = d1;
return true;
}
bool StaticDimension::broadcast_merge(StaticDimension& dst, const StaticDimension& d1, const StaticDimension& d2) {
if (d1 == 1) {
dst = d2;
return true;
}
if (d2 == 1) {
dst = d1;
return true;
}
return merge(dst, d1, d2);
}
StaticDimension::value_type StaticDimension::get_length() const {
return m_dimension;
}
StaticDimension::value_type StaticDimension::get_max_length() const {
return m_dimension;
}
StaticDimension::value_type StaticDimension::get_min_length() const {
return m_dimension;
}

View File

@ -0,0 +1,56 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <cstddef>
#include <limits>
#include <stdexcept>
#include <ostream>
namespace ov {
/// \brief Class representing a dimension, which must be static,
/// in a shape or shape-like object.
///
/// Provides similar API to the public Dimension class.
class StaticDimension {
public:
using value_type = size_t;
/// \brief Construct a static dimension.
/// \param dimension Value of the dimension.
StaticDimension(value_type dimension);
/// \brief Construct a zero dimension
StaticDimension() = default;
bool operator==(const StaticDimension& dimension) const;
bool operator!=(const StaticDimension& dimension) const;
static bool is_static() { return true; }
static bool is_dynamic() { return false; }
value_type get_length() const;
value_type get_min_length() const;
value_type get_max_length() const;
bool same_scheme(const StaticDimension& dim) const;
bool compatible(const StaticDimension& d) const;
static bool merge(StaticDimension& dst, const StaticDimension& d1, const StaticDimension& d2);
static bool broadcast_merge(StaticDimension& dst, const StaticDimension& d1, const StaticDimension& d2);
StaticDimension operator+(const StaticDimension& dim) const;
StaticDimension operator-(const StaticDimension& dim) const;
StaticDimension operator*(const StaticDimension& dim) const;
StaticDimension operator&(const StaticDimension& dim) const;
StaticDimension& operator+=(const StaticDimension& dim);
StaticDimension& operator*=(const StaticDimension& dim);
StaticDimension& operator&=(const StaticDimension& dim);
private:
value_type m_dimension = 0;
};
std::ostream& operator<<(std::ostream& str, const StaticDimension& dimension);
} // namespace ov

View File

@ -0,0 +1,160 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "static_shape.hpp"
using namespace ov;
ov::StaticShape::StaticShape(std::vector<StaticDimension> dimensions)
: std::vector<StaticDimension>(std::move(dimensions)) {}
ov::StaticShape::StaticShape(const std::vector<StaticDimension::value_type>& dimensions)
: std::vector<StaticDimension>(dimensions.begin(), dimensions.end()) {}
ov::StaticShape::StaticShape(std::initializer_list<StaticDimension> init)
: std::vector<StaticDimension>(init.begin(), init.end()) {}
ov::Shape ov::StaticShape::get_max_shape() const {
return (*this).to_shape();
}
ov::Shape ov::StaticShape::get_min_shape() const {
return (*this).to_shape();
}
ov::Shape ov::StaticShape::get_shape() const {
return (*this).to_shape();
}
ov::StaticShape ov::operator+(const StaticShape& s1, const StaticShape& s2) {
if (s1.size() != s2.size()) {
throw std::invalid_argument("rank mismatch");
}
std::vector<StaticDimension> result(s1.size());
for (size_t i = 0; i < s1.size(); ++i)
result[i] = (s1[i] + s2[i]);
return result;
}
std::ostream& ov::operator<<(std::ostream& str, const StaticShape& shape) {
str << "{";
bool first = true;
for (const auto& d : shape) {
if (!first) str << ",";
str << d;
first = false;
}
return (str << "}");
}
bool ov::StaticShape::compatible(const StaticShape& s) const {
if (size() != s.size())
return false;
for (size_t i = 0; i < size(); ++i)
if (!((*this)[i]).compatible(s[i]))
return false;
return true;
}
bool ov::StaticShape::same_scheme(const StaticShape& s) const {
if (size() != s.size())
return false;
for (size_t i = 0; i < size(); ++i)
if (!((*this)[i]).same_scheme(s[i]))
return false;
return true;
}
bool ov::StaticShape::merge_rank(Rank r) {
if (r.is_dynamic()) {
return true;
} else {
return (static_cast<int64_t>(size()) == r.get_length());
}
}
ov::Shape ov::StaticShape::to_shape() const {
std::vector<size_t> shape_dimensions(size());
std::transform(begin(), end(), shape_dimensions.begin(), [](const StaticDimension& d) {
return d.get_length();
});
return shape_dimensions;
}
bool ov::StaticShape::merge_into(StaticShape& dst, const StaticShape& src) {
if (dst.size() != src.size())
return false;
bool success = true;
for (size_t i = 0; i < dst.size(); ++i)
success &= StaticDimension::merge(dst[i], dst[i], src[i]);
return success;
}
bool ov::StaticShape::broadcast_merge_into(StaticShape& dst,
const StaticShape& src,
const ngraph::op::AutoBroadcastSpec& autob) {
switch (autob.m_type) {
case ngraph::op::AutoBroadcastType::NONE:
return true;
case ngraph::op::AutoBroadcastType::NUMPY: {
auto dst_rank = dst.size();
auto src_rank = src.size();
auto new_rank = std::max(dst_rank, src_rank);
std::vector<StaticDimension> dims(new_rank);
bool success = true;
for (int64_t i = 0; i < new_rank; i++) {
auto dsti = i < (new_rank - dst_rank) ? StaticDimension(1) : dst[i - (new_rank - dst_rank)];
auto srci = i < (new_rank - src_rank) ? StaticDimension(1) : src[i - (new_rank - src_rank)];
success &= StaticDimension::broadcast_merge(dims[i], dsti, srci);
}
dst = StaticShape(std::move(dims));
return success;
}
case ngraph::op::AutoBroadcastType::PDPD: {
// Ranks are both static.
auto dst_rank = dst.rank().get_length();
auto src_rank = src.rank().get_length();
if (dst_rank == src_rank && dst.compatible(src))
return true;
int64_t axis = autob.m_axis;
if (axis < -1) {
return false;
}
if (axis == -1) {
axis = dst_rank - src_rank;
}
size_t len = src_rank;
while (len > 0 && src[len - 1].is_static() && src[len - 1].get_length() == 1) {
--len;
}
for (size_t i = axis; i < axis + len; ++i) {
if (!(dst[i].compatible(src[i - axis]))) {
return false;
}
}
return true;
}
default:
NGRAPH_CHECK(false, "Unsupported auto broadcast type: ", autob.m_type);
}
return false;
}
bool ov::StaticShape::operator==(const StaticShape& shape) const {
if (size() != shape.size())
return false;
for (auto i = 0; i < size(); ++i)
if ((*this)[i] != shape[i])
return false;
return true;
}
bool ov::StaticShape::operator!=(const StaticShape& partial_shape) const {
return !(*this == partial_shape);
}

View File

@ -0,0 +1,57 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <cstddef>
#include "ngraph/op/util/attr_types.hpp"
#include "openvino/core/attribute_adapter.hpp"
#include "static_dimension.hpp"
#include "openvino/core/rank.hpp"
#include "openvino/core/shape.hpp"
namespace ov {
namespace op {
struct AutoBroadcastSpec;
}
/// \brief Class representing a shape that must be totally static.
class StaticShape : public std::vector<StaticDimension> {
public:
StaticShape(std::initializer_list<StaticDimension> init);
StaticShape(const std::vector<StaticDimension::value_type>& dimensions);
StaticShape(std::vector<StaticDimension> dimensions);
static bool is_static() { return true; }
static bool is_dynamic() { return false; }
Rank rank() const { return Rank(size()); }
bool compatible(const StaticShape& s) const;
bool same_scheme(const StaticShape& s) const;
bool refines(const StaticShape& s) const;
bool merge_rank(Rank r);
Shape to_shape() const;
friend std::ostream& operator<<(std::ostream& str, const StaticShape& shape);
friend StaticShape operator+(const StaticShape& s1, const StaticShape& s2);
bool operator==(const StaticShape& shape) const;
bool operator!=(const StaticShape& shape) const;
/// Get the max bounding shape
Shape get_max_shape() const;
/// Get the min bounding shape
Shape get_min_shape() const;
/// Get the unique shape
Shape get_shape() const;
static bool merge_into(StaticShape& dst, const StaticShape& src);
static bool broadcast_merge_into(StaticShape& dst,
const StaticShape& src,
const ngraph::op::AutoBroadcastSpec& autob);
};
StaticShape operator+(const StaticShape& s1, const StaticShape& s2);
std::ostream& operator<<(std::ostream& str, const StaticShape& shape);
} // namespace ov

View File

@ -18,6 +18,7 @@ addIeTargetTest(
mkldnn
inference_engine_transformations
inference_engine_lp_transformations
ov_shape_inference
inference_engine_s
ADD_CPPLINT
LABELS

View File

@ -0,0 +1,85 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <openvino/core/coordinate_diff.hpp>
#include <openvino/op/convolution.hpp>
#include <openvino/op/parameter.hpp>
#include <convolution_shape_inference.hpp>
#include <openvino/op/ops.hpp>
#include "utils/shape_inference/static_shape.hpp"
using namespace ov;
TEST(StaticShapeInferenceTest, ConvolutionTest) {
Strides strides{1, 1};
CoordinateDiff pads_begin{0, 0};
CoordinateDiff pads_end{0, 0};
Strides dilations{1, 1};
const auto auto_pad = op::PadType::SAME_LOWER;
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto filters = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto conv =
std::make_shared<op::v1::Convolution>(data, filters, strides, pads_begin, pads_end, dilations, auto_pad);
std::vector<PartialShape> input_shapes = {PartialShape{3, 6, 5, 5}, PartialShape{7, 6, 3, 3}}, output_shapes = {PartialShape{}};
shape_infer(conv.get(), input_shapes, output_shapes);
ASSERT_EQ(output_shapes[0], PartialShape({3, 7, 5, 5}));
ASSERT_EQ(conv->get_pads_begin(), (CoordinateDiff{1, 1}));
ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{1, 1}));
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 6, 5, 5}, StaticShape{7, 6, 3, 3}}, static_output_shapes = {StaticShape{}};
shape_infer(conv.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 7, 5, 5}));
ASSERT_EQ(conv->get_pads_begin(), (CoordinateDiff{1, 1}));
ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{1, 1}));
}
#if 0
TEST(StaticShapeInferenceTest, ConvolutionTimeTest) {
Strides strides{1, 1};
CoordinateDiff pads_begin{0, 0};
CoordinateDiff pads_end{0, 0};
Strides dilations{1, 1};
const auto auto_pad = op::PadType::SAME_LOWER;
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{3, 6, 5, 5});
auto filters = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{7, 6, 3, 3});
auto conv =
std::make_shared<op::v1::Convolution>(data, filters, strides, pads_begin, pads_end, dilations, auto_pad);
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 6, 5, 5}, StaticShape{7, 6, 3, 3}}, static_output_shapes = {StaticShape{}};
auto before = std::chrono::high_resolution_clock::now();
auto after = std::chrono::high_resolution_clock::now();
std::cout << conv << std::endl;
auto convolution_time_sum = 0;
for (size_t i = 0; i < 10; ++i) {
before = std::chrono::high_resolution_clock::now();
shape_infer(conv.get(), static_input_shapes, static_output_shapes);
after = std::chrono::high_resolution_clock::now();
auto diff = std::chrono::duration_cast<std::chrono::nanoseconds>(after - before).count();
std::cout << diff << " ns" << std::endl;
convolution_time_sum += diff;
}
// other operation creation and time measurements: ReLU is an example
auto relu = std::make_shared<op::v0::Relu>(data);
std::cout << relu << std::endl;
auto other_op_time_sum = 0;
for (size_t i = 0; i < 10; ++i) {
before = std::chrono::high_resolution_clock::now();
relu->validate_and_infer_types();
after = std::chrono::high_resolution_clock::now();
auto diff = std::chrono::duration_cast<std::chrono::nanoseconds>(after - before).count();
std::cout << diff << " ns" << std::endl;
other_op_time_sum += diff;
}
std::cout << (convolution_time_sum >= other_op_time_sum ? "ON PAR WITH CONVOLUTION: " : "LONGER THAN CONVOLUTION ")
<< 1. * other_op_time_sum / convolution_time_sum << std::endl;
}
#endif

View File

@ -10,6 +10,7 @@ file(GLOB_RECURSE PUBLIC_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp)
add_subdirectory(builder)
add_subdirectory(reference)
add_subdirectory(shape_inference)
# Create named folders for the sources within the .vcproj
# Empty name lists them directly under the .vcproj
@ -39,7 +40,7 @@ endif()
addVersionDefines(src/version.cpp CI_BUILD_NUMBER)
target_link_libraries(ngraph PRIVATE ngraph::builder ngraph::reference openvino::util)
target_link_libraries(ngraph PRIVATE ngraph::builder ngraph::reference openvino::util ov_shape_inference)
ie_mark_target_as_cc(ngraph)

View File

@ -293,6 +293,18 @@ public:
return m_dimensions.crend();
}
/// \brief Resizes dimensions container to contain count elements
void resize(size_t count) {
m_dimensions.resize(count);
m_rank_is_static = true;
m_shape_type = ShapeType::SHAPE_IS_UPDATED;
}
/// \brief Returns size of dimension vector. Requires rank to be static
size_t size() const {
OPENVINO_ASSERT(rank().is_static());
return m_dimensions.size();
}
private:
// Private constructor for PartialShape::dynamic().
PartialShape(bool rank_is_static, std::vector<Dimension> dimensions);

View File

@ -98,6 +98,11 @@ protected:
CoordinateDiff m_pads_begin;
CoordinateDiff m_pads_end;
PadType m_auto_pad;
int64_t m_num_spatial = -1;
private:
template <class T>
friend void shape_infer(Convolution* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes);
};
/// \brief Data batch backprop for batched convolution operation.

View File

@ -0,0 +1,23 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
set(TARGET_NAME "ov_shape_inference")
file(GLOB_RECURSE PUBLIC_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp)
set(SHAPE_INFER_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/include")
# Create named folders for the sources within the .vcproj
# Empty name lists them directly under the .vcproj
source_group("include" FILES ${PUBLIC_HEADERS})
# Create shared library
add_library(${TARGET_NAME} INTERFACE)
# Defines macro in C++ to load backend plugin
target_include_directories(${TARGET_NAME} INTERFACE ${SHAPE_INFER_INCLUDE_DIR} ${NGRAPH_INCLUDE_PATH})
# developer package
openvino_developer_export_targets(COMPONENT ngraph TARGETS ${TARGET_NAME})

View File

@ -0,0 +1,172 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/op/convolution.hpp>
namespace ov {
namespace op {
namespace v1 {
template<class T>
void shape_infer(Convolution* op, const std::vector<T> &input_shapes, std::vector<T> &output_shapes) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 && output_shapes.size() == 1);
auto input_shape = input_shapes[0], filters_shape = input_shapes[1];
auto& dilations = op->m_dilations;
auto& strides = op->m_strides;
auto& num_spatial = op->m_num_spatial;
auto& pad_begin = op->m_pads_begin, &pad_end = op->m_pads_end;
const auto& auto_pad = op->m_auto_pad;
calculate_num_spatial_dims_and_update_attributes(op, input_shape, filters_shape, dilations,
strides, pad_begin, pad_end, auto_pad, num_spatial);
if (num_spatial < 1)
return;
// ranks are originally static or aligned with num_spatials, attributes are valid
auto& output_shape = output_shapes[0];
output_shape.resize(num_spatial + 2);
output_shape[0] = input_shape[0];
output_shape[1] = filters_shape[0];
NODE_VALIDATION_CHECK(
op,
input_shape[1].is_dynamic() || filters_shape[1].is_dynamic() || input_shape[1] == filters_shape[1],
"Data batch channel count (",
input_shape[1],
") does not match filter input ",
"channel count (",
filters_shape[1],
").");
for (int64_t i = 0; i < num_spatial; ++i) {
const auto& input_dim = input_shape[i + 2];
const auto& filters_dim = filters_shape[i + 2];
if (input_dim.is_static() && filters_dim.is_static()) {
const int64_t& window_dilated_dim = (filters_dim.get_length() - 1) * dilations[i] + 1;
NODE_VALIDATION_CHECK(op,
window_dilated_dim > 0,
"Window after dilation has dimension less than 1 (dim: ",
window_dilated_dim,
") at axis ",
i,
".");
if (auto_pad == op::PadType::SAME_UPPER || auto_pad == op::PadType::SAME_LOWER) {
const int64_t& image_size = input_dim.get_length();
const int64_t& filter_stride = strides[i];
const int64_t& output_size = (image_size + filter_stride - 1) / filter_stride;
const int64_t& tmp = (output_size - 1) * filter_stride + window_dilated_dim;
const int64_t& padding_needed = tmp > image_size ? tmp - image_size : 0;
const size_t& padding_lhs = static_cast<size_t>(padding_needed / 2);
const size_t& padding_rhs = static_cast<size_t>(padding_needed - padding_lhs);
pad_begin[i] = auto_pad == op::PadType::SAME_UPPER ? padding_lhs : padding_rhs;
pad_end[i] = auto_pad == op::PadType::SAME_UPPER ? padding_rhs : padding_lhs;
}
const int64_t& data_padded_dilated_dim = input_dim.get_length() + pad_begin[i] + pad_end[i];
NODE_VALIDATION_CHECK(op,
window_dilated_dim <= data_padded_dilated_dim,
"Window after dilation has dimension (dim: ",
window_dilated_dim,
") larger than the data shape after padding (dim: ",
data_padded_dilated_dim,
") at axis ",
i,
".");
output_shape[i + 2] = (data_padded_dilated_dim - window_dilated_dim) / strides[i] + 1;
}
}
}
template <class ShapeType>
void calculate_num_spatial_dims_and_update_attributes(Convolution* op,
ShapeType& input_shape,
ShapeType& filters_shape,
Strides& dilations,
Strides& strides,
CoordinateDiff& pad_begin,
CoordinateDiff& pad_end,
const op::PadType& auto_pad,
int64_t& num_spatial) {
const auto &input_rank = input_shape.rank();
const auto &filters_rank = filters_shape.rank();
if (num_spatial == -1) {
if (const auto &size = dilations.size())
num_spatial = static_cast<int64_t>(size);
if (const auto &size = strides.size())
num_spatial = static_cast<int64_t>(size);
if (const auto &size = pad_begin.size())
num_spatial = static_cast<int64_t>(size);
if (const auto &size = pad_end.size())
num_spatial = static_cast<int64_t>(size);
if (input_rank.is_static())
num_spatial = input_rank.get_length() - 2;
if (filters_rank.is_static())
num_spatial = filters_rank.get_length() - 2;
if (num_spatial == -1)
return; // can not deduce output rank
if (strides.empty())
strides = Strides(num_spatial, 1);
if (dilations.empty())
dilations = Strides(num_spatial, 1);
if (pad_begin.empty() || auto_pad == op::PadType::VALID)
pad_begin = CoordinateDiff(num_spatial, 0);
if (pad_end.empty() || auto_pad == op::PadType::VALID)
pad_end = CoordinateDiff(num_spatial, 0);
NODE_VALIDATION_CHECK(op,
static_cast<int64_t>(strides.size()) == num_spatial,
"Strides should be defined for all and only spatial features.");
NODE_VALIDATION_CHECK(op,
static_cast<int64_t>(dilations.size()) == num_spatial,
"Dilations should be defined for all and only spatial features.");
NODE_VALIDATION_CHECK(op,
static_cast<int64_t>(pad_begin.size()) == num_spatial &&
static_cast<int64_t>(pad_end.size()) == num_spatial,
"Pads should be defined for all and only spatial features.");
NODE_VALIDATION_CHECK(op,
std::all_of(dilations.begin(),
dilations.end(),
[](const size_t &i) {
return i > 0;
}),
"Filter dilation (",
dilations,
") has zero dimension.");
NODE_VALIDATION_CHECK(op,
std::all_of(strides.begin(),
strides.end(),
[](const size_t &i) {
return i > 0;
}),
"Filter strides (",
strides,
") has zero dimension.");
}
if (input_rank.is_dynamic())
input_shape.resize(num_spatial + 2);
if (filters_rank.is_dynamic())
filters_shape.resize(num_spatial + 2);
NODE_VALIDATION_CHECK(op,
(static_cast<int64_t>(input_shape.size()) == (num_spatial + 2)) &&
(static_cast<int64_t>(filters_shape.size()) == (num_spatial + 2)),
"Data batch and filters rank do not match (data batch shape: ",
input_shape,
", filters shape: ",
filters_shape,
").");
}
}
}
}

View File

@ -4,6 +4,8 @@
#include "ngraph/op/convolution.hpp"
#include <convolution_shape_inference.hpp>
#include "itt.hpp"
#include "ngraph/axis_vector.hpp"
#include "ngraph/coordinate_diff.hpp"
@ -45,9 +47,7 @@ bool op::v1::Convolution::visit_attributes(AttributeVisitor& visitor) {
void op::v1::Convolution::validate_and_infer_types() {
NGRAPH_OP_SCOPE(v1_Convolution_validate_and_infer_types);
const ov::PartialShape& data_batch_pshape = get_input_partial_shape(0);
element::Type data_batch_et = get_input_element_type(0);
const ov::PartialShape& filters_pshape = get_input_partial_shape(1);
element::Type filters_et = get_input_element_type(1);
element::Type result_et;
@ -64,24 +64,10 @@ void op::v1::Convolution::validate_and_infer_types() {
"Element types must be numeric. Got: ",
result_et);
Rank result_ps_rank;
NODE_VALIDATION_CHECK(this,
Rank::merge(result_ps_rank, data_batch_pshape.rank(), filters_pshape.rank()),
"Data batch and filters inputs must have same rank. Got: ",
data_batch_pshape,
" and ",
filters_pshape);
ov::PartialShape result_shape = validate_and_infer_convolution_forward_output_shape(this,
result_ps_rank,
data_batch_pshape,
filters_pshape,
m_auto_pad,
m_strides,
m_dilations,
m_pads_begin,
m_pads_end);
set_output_type(0, result_et, result_shape);
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0), get_input_partial_shape(1)};
shape_infer(this, input_shapes, output_shapes);
set_output_type(0, result_et, output_shapes[0]);
}
shared_ptr<Node> op::v1::Convolution::clone_with_new_inputs(const OutputVector& new_args) const {

View File

@ -45,7 +45,7 @@ TEST(type_prop, conv_v1_partial_auto_padding_same) {
auto conv =
make_shared<op::v1::Convolution>(data_batch, filters, strides, pads_begin, pads_end, dilations, auto_pad);
ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(PartialShape{1, 1, 5, 5}));
ASSERT_EQ(conv->get_output_partial_shape(0), (PartialShape{1, 1, 5, 5}));
ASSERT_EQ(conv->get_pads_begin(), (CoordinateDiff{1, 1}));
ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{1, 1}));
}
@ -65,7 +65,7 @@ TEST(type_prop, conv_v1_partial_auto_padding_same_nc_dims_dynamic_same_lower) {
auto conv =
make_shared<op::v1::Convolution>(data_batch, filters, strides, pads_begin, pads_end, dilations, auto_pad);
ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme({Dimension::dynamic(), 1, 5, 5}));
ASSERT_EQ(conv->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), 1, 5, 5}));
ASSERT_EQ(conv->get_pads_begin(), (CoordinateDiff{1, 1}));
ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{1, 1}));
}
@ -85,7 +85,7 @@ TEST(type_prop, conv_v1_partial_auto_padding_same_nc_dims_dynamic_same_upper) {
auto conv =
make_shared<op::v1::Convolution>(data_batch, filters, strides, pads_begin, pads_end, dilations, auto_pad);
ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme({Dimension::dynamic(), 1, 5, 5}));
ASSERT_EQ(conv->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), 1, 5, 5}));
ASSERT_EQ(conv->get_pads_begin(), (CoordinateDiff{0, 0}));
ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{1, 1}));
}
@ -105,7 +105,7 @@ TEST(type_prop, conv_v1_partial_auto_padding_same_spatial_dims_dynamic) {
auto conv =
make_shared<op::v1::Convolution>(data_batch, filters, strides, pads_begin, pads_end, dilations, auto_pad);
ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme({1, 1, Dimension::dynamic(), 5}));
ASSERT_EQ(conv->get_output_partial_shape(0), PartialShape({1, 1, Dimension::dynamic(), 5}));
ASSERT_EQ(conv->get_pads_begin(), (CoordinateDiff{0, 1}));
ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{0, 1}));
}
@ -125,8 +125,8 @@ TEST(type_prop, conv_v1_partial_data_shape_dynamic) {
auto conv =
make_shared<op::v1::Convolution>(data_batch, filters, strides, pads_begin, pads_end, dilations, auto_pad);
ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(
{Dimension::dynamic(), 1, Dimension::dynamic(), Dimension::dynamic()}));
ASSERT_EQ(conv->get_pads_begin(), (CoordinateDiff{}));
ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{}));
ASSERT_EQ(conv->get_output_partial_shape(0),
PartialShape({Dimension::dynamic(), 1, Dimension::dynamic(), Dimension::dynamic()}));
ASSERT_EQ(conv->get_pads_begin(), (CoordinateDiff{0, 0}));
ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{0, 0}));
}