BinaryElementwiseArithmetic - common shape_infer (#13421)
* Use eltwise_shape_infer form shape_inference in validate_and_infer_elementwise_args * Align Unary ops, remove redundant validate_and_infer_elementwise_args usage * Add test with default constructor for BinaryElementwiseArithmetic ops * Style apply * Fix expected error message * Add common shape_infer tests for BinaryElementiwiseArithmetic ops * Remove old Add test * Update NGRAPH_CHECK to OV ASSERT * Removal of redundant autob param to the validate function * Tests update
This commit is contained in:
parent
2ffb915338
commit
3c0b5c7f9b
@ -66,7 +66,7 @@ public:
|
||||
|
||||
private:
|
||||
AutoBroadcastSpec m_autob;
|
||||
void validate_and_infer_elementwise_arithmetic(const op::AutoBroadcastSpec& autob);
|
||||
void validate_and_infer_elementwise_arithmetic();
|
||||
};
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
|
@ -9,9 +9,7 @@
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace util {
|
||||
std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args(
|
||||
Node* node,
|
||||
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
|
||||
std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args(Node* node);
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
|
@ -167,4 +167,4 @@ inline void check_divided_result<ov::Dimension>(const ov::Node* op,
|
||||
"]",
|
||||
" must be a multiple of divisor: ",
|
||||
divisor);
|
||||
}
|
||||
}
|
||||
|
@ -26,11 +26,10 @@ bool ngraph::op::v1::LogicalNot::visit_attributes(AttributeVisitor& visitor) {
|
||||
|
||||
void op::v1::LogicalNot::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v1_LogicalNot_validate_and_infer_types);
|
||||
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
|
||||
element::Type& args_et = std::get<0>(args_et_pshape);
|
||||
ov::PartialShape& args_pshape = std::get<1>(args_et_pshape);
|
||||
|
||||
set_output_type(0, args_et, args_pshape);
|
||||
const auto& element_type = get_input_element_type(0);
|
||||
// No boolean element_type validation for backward compatibility
|
||||
const auto& arg_pshape = get_input_partial_shape(0);
|
||||
set_output_type(0, element_type, arg_pshape);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v1::LogicalNot::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
|
@ -23,9 +23,8 @@ ov::op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const Out
|
||||
: Op({arg0, arg1}),
|
||||
m_autob(autob) {}
|
||||
|
||||
void ov::op::util::BinaryElementwiseArithmetic::validate_and_infer_elementwise_arithmetic(
|
||||
const op::AutoBroadcastSpec& autob) {
|
||||
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this, autob);
|
||||
void ov::op::util::BinaryElementwiseArithmetic::validate_and_infer_elementwise_arithmetic() {
|
||||
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
|
||||
element::Type& args_et = std::get<0>(args_et_pshape);
|
||||
PartialShape& args_pshape = std::get<1>(args_et_pshape);
|
||||
|
||||
@ -40,7 +39,7 @@ void ov::op::util::BinaryElementwiseArithmetic::validate_and_infer_elementwise_a
|
||||
|
||||
void ov::op::util::BinaryElementwiseArithmetic::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v0_util_BinaryElementwiseArithmetic_validate_and_infer_types);
|
||||
validate_and_infer_elementwise_arithmetic(m_autob);
|
||||
validate_and_infer_elementwise_arithmetic();
|
||||
}
|
||||
|
||||
bool ov::op::util::BinaryElementwiseArithmetic::visit_attributes(AttributeVisitor& visitor) {
|
||||
|
@ -23,7 +23,7 @@ ov::op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const Out
|
||||
|
||||
void ov::op::util::BinaryElementwiseComparison::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v0_util_BinaryElementwiseComparison_validate_and_infer_types);
|
||||
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this, m_autob);
|
||||
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
|
||||
PartialShape& args_pshape = std::get<1>(args_et_pshape);
|
||||
|
||||
set_output_type(0, element::boolean, args_pshape);
|
||||
|
@ -23,7 +23,7 @@ ov::op::util::BinaryElementwiseLogical::BinaryElementwiseLogical(const Output<No
|
||||
void ov::op::util::BinaryElementwiseLogical::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v0_util_BinaryElementwiseLogical_validate_and_infer_types);
|
||||
|
||||
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this, m_autob);
|
||||
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
|
||||
element::Type& args_et = std::get<0>(args_et_pshape);
|
||||
PartialShape& args_pshape = std::get<1>(args_et_pshape);
|
||||
|
||||
|
@ -5,34 +5,31 @@
|
||||
#include "ngraph/op/util/elementwise_args.hpp"
|
||||
|
||||
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
std::tuple<ov::element::Type, ov::PartialShape> ov::op::util::validate_and_infer_elementwise_args(
|
||||
Node* node,
|
||||
const op::AutoBroadcastSpec& autob) {
|
||||
NGRAPH_CHECK(node != nullptr, "nGraph node is empty! Cannot validate eltwise arguments.");
|
||||
element::Type element_type = node->get_input_element_type(0);
|
||||
PartialShape pshape = node->get_input_partial_shape(0);
|
||||
std::tuple<ov::element::Type, ov::PartialShape> ov::op::util::validate_and_infer_elementwise_args(Node* node) {
|
||||
OPENVINO_ASSERT(node != nullptr, "Node is empty! Cannot validate eltwise arguments.");
|
||||
constexpr size_t valid_inputs_count = 2;
|
||||
NODE_VALIDATION_CHECK(node,
|
||||
node->get_input_size() == valid_inputs_count,
|
||||
"Incorrect number of inputs. Required: ",
|
||||
valid_inputs_count);
|
||||
|
||||
if (node->get_input_size() > 1) {
|
||||
for (size_t i = 1; i < node->get_input_size(); ++i) {
|
||||
NODE_VALIDATION_CHECK(node,
|
||||
element::Type::merge(element_type, element_type, node->get_input_element_type(i)),
|
||||
"Argument element types are inconsistent.");
|
||||
element::Type result_et;
|
||||
NODE_VALIDATION_CHECK(
|
||||
node,
|
||||
element::Type::merge(result_et, node->get_input_element_type(0), node->get_input_element_type(1)),
|
||||
"Arguments do not have the same element type (arg0 element type: ",
|
||||
node->get_input_element_type(0),
|
||||
", arg1 element type: ",
|
||||
node->get_input_element_type(1),
|
||||
").");
|
||||
|
||||
if (autob.m_type == op::AutoBroadcastType::NONE) {
|
||||
NODE_VALIDATION_CHECK(node,
|
||||
PartialShape::merge_into(pshape, node->get_input_partial_shape(i)),
|
||||
"Argument shapes are inconsistent.");
|
||||
} else if (autob.m_type == op::AutoBroadcastType::NUMPY || autob.m_type == op::AutoBroadcastType::PDPD) {
|
||||
NODE_VALIDATION_CHECK(
|
||||
node,
|
||||
PartialShape::broadcast_merge_into(pshape, node->get_input_partial_shape(i), autob),
|
||||
"Argument shapes are inconsistent.");
|
||||
} else {
|
||||
NODE_VALIDATION_CHECK(node, false, "Unsupported auto broadcast specification");
|
||||
}
|
||||
}
|
||||
}
|
||||
const auto& A_shape = node->get_input_partial_shape(0);
|
||||
const auto& B_shape = node->get_input_partial_shape(1);
|
||||
std::vector<ov::PartialShape> input_shapes = {A_shape, B_shape};
|
||||
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
|
||||
eltwise_shape_infer(node, input_shapes, output_shapes);
|
||||
|
||||
return std::make_tuple(element_type, pshape);
|
||||
return std::make_tuple(result_et, output_shapes[0]);
|
||||
}
|
||||
|
@ -14,17 +14,15 @@ ov::op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic() : Op() {}
|
||||
ov::op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const Output<Node>& arg) : Op({arg}) {}
|
||||
|
||||
void ov::op::util::UnaryElementwiseArithmetic::validate_and_infer_elementwise_arithmetic() {
|
||||
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
|
||||
element::Type& args_et = std::get<0>(args_et_pshape);
|
||||
PartialShape& args_pshape = std::get<1>(args_et_pshape);
|
||||
|
||||
const auto& element_type = get_input_element_type(0);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
args_et.is_dynamic() || args_et != element::boolean,
|
||||
element_type.is_dynamic() || element_type != element::boolean,
|
||||
"Arguments cannot have boolean element type (argument element type: ",
|
||||
args_et,
|
||||
element_type,
|
||||
").");
|
||||
|
||||
set_output_type(0, args_et, args_pshape);
|
||||
const auto& arg_pshape = get_input_partial_shape(0);
|
||||
set_output_type(0, element_type, arg_pshape);
|
||||
}
|
||||
|
||||
void ov::op::util::UnaryElementwiseArithmetic::validate_and_infer_types() {
|
||||
|
@ -29,6 +29,30 @@ class ArithmeticOperator : public testing::Test {};
|
||||
|
||||
TYPED_TEST_SUITE_P(ArithmeticOperator);
|
||||
|
||||
TYPED_TEST_P(ArithmeticOperator, default_constructor) {
|
||||
auto A = std::make_shared<op::Parameter>(element::f32, PartialShape{-1, 4, 1, 6, Dimension(1, 6), Dimension(2, 6)});
|
||||
auto B = std::make_shared<op::Parameter>(element::f32, PartialShape{-1, 1, 5, 6, Dimension(5, 8), Dimension(5, 8)});
|
||||
|
||||
const auto op = std::make_shared<TypeParam>();
|
||||
|
||||
op->set_argument(0, A);
|
||||
op->set_argument(1, B);
|
||||
|
||||
auto autob = op::AutoBroadcastSpec(op::AutoBroadcastType::NONE);
|
||||
op->set_autob(autob);
|
||||
EXPECT_EQ(op->get_autob(), op::AutoBroadcastType::NONE);
|
||||
ASSERT_THROW(op->validate_and_infer_types(), NodeValidationFailure);
|
||||
|
||||
autob = op::AutoBroadcastSpec(op::AutoBroadcastType::NUMPY);
|
||||
op->set_autob(autob);
|
||||
EXPECT_EQ(op->get_autob(), op::AutoBroadcastType::NUMPY);
|
||||
|
||||
op->validate_and_infer_types();
|
||||
|
||||
EXPECT_EQ(op->get_element_type(), element::f32);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{-1, 4, 5, 6, Dimension(5, 8), Dimension(5, 6)}));
|
||||
}
|
||||
|
||||
TYPED_TEST_P(ArithmeticOperator, shape_inference_2D) {
|
||||
auto A = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
|
||||
auto B = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
|
||||
@ -786,6 +810,8 @@ TYPED_TEST_P(ArithmeticOperator, labels_equal_dynamic_shape_broadcast_none) {
|
||||
}
|
||||
|
||||
REGISTER_TYPED_TEST_SUITE_P(ArithmeticOperator,
|
||||
default_constructor,
|
||||
|
||||
// Static shapes
|
||||
shape_inference_2D,
|
||||
shape_inference_4D,
|
||||
|
@ -39,7 +39,7 @@ void test_binary(std::string /* node_type */,
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incompatible view arguments not detected.";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument element types are inconsistent"));
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Arguments do not have the same element type"));
|
||||
} catch (...) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
@ -91,7 +91,7 @@ void test_binary_logical(std::string /* node_type */,
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incompatible view arguments not detected.";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument element types are inconsistent"));
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Arguments do not have the same element type"));
|
||||
} catch (...) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
|
@ -0,0 +1,115 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "utils.hpp"
|
||||
#include "openvino/op/ops.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "utils/shape_inference/shape_inference.hpp"
|
||||
#include "utils/shape_inference/static_shape.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
|
||||
template <class T>
|
||||
class StaticShapeInferenceTest_BEA : public testing::Test {};
|
||||
|
||||
// StaticShapeInferenceTest for BinaryElementwiseArithmetis (BEA) operations
|
||||
TYPED_TEST_SUITE_P(StaticShapeInferenceTest_BEA);
|
||||
|
||||
TYPED_TEST_P(StaticShapeInferenceTest_BEA, shape_inference_autob_numpy_equal_rank) {
|
||||
auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
|
||||
auto node = std::make_shared<TypeParam>(A, B);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1, 1, 5}, StaticShape{3, 1, 6, 1}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(node.get(), static_input_shapes, static_output_shapes);
|
||||
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 1, 6, 5}));
|
||||
}
|
||||
|
||||
TYPED_TEST_P(StaticShapeInferenceTest_BEA, shape_inference_autob_numpy_a_rank_higher) {
|
||||
auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||
|
||||
auto node = std::make_shared<TypeParam>(A, B);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 4, 1, 5}, StaticShape{4, 6, 1}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(node.get(), static_input_shapes, static_output_shapes);
|
||||
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 4, 6, 5}));
|
||||
}
|
||||
|
||||
TYPED_TEST_P(StaticShapeInferenceTest_BEA, shape_inference_autob_numpy_b_rank_higher) {
|
||||
auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||
auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
|
||||
auto node = std::make_shared<TypeParam>(A, B);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{4, 6, 1}, StaticShape{3, 4, 1, 5}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(node.get(), static_input_shapes, static_output_shapes);
|
||||
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 4, 6, 5}));
|
||||
}
|
||||
|
||||
TYPED_TEST_P(StaticShapeInferenceTest_BEA, shape_inference_autob_numpy_incompatible_shapes) {
|
||||
auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
|
||||
auto node = std::make_shared<TypeParam>(A, B);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 4, 6, 5}, StaticShape{2, 4, 6, 5}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
|
||||
ASSERT_THROW(shape_inference(node.get(), static_input_shapes, static_output_shapes), NodeValidationFailure);
|
||||
}
|
||||
|
||||
TYPED_TEST_P(StaticShapeInferenceTest_BEA, shape_inference_aubtob_none) {
|
||||
auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
|
||||
auto node = std::make_shared<TypeParam>(A, B, op::AutoBroadcastType::NONE);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 4, 6, 5}, StaticShape{3, 4, 6, 5}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(node.get(), static_input_shapes, static_output_shapes);
|
||||
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 4, 6, 5}));
|
||||
}
|
||||
|
||||
TYPED_TEST_P(StaticShapeInferenceTest_BEA, shape_inference_aubtob_none_incompatible_shapes) {
|
||||
auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
|
||||
auto node = std::make_shared<TypeParam>(A, B, op::AutoBroadcastType::NONE);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 4, 6, 5}, StaticShape{3, 1, 6, 1}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
|
||||
ASSERT_THROW(shape_inference(node.get(), static_input_shapes, static_output_shapes), NodeValidationFailure);
|
||||
}
|
||||
|
||||
REGISTER_TYPED_TEST_SUITE_P(StaticShapeInferenceTest_BEA,
|
||||
shape_inference_autob_numpy_equal_rank,
|
||||
shape_inference_autob_numpy_a_rank_higher,
|
||||
shape_inference_autob_numpy_b_rank_higher,
|
||||
shape_inference_autob_numpy_incompatible_shapes,
|
||||
shape_inference_aubtob_none,
|
||||
shape_inference_aubtob_none_incompatible_shapes);
|
||||
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_add, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Add>);
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_divide, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Divide>);
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_floor_mod, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::FloorMod>);
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_maximum, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Maximum>);
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_minimum, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Minimum>);
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_mod, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Mod>);
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_multiply, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Multiply>);
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_power, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Power>);
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_squared_difference, StaticShapeInferenceTest_BEA, ::testing::Types<op::v0::SquaredDifference>);
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer_subtract, StaticShapeInferenceTest_BEA, ::testing::Types<op::v1::Subtract>);
|
@ -26,19 +26,6 @@ TEST(StaticShapeInferenceTest, UnaryEltwiseTest) {
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 6, 5, 5}));
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, BinaryEltwiseTest) {
|
||||
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
auto data_1 = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||
|
||||
auto node = std::make_shared<op::v1::Add>(data, data_1);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 6, 1, 5}, StaticShape{1, 3, 5}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(node.get(), static_input_shapes, static_output_shapes);
|
||||
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 6, 3, 5}));
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, FakeQuantizeTest) {
|
||||
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
auto il = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||
|
Loading…
Reference in New Issue
Block a user