Revise reference implementation for ReduceProd operation (#5774)

This commit is contained in:
Gabriele Galiero Casay 2021-06-14 11:29:33 +02:00 committed by GitHub
parent 458435ad75
commit dc415573d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 295 additions and 282 deletions

View File

@ -2,6 +2,8 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#pragma once
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
@ -18,10 +20,5 @@ namespace ngraph
/// \return Normalized (positive only) axes as an AxisSet object. /// \return Normalized (positive only) axes as an AxisSet object.
AxisSet get_normalized_axes_from_tensor(const HostTensorPtr tensor, AxisSet get_normalized_axes_from_tensor(const HostTensorPtr tensor,
const ngraph::Rank& rank, const ngraph::Rank& rank,
const std::string& node_description) const std::string& node_description);
{
const auto axes_vector = host_tensor_2_vector<int64_t>(tensor);
const auto normalized_axes = ngraph::normalize_axes(node_description, axes_vector, rank);
return AxisSet{normalized_axes};
}
} // namespace ngraph } // namespace ngraph

View File

@ -5,6 +5,7 @@
#pragma once #pragma once
#include <cmath> #include <cmath>
#include <numeric>
#include "ngraph/coordinate_transform.hpp" #include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp" #include "ngraph/shape_util.hpp"
@ -16,29 +17,27 @@ namespace ngraph
namespace reference namespace reference
{ {
template <typename T> template <typename T>
void product(const T* arg, void product(const T* arg, T* out, const Shape& in_shape, const AxisSet& reduction_axes)
T* out,
const Shape& in_shape,
const AxisSet& reduction_axes,
bool keep_dims)
{ {
auto out_shape = reduce(in_shape, reduction_axes, keep_dims); constexpr bool dont_keep_dims_in_output = false;
CoordinateTransform output_transform(out_shape); const auto out_shape = reduce(in_shape, reduction_axes, dont_keep_dims_in_output);
std::fill(out, out + shape_size(out_shape), 1);
for (const Coordinate& output_coord : output_transform) const auto in_strides = row_major_strides(in_shape);
{ const auto out_strides = row_major_strides(out_shape);
out[output_transform.index(output_coord)] = 1;
}
CoordinateTransform input_transform(in_shape);
CoordinateTransformBasic input_transform(in_shape);
for (const Coordinate& input_coord : input_transform) for (const Coordinate& input_coord : input_transform)
{ {
Coordinate output_coord = reduce(input_coord, reduction_axes, keep_dims); const Coordinate output_coord =
reduce(input_coord, reduction_axes, dont_keep_dims_in_output);
size_t output_index = output_transform.index(output_coord); const size_t in_idx = std::inner_product(
input_coord.begin(), input_coord.end(), in_strides.begin(), 0);
const size_t out_idx = std::inner_product(
output_coord.begin(), output_coord.end(), out_strides.begin(), 0);
out[output_index] = out[output_index] * arg[input_transform.index(input_coord)]; out[out_idx] = out[out_idx] * arg[in_idx];
} }
} }
} // namespace reference } // namespace reference

View File

@ -45,7 +45,7 @@ namespace ngraph
void sum(const T* arg, T* out, const Shape& in_shape, const AxisSet& reduction_axes) void sum(const T* arg, T* out, const Shape& in_shape, const AxisSet& reduction_axes)
{ {
constexpr bool dont_keep_dims_in_output = false; constexpr bool dont_keep_dims_in_output = false;
auto out_shape = reduce(in_shape, reduction_axes, dont_keep_dims_in_output); const auto out_shape = reduce(in_shape, reduction_axes, dont_keep_dims_in_output);
std::vector<T> cs(shape_size(out_shape), 0); std::vector<T> cs(shape_size(out_shape), 0);
std::fill(out, out + shape_size(out_shape), 0); std::fill(out, out + shape_size(out_shape), 0);
@ -56,12 +56,12 @@ namespace ngraph
CoordinateTransformBasic input_transform(in_shape); CoordinateTransformBasic input_transform(in_shape);
for (const Coordinate& input_coord : input_transform) for (const Coordinate& input_coord : input_transform)
{ {
Coordinate output_coord = const Coordinate output_coord =
reduce(input_coord, reduction_axes, dont_keep_dims_in_output); reduce(input_coord, reduction_axes, dont_keep_dims_in_output);
size_t in_idx = std::inner_product( const size_t in_idx = std::inner_product(
input_coord.begin(), input_coord.end(), in_strides.begin(), 0); input_coord.begin(), input_coord.end(), in_strides.begin(), 0);
size_t out_idx = std::inner_product( const size_t out_idx = std::inner_product(
output_coord.begin(), output_coord.end(), out_strides.begin(), 0); output_coord.begin(), output_coord.end(), out_strides.begin(), 0);
T x = arg[in_idx]; T x = arg[in_idx];

View File

@ -6,6 +6,7 @@
#include <ngraph/validation_util.hpp> #include <ngraph/validation_util.hpp>
#include "itt.hpp" #include "itt.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/util/evaluate_helpers.hpp"
#include "ngraph/runtime/host_tensor.hpp" #include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/product.hpp" #include "ngraph/runtime/reference/product.hpp"
#include "ngraph/shape_util.hpp" #include "ngraph/shape_util.hpp"
@ -45,7 +46,7 @@ namespace reduce_prod
{ {
out->set_shape(reduce(arg->get_shape(), axes, keep_dims)); out->set_shape(reduce(arg->get_shape(), axes, keep_dims));
runtime::reference::product( runtime::reference::product(
arg->get_data_ptr<ET>(), out->get_data_ptr<ET>(), arg->get_shape(), axes, keep_dims); arg->get_data_ptr<ET>(), out->get_data_ptr<ET>(), arg->get_shape(), axes);
return true; return true;
} }
@ -75,8 +76,11 @@ bool op::v1::ReduceProd::evaluate(const HostTensorVector& outputs,
NGRAPH_OP_SCOPE(v1_ReduceProd_evaluate); NGRAPH_OP_SCOPE(v1_ReduceProd_evaluate);
NGRAPH_CHECK(validate_host_tensor_vector(inputs, 2)); NGRAPH_CHECK(validate_host_tensor_vector(inputs, 2));
NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1)); NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1));
return reduce_prod::evaluate_product(
inputs[0], outputs[0], get_reduction_axes(), get_keep_dims()); const auto reduction_axes = get_normalized_axes_from_tensor(
inputs[1], inputs[0]->get_partial_shape().rank(), get_friendly_name());
return reduce_prod::evaluate_product(inputs[0], outputs[0], reduction_axes, get_keep_dims());
} }
bool op::v1::ReduceProd::has_evaluate() const bool op::v1::ReduceProd::has_evaluate() const

