[Snippets] Check for cyclic dependencies during ternary merge. (#10374)
This commit is contained in:
parent
a3887f3328
commit
6500ec775d
@ -133,13 +133,7 @@ Shape snippets::op::Subgraph::canonicalize(const BlockedShapeVector& outputShape
|
||||
|
||||
NODE_VALIDATION_CHECK(this, outputShapes.size() == m_body->get_results().size(),
|
||||
"number of results for snippet doesn't match passed to generate method: ", outputShapes.size(), " vs ", m_body->get_results().size(), ".");
|
||||
// todo: does it allowed to have outputs with different layouts? I assume no, remove if invalid
|
||||
const AxisVector outOrder = get<1>(outputShapes[0]);
|
||||
for (size_t i = 1; i < outputShapes.size(); i++) {
|
||||
const AxisVector order_i = get<1>(outputShapes[i]);
|
||||
NODE_VALIDATION_CHECK(this, outOrder.size() == order_i.size() && equal(outOrder.begin(), outOrder.end(), order_i.begin()),
|
||||
"Snippets output shapes must have the same layout");
|
||||
}
|
||||
|
||||
auto getMaxRankBlockedShape = [](const BlockedShapeVector& blockedShapes) -> const BlockedShape& {
|
||||
return *std::max_element(blockedShapes.begin(), blockedShapes.end(),
|
||||
[&](const BlockedShape& lhs, const BlockedShape& rhs) {
|
||||
@ -187,18 +181,33 @@ Shape snippets::op::Subgraph::canonicalize(const BlockedShapeVector& outputShape
|
||||
}
|
||||
|
||||
m_body->validate_nodes_and_infer_types();
|
||||
auto skipStartEndOnes = [](const Shape& shape) {
|
||||
auto begin = shape.begin();
|
||||
auto end = shape.end();
|
||||
while (*begin == 1 && begin != end)
|
||||
begin++;
|
||||
while (begin != end && *(end-1) == 1)
|
||||
end--;
|
||||
Shape trimmedShape(end - begin, 1);
|
||||
std::copy(begin, end, trimmedShape.begin());
|
||||
return trimmedShape;
|
||||
};
|
||||
|
||||
// Check that output shapes are broadcastable => can be scheduled
|
||||
const auto& body_results = m_body->get_results();
|
||||
PartialShape outPShape = body_results[0]->get_shape();
|
||||
for (size_t i = 0; i < body_results.size(); i++) {
|
||||
auto shape_i = body_results[i]->get_shape();
|
||||
PartialShape pShape_i(shape_i);
|
||||
auto outputShape_i = std::get<0>(outputShapes[i]);
|
||||
// Check that the produced output shape corresponds to the passed shape
|
||||
bool compatibleWithPassedShape = PartialShape::broadcast_merge_into(pShape_i, std::get<0>(outputShapes[i]),
|
||||
// Some produced shapes may have been changed to be broadcastable (e.g. blocked + planar outputs),
|
||||
// so we need to remove leading and trailing "1" before the comparison
|
||||
PartialShape pShape_i(skipStartEndOnes(shape_i));
|
||||
bool compatibleWithPassedShape = PartialShape::broadcast_merge_into(pShape_i, skipStartEndOnes(outputShape_i),
|
||||
::ngraph::op::AutoBroadcastType::NUMPY);
|
||||
NODE_VALIDATION_CHECK(this, compatibleWithPassedShape, "Inferred and passed results shapes are difference for snippet : ",
|
||||
shape_i, " vs ", std::get<0>(outputShapes[i]), ".");
|
||||
NODE_VALIDATION_CHECK(this, ov::shape_size(shape_i) == ov::shape_size(outputShape_i) &&
|
||||
compatibleWithPassedShape, "Inferred and passed results shapes are incompatible for snippet ",
|
||||
get_friendly_name(), " : ", shape_i, " vs ", outputShape_i, ".");
|
||||
// Check that output shapes are broadcastable to each other => can be scheduled
|
||||
bool compatibleWithOtherOutputs = PartialShape::broadcast_merge_into(outPShape, shape_i,
|
||||
::ngraph::op::AutoBroadcastType::NUMPY);
|
||||
|
@ -372,9 +372,12 @@ TokenizeSnippets::TokenizeSnippets() {
|
||||
|
||||
auto internal = input_body_parameters[i];
|
||||
auto internal_consumers = internal->outputs();
|
||||
|
||||
if (auto to_replace_with = ov::as_type_ptr<op::Subgraph>(subgraph->get_input_node_shared_ptr(i))) {
|
||||
for (auto output : internal_consumers) {
|
||||
// todo: In principle, we can still attach the node to the subgraph if cyclic dependency is introduced during ternary merge.
|
||||
// Need to support.
|
||||
if (cyclicDependencyIsIntoduced(to_replace_with, currentTopoBounds))
|
||||
return abort_with_strategy("Attempt to perform recurrent merge for cyclic-dependent subgraphs. Aborting.");
|
||||
for (const auto& output : internal_consumers) {
|
||||
for (auto consumer : output.get_target_inputs()) {
|
||||
auto other_body = clones[subgraph->get_input_node_shared_ptr(i)];
|
||||
auto other_body_result = other_body->get_results()[consumer.get_source_output().get_index()];
|
||||
|
Loading…
Reference in New Issue
Block a user