Snippets pass manager (#18846)

Ivan Novoselov 2023-09-01 11:31:42 +03:00 committed by GitHub
parent 38cad619af
commit b0d917f0cb
11 changed files with 319 additions and 174 deletions

View File

@@ -55,7 +55,7 @@ public:
         register_pass(pass);
     }
-    void run(lowered::LinearIR& linear_ir);
+    void run(lowered::LinearIR& linear_ir) const;
 private:
     std::vector<std::shared_ptr<Pass>> m_passes;

View File

@@ -10,7 +10,7 @@
 #include <openvino/op/util/sub_graph_base.hpp>
 #include "openvino/op/op.hpp"
 #include "openvino/core/rt_info.hpp"
-#include <ngraph/pass/manager.hpp>
+#include "snippets/pass_manager.hpp"
 #include "snippets/generator.hpp"
@@ -26,8 +26,6 @@ namespace op {
 class Subgraph : public ov::op::util::SubGraphOp {
 public:
     OPENVINO_OP("Subgraph", "SnippetsOpset", ov::op::util::SubGraphOp);
-    enum {DYNAMIC_DIMENSION = 0xffffffffffffffff};
     // < 1, 42, 17, 15, 16> < 0, 1, 2, 3, 1>
     // should be:
     // A = < 1, 42, 17, 15> -> < 1, 3, 17, 15, 16> < 0, 1, 2, 3, 1>
@@ -74,9 +72,9 @@ public:
     Subgraph() = default;
-    Subgraph(const OutputVector& args, std::shared_ptr<ov::Model> body);
-    Subgraph(const NodeVector& args, std::shared_ptr<ov::Model> body);
+    Subgraph(const OutputVector& args, const std::shared_ptr<ov::Model>& body);
+    Subgraph(const NodeVector& args, const std::shared_ptr<ov::Model>& body);
     bool visit_attributes(AttributeVisitor& visitor) override;
@@ -101,18 +99,14 @@ public:
     bool has_domain_sensitive_ops() const { return config.m_has_domain_sensitive_ops; }
     snippets::Schedule generate(const BlockedShapeVector& output_shapes,
                                 const BlockedShapeVector& input_shapes,
-                                ov::pass::Manager& pre_common,
-                                ov::pass::Manager& post_common,
-                                ov::pass::Manager& post_precision,
-                                lowered::pass::PassPipeline& target_lowered_markup_pipeline,
-                                lowered::pass::PassPipeline& target_lowered_pipeline,
+                                const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
+                                const lowered::pass::PassPipeline& control_flow_passes_pre_common,
+                                const lowered::pass::PassPipeline& control_flow_passes_post_common,
                                 const void* compile_params = nullptr);
     snippets::Schedule generate(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes, const void* compile_params = nullptr);
-    snippets::Schedule generate(ov::pass::Manager& pre_common,
-                                ov::pass::Manager& post_common,
-                                ov::pass::Manager& post_precision,
-                                lowered::pass::PassPipeline& target_lowered_markup_pipeline,
-                                lowered::pass::PassPipeline& target_lowered_pipeline,
+    snippets::Schedule generate(const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
+                                const lowered::pass::PassPipeline& control_flow_passes_pre_common,
+                                const lowered::pass::PassPipeline& control_flow_passes_post_common,
                                 const void* compile_params = nullptr);
     snippets::Schedule generate(const void* compile_params = nullptr);
     ov::PartialShape canonicalize(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes);
@@ -146,10 +140,10 @@ public:
 private:
     void align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes);
-    void data_flow_transformations(ov::pass::Manager& pre_common, ov::pass::Manager& post_common, ov::pass::Manager& post_precision);
+    void data_flow_transformations(const std::vector<snippets::pass::Manager::PositionedPass>& backend_passes);
     void control_flow_transformations(lowered::LinearIR& linear_ir,
-                                      lowered::pass::PassPipeline& target_markup_pipeline,
-                                      lowered::pass::PassPipeline& target_pipeline);
+                                      const lowered::pass::PassPipeline& backend_passes_pre_common,
+                                      const lowered::pass::PassPipeline& backend_passes_post_common);
     void init_config();
     // Count of Subgraph virtual ports:
     // - Potential non-scalar Constants that will be created after some transformations (At the moment it's relevant only for FakeQuantize decomposition)

View File

@@ -0,0 +1,82 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/manager.hpp"
#include "openvino/pass/pass.hpp"
#include "openvino/pass/validate.hpp"

#include <typeinfo>

namespace ov {
namespace snippets {
namespace pass {

/**
 * @brief Manager is like ov::pass::Manager, but allows to insert new passes at arbitrary places in the pipeline
 * @ingroup snippets
 */
class Manager : public ov::pass::Manager {
public:
    ~Manager() override = default;
    using PassBase = ov::pass::PassBase;
    using Validate = ov::pass::Validate;
    /**
     * @brief PassPosition describes a particular position in a transformation pipeline,
     *        where a new transformation should be inserted.
     * @param pass_name name of the anchor pass, the new pass will be inserted before/after it.
     *        Empty pass_name could mean either beginning or the end of the pipeline depending on the `after` flag.
     *        No default value. Note that pass names namespaces are not supported, ov::PassName and snippets::PassName
     *        are considered identical.
     * @param after `true` if the new pass should be inserted after the anchor pass, `false` otherwise (default).
     *        If `pass_name` is empty, `true` means the end, and `false` - the beginning of the pipeline.
     * @param pass_instance the number of the pass with matching `pass_name` to be considered as the anchor pass.
     *        0 (default) means the first pass with `pass_name` will be considered as the anchor pass.
     * @ingroup snippets
     */
    class PassPosition {
    public:
        enum class Place {Before, After, PipelineStart, PipelineEnd};
        using PassListType = std::vector<std::shared_ptr<ov::pass::PassBase>>;
        explicit PassPosition(Place pass_place);
        explicit PassPosition(Place pass_place, std::string pass_name, size_t pass_instance = 0);
        PassListType::const_iterator get_insert_position(const PassListType& pass_list) const;
    private:
        const std::string m_pass_name;
        const size_t m_pass_instance{0};
        const Place m_place{Place::Before};
    };
    struct PositionedPass {
        PassPosition position;
        std::shared_ptr<PassBase> pass;
        PositionedPass(PassPosition arg_pos, std::shared_ptr<PassBase> arg_pass)
            : position(std::move(arg_pos)), pass(std::move(arg_pass)) {
        }
    };

    template <typename T, class... Args>
    std::shared_ptr<T> register_pass(Args&&... args) {
        return ov::pass::Manager::register_pass<T>(args...);
    }

    template <typename T, class Pos, class... Args, std::enable_if<std::is_same<PassPosition, Pos>::value, bool>() = true>
    std::shared_ptr<T> register_pass(const PassPosition& position, Args&&... args) {
        static_assert(std::is_base_of<PassBase, T>::value, "Attempt to insert pass that is not derived from PassBase");
        auto pass = std::make_shared<T>(std::forward<Args>(args)...);
        auto rc = insert_pass_instance(position, pass);
        rc->set_pass_config(m_pass_config);
        if (!m_pass_config->is_enabled<T>()) {
            m_pass_config->disable<T>();
        }
        return rc;
    }

    std::shared_ptr<PassBase> register_pass_instance(const PassPosition& pass_id, const std::shared_ptr<PassBase>& pass);
    void register_positioned_passes(const std::vector<PositionedPass>& pos_passes);

protected:
    std::shared_ptr<Manager::PassBase> insert_pass_instance(const PassPosition& position, const std::shared_ptr<PassBase>& pass);
};

} // namespace pass
} // namespace snippets
} // namespace ov
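
For illustration only (not part of this commit): a minimal sketch of how a backend could use the new snippets pass manager. MyBackendPass and compile_with_backend_passes are hypothetical names, the include paths are assumed from the snippets module layout, and the subgraph is assumed to be tokenized with a generator already set. The "PropagatePrecision" anchor works because Subgraph::data_flow_transformations registers that pass before splicing in the positioned backend passes (see the subgraph.cpp changes below).

#include "snippets/pass_manager.hpp"
#include "snippets/op/subgraph.hpp"
#include "openvino/pass/pass.hpp"

// Hypothetical backend pass, defined here only so the sketch is self-contained
class MyBackendPass : public ov::pass::ModelPass {
public:
    bool run_on_model(const std::shared_ptr<ov::Model>& model) override { return false; }
};

ov::snippets::Schedule compile_with_backend_passes(const std::shared_ptr<ov::snippets::op::Subgraph>& subgraph) {
    using PassPosition = ov::snippets::pass::Manager::PassPosition;
    using Place = PassPosition::Place;

    std::vector<ov::snippets::pass::Manager::PositionedPass> data_flow_passes;
    // Run the backend pass before the common data-flow pipeline starts
    data_flow_passes.emplace_back(PassPosition(Place::PipelineStart), std::make_shared<MyBackendPass>());
    // ...or anchor it to a common pass by name (namespaces are ignored when matching)
    data_flow_passes.emplace_back(PassPosition(Place::Before, "PropagatePrecision"), std::make_shared<MyBackendPass>());

    // The two PassPipeline arguments (pre/post-common control-flow passes on the linear IR)
    // may stay empty if the backend has nothing to add there
    return subgraph->generate(data_flow_passes, {}, {}, /*compile_params=*/nullptr);
}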

View File

@@ -14,7 +14,7 @@ void PassPipeline::register_pass(const std::shared_ptr<Pass>& pass) {
     m_passes.push_back(pass);
 }
-void PassPipeline::run(LinearIR& linear_ir) {
+void PassPipeline::run(LinearIR& linear_ir) const {
     for (const auto& pass : m_passes) {
         pass->run(linear_ir);
     }

View File

@@ -43,7 +43,7 @@
 #include "transformations/utils/utils.hpp"
-#include <ngraph/pass/manager.hpp>
+#include "snippets/pass_manager.hpp"
 #include "ngraph/pass/constant_folding.hpp"
 #include "ov_ops/type_relaxed.hpp"
 #include <openvino/pass/serialize.hpp>
@@ -57,16 +57,17 @@ using namespace ov::op::util;
 namespace ov {
 namespace snippets {
+namespace op {
-void snippets::op::Subgraph::set_generator(std::shared_ptr<ov::snippets::Generator> generator) {
+void Subgraph::set_generator(std::shared_ptr<ov::snippets::Generator> generator) {
     m_generator = generator;
 }
-void snippets::op::Subgraph::set_virtual_port_count(const size_t count) {
+void Subgraph::set_virtual_port_count(const size_t count) {
     m_virtual_port_count = count;
 }
-auto snippets::op::Subgraph::is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bool {
+auto Subgraph::is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bool {
     return ov::is_type<ov::op::v1::Transpose>(op) ||
            ov::is_type<ov::op::v1::Softmax>(op) ||
            ov::is_type<ov::op::v8::Softmax>(op) ||
@@ -75,7 +76,7 @@ auto snippets::op::Subgraph::is_domain_sensitive_op(const std::shared_ptr<ov::No
            ov::is_type<ov::op::v3::Broadcast>(op); // the both input and broadcast shapes (the both - are inputs of op). Note: is used only in MHA pattern
 }
-void snippets::op::Subgraph::init_config() {
+void Subgraph::init_config() {
     auto update = [](bool& flag, bool status) { flag = flag || status; };
     const auto ops = body_ptr()->get_ops();
     for (const auto& op : ops) {
@@ -84,7 +85,7 @@ void snippets::op::Subgraph::init_config() {
     }
 }
-auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t {
+auto Subgraph::get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t {
     // The count of potential unique Buffers - it's hidden virtual ports as well
     // We should go through Subgraph and calculate potential non-inplace Buffers count.
     // These Buffers can be in 2 cases:
@@ -129,7 +130,9 @@ auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& op
             const auto consumers = matmul->get_output_target_inputs(0);
             if (std::none_of(consumers.begin(), consumers.end(),
-                             [](const ov::Input<ov::Node>& in) { return ov::is_type<ov::op::v0::Result>(in.get_node()); })) {
+                             [](const ov::Input<ov::Node>& in) {
+                                 return ov::is_type<ov::op::v0::Result>(in.get_node());
+                             })) {
                 used_precision_size.push_back(matmul->get_element_type().size());
             }
         }
@@ -138,9 +141,9 @@ auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& op
     return used_precision_size.size();
 }
-snippets::op::Subgraph::Subgraph(const OutputVector& args, std::shared_ptr<ov::Model> body)
+Subgraph::Subgraph(const OutputVector& args, const std::shared_ptr<ov::Model>& body)
     : SubGraphOp(args), m_generator(nullptr) {
-    set_function(body);
+    SubGraphOp::set_function(body);
     init_config();
     constructor_validate_and_infer_types();
     for (size_t i = 0; i < body->get_parameters().size(); ++i)
@@ -150,15 +153,15 @@ snippets::op::Subgraph::Subgraph(const OutputVector& args, std::shared_ptr<ov::M
     m_transformations_allowed = false;
 }
-snippets::op::Subgraph::Subgraph(const NodeVector& args, std::shared_ptr<ov::Model> body)
-    : Subgraph(as_output_vector(args), std::move(body)) {}
+Subgraph::Subgraph(const NodeVector& args, const std::shared_ptr<ov::Model>& body)
+    : Subgraph(as_output_vector(args), body) {}
-std::shared_ptr<Node> snippets::op::Subgraph::clone_with_new_inputs(const OutputVector& inputs) const {
+std::shared_ptr<Node> Subgraph::clone_with_new_inputs(const OutputVector& inputs) const {
     INTERNAL_OP_SCOPE(Subgraph);
     return make_shared<Subgraph>(inputs, body().clone());
 }
-std::vector<PartialShape> snippets::op::Subgraph::reshape_body(const std::vector<PartialShape>& input_shapes) {
+std::vector<PartialShape> Subgraph::reshape_body(const std::vector<PartialShape>& input_shapes) {
     auto& params = body_ptr()->get_parameters();
     OPENVINO_ASSERT(params.size() == input_shapes.size(), "Got invalid number of input shapes to reshape subgraph body");
     for (size_t i = 0; i < params.size(); ++i) {
@@ -172,7 +175,7 @@ std::vector<PartialShape> snippets::op::Subgraph::reshape_body(const std::vector
     return output_shapes;
 }
-std::vector<Shape> snippets::op::Subgraph::reshape_body(const std::vector<Shape>& input_shapes) {
+std::vector<Shape> Subgraph::reshape_body(const std::vector<Shape>& input_shapes) {
     auto& params = body_ptr()->get_parameters();
     OPENVINO_ASSERT(params.size() == input_shapes.size(), "Got invalid number of input shapes to reshape subgraph body");
     for (size_t i = 0; i < params.size(); ++i) {
@@ -188,7 +191,7 @@ std::vector<Shape> snippets::op::Subgraph::reshape_body(const std::vector<Shape>
     return output_shapes;
 }
-void snippets::op::Subgraph::validate_and_infer_types() {
+void Subgraph::validate_and_infer_types() {
     INTERNAL_OP_SCOPE(Subgraph);
     OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::validate_and_infer_types")
     ov::ParameterVector old_parameters;
@@ -212,14 +215,14 @@ void snippets::op::Subgraph::validate_and_infer_types() {
     }
 }
-bool snippets::op::Subgraph::visit_attributes(AttributeVisitor& visitor) {
+bool Subgraph::visit_attributes(AttributeVisitor& visitor) {
     visitor.on_attribute("body", body_ptr());
     visitor.on_attribute("input_descriptions", m_input_descriptions[0]);
     visitor.on_attribute("output_descriptions", m_output_descriptions[0]);
     return true;
 }
-auto snippets::op::Subgraph::wrap_node_as_subgraph(const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<op::Subgraph> {
+auto Subgraph::wrap_node_as_subgraph(const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<op::Subgraph> {
     INTERNAL_OP_SCOPE(Subgraph);
     OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::wrap_node_as_subgraph")
     ov::ParameterVector body_parameters;
@@ -278,7 +281,7 @@ auto snippets::op::Subgraph::wrap_node_as_subgraph(const std::shared_ptr<ov::Nod
     return subgraph;
 }
-void snippets::op::Subgraph::fill_empty_output_names(const Output<Node>& target_output_node, const Output<Node>& replacement_output_node) {
+void Subgraph::fill_empty_output_names(const Output<Node>& target_output_node, const Output<Node>& replacement_output_node) {
     NGRAPH_SUPPRESS_DEPRECATED_START
     auto& out_tensor = target_output_node.get_tensor();
     const std::string new_name = ov::op::util::get_ie_output_name(replacement_output_node);
@@ -291,7 +294,7 @@ void snippets::op::Subgraph::fill_empty_output_names(const Output<Node>& target_
     NGRAPH_SUPPRESS_DEPRECATED_END
 }
-auto snippets::op::Subgraph::constant_input_should_be_inside_body(const std::shared_ptr<ov::Node>& node) -> bool {
+auto Subgraph::constant_input_should_be_inside_body(const std::shared_ptr<ov::Node>& node) -> bool {
     return ov::is_type<ov::op::v1::Transpose>(node) ||
            ov::is_type<ov::op::v1::Broadcast>(node) ||
            ov::is_type<ov::op::v3::Broadcast>(node) ||
@@ -311,16 +314,18 @@ ov::PartialShape snippets::op::Subgraph::canonicalize(const BlockedShapeVector&
     INTERNAL_OP_SCOPE(Subgraph);
     OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::canonicalize")
     NODE_VALIDATION_CHECK(this, inputShapes.size() == body_ptr()->get_parameters().size(),
-                          "Number of parameters for snippet doesn't match passed to generate method: ", inputShapes.size(), " vs ", body_ptr()->get_parameters().size(), ".");
+                          "Number of parameters for snippet doesn't match passed to generate method: ",
+                          inputShapes.size(), " vs ", body_ptr()->get_parameters().size(), ".");
     NODE_VALIDATION_CHECK(this, outputShapes.size() == body_ptr()->get_results().size(),
-                          "number of results for snippet doesn't match passed to generate method: ", outputShapes.size(), " vs ", body_ptr()->get_results().size(), ".");
+                          "number of results for snippet doesn't match passed to generate method: ",
+                          outputShapes.size(), " vs ", body_ptr()->get_results().size(), ".");
     auto getMaxRankBlockedShape = [](const BlockedShapeVector& blockedShapes) -> const BlockedShape& {
         return *std::max_element(blockedShapes.begin(), blockedShapes.end(),
                                  [&](const BlockedShape& lhs, const BlockedShape& rhs) {
                                      return std::get<0>(lhs).size() < std::get<0>(rhs).size();
                                  });
     };
     PartialShape baseShape;
     AxisVector baseOrder;
@@ -362,9 +367,9 @@ ov::PartialShape snippets::op::Subgraph::canonicalize(const BlockedShapeVector&
                               PartialShape::broadcast_merge_into(tmpPShape, inShape, ::ov::op::AutoBroadcastType::NUMPY),
                               "Failed to create broadcastable shapes in snippets canonicalization");
         const auto paramShape = body_ptr()->get_parameters()[i]->get_partial_shape();
         const auto paramType = body_ptr()->get_parameters()[i]->get_element_type();
         if (paramShape.size() != inShape.size() || !equal(paramShape.begin(), paramShape.end(), inShape.begin()))
             body_ptr()->replace_parameter(i, std::make_shared<ov::op::v0::Parameter>(paramType, inShape));
     }
     body_ptr()->validate_nodes_and_infer_types();
@@ -373,10 +378,10 @@ ov::PartialShape snippets::op::Subgraph::canonicalize(const BlockedShapeVector&
         auto end = shape.end();
         while (begin != end && *begin == 1)
             begin++;
-        while (begin != end && *(end-1) == 1)
+        while (begin != end && *(end - 1) == 1)
             end--;
-        PartialShape trimmedShape(std::vector<ov::Dimension> (end - begin, 1));
+        PartialShape trimmedShape(std::vector<ov::Dimension>(end - begin, 1));
         std::copy(begin, end, trimmedShape.begin());
         return trimmedShape;
     };
@@ -456,7 +461,7 @@ ov::PartialShape snippets::op::Subgraph::canonicalized_body_shape_infer(const Bl
     return master_shape;
 }
-bool snippets::op::Subgraph::check_broadcast(const std::shared_ptr<const ov::Node>& node) noexcept {
+bool Subgraph::check_broadcast(const std::shared_ptr<const ov::Node>& node) noexcept {
     const auto elementwise = std::dynamic_pointer_cast<const ov::op::util::BinaryElementwiseArithmetic>(node);
     return
         (elementwise == nullptr) ||
@@ -464,8 +469,8 @@ bool snippets::op::Subgraph::check_broadcast(const std::shared_ptr<const ov::Nod
         (elementwise->get_autob().m_type != ov::op::AutoBroadcastType::PDPD);
 }
-void snippets::op::Subgraph::align_element_types(const BlockedShapeVector& outputShapes,
-                                                 const BlockedShapeVector& inputShapes) {
+void Subgraph::align_element_types(const BlockedShapeVector& outputShapes,
+                                   const BlockedShapeVector& inputShapes) {
     // We should insert Convert before Results to set original output element type if needed
     const auto& body_results = body_ptr()->get_results();
     for (size_t i = 0; i < outputShapes.size(); i++) {
@@ -534,54 +539,45 @@ void snippets::op::Subgraph::align_element_types(const BlockedShapeVector& outpu
     }
 }
-void snippets::op::Subgraph::data_flow_transformations(ov::pass::Manager& pre_common,
-                                                       ov::pass::Manager& post_common,
-                                                       ov::pass::Manager& post_precision) {
+void Subgraph::data_flow_transformations(const std::vector<snippets::pass::Manager::PositionedPass>& backend_passes) {
     INTERNAL_OP_SCOPE(Subgraph);
     OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::data_flow_transformations")
     const auto& params = body_ptr()->get_parameters();
     bool inputs_has_dynamic_last_dims = std::any_of(params.begin(), params.end(),
                                                     [](const shared_ptr<ov::op::v0::Parameter>& p) {
                                                         return p->get_partial_shape().rbegin()->is_dynamic();
                                                     });
-    pre_common.run_passes(body_ptr());
-    ov::pass::Manager common_manager;
+    snippets::pass::Manager manager;
     if (config.m_has_domain_sensitive_ops) {
-        common_manager.register_pass<snippets::pass::MatMulToBrgemm>();
-        common_manager.register_pass<snippets::pass::FuseTransposeBrgemm>();
-        common_manager.register_pass<snippets::pass::TransposeDecomposition>();
-        common_manager.register_pass<snippets::pass::SetSoftmaxPorts>();
+        manager.register_pass<snippets::pass::MatMulToBrgemm>();
+        manager.register_pass<snippets::pass::FuseTransposeBrgemm>();
+        manager.register_pass<snippets::pass::TransposeDecomposition>();
+        manager.register_pass<snippets::pass::SetSoftmaxPorts>();
     }
-    common_manager.register_pass<snippets::pass::BroadcastToMoveBroadcast>();
-    common_manager.register_pass<snippets::pass::ConvertConstantsToScalars>();
-    common_manager.register_pass<snippets::pass::ConvertPowerToPowerStatic>();
+    manager.register_pass<snippets::pass::BroadcastToMoveBroadcast>();
+    manager.register_pass<snippets::pass::ConvertConstantsToScalars>();
+    manager.register_pass<snippets::pass::ConvertPowerToPowerStatic>();
     // todo: presently dynamic pipeline is activated even if the last two dimension are static
     // In general, we can use static kernels in this case, but several parameters (src and dst memory pointers for example)
     // should be passed as run-time args, so it's a mixed mode: kernel is shape-aware, but some additional runtime args are required
     // Presently Broadcasting is organized in the following way:
     // * ALL last dims are static => broadcasting is handled via MoveBroadcast and pointer arithmetics (even for dynamic upper dims)
     if (!inputs_has_dynamic_last_dims) {
-        common_manager.register_pass<snippets::pass::InsertMoveBroadcast>();
+        manager.register_pass<snippets::pass::InsertMoveBroadcast>();
     }
-    common_manager.run_passes(body_ptr());
-    post_common.run_passes(body_ptr());
-    ov::pass::Manager precision_manager;
-    precision_manager.register_pass<snippets::pass::PropagatePrecision>(m_generator->get_target_machine());
-    precision_manager.register_pass<ov::pass::ConstantFolding>();
-    precision_manager.register_pass<snippets::pass::ConvertConstantsToScalars>();
-    precision_manager.run_passes(body_ptr());
-    post_precision.run_passes(body_ptr());
+    manager.register_pass<snippets::pass::PropagatePrecision>(m_generator->get_target_machine());
+    manager.register_pass<ov::pass::ConstantFolding>();
+    manager.register_pass<snippets::pass::ConvertConstantsToScalars>();
+    manager.register_positioned_passes(backend_passes);
+    manager.run_passes(body_ptr());
 }
-void snippets::op::Subgraph::control_flow_transformations(lowered::LinearIR& linear_ir,
-                                                          lowered::pass::PassPipeline& target_markup_pipeline,
-                                                          lowered::pass::PassPipeline& target_pipeline) {
+void Subgraph::control_flow_transformations(lowered::LinearIR& linear_ir,
+                                            const lowered::pass::PassPipeline& backend_passes_pre_common,
+                                            const lowered::pass::PassPipeline& backend_passes_post_common) {
     INTERNAL_OP_SCOPE(Subgraph);
     OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::control_flow_transformations")
@@ -590,7 +586,7 @@ void snippets::op::Subgraph::control_flow_transformations(lowered::LinearIR& lin
     // Ticket: 113666
     // TODO: Make pass pipeline with backend passes more flexible
-    target_markup_pipeline.run(linear_ir);
+    backend_passes_pre_common.run(linear_ir);
     lowered::pass::PassPipeline common_pipeline;
     common_pipeline.register_pass<lowered::pass::MarkLoops>(vector_size);
@@ -607,7 +603,7 @@ void snippets::op::Subgraph::control_flow_transformations(lowered::LinearIR& lin
     common_pipeline.register_pass<lowered::pass::InsertLoops>();
     common_pipeline.run(linear_ir);
-    target_pipeline.run(linear_ir);
+    backend_passes_post_common.run(linear_ir);
     const auto buffer_allocation_pass = std::make_shared<lowered::pass::AllocateBuffers>();
     lowered::pass::PassPipeline buffer_pipeline;
@@ -624,43 +620,36 @@ void snippets::op::Subgraph::control_flow_transformations(lowered::LinearIR& lin
     m_buffer_scratchpad = buffer_allocation_pass->get_scratchpad_size();
 }
-snippets::Schedule snippets::op::Subgraph::generate(const BlockedShapeVector& output_shapes,
-                                                    const BlockedShapeVector& input_shapes,
-                                                    const void* compile_params) {
+snippets::Schedule Subgraph::generate(const BlockedShapeVector& output_shapes,
+                                      const BlockedShapeVector& input_shapes,
+                                      const void* compile_params) {
     canonicalize(output_shapes, input_shapes);
     return generate(compile_params);
 }
-snippets::Schedule snippets::op::Subgraph::generate(const BlockedShapeVector& output_shapes,
-                                                    const BlockedShapeVector& input_shapes,
-                                                    ov::pass::Manager& pre_common,
-                                                    ov::pass::Manager& post_common,
-                                                    ov::pass::Manager& post_precision,
-                                                    lowered::pass::PassPipeline& target_lowered_markup_pipeline,
-                                                    lowered::pass::PassPipeline& target_lowered_pipeline,
-                                                    const void* compile_params) {
+snippets::Schedule Subgraph::generate(const BlockedShapeVector& output_shapes,
+                                      const BlockedShapeVector& input_shapes,
+                                      const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
+                                      const lowered::pass::PassPipeline& control_flow_passes_pre_common,
+                                      const lowered::pass::PassPipeline& control_flow_passes_post_common,
+                                      const void* compile_params) {
     canonicalize(output_shapes, input_shapes);
-    return generate(pre_common, post_common, post_precision, target_lowered_markup_pipeline, target_lowered_pipeline, compile_params);
+    return generate(data_flow_passes, control_flow_passes_pre_common, control_flow_passes_post_common, compile_params);
 }
-snippets::Schedule snippets::op::Subgraph::generate(const void* compile_params) {
-    auto mngr = ov::pass::Manager();
-    auto lowered = lowered::pass::PassPipeline();
-    return generate(mngr, mngr, mngr, lowered, lowered, compile_params);
+snippets::Schedule Subgraph::generate(const void* compile_params) {
+    return generate({}, {}, {}, compile_params);
 }
-snippets::Schedule snippets::op::Subgraph::generate(
-    ov::pass::Manager& pre_common,
-    ov::pass::Manager& post_common,
-    ov::pass::Manager& post_precision,
-    lowered::pass::PassPipeline& target_lowered_markup_pipeline,
-    lowered::pass::PassPipeline& target_lowered_pipeline,
-    const void* compile_params) {
+snippets::Schedule Subgraph::generate(const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
+                                      const lowered::pass::PassPipeline& control_flow_passes_pre_common,
+                                      const lowered::pass::PassPipeline& control_flow_passes_post_common,
+                                      const void* compile_params) {
     INTERNAL_OP_SCOPE(Subgraph);
     OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::generate")
     NGRAPH_CHECK(m_generator != nullptr, "generate is called while generator is not set");
-    data_flow_transformations(pre_common, post_common, post_precision);
+    data_flow_transformations(data_flow_passes);
     lowered::Config lowering_config;
     lowering_config.m_save_expressions = config.m_has_domain_sensitive_ops;
@@ -668,7 +657,7 @@ snippets::Schedule snippets::op::Subgraph::generate(
     lowering_config.m_loop_depth = tileRank;
     lowered::LinearIR linear_ir = lowered::LinearIR(body_ptr(), lowering_config);
-    control_flow_transformations(linear_ir, target_lowered_markup_pipeline, target_lowered_pipeline);
+    control_flow_transformations(linear_ir, control_flow_passes_pre_common, control_flow_passes_post_common);
     // actual code emission
     const auto& lowering_result = m_generator->generate(linear_ir, lowering_config, compile_params);
@@ -677,31 +666,32 @@ snippets::Schedule snippets::op::Subgraph::generate(
     return {master_shape, false /*canBeLinearized*/, ptr};
 }
-void snippets::op::Subgraph::print() const {
+void Subgraph::print() const {
     INTERNAL_OP_SCOPE(Subgraph);
     remark(13) << "subgraph " << this->get_friendly_name() << " "
                << this->get_type_name()
               << " which contains " << body_ptr()->get_ops().size() << " nodes" << std::endl;
     int qqq = 0;
-    for (auto op : body_ptr()->get_ordered_ops()) {
-        remark(13) << "op " << qqq++ << " " << op->get_friendly_name() << " (" << op->get_type_name() << ") " << op << std::endl;
+    for (const auto& op : body_ptr()->get_ordered_ops()) {
+        remark(13) << "op " << qqq++ << " " << op->get_friendly_name() << " (" << op->get_type_name() << ") " << op
+                   << std::endl;
     }
     for (auto& in : this->inputs()) {
         remark(13) << " -> " << in.get_source_output().get_node_shared_ptr()->get_friendly_name() << " "
                    << in.get_source_output().get_node_shared_ptr() << std::endl;
     }
     for (auto& out : this->outputs()) {
         for (auto& user : out.get_target_inputs()) {
             remark(13) << " <- " << user.get_node()->get_friendly_name() << " " << user.get_node() << std::endl;
         }
         remark(13) << std::endl;
     }
 }
-void snippets::op::Subgraph::print_statistics(bool verbose) {
+void Subgraph::print_statistics(bool verbose) {
     INTERNAL_OP_SCOPE(Subgraph);
     auto getNodeInventory = [](std::shared_ptr<ov::Node> n) -> size_t {
         size_t total = 0;
@@ -725,22 +715,22 @@ void snippets::op::Subgraph::print_statistics(bool verbose) {
         return total;
     };
-    auto getModelInventory = [getNodeInventory](const ov::Model & f) -> size_t {
+    auto getModelInventory = [getNodeInventory](const ov::Model& f) -> size_t {
         size_t total = 0;
         for (auto op : f.get_ordered_ops()) {
             // Results and parameters are artificially introduced,
             // while Constants are already considered if they are inputs of other operation
             // this should lead to 1:1 inventory for single node operations
             if (!ov::as_type_ptr<ov::opset1::Parameter>(op)
                 && !ov::as_type_ptr<ov::opset1::Result>(op)
                 && !ov::as_type_ptr<ov::opset1::Constant>(op)) {
                 total += getNodeInventory(op);
             }
         }
         return total;
     };
-    auto countConstants = [](const ov::Model & f) -> size_t {
+    auto countConstants = [](const ov::Model& f) -> size_t {
         size_t count = 0;
         for (auto op : f.get_ordered_ops()) {
             count += !!ov::as_type_ptr<ov::opset1::Constant>(op) ? 1 : 0;
@@ -762,7 +752,7 @@ void snippets::op::Subgraph::print_statistics(bool verbose) {
         }
     }
-void snippets::op::Subgraph::serialize() const {
+void Subgraph::serialize() const {
     std::stringstream xmlFile, binFile;
     ov::pass::Serialize serializer(xmlFile, xmlFile, ov::pass::Serialize::Version::IR_V10);
     serializer.run_on_model(body_ptr());
@@ -771,5 +761,6 @@ void snippets::op::Subgraph::serialize() const {
     std::cout << m_model << std::endl;
 }
+} // namespace op
 } // namespace snippets
 } // namespace ov

View File

@@ -0,0 +1,81 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/pass_manager.hpp"

namespace ov {
namespace snippets {
namespace pass {

Manager::PassPosition::PassPosition(Place pass_place) : m_place(pass_place) {
    OPENVINO_ASSERT(m_place == Place::PipelineStart || m_place == Place::PipelineEnd,
                    "Invalid arg: pass_name and pass_instance args could be omitted only for Place::PipelineStart/Place::PipelineEnd");
}

Manager::PassPosition::PassPosition(Place pass_place, std::string pass_name, size_t pass_instance)
    : m_pass_name(std::move(pass_name)), m_pass_instance(pass_instance), m_place(pass_place) {
    OPENVINO_ASSERT((m_place == Place::Before || m_place == Place::After) && !m_pass_name.empty(),
                    "Invalid args combination: pass_place must be Place::Before/Place::After and pass_name must be non-empty");
}

Manager::PassPosition::PassListType::const_iterator
Manager::PassPosition::get_insert_position(const PassListType& pass_list) const {
    size_t pass_count = 0;
    auto match = [this, &pass_count](const std::shared_ptr<PassBase>& p) {
        auto name = p->get_name();
        // Note that MatcherPass and ModelPass currently have different naming policies:
        // - MatcherPass have names without namespaces, e.g. ConvertToSwishCPU
        // - Similar ModelPass name includes namespaces: ov::snippets::pass::ConvertToSwishCPU
        // So we have to remove everything before the last ':', and ':' itself
        if (name.size() > m_pass_name.size()) {
            const auto pos = name.find_last_of(':');
            if (pos == std::string::npos)
                return false;
            name = name.substr(pos + 1);
        }
        if (name == m_pass_name) {
            if (m_pass_instance == pass_count)
                return true;
            pass_count++;
        }
        return false;
    };
    switch (m_place) {
    case Place::PipelineStart: return pass_list.cbegin();
    case Place::PipelineEnd: return pass_list.cend();
    case Place::Before:
    case Place::After: {
        auto insert_it = std::find_if(pass_list.cbegin(), pass_list.cend(), match);
        OPENVINO_ASSERT(insert_it != pass_list.cend(), "snippets::pass::Manager failed to find pass ", m_pass_name);
        return m_place == Place::After ? std::next(insert_it) : insert_it;
    }
    default:
        OPENVINO_THROW("Unsupported Place type in PassPosition::get_insert_position");
    }
}

std::shared_ptr<Manager::PassBase> Manager::register_pass_instance(const PassPosition& position,
                                                                   const std::shared_ptr<PassBase>& pass) {
    pass->set_pass_config(m_pass_config);
    return insert_pass_instance(position, pass);
}

void Manager::register_positioned_passes(const std::vector<PositionedPass>& pos_passes) {
    for (const auto& pp : pos_passes)
        register_pass_instance(pp.position, pp.pass);
}

std::shared_ptr<Manager::PassBase> Manager::insert_pass_instance(const PassPosition& position,
                                                                 const std::shared_ptr<PassBase>& pass) {
    auto insert_pos = position.get_insert_position(m_pass_list);
    insert_pos = m_pass_list.insert(insert_pos, pass);
    if (m_per_pass_validation) {
        // Note: insert_pos points to the newly inserted pass, so advance to validate the pass results
        std::advance(insert_pos, 1);
        m_pass_list.insert(insert_pos, std::make_shared<ov::pass::Validate>());
    }
    return pass;
}

} // namespace pass
} // namespace snippets
} // namespace ov
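
As an editorial illustration (not from the commit): since Subgraph::data_flow_transformations registers ConvertConstantsToScalars twice, the pass_instance argument documented in the header is what disambiguates which occurrence an anchored insertion refers to. A minimal sketch, assuming only the pass_manager.hpp header added above:

#include "snippets/pass_manager.hpp"

// pass_instance 0 (the default) anchors on the first occurrence of the named pass,
// 1 on the second, and so on.
void pass_position_examples() {
    using PassPosition = ov::snippets::pass::Manager::PassPosition;
    using Place = PassPosition::Place;

    // Insert before the first ConvertConstantsToScalars in the pipeline
    PassPosition before_first(Place::Before, "ConvertConstantsToScalars");
    // Insert right after the second ConvertConstantsToScalars (the one that follows ConstantFolding)
    PassPosition after_second(Place::After, "ConvertConstantsToScalars", 1);
    (void)before_first;
    (void)after_second;
}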

View File

@@ -50,13 +50,13 @@ public:
 protected:
     static std::shared_ptr<ov::snippets::op::Subgraph> getSubgraph(const std::shared_ptr<Model>& f);
-    static std::shared_ptr<ov::snippets::op::Subgraph> getLoweredSubgraph(const std::shared_ptr<Model>& f,
-                                                                          const ov::PartialShape& master_shape,
-                                                                          ov::pass::Manager pre_dialect = {},
-                                                                          ov::pass::Manager post_dialect = {},
-                                                                          ov::pass::Manager post_precision = {},
-                                                                          ov::snippets::lowered::pass::PassPipeline lowered_pipeline = {},
-                                                                          const std::shared_ptr<ov::snippets::Generator> generator = nullptr);
+    static std::shared_ptr<ov::snippets::op::Subgraph>
+    getLoweredSubgraph(const std::shared_ptr<Model>& f,
+                       const ov::PartialShape& master_shape,
+                       const std::vector<ov::snippets::pass::Manager::PositionedPass>& backend_passes = {},
+                       const ov::snippets::lowered::pass::PassPipeline& lowered_pre_common = {},
+                       const ov::snippets::lowered::pass::PassPipeline& lowered_post_common = {},
+                       const std::shared_ptr<ov::snippets::Generator>& generator = nullptr);
     static std::shared_ptr<ov::snippets::op::Subgraph> getTokenizedSubgraph(const std::shared_ptr<Model>& f);
     ov::PartialShape master_shape{};
 };

View File

@@ -99,35 +99,19 @@ std::shared_ptr<ov::snippets::op::Subgraph> LoweringTests::getSubgraph(const std
     return subgraph;
 }
-std::shared_ptr<ov::snippets::op::Subgraph> LoweringTests::getLoweredSubgraph(const std::shared_ptr<Model> &f,
-                                                                              const ov::PartialShape& master_shape,
-                                                                              ov::pass::Manager pre_dialect,
-                                                                              ov::pass::Manager post_dialect,
-                                                                              ov::pass::Manager post_precision,
-                                                                              ov::snippets::lowered::pass::PassPipeline lowered_pipeline,
-                                                                              const std::shared_ptr<ov::snippets::Generator> generator) {
+std::shared_ptr<ov::snippets::op::Subgraph>
+LoweringTests::getLoweredSubgraph(const std::shared_ptr<Model> &f,
+                                  const ov::PartialShape& master_shape,
+                                  const std::vector<ov::snippets::pass::Manager::PositionedPass>& backend_passes,
+                                  const ov::snippets::lowered::pass::PassPipeline& lowered_pre_common,
+                                  const ov::snippets::lowered::pass::PassPipeline& lowered_post_common,
+                                  const std::shared_ptr<ov::snippets::Generator>& generator) {
     auto subgraph = getTokenizedSubgraph(f);
     subgraph->set_generator(generator == nullptr ? std::make_shared<DummyGenerator>() : generator);
     subgraph->set_master_shape(master_shape);
-    const auto& body = subgraph->body_ptr();
-    auto& body_rt_info = body->get_rt_info();
-    // todo: insertLoops pass requires body_rt_info["PluginShapesOverride"] and subgraph->set_tile_rank to work normally
-    // consider revising snippets-plugin shape and scheduling communication
-    std::vector<std::vector<size_t>> new_shapes;
-    for (const auto& p : body->get_parameters()) {
-        const auto pshape = p->get_output_partial_shape(0);
-        OPENVINO_ASSERT(pshape.is_static(), "getLoweredSubgraph supports only static shapes");
-        new_shapes.push_back(pshape.get_shape());
-    }
-    for (const auto& r : body->get_results()) {
-        const auto pshape = r->get_input_partial_shape(0);
-        OPENVINO_ASSERT(pshape.is_static(), "getLoweredSubgraph supports only static shapes");
-        new_shapes.push_back(pshape.get_shape());
-    }
-    body_rt_info["PluginShapesOverride"] = new_shapes;
     subgraph->set_tile_rank(2);
-    ov::snippets::lowered::pass::PassPipeline empty_pipeline;
-    subgraph->generate(pre_dialect, post_precision, post_precision, empty_pipeline, lowered_pipeline);
+    // Note: lowered_pipeline would have no effect on subgraph body, since it's applied on linear IR
+    subgraph->generate(backend_passes, lowered_pre_common, lowered_post_common);
     return subgraph;
 }

View File

@@ -21,7 +21,7 @@ namespace pass {
 class OPENVINO_API Manager {
 public:
     Manager();
-    ~Manager();
+    virtual ~Manager();
     //// \brief Construct Manager with shared PassConfig instance
     explicit Manager(std::shared_ptr<PassConfig> pass_config);

View File

@@ -694,23 +694,30 @@ bool Snippet::SnippetJitExecutor::optimizeExecDomain(std::vector<VectorDims>& in
 }
 void Snippet::SnippetJitExecutor::generate(const jit_snippets_compile_args* jcp) {
-    ov::pass::Manager pre_dialect;
-    pre_dialect.register_pass<ConvertToSwishCPU>();
+    using Manager = snippets::pass::Manager;
+    using PassPosition = snippets::pass::Manager::PassPosition;
+    using Place = snippets::pass::Manager::PassPosition::Place;
+    std::vector<Manager::PositionedPass> backend_passes;
+#define SNIPPETS_REGISTER_PASS(PASS_POS, PASS, ...) \
+    backend_passes.emplace_back(PASS_POS, std::make_shared<PASS>(__VA_ARGS__))
+    SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineStart), ConvertToSwishCPU);
     if (enforceBF16 && snippet_for_generation->has_domain_sensitive_ops()) {
         // enforce BF16 precisions to supported operations
         // MatMul has to be decomposed to Brgemm operations before enforcement
-        // Note, MatMul decomposition will be ran later again for case if BF16 enforcement is not happened
-        CPU_REGISTER_PASS_X64(pre_dialect, ov::snippets::pass::MatMulToBrgemm);
-        CPU_REGISTER_PASS_X64(pre_dialect, pass::EnforcePrecision, element::f32, element::bf16);
+        // Note, MatMul decomposition will be run later again for case if BF16 enforcement is not happened
+        SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineStart), ov::snippets::pass::MatMulToBrgemm);
+        SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineStart), pass::EnforcePrecision, element::f32, element::bf16);
     }
-    ov::pass::Manager post_dialect;
-    CPU_REGISTER_PASS_X64(post_dialect, ov::intel_cpu::pass::BrgemmToBrgemmCPU);
-    CPU_REGISTER_PASS_X64(post_dialect, ov::intel_cpu::pass::SetBrgemmCPUBlockingParams);
-    ov::pass::Manager post_precision;
-    CPU_REGISTER_PASS_X64(post_precision, ov::intel_cpu::pass::RemoveConverts);
-    CPU_REGISTER_PASS_X64(post_precision, ov::intel_cpu::pass::MulAddToFMA);
+    SNIPPETS_REGISTER_PASS(PassPosition(Place::Before, "PropagatePrecision"), ov::intel_cpu::pass::BrgemmToBrgemmCPU);
+    SNIPPETS_REGISTER_PASS(PassPosition(Place::Before, "PropagatePrecision"), ov::intel_cpu::pass::SetBrgemmCPUBlockingParams);
+    SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineEnd), ov::intel_cpu::pass::RemoveConverts);
+    SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineEnd), ov::intel_cpu::pass::MulAddToFMA);
+#undef SNIPPETS_REGISTER_PASS
     ov::snippets::lowered::pass::PassPipeline control_flow_markup_pipeline;
     CPU_REGISTER_PASS_X64(control_flow_markup_pipeline, ov::intel_cpu::pass::BrgemmBlocking);
@@ -718,13 +725,10 @@ void Snippet::SnippetJitExecutor::generate(const jit_snippets_compile_args* jcp)
     ov::snippets::lowered::pass::PassPipeline control_flow_pipeline;
     CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::FuseLoadStoreConvert);
-    schedule = snippet_for_generation->generate(
-        pre_dialect,
-        post_dialect,
-        post_precision,
-        control_flow_markup_pipeline,
-        control_flow_pipeline,
-        reinterpret_cast<const void*>(jcp));
+    schedule = snippet_for_generation->generate(backend_passes,
+                                                control_flow_markup_pipeline,
+                                                control_flow_pipeline,
+                                                reinterpret_cast<const void*>(jcp));
 }
 bool Snippet::SnippetJitExecutor::schedule_created() {
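
Editorial note on ordering (derived from the Subgraph::data_flow_transformations changes above, not stated in the commit itself): the positioned passes registered here effectively reproduce the old pre_dialect/post_dialect/post_precision split. ConvertToSwishCPU (and, under BF16 enforcement, MatMulToBrgemm and EnforcePrecision) run at the very start of the data-flow pipeline; BrgemmToBrgemmCPU and SetBrgemmCPUBlockingParams are spliced in right before the common PropagatePrecision pass; RemoveConverts and MulAddToFMA run after the whole data-flow pipeline, since they are anchored at Place::PipelineEnd.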

View File

@@ -8,6 +8,7 @@
 #include <transformations/snippets/x64/op/fused_mul_add.hpp>
 #include "snippets/op/scalar.hpp"
 #include "lowering_utils.hpp"
+#include "snippets/pass_manager.hpp"
 namespace ov {
 namespace test {
@@ -117,6 +118,7 @@ public:
 protected:
     void SetUp() override {
+        using PassPosition = ov::snippets::pass::Manager::PassPosition;
         LoweringTests::SetUp();
         std::vector<PartialShape> inputShapes(3);
         size_t add_input_idx;
@@ -124,7 +126,9 @@ protected:
         const bool scalar_input = ov::shape_size(inputShapes[2].to_shape()) == 1;
         snippets_model = std::make_shared<EltwiseWithMulAddFunction>(inputShapes, add_input_idx, scalar_input);
-        cpu_manager.register_pass<ov::intel_cpu::pass::MulAddToFMA>();
+        // Note: this inserts MulAddToFMA at the end of the pipeline
+        backend_passes.emplace_back(PassPosition(PassPosition::Place::PipelineEnd),
+                                    std::make_shared<ov::intel_cpu::pass::MulAddToFMA>());
         std::vector<ov::Node::type_info_t> custom_opset{ov::intel_cpu::FusedMulAdd::get_type_info_static()};
         auto target_machine = std::make_shared<DummyTargetMachine>(custom_opset);
@@ -133,11 +137,16 @@ protected:
     std::shared_ptr<SnippetsFunctionBase> snippets_model;
     std::shared_ptr<ov::snippets::Generator> generator;
-    ov::pass::Manager cpu_manager;
+    std::vector<ov::snippets::pass::Manager::PositionedPass> backend_passes;
 };
 TEST_P(MulAddToFMATests, MulAddToFMATests) {
-    auto subgraph = getLoweredSubgraph(snippets_model->getOriginal(), master_shape, {}, {}, cpu_manager, {}, generator);
+    auto subgraph = getLoweredSubgraph(snippets_model->getOriginal(),
+                                       master_shape,
+                                       backend_passes,
+                                       {},
+                                       {},
+                                       generator);
     model = subgraph->body_ptr();
     model_ref = snippets_model->getLowered();
 }