ConvertMinimum adjusted to handle scalar input correctly (#6017)
This commit is contained in:
parent
9de75d9cf1
commit
5c716d2afc
@ -30,14 +30,14 @@ ngraph::pass::ConvertMinimum::ConvertMinimum() {
|
||||
*/
|
||||
|
||||
auto neg_0 = std::make_shared<ngraph::opset1::Multiply>(minimum->input(0).get_source_output(),
|
||||
opset1::Constant::create(minimum->get_input_element_type(0), Shape{1}, {-1}));
|
||||
opset1::Constant::create(minimum->get_input_element_type(0), Shape{}, {-1}));
|
||||
|
||||
auto neg_1 = std::make_shared<ngraph::opset1::Multiply>(minimum->input(1).get_source_output(),
|
||||
opset1::Constant::create(minimum->get_input_element_type(1), Shape{1}, {-1}));
|
||||
opset1::Constant::create(minimum->get_input_element_type(1), Shape{}, {-1}));
|
||||
|
||||
auto max = std::make_shared<ngraph::opset1::Maximum>(neg_0, neg_1);
|
||||
|
||||
auto neg_2 = std::make_shared<ngraph::opset1::Multiply>(max, opset1::Constant::create(max->get_element_type(), Shape{1}, {-1}));
|
||||
auto neg_2 = std::make_shared<ngraph::opset1::Multiply>(max, opset1::Constant::create(max->get_element_type(), Shape{}, {-1}));
|
||||
|
||||
neg_2->set_friendly_name(minimum->get_friendly_name());
|
||||
ngraph::copy_runtime_info(minimum, {neg_0, neg_1, max, neg_2});
|
||||
|
@ -130,7 +130,6 @@ TEST_F(NGraphReaderTests, ReadHSigmoidNetwork) {
|
||||
<layer name="Multiply_744" type="Const" precision="FP32" id="4">
|
||||
<output>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
</output>
|
||||
<blobs>
|
||||
@ -147,7 +146,6 @@ TEST_F(NGraphReaderTests, ReadHSigmoidNetwork) {
|
||||
<dim>22</dim>
|
||||
</port>
|
||||
<port id="1">
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
|
Loading…
Reference in New Issue
Block a user