[GNA] Fixed output convert transformation (#9808)

* Fixed output convert transformation

* Update src/plugins/intel_gna/transformations/remove_converts.cpp

* Reverted wa
This commit is contained in:
Mikhail Ryzhov 2022-01-31 23:08:24 +03:00 committed by GitHub
parent 73e9eb4c61
commit bcdf7b0cad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 1 deletions

View File

@ -13,6 +13,7 @@
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/rt_info.hpp>
namespace GNAPluginNS {
NGRAPH_RTTI_DEFINITION(RemoveInputConvert, "RemoveInputConvert", 0);
@ -72,7 +73,7 @@ namespace GNAPluginNS {
return false;
}
// the result presicion will be changed automaically
// the result precision will be changed automatically
ngraph::replace_output_update_name(convert_node->output(0), convert_node->input_value(0));
return true;
};

View File

@ -254,6 +254,45 @@ public:
}
};
class RemoveOutputConvertConnectedToLayerTest: public RemoveOutputConvertTest {
public:
void SetUp() override {
std::tie(net_precision_, target_precision_) = this->GetParam();
const ngraph::Shape input_shape{1, 10};
// test function
{
auto input = ngraph::builder::makeParams(net_precision_, {input_shape, input_shape, input_shape, input_shape});
auto mul1 = ngraph::builder::makeEltwise(input[0], input[1], ngraph::helpers::EltwiseTypes::ADD);
auto mul2 = ngraph::builder::makeEltwise(input[2], input[3], ngraph::helpers::EltwiseTypes::ADD);
auto mul3 = ngraph::builder::makeEltwise(mul1, mul2, ngraph::helpers::EltwiseTypes::ADD);
auto convert1 = ngraph::builder::makeConversion(mul1, target_precision_, ngraph::helpers::ConversionTypes::CONVERT);
auto convert2 = ngraph::builder::makeConversion(mul2, target_precision_, ngraph::helpers::ConversionTypes::CONVERT);
auto convert3 = ngraph::builder::makeConversion(mul3, target_precision_, ngraph::helpers::ConversionTypes::CONVERT);
auto result1 = std::make_shared<ngraph::opset8::Result>(convert1);
auto result2 = std::make_shared<ngraph::opset8::Result>(convert2);
auto result3 = std::make_shared<ngraph::opset8::Result>(convert3);
func_ = std::make_shared<ngraph::Function>(ngraph::ResultVector{result1, result2, result3}, input, "multiple_output");
}
// ref function
{
auto input = ngraph::builder::makeParams(net_precision_, {input_shape, input_shape, input_shape, input_shape});
auto mul1 = ngraph::builder::makeEltwise(input[0], input[1], ngraph::helpers::EltwiseTypes::ADD);
auto mul2 = ngraph::builder::makeEltwise(input[2], input[3], ngraph::helpers::EltwiseTypes::ADD);
auto mul3 = ngraph::builder::makeEltwise(mul1, mul2, ngraph::helpers::EltwiseTypes::ADD);
auto result1 = std::make_shared<ngraph::opset8::Result>(mul1);
auto result2 = std::make_shared<ngraph::opset8::Result>(mul2);
auto result3 = std::make_shared<ngraph::opset8::Result>(mul3);
ref_func_no_convert_ = std::make_shared<ngraph::Function>(ngraph::ResultVector{result1, result2, result3}, input, "multiple_output");
}
// ref function convert should not be removed
ref_func_convert_ = ngraph::clone_function(*func_);
}
};
ov::element::TypeVector netTypes = {
ov::element::f16,
ov::element::f32,
@ -294,6 +333,10 @@ TEST_P(RemoveMultiOutputsConvertTest, CompareWithRefs) {
Run();
}
TEST_P(RemoveOutputConvertConnectedToLayerTest, CompareWithRefs) {
Run();
}
INSTANTIATE_TEST_SUITE_P(TransformationTests, RemoveInputConvertTest,
::testing::Combine(
::testing::ValuesIn(netTypes),
@ -323,4 +366,10 @@ INSTANTIATE_TEST_SUITE_P(TransformationTests, RemoveMultiOutputsConvertTest,
::testing::ValuesIn(netTypes),
::testing::ValuesIn(targetTypes)),
RemoveMultiOutputsConvertTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(TransformationTests, RemoveOutputConvertConnectedToLayerTest,
::testing::Combine(
::testing::ValuesIn(netTypes),
::testing::ValuesIn(targetTypes)),
RemoveOutputConvertConnectedToLayerTest::getTestCaseName);
} // namespace testing