Optimize and simplify ConvertCompressedOnlyToLegacy transformation (#9702)
This commit is contained in:
parent
5fe228bc14
commit
96e7ee58e1
@ -11,24 +11,12 @@
|
|||||||
namespace ov {
|
namespace ov {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
|
|
||||||
class TRANSFORMATIONS_API ConvertPrecisionCompressedOnly;
|
|
||||||
class TRANSFORMATIONS_API EnableDecompressionConvertConstantFolding;
|
class TRANSFORMATIONS_API EnableDecompressionConvertConstantFolding;
|
||||||
class TRANSFORMATIONS_API ConvertCompressedOnlyToLegacy;
|
class TRANSFORMATIONS_API ConvertCompressedOnlyToLegacy;
|
||||||
|
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ov
|
} // namespace ov
|
||||||
|
|
||||||
/**
|
|
||||||
* @ingroup ie_transformation_common_api
|
|
||||||
* @brief ConvertPrecisionCompressedOnly transformation runs ConvertPrecision transformation for CompressedOnly format.
|
|
||||||
*/
|
|
||||||
|
|
||||||
class ov::pass::ConvertPrecisionCompressedOnly : public ov::pass::ModelPass {
|
|
||||||
public:
|
|
||||||
OPENVINO_RTTI("ConvertPrecisionCompressedOnly", "0");
|
|
||||||
bool run_on_model(const std::shared_ptr<Model>& f) override;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @ingroup ie_transformation_common_api
|
* @ingroup ie_transformation_common_api
|
||||||
* @brief Enables ConstantFolding for Convert operation in compressed function.
|
* @brief Enables ConstantFolding for Convert operation in compressed function.
|
||||||
|
@ -12,15 +12,6 @@
|
|||||||
|
|
||||||
using namespace ov;
|
using namespace ov;
|
||||||
|
|
||||||
bool ov::pass::ConvertPrecisionCompressedOnly::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
|
||||||
if (ngraph::op::util::has_decompression_converts(f)) {
|
|
||||||
const precisions_array convert_precision_list{{ov::element::f32, ov::element::f16}};
|
|
||||||
auto convert_precision = ngraph::pass::ConvertPrecision(convert_precision_list);
|
|
||||||
return convert_precision.run_on_model(f);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
ov::pass::EnableDecompressionConvertConstantFolding::EnableDecompressionConvertConstantFolding() {
|
ov::pass::EnableDecompressionConvertConstantFolding::EnableDecompressionConvertConstantFolding() {
|
||||||
MATCHER_SCOPE(EnableDecompressionConvertConstantFolding);
|
MATCHER_SCOPE(EnableDecompressionConvertConstantFolding);
|
||||||
auto convert = pattern::wrap_type<opset8::Convert>();
|
auto convert = pattern::wrap_type<opset8::Convert>();
|
||||||
@ -38,13 +29,15 @@ ov::pass::EnableDecompressionConvertConstantFolding::EnableDecompressionConvertC
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool ov::pass::ConvertCompressedOnlyToLegacy::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
bool ov::pass::ConvertCompressedOnlyToLegacy::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||||
|
if (ngraph::op::util::has_decompression_converts(f)) {
|
||||||
Manager manager(get_pass_config());
|
Manager manager(get_pass_config());
|
||||||
|
|
||||||
manager.register_pass<ov::pass::ConvertPrecisionCompressedOnly>();
|
const precisions_array convert_precision_list{{ov::element::f32, ov::element::f16}};
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(convert_precision_list);
|
||||||
manager.register_pass<ov::pass::EnableDecompressionConvertConstantFolding>();
|
manager.register_pass<ov::pass::EnableDecompressionConvertConstantFolding>();
|
||||||
manager.register_pass<ov::pass::ConstantFolding>();
|
manager.register_pass<ov::pass::ConstantFolding>();
|
||||||
|
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user