[Snippets] Created common static method is_supported_fq (#19775)
This commit is contained in:
committed by
GitHub
parent
3ce48fc3d6
commit
69c237f340
@@ -6,8 +6,6 @@
|
||||
|
||||
#include "openvino/op/fake_quantize.hpp"
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "snippets/pass/transform_convert.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace snippets {
|
||||
@@ -82,6 +80,8 @@ public:
|
||||
class CommonFakeQuantizeDecomposition: public ov::pass::ModelPass {
|
||||
public:
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
|
||||
|
||||
static bool is_supported_fq(const std::shared_ptr<const ov::op::v0::FakeQuantize>& fq);
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
|
||||
41
src/common/snippets/include/snippets/pass/validate.hpp
Normal file
41
src/common/snippets/include/snippets/pass/validate.hpp
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace snippets {
|
||||
namespace pass {
|
||||
|
||||
/**
|
||||
* @interface Validate
|
||||
* @brief The pass validates OV model on correctness after all common optimizations
|
||||
* @ingroup snippets
|
||||
*/
|
||||
class Validate: public ov::pass::ModelPass {
|
||||
public:
|
||||
OPENVINO_RTTI("Validate", "0");
|
||||
Validate(const std::shared_ptr<ov::pass::PassConfig>& pass_config) : m_pass_config(pass_config) {}
|
||||
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
|
||||
|
||||
private:
|
||||
bool is_supported_constant(const std::shared_ptr<const ov::Node>& op);
|
||||
bool is_supported_convert(const std::shared_ptr<const ov::Node>& op);
|
||||
bool is_supported_matmul(const std::shared_ptr<const ov::Node>& op);
|
||||
bool is_supported_softmax(const std::shared_ptr<const ov::Node>& op);
|
||||
bool is_supported_fq(const std::shared_ptr<const ov::Node>& node);
|
||||
bool is_supported_transpose(const std::shared_ptr<const ov::Node>& node);
|
||||
bool is_supported_op(const std::shared_ptr<const ov::Node>& node);
|
||||
|
||||
// Pass config of CommonOptimizations that contains information: which of common passes are disabled
|
||||
std::shared_ptr<ov::pass::PassConfig> m_pass_config;
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace snippets
|
||||
} // namespace ov
|
||||
@@ -51,6 +51,10 @@ constexpr bool everyone_is(T val, P item, Args... item_others) {
|
||||
return val == item && everyone_is(val, item_others...);
|
||||
}
|
||||
|
||||
constexpr inline bool implication(bool cause, bool cond) {
|
||||
return !cause || !!cond;
|
||||
}
|
||||
|
||||
VectorDims get_planar_vdims(const VectorDims& shape, const std::vector<size_t>& layout);
|
||||
VectorDims get_planar_vdims(const snippets::lowered::PortDescriptorPtr& port_desc);
|
||||
VectorDims get_planar_vdims(const snippets::lowered::ExpressionPort& expr_port);
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "snippets/pass/tokenization.hpp"
|
||||
#include "snippets/pass/transpose_decomposition.hpp"
|
||||
#include "snippets/pass/fuse_transpose_brgemm.hpp"
|
||||
#include "snippets/pass/fq_decomposition.hpp"
|
||||
#include "snippets/op/subgraph.hpp"
|
||||
#include "snippets/utils.hpp"
|
||||
|
||||
@@ -86,15 +87,7 @@ auto is_supported_op(const std::shared_ptr<const Node> &n) -> bool {
|
||||
};
|
||||
|
||||
auto is_supported_fq_op = [](const std::shared_ptr<const Node>& n) -> bool {
|
||||
// TODO [92179]: Add support of FakeQuantize with non-constants inputs and with binarization algorithm.
|
||||
const auto fq = ov::as_type_ptr<const opset1::FakeQuantize>(n);
|
||||
return fq && fq->get_levels() != 2 &&
|
||||
is_type<ov::op::v0::Constant>(n->get_input_node_shared_ptr(1)) &&
|
||||
is_type<ov::op::v0::Constant>(n->get_input_node_shared_ptr(2)) &&
|
||||
is_type<ov::op::v0::Constant>(n->get_input_node_shared_ptr(3)) &&
|
||||
is_type<ov::op::v0::Constant>(n->get_input_node_shared_ptr(4)) &&
|
||||
(fq->get_auto_broadcast() == ov::op::AutoBroadcastType::NUMPY ||
|
||||
fq->get_auto_broadcast() == ov::op::AutoBroadcastType::NONE);
|
||||
return CommonFakeQuantizeDecomposition::is_supported_fq(ov::as_type_ptr<const opset1::FakeQuantize>(n));
|
||||
};
|
||||
|
||||
auto is_supported_ternary_eltwise_op = [](const std::shared_ptr<const Node> &n) -> bool {
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
|
||||
#include "snippets/pass/transpose_decomposition.hpp"
|
||||
#include "snippets/pass/fuse_transpose_brgemm.hpp"
|
||||
#include "snippets/pass/transform_convert.hpp"
|
||||
#include "snippets/pass/validate.hpp"
|
||||
#include "snippets/op/subgraph.hpp"
|
||||
#include "snippets/itt.hpp"
|
||||
|
||||
@@ -372,7 +374,7 @@ CommonOptimizations::CommonOptimizations(const SnippetsTokenization::Config& con
|
||||
|
||||
// Firstly, we should transform all original Converts inside body to ConvertTruncation to save original behavior.
|
||||
// Then if Subgraph contains FakeQuantize we enable specific transformation for quantized subgraphs.
|
||||
ov::pass::Manager manager;
|
||||
ov::pass::Manager manager(get_pass_config());
|
||||
manager.register_pass<ov::snippets::pass::TransformConvertToConvertTruncation>();
|
||||
manager.register_pass<ov::snippets::pass::ExplicitTransposeMatMulInputs>();
|
||||
if (is_quantized) {
|
||||
@@ -392,6 +394,10 @@ CommonOptimizations::CommonOptimizations(const SnippetsTokenization::Config& con
|
||||
if (config.split_m_dimension)
|
||||
SplitDimensionM(subgraph, config.concurrency);
|
||||
}
|
||||
|
||||
// Validate the body after all common optimizations
|
||||
ov::snippets::pass::Validate(get_pass_config()).run_on_model(body);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
|
||||
@@ -12,35 +12,19 @@
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/pass/validate.hpp"
|
||||
|
||||
#include "openvino/reference/autobroadcast_binop.hpp"
|
||||
#include "openvino/reference/broadcast.hpp"
|
||||
|
||||
#include "snippets/itt.hpp"
|
||||
#include "snippets/utils.hpp"
|
||||
#include "snippets/op/convert_saturation.hpp"
|
||||
|
||||
namespace {
|
||||
bool isValidRangesInputs(const std::shared_ptr<ov::opset1::FakeQuantize>& fq) {
|
||||
auto il = fq->input_value(1);
|
||||
auto ih = fq->input_value(2);
|
||||
auto greater_equal = std::make_shared<ov::opset1::Greater>(il, ih);
|
||||
|
||||
ov::OutputVector result(1);
|
||||
if (!greater_equal->constant_fold(result, greater_equal->input_values()))
|
||||
return false;
|
||||
|
||||
auto res_node = std::dynamic_pointer_cast<const ov::opset1::Constant>(result[0].get_node_shared_ptr());
|
||||
|
||||
const std::vector<bool> comp_result = res_node->cast_vector<bool>();
|
||||
|
||||
return !std::any_of(comp_result.begin(), comp_result.end(), [](const bool value) {
|
||||
return value;
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ov::snippets::pass::FakeQuantizeDecomposition::FakeQuantizeDecomposition() {
|
||||
MATCHER_SCOPE(FakeQuantizeDecomposition);
|
||||
|
||||
auto fake_quantize = ov::pass::pattern::wrap_type<ov::opset1::FakeQuantize>(
|
||||
auto fake_quantize = ov::pass::pattern::wrap_type<ov::op::v0::FakeQuantize>(
|
||||
OutputVector{ov::pass::pattern::any_input(),
|
||||
ov::pass::pattern::wrap_type<ov::op::v0::Constant>(),
|
||||
ov::pass::pattern::wrap_type<ov::op::v0::Constant>(),
|
||||
@@ -50,14 +34,16 @@ ov::snippets::pass::FakeQuantizeDecomposition::FakeQuantizeDecomposition() {
|
||||
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
|
||||
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::FakeQuantizeDecomposition")
|
||||
auto& pattern_to_output = m.get_pattern_value_map();
|
||||
const auto fake_quantize_node = std::dynamic_pointer_cast<ov::opset1::FakeQuantize>(
|
||||
const auto fake_quantize_node = std::dynamic_pointer_cast<ov::op::v0::FakeQuantize>(
|
||||
pattern_to_output.at(fake_quantize).get_node_shared_ptr());
|
||||
|
||||
if (!fake_quantize_node || transformation_callback(fake_quantize_node) ||
|
||||
!isValidRangesInputs(fake_quantize_node)) {
|
||||
if (!fake_quantize_node || transformation_callback(fake_quantize_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
OPENVINO_ASSERT(CommonFakeQuantizeDecomposition::is_supported_fq(fake_quantize_node),
|
||||
"FQ Decomposition got invalid FakeQuantize node with the name " + fake_quantize_node->get_friendly_name());
|
||||
|
||||
Output<Node> data{fake_quantize_node->input_value(0)};
|
||||
const Output<Node> input_low{fake_quantize_node->input_value(1)};
|
||||
const Output<Node> input_high{fake_quantize_node->input_value(2)};
|
||||
@@ -94,8 +80,8 @@ ov::snippets::pass::FakeQuantizeDecomposition::FakeQuantizeDecomposition() {
|
||||
|
||||
// if we set input_low or input_high in formula we got output = output_low and output = output_high
|
||||
// respectively so we just clamp x
|
||||
const auto max = std::make_shared<ov::opset1::Maximum>(data, input_low);
|
||||
const auto min = std::make_shared<ov::opset1::Minimum>(max, input_high);
|
||||
const auto max = std::make_shared<ov::op::v1::Maximum>(data, input_low);
|
||||
const auto min = std::make_shared<ov::op::v1::Minimum>(max, input_high);
|
||||
decomp_ops.push_back(max);
|
||||
decomp_ops.push_back(min);
|
||||
|
||||
@@ -106,30 +92,30 @@ ov::snippets::pass::FakeQuantizeDecomposition::FakeQuantizeDecomposition() {
|
||||
input_high.get_partial_shape(),
|
||||
broadcast_type);
|
||||
const auto scales =
|
||||
std::make_shared<ov::opset1::Constant>(ov::element::f32, scale_shape.get_shape(), out_scales);
|
||||
std::make_shared<ov::op::v0::Constant>(ov::element::f32, scale_shape.get_shape(), out_scales);
|
||||
decomp_ops.push_back(scales);
|
||||
|
||||
result = std::make_shared<ov::opset1::Multiply>(min, scales);
|
||||
result = std::make_shared<ov::op::v1::Multiply>(min, scales);
|
||||
decomp_ops.push_back(result);
|
||||
} else {
|
||||
// (levels-1)
|
||||
const auto levels_minus_one =
|
||||
std::make_shared<ov::opset1::Constant>(input_type, Shape{}, fake_quantize_node->get_levels() - 1);
|
||||
std::make_shared<ov::op::v0::Constant>(input_type, Shape{}, fake_quantize_node->get_levels() - 1);
|
||||
decomp_ops.push_back(levels_minus_one);
|
||||
// (input_high - input_low)
|
||||
const auto subInHighLow = std::make_shared<ov::opset1::Subtract>(input_high, input_low);
|
||||
const auto subInHighLow = std::make_shared<ov::op::v1::Subtract>(input_high, input_low);
|
||||
// (levels-1) / (input_high - input_low)
|
||||
const auto isc = std::make_shared<ov::opset1::Divide>(levels_minus_one, subInHighLow);
|
||||
const auto isc = std::make_shared<ov::op::v1::Divide>(levels_minus_one, subInHighLow);
|
||||
// input_low * (levels-1) / (input_high - input_low)
|
||||
const auto ish = std::make_shared<ov::opset1::Multiply>(input_low, isc);
|
||||
const auto ish = std::make_shared<ov::op::v1::Multiply>(input_low, isc);
|
||||
decomp_ops.push_back(subInHighLow);
|
||||
decomp_ops.push_back(isc);
|
||||
decomp_ops.push_back(ish);
|
||||
|
||||
// x * (levels-1) / (input_high - input_low)
|
||||
const auto after_isc_apply = std::make_shared<ov::opset1::Multiply>(min, isc);
|
||||
const auto after_isc_apply = std::make_shared<ov::op::v1::Multiply>(min, isc);
|
||||
// x * (levels-1) / (input_high - input_low) - input_low * (levels-1) / (input_high - input_low)
|
||||
result = std::make_shared<ov::opset1::Subtract>(after_isc_apply, ish);
|
||||
result = std::make_shared<ov::op::v1::Subtract>(after_isc_apply, ish);
|
||||
decomp_ops.push_back(after_isc_apply);
|
||||
decomp_ops.push_back(result);
|
||||
}
|
||||
@@ -143,20 +129,20 @@ ov::snippets::pass::FakeQuantizeDecomposition::FakeQuantizeDecomposition() {
|
||||
if (do_dequantize) {
|
||||
// (levels-1)
|
||||
const auto levels_minus_one =
|
||||
std::make_shared<ov::opset1::Constant>(input_type, Shape{}, fake_quantize_node->get_levels() - 1);
|
||||
std::make_shared<ov::op::v0::Constant>(input_type, Shape{}, fake_quantize_node->get_levels() - 1);
|
||||
// (output_high - output_low)
|
||||
const auto sub_out_high_low = std::make_shared<ov::opset1::Subtract>(output_high, output_low);
|
||||
const auto sub_out_high_low = std::make_shared<ov::op::v1::Subtract>(output_high, output_low);
|
||||
// (output_high - output_low) / (levels-1)
|
||||
const auto osc = std::make_shared<ov::opset1::Divide>(sub_out_high_low, levels_minus_one);
|
||||
const auto osc = std::make_shared<ov::op::v1::Divide>(sub_out_high_low, levels_minus_one);
|
||||
decomp_ops.push_back(sub_out_high_low);
|
||||
decomp_ops.push_back(osc);
|
||||
|
||||
// round(x * (levels-1) / (input_high - input_low) - input_low * (levels-1) / (input_high - input_low)) *
|
||||
// (output_high - output_low) / (levels-1)
|
||||
const auto after_osc_apply = std::make_shared<ov::opset1::Multiply>(result, osc);
|
||||
const auto after_osc_apply = std::make_shared<ov::op::v1::Multiply>(result, osc);
|
||||
// round(x * (levels-1) / (input_high - input_low) - input_low * (levels-1) / (input_high - input_low)) *
|
||||
// (output_high - output_low) / (levels-1) + output_low
|
||||
result = std::make_shared<ov::opset1::Add>(after_osc_apply, output_low);
|
||||
result = std::make_shared<ov::op::v1::Add>(after_osc_apply, output_low);
|
||||
decomp_ops.push_back(after_osc_apply);
|
||||
decomp_ops.push_back(result);
|
||||
}
|
||||
@@ -177,21 +163,17 @@ ov::snippets::pass::FakeQuantizeDecomposition::FakeQuantizeDecomposition() {
|
||||
}
|
||||
|
||||
bool ov::snippets::pass::FakeQuantizeDecomposition::getScalesAndShifts(
|
||||
const std::shared_ptr<const ov::opset1::FakeQuantize>& fq_node,
|
||||
const std::shared_ptr<const ov::op::v0::FakeQuantize>& fq_node,
|
||||
std::vector<float>& cl,
|
||||
std::vector<float>& ch,
|
||||
std::vector<float>& isc,
|
||||
std::vector<float>& ish,
|
||||
std::vector<float>& osc,
|
||||
std::vector<float>& osh) {
|
||||
auto input_low_constant =
|
||||
std::dynamic_pointer_cast<ov::opset1::Constant>(fq_node->get_input_node_shared_ptr(1));
|
||||
auto input_high_constant =
|
||||
std::dynamic_pointer_cast<ov::opset1::Constant>(fq_node->get_input_node_shared_ptr(2));
|
||||
auto output_low_constant =
|
||||
std::dynamic_pointer_cast<ov::opset1::Constant>(fq_node->get_input_node_shared_ptr(3));
|
||||
auto output_high_constant =
|
||||
std::dynamic_pointer_cast<ov::opset1::Constant>(fq_node->get_input_node_shared_ptr(4));
|
||||
auto input_low_constant = ov::as_type_ptr<ov::op::v0::Constant>(fq_node->get_input_node_shared_ptr(1));
|
||||
auto input_high_constant = ov::as_type_ptr<ov::op::v0::Constant>(fq_node->get_input_node_shared_ptr(2));
|
||||
auto output_low_constant = ov::as_type_ptr<ov::op::v0::Constant>(fq_node->get_input_node_shared_ptr(3));
|
||||
auto output_high_constant = ov::as_type_ptr<ov::op::v0::Constant>(fq_node->get_input_node_shared_ptr(4));
|
||||
if (!input_low_constant || !input_high_constant || !output_low_constant || !output_high_constant)
|
||||
return false;
|
||||
|
||||
@@ -305,12 +287,12 @@ bool ov::snippets::pass::FakeQuantizeDecomposition::getScalesAndShifts(
|
||||
}
|
||||
|
||||
std::vector<float> ov::snippets::pass::FakeQuantizeDecomposition::calculateScales(const ov::element::Type& out_type,
|
||||
const std::vector<float>& cl,
|
||||
const std::vector<float>& ch,
|
||||
const std::vector<float>& isc,
|
||||
const std::vector<float>& ish,
|
||||
const std::vector<float>& osc,
|
||||
const std::vector<float>& osh) {
|
||||
const std::vector<float>& cl,
|
||||
const std::vector<float>& ch,
|
||||
const std::vector<float>& isc,
|
||||
const std::vector<float>& ish,
|
||||
const std::vector<float>& osc,
|
||||
const std::vector<float>& osh) {
|
||||
std::vector<float> out_scales;
|
||||
if (out_type == ov::element::u8 &&
|
||||
std::all_of(cl.cbegin(),
|
||||
@@ -360,6 +342,32 @@ std::vector<float> ov::snippets::pass::FakeQuantizeDecomposition::calculateScale
|
||||
return out_scales;
|
||||
}
|
||||
|
||||
bool ov::snippets::pass::CommonFakeQuantizeDecomposition::is_supported_fq(const std::shared_ptr<const ov::op::v0::FakeQuantize>& fq) {
|
||||
// TODO [92179]: Add support of FakeQuantize with non-constants inputs and with binarization algorithm.
|
||||
auto is_valid_range_values = [](const std::shared_ptr<const ov::op::v0::FakeQuantize>& fq) {
|
||||
const auto il = fq->input_value(1);
|
||||
const auto ih = fq->input_value(2);
|
||||
const auto greater_equal = std::make_shared<ov::op::v1::Greater>(il, ih);
|
||||
|
||||
ov::OutputVector result(1);
|
||||
if (!greater_equal->constant_fold(result, greater_equal->input_values()))
|
||||
return false;
|
||||
|
||||
const auto res_node = std::dynamic_pointer_cast<const ov::op::v0::Constant>(result[0].get_node_shared_ptr());
|
||||
const auto comp_result = res_node->cast_vector<bool>();
|
||||
return !std::any_of(comp_result.begin(), comp_result.end(), [](const bool value) {
|
||||
return value;
|
||||
});
|
||||
};
|
||||
return fq && fq->get_levels() != 2 &&
|
||||
ov::is_type<ov::op::v0::Constant>(fq->get_input_node_shared_ptr(1)) &&
|
||||
ov::is_type<ov::op::v0::Constant>(fq->get_input_node_shared_ptr(2)) &&
|
||||
ov::is_type<ov::op::v0::Constant>(fq->get_input_node_shared_ptr(3)) &&
|
||||
ov::is_type<ov::op::v0::Constant>(fq->get_input_node_shared_ptr(4)) &&
|
||||
utils::one_of(fq->get_auto_broadcast(), ov::op::AutoBroadcastType::NUMPY, ov::op::AutoBroadcastType::NONE) &&
|
||||
is_valid_range_values(fq);
|
||||
}
|
||||
|
||||
bool ov::snippets::pass::CommonFakeQuantizeDecomposition::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||
RUN_ON_FUNCTION_SCOPE(CommonFakeQuantizeDecomposition);
|
||||
ov::pass::Manager manager;
|
||||
|
||||
111
src/common/snippets/src/pass/validate.cpp
Normal file
111
src/common/snippets/src/pass/validate.cpp
Normal file
@@ -0,0 +1,111 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "snippets/pass/validate.hpp"
|
||||
|
||||
#include "snippets/op/convert_saturation.hpp"
|
||||
#include "snippets/op/convert_truncation.hpp"
|
||||
#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
|
||||
#include "snippets/pass/fq_decomposition.hpp"
|
||||
#include "snippets/utils.hpp"
|
||||
#include "snippets/itt.hpp"
|
||||
|
||||
#include "openvino/op/fake_quantize.hpp"
|
||||
#include "openvino/op/broadcast.hpp"
|
||||
#include "openvino/op/matmul.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "openvino/op/softmax.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
|
||||
|
||||
namespace ov {
|
||||
namespace snippets {
|
||||
namespace pass {
|
||||
|
||||
namespace {
|
||||
#define VALIDATE(op, op_type, validator) \
|
||||
if (ov::is_type<op_type>(op)) \
|
||||
OPENVINO_ASSERT(validator(op), "Snippets validation of OV body has been failed: " + \
|
||||
std::string(op->get_type_name()) + " op " + op->get_friendly_name() + " is not supported"); \
|
||||
else
|
||||
|
||||
} // namespace
|
||||
|
||||
bool Validate::is_supported_constant(const std::shared_ptr<const ov::Node>& op) {
|
||||
const auto constant = ov::as_type_ptr<const ov::op::v0::Constant>(op);
|
||||
const auto consumers = op->get_output_target_inputs(0);
|
||||
return constant &&
|
||||
(ov::shape_size(constant->get_output_shape(0)) == 1 ||
|
||||
std::all_of(consumers.cbegin(), consumers.cend(),
|
||||
[](const ov::Input<ov::Node>& in) {
|
||||
return ov::is_type<const ov::op::v1::Transpose>(in.get_node()) ||
|
||||
ov::is_type<const ov::op::v1::Broadcast>(in.get_node()) ||
|
||||
ov::is_type<const ov::op::v3::Broadcast>(in.get_node());
|
||||
}));
|
||||
}
|
||||
|
||||
bool Validate::is_supported_convert(const std::shared_ptr<const ov::Node>& op) {
|
||||
return ov::is_type<const op::ConvertTruncation>(op) || ov::is_type<const op::ConvertSaturation>(op);
|
||||
}
|
||||
|
||||
bool Validate::is_supported_matmul(const std::shared_ptr<const ov::Node>& op) {
|
||||
// If ExplicitTransposeMatMulInputs pass is enabled, MatMul should have not transposed inputs
|
||||
const auto matmul = ov::as_type_ptr<const ov::op::v0::MatMul>(op);
|
||||
return matmul && utils::implication(m_pass_config->is_enabled<ov::snippets::pass::ExplicitTransposeMatMulInputs>(),
|
||||
!matmul->get_transpose_a() && !matmul->get_transpose_b());
|
||||
}
|
||||
|
||||
bool Validate::is_supported_softmax(const std::shared_ptr<const ov::Node>& op) {
|
||||
// Softmax is supported only with axis by last dim
|
||||
const auto softmax_rank = op->get_input_partial_shape(0).rank();
|
||||
int64_t axis = 0;
|
||||
if (const auto softmax_v8 = ov::as_type_ptr<const ov::op::v8::Softmax>(op)) {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
axis = ov::normalize_axis(softmax_v8->get_friendly_name(), softmax_v8->get_axis(), softmax_rank);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
} else if (const auto softmax_v1 = ov::as_type_ptr<const ov::op::v1::Softmax>(op)) {
|
||||
axis = softmax_v1->get_axis();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return axis == softmax_rank.get_length() - 1;
|
||||
}
|
||||
|
||||
bool Validate::is_supported_fq(const std::shared_ptr<const ov::Node>& node) {
|
||||
// FQ is decomposed into ops in CommonFakeQuantizeDecomposition pass
|
||||
return m_pass_config->is_disabled<ov::snippets::pass::CommonFakeQuantizeDecomposition>();
|
||||
}
|
||||
|
||||
bool Validate::is_supported_transpose(const std::shared_ptr<const ov::Node>& node) {
|
||||
// Transpose is supported only on Inputs or Outputs of body
|
||||
const auto consumers = node->get_output_target_inputs(0);
|
||||
return (ov::is_type<ov::op::v0::Parameter>(node->get_input_node_shared_ptr(0))) ||
|
||||
(consumers.size() == 1 && ov::is_type<ov::op::v0::Result>(consumers.cbegin()->get_node()));
|
||||
}
|
||||
|
||||
bool Validate::is_supported_op(const std::shared_ptr<const ov::Node>& node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Validate::run_on_model(const std::shared_ptr<ov::Model>& m) {
|
||||
RUN_ON_MODEL_SCOPE(Validate);
|
||||
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::Validate")
|
||||
|
||||
for (const auto& op : m->get_ordered_ops()) {
|
||||
VALIDATE(op, ov::op::v0::Constant, is_supported_constant)
|
||||
VALIDATE(op, ov::op::v0::Convert, is_supported_convert)
|
||||
VALIDATE(op, ov::op::v0::MatMul, is_supported_matmul)
|
||||
VALIDATE(op, ov::op::v1::Softmax, is_supported_softmax)
|
||||
VALIDATE(op, ov::op::v8::Softmax, is_supported_softmax)
|
||||
VALIDATE(op, ov::op::v0::FakeQuantize, is_supported_fq)
|
||||
VALIDATE(op, ov::op::v1::Transpose, is_supported_transpose)
|
||||
VALIDATE(op, ov::op::v1::Reshape, is_supported_op);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace pass
|
||||
} // namespace snippets
|
||||
} // namespace ov
|
||||
Reference in New Issue
Block a user