View File

@ -7,11 +7,11 @@
#include "itt.hpp" #include "itt.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/util/evaluate_helpers.hpp"
#include "ngraph/op/util/op_types.hpp" #include "ngraph/op/util/op_types.hpp"
#include "ngraph/runtime/host_tensor.hpp" #include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/sum.hpp" #include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/shape_util.hpp" #include "ngraph/shape_util.hpp"
#include "util/evaluate_helpers.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
@ -80,7 +80,7 @@ bool op::v1::ReduceSum::evaluate(const HostTensorVector& outputs,
NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1)); NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1));
const auto reduction_axes = get_normalized_axes_from_tensor( const auto reduction_axes = get_normalized_axes_from_tensor(
inputs[1], get_input_partial_shape(0).rank(), get_friendly_name()); inputs[1], inputs[0]->get_partial_shape().rank(), get_friendly_name());
return reduce_sum::evaluate_sum(inputs[0], outputs[0], reduction_axes, get_keep_dims()); return reduce_sum::evaluate_sum(inputs[0], outputs[0], reduction_axes, get_keep_dims());
} }

View File

@ -0,0 +1,17 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/util/evaluate_helpers.hpp"
namespace ngraph
{
AxisSet get_normalized_axes_from_tensor(const HostTensorPtr tensor,
const ngraph::Rank& rank,
const std::string& node_description)
{
const auto axes_vector = host_tensor_2_vector<int64_t>(tensor);
const auto normalized_axes = ngraph::normalize_axes(node_description, axes_vector, rank);
return AxisSet{normalized_axes};
}
} // namespace ngraph

