diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/tile.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/tile.hpp
index 53d25ee7f5a..40e5f5565f8 100644
--- a/ngraph/core/reference/include/ngraph/runtime/reference/tile.hpp
+++ b/ngraph/core/reference/include/ngraph/runtime/reference/tile.hpp
@@ -31,7 +31,8 @@ namespace ngraph
                       char* out,
                       const Shape& in_shape,
                       const Shape& out_shape,
-                      size_t elem_size);
+                      const size_t elem_size,
+                      const std::vector<int64_t>& repeats);
         }
     }
 }
diff --git a/ngraph/core/reference/src/runtime/reference/tile.cpp b/ngraph/core/reference/src/runtime/reference/tile.cpp
index 335a5307ace..5d78dee2483 100644
--- a/ngraph/core/reference/src/runtime/reference/tile.cpp
+++ b/ngraph/core/reference/src/runtime/reference/tile.cpp
@@ -14,34 +14,121 @@
 // limitations under the License.
 //*****************************************************************************
 
+#include <algorithm>
 #include <cmath>
-#include <stdio.h>
+#include <cstdio>
+#include <cstring>
+#include <numeric>
 
 #include "ngraph/check.hpp"
 #include "ngraph/runtime/reference/tile.hpp"
 
 using namespace ngraph;
 
-void runtime::reference::tile(
-    const char* arg, char* out, const Shape& in_shape, const Shape& out_shape, size_t elem_size)
+namespace
+{
+    /// \brief For each axis calculates the product of inner axes
+    /// If dims has shape (2, 3, 4) then for 2 (first axis) the inner axes would be (3, 4)
+    /// and for 3 (second axis) it would be (4)
+    /// If dims has shape(2, 3, 4) then the output vector would be (3 * 4, 4, 1)
+    /// The outermost axis is not used. For innermost axis it is always 1.
+    /// An empty (rank 0) shape yields an empty vector.
+    /// \param[in] dims Shape of the output
+    ///
+    /// \return Vector containing calculated values for each axis.
+    std::vector<int64_t> create_pitches(const Shape& dims)
+    {
+        if (dims.empty())
+        {
+            return {};
+        }
+
+        std::vector<int64_t> pitch;
+        pitch.resize(dims.size() - 1);
+        std::partial_sum(
+            dims.rbegin(), dims.rend() - 1, pitch.rbegin(), std::multiplies<int64_t>());
+        pitch.push_back(1);
+        return pitch;
+    }
+}
+
+/// \brief Repeats the input along every axis as many times as requested in repeats
+/// \param[in] arg Input data
+/// \param[out] out Output buffer, written sequentially; expected to hold out_shape elements
+/// \param[in] in_shape Shape of the input; padded with leading 1s up to the output rank
+/// \param[in] out_shape Shape of the output
+/// \param[in] elem_size Size of a single element in bytes
+/// \param[in] repeats Number of repetitions per axis, indexed like the axes of out_shape
+void runtime::reference::tile(const char* arg,
+                              char* out,
+                              const Shape& in_shape,
+                              const Shape& out_shape,
+                              const size_t elem_size,
+                              const std::vector<int64_t>& repeats)
 {
+    // An output with a zero-sized axis holds no elements (a repeat of 0 or an empty input
+    // axis), so nothing may be written to out.
+    if (std::any_of(out_shape.begin(), out_shape.end(), [](size_t dim) { return dim == 0; }))
+    {
+        return;
+    }
+
+    // A rank 0 output has no axes to walk: it is a copy of the single input element.
+    if (out_shape.empty())
+    {
+        memcpy(out, arg, elem_size);
+        return;
+    }
+
     Shape in_shape_expanded(in_shape);
     in_shape_expanded.insert(in_shape_expanded.begin(), out_shape.size() - in_shape.size(), 1);
-    CoordinateTransform input_transform(in_shape_expanded);
-    CoordinateTransform output_transform(out_shape);
+    size_t block_size = 0;
+    int64_t num_repeats = 0;
+    const size_t input_rank = in_shape_expanded.size();
+    const int64_t last_dim = in_shape_expanded[input_rank - 1];
+    const std::vector<int64_t> pitches = create_pitches(out_shape);
+    const char* copy = nullptr;
 
-    for (const Coordinate& output_coord : output_transform)
+    std::vector<size_t> indices(in_shape_expanded.size() - 1, 0);
+    size_t axis = indices.size();
+
+    // Copy and repeat data for innermost axis as many times as described in the repeats parameter
+    // The walk ends once the inner loop runs out of axes: axis then wraps around and exceeds
+    // indices.size().
+    while (axis <= indices.size())
     {
-        std::vector<size_t> coord;
-        for (auto i = 0; i < output_coord.size(); i++)
-        {
-            auto val = output_coord[i] % in_shape_expanded[i];
-            coord.push_back(val);
-        }
-        Coordinate input_coord(coord);
+        block_size = last_dim * elem_size;
+        memcpy(out, arg, block_size);
+        out += block_size;
+        arg += block_size;
 
-        memcpy(out + output_transform.index(output_coord) * elem_size,
-               arg + input_transform.index(input_coord) * elem_size,
-               elem_size);
+        copy = out - block_size;
+        num_repeats = repeats[input_rank - 1] - 1;
+        for (int64_t i = 0; i < num_repeats; ++i)
+        {
+            memcpy(out, copy, block_size);
+            out += block_size;
+        }
+
+        // Copy and repeat data for other axes as many times as described in the repeats parameter
+        while (axis-- != 0)
+        {
+            if (++indices[axis] != in_shape_expanded[axis])
+            {
+                axis = indices.size();
+                break;
+            }
+            indices[axis] = 0;
+
+            ptrdiff_t pitch = pitches[axis] * in_shape_expanded[axis];
+            block_size = pitch * elem_size;
+            copy = out - block_size;
+            num_repeats = repeats[axis] - 1;
+            for (int64_t i = 0; i < num_repeats; i++)
+            {
+                memcpy(out, copy, block_size);
+                out += block_size;
+            }
+        }
     }
 }
diff --git a/ngraph/core/src/op/tile.cpp b/ngraph/core/src/op/tile.cpp
index 0e2ab2d573c..ee7f415914d 100644
--- a/ngraph/core/src/op/tile.cpp
+++ b/ngraph/core/src/op/tile.cpp
@@ -119,7 +119,8 @@ bool op::v0::Tile::evaluate(const HostTensorVector& outputs, const HostTensorVec
                              output->get_data_ptr<char>(),
                              data->get_shape(),
                              output_shape,
-                             data->get_element_type().size());
+                             data->get_element_type().size(),
+                             repeats_val);
 
     return true;
 }
diff --git a/ngraph/test/runtime/interpreter/unit_test.manifest b/ngraph/test/runtime/interpreter/unit_test.manifest
index 35b32188401..02fafaf8086 100644
--- a/ngraph/test/runtime/interpreter/unit_test.manifest
+++ b/ngraph/test/runtime/interpreter/unit_test.manifest
@@ -1,5 +1,3 @@
-tile_3d_small_data_rank
-tile_3d_few_repeats
 fake_quantize_pdpd
 convert_float32_bf16
 convert_bf16_float32