Reference implementation for Tile op (#2641)
This commit is contained in:
parent
32b886a892
commit
85b06835aa
@ -31,7 +31,8 @@ namespace ngraph
|
||||
char* out,
|
||||
const Shape& in_shape,
|
||||
const Shape& out_shape,
|
||||
size_t elem_size);
|
||||
const size_t elem_size,
|
||||
const std::vector<int64_t>& repeats);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -14,34 +14,91 @@
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <stdio.h>
|
||||
#include <cstdio>
|
||||
#include <numeric>
|
||||
|
||||
#include "ngraph/check.hpp"
|
||||
#include "ngraph/runtime/reference/tile.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
void runtime::reference::tile(
|
||||
const char* arg, char* out, const Shape& in_shape, const Shape& out_shape, size_t elem_size)
|
||||
namespace
|
||||
{
|
||||
/// \brief For each axis calculates the product of inner axes
|
||||
/// If dims has shape (2, 3, 4) then for 2 (first axis) the inner axes would be (3, 4)
|
||||
/// and for 3 (second axis) it would be (4)
|
||||
/// If dims has shape(2, 3, 4) then the output vector would be (3 * 4, 4, 1)
|
||||
/// The outermost axis is not used. For innermost axis it is always 1.
|
||||
/// \param[in] dims Shape of the output
|
||||
///
|
||||
/// \return Vector containing calculated values for each axis.
|
||||
std::vector<int64_t> create_pitches(const Shape& dims)
|
||||
{
|
||||
std::vector<int64_t> pitch;
|
||||
pitch.resize(dims.size() - 1);
|
||||
std::partial_sum(
|
||||
dims.rbegin(), dims.rend() - 1, pitch.rbegin(), std::multiplies<int64_t>());
|
||||
pitch.push_back(1);
|
||||
return pitch;
|
||||
}
|
||||
}
|
||||
|
||||
void runtime::reference::tile(const char* arg,
|
||||
char* out,
|
||||
const Shape& in_shape,
|
||||
const Shape& out_shape,
|
||||
const size_t elem_size,
|
||||
const std::vector<int64_t>& repeats)
|
||||
{
|
||||
Shape in_shape_expanded(in_shape);
|
||||
in_shape_expanded.insert(in_shape_expanded.begin(), out_shape.size() - in_shape.size(), 1);
|
||||
CoordinateTransform input_transform(in_shape_expanded);
|
||||
CoordinateTransform output_transform(out_shape);
|
||||
size_t block_size = 0;
|
||||
int64_t num_repeats = 0;
|
||||
const int input_rank = in_shape_expanded.size();
|
||||
const int64_t last_dim = in_shape_expanded[input_rank - 1];
|
||||
const std::vector<int64_t> pitches = create_pitches(out_shape);
|
||||
const char* copy = nullptr;
|
||||
|
||||
for (const Coordinate& output_coord : output_transform)
|
||||
std::vector<int64_t> indices(in_shape_expanded.size() - 1, 0);
|
||||
size_t axis = indices.size();
|
||||
|
||||
// Copy and repeat data for innermost axis as many times as described in the repeats parameter
|
||||
while (axis <= indices.size())
|
||||
{
|
||||
std::vector<size_t> coord;
|
||||
for (auto i = 0; i < output_coord.size(); i++)
|
||||
{
|
||||
auto val = output_coord[i] % in_shape_expanded[i];
|
||||
coord.push_back(val);
|
||||
}
|
||||
Coordinate input_coord(coord);
|
||||
block_size = last_dim * elem_size;
|
||||
memcpy(out, arg, block_size);
|
||||
out += block_size;
|
||||
arg += block_size;
|
||||
|
||||
memcpy(out + output_transform.index(output_coord) * elem_size,
|
||||
arg + input_transform.index(input_coord) * elem_size,
|
||||
elem_size);
|
||||
copy = out - block_size;
|
||||
num_repeats = repeats[input_rank - 1] - 1;
|
||||
for (int64_t i = 0; i < num_repeats; ++i)
|
||||
{
|
||||
memcpy(out, copy, block_size);
|
||||
out += block_size;
|
||||
}
|
||||
|
||||
// Copy and repeat data for other axes as many times as described in the repeats parameter
|
||||
while (axis-- != 0)
|
||||
{
|
||||
if (++indices[axis] != in_shape_expanded[axis])
|
||||
{
|
||||
axis = indices.size();
|
||||
break;
|
||||
}
|
||||
indices[axis] = 0;
|
||||
|
||||
ptrdiff_t pitch = pitches[axis] * in_shape_expanded[axis];
|
||||
block_size = pitch * elem_size;
|
||||
copy = out - block_size;
|
||||
num_repeats = repeats[axis] - 1;
|
||||
for (int64_t i = 0; i < num_repeats; i++)
|
||||
{
|
||||
memcpy(out, copy, block_size);
|
||||
out += block_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -119,7 +119,8 @@ bool op::v0::Tile::evaluate(const HostTensorVector& outputs, const HostTensorVec
|
||||
output->get_data_ptr<char>(),
|
||||
data->get_shape(),
|
||||
output_shape,
|
||||
data->get_element_type().size());
|
||||
data->get_element_type().size(),
|
||||
repeats_val);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -1,5 +1,3 @@
|
||||
tile_3d_small_data_rank
|
||||
tile_3d_few_repeats
|
||||
fake_quantize_pdpd
|
||||
convert_float32_bf16
|
||||
convert_bf16_float32
|
||||
|
Loading…
Reference in New Issue
Block a user