diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index 7dff3913085..e73dc6b4bce 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -133,13 +133,7 @@ Shape snippets::op::Subgraph::canonicalize(const BlockedShapeVector& outputShape NODE_VALIDATION_CHECK(this, outputShapes.size() == m_body->get_results().size(), "number of results for snippet doesn't match passed to generate method: ", outputShapes.size(), " vs ", m_body->get_results().size(), "."); - // todo: does it allowed to have outputs with different layouts? I assume no, remove if invalid - const AxisVector outOrder = get<1>(outputShapes[0]); - for (size_t i = 1; i < outputShapes.size(); i++) { - const AxisVector order_i = get<1>(outputShapes[i]); - NODE_VALIDATION_CHECK(this, outOrder.size() == order_i.size() && equal(outOrder.begin(), outOrder.end(), order_i.begin()), - "Snippets output shapes must have the same layout"); - } + auto getMaxRankBlockedShape = [](const BlockedShapeVector& blockedShapes) -> const BlockedShape& { return *std::max_element(blockedShapes.begin(), blockedShapes.end(), [&](const BlockedShape& lhs, const BlockedShape& rhs) { @@ -187,18 +181,34 @@ Shape snippets::op::Subgraph::canonicalize(const BlockedShapeVector& outputShape } m_body->validate_nodes_and_infer_types(); + // Returns a copy of the shape without leading/trailing "1" dims; bounds are tested before dereferencing so empty and all-ones shapes are safe. + auto skipStartEndOnes = [](const Shape& shape) { + auto begin = shape.begin(); + auto end = shape.end(); + while (begin != end && *begin == 1) + begin++; + while (begin != end && *(end-1) == 1) + end--; + Shape trimmedShape(end - begin, 1); + std::copy(begin, end, trimmedShape.begin()); + return trimmedShape; + }; // Check that output shapes are broadcastable => can be scheduled const auto& body_results = m_body->get_results(); PartialShape outPShape = body_results[0]->get_shape(); for (size_t i = 0; i < body_results.size(); i++) { auto shape_i = body_results[i]->get_shape(); - PartialShape pShape_i(shape_i); + auto
outputShape_i = std::get<0>(outputShapes[i]); // Check that the produced output shape corresponds to the passed shape - bool compatibleWithPassedShape = PartialShape::broadcast_merge_into(pShape_i, std::get<0>(outputShapes[i]), + // Some produced shapes may have been changed to be broadcastable (e.g. blocked + planar outputs), + // so we need to remove leading and trailing "1" before the comparison + PartialShape pShape_i(skipStartEndOnes(shape_i)); + bool compatibleWithPassedShape = PartialShape::broadcast_merge_into(pShape_i, skipStartEndOnes(outputShape_i), ::ngraph::op::AutoBroadcastType::NUMPY); - NODE_VALIDATION_CHECK(this, compatibleWithPassedShape, "Inferred and passed results shapes are difference for snippet : ", - shape_i, " vs ", std::get<0>(outputShapes[i]), "."); + NODE_VALIDATION_CHECK(this, ov::shape_size(shape_i) == ov::shape_size(outputShape_i) && + compatibleWithPassedShape, "Inferred and passed results shapes are incompatible for snippet ", + get_friendly_name(), " : ", shape_i, " vs ", outputShape_i, "."); // Check that output shapes are broadcastable to each other => can be scheduled bool compatibleWithOtherOutputs = PartialShape::broadcast_merge_into(outPShape, shape_i, ::ngraph::op::AutoBroadcastType::NUMPY); diff --git a/src/common/snippets/src/pass/collapse_subgraph.cpp b/src/common/snippets/src/pass/collapse_subgraph.cpp index 1529b58a81d..0bd462f1ee5 100644 --- a/src/common/snippets/src/pass/collapse_subgraph.cpp +++ b/src/common/snippets/src/pass/collapse_subgraph.cpp @@ -372,9 +372,12 @@ TokenizeSnippets::TokenizeSnippets() { auto internal = input_body_parameters[i]; auto internal_consumers = internal->outputs(); - if (auto to_replace_with = ov::as_type_ptr(subgraph->get_input_node_shared_ptr(i))) { - for (auto output : internal_consumers) { + // todo: In principle, we can still attach the node to the subgraph if cyclic dependency is introduced during ternary merge. + // Need to support. 
+ if (cyclicDependencyIsIntoduced(to_replace_with, currentTopoBounds)) + return abort_with_strategy("Attempt to perform recurrent merge for cyclic-dependent subgraphs. Aborting."); + for (const auto& output : internal_consumers) { for (auto consumer : output.get_target_inputs()) { auto other_body = clones[subgraph->get_input_node_shared_ptr(i)]; auto other_body_result = other_body->get_results()[consumer.get_source_output().get_index()];