ConvertMinimum adjusted to handle scalar input correctly (#6017)

This commit is contained in:
Evgenya Stepyreva 2021-06-03 22:22:40 +03:00 committed by GitHub
parent 9de75d9cf1
commit 5c716d2afc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 5 deletions

View File

@ -30,14 +30,14 @@ ngraph::pass::ConvertMinimum::ConvertMinimum() {
*/
auto neg_0 = std::make_shared<ngraph::opset1::Multiply>(minimum->input(0).get_source_output(),
opset1::Constant::create(minimum->get_input_element_type(0), Shape{1}, {-1}));
opset1::Constant::create(minimum->get_input_element_type(0), Shape{}, {-1}));
auto neg_1 = std::make_shared<ngraph::opset1::Multiply>(minimum->input(1).get_source_output(),
opset1::Constant::create(minimum->get_input_element_type(1), Shape{1}, {-1}));
opset1::Constant::create(minimum->get_input_element_type(1), Shape{}, {-1}));
auto max = std::make_shared<ngraph::opset1::Maximum>(neg_0, neg_1);
auto neg_2 = std::make_shared<ngraph::opset1::Multiply>(max, opset1::Constant::create(max->get_element_type(), Shape{1}, {-1}));
auto neg_2 = std::make_shared<ngraph::opset1::Multiply>(max, opset1::Constant::create(max->get_element_type(), Shape{}, {-1}));
neg_2->set_friendly_name(minimum->get_friendly_name());
ngraph::copy_runtime_info(minimum, {neg_0, neg_1, max, neg_2});

View File

@ -130,7 +130,6 @@ TEST_F(NGraphReaderTests, ReadHSigmoidNetwork) {
<layer name="Multiply_744" type="Const" precision="FP32" id="4">
<output>
<port id="0" precision="FP32">
<dim>1</dim>
</port>
</output>
<blobs>
@ -147,7 +146,6 @@ TEST_F(NGraphReaderTests, ReadHSigmoidNetwork) {
<dim>22</dim>
</port>
<port id="1">
<dim>1</dim>
</port>
</input>
<output>