Optimisations for binary operations broadcast (#1058)
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
#include <cstdio>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
@@ -29,6 +30,55 @@
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
namespace
|
||||
{
|
||||
Strides default_strides(size_t n_axes) { return Strides(n_axes, 1); }
|
||||
CoordinateDiff default_padding(size_t n_axes) { return CoordinateDiff(n_axes, 0); }
|
||||
AxisVector default_axis_order(size_t n_axes)
|
||||
{
|
||||
AxisVector result(n_axes);
|
||||
std::iota(result.begin(), result.end(), 0);
|
||||
return result;
|
||||
}
|
||||
|
||||
Coordinate default_source_start_corner(size_t n_axes) { return Coordinate(n_axes, 0); }
|
||||
Coordinate default_source_end_corner(const Shape& source_shape) { return source_shape; }
|
||||
}
|
||||
|
||||
/// Keeps a copy of the shape that index() and begin() operate on.
CoordinateTransformBasic::CoordinateTransformBasic(const Shape& source_shape)
    : m_source_shape(source_shape)
{
}
|
||||
|
||||
// Compute the index of a source-space coordinate in the buffer.
|
||||
size_t CoordinateTransformBasic::index(const Coordinate& c) const noexcept
|
||||
{
|
||||
size_t index = 0;
|
||||
size_t stride = 1;
|
||||
size_t const padding = c.size() - m_source_shape.size();
|
||||
|
||||
for (size_t axis = m_source_shape.size(); axis-- > 0;)
|
||||
{
|
||||
if (m_source_shape[axis] > 1)
|
||||
{
|
||||
index += c[axis + padding] * stride;
|
||||
stride *= m_source_shape[axis];
|
||||
}
|
||||
}
|
||||
|
||||
return index;
|
||||
}
|
||||
|
||||
/// Creates an iterator over the source shape in its initial state.
CoordinateIterator CoordinateTransformBasic::begin() const noexcept
{
    CoordinateIterator first(m_source_shape);
    return first;
}
|
||||
|
||||
/// Forwards to the shared end iterator provided by CoordinateIterator::end().
const CoordinateIterator& CoordinateTransformBasic::end() const noexcept
{
    const CoordinateIterator& sentinel = CoordinateIterator::end();
    return sentinel;
}
|
||||
|
||||
CoordinateTransform::CoordinateTransform(const Shape& source_shape,
|
||||
const Coordinate& source_start_corner,
|
||||
const Coordinate& source_end_corner,
|
||||
@@ -37,7 +87,7 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
|
||||
const CoordinateDiff& target_padding_below,
|
||||
const CoordinateDiff& target_padding_above,
|
||||
const Strides& target_dilation_strides)
|
||||
: m_source_shape(source_shape)
|
||||
: CoordinateTransformBasic(source_shape)
|
||||
, m_source_start_corner(source_start_corner)
|
||||
, m_source_end_corner(source_end_corner)
|
||||
, m_source_strides(source_strides)
|
||||
@@ -45,7 +95,6 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
|
||||
, m_target_padding_below(target_padding_below)
|
||||
, m_target_padding_above(target_padding_above)
|
||||
, m_target_dilation_strides(target_dilation_strides)
|
||||
, m_end_iterator(Shape(), true)
|
||||
{
|
||||
m_n_axes = source_shape.size();
|
||||
|
||||
@@ -170,11 +219,6 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a stride vector with one entry per axis, every entry equal to 1.
Strides CoordinateTransform::default_strides(size_t n_axes)
{
    Strides unit_strides(n_axes, 1);
    return unit_strides;
}
|
||||
|
||||
CoordinateTransform::CoordinateTransform(const Shape& source_shape,
|
||||
const Coordinate& source_start_corner,
|
||||
const Coordinate& source_end_corner,
|
||||
@@ -193,11 +237,6 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
|
||||
{
|
||||
}
|
||||
|
||||
/// Builds a padding vector with one entry per axis, every entry equal to 0.
CoordinateDiff CoordinateTransform::default_padding(size_t n_axes)
{
    CoordinateDiff no_padding(n_axes, 0);
    return no_padding;
}
|
||||
|
||||
CoordinateTransform::CoordinateTransform(const Shape& source_shape,
|
||||
const Coordinate& source_start_corner,
|
||||
const Coordinate& source_end_corner,
|
||||
@@ -214,15 +253,6 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
|
||||
{
|
||||
}
|
||||
|
||||
/// Builds the identity axis order 0, 1, ..., n_axes - 1.
AxisVector CoordinateTransform::default_axis_order(size_t n_axes)
{
    AxisVector order(n_axes);
    for (size_t axis = 0; axis < n_axes; ++axis)
    {
        order[axis] = axis;
    }
    return order;
}
|
||||
|
||||
CoordinateTransform::CoordinateTransform(const Shape& source_shape,
|
||||
const Coordinate& source_start_corner,
|
||||
const Coordinate& source_end_corner,
|
||||
@@ -252,16 +282,6 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
|
||||
{
|
||||
}
|
||||
|
||||
/// Builds the all-zero coordinate with n_axes entries.
Coordinate CoordinateTransform::default_source_start_corner(size_t n_axes)
{
    Coordinate origin(n_axes, 0);
    return origin;
}
|
||||
|
||||
/// Uses the source shape itself as the end corner.
Coordinate CoordinateTransform::default_source_end_corner(const Shape& source_shape)
{
    return source_shape;
}
|
||||
|
||||
CoordinateTransform::CoordinateTransform(const Shape& source_shape)
|
||||
: CoordinateTransform(source_shape,
|
||||
default_source_start_corner(source_shape.size()),
|
||||
@@ -274,25 +294,10 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape)
|
||||
{
|
||||
}
|
||||
|
||||
// Compute the index of a source-space coordinate in the buffer.
|
||||
size_t CoordinateTransform::index_source(const Coordinate& c) const
|
||||
{
|
||||
size_t index = 0;
|
||||
size_t stride = 1;
|
||||
|
||||
for (size_t axis = m_n_axes; axis-- > 0;)
|
||||
{
|
||||
index += c[axis] * stride;
|
||||
stride *= m_source_shape[axis];
|
||||
}
|
||||
|
||||
return index;
|
||||
}
|
||||
|
||||
// Compute the index of a target-space coordinate in the buffer.
|
||||
size_t CoordinateTransform::index(const Coordinate& c) const
|
||||
{
|
||||
return index_source(to_source_coordinate(c));
|
||||
return CoordinateTransformBasic::index(to_source_coordinate(c));
|
||||
}
|
||||
|
||||
// Convert a target-space coordinate to a source-space coordinate.
|
||||
@@ -373,18 +378,56 @@ bool CoordinateTransform::has_source_coordinate(const Coordinate& c_target) cons
|
||||
return true;
|
||||
}
|
||||
|
||||
const Shape& CoordinateTransform::get_target_shape() const
|
||||
/// Read-only access to the stored source shape.
const Shape& CoordinateTransform::get_source_shape() const noexcept
{
    return m_source_shape;
}
|
||||
|
||||
/// Read-only access to the stored target shape.
const Shape& CoordinateTransform::get_target_shape() const noexcept
{
    return m_target_shape;
}
|
||||
|
||||
// The "is_end" parameter is true if we want the "end()" iterator.
|
||||
CoordinateTransform::Iterator::Iterator(const Shape& target_shape, bool is_end)
|
||||
: m_target_shape(target_shape)
|
||||
const Coordinate& CoordinateTransform::get_source_start_corner() const noexcept
|
||||
{
|
||||
// Initial coordinate is (0,...,0) in the target space.
|
||||
m_coordinate = Coordinate(target_shape.size(), 0);
|
||||
return m_source_start_corner;
|
||||
}
|
||||
|
||||
/// Read-only access to the stored source end corner.
const Coordinate& CoordinateTransform::get_source_end_corner() const noexcept
{
    return m_source_end_corner;
}
|
||||
|
||||
/// Read-only access to the stored source strides.
const Strides& CoordinateTransform::get_source_strides() const noexcept
{
    return m_source_strides;
}
|
||||
|
||||
/// Read-only access to the stored source axis order.
const AxisVector& CoordinateTransform::get_source_axis_order() const noexcept
{
    return m_source_axis_order;
}
|
||||
|
||||
/// Read-only access to the stored target dilation strides.
const Strides& CoordinateTransform::get_target_dilation_strides() const noexcept
{
    return m_target_dilation_strides;
}
|
||||
|
||||
/// Creates an iterator over the target shape in its initial state.
CoordinateIterator CoordinateTransform::begin() const noexcept
{
    CoordinateIterator first(m_target_shape);
    return first;
}
|
||||
|
||||
/// Forwards to the shared end iterator provided by CoordinateIterator::end().
const CoordinateIterator& CoordinateTransform::end() const noexcept
{
    const CoordinateIterator& sentinel = CoordinateIterator::end();
    return sentinel;
}
|
||||
|
||||
// The "is_end" parameter is true if we want the "end()" iterator.
|
||||
CoordinateIterator::CoordinateIterator(const Shape& target_shape, bool is_end)
|
||||
: m_target_shape(target_shape)
|
||||
, m_coordinate(target_shape.size(), 0)
|
||||
{
|
||||
// The case where we have a zero-length axis is a bit special, in that
|
||||
// the iterator always starts out of bounds.
|
||||
m_empty = false;
|
||||
@@ -401,19 +444,20 @@ CoordinateTransform::Iterator::Iterator(const Shape& target_shape, bool is_end)
|
||||
m_oob = is_end || m_empty;
|
||||
}
|
||||
|
||||
void CoordinateTransform::Iterator::operator++()
|
||||
void CoordinateIterator::operator++()
|
||||
{
|
||||
// If we are out of bounds, start over at (0,...0). (TODO: not sure if that's what we want. It
|
||||
// might be best to stay put?)
|
||||
advance(m_target_shape.size() - 1);
|
||||
}
|
||||
|
||||
void CoordinateIterator::advance(size_t axis) noexcept
|
||||
{
|
||||
m_oob |= m_target_shape.empty();
|
||||
|
||||
if (m_oob)
|
||||
{
|
||||
std::fill(m_coordinate.begin(), m_coordinate.end(), 0);
|
||||
m_oob = m_empty;
|
||||
return;
|
||||
}
|
||||
|
||||
// Increment the target coordinate.
|
||||
for (size_t axis = m_target_shape.size(); axis-- > 0;)
|
||||
do
|
||||
{
|
||||
m_coordinate[axis]++;
|
||||
|
||||
@@ -426,21 +470,21 @@ void CoordinateTransform::Iterator::operator++()
|
||||
{
|
||||
m_coordinate[axis] = 0;
|
||||
}
|
||||
}
|
||||
} while (axis-- > 0);
|
||||
|
||||
// If we are still here there was carry-out from the most significant axis. We are now out of
|
||||
// bounds.
|
||||
m_oob = true;
|
||||
}
|
||||
|
||||
CoordinateTransform::Iterator CoordinateTransform::Iterator::operator++(int)
|
||||
CoordinateIterator CoordinateIterator::operator++(int)
|
||||
{
|
||||
CoordinateTransform::Iterator temp = *this;
|
||||
CoordinateIterator temp = *this;
|
||||
++(*this);
|
||||
return temp;
|
||||
}
|
||||
|
||||
void CoordinateTransform::Iterator::operator+=(size_t n)
|
||||
void CoordinateIterator::operator+=(size_t n)
|
||||
{
|
||||
for (size_t i = 0; i < n; i++)
|
||||
{
|
||||
@@ -448,17 +492,17 @@ void CoordinateTransform::Iterator::operator+=(size_t n)
|
||||
}
|
||||
}
|
||||
|
||||
const Coordinate& CoordinateTransform::Iterator::operator*() const
|
||||
/// Gives read-only access to the coordinate the iterator currently holds.
const Coordinate& CoordinateIterator::operator*() const noexcept
{
    return m_coordinate;
}
|
||||
|
||||
bool CoordinateTransform::Iterator::operator!=(const Iterator& it)
|
||||
bool CoordinateIterator::operator!=(const CoordinateIterator& it) const noexcept
|
||||
{
|
||||
return !(*this == it);
|
||||
}
|
||||
|
||||
bool CoordinateTransform::Iterator::operator==(const Iterator& it)
|
||||
bool CoordinateIterator::operator==(const CoordinateIterator& it) const noexcept
|
||||
{
|
||||
if (it.m_oob)
|
||||
{
|
||||
@@ -490,3 +534,9 @@ bool CoordinateTransform::Iterator::operator==(const Iterator& it)
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Returns a reference to a single function-local static iterator that is
/// constructed once from an empty shape with is_end set to true.
const CoordinateIterator& CoordinateIterator::end()
{
    static const CoordinateIterator past_the_end(Shape(), true);
    return past_the_end;
}
|
||||
|
||||
@@ -24,9 +24,87 @@
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
class NGRAPH_API CoordinateTransform
|
||||
/// \brief A helper class for iterating over tensor coordinates.
|
||||
/// For example, for tensor with dimensions {2, 3} this iterator
|
||||
/// produces the following coordinates:
|
||||
/// {0,0}, {0,1}, {0,2},
|
||||
/// {1,0}, {1,1}, {1,2}
|
||||
class NGRAPH_API CoordinateIterator
|
||||
{
|
||||
public:
|
||||
/// \brief Coordinates iterator constructor
|
||||
/// \param target_shape The target shape for coordinates iteration
|
||||
/// \param is_end When true, the iterator is constructed as the "end()" iterator.
|
||||
CoordinateIterator(const Shape& target_shape, bool is_end = false);
|
||||
|
||||
/// \brief The prefix increment operator advances the iterator by one.
|
||||
void operator++();
|
||||
|
||||
/// \brief The postfix increment operator advances the iterator by one and returns its previous value.
|
||||
CoordinateIterator operator++(int);
|
||||
|
||||
/// \brief Increments iterator n times.
|
||||
/// \param n number of elements it should be advanced
|
||||
void operator+=(size_t n);
|
||||
|
||||
/// \brief Iterator dereferencing operator returns reference to current pointed coordinate.
|
||||
const Coordinate& operator*() const noexcept;
|
||||
|
||||
/// \brief Checks for iterator inequality.
|
||||
/// \param it second iterator to compare
|
||||
bool operator!=(const CoordinateIterator& it) const noexcept;
|
||||
|
||||
/// \brief Checks for iterator equality.
|
||||
/// \param it second iterator to compare
|
||||
bool operator==(const CoordinateIterator& it) const noexcept;
|
||||
|
||||
/// \brief Increments the coordinate along the given axis by one step, carrying into the preceding axes on overflow.
|
||||
/// \param axis index used for iteration
|
||||
void advance(size_t axis) noexcept;
|
||||
|
||||
/// \brief Useful function to build the last iterator.
|
||||
/// Returns a singleton that points to the last iterator.
|
||||
static const CoordinateIterator& end();
|
||||
|
||||
private:
|
||||
Shape m_target_shape;
|
||||
Coordinate m_coordinate;
|
||||
bool m_oob;
|
||||
bool m_empty;
|
||||
};
|
||||
|
||||
/// \brief Class which allows to calculate item index with given coordinates in tensor
|
||||
/// and helps to iterate over all coordinates.
|
||||
/// Tensor items should be placed in memory in row-major order.
|
||||
class NGRAPH_API CoordinateTransformBasic
|
||||
{
|
||||
public:
|
||||
using Iterator = CoordinateIterator;
|
||||
|
||||
CoordinateTransformBasic(const Shape& source_shape);
|
||||
|
||||
/// \brief The tensor element index calculation by given coordinate.
|
||||
/// \param c tensor element coordinate
|
||||
size_t index(const Coordinate& c) const noexcept;
|
||||
|
||||
/// \brief Returns an iterator to the first coordinate of the tensor.
|
||||
CoordinateIterator begin() const noexcept;
|
||||
|
||||
/// \brief Returns an iterator to the coordinate following the last element of the tensor.
|
||||
const CoordinateIterator& end() const noexcept;
|
||||
|
||||
protected:
|
||||
Shape m_source_shape;
|
||||
};
|
||||
|
||||
/// \brief Class which allows to calculate item index with given coordinates in tensor
|
||||
/// and helps to iterate over the subset of coordinates.
|
||||
/// Tensor items should be placed in memory in row-major order.
|
||||
class NGRAPH_API CoordinateTransform : protected CoordinateTransformBasic
|
||||
{
|
||||
public:
|
||||
using Iterator = CoordinateIterator;
|
||||
|
||||
CoordinateTransform(const Shape& source_shape,
|
||||
const Coordinate& source_start_corner,
|
||||
const Coordinate& source_end_corner,
|
||||
@@ -61,47 +139,33 @@ namespace ngraph
|
||||
|
||||
CoordinateTransform(const Shape& source_shape);
|
||||
|
||||
/// \brief The tensor element index calculation by given coordinate.
|
||||
/// \param c tensor element coordinate
|
||||
size_t index(const Coordinate& c) const;
|
||||
|
||||
/// \brief Checks that coordinate belongs to given coordinates subset.
|
||||
/// \param c tensor element coordinate
|
||||
bool has_source_coordinate(const Coordinate& c) const;
|
||||
|
||||
/// \brief Convert a target-space coordinate to a source-space coordinate.
|
||||
/// \param c tensor element coordinate
|
||||
Coordinate to_source_coordinate(const Coordinate& c) const;
|
||||
const Shape& get_target_shape() const;
|
||||
|
||||
const Shape& get_source_shape() const { return m_source_shape; }
|
||||
const Coordinate& get_source_start_corner() const { return m_source_start_corner; }
|
||||
const Coordinate& get_source_end_corner() const { return m_source_end_corner; }
|
||||
const Strides& get_source_strides() const { return m_source_strides; }
|
||||
const AxisVector& get_source_axis_order() const { return m_source_axis_order; }
|
||||
const Strides& get_target_dilation_strides() const { return m_target_dilation_strides; }
|
||||
class NGRAPH_API Iterator
|
||||
{
|
||||
public:
|
||||
Iterator(const Shape& target_shape, bool is_end = false);
|
||||
const Shape& get_source_shape() const noexcept;
|
||||
const Shape& get_target_shape() const noexcept;
|
||||
const Coordinate& get_source_start_corner() const noexcept;
|
||||
const Coordinate& get_source_end_corner() const noexcept;
|
||||
const Strides& get_source_strides() const noexcept;
|
||||
const AxisVector& get_source_axis_order() const noexcept;
|
||||
const Strides& get_target_dilation_strides() const noexcept;
|
||||
|
||||
void operator++();
|
||||
Iterator operator++(int);
|
||||
void operator+=(size_t n);
|
||||
const Coordinate& operator*() const;
|
||||
bool operator!=(const Iterator& it);
|
||||
bool operator==(const Iterator& it);
|
||||
/// \brief Returns an iterator to the first coordinate of the tensor.
|
||||
CoordinateIterator begin() const noexcept;
|
||||
|
||||
private:
|
||||
Shape m_target_shape;
|
||||
Shape m_axis_walk_order;
|
||||
Coordinate m_coordinate;
|
||||
bool m_oob;
|
||||
bool m_empty;
|
||||
};
|
||||
/// \brief Returns an iterator to the coordinate following the last element of the tensor.
|
||||
const CoordinateIterator& end() const noexcept;
|
||||
|
||||
Iterator begin() noexcept { return Iterator(m_target_shape); }
|
||||
Iterator end() noexcept { return m_end_iterator; }
|
||||
size_t index_source(const Coordinate& c) const;
|
||||
static Strides default_strides(size_t n_axes);
|
||||
static CoordinateDiff default_padding(size_t n_axes);
|
||||
static AxisVector default_axis_order(size_t n_axes);
|
||||
static Coordinate default_source_start_corner(size_t n_axes);
|
||||
static Coordinate default_source_end_corner(const Shape& source_shape);
|
||||
|
||||
Shape m_source_shape;
|
||||
private:
|
||||
Coordinate m_source_start_corner;
|
||||
Coordinate m_source_end_corner;
|
||||
Strides m_source_strides;
|
||||
@@ -112,6 +176,5 @@ namespace ngraph
|
||||
|
||||
Shape m_target_shape;
|
||||
size_t m_n_axes;
|
||||
Iterator m_end_iterator;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -28,6 +28,55 @@ namespace ngraph
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
namespace internal
|
||||
{
|
||||
template <int A0, int A1, typename T, typename U, typename Functor>
|
||||
inline void numpy_autobroadcast_binop(const T* arg0,
|
||||
const T* arg1,
|
||||
U* out,
|
||||
const Shape& arg0_shape,
|
||||
const Shape& arg1_shape,
|
||||
const Shape& output_shape,
|
||||
const size_t stride,
|
||||
const size_t axis,
|
||||
Functor elementwise_functor)
|
||||
{
|
||||
CoordinateTransformBasic arg0_transform(arg0_shape);
|
||||
CoordinateTransformBasic arg1_transform(arg1_shape);
|
||||
|
||||
for (CoordinateIterator it(output_shape), ite = CoordinateIterator::end();
|
||||
it != ite;
|
||||
it.advance(axis))
|
||||
{
|
||||
const Coordinate& output_coord = *it;
|
||||
size_t const idx0 = arg0_transform.index(output_coord);
|
||||
size_t const idx1 = arg1_transform.index(output_coord);
|
||||
for (size_t i = 0; i < stride; ++i)
|
||||
*out++ = elementwise_functor(arg0[idx0 + i * A0], arg1[idx1 + i * A1]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void calculate_fixed_idx_and_stride(size_t& arg0_p, // Fixed idx
|
||||
size_t arg1_p,
|
||||
size_t& stride,
|
||||
size_t arg0_shape_padding,
|
||||
size_t arg1_shape_padding,
|
||||
const Shape& arg0_shape,
|
||||
const Shape& arg1_shape)
|
||||
{
|
||||
while ((arg0_p < arg0_shape_padding ||
|
||||
arg0_shape[arg0_p - arg0_shape_padding] == 1) &&
|
||||
(arg0_p >= arg1_shape_padding &&
|
||||
arg1_shape[arg0_p - arg1_shape_padding] != 1) &&
|
||||
--arg0_p > arg1_p)
|
||||
;
|
||||
|
||||
stride = arg0_p < arg1_shape_padding
|
||||
? shape_size(arg1_shape)
|
||||
: row_major_stride(arg1_shape, arg0_p - arg1_shape_padding);
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief Helper function to implement autobroadcasting elementwise binop references.
|
||||
///
|
||||
/// \tparam T Element type of the input tensors.
|
||||
@@ -85,62 +134,132 @@ namespace ngraph
|
||||
// ------------
|
||||
// [ 3, 2, 6]
|
||||
{
|
||||
Shape arg0_padded_shape = arg0_shape;
|
||||
Shape arg1_padded_shape = arg1_shape;
|
||||
size_t const shape_rank =
|
||||
std::max(arg0_shape.size(), arg1_shape.size()) + 1;
|
||||
size_t const arg0_shape_padding = shape_rank - arg0_shape.size();
|
||||
size_t const arg1_shape_padding = shape_rank - arg1_shape.size();
|
||||
|
||||
while (arg0_padded_shape.size() < arg1_padded_shape.size())
|
||||
Shape output_shape(shape_rank, 0);
|
||||
|
||||
size_t arg0_p = 0, arg1_p = 0;
|
||||
|
||||
for (size_t i = 0; i < shape_rank; i++)
|
||||
{
|
||||
arg0_padded_shape.insert(arg0_padded_shape.begin(), 1);
|
||||
}
|
||||
Shape::value_type arg0_dim =
|
||||
i < arg0_shape_padding ? 1 : arg0_shape[i - arg0_shape_padding];
|
||||
Shape::value_type arg1_dim =
|
||||
i < arg1_shape_padding ? 1 : arg1_shape[i - arg1_shape_padding];
|
||||
|
||||
while (arg1_padded_shape.size() < arg0_padded_shape.size())
|
||||
{
|
||||
arg1_padded_shape.insert(arg1_padded_shape.begin(), 1);
|
||||
}
|
||||
output_shape[i] = arg0_dim == 1 ? arg1_dim : arg0_dim;
|
||||
|
||||
Shape arg0_squeezed_shape;
|
||||
Shape arg1_squeezed_shape;
|
||||
AxisSet arg0_squeezed_axes;
|
||||
AxisSet arg1_squeezed_axes;
|
||||
Shape output_shape;
|
||||
|
||||
for (size_t i = 0; i < arg0_padded_shape.size(); i++)
|
||||
{
|
||||
if (arg0_padded_shape[i] == 1)
|
||||
if (arg0_dim != arg1_dim)
|
||||
{
|
||||
arg0_squeezed_axes.insert(i);
|
||||
if (arg0_dim == 1)
|
||||
arg0_p = std::max(arg0_p, i);
|
||||
|
||||
if (arg1_dim == 1)
|
||||
arg1_p = std::max(arg1_p, i);
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
// Universal function without optimisations
|
||||
CoordinateTransformBasic arg0_transform(arg0_shape);
|
||||
CoordinateTransformBasic arg1_transform(arg1_shape);
|
||||
U *dst = out;
|
||||
|
||||
for(CoordinateIterator it(output_shape),
|
||||
ite = CoordinateIterator::end();
|
||||
it != ite;
|
||||
++it)
|
||||
{
|
||||
const Coordinate& output_coord = *it;
|
||||
size_t const idx0 = arg0_transform.index(output_coord);
|
||||
size_t const idx1 = arg1_transform.index(output_coord);
|
||||
*dst++ = elementwise_functor(arg0[idx0], arg1[idx1]);
|
||||
}
|
||||
#else
|
||||
using internal::numpy_autobroadcast_binop;
|
||||
using internal::calculate_fixed_idx_and_stride;
|
||||
|
||||
if (arg0_p < arg1_p)
|
||||
{
|
||||
size_t stride =
|
||||
row_major_stride(arg0_shape, arg1_p - arg0_shape_padding);
|
||||
|
||||
if (stride > 1)
|
||||
numpy_autobroadcast_binop<1, 1>(arg0,
|
||||
arg1,
|
||||
out,
|
||||
arg0_shape,
|
||||
arg1_shape,
|
||||
output_shape,
|
||||
stride,
|
||||
arg1_p,
|
||||
elementwise_functor);
|
||||
else
|
||||
{
|
||||
arg0_squeezed_shape.push_back(arg0_padded_shape[i]);
|
||||
}
|
||||
calculate_fixed_idx_and_stride(arg1_p,
|
||||
arg0_p,
|
||||
stride,
|
||||
arg1_shape_padding,
|
||||
arg0_shape_padding,
|
||||
arg1_shape,
|
||||
arg0_shape);
|
||||
|
||||
if (arg1_padded_shape[i] == 1)
|
||||
{
|
||||
arg1_squeezed_axes.insert(i);
|
||||
numpy_autobroadcast_binop<1, 0>(arg0,
|
||||
arg1,
|
||||
out,
|
||||
arg0_shape,
|
||||
arg1_shape,
|
||||
output_shape,
|
||||
stride,
|
||||
arg1_p,
|
||||
elementwise_functor);
|
||||
}
|
||||
}
|
||||
else if (arg0_p > arg1_p)
|
||||
{
|
||||
size_t stride =
|
||||
row_major_stride(arg1_shape, arg0_p - arg1_shape_padding);
|
||||
|
||||
if (stride > 1)
|
||||
numpy_autobroadcast_binop<1, 1>(arg0,
|
||||
arg1,
|
||||
out,
|
||||
arg0_shape,
|
||||
arg1_shape,
|
||||
output_shape,
|
||||
stride,
|
||||
arg0_p,
|
||||
elementwise_functor);
|
||||
else
|
||||
{
|
||||
arg1_squeezed_shape.push_back(arg1_padded_shape[i]);
|
||||
calculate_fixed_idx_and_stride(arg0_p,
|
||||
arg1_p,
|
||||
stride,
|
||||
arg0_shape_padding,
|
||||
arg1_shape_padding,
|
||||
arg0_shape,
|
||||
arg1_shape);
|
||||
|
||||
numpy_autobroadcast_binop<0, 1>(arg0,
|
||||
arg1,
|
||||
out,
|
||||
arg0_shape,
|
||||
arg1_shape,
|
||||
output_shape,
|
||||
stride,
|
||||
arg0_p,
|
||||
elementwise_functor);
|
||||
}
|
||||
|
||||
output_shape.push_back(arg0_padded_shape[i] == 1
|
||||
? arg1_padded_shape[i]
|
||||
: arg0_padded_shape[i]);
|
||||
}
|
||||
|
||||
CoordinateTransform arg0_transform(arg0_squeezed_shape);
|
||||
CoordinateTransform arg1_transform(arg1_squeezed_shape);
|
||||
CoordinateTransform output_transform(output_shape);
|
||||
|
||||
for (const Coordinate& output_coord : output_transform)
|
||||
else
|
||||
{
|
||||
Coordinate arg0_coord = reduce(output_coord, arg0_squeezed_axes);
|
||||
Coordinate arg1_coord = reduce(output_coord, arg1_squeezed_axes);
|
||||
out[output_transform.index(output_coord)] =
|
||||
elementwise_functor(arg0[arg0_transform.index(arg0_coord)],
|
||||
arg1[arg1_transform.index(arg1_coord)]);
|
||||
for (size_t i = 0, end = shape_size(output_shape); i < end; ++i)
|
||||
out[i] = elementwise_functor(arg0[i], arg1[i]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
break;
|
||||
case op::AutoBroadcastType::PDPD:
|
||||
|
||||
@@ -93,6 +93,17 @@ namespace ngraph
|
||||
return strides;
|
||||
}
|
||||
|
||||
/// \brief Row-major stride of one axis: the product of all extents that
///        follow \p axis in \p shape.
/// \param shape Indexable container of extents providing size() and operator[].
/// \param axis Axis whose stride is requested.
/// \return Product of shape[axis + 1] .. shape[size - 1]; 1 when no extents follow.
template <typename SHAPE_TYPE>
size_t row_major_stride(const SHAPE_TYPE& shape, size_t axis)
{
    const size_t rank = shape.size();
    size_t product = 1;
    for (size_t following = axis + 1; following < rank; ++following)
    {
        product *= shape[following];
    }
    return product;
}
|
||||
|
||||
template <typename SHAPE_TYPE>
|
||||
inline bool is_scalar(const SHAPE_TYPE& shape)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user