[Transformations] ConvertBroadcast3 for boolean fix (#10001)

This commit is contained in:
Maxim Andronov 2022-02-01 12:53:05 +03:00 committed by GitHub
parent e1e467f23f
commit 6866ced978
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 5 deletions

View File

@ -96,7 +96,7 @@ ngraph::pass::ConvertBroadcast3::ConvertBroadcast3() {
auto constant_one = opset1::Constant::create(input_element_type, {1}, {1});
auto broadcast_ones = std::make_shared<opset1::Broadcast>(constant_one, target_shape_input);
if (input_element_type == element::boolean) {
input = std::make_shared<ngraph::opset1::LogicalOr>(input, broadcast_ones);
input = std::make_shared<ngraph::opset1::LogicalAnd>(input, broadcast_ones);
} else {
input = std::make_shared<ngraph::opset1::Multiply>(input, broadcast_ones);
}

View File

@ -170,7 +170,7 @@ public:
}
};
class ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest: public CommonTestUtils::TestsCommon,
class ConvertBroadcast3BIDIRECTBroadcastLogicalAndTest: public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
public:
std::shared_ptr<Function> f, f_ref;
@ -198,7 +198,7 @@ public:
auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
auto constant_one = opset1::Constant::create(ngraph::element::boolean, {1}, {1});
auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape_node, op::AutoBroadcastType::NUMPY);
auto mul = std::make_shared<ngraph::opset1::LogicalOr>(input, broadcast);
auto mul = std::make_shared<ngraph::opset1::LogicalAnd>(input, broadcast);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input, target_shape_node});
}
};
@ -219,7 +219,7 @@ TEST_P(ConvertBroadcast3BIDIRECTBroadcastMultiplyTest, CompareFunctions) {
convert_broadcast3_test(f, f_ref);
}
TEST_P(ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest, CompareFunctions) {
TEST_P(ConvertBroadcast3BIDIRECTBroadcastLogicalAndTest, CompareFunctions) {
convert_broadcast3_test(f, f_ref);
}
@ -296,7 +296,7 @@ INSTANTIATE_TEST_SUITE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBro
std::make_tuple(InputShape{2, DYN, 9}, TargetShape{3}),
std::make_tuple(InputShape{3, 3, DYN}, TargetShape{2})));
INSTANTIATE_TEST_SUITE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest,
INSTANTIATE_TEST_SUITE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastLogicalAndTest,
testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{5}),
std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{4}),
std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{3}),