// File: openvino/src/common/snippets/src/pass/validate.cpp
// (112 lines, 4.6 KiB, C++)

// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "snippets/pass/validate.hpp"
#include "snippets/op/convert_saturation.hpp"
#include "snippets/op/convert_truncation.hpp"
#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
#include "snippets/pass/fq_decomposition.hpp"
#include "snippets/utils.hpp"
#include "snippets/itt.hpp"
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/softmax.hpp"
#include "openvino/core/validation_util.hpp"
namespace ov {
namespace snippets {
namespace pass {
namespace {
// VALIDATE(op, op_type, validator) expands to one link of an if / else-if chain:
//
//     if (ov::is_type<op_type>(op)) OPENVINO_ASSERT(validator(op), <message>); else
//
// - `op`        : node handle; it is passed to ov::is_type and to `validator`, and is
//                 dereferenced for get_type_name() / get_friendly_name() in the message.
// - `op_type`   : operation type the link applies to.
// - `validator` : callable invoked as validator(op); its result is the asserted condition.
//
// The expansion intentionally ends with a dangling `else`, so consecutive invocations
// chain into if / else if / ... and only the first matching link is evaluated.
// The LAST invocation of a chain must be followed by a statement (a bare `;` is enough)
// to terminate that trailing `else`.
#define VALIDATE(op, op_type, validator) \
if (ov::is_type<op_type>(op)) \
OPENVINO_ASSERT(validator(op), "Snippets validation of OV body has been failed: " + \
std::string(op->get_type_name()) + " op " + op->get_friendly_name() + " is not supported"); \
else
} // namespace
bool Validate::is_supported_constant(const std::shared_ptr<const ov::Node>& op) {
    // A Constant is accepted in two cases:
    //   1. its output holds exactly one element (shape_size == 1), or
    //   2. every consumer of its output 0 is a Transpose (v1) or a Broadcast (v1 / v3).
    const auto constant = ov::as_type_ptr<const ov::op::v0::Constant>(op);
    // The consumer set is queried up front, before the type check, exactly once.
    const auto consumers = op->get_output_target_inputs(0);

    if (!constant)
        return false;
    if (ov::shape_size(constant->get_output_shape(0)) == 1)
        return true;

    // True when the node that owns the given input is one of the accepted consumer kinds.
    const auto is_accepted_consumer = [](const ov::Input<ov::Node>& in) {
        const auto consumer = in.get_node();
        if (ov::is_type<const ov::op::v1::Transpose>(consumer))
            return true;
        if (ov::is_type<const ov::op::v1::Broadcast>(consumer))
            return true;
        return ov::is_type<const ov::op::v3::Broadcast>(consumer);
    };
    return std::all_of(consumers.cbegin(), consumers.cend(), is_accepted_consumer);
}
bool Validate::is_supported_convert(const std::shared_ptr<const ov::Node>& op) {
    // Only the snippets-specific convert flavours pass validation:
    // ConvertTruncation or ConvertSaturation. Anything else is rejected.
    if (ov::is_type<const ov::snippets::op::ConvertTruncation>(op))
        return true;
    return ov::is_type<const ov::snippets::op::ConvertSaturation>(op);
}
bool Validate::is_supported_matmul(const std::shared_ptr<const ov::Node>& op) {
    // Rule: when the ExplicitTransposeMatMulInputs pass is enabled in the pass config,
    // the MatMul must have neither transpose_a nor transpose_b set.
    const auto matmul = ov::as_type_ptr<const ov::op::v0::MatMul>(op);
    if (!matmul)
        return false;

    const bool explicit_transpose_enabled =
        m_pass_config->is_enabled<ov::snippets::pass::ExplicitTransposeMatMulInputs>();
    const bool inputs_not_transposed = !matmul->get_transpose_a() && !matmul->get_transpose_b();
    return utils::implication(explicit_transpose_enabled, inputs_not_transposed);
}
bool Validate::is_supported_softmax(const std::shared_ptr<const ov::Node>& op) {
    // Softmax is supported only when it is applied along the last dimension of its input.
    // Returns true for a v1 / v8 Softmax whose (normalized) axis equals rank - 1,
    // and false for any other op, any other axis, or an input of dynamic rank.
    const auto softmax_rank = op->get_input_partial_shape(0).rank();

    // The last-axis check needs a concrete rank. Bail out early for a dynamic rank so the
    // caller reports an ordinary "not supported" validation failure instead of whatever
    // error get_length() / axis normalization would raise on a dynamic dimension.
    if (softmax_rank.is_dynamic())
        return false;

    int64_t axis = 0;
    if (const auto softmax_v8 = ov::as_type_ptr<const ov::op::v8::Softmax>(op)) {
        // The v8 axis is passed through normalize_axis together with the rank before comparison.
        OPENVINO_SUPPRESS_DEPRECATED_START
        axis = ov::normalize_axis(softmax_v8->get_friendly_name(), softmax_v8->get_axis(), softmax_rank);
        OPENVINO_SUPPRESS_DEPRECATED_END
    } else if (const auto softmax_v1 = ov::as_type_ptr<const ov::op::v1::Softmax>(op)) {
        // The v1 axis is used as-is; the conversion to the signed type is made explicit.
        axis = static_cast<int64_t>(softmax_v1->get_axis());
    } else {
        // Not a Softmax of a known version.
        return false;
    }
    return axis == softmax_rank.get_length() - 1;
}
bool Validate::is_supported_fq(const std::shared_ptr<const ov::Node>& node) {
    // FakeQuantize is normally decomposed by the CommonFakeQuantizeDecomposition pass,
    // so an FQ node is accepted only when that pass is disabled in the pass config.
    // The node itself is not inspected.
    const bool decomposition_disabled =
        m_pass_config->is_disabled<ov::snippets::pass::CommonFakeQuantizeDecomposition>();
    return decomposition_disabled;
}
bool Validate::is_supported_transpose(const std::shared_ptr<const ov::Node>& node) {
    // A Transpose is accepted only at the boundary of the body:
    //   - directly after a Parameter (its input 0 is produced by a Parameter), or
    //   - directly before a Result (its output 0 has exactly one consumer, a Result).
    const auto consumers = node->get_output_target_inputs(0);
    const auto producer = node->get_input_node_shared_ptr(0);

    if (ov::is_type<ov::op::v0::Parameter>(producer))
        return true;
    if (consumers.size() != 1)
        return false;
    return ov::is_type<ov::op::v0::Result>(consumers.cbegin()->get_node());
}
bool Validate::is_supported_op(const std::shared_ptr<const ov::Node>& node) {
    // Validator for op types that must never appear in the body:
    // it rejects unconditionally and does not look at the node.
    constexpr bool supported = false;
    return supported;
}
bool Validate::run_on_model(const std::shared_ptr<ov::Model>& m) {
    // Walks the ops returned by get_ordered_ops() and runs the type-specific validator
    // for each one. A failed check is reported through OPENVINO_ASSERT inside VALIDATE;
    // ops whose type matches none of the listed types are not checked.
    // Always returns true when no check fails.
    RUN_ON_MODEL_SCOPE(Validate);
    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::Validate")

    const auto ordered_ops = m->get_ordered_ops();
    for (const auto& node : ordered_ops) {
        // Each VALIDATE ends with `else`, so the lines below form a single
        // if / else-if chain: only the first matching type is validated.
        VALIDATE(node, ov::op::v0::Constant, is_supported_constant)
        VALIDATE(node, ov::op::v0::Convert, is_supported_convert)
        VALIDATE(node, ov::op::v0::MatMul, is_supported_matmul)
        VALIDATE(node, ov::op::v1::Softmax, is_supported_softmax)
        VALIDATE(node, ov::op::v8::Softmax, is_supported_softmax)
        VALIDATE(node, ov::op::v0::FakeQuantize, is_supported_fq)
        VALIDATE(node, ov::op::v1::Transpose, is_supported_transpose)
        // The trailing `;` is the empty statement that closes the final `else`.
        VALIDATE(node, ov::op::v1::Reshape, is_supported_op);
    }
    return true;
}
} // namespace pass
} // namespace snippets
} // namespace ov