[Transformations] Leftovers: FuseU4WeightsAndZeroPoint transformation (#20709)
* util::visit_constant_path * ConvertU4WeightsZeroPointToScalar: avoid unnecessary insert * Review comments applied * codestyle fix
This commit is contained in:
parent
c50bd2b55f
commit
598da6e5c0
@ -192,10 +192,30 @@ TRANSFORMATIONS_API std::shared_ptr<Node> clone_try_fold(const std::shared_ptr<N
|
|||||||
TRANSFORMATIONS_API bool shapes_equal_except_dynamic_expected_batch(const PartialShape& expected,
|
TRANSFORMATIONS_API bool shapes_equal_except_dynamic_expected_batch(const PartialShape& expected,
|
||||||
const PartialShape& actual);
|
const PartialShape& actual);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Traverses a shapeOf subgraph starting from the node and not including the ShapeOf nodes,
|
||||||
|
* and calls "func" for each ov::Node.
|
||||||
|
*
|
||||||
|
* \param node The node from which constant path is started.
|
||||||
|
* \param visited Set of nodes which were visited.
|
||||||
|
* \param func The function which is called for each visited node.
|
||||||
|
*/
|
||||||
TRANSFORMATIONS_API void visit_shape_path(ov::Node* node,
|
TRANSFORMATIONS_API void visit_shape_path(ov::Node* node,
|
||||||
std::unordered_set<ov::Node*>& visited,
|
std::unordered_set<ov::Node*>& visited,
|
||||||
std::function<void(ov::Node*)> func);
|
std::function<void(ov::Node*)> func);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Traverses a constant path starting from "node", and calls "func" for each ov::Node.
|
||||||
|
* If the function was called for non-constant subgraph, exception is thrown.
|
||||||
|
*
|
||||||
|
* \param node The node from which constant path is started.
|
||||||
|
* \param visited Set of nodes which were visited.
|
||||||
|
* \param func The function which is called for each visited node.
|
||||||
|
*/
|
||||||
|
TRANSFORMATIONS_API void visit_constant_path(ov::Node* node,
|
||||||
|
std::unordered_set<ov::Node*>& visited,
|
||||||
|
std::function<void(ov::Node*)> func);
|
||||||
|
|
||||||
template <typename T, typename... Args>
|
template <typename T, typename... Args>
|
||||||
std::shared_ptr<Node> make_try_fold(Args&&... args) {
|
std::shared_ptr<Node> make_try_fold(Args&&... args) {
|
||||||
auto unary_output_node = std::make_shared<T>(std::forward<Args>(args)...);
|
auto unary_output_node = std::make_shared<T>(std::forward<Args>(args)...);
|
||||||
|
@ -51,21 +51,15 @@ ov::pass::ConvertU4WeightsZeroPointToScalar::ConvertU4WeightsZeroPointToScalar()
|
|||||||
if (ov::shape_size(weights->get_shape()) < ov::shape_size(zero_point->get_shape()))
|
if (ov::shape_size(weights->get_shape()) < ov::shape_size(zero_point->get_shape()))
|
||||||
std::swap(zero_point, weights);
|
std::swap(zero_point, weights);
|
||||||
|
|
||||||
auto zero_point_shape = zero_point->get_shape();
|
const auto& zp_shape = zero_point->get_shape();
|
||||||
if (ov::shape_size(zero_point_shape) == 1)
|
if (ov::shape_size(zp_shape) == 1)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
const auto& weights_shape = weights->get_shape();
|
const auto& weights_shape = weights->get_shape();
|
||||||
const size_t weights_rank = weights_shape.size();
|
|
||||||
const size_t zero_point_rank = zero_point_shape.size();
|
|
||||||
// Zero point constant can be converted into scalar only if this does not affect Subtract output shape
|
// Zero point constant can be converted into scalar only if this does not affect Subtract output shape
|
||||||
if (weights_rank < zero_point_rank)
|
if (weights_shape.size() < zp_shape.size() ||
|
||||||
|
!std::equal(zp_shape.rbegin(), zp_shape.rend(), weights_shape.rbegin(), std::less_equal<size_t>())) {
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
zero_point_shape.insert(zero_point_shape.begin(), weights_rank - zero_point_rank, 1);
|
|
||||||
for (size_t i = 0; i < weights_rank; ++i) {
|
|
||||||
if (zero_point_shape[i] > weights_shape[i])
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
float zp_value;
|
float zp_value;
|
||||||
|
@ -72,7 +72,7 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
std::unordered_set<Node*> visited;
|
std::unordered_set<Node*> visited;
|
||||||
ov::op::util::visit_shape_path(input.get_node(), visited, keep_const_precision);
|
ov::op::util::visit_constant_path(input.get_node(), visited, keep_const_precision);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (subtract_it != pattern_map.end()) {
|
if (subtract_it != pattern_map.end()) {
|
||||||
|
@ -21,6 +21,34 @@ namespace ov {
|
|||||||
namespace op {
|
namespace op {
|
||||||
namespace util {
|
namespace util {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
void visit_path_impl(ov::Node* node,
|
||||||
|
std::unordered_set<ov::Node*>& visited,
|
||||||
|
std::function<void(ov::Node*)> func,
|
||||||
|
std::function<bool(ov::Node*)> skip_node_predicate) {
|
||||||
|
if (!node)
|
||||||
|
return;
|
||||||
|
visited.insert(node);
|
||||||
|
std::deque<ov::Node*> nodes{node};
|
||||||
|
while (!nodes.empty()) {
|
||||||
|
auto curr_node = nodes.front();
|
||||||
|
nodes.pop_front();
|
||||||
|
if (skip_node_predicate(curr_node))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
func(curr_node);
|
||||||
|
for (auto& input_value : curr_node->input_values()) {
|
||||||
|
// continue searching
|
||||||
|
const auto& input_node = input_value.get_node();
|
||||||
|
if (visited.count(input_node))
|
||||||
|
continue;
|
||||||
|
nodes.push_front(input_node);
|
||||||
|
visited.insert(input_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
bool get_single_value(const std::shared_ptr<op::v0::Constant>& const_node, float& value, bool check_value_range) {
|
bool get_single_value(const std::shared_ptr<op::v0::Constant>& const_node, float& value, bool check_value_range) {
|
||||||
switch (const_node->get_element_type()) {
|
switch (const_node->get_element_type()) {
|
||||||
case element::Type_t::f16:
|
case element::Type_t::f16:
|
||||||
@ -242,28 +270,18 @@ bool shapes_equal_except_dynamic_expected_batch(const ov::PartialShape& expected
|
|||||||
}
|
}
|
||||||
|
|
||||||
void visit_shape_path(Node* node, std::unordered_set<ov::Node*>& visited, std::function<void(ov::Node*)> func) {
|
void visit_shape_path(Node* node, std::unordered_set<ov::Node*>& visited, std::function<void(ov::Node*)> func) {
|
||||||
if (!node)
|
auto is_shapeof = [](ov::Node* node) {
|
||||||
return;
|
return ov::is_type<opset1::ShapeOf>(node) || ov::is_type<opset3::ShapeOf>(node);
|
||||||
visited.insert(node);
|
};
|
||||||
std::deque<ov::Node*> nodes{node};
|
visit_path_impl(node, visited, func, is_shapeof);
|
||||||
while (!nodes.empty()) {
|
}
|
||||||
auto curr_node = nodes.front();
|
|
||||||
nodes.pop_front();
|
|
||||||
// Do not check if already visited
|
|
||||||
if (ov::is_type<opset1::ShapeOf>(curr_node) || ov::is_type<opset3::ShapeOf>(curr_node)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
func(curr_node);
|
void visit_constant_path(ov::Node* node, std::unordered_set<ov::Node*>& visited, std::function<void(ov::Node*)> func) {
|
||||||
for (auto& input_value : curr_node->input_values()) {
|
auto check_parameter = [](ov::Node* node) {
|
||||||
// continue searching
|
OPENVINO_ASSERT(!ov::is_type<opset1::Parameter>(node), "visit_constant_path is called for non-constant path.");
|
||||||
const auto& input_node = input_value.get_node();
|
return false;
|
||||||
if (visited.count(input_node))
|
};
|
||||||
continue;
|
visit_path_impl(node, visited, func, check_parameter);
|
||||||
nodes.push_front(input_node);
|
|
||||||
visited.insert(input_node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_dequantization_subgraph(const Output<Node>& node) {
|
bool is_dequantization_subgraph(const Output<Node>& node) {
|
||||||
|
@ -476,7 +476,7 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
|
|||||||
SetSnippetsNodeType(node->shared_from_this(), snippets::pass::SnippetsNodeType::SkippedByPlugin);
|
SetSnippetsNodeType(node->shared_from_this(), snippets::pass::SnippetsNodeType::SkippedByPlugin);
|
||||||
};
|
};
|
||||||
std::unordered_set<Node*> visited;
|
std::unordered_set<Node*> visited;
|
||||||
ov::op::util::visit_shape_path(node->get_input_node_ptr(1), visited, markup_func);
|
ov::op::util::visit_constant_path(node->get_input_node_ptr(1), visited, markup_func);
|
||||||
}
|
}
|
||||||
if (isSuitableConvolutionParent(node)) {
|
if (isSuitableConvolutionParent(node)) {
|
||||||
// Initiate fusing chain
|
// Initiate fusing chain
|
||||||
|
Loading…
Reference in New Issue
Block a user