New Gather op reference implementation. (#3633)
* New Gather op reference implementation. * Unify span implementation for gather and gather_nd. Create span.hpp for common implementation of span. * Move span to utils directory. * Address review comments. * update span * Address PR comments. Co-authored-by: Patryk Elszkowski <patryk.elszkowki@intel.com>
This commit is contained in:
parent
96b032504e
commit
bd9bbe09c3
@ -185,7 +185,7 @@ namespace ngraph
|
||||
Strides(source_shape.size(), 1));
|
||||
}
|
||||
|
||||
/// \brief Class allows to iterate over Tensor with reverted axies part by part.
|
||||
/// \brief Class allows to iterate over Tensor with reverted axes part by part.
|
||||
///
|
||||
/// To create ReverseRange use _reverse_ function.
|
||||
///
|
||||
@ -213,8 +213,14 @@ namespace ngraph
|
||||
return ReverseRange(source_shape, reversed_axis);
|
||||
}
|
||||
|
||||
inline ReverseRange index(const Shape& source_shape)
|
||||
{
|
||||
return reverse(source_shape, {});
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
using impl::Direction;
|
||||
using impl::index;
|
||||
using impl::reverse;
|
||||
using impl::slice;
|
||||
} // namespace coordinates
|
||||
|
@ -18,8 +18,10 @@
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "ngraph/coordinate_range.hpp"
|
||||
#include "ngraph/coordinate_transform.hpp"
|
||||
#include "ngraph/runtime/reference/gather_nd.hpp"
|
||||
#include "utils/span.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -27,147 +29,105 @@ namespace ngraph
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
// Implement gather by calling gather_nd on sub-problems
|
||||
// # prepare constant shapes for tensors used for sub problems
|
||||
// indices'.shape = indices.shape[-1] + [1]
|
||||
// params'.shape = params.shape[axis:]
|
||||
// out'.shape = params'.shape
|
||||
// out'.shape[0] = indices.shape[-1]
|
||||
// # call sub-problems
|
||||
// foreach (params_index, out_index) in outer "axis" dimensions
|
||||
// # params_prime is shared by inner loop
|
||||
// params' = param[params_index] # rank(params') == rank(params) - axis
|
||||
// foreach indices_index in outer N-1 dimensions
|
||||
// indices' = indices[indices_index] # rank(indices') == 2
|
||||
// out_index = out_index + indices_index
|
||||
// out' = out[out_index] # rank(out') == rank(params')
|
||||
// gather_nd(params', indices'', out')
|
||||
namespace
|
||||
{
|
||||
template <typename Container>
|
||||
Shape to_shape(const Container& c)
|
||||
{
|
||||
return Shape(begin(c), end(c));
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
std::vector<size_t>
|
||||
join(const Container& c1, const Container& c2, const Container& c3)
|
||||
{
|
||||
using container_value_type =
|
||||
typename std::remove_cv<typename Container::value_type>::type;
|
||||
static_assert(std::is_same<container_value_type, size_t>::value,
|
||||
"Expect same type in container");
|
||||
std::vector<size_t> ret;
|
||||
ret.reserve(c1.size() + c2.size() + c3.size());
|
||||
std::copy(begin(c1), end(c1), std::back_inserter(ret));
|
||||
std::copy(begin(c2), end(c2), std::back_inserter(ret));
|
||||
std::copy(begin(c3), end(c3), std::back_inserter(ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
const auto only_one = [] { return coordinates::index(Shape{1}); };
|
||||
} // namespace
|
||||
template <typename T, typename U>
|
||||
void gather(const T* params,
|
||||
const U* indices,
|
||||
T* out,
|
||||
void gather(const T* const params,
|
||||
const U* const indices,
|
||||
T* const out,
|
||||
const Shape& params_shape,
|
||||
const Shape& indices_shape,
|
||||
const Shape& out_shape,
|
||||
size_t axis)
|
||||
{
|
||||
// prepare shape of params_prime (remove first "axis" dimensions)
|
||||
const Shape params_prime_shape(params_shape.begin() + axis, params_shape.end());
|
||||
// prepare shape of indices_prime
|
||||
const size_t indices_ndim = indices_shape.size();
|
||||
Shape indices_prime_shape;
|
||||
// prepare shape of out_prime (same as params_prime except for first dim)
|
||||
Shape out_prime_shape(params_prime_shape);
|
||||
if (indices_ndim > 0)
|
||||
{
|
||||
out_prime_shape[0] = indices_shape[indices_ndim - 1];
|
||||
indices_prime_shape.emplace_back(indices_shape[indices_ndim - 1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
out_prime_shape[0] = 1;
|
||||
}
|
||||
indices_prime_shape.emplace_back(1);
|
||||
using std::next;
|
||||
assert(std::memset(out, 0, shape_size(out_shape) * sizeof(T)));
|
||||
|
||||
// Create a CoordinateTransform for "out" that visits the outer "axis" dimensions
|
||||
const size_t out_ndim = out_shape.size();
|
||||
const Coordinate out_outer_start_corner(out_ndim, 0);
|
||||
Coordinate out_outer_end_corner(out_shape);
|
||||
for (size_t i = axis; i < out_ndim; i++)
|
||||
{
|
||||
out_outer_end_corner[i] = 1;
|
||||
}
|
||||
Strides out_outer_strides(out_ndim, 1);
|
||||
AxisVector out_outer_axis_order(out_ndim);
|
||||
std::iota(out_outer_axis_order.begin(), out_outer_axis_order.end(), 0);
|
||||
CoordinateTransform out_outer_transform(out_shape,
|
||||
out_outer_start_corner,
|
||||
out_outer_end_corner,
|
||||
out_outer_strides,
|
||||
out_outer_axis_order);
|
||||
const auto params_axes_part = span(params_shape).subspan(0, axis);
|
||||
|
||||
// Create a CoordinateTransform for "params" that visits the outer "axis" dimensions
|
||||
const size_t params_ndim = params_shape.size();
|
||||
const Coordinate params_outer_start_corner(params_ndim, 0);
|
||||
Coordinate params_outer_end_corner(params_shape);
|
||||
for (size_t i = axis; i < params_ndim; i++)
|
||||
{
|
||||
params_outer_end_corner[i] = 1;
|
||||
}
|
||||
const Strides params_outer_strides(params_ndim, 1);
|
||||
AxisVector params_outer_axis_order(params_ndim);
|
||||
std::iota(params_outer_axis_order.begin(), params_outer_axis_order.end(), 0);
|
||||
const CoordinateTransform params_outer_transform(params_shape,
|
||||
params_outer_start_corner,
|
||||
params_outer_end_corner,
|
||||
params_outer_strides,
|
||||
params_outer_axis_order);
|
||||
NGRAPH_CHECK(params_shape.size() >= axis, "Not enough axes in param_shape.");
|
||||
|
||||
// Create a CoordinateTransform for "indices" that visits only the first element
|
||||
// along inner most axis
|
||||
const Coordinate indices_outer_start_corner(indices_ndim, 0);
|
||||
Coordinate indices_outer_end_corner(indices_shape);
|
||||
if (indices_ndim > 0)
|
||||
{
|
||||
indices_outer_end_corner[indices_ndim - 1] = 1;
|
||||
}
|
||||
const Strides indices_outer_strides(indices_ndim, 1);
|
||||
AxisVector indices_outer_axis_order(indices_ndim);
|
||||
std::iota(indices_outer_axis_order.begin(), indices_outer_axis_order.end(), 0);
|
||||
const CoordinateTransform indices_outer_transform(indices_shape,
|
||||
indices_outer_start_corner,
|
||||
indices_outer_end_corner,
|
||||
indices_outer_strides,
|
||||
indices_outer_axis_order);
|
||||
const auto remainder_part_shape = span(params_shape).subspan(axis + 1);
|
||||
|
||||
// Create an inner CoordinateTransfrom for "out"
|
||||
const size_t out_inner_ndim = out_ndim - axis;
|
||||
const Shape out_inner_shape(out_shape.begin() + axis, out_shape.end());
|
||||
const Coordinate out_inner_start_corner(out_inner_ndim, 0);
|
||||
Coordinate out_inner_end_corner(out_inner_shape);
|
||||
if (indices_ndim > 0)
|
||||
{
|
||||
out_inner_end_corner[indices_ndim - 1] = 1;
|
||||
}
|
||||
for (size_t i = indices_ndim; i < out_inner_ndim; i++)
|
||||
{
|
||||
out_inner_end_corner[i] = 1;
|
||||
}
|
||||
const Strides out_inner_strides(out_inner_ndim, 1);
|
||||
AxisVector out_inner_axis_order(out_inner_ndim);
|
||||
std::iota(out_inner_axis_order.begin(), out_inner_axis_order.end(), 0);
|
||||
const CoordinateTransform out_inner_transform(out_inner_shape,
|
||||
out_inner_start_corner,
|
||||
out_inner_end_corner,
|
||||
out_inner_strides,
|
||||
out_inner_axis_order);
|
||||
const auto found_out_shape =
|
||||
join(params_axes_part, span(indices_shape), remainder_part_shape);
|
||||
|
||||
auto out_outer_coord_iter = out_outer_transform.begin();
|
||||
for (const Coordinate& params_outer_coord : params_outer_transform)
|
||||
{
|
||||
if (out_outer_coord_iter == out_outer_transform.end())
|
||||
break;
|
||||
const T* params_prime =
|
||||
¶ms[params_outer_transform.index(params_outer_coord)];
|
||||
T* out_outer = &out[out_outer_transform.index(*out_outer_coord_iter)];
|
||||
NGRAPH_CHECK(found_out_shape == out_shape,
|
||||
"Output shape mismatch with calculations");
|
||||
|
||||
auto out_inner_coord_iter = out_inner_transform.begin();
|
||||
for (const Coordinate& indices_outer_coord : indices_outer_transform)
|
||||
const auto batch_shape = span(params_shape).subspan(axis);
|
||||
|
||||
const auto batch_size = shape_size(batch_shape);
|
||||
|
||||
const auto copy_size = shape_size(remainder_part_shape);
|
||||
|
||||
const size_t copy_round_in_batch =
|
||||
indices_shape.size() > 1
|
||||
? shape_size(span(indices_shape.data(), indices_shape.size() - 1))
|
||||
: 1;
|
||||
const size_t round_batch_offset = indices_shape.empty() ? 1 : indices_shape.back();
|
||||
|
||||
auto dst = out;
|
||||
|
||||
auto gather_range = params_axes_part.empty()
|
||||
? only_one()
|
||||
: coordinates::index(to_shape(params_axes_part));
|
||||
for (auto i : gather_range)
|
||||
{
|
||||
auto batch_index = i.begin_index;
|
||||
for (size_t batch = 0; batch != i.element_number;
|
||||
batch_index += i.step, ++batch)
|
||||
{
|
||||
if (out_inner_coord_iter == out_inner_transform.end())
|
||||
break;
|
||||
const U* indices_prime =
|
||||
&indices[indices_outer_transform.index(indices_outer_coord)];
|
||||
T* out_prime = &out_outer[out_inner_transform.index(*out_inner_coord_iter)];
|
||||
gather_nd<T, U>(params_prime,
|
||||
indices_prime,
|
||||
out_prime,
|
||||
params_prime_shape,
|
||||
indices_prime_shape,
|
||||
out_prime_shape);
|
||||
++out_inner_coord_iter;
|
||||
const auto batch_offset = batch_index * batch_size;
|
||||
assert(batch_offset < shape_size(params_shape));
|
||||
for (size_t round = 0; round != copy_round_in_batch; ++round)
|
||||
{
|
||||
const U* input_indices = indices + round * round_batch_offset;
|
||||
const auto indices_no =
|
||||
indices_shape.empty() ? 1 : indices_shape.back();
|
||||
|
||||
assert(!batch_shape.empty());
|
||||
for (size_t ii = 0; ii != indices_no; ++ii)
|
||||
{
|
||||
const auto positive_input_index =
|
||||
input_indices[ii] < 0 ? batch_shape.front() + input_indices[ii]
|
||||
: input_indices[ii];
|
||||
|
||||
const auto src_offset =
|
||||
batch_offset + copy_size * positive_input_index;
|
||||
|
||||
const auto src_begin = next(params, src_offset);
|
||||
const auto src_end = next(src_begin, copy_size);
|
||||
|
||||
std::copy(src_begin, src_end, dst);
|
||||
dst += copy_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
++out_outer_coord_iter;
|
||||
}
|
||||
}
|
||||
} // namespace reference
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "ngraph/coordinate_transform.hpp"
|
||||
#include "utils/span.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -28,52 +29,8 @@ namespace ngraph
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
namespace
|
||||
namespace details
|
||||
{
|
||||
template <bool check>
|
||||
using Required = typename std::enable_if<check, bool>::type;
|
||||
|
||||
template <typename It>
|
||||
struct IsRandomAccessIt
|
||||
{
|
||||
static constexpr bool value =
|
||||
std::is_same<typename It::iterator_category,
|
||||
std::random_access_iterator_tag>::value;
|
||||
};
|
||||
|
||||
template <typename Iterator, Required<IsRandomAccessIt<Iterator>::value> = true>
|
||||
class Span
|
||||
{
|
||||
public:
|
||||
Span(Iterator begin, Iterator end)
|
||||
: m_begin{begin}
|
||||
, m_end{end}
|
||||
{
|
||||
}
|
||||
|
||||
Iterator begin() const { return m_begin; }
|
||||
Iterator end() const { return m_end; };
|
||||
typename Iterator::value_type operator[](size_t idx) const
|
||||
{
|
||||
return *next(m_begin, idx);
|
||||
}
|
||||
|
||||
typename Iterator::difference_type size() const
|
||||
{
|
||||
return std::distance(m_begin, m_end);
|
||||
}
|
||||
|
||||
private:
|
||||
Iterator m_begin;
|
||||
Iterator m_end;
|
||||
};
|
||||
|
||||
template <typename Iterator>
|
||||
Span<Iterator> span(Iterator begin, Iterator end)
|
||||
{
|
||||
return Span<Iterator>{begin, end};
|
||||
};
|
||||
|
||||
template <typename Iterator>
|
||||
std::vector<size_t> get_indices_offsets(const Iterator beg,
|
||||
const Iterator end,
|
||||
@ -90,7 +47,7 @@ namespace ngraph
|
||||
|
||||
return offsets;
|
||||
}
|
||||
} // namespace
|
||||
} // namespace details
|
||||
|
||||
///
|
||||
/// Implementation find maximum length of *slice* of input *params* which might be
|
||||
@ -143,14 +100,14 @@ namespace ngraph
|
||||
"params_shape should have enough rank to be index by indices"};
|
||||
}
|
||||
|
||||
const auto slice_shape =
|
||||
span(next(begin(params_shape), first_slice_index_in_params), end(params_shape));
|
||||
const auto slice_shape = span(params_shape).subspan(first_slice_index_in_params);
|
||||
const auto slice_size = shape_size(slice_shape);
|
||||
|
||||
const auto dims_begin = next(rbegin(params_shape), slice_shape.size());
|
||||
const auto dims_end = next(dims_begin, indices_shape.back() - 1);
|
||||
|
||||
const auto indices_offsets = get_indices_offsets(dims_begin, dims_end, slice_size);
|
||||
const auto indices_offsets =
|
||||
details::get_indices_offsets(dims_begin, dims_end, slice_size);
|
||||
|
||||
const auto batch_offset = indices_offsets.front() * params_shape[batch_dims];
|
||||
|
||||
|
@ -0,0 +1,154 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iterator>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
namespace details
|
||||
{
|
||||
template <bool check>
|
||||
using Required = typename std::enable_if<check, bool>::type;
|
||||
|
||||
template <typename It>
|
||||
struct IsRandomAccessIt
|
||||
{
|
||||
static constexpr bool value =
|
||||
std::is_same<typename It::iterator_category,
|
||||
std::random_access_iterator_tag>::value;
|
||||
};
|
||||
|
||||
template <typename... Args>
|
||||
using void_t = void;
|
||||
|
||||
template <typename, typename = size_t>
|
||||
struct is_complete : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_complete<T, decltype(sizeof(T))> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename It>
|
||||
struct from_iterator
|
||||
{
|
||||
using stored_value = typename std::remove_pointer<
|
||||
typename std::iterator_traits<It>::pointer>::type;
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
|
||||
/// @brief Span should mimic std::span
|
||||
template <typename Element>
|
||||
class Span
|
||||
{
|
||||
public:
|
||||
static_assert(std::is_object<Element>::value,
|
||||
"Element must be an object type (not a reference type or void)");
|
||||
static_assert(details::is_complete<Element>::value,
|
||||
"Element must be a complete type (not a forward declaration)");
|
||||
static_assert(!std::is_abstract<Element>::value,
|
||||
"Element cannot be an abstract class type");
|
||||
|
||||
constexpr Span() = default;
|
||||
|
||||
constexpr Span(Element* data, std::size_t size)
|
||||
: m_data{data}
|
||||
, m_size{size}
|
||||
{
|
||||
}
|
||||
|
||||
using value_type = Element;
|
||||
using size_type = std::size_t;
|
||||
|
||||
constexpr Element* begin() const noexcept { return m_data; }
|
||||
constexpr Element* end() const noexcept { return m_data + m_size; }
|
||||
friend constexpr Element* begin(const Span& s) noexcept { return s.begin(); }
|
||||
friend constexpr Element* end(const Span& s) noexcept { return s.end(); }
|
||||
constexpr std::size_t size() const noexcept { return m_size; }
|
||||
constexpr bool empty() const noexcept { return !m_size; }
|
||||
constexpr Element& front() const noexcept { return *m_data; }
|
||||
constexpr Element& back() const noexcept { return *(m_data + (m_size - 1)); }
|
||||
constexpr Element& operator[](std::size_t idx) const { return *(m_data + idx); }
|
||||
Element& at(std::size_t idx) const { return *(m_data + idx); }
|
||||
Span subspan(std::size_t offset,
|
||||
std::size_t size = std::numeric_limits<std::size_t>::max())
|
||||
{
|
||||
if (offset > m_size)
|
||||
{
|
||||
return {};
|
||||
}
|
||||
return {m_data + offset, std::min(size, m_size - offset)};
|
||||
}
|
||||
|
||||
private:
|
||||
Element* m_data{nullptr};
|
||||
std::size_t m_size{0};
|
||||
};
|
||||
|
||||
template <typename Iterator,
|
||||
typename value = typename details::from_iterator<Iterator>::stored_value,
|
||||
details::Required<details::IsRandomAccessIt<Iterator>::value> = true>
|
||||
constexpr auto span(Iterator first, Iterator second) -> Span<value>
|
||||
{
|
||||
return Span<value>{
|
||||
std::addressof(*first),
|
||||
static_cast<typename Span<value>::size_type>(std::distance(first, second))};
|
||||
}
|
||||
|
||||
template <typename Container,
|
||||
// check if Container has contiguous range memory
|
||||
typename = details::void_t<decltype(std::declval<Container>().data()),
|
||||
decltype(std::declval<Container>().size())>>
|
||||
constexpr auto span(const Container& c) -> Span<const typename Container::value_type>
|
||||
{
|
||||
return {c.data(), c.size()};
|
||||
}
|
||||
|
||||
template <typename Container,
|
||||
// check if Container has contiguous range memory
|
||||
typename = details::void_t<decltype(std::declval<Container>().data()),
|
||||
decltype(std::declval<Container>().size())>>
|
||||
constexpr auto span(Container& c) -> Span<typename Container::value_type>
|
||||
{
|
||||
return {c.data(), c.size()};
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
constexpr auto span(const Element* data, std::size_t size) -> Span<const Element>
|
||||
{
|
||||
return {data, size};
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
constexpr auto span(Element* data, std::size_t size) -> Span<Element>
|
||||
{
|
||||
return {data, size};
|
||||
}
|
||||
|
||||
} // namespace reference
|
||||
} // namespace runtime
|
||||
} // namespace ngraph
|
@ -72,19 +72,95 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_4d_indices_axis_0_2d_input)
|
||||
auto f = make_shared<Function>(G, ParameterVector{P, I});
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(f);
|
||||
test_case.add_input<float>({1.0f, 1.1f, 2.0f, 2.1f, 3.0f, 3.1f});
|
||||
test_case.add_input<int32_t>({0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2,
|
||||
0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2,
|
||||
0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2});
|
||||
|
||||
// clang-format off
|
||||
test_case.add_input<float>({1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f});
|
||||
|
||||
test_case.add_input<int32_t>({0, 1, 1, 2,
|
||||
0, 1, 1, 2,
|
||||
0, 1, 1, 2,
|
||||
|
||||
0, 1, 1, 2,
|
||||
0, 1, 1, 2,
|
||||
0, 1, 1, 2,
|
||||
|
||||
|
||||
0, 1, 1, 2,
|
||||
0, 1, 1, 2,
|
||||
0, 1, 1, 2,
|
||||
|
||||
0, 1, 1, 2,
|
||||
0, 1, 1, 2,
|
||||
0, 1, 1, 2});
|
||||
test_case.add_expected_output<float>(
|
||||
out_shape,
|
||||
{1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f,
|
||||
3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f,
|
||||
2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f,
|
||||
2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f,
|
||||
1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f,
|
||||
3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f,
|
||||
2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f});
|
||||
{ 1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f});
|
||||
// clang-format on
|
||||
test_case.run(MIN_FLOAT_TOLERANCE_BITS);
|
||||
}
|
||||
|
||||
@ -100,14 +176,50 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_3d_indices_axis_0_2d_input)
|
||||
auto f = make_shared<Function>(G, ParameterVector{P, I});
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(f);
|
||||
test_case.add_input<float>({1.0f, 1.1f, 2.0f, 2.1f, 3.0f, 3.1f});
|
||||
// clang-format off
|
||||
test_case.add_input<float>({1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f});
|
||||
test_case.add_input<int32_t>(
|
||||
{0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2});
|
||||
{0, 1, 1, 2,
|
||||
0, 1, 1, 2,
|
||||
0, 1, 1, 2,
|
||||
|
||||
0, 1, 1, 2,
|
||||
0, 1, 1, 2,
|
||||
0, 1, 1, 2});
|
||||
test_case.add_expected_output<float>(
|
||||
out_shape, {1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f,
|
||||
2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f,
|
||||
1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f,
|
||||
2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f});
|
||||
out_shape, {1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f,
|
||||
|
||||
1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f});
|
||||
// clang-format on
|
||||
test_case.run(MIN_FLOAT_TOLERANCE_BITS);
|
||||
}
|
||||
|
||||
@ -123,10 +235,20 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_2d_indices_axis_0_2d_input)
|
||||
auto f = make_shared<Function>(G, ParameterVector{P, I});
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(f);
|
||||
test_case.add_input<float>({1.0f, 1.1f, 2.0f, 2.1f, 3.0f, 3.1f});
|
||||
// clang-format off
|
||||
test_case.add_input<float>({1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f});
|
||||
// clang-format on
|
||||
test_case.add_input<int32_t>({0, 1, 1, 2});
|
||||
// clang-format off
|
||||
test_case.add_expected_output<float>(out_shape,
|
||||
{1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f});
|
||||
{1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f});
|
||||
// clang-format on
|
||||
test_case.run(MIN_FLOAT_TOLERANCE_BITS);
|
||||
}
|
||||
|
||||
@ -142,10 +264,24 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_2d_negative_and_positive_indices_axis_0_2d_i
|
||||
auto f = make_shared<Function>(G, ParameterVector{P, I});
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(f);
|
||||
test_case.add_input<float>({1.0f, 1.1f, 2.0f, 2.1f, 3.0f, 3.1f});
|
||||
|
||||
// clang-format off
|
||||
test_case.add_input<float>({1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f});
|
||||
// clang-format on
|
||||
|
||||
test_case.add_input<int32_t>({0, -2, 1, 2});
|
||||
|
||||
// clang-format off
|
||||
test_case.add_expected_output<float>(out_shape,
|
||||
{1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f});
|
||||
{1.0f, 1.1f,
|
||||
2.0f, 2.1f,
|
||||
|
||||
2.0f, 2.1f,
|
||||
3.0f, 3.1f});
|
||||
// clang-format on
|
||||
|
||||
test_case.run(MIN_FLOAT_TOLERANCE_BITS);
|
||||
}
|
||||
|
||||
@ -197,9 +333,19 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_2d_indices_axis_1_2d_input)
|
||||
auto f = make_shared<Function>(G, ParameterVector{P, I});
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(f);
|
||||
test_case.add_input<float>({1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f, 3.0f, 3.1f, 3.2f});
|
||||
|
||||
// clang-format off
|
||||
test_case.add_input<float>({1.0f, 1.1f, 1.2f,
|
||||
2.0f, 2.1f, 2.2f,
|
||||
3.0f, 3.1f, 3.2f});
|
||||
// clang-format on
|
||||
test_case.add_input<int32_t>({0, 2});
|
||||
test_case.add_expected_output<float>(out_shape, {1.0f, 1.2f, 2.0f, 2.2f, 3.0f, 3.2f});
|
||||
|
||||
// clang-format off
|
||||
test_case.add_expected_output<float>(out_shape, {1.0f, 1.2f,
|
||||
2.0f, 2.2f,
|
||||
3.0f, 3.2f});
|
||||
// clang-format on
|
||||
test_case.run(MIN_FLOAT_TOLERANCE_BITS);
|
||||
}
|
||||
|
||||
@ -215,14 +361,40 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_1d_indices_axis_2_4d_input)
|
||||
auto f = make_shared<Function>(G, ParameterVector{P, I});
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(f);
|
||||
test_case.add_input<float>({1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f, 3.0f, 3.1f, 3.2f,
|
||||
1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f, 3.0f, 3.1f, 3.2f,
|
||||
1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f, 3.0f, 3.1f, 3.2f,
|
||||
1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f, 3.0f, 3.1f, 3.2f});
|
||||
// clang-format off
|
||||
test_case.add_input<float>({ 1.0f, 1.1f, 1.2f,
|
||||
2.0f, 2.1f, 2.2f,
|
||||
3.0f, 3.1f, 3.2f,
|
||||
|
||||
11.0f, 11.1f, 11.2f,
|
||||
12.0f, 12.1f, 12.2f,
|
||||
13.0f, 13.1f, 13.2f,
|
||||
|
||||
|
||||
101.0f, 101.1f, 101.2f,
|
||||
102.0f, 102.1f, 102.2f,
|
||||
103.0f, 103.1f, 103.2f,
|
||||
|
||||
111.0f, 111.1f, 111.2f,
|
||||
112.0f, 112.1f, 112.2f,
|
||||
113.0f, 113.1f, 113.2f});
|
||||
// clang-format on
|
||||
test_case.add_input<int32_t>({0, 2});
|
||||
// clang-format off
|
||||
test_case.add_expected_output<float>(
|
||||
out_shape, {1.0f, 1.1f, 1.2f, 3.0f, 3.1f, 3.2f, 1.0f, 1.1f, 1.2f, 3.0f, 3.1f, 3.2f,
|
||||
1.0f, 1.1f, 1.2f, 3.0f, 3.1f, 3.2f, 1.0f, 1.1f, 1.2f, 3.0f, 3.1f, 3.2f});
|
||||
out_shape, { 1.0f, 1.1f, 1.2f,
|
||||
3.0f, 3.1f, 3.2f,
|
||||
|
||||
11.0f, 11.1f, 11.2f,
|
||||
13.0f, 13.1f, 13.2f,
|
||||
|
||||
|
||||
101.0f, 101.1f, 101.2f,
|
||||
103.0f, 103.1f, 103.2f,
|
||||
|
||||
111.0f, 111.1f, 111.2f,
|
||||
113.0f, 113.1f, 113.2f});
|
||||
// clang-format on
|
||||
test_case.run(MIN_FLOAT_TOLERANCE_BITS);
|
||||
}
|
||||
|
||||
@ -404,4 +576,4 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_axis_0_bool)
|
||||
test_case.add_input<int64_t>({0, 1, 1, 2});
|
||||
test_case.add_expected_output<char>(out_shape, {1, 1, 1, 0, 1, 0, 0, 1});
|
||||
test_case.run(MIN_FLOAT_TOLERANCE_BITS);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user