[Transformations] Leftovers: FuseU4WeightsAndZeroPoint transformation (#20709)
* util::visit_constant_path * ConvertU4WeightsZeroPointToScalar: avoid unnecessary insert * Review comments applied * codestyle fix
This commit is contained in:
parent
c50bd2b55f
commit
598da6e5c0
@ -192,10 +192,30 @@ TRANSFORMATIONS_API std::shared_ptr<Node> clone_try_fold(const std::shared_ptr<N
|
||||
TRANSFORMATIONS_API bool shapes_equal_except_dynamic_expected_batch(const PartialShape& expected,
|
||||
const PartialShape& actual);
|
||||
|
||||
/**
|
||||
* \brief Traverses a shapeOf subgraph starting from the node and not including the ShapeOf nodes,
|
||||
* and calls "func" for each ov::Node.
|
||||
*
|
||||
* \param node The node from which constant path is started.
|
||||
* \param visited Set of nodes which were visited.
|
||||
* \param func The function which is called for each visited node.
|
||||
*/
|
||||
TRANSFORMATIONS_API void visit_shape_path(ov::Node* node,
|
||||
std::unordered_set<ov::Node*>& visited,
|
||||
std::function<void(ov::Node*)> func);
|
||||
|
||||
/**
|
||||
* \brief Traverses a constant path starting from "node", and calls "func" for each ov::Node.
|
||||
* If the function was called for non-constant subgraph, exception is thrown.
|
||||
*
|
||||
* \param node The node from which constant path is started.
|
||||
* \param visited Set of nodes which were visited.
|
||||
* \param func The function which is called for each visited node.
|
||||
*/
|
||||
TRANSFORMATIONS_API void visit_constant_path(ov::Node* node,
|
||||
std::unordered_set<ov::Node*>& visited,
|
||||
std::function<void(ov::Node*)> func);
|
||||
|
||||
template <typename T, typename... Args>
|
||||
std::shared_ptr<Node> make_try_fold(Args&&... args) {
|
||||
auto unary_output_node = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
|
@ -51,21 +51,15 @@ ov::pass::ConvertU4WeightsZeroPointToScalar::ConvertU4WeightsZeroPointToScalar()
|
||||
if (ov::shape_size(weights->get_shape()) < ov::shape_size(zero_point->get_shape()))
|
||||
std::swap(zero_point, weights);
|
||||
|
||||
auto zero_point_shape = zero_point->get_shape();
|
||||
if (ov::shape_size(zero_point_shape) == 1)
|
||||
const auto& zp_shape = zero_point->get_shape();
|
||||
if (ov::shape_size(zp_shape) == 1)
|
||||
return false;
|
||||
|
||||
const auto& weights_shape = weights->get_shape();
|
||||
const size_t weights_rank = weights_shape.size();
|
||||
const size_t zero_point_rank = zero_point_shape.size();
|
||||
// Zero point constant can be converted into scalar only if this does not affect Subtract output shape
|
||||
if (weights_rank < zero_point_rank)
|
||||
if (weights_shape.size() < zp_shape.size() ||
|
||||
!std::equal(zp_shape.rbegin(), zp_shape.rend(), weights_shape.rbegin(), std::less_equal<size_t>())) {
|
||||
return false;
|
||||
|
||||
zero_point_shape.insert(zero_point_shape.begin(), weights_rank - zero_point_rank, 1);
|
||||
for (size_t i = 0; i < weights_rank; ++i) {
|
||||
if (zero_point_shape[i] > weights_shape[i])
|
||||
return false;
|
||||
}
|
||||
|
||||
float zp_value;
|
||||
|
@ -72,7 +72,7 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::
|
||||
}
|
||||
};
|
||||
std::unordered_set<Node*> visited;
|
||||
ov::op::util::visit_shape_path(input.get_node(), visited, keep_const_precision);
|
||||
ov::op::util::visit_constant_path(input.get_node(), visited, keep_const_precision);
|
||||
}
|
||||
|
||||
if (subtract_it != pattern_map.end()) {
|
||||
|
@ -21,6 +21,34 @@ namespace ov {
|
||||
namespace op {
|
||||
namespace util {
|
||||
|
||||
namespace {
|
||||
void visit_path_impl(ov::Node* node,
|
||||
std::unordered_set<ov::Node*>& visited,
|
||||
std::function<void(ov::Node*)> func,
|
||||
std::function<bool(ov::Node*)> skip_node_predicate) {
|
||||
if (!node)
|
||||
return;
|
||||
visited.insert(node);
|
||||
std::deque<ov::Node*> nodes{node};
|
||||
while (!nodes.empty()) {
|
||||
auto curr_node = nodes.front();
|
||||
nodes.pop_front();
|
||||
if (skip_node_predicate(curr_node))
|
||||
continue;
|
||||
|
||||
func(curr_node);
|
||||
for (auto& input_value : curr_node->input_values()) {
|
||||
// continue searching
|
||||
const auto& input_node = input_value.get_node();
|
||||
if (visited.count(input_node))
|
||||
continue;
|
||||
nodes.push_front(input_node);
|
||||
visited.insert(input_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool get_single_value(const std::shared_ptr<op::v0::Constant>& const_node, float& value, bool check_value_range) {
|
||||
switch (const_node->get_element_type()) {
|
||||
case element::Type_t::f16:
|
||||
@ -242,28 +270,18 @@ bool shapes_equal_except_dynamic_expected_batch(const ov::PartialShape& expected
|
||||
}
|
||||
|
||||
void visit_shape_path(Node* node, std::unordered_set<ov::Node*>& visited, std::function<void(ov::Node*)> func) {
|
||||
if (!node)
|
||||
return;
|
||||
visited.insert(node);
|
||||
std::deque<ov::Node*> nodes{node};
|
||||
while (!nodes.empty()) {
|
||||
auto curr_node = nodes.front();
|
||||
nodes.pop_front();
|
||||
// Do not check if already visited
|
||||
if (ov::is_type<opset1::ShapeOf>(curr_node) || ov::is_type<opset3::ShapeOf>(curr_node)) {
|
||||
continue;
|
||||
}
|
||||
auto is_shapeof = [](ov::Node* node) {
|
||||
return ov::is_type<opset1::ShapeOf>(node) || ov::is_type<opset3::ShapeOf>(node);
|
||||
};
|
||||
visit_path_impl(node, visited, func, is_shapeof);
|
||||
}
|
||||
|
||||
func(curr_node);
|
||||
for (auto& input_value : curr_node->input_values()) {
|
||||
// continue searching
|
||||
const auto& input_node = input_value.get_node();
|
||||
if (visited.count(input_node))
|
||||
continue;
|
||||
nodes.push_front(input_node);
|
||||
visited.insert(input_node);
|
||||
}
|
||||
}
|
||||
void visit_constant_path(ov::Node* node, std::unordered_set<ov::Node*>& visited, std::function<void(ov::Node*)> func) {
|
||||
auto check_parameter = [](ov::Node* node) {
|
||||
OPENVINO_ASSERT(!ov::is_type<opset1::Parameter>(node), "visit_constant_path is called for non-constant path.");
|
||||
return false;
|
||||
};
|
||||
visit_path_impl(node, visited, func, check_parameter);
|
||||
}
|
||||
|
||||
bool is_dequantization_subgraph(const Output<Node>& node) {
|
||||
|
@ -476,7 +476,7 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
|
||||
SetSnippetsNodeType(node->shared_from_this(), snippets::pass::SnippetsNodeType::SkippedByPlugin);
|
||||
};
|
||||
std::unordered_set<Node*> visited;
|
||||
ov::op::util::visit_shape_path(node->get_input_node_ptr(1), visited, markup_func);
|
||||
ov::op::util::visit_constant_path(node->get_input_node_ptr(1), visited, markup_func);
|
||||
}
|
||||
if (isSuitableConvolutionParent(node)) {
|
||||
// Initiate fusing chain
|
||||
|
Loading…
Reference in New Issue
Block a user