View File

@ -73,6 +73,7 @@ set(SRC
op_eval/non_zero.cpp op_eval/non_zero.cpp
op_eval/reduce_l1.cpp op_eval/reduce_l1.cpp
op_eval/reduce_l2.cpp op_eval/reduce_l2.cpp
op_eval/reduce_prod.cpp
op_eval/reduce_sum.cpp op_eval/reduce_sum.cpp
op_eval/roi_align.cpp op_eval/roi_align.cpp
op_eval/roi_pooling.cpp op_eval/roi_pooling.cpp

View File

@ -80,95 +80,6 @@ NGRAPH_TEST(${BACKEND_NAME}, reduce_product_matrix_rows)
EXPECT_TRUE(test::all_close_f((vector<float>{2, 12, 30}), read_vector<float>(result))); EXPECT_TRUE(test::all_close_f((vector<float>{2, 12, 30}), read_vector<float>(result)));
} }
NGRAPH_TEST(${BACKEND_NAME}, reduce_product_matrix_rows_zero)
{
Shape shape_a{3, 0};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{3};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 1);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, false), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3, 3, 3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1, 1, 1}), read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, reduce_product_matrix_cols_zero)
{
// Now the reduction (g(x:float32[2,2],y:float32[]) = reduce(x,y,f,axes={})).
Shape shape_a{0, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{2};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 0);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, false), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3, 3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1, 1}), read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, reduce_product_vector_zero)
{
Shape shape_a{0};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 0);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, false), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1}), read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, reduce_product_matrix_to_scalar_zero_by_zero)
{
Shape shape_a{0, 0};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{};
auto axes = make_shared<op::Constant>(element::i32, Shape{2}, vector<int32_t>{0, 1});
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, false), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1}), read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, reduce_product_3d_to_matrix_most_sig) NGRAPH_TEST(${BACKEND_NAME}, reduce_product_3d_to_matrix_most_sig)
{ {
Shape shape_a{3, 3, 3}; Shape shape_a{3, 3, 3};
@ -283,31 +194,6 @@ NGRAPH_TEST(${BACKEND_NAME}, reduce_product_3d_to_scalar)
read_vector<float>(result))); read_vector<float>(result)));
} }
NGRAPH_TEST(${BACKEND_NAME}, reduce_product_3d_eliminate_zero_dim)
{
Shape shape_a{3, 0, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{3, 2};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 1);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, false), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
// Overwrite the initial result vector to make sure we're not just coincidentally getting the
// right value.
copy_data(result, vector<float>{2112, 2112, 2112, 2112, 2112, 2112});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1, 1, 1, 1, 1, 1}), read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, reduce_product_2d_to_scalar_int32) NGRAPH_TEST(${BACKEND_NAME}, reduce_product_2d_to_scalar_int32)
{ {
Shape shape_a{3, 3}; Shape shape_a{3, 3};
@ -433,95 +319,6 @@ NGRAPH_TEST(${BACKEND_NAME}, reduce_product_keep_matrix_rows)
EXPECT_TRUE(test::all_close_f((vector<float>{2, 12, 30}), read_vector<float>(result))); EXPECT_TRUE(test::all_close_f((vector<float>{2, 12, 30}), read_vector<float>(result)));
} }
NGRAPH_TEST(${BACKEND_NAME}, reduce_product_keep_matrix_rows_zero)
{
Shape shape_a{3, 0};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{3, 1};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 1);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, true), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3, 3, 3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1, 1, 1}), read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, reduce_product_keep_matrix_cols_zero)
{
// Now the reduction (g(x:float32[2,2],y:float32[]) = reduce(x,y,f,axes={})).
Shape shape_a{0, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{1, 2};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 0);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, true), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3, 3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1, 1}), read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, reduce_product_keep_vector_zero)
{
Shape shape_a{0};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{1};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 0);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, true), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1}), read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, reduce_product_keep_matrix_to_scalar_zero_by_zero)
{
Shape shape_a{0, 0};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{1, 1};
auto axes = make_shared<op::Constant>(element::i32, Shape{2}, vector<int32_t>{0, 1});
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, true), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1}), read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, reduce_product_keep_3d_to_matrix_most_sig) NGRAPH_TEST(${BACKEND_NAME}, reduce_product_keep_3d_to_matrix_most_sig)
{ {
Shape shape_a{3, 3, 3}; Shape shape_a{3, 3, 3};
@ -636,31 +433,6 @@ NGRAPH_TEST(${BACKEND_NAME}, reduce_product_keep_3d_to_scalar)
read_vector<float>(result))); read_vector<float>(result)));
} }
NGRAPH_TEST(${BACKEND_NAME}, reduce_product_keep_3d_eliminate_zero_dim)
{
Shape shape_a{3, 0, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{3, 1, 2};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 1);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, true), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
// Overwrite the initial result vector to make sure we're not just coincidentally getting the
// right value.
copy_data(result, vector<float>{2112, 2112, 2112, 2112, 2112, 2112});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1, 1, 1, 1, 1, 1}), read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, reduce_product_keep_2d_to_scalar_int32) NGRAPH_TEST(${BACKEND_NAME}, reduce_product_keep_2d_to_scalar_int32)
{ {
Shape shape_a{3, 3}; Shape shape_a{3, 3};

View File

@ -0,0 +1,244 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/test_control.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
TEST(op_eval, reduce_product_matrix_rows_zero)
{
Shape shape_a{3, 0};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{3};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 1);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, false), ParameterVector{A});
auto backend = runtime::Backend::create("INTERPRETER");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3, 3, 3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1, 1, 1}), read_vector<float>(result)));
}
TEST(op_eval, reduce_product_matrix_cols_zero)
{
// Now the reduction (g(x:float32[2,2],y:float32[]) = reduce(x,y,f,axes={})).
Shape shape_a{0, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{2};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 0);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, false), ParameterVector{A});
auto backend = runtime::Backend::create("INTERPRETER");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3, 3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1, 1}), read_vector<float>(result)));
}
TEST(op_eval, reduce_product_vector_zero)
{
Shape shape_a{0};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 0);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, false), ParameterVector{A});
auto backend = runtime::Backend::create("INTERPRETER");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1}), read_vector<float>(result)));
}
TEST(op_eval, reduce_product_matrix_to_scalar_zero_by_zero)
{
Shape shape_a{0, 0};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{};
auto axes = make_shared<op::Constant>(element::i32, Shape{2}, vector<int32_t>{0, 1});
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, false), ParameterVector{A});
auto backend = runtime::Backend::create("INTERPRETER");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1}), read_vector<float>(result)));
}
TEST(op_eval, reduce_product_3d_eliminate_zero_dim)
{
Shape shape_a{3, 0, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{3, 2};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 1);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, false), ParameterVector{A});
auto backend = runtime::Backend::create("INTERPRETER");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
// Overwrite the initial result vector to make sure we're not just coincidentally getting the
// right value.
copy_data(result, vector<float>{2112, 2112, 2112, 2112, 2112, 2112});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1, 1, 1, 1, 1, 1}), read_vector<float>(result)));
}
TEST(op_eval, reduce_product_keep_matrix_rows_zero)
{
Shape shape_a{3, 0};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{3, 1};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 1);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, true), ParameterVector{A});
auto backend = runtime::Backend::create("INTERPRETER");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3, 3, 3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1, 1, 1}), read_vector<float>(result)));
}
TEST(op_eval, reduce_product_keep_matrix_cols_zero)
{
// Now the reduction (g(x:float32[2,2],y:float32[]) = reduce(x,y,f,axes={})).
Shape shape_a{0, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{1, 2};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 0);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, true), ParameterVector{A});
auto backend = runtime::Backend::create("INTERPRETER");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3, 3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1, 1}), read_vector<float>(result)));
}
TEST(op_eval, reduce_product_keep_vector_zero)
{
Shape shape_a{0};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{1};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 0);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, true), ParameterVector{A});
auto backend = runtime::Backend::create("INTERPRETER");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1}), read_vector<float>(result)));
}
TEST(op_eval, reduce_product_keep_matrix_to_scalar_zero_by_zero)
{
Shape shape_a{0, 0};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{1, 1};
auto axes = make_shared<op::Constant>(element::i32, Shape{2}, vector<int32_t>{0, 1});
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, true), ParameterVector{A});
auto backend = runtime::Backend::create("INTERPRETER");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
copy_data(result, vector<float>({3}));
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1}), read_vector<float>(result)));
}
TEST(op_eval, reduce_product_keep_3d_eliminate_zero_dim)
{
Shape shape_a{3, 0, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{3, 1, 2};
auto axes = make_shared<op::Constant>(element::i32, Shape{}, 1);
auto f =
make_shared<Function>(make_shared<op::v1::ReduceProd>(A, axes, true), ParameterVector{A});
auto backend = runtime::Backend::create("INTERPRETER");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_rt);
// Overwrite the initial result vector to make sure we're not just coincidentally getting the
// right value.
copy_data(result, vector<float>{2112, 2112, 2112, 2112, 2112, 2112});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f((vector<float>{1, 1, 1, 1, 1, 1}), read_vector<float>(result)));
}

