From 598da6e5c01b75ef9b268372371f2e1390d6cd57 Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Mon, 27 Nov 2023 10:38:05 +0100 Subject: [PATCH] [Transformations] Leftovers: FuseU4WeightsAndZeroPoint transformation (#20709) * util::visit_constant_path * ConvertU4WeightsZeroPointToScalar: avoid unnecessary insert * Review comments applied * codestyle fix --- .../include/transformations/utils/utils.hpp | 20 +++++++ ...onvert_u4_weights_zero_point_to_scalar.cpp | 14 ++--- .../mark_dequantization_subgraph.cpp | 2 +- .../src/transformations/utils/utils.cpp | 60 ++++++++++++------- .../x64/pass/snippets_mark_skipped.cpp | 2 +- 5 files changed, 65 insertions(+), 33 deletions(-) diff --git a/src/common/transformations/include/transformations/utils/utils.hpp b/src/common/transformations/include/transformations/utils/utils.hpp index 1961c35ef16..b6f4c853c5c 100644 --- a/src/common/transformations/include/transformations/utils/utils.hpp +++ b/src/common/transformations/include/transformations/utils/utils.hpp @@ -192,10 +192,30 @@ TRANSFORMATIONS_API std::shared_ptr clone_try_fold(const std::shared_ptr& visited, std::function func); +/** + * \brief Traverses a constant path starting from "node", and calls "func" for each ov::Node. + * If the function is called for a non-constant subgraph, an exception is thrown. + * + * \param node The node from which the constant path starts. + * \param visited Set of nodes which were visited. + * \param func The function which is called for each visited node. + */ +TRANSFORMATIONS_API void visit_constant_path(ov::Node* node, + std::unordered_set& visited, + std::function func); + template std::shared_ptr make_try_fold(Args&&... 
args) { auto unary_output_node = std::make_shared(std::forward(args)...); diff --git a/src/common/transformations/src/transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.cpp b/src/common/transformations/src/transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.cpp index 6313db127ac..79be25747ae 100644 --- a/src/common/transformations/src/transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.cpp @@ -51,21 +51,15 @@ ov::pass::ConvertU4WeightsZeroPointToScalar::ConvertU4WeightsZeroPointToScalar() if (ov::shape_size(weights->get_shape()) < ov::shape_size(zero_point->get_shape())) std::swap(zero_point, weights); - auto zero_point_shape = zero_point->get_shape(); - if (ov::shape_size(zero_point_shape) == 1) + const auto& zp_shape = zero_point->get_shape(); + if (ov::shape_size(zp_shape) == 1) return false; const auto& weights_shape = weights->get_shape(); - const size_t weights_rank = weights_shape.size(); - const size_t zero_point_rank = zero_point_shape.size(); // Zero point constant can be converted into scalar only if this does not affect Subtract output shape - if (weights_rank < zero_point_rank) + if (weights_shape.size() < zp_shape.size() || + !std::equal(zp_shape.rbegin(), zp_shape.rend(), weights_shape.rbegin(), std::less_equal())) { return false; - - zero_point_shape.insert(zero_point_shape.begin(), weights_rank - zero_point_rank, 1); - for (size_t i = 0; i < weights_rank; ++i) { - if (zero_point_shape[i] > weights_shape[i]) - return false; } float zp_value; diff --git a/src/common/transformations/src/transformations/low_precision/mark_dequantization_subgraph.cpp b/src/common/transformations/src/transformations/low_precision/mark_dequantization_subgraph.cpp index c2662dc77e2..2c95d0a36a3 100644 --- 
a/src/common/transformations/src/transformations/low_precision/mark_dequantization_subgraph.cpp +++ b/src/common/transformations/src/transformations/low_precision/mark_dequantization_subgraph.cpp @@ -72,7 +72,7 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element:: } }; std::unordered_set visited; - ov::op::util::visit_shape_path(input.get_node(), visited, keep_const_precision); + ov::op::util::visit_constant_path(input.get_node(), visited, keep_const_precision); } if (subtract_it != pattern_map.end()) { diff --git a/src/common/transformations/src/transformations/utils/utils.cpp b/src/common/transformations/src/transformations/utils/utils.cpp index 9e8a6fad92e..60a91f0227b 100644 --- a/src/common/transformations/src/transformations/utils/utils.cpp +++ b/src/common/transformations/src/transformations/utils/utils.cpp @@ -21,6 +21,34 @@ namespace ov { namespace op { namespace util { +namespace { +void visit_path_impl(ov::Node* node, + std::unordered_set& visited, + std::function func, + std::function skip_node_predicate) { + if (!node) + return; + visited.insert(node); + std::deque nodes{node}; + while (!nodes.empty()) { + auto curr_node = nodes.front(); + nodes.pop_front(); + if (skip_node_predicate(curr_node)) + continue; + + func(curr_node); + for (auto& input_value : curr_node->input_values()) { + // continue searching + const auto& input_node = input_value.get_node(); + if (visited.count(input_node)) + continue; + nodes.push_front(input_node); + visited.insert(input_node); + } + } +} +} // namespace + bool get_single_value(const std::shared_ptr& const_node, float& value, bool check_value_range) { switch (const_node->get_element_type()) { case element::Type_t::f16: @@ -242,28 +270,18 @@ bool shapes_equal_except_dynamic_expected_batch(const ov::PartialShape& expected } void visit_shape_path(Node* node, std::unordered_set& visited, std::function func) { - if (!node) - return; - visited.insert(node); - std::deque nodes{node}; - while 
(!nodes.empty()) { - auto curr_node = nodes.front(); - nodes.pop_front(); - // Do not check if already visited - if (ov::is_type(curr_node) || ov::is_type(curr_node)) { - continue; - } + auto is_shapeof = [](ov::Node* node) { + return ov::is_type(node) || ov::is_type(node); + }; + visit_path_impl(node, visited, func, is_shapeof); +} - func(curr_node); - for (auto& input_value : curr_node->input_values()) { - // continue searching - const auto& input_node = input_value.get_node(); - if (visited.count(input_node)) - continue; - nodes.push_front(input_node); - visited.insert(input_node); - } - } +void visit_constant_path(ov::Node* node, std::unordered_set& visited, std::function func) { + auto check_parameter = [](ov::Node* node) { + OPENVINO_ASSERT(!ov::is_type(node), "visit_constant_path is called for non-constant path."); + return false; + }; + visit_path_impl(node, visited, func, check_parameter); } bool is_dequantization_subgraph(const Output& node) { diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.cpp index eb943e0658b..d21d6b96e54 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.cpp @@ -476,7 +476,7 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr &m) { SetSnippetsNodeType(node->shared_from_this(), snippets::pass::SnippetsNodeType::SkippedByPlugin); }; std::unordered_set visited; - ov::op::util::visit_shape_path(node->get_input_node_ptr(1), visited, markup_func); + ov::op::util::visit_constant_path(node->get_input_node_ptr(1), visited, markup_func); } if (isSuitableConvolutionParent(node)) { // Initiate fusing chain