BinaryElementwiseArithmetic - common shape_infer (#13421)

* Use eltwise_shape_infer form shape_inference in validate_and_infer_elementwise_args

* Align Unary ops, remove redundant validate_and_infer_elementwise_args usage

* Add test with default constructor for BinaryElementwiseArithmetic ops

* Style apply

* Fix expected error message

* Add common shape_infer tests for BinaryElementiwiseArithmetic ops

* Remove old Add test

* Update NGRAPH_CHECK to OV ASSERT

* Removal of redundant autob param to the validate function

* Tests update
This commit is contained in:
Katarzyna Mitrus 2022-10-19 10:57:16 +02:00 committed by GitHub
parent 2ffb915338
commit 3c0b5c7f9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 183 additions and 64 deletions

View File

@ -66,7 +66,7 @@ public:
private:
AutoBroadcastSpec m_autob;
void validate_and_infer_elementwise_arithmetic(const op::AutoBroadcastSpec& autob);
void validate_and_infer_elementwise_arithmetic();
};
} // namespace util
} // namespace op

View File

@ -9,9 +9,7 @@
namespace ov {
namespace op {
namespace util {
std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args(
Node* node,
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args(Node* node);
}
} // namespace op
} // namespace ov

View File

@ -167,4 +167,4 @@ inline void check_divided_result<ov::Dimension>(const ov::Node* op,
"]",
" must be a multiple of divisor: ",
divisor);
}
}

View File