View File

@ -373,6 +373,8 @@ all_dynamic
# disabled reference implementation # disabled reference implementation
reduce_sum_keep_2d_to_scalar_int8 reduce_sum_keep_2d_to_scalar_int8
reduce_sum_2d_to_scalar_int8 reduce_sum_2d_to_scalar_int8
reduce_product_to_scalar_int8
reduce_product_keep_to_scalar_int8
# accuracy # accuracy
reduce_sum_keep_stable_acc reduce_sum_keep_stable_acc
reduce_sum_keep_3d_to_scalar_int32 reduce_sum_keep_3d_to_scalar_int32
@ -458,17 +460,6 @@ onnx_dyn_shapes_model_tile_static
gather_4d_indices_axis_0_uint8 gather_4d_indices_axis_0_uint8
tensor_constant_with_op tensor_constant_with_op
constant_equality_bool constant_equality_bool
reduce_product_matrix_rows
reduce_product_3d_to_matrix_most_sig
reduce_product_3d_to_matrix_least_sig
reduce_product_keep_matrix_columns
reduce_product_keep_matrix_rows
reduce_product_keep_3d_to_matrix_most_sig
reduce_product_keep_3d_to_matrix_least_sig
reduce_product_matrix_columns_dynamic
reduce_product_matrix_rows_dynamic
reduce_product_keep_matrix_columns_dynamic
reduce_product_keep_matrix_rows_dynamic
reduce_min_matrix_columns reduce_min_matrix_columns
reduce_min_matrix_rows reduce_min_matrix_rows
reduce_min_matrix_rows_int32 reduce_min_matrix_rows_int32
@ -485,18 +476,6 @@ reduce_min_keep_matrix_columns_dynamic
reduce_min_keep_matrix_rows_dynamic reduce_min_keep_matrix_rows_dynamic
# zero dimension / result mismatch # zero dimension / result mismatch
reduce_product_matrix_rows_zero
reduce_product_matrix_cols_zero
reduce_product_vector_zero
reduce_product_matrix_to_scalar_zero_by_zero
reduce_product_3d_eliminate_zero_dim
reduce_product_to_scalar_int8
reduce_product_keep_matrix_rows_zero
reduce_product_keep_matrix_cols_zero
reduce_product_keep_vector_zero
reduce_product_keep_matrix_to_scalar_zero_by_zero
reduce_product_keep_3d_eliminate_zero_dim
reduce_product_keep_to_scalar_int8
reduce_min_to_scalar_int8 reduce_min_to_scalar_int8
reduce_min_matrix_rows_zero reduce_min_matrix_rows_zero
reduce_min_matrix_cols_zero reduce_min_matrix_cols_zero