[GPU] Apply is_non_decompression_multiply() callback only for compressed models (#21719)
This commit is contained in:
parent
98e8caad79
commit
032ac898e2
|
@ -188,9 +188,8 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
|
|||
bool enableInt8;
|
||||
bool unroll_loop = config.get_property(ov::intel_gpu::enable_loop_unrolling);
|
||||
{
|
||||
ov::pass::Manager manager;
|
||||
auto pass_config = manager.get_pass_config();
|
||||
manager.set_per_pass_validation(false);
|
||||
ov::pass::Manager initial_transformations_manager;
|
||||
initial_transformations_manager.set_per_pass_validation(false);
|
||||
|
||||
// Temporary solution, global rt info cleanup is needed
|
||||
for (auto& node : func->get_ops()) {
|
||||
|
@ -199,13 +198,8 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
|
|||
}
|
||||
|
||||
enableInt8 = config.get_property(ov::intel_gpu::enable_lp_transformations) && ov::pass::low_precision::LowPrecision::isFunctionQuantized(func);
|
||||
if (enableInt8) {
|
||||
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
|
||||
std::vector<ov::element::Type>{ ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4 });
|
||||
}
|
||||
|
||||
manager.register_pass<ov::pass::InitNodeInfo>();
|
||||
manager.register_pass<EinsumDecomposition>();
|
||||
initial_transformations_manager.register_pass<ov::pass::InitNodeInfo>();
|
||||
initial_transformations_manager.register_pass<EinsumDecomposition>();
|
||||
|
||||
precisions_map fp_convert_precision_map = {
|
||||
{ov::element::f64, ov::element::f32}
|
||||
|
@ -254,19 +248,19 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
|
|||
}
|
||||
|
||||
type_to_fuse_map empty_fuse_map = {};
|
||||
manager.register_pass<ov::pass::Validate>();
|
||||
initial_transformations_manager.register_pass<ov::pass::Validate>();
|
||||
|
||||
// fuse softmax, MVN patterns, so that they will not be marked as precision sensitive in ConvertPrecision
|
||||
manager.register_pass<ov::pass::SoftmaxFusion>();
|
||||
manager.register_pass<ov::pass::MVNFusion>();
|
||||
initial_transformations_manager.register_pass<ov::pass::SoftmaxFusion>();
|
||||
initial_transformations_manager.register_pass<ov::pass::MVNFusion>();
|
||||
// decompose MVNs that sre not supported in GPU, so that they will be marked as precision sensitive in ConvertPrecision
|
||||
manager.register_pass<ov::pass::MVN6Decomposition>();
|
||||
initial_transformations_manager.register_pass<ov::pass::MVN6Decomposition>();
|
||||
// Run these broadcast optimizations earlier to ensure that those are executed before NopElimination/ConstantFolding
|
||||
manager.register_pass<ov::pass::BroadcastElementwiseFusion>();
|
||||
manager.register_pass<ov::pass::BroadcastTransition>();
|
||||
initial_transformations_manager.register_pass<ov::pass::BroadcastElementwiseFusion>();
|
||||
initial_transformations_manager.register_pass<ov::pass::BroadcastTransition>();
|
||||
|
||||
manager.register_pass<ov::pass::KeepConstantsPrecisionAndAddConverts>();
|
||||
pass_config->set_callback<ov::pass::KeepConstantsPrecisionAndAddConverts>(
|
||||
initial_transformations_manager.register_pass<ov::pass::KeepConstantsPrecisionAndAddConverts>();
|
||||
initial_transformations_manager.get_pass_config()->set_callback<ov::pass::KeepConstantsPrecisionAndAddConverts>(
|
||||
[](const_node_ptr& node) -> bool {
|
||||
auto next_node = node->get_output_target_inputs(0).begin()->get_node();
|
||||
if (is_type<ov::op::v0::Convert>(next_node)) {
|
||||
|
@ -275,9 +269,22 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
|
|||
return !is_type<ov::op::v0::MatMul>(next_node);
|
||||
});
|
||||
|
||||
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(ov::element::TypeVector{ov::element::u8, ov::element::u4, ov::element::i4}, true);
|
||||
initial_transformations_manager.register_pass<ov::pass::MarkDequantizationSubgraph>(ov::element::TypeVector{ov::element::u8,
|
||||
ov::element::u4,
|
||||
ov::element::i4}, true);
|
||||
|
||||
// Ignore nodes that are not related to FullyConnected and allow ConstantFolding to be applied to them
|
||||
pass_config->set_callback<ov::pass::MarkDequantizationSubgraph>(is_non_decompression_multiply);
|
||||
initial_transformations_manager.get_pass_config()->set_callback<ov::pass::MarkDequantizationSubgraph>(is_non_decompression_multiply);
|
||||
initial_transformations_manager.run_passes(func);
|
||||
|
||||
ov::pass::Manager manager;
|
||||
auto pass_config = manager.get_pass_config();
|
||||
|
||||
// Need to check if transfomrations work correctly for mixed models with both compression and quantization at the same time.
|
||||
if (enableInt8) {
|
||||
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
|
||||
std::vector<ov::element::Type>{ ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4 });
|
||||
}
|
||||
|
||||
manager.register_pass<ov::intel_gpu::MoveConvertAfterGather>();
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user