From 978121e91ed1a974b4576e78b7fea89186f39438 Mon Sep 17 00:00:00 2001 From: Szymon Durawa Date: Mon, 14 Jun 2021 13:28:32 +0200 Subject: [PATCH] Ref implementation transposed convolution revise (#4999) --- .../ngraph/runtime/reference/convolution.hpp | 23 +- .../reference/convolution_backprop_data.hpp | 603 ++++---- .../group_convolution_backprop_data.hpp | 25 +- ngraph/test/CMakeLists.txt | 1 + .../test/backend/convolution_backprop.in.cpp | 1221 +++++++++++++++++ .../runtime/interpreter/evaluates_map.cpp | 3 +- 6 files changed, 1622 insertions(+), 254 deletions(-) create mode 100644 ngraph/test/backend/convolution_backprop.in.cpp diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/convolution.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/convolution.hpp index e5e12bac9ac..adee512d975 100644 --- a/ngraph/core/reference/include/ngraph/runtime/reference/convolution.hpp +++ b/ngraph/core/reference/include/ngraph/runtime/reference/convolution.hpp @@ -18,8 +18,6 @@ #include "ngraph/runtime/reference/split.hpp" #include "ngraph/util.hpp" -// can't be removed currently due to arm-plugin dependency -#include "ngraph/runtime/reference/convolution_backprop_data.hpp" namespace ngraph { namespace runtime @@ -42,15 +40,18 @@ namespace ngraph std::vector dilation; std::vector pads_begin; std::vector pads_end; + std::vector output_padding; ConvolutionParams(const Strides& strides_, const Strides& dilation_, const CoordinateDiff& pads_begin_, - const CoordinateDiff& pads_end_) + const CoordinateDiff& pads_end_, + const CoordinateDiff& output_padding_ = {0, 0, 0}) : strides{strides_.begin(), strides_.end()} , dilation{dilation_.begin(), dilation_.end()} , pads_begin{pads_begin_.begin(), pads_begin_.end()} - , pads_end{pads_end_.begin(), pads_end_.end()} {}; + , pads_end{pads_end_.begin(), pads_end_.end()} + , output_padding{output_padding_.begin(), output_padding_.end()} {}; }; template @@ -86,15 +87,18 @@ namespace ngraph const size_t filter_channel_size = shape_size(filter_channel_shape); for (int i_z = -p.pads_begin[0]; - i_z <= (p.pads_end[0] + input_size_z - dilated_filter_size_z); + i_z <= (p.pads_end[0] + input_size_z - dilated_filter_size_z + + p.output_padding[0]); i_z += p.strides[0]) { for (int i_y = -p.pads_begin[1]; - i_y <= (p.pads_end[1] + input_size_y - dilated_filter_size_y); + i_y <= (p.pads_end[1] + input_size_y - dilated_filter_size_y + + p.output_padding[1]); i_y += p.strides[1]) { for (int i_x = -p.pads_begin[2]; - i_x <= (p.pads_end[2] + input_size_x - dilated_filter_size_x); + i_x <= (p.pads_end[2] + input_size_x - dilated_filter_size_x + + p.output_padding[2]); i_x += p.strides[2]) { auto input_channel = batch; @@ -154,6 +158,8 @@ namespace ngraph std::prev(p.pads_begin.end(), spatial_rank), missing_dims, 0); p.pads_end.insert( std::prev(p.pads_end.end(), spatial_rank), missing_dims, 0); + p.output_padding.insert( + std::prev(p.output_padding.end(), spatial_rank), missing_dims, 0); in_shape.insert(std::next(in_shape.end(), -spatial_rank), missing_dims, 1); filter_shape.insert( std::prev(filter_shape.end(), spatial_rank), missing_dims, 1); @@ -324,3 +330,6 @@ namespace ngraph } // namespace reference } // namespace runtime } // namespace ngraph + +// can't be removed currently due to arm-plugin dependency +#include "ngraph/runtime/reference/convolution_backprop_data.hpp" diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/convolution_backprop_data.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/convolution_backprop_data.hpp index 3fa325f7726..1c755198163 100644 --- a/ngraph/core/reference/include/ngraph/runtime/reference/convolution_backprop_data.hpp +++ b/ngraph/core/reference/include/ngraph/runtime/reference/convolution_backprop_data.hpp @@ -10,11 +10,7 @@ #include #include "ngraph/axis_vector.hpp" -#include "ngraph/coordinate_transform.hpp" -#include "ngraph/runtime/reference/concat.hpp" -#include "ngraph/runtime/reference/helpers.hpp" -#include "ngraph/runtime/reference/reverse.hpp" -#include "ngraph/runtime/reference/split.hpp" +#include "ngraph/runtime/reference/convolution.hpp" #include "ngraph/util.hpp" namespace ngraph @@ -23,217 +19,302 @@ namespace ngraph { namespace reference { - // in: NC_I... - // filter: C_OC_I... - // out: NC_O... - template ::type> - void convolution_backprop_impl(const INPUT* in, - const FILTER* filter, - OUTPUT* out, - const Shape& in_shape, - const Shape& filter_shape, - const Shape& out_shape, - const Strides& stride, - const Strides& filter_dilation, - const CoordinateDiff& in_pad_below, - const CoordinateDiff& in_pad_above, - const Strides& in_dilation, - size_t in_batch_axis, - size_t in_channel_axis, - size_t filter_out_channel_axis, - size_t filter_in_channel_axis, - size_t out_batch_axis, - size_t out_channel_axis) + namespace { - auto old_mode = std::fegetround(); - std::fesetround(FE_TONEAREST); - // Comments throughout assume without loss of generality that: - // - // * batch axes for both in and out are 0 - // * in channel axes for both in and filter are 1 - // * out channel axes for filter is 0 - // * out channel axis for out is 1 + constexpr size_t filter_input_ch_axis = 0; - // At the outermost level we will walk over every out coordinate O. - CoordinateTransform out_transform(out_shape); - - for (const Coordinate& out_coord : out_transform) + template + void extend_with_zeros(const Strides& strides, + const Shape& input_shape, + const T* in, + Shape& output_shape, + std::vector& input_zeros) { - // Our out coordinate O will have the form: - // - // (N,chan_out,i_1,...,i_n) + std::vector input_3d(3, 1); + std::vector strides_3d(3, 1); + std::vector output_3d(3, 1); - size_t batch_index = out_coord[out_batch_axis]; - size_t out_channel = out_coord[out_channel_axis]; - - // For the in we need to iterate the coordinate: - // - // I: - // - // over the range (noninclusive on the right): - // - // (N,0,s_1*i_1,s_2*i_2,...,s_n*i_n) -> - // - // (N+1, - // chans_in_count, - // s_1*i_1+ l_1*filter_dims_1, - /// ..., - /// s_n*i_n +l_n*filter_dims_n) - // - // with strides: - // - // (1,l_1,...,l_n). - // - // Note that we are iterating within the *padded* and *dilated* in batch, so - // further down we must check the current coordinate is in the pad or dilation - // gap. - - size_t n_spatial_dimensions = in_shape.size() - 2; - size_t n_in_channels = in_shape[in_channel_axis]; - - Coordinate in_transform_start(2 + n_spatial_dimensions); - Coordinate in_transform_end(2 + n_spatial_dimensions); - Strides in_transform_movement_strides(2 + n_spatial_dimensions, 1); - CoordinateDiff in_transform_pad_below(2 + n_spatial_dimensions, 0); - CoordinateDiff in_transform_pad_above(2 + n_spatial_dimensions, 0); - Strides in_transform_dilation_strides(2 + n_spatial_dimensions, 1); - - in_transform_start[in_batch_axis] = batch_index; - in_transform_end[in_batch_axis] = batch_index + 1; - in_transform_start[in_channel_axis] = 0; - in_transform_end[in_channel_axis] = 1; - - for (size_t i = 2; i < n_spatial_dimensions + 2; i++) + for (size_t i = 0; i < strides.size(); ++i) { - size_t filter_dilation_stride = filter_dilation[i - 2]; - size_t filter_movement_stride = stride[i - 2]; - std::ptrdiff_t below_pad = in_pad_below[i - 2]; - std::ptrdiff_t above_pad = in_pad_above[i - 2]; - size_t in_dilation_stride = in_dilation[i - 2]; - - in_transform_start[i] = filter_movement_stride * out_coord[i]; - in_transform_end[i] = in_transform_start[i] + - (filter_shape[i] - 1) * filter_dilation_stride + 1; - in_transform_movement_strides[i] = filter_dilation_stride; - in_transform_pad_below[i] = below_pad; - in_transform_pad_above[i] = above_pad; - in_transform_dilation_strides[i] = in_dilation_stride; + output_shape[i + 2] = + input_shape[i + 2] + (strides[i] - 1) * (input_shape[i + 2] - 1); + input_3d[input_3d.size() - strides.size() + i] = input_shape[i + 2]; + strides_3d[strides_3d.size() - strides.size() + i] = strides[i]; + output_3d[output_3d.size() - strides.size() + i] = output_shape[i + 2]; } - AxisVector in_transform_axis_order(2 + n_spatial_dimensions); - for (size_t i = 0; i < in_transform_axis_order.size(); i++) + const size_t input_size = shape_size(input_3d); + if (input_size == 1) { - in_transform_axis_order[i] = i; - } - CoordinateTransform in_transform(in_shape, - in_transform_start, - in_transform_end, - in_transform_movement_strides, - in_transform_axis_order, - in_transform_pad_below, - in_transform_pad_above, - in_transform_dilation_strides); - - // Simultaneously with iterating I, for the filter we need to iterate the - // coordinate: - // - // F - // - // over the range (noninclusive on the right): - // - // (chan_out,0,0,...,0) -> - // (chan_out+1, - // chans_in_count, - // filter_dims_1, - // ..., - // filter_dims_n) - // - // with unit stride. - - Shape filter_transform_start(2 + n_spatial_dimensions); - Shape filter_transform_end(2 + n_spatial_dimensions); - - filter_transform_start[filter_out_channel_axis] = out_channel; - filter_transform_end[filter_out_channel_axis] = out_channel + 1; - filter_transform_start[filter_in_channel_axis] = 0; - filter_transform_end[filter_in_channel_axis] = 1; - - for (size_t i = 2; i < n_spatial_dimensions + 2; i++) - { - filter_transform_start[i] = 0; - filter_transform_end[i] = filter_shape[i]; - } - - CoordinateTransform filter_transform( - filter_shape, filter_transform_start, filter_transform_end); - - // As we go, we sum up: - // - // out[O] += in[I] * filter[F]. - - ACCUMULATION result = 0; - - CoordinateTransform::Iterator in_it = in_transform.begin(); - CoordinateTransform::Iterator filter_it = filter_transform.begin(); - CoordinateTransform::Iterator in_it_end = in_transform.end(); - CoordinateTransform::Iterator filter_it_end = filter_transform.end(); - - size_t in_channel_stride = row_major_strides(in_shape).at(in_channel_axis); - size_t filter_in_channel_stride = - row_major_strides(filter_shape).at(filter_in_channel_axis); - - while (in_it != in_it_end && filter_it != filter_it_end) - { - const Coordinate& in_coord = *in_it; - if (in_transform.has_source_coordinate(in_coord)) + for (size_t i = 0; i < shape_size(input_shape); ++i) { - size_t in_idx = in_transform.index(in_coord); - const Coordinate& filter_coord = *filter_it; - size_t filter_idx = filter_transform.index(filter_coord); - for (size_t in_channel = 0; in_channel < n_in_channels; ++in_channel) + input_zeros.push_back(in[i]); + } + } + else + { + for (size_t batch = 0; batch < input_shape[0]; ++batch) + { + const auto offset_batch = batch * input_size * input_shape[1]; + for (size_t channel = 0; channel < input_shape[1]; ++channel) { - ACCUMULATION in_v = static_cast(in[in_idx]); - ACCUMULATION f_v = static_cast(filter[filter_idx]); + const auto offset_channel = offset_batch + channel * input_size; + for (int i_z = 0; i_z < input_3d[0]; ++i_z) + { + const auto offset_i_z = i_z * input_3d[2] * input_3d[1]; + for (int i_y = 0; i_y < input_3d[1]; ++i_y) + { + const auto offset_i_y = i_y * input_3d[2]; + for (int i_x = 0; i_x < input_3d[2]; ++i_x) + { + input_zeros.push_back( + in[offset_channel + i_x + offset_i_y + offset_i_z]); - result += in_v * f_v; - in_idx += in_channel_stride; - filter_idx += filter_in_channel_stride; + if (i_x < input_3d[2] - 1) + { + for (int k = 0; k < strides_3d[2] - 1; k++) + { + input_zeros.push_back(0); + } + } + } + + if (i_y < input_3d[1] - 1) + { + const auto new_size = + output_3d[2] * (strides_3d[1] - 1); + input_zeros.insert(input_zeros.begin() + + input_zeros.size(), + new_size, + 0); + } + } + + if (i_z < input_3d[0] - 1) + { + const auto new_size = + output_3d[1] * output_3d[2] * (strides_3d[0] - 1); + input_zeros.insert( + input_zeros.begin() + input_zeros.size(), new_size, 0); + } + } } } - ++in_it; - ++filter_it; } - - out[out_transform.index(out_coord)] = result; } - std::fesetround(old_mode); + + void infer_forward_convbackprop_output_shape(const Shape& in_spatial_shape, + const Shape& f_spatial_shape, + const Shape& out_spatial_shape, + Shape& infer_spatial_shape, + const Strides& strides, + const Strides& dilations, + const CoordinateDiff& output_padding) + { + for (size_t idx = 0; idx < in_spatial_shape.size(); idx++) + { + int total_padding = strides[idx] * (in_spatial_shape[idx] - 1) + + dilations[idx] * (f_spatial_shape[idx] - 1) + 1 - + out_spatial_shape[idx] + output_padding[idx]; + size_t padded_dim = std::max(total_padding, 0); + size_t filter_dilated_dim = dilations[idx] * (f_spatial_shape[idx] - 1) + 1; + size_t out_spatial_dim = (in_spatial_shape[idx] - 1) * strides[idx] + + filter_dilated_dim - padded_dim + + output_padding[idx]; + infer_spatial_shape.push_back(out_spatial_dim); + } + } + + void validate_convolution_backprop_parameters(const Shape& in_shape, + const Shape& f_shape, + const Shape& out_shape, + const Strides& strides, + const Strides& dilations, + const CoordinateDiff& pads_begin, + const CoordinateDiff& pads_end, + const CoordinateDiff& output_padding) + { + // this implementation supports 1D, 2D and 3D convolutions + NGRAPH_CHECK(in_shape.size() >= 3 && in_shape.size() <= 5, + "Unsupported input rank: ", + in_shape); + + NGRAPH_CHECK(in_shape.size() == f_shape.size(), + "Incompatible input ranks: ", + in_shape.size(), + " and ", + f_shape.size()); + + NGRAPH_CHECK(in_shape[in_channel_axis] == f_shape[filter_input_ch_axis], + "Incompatible input channels in data batch and filters shapes: ", + in_shape[in_channel_axis], + " and ", + f_shape[filter_input_ch_axis]); + + NGRAPH_CHECK(in_shape.size() == out_shape.size(), + "Incompatible input and output ranks: ", + in_shape.size(), + " and ", + out_shape.size()); + + const auto spatial_dims = in_shape.size() - 2; + NGRAPH_CHECK(strides.size() == spatial_dims, + "Strides not definied for all and only spatial dimensions."); + + NGRAPH_CHECK(dilations.size() == spatial_dims, + "Dilations not defined for all and only spatial dimensions."); + + NGRAPH_CHECK((pads_begin.size() == pads_end.size()) && + (pads_begin.size() == spatial_dims), + "Pads not defined for all and only spatial dimensions."); + + NGRAPH_CHECK(!output_padding.empty() && output_padding.size() == spatial_dims, + "Output padding not defined for all and only spatial dimensions."); + + Shape out_spatial_shape{std::next(out_shape.begin(), 2), std::end(out_shape)}; + Shape infered_out_spatial_shape{}; + infer_forward_convbackprop_output_shape( + Shape{std::next(in_shape.begin(), 2), std::end(in_shape)}, + Shape{std::next(f_shape.begin(), 2), std::end(f_shape)}, + Shape{std::next(out_shape.begin(), 2), std::end(out_shape)}, + infered_out_spatial_shape, + strides, + dilations, + output_padding); + NGRAPH_CHECK(out_spatial_shape == infered_out_spatial_shape, + "Incorrect output shape provided"); + } + } // namespace + + template + void convolution_backprop_impl(const T* in, + const T* f, + T* out, + const Shape& in_shape, + const Shape& f_shape, + const Shape& out_shape, + const Strides& strides, + const Strides& dilation, + const CoordinateDiff& pads_begin, + const CoordinateDiff& pads_end, + const CoordinateDiff& output_padding) + + { + // here we are converting all param types to int's to avoid arithmetic issues + // (e.g signed + unsigned) in indexes calculation later + ConvolutionParams params{strides, dilation, pads_begin, pads_end, output_padding}; + + // here we are extending spatial dimensions to 3D, because we are going to use 3D + // convolution implementation to convolve also in 1D & 2D case + Shape input_shape{in_shape}; + Shape filters_shape{f_shape}; + if (in_shape.size() < 5) + { + extend_to_3D(params, input_shape, filters_shape); + } + + for (size_t i = 0; i < input_shape.size() - 2; ++i) + { + if (input_shape[i + 2] > 1 || filters_shape[i + 2] > 1) + { + params.pads_begin[i] = filters_shape[i + 2] - params.pads_begin[i] - 1; + params.pads_end[i] = filters_shape[i + 2] - params.pads_end[i] - 1; + } + else + { + params.pads_begin[i] = 0; + params.pads_end[i] = 0; + } + } + + // convert output shape to 3D, contains only dimensions + Shape out_shape_3d{out_shape.begin() + 2, out_shape.end()}; + + int out_shape_rank = out_shape.size() - 2; + if (out_shape_rank < 3) + { + int missing_dims = 3 - out_shape_rank; + out_shape_3d.insert( + std::prev(out_shape_3d.end(), out_shape_rank), missing_dims, 1); + } + + // modify params.pads_end when output_shape was provided in ctor in order to + // calculate expected number of output elements + for (size_t i = 0; i < out_shape_3d.size(); i++) + { + if (out_shape_3d[i] > 1) + { + // expected_dim = (in - 1)* strides + filter - 2*padding + out_padding + // strides is already applied (through 0's extension in input) + // padding = pads_begin + pads_end, formula below is using + // params.pad_begin/params.pads_end: + const size_t expected_dim = + out_shape_3d[i] - ((input_shape[i + 2] - 1) - filters_shape[i + 2] + + params.pads_begin[i] + params.pads_end[i] + 2 + + params.output_padding[i]); + params.pads_end[i] += expected_dim; + } + } + + const size_t filters_count = filters_shape[filter_out_ch_axis]; + const Shape filter_shape(++filters_shape.begin(), filters_shape.end()); + const size_t filter_size = shape_size(filter_shape); + + const size_t batches_count = input_shape[in_batch_axis]; + Shape batch_shape(++input_shape.begin(), input_shape.end()); + const size_t batch_size = shape_size(batch_shape); + + auto batch = in; + + for (size_t batch_idx = 0; batch_idx < batches_count; ++batch_idx) + { + auto filter = f; + for (size_t f_idx = 0; f_idx < filters_count; ++f_idx) + { + convolve_3D_channels(params, batch, batch_shape, filter, filter_shape, out); + filter += filter_size; + } + batch += batch_size; + } } - template ::type> - void convolution_backprop_in(const OUTPUT* delta_out, - const FILTER* filter, - INPUT* delta_in, - const Shape& out_shape, - const Shape& filter_shape, + template + void convolution_backprop_in(const T* delta_in, + const T* filter, + T* delta_out, const Shape& in_shape, + const Shape& filter_shape, + const Shape& out_shape, const Strides& in_dilation, const Strides& filter_dilation, const CoordinateDiff& forward_in_pad_bellow, const CoordinateDiff& forward_in_pad_above, - const Strides& stride) + const Strides& stride, + const CoordinateDiff& output_padding) { + std::vector extended_input; + std::vector extended_filter; + AxisSet reverse_axes; + + Shape conv_input_shape = in_shape; + Shape conv_filter_shape = filter_shape; + Strides conv_stride = stride; + Strides conv_filter_dilation = filter_dilation; + auto conv_input_data = delta_in; + + validate_convolution_backprop_parameters(in_shape, + filter_shape, + out_shape, + stride, + filter_dilation, + forward_in_pad_bellow, + forward_in_pad_above, + output_padding); + // Note that we only reverse the spatial dimensions here (loop // starts at 2) - std::vector reversed(shape_size(filter_shape)); - AxisSet reverse_axes; - size_t reverse_axes_start = 2; - for (size_t i = reverse_axes_start; i < filter_shape.size(); ++i) + std::vector reversed(shape_size(filter_shape)); + for (size_t i = 2; i < filter_shape.size(); ++i) { reverse_axes.insert(i); } @@ -242,55 +323,109 @@ namespace ngraph filter_shape, filter_shape, reverse_axes, - sizeof(FILTER)); - size_t filter_out_channel_axis = 1; - size_t filter_in_channel_axis = 0; + sizeof(T)); - // Compute backward pad out pad bellow - size_t spatial_dim_count = in_shape.size() - 2; + auto conv_filter_data = &reversed[0]; - CoordinateDiff backward_delta_out_pad_below; - backward_delta_out_pad_below.resize(spatial_dim_count); - - for (size_t i = 0; i < spatial_dim_count; i++) + // if channel number for output is > 1 then reverse layout of filter coefficients as + // it is required by convolve_3D_channels() function. + // Current layout: + // batch0_ch0|batch0_ch1|...|batch0_chN|...|batch1_ch0|batch1_ch1|...|batch1_chN|... + // Expected layout: + // batch0_ch0|batch1_ch0|...|batchN_ch0|...|batch0_ch1|batch1_ch1|...|batch1_chN|... + if (filter_shape[1] > 1) { - backward_delta_out_pad_below[i] = - (static_cast(filter_shape[i + 2]) - 1) * filter_dilation[i] - - forward_in_pad_bellow[i]; - } - // Compute backward pad out pad above - CoordinateDiff backward_delta_out_pad_above; - backward_delta_out_pad_above.resize(spatial_dim_count); + std::vector temp_reversed(reversed); + const Shape filter_dim_shape(filter_shape.begin() + 2, filter_shape.end()); + const size_t filter_size = shape_size(filter_dim_shape); - for (size_t i = 0; i < spatial_dim_count; i++) - { - backward_delta_out_pad_above[i] = - (static_cast(filter_shape[i + 2]) - 1) * filter_dilation[i] + - ((forward_in_pad_bellow[i] + ((in_shape[i + 2]) - 1) * in_dilation[i] + - forward_in_pad_above[i] - - (static_cast(filter_shape[i + 2]) - 1) * filter_dilation[i]) % - stride[i]) - - forward_in_pad_above[i]; + for (size_t i = 0; i < filter_shape[1]; i++) + { + for (size_t j = 0; j < filter_shape[0]; j++) + { + const auto delta = temp_reversed.begin() + + j * filter_shape[1] * filter_size + i * filter_size; + const auto out = reversed.begin() + i * filter_shape[0] * filter_size + + j * filter_size; + std::copy(delta, delta + filter_size, out); + } + } } - convolution_backprop_impl( - delta_out, - &reversed[0], - delta_in, - out_shape, - filter_shape, - in_shape, - in_dilation, - filter_dilation, - backward_delta_out_pad_below, - backward_delta_out_pad_above, - stride, - 0, - 1, - filter_out_channel_axis, - filter_in_channel_axis, - 0, - 1); + // swap filter batch and channels + std::iter_swap(conv_filter_shape.begin(), conv_filter_shape.begin() + 1); + + // extend stride and filter inputs with zero padding for stride and filter_dilation + // > 1, after that set stride and filter params to 1. + const size_t stride_dim = + std::accumulate(stride.begin(), stride.end(), 1, std::multiplies()); + if (stride_dim >= 2) + { + extend_with_zeros(stride, in_shape, delta_in, conv_input_shape, extended_input); + std::fill(conv_stride.begin(), conv_stride.end(), 1); + conv_input_data = &extended_input[0]; + } + + const size_t dilation_dim = std::accumulate( + filter_dilation.begin(), filter_dilation.end(), 1, std::multiplies()); + if (dilation_dim >= 2) + { + extend_with_zeros(filter_dilation, + filter_shape, + reinterpret_cast(&reversed[0]), + conv_filter_shape, + extended_filter); + std::fill(conv_filter_dilation.begin(), conv_filter_dilation.end(), 1); + conv_filter_data = &extended_filter[0]; + } + + convolution_backprop_impl(conv_input_data, + conv_filter_data, + delta_out, + conv_input_shape, + conv_filter_shape, + out_shape, + conv_stride, + conv_filter_dilation, + forward_in_pad_bellow, + forward_in_pad_above, + output_padding); + } + + // DEPRECATED, can't be removed currently due to arm-plugin dependency + template ::type> + NGRAPH_DEPRECATED( + "convolution_backprop_in function with 4 template types is deprecated, use " + "function with 1 template and output_padding parameter.") + void convolution_backprop_in(const INPUT* delta_in, + const FILTER* filter, + OUTPUT* delta_out, + const Shape& in_shape, + const Shape& filter_shape, + const Shape& out_shape, + const Strides& in_dilation, + const Strides& filter_dilation, + const CoordinateDiff& forward_in_pad_bellow, + const CoordinateDiff& forward_in_pad_above, + const Strides& stride) + { + const ngraph::CoordinateDiff output_padding(in_shape.size() - 2, 0); + + convolution_backprop_in(delta_in, + filter, + delta_out, + in_shape, + filter_shape, + out_shape, + in_dilation, + filter_dilation, + forward_in_pad_bellow, + forward_in_pad_above, + stride, + output_padding); } } // namespace reference } // namespace runtime diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/group_convolution_backprop_data.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/group_convolution_backprop_data.hpp index b70c0d3ed9a..306ddf047bf 100644 --- a/ngraph/core/reference/include/ngraph/runtime/reference/group_convolution_backprop_data.hpp +++ b/ngraph/core/reference/include/ngraph/runtime/reference/group_convolution_backprop_data.hpp @@ -178,23 +178,24 @@ namespace ngraph const size_t group_out_size = shape_size(group_out_shape); Strides in_dilation(in_shape.size(), 1); + const ngraph::CoordinateDiff output_padding(in_shape.size() - 2, 0); for (size_t batch_idx = 0; batch_idx < in_shape[in_batch_axis]; ++batch_idx) { group_filter = f; for (size_t group_idx = 0; group_idx < group_count; ++group_idx) { - runtime::reference::convolution_backprop_in( - group_batch, - group_filter, - group_out, - group_batch_shape, - group_filter_shape, - group_out_shape, - in_dilation, - dilation, - pads_begin, - pads_end, - strides); + runtime::reference::convolution_backprop_in(group_batch, + group_filter, + group_out, + group_batch_shape, + group_filter_shape, + group_out_shape, + in_dilation, + dilation, + pads_begin, + pads_end, + strides, + output_padding); group_batch += group_batch_size; group_filter += group_filter_size; group_out += group_out_size; diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index 5ca3049b6e2..c3c5ab80405 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -356,6 +356,7 @@ set(MULTI_TEST_SRC backend/constant.in.cpp backend/convert.in.cpp backend/convert_like.in.cpp + backend/convolution_backprop.in.cpp backend/convolution.in.cpp backend/binary_convolution.in.cpp backend/clamp.in.cpp diff --git a/ngraph/test/backend/convolution_backprop.in.cpp b/ngraph/test/backend/convolution_backprop.in.cpp new file mode 100644 index 00000000000..04269995956 --- /dev/null +++ b/ngraph/test/backend/convolution_backprop.in.cpp @@ -0,0 +1,1221 @@ +//***************************************************************************** +// Copyright 2017-2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include "gtest/gtest.h" +#include "ngraph/ngraph.hpp" +#include "ngraph/runtime/tensor.hpp" +#include "runtime/backend.hpp" +#include "util/all_close.hpp" +#include "util/all_close_f.hpp" +#include "util/engine/test_engines.hpp" +#include "util/known_element_types.hpp" +#include "util/ndarray.hpp" +#include "util/test_case.hpp" +#include "util/test_control.hpp" +#include "util/test_tools.hpp" + +using namespace std; +using namespace ngraph; + +static string s_manifest = "${MANIFEST}"; +using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME}); + +static void ConvolutionBackpropTest(const std::vector& inputs, + const Shape inputs_shape, + const std::vector& filters, + const Shape filter_shape, + const std::vector& outputs, + const Shape outputs_shape, + const Strides& strides, + const CoordinateDiff& padding, + const Strides& dilations, + const CoordinateDiff& output_padding) +{ + const CoordinateDiff pads_begin{padding}; + const CoordinateDiff pads_end{padding}; + const op::PadType auto_pad{op::PadType::EXPLICIT}; + const CoordinateDiff out_padding{output_padding}; + + auto inputs_param = make_shared(element::f32, inputs_shape); + auto filters_param = make_shared(element::f32, filter_shape); + auto conv = make_shared(inputs_param, + filters_param, + strides, + pads_begin, + pads_end, + dilations, + auto_pad, + out_padding); + auto f = make_shared(conv, ParameterVector{inputs_param, filters_param}); + + auto test_case = test::TestCase(f); + test_case.add_input(inputs); + test_case.add_input(filters); + test_case.add_expected_output(outputs_shape, outputs); + test_case.run(); +} + +// --------------------- 1D convolution ------------------------------------------ +// clang-format off +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_1D_1batch_1channel) +{ + const Strides strides{1}; + const CoordinateDiff padding{0}; + const Strides dilations{1}; + const CoordinateDiff output_padding{0}; + + const Shape inputs_shape{1, 1, 4}; + const std::vector inputs{5.0f, 6.0f, 7.0f, 2.0f}; + + const Shape filter_shape{1, 1, 3}; + const std::vector filters{2.0f, 0.0f, 1.0f}; + + const Shape outputs_shape{1, 1, 6}; + const std::vector outputs{10.0f, 12.0f, 19.0f, 10.0f, 7.0f, 2.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_1D_1batch_1channel_padding) +{ + const Strides strides{1}; + const CoordinateDiff padding{1}; + const Strides dilations{1}; + const CoordinateDiff output_padding{0}; + + const Shape inputs_shape{1, 1, 4}; + const std::vector inputs{5.0f, 6.0f, 7.0f, 2.0f}; + + const Shape filter_shape{1, 1, 3}; + const std::vector filters{2.0f, 0.0f, 1.0f}; + + const Shape outputs_shape{1, 1, 4}; + const std::vector outputs{12.0f, 19.0f, 10.0f, 7.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_1D_1batch_1channel_stride) +{ + const Strides strides{2}; + const CoordinateDiff padding{0}; + const Strides dilations{1}; + const CoordinateDiff output_padding{0}; + + const Shape inputs_shape{1, 1, 2}; + const std::vector inputs{5.0f, 7.0f}; + + const Shape filter_shape{1, 1, 3}; + const std::vector filters{2.0f, 0.0f, 1.0f}; + + const Shape outputs_shape{1, 1, 5}; + const std::vector outputs{10.0f, 0.0f, 19.0f, 0.0f, 7.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_1D_1batch_1channel_output_padding) +{ + const Strides strides{1}; + const CoordinateDiff padding{1}; + const Strides dilations{1}; + const CoordinateDiff output_padding{1}; + + const Shape inputs_shape{1, 1, 4}; + const std::vector inputs{5.0f, 6.0f, 7.0f, 2.0f}; + + const Shape filter_shape{1, 1, 3}; + const std::vector filters{2.0f, 0.0f, 1.0f}; + + const Shape outputs_shape{1, 1, 5}; + const std::vector outputs{12.0f, 19.0f, 10.0f, 7.0f, 2.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_1D_1batch_1channel_dilation) +{ + const Strides strides{1}; + const CoordinateDiff padding{0}; + const Strides dilations{2}; + const CoordinateDiff output_padding{0}; + + const Shape inputs_shape{1, 1, 3}; + const std::vector inputs{8.0f, 5.0f, 1.0f}; + + const Shape filter_shape{1, 1, 3}; + const std::vector filters{2.0f, 0.0f, 1.0f}; + + const Shape outputs_shape{1, 1, 7}; + const std::vector outputs{16.0f, 10.0f, 2.0f, 0.0f, 8.0f, 5.0f, 1.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_1D_1batch_1channel_padding_stride_dilation) +{ + const Strides strides{2}; + const CoordinateDiff padding{2}; + const Strides dilations{2}; + const CoordinateDiff output_padding{0}; + + const Shape inputs_shape{1, 1, 4}; + const std::vector inputs{3.0f, 9.0f, 1.0f, 2.0f}; + + const Shape filter_shape{1, 1, 3}; + const std::vector filters{2.0f, 0.0f, 1.0f}; + + const Shape outputs_shape{1, 1, 7}; + const std::vector outputs{18.0f, 0.0f, 5.0f, 0.0f, 13.0f, 0.0f, 1.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_1D_1batch_2channel) +{ + const Strides strides{1}; + const CoordinateDiff padding{0}; + const Strides dilations{1}; + const CoordinateDiff output_padding{0}; + + const Shape inputs_shape{1, 1, 2}; + const std::vector inputs{10.0f, 3.0f}; + + const Shape filter_shape{1, 2, 3}; + const std::vector filters{ + // channel 1 + 2.0f, 0.0f, 1.0f, + // channel 2 + 1.0f, 0.0f, 2.0f}; + + const Shape outputs_shape{1, 2, 4}; + const std::vector outputs{ + // channel 1 + 20.0f, 6.0f, 10.0f, 3.0f, + // channel 2 + 10.0f, 3.0f, 20.0f, 6.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_1D_1batch_2filter) +{ + const Strides strides{1}; + const CoordinateDiff padding{0}; + const Strides dilations{1}; + const CoordinateDiff output_padding{0}; + + const Shape inputs_shape{1, 2, 2}; + const std::vector inputs{ + // channel 1 + 4.0f, 7.0f, + // channel 2 + 5.0f, 5.0f}; + + const Shape filter_shape{2, 1, 3}; + const std::vector filters{ + // filter 1 + 2.0f, 0.0f, 1.0f, + // filter 2 + 1.0f, 0.0f, 2.0f}; + + const Shape outputs_shape{1, 1, 4}; + const std::vector outputs{13.0f, 19.0f, 14.0f, 17.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_1D_2batch_1channel) +{ + const Strides strides{1}; + const CoordinateDiff padding{0}; + const Strides dilations{1}; + const CoordinateDiff output_padding{0}; + + const Shape inputs_shape{2, 1, 2}; + const std::vector inputs{ + // batch 1 + 1.0f, 3.0f, + // batch 2 + 2.0f, 2.0f}; + + const Shape filter_shape{1, 1, 3}; + const std::vector filters{2.0f, 0.0f, 1.0f}; + + const Shape outputs_shape{2, 1, 4}; + const std::vector outputs{ + // batch 1 + 2.0f, 6.0f, 1.0f, 3.0f, + // batch 2 + 4.0f, 4.0f, 2.0f, 2.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +// --------------------- 2D convolution ------------------------------------------ +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_2D_1batch_1channel) +{ + const Strides strides{1, 1}; + const CoordinateDiff padding{0, 0}; + const Strides dilations{1, 1}; + const CoordinateDiff output_padding{0, 0}; + + const Shape inputs_shape{1, 1, 2, 2}; + const std::vector inputs{1.0f, 3.0f, + 7.0f, 5.0f}; + + const Shape filter_shape{1, 1, 3, 3}; + const std::vector filters{1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 3.0f, 2.0f, 1.0f}; + + const Shape outputs_shape{1, 1, 4, 4}; + const std::vector outputs{1.0f, 5.0f, 9.0f, 9.0f, + 7.0f, 20.0f, 34.0f, 15.0f, + 3.0f, 18.0f, 12.0f, 3.0f, + 21.0f, 29.0f, 17.0f, 5.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_2D_1batch_1channel_output_padding) +{ + const Strides strides{1, 1}; + const CoordinateDiff padding{1, 1}; + const Strides dilations{1, 1}; + const CoordinateDiff output_padding{1, 1}; + + const Shape inputs_shape{1, 1, 2, 2}; + const std::vector inputs{1.0f, 3.0f, + 7.0f, 5.0f}; + + const Shape filter_shape{1, 1, 3, 3}; + const std::vector filters{1.0f, 2.0f, 3.0f, + 1.0f, 1.0f, 1.0f, + 3.0f, 2.0f, 1.0f}; + + const Shape outputs_shape{1, 1, 3, 3}; + const std::vector outputs{23.0f, 35.0f, 18.0f, + 23.0f, 19.0f, 8.0f, + 29.0f, 17.0f, 5.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_2D_1batch_1channel_padding) +{ + const Strides strides{1, 1}; + const CoordinateDiff padding{1, 1}; + const Strides dilations{1, 1}; + const CoordinateDiff output_padding{0, 0}; + + const Shape inputs_shape{1, 1, 4, 4}; + const std::vector inputs{1.0f, 3.0f, 5.0f, 7.0f, + 7.0f, 5.0f, 3.0f, 1.0f, + 2.0f, 4.0f, 6.0f, 8.0f, + 8.0f, 6.0f, 4.0f, 2.0f}; + + const Shape filter_shape{1, 1, 3, 3}; + const std::vector filters{1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f}; + + const Shape outputs_shape{1, 1, 4, 4}; + const std::vector outputs{20.0f, 37.0f, 27.0f, 18.0f, + 22.0f, 40.0f, 60.0f, 52.0f, + 41.0f, 69.0f, 49.0f, 31.0f, + 18.0f, 26.0f, 34.0f, 22.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_2D_1batch_1channel_stride) +{ + const Strides strides{2, 2}; + const CoordinateDiff padding{0, 0}; + const Strides dilations{1, 1}; + const CoordinateDiff output_padding{0, 0}; + + const Shape inputs_shape{1, 1, 2, 2}; + const std::vector inputs{2.0f, 5.0f, + 4.0f, 3.0f}; + + const Shape filter_shape{1, 1, 3, 3}; + const std::vector filters{1.0f, 2.0f, 3.0f, + 1.0f, 1.0f, 1.0f, + 3.0f, 2.0f, 1.0f}; + + const Shape outputs_shape{1, 1, 5, 5}; + const std::vector outputs{2.0f, 4.0f, 11.0f, 10.0f, 15.0f, + 2.0f, 2.0f, 7.0f, 5.0f, 5.0f, + 10.0f, 12.0f, 32.0f, 16.0f, 14.0f, + 4.0f, 4.0f, 7.0f, 3.0f, 3.0f, + 12.0f, 8.0f, 13.0f, 6.0f, 3.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_2D_1batch_1channel_dilation) +{ + const Strides strides{1, 1}; + const CoordinateDiff padding{0, 0}; + const Strides dilations{2, 2}; + const CoordinateDiff output_padding{0, 0}; + + const Shape inputs_shape{1, 1, 2, 2}; + const std::vector inputs{2.0f, 3.0f, + 4.0f, 3.0f}; + + const Shape filter_shape{1, 1, 3, 3}; + const std::vector filters{1.0f, 2.0f, 3.0f, + 1.0f, 1.0f, 1.0f, + 3.0f, 2.0f, 1.0f}; + + const Shape outputs_shape{1, 1, 6, 6}; + const std::vector outputs{2.f, 3.f, 4.f, 6.f, 6.f, 9.f, + 4.f, 3.f, 8.f, 6.f, 12.f, 9.f, + 2.f, 3.f, 2.f, 3.f, 2.f, 3.f, + 4.f, 3.f, 4.f, 3.f, 4.f, 3.f, + 6.f, 9.f, 4.f, 6.f, 2.f, 3.f, + 12.f, 9.f, 8.f, 6.f, 4.f, 3.f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_2D_1batch_1channel_padding_strides_dilation) +{ + const Strides strides{2, 2}; + const CoordinateDiff padding{2, 2}; + const Strides dilations{2, 2}; + const CoordinateDiff output_padding{0, 0}; + + const Shape inputs_shape{1, 1, 3, 3}; + const std::vector inputs{1.0f, 3.0f, 5.0f, + 7.0f, 5.0f, 3.0f, + 2.0f, 4.0f, 6.0f}; + + const Shape filter_shape{1, 1, 3, 3}; + const std::vector filters{1.0f, 2.0f, 3.0f, + 1.0f, 1.0f, 1.0f, + 3.0f, 2.0f, 1.0f}; + + const Shape outputs_shape{1, 1, 5, 5}; + const std::vector outputs{23.0f, 0.0f, 43.0f, 0.0f, 29.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 31.0f, 0.0f, 57.0f, 0.0f, 45.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 35.0f, 0.0f, 38.0f, 0.0f, 21.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_2D_1batch_2channel) +{ + const Strides strides{1, 1}; + const CoordinateDiff padding{0, 0}; + const Strides dilations{1, 1}; + const CoordinateDiff output_padding{0, 0}; + + const Shape inputs_shape{1, 1, 2, 2}; + const std::vector inputs{1.0f, 3.0f, + 7.0f, 5.0f}; + + const Shape filter_shape{1, 2, 3, 3}; + const std::vector filters{ + // channel 1 + 5.0f, 3.0f, 5.0f, + 1.0f, 3.0f, 1.0f, + 4.0f, 2.0f, 4.0f, + // channel 2 + -5.0f, 3.0f, 5.0f, + 1.0f, -3.0f, 1.0f, + 4.0f, 2.0f, -4.0f}; + + const Shape outputs_shape{1, 2, 4, 4}; + const std::vector outputs{ + // channel 1 + 5.0f, 18.0f, 14.0f, 15.0f, + 36.0f, 52.0f, 60.0f, 28.0f, + 11.0f, 40.0f, 32.0f, 17.0f, + 28.0f, 34.0f, 38.0f, 20.0f, + // channel 2 + -5.0f, -12.0f, 14.0f, 15.0f, + -34.0f, -4.0f, 42.0f, 28.0f, + 11.0f, -2.0f, -6.0f, -7.0f, + 28.0f, 34.0f, -18.0f, -20.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_2D_1batch_2filter) +{ + const Strides strides{1, 1}; + const CoordinateDiff padding{0, 0}; + const Strides dilations{1, 1}; + const CoordinateDiff output_padding{0, 0}; + + const Shape inputs_shape{1, 2, 2, 2}; + const std::vector inputs{ + // channel 1 + 1.0f, 3.0f, + 7.0f, 5.0f, + // channel 2 + 2.0f, 4.0f, + 8.0f, 6.0f}; + + const Shape filter_shape{2, 1, 3, 3}; + const std::vector filters{ + // channel 1 + 5.0f, 3.0f, 5.0f, + 1.0f, 3.0f, 1.0f, + 4.0f, 2.0f, 4.0f, + // channel 2 + -5.0f, 3.0f, 5.0f, + 1.0f, -3.0f, 1.0f, + 4.0f, 2.0f, -4.0f}; + + const Shape outputs_shape{1, 1, 4, 4}; + const std::vector outputs{ + -5.0f, 4.0f, 36.0f, 35.0f, + -2.0f, 44.0f, 108.0f, 62.0f, + 27.0f, 42.0f, 22.0f, 7.0f, + 60.0f, 74.0f, 18.0f, -4.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_2D_2batch_2filter) +{ + const Strides strides{1, 1}; + const CoordinateDiff padding{0, 0}; + const Strides dilations{1, 1}; + const CoordinateDiff output_padding{0, 0}; + + const Shape inputs_shape{1, 2, 1, 1}; + const std::vector inputs{ + // channel 1 + 2.0f, + // channel 2 + 3.0f}; + + const Shape filter_shape{2, 2, 2, 2}; + const std::vector filters{ + // batch 0 + // channel 1 + 5.0f, 3.0f, + 1.0f, 3.0f, + // channel 2 + -5.0f, 3.0f, + 1.0f, -3.0f, + // batch 1 + // channel 1 + 5.0f, 3.0f, + 1.0f, 3.0f, + // channel 2 + -5.0f, 3.0f, + 1.0f, -3.0f}; + + const Shape outputs_shape{1, 2, 2, 2}; + const std::vector outputs{ + 25.0f, 15.0f, 5.0f, 15.0f, -25.0f, 15.0f, 5.0f, -15.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_2D_2batch_1channel) +{ + const Strides strides{1, 1}; + const CoordinateDiff padding{0, 0}; + const Strides dilations{1, 1}; + const CoordinateDiff output_padding{0, 0}; + + const Shape inputs_shape{2, 1, 2, 2}; + const std::vector inputs{ + // batch 1 + 1.0f, 3.0f, + 1.0f, 3.0f, + // batch 2 + -1.0f, 3.0f, + 1.0f, 3.0f}; + + const Shape filter_shape{1, 1, 3, 3}; + const std::vector filters{-5.0f, 3.0f, 5.0f, + 1.0f, -3.0f, 1.0f, + 4.0f, 2.0f, -4.0f}; + + const Shape outputs_shape{2, 1, 4, 4}; + const std::vector outputs{ + // batch 1 + -5.0f, -12.0f, 14.0f, 15.0f, + -4.0f, -12.0f, 6.0f, 18.0f, + 5.0f, 14.0f, -6.0f, -9.0f, + 4.0f, 14.0f, 2.0f, -12.0f, + // batch 2 + 5.0f, -18.0f, 4.0f, 15.0f, + -6.0f, -6.0f, 4.0f, 18.0f, + -3.0f, 10.0f, 2.0f, -9.0f, + 4.0f, 14.0f, 2.0f, -12.0f}; + + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +// --------------------- 3D convolution ------------------------------------------ +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_3D_1batch_1channel) +{ + const Strides strides{1, 1, 1}; + const CoordinateDiff padding{0, 0, 0}; + const Strides dilations{1, 1, 1}; + const CoordinateDiff output_padding{0, 0, 0}; + + const Shape inputs_shape{1, 1, 2, 2, 2}; + const std::vector inputs{ + // depth: 1 + 15.0f, 3.0f, + 21.0f, 10.0f, + // depth: 2 + 10.0f, 13.0f, + 11.0f, 17.0f}; + + const Shape filter_shape{1, 1, 3, 3, 3}; + const std::vector filters{ + // depth: 1 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 2 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 3 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f}; + + const Shape outputs_shape{1, 1, 4, 4, 4}; + const std::vector outputs{ + // depth: 1 + 15.0f, 33.0f, 51.0f, 9.0f, + 21.0f, 67.0f, 86.0f, 30.0f, + 30.0f, 42.0f, 43.0f, 6.0f, + 42.0f, 41.0f, 52.0f, 20.0f, + // depth: 2 + 25.0f, 66.0f, 107.0f, 48.0f, + 32.0f, 116.0f, 166.0f, 81.0f, + 50.0f, 89.0f, 93.0f, 32.0f, + 64.0f, 86.0f, 91.0f, 54.0f, + // depth: 3 + 25.0f, 66.0f, 107.0f, 48.0f, + 32.0f, 116.0f, 166.0f, 81.0f, + 50.0f, 89.0f, 93.0f, 32.0f, + 64.0f, 86.0f, 91.0f, 54.0f, + // depth: 4 + 10.0f, 33.0f, 56.0f, 39.0f, + 11.0f, 49.0f, 80.0f, 51.0f, + 20.0f, 47.0f, 50.0f, 26.0f, + 22.0f, 45.0f, 39.0f, 34.0f + }; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_3D_1batch_1channel_output_padding) +{ + const Strides strides{1, 1, 1}; + const CoordinateDiff padding{1, 1, 1}; + const Strides dilations{1, 1, 1}; + const CoordinateDiff output_padding{1, 1, 1}; + + const Shape inputs_shape{1, 1, 2, 2, 2}; + const std::vector inputs{ + // depth: 1 + 15.0f, 3.0f, + 21.0f, 10.0f, + // depth: 2 + 10.0f, 13.0f, + 11.0f, 17.0f}; + + const Shape filter_shape{1, 1, 3, 3, 3}; + const std::vector filters{ + // depth: 1 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 2 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 3 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f}; + + const Shape outputs_shape{1, 1, 3, 3, 3}; + const std::vector outputs{ + // depth: 1 + 116.0f, 166.0f, 81.0f, + 89.0f, 93.0f, 32.0f, + 86.0f, 91.0f, 54.0f, + // depth: 2 + 116.0f, 166.0f, 81.0f, + 89.0f, 93.0f, 32.0f, + 86.0f, 91.0f, 54.0f, + // depth: 3 + 49.0f, 80.0f, 51.0f, + 47.0f, 50.0f, 26.0f, + 45.0f, 39.0f, 34.0f + }; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_3D_1batch_1channel_padding) +{ + const Strides strides{1, 1, 1}; + const CoordinateDiff padding{1, 1, 1}; + const Strides dilations{1, 1, 1}; + const CoordinateDiff output_padding{0, 0, 0}; + + const Shape inputs_shape{1, 1, 4, 4, 4}; + const std::vector inputs{ + // depth: 1 + 1.0f, 3.0f, 2.0f, 1.0f, + 1.0f, 3.0f, 3.0f, 1.0f, + 2.0f, 1.0f, 1.0f, 3.0f, + 3.0f, 2.0f, 3.0f, 3.0f, + // depth: 2 + 1.0f, 3.0f, 2.0f, 1.0f, + 1.0f, 3.0f, 3.0f, 1.0f, + 2.0f, 1.0f, 1.0f, 3.0f, + 3.0f, 2.0f, 3.0f, 3.0f, + // depth: 3 + 1.0f, 3.0f, 2.0f, 1.0f, + 1.0f, 3.0f, 3.0f, 1.0f, + 2.0f, 1.0f, 1.0f, 3.0f, + 3.0f, 2.0f, 3.0f, 3.0f, + // depth: 4 + 1.0f, 3.0f, 2.0f, 1.0f, + 1.0f, 3.0f, 3.0f, 1.0f, + 2.0f, 1.0f, 1.0f, 3.0f, + 3.0f, 2.0f, 3.0f, 3.0f + }; + + const Shape filter_shape{1, 1, 3, 3, 3}; + const std::vector filters{ + // depth: 1 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 2 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 3 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f}; + + const Shape outputs_shape{1, 1, 4, 4, 4}; + const std::vector outputs{ + // depth: 1 + 12.0f, 30.0f, 36.0f, 24.0f, + 26.0f, 42.0f, 42.0f, 30.0f, + 34.0f, 56.0f, 54.0f, 50.0f, + 14.0f, 18.0f, 24.0f, 16.0f, + // depth: 2 + 18.0f, 45.0f, 54.0f, 36.0f, + 39.0f, 63.0f, 63.0f, 45.0f, + 51.0f, 84.0f, 81.0f, 75.0f, + 21.0f, 27.0f, 36.0f, 24.0f, + // depth: 3 + 18.0f, 45.0f, 54.0f, 36.0f, + 39.0f, 63.0f, 63.0f, 45.0f, + 51.0f, 84.0f, 81.0f, 75.0f, + 21.0f, 27.0f, 36.0f, 24.0f, + // depth: 4 + 12.0f, 30.0f, 36.0f, 24.0f, + 26.0f, 42.0f, 42.0f, 30.0f, + 34.0f, 56.0f, 54.0f, 50.0f, + 14.0f, 18.0f, 24.0f, 16.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_3D_1batch_1channel_stride) +{ + const Strides strides{2, 2, 2}; + const CoordinateDiff padding{0, 0, 0}; + const Strides dilations{1, 1, 1}; + const CoordinateDiff output_padding{0, 0, 0}; + + const Shape inputs_shape{1, 1, 2, 2, 2}; + const std::vector inputs{ + // depth: 1 + 15.0f, 3.0f, + 21.0f, 10.0f, + // depth: 2 + 10.0f, 13.0f, + 11.0f, 17.0f}; + + const Shape filter_shape{1, 1, 3, 3, 3}; + const std::vector filters{ + // depth: 1 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 2 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 3 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f}; + + const Shape outputs_shape{1, 1, 5, 5, 5}; + const std::vector outputs{ + // depth: 1 + 15.0f, 30.0f, 48.0f, 6.0f, 9.0f, + 0.0f, 15.0f, 0.0f, 3.0f, 0.0f, + 51.0f, 57.0f, 109.0f, 23.0f, 36.0f, + 0.0f, 21.0f, 0.0f, 10.0f, 0.0f, + 42.0f, 21.0f, 62.0f, 10.0f, 20.0f, + // depth: 2 + 15.0f, 30.0f, 48.0f, 6.0f, 9.0f, + 0.0f, 15.0f, 0.0f, 3.0f, 0.0f, + 51.0f, 57.0f, 109.0f, 23.0f, 36.0f, + 0.0f, 21.0f, 0.0f, 10.0f, 0.0f, + 42.0f, 21.0f, 62.0f, 10.0f, 20.0f, + // depth: 3 + 25.0f, 50.0f, 91.0f, 32.0f, 48.0f, + 0.0f, 25.0f, 0.0f, 16.0f, 0.0f, + 82.0f, 89.0f, 205.0f, 70.0f, 113.0f, + 0.0f, 32.0f, 0.0f, 27.0f, 0.0f, + 64.0f, 32.0f, 118.0f, 27.0f, 54.0f, + // depth: 4 + 10.0f, 20.0f, 43.0f, 26.0f, 39.0f, + 0.0f, 10.0f, 0.0f, 13.0f, 0.0f, + 31.0f, 32.0f, 96.0f, 47.0f, 77.0f, + 0.0f, 11.0f, 0.0f, 17.0f, 0.0f, + 22.0f, 11.0f, 56.0f, 17.0f, 34.0f, + // depth: 5 + 10.0f, 20.0f, 43.0f, 26.0f, 39.0f, + 0.0f, 10.0f, 0.0f, 13.0f, 0.0f, + 31.0f, 32.0f, 96.0f, 47.0f, 77.0f, + 0.0f, 11.0f, 0.0f, 17.0f, 0.0f, + 22.0f, 11.0f, 56.0f, 17.0f, 34.0f + }; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_3D_1batch_1channel_padding_strides_dilation) +{ + const Strides strides{2, 2, 2}; + const CoordinateDiff padding{2, 2, 2}; + const Strides dilations{2, 2, 2}; + const CoordinateDiff output_padding{0, 0, 0}; + + const Shape inputs_shape{1, 1, 4, 4, 4}; + const std::vector inputs{ + // depth: 1 + 1.0f, 3.0f, 2.0f, 1.0f, + 1.0f, 3.0f, 3.0f, 1.0f, + 2.0f, 1.0f, 1.0f, 3.0f, + 3.0f, 2.0f, 3.0f, 3.0f, + // depth: 2 + 1.0f, 3.0f, 2.0f, 1.0f, + 1.0f, 3.0f, 3.0f, 1.0f, + 2.0f, 1.0f, 1.0f, 3.0f, + 3.0f, 2.0f, 3.0f, 3.0f, + // depth: 3 + 1.0f, 3.0f, 2.0f, 1.0f, + 1.0f, 3.0f, 3.0f, 1.0f, + 2.0f, 1.0f, 1.0f, 3.0f, + 3.0f, 2.0f, 3.0f, 3.0f, + // depth: 4 + 1.0f, 3.0f, 2.0f, 1.0f, + 1.0f, 3.0f, 3.0f, 1.0f, + 2.0f, 1.0f, 1.0f, 3.0f, + 3.0f, 2.0f, 3.0f, 3.0f + }; + + const Shape filter_shape{1, 1, 3, 3, 3}; + const std::vector filters{ + // depth: 1 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 2 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 3 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f}; + + const Shape outputs_shape{1, 1, 7, 7, 7}; + const std::vector outputs{ + // depth: 1 + 12.0f, 0.0f, 30.0f, 0.0f, 36.0f, 0.0f, 24.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 26.0f, 0.0f, 42.0f, 0.0f, 42.0f, 0.0f, 30.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 34.0f, 0.0f, 56.0f, 0.0f, 54.0f, 0.0f, 50.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 14.0f, 0.0f, 18.0f, 0.0f, 24.0f, 0.0f, 16.0f, + // depth: 2 + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + // depth: 3 + 18.0f, 0.0f, 45.0f, 0.0f, 54.0f, 0.0f, 36.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 39.0f, 0.0f, 63.0f, 0.0f, 63.0f, 0.0f, 45.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 51.0f, 0.0f, 84.0f, 0.0f, 81.0f, 0.0f, 75.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 21.0f, 0.0f, 27.0f, 0.0f, 36.0f, 0.0f, 24.0f, + // depth: 4 + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + // depth: 5 + 18.0f, 0.0f, 45.0f, 0.0f, 54.0f, 0.0f, 36.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 39.0f, 0.0f, 63.0f, 0.0f, 63.0f, 0.0f, 45.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 51.0f, 0.0f, 84.0f, 0.0f, 81.0f, 0.0f, 75.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 21.0f, 0.0f, 27.0f, 0.0f, 36.0f, 0.0f, 24.0f, + // depth: 6 + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + // depth: 7 + 12.0f, 0.0f, 30.0f, 0.0f, 36.0f, 0.0f, 24.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 26.0f, 0.0f, 42.0f, 0.0f, 42.0f, 0.0f, 30.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 34.0f, 0.0f, 56.0f, 0.0f, 54.0f, 0.0f, 50.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 14.0f, 0.0f, 18.0f, 0.0f, 24.0f, 0.0f, 16.0f + }; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_3D_1batch_2channel) +{ + const Strides strides{1, 1, 1}; + const CoordinateDiff padding{0, 0, 0}; + const Strides dilations{1, 1, 1}; + const CoordinateDiff output_padding{0, 0, 0}; + + const Shape inputs_shape{1, 1, 2, 2, 2}; + const std::vector inputs{ + // depth: 1 + 1.0f, 8.0f, + 1.0f, 3.0f, + // depth: 2 + 1.0f, 7.0f, + 3.0f, 8.0f}; + + const Shape filter_shape{1, 2, 3, 3, 3}; + const std::vector filters{ + // -- channel 1 -- + // depth: 1 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 2 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 3 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // -- channel 2 -- + // depth: 1 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 2 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 3 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f + }; + + const Shape outputs_shape{1, 2, 4, 4, 4}; + const std::vector outputs{ + // -- channel 1 -- + // depth: 1 + 1.0f, 10.0f, 19.0f, 24.0f, + 1.0f, 6.0f, 17.0f, 9.0f, + 2.0f, 18.0f, 13.0f, 16.0f, + 2.0f, 7.0f, 5.0f, 6.0f, + // depth: 2 + 2.0f, 19.0f, 36.0f, 45.0f, + 4.0f, 21.0f, 49.0f, 33.0f, + 4.0f, 36.0f, 30.0f, 30.0f, + 8.0f, 26.0f, 19.0f, 22.0f, + // depth: 3 + 2.0f, 19.0f, 36.0f, 45.0f, + 4.0f, 21.0f, 49.0f, 33.0f, + 4.0f, 36.0f, 30.0f, 30.0f, + 8.0f, 26.0f, 19.0f, 22.0f, + // depth: 4 + 1.0f, 9.0f, 17.0f, 21.0f, + 3.0f, 15.0f, 32.0f, 24.0f, + 2.0f, 18.0f, 17.0f, 14.0f, + 6.0f, 19.0f, 14.0f, 16.0f, + // -- channel 2 -- + // depth: 1 + 1.0f, 10.0f, 19.0f, 24.0f, + 1.0f, 6.0f, 17.0f, 9.0f, + 2.0f, 18.0f, 13.0f, 16.0f, + 2.0f, 7.0f, 5.0f, 6.0f, + // depth: 2 + 2.0f, 19.0f, 36.0f, 45.0f, + 4.0f, 21.0f, 49.0f, 33.0f, + 4.0f, 36.0f, 30.0f, 30.0f, + 8.0f, 26.0f, 19.0f, 22.0f, + // depth: 3 + 2.0f, 19.0f, 36.0f, 45.0f, + 4.0f, 21.0f, 49.0f, 33.0f, + 4.0f, 36.0f, 30.0f, 30.0f, + 8.0f, 26.0f, 19.0f, 22.0f, + // depth: 4 + 1.0f, 9.0f, 17.0f, 21.0f, + 3.0f, 15.0f, 32.0f, 24.0f, + 2.0f, 18.0f, 17.0f, 14.0f, + 6.0f, 19.0f, 14.0f, 16.0f + }; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_3D_1batch_2filter) +{ + const Strides strides{1, 1, 1}; + const CoordinateDiff padding{0, 0, 0}; + const Strides dilations{1, 1, 1}; + const CoordinateDiff output_padding{0, 0, 0}; + + const Shape inputs_shape{1, 2, 2, 2, 2}; + const std::vector inputs{ + // -- in 1 -- + // depth: 1 + 1.0f, 3.0f, + 2.0f, 5.0f, + // depth: 2 + 1.0f, 0.0f, + 3.0f, 6.0f, + // -- in 2 -- + // depth: 1 + 1.0f, 3.0f, + 2.0f, 5.0f, + // depth: 2 + 3.0f, 0.0f, + 1.0f, 8.0f}; + + const Shape filter_shape{2, 1, 3, 3, 3}; + const std::vector filters{ + // -- filter 1 -- + // depth: 1 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 2 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 3 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // -- filter 2 -- + // depth: 1 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 2 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 3 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f + }; + + const Shape outputs_shape{1, 1, 4, 4, 4}; + const std::vector outputs{ + // depth: 1 + 2.0f, 10.0f, 18.0f, 18.0f, + 4.0f, 20.0f, 38.0f, 30.0f, + 4.0f, 18.0f, 20.0f, 12.0f, + 8.0f, 24.0f, 18.0f, 20.0f, + // depth: 2 + 6.0f, 18.0f, 30.0f, 18.0f, + 8.0f, 46.0f, 78.0f, 72.0f, + 12.0f, 26.0f, 42.0f, 12.0f, + 16.0f, 56.0f, 40.0f, 48.0f, + // depth: 3 + 6.0f, 18.0f, 30.0f, 18.0f, + 8.0f, 46.0f, 78.0f, 72.0f, + 12.0f, 26.0f, 42.0f, 12.0f, + 16.0f, 56.0f, 40.0f, 48.0f, + // depth: 4 + 4.0f, 8.0f, 12.0f, 0.0f, + 4.0f, 26.0f, 40.0f, 42.0f, + 8.0f, 8.0f, 22.0f, 0.0f, + 8.0f, 32.0f, 22.0f, 28.0f + }; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} + +NGRAPH_TEST(${BACKEND_NAME}, convolution_backprop_3D_2batch_1channel) +{ + const Strides strides{1, 1, 1}; + const CoordinateDiff padding{0, 0, 0}; + const Strides dilations{1, 1, 1}; + const CoordinateDiff output_padding{0, 0, 0}; + + const Shape inputs_shape{2, 1, 2, 2, 2}; + const std::vector inputs{ + // -- batch 1 -- + // depth: 1 + 1.0f, 3.0f, + 2.0f, 5.0f, + // depth: 2 + 1.0f, 0.0f, + 6.0f, 4.0f, + // -- batch 2 -- + // depth: 1 + 1.0f, 5.0f, + 2.0f, 8.0f, + // depth: 2 + 2.0f, 1.0f, + 0.0f, 5.0f}; + const Shape filter_shape{1, 1, 3, 3, 3}; + const std::vector filters{ + // depth: 1 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 2 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f, + // depth: 3 + 1.0f, 2.0f, 3.0f, + 0.0f, 1.0f, 0.0f, + 2.0f, 1.0f, 2.0f}; + + const Shape outputs_shape{2, 1, 4, 4, 4}; + const std::vector outputs{ + // -- batch 1 -- + // depth: 1 + 1.0f, 5.0f, 9.0f, 9.0f, + 2.0f, 10.0f, 19.0f, 15.0f, + 2.0f, 9.0f, 10.0f, 6.0f, + 4.0f, 12.0f, 9.0f, 10.0f, + // depth: 2 + 2.0f, 7.0f, 12.0f, 9.0f, + 8.0f, 27.0f, 45.0f, 27.0f, + 4.0f, 16.0f, 16.0f, 6.0f, + 16.0f, 26.0f, 25.0f, 18.0f, + // depth: 3 + 2.0f, 7.0f, 12.0f, 9.0f, + 8.0f, 27.0f, 45.0f, 27.0f, + 4.0f, 16.0f, 16.0f, 6.0f, + 16.0f, 26.0f, 25.0f, 18.0f, + // depth: 4 + 1.0f, 2.0f, 3.0f, 0.0f, + 6.0f, 17.0f, 26.0f, 12.0f, + 2.0f, 7.0f, 6.0f, 0.0f, + 12.0f, 14.0f, 16.0f, 8.0f, + // -- batch 2 -- + // depth: 1 + 1.0f, 7.0f, 13.0f, 15.0f, + 2.0f, 13.0f, 27.0f, 24.0f, + 2.0f, 13.0f, 15.0f, 10.0f, + 4.0f, 18.0f, 12.0f, 16.0f, + // depth: 2 + 3.0f, 12.0f, 21.0f, 18.0f, + 2.0f, 20.0f, 38.0f, 39.0f, + 6.0f, 17.0f, 25.0f, 12.0f, + 4.0f, 28.0f, 17.0f, 26.0f, + // depth: 3 + 3.0f, 12.0f, 21.0f, 18.0f, + 2.0f, 20.0f, 38.0f, 39.0f, + 6.0f, 17.0f, 25.0f, 12.0f, + 4.0f, 28.0f, 17.0f, 26.0f, + // depth: 4 + 2.0f, 5.0f, 8.0f, 3.0f, + 0.0f, 7.0f, 11.0f, 15.0f, + 4.0f, 4.0f, 10.0f, 2.0f, + 0.0f, 10.0f, 5.0f, 10.0f}; + + ConvolutionBackpropTest(inputs, inputs_shape, filters, filter_shape, outputs, outputs_shape, + strides, padding, dilations, output_padding); +} diff --git a/ngraph/test/runtime/interpreter/evaluates_map.cpp b/ngraph/test/runtime/interpreter/evaluates_map.cpp index 61010117779..4446a4a2434 100644 --- a/ngraph/test/runtime/interpreter/evaluates_map.cpp +++ b/ngraph/test/runtime/interpreter/evaluates_map.cpp @@ -283,7 +283,8 @@ namespace op->get_dilations(), op->get_pads_begin(), op->get_pads_end(), - op->get_strides()); + op->get_strides(), + op->get_output_padding()); return true; }