[Transformations] Leftovers: FuseU4WeightsAndZeroPoint transformation (#20709)

* util::visit_constant_path

* ConvertU4WeightsZeroPointToScalar: avoid unnecessary insert

* Review comments applied

* codestyle fix
This commit is contained in:
Vladislav Golubev 2023-11-27 10:38:05 +01:00 committed by GitHub
parent c50bd2b55f
commit 598da6e5c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 65 additions and 33 deletions

View File

@ -192,10 +192,30 @@ TRANSFORMATIONS_API std::shared_ptr<Node> clone_try_fold(const std::shared_ptr<N
TRANSFORMATIONS_API bool shapes_equal_except_dynamic_expected_batch(const PartialShape& expected,
const PartialShape& actual);
/**
* \brief Traverses a shapeOf subgraph starting from the node, walking from every node to its inputs,
* and calls "func" for each visited ov::Node. ShapeOf nodes are not included: "func" is not called
* for them and the traversal does not continue through their inputs.
*
* \param node The node from which the shape path is started. If it is null, nothing is done.
* \param visited Set of nodes which were visited. Inputs already present in the set are not visited
* again, and every node reached during the traversal is added to it.
* \param func The function which is called for each visited node.
*/
TRANSFORMATIONS_API void visit_shape_path(ov::Node* node,
std::unordered_set<ov::Node*>& visited,
std::function<void(ov::Node*)> func);
/**
* \brief Traverses a constant path starting from "node", walking from every node to its inputs,
* and calls "func" for each visited ov::Node.
* If a Parameter node is reached (i.e. the function was called for a non-constant subgraph),
* an exception is thrown.
*
* \param node The node from which the constant path is started. If it is null, nothing is done.
* \param visited Set of nodes which were visited. Inputs already present in the set are not visited
* again, and every node reached during the traversal is added to it.
* \param func The function which is called for each visited node.
*/
TRANSFORMATIONS_API void visit_constant_path(ov::Node* node,
std::unordered_set<ov::Node*>& visited,
std::function<void(ov::Node*)> func);
template <typename T, typename... Args>
std::shared_ptr<Node> make_try_fold(Args&&... args) {
auto unary_output_node = std::make_shared<T>(std::forward<Args>(args)...);

View File

@ -51,21 +51,15 @@ ov::pass::ConvertU4WeightsZeroPointToScalar::ConvertU4WeightsZeroPointToScalar()
if (ov::shape_size(weights->get_shape()) < ov::shape_size(zero_point->get_shape()))
std::swap(zero_point, weights);
auto zero_point_shape = zero_point->get_shape();
if (ov::shape_size(zero_point_shape) == 1)
const auto& zp_shape = zero_point->get_shape();
if (ov::shape_size(zp_shape) == 1)
return false;
const auto& weights_shape = weights->get_shape();
const size_t weights_rank = weights_shape.size();
const size_t zero_point_rank = zero_point_shape.size();
// Zero point constant can be converted into scalar only if this does not affect Subtract output shape
if (weights_rank < zero_point_rank)
if (weights_shape.size() < zp_shape.size() ||
!std::equal(zp_shape.rbegin(), zp_shape.rend(), weights_shape.rbegin(), std::less_equal<size_t>())) {
return false;
zero_point_shape.insert(zero_point_shape.begin(), weights_rank - zero_point_rank, 1);
for (size_t i = 0; i < weights_rank; ++i) {
if (zero_point_shape[i] > weights_shape[i])
return false;
}
float zp_value;

View File

@ -72,7 +72,7 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::
}
};
std::unordered_set<Node*> visited;
ov::op::util::visit_shape_path(input.get_node(), visited, keep_const_precision);
ov::op::util::visit_constant_path(input.get_node(), visited, keep_const_precision);
}
if (subtract_it != pattern_map.end()) {

View File

@ -21,6 +21,34 @@ namespace ov {
namespace op {
namespace util {
namespace {
// Shared traversal used by the public visit_*_path helpers.
//
// Starting at "node", walks from each node to the producers of its inputs. For every node taken
// from the work list, "skip_node_predicate" is evaluated first: when it returns true the node is
// dropped (no "func" call, its inputs are not explored); otherwise "func" is called and the
// not-yet-visited producers are scheduled. Every scheduled node is recorded in "visited".
// A null "node" is a no-op.
void visit_path_impl(ov::Node* node,
                     std::unordered_set<ov::Node*>& visited,
                     std::function<void(ov::Node*)> func,
                     std::function<bool(ov::Node*)> skip_node_predicate) {
    if (node == nullptr) {
        return;
    }

    visited.insert(node);
    std::deque<ov::Node*> pending{node};
    while (!pending.empty()) {
        ov::Node* const current = pending.front();
        pending.pop_front();

        if (!skip_node_predicate(current)) {
            func(current);
            for (auto& producer_output : current->input_values()) {
                ov::Node* const producer = producer_output.get_node();
                // insert() reports whether the producer was new to the set; only new ones are queued.
                if (visited.insert(producer).second) {
                    // Pushed to the front, i.e. the most recently found producer is handled next.
                    pending.push_front(producer);
                }
            }
        }
    }
}
} // namespace
bool get_single_value(const std::shared_ptr<op::v0::Constant>& const_node, float& value, bool check_value_range) {
switch (const_node->get_element_type()) {
case element::Type_t::f16:
@ -242,28 +270,18 @@ bool shapes_equal_except_dynamic_expected_batch(const ov::PartialShape& expected
}
// Calls "func" for every node on the path that produces "node", stopping at ShapeOf nodes.
//
// The traversal itself lives in visit_path_impl; this function only supplies the skip rule.
// Keeping a second, hand-rolled work-list loop here next to the delegation would invoke "func"
// on the start node twice, so the delegation is the only traversal performed.
//
// node    - start of the traversal; null is accepted and results in no work.
// visited - nodes already seen; updated with every node reached during the walk.
// func    - callback invoked for each visited node that is not a ShapeOf.
void visit_shape_path(Node* node, std::unordered_set<ov::Node*>& visited, std::function<void(ov::Node*)> func) {
    // ShapeOf (opset1 and opset3 flavours) ends the path: the node is neither reported to "func"
    // nor are its inputs explored.
    auto is_shapeof = [](ov::Node* node) {
        return ov::is_type<opset1::ShapeOf>(node) || ov::is_type<opset3::ShapeOf>(node);
    };
    visit_path_impl(node, visited, func, is_shapeof);
}
// Calls "func" for every node on the path that produces "node".
//
// The path is expected to be constant: reaching an opset1::Parameter trips the assertion below.
// No node is ever skipped, so the predicate exists purely to perform that check before "func"
// is invoked on the node.
void visit_constant_path(ov::Node* node,
                         std::unordered_set<ov::Node*>& visited,
                         std::function<void(ov::Node*)> func) {
    const auto reject_parameters = [](ov::Node* node) -> bool {
        OPENVINO_ASSERT(!ov::is_type<opset1::Parameter>(node), "visit_constant_path is called for non-constant path.");
        // Returning false means "do not skip": the node is visited and its inputs are explored.
        return false;
    };
    visit_path_impl(node, visited, func, reject_parameters);
}
bool is_dequantization_subgraph(const Output<Node>& node) {

View File

@ -476,7 +476,7 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
SetSnippetsNodeType(node->shared_from_this(), snippets::pass::SnippetsNodeType::SkippedByPlugin);
};
std::unordered_set<Node*> visited;
ov::op::util::visit_shape_path(node->get_input_node_ptr(1), visited, markup_func);
ov::op::util::visit_constant_path(node->get_input_node_ptr(1), visited, markup_func);
}
if (isSuitableConvolutionParent(node)) {
// Initiate fusing chain