Review logical ops for shape inference aspects (#14393)

* Review dims and labels propagation for logical not

* Review dims and labels propagation
for logical and, or, xor

* Remove duplicated tests

* Expand logical ops tests by numpy broadcast
and inputs order

* Review template shape infer of logical ops
 - add static shape inference test
 - add default ctor test

* Default ctor test for LogicalNot op
This commit is contained in:
Pawel Raasz 2022-12-16 11:37:47 +01:00 committed by GitHub
parent 72c39c3e32
commit b850f422ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 304 additions and 133 deletions

View File

@ -63,78 +63,6 @@ TEST(type_prop, add_bad_arguments) {
}); });
} }
//
// Tests for binary elementwise logical ops.
//
void test_binary_logical(std::string /* node_type */,
shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y)) {
// Check for bad arguments
auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::i32, Shape{2, 4});
auto tv0_2_4_param_3 = make_shared<op::Parameter>(element::i32, Shape{2, 4});
auto tv0_4_2_param = make_shared<op::Parameter>(element::boolean, Shape{4, 2});
auto test_binary_bad_arguments_view_shapes = [&](const shared_ptr<Node>& x, const shared_ptr<Node>& y) {
try {
auto node = f(x, y);
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
}
};
test_binary_bad_arguments_view_shapes(tv0_2_4_param_0, tv0_4_2_param);
auto test_binary_differ_arguments_view_element_types = [&](const shared_ptr<Node>& x, const shared_ptr<Node>& y) {
try {
auto node = f(x, y);
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Arguments do not have the same element type"));
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
}
};
auto test_binary_non_bool_arguments_view_element_types = [&](const shared_ptr<Node>& x, const shared_ptr<Node>& y) {
try {
auto node = f(x, y);
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
} catch (const ngraph_error& error) {
EXPECT_HAS_SUBSTRING(error.what(), "must have boolean element type");
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
}
};
test_binary_differ_arguments_view_element_types(tv0_2_4_param_0, tv0_2_4_param_2);
test_binary_differ_arguments_view_element_types(tv0_2_4_param_2, tv0_2_4_param_0);
test_binary_non_bool_arguments_view_element_types(tv0_2_4_param_2, tv0_2_4_param_3);
auto test_binary_good_arguments = [&](const shared_ptr<Node>& x, const shared_ptr<Node>& y) {
auto node = f(x, y);
EXPECT_TRUE(node->has_same_type(node->input_values()[0].get_node_shared_ptr()));
};
test_binary_good_arguments(tv0_2_4_param_0, tv0_2_4_param_1);
}
TEST(type_prop, or_bad_arguments) {
test_binary_logical("Or", [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::v1::LogicalOr>(x, y);
});
}
TEST(type_prop, xor_bad_arguments) {
test_binary_logical("Xor", [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Xor>(x, y);
});
}
namespace { namespace {
template <typename T> template <typename T>
void test_binary_eltwise_numpy(const element::Type& et, const op::AutoBroadcastSpec& autob) { void test_binary_eltwise_numpy(const element::Type& et, const op::AutoBroadcastSpec& autob) {
@ -188,8 +116,6 @@ shared_ptr<op::v1::Reshape> createReshapeSubgraph(PartialShape param_shape,
TEST(type_prop, eltwise_auto_bcast) { TEST(type_prop, eltwise_auto_bcast) {
test_binary_eltwise_numpy<op::v1::Add>(element::f32, op::AutoBroadcastType::NUMPY); test_binary_eltwise_numpy<op::v1::Add>(element::f32, op::AutoBroadcastType::NUMPY);
test_binary_eltwise_numpy<op::v1::Maximum>(element::f32, op::AutoBroadcastType::NUMPY); test_binary_eltwise_numpy<op::v1::Maximum>(element::f32, op::AutoBroadcastType::NUMPY);
test_binary_eltwise_numpy<op::v1::LogicalOr>(element::boolean, op::AutoBroadcastType::NUMPY);
test_binary_eltwise_numpy<op::Xor>(element::boolean, op::AutoBroadcastType::NUMPY);
} }
// --- Binary elementwise comparision ops tests - start // --- Binary elementwise comparision ops tests - start
@ -427,8 +353,6 @@ TEST(type_prop, binary_arithmetic_bad_argument_element_types) {
TEST(type_prop, binary_arithmetic_bad_argument_shape_with_none_autobroadcast_attribute) { TEST(type_prop, binary_arithmetic_bad_argument_shape_with_none_autobroadcast_attribute) {
test_binary_eltwise_bad_argument_shape<op::v1::Add>(element::f32); test_binary_eltwise_bad_argument_shape<op::v1::Add>(element::f32);
test_binary_eltwise_bad_argument_shape<op::v1::Maximum>(element::f32); test_binary_eltwise_bad_argument_shape<op::v1::Maximum>(element::f32);
test_binary_eltwise_bad_argument_shape<op::v1::LogicalOr>(element::boolean);
test_binary_eltwise_bad_argument_shape<op::Xor>(element::boolean);
} }
TEST(type_prop, binary_elementwise_arithmetic_both_dynamic) { TEST(type_prop, binary_elementwise_arithmetic_both_dynamic) {
@ -592,11 +516,6 @@ TEST(type_prop, logic_arith_compare_partial_et) {
return std::make_shared<op::v1::Add>(param0, param1); return std::make_shared<op::v1::Add>(param0, param1);
}; };
auto test_logical_not = [](element::Type et) -> std::shared_ptr<Node> {
auto param = std::make_shared<op::Parameter>(et, Shape{1, 2, 3});
return std::make_shared<op::v1::LogicalNot>(param);
};
// Arith ops: // Arith ops:
// //
// int int -> int // int int -> int
@ -617,21 +536,6 @@ TEST(type_prop, logic_arith_compare_partial_et) {
ASSERT_EQ(test_arith(element::dynamic, element::i32)->get_element_type(), element::i32); ASSERT_EQ(test_arith(element::dynamic, element::i32)->get_element_type(), element::i32);
ASSERT_ANY_THROW({ test_arith(element::dynamic, element::boolean); }); ASSERT_ANY_THROW({ test_arith(element::dynamic, element::boolean); });
ASSERT_EQ(test_arith(element::dynamic, element::dynamic)->get_element_type(), element::dynamic); ASSERT_EQ(test_arith(element::dynamic, element::dynamic)->get_element_type(), element::dynamic);
// Logical negation op:
//
// Current behavior:
// int -> int
// boo -> boo
// dyn -> dyn
//
// TODO(amprocte): I believe the behavior should actually be:
// int -> !
// boo -> boo
// dyn -> boo
ASSERT_EQ(test_logical_not(element::i32)->get_element_type(), element::i32);
ASSERT_EQ(test_logical_not(element::boolean)->get_element_type(), element::boolean);
ASSERT_EQ(test_logical_not(element::dynamic)->get_element_type(), element::dynamic);
} }
TEST(type_prop, interval_value_propagation_add_rhs) { TEST(type_prop, interval_value_propagation_add_rhs) {

View File

@ -8,42 +8,69 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
using namespace testing;
using LogicalNotTestParam = std::tuple<element::Type, PartialShape>;
namespace { namespace {
void type_check(const ngraph::element::Type& type) { using namespace ngraph::element;
auto input = make_shared<op::Parameter>(type, Shape{1, 3, 6}); constexpr size_t exp_num_of_outputs = 1;
auto logical_not = make_shared<op::v1::LogicalNot>(input);
ASSERT_EQ(logical_not->get_element_type(), type); const auto types = Values(boolean, i16, i32, i64, u16, u32, u64, f32, f64);
} const auto static_shapes = Values(PartialShape{0}, PartialShape{1}, PartialShape{2, 3, 7, 8});
const auto dynamic_shapes = Values(PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()},
PartialShape{2, {-1, 5}, {4, -1}, -1, {3, 8}},
PartialShape::dynamic());
} // namespace } // namespace
TEST(type_prop, logical_not_i32) { class LogicalNotTest : public TypePropOpTest<op::v1::LogicalNot>, public WithParamInterface<LogicalNotTestParam> {
type_check(element::i32); protected:
void SetUp() override {
std::tie(exp_type, exp_shape) = GetParam();
}
element::Type exp_type;
PartialShape exp_shape;
};
INSTANTIATE_TEST_SUITE_P(type_prop_static_shape,
LogicalNotTest,
Combine(types, static_shapes),
PrintToStringParamName());
INSTANTIATE_TEST_SUITE_P(type_prop_dynamic_static_rank,
LogicalNotTest,
Combine(types, dynamic_shapes),
PrintToStringParamName());
TEST_P(LogicalNotTest, propagate_dimensions) {
const auto input = std::make_shared<op::Parameter>(exp_type, exp_shape);
const auto op = make_op(input);
EXPECT_EQ(op->get_element_type(), exp_type);
EXPECT_EQ(op->get_output_size(), exp_num_of_outputs);
EXPECT_EQ(op->get_output_partial_shape(0), exp_shape);
} }
TEST(type_prop, logical_not_i64) { TEST_P(LogicalNotTest, propagate_labels) {
type_check(element::i64); if (exp_shape.rank().is_static()) {
set_shape_labels(exp_shape, 10);
}
const auto exp_labels = get_shape_labels(exp_shape);
const auto input = std::make_shared<op::Parameter>(exp_type, exp_shape);
const auto op = make_op(input);
EXPECT_EQ(get_shape_labels(op->get_output_partial_shape(0)), exp_labels);
} }
TEST(type_prop, logical_not_u32) { TEST_P(LogicalNotTest, default_ctor) {
type_check(element::u32); const auto op = std::make_shared<op::v1::LogicalNot>();
} const auto input = std::make_shared<op::Parameter>(exp_type, exp_shape);
TEST(type_prop, logical_not_u64) { op->set_argument(0, input);
type_check(element::u64); op->validate_and_infer_types();
}
TEST(type_prop, logical_not_f16) { EXPECT_EQ(op->get_element_type(), exp_type);
type_check(element::f16); EXPECT_EQ(op->get_output_size(), exp_num_of_outputs);
} EXPECT_EQ(op->get_output_partial_shape(0), exp_shape);
TEST(type_prop, logical_not_f32) {
type_check(element::f32);
}
TEST(type_prop, logical_not_shape_inference) {
auto input = make_shared<op::Parameter>(element::boolean, Shape{1, 3, 6});
auto logical_not = make_shared<op::v1::LogicalNot>(input);
ASSERT_EQ(logical_not->get_shape(), (Shape{1, 3, 6}));
} }

View File

@ -4,7 +4,8 @@
#pragma once #pragma once
#include "gtest/gtest.h" #include "common_test_utils/test_assertions.hpp"
#include "dimension_tracker.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp" #include "util/type_prop.hpp"
@ -16,7 +17,10 @@ public:
}; };
template <typename T> template <typename T>
class LogicalOperatorTypeProp : public testing::Test {}; class LogicalOperatorTypeProp : public TypePropOpTest<typename T::op_type> {
protected:
size_t exp_logical_op_output_size{1};
};
class LogicalOperatorTypeName { class LogicalOperatorTypeName {
public: public:
@ -90,24 +94,111 @@ TYPED_TEST_P(LogicalOperatorTypeProp, incorrect_shape) {
ngraph::Shape{1, 2, 3}); ngraph::Shape{1, 2, 3});
} }
TYPED_TEST_P(LogicalOperatorTypeProp, broadcast) { TYPED_TEST_P(LogicalOperatorTypeProp, inputs_have_different_types) {
using namespace ngraph;
const auto a = std::make_shared<op::Parameter>(element::boolean, PartialShape{1, 1, 6});
const auto b = std::make_shared<op::Parameter>(element::f16, PartialShape{1, 3, 1});
OV_EXPECT_THROW(const auto logical_op = this->make_op(a, b),
NodeValidationFailure,
testing::HasSubstr("Arguments do not have the same element type"));
}
TYPED_TEST_P(LogicalOperatorTypeProp, inputs_have_inconsistent_shapes) {
using namespace ngraph;
const auto a = std::make_shared<op::Parameter>(element::boolean, PartialShape{1, 1, 6});
const auto b = std::make_shared<op::Parameter>(element::boolean, PartialShape{1, 3, 3});
OV_EXPECT_THROW(const auto logical_op = this->make_op(a, b),
NodeValidationFailure,
testing::HasSubstr("Argument shapes are inconsistent"));
}
TYPED_TEST_P(LogicalOperatorTypeProp, shape_broadcast) {
using namespace ngraph;
using OP_Type = typename TypeParam::op_type;
const auto exp_dtype = TypeParam::element_type;
const auto a = std::make_shared<op::Parameter>(element::boolean, Shape{1, 1, 6});
const auto b = std::make_shared<op::Parameter>(element::boolean, Shape{1, 3, 1});
const auto logical_op = this->make_op(a, b);
EXPECT_EQ(logical_op->get_element_type(), exp_dtype);
EXPECT_EQ(logical_op->get_output_size(), this->exp_logical_op_output_size);
EXPECT_EQ(logical_op->get_shape(), Shape({1, 3, 6}));
}
TYPED_TEST_P(LogicalOperatorTypeProp, partial_shape_no_broadcast) {
using namespace ngraph;
using namespace testing;
using OP_Type = typename TypeParam::op_type; using OP_Type = typename TypeParam::op_type;
auto input1 = std::make_shared<ngraph::op::Parameter>(ngraph::element::boolean, ngraph::Shape{1, 1, 6}); auto shape_a = PartialShape{1, {2, 4}, {2, 5}, 4, -1};
auto input2 = std::make_shared<ngraph::op::Parameter>(ngraph::element::boolean, ngraph::Shape{1, 3, 1}); auto shape_b = PartialShape{1, 3, {1, 6}, 4, {-1, 5}};
set_shape_labels(shape_a, std::vector<size_t>{ov::no_label, 11, 12, ov::no_label, 14});
set_shape_labels(shape_b, std::vector<size_t>{20, 21, ov::no_label, ov::no_label, ov::no_label});
const auto exp_shape = PartialShape{1, 3, {2, 5}, 4, {-1, 5}};
auto logical_and = std::make_shared<OP_Type>(input1, input2); const auto a = std::make_shared<op::Parameter>(element::boolean, shape_a);
const auto b = std::make_shared<op::Parameter>(element::boolean, shape_b);
ASSERT_EQ(logical_and->get_element_type(), ngraph::element::boolean); EXPECT_THAT(this->make_op(a, b, "NONE")->get_output_partial_shape(0),
ASSERT_EQ(logical_and->get_shape(), (ngraph::Shape{1, 3, 6})); AllOf(Eq(exp_shape), ResultOf(get_shape_labels, ElementsAre(20, 21, 12, ov::no_label, 14))));
EXPECT_THAT(this->make_op(b, a, "NONE")->get_output_partial_shape(0),
AllOf(Eq(exp_shape), ResultOf(get_shape_labels, ElementsAre(20, 11, 12, ov::no_label, 14))));
}
TYPED_TEST_P(LogicalOperatorTypeProp, partial_shape_numpy_broadcast) {
using namespace ngraph;
using namespace testing;
using OP_Type = typename TypeParam::op_type;
auto shape_a = PartialShape{1, {2, 4}, {2, 5}, 4, -1};
auto shape_b = PartialShape{1, 3, {1, 6}, 4};
set_shape_labels(shape_a, std::vector<size_t>{ov::no_label, 11, 12, 13, 14});
set_shape_labels(shape_b, std::vector<size_t>{20, 21, ov::no_label, 23});
const auto exp_shape = PartialShape{1, {2, 4}, 3, 4, 4};
const auto a = std::make_shared<op::Parameter>(element::boolean, shape_a);
const auto b = std::make_shared<op::Parameter>(element::boolean, shape_b);
EXPECT_THAT(this->make_op(a, b, "NUMPY")->get_output_partial_shape(0),
AllOf(Eq(exp_shape), ResultOf(get_shape_labels, ElementsAre(ov::no_label, 11, 21, 13, 23))));
EXPECT_THAT(this->make_op(b, a, "NUMPY")->get_output_partial_shape(0),
AllOf(Eq(exp_shape), ResultOf(get_shape_labels, ElementsAre(ov::no_label, 11, 12, 13, 23))));
}
TYPED_TEST_P(LogicalOperatorTypeProp, default_ctor) {
using namespace ngraph;
const auto op = this->make_op();
const auto a = std::make_shared<op::Parameter>(element::boolean, PartialShape{1, {2, 4}, {2, 5}, 4, -1});
const auto b = std::make_shared<op::Parameter>(element::boolean, PartialShape{1, 3, {1, 6}, 4});
op->set_arguments(NodeVector{a, b});
op->set_autob("NUMPY");
op->validate_and_infer_types();
EXPECT_EQ(op->get_autob(), op::AutoBroadcastSpec("NUMPY"));
EXPECT_EQ(op->get_element_type(), element::boolean);
EXPECT_EQ(op->get_output_size(), this->exp_logical_op_output_size);
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({1, {2, 4}, 3, 4, 4}));
} }
REGISTER_TYPED_TEST_SUITE_P(LogicalOperatorTypeProp, REGISTER_TYPED_TEST_SUITE_P(LogicalOperatorTypeProp,
broadcast, shape_broadcast,
partial_shape_no_broadcast,
partial_shape_numpy_broadcast,
incorrect_type_f32, incorrect_type_f32,
incorrect_type_f64, incorrect_type_f64,
incorrect_type_i32, incorrect_type_i32,
incorrect_type_i64, incorrect_type_i64,
incorrect_type_u32, incorrect_type_u32,
incorrect_type_u64, incorrect_type_u64,
incorrect_shape); incorrect_shape,
inputs_have_different_types,
inputs_have_inconsistent_shapes,
default_ctor);

View File

@ -0,0 +1,107 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "common_test_utils/test_assertions.hpp"
#include "gmock/gmock.h"
#include "openvino/op/parameter.hpp"
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
using namespace testing;
template <class TOp>
class BELStaticShapeInferenceTest : public OpStaticShapeInferenceTest<TOp> {
protected:
void SetUp() override {
this->output_shapes = ShapeVector(1);
}
element::Type dtype{element::boolean};
};
TYPED_TEST_SUITE_P(BELStaticShapeInferenceTest);
TYPED_TEST_P(BELStaticShapeInferenceTest, broadcast_none) {
const auto a = std::make_shared<op::v0::Parameter>(this->dtype, PartialShape{-1, -1, -1, -1});
const auto b = std::make_shared<op::v0::Parameter>(this->dtype, PartialShape{-1, -1, -1, -1});
const auto op = this->make_op(a, b, op::AutoBroadcastType::NONE);
this->input_shapes = {StaticShape{3, 4, 7, 5}, StaticShape{3, 4, 7, 5}};
shape_inference(op.get(), this->input_shapes, this->output_shapes);
ASSERT_EQ(this->output_shapes.front(), StaticShape({3, 4, 7, 5}));
}
TYPED_TEST_P(BELStaticShapeInferenceTest, broadcast_none_incompatible_shapes) {
const auto a = std::make_shared<op::v0::Parameter>(this->dtype, PartialShape{-1, -1, -1, -1});
const auto b = std::make_shared<op::v0::Parameter>(this->dtype, PartialShape{-1, -1, -1, -1});
const auto op = this->make_op(a, b, op::AutoBroadcastType::NONE);
this->input_shapes = {StaticShape{3, 4, 6, 5}, StaticShape{3, 1, 6, 1}};
OV_EXPECT_THROW(shape_inference(op.get(), this->input_shapes, this->output_shapes),
NodeValidationFailure,
HasSubstr("Argument shapes are inconsistent."))
}
TYPED_TEST_P(BELStaticShapeInferenceTest, broadcast_numpy_equal_rank) {
const auto a = std::make_shared<op::v0::Parameter>(this->dtype, PartialShape{-1, -1, -1, -1});
const auto b = std::make_shared<op::v0::Parameter>(this->dtype, PartialShape{-1, -1, -1, -1});
const auto op = this->make_op(a, b);
this->input_shapes = {StaticShape{3, 1, 1, 5}, StaticShape{3, 1, 6, 1}};
shape_inference(op.get(), this->input_shapes, this->output_shapes);
ASSERT_EQ(this->output_shapes.front(), StaticShape({3, 1, 6, 5}));
}
TYPED_TEST_P(BELStaticShapeInferenceTest, broadcast_numpy_a_rank_higher) {
const auto a = std::make_shared<op::v0::Parameter>(this->dtype, PartialShape{-1, -1, -1, -1});
const auto b = std::make_shared<op::v0::Parameter>(this->dtype, PartialShape{-1, -1, -1});
const auto op = this->make_op(a, b);
this->input_shapes = {StaticShape{6, 5, 1, 8}, StaticShape{5, 6, 1}},
shape_inference(op.get(), this->input_shapes, this->output_shapes);
ASSERT_EQ(this->output_shapes.front(), StaticShape({6, 5, 6, 8}));
}
TYPED_TEST_P(BELStaticShapeInferenceTest, broadcast_numpy_b_rank_higher) {
const auto a = std::make_shared<op::v0::Parameter>(this->dtype, PartialShape{-1, -1, -1});
const auto b = std::make_shared<op::v0::Parameter>(this->dtype, PartialShape{-1, -1, -1, -1});
const auto op = this->make_op(a, b);
this->input_shapes = {StaticShape{5, 6, 1}, StaticShape{6, 5, 1, 8}},
shape_inference(op.get(), this->input_shapes, this->output_shapes);
ASSERT_EQ(this->output_shapes.front(), StaticShape({6, 5, 6, 8}));
}
TYPED_TEST_P(BELStaticShapeInferenceTest, broadcast_numpy_incompatible_shapes) {
const auto a = std::make_shared<op::v0::Parameter>(this->dtype, PartialShape{-1, -1, -1, -1});
const auto b = std::make_shared<op::v0::Parameter>(this->dtype, PartialShape{-1, -1, -1, -1});
const auto op = this->make_op(a, b);
this->input_shapes = {StaticShape{3, 4, 6, 6}, StaticShape{2, 4, 6, 6}};
OV_EXPECT_THROW(shape_inference(op.get(), this->input_shapes, this->output_shapes),
NodeValidationFailure,
HasSubstr("Argument shapes are inconsistent."))
}
REGISTER_TYPED_TEST_SUITE_P(BELStaticShapeInferenceTest,
broadcast_none,
broadcast_none_incompatible_shapes,
broadcast_numpy_equal_rank,
broadcast_numpy_a_rank_higher,
broadcast_numpy_b_rank_higher,
broadcast_numpy_incompatible_shapes);
using BinaryLogicOpTypes = Types<op::v1::LogicalAnd, op::v1::LogicalOr, op::v1::LogicalXor>;
INSTANTIATE_TYPED_TEST_SUITE_P(shape_inference, BELStaticShapeInferenceTest, BinaryLogicOpTypes);

View File

@ -0,0 +1,42 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "common_test_utils/test_assertions.hpp"
#include "gmock/gmock.h"
#include "openvino/op/logical_not.hpp"
#include "openvino/op/parameter.hpp"
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
using namespace testing;
class LogicalNotStaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v1::LogicalNot> {
protected:
void SetUp() override {
this->output_shapes = ShapeVector(1);
}
};
TEST_F(LogicalNotStaticShapeInferenceTest, static_rank) {
const auto a = std::make_shared<op::v0::Parameter>(element::boolean, PartialShape{-1, -1, -1, -1});
const auto op = this->make_op(a);
this->input_shapes = {StaticShape{3, 4, 7, 5}};
shape_inference(op.get(), this->input_shapes, this->output_shapes);
ASSERT_EQ(this->output_shapes.front(), StaticShape({3, 4, 7, 5}));
}
TEST_F(LogicalNotStaticShapeInferenceTest, dynamic_rank) {
const auto a = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
const auto op = this->make_op(a);
this->input_shapes = {StaticShape{3, 1, 5, 2}};
shape_inference(op.get(), this->input_shapes, this->output_shapes);
ASSERT_EQ(this->output_shapes.front(), StaticShape({3, 1, 5, 2}));
}