[CPU] Fixed checks in snippets_mark_skipped for MatMul weights (#19141)

This commit is contained in:
Anton Voronov 2023-08-14 16:04:35 +04:00 committed by GitHub
parent 46f428eeac
commit 87f9b2bdf8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -254,7 +254,7 @@ bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, con
ov::PartialShape matmul_shape;
for (const auto &parent_out : node->input_values()) {
const auto parent = parent_out.get_node_shared_ptr();
if (ov::is_type<ov::op::v0::Constant>(parent) || ov::is_type<ov::op::v0::Convert>(parent)) {
if (ov::op::util::is_on_constant_path(parent)) {
bias_shape = parent_out.get_shape();
num_non_const_inputs++;
} else {
@ -265,8 +265,7 @@ bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, con
// first check that weights are constant and both activations and weights have static shape
if (grandparents.size() == 2 &&
grandparents[1].get_partial_shape().is_static() &&
(ov::is_type<ov::op::v0::Constant>(grandparents[1].get_node_shared_ptr())
|| ov::is_type<ov::op::v0::Convert>(grandparents[1].get_node_shared_ptr()))) {
(ov::op::util::is_on_constant_path(grandparents[1].get_node_shared_ptr()))) {
auto rank_a = grandparents[0].get_partial_shape().rank().get_length();
auto rank_w = grandparents[1].get_partial_shape().rank().get_length();
if (rank_a != 1 && rank_w != 1 && rank_a <= 3 && rank_w <= 3)