@ -26,11 +26,10 @@ bool ngraph::op::v1::LogicalNot::visit_attributes(AttributeVisitor& visitor) {
void op::v1::LogicalNot::validate_and_infer_types() {
OV_OP_SCOPE(v1_LogicalNot_validate_and_infer_types);
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
element::Type& args_et = std::get<0>(args_et_pshape);
ov::PartialShape& args_pshape = std::get<1>(args_et_pshape);
set_output_type(0, args_et, args_pshape);
const auto& element_type = get_input_element_type(0);
// No boolean element_type validation for backward compatibility
const auto& arg_pshape = get_input_partial_shape(0);
set_output_type(0, element_type, arg_pshape);
}
shared_ptr<Node> op::v1::LogicalNot::clone_with_new_inputs(const OutputVector& new_args) const {

View File

@ -23,9 +23,8 @@ ov::op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const Out
: Op({arg0, arg1}),
m_autob(autob) {}
void ov::op::util::BinaryElementwiseArithmetic::validate_and_infer_elementwise_arithmetic(
const op::AutoBroadcastSpec& autob) {
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this, autob);
void ov::op::util::BinaryElementwiseArithmetic::validate_and_infer_elementwise_arithmetic() {
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
@ -40,7 +39,7 @@ void ov::op::util::BinaryElementwiseArithmetic::validate_and_infer_elementwise_a
void ov::op::util::BinaryElementwiseArithmetic::validate_and_infer_types() {
OV_OP_SCOPE(v0_util_BinaryElementwiseArithmetic_validate_and_infer_types);
validate_and_infer_elementwise_arithmetic(m_autob);
validate_and_infer_elementwise_arithmetic();
}
bool ov::op::util::BinaryElementwiseArithmetic::visit_attributes(AttributeVisitor& visitor) {

View File

@ -23,7 +23,7 @@ ov::op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const Out
void ov::op::util::BinaryElementwiseComparison::validate_and_infer_types() {
OV_OP_SCOPE(v0_util_BinaryElementwiseComparison_validate_and_infer_types);
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this, m_autob);
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
set_output_type(0, element::boolean, args_pshape);

View File

@ -23,7 +23,7 @@ ov::op::util::BinaryElementwiseLogical::BinaryElementwiseLogical(const Output<No
void ov::op::util::BinaryElementwiseLogical::validate_and_infer_types() {
OV_OP_SCOPE(v0_util_BinaryElementwiseLogical_validate_and_infer_types);
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this, m_autob);
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);

View File

@ -5,34 +5,31 @@
#include "ngraph/op/util/elementwise_args.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "utils.hpp"
std::tuple<ov::element::Type, ov::PartialShape> ov::op::util::validate_and_infer_elementwise_args(
Node* node,
const op::AutoBroadcastSpec& autob) {
NGRAPH_CHECK(node != nullptr, "nGraph node is empty! Cannot validate eltwise arguments.");
element::Type element_type = node->get_input_element_type(0);
PartialShape pshape = node->get_input_partial_shape(0);
std::tuple<ov::element::Type, ov::PartialShape> ov::op::util::validate_and_infer_elementwise_args(Node* node) {
OPENVINO_ASSERT(node != nullptr, "Node is empty! Cannot validate eltwise arguments.");
constexpr size_t valid_inputs_count = 2;
NODE_VALIDATION_CHECK(node,
node->get_input_size() == valid_inputs_count,
"Incorrect number of inputs. Required: ",
valid_inputs_count);
if (node->get_input_size() > 1) {
for (size_t i = 1; i < node->get_input_size(); ++i) {
NODE_VALIDATION_CHECK(node,
element::Type::merge(element_type, element_type, node->get_input_element_type(i)),
"Argument element types are inconsistent.");
element::Type result_et;
NODE_VALIDATION_CHECK(
node,
element::Type::merge(result_et, node->get_input_element_type(0), node->get_input_element_type(1)),
"Arguments do not have the same element type (arg0 element type: ",
node->get_input_element_type(0),
", arg1 element type: ",
node->get_input_element_type(1),
").");
if (autob.m_type == op::AutoBroadcastType::NONE) {
NODE_VALIDATION_CHECK(node,
PartialShape::merge_into(pshape, node->get_input_partial_shape(i)),
"Argument shapes are inconsistent.");
} else if (autob.m_type == op::AutoBroadcastType::NUMPY || autob.m_type == op::AutoBroadcastType::PDPD) {
NODE_VALIDATION_CHECK(
node,
PartialShape::broadcast_merge_into(pshape, node->get_input_partial_shape(i), autob),
"Argument shapes are inconsistent.");
} else {
NODE_VALIDATION_CHECK(node, false, "Unsupported auto broadcast specification");
}
}
}
const auto& A_shape = node->get_input_partial_shape(0);
const auto& B_shape = node->get_input_partial_shape(1);
std::vector<ov::PartialShape> input_shapes = {A_shape, B_shape};
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
eltwise_shape_infer(node, input_shapes, output_shapes);
return std::make_tuple(element_type, pshape);
return std::make_tuple(result_et, output_shapes[0]);
}

View File

@ -14,17 +14,15 @@ ov::op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic() : Op() {}
ov::op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const Output<Node>& arg) : Op({arg}) {}
void ov::op::util::UnaryElementwiseArithmetic::validate_and_infer_elementwise_arithmetic() {
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
const auto& element_type = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
args_et.is_dynamic() || args_et != element::boolean,
element_type.is_dynamic() || element_type != element::boolean,
"Arguments cannot have boolean element type (argument element type: ",
args_et,
element_type,
").");
set_output_type(0, args_et, args_pshape);
const auto& arg_pshape = get_input_partial_shape(0);
set_output_type(0, element_type, arg_pshape);
}
void ov::op::util::UnaryElementwiseArithmetic::validate_and_infer_types() {

View File

@ -29,6 +29,30 @@ class ArithmeticOperator : public testing::Test {};
TYPED_TEST_SUITE_P(ArithmeticOperator);
TYPED_TEST_P(ArithmeticOperator, default_constructor) {
auto A = std::make_shared<op::Parameter>(element::f32, PartialShape{-1, 4, 1, 6, Dimension(1, 6), Dimension(2, 6)});
auto B = std::make_shared<op::Parameter>(element::f32, PartialShape{-1, 1, 5, 6, Dimension(5, 8), Dimension(5, 8)});
const auto op = std::make_shared<TypeParam>();
op->set_argument(0, A);
op->set_argument(1, B);
auto autob = op::AutoBroadcastSpec(op::AutoBroadcastType::NONE);
op->set_autob(autob);
EXPECT_EQ(op->get_autob(), op::AutoBroadcastType::NONE);
ASSERT_THROW(op->validate_and_infer_types(), NodeValidationFailure);
autob = op::AutoBroadcastSpec(op::AutoBroadcastType::NUMPY);
op->set_autob(autob);
EXPECT_EQ(op->get_autob(), op::AutoBroadcastType::NUMPY);
op->validate_and_infer_types();
EXPECT_EQ(op->get_element_type(), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{-1, 4, 5, 6, Dimension(5, 8), Dimension(5, 6)}));
}
TYPED_TEST_P(ArithmeticOperator, shape_inference_2D) {
auto A = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
auto B = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
@ -786,6 +810,8 @@ TYPED_TEST_P(ArithmeticOperator, labels_equal_dynamic_shape_broadcast_none) {
}
REGISTER_TYPED_TEST_SUITE_P(ArithmeticOperator,
default_constructor,
// Static shapes
shape_inference_2D,
shape_inference_4D,

View File

@ -39,7 +39,7 @@ void test_binary(std::string /* node_type */,
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument element types are inconsistent"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Arguments do not have the same element type"));
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
}
@ -91,7 +91,7 @@ void test_binary_logical(std::string /* node_type */,
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument element types are inconsistent"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Arguments do not have the same element type"));
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
}

View File

@ -0,0 +1,115 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "utils.hpp"
#include "openvino/op/ops.hpp"
#include "openvino/op/parameter.hpp"
#include "utils/shape_inference/shape_inference.hpp"
#include "utils/shape_inference/static_shape.hpp"
using namespace ov;
using namespace ov::intel_cpu;
template <class T>
class StaticShapeInferenceTest_BEA : public testing::Test {};
// StaticShapeInferenceTest for BinaryElementwiseArithmetis (BEA) operations
TYPED_TEST_SUITE_P(StaticShapeInferenceTest_BEA);
TYPED_TEST_P(StaticShapeInferenceTest_BEA, shape_inference_autob_numpy_equal_rank) {
auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto node = std::make_shared<TypeParam>(A, B);
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1, 1, 5}, StaticShape{3, 1, 6, 1}},
static_output_shapes = {StaticShape{}};
shape_inference(node.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 1, 6, 5}));
}
TYPED_TEST_P(StaticShapeInferenceTest_BEA, shape_inference_autob_numpy_a_rank_higher) {
auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto node = std::make_shared<TypeParam>(A, B);
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 4, 1, 5}, StaticShape{4, 6, 1}},
static_output_shapes = {StaticShape{}};
shape_inference(node.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 4, 6, 5}));
}
TYPED_TEST_P(StaticShapeInferenceTest_BEA, shape_inference_autob_numpy_b_rank_higher) {
auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto node = std::make_shared<TypeParam>(A, B);
std::vector<StaticShape> static_input_shapes = {StaticShape{4, 6, 1}, StaticShape{3, 4, 1, 5}},
static_output_shapes = {StaticShape{}};
shape_inference(node.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 4, 6, 5}));
}
TYPED_TEST_P(StaticShapeInferenceTest_BEA, shape_inference_autob_numpy_incompatible_shapes) {
auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto node = std::make_shared<TypeParam>(A, B);
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 4, 6, 5}, StaticShape{2, 4, 6, 5}},
static_output_shapes = {StaticShape{}};
ASSERT_THROW(shape_inference(node.get(), static_input_shapes, static_output_shapes), NodeValidationFailure);
}
TYPED_TEST_P(StaticShapeInferenceTest_BEA, shape_inference_aubtob_none) {
auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto node = std::make_shared<TypeParam>(A, B, op::AutoBroadcastType::NONE);
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 4, 6, 5}, StaticShape{3, 4, 6, 5}},
static_output_shapes = {StaticShape{}};
shape_inference(node.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 4, 6, 5}));
}
TYPED_TEST_P(StaticShapeInferenceTest_BEA, shape_inference_aubtob_none_incompatible_shapes) {
auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto node = std::make_shared<TypeParam>(A, B, op::AutoBroadcastType::NONE);
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 4, 6, 5}, StaticShape{3, 1, 6, 1}},
static_output_shapes = {StaticShape{}};
ASSERT_THROW(shape_inference(node.get(), static_input_shapes, static_output_shapes), NodeValidationFailure);
}
REGISTER_TYPED_TEST_SUITE_P(StaticShapeInferenceTest_BEA,
shape_inference_autob_numpy_equal_rank,
shape_inference_autob_numpy_a_rank_higher,
shape_inference_autob_numpy_b_rank_higher,
shape_inference_autob_numpy_incompatible_shapes,
shape_inference_aubtob_none,
shape_inference_aubtob_none_incompatible_shapes);
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_add, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Add>);
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_divide, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Divide>);
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_floor_mod, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::FloorMod>);
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_maximum, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Maximum>);
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_minimum, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Minimum>);
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_mod, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Mod>);
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_multiply, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Multiply>);
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_power, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Power>);
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_squared_difference, StaticShapeInferenceTest_BEA, ::testing::Types<op::v0::SquaredDifference>);
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_subtract, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Subtract>);

View File

@ -26,19 +26,6 @@ TEST(StaticShapeInferenceTest, UnaryEltwiseTest) {
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 6, 5, 5}));
}
TEST(StaticShapeInferenceTest, BinaryEltwiseTest) {
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto data_1 = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto node = std::make_shared<op::v1::Add>(data, data_1);
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 6, 1, 5}, StaticShape{1, 3, 5}},
static_output_shapes = {StaticShape{}};
shape_inference(node.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 6, 3, 5}));
}
TEST(StaticShapeInferenceTest, FakeQuantizeTest) {
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto il = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});