Ref implementation transposed convolution revise (#4999)
This commit is contained in: parent 826638e523, commit 978121e91e
@@ -18,8 +18,6 @@
#include "ngraph/runtime/reference/split.hpp"
#include "ngraph/util.hpp"

// can't be removed currently due to arm-plugin dependency
#include "ngraph/runtime/reference/convolution_backprop_data.hpp"
namespace ngraph
{
namespace runtime
@@ -42,15 +40,18 @@ namespace ngraph
std::vector<int> dilation;
std::vector<int> pads_begin;
std::vector<int> pads_end;
std::vector<int> output_padding;

ConvolutionParams(const Strides& strides_,
const Strides& dilation_,
const CoordinateDiff& pads_begin_,
const CoordinateDiff& pads_end_)
const CoordinateDiff& pads_end_,
const CoordinateDiff& output_padding_ = {0, 0, 0})
: strides{strides_.begin(), strides_.end()}
, dilation{dilation_.begin(), dilation_.end()}
, pads_begin{pads_begin_.begin(), pads_begin_.end()}
, pads_end{pads_end_.begin(), pads_end_.end()} {};
, pads_end{pads_end_.begin(), pads_end_.end()}
, output_padding{output_padding_.begin(), output_padding_.end()} {};
};

template <typename Int>
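In the hunk above, ConvolutionParams gains an output_padding member, and the new constructor argument defaults to {0, 0, 0}, so callers that never ask for extra output padding keep their previous behaviour. A minimal sketch of that default; the variable names and argument values below are illustrative only:

// Both objects describe the same 3-D convolution: omitting output_padding_
// is equivalent to passing a zero CoordinateDiff for every spatial axis.
ConvolutionParams implicit_zero{Strides{1, 1, 1},
                                Strides{1, 1, 1},
                                CoordinateDiff{1, 1, 1},
                                CoordinateDiff{1, 1, 1}};
ConvolutionParams explicit_zero{Strides{1, 1, 1},
                                Strides{1, 1, 1},
                                CoordinateDiff{1, 1, 1},
                                CoordinateDiff{1, 1, 1},
                                CoordinateDiff{0, 0, 0}};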
@@ -86,15 +87,18 @@ namespace ngraph
const size_t filter_channel_size = shape_size(filter_channel_shape);

for (int i_z = -p.pads_begin[0];
i_z <= (p.pads_end[0] + input_size_z - dilated_filter_size_z);
i_z <= (p.pads_end[0] + input_size_z - dilated_filter_size_z +
p.output_padding[0]);
i_z += p.strides[0])
{
for (int i_y = -p.pads_begin[1];
i_y <= (p.pads_end[1] + input_size_y - dilated_filter_size_y);
i_y <= (p.pads_end[1] + input_size_y - dilated_filter_size_y +
p.output_padding[1]);
i_y += p.strides[1])
{
for (int i_x = -p.pads_begin[2];
i_x <= (p.pads_end[2] + input_size_x - dilated_filter_size_x);
i_x <= (p.pads_end[2] + input_size_x - dilated_filter_size_x +
p.output_padding[2]);
i_x += p.strides[2])
{
auto input_channel = batch;
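Each spatial loop above now runs its index up to pads_end + input_size - dilated_filter_size + output_padding, so output_padding simply appends extra window positions, and therefore extra output elements, at the high end of every axis. A small standalone check of the resulting iteration count along one axis, using plain integers rather than the ngraph types (the helper name is made up for illustration):

#include <cassert>

// Number of window positions visited by
//   for (int i = -pads_begin;
//        i <= pads_end + input_size - dilated_filter_size + output_padding;
//        i += stride)
// assuming at least one valid position exists.
int output_positions(int input_size, int dilated_filter_size, int pads_begin,
                     int pads_end, int output_padding, int stride)
{
    const int first = -pads_begin;
    const int last = pads_end + input_size - dilated_filter_size + output_padding;
    return (last - first) / stride + 1;
}

int main()
{
    // 1-D: 4 inputs, 3-tap filter, no padding, stride 1 -> 2 positions.
    assert(output_positions(4, 3, 0, 0, 0, 1) == 2);
    // output_padding = 1 adds one trailing position -> 3.
    assert(output_positions(4, 3, 0, 0, 1, 1) == 3);
    return 0;
}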
@@ -154,6 +158,8 @@ namespace ngraph
std::prev(p.pads_begin.end(), spatial_rank), missing_dims, 0);
p.pads_end.insert(
std::prev(p.pads_end.end(), spatial_rank), missing_dims, 0);
p.output_padding.insert(
std::prev(p.output_padding.end(), spatial_rank), missing_dims, 0);
in_shape.insert(std::next(in_shape.end(), -spatial_rank), missing_dims, 1);
filter_shape.insert(
std::prev(filter_shape.end(), spatial_rank), missing_dims, 1);
@@ -324,3 +330,6 @@ namespace ngraph
} // namespace reference
} // namespace runtime
} // namespace ngraph

// can't be removed currently due to arm-plugin dependency
#include "ngraph/runtime/reference/convolution_backprop_data.hpp"
@@ -10,11 +10,7 @@
#include <numeric>

#include "ngraph/axis_vector.hpp"
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/helpers.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/split.hpp"
#include "ngraph/runtime/reference/convolution.hpp"
#include "ngraph/util.hpp"

namespace ngraph
@@ -23,217 +19,302 @@ namespace ngraph
{
namespace reference
{
// in: NC_I...
// filter: C_OC_I...
// out: NC_O...
template <typename INPUT,
typename FILTER,
typename OUTPUT,
typename ACCUMULATION = typename widen<OUTPUT>::type>
void convolution_backprop_impl(const INPUT* in,
const FILTER* filter,
OUTPUT* out,
const Shape& in_shape,
const Shape& filter_shape,
const Shape& out_shape,
const Strides& stride,
const Strides& filter_dilation,
const CoordinateDiff& in_pad_below,
const CoordinateDiff& in_pad_above,
const Strides& in_dilation,
size_t in_batch_axis,
size_t in_channel_axis,
size_t filter_out_channel_axis,
size_t filter_in_channel_axis,
size_t out_batch_axis,
size_t out_channel_axis)
namespace
{
auto old_mode = std::fegetround();
std::fesetround(FE_TONEAREST);
// Comments throughout assume without loss of generality that:
//
// * batch axes for both in and out are 0
// * in channel axes for both in and filter are 1
// * out channel axes for filter is 0
// * out channel axis for out is 1
constexpr size_t filter_input_ch_axis = 0;

// At the outermost level we will walk over every out coordinate O.
CoordinateTransform out_transform(out_shape);

for (const Coordinate& out_coord : out_transform)
template <typename T>
void extend_with_zeros(const Strides& strides,
const Shape& input_shape,
const T* in,
Shape& output_shape,
std::vector<T>& input_zeros)
{
// Our out coordinate O will have the form:
//
// (N,chan_out,i_1,...,i_n)
std::vector<int> input_3d(3, 1);
std::vector<int> strides_3d(3, 1);
std::vector<int> output_3d(3, 1);

size_t batch_index = out_coord[out_batch_axis];
size_t out_channel = out_coord[out_channel_axis];

// For the in we need to iterate the coordinate:
//
// I:
//
// over the range (noninclusive on the right):
//
// (N,0,s_1*i_1,s_2*i_2,...,s_n*i_n) ->
//
// (N+1,
// chans_in_count,
// s_1*i_1+ l_1*filter_dims_1,
/// ...,
/// s_n*i_n +l_n*filter_dims_n)
//
// with strides:
//
// (1,l_1,...,l_n).
//
// Note that we are iterating within the *padded* and *dilated* in batch, so
// further down we must check the current coordinate is in the pad or dilation
// gap.

size_t n_spatial_dimensions = in_shape.size() - 2;
size_t n_in_channels = in_shape[in_channel_axis];

Coordinate in_transform_start(2 + n_spatial_dimensions);
Coordinate in_transform_end(2 + n_spatial_dimensions);
Strides in_transform_movement_strides(2 + n_spatial_dimensions, 1);
CoordinateDiff in_transform_pad_below(2 + n_spatial_dimensions, 0);
CoordinateDiff in_transform_pad_above(2 + n_spatial_dimensions, 0);
Strides in_transform_dilation_strides(2 + n_spatial_dimensions, 1);

in_transform_start[in_batch_axis] = batch_index;
in_transform_end[in_batch_axis] = batch_index + 1;
in_transform_start[in_channel_axis] = 0;
in_transform_end[in_channel_axis] = 1;

for (size_t i = 2; i < n_spatial_dimensions + 2; i++)
for (size_t i = 0; i < strides.size(); ++i)
{
size_t filter_dilation_stride = filter_dilation[i - 2];
size_t filter_movement_stride = stride[i - 2];
std::ptrdiff_t below_pad = in_pad_below[i - 2];
std::ptrdiff_t above_pad = in_pad_above[i - 2];
size_t in_dilation_stride = in_dilation[i - 2];

in_transform_start[i] = filter_movement_stride * out_coord[i];
in_transform_end[i] = in_transform_start[i] +
(filter_shape[i] - 1) * filter_dilation_stride + 1;
in_transform_movement_strides[i] = filter_dilation_stride;
in_transform_pad_below[i] = below_pad;
in_transform_pad_above[i] = above_pad;
in_transform_dilation_strides[i] = in_dilation_stride;
output_shape[i + 2] =
input_shape[i + 2] + (strides[i] - 1) * (input_shape[i + 2] - 1);
input_3d[input_3d.size() - strides.size() + i] = input_shape[i + 2];
strides_3d[strides_3d.size() - strides.size() + i] = strides[i];
output_3d[output_3d.size() - strides.size() + i] = output_shape[i + 2];
}

AxisVector in_transform_axis_order(2 + n_spatial_dimensions);
for (size_t i = 0; i < in_transform_axis_order.size(); i++)
const size_t input_size = shape_size(input_3d);
if (input_size == 1)
{
in_transform_axis_order[i] = i;
}
CoordinateTransform in_transform(in_shape,
in_transform_start,
in_transform_end,
in_transform_movement_strides,
in_transform_axis_order,
in_transform_pad_below,
in_transform_pad_above,
in_transform_dilation_strides);

// Simultaneously with iterating I, for the filter we need to iterate the
// coordinate:
//
// F
//
// over the range (noninclusive on the right):
//
// (chan_out,0,0,...,0) ->
// (chan_out+1,
// chans_in_count,
// filter_dims_1,
// ...,
// filter_dims_n)
//
// with unit stride.

Shape filter_transform_start(2 + n_spatial_dimensions);
Shape filter_transform_end(2 + n_spatial_dimensions);

filter_transform_start[filter_out_channel_axis] = out_channel;
filter_transform_end[filter_out_channel_axis] = out_channel + 1;
filter_transform_start[filter_in_channel_axis] = 0;
filter_transform_end[filter_in_channel_axis] = 1;

for (size_t i = 2; i < n_spatial_dimensions + 2; i++)
{
filter_transform_start[i] = 0;
filter_transform_end[i] = filter_shape[i];
}

CoordinateTransform filter_transform(
filter_shape, filter_transform_start, filter_transform_end);

// As we go, we sum up:
//
// out[O] += in[I] * filter[F].

ACCUMULATION result = 0;

CoordinateTransform::Iterator in_it = in_transform.begin();
CoordinateTransform::Iterator filter_it = filter_transform.begin();
CoordinateTransform::Iterator in_it_end = in_transform.end();
CoordinateTransform::Iterator filter_it_end = filter_transform.end();

size_t in_channel_stride = row_major_strides(in_shape).at(in_channel_axis);
size_t filter_in_channel_stride =
row_major_strides(filter_shape).at(filter_in_channel_axis);

while (in_it != in_it_end && filter_it != filter_it_end)
{
const Coordinate& in_coord = *in_it;
if (in_transform.has_source_coordinate(in_coord))
for (size_t i = 0; i < shape_size(input_shape); ++i)
{
size_t in_idx = in_transform.index(in_coord);
const Coordinate& filter_coord = *filter_it;
size_t filter_idx = filter_transform.index(filter_coord);
for (size_t in_channel = 0; in_channel < n_in_channels; ++in_channel)
input_zeros.push_back(in[i]);
}
}
else
{
for (size_t batch = 0; batch < input_shape[0]; ++batch)
{
const auto offset_batch = batch * input_size * input_shape[1];
for (size_t channel = 0; channel < input_shape[1]; ++channel)
{
ACCUMULATION in_v = static_cast<ACCUMULATION>(in[in_idx]);
ACCUMULATION f_v = static_cast<ACCUMULATION>(filter[filter_idx]);
const auto offset_channel = offset_batch + channel * input_size;
for (int i_z = 0; i_z < input_3d[0]; ++i_z)
{
const auto offset_i_z = i_z * input_3d[2] * input_3d[1];
for (int i_y = 0; i_y < input_3d[1]; ++i_y)
{
const auto offset_i_y = i_y * input_3d[2];
for (int i_x = 0; i_x < input_3d[2]; ++i_x)
{
input_zeros.push_back(
in[offset_channel + i_x + offset_i_y + offset_i_z]);

result += in_v * f_v;
in_idx += in_channel_stride;
filter_idx += filter_in_channel_stride;
if (i_x < input_3d[2] - 1)
{
for (int k = 0; k < strides_3d[2] - 1; k++)
{
input_zeros.push_back(0);
}
}
}

if (i_y < input_3d[1] - 1)
{
const auto new_size =
output_3d[2] * (strides_3d[1] - 1);
input_zeros.insert(input_zeros.begin() +
input_zeros.size(),
new_size,
0);
}
}

if (i_z < input_3d[0] - 1)
{
const auto new_size =
output_3d[1] * output_3d[2] * (strides_3d[0] - 1);
input_zeros.insert(
input_zeros.begin() + input_zeros.size(), new_size, 0);
}
}
}
}
++in_it;
++filter_it;
}

out[out_transform.index(out_coord)] = result;
}
std::fesetround(old_mode);
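The extend_with_zeros helper above up-samples a tensor by writing (stride - 1) zeros between neighbouring elements along every spatial axis, per batch and channel, which is what later allows the strided transposed convolution to be computed as a unit-stride convolution. A minimal single-channel 1-D illustration of the same idea, standard library only, not the ngraph helper itself (the function name is made up):

#include <cstddef>
#include <vector>

// Insert (stride - 1) zeros between consecutive elements of one 1-D channel.
// For in = {1, 2, 3} and stride = 2 the result is {1, 0, 2, 0, 3}, i.e.
// in.size() + (stride - 1) * (in.size() - 1) elements, matching the
// output_shape formula used by extend_with_zeros above.
std::vector<float> dilate_with_zeros_1d(const std::vector<float>& in, std::size_t stride)
{
    std::vector<float> out;
    if (in.empty())
    {
        return out;
    }
    out.reserve(in.size() + (stride - 1) * (in.size() - 1));
    for (std::size_t i = 0; i < in.size(); ++i)
    {
        out.push_back(in[i]);
        if (i + 1 < in.size())
        {
            out.insert(out.end(), stride - 1, 0.0f);
        }
    }
    return out;
}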
void infer_forward_convbackprop_output_shape(const Shape& in_spatial_shape,
const Shape& f_spatial_shape,
const Shape& out_spatial_shape,
Shape& infer_spatial_shape,
const Strides& strides,
const Strides& dilations,
const CoordinateDiff& output_padding)
{
for (size_t idx = 0; idx < in_spatial_shape.size(); idx++)
{
int total_padding = strides[idx] * (in_spatial_shape[idx] - 1) +
dilations[idx] * (f_spatial_shape[idx] - 1) + 1 -
out_spatial_shape[idx] + output_padding[idx];
size_t padded_dim = std::max<size_t>(total_padding, 0);
size_t filter_dilated_dim = dilations[idx] * (f_spatial_shape[idx] - 1) + 1;
size_t out_spatial_dim = (in_spatial_shape[idx] - 1) * strides[idx] +
filter_dilated_dim - padded_dim +
output_padding[idx];
infer_spatial_shape.push_back(out_spatial_dim);
}
}
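infer_forward_convbackprop_output_shape above encodes the usual transposed-convolution shape rule, out = (in - 1) * stride + dilation * (filter - 1) + 1 - pads_begin - pads_end + output_padding, with the two pads folded into total_padding and clamped at zero. A worked one-axis check of that rule under the same assumptions (the helper name is made up for illustration):

#include <cassert>
#include <cstddef>

// Transposed-convolution output size along one spatial axis.
std::size_t deconv_out_dim(std::size_t in, std::size_t filter, std::size_t stride,
                           std::size_t dilation, std::size_t pads_begin,
                           std::size_t pads_end, std::size_t output_padding)
{
    return (in - 1) * stride + dilation * (filter - 1) + 1 - pads_begin - pads_end +
           output_padding;
}

int main()
{
    // 4 inputs, 3-tap filter, stride 2, no dilation or padding -> 9 outputs.
    assert(deconv_out_dim(4, 3, 2, 1, 0, 0, 0) == 9);
    // output_padding = 1 appends one extra output element -> 10.
    assert(deconv_out_dim(4, 3, 2, 1, 0, 0, 1) == 10);
    return 0;
}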
void validate_convolution_backprop_parameters(const Shape& in_shape,
const Shape& f_shape,
const Shape& out_shape,
const Strides& strides,
const Strides& dilations,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const CoordinateDiff& output_padding)
{
// this implementation supports 1D, 2D and 3D convolutions
NGRAPH_CHECK(in_shape.size() >= 3 && in_shape.size() <= 5,
"Unsupported input rank: ",
in_shape);

NGRAPH_CHECK(in_shape.size() == f_shape.size(),
"Incompatible input ranks: ",
in_shape.size(),
" and ",
f_shape.size());

NGRAPH_CHECK(in_shape[in_channel_axis] == f_shape[filter_input_ch_axis],
"Incompatible input channels in data batch and filters shapes: ",
in_shape[in_channel_axis],
" and ",
f_shape[filter_input_ch_axis]);

NGRAPH_CHECK(in_shape.size() == out_shape.size(),
"Incompatible input and output ranks: ",
in_shape.size(),
" and ",
out_shape.size());

const auto spatial_dims = in_shape.size() - 2;
NGRAPH_CHECK(strides.size() == spatial_dims,
"Strides not defined for all and only spatial dimensions.");

NGRAPH_CHECK(dilations.size() == spatial_dims,
"Dilations not defined for all and only spatial dimensions.");

NGRAPH_CHECK((pads_begin.size() == pads_end.size()) &&
(pads_begin.size() == spatial_dims),
"Pads not defined for all and only spatial dimensions.");

NGRAPH_CHECK(!output_padding.empty() && output_padding.size() == spatial_dims,
"Output padding not defined for all and only spatial dimensions.");

Shape out_spatial_shape{std::next(out_shape.begin(), 2), std::end(out_shape)};
Shape infered_out_spatial_shape{};
infer_forward_convbackprop_output_shape(
Shape{std::next(in_shape.begin(), 2), std::end(in_shape)},
Shape{std::next(f_shape.begin(), 2), std::end(f_shape)},
Shape{std::next(out_shape.begin(), 2), std::end(out_shape)},
infered_out_spatial_shape,
strides,
dilations,
output_padding);
NGRAPH_CHECK(out_spatial_shape == infered_out_spatial_shape,
"Incorrect output shape provided");
}
} // namespace
template <typename T>
void convolution_backprop_impl(const T* in,
const T* f,
T* out,
const Shape& in_shape,
const Shape& f_shape,
const Shape& out_shape,
const Strides& strides,
const Strides& dilation,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const CoordinateDiff& output_padding)

{
// here we are converting all param types to int's to avoid arithmetic issues
// (e.g signed + unsigned) in indexes calculation later
ConvolutionParams params{strides, dilation, pads_begin, pads_end, output_padding};

// here we are extending spatial dimensions to 3D, because we are going to use 3D
// convolution implementation to convolve also in 1D & 2D case
Shape input_shape{in_shape};
Shape filters_shape{f_shape};
if (in_shape.size() < 5)
{
extend_to_3D(params, input_shape, filters_shape);
}

for (size_t i = 0; i < input_shape.size() - 2; ++i)
{
if (input_shape[i + 2] > 1 || filters_shape[i + 2] > 1)
{
params.pads_begin[i] = filters_shape[i + 2] - params.pads_begin[i] - 1;
params.pads_end[i] = filters_shape[i + 2] - params.pads_end[i] - 1;
}
else
{
params.pads_begin[i] = 0;
params.pads_end[i] = 0;
}
}

// convert output shape to 3D, contains only dimensions
Shape out_shape_3d{out_shape.begin() + 2, out_shape.end()};

int out_shape_rank = out_shape.size() - 2;
if (out_shape_rank < 3)
{
int missing_dims = 3 - out_shape_rank;
out_shape_3d.insert(
std::prev(out_shape_3d.end(), out_shape_rank), missing_dims, 1);
}

// modify params.pads_end when output_shape was provided in ctor in order to
// calculate expected number of output elements
for (size_t i = 0; i < out_shape_3d.size(); i++)
{
if (out_shape_3d[i] > 1)
{
// expected_dim = (in - 1)* strides + filter - 2*padding + out_padding
// strides is already applied (through 0's extension in input)
// padding = pads_begin + pads_end, formula below is using
// params.pad_begin/params.pads_end:
const size_t expected_dim =
out_shape_3d[i] - ((input_shape[i + 2] - 1) - filters_shape[i + 2] +
params.pads_begin[i] + params.pads_end[i] + 2 +
params.output_padding[i]);
params.pads_end[i] += expected_dim;
}
}

const size_t filters_count = filters_shape[filter_out_ch_axis];
const Shape filter_shape(++filters_shape.begin(), filters_shape.end());
const size_t filter_size = shape_size(filter_shape);

const size_t batches_count = input_shape[in_batch_axis];
Shape batch_shape(++input_shape.begin(), input_shape.end());
const size_t batch_size = shape_size(batch_shape);

auto batch = in;

for (size_t batch_idx = 0; batch_idx < batches_count; ++batch_idx)
{
auto filter = f;
for (size_t f_idx = 0; f_idx < filters_count; ++f_idx)
{
convolve_3D_channels(params, batch, batch_shape, filter, filter_shape, out);
filter += filter_size;
}
batch += batch_size;
}
}
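The pads_begin/pads_end adjustment in convolution_backprop_impl above is the standard trick for turning a transposed convolution into a forward convolution over the zero-extended input: for a kernel of spatial size k and original padding p, the equivalent forward convolution is padded by k - p - 1 on that side, with size-1 axes zeroed out as a degenerate case. A one-line illustration of that relationship (the helper name is made up; the loop above applies the same expression per axis):

// k = 3, p = 0 -> forward pad 2;  k = 3, p = 1 -> forward pad 1.
inline int flipped_pad(int kernel_size, int pad) { return kernel_size - pad - 1; }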
template <typename OUTPUT,
typename FILTER,
typename INPUT,
typename ACCUMULATION = typename widen<INPUT>::type>
void convolution_backprop_in(const OUTPUT* delta_out,
const FILTER* filter,
INPUT* delta_in,
const Shape& out_shape,
const Shape& filter_shape,
template <typename T>
void convolution_backprop_in(const T* delta_in,
const T* filter,
T* delta_out,
const Shape& in_shape,
const Shape& filter_shape,
const Shape& out_shape,
const Strides& in_dilation,
const Strides& filter_dilation,
const CoordinateDiff& forward_in_pad_bellow,
const CoordinateDiff& forward_in_pad_above,
const Strides& stride)
const Strides& stride,
const CoordinateDiff& output_padding)
{
std::vector<T> extended_input;
std::vector<T> extended_filter;
AxisSet reverse_axes;

Shape conv_input_shape = in_shape;
Shape conv_filter_shape = filter_shape;
Strides conv_stride = stride;
Strides conv_filter_dilation = filter_dilation;
auto conv_input_data = delta_in;

validate_convolution_backprop_parameters(in_shape,
filter_shape,
out_shape,
stride,
filter_dilation,
forward_in_pad_bellow,
forward_in_pad_above,
output_padding);

// Note that we only reverse the spatial dimensions here (loop
// starts at 2)
std::vector<INPUT> reversed(shape_size(filter_shape));
AxisSet reverse_axes;
size_t reverse_axes_start = 2;
for (size_t i = reverse_axes_start; i < filter_shape.size(); ++i)
std::vector<T> reversed(shape_size(filter_shape));
for (size_t i = 2; i < filter_shape.size(); ++i)
{
reverse_axes.insert(i);
}
@@ -242,55 +323,109 @@ namespace ngraph
filter_shape,
filter_shape,
reverse_axes,
sizeof(FILTER));
size_t filter_out_channel_axis = 1;
size_t filter_in_channel_axis = 0;
sizeof(T));

// Compute backward pad out pad bellow
size_t spatial_dim_count = in_shape.size() - 2;
auto conv_filter_data = &reversed[0];

CoordinateDiff backward_delta_out_pad_below;
backward_delta_out_pad_below.resize(spatial_dim_count);

for (size_t i = 0; i < spatial_dim_count; i++)
// if channel number for output is > 1 then reverse layout of filter coefficients as
// it is required by convolve_3D_channels() function.
// Current layout:
// batch0_ch0|batch0_ch1|...|batch0_chN|...|batch1_ch0|batch1_ch1|...|batch1_chN|...
// Expected layout:
// batch0_ch0|batch1_ch0|...|batchN_ch0|...|batch0_ch1|batch1_ch1|...|batch1_chN|...
if (filter_shape[1] > 1)
{
backward_delta_out_pad_below[i] =
(static_cast<ptrdiff_t>(filter_shape[i + 2]) - 1) * filter_dilation[i] -
forward_in_pad_bellow[i];
}
// Compute backward pad out pad above
CoordinateDiff backward_delta_out_pad_above;
backward_delta_out_pad_above.resize(spatial_dim_count);
std::vector<T> temp_reversed(reversed);
const Shape filter_dim_shape(filter_shape.begin() + 2, filter_shape.end());
const size_t filter_size = shape_size(filter_dim_shape);

for (size_t i = 0; i < spatial_dim_count; i++)
{
backward_delta_out_pad_above[i] =
(static_cast<ptrdiff_t>(filter_shape[i + 2]) - 1) * filter_dilation[i] +
((forward_in_pad_bellow[i] + ((in_shape[i + 2]) - 1) * in_dilation[i] +
forward_in_pad_above[i] -
(static_cast<ptrdiff_t>(filter_shape[i + 2]) - 1) * filter_dilation[i]) %
stride[i]) -
forward_in_pad_above[i];
for (size_t i = 0; i < filter_shape[1]; i++)
{
for (size_t j = 0; j < filter_shape[0]; j++)
{
const auto delta = temp_reversed.begin() +
j * filter_shape[1] * filter_size + i * filter_size;
const auto out = reversed.begin() + i * filter_shape[0] * filter_size +
j * filter_size;
std::copy(delta, delta + filter_size, out);
}
}
}

convolution_backprop_impl<OUTPUT, FILTER, INPUT, ACCUMULATION>(
delta_out,
&reversed[0],
delta_in,
out_shape,
filter_shape,
in_shape,
in_dilation,
filter_dilation,
backward_delta_out_pad_below,
backward_delta_out_pad_above,
stride,
0,
1,
filter_out_channel_axis,
filter_in_channel_axis,
0,
1);
// swap filter batch and channels
std::iter_swap(conv_filter_shape.begin(), conv_filter_shape.begin() + 1);

// extend stride and filter inputs with zero padding for stride and filter_dilation
// > 1, after that set stride and filter params to 1.
const size_t stride_dim =
std::accumulate(stride.begin(), stride.end(), 1, std::multiplies<size_t>());
if (stride_dim >= 2)
{
extend_with_zeros(stride, in_shape, delta_in, conv_input_shape, extended_input);
std::fill(conv_stride.begin(), conv_stride.end(), 1);
conv_input_data = &extended_input[0];
}

const size_t dilation_dim = std::accumulate(
filter_dilation.begin(), filter_dilation.end(), 1, std::multiplies<size_t>());
if (dilation_dim >= 2)
{
extend_with_zeros<T>(filter_dilation,
filter_shape,
reinterpret_cast<const T*>(&reversed[0]),
conv_filter_shape,
extended_filter);
std::fill(conv_filter_dilation.begin(), conv_filter_dilation.end(), 1);
conv_filter_data = &extended_filter[0];
}

convolution_backprop_impl(conv_input_data,
conv_filter_data,
delta_out,
conv_input_shape,
conv_filter_shape,
out_shape,
conv_stride,
conv_filter_dilation,
forward_in_pad_bellow,
forward_in_pad_above,
output_padding);
}
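The single-template convolution_backprop_in above is the overload callers are expected to use from now on; the parameters keep their previous meaning and output_padding must hold one entry per spatial axis. A hedged usage sketch for a 1-D, single-batch, single-channel case; the function name and tensor values below are illustrative only, and it assumes the ngraph reference headers are on the include path:

#include <vector>
#include "ngraph/runtime/reference/convolution_backprop_data.hpp"

void transposed_conv_1d_example()
{
    using namespace ngraph;

    // delta_in: N = 1, C = 1, W = 3; filter: 1 x 1 x 3.
    std::vector<float> delta_in{1.0f, 2.0f, 3.0f};
    std::vector<float> filter{1.0f, 0.5f, 0.25f};
    // (3 - 1) * 2 + 3 = 7 output elements for stride 2, 3-tap filter, no padding.
    std::vector<float> delta_out(7, 0.0f);

    runtime::reference::convolution_backprop_in(delta_in.data(),
                                                filter.data(),
                                                delta_out.data(),
                                                Shape{1, 1, 3},
                                                Shape{1, 1, 3},
                                                Shape{1, 1, 7},
                                                Strides{1},         // in_dilation
                                                Strides{1},         // filter_dilation
                                                CoordinateDiff{0},  // forward pads_begin
                                                CoordinateDiff{0},  // forward pads_end
                                                Strides{2},         // stride
                                                CoordinateDiff{0}); // output_padding
}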
// DEPRECATED, can't be removed currently due to arm-plugin dependency
template <typename OUTPUT,
typename FILTER,
typename INPUT,
typename ACCUMULATION = typename widen<INPUT>::type>
NGRAPH_DEPRECATED(
"convolution_backprop_in function with 4 template types is deprecated, use "
"function with 1 template and output_padding parameter.")
void convolution_backprop_in(const INPUT* delta_in,
const FILTER* filter,
OUTPUT* delta_out,
const Shape& in_shape,
const Shape& filter_shape,
const Shape& out_shape,
const Strides& in_dilation,
const Strides& filter_dilation,
const CoordinateDiff& forward_in_pad_bellow,
const CoordinateDiff& forward_in_pad_above,
const Strides& stride)
{
const ngraph::CoordinateDiff output_padding(in_shape.size() - 2, 0);

convolution_backprop_in(delta_in,
filter,
delta_out,
in_shape,
filter_shape,
out_shape,
in_dilation,
filter_dilation,
forward_in_pad_bellow,
forward_in_pad_above,
stride,
output_padding);
}
} // namespace reference
} // namespace runtime
@@ -178,23 +178,24 @@ namespace ngraph
const size_t group_out_size = shape_size(group_out_shape);

Strides in_dilation(in_shape.size(), 1);
const ngraph::CoordinateDiff output_padding(in_shape.size() - 2, 0);
for (size_t batch_idx = 0; batch_idx < in_shape[in_batch_axis]; ++batch_idx)
{
group_filter = f;
for (size_t group_idx = 0; group_idx < group_count; ++group_idx)
{
runtime::reference::convolution_backprop_in<INPUT, FILTER, OUTPUT, ACCU>(
group_batch,
group_filter,
group_out,
group_batch_shape,
group_filter_shape,
group_out_shape,
in_dilation,
dilation,
pads_begin,
pads_end,
strides);
runtime::reference::convolution_backprop_in(group_batch,
group_filter,
group_out,
group_batch_shape,
group_filter_shape,
group_out_shape,
in_dilation,
dilation,
pads_begin,
pads_end,
strides,
output_padding);
group_batch += group_batch_size;
group_filter += group_filter_size;
group_out += group_out_size;
@@ -356,6 +356,7 @@ set(MULTI_TEST_SRC
backend/constant.in.cpp
backend/convert.in.cpp
backend/convert_like.in.cpp
backend/convolution_backprop.in.cpp
backend/convolution.in.cpp
backend/binary_convolution.in.cpp
backend/clamp.in.cpp
ngraph/test/backend/convolution_backprop.in.cpp (new file, 1221 lines)
File diff suppressed because it is too large
@@ -283,7 +283,8 @@ namespace
op->get_dilations(),
op->get_pads_begin(),
op->get_pads_end(),
op->get_strides());
op->get_strides(),
op->get_output_padding());
return true;
}