[GNA] Fixed output convert transformation (#9808)
* Fixed output convert transformation * Update src/plugins/intel_gna/transformations/remove_converts.cpp * Reverted wa
This commit is contained in:
parent
73e9eb4c61
commit
bcdf7b0cad
@ -13,6 +13,7 @@
|
|||||||
#include <ngraph/opsets/opset8.hpp>
|
#include <ngraph/opsets/opset8.hpp>
|
||||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
#include <ngraph/pass/manager.hpp>
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
#include <ngraph/rt_info.hpp>
|
||||||
|
|
||||||
namespace GNAPluginNS {
|
namespace GNAPluginNS {
|
||||||
NGRAPH_RTTI_DEFINITION(RemoveInputConvert, "RemoveInputConvert", 0);
|
NGRAPH_RTTI_DEFINITION(RemoveInputConvert, "RemoveInputConvert", 0);
|
||||||
@ -72,7 +73,7 @@ namespace GNAPluginNS {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// the result presicion will be changed automaically
|
// the result precision will be changed automatically
|
||||||
ngraph::replace_output_update_name(convert_node->output(0), convert_node->input_value(0));
|
ngraph::replace_output_update_name(convert_node->output(0), convert_node->input_value(0));
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
@ -254,6 +254,45 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class RemoveOutputConvertConnectedToLayerTest: public RemoveOutputConvertTest {
|
||||||
|
public:
|
||||||
|
void SetUp() override {
|
||||||
|
std::tie(net_precision_, target_precision_) = this->GetParam();
|
||||||
|
const ngraph::Shape input_shape{1, 10};
|
||||||
|
// test function
|
||||||
|
{
|
||||||
|
auto input = ngraph::builder::makeParams(net_precision_, {input_shape, input_shape, input_shape, input_shape});
|
||||||
|
auto mul1 = ngraph::builder::makeEltwise(input[0], input[1], ngraph::helpers::EltwiseTypes::ADD);
|
||||||
|
auto mul2 = ngraph::builder::makeEltwise(input[2], input[3], ngraph::helpers::EltwiseTypes::ADD);
|
||||||
|
auto mul3 = ngraph::builder::makeEltwise(mul1, mul2, ngraph::helpers::EltwiseTypes::ADD);
|
||||||
|
auto convert1 = ngraph::builder::makeConversion(mul1, target_precision_, ngraph::helpers::ConversionTypes::CONVERT);
|
||||||
|
auto convert2 = ngraph::builder::makeConversion(mul2, target_precision_, ngraph::helpers::ConversionTypes::CONVERT);
|
||||||
|
auto convert3 = ngraph::builder::makeConversion(mul3, target_precision_, ngraph::helpers::ConversionTypes::CONVERT);
|
||||||
|
auto result1 = std::make_shared<ngraph::opset8::Result>(convert1);
|
||||||
|
auto result2 = std::make_shared<ngraph::opset8::Result>(convert2);
|
||||||
|
auto result3 = std::make_shared<ngraph::opset8::Result>(convert3);
|
||||||
|
|
||||||
|
func_ = std::make_shared<ngraph::Function>(ngraph::ResultVector{result1, result2, result3}, input, "multiple_output");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ref function
|
||||||
|
{
|
||||||
|
auto input = ngraph::builder::makeParams(net_precision_, {input_shape, input_shape, input_shape, input_shape});
|
||||||
|
auto mul1 = ngraph::builder::makeEltwise(input[0], input[1], ngraph::helpers::EltwiseTypes::ADD);
|
||||||
|
auto mul2 = ngraph::builder::makeEltwise(input[2], input[3], ngraph::helpers::EltwiseTypes::ADD);
|
||||||
|
auto mul3 = ngraph::builder::makeEltwise(mul1, mul2, ngraph::helpers::EltwiseTypes::ADD);
|
||||||
|
auto result1 = std::make_shared<ngraph::opset8::Result>(mul1);
|
||||||
|
auto result2 = std::make_shared<ngraph::opset8::Result>(mul2);
|
||||||
|
auto result3 = std::make_shared<ngraph::opset8::Result>(mul3);
|
||||||
|
|
||||||
|
ref_func_no_convert_ = std::make_shared<ngraph::Function>(ngraph::ResultVector{result1, result2, result3}, input, "multiple_output");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ref function convert should not be removed
|
||||||
|
ref_func_convert_ = ngraph::clone_function(*func_);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
ov::element::TypeVector netTypes = {
|
ov::element::TypeVector netTypes = {
|
||||||
ov::element::f16,
|
ov::element::f16,
|
||||||
ov::element::f32,
|
ov::element::f32,
|
||||||
@ -294,6 +333,10 @@ TEST_P(RemoveMultiOutputsConvertTest, CompareWithRefs) {
|
|||||||
Run();
|
Run();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(RemoveOutputConvertConnectedToLayerTest, CompareWithRefs) {
|
||||||
|
Run();
|
||||||
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(TransformationTests, RemoveInputConvertTest,
|
INSTANTIATE_TEST_SUITE_P(TransformationTests, RemoveInputConvertTest,
|
||||||
::testing::Combine(
|
::testing::Combine(
|
||||||
::testing::ValuesIn(netTypes),
|
::testing::ValuesIn(netTypes),
|
||||||
@ -323,4 +366,10 @@ INSTANTIATE_TEST_SUITE_P(TransformationTests, RemoveMultiOutputsConvertTest,
|
|||||||
::testing::ValuesIn(netTypes),
|
::testing::ValuesIn(netTypes),
|
||||||
::testing::ValuesIn(targetTypes)),
|
::testing::ValuesIn(targetTypes)),
|
||||||
RemoveMultiOutputsConvertTest::getTestCaseName);
|
RemoveMultiOutputsConvertTest::getTestCaseName);
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(TransformationTests, RemoveOutputConvertConnectedToLayerTest,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::ValuesIn(netTypes),
|
||||||
|
::testing::ValuesIn(targetTypes)),
|
||||||
|
RemoveOutputConvertConnectedToLayerTest::getTestCaseName);
|
||||||
} // namespace testing
|
} // namespace testing
|
||||||
|
Loading…
Reference in New Issue
Block a user