[Snippets] Check for cyclic dependencies during ternary merge. (#10374)

This commit is contained in:
Ivan Novoselov 2022-02-22 13:30:15 +03:00 committed by GitHub
parent a3887f3328
commit 6500ec775d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 13 deletions

View File

@ -133,13 +133,7 @@ Shape snippets::op::Subgraph::canonicalize(const BlockedShapeVector& outputShape
NODE_VALIDATION_CHECK(this, outputShapes.size() == m_body->get_results().size(),
"number of results for snippet doesn't match passed to generate method: ", outputShapes.size(), " vs ", m_body->get_results().size(), ".");
// todo: does it allowed to have outputs with different layouts? I assume no, remove if invalid
const AxisVector outOrder = get<1>(outputShapes[0]);
for (size_t i = 1; i < outputShapes.size(); i++) {
const AxisVector order_i = get<1>(outputShapes[i]);
NODE_VALIDATION_CHECK(this, outOrder.size() == order_i.size() && equal(outOrder.begin(), outOrder.end(), order_i.begin()),
"Snippets output shapes must have the same layout");
}
auto getMaxRankBlockedShape = [](const BlockedShapeVector& blockedShapes) -> const BlockedShape& {
return *std::max_element(blockedShapes.begin(), blockedShapes.end(),
[&](const BlockedShape& lhs, const BlockedShape& rhs) {
@ -187,18 +181,33 @@ Shape snippets::op::Subgraph::canonicalize(const BlockedShapeVector& outputShape
}
m_body->validate_nodes_and_infer_types();
auto skipStartEndOnes = [](const Shape& shape) {
auto begin = shape.begin();
auto end = shape.end();
while (*begin == 1 && begin != end)
begin++;
while (begin != end && *(end-1) == 1)
end--;
Shape trimmedShape(end - begin, 1);
std::copy(begin, end, trimmedShape.begin());
return trimmedShape;
};
// Check that output shapes are broadcastable => can be scheduled
const auto& body_results = m_body->get_results();
PartialShape outPShape = body_results[0]->get_shape();
for (size_t i = 0; i < body_results.size(); i++) {
auto shape_i = body_results[i]->get_shape();
PartialShape pShape_i(shape_i);
auto outputShape_i = std::get<0>(outputShapes[i]);
// Check that the produced output shape corresponds to the passed shape
bool compatibleWithPassedShape = PartialShape::broadcast_merge_into(pShape_i, std::get<0>(outputShapes[i]),
// Some produced shapes may have been changed to be broadcastable (e.g. blocked + planar outputs),
// so we need to remove leading and trailing "1" before the comparison
PartialShape pShape_i(skipStartEndOnes(shape_i));
bool compatibleWithPassedShape = PartialShape::broadcast_merge_into(pShape_i, skipStartEndOnes(outputShape_i),
::ngraph::op::AutoBroadcastType::NUMPY);
NODE_VALIDATION_CHECK(this, compatibleWithPassedShape, "Inferred and passed results shapes are difference for snippet : ",
shape_i, " vs ", std::get<0>(outputShapes[i]), ".");
NODE_VALIDATION_CHECK(this, ov::shape_size(shape_i) == ov::shape_size(outputShape_i) &&
compatibleWithPassedShape, "Inferred and passed results shapes are incompatible for snippet ",
get_friendly_name(), " : ", shape_i, " vs ", outputShape_i, ".");
// Check that output shapes are broadcastable to each other => can be scheduled
bool compatibleWithOtherOutputs = PartialShape::broadcast_merge_into(outPShape, shape_i,
::ngraph::op::AutoBroadcastType::NUMPY);

View File

@ -372,9 +372,12 @@ TokenizeSnippets::TokenizeSnippets() {
auto internal = input_body_parameters[i];
auto internal_consumers = internal->outputs();
if (auto to_replace_with = ov::as_type_ptr<op::Subgraph>(subgraph->get_input_node_shared_ptr(i))) {
for (auto output : internal_consumers) {
// todo: In principle, we can still attach the node to the subgraph if cyclic dependency is introduced during ternary merge.
// Need to support.
if (cyclicDependencyIsIntoduced(to_replace_with, currentTopoBounds))
return abort_with_strategy("Attempt to perform recurrent merge for cyclic-dependent subgraphs. Aborting.");
for (const auto& output : internal_consumers) {
for (auto consumer : output.get_target_inputs()) {
auto other_body = clones[subgraph->get_input_node_shared_ptr(i)];
auto other_body_result = other_body->get_results()[consumer.get_source_output().get_index()];