[Transformations] ConvertBroadcast3 for boolean fix (#10001)
This commit is contained in:
parent
e1e467f23f
commit
6866ced978
@ -96,7 +96,7 @@ ngraph::pass::ConvertBroadcast3::ConvertBroadcast3() {
|
||||
auto constant_one = opset1::Constant::create(input_element_type, {1}, {1});
|
||||
auto broadcast_ones = std::make_shared<opset1::Broadcast>(constant_one, target_shape_input);
|
||||
if (input_element_type == element::boolean) {
|
||||
input = std::make_shared<ngraph::opset1::LogicalOr>(input, broadcast_ones);
|
||||
input = std::make_shared<ngraph::opset1::LogicalAnd>(input, broadcast_ones);
|
||||
} else {
|
||||
input = std::make_shared<ngraph::opset1::Multiply>(input, broadcast_ones);
|
||||
}
|
||||
|
@ -170,7 +170,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest: public CommonTestUtils::TestsCommon,
|
||||
class ConvertBroadcast3BIDIRECTBroadcastLogicalAndTest: public CommonTestUtils::TestsCommon,
|
||||
public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
|
||||
public:
|
||||
std::shared_ptr<Function> f, f_ref;
|
||||
@ -198,7 +198,7 @@ public:
|
||||
auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
|
||||
auto constant_one = opset1::Constant::create(ngraph::element::boolean, {1}, {1});
|
||||
auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape_node, op::AutoBroadcastType::NUMPY);
|
||||
auto mul = std::make_shared<ngraph::opset1::LogicalOr>(input, broadcast);
|
||||
auto mul = std::make_shared<ngraph::opset1::LogicalAnd>(input, broadcast);
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input, target_shape_node});
|
||||
}
|
||||
};
|
||||
@ -219,7 +219,7 @@ TEST_P(ConvertBroadcast3BIDIRECTBroadcastMultiplyTest, CompareFunctions) {
|
||||
convert_broadcast3_test(f, f_ref);
|
||||
}
|
||||
|
||||
TEST_P(ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest, CompareFunctions) {
|
||||
TEST_P(ConvertBroadcast3BIDIRECTBroadcastLogicalAndTest, CompareFunctions) {
|
||||
convert_broadcast3_test(f, f_ref);
|
||||
}
|
||||
|
||||
@ -296,7 +296,7 @@ INSTANTIATE_TEST_SUITE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBro
|
||||
std::make_tuple(InputShape{2, DYN, 9}, TargetShape{3}),
|
||||
std::make_tuple(InputShape{3, 3, DYN}, TargetShape{2})));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest,
|
||||
INSTANTIATE_TEST_SUITE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastLogicalAndTest,
|
||||
testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{5}),
|
||||
std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{4}),
|
||||
std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{3}),
|
||||
|
Loading…
Reference in New Issue
Block a user