Optimize and simplify ConvertCompressedOnlyToLegacy transformation (#9702)

This commit is contained in:
Maxim Vafin 2022-01-18 13:30:28 +03:00 committed by GitHub
parent 5fe228bc14
commit 96e7ee58e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 27 deletions

View File

@ -11,24 +11,12 @@
namespace ov {
namespace pass {
class TRANSFORMATIONS_API ConvertPrecisionCompressedOnly;
class TRANSFORMATIONS_API EnableDecompressionConvertConstantFolding;
class TRANSFORMATIONS_API ConvertCompressedOnlyToLegacy;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief ConvertPrecisionCompressedOnly transformation runs ConvertPrecision transformation for CompressedOnly format.
*/
class ov::pass::ConvertPrecisionCompressedOnly : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("ConvertPrecisionCompressedOnly", "0");
bool run_on_model(const std::shared_ptr<Model>& f) override;
};
/**
* @ingroup ie_transformation_common_api
* @brief Enables ConstantFolding for Convert operation in compressed function.

View File

@ -12,15 +12,6 @@
using namespace ov;
bool ov::pass::ConvertPrecisionCompressedOnly::run_on_model(const std::shared_ptr<ov::Model>& f) {
if (ngraph::op::util::has_decompression_converts(f)) {
const precisions_array convert_precision_list{{ov::element::f32, ov::element::f16}};
auto convert_precision = ngraph::pass::ConvertPrecision(convert_precision_list);
return convert_precision.run_on_model(f);
}
return false;
}
ov::pass::EnableDecompressionConvertConstantFolding::EnableDecompressionConvertConstantFolding() {
MATCHER_SCOPE(EnableDecompressionConvertConstantFolding);
auto convert = pattern::wrap_type<opset8::Convert>();
@ -38,13 +29,15 @@ ov::pass::EnableDecompressionConvertConstantFolding::EnableDecompressionConvertC
}
bool ov::pass::ConvertCompressedOnlyToLegacy::run_on_model(const std::shared_ptr<ov::Model>& f) {
Manager manager(get_pass_config());
if (ngraph::op::util::has_decompression_converts(f)) {
Manager manager(get_pass_config());
manager.register_pass<ov::pass::ConvertPrecisionCompressedOnly>();
manager.register_pass<ov::pass::EnableDecompressionConvertConstantFolding>();
manager.register_pass<ov::pass::ConstantFolding>();
manager.run_passes(f);
const precisions_array convert_precision_list{{ov::element::f32, ov::element::f16}};
manager.register_pass<ngraph::pass::ConvertPrecision>(convert_precision_list);
manager.register_pass<ov::pass::EnableDecompressionConvertConstantFolding>();
manager.register_pass<ov::pass::ConstantFolding>();
manager.run_passes(f);
}
return false;
}