Snippets pass manager (#18846)
This commit is contained in:
parent
38cad619af
commit
b0d917f0cb
@ -55,7 +55,7 @@ public:
|
||||
register_pass(pass);
|
||||
}
|
||||
|
||||
void run(lowered::LinearIR& linear_ir);
|
||||
void run(lowered::LinearIR& linear_ir) const;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<Pass>> m_passes;
|
||||
|
@ -10,7 +10,7 @@
|
||||
#include <openvino/op/util/sub_graph_base.hpp>
|
||||
#include "openvino/op/op.hpp"
|
||||
#include "openvino/core/rt_info.hpp"
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include "snippets/pass_manager.hpp"
|
||||
|
||||
#include "snippets/generator.hpp"
|
||||
|
||||
@ -26,8 +26,6 @@ namespace op {
|
||||
class Subgraph : public ov::op::util::SubGraphOp {
|
||||
public:
|
||||
OPENVINO_OP("Subgraph", "SnippetsOpset", ov::op::util::SubGraphOp);
|
||||
enum {DYNAMIC_DIMENSION = 0xffffffffffffffff};
|
||||
|
||||
// < 1, 42, 17, 15, 16> < 0, 1, 2, 3, 1>
|
||||
// should be:
|
||||
// A = < 1, 42, 17, 15> -> < 1, 3, 17, 15, 16> < 0, 1, 2, 3, 1>
|
||||
@ -74,9 +72,9 @@ public:
|
||||
|
||||
Subgraph() = default;
|
||||
|
||||
Subgraph(const OutputVector& args, std::shared_ptr<ov::Model> body);
|
||||
Subgraph(const OutputVector& args, const std::shared_ptr<ov::Model>& body);
|
||||
|
||||
Subgraph(const NodeVector& args, std::shared_ptr<ov::Model> body);
|
||||
Subgraph(const NodeVector& args, const std::shared_ptr<ov::Model>& body);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
@ -101,18 +99,14 @@ public:
|
||||
bool has_domain_sensitive_ops() const { return config.m_has_domain_sensitive_ops; }
|
||||
snippets::Schedule generate(const BlockedShapeVector& output_shapes,
|
||||
const BlockedShapeVector& input_shapes,
|
||||
ov::pass::Manager& pre_common,
|
||||
ov::pass::Manager& post_common,
|
||||
ov::pass::Manager& post_precision,
|
||||
lowered::pass::PassPipeline& target_lowered_markup_pipeline,
|
||||
lowered::pass::PassPipeline& target_lowered_pipeline,
|
||||
const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
|
||||
const lowered::pass::PassPipeline& control_flow_passes_pre_common,
|
||||
const lowered::pass::PassPipeline& control_flow_passes_post_common,
|
||||
const void* compile_params = nullptr);
|
||||
snippets::Schedule generate(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes, const void* compile_params = nullptr);
|
||||
snippets::Schedule generate(ov::pass::Manager& pre_common,
|
||||
ov::pass::Manager& post_common,
|
||||
ov::pass::Manager& post_precision,
|
||||
lowered::pass::PassPipeline& target_lowered_markup_pipeline,
|
||||
lowered::pass::PassPipeline& target_lowered_pipeline,
|
||||
snippets::Schedule generate(const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
|
||||
const lowered::pass::PassPipeline& control_flow_passes_pre_common,
|
||||
const lowered::pass::PassPipeline& control_flow_passes_post_common,
|
||||
const void* compile_params = nullptr);
|
||||
snippets::Schedule generate(const void* compile_params = nullptr);
|
||||
ov::PartialShape canonicalize(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes);
|
||||
@ -146,10 +140,10 @@ public:
|
||||
|
||||
private:
|
||||
void align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes);
|
||||
void data_flow_transformations(ov::pass::Manager& pre_common, ov::pass::Manager& post_common, ov::pass::Manager& post_precision);
|
||||
void data_flow_transformations(const std::vector<snippets::pass::Manager::PositionedPass>& backend_passes);
|
||||
void control_flow_transformations(lowered::LinearIR& linear_ir,
|
||||
lowered::pass::PassPipeline& target_markup_pipeline,
|
||||
lowered::pass::PassPipeline& target_pipeline);
|
||||
const lowered::pass::PassPipeline& backend_passes_pre_common,
|
||||
const lowered::pass::PassPipeline& backend_passes_post_common);
|
||||
void init_config();
|
||||
// Count of Subgraph virtual ports:
|
||||
// - Potential non-scalar Constants that will be created after some transformations (At the moment it's relevant only for FakeQuantize decomposition)
|
||||
|
82
src/common/snippets/include/snippets/pass_manager.hpp
Normal file
82
src/common/snippets/include/snippets/pass_manager.hpp
Normal file
@ -0,0 +1,82 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "openvino/pass/validate.hpp"
|
||||
#include <typeinfo>
|
||||
|
||||
|
||||
namespace ov {
|
||||
namespace snippets {
|
||||
namespace pass {
|
||||
/**
|
||||
* @brief Manager is like ov::pass::Manager, but allows to insert new passes at arbitrary places in the pipeline
|
||||
* @ingroup snippets
|
||||
*/
|
||||
class Manager : public ov::pass::Manager {
|
||||
public:
|
||||
~Manager() override = default;
|
||||
using PassBase = ov::pass::PassBase;
|
||||
using Validate = ov::pass::Validate;
|
||||
/**
|
||||
* @brief PassPosition describes a particular position in a transformation pipeline,
|
||||
* where a new transformation should be inserted.
|
||||
* @param pass_name name of the anchor pass, the new pass will be inserted before/after it.
|
||||
* Empty pass_name could mean either beginning or the end of the pipeline depending on the `after` flag.
|
||||
* No default value. Note that pass names namespaces are not supported, ov::PassName and snippets::PassName
|
||||
* are considered identical.
|
||||
* @param after `true` if the new pass should be inserted before the anchor pass, `false` otherwise (default).
|
||||
* If `pass_name` is empty, `true` means the end, and `false` - the beginning of the pipeline.
|
||||
* @param pass_instance the number of the pass with matching `pass_name` to be considered as the anchor pass.
|
||||
* 0 (default) means the first pass with `pass_name` will be considered as the anchor pass.
|
||||
* @ingroup snippets
|
||||
*/
|
||||
class PassPosition {
|
||||
public:
|
||||
enum class Place {Before, After, PipelineStart, PipelineEnd};
|
||||
using PassListType = std::vector<std::shared_ptr<ov::pass::PassBase>>;
|
||||
explicit PassPosition(Place pass_place);
|
||||
explicit PassPosition(Place pass_place, std::string pass_name, size_t pass_instance = 0);
|
||||
PassListType::const_iterator get_insert_position(const PassListType& pass_list) const;
|
||||
private:
|
||||
const std::string m_pass_name;
|
||||
const size_t m_pass_instance{0};
|
||||
const Place m_place{Place::Before};
|
||||
};
|
||||
struct PositionedPass {
|
||||
PassPosition position;
|
||||
std::shared_ptr<PassBase> pass;
|
||||
PositionedPass(PassPosition arg_pos, std::shared_ptr<PassBase> arg_pass)
|
||||
: position(std::move(arg_pos)), pass(std::move(arg_pass)) {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, class... Args>
|
||||
std::shared_ptr<T> register_pass(Args&&... args) {
|
||||
return ov::pass::Manager::register_pass<T>(args...);
|
||||
}
|
||||
template <typename T, class Pos, class... Args, std::enable_if<std::is_same<PassPosition, Pos>::value, bool>() = true>
|
||||
std::shared_ptr<T> register_pass(const PassPosition& position, Args&&... args) {
|
||||
static_assert(std::is_base_of<PassBase, T>::value, "Attempt to insert pass that is not derived from PassBase");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto rc = insert_pass_instance(position, pass);
|
||||
rc->set_pass_config(m_pass_config);
|
||||
if (!m_pass_config->is_enabled<T>()) {
|
||||
m_pass_config->disable<T>();
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
std::shared_ptr<PassBase> register_pass_instance(const PassPosition& pass_id, const std::shared_ptr<PassBase>& pass);
|
||||
void register_positioned_passes(const std::vector<PositionedPass>& pos_passes);
|
||||
|
||||
protected:
|
||||
std::shared_ptr<Manager::PassBase> insert_pass_instance(const PassPosition& position, const std::shared_ptr<PassBase>& pass);
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace snippets
|
||||
} // namespace ov
|
@ -14,7 +14,7 @@ void PassPipeline::register_pass(const std::shared_ptr<Pass>& pass) {
|
||||
m_passes.push_back(pass);
|
||||
}
|
||||
|
||||
void PassPipeline::run(LinearIR& linear_ir) {
|
||||
void PassPipeline::run(LinearIR& linear_ir) const {
|
||||
for (const auto& pass : m_passes) {
|
||||
pass->run(linear_ir);
|
||||
}
|
||||
|
@ -43,7 +43,7 @@
|
||||
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include "snippets/pass_manager.hpp"
|
||||
#include "ngraph/pass/constant_folding.hpp"
|
||||
#include "ov_ops/type_relaxed.hpp"
|
||||
#include <openvino/pass/serialize.hpp>
|
||||
@ -57,16 +57,17 @@ using namespace ov::op::util;
|
||||
|
||||
namespace ov {
|
||||
namespace snippets {
|
||||
namespace op {
|
||||
|
||||
void snippets::op::Subgraph::set_generator(std::shared_ptr<ov::snippets::Generator> generator) {
|
||||
void Subgraph::set_generator(std::shared_ptr<ov::snippets::Generator> generator) {
|
||||
m_generator = generator;
|
||||
}
|
||||
|
||||
void snippets::op::Subgraph::set_virtual_port_count(const size_t count) {
|
||||
void Subgraph::set_virtual_port_count(const size_t count) {
|
||||
m_virtual_port_count = count;
|
||||
}
|
||||
|
||||
auto snippets::op::Subgraph::is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bool {
|
||||
auto Subgraph::is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bool {
|
||||
return ov::is_type<ov::op::v1::Transpose>(op) ||
|
||||
ov::is_type<ov::op::v1::Softmax>(op) ||
|
||||
ov::is_type<ov::op::v8::Softmax>(op) ||
|
||||
@ -75,7 +76,7 @@ auto snippets::op::Subgraph::is_domain_sensitive_op(const std::shared_ptr<ov::No
|
||||
ov::is_type<ov::op::v3::Broadcast>(op); // the both input and broadcast shapes (the both - are inputs of op). Note: is used only in MHA pattern
|
||||
}
|
||||
|
||||
void snippets::op::Subgraph::init_config() {
|
||||
void Subgraph::init_config() {
|
||||
auto update = [](bool& flag, bool status) { flag = flag || status; };
|
||||
const auto ops = body_ptr()->get_ops();
|
||||
for (const auto& op : ops) {
|
||||
@ -84,7 +85,7 @@ void snippets::op::Subgraph::init_config() {
|
||||
}
|
||||
}
|
||||
|
||||
auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t {
|
||||
auto Subgraph::get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t {
|
||||
// The count of potential unique Buffers - it's hidden virtual ports as well
|
||||
// We should go through Subgraph and calculate potential non-inplace Buffers count.
|
||||
// These Buffers can be in 2 cases:
|
||||
@ -129,7 +130,9 @@ auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& op
|
||||
|
||||
const auto consumers = matmul->get_output_target_inputs(0);
|
||||
if (std::none_of(consumers.begin(), consumers.end(),
|
||||
[](const ov::Input<ov::Node>& in) { return ov::is_type<ov::op::v0::Result>(in.get_node()); })) {
|
||||
[](const ov::Input<ov::Node>& in) {
|
||||
return ov::is_type<ov::op::v0::Result>(in.get_node());
|
||||
})) {
|
||||
used_precision_size.push_back(matmul->get_element_type().size());
|
||||
}
|
||||
}
|
||||
@ -138,9 +141,9 @@ auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& op
|
||||
return used_precision_size.size();
|
||||
}
|
||||
|
||||
snippets::op::Subgraph::Subgraph(const OutputVector& args, std::shared_ptr<ov::Model> body)
|
||||
: SubGraphOp(args), m_generator(nullptr) {
|
||||
set_function(body);
|
||||
Subgraph::Subgraph(const OutputVector& args, const std::shared_ptr<ov::Model>& body)
|
||||
: SubGraphOp(args), m_generator(nullptr) {
|
||||
SubGraphOp::set_function(body);
|
||||
init_config();
|
||||
constructor_validate_and_infer_types();
|
||||
for (size_t i = 0; i < body->get_parameters().size(); ++i)
|
||||
@ -150,15 +153,15 @@ snippets::op::Subgraph::Subgraph(const OutputVector& args, std::shared_ptr<ov::M
|
||||
m_transformations_allowed = false;
|
||||
}
|
||||
|
||||
snippets::op::Subgraph::Subgraph(const NodeVector& args, std::shared_ptr<ov::Model> body)
|
||||
: Subgraph(as_output_vector(args), std::move(body)) {}
|
||||
Subgraph::Subgraph(const NodeVector& args, const std::shared_ptr<ov::Model>& body)
|
||||
: Subgraph(as_output_vector(args), body) {}
|
||||
|
||||
std::shared_ptr<Node> snippets::op::Subgraph::clone_with_new_inputs(const OutputVector& inputs) const {
|
||||
std::shared_ptr<Node> Subgraph::clone_with_new_inputs(const OutputVector& inputs) const {
|
||||
INTERNAL_OP_SCOPE(Subgraph);
|
||||
return make_shared<Subgraph>(inputs, body().clone());
|
||||
}
|
||||
|
||||
std::vector<PartialShape> snippets::op::Subgraph::reshape_body(const std::vector<PartialShape>& input_shapes) {
|
||||
std::vector<PartialShape> Subgraph::reshape_body(const std::vector<PartialShape>& input_shapes) {
|
||||
auto& params = body_ptr()->get_parameters();
|
||||
OPENVINO_ASSERT(params.size() == input_shapes.size(), "Got invalid number of input shapes to reshape subgraph body");
|
||||
for (size_t i = 0; i < params.size(); ++i) {
|
||||
@ -172,7 +175,7 @@ std::vector<PartialShape> snippets::op::Subgraph::reshape_body(const std::vector
|
||||
return output_shapes;
|
||||
}
|
||||
|
||||
std::vector<Shape> snippets::op::Subgraph::reshape_body(const std::vector<Shape>& input_shapes) {
|
||||
std::vector<Shape> Subgraph::reshape_body(const std::vector<Shape>& input_shapes) {
|
||||
auto& params = body_ptr()->get_parameters();
|
||||
OPENVINO_ASSERT(params.size() == input_shapes.size(), "Got invalid number of input shapes to reshape subgraph body");
|
||||
for (size_t i = 0; i < params.size(); ++i) {
|
||||
@ -188,7 +191,7 @@ std::vector<Shape> snippets::op::Subgraph::reshape_body(const std::vector<Shape>
|
||||
return output_shapes;
|
||||
}
|
||||
|
||||
void snippets::op::Subgraph::validate_and_infer_types() {
|
||||
void Subgraph::validate_and_infer_types() {
|
||||
INTERNAL_OP_SCOPE(Subgraph);
|
||||
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::validate_and_infer_types")
|
||||
ov::ParameterVector old_parameters;
|
||||
@ -212,14 +215,14 @@ void snippets::op::Subgraph::validate_and_infer_types() {
|
||||
}
|
||||
}
|
||||
|
||||
bool snippets::op::Subgraph::visit_attributes(AttributeVisitor& visitor) {
|
||||
bool Subgraph::visit_attributes(AttributeVisitor& visitor) {
|
||||
visitor.on_attribute("body", body_ptr());
|
||||
visitor.on_attribute("input_descriptions", m_input_descriptions[0]);
|
||||
visitor.on_attribute("output_descriptions", m_output_descriptions[0]);
|
||||
return true;
|
||||
}
|
||||
|
||||
auto snippets::op::Subgraph::wrap_node_as_subgraph(const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<op::Subgraph> {
|
||||
auto Subgraph::wrap_node_as_subgraph(const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<op::Subgraph> {
|
||||
INTERNAL_OP_SCOPE(Subgraph);
|
||||
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::wrap_node_as_subgraph")
|
||||
ov::ParameterVector body_parameters;
|
||||
@ -278,7 +281,7 @@ auto snippets::op::Subgraph::wrap_node_as_subgraph(const std::shared_ptr<ov::Nod
|
||||
return subgraph;
|
||||
}
|
||||
|
||||
void snippets::op::Subgraph::fill_empty_output_names(const Output<Node>& target_output_node, const Output<Node>& replacement_output_node) {
|
||||
void Subgraph::fill_empty_output_names(const Output<Node>& target_output_node, const Output<Node>& replacement_output_node) {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
auto& out_tensor = target_output_node.get_tensor();
|
||||
const std::string new_name = ov::op::util::get_ie_output_name(replacement_output_node);
|
||||
@ -291,7 +294,7 @@ void snippets::op::Subgraph::fill_empty_output_names(const Output<Node>& target_
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
auto snippets::op::Subgraph::constant_input_should_be_inside_body(const std::shared_ptr<ov::Node>& node) -> bool {
|
||||
auto Subgraph::constant_input_should_be_inside_body(const std::shared_ptr<ov::Node>& node) -> bool {
|
||||
return ov::is_type<ov::op::v1::Transpose>(node) ||
|
||||
ov::is_type<ov::op::v1::Broadcast>(node) ||
|
||||
ov::is_type<ov::op::v3::Broadcast>(node) ||
|
||||
@ -311,16 +314,18 @@ ov::PartialShape snippets::op::Subgraph::canonicalize(const BlockedShapeVector&
|
||||
INTERNAL_OP_SCOPE(Subgraph);
|
||||
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::canonicalize")
|
||||
NODE_VALIDATION_CHECK(this, inputShapes.size() == body_ptr()->get_parameters().size(),
|
||||
"Number of parameters for snippet doesn't match passed to generate method: ", inputShapes.size(), " vs ", body_ptr()->get_parameters().size(), ".");
|
||||
"Number of parameters for snippet doesn't match passed to generate method: ",
|
||||
inputShapes.size(), " vs ", body_ptr()->get_parameters().size(), ".");
|
||||
|
||||
NODE_VALIDATION_CHECK(this, outputShapes.size() == body_ptr()->get_results().size(),
|
||||
"number of results for snippet doesn't match passed to generate method: ", outputShapes.size(), " vs ", body_ptr()->get_results().size(), ".");
|
||||
"number of results for snippet doesn't match passed to generate method: ",
|
||||
outputShapes.size(), " vs ", body_ptr()->get_results().size(), ".");
|
||||
|
||||
auto getMaxRankBlockedShape = [](const BlockedShapeVector& blockedShapes) -> const BlockedShape& {
|
||||
return *std::max_element(blockedShapes.begin(), blockedShapes.end(),
|
||||
[&](const BlockedShape& lhs, const BlockedShape& rhs) {
|
||||
return std::get<0>(lhs).size() < std::get<0>(rhs).size();
|
||||
});
|
||||
[&](const BlockedShape& lhs, const BlockedShape& rhs) {
|
||||
return std::get<0>(lhs).size() < std::get<0>(rhs).size();
|
||||
});
|
||||
};
|
||||
PartialShape baseShape;
|
||||
AxisVector baseOrder;
|
||||
@ -362,9 +367,9 @@ ov::PartialShape snippets::op::Subgraph::canonicalize(const BlockedShapeVector&
|
||||
PartialShape::broadcast_merge_into(tmpPShape, inShape, ::ov::op::AutoBroadcastType::NUMPY),
|
||||
"Failed to create broadcastable shapes in snippets canonicalization");
|
||||
const auto paramShape = body_ptr()->get_parameters()[i]->get_partial_shape();
|
||||
const auto paramType = body_ptr()->get_parameters()[i]->get_element_type();
|
||||
const auto paramType = body_ptr()->get_parameters()[i]->get_element_type();
|
||||
if (paramShape.size() != inShape.size() || !equal(paramShape.begin(), paramShape.end(), inShape.begin()))
|
||||
body_ptr()->replace_parameter(i, std::make_shared<ov::op::v0::Parameter>(paramType, inShape));
|
||||
body_ptr()->replace_parameter(i, std::make_shared<ov::op::v0::Parameter>(paramType, inShape));
|
||||
}
|
||||
body_ptr()->validate_nodes_and_infer_types();
|
||||
|
||||
@ -373,10 +378,10 @@ ov::PartialShape snippets::op::Subgraph::canonicalize(const BlockedShapeVector&
|
||||
auto end = shape.end();
|
||||
while (begin != end && *begin == 1)
|
||||
begin++;
|
||||
while (begin != end && *(end-1) == 1)
|
||||
while (begin != end && *(end - 1) == 1)
|
||||
end--;
|
||||
|
||||
PartialShape trimmedShape(std::vector<ov::Dimension> (end - begin, 1));
|
||||
PartialShape trimmedShape(std::vector<ov::Dimension>(end - begin, 1));
|
||||
std::copy(begin, end, trimmedShape.begin());
|
||||
return trimmedShape;
|
||||
};
|
||||
@ -456,7 +461,7 @@ ov::PartialShape snippets::op::Subgraph::canonicalized_body_shape_infer(const Bl
|
||||
return master_shape;
|
||||
}
|
||||
|
||||
bool snippets::op::Subgraph::check_broadcast(const std::shared_ptr<const ov::Node>& node) noexcept {
|
||||
bool Subgraph::check_broadcast(const std::shared_ptr<const ov::Node>& node) noexcept {
|
||||
const auto elementwise = std::dynamic_pointer_cast<const ov::op::util::BinaryElementwiseArithmetic>(node);
|
||||
return
|
||||
(elementwise == nullptr) ||
|
||||
@ -464,8 +469,8 @@ bool snippets::op::Subgraph::check_broadcast(const std::shared_ptr<const ov::Nod
|
||||
(elementwise->get_autob().m_type != ov::op::AutoBroadcastType::PDPD);
|
||||
}
|
||||
|
||||
void snippets::op::Subgraph::align_element_types(const BlockedShapeVector& outputShapes,
|
||||
const BlockedShapeVector& inputShapes) {
|
||||
void Subgraph::align_element_types(const BlockedShapeVector& outputShapes,
|
||||
const BlockedShapeVector& inputShapes) {
|
||||
// We should insert Convert before Results to set original output element type if needed
|
||||
const auto& body_results = body_ptr()->get_results();
|
||||
for (size_t i = 0; i < outputShapes.size(); i++) {
|
||||
@ -534,54 +539,45 @@ void snippets::op::Subgraph::align_element_types(const BlockedShapeVector& outpu
|
||||
}
|
||||
}
|
||||
|
||||
void snippets::op::Subgraph::data_flow_transformations(ov::pass::Manager& pre_common,
|
||||
ov::pass::Manager& post_common,
|
||||
ov::pass::Manager& post_precision) {
|
||||
void Subgraph::data_flow_transformations(const std::vector<snippets::pass::Manager::PositionedPass>& backend_passes) {
|
||||
INTERNAL_OP_SCOPE(Subgraph);
|
||||
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::data_flow_transformations")
|
||||
|
||||
const auto& params = body_ptr()->get_parameters();
|
||||
const auto& params = body_ptr()->get_parameters();
|
||||
bool inputs_has_dynamic_last_dims = std::any_of(params.begin(), params.end(),
|
||||
[](const shared_ptr<ov::op::v0::Parameter>& p) {
|
||||
return p->get_partial_shape().rbegin()->is_dynamic();
|
||||
});
|
||||
|
||||
pre_common.run_passes(body_ptr());
|
||||
|
||||
ov::pass::Manager common_manager;
|
||||
snippets::pass::Manager manager;
|
||||
if (config.m_has_domain_sensitive_ops) {
|
||||
common_manager.register_pass<snippets::pass::MatMulToBrgemm>();
|
||||
common_manager.register_pass<snippets::pass::FuseTransposeBrgemm>();
|
||||
common_manager.register_pass<snippets::pass::TransposeDecomposition>();
|
||||
common_manager.register_pass<snippets::pass::SetSoftmaxPorts>();
|
||||
manager.register_pass<snippets::pass::MatMulToBrgemm>();
|
||||
manager.register_pass<snippets::pass::FuseTransposeBrgemm>();
|
||||
manager.register_pass<snippets::pass::TransposeDecomposition>();
|
||||
manager.register_pass<snippets::pass::SetSoftmaxPorts>();
|
||||
}
|
||||
common_manager.register_pass<snippets::pass::BroadcastToMoveBroadcast>();
|
||||
common_manager.register_pass<snippets::pass::ConvertConstantsToScalars>();
|
||||
common_manager.register_pass<snippets::pass::ConvertPowerToPowerStatic>();
|
||||
manager.register_pass<snippets::pass::BroadcastToMoveBroadcast>();
|
||||
manager.register_pass<snippets::pass::ConvertConstantsToScalars>();
|
||||
manager.register_pass<snippets::pass::ConvertPowerToPowerStatic>();
|
||||
// todo: presently dynamic pipeline is activated even if the last two dimension are static
|
||||
// In general, we can use static kernels in this case, but several parameters (src and dst memory pointers for example)
|
||||
// should be passed as run-time args, so it's a mixed mode: kernel is shape-aware, but some additional runtime args are required
|
||||
// Presently Broadcasting is organized in the following way:
|
||||
// * ALL last dims are static => broadcasting is handled via MoveBroadcast and pointer arithmetics (even for dynamic upper dims)
|
||||
if (!inputs_has_dynamic_last_dims) {
|
||||
common_manager.register_pass<snippets::pass::InsertMoveBroadcast>();
|
||||
manager.register_pass<snippets::pass::InsertMoveBroadcast>();
|
||||
}
|
||||
common_manager.run_passes(body_ptr());
|
||||
|
||||
post_common.run_passes(body_ptr());
|
||||
manager.register_pass<snippets::pass::PropagatePrecision>(m_generator->get_target_machine());
|
||||
manager.register_pass<ov::pass::ConstantFolding>();
|
||||
manager.register_pass<snippets::pass::ConvertConstantsToScalars>();
|
||||
|
||||
ov::pass::Manager precision_manager;
|
||||
precision_manager.register_pass<snippets::pass::PropagatePrecision>(m_generator->get_target_machine());
|
||||
precision_manager.register_pass<ov::pass::ConstantFolding>();
|
||||
precision_manager.register_pass<snippets::pass::ConvertConstantsToScalars>();
|
||||
precision_manager.run_passes(body_ptr());
|
||||
|
||||
post_precision.run_passes(body_ptr());
|
||||
manager.register_positioned_passes(backend_passes);
|
||||
manager.run_passes(body_ptr());
|
||||
}
|
||||
|
||||
void snippets::op::Subgraph::control_flow_transformations(lowered::LinearIR& linear_ir,
|
||||
lowered::pass::PassPipeline& target_markup_pipeline,
|
||||
lowered::pass::PassPipeline& target_pipeline) {
|
||||
void Subgraph::control_flow_transformations(lowered::LinearIR& linear_ir,
|
||||
const lowered::pass::PassPipeline& backend_passes_pre_common,
|
||||
const lowered::pass::PassPipeline& backend_passes_post_common) {
|
||||
INTERNAL_OP_SCOPE(Subgraph);
|
||||
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::control_flow_transformations")
|
||||
|
||||
@ -590,7 +586,7 @@ void snippets::op::Subgraph::control_flow_transformations(lowered::LinearIR& lin
|
||||
|
||||
// Ticket: 113666
|
||||
// TODO: Make pass pipeline with backend passes more flexible
|
||||
target_markup_pipeline.run(linear_ir);
|
||||
backend_passes_pre_common.run(linear_ir);
|
||||
|
||||
lowered::pass::PassPipeline common_pipeline;
|
||||
common_pipeline.register_pass<lowered::pass::MarkLoops>(vector_size);
|
||||
@ -607,7 +603,7 @@ void snippets::op::Subgraph::control_flow_transformations(lowered::LinearIR& lin
|
||||
common_pipeline.register_pass<lowered::pass::InsertLoops>();
|
||||
common_pipeline.run(linear_ir);
|
||||
|
||||
target_pipeline.run(linear_ir);
|
||||
backend_passes_post_common.run(linear_ir);
|
||||
|
||||
const auto buffer_allocation_pass = std::make_shared<lowered::pass::AllocateBuffers>();
|
||||
lowered::pass::PassPipeline buffer_pipeline;
|
||||
@ -624,43 +620,36 @@ void snippets::op::Subgraph::control_flow_transformations(lowered::LinearIR& lin
|
||||
m_buffer_scratchpad = buffer_allocation_pass->get_scratchpad_size();
|
||||
}
|
||||
|
||||
snippets::Schedule snippets::op::Subgraph::generate(const BlockedShapeVector& output_shapes,
|
||||
const BlockedShapeVector& input_shapes,
|
||||
const void* compile_params) {
|
||||
snippets::Schedule Subgraph::generate(const BlockedShapeVector& output_shapes,
|
||||
const BlockedShapeVector& input_shapes,
|
||||
const void* compile_params) {
|
||||
canonicalize(output_shapes, input_shapes);
|
||||
return generate(compile_params);
|
||||
}
|
||||
|
||||
snippets::Schedule snippets::op::Subgraph::generate(const BlockedShapeVector& output_shapes,
|
||||
const BlockedShapeVector& input_shapes,
|
||||
ov::pass::Manager& pre_common,
|
||||
ov::pass::Manager& post_common,
|
||||
ov::pass::Manager& post_precision,
|
||||
lowered::pass::PassPipeline& target_lowered_markup_pipeline,
|
||||
lowered::pass::PassPipeline& target_lowered_pipeline,
|
||||
const void* compile_params) {
|
||||
snippets::Schedule Subgraph::generate(const BlockedShapeVector& output_shapes,
|
||||
const BlockedShapeVector& input_shapes,
|
||||
const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
|
||||
const lowered::pass::PassPipeline& control_flow_passes_pre_common,
|
||||
const lowered::pass::PassPipeline& control_flow_passes_post_common,
|
||||
const void* compile_params) {
|
||||
canonicalize(output_shapes, input_shapes);
|
||||
return generate(pre_common, post_common, post_precision, target_lowered_markup_pipeline, target_lowered_pipeline, compile_params);
|
||||
return generate(data_flow_passes, control_flow_passes_pre_common, control_flow_passes_post_common, compile_params);
|
||||
}
|
||||
|
||||
snippets::Schedule snippets::op::Subgraph::generate(const void* compile_params) {
|
||||
auto mngr = ov::pass::Manager();
|
||||
auto lowered = lowered::pass::PassPipeline();
|
||||
return generate(mngr, mngr, mngr, lowered, lowered, compile_params);
|
||||
snippets::Schedule Subgraph::generate(const void* compile_params) {
|
||||
return generate({}, {}, {}, compile_params);
|
||||
}
|
||||
|
||||
snippets::Schedule snippets::op::Subgraph::generate(
|
||||
ov::pass::Manager& pre_common,
|
||||
ov::pass::Manager& post_common,
|
||||
ov::pass::Manager& post_precision,
|
||||
lowered::pass::PassPipeline& target_lowered_markup_pipeline,
|
||||
lowered::pass::PassPipeline& target_lowered_pipeline,
|
||||
const void* compile_params) {
|
||||
snippets::Schedule Subgraph::generate(const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
|
||||
const lowered::pass::PassPipeline& control_flow_passes_pre_common,
|
||||
const lowered::pass::PassPipeline& control_flow_passes_post_common,
|
||||
const void* compile_params) {
|
||||
INTERNAL_OP_SCOPE(Subgraph);
|
||||
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::generate")
|
||||
NGRAPH_CHECK(m_generator != nullptr, "generate is called while generator is not set");
|
||||
|
||||
data_flow_transformations(pre_common, post_common, post_precision);
|
||||
data_flow_transformations(data_flow_passes);
|
||||
|
||||
lowered::Config lowering_config;
|
||||
lowering_config.m_save_expressions = config.m_has_domain_sensitive_ops;
|
||||
@ -668,7 +657,7 @@ snippets::Schedule snippets::op::Subgraph::generate(
|
||||
lowering_config.m_loop_depth = tileRank;
|
||||
|
||||
lowered::LinearIR linear_ir = lowered::LinearIR(body_ptr(), lowering_config);
|
||||
control_flow_transformations(linear_ir, target_lowered_markup_pipeline, target_lowered_pipeline);
|
||||
control_flow_transformations(linear_ir, control_flow_passes_pre_common, control_flow_passes_post_common);
|
||||
|
||||
// actual code emission
|
||||
const auto& lowering_result = m_generator->generate(linear_ir, lowering_config, compile_params);
|
||||
@ -677,31 +666,32 @@ snippets::Schedule snippets::op::Subgraph::generate(
|
||||
return {master_shape, false /*canBeLinearized*/, ptr};
|
||||
}
|
||||
|
||||
void snippets::op::Subgraph::print() const {
|
||||
void Subgraph::print() const {
|
||||
INTERNAL_OP_SCOPE(Subgraph);
|
||||
remark(13) << "subgraph " << this->get_friendly_name() << " "
|
||||
<< this->get_type_name()
|
||||
<< " which contains " << body_ptr()->get_ops().size() << " nodes" << std::endl;
|
||||
<< this->get_type_name()
|
||||
<< " which contains " << body_ptr()->get_ops().size() << " nodes" << std::endl;
|
||||
|
||||
int qqq = 0;
|
||||
for (auto op : body_ptr()->get_ordered_ops()) {
|
||||
remark(13) << "op " << qqq++ << " " << op->get_friendly_name() << " (" << op->get_type_name() << ") " << op << std::endl;
|
||||
for (const auto& op : body_ptr()->get_ordered_ops()) {
|
||||
remark(13) << "op " << qqq++ << " " << op->get_friendly_name() << " (" << op->get_type_name() << ") " << op
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
for (auto& in : this->inputs()) {
|
||||
remark(13) << " -> " << in.get_source_output().get_node_shared_ptr()->get_friendly_name() << " "
|
||||
<< in.get_source_output().get_node_shared_ptr() << std::endl;
|
||||
<< in.get_source_output().get_node_shared_ptr() << std::endl;
|
||||
}
|
||||
|
||||
for (auto& out : this->outputs()) {
|
||||
for (auto& user : out.get_target_inputs()) {
|
||||
remark(13) << " <- " << user.get_node()->get_friendly_name() << " " << user.get_node() << std::endl;
|
||||
remark(13) << " <- " << user.get_node()->get_friendly_name() << " " << user.get_node() << std::endl;
|
||||
}
|
||||
remark(13) << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void snippets::op::Subgraph::print_statistics(bool verbose) {
|
||||
void Subgraph::print_statistics(bool verbose) {
|
||||
INTERNAL_OP_SCOPE(Subgraph);
|
||||
auto getNodeInventory = [](std::shared_ptr<ov::Node> n) -> size_t {
|
||||
size_t total = 0;
|
||||
@ -725,22 +715,22 @@ void snippets::op::Subgraph::print_statistics(bool verbose) {
|
||||
return total;
|
||||
};
|
||||
|
||||
auto getModelInventory = [getNodeInventory](const ov::Model & f) -> size_t {
|
||||
auto getModelInventory = [getNodeInventory](const ov::Model& f) -> size_t {
|
||||
size_t total = 0;
|
||||
for (auto op : f.get_ordered_ops()) {
|
||||
// Results and parameters are artificially introduced,
|
||||
// while Constants are already considered if they are inputs of other operation
|
||||
// this should lead to 1:1 inventory for single node operations
|
||||
if (!ov::as_type_ptr<ov::opset1::Parameter>(op)
|
||||
&& !ov::as_type_ptr<ov::opset1::Result>(op)
|
||||
&& !ov::as_type_ptr<ov::opset1::Constant>(op)) {
|
||||
&& !ov::as_type_ptr<ov::opset1::Result>(op)
|
||||
&& !ov::as_type_ptr<ov::opset1::Constant>(op)) {
|
||||
total += getNodeInventory(op);
|
||||
}
|
||||
}
|
||||
return total;
|
||||
};
|
||||
|
||||
auto countConstants = [](const ov::Model & f) -> size_t {
|
||||
auto countConstants = [](const ov::Model& f) -> size_t {
|
||||
size_t count = 0;
|
||||
for (auto op : f.get_ordered_ops()) {
|
||||
count += !!ov::as_type_ptr<ov::opset1::Constant>(op) ? 1 : 0;
|
||||
@ -762,7 +752,7 @@ void snippets::op::Subgraph::print_statistics(bool verbose) {
|
||||
}
|
||||
}
|
||||
|
||||
void snippets::op::Subgraph::serialize() const {
|
||||
void Subgraph::serialize() const {
|
||||
std::stringstream xmlFile, binFile;
|
||||
ov::pass::Serialize serializer(xmlFile, xmlFile, ov::pass::Serialize::Version::IR_V10);
|
||||
serializer.run_on_model(body_ptr());
|
||||
@ -771,5 +761,6 @@ void snippets::op::Subgraph::serialize() const {
|
||||
std::cout << m_model << std::endl;
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace snippets
|
||||
} // namespace ov
|
||||
|
81
src/common/snippets/src/pass_manager.cpp
Normal file
81
src/common/snippets/src/pass_manager.cpp
Normal file
@ -0,0 +1,81 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "snippets/pass_manager.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace snippets {
|
||||
namespace pass {
|
||||
Manager::PassPosition::PassPosition(Place pass_place) : m_place(pass_place) {
|
||||
OPENVINO_ASSERT(m_place == Place::PipelineStart || m_place == Place::PipelineEnd,
|
||||
"Invalid arg: pass_name and pass_instance args could be omitted only for Place::PipelineStart/Place::PipelineEnd");
|
||||
}
|
||||
Manager::PassPosition::PassPosition(Place pass_place, std::string pass_name, size_t pass_instance)
|
||||
: m_pass_name(std::move(pass_name)), m_pass_instance(pass_instance), m_place(pass_place) {
|
||||
OPENVINO_ASSERT((m_place == Place::Before || m_place == Place::After) && !m_pass_name.empty(),
|
||||
"Invalid args combination: pass_place must be Place::Before/Place::After and pass_name must be non-empty");
|
||||
}
|
||||
|
||||
Manager::PassPosition::PassListType::const_iterator
|
||||
Manager::PassPosition::get_insert_position(const PassListType& pass_list) const {
|
||||
size_t pass_count = 0;
|
||||
auto match = [this, &pass_count](const std::shared_ptr<PassBase>& p) {
|
||||
auto name = p->get_name();
|
||||
// Note that MatcherPass and ModelPass currently have different naming policies:
|
||||
// - MatcherPass have names without namespaces, e.g. ConvertToSwishCPU
|
||||
// - Similar ModelPass name includes namespaces: ov::snippets::pass::ConvertToSwishCPU
|
||||
// So we have to remove everything before the last ':', and ':' itself
|
||||
if (name.size() > m_pass_name.size()) {
|
||||
const auto pos = name.find_last_of(':');
|
||||
if (pos == std::string::npos)
|
||||
return false;
|
||||
name = name.substr(pos + 1);
|
||||
}
|
||||
if (name == m_pass_name) {
|
||||
if (m_pass_instance == pass_count)
|
||||
return true;
|
||||
pass_count++;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
switch (m_place) {
|
||||
case Place::PipelineStart: return pass_list.cbegin();
|
||||
case Place::PipelineEnd: return pass_list.cend();
|
||||
case Place::Before:
|
||||
case Place::After: {
|
||||
auto insert_it = std::find_if(pass_list.cbegin(), pass_list.cend(), match);
|
||||
OPENVINO_ASSERT(insert_it != pass_list.cend(), "snippets::pass::Manager failed to find pass ", m_pass_name);
|
||||
return m_place == Place::After ? std::next(insert_it) : insert_it;
|
||||
}
|
||||
default:
|
||||
OPENVINO_THROW("Unsupported Place type in PassPosition::get_insert_position");
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Manager::PassBase> Manager::register_pass_instance(const PassPosition& position,
|
||||
const std::shared_ptr<PassBase>& pass) {
|
||||
pass->set_pass_config(m_pass_config);
|
||||
return insert_pass_instance(position, pass);
|
||||
}
|
||||
|
||||
void Manager::register_positioned_passes(const std::vector<PositionedPass>& pos_passes) {
|
||||
for (const auto& pp : pos_passes)
|
||||
register_pass_instance(pp.position, pp.pass);
|
||||
}
|
||||
|
||||
std::shared_ptr<Manager::PassBase> Manager::insert_pass_instance(const PassPosition& position,
|
||||
const std::shared_ptr<PassBase>& pass) {
|
||||
auto insert_pos = position.get_insert_position(m_pass_list);
|
||||
insert_pos = m_pass_list.insert(insert_pos, pass);
|
||||
if (m_per_pass_validation) {
|
||||
// Note: insert_pos points to the newly inserted pass, so advance to validate the pass results
|
||||
std::advance(insert_pos, 1);
|
||||
m_pass_list.insert(insert_pos, std::make_shared<ov::pass::Validate>());
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
} // namespace pass
|
||||
}// namespace snippets
|
||||
}// namespace ov
|
@ -50,13 +50,13 @@ public:
|
||||
|
||||
protected:
|
||||
static std::shared_ptr<ov::snippets::op::Subgraph> getSubgraph(const std::shared_ptr<Model>& f);
|
||||
static std::shared_ptr<ov::snippets::op::Subgraph> getLoweredSubgraph(const std::shared_ptr<Model>& f,
|
||||
const ov::PartialShape& master_shape,
|
||||
ov::pass::Manager pre_dialect = {},
|
||||
ov::pass::Manager post_dialect = {},
|
||||
ov::pass::Manager post_precision = {},
|
||||
ov::snippets::lowered::pass::PassPipeline lowered_pipeline = {},
|
||||
const std::shared_ptr<ov::snippets::Generator> generator = nullptr);
|
||||
static std::shared_ptr<ov::snippets::op::Subgraph>
|
||||
getLoweredSubgraph(const std::shared_ptr<Model>& f,
|
||||
const ov::PartialShape& master_shape,
|
||||
const std::vector<ov::snippets::pass::Manager::PositionedPass>& backend_passes = {},
|
||||
const ov::snippets::lowered::pass::PassPipeline& lowered_pre_common = {},
|
||||
const ov::snippets::lowered::pass::PassPipeline& lowered_post_common = {},
|
||||
const std::shared_ptr<ov::snippets::Generator>& generator = nullptr);
|
||||
static std::shared_ptr<ov::snippets::op::Subgraph> getTokenizedSubgraph(const std::shared_ptr<Model>& f);
|
||||
ov::PartialShape master_shape{};
|
||||
};
|
||||
|
@ -99,35 +99,19 @@ std::shared_ptr<ov::snippets::op::Subgraph> LoweringTests::getSubgraph(const std
|
||||
return subgraph;
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::snippets::op::Subgraph> LoweringTests::getLoweredSubgraph(const std::shared_ptr<Model> &f,
|
||||
const ov::PartialShape& master_shape,
|
||||
ov::pass::Manager pre_dialect,
|
||||
ov::pass::Manager post_dialect,
|
||||
ov::pass::Manager post_precision,
|
||||
ov::snippets::lowered::pass::PassPipeline lowered_pipeline,
|
||||
const std::shared_ptr<ov::snippets::Generator> generator) {
|
||||
std::shared_ptr<ov::snippets::op::Subgraph>
|
||||
LoweringTests::getLoweredSubgraph(const std::shared_ptr<Model> &f,
|
||||
const ov::PartialShape& master_shape,
|
||||
const std::vector<ov::snippets::pass::Manager::PositionedPass>& backend_passes,
|
||||
const ov::snippets::lowered::pass::PassPipeline& lowered_pre_common,
|
||||
const ov::snippets::lowered::pass::PassPipeline& lowered_post_common,
|
||||
const std::shared_ptr<ov::snippets::Generator>& generator) {
|
||||
auto subgraph = getTokenizedSubgraph(f);
|
||||
subgraph->set_generator(generator == nullptr ? std::make_shared<DummyGenerator>() : generator);
|
||||
subgraph->set_master_shape(master_shape);
|
||||
const auto& body = subgraph->body_ptr();
|
||||
auto& body_rt_info = body->get_rt_info();
|
||||
// todo: insertLoops pass requires body_rt_info["PluginShapesOverride"] and subgraph->set_tile_rank to work normally
|
||||
// consider revising snippets-plugin shape and scheduling communication
|
||||
std::vector<std::vector<size_t>> new_shapes;
|
||||
for (const auto& p : body->get_parameters()) {
|
||||
const auto pshape = p->get_output_partial_shape(0);
|
||||
OPENVINO_ASSERT(pshape.is_static(), "getLoweredSubgraph supports only static shapes");
|
||||
new_shapes.push_back(pshape.get_shape());
|
||||
}
|
||||
for (const auto& r : body->get_results()) {
|
||||
const auto pshape = r->get_input_partial_shape(0);
|
||||
OPENVINO_ASSERT(pshape.is_static(), "getLoweredSubgraph supports only static shapes");
|
||||
new_shapes.push_back(pshape.get_shape());
|
||||
}
|
||||
body_rt_info["PluginShapesOverride"] = new_shapes;
|
||||
subgraph->set_tile_rank(2);
|
||||
ov::snippets::lowered::pass::PassPipeline empty_pipeline;
|
||||
subgraph->generate(pre_dialect, post_precision, post_precision, empty_pipeline, lowered_pipeline);
|
||||
// Note: lowered_pipeline would have no effect on subgraph body, since it's applied on linear IR
|
||||
subgraph->generate(backend_passes, lowered_pre_common, lowered_post_common);
|
||||
return subgraph;
|
||||
}
|
||||
|
||||
|
@ -21,7 +21,7 @@ namespace pass {
|
||||
class OPENVINO_API Manager {
|
||||
public:
|
||||
Manager();
|
||||
~Manager();
|
||||
virtual ~Manager();
|
||||
|
||||
//// \brief Construct Manager with shared PassConfig instance
|
||||
explicit Manager(std::shared_ptr<PassConfig> pass_config);
|
||||
|
@ -694,23 +694,30 @@ bool Snippet::SnippetJitExecutor::optimizeExecDomain(std::vector<VectorDims>& in
|
||||
}
|
||||
|
||||
void Snippet::SnippetJitExecutor::generate(const jit_snippets_compile_args* jcp) {
|
||||
ov::pass::Manager pre_dialect;
|
||||
pre_dialect.register_pass<ConvertToSwishCPU>();
|
||||
using Manager = snippets::pass::Manager;
|
||||
using PassPosition = snippets::pass::Manager::PassPosition;
|
||||
using Place = snippets::pass::Manager::PassPosition::Place;
|
||||
std::vector<Manager::PositionedPass> backend_passes;
|
||||
|
||||
#define SNIPPETS_REGISTER_PASS(PASS_POS, PASS, ...) \
|
||||
backend_passes.emplace_back(PASS_POS, std::make_shared<PASS>(__VA_ARGS__))
|
||||
|
||||
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineStart), ConvertToSwishCPU);
|
||||
if (enforceBF16 && snippet_for_generation->has_domain_sensitive_ops()) {
|
||||
// enforce BF16 precisions to supported operations
|
||||
// MatMul has to be decomposed to Brgemm operations before enforcement
|
||||
// Note, MatMul decomposition will be ran later again for case if BF16 enforcement is not happened
|
||||
CPU_REGISTER_PASS_X64(pre_dialect, ov::snippets::pass::MatMulToBrgemm);
|
||||
CPU_REGISTER_PASS_X64(pre_dialect, pass::EnforcePrecision, element::f32, element::bf16);
|
||||
// Note, MatMul decomposition will be run later again for case if BF16 enforcement is not happened
|
||||
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineStart), ov::snippets::pass::MatMulToBrgemm);
|
||||
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineStart), pass::EnforcePrecision, element::f32, element::bf16);
|
||||
}
|
||||
|
||||
ov::pass::Manager post_dialect;
|
||||
CPU_REGISTER_PASS_X64(post_dialect, ov::intel_cpu::pass::BrgemmToBrgemmCPU);
|
||||
CPU_REGISTER_PASS_X64(post_dialect, ov::intel_cpu::pass::SetBrgemmCPUBlockingParams);
|
||||
SNIPPETS_REGISTER_PASS(PassPosition(Place::Before, "PropagatePrecision"), ov::intel_cpu::pass::BrgemmToBrgemmCPU);
|
||||
SNIPPETS_REGISTER_PASS(PassPosition(Place::Before, "PropagatePrecision"), ov::intel_cpu::pass::SetBrgemmCPUBlockingParams);
|
||||
|
||||
ov::pass::Manager post_precision;
|
||||
CPU_REGISTER_PASS_X64(post_precision, ov::intel_cpu::pass::RemoveConverts);
|
||||
CPU_REGISTER_PASS_X64(post_precision, ov::intel_cpu::pass::MulAddToFMA);
|
||||
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineEnd), ov::intel_cpu::pass::RemoveConverts);
|
||||
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineEnd), ov::intel_cpu::pass::MulAddToFMA);
|
||||
|
||||
#undef SNIPPETS_REGISTER_PASS
|
||||
|
||||
ov::snippets::lowered::pass::PassPipeline control_flow_markup_pipeline;
|
||||
CPU_REGISTER_PASS_X64(control_flow_markup_pipeline, ov::intel_cpu::pass::BrgemmBlocking);
|
||||
@ -718,13 +725,10 @@ void Snippet::SnippetJitExecutor::generate(const jit_snippets_compile_args* jcp)
|
||||
ov::snippets::lowered::pass::PassPipeline control_flow_pipeline;
|
||||
CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::FuseLoadStoreConvert);
|
||||
|
||||
schedule = snippet_for_generation->generate(
|
||||
pre_dialect,
|
||||
post_dialect,
|
||||
post_precision,
|
||||
control_flow_markup_pipeline,
|
||||
control_flow_pipeline,
|
||||
reinterpret_cast<const void*>(jcp));
|
||||
schedule = snippet_for_generation->generate(backend_passes,
|
||||
control_flow_markup_pipeline,
|
||||
control_flow_pipeline,
|
||||
reinterpret_cast<const void*>(jcp));
|
||||
}
|
||||
|
||||
bool Snippet::SnippetJitExecutor::schedule_created() {
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <transformations/snippets/x64/op/fused_mul_add.hpp>
|
||||
#include "snippets/op/scalar.hpp"
|
||||
#include "lowering_utils.hpp"
|
||||
#include "snippets/pass_manager.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
@ -117,6 +118,7 @@ public:
|
||||
|
||||
protected:
|
||||
void SetUp() override {
|
||||
using PassPosition = ov::snippets::pass::Manager::PassPosition;
|
||||
LoweringTests::SetUp();
|
||||
std::vector<PartialShape> inputShapes(3);
|
||||
size_t add_input_idx;
|
||||
@ -124,7 +126,9 @@ protected:
|
||||
const bool scalar_input = ov::shape_size(inputShapes[2].to_shape()) == 1;
|
||||
snippets_model = std::make_shared<EltwiseWithMulAddFunction>(inputShapes, add_input_idx, scalar_input);
|
||||
|
||||
cpu_manager.register_pass<ov::intel_cpu::pass::MulAddToFMA>();
|
||||
// Note: this inserts MulAddToFMA at the end of the pipeline
|
||||
backend_passes.emplace_back(PassPosition(PassPosition::Place::PipelineEnd),
|
||||
std::make_shared<ov::intel_cpu::pass::MulAddToFMA>());
|
||||
|
||||
std::vector<ov::Node::type_info_t> custom_opset{ov::intel_cpu::FusedMulAdd::get_type_info_static()};
|
||||
auto target_machine = std::make_shared<DummyTargetMachine>(custom_opset);
|
||||
@ -133,11 +137,16 @@ protected:
|
||||
|
||||
std::shared_ptr<SnippetsFunctionBase> snippets_model;
|
||||
std::shared_ptr<ov::snippets::Generator> generator;
|
||||
ov::pass::Manager cpu_manager;
|
||||
std::vector<ov::snippets::pass::Manager::PositionedPass> backend_passes;
|
||||
};
|
||||
|
||||
TEST_P(MulAddToFMATests, MulAddToFMATests) {
|
||||
auto subgraph = getLoweredSubgraph(snippets_model->getOriginal(), master_shape, {}, {}, cpu_manager, {}, generator);
|
||||
auto subgraph = getLoweredSubgraph(snippets_model->getOriginal(),
|
||||
master_shape,
|
||||
backend_passes,
|
||||
{},
|
||||
{},
|
||||
generator);
|
||||
model = subgraph->body_ptr();
|
||||
model_ref = snippets_model->getLowered();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user