[Snippets] Dynamic pipeline reorganization (#18563)

Ivan Novoselov 2023-10-24 06:23:10 +01:00 committed by GitHub
parent 59fe0a05a4
commit bc82ba4419
53 changed files with 1316 additions and 825 deletions

View File

@ -12,7 +12,6 @@
namespace ov {
namespace snippets {
using code = const uint8_t *;
using RegInfo = std::pair<std::vector<size_t>, std::vector<size_t>>;
/**

View File

@ -11,12 +11,32 @@
#include "snippets_isa.hpp"
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/pass/pass.hpp"
#include "snippets/shape_types.hpp"
#include "target_machine.hpp"
namespace ov {
namespace snippets {
class Generator;
/**
* @interface LoweringResult
* @brief Holds all relevant information produced during lowering
* @param compiled_snippet pointer to interface class that encapsulates compiled binary code
* @param buffer_scratchpad_size the amount of additional memory required by the binary code to execute.
* Must be allocated and freed by the backend.
*/
class LoweringResult {
friend class Generator;
// Some emitters rely on other precompiled kernels.
// We need to keep the pointers to such emitters alive, so that the kernels are still accessible at runtime.
std::vector<std::shared_ptr<Emitter>> m_saved_emitters{};
public:
std::shared_ptr<CompiledSnippet> compiled_snippet = nullptr;
size_t buffer_scratchpad_size = 0;
};
/**
* @interface Schedule
* @brief Return scheduling information and pointer to generated kernel code
@ -26,20 +46,21 @@ class Schedule {
public:
Schedule() = default;
/**
* @brief Default to create schedule out of specific parameters
* @param wd work domain for kernel execution
* @param p pointer to generated code
* @brief Create schedule out of specific parameters
* @param domain work domain for kernel execution
* @param lr lowering result produced during code generation
*/
Schedule(const VectorDims& wd, code p) : parallel_exec_domain(wd), ptr(p) {}
Schedule(std::vector<size_t>&& domain, LoweringResult&& lr) : parallel_exec_domain(domain), lowering_result(lr) {}
Schedule(std::vector<size_t> domain, LoweringResult&& lr) : parallel_exec_domain(std::move(domain)), lowering_result(lr) {}
/**
* @brief Returns a callable instance of the code pointer
*/
template<typename K> K get_callable() const {
return reinterpret_cast<K>(const_cast<unsigned char*>(ptr));
return reinterpret_cast<K>(const_cast<unsigned char*>(lowering_result.compiled_snippet->get_code()));
}
VectorDims parallel_exec_domain {};
code ptr {nullptr};
LoweringResult lowering_result {};
};
/**
@ -52,7 +73,7 @@ public:
/**
* @brief Default constructor
*/
Generator(const std::shared_ptr<TargetMachine>& t) : target(t), lowered_saved{} {}
Generator(const std::shared_ptr<TargetMachine>& t) : target(t) {}
/**
* @brief Default destructor
*/
@ -62,17 +83,13 @@ public:
* @brief Allows tweaking the lowering process.
*/
/**
* @brief virtual method that any specific implementation should implement
* @param m model in canonical form for table-based code generation
* @param config config with transformation and optimization parameters
* @param compile_params parameters for generated code
* @return pointer to generated code
* @brief generates executable code
* @param linear_ir lowered IR for code generation
* @param result variable to handle the result; only the compiled_snippet and m_saved_emitters fields will be modified
* @param compile_params compile-time parameters used for code generation
* @return void
*/
struct LoweringResult {
LoweringResult(code c) : binary_code(c) {}
code binary_code = nullptr;
};
LoweringResult generate(lowered::LinearIR& linear_ir, const lowered::Config& config, const void* compile_params = nullptr);
void generate(lowered::LinearIR& linear_ir, LoweringResult& result, const void* compile_params = nullptr) const;
/**
* @brief gets target machine
@ -96,17 +113,21 @@ public:
*/
opRegType get_op_reg_type(const std::shared_ptr<Node>& op) const;
virtual std::shared_ptr<Generator> clone() const = 0;
protected:
/**
* @brief gets register type by specific plugin op type
* @return register type
*/
virtual opRegType get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const;
/**
* @brief returns true if an emitter can use precompiled kernel.
* @return bool
*/
virtual bool uses_precompiled_kernel(const std::shared_ptr<Emitter>& emitter) const { return false; }
std::shared_ptr<TargetMachine> target;
// todo: we need to save lowered code to access compiled brgemm kernels at execution time (normally lowered is destructed by then).
// This is a temporary solution; remove it when kernel caching is implemented. Don't forget to make generate a const method.
lowered::LinearIR lowered_saved;
};
} // namespace snippets
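A minimal usage sketch of the reworked generate flow (illustrative only: the kernel signature and call arguments are assumptions, not part of this change):

// Backend-side flow, assuming a prepared linear_ir and a target-specific generator.
LoweringResult result;
generator->generate(linear_ir, result);  // fills compiled_snippet (and m_saved_emitters internally)
Schedule schedule(linear_ir.get_master_shape(), std::move(result));
using KernelFn = void (*)(const void*);  // assumed kernel signature
KernelFn kernel = schedule.get_callable<KernelFn>();
kernel(nullptr);  // a real backend passes prepared call args here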

View File

@ -74,7 +74,6 @@ protected:
std::vector<size_t> m_loop_ids{};
std::shared_ptr<IShapeInferSnippets> m_shapeInference{nullptr};
};
using ExpressionPtr = std::shared_ptr<Expression>;
class IOExpression : public Expression {
friend class LinearIR;

View File

@ -27,6 +27,13 @@ public:
}
return create(n, params...);
}
template<class ExprType, typename std::enable_if<std::is_base_of<Expression, ExprType>::value, bool>::type = true>
static ExpressionPtr shallow_copy(const std::shared_ptr<ExprType>& expr) {
if (const auto& io_expr = std::dynamic_pointer_cast<IOExpression>(expr))
return std::make_shared<IOExpression>(*io_expr);
else
return std::make_shared<ExprType>(*expr);
}
private:
/* -- Default Builders - initialize input port connectors from parents and create new output port connectors themselves */

View File

@ -116,6 +116,7 @@ public:
IShapeInferSnippets::Result shape_infer(const std::vector<VectorDimsRef>& input_shapes);
const std::shared_ptr<ShapeInferSnippetsNode>& get_shape_infer_instance() const {return m_shape_infer; }
VectorDims get_master_shape() const;
LinearIR deep_copy() const;
private:
std::shared_ptr<ShapeInferSnippetsNode> m_shape_infer = nullptr;

View File

@ -0,0 +1,28 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "pass.hpp"
namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
/**
* @interface InsertBroadcastMove
* @brief Injects explicit BroadcastMove operations when the most varying dimension is broadcast
* @ingroup snippets
*/
class InsertBroadcastMove : public Pass {
public:
OPENVINO_RTTI("InsertBroadcastMove", "Pass")
bool run(LinearIR& linear_ir) override;
};
} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
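An assumed before/after illustration of the pass on a linear IR fragment (op names and shapes are illustrative):

// Before: the last dimensions disagree (32 vs 1), so the second input must be broadcast:
//   Multiply(a: [2, 17, 32], b: [2, 17, 1])
// After: an explicit BroadcastMove is injected between the size-1 producer and the consumer:
//   bm = BroadcastMove(b)
//   Multiply(a: [2, 17, 32], bm: [2, 17, 32])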

View File

@ -18,8 +18,8 @@ namespace pass {
*/
class SoftmaxDecomposition : public Pass {
public:
explicit SoftmaxDecomposition(size_t vector_size);
OPENVINO_RTTI("SoftmaxDecomposition", "Pass")
explicit SoftmaxDecomposition(size_t vector_size);
bool run(LinearIR& linear_ir) override;
private:

View File

@ -0,0 +1,31 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "pass.hpp"
#include "snippets/lowered/loop_manager.hpp"
namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
/**
* @interface ValidateShapes
* @brief The pass checks that there are no dynamic shapes in the IR
* @ingroup snippets
*/
class ValidateShapes : public Pass {
public:
OPENVINO_RTTI("ValidateShapes", "Pass")
ValidateShapes() = default;
bool run(LinearIR& linear_ir) override;
};
} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov

View File

@ -0,0 +1,54 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/op/op.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
namespace ov {
namespace snippets {
namespace op {
/**
* @interface RankNormalization
* @brief Generated by Canonicalization for rank normalization purposes. It can pad an input shape with several `1`s, but only at its first or last dimensions.
* @arg num_prepend - number of `1`s that will be inserted at the beginning of the input shape. Any value is allowed.
* @arg num_append - number of `1`s that will be inserted at the end of the input shape. Can be either 0 (default) or 1.
* @ingroup snippets
*/
// Note that technically the same goal could be achieved using the op::Unsqueeze operation,
// but RankNormalization has much narrower semantics, and hence allows easier control and more efficient shape inference.
//
class RankNormalization : public ov::op::Op {
public:
OPENVINO_OP("RankNormalization", "SnippetsOpset");
RankNormalization() = default;
RankNormalization(const Output<Node>& data, size_t num_prepend, size_t num_append);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
size_t get_num_append() const { return m_num_append; }
size_t get_num_prepend() const { return m_num_prepend; }
class ShapeInfer : public IShapeInferSnippets {
size_t m_num_prepend = 0;
size_t m_num_append = 0;
public:
explicit ShapeInfer(const std::shared_ptr<ov::Node>& n);
IShapeInferSnippets::Result
infer(const std::vector<VectorDimsRef>& input_shapes) override;
};
private:
size_t m_num_prepend = 0;
size_t m_num_append = 0;
};
} // namespace op
} // namespace snippets
} // namespace ov
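A worked example of the shape semantics, derived from the op's shape-infer logic (see rank_normalization.cpp further down):

// input shape: [3, 17], num_prepend = 2, num_append = 1
// output shape: [1, 1, 3, 17, 1]
// num_prepend `1`s are inserted before the shape; num_append (0 or 1) after it.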

View File

@ -12,6 +12,7 @@
#include "openvino/core/rt_info.hpp"
#include "snippets/pass_manager.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
#include "snippets/lowered/pass/pass.hpp"
#include "snippets/generator.hpp"
@ -68,7 +69,8 @@ public:
//
// D = < 1, 3, 17, 15, 32> < 0, 1, 2, 3, 4>
// E = < 1, 3, 17, 1, 32> < 0, 1, 2, 3, 4>
using BlockedShape = std::tuple<ov::PartialShape, ov::AxisVector, ov::element::Type>;
using Layout = std::vector<size_t>;
using BlockedShape = std::pair<VectorDims, Layout>;
using BlockedShapeVector = std::vector<BlockedShape>;
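Hypothetical values for the new representation (the duplicated axis index marks a blocked dimension, matching the blocked-layout detection used during canonicalization):

// planar NCHW:     BlockedShape{{1, 32, 17, 15},    {0, 1, 2, 3}};
// blocked nChw16c: BlockedShape{{1, 2, 17, 15, 16}, {0, 1, 2, 3, 1}};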
Subgraph() = default;
@ -94,43 +96,36 @@ public:
const std::shared_ptr<ov::snippets::Generator>& get_generator() const { return m_generator; }
std::shared_ptr<ov::snippets::Generator>& get_generator() { return m_generator; }
size_t get_buffer_scratchpad_size() const { return m_buffer_scratchpad; }
size_t get_virtual_port_count() const { return m_virtual_port_count; }
bool is_quantized() const { return config.m_is_quantized; }
bool has_domain_sensitive_ops() const { return config.m_has_domain_sensitive_ops; }
snippets::Schedule generate(const BlockedShapeVector& output_shapes,
const BlockedShapeVector& input_shapes,
const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
const lowered::pass::PassPipeline& control_flow_passes_pre_common,
const lowered::pass::PassPipeline& control_flow_passes_post_common,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory = nullptr,
const void* compile_params = nullptr);
snippets::Schedule generate(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes, const void* compile_params = nullptr);
snippets::Schedule generate(const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
const lowered::pass::PassPipeline& control_flow_passes_pre_common,
const lowered::pass::PassPipeline& control_flow_passes_post_common,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory = nullptr,
const void* compile_params = nullptr);
snippets::Schedule generate(const void* compile_params = nullptr);
ov::PartialShape canonicalize(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes);
ov::PartialShape canonicalized_body_shape_infer(const BlockedShapeVector& input_shapes);
std::vector<PartialShape> reshape_body(const std::vector<PartialShape>& input_shapes);
std::vector<Shape> reshape_body(const std::vector<Shape>& input_shapes);
snippets::Schedule generate(const BlockedShapeVector& blocked_input_shapes = {},
const std::vector<ov::element::Type>& input_precisions = {},
const std::vector<ov::element::Type>& output_precisions = {},
const std::vector<pass::Manager::PositionedPass>& data_flow_passes = {},
const lowered::pass::PassPipeline& control_flow_passes_pre_common = {},
const lowered::pass::PassPipeline& control_flow_passes_post_common = {},
const std::shared_ptr<IShapeInferSnippetsFactory>& factory = nullptr,
const void* compile_params = nullptr);
snippets::Schedule generate_from_linear_ir(const lowered::pass::PassPipeline& backend_passes_pre_common = {},
const lowered::pass::PassPipeline& backend_passes_post_common = {},
const void* compile_params = nullptr) const;
IShapeInferSnippets::Result shape_infer(const std::vector<VectorDimsRef>& input_shapes);
// The plugin sets the generator for a snippet to some specific generator.
// It is going to be replaced with a jitters table later.
void set_generator(std::shared_ptr<ov::snippets::Generator> generator);
void set_tile_rank(size_t newRank) {tileRank = newRank;}
void set_virtual_port_count(const size_t count);
void set_min_jit_work_amount(const size_t jit_work_amount);
void set_min_parallel_work_amount(const size_t parallel_work_amount);
void set_virtual_port_count(size_t count);
void set_min_jit_work_amount(size_t jit_work_amount);
void set_min_parallel_work_amount(size_t parallel_work_amount);
void print() const;
void serialize() const;
void set_master_shape(ov::PartialShape new_shape) {master_shape = std::move(new_shape);}
VectorDims infer_master_shape();
static auto wrap_node_as_subgraph(const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<Subgraph>;
static void fill_empty_output_names(const Output<Node>& target_output_node, const Output<Node>& replacement_output_node);
@ -143,28 +138,30 @@ public:
// Return estimated unique buffer count (upper bound). It's needed for tokenization
static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t;
static auto is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bool;
void data_flow_transformations(const BlockedShapeVector& blocked_input_shapes = {},
const std::vector<ov::element::Type>& input_precisions = {},
const std::vector<ov::element::Type>& output_precisions = {},
const std::vector<snippets::pass::Manager::PositionedPass>& = {});
std::shared_ptr<lowered::LinearIR>
convert_body_to_linear_ir(const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory = std::make_shared<IShapeInferSnippetsFactory>()) const;
convert_body_to_linear_ir(const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory = std::make_shared<IShapeInferSnippetsFactory>());
std::shared_ptr<Subgraph> clone() const;
private:
void align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes);
void data_flow_transformations(const std::vector<snippets::pass::Manager::PositionedPass>& backend_passes);
void control_flow_transformations(lowered::LinearIR& linear_ir,
LoweringResult& lowering_result,
const lowered::pass::PassPipeline& backend_passes_pre_common,
const lowered::pass::PassPipeline& backend_passes_post_common);
const lowered::pass::PassPipeline& backend_passes_post_common) const;
void init_config();
// Count of Subgraph virtual ports:
// - Potential non-scalar Constants that will be created after some transformations (At the moment it's relevant only for FakeQuantize decomposition)
// NOTE: To avoid overheads in each calculation of this count (for example, in validate_and_type_infer()),
// we should MANUALLY calculate it where it is needed.
size_t m_virtual_port_count = 0;
size_t m_buffer_scratchpad = 0lu;
Shape exec_domain = {};
std::shared_ptr<ov::snippets::Generator> m_generator = nullptr;
ov::PartialShape master_shape;
size_t tileRank = 0; // set by plugin to specify the number of dimensions processed in a single kernel call
size_t maxInputRank = 0;
std::vector<size_t> appendOnesForCanonical;
std::shared_ptr<lowered::LinearIR> m_linear_ir = nullptr;
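A plausible plugin-side call sequence for the consolidated generate API above (argument values are assumptions):

// First compilation: canonicalize to the blocked input shapes, align precisions, lower and emit.
auto schedule = subgraph->generate(blocked_input_shapes, input_precisions, output_precisions);
// Recompile from the cached LinearIR, e.g. with extra backend passes, skipping the data-flow stage.
auto schedule2 = subgraph->generate_from_linear_ir(backend_passes_pre_common, backend_passes_post_common);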

View File

@ -0,0 +1,34 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
#include "snippets/op/subgraph.hpp"
namespace ov {
namespace snippets {
namespace pass {
/**
* @interface AlignElementTypes
* @brief Align body precision with expected input/output precision. Insert op::ConvertSaturation if necessary.
* @ingroup snippets
*/
class AlignElementTypes: public ov::pass::ModelPass {
public:
OPENVINO_RTTI("AlignElementTypes");
AlignElementTypes(std::vector<ov::element::Type> input_precisions,
std::vector<ov::element::Type> output_precisions);
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
private:
std::vector<ov::element::Type> m_input_precisions;
std::vector<ov::element::Type> m_output_precisions;
};
} // namespace pass
} // namespace snippets
} // namespace ov

View File

@ -0,0 +1,39 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/shape_types.hpp"
namespace ov {
namespace snippets {
namespace pass {
/**
* @interface Canonicalization
* @brief Canonicalization inserts RankNormalization (an ov::op::Unsqueeze analogue) operations to account for:
* - input rank mismatch: inputs with smaller ranks are prepended with 1s
* - layout mismatch (only planar + blocked is supported): planar shapes are appended with 1
* @ingroup snippets
*/
class Canonicalization: public ov::pass::ModelPass {
public:
OPENVINO_RTTI("Canonicalization");
using BlockedShapeVector = op::Subgraph::BlockedShapeVector;
using Layout = std::vector<size_t>;
explicit Canonicalization(const BlockedShapeVector& blocked_input_shapes);
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
private:
std::vector<VectorDims> m_in_shapes;
std::vector<Layout> m_in_layouts;
bool m_has_dynamic_inputs = false;
};
} // namespace pass
} // namespace snippets
} // namespace ov
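An assumption-based sketch of the insertions this pass performs:

// Rank mismatch: inputs [17, 31] and [1, 3, 17, 31]
//   -> RankNormalization(num_prepend=2, num_append=0) on the first input yields [1, 1, 17, 31]
// Planar + blocked: planar input [1, 32, 17, 31] alongside blocked [1, 2, 17, 31, 16]
//   -> RankNormalization(num_prepend=0, num_append=1) on the planar input yields [1, 32, 17, 31, 1]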

View File

@ -24,6 +24,7 @@
#include "op/loop.hpp"
#include "op/brgemm.hpp"
#include "op/vector_buffer.hpp"
#include "op/rank_normalization.hpp"
namespace ov {
namespace snippets {

View File

@ -22,6 +22,7 @@ OV_OP(Store, ov::snippets::op)
OV_OP(BroadcastMove, ov::snippets::op)
OV_OP(Scalar, ov::snippets::op)
OV_OP(Nop, ov::snippets::op)
OV_OP(RankNormalization, ov::snippets::op)
// Layout-oblivious from opset1

View File

@ -13,6 +13,15 @@
namespace ov {
namespace snippets {
struct CompiledSnippet {
virtual const uint8_t* get_code() const = 0;
virtual size_t get_code_size() const = 0;
virtual bool empty() const = 0;
virtual ~CompiledSnippet() = default;
};
using CompiledSnippetPtr = std::shared_ptr<CompiledSnippet>;
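A minimal sketch of a backend-side implementation of the new interface (hypothetical; a real backend would typically wrap jit-allocated executable memory rather than a plain vector):

struct VectorCompiledSnippet : public CompiledSnippet {  // requires <vector> and <cstdint>
    std::vector<uint8_t> blob;  // owns the emitted machine code (assumption)
    const uint8_t* get_code() const override { return blob.data(); }
    size_t get_code_size() const override { return blob.size(); }
    bool empty() const override { return blob.empty(); }
};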
typedef std::pair<std::function<std::shared_ptr<Emitter>(const lowered::ExpressionPtr&)>,
std::function<std::set<ov::element::TypeVector>(const std::shared_ptr<ov::Node>&)>> jitters_value;
@ -33,7 +42,7 @@ public:
* @brief finalizes code generation
* @return generated kernel binary
*/
virtual code get_snippet() const = 0;
virtual CompiledSnippetPtr get_snippet() = 0;
/**
* @brief gets number of lanes supported by target's vector ISA

View File

@ -58,6 +58,7 @@ constexpr inline bool implication(bool cause, bool cond) {
VectorDims get_planar_vdims(const VectorDims& shape, const std::vector<size_t>& layout);
VectorDims get_planar_vdims(const snippets::lowered::PortDescriptorPtr& port_desc);
VectorDims get_planar_vdims(const snippets::lowered::ExpressionPort& expr_port);
bool is_dynamic_vdims(const VectorDims& shape);
} // namespace utils
} // namespace snippets
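Judging by its uses in LIRShapeInfer and ValidateShapes below, the new helper presumably reduces to:

// std::any_of comes from <algorithm>
bool is_dynamic_vdims(const VectorDims& shape) {
    return std::any_of(shape.begin(), shape.end(),
                       [](size_t d) { return d == IShapeInferSnippets::DYNAMIC_DIMENSION; });
}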

View File

@ -15,7 +15,7 @@
namespace ov {
namespace snippets {
Generator::LoweringResult Generator::generate(lowered::LinearIR& linear_ir, const lowered::Config& config, const void* compile_params) {
void Generator::generate(lowered::LinearIR& linear_ir, LoweringResult& result, const void* compile_params) const {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::Generator::generate")
OV_ITT_TASK_CHAIN(GENERATE, ov::pass::itt::domains::SnippetsTransform, "Snippets::Generator", "::Transformations")
if (!target->is_supported())
@ -28,7 +28,6 @@ Generator::LoweringResult Generator::generate(lowered::LinearIR& linear_ir, cons
lowered_pipeline.register_pass<lowered::pass::AssignRegisters>(reg_type_mapper);
lowered_pipeline.register_pass<lowered::pass::InsertTailLoop>();
lowered_pipeline.run(linear_ir);
linear_ir.init_emitters(target);
OV_ITT_TASK_NEXT(GENERATE, "::EmitCode")
@ -45,12 +44,15 @@ Generator::LoweringResult Generator::generate(lowered::LinearIR& linear_ir, cons
}
OV_ITT_TASK_NEXT(GENERATE, "::GetSnippet")
// todo: we save lowered to access compiled brgemm kernels at execution time (normally lowered is destructed by then)
// remove this when kernel caching is implemented. Don't forget to make generate a const method.
if (config.m_save_expressions)
lowered_saved = linear_ir;
return { target->get_snippet() };
// Note: some emitters use precompiled kernels. They need to be saved, so the kernels are accessible at runtime.
if (linear_ir.get_config().m_save_expressions) {
for (const auto& expr : linear_ir) {
const auto& emitter = expr->get_emitter();
if (uses_precompiled_kernel(emitter))
result.m_saved_emitters.emplace_back(emitter);
}
}
result.compiled_snippet = target->get_snippet();
}
std::shared_ptr<const TargetMachine> Generator::get_target_machine() const {
@ -63,7 +65,8 @@ Generator::opRegType Generator::get_op_reg_type(const std::shared_ptr<Node>& op)
std::dynamic_pointer_cast<op::LoopBegin>(op) ||
std::dynamic_pointer_cast<op::LoopEnd>(op) ||
std::dynamic_pointer_cast<op::Brgemm>(op) ||
std::dynamic_pointer_cast<op::Buffer>(op))
std::dynamic_pointer_cast<op::Buffer>(op) ||
std::dynamic_pointer_cast<op::RankNormalization>(op))
return gpr2gpr;
else if (std::dynamic_pointer_cast<snippets::op::Load>(op) ||
std::dynamic_pointer_cast<snippets::op::BroadcastLoad>(op))

View File

@ -122,6 +122,59 @@ LinearIR::container LinearIR::deep_copy_range(LinearIR::container::const_iterato
return result;
}
LinearIR LinearIR::deep_copy() const {
// todo: implement the same functionality using standard copy constructor
auto clone_ports_descriptors = [](std::vector<PortDescriptorPtr>& ports) {
std::for_each(ports.begin(), ports.end(), [](PortDescriptorPtr& pd) { pd = pd->clone(); });
};
const auto& original_lir = *this;
LinearIR new_lir;
new_lir.m_config = original_lir.m_config;
new_lir.m_shape_infer = original_lir.m_shape_infer;
NodeVector original_nodes;
original_nodes.reserve(original_lir.m_expressions.size());
std::unordered_map<PortConnectorPtr, PortConnectorPtr> connectors_map;
for (const auto& orig_expr : original_lir) {
original_nodes.push_back(orig_expr->get_node());
const auto& copy_expr = ExpressionFactory::shallow_copy(orig_expr);
clone_ports_descriptors(copy_expr->m_input_port_descriptors);
clone_ports_descriptors(copy_expr->m_output_port_descriptors);
for (auto& orig_con : copy_expr->m_output_port_connectors) {
const auto& copy_source = copy_expr->get_output_port(orig_con->get_source().get_index());
const auto& copy_con = std::make_shared<PortConnector>(copy_source);
connectors_map[orig_con] = copy_con;
orig_con = copy_con;
}
for (size_t i = 0; i < copy_expr->get_input_count(); i++) {
const auto& copy_connector = connectors_map[copy_expr->get_input_port_connector(i)];
const auto& copy_consumer = copy_expr->get_input_port(i);
copy_connector->add_consumer(copy_consumer);
copy_expr->replace_input(i, copy_connector);
}
if (auto io_expr = std::dynamic_pointer_cast<IOExpression>(copy_expr))
new_lir.m_io_expressions.push_back(io_expr);
new_lir.m_expressions.push_back(copy_expr);
}
// node_map and expr_map map original node pointer (expression) to a new pointer (expression)
ngraph::NodeMap node_map;
OPENVINO_SUPPRESS_DEPRECATED_START
ngraph::clone_nodes(original_nodes, node_map);
OPENVINO_SUPPRESS_DEPRECATED_END
new_lir.m_node2expression_map.clear();
for (const auto& copy_expr : new_lir.m_expressions) {
copy_expr->m_source_node = node_map[copy_expr->m_source_node.get()];
new_lir.m_node2expression_map[copy_expr->m_source_node] = copy_expr;
}
new_lir.m_loop_manager = std::make_shared<LoopManager>();
// It's Ok to share shapeInfer factory, since LIR doesn't change it
new_lir.m_shape_infer_factory = m_shape_infer_factory;
// Note: shapeInfer stores expression pointers. We re-create it, so shape inference is performed on the cloned exprs.
new_lir.m_shape_infer = std::make_shared<LIRShapeInfer>(new_lir.m_expressions, new_lir.m_io_expressions);
return new_lir;
}
void LinearIR::debug_print(bool tds_as_pointers) const {
auto print_rinfo = [](const RegInfo& rinfo) {
std::cerr << " : {";
@ -320,7 +373,7 @@ VectorDims LinearIR::get_master_shape() const {
for (const auto& oe : out_exprs) {
const auto& port_desc = oe->get_input_port_descriptor(0);
OPENVINO_ASSERT(ov::snippets::broadcast_merge_into(master_shape, port_desc->get_shape()),
"Failed to merge input shapes in OptimizeDomain pass");
"Failed to merge input shapes in infer_master_shape");
}
}
return master_shape;
@ -339,6 +392,19 @@ LinearIR::LIRShapeInfer::LIRShapeInfer(container& body_exprs, io_container& io_e
OPENVINO_THROW("Invalid io expression type detected");
}
}
// Note that if all output shapes are static, as is the case when the first shape inference was performed on the nGraph model,
// we can treat them as the last inference result
std::vector<VectorDims> outputDims;
outputDims.reserve(m_output_exprs.size());
for (const auto& expr : m_output_exprs) {
const auto &shape = expr->get_input_port_descriptor(0)->get_shape();
if (utils::is_dynamic_vdims(shape)) {
outputDims.clear();
break;
}
outputDims.push_back(shape);
}
m_last_result = {outputDims, ShapeInferStatus::success};
}
IShapeInferSnippets::Result LinearIR::LIRShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {

View File

@ -46,12 +46,21 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
for (const auto& expr : expressions) {
auto op = expr->get_node();
if (const auto io_expr = std::dynamic_pointer_cast<IOExpression>(expr)) {
if (io_expr->get_type() == IOExpression::io_type::INPUT)
manually_assigned_gprs[expr->get_output_port_connector(0)] = io_expr->get_index();
else if (io_expr->get_type() == IOExpression::io_type::OUTPUT)
if (io_expr->get_type() == IOExpression::io_type::INPUT) {
const auto& out_connector = expr->get_output_port_connector(0);
manually_assigned_gprs[out_connector] = io_expr->get_index();
const auto& consumer_inputs = out_connector->get_consumers();
const auto& first_consumer = consumer_inputs.begin()->get_expr();
// TODO [96434]: Support RankNormalization (Reshape) in an arbitrary place in the pipeline, not just after inputs
if (ov::is_type<op::RankNormalization>(first_consumer->get_node())) {
OPENVINO_ASSERT(consumer_inputs.size() == 1, "RankNormalization is supposed to be the only consumer");
manually_assigned_gprs[first_consumer->get_output_port_connector(0)] = io_expr->get_index();
}
} else if (io_expr->get_type() == IOExpression::io_type::OUTPUT) {
manually_assigned_gprs[expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
else
} else {
OPENVINO_THROW("Unsupported io_type detected");
}
} else if (const auto& buffer = ov::as_type_ptr<op::Buffer>(op)) {
const auto buffer_id = buffer->get_id();
// All buffers have one common data pointer

View File

@ -0,0 +1,90 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "snippets/lowered/pass/insert_broadcastmove.hpp"
#include "snippets/utils.hpp"
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/snippets_isa.hpp"
#include "snippets/itt.hpp"
namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
bool InsertBroadcastMove::run(LinearIR& linear_ir) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::InsertBroadcastMove")
bool modified = false;
const auto& loop_manager = linear_ir.get_loop_manager();
auto supports_broadcasting = [](const std::shared_ptr<ov::Node>& n) {
return ov::op::util::supports_auto_broadcast(n) ||
n->get_autob().m_type == ov::op::AutoBroadcastType::NUMPY ||
is_type<ov::op::v0::PRelu>(n);
};
auto dont_need_broadcasting = [](const ov::Output<ov::Node>& v){
// We don't need to insert BroadcastMove after the following operations:
// - Scalar has an emitter with explicit broadcasting
// - VectorBuffer has a scalar output shape to avoid broadcast conflicts and manual shape insertion.
// - Fill can be inserted only after VectorBuffer, and should be ignored as well.
return utils::is_scalar_constant(v.get_node_shared_ptr()) ||
ov::is_type<ov::snippets::op::VectorBuffer>(v.get_node_shared_ptr()) ||
ov::is_type<ov::snippets::op::Fill>(v.get_node_shared_ptr());
};
for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) {
const auto& expr = *expr_it;
const auto& node = expr->get_node();
const auto& descriptors = expr->get_input_port_descriptors();
if (!supports_broadcasting(node) || descriptors.size() < 2)
continue;
const auto& connectors = expr->get_input_port_connectors();
OPENVINO_ASSERT(connectors.size() == descriptors.size(),
"Invalid expression configuration: connectors and descriptors size mismatch");
std::vector<size_t> last_dims(descriptors.size());
std::transform(descriptors.begin(), descriptors.end(), last_dims.begin(),
[](const std::shared_ptr<PortDescriptor>& d){
return d->get_shape().back();
});
const auto broadcasted_dim = *std::max_element(last_dims.begin(), last_dims.end());
for (size_t i = 0; i < last_dims.size(); i++) {
const auto& parent_port = connectors[i]->get_source();
if (last_dims[i] != broadcasted_dim &&
!dont_need_broadcasting(parent_port.get_expr()->get_node())) {
OPENVINO_ASSERT(last_dims[i] == 1,
"Attempt to broadcast non-1 dimension. Target dim: ", broadcasted_dim,
" This dim: ", last_dims[i]);
auto input_shape = descriptors[i]->get_shape();
// Note that input_shape could be empty (aka ngraph scalar), so we can't just replace the last dim
if (input_shape.empty())
input_shape.resize(1);
input_shape.back() = last_dims[i];
const auto broadcast = std::make_shared<op::BroadcastMove>(node->get_input_source_output(i), utils::vdims_to_pshape(input_shape));
PortDescriptorUtils::set_port_descriptor_ptr(broadcast->output(0), connectors[i]->get_source().get_descriptor_ptr()->clone());
const auto broadcast_expr = linear_ir.create_expression(broadcast, {connectors[i]});
linear_ir.insert(expr_it, broadcast_expr);
linear_ir.replace_input(expr->get_input_port(i), broadcast_expr->get_output_port_connector(0));
// Note that BroadcastMove modifies the next expr's input shape, so we need to update the
// expr's input port descriptor to reflect the changes
expr->get_input_port_descriptor(i)->set_shape(broadcast_expr->get_output_port_descriptor(0)->get_shape());
// Copy loop identifiers
const auto& loop_ids = expr->get_loop_ids();
broadcast_expr->set_loop_ids(loop_ids);
loop_manager->update_loops_port(loop_ids, expr->get_input_port(0), {broadcast_expr->get_input_port(0)}, true);
modified = true;
}
}
}
return modified;
}
} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov

View File

@ -35,10 +35,9 @@ std::vector<size_t> get_buffer_loop_ids(const std::vector<size_t>& lhs, const st
ov::Shape compute_allocation_shape(const LinearIR::LoopManagerPtr& loop_manager,
const std::vector<size_t>& buffer_loop_ids,
const std::vector<size_t>& parent_loop_ids,
const ov::Output<ov::Node>& parent_output,
const ExpressionPort& expr_port,
const int allocation_rank) {
const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(parent_output);
const auto planar_shape = utils::get_planar_vdims(port);
const auto& planar_shape = utils::get_planar_vdims(expr_port);
const size_t rank = allocation_rank >= 0 ? std::min(static_cast<size_t>(allocation_rank), planar_shape.size()) : planar_shape.size();
ov::Shape allocation_shape(rank);
@ -123,9 +122,9 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::constExprIt&
for (const auto& entry_point : loop_entries) {
const auto& entry_port = entry_point.expr_port;
const auto& expr = entry_port->get_expr();
const auto port = entry_port->get_index();
const auto port_idx = entry_port->get_index();
const auto node = expr->get_node();
const auto& input_connector = expr->get_input_port_connector(port);
const auto& input_connector = expr->get_input_port_connector(port_idx);
const auto& parent_expr_output = input_connector->get_source();
const auto& parent_expr = parent_expr_output.get_expr();
const auto parent_port = parent_expr_output.get_index();
@ -140,7 +139,7 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::constExprIt&
const auto parent_ma = ov::as_type_ptr<op::MemoryAccess>(parent);
const auto node_ma = ov::as_type_ptr<op::MemoryAccess>(node);
bool is_buffer_needed = (parent_ma && parent_ma->is_memory_access_output_port(parent_port)) ||
(node_ma && node_ma->is_memory_access_input_port(port));
(node_ma && node_ma->is_memory_access_input_port(port_idx));
const auto current_loops = expr->get_loop_ids();
const auto parent_loops = parent_expr->get_loop_ids();
const auto buffer_loop_ids = get_buffer_loop_ids(current_loops, parent_loops, is_buffer_needed);
@ -154,7 +153,7 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::constExprIt&
const auto allocation_shape = compute_allocation_shape(loop_manager,
buffer_loop_ids,
parent_loops,
parent->output(parent_port),
parent_expr_output,
m_buffer_allocation_rank);
const auto buffer = std::make_shared<op::Buffer>(parent->output(parent_port), allocation_shape);
PortDescriptorUtils::set_port_descriptor_ptr(buffer->output(0), parent_expr_output.get_descriptor_ptr()->clone());
@ -169,7 +168,7 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::constExprIt&
for (const auto& exit_point : loop_exits) {
const auto& exit_port = exit_point.expr_port;
const auto& expr = exit_port->get_expr();
const auto port = exit_port->get_index();
const auto port_idx = exit_port->get_index();
const auto node = expr->get_node();
const auto output_connector = exit_port->get_port_connector_ptr();
const auto child_exprs_inputs = output_connector->get_consumers();
@ -200,7 +199,7 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::constExprIt&
const auto child_ma = ov::as_type_ptr<op::MemoryAccess>(child);
const auto node_ma = ov::as_type_ptr<op::MemoryAccess>(node);
bool is_buffer_needed = (child_ma && child_ma->is_memory_access_input_port(child_port)) ||
(node_ma && node_ma->is_memory_access_output_port(port));
(node_ma && node_ma->is_memory_access_output_port(port_idx));
const auto local_buffer_loop_ids = get_buffer_loop_ids(current_loops, child_expr->get_loop_ids(), is_buffer_needed);
if (is_buffer_needed) {
@ -247,9 +246,9 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::constExprIt&
const auto allocation_shape = compute_allocation_shape(loop_manager,
buffer_loop_ids,
current_loops,
node->output(port),
*exit_port,
m_buffer_allocation_rank);
auto buffer = std::make_shared<op::Buffer>(node->output(port), allocation_shape);
auto buffer = std::make_shared<op::Buffer>(node->output(port_idx), allocation_shape);
PortDescriptorUtils::set_port_descriptor_ptr(buffer->output(0), exit_port->get_descriptor_ptr()->clone());
// We cannot insert the Node output connector on the Buffer output because not all consumers of the Node need the Buffer
// Example:

View File

@ -3,7 +3,7 @@
//
#include "snippets/lowered/pass/insert_load_store.hpp"
#include "snippets/op/rank_normalization.hpp"
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/snippets_isa.hpp"
@ -30,14 +30,18 @@ size_t InsertLoadStore::get_count(const PortDescriptorPtr& port_desc) const {
}
bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
std::shared_ptr<Expression> data_expr = *data_expr_it;
auto consumer_inputs = data_expr->get_output_port_connector(0)->get_consumers();
const auto& first_consumer = consumer_inputs.begin()->get_expr();
if (is_type<op::RankNormalization>(first_consumer->get_node())) {
OPENVINO_ASSERT(consumer_inputs.size() == 1, "RankNormalization is supposed to be the only consumer");
data_expr = first_consumer;
}
const auto& loop_manager = linear_ir.get_loop_manager();
const auto& data_expr = *data_expr_it;
const auto& data_node = data_expr->get_node();
const auto& data_ngraph_output = data_expr->get_node()->output(0);
const auto& output_connector = data_expr->get_output_port_connector(0);
const auto consumer_inputs = output_connector->get_consumers();
bool was_inserted = false;
for (const auto& consumer_input : consumer_inputs) {
for (const auto& consumer_input : output_connector->get_consumers()) {
const auto& consumer_expr = consumer_input.get_expr();
const auto port = consumer_input.get_index();
const auto& consumer = consumer_expr->get_node();
@ -46,7 +50,7 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr
return false;
const auto loop_ids = consumer_expr->get_loop_ids();
const auto load = std::make_shared<op::Load>(data_node->output(0), get_count(data_expr->get_output_port_descriptor(0)));
const auto load = std::make_shared<op::Load>(data_ngraph_output, get_count(data_expr->get_output_port_descriptor(0)));
PortDescriptorUtils::set_port_descriptor_ptr(load->output(0), consumer_input.get_descriptor_ptr()->clone());
const auto load_expr = linear_ir.create_expression(load, {output_connector});
linear_ir.insert(linear_ir.find_after(data_expr_it, consumer_expr), load_expr);
@ -55,7 +59,7 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr
load_expr->set_loop_ids(loop_ids);
// Need to update all the corresponding Loops with the same Entry Point
const auto prev_entry_point = consumer_input;
const auto& prev_entry_point = consumer_input;
const auto new_entry_point = load_expr->get_input_port(0);
loop_manager->update_loops_port(loop_ids, prev_entry_point, {new_entry_point}, true);
was_inserted = true;
@ -116,20 +120,14 @@ bool InsertLoadStore::run(LinearIR& linear_ir) {
const auto& node = expr->get_node();
if (ov::is_type<ov::op::v0::Parameter>(node)) {
modified |= insert_load(linear_ir, expr_it);
continue;
}
if (ov::is_type<ov::op::v0::Result>(node)) {
} else if (ov::is_type<ov::op::v0::Result>(node)) {
modified |= insert_store(linear_ir, expr_it);
continue;
}
if (auto buffer = ov::as_type_ptr<op::Buffer>(node)) {
} else if (auto buffer = ov::as_type_ptr<op::Buffer>(node)) {
modified |= insert_load(linear_ir, expr_it);
if (buffer->is_intermediate_memory())
modified |= insert_store(linear_ir, expr_it);
continue;
}
}
return modified;
}

View File

@ -29,7 +29,8 @@ bool MarkLoops::run(LinearIR& linear_ir) {
auto is_not_start_point = [](const std::shared_ptr<ov::Node>& node) {
return ov::is_type<ov::op::v0::Result>(node) ||
ov::is_type<ov::op::v0::Constant>(node) ||
ov::is_type<ov::op::v0::Parameter>(node);
ov::is_type<ov::op::v0::Parameter>(node) ||
ov::is_type<op::RankNormalization>(node);
};
auto are_conflicted = [](const ExpressionPort& lhs, const ExpressionPort& rhs) {

View File

@ -8,6 +8,7 @@
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/snippets_isa.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
#include "snippets/utils.hpp"
namespace ov {
@ -79,18 +80,28 @@ bool OptimizeDomain::run(snippets::lowered::LinearIR& linear_ir) {
return false;
}
OPENVINO_ASSERT(config.m_min_parallel_work_amount != 0, "OptimizeDomain: Min parallel work amount can't equal to zero");
std::vector<std::shared_ptr<snippets::lowered::IOExpression>> input_exprs;
std::vector<VectorDims> input_shapes;
VectorDims master_shape = linear_ir.get_master_shape();
for (const auto& expr : linear_ir.get_IO_ops()) {
if (expr->get_type() == snippets::lowered::IOExpression::io_type::INPUT) {
input_exprs.push_back(expr);
const auto& shape = expr->get_output_port_descriptor(0)->get_shape();
bool blocked_input_shapes = false;
for (const auto& io_expr : linear_ir.get_IO_ops()) {
if (io_expr->get_type() == snippets::lowered::IOExpression::io_type::INPUT) {
auto consumer_inputs = io_expr->get_output_port_connector(0)->get_consumers();
const auto& first_consumer = consumer_inputs.begin()->get_expr();
if (auto rank_norm = as_type_ptr<op::RankNormalization>(first_consumer->get_node())) {
// If RankNormalization appends dims, then the appended dims will be broadcasted
// so collapsing is not allowed. We may increment tile rank though.
if (rank_norm->get_num_append() != 0)
blocked_input_shapes = true;
// If RankNormalization prepends dims, then the dims should be ignored during domain optimization
// to avoid passing already incremented shapes to linear_ir.shape_infer()
}
const ExpressionPtr& shape_producing_expr = blocked_input_shapes ?
first_consumer :
io_expr;
const auto& shape = utils::get_planar_vdims(shape_producing_expr->get_output_port_descriptor(0));
OPENVINO_ASSERT(std::none_of(shape.begin(), shape.end(),
[](size_t d) {return d == snippets::IShapeInferSnippets::DYNAMIC_DIMENSION; }),
"OptimizeDomain pass does not support dynamic shapes");
OPENVINO_ASSERT(ov::snippets::broadcast_merge_into(master_shape, shape),
"Failed to merge input shapes in OptimizeDomain pass");
input_shapes.emplace_back(shape);
}
}
@ -98,7 +109,9 @@ bool OptimizeDomain::run(snippets::lowered::LinearIR& linear_ir) {
master_shape.end(),
(size_t)1,
std::multiplies<size_t>());
const auto num_dims_collapsed = optimize(input_shapes,
const auto num_dims_collapsed = blocked_input_shapes ?
0 :
optimize(input_shapes,
master_shape,
total_work_amount,
config.m_min_parallel_work_amount,

View File

@ -19,23 +19,25 @@ bool PropagateLayout::run(LinearIR& linear_ir) {
if (linear_ir.empty())
return false;
for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) {
const auto& expr = *expr_it;
for (const auto& expr : linear_ir) {
const auto io_expr = std::dynamic_pointer_cast<IOExpression>(expr);
if (!io_expr)
continue;
const bool is_input = io_expr->get_type() == IOExpression::io_type::INPUT;
const auto& connectors = is_input ? expr->get_output_port_connectors() : expr->get_input_port_connectors();
if (connectors.size() != 1)
OPENVINO_THROW("Parameter/Results should have exactly one output/input");
OPENVINO_ASSERT(connectors.size() == 1, "Parameter/Results should have exactly one output/input");
// If input - we should be looking downstream, if output - upstream
const auto& target_connector = connectors.front();
if (is_input) {
const auto consumer_inputs = target_connector->get_consumers();
// Note that here we consider only the first child (which is usually load),
// but often there is another child - LoopEnd
auto consumer_inputs = target_connector->get_consumers();
const auto& first_consumer = consumer_inputs.begin()->get_expr();
// If there is a RankNormalization op after a parameter - we should skip it
if (is_type<op::RankNormalization>(first_consumer->get_node()))
consumer_inputs = first_consumer->get_output_port_connector(0)->get_consumers();
std::set<std::vector<size_t>> child_layouts;
for (const auto& child_input : consumer_inputs) {
const auto& child = child_input.get_expr();

View File

@ -44,13 +44,15 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
// Float constant values in byte representation
const auto float_min_constant = uint32_t(0xff7fffff);
const auto zero_constant = uint32_t(0x00000000);
const bool is_dynamic = softmax->is_dynamic();
// We need an iterator to the inserted element
auto push_node = [&linear_ir, &expr_it](const std::shared_ptr<Node>& n) {
auto push_node = [&linear_ir, &expr_it, is_dynamic](const std::shared_ptr<Node>& n) {
const auto expr = linear_ir.insert(expr_it, n);
if (is_dynamic)
expr->get()->updateShapes();
return std::make_pair(expr, n);
};
const ov::PartialShape broadcasted_shape(softmax_expr->get_input_port_descriptor(0)->get_shape());
// Note: VectorBuffer is a special case, since it should go before the initial Load. So we handle it separately
const auto& vector_buffer_max = push_node(std::make_shared<op::VectorBuffer>());
// Init value of vector buffer for ReduceMax is -FLOAT_MIN.
@ -65,9 +67,8 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
std::vector<ExpressionPort>{(*max.first)->get_input_port(0),
(*max.first)->get_input_port(1)},
std::vector<ExpressionPort>{(*max.first)->get_output_port(0)});
const auto broadcast_horizon_max = push_node(
std::make_shared<op::BroadcastMove>(horizon_max.second, horizon_max.second->get_input_partial_shape(0)));
std::make_shared<op::BroadcastMove>(horizon_max.second, broadcasted_shape));
const auto vector_buffer_sum = push_node(std::make_shared<op::VectorBuffer>());
// Init value of vector buffer for ReduceSum is zero.
const auto fill_sum = push_node(std::make_shared<op::Fill>(vector_buffer_sum.second, 0, zero_constant));
@ -89,7 +90,7 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
// Divide is an expensive operation, so we decompose it into 1 / x * y, where 1 / x is executed outside the loop
const auto pow = push_node(std::make_shared<op::PowerStatic>(horizon_sum.second, -1.f));
const auto broadcast_pow = push_node(std::make_shared<op::BroadcastMove>(pow.second, horizon_sum.second->get_input_partial_shape(0)));
const auto broadcast_pow = push_node(std::make_shared<op::BroadcastMove>(pow.second, broadcasted_shape));
// Mul (pseudo-Divide loop)
const auto mul = push_node(std::make_shared<ov::op::v1::Multiply>(exp.second, broadcast_pow.second));

View File

@ -0,0 +1,48 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "snippets/lowered/pass/validate_shapes.hpp"
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
#include "snippets/itt.hpp"
namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
bool ValidateShapes::run(LinearIR& linear_ir) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ValidateShapes")
for (const auto& expr : linear_ir) {
const auto num_inputs = expr->get_input_count();
const auto& port_connectors = expr->get_input_port_connectors();
const auto& port_descriptors = expr->get_input_port_descriptors();
OPENVINO_ASSERT(port_connectors.size() == num_inputs, "Invalid number of port connectors detected");
OPENVINO_ASSERT(port_descriptors.size() == num_inputs, "Invalid number of port descriptors detected");
for (size_t i = 0; i < num_inputs; i++) {
const auto& descr = port_descriptors[i];
const auto& layout = descr->get_layout();
const auto& shape = descr->get_shape();
const auto& n = expr->get_node();
OPENVINO_ASSERT(std::none_of(shape.begin(), shape.end(),
[](size_t d) {return d == IShapeInferSnippets::DYNAMIC_DIMENSION;}),
"Dynamic dimensions are not allowed at this point of pipeline. ",
"Check the expr for node ", n->get_friendly_name());
OPENVINO_ASSERT(layout.size() == shape.size(), "Layout and shape sizes must match. ",
"Check the expr for node ", n->get_friendly_name());
const auto& parent_desc = port_connectors[i]->get_source().get_descriptor_ptr();
const auto& parent_shape = parent_desc->get_shape();
OPENVINO_ASSERT(parent_shape == shape, "Parent shape must be equal to the expression shape. ",
"Check the expr for node ", n->get_friendly_name());
}
}
return false;
}
} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov

View File

@ -3,6 +3,7 @@
//
#include "snippets/lowered/port_descriptor.hpp"
#include <snippets/utils.hpp>
namespace ov {
namespace snippets {
@ -12,13 +13,15 @@ size_t PortDescriptor::ServiceDimensions::FULL_DIM = SIZE_MAX;
PortDescriptor::PortDescriptor(const ov::Input<ov::Node>& in, VectorDims subtensor_shape, std::vector<size_t> layout)
: PortDescriptor(ov::Input<const Node>(in.get_node(), in.get_index()), std::move(subtensor_shape), std::move(layout)) {}
PortDescriptor::PortDescriptor(const ov::Input<const ov::Node>& in, VectorDims subtensor_shape, std::vector<size_t> layout)
: PortDescriptor(in.get_shape(), std::move(subtensor_shape), std::move(layout)) {}
PortDescriptor::PortDescriptor(const ov::Input<const ov::Node>& in, std::vector<size_t> subtensor_shape, std::vector<size_t> layout)
: PortDescriptor(utils::pshape_to_vdims(in.get_partial_shape()), std::move(subtensor_shape), std::move(layout)) {}
PortDescriptor::PortDescriptor(const ov::Output<ov::Node>& out, VectorDims subtensor_shape, std::vector<size_t> layout)
: PortDescriptor(ov::Output<const Node>(out.get_node(), out.get_index()), std::move(subtensor_shape), std::move(layout)) {}
PortDescriptor::PortDescriptor(const ov::Output<const ov::Node>& out, VectorDims subtensor_shape, std::vector<size_t> layout)
: PortDescriptor(out.get_shape(), std::move(subtensor_shape), std::move(layout)) {}
PortDescriptor::PortDescriptor(const ov::Output<const ov::Node>& out, std::vector<size_t> subtensor_shape, std::vector<size_t> layout)
: PortDescriptor(utils::pshape_to_vdims(out.get_partial_shape()), std::move(subtensor_shape), std::move(layout)) {}
PortDescriptor::PortDescriptor(VectorDims shape, VectorDims subtensor_shape, std::vector<size_t> layout)
: m_tensor_shape(std::move(shape)), m_layout(std::move(layout)), m_subtensor_shape(std::move(subtensor_shape)) {
@ -30,13 +33,12 @@ void PortDescriptor::validate_arguments() {
m_layout.resize(m_tensor_shape.size());
// NCHW layout by default
std::iota(m_layout.begin(), m_layout.end(), 0);
} else if (m_layout.size() != m_tensor_shape.size()) {
OPENVINO_THROW("Snippets tensor descriptor: Layout size must be equal to the shape size");
}
OPENVINO_ASSERT(m_layout.size() == m_tensor_shape.size(), "Snippets tensor descriptor: Layout size must be equal to the shape size");
}
PortDescriptorPtr PortDescriptor::clone() const {
const auto desc = std::make_shared<PortDescriptor>(m_tensor_shape, m_subtensor_shape, m_layout);
auto desc = std::make_shared<PortDescriptor>(m_tensor_shape, m_subtensor_shape, m_layout);
desc->set_reg(m_reg);
return desc;
}

View File

@ -46,15 +46,13 @@ bool Buffer::visit_attributes(AttributeVisitor& visitor) {
void Buffer::validate_and_infer_types() {
INTERNAL_OP_SCOPE(Buffer_validate_and_infer_types);
ov::Shape output_shape;
ov::PartialShape output_shape;
if (m_type == Type::NewMemory) {
OPENVINO_ASSERT(get_input_size() == 0, "Buffer with newly allocated memory must not have arguments!");
output_shape = m_shape;
} else if (m_type == Type::IntermediateMemory) {
const auto& input_shape = get_input_partial_shape(0);
OPENVINO_ASSERT(input_shape.is_static(), "Buffer supports only static input shape");
m_element_type = get_input_element_type(0);
output_shape = input_shape.get_shape();
output_shape = get_input_partial_shape(0);
} else {
OPENVINO_THROW("Buffer supports only the following types: NewMemory and IntermediateMemory");
}

View File

@ -0,0 +1,57 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "snippets/op/rank_normalization.hpp"
#include "snippets/utils.hpp"
namespace ov {
namespace snippets {
namespace op {
RankNormalization::RankNormalization(const Output<Node>& data, size_t num_prepend, size_t num_append) :
Op({data}), m_num_prepend(num_prepend), m_num_append(num_append) {
constructor_validate_and_infer_types();
}
std::shared_ptr<ov::Node> RankNormalization::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return std::make_shared<RankNormalization>(new_args[0], m_num_prepend, m_num_append);
}
void RankNormalization::validate_and_infer_types() {
auto new_shape = get_input_partial_shape(0);
// Note: other values are not allowed, only planar + blocked layout combination can be normalized.
NODE_VALIDATION_CHECK(this, utils::one_of(m_num_append, 0lu, 1lu),
"num_append could be only 0 or 1, other values are not allowed.");
new_shape.insert(new_shape.begin(), m_num_prepend, Dimension(1));
new_shape.insert(new_shape.end(), m_num_append, Dimension(1));
set_output_type(0, get_input_element_type(0), new_shape);
}
bool RankNormalization::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("num_prepend", m_num_prepend);
visitor.on_attribute("num_append", m_num_append);
return true;
}
RankNormalization::ShapeInfer::ShapeInfer(const std::shared_ptr<ov::Node>& n) {
const auto& rank_norm = as_type_ptr<RankNormalization>(n);
OPENVINO_ASSERT(rank_norm, "Invalid operation passed to RankNormalization::ShapeInfer: ", n->get_type_info().name);
m_num_append = rank_norm->m_num_append;
m_num_prepend = rank_norm->m_num_prepend;
}
IShapeInferSnippets::Result
RankNormalization::ShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
OPENVINO_ASSERT(input_shapes.size() == 1, "Invalid number of input shapes passed to RankNormalization::ShapeInfer::infer");
VectorDims out_shape = input_shapes[0].get();
out_shape.insert(out_shape.begin(), m_num_prepend, 1);
out_shape.insert(out_shape.end(), m_num_append, 1);
return {{out_shape}, ShapeInferStatus::success};
}
} // namespace op
} // namespace snippets
} // namespace ov

View File

@ -4,14 +4,17 @@
#include "snippets/op/scalar.hpp"
namespace ov {
namespace snippets {
namespace op {
std::shared_ptr<ov::Node> ov::snippets::op::Scalar::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<ov::Node> Scalar::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return std::make_shared<Scalar>(*this);
}
// Scalar currently supports only one-element constants; this could be changed in the future
void ov::snippets::op::Scalar::validate_and_infer_types() {
void Scalar::validate_and_infer_types() {
Constant::validate_and_infer_types();
auto out_pshape = get_output_partial_shape(0);
NODE_VALIDATION_CHECK(this, out_pshape.is_static(), "Scalar supports only static input shapes");
@ -20,7 +23,7 @@ void ov::snippets::op::Scalar::validate_and_infer_types() {
" shape");
}
bool ov::snippets::op::Scalar::visit_attributes(AttributeVisitor& visitor) {
bool Scalar::visit_attributes(AttributeVisitor& visitor) {
auto shape = get_output_shape(0);
auto type = get_output_element_type(0);
auto value = cast_vector<float>();
@ -29,3 +32,7 @@ bool ov::snippets::op::Scalar::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("value", value);
return true;
}
} // namespace op
} // namespace snippets
} // namespace ov

View File

@ -6,9 +6,7 @@
#include "snippets/remarks.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/op/convert_saturation.hpp"
#include "snippets/pass/insert_movebroadcast.hpp"
#include "snippets/pass/broadcast_to_movebroadcast.hpp"
#include "snippets/pass/propagate_precision.hpp"
#include "snippets/pass/convert_constants.hpp"
@ -17,6 +15,9 @@
#include "snippets/pass/matmul_to_brgemm.hpp"
#include "snippets/pass/fuse_transpose_brgemm.hpp"
#include "snippets/pass/set_softmax_ports.hpp"
#include "snippets/pass/canonicalization.hpp"
#include "snippets/pass/align_element_types.hpp"
#include "snippets/lowered/pass/validate_shapes.hpp"
#include "snippets/utils.hpp"
@ -29,6 +30,7 @@
#include "snippets/lowered/pass/init_loops.hpp"
#include "snippets/lowered/pass/insert_buffers.hpp"
#include "snippets/lowered/pass/insert_load_store.hpp"
#include "snippets/lowered/pass/insert_broadcastmove.hpp"
#include "snippets/lowered/pass/load_movebroadcast_to_broadcastload.hpp"
#include "snippets/lowered/pass/allocate_buffers.hpp"
#include "snippets/lowered/pass/propagate_layout.hpp"
@ -61,7 +63,7 @@ namespace snippets {
namespace op {
void Subgraph::set_generator(std::shared_ptr<ov::snippets::Generator> generator) {
m_generator = generator;
m_generator = std::move(generator);
}
void Subgraph::set_virtual_port_count(const size_t count) {
@ -171,36 +173,6 @@ std::shared_ptr<Node> Subgraph::clone_with_new_inputs(const OutputVector& inputs
return make_shared<Subgraph>(inputs, body().clone());
}
std::vector<PartialShape> Subgraph::reshape_body(const std::vector<PartialShape>& input_shapes) {
auto& params = body_ptr()->get_parameters();
OPENVINO_ASSERT(params.size() == input_shapes.size(), "Got invalid number of input shapes to reshape subgraph body");
for (size_t i = 0; i < params.size(); ++i) {
params[i]->set_partial_shape(input_shapes[i]);
}
body_ptr()->validate_nodes_and_infer_types();
std::vector<PartialShape> output_shapes;
for (const auto& res : body_ptr()->get_results()) {
output_shapes.emplace_back(res->get_input_partial_shape(0));
}
return output_shapes;
}
std::vector<Shape> Subgraph::reshape_body(const std::vector<Shape>& input_shapes) {
auto& params = body_ptr()->get_parameters();
OPENVINO_ASSERT(params.size() == input_shapes.size(), "Got invalid number of input shapes to reshape subgraph body");
for (size_t i = 0; i < params.size(); ++i) {
params[i]->set_partial_shape(input_shapes[i]);
}
body_ptr()->validate_nodes_and_infer_types();
std::vector<Shape> output_shapes;
for (const auto& res : body_ptr()->get_results()) {
auto pshape = res->get_input_partial_shape(0);
OPENVINO_ASSERT(pshape.is_static(), "Subgraph inferred dynamic output shape during reshape with static inputs");
output_shapes.emplace_back(res->get_input_partial_shape(0).get_shape());
}
return output_shapes;
}
void Subgraph::validate_and_infer_types() {
INTERNAL_OP_SCOPE(Subgraph);
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::validate_and_infer_types")
@ -311,166 +283,6 @@ auto Subgraph::constant_input_should_be_inside_body(const std::shared_ptr<ov::No
ov::is_type<ov::op::v1::Reshape>(node);
}
///
/// \brief Canonization transforms original subgraph and to canonical form suitable for code generation. In particular,
/// it handles supported layout conversions, broadcasts inputs and outputs to a single rank and layout. Canonicalization
/// returns master-shape (max rank + max dimensions over all outputs) that can be used for scheduling.
/// Canonicalization currently supports only the following layout conversions:
/// * None: all inputs have the same layout
/// * Planar + blocked: some inputs have blocked, and some have planar layouts, e.g. <N, C, H, W, c> + <N, C, H, W>
/// Also there is precision aligning inside body of subgraph during canonicalization
ov::PartialShape snippets::op::Subgraph::canonicalize(const BlockedShapeVector& outputShapes,
const BlockedShapeVector& inputShapes) {
INTERNAL_OP_SCOPE(Subgraph);
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::canonicalize")
NODE_VALIDATION_CHECK(this, inputShapes.size() == body_ptr()->get_parameters().size(),
"Number of parameters for snippet doesn't match passed to generate method: ",
inputShapes.size(), " vs ", body_ptr()->get_parameters().size(), ".");
NODE_VALIDATION_CHECK(this, outputShapes.size() == body_ptr()->get_results().size(),
"number of results for snippet doesn't match passed to generate method: ",
outputShapes.size(), " vs ", body_ptr()->get_results().size(), ".");
auto getMaxRankBlockedShape = [](const BlockedShapeVector& blockedShapes) -> const BlockedShape& {
return *std::max_element(blockedShapes.begin(), blockedShapes.end(),
[&](const BlockedShape& lhs, const BlockedShape& rhs) {
return std::get<0>(lhs).size() < std::get<0>(rhs).size();
});
};
PartialShape baseShape;
AxisVector baseOrder;
std::tie(baseShape, baseOrder, std::ignore) = getMaxRankBlockedShape(inputShapes);
maxInputRank = baseShape.size();
appendOnesForCanonical.resize(inputShapes.size(), 0);
const bool baseIsBlocked = baseOrder.size() != std::set<size_t>(baseOrder.begin(), baseOrder.end()).size();
for (size_t i = 0; i < inputShapes.size(); i++) {
const auto& blockedShape = inputShapes[i];
PartialShape inShape;
AxisVector inOrder;
element::Type inType;
std::tie(inShape, inOrder, inType) = blockedShape;
const auto inRank = inShape.size();
NODE_VALIDATION_CHECK(this, inRank <= maxInputRank, "Input rank can't be larger than output rank in snippets.");
if (inRank < maxInputRank) {
appendOnesForCanonical[i] = maxInputRank - inRank;
PartialShape newShape(ov::Shape(maxInputRank, 1));
// todo: more complicated logics is needed if we want to merge smth else than blocked and planar
if (baseIsBlocked) {
const bool inIsNotBlocked = inOrder.size() == std::set<size_t>(inOrder.begin(), inOrder.end()).size();
NODE_VALIDATION_CHECK(this, inIsNotBlocked, "Snippets don't support conversion between blocked layouts of different ranks");
inShape.insert(inShape.end(), ov::Dimension(1));
appendOnesForCanonical[i]--;
}
NODE_VALIDATION_CHECK(this, PartialShape::broadcast_merge_into(newShape, inShape, ov::op::AutoBroadcastType::NUMPY),
"Failed to broadcast_merge inputs in snippets canonicalization");
inShape = std::move(newShape);
} else {
// todo: 4d blocked + 5d planar layouts are not supported: <N, C, H, W, c> + <N, C, D, H, W>
NODE_VALIDATION_CHECK(this,
equal(baseOrder.begin(), baseOrder.end(), inOrder.begin()),
"Snippets canonicalization got input shapes of equal ranks but different layouts, which is not supported");
}
ov::PartialShape tmpPShape(baseShape);
// todo: we need to generalize canonicalization for domain-sensitive ops. E.g. MatMul inputs can't be broadcasted one to another
if (!config.m_has_domain_sensitive_ops)
NODE_VALIDATION_CHECK(this,
PartialShape::broadcast_merge_into(tmpPShape, inShape, ::ov::op::AutoBroadcastType::NUMPY),
"Failed to create broadcastable shapes in snippets canonicalization");
const auto paramShape = body_ptr()->get_parameters()[i]->get_partial_shape();
const auto paramType = body_ptr()->get_parameters()[i]->get_element_type();
if (paramShape.size() != inShape.size() || !equal(paramShape.begin(), paramShape.end(), inShape.begin()))
body_ptr()->replace_parameter(i, std::make_shared<ov::op::v0::Parameter>(paramType, inShape));
}
body_ptr()->validate_nodes_and_infer_types();
auto skipStartEndOnes = [](const PartialShape& shape) {
auto begin = shape.begin();
auto end = shape.end();
while (begin != end && *begin == 1)
begin++;
while (begin != end && *(end - 1) == 1)
end--;
PartialShape trimmedShape(std::vector<ov::Dimension>(end - begin, 1));
std::copy(begin, end, trimmedShape.begin());
return trimmedShape;
};
// Check that output shapes are broadcastable => can be scheduled
const auto& body_results = body_ptr()->get_results();
PartialShape outPShape = body_results[0]->get_input_partial_shape(0);
// todo: we need a slightly more general approach for backward ROI propagation
const auto& result_parent = body_results[0]->get_input_node_shared_ptr(0);
if (body_results.size() == 1 &&
ov::is_type<ov::op::v1::Transpose>(result_parent) &&
ov::is_type<ov::op::v0::MatMul>(result_parent->get_input_node_shared_ptr(0))) {
outPShape = result_parent->get_input_partial_shape(0);
} else {
for (size_t i = 0; i < body_results.size(); i++) {
auto shape_i = body_results[i]->get_input_partial_shape(0);
auto outputShape_i = std::get<0>(outputShapes[i]);
// Check that the produced output shape corresponds to the passed shape
// Some produced shapes may have been changed to be broadcastable (e.g. blocked + planar outputs),
// so we need to remove leading and trailing "1" before the comparison
PartialShape pShape_i(skipStartEndOnes(shape_i));
bool compatibleWithPassedShape = PartialShape::broadcast_merge_into(pShape_i,
skipStartEndOnes(outputShape_i),
::ov::op::AutoBroadcastType::NUMPY);
NODE_VALIDATION_CHECK(this, compatibleWithPassedShape,
"Inferred and passed results shapes are incompatible for snippet ");
// Check that output shapes are broadcastable to each other => can be scheduled
bool compatibleWithOtherOutputs = PartialShape::broadcast_merge_into(outPShape, shape_i,
::ov::op::AutoBroadcastType::NUMPY);
NODE_VALIDATION_CHECK(this, compatibleWithOtherOutputs,
"Snippets output shapes must be numpy broadcastable");
}
}
// We should insert Converts after Parameters and Constant and before Results
// to align precision inside Subgraph body that is supported by Plugin
align_element_types(outputShapes, inputShapes);
master_shape = outPShape;
return master_shape;
}
ov::PartialShape snippets::op::Subgraph::canonicalized_body_shape_infer(const BlockedShapeVector& inputShapes) {
std::vector<Shape> normInputShapes;
for (size_t i = 0; i < inputShapes.size(); i++) {
PartialShape inShape = std::get<0>(inputShapes[i]);
const auto inRank = inShape.size();
if (inRank < maxInputRank) {
PartialShape newShape(ov::Shape(maxInputRank, 1));
for (size_t ir = 0; ir < inRank; ir++) {
newShape[appendOnesForCanonical[i] + ir] = inShape[ir];
}
normInputShapes.push_back(newShape.get_shape());
} else {
normInputShapes.push_back(inShape.get_shape());
}
}
reshape_body(normInputShapes);
const auto& body_results = body_ptr()->get_results();
PartialShape outPShape = body_results[0]->get_input_partial_shape(0);
const auto& result_parent = body_results[0]->get_input_node_shared_ptr(0);
if (body_results.size() == 1 &&
ov::is_type<ov::op::v1::Transpose>(result_parent) &&
ov::is_type<ov::op::v0::MatMul>(result_parent->get_input_node_shared_ptr(0))) {
outPShape = result_parent->get_input_partial_shape(0);
} else {
for (size_t i = 0; i < body_results.size(); i++) {
auto shape_i = body_results[i]->get_input_partial_shape(0);
bool compatibleWithOtherOutputs = PartialShape::broadcast_merge_into(outPShape, shape_i,
::ov::op::AutoBroadcastType::NUMPY);
NODE_VALIDATION_CHECK(this, compatibleWithOtherOutputs,
"Snippets output shapes must be numpy broadcastable");
}
}
master_shape = outPShape;
return master_shape;
}
bool Subgraph::check_broadcast(const std::shared_ptr<const ov::Node>& node) noexcept {
const auto elementwise = std::dynamic_pointer_cast<const ov::op::util::BinaryElementwiseArithmetic>(node);
return
@ -503,8 +315,40 @@ IShapeInferSnippets::Result Subgraph::OVShapeInfer::infer(const std::vector<Vect
return m_last_result;
}
VectorDims Subgraph::infer_master_shape() {
std::vector<VectorDims> output_dims;
if (is_dynamic()) {
// Note that in the dynamic case shapeInfer() is called before PrepareParams,
// so the last_result must already be available.
// In principle, we could instantiate shape_infer here, but that's not the intended pipeline behavior.
OPENVINO_ASSERT(m_shape_infer, "Can't calculate master_shape when shapeInfer is not initialized");
output_dims = m_shape_infer->get_last_result().dims;
OPENVINO_ASSERT(!output_dims.empty(), "Can't calculate master_shape before the first shape inference");
} else {
for (const auto& res : body_ptr()->get_results()) {
const auto& res_input = res->input(0);
OPENVINO_ASSERT(res_input.get_partial_shape().is_static(), "Result has a dynamic shape in the static pipeline");
// We need to account for the shape's layout stored in the Output<Node> rt_info
const auto& planar_shape = utils::get_planar_pshape(res_input.get_source_output());
output_dims.emplace_back(planar_shape.get_shape());
}
}
if (output_dims.size() == 1)
return output_dims.front();
const auto& default_broadcasting = std::make_shared<NumpyBroadcastShapeInfer>();
// Note: we have to convert vector<VectorDims> to vector<reference_wrapper<const VectorDims>>
// because of the shape inference interface
std::vector<std::reference_wrapper<const VectorDims>> inputs;
inputs.reserve(output_dims.size());
for (const auto& d : output_dims)
inputs.emplace_back(d);
return default_broadcasting->infer(inputs).dims.front();
}
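The reference-wrapper adaptation above avoids copying the output shapes when they are fed back into shape inference. A minimal standalone sketch of the pattern (illustrative only; VectorDims is assumed to be std::vector<size_t>):

    #include <cstddef>
    #include <functional>
    #include <vector>

    using VectorDims = std::vector<size_t>;

    // Wrap each VectorDims in a reference_wrapper so a shape-inference-style
    // interface can be called without copying the vectors. The wrappers are
    // only valid while `dims` is alive.
    std::vector<std::reference_wrapper<const VectorDims>>
    as_ref_wrappers(const std::vector<VectorDims>& dims) {
        std::vector<std::reference_wrapper<const VectorDims>> refs;
        refs.reserve(dims.size());
        for (const auto& d : dims)
            refs.emplace_back(d);
        return refs;
    }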
std::shared_ptr<lowered::LinearIR>
Subgraph::convert_body_to_linear_ir(const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory) const {
Subgraph::convert_body_to_linear_ir(const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory) {
lowered::Config lowering_config;
lowering_config.m_save_expressions = config.m_has_domain_sensitive_ops;
lowering_config.m_need_fill_tail_register = config.m_has_domain_sensitive_ops;
@ -513,89 +357,44 @@ Subgraph::convert_body_to_linear_ir(const std::shared_ptr<IShapeInferSnippetsFac
lowering_config.m_min_parallel_work_amount = config.m_min_parallel_work_amount;
lowering_config.m_min_kernel_work_amount = config.m_min_jit_work_amount;
return std::make_shared<lowered::LinearIR>(body_ptr(), shape_infer_factory, lowering_config);
m_linear_ir = std::make_shared<lowered::LinearIR>(body_ptr(), shape_infer_factory, lowering_config);
m_shape_infer = m_linear_ir->get_shape_infer_instance();
return m_linear_ir;
}
void Subgraph::align_element_types(const BlockedShapeVector& outputShapes,
const BlockedShapeVector& inputShapes) {
// We should insert Convert before Results to set original output element type if needed
const auto& body_results = body_ptr()->get_results();
for (size_t i = 0; i < outputShapes.size(); i++) {
const auto needed_out_type = std::get<2>(outputShapes[i]);
if (body_results[i]->get_input_element_type(0) != needed_out_type) {
auto parent_output = body_results[i]->get_input_source_output(0);
std::shared_ptr<ov::Node> consumer = body_results[i];
// Snippets supports Transpose only after Parameter or before Result nodes
// So we have to insert Convert before Transpose (if there is) on Subgraph outputs
const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(parent_output.get_node_shared_ptr());
if (transpose) {
OPENVINO_ASSERT(parent_output.get_target_inputs().size() == 1,
"If Result has Transpose on input, this Result must be single consumer of the Transpose");
parent_output = transpose->get_input_source_output(0);
consumer = transpose;
}
const auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(parent_output, needed_out_type);
ov::copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
consumer->set_argument(0, convert);
consumer->validate_and_infer_types();
if (consumer != body_results[i])
body_results[i]->validate_and_infer_types();
}
}
// We should change existing element type to original for Parameters if needed
const auto& parameters = body_ptr()->get_parameters();
for (size_t i = 0; i < inputShapes.size(); ++i) {
const auto needed_in_type = std::get<2>(inputShapes[i]);
const auto& parameter = parameters[i];
const auto original_type = parameter->get_element_type();
if (original_type != needed_in_type) {
parameter->set_element_type(needed_in_type);
parameter->validate_and_infer_types();
auto parent_output = parameter->output(0);
auto consumer_inputs = parent_output.get_target_inputs();
// Snippets supports Transpose only after Parameter or before Result nodes
// So we have to insert Convert after Transpose (if there is) on Subgraph inputs
if (std::any_of(consumer_inputs.cbegin(), consumer_inputs.cend(),
[](const ov::Input<ov::Node>& input) { return ov::is_type<ov::op::v1::Transpose>(input.get_node()); })) {
OPENVINO_ASSERT(consumer_inputs.size() == 1,
"If Parameter has Transpose on output, this Transpose must be single consumer of the Parameter");
const auto transpose = consumer_inputs.begin()->get_node()->shared_from_this();
transpose->validate_and_infer_types();
parent_output = transpose;
consumer_inputs = parent_output.get_target_inputs();
}
const auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(parent_output, original_type);
ov::copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
for (const auto input : consumer_inputs) {
const auto& input_node = input.get_node();
if (input_node == convert.get()) {
continue;
}
input_node->set_argument(input.get_index(), convert->output(0));
}
}
std::shared_ptr<Subgraph> Subgraph::clone() const {
ov::OutputVector subgraph_node_inputs;
for (const auto &input : input_values()) {
auto new_input = std::make_shared<ov::opset1::Parameter>(input.get_element_type(), input.get_partial_shape());
subgraph_node_inputs.push_back(new_input);
}
std::shared_ptr<ov::Model> new_body = body_ptr()->clone();
auto result = std::make_shared<snippets::op::Subgraph>(subgraph_node_inputs, new_body);
// Note: ov::copy_runtime_info accepts only shared_ptr<ov::Node> as "from" but never modifies it,
// so we have to cast away constness to copy runtime info
ov::copy_runtime_info(const_pointer_cast<Node>(shared_from_this()), result);
result->set_friendly_name(get_friendly_name());
if (m_linear_ir)
result->m_linear_ir = std::make_shared<lowered::LinearIR>(m_linear_ir->deep_copy());
// Note: we don't update shapeInfer here, since it's initialized in the constructor
if (m_generator)
result->m_generator = m_generator->clone();
return result;
}
void Subgraph::data_flow_transformations(const std::vector<snippets::pass::Manager::PositionedPass>& backend_passes) {
void Subgraph::data_flow_transformations(const BlockedShapeVector& blocked_input_shapes,
const std::vector<ov::element::Type>& input_precisions,
const std::vector<ov::element::Type>& output_precisions,
const std::vector<snippets::pass::Manager::PositionedPass>& backend_passes) {
INTERNAL_OP_SCOPE(Subgraph);
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::data_flow_transformations")
const auto& params = body_ptr()->get_parameters();
bool inputs_has_dynamic_last_dims = std::any_of(params.begin(), params.end(),
[](const shared_ptr<ov::op::v0::Parameter>& p) {
return p->get_partial_shape().rbegin()->is_dynamic();
});
snippets::pass::Manager manager;
ov::snippets::pass::Manager manager;
if (!blocked_input_shapes.empty())
manager.register_pass<snippets::pass::Canonicalization>(blocked_input_shapes);
if (!input_precisions.empty() && !output_precisions.empty())
manager.register_pass<snippets::pass::AlignElementTypes>(input_precisions, output_precisions);
if (config.m_has_domain_sensitive_ops) {
manager.register_pass<snippets::pass::MatMulToBrgemm>();
manager.register_pass<snippets::pass::FuseTransposeBrgemm>();
@ -605,14 +404,6 @@ void Subgraph::data_flow_transformations(const std::vector<snippets::pass::Manag
manager.register_pass<snippets::pass::BroadcastToMoveBroadcast>();
manager.register_pass<snippets::pass::ConvertConstantsToScalars>();
manager.register_pass<snippets::pass::ConvertPowerToPowerStatic>();
// todo: presently dynamic pipeline is activated even if the last two dimension are static
// In general, we can use static kernels in this case, but several parameters (src and dst memory pointers for example)
// should be passed as run-time args, so it's a mixed mode: kernel is shape-aware, but some additional runtime args are required
// Presently Broadcasting is organized in the following way:
// * ALL last dims are static => broadcasting is handled via MoveBroadcast and pointer arithmetics (even for dynamic upper dims)
if (!inputs_has_dynamic_last_dims) {
manager.register_pass<snippets::pass::InsertMoveBroadcast>();
}
manager.register_pass<snippets::pass::PropagatePrecision>(m_generator->get_target_machine());
manager.register_pass<ov::pass::ConstantFolding>();
@ -623,8 +414,9 @@ void Subgraph::data_flow_transformations(const std::vector<snippets::pass::Manag
}
void Subgraph::control_flow_transformations(lowered::LinearIR& linear_ir,
LoweringResult& lowering_result,
const lowered::pass::PassPipeline& backend_passes_pre_common,
const lowered::pass::PassPipeline& backend_passes_post_common) {
const lowered::pass::PassPipeline& backend_passes_post_common) const {
INTERNAL_OP_SCOPE(Subgraph);
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::control_flow_transformations")
@ -649,7 +441,11 @@ void Subgraph::control_flow_transformations(lowered::LinearIR& linear_ir,
common_pipeline.register_pass<lowered::pass::InsertBuffers>(buffer_allocation_rank);
common_pipeline.register_pass<lowered::pass::InsertLoadStore>(vector_size);
common_pipeline.register_pass<lowered::pass::MoveScalarToConsumer>();
common_pipeline.register_pass<lowered::pass::InsertBroadcastMove>();
common_pipeline.register_pass<lowered::pass::LoadMoveBroadcastToBroadcastLoad>();
common_pipeline.register_pass<lowered::pass::ValidateShapes>();
common_pipeline.register_pass<lowered::pass::ValidateLoops>();
common_pipeline.register_pass<lowered::pass::InitLoops>();
common_pipeline.register_pass<lowered::pass::InsertLoops>();
@ -669,57 +465,44 @@ void Subgraph::control_flow_transformations(lowered::LinearIR& linear_ir,
final_pipeline.register_pass<lowered::pass::CleanupLoopOffsets>();
final_pipeline.run(linear_ir);
m_buffer_scratchpad = buffer_allocation_pass->get_scratchpad_size();
lowering_result.buffer_scratchpad_size = buffer_allocation_pass->get_scratchpad_size();
}
snippets::Schedule Subgraph::generate(const BlockedShapeVector& output_shapes,
const BlockedShapeVector& input_shapes,
snippets::Schedule Subgraph::generate(const BlockedShapeVector& blocked_input_shapes,
const std::vector<ov::element::Type>& input_precisions,
const std::vector<ov::element::Type>& output_precisions,
const std::vector<snippets::pass::Manager::PositionedPass>& data_flow_backend_passes,
const lowered::pass::PassPipeline& backend_passes_pre_common,
const lowered::pass::PassPipeline& backend_passes_post_common,
const std::shared_ptr<IShapeInferSnippetsFactory>& factory,
const void* compile_params) {
canonicalize(output_shapes, input_shapes);
return generate(compile_params);
data_flow_transformations(blocked_input_shapes, input_precisions, output_precisions, data_flow_backend_passes);
convert_body_to_linear_ir(factory);
return generate_from_linear_ir(backend_passes_pre_common, backend_passes_post_common, compile_params);
}
snippets::Schedule Subgraph::generate(const BlockedShapeVector& output_shapes,
const BlockedShapeVector& input_shapes,
const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
const lowered::pass::PassPipeline& control_flow_passes_pre_common,
const lowered::pass::PassPipeline& control_flow_passes_post_common,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory,
const void* compile_params) {
canonicalize(output_shapes, input_shapes);
return generate(data_flow_passes, control_flow_passes_pre_common, control_flow_passes_post_common,
shape_infer_factory, compile_params);
}
snippets::Schedule Subgraph::generate(const void* compile_params) {
return generate({}, {}, {}, nullptr, compile_params);
}
snippets::Schedule Subgraph::generate(const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
const lowered::pass::PassPipeline& control_flow_passes_pre_common,
const lowered::pass::PassPipeline& control_flow_passes_post_common,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory,
const void* compile_params) {
snippets::Schedule Subgraph::generate_from_linear_ir(const lowered::pass::PassPipeline& backend_passes_pre_common,
const lowered::pass::PassPipeline& backend_passes_post_common,
const void* compile_params) const {
INTERNAL_OP_SCOPE(Subgraph);
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::generate")
OPENVINO_ASSERT(m_generator != nullptr, "generate is called while generator is not set");
data_flow_transformations(data_flow_passes);
lowered::LinearIR linear_ir = *convert_body_to_linear_ir(shape_infer_factory);
control_flow_transformations(linear_ir, control_flow_passes_pre_common, control_flow_passes_post_common);
// actual code emission
const auto& lowering_result = m_generator->generate(linear_ir, linear_ir.get_config(), compile_params);
const auto ptr = lowering_result.binary_code;
// Note: some transformations performed in the generator, e.g. tail insertion, can break shape propagation.
// Until this behavior is fixed, we have to make a copy of the LIR before giving it to the generator.
OPENVINO_ASSERT(m_linear_ir, "Attempt to call generate, when linear IR was not initialized");
auto linear_ir = m_linear_ir->deep_copy();
LoweringResult lowering_result;
control_flow_transformations(linear_ir, lowering_result, backend_passes_pre_common, backend_passes_post_common);
m_generator->generate(linear_ir, lowering_result, compile_params);
VectorDims parallel_exec_domain = linear_ir.get_master_shape();
const size_t loop_depth = linear_ir.get_config().m_loop_depth;
for (size_t i = 0; i < loop_depth; i++)
parallel_exec_domain[parallel_exec_domain.size() - 1 - i] = 1;
return {parallel_exec_domain, ptr};
return {parallel_exec_domain, std::move(lowering_result)};
}
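For reference, the parallel execution domain computed above is simply the master shape with the innermost loop_depth dimensions collapsed to 1, since those dimensions are iterated inside the generated kernel. A minimal sketch (illustrative; VectorDims assumed to be std::vector<size_t>):

    #include <cstddef>
    #include <vector>

    using VectorDims = std::vector<size_t>;

    VectorDims make_parallel_domain(VectorDims master_shape, size_t loop_depth) {
        // The innermost `loop_depth` dims are executed inside the kernel,
        // so they are excluded (set to 1) from the parallel domain.
        for (size_t i = 0; i < loop_depth && i < master_shape.size(); ++i)
            master_shape[master_shape.size() - 1 - i] = 1;
        return master_shape;
    }
    // e.g. master shape {2, 3, 10, 64} with loop_depth = 2 -> {2, 3, 1, 1}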
void Subgraph::print() const {

View File

@ -0,0 +1,106 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "snippets/pass/align_element_types.hpp"
#include "snippets/itt.hpp"
namespace ov {
namespace snippets {
pass::AlignElementTypes::AlignElementTypes(std::vector<ov::element::Type> input_precisions,
std::vector<ov::element::Type> output_precisions) :
m_input_precisions(std::move(input_precisions)),
m_output_precisions(std::move(output_precisions)) {
}
bool pass::AlignElementTypes::run_on_model(const std::shared_ptr<ov::Model>& m) {
RUN_ON_MODEL_SCOPE(AlignElementTypes);
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::AlignElementTypes")
bool is_modified = false;
const auto& results = m->get_results();
const auto& params = m->get_parameters();
OPENVINO_ASSERT(m_input_precisions.size() == params.size() && m_output_precisions.size() == results.size(),
"Number of parameters for snippet doesn't match passed to the Canonicalization pass. ");
// We should insert Convert before Results to set original output element type if needed
for (size_t i = 0; i < m_output_precisions.size(); i++) {
const auto needed_out_type = m_output_precisions[i];
if (results[i]->get_input_element_type(0) != needed_out_type) {
std::shared_ptr<ov::Node> consumer = results[i];
auto parent_output = consumer->get_input_source_output(0);
// Snippets supports Transpose only after Parameter or before Result nodes
// So we have to insert Convert before Transpose (if there is) on Subgraph outputs
const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(parent_output.get_node_shared_ptr());
if (transpose) {
OPENVINO_ASSERT(parent_output.get_target_inputs().size() == 1,
"If Result has Transpose on input, this Result must be single consumer of the Transpose");
parent_output = transpose->get_input_source_output(0);
consumer = transpose;
}
const auto convert = std::make_shared<op::ConvertSaturation>(parent_output, needed_out_type);
ov::copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
consumer->set_argument(0, convert);
consumer->validate_and_infer_types();
if (transpose)
results[i]->validate_and_infer_types();
is_modified = true;
}
}
// We should change existing element type to original for Parameters if needed
for (size_t i = 0; i < m_input_precisions.size(); ++i) {
const auto needed_in_type = m_input_precisions[i];
const auto& parameter = params[i];
const auto original_type = parameter->get_element_type();
if (original_type != needed_in_type) {
parameter->set_element_type(needed_in_type);
parameter->validate_and_infer_types();
auto parent_output = parameter->output(0);
auto consumer_inputs = parent_output.get_target_inputs();
const auto& first_child = consumer_inputs.begin()->get_node()->shared_from_this();
// Note: RankNormalization is designed for shape-inference purposes only.
// It does not process any data (nor does it emit any code), so it doesn't require Convert operations
if (is_type<op::RankNormalization>(first_child)) {
OPENVINO_ASSERT(consumer_inputs.size() == 1, "RankNormalization is supposed to be the only consumer");
parent_output = first_child->output(0);
consumer_inputs = parent_output.get_target_inputs();
}
// Snippets supports Transpose only after Parameter or before Result nodes
// So we have to insert Convert after Transpose (if there is) on Subgraph inputs
if (std::any_of(consumer_inputs.cbegin(), consumer_inputs.cend(),
[](const ov::Input<ov::Node>& input) { return ov::is_type<ov::op::v1::Transpose>(input.get_node()); })) {
OPENVINO_ASSERT(consumer_inputs.size() == 1,
"If Parameter has Transpose on output, this Transpose must be single consumer of the Parameter");
const auto transpose = consumer_inputs.begin()->get_node()->shared_from_this();
transpose->validate_and_infer_types();
parent_output = transpose;
consumer_inputs = parent_output.get_target_inputs();
}
const auto& convert = std::make_shared<ov::snippets::op::ConvertSaturation>(parent_output, original_type);
ov::copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
for (const auto input : consumer_inputs) {
const auto& input_node = input.get_node();
if (input_node == convert.get()) {
continue;
}
input_node->set_argument(input.get_index(), convert->output(0));
}
is_modified = true;
}
}
return is_modified;
}
} // namespace snippets
} // namespace ov
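For orientation, a hedged usage sketch of this pass (illustrative only; the model-building calls are standard OpenVINO API, but wiring the pass standalone like this, outside the Subgraph pipeline, is an assumption): a body producing f32 while the plugin expects u8 gets a ConvertSaturation inserted before the Result.

    #include <memory>
    #include <vector>

    #include "openvino/core/model.hpp"
    #include "openvino/op/parameter.hpp"
    #include "openvino/op/relu.hpp"
    #include "openvino/pass/manager.hpp"
    #include "snippets/pass/align_element_types.hpp"

    void align_element_types_example() {
        // Body computes in f32, but the plugin expects u8 on the output.
        auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 3});
        auto relu  = std::make_shared<ov::op::v0::Relu>(param);
        auto model = std::make_shared<ov::Model>(ov::OutputVector{relu}, ov::ParameterVector{param});

        ov::pass::Manager manager;
        manager.register_pass<ov::snippets::pass::AlignElementTypes>(
                std::vector<ov::element::Type>{ov::element::f32},  // input precision already matches
                std::vector<ov::element::Type>{ov::element::u8});  // output precision needs a Convert
        manager.run_passes(model);  // a ConvertSaturation(f32 -> u8) now precedes the Result
    }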

View File

@ -5,7 +5,7 @@
#include "snippets/itt.hpp"
#include "snippets/pass/broadcast_to_movebroadcast.hpp"
#include "snippets/pass/insert_movebroadcast.hpp"
#include "snippets/op/broadcastmove.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/opsets/opset1.hpp"
@ -30,15 +30,19 @@ ov::snippets::pass::BroadcastToMoveBroadcast::BroadcastToMoveBroadcast() {
const auto target_shape = root->get_output_partial_shape(0);
const auto value_shape = root->get_input_partial_shape(0);
if (target_shape.is_dynamic() || value_shape.is_dynamic()) {
return false;
OPENVINO_ASSERT(target_shape.is_static() && value_shape.rank().is_static(), "Broadcast with dynamic target shape is not supported in Snippets");
// Insert BroadcastMove only if the last dimension needs to be broadcasted. Broadcasting over higher-level dims
// is handled by pointer arithmetic. Note that this behavior should be changed if full op::Broadcast support is added.
Output<ov::Node> in_value = root->input_value(0);
if (*target_shape.rbegin() != *value_shape.rbegin()) {
auto broadcasted_shape = value_shape;
*broadcasted_shape.rbegin() = *target_shape.rbegin();
const auto& broadcast_node = std::make_shared<ov::snippets::op::BroadcastMove>(in_value, broadcasted_shape);
in_value = broadcast_node->output(0);
}
const auto broadcast_node = ov::snippets::pass::InsertMoveBroadcast::BroadcastNodeLastDim(root->input_value(0),
target_shape.get_shape(),
value_shape.get_shape());
replace_output_update_name(root->output(0), broadcast_node);
ov::copy_runtime_info(root, broadcast_node.get_node_shared_ptr());
replace_output_update_name(root->output(0), in_value);
ov::copy_runtime_info(root, in_value.get_node_shared_ptr());
return true;
};
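The rule above is easy to state in isolation: an explicit BroadcastMove is emitted only when the innermost dimension differs; broadcasting over outer dimensions is left to pointer arithmetic at runtime. A minimal sketch (illustrative shapes):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    using Shape = std::vector<size_t>;

    // True only if the innermost (last) dimension must be broadcasted.
    bool needs_broadcast_move(const Shape& value, const Shape& target) {
        assert(!value.empty() && !target.empty());
        return value.back() != target.back();
    }
    // needs_broadcast_move({2, 3, 10, 1},  {2, 3, 10, 64}) == true   // insert BroadcastMove
    // needs_broadcast_move({2, 1, 10, 64}, {2, 3, 10, 64}) == false  // pointer arithmetic only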

View File

@ -0,0 +1,84 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "snippets/pass/canonicalization.hpp"
#include "snippets/op/rank_normalization.hpp"
#include "snippets/itt.hpp"
#include "snippets/utils.hpp"
#include "snippets/lowered/port_descriptor.hpp"
namespace ov {
namespace snippets {
pass::Canonicalization::Canonicalization(const BlockedShapeVector& blocked_input_shapes) {
m_in_shapes.reserve(blocked_input_shapes.size());
m_in_layouts.reserve(blocked_input_shapes.size());
for (const auto& bs : blocked_input_shapes) {
m_has_dynamic_inputs |= utils::is_dynamic_vdims(bs.first);
m_in_shapes.emplace_back(bs.first);
m_in_layouts.emplace_back(bs.second);
// Note: Blocking (if any) must be accounted for in input shapes
OPENVINO_ASSERT(m_in_shapes.back().size() == m_in_layouts.back().size(), "Input shapes and layouts must have the same rank");
}
}
bool pass::Canonicalization::run_on_model(const std::shared_ptr<ov::Model>& m) {
RUN_ON_MODEL_SCOPE(Canonicalization);
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::Canonicalization")
bool is_modified = false;
const ParameterVector& params = m->get_parameters();
OPENVINO_ASSERT(m_in_shapes.size() == params.size(),
"Number of parameters for snippet doesn't match passed to the Canonicalization pass. ",
"Expected: ", m_in_shapes.size(), " Got: ", params.size(), ".");
// Note that shape rank also incorporates layout, so NCHW16c would have shape rank 5
auto is_blocked_layout = [](const Layout& l) {
return l.size() != std::set<size_t>(l.begin(), l.end()).size();
};
auto compare_ranks = [](const Layout& l, const Layout& r) {
return l.size() < r.size();
};
// Layout with the max rank
const auto& max_rank_it = std::max_element(m_in_layouts.begin(), m_in_layouts.end(), compare_ranks);
Layout base_layout = *max_rank_it;
size_t max_rank = base_layout.size();
const bool base_is_blocked = is_blocked_layout(base_layout);
for (size_t i = 0; i < m_in_layouts.size(); i++) {
const auto& i_layout = m_in_layouts[i];
const auto& i_shape = m_in_shapes[i];
const auto i_rank = i_layout.size();
const bool i_is_blocked = is_blocked_layout(i_layout);
// Canonicalization logic, briefly:
// * If this input is blocked => reshape the corresponding input Parameter, so the following transformations
// will work with a shape of a larger rank. In the dynamic case, this shape will be updated during the shapeInfer()
// call, but the important point is that the shape rank won't change.
// * If some of the input shapes are blocked (=> base_is_blocked), but this input is planar,
// then insert a RankNormalization op after this input. This is needed so that all shapes inside the body have
// matching ranks.
if (i_is_blocked) {
OPENVINO_ASSERT(base_is_blocked && i_rank == max_rank, "If this shape is blocked, base must also be blocked");
params[i]->set_partial_shape(snippets::utils::vdims_to_pshape(i_shape));
is_modified = true;
} else if (i_rank < max_rank) {
size_t num_append = base_is_blocked;
OPENVINO_ASSERT(max_rank >= i_rank + num_append, "Unsupported blocked shapes combination in canonicalization");
size_t num_prepend = max_rank - i_rank - num_append;
const auto& out = params[i]->output(0);
const auto& target_inputs = out.get_target_inputs();
auto rank_norm = std::make_shared<op::RankNormalization>(out, num_prepend, num_append);
for (auto& in : target_inputs)
in.replace_source_output(rank_norm);
is_modified = true;
} else {
// todo: 4d blocked + 5d planar layouts are not supported: <N, C, H, W, c> + <N, C, D, H, W>
OPENVINO_ASSERT(equal(base_layout.begin(), base_layout.end(), i_layout.begin()),
"Canonicalization got input shapes of equal ranks but different layouts, which is not supported");
}
}
return is_modified;
}
} // namespace snippets
} // namespace ov
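The prepend/append arithmetic used above is compact enough to state separately. A hedged sketch (assuming, as in the pass, that a blocked base layout contributes exactly one trailing block dimension):

    #include <cstddef>
    #include <utility>

    // Returns {num_prepend, num_append} for a planar input of rank `i_rank`
    // normalized against a base layout of rank `max_rank`.
    std::pair<size_t, size_t> rank_normalization_counts(size_t i_rank, size_t max_rank, bool base_is_blocked) {
        const size_t num_append  = base_is_blocked ? 1 : 0;  // trailing block dim, e.g. the "16c" in NCHW16c
        const size_t num_prepend = max_rank - i_rank - num_append;
        return {num_prepend, num_append};
    }
    // Example: base layout {0, 1, 2, 3, 1} (rank 5, blocked) and a rank-3 planar input
    // -> num_append = 1, num_prepend = 5 - 3 - 1 = 1, i.e. RankNormalization(out, 1, 1)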

View File

@ -24,8 +24,7 @@ ov::snippets::pass::ConvertConstantsToScalars::ConvertConstantsToScalars() {
// Note that all Constants {1,1,1,1} are converted to Scalar {1} here.
// This is needed to simplify shape inference, since {1,1,1,1} Constants could otherwise increase the output rank.
// Also, some operations support only scalar shapes, so scalars must be kept separate from shape [1].
const auto shape = constant->get_output_shape(0).size() == 0 ? ov::Shape{} : ov::Shape{1};
auto scalar = std::make_shared<snippets::op::Scalar>(ov::op::v0::Constant(*constant, shape));
auto scalar = std::make_shared<snippets::op::Scalar>(ov::op::v0::Constant(*constant, ov::Shape{1}));
scalar->set_friendly_name(constant->get_friendly_name());
ov::copy_runtime_info(constant, scalar);
ov::replace_node(constant, scalar);
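A short worked example of the conversion above (a sketch; the snippets header path is assumed from the pass's own includes): a rank-4 one-element Constant collapses into a rank-1 Scalar, so its padding ones can no longer inflate downstream output ranks.

    #include <memory>
    #include <vector>

    #include "openvino/op/constant.hpp"
    #include "snippets/op/scalar.hpp"

    void constant_to_scalar_example() {
        // A {1,1,1,1} Constant holding a single value...
        auto constant = std::make_shared<ov::op::v0::Constant>(
                ov::element::f32, ov::Shape{1, 1, 1, 1}, std::vector<float>{0.5f});
        // ...is rebuilt as a Scalar with shape {1}, exactly as the pass does above.
        auto scalar = std::make_shared<ov::snippets::op::Scalar>(
                ov::op::v0::Constant(*constant, ov::Shape{1}));
    }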

View File

@ -25,11 +25,9 @@ ov::snippets::pass::SetSoftmaxPorts::SetSoftmaxPorts() {
auto root = m.get_match_root();
const auto& pshape = root->get_input_partial_shape(0);
if (pshape.is_dynamic())
return false;
const auto shape = pshape.get_shape();
const auto rank = shape.size();
OPENVINO_ASSERT(!pshape.rank().is_dynamic(), "SetSoftmaxPorts doesn't support dynamic ranks");
const auto rank = pshape.rank().get_length();
int64_t axis;
if (const auto softmax_v8 = ov::as_type_ptr<ov::op::v8::Softmax>(root)) {
@ -44,7 +42,7 @@ ov::snippets::pass::SetSoftmaxPorts::SetSoftmaxPorts() {
OPENVINO_ASSERT(axis < static_cast<int64_t>(rank), "Softmax has incorrect axis");
std::vector<size_t> subtensor(rank, 1);
for (size_t i = axis; i < rank; ++i)
for (auto i = axis; i < rank; ++i)
subtensor[i] = lowered::PortDescriptor::ServiceDimensions::FULL_DIM;
lowered::PortDescriptorUtils::set_port_descriptor_ptr(root->input(0), std::make_shared<lowered::PortDescriptor>(root->input(0), subtensor));

View File

@ -63,6 +63,7 @@ const IShapeInferSnippetsFactory::TRegistry IShapeInferSnippetsFactory::registry
SHAPE_INFER_PREDEFINED(ov::op::v0::Result, EmptyShapeInfer),
//
SHAPE_INFER_OP_SPECIFIC(op::LoadReshape),
SHAPE_INFER_OP_SPECIFIC(op::RankNormalization),
SHAPE_INFER_OP_SPECIFIC(op::BroadcastLoad),
SHAPE_INFER_OP_SPECIFIC(op::BroadcastMove),
};

View File

@ -92,7 +92,8 @@ VectorDims pshape_to_vdims(const PartialShape& pshape) {
result.reserve(pshape.size());
for (const auto& d : pshape)
result.push_back(d.is_dynamic() ? IShapeInferSnippets::DYNAMIC_DIMENSION : d.get_length());
return result;
// Note: PartialShape could be empty, which designates a scalar value. However, Scalars are represented as {1} in Snippets
return result.empty() ? VectorDims {1} : result;
}
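The scalar convention above, in isolation: an empty PartialShape denotes a scalar, and Snippets represents it as {1}. A minimal sketch (the DYNAMIC_DIMENSION sentinel value is an assumption here; negative input dims stand for dynamic ones):

    #include <cstddef>
    #include <cstdint>
    #include <limits>
    #include <vector>

    using VectorDims = std::vector<size_t>;
    constexpr size_t DYNAMIC_DIMENSION = std::numeric_limits<size_t>::max();  // assumed sentinel

    VectorDims to_vdims(const std::vector<int64_t>& dims) {
        VectorDims result;
        result.reserve(dims.size());
        for (auto d : dims)
            result.push_back(d < 0 ? DYNAMIC_DIMENSION : static_cast<size_t>(d));
        // An empty shape denotes a scalar; Snippets represents scalars as {1}.
        return result.empty() ? VectorDims{1} : result;
    }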
ov::PartialShape vdims_to_pshape(const VectorDims& vdims) {
@ -132,6 +133,10 @@ VectorDims get_planar_vdims(const snippets::lowered::ExpressionPort& expr_port)
return get_planar_vdims(expr_port.get_descriptor_ptr());
}
bool is_dynamic_vdims(const VectorDims& shape) {
return std::any_of(shape.cbegin(), shape.cend(), [](size_t v){ return v == IShapeInferSnippets::DYNAMIC_DIMENSION; });
}
} // namespace utils
} // namespace snippets
} // namespace ov

View File

@ -6,6 +6,8 @@
#include <common_test_utils/ov_test_utils.hpp>
#include "snippets/op/subgraph.hpp"
#include "snippets_helpers.hpp"
#include "snippets/pass_manager.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
namespace ov {
namespace test {
@ -23,11 +25,17 @@ public:
void emit_data() const override {}
};
struct DummyCompiledSnippet : public ov::snippets::CompiledSnippet {
const uint8_t* get_code() const override { return nullptr; }
size_t get_code_size() const override { return 0; }
bool empty() const override { return true; }
};
class DummyTargetMachine : public ov::snippets::TargetMachine {
public:
DummyTargetMachine(const std::vector<ov::Node::type_info_t>& custom_opset = {});
bool is_supported() const override { return true; }
ov::snippets::code get_snippet() const override { return nullptr; }
ov::snippets::CompiledSnippetPtr get_snippet() override { return std::make_shared<DummyCompiledSnippet>(); }
size_t get_lanes() const override { return 10; }
};
@ -35,6 +43,7 @@ class DummyGenerator : public ov::snippets::Generator {
public:
DummyGenerator() : ov::snippets::Generator(std::make_shared<DummyTargetMachine>()) {}
DummyGenerator(const std::shared_ptr<ov::snippets::TargetMachine>& t) : ov::snippets::Generator(t) {}
std::shared_ptr<Generator> clone() const override { return std::make_shared<DummyGenerator>(target); }
protected:
opRegType get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const override { return vec2vec; };
@ -48,13 +57,15 @@ public:
void TearDown() override;
static std::shared_ptr<ov::snippets::op::Subgraph> getSubgraph(const std::shared_ptr<Model>& f);
using IShapeInferSnippetsFactory = ov::snippets::IShapeInferSnippetsFactory;
static std::shared_ptr<ov::snippets::op::Subgraph>
getLoweredSubgraph(const std::shared_ptr<Model>& f,
const ov::PartialShape& master_shape,
const std::vector<ov::snippets::pass::Manager::PositionedPass>& backend_passes = {},
const ov::snippets::lowered::pass::PassPipeline& lowered_pre_common = {},
const ov::snippets::lowered::pass::PassPipeline& lowered_post_common = {},
const std::shared_ptr<ov::snippets::Generator>& generator = nullptr);
const std::shared_ptr<ov::snippets::Generator>& generator = nullptr,
const std::shared_ptr<IShapeInferSnippetsFactory>& factory = std::make_shared<IShapeInferSnippetsFactory>());
static std::shared_ptr<ov::snippets::op::Subgraph> getTokenizedSubgraph(const std::shared_ptr<Model>& f);
protected:

View File

@ -5,36 +5,25 @@
#pragma once
#include "lowering_utils.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets_helpers.hpp"
#include "snippets/shape_types.hpp"
#include "snippets/pass/canonicalization.hpp"
namespace ov {
namespace test {
namespace snippets {
using BlockedShape = ov::snippets::op::Subgraph::BlockedShape;
using BlockedShapeVector = ov::snippets::op::Subgraph::BlockedShapeVector;
// todo: implement tests with 3 inputs and two outputs (aka SnippetsCanonicalizationParams3Inputs)
// Note that the expected output shape isn't necessary equal to one of the output blocked_shapes.
// For example, consider the following graph: (1, 2, 2, 1, 8) + (1, 2, 1, 1, 8) + (1, 2, 1, 5, 8) => (1, 2, 2, 1, 8) + (1, 2, 1, 5, 8).
typedef std::tuple<
std::tuple<Shape, BlockedShape>, // Shape & BlockedShape for input 0
std::tuple<Shape, BlockedShape>, // Shape & BlockedShape for input 0
BlockedShape, // BlockedShape output shape passed to canonicalize()
Shape // expected output Shape
> canonicalizationParams;
class CanonicalizationTests : public LoweringTests, public testing::WithParamInterface<canonicalizationParams> {
class CanonicalizationTests : public TransformationTestsF {
public:
static std::string getTestCaseName(testing::TestParamInfo<canonicalizationParams> obj);
using VectorDims = ov::snippets::VectorDims;
using Layout = std::vector<size_t>;
virtual void run();
protected:
void SetUp() override;
std::shared_ptr<SnippetsFunctionBase> snippets_model;
Shape expected_output_shape;
BlockedShapeVector input_blocked_shapes;
BlockedShapeVector output_blocked_shapes;
std::vector<VectorDims> m_input_shapes;
std::vector<Layout> m_input_layouts;
void prepare_functions(const std::vector<VectorDims>& shapes);
};
} // namespace snippets

View File

@ -106,13 +106,13 @@ std::shared_ptr<ov::snippets::op::Subgraph>
const std::vector<ov::snippets::pass::Manager::PositionedPass>& backend_passes,
const ov::snippets::lowered::pass::PassPipeline& lowered_pre_common,
const ov::snippets::lowered::pass::PassPipeline& lowered_post_common,
const std::shared_ptr<ov::snippets::Generator>& generator) {
const std::shared_ptr<ov::snippets::Generator>& generator,
const std::shared_ptr<IShapeInferSnippetsFactory>& factory) {
auto subgraph = getTokenizedSubgraph(f);
subgraph->set_generator(generator == nullptr ? std::make_shared<DummyGenerator>() : generator);
subgraph->set_master_shape(master_shape);
subgraph->set_tile_rank(2);
// Note: lowered_pipeline would have no effect on the subgraph body, since it's applied to the linear IR
subgraph->generate(backend_passes, lowered_pre_common, lowered_post_common);
subgraph->generate({}, {}, {}, backend_passes, lowered_pre_common, lowered_post_common, factory);
return subgraph;
}

View File

@ -5,101 +5,84 @@
#include <gtest/gtest.h>
#include "pass/canonicalization.hpp"
#include "common_test_utils/common_utils.hpp"
#include <subgraph_lowered.hpp>
#include "snippets/pass/canonicalization.hpp"
#include "snippets/op/rank_normalization.hpp"
#include <subgraph_simple.hpp>
namespace ov {
namespace test {
namespace snippets {
using ov::snippets::op::Subgraph;
namespace {
void normalizeParameter(const std::shared_ptr<ov::opset1::Parameter>& par, size_t num_prepend, size_t num_append) {
auto target_inputs = par->get_output_target_inputs(0);
auto rank_norm = std::make_shared<ov::snippets::op::RankNormalization>(par,
num_prepend,
num_append);
for (auto& t : target_inputs)
t.replace_source_output(rank_norm);
}
} // namespace
class SKIP_CanonicalizationTests : public CanonicalizationTests {
public:
void SetUp() override {
GTEST_SKIP();
}
void TearDown() override{};
};
std::string CanonicalizationTests::getTestCaseName(testing::TestParamInfo<canonicalizationParams> obj) {
std::vector<std::tuple<Shape, Subgraph::BlockedShape>> inputs(2);
Subgraph::BlockedShape output;
Shape expectedOutput;
std::tie(inputs[0], inputs[1], output, expectedOutput) = obj.param;
std::ostringstream result;
for (size_t i = 0; i < inputs.size(); i++) {
const auto& blockedshape = std::get<1>(inputs[i]);
// input shape
result << "IS[" << i << "]=" << ov::test::utils::vec2str(std::get<0>(inputs[i])) << "_";
// input blocked shape
result << "IBS[" << i << "]=" << ov::test::utils::partialShape2str({std::get<0>(blockedshape)}) << "_";
// input blocked order
result << "IBO[" << i << "]=" << ov::test::utils::vec2str(std::get<1>(blockedshape)) << "_";
}
// output blocked shape
result << "OBS[0]=" << ov::test::utils::partialShape2str({std::get<0>(output)}) << "_";
// output blocked order
result << "OBO[0]=" << ov::test::utils::vec2str(std::get<1>(output)) << "_";
result << "ExpOS[0]=" << ov::test::utils::vec2str(expectedOutput) << "_";
return result.str();
void CanonicalizationTests::prepare_functions(const std::vector<VectorDims>& shapes) {
std::vector<PartialShape> pshapes;
pshapes.reserve(shapes.size());
for (const auto& v : shapes)
pshapes.emplace_back(v);
const auto &f = AddFunction(pshapes);
model = f.getOriginal();
model_ref = model->clone();
}
void CanonicalizationTests::SetUp() {
TransformationTestsF::SetUp();
std::vector<std::tuple<Shape, Subgraph::BlockedShape>> inputs(2);
output_blocked_shapes.resize(1);
std::tie(inputs[0], inputs[1], output_blocked_shapes[0], expected_output_shape) = this->GetParam();
input_blocked_shapes = {std::get<1>(inputs[0]), std::get<1>(inputs[1])};
snippets_model = std::make_shared<AddFunction>(std::vector<PartialShape>{std::get<0>(inputs[0]), std::get<0>(inputs[1])});
void CanonicalizationTests::run() {
ASSERT_TRUE(model);
ASSERT_EQ(m_input_shapes.size(), m_input_layouts.size());
BlockedShapeVector blocked_input_shapes;
blocked_input_shapes.reserve(m_input_shapes.size());
for (size_t i = 0; i < m_input_shapes.size(); i++)
blocked_input_shapes.emplace_back(m_input_shapes[i], m_input_layouts[i]);
manager.register_pass<ov::snippets::pass::Canonicalization>(blocked_input_shapes);
disable_rt_info_check();
}
TEST_P(CanonicalizationTests, Add) {
model = snippets_model->getOriginal();
model_ref = snippets_model->getReference();
auto subgraph = getTokenizedSubgraph(model);
subgraph->set_generator(std::make_shared<DummyGenerator>());
auto canonical_output_shape = subgraph->canonicalize(output_blocked_shapes, input_blocked_shapes);
ASSERT_TRUE(canonical_output_shape.is_static());
ASSERT_DIMS_EQ(canonical_output_shape.get_shape(), expected_output_shape);
TEST_F(CanonicalizationTests, smoke_Snippets_Canonicalization_0) {
m_input_shapes = {{2, 3, 10, 64}, {2, 3, 10, 64}};
m_input_layouts = {{0, 1, 2, 3}, {0, 1, 2, 3}};
prepare_functions(m_input_shapes);
run();
}
namespace CanonicalizationTestsInstantiation {
using ov::snippets::op::Subgraph;
std::vector<Shape> input_shapes;
Shape expected_output_shape;
TEST_F(CanonicalizationTests, smoke_Snippets_Canonicalization_1) {
m_input_shapes = {{2, 3, 10, 64},
{10, 64}};
m_input_layouts = {{0, 1, 2, 3},
{0, 1}};
prepare_functions(m_input_shapes);
normalizeParameter(model_ref->get_parameters()[1], 2, 0);
run();
}
using ov::Shape;
ov::element::Type_t prec = ov::element::f32;
std::tuple<Shape, Subgraph::BlockedShape> blockedInput0{{1, 64, 2, 5},
{{1, 4, 2, 5, 16}, {0, 1, 2, 3, 1}, prec}};
Subgraph::BlockedShape output{{1, 4, 2, 5, 16}, {0, 1, 2, 3, 1}, prec};
Shape canonical_shape{1, 4, 2, 5, 16};
TEST_F(CanonicalizationTests, smoke_Snippets_Canonicalization_2) {
m_input_shapes = {{2, 3, 10, 64, 16},
{1, 10, 64}};
m_input_layouts = {{0, 1, 2, 3, 1},
{0, 1, 2}};
prepare_functions({{2, 48, 10, 64},
{1, 10, 64}});
const auto& params = model_ref->get_parameters();
// Note: We can't create functions with mismatching input shapes,
// so we have to set the Parameter shapes after the functions were created.
// This reproduces the Snippets pipeline well, since blocked shapes are set after tokenization.
params[0]->set_partial_shape(PartialShape(m_input_shapes[0]));
model->get_parameters()[0]->set_partial_shape(PartialShape(m_input_shapes[0]));
std::vector<std::tuple<Shape, Subgraph::BlockedShape>> blockedInput1{{{1, 1, 2, 5}, {{1, 1, 2, 5, 1}, {0, 1, 2, 3, 1}, prec}},
{{1, 1, 2, 1}, {{1, 1, 2, 1, 1}, {0, 1, 2, 3, 1}, prec}},
{{1, 64, 1, 1}, {{1, 4, 1, 1, 16}, {0, 1, 2, 3, 1}, prec}}};
normalizeParameter(params[1], 1, 1);
// We need to trigger validate_nodes_and_infer_types() manually to propagate the new blocked shapes.
// This is correct, since RankNormalization re-enables shape propagation for blocked shapes.
model_ref->validate_nodes_and_infer_types();
run();
}
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_BroadcastBlocked,
SKIP_CanonicalizationTests /* CVS-114607 */,
::testing::Combine(::testing::Values(blockedInput0),
::testing::ValuesIn(blockedInput1),
::testing::Values(output),
::testing::Values(canonical_shape)),
CanonicalizationTests::getTestCaseName);
std::vector<std::tuple<Shape, Subgraph::BlockedShape>> planarInput1{{{1, 1, 2, 5}, {{1, 2, 5}, {0, 1, 2}, prec}},
{{1, 1, 2, 5}, {{2, 5}, {0, 1}, prec}},
{{1, 2, 5}, {{2, 5}, {0, 1}, prec}},
{{2, 5}, {{2, 5}, {0, 1}, prec}},
{{5}, {{5}, {0}, prec}}};
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_BroadcastPlanar,
SKIP_CanonicalizationTests /* CVS-114607 */,
::testing::Combine(::testing::Values(blockedInput0),
::testing::ValuesIn(planarInput1),
::testing::Values(output),
::testing::Values(canonical_shape)),
CanonicalizationTests::getTestCaseName);
} // namespace CanonicalizationTestsInstantiation
} // namespace snippets
} // namespace test

View File

@ -25,25 +25,25 @@
#include <ngraph/opsets/opset5.hpp>
using namespace std;
namespace ov {
#define CREATE_SNIPPETS_EMITTER(e_type) { \
[this](const ov::snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
[this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
return std::make_shared<e_type>(h.get(), isa, expr); \
}, \
[](const std::shared_ptr<ngraph::Node>& n) -> std::set<std::vector<element::Type>> { \
return e_type::get_supported_precisions(n); \
} \
};
}
#define CREATE_CPU_EMITTER(e_type) { \
[this](const ov::snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
[this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
return std::make_shared<e_type>(h.get(), isa, expr->get_node()); \
}, \
[](const std::shared_ptr<ngraph::Node>& n) -> std::set<std::vector<element::Type>> { \
[](const std::shared_ptr<ov::Node>& n) -> std::set<std::vector<element::Type>> { \
return e_type::get_supported_precisions(n); \
} \
};
}
class jit_snippet : public dnnl::impl::cpu::x64::jit_generator {
public:
@ -58,94 +58,95 @@ public:
}
};
ov::intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t host_isa)
intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t host_isa)
: TargetMachine(), h(new jit_snippet()), isa(host_isa) {
// data movement
jitters[ov::op::v0::Parameter::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(NopEmitter);
jitters[ov::op::v0::Result::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(NopEmitter);
jitters[op::v0::Parameter::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(NopEmitter);
jitters[op::v0::Result::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(NopEmitter);
jitters[snippets::op::Buffer::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(NopEmitter);
jitters[snippets::op::VectorBuffer::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(NopEmitter);
// jitters[ov::op::v1::Constant::get_type_info_static()] = CREATE_CPU_EMITTER(); // Not supported
jitters[snippets::op::RankNormalization::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(NopEmitter);
// jitters[op::v1::Constant::get_type_info_static()] = CREATE_CPU_EMITTER(); // Not supported
jitters[snippets::op::Load::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(LoadEmitter);
jitters[snippets::op::LoadReshape::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(LoadEmitter);
jitters[snippets::op::BroadcastLoad::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BroadcastLoadEmitter);
jitters[ov::intel_cpu::LoadConvertSaturation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(LoadConvertEmitter);
jitters[ov::intel_cpu::LoadConvertTruncation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(LoadConvertEmitter);
jitters[intel_cpu::LoadConvertSaturation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(LoadConvertEmitter);
jitters[intel_cpu::LoadConvertTruncation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(LoadConvertEmitter);
jitters[snippets::op::Store::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(StoreEmitter);
jitters[ov::intel_cpu::StoreConvertSaturation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(StoreConvertEmitter);
jitters[ov::intel_cpu::StoreConvertTruncation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(StoreConvertEmitter);
jitters[intel_cpu::StoreConvertSaturation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(StoreConvertEmitter);
jitters[intel_cpu::StoreConvertTruncation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(StoreConvertEmitter);
jitters[snippets::op::Scalar::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(ScalarEmitter);
jitters[snippets::op::BroadcastMove::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BroadcastMoveEmitter);
// jitters[snippets::op::Nop::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(NopEmitter); // Not supported
// jitters[ov::op::v1::Broadcast::get_type_info_static()] = CREATE_CPU_EMITTER(); // Not supported
// jitters[op::v1::Broadcast::get_type_info_static()] = CREATE_CPU_EMITTER(); // Not supported
jitters[snippets::op::ConvertTruncation::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_convert_truncation_emitter);
jitters[snippets::op::ConvertSaturation::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_convert_saturation_emitter);
// jitters[ov::op::v1::FakeQuantize::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
jitters[snippets::op::ConvertTruncation::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_convert_truncation_emitter);
jitters[snippets::op::ConvertSaturation::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_convert_saturation_emitter);
// jitters[op::v1::FakeQuantize::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
// ternary
jitters[ov::op::v1::Select::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_select_emitter);
jitters[ov::intel_cpu::FusedMulAdd::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_mul_add_emitter);
jitters[op::v1::Select::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_select_emitter);
jitters[intel_cpu::FusedMulAdd::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_mul_add_emitter);
// binary
jitters[ov::op::v1::Add::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_add_emitter);
jitters[ov::op::v1::Divide::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_divide_emitter);
jitters[ov::op::v1::Equal::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_equal_emitter);
jitters[ov::op::v1::FloorMod::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_floor_mod_emitter);
jitters[ov::op::v1::Greater::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_greater_emitter);
jitters[ov::op::v1::GreaterEqual::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_greater_equal_emitter);
jitters[ov::op::v1::Less::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_less_emitter);
jitters[ov::op::v1::LessEqual::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_less_equal_emitter);
jitters[ov::op::v1::LogicalAnd::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_logical_and_emitter);
jitters[ov::op::v1::LogicalOr::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_logical_or_emitter);
jitters[ov::op::v1::LogicalXor::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_logical_xor_emitter);
jitters[ov::op::v1::Maximum::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_maximum_emitter);
jitters[ov::op::v1::Minimum::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_minimum_emitter);
jitters[ov::op::v1::Mod::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_mod_emitter);
jitters[ov::op::v1::Multiply::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_multiply_emitter);
jitters[ov::op::v1::NotEqual::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_not_equal_emitter);
jitters[snippets::op::PowerStatic::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_power_static_emitter);
jitters[ov::op::v1::Power::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_power_dynamic_emitter);
jitters[ov::op::v0::PRelu::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_prelu_emitter);
jitters[ov::op::v0::SquaredDifference::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_squared_difference_emitter);
jitters[ov::op::v1::Subtract::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_subtract_emitter);
jitters[ov::op::v0::Xor::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_logical_xor_emitter);
jitters[op::v1::Add::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_add_emitter);
jitters[op::v1::Divide::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_divide_emitter);
jitters[op::v1::Equal::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_equal_emitter);
jitters[op::v1::FloorMod::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_floor_mod_emitter);
jitters[op::v1::Greater::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_greater_emitter);
jitters[op::v1::GreaterEqual::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_greater_equal_emitter);
jitters[op::v1::Less::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_less_emitter);
jitters[op::v1::LessEqual::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_less_equal_emitter);
jitters[op::v1::LogicalAnd::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_logical_and_emitter);
jitters[op::v1::LogicalOr::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_logical_or_emitter);
jitters[op::v1::LogicalXor::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_logical_xor_emitter);
jitters[op::v1::Maximum::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_maximum_emitter);
jitters[op::v1::Minimum::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_minimum_emitter);
jitters[op::v1::Mod::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_mod_emitter);
jitters[op::v1::Multiply::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_multiply_emitter);
jitters[op::v1::NotEqual::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_not_equal_emitter);
jitters[snippets::op::PowerStatic::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_power_static_emitter);
jitters[op::v1::Power::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_power_dynamic_emitter);
jitters[op::v0::PRelu::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_prelu_emitter);
jitters[op::v0::SquaredDifference::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_squared_difference_emitter);
jitters[op::v1::Subtract::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_subtract_emitter);
jitters[op::v0::Xor::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_logical_xor_emitter);
// unary
jitters[ov::op::v0::Abs::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_abs_emitter);
// jitters[ov::op::v1::Acos::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
// jitters[ov::op::v1::Asin::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
// jitters[ov::op::v1::Atan::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
jitters[ov::op::v0::Ceiling::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_ceiling_emitter);
jitters[ov::op::v0::Clamp::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_clamp_emitter);
// jitters[ov::op::v1::Cos::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
// jitters[ov::op::v1::Cosh::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
jitters[ov::op::v0::Elu::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_elu_emitter);
jitters[ov::op::v0::Erf::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_erf_emitter);
jitters[ov::op::v0::Exp::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_exp_emitter);
jitters[ov::op::v0::Floor::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_floor_emitter);
jitters[ngraph::opset5::Round::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_round_emitter);
// jitters[ov::op::v1::Log::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
jitters[ov::op::v1::LogicalNot::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_logical_not_emitter);
jitters[ov::op::v0::Negative::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_negative_emitter);
jitters[ov::op::v0::Relu::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_relu_emitter);
// jitters[ov::op::v1::Sign::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
jitters[ov::op::v0::Sigmoid::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_sigmoid_emitter);
// jitters[ov::op::v1::Sin::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
// jitters[ov::op::v1::Sinh::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
jitters[ov::op::v0::Sqrt::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_sqrt_emitter);
// jitters[ov::op::v1::Tan::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
jitters[ov::op::v0::Tanh::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_tanh_emitter);
jitters[op::v0::Abs::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_abs_emitter);
// jitters[op::v1::Acos::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
// jitters[op::v1::Asin::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
// jitters[op::v1::Atan::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
jitters[op::v0::Ceiling::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_ceiling_emitter);
jitters[op::v0::Clamp::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_clamp_emitter);
// jitters[op::v1::Cos::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
// jitters[op::v1::Cosh::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
jitters[op::v0::Elu::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_elu_emitter);
jitters[op::v0::Erf::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_erf_emitter);
jitters[op::v0::Exp::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_exp_emitter);
jitters[op::v0::Floor::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_floor_emitter);
jitters[ngraph::opset5::Round::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_round_emitter);
// jitters[op::v1::Log::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
jitters[op::v1::LogicalNot::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_logical_not_emitter);
jitters[op::v0::Negative::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_negative_emitter);
jitters[op::v0::Relu::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_relu_emitter);
// jitters[op::v1::Sign::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
jitters[op::v0::Sigmoid::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_sigmoid_emitter);
// jitters[op::v1::Sin::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
// jitters[op::v1::Sinh::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
jitters[op::v0::Sqrt::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_sqrt_emitter);
// jitters[op::v1::Tan::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
jitters[op::v0::Tanh::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_tanh_emitter);
jitters[ov::intel_cpu::SwishNode::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_swish_emitter);
jitters[ngraph::op::v4::HSwish::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_hswish_emitter);
// jitters[ov::op::v1::HardSigmoid::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
// jitters[ov::op::v1::Selu::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
jitters[ngraph::op::v0::Gelu::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_gelu_v0_emitter);
jitters[ngraph::op::v7::Gelu::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_gelu_v7_emitter);
jitters[intel_cpu::SwishNode::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_swish_emitter);
jitters[ngraph::op::v4::HSwish::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_hswish_emitter);
// jitters[op::v1::HardSigmoid::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
// jitters[op::v1::Selu::get_type_info_static()] = CREATE_CPU_EMITTER(); // not supported
jitters[ngraph::op::v0::Gelu::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_gelu_v0_emitter);
jitters[ngraph::op::v7::Gelu::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_gelu_v7_emitter);
jitters[snippets::op::Fill::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(FillEmitter);
jitters[snippets::op::HorizonMax::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(HorizonEmitter);
@ -154,11 +155,11 @@ ov::intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_
jitters[snippets::op::Kernel::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(KernelEmitter);
jitters[snippets::op::LoopBegin::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(LoopBeginEmitter);
jitters[snippets::op::LoopEnd::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(LoopEndEmitter);
jitters[ov::intel_cpu::BrgemmCPU::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BrgemmEmitter);
jitters[ov::intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BrgemmCopyBEmitter);
jitters[intel_cpu::BrgemmCPU::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BrgemmEmitter);
jitters[intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BrgemmCopyBEmitter);
}
size_t ov::intel_cpu::CPUTargetMachine::get_lanes() const {
size_t intel_cpu::CPUTargetMachine::get_lanes() const {
switch (isa) {
case dnnl::impl::cpu::x64::avx2 : return dnnl::impl::cpu::x64::cpu_isa_traits<dnnl::impl::cpu::x64::avx2>::vlen / sizeof(float);
case dnnl::impl::cpu::x64::sse41 : return dnnl::impl::cpu::x64::cpu_isa_traits<dnnl::impl::cpu::x64::sse41>::vlen / sizeof(float);
@ -167,28 +168,62 @@ size_t ov::intel_cpu::CPUTargetMachine::get_lanes() const {
}
}
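For reference, the lane counts this switch yields are the vector register width in bytes divided by sizeof(float); the avx512_core value is inferred from its 64-byte registers (that case is elided by the hunk above):
// avx512_core: 64 / 4 = 16 fp32 lanes
// avx2:        32 / 4 =  8 fp32 lanes
// sse41:       16 / 4 =  4 fp32 lanes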
bool ov::intel_cpu::CPUTargetMachine::is_supported() const {
dnnl::impl::cpu::x64::cpu_isa_t intel_cpu::CPUTargetMachine::get_isa() const {
return isa;
}
bool intel_cpu::CPUTargetMachine::is_supported() const {
return dnnl::impl::cpu::x64::mayiuse(isa);
}
ov::snippets::code ov::intel_cpu::CPUTargetMachine::get_snippet() const {
snippets::CompiledSnippetPtr intel_cpu::CPUTargetMachine::get_snippet() {
if (h->create_kernel() != dnnl::impl::status::success) {
IE_THROW() << "Failed to create jit_kernel in get_snippet()";
}
return h->jit_ker();
const auto& result = std::make_shared<CompiledSnippetCPU>(std::unique_ptr<dnnl::impl::cpu::x64::jit_generator>(h.release()));
// Note that we reset the generator here, since ownership of the generated code was transferred to CompiledSnippetCPU
h.reset(new jit_snippet());
return result;
}
ov::intel_cpu::CPUGenerator::CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa_) : Generator(std::make_shared<CPUTargetMachine>(isa_)) {
intel_cpu::CompiledSnippetCPU::CompiledSnippetCPU(std::unique_ptr<dnnl::impl::cpu::x64::jit_generator> h) : h_compiled(std::move(h)) {
OPENVINO_ASSERT(h_compiled && h_compiled->jit_ker(), "Got invalid jit generator or kernel was not compiled");
}
ov::snippets::Generator::opRegType ov::intel_cpu::CPUGenerator::get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const {
if (std::dynamic_pointer_cast<ov::intel_cpu::BrgemmCPU>(op) ||
std::dynamic_pointer_cast<ov::intel_cpu::BrgemmCopyB>(op))
const uint8_t* intel_cpu::CompiledSnippetCPU::get_code() const {
return h_compiled->jit_ker();
}
size_t intel_cpu::CompiledSnippetCPU::get_code_size() const {
return h_compiled->getSize();
}
bool intel_cpu::CompiledSnippetCPU::empty() const {
return get_code_size() == 0;
}
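For illustration, a minimal sketch of how a caller might drive the new CompiledSnippet interface; the (indexes, call_args) kernel signature is assumed from the CPU plugin's jit_snippets kernels and is not part of this diff:
// Hypothetical usage: validate the snippet, then call it through a typed entry point.
using kernel_fn = void (*)(const void* indexes, const void* call_args);
inline void run_snippet(const snippets::CompiledSnippetPtr& snippet, const void* indexes, const void* args) {
    OPENVINO_ASSERT(snippet && !snippet->empty(), "Compiled snippet is empty");
    const auto fn = reinterpret_cast<kernel_fn>(const_cast<uint8_t*>(snippet->get_code()));
    fn(indexes, args);
}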
intel_cpu::CPUGenerator::CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa_) : Generator(std::make_shared<CPUTargetMachine>(isa_)) {
}
std::shared_ptr<snippets::Generator> intel_cpu::CPUGenerator::clone() const {
const auto& cpu_target_machine = std::dynamic_pointer_cast<CPUTargetMachine>(target);
OPENVINO_ASSERT(cpu_target_machine, "Failed to clone CPUGenerator: the instance contains incompatible TargetMachine type");
return std::make_shared<CPUGenerator>(cpu_target_machine->get_isa());
}
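clone() is presumably meant to give each executor its own generator (and thus its own jit code buffer); a minimal usage sketch under that assumption:
// Sketch: derive an independent generator with the same ISA for a worker.
std::shared_ptr<snippets::Generator> make_worker_generator(const std::shared_ptr<snippets::Generator>& base) {
    auto copy = base->clone();  // for CPUGenerator this re-creates a CPUTargetMachine via get_isa()
    OPENVINO_ASSERT(copy, "Generator clone failed");
    return copy;
}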
snippets::Generator::opRegType intel_cpu::CPUGenerator::get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const {
if (std::dynamic_pointer_cast<intel_cpu::BrgemmCPU>(op) ||
std::dynamic_pointer_cast<intel_cpu::BrgemmCopyB>(op))
return gpr2gpr;
else if (
std::dynamic_pointer_cast<ov::intel_cpu::FusedMulAdd>(op) ||
std::dynamic_pointer_cast<ov::intel_cpu::SwishNode>(op))
std::dynamic_pointer_cast<intel_cpu::FusedMulAdd>(op) ||
std::dynamic_pointer_cast<intel_cpu::SwishNode>(op))
return vec2vec;
else
OPENVINO_THROW("Register type of the operation " + std::string(op->get_type_name()) + " isn't determined!");
}
bool intel_cpu::CPUGenerator::uses_precompiled_kernel(const std::shared_ptr<snippets::Emitter>& e) const {
return std::dynamic_pointer_cast<intel_cpu::BrgemmEmitter>(e) ||
std::dynamic_pointer_cast<intel_cpu::BrgemmCopyBEmitter>(e);
}
} // namespace ov

View File

@ -13,13 +13,23 @@
namespace ov {
namespace intel_cpu {
class CompiledSnippetCPU : public snippets::CompiledSnippet {
const std::unique_ptr<const dnnl::impl::cpu::x64::jit_generator> h_compiled;
public:
const uint8_t* get_code() const override;
size_t get_code_size() const override;
bool empty() const override;
explicit CompiledSnippetCPU(std::unique_ptr<dnnl::impl::cpu::x64::jit_generator> h);
};
class CPUTargetMachine : public snippets::TargetMachine {
public:
CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t host_isa);
explicit CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t host_isa);
bool is_supported() const override;
snippets::code get_snippet() const override;
snippets::CompiledSnippetPtr get_snippet() override;
size_t get_lanes() const override;
dnnl::impl::cpu::x64::cpu_isa_t get_isa() const;
private:
std::unique_ptr<dnnl::impl::cpu::x64::jit_generator> h;
@ -29,8 +39,10 @@ private:
class CPUGenerator : public snippets::Generator {
public:
CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa);
std::shared_ptr<Generator> clone() const override;
protected:
bool uses_precompiled_kernel(const std::shared_ptr<snippets::Emitter>& emitter) const override;
opRegType get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const override;
};

View File

@ -11,6 +11,7 @@
#include "snippets/lowered/port_connector.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op//brgemm_cpu.hpp"
#include "snippets/op/rank_normalization.hpp"
using namespace InferenceEngine;
using namespace Xbyak;
@ -121,7 +122,12 @@ KernelEmitter::KernelEmitter(jit_generator* h, cpu_isa_t isa, const ExpressionPt
element::Type etype;
switch (expr->get_type()) {
case snippets::lowered::IOExpression::io_type::INPUT: {
desc = expr->get_output_port_descriptor(0);
const auto first_consumer = expr->get_output_port_connector(0)->get_consumers().begin()->get_expr();
if (ov::is_type<snippets::op::RankNormalization>(first_consumer->get_node())) {
desc = first_consumer->get_output_port_descriptor(0);
} else {
desc = expr->get_output_port_descriptor(0);
}
etype = expr->get_node()->get_output_element_type(0);
num_inputs++;
break;
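The descriptor selection above can be read as a small look-through helper; a hedged sketch (the helper name is illustrative and not part of this commit):
// If a Parameter's first consumer is a RankNormalization op, the layout the kernel must
// honor is the normalized one, so take the descriptor from that consumer's output port.
snippets::lowered::PortDescriptorPtr input_io_descriptor(const snippets::lowered::ExpressionPtr& expr) {
    const auto& consumer = expr->get_output_port_connector(0)->get_consumers().begin()->get_expr();
    return ov::is_type<snippets::op::RankNormalization>(consumer->get_node())
               ? consumer->get_output_port_descriptor(0)
               : expr->get_output_port_descriptor(0);
}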

View File

@ -157,6 +157,7 @@ std::map<std::string, ngraph::OpSet> Extension::getOpSets() {
NGRAPH_OP(Store, ov::snippets::op)
NGRAPH_OP(Subgraph, ov::snippets::op)
NGRAPH_OP(VectorBuffer, ov::snippets::op)
NGRAPH_OP(RankNormalization, ov::snippets::op)
NGRAPH_OP_X64(LoadConvertSaturation, ov::intel_cpu)
NGRAPH_OP_X64(LoadConvertTruncation, ov::intel_cpu)
NGRAPH_OP_X64(StoreConvertSaturation, ov::intel_cpu)

View File

@ -13,7 +13,6 @@
#include <onednn/dnnl.h>
#include <dnnl_extension_utils.h>
#include <ngraph/pass/visualize_tree.hpp>
#include <ngraph/rt_info.hpp>
#include <ie_ngraph_utils.hpp>
@ -119,67 +118,36 @@ bool SnippetKey::operator==(const SnippetKey& rhs) const {
return true;
}
snippets::op::Subgraph::BlockedShapeVector getBlockedShapes(const std::vector<std::vector<size_t>>& memBlockedDims,
const std::vector<std::vector<size_t>>& memOrders, const std::vector<InferenceEngine::Precision>& memPrecs) {
size_t numShapes = memBlockedDims.size();
if (memOrders.size() != numShapes || memPrecs.size() != numShapes)
IE_THROW(Unexpected) << "Number of shapes is mismatched for dimensions, orders and precisions";
snippets::op::Subgraph::BlockedShapeVector blockedShapes(numShapes);
for (size_t i = 0; i < numShapes; i++) {
size_t dimSize = memBlockedDims[i].size();
std::vector<Dimension> dims(dimSize);
for (size_t j = 0; j < dimSize; j++) {
dims[j] = memBlockedDims[i][j];
}
ov::PartialShape shape(dims);
ov::AxisVector order(memOrders[i]);
ov::element::Type precision = InferenceEngine::details::convertPrecision(memPrecs[i]);
blockedShapes[i] = snippets::op::Subgraph::BlockedShape{shape, order, precision};
}
return blockedShapes;
}
} // namespace
Snippet::Snippet(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr& context)
: Node(op, context, SnippetShapeInferFactory(op)) {
host_isa = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) ?
dnnl::impl::cpu::x64::avx512_core : dnnl::impl::cpu::x64::avx2;
original_snippet = ov::as_type_ptr<snippets::op::Subgraph>(op);
if (!original_snippet) {
IE_THROW(NotImplemented) << "Node is not an instance of snippets::op::Subgraph";
}
init_body_hash();
is_dynamic = isDynamicNgraphNode(op);
}
const auto& tmp_snippet = ov::as_type_ptr<snippets::op::Subgraph>(op);
OPENVINO_ASSERT(tmp_snippet, "Attempt to create Snippet node from an invalid op type");
snippetAttrs.snippet = tmp_snippet->clone();
snippetAttrs.bodyHash = get_body_hash(tmp_snippet);
void Snippet::copy_snippet() const {
ov::OutputVector subgraph_node_inputs;
for (const auto &input : original_snippet->input_values()) {
auto new_input = std::make_shared<ov::opset1::Parameter>(input.get_element_type(), input.get_partial_shape());
subgraph_node_inputs.push_back(new_input);
}
std::shared_ptr<ov::Model> new_body = original_snippet->body_ptr()->clone();
snippetAttrs.snippet = std::make_shared<snippets::op::Subgraph>(subgraph_node_inputs, new_body);
ov::copy_runtime_info(original_snippet, snippetAttrs.snippet);
snippetAttrs.snippet->set_friendly_name(original_snippet->get_friendly_name());
#if defined(OPENVINO_ARCH_X86_64)
snippetAttrs.snippet->set_generator(std::make_shared<CPUGenerator>(host_isa));
#else
IE_THROW(NotImplemented) << "CPU plugin: code-generation is not supported on non-x64 platforms";
OPENVINO_THROW("CPU plugin: Snippets code-generator is not supported on non-x64 platforms");
#endif // OPENVINO_ARCH_X86_64
// Note: we have to update shapeInfer, so it uses the per-thread op::Subgraph copy
shapeInference = SnippetShapeInferFactory(snippetAttrs.snippet).makeShapeInfer();
is_dynamic = isDynamicNgraphNode(op);
}
void Snippet::init_body_hash() {
uint64_t Snippet::get_body_hash(const std::shared_ptr<snippets::op::Subgraph>& snippet) {
uint64_t seed = 0;
ov::snippets::pass::Hash hash_function(seed);
hash_function.run_on_model(original_snippet->body_ptr());
snippetAttrs.bodyHash = seed;
hash_function.run_on_model(snippet->body_ptr());
return seed;
}
void Snippet::initSupportedPrimitiveDescriptors() {
copy_snippet();
if (!supportedPrimitiveDescriptors.empty())
return;
@ -315,16 +283,29 @@ void Snippet::selectOptimalPrimitiveDescriptor() {
}
void Snippet::initOptimalPrimitiveDescriptor() {
const auto isPlanar = [](const VectorDims& order ) {
for (size_t i = 0; i < order.size(); ++i)
if (order[i] != i)
return false;
return true;
};
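For reference, how isPlanar classifies typical CPU plugin dimension orders:
// isPlanar({0, 1, 2, 3})    -> true   (plain layout, e.g. NCHW)
// isPlanar({0, 2, 3, 1})    -> false  (permuted layout, e.g. NHWC)
// isPlanar({0, 1, 2, 3, 1}) -> false  (blocked layout, e.g. nChw8c)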
Node::initOptimalPrimitiveDescriptor();
// memory order and precision are determined now, so there is no need to prepare them for each dynamic shape.
const auto config = getSelectedPrimitiveDescriptor()->getConfig();
inputNum = config.inConfs.size();
snippets::op::Subgraph::BlockedShapeVector in_blocked_shapes;
snippetAttrs.inMemPrecs.resize(inputNum);
snippetAttrs.inMemOrders.resize(inputNum);
in_blocked_shapes.reserve(inputNum);
snippetAttrs.has_non_planar_inputs = false;
for (size_t i = 0; i < inputNum; i++) {
const auto& memDesc = config.inConfs[i].getMemDesc();
snippetAttrs.inMemPrecs[i] = memDesc->getPrecision();
snippetAttrs.inMemOrders[i] = memDesc->as<BlockedMemoryDesc>()->getOrder();
const auto& blockedDesc = memDesc->as<BlockedMemoryDesc>();
const auto& order = blockedDesc->getOrder();
snippetAttrs.inMemOrders[i] = order;
snippetAttrs.has_non_planar_inputs |= !isPlanar(order);
in_blocked_shapes.emplace_back(blockedDesc->getBlockDims(), order);
}
outputNum = config.outConfs.size();
snippetAttrs.outMemPrecs.resize(outputNum);
@ -338,6 +319,52 @@ void Snippet::initOptimalPrimitiveDescriptor() {
snippetAttrs.outMemBlockedDims.resize(outputNum);
srcMemPtrs.resize(inputNum);
dstMemPtrs.resize(outputNum);
// here we should perform all shape-agnostic snippets passes:
// * canonicalization (RankNormalization insertion)
// * precision propagation & element type alignment
// * data flow optimizations
// The results of these transformations will be reused for all shapes
using Manager = snippets::pass::Manager;
std::vector<Manager::PositionedPass> backend_passes;
#if defined(OPENVINO_ARCH_X86_64)
using PassPosition = snippets::pass::Manager::PassPosition;
using Place = snippets::pass::Manager::PassPosition::Place;
# define SNIPPETS_REGISTER_PASS(PASS_POS, PASS, ...) \
backend_passes.emplace_back(PASS_POS, std::make_shared<PASS>(__VA_ARGS__))
#else
# define SNIPPETS_REGISTER_PASS(PASS_POS, PASS, ...)
#endif // OPENVINO_ARCH_X86_64
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineStart), ConvertToSwishCPU);
if (context->getConfig().inferencePrecision == ov::element::bf16 && snippetAttrs.snippet->has_domain_sensitive_ops()) {
// enforce BF16 precisions to supported operations
// MatMul has to be decomposed into Brgemm operations before enforcement
// Note: MatMul decomposition will be run again later in case BF16 enforcement did not happen
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineStart), ov::snippets::pass::MatMulToBrgemm);
SNIPPETS_REGISTER_PASS(PassPosition(Place::After, "MatMulToBrgemm"), pass::EnforcePrecision, element::f32, element::bf16);
}
SNIPPETS_REGISTER_PASS(PassPosition(Place::Before, "PropagatePrecision"), ov::intel_cpu::pass::BrgemmToBrgemmCPU);
SNIPPETS_REGISTER_PASS(PassPosition(Place::Before, "PropagatePrecision"), ov::intel_cpu::pass::SetBrgemmCPUBlockingParams);
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineEnd), ov::intel_cpu::pass::RemoveConverts);
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineEnd), ov::intel_cpu::pass::MulAddToFMA);
#undef SNIPPETS_REGISTER_PASS
std::vector<ov::element::Type> input_precisions;
std::vector<ov::element::Type> output_precisions;
input_precisions.reserve(inputNum);
for (const auto& p : snippetAttrs.inMemPrecs) {
input_precisions.push_back(InferenceEngine::details::convertPrecision(p));
}
output_precisions.reserve(outputNum);
for (const auto& p : snippetAttrs.outMemPrecs)
output_precisions.push_back(InferenceEngine::details::convertPrecision(p));
snippetAttrs.snippet->data_flow_transformations(in_blocked_shapes, input_precisions, output_precisions, backend_passes);
snippetAttrs.snippet->convert_body_to_linear_ir(std::make_shared<snippets::CPUShapeInferSnippetsFactory>());
}
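For clarity, SNIPPETS_REGISTER_PASS is a thin wrapper over PositionedPass construction; e.g. the EnforcePrecision registration above expands to:
// Equivalent expansion of the EnforcePrecision registration (x64 build):
backend_passes.emplace_back(PassPosition(Place::After, "MatMulToBrgemm"),
                            std::make_shared<pass::EnforcePrecision>(element::f32, element::bf16));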
InferenceEngine::Precision Snippet::getRuntimePrecision() const {
@ -361,9 +388,8 @@ void Snippet::prepareParams() {
SnippetKey key = {snippetAttrs};
auto builder = [this](const SnippetKey& key) -> std::shared_ptr<SnippetExecutor> {
std::shared_ptr<SnippetExecutor> executor = std::make_shared<SnippetJitExecutor>(key.attrs, is_canonicalized,
is_dynamic, context->getConfig().inferencePrecision == ov::element::bf16);
is_canonicalized = true;
std::shared_ptr<SnippetExecutor> executor =
std::make_shared<SnippetJitExecutor>(key.attrs, is_dynamic, context->getConfig().inferencePrecision == ov::element::bf16);
return executor;
};
@ -426,15 +452,17 @@ void Snippet::executeDynamicImpl(dnnl::stream strm) {
}
void Snippet::SnippetJitExecutor::exec(const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs) {
if (schedule.ptr == nullptr) {
if (schedule.lowering_result.compiled_snippet->empty()) {
IE_THROW() << "Snippet can't use Optimized implementation and can't fallback to reference";
}
auto initStartMemoryOffsets = [this, &inMemPtrs, &outMemPtrs]() {
for (size_t i = 0; i < numInput; i++) {
start_offset_in[i] = inMemPtrs[i]->getDescWithType<BlockedMemoryDesc>()->getOffsetPadding() * dataSize[i];
start_offset_in[i] =
static_cast<ptrdiff_t>(inMemPtrs[i]->getDescWithType<BlockedMemoryDesc>()->getOffsetPadding() * dataSize[i]);
}
for (size_t i = 0; i < numOutput; i++) {
start_offset_out[i] = outMemPtrs[i]->getDescWithType<BlockedMemoryDesc>()->getOffsetPadding() * dataSize[i + numInput];
start_offset_out[i] =
static_cast<ptrdiff_t>(outMemPtrs[i]->getDescWithType<BlockedMemoryDesc>()->getOffsetPadding() * dataSize[i + numInput]);
}
};
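A worked example of the offset computation above, assuming an input with 16 padded elements before the data and fp32 precision:
// offsetPadding = 16 (elements), dataSize = sizeof(float) = 4 (bytes)
// start_offset_in = 16 * 4 = 64 bytes from the base of the memory block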
// initialize start offsets to src and dst memory
@ -465,13 +493,13 @@ void Snippet::SnippetJitExecutor::update_ptrs(jit_snippets_call_args& call_args,
void Snippet::SnippetJitExecutor::schedule_6d(const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs) {
const auto& dom = parallel_exec_domain;
// < N, C, H, W > < 1, 1, N, C*H*W>
const auto& callable = schedule.get_callable<kernel>();
parallel_for5d(dom[0], dom[1], dom[2], dom[3], dom[4],
[&](int64_t d0, int64_t d1, int64_t d2, int64_t d3, int64_t d4) {
int64_t indexes[] = {d0, d1, d2, d3, d4};
jit_snippets_call_args call_args;
update_ptrs(call_args, inMemPtrs, outMemPtrs);
schedule.get_callable<kernel>()(indexes, &call_args);
callable(indexes, &call_args);
});
}
@ -487,8 +515,8 @@ void Snippet::SnippetJitExecutor::schedule_nt(const std::vector<MemoryPtr>& inMe
std::vector<int64_t> indexes(work_size.size() - 1, 0);
for (size_t iwork = start; iwork < end; ++iwork) {
size_t tmp = iwork;
for (ptrdiff_t j = work_size.size() - 2; j >= 0; j--) {
indexes[j] = tmp % work_size[j];
for (ptrdiff_t j = static_cast<ptrdiff_t>(work_size.size()) - 2; j >= 0; j--) {
indexes[j] = static_cast<int64_t>(tmp % work_size[j]);
tmp /= work_size[j];
}
@ -497,49 +525,25 @@ void Snippet::SnippetJitExecutor::schedule_nt(const std::vector<MemoryPtr>& inMe
});
}
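The inner loop above is a mixed-radix decomposition of the flat work-item index; a worked example, assuming work_size = {2, 3, 4, W} where the last element is processed inside the kernel:
// iwork = 7:
//   j = 2: indexes[2] = 7 % 4 = 3, tmp = 7 / 4 = 1
//   j = 1: indexes[1] = 1 % 3 = 1, tmp = 1 / 3 = 0
//   j = 0: indexes[0] = 0 % 2 = 0
// so flat index 7 maps to the coordinate {0, 1, 3}.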
Snippet::SnippetExecutor::SnippetExecutor(const SnippetAttrs& attrs, bool is_canonicalized, bool is_dynamic, bool enforceBF16)
: snippetAttrs(attrs), is_canonicalized(is_canonicalized), is_dynamic(is_dynamic), enforceBF16(enforceBF16) {}
Snippet::SnippetExecutor::SnippetExecutor(SnippetAttrs attrs, bool is_dynamic, bool enforceBF16)
: snippetAttrs(std::move(attrs)), is_dynamic(is_dynamic), enforceBF16(enforceBF16) {}
Snippet::SnippetJitExecutor::SnippetJitExecutor(const SnippetAttrs& attrs, bool is_canonicalized, bool is_dynamic, bool enforceBF16) :
SnippetExecutor(attrs, is_canonicalized, is_dynamic, enforceBF16) {
Snippet::SnippetJitExecutor::SnippetJitExecutor(SnippetAttrs attrs, bool is_dynamic, bool enforceBF16) :
SnippetExecutor(std::move(attrs), is_dynamic, enforceBF16) {
numInput = snippetAttrs.inMemBlockedDims.size();
numOutput = snippetAttrs.outMemBlockedDims.size();
start_offset_in.resize(numInput);
start_offset_out.resize(numOutput);
auto local_copy = [this]() {
ov::OutputVector subgraph_node_inputs;
for (size_t i = 0; i < numInput; i++) {
const auto paramShape = snippetAttrs.snippet->body_ptr()->get_parameters()[i]->get_shape();
const auto paramType = snippetAttrs.snippet->body_ptr()->get_parameters()[i]->get_element_type();
auto new_input = std::make_shared<ov::opset1::Parameter>(paramType, paramShape);
subgraph_node_inputs.push_back(new_input);
}
std::shared_ptr<ov::Model> new_body = snippetAttrs.snippet->body_ptr()->clone();
snippet_for_generation = std::make_shared<ov::snippets::op::Subgraph>(subgraph_node_inputs, new_body);
ov::copy_runtime_info(snippetAttrs.snippet, snippet_for_generation);
snippet_for_generation->set_friendly_name(snippetAttrs.snippet->get_friendly_name());
#if defined(OPENVINO_ARCH_X86_64)
auto host_isa = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)
? dnnl::impl::cpu::x64::avx512_core
: dnnl::impl::cpu::x64::avx2;
snippet_for_generation->set_generator(std::make_shared<CPUGenerator>(host_isa));
#else
IE_THROW(NotImplemented) << "CPU plugin: code-generation is not supported on non-x64 platforms";
#endif // OPENVINO_ARCH_X86_64
};
// is_canonicalized == true means we just reshape the canonicalized graph with new input shapes and get the updated master shape;
// false means we perform canonicalization and determine master_shape on snippetAttrs.snippet.
ov::PartialShape canonicalShape = canonicalizeBody(is_canonicalized);
if (is_dynamic) {
// we need a local snippet for generation, which may be adjusted based on the input shapes.
// The adjustment may be incompatible with a new input shape of a dynamic node, e.g. an inserted BroadcastMove.
local_copy();
} else {
snippet_for_generation = snippetAttrs.snippet;
// todo: snippets don't support backend-provided blocking, so we need to reshape body
// using blocked shapes first. This can be removed after [121670]
if (snippetAttrs.has_non_planar_inputs) {
std::vector<snippets::VectorDimsRef> in_shapes;
for (const auto& s : snippetAttrs.inMemBlockedDims)
in_shapes.emplace_back(s);
snippetAttrs.snippet->shape_infer(in_shapes);
}
const VectorDims& canonicalShape = snippetAttrs.snippet->infer_master_shape();
// initialize by maximum output dimension. Dimensions of outputs should be broadcastable
tensorRank = std::max(static_cast<size_t>(rank6D), canonicalShape.size());
@ -552,85 +556,39 @@ Snippet::SnippetJitExecutor::SnippetJitExecutor(const SnippetAttrs& attrs, bool
};
initDataSizes();
if (canonicalShape.is_dynamic())
if (std::any_of(canonicalShape.begin(), canonicalShape.end(),
[](size_t x){return x == snippets::IShapeInferSnippets::DYNAMIC_DIMENSION;}))
IE_THROW() << "Snippets: Canonicalization returned dynamic shape in static pipeline";
snippet_for_generation->set_min_parallel_work_amount(static_cast<size_t>(parallel_get_max_threads()));
snippetAttrs.snippet->set_min_parallel_work_amount(static_cast<size_t>(parallel_get_max_threads()));
// Note: minimal JIT work amount is a predefined value that describes the number of kernel iterations (work amount)
// needed to cover kernel call overhead. It is used for balancing between parallel and JIT work amounts in domain optimization.
snippet_for_generation->set_min_jit_work_amount(256);
snippetAttrs.snippet->set_min_jit_work_amount(256);
// generate
jit_snippets_compile_args jcp;
jcp.parallel_executor_ndims = tensorRank;
generate(&jcp);
buffer_scratchpad_size = snippet_for_generation->get_buffer_scratchpad_size();
buffer_scratchpad_size = schedule.lowering_result.buffer_scratchpad_size;
buffer_scratchpad.resize(buffer_scratchpad_size * parallel_get_max_threads(), 0);
parallel_exec_domain = schedule.parallel_exec_domain;
harnessWorkAmount = std::accumulate(parallel_exec_domain.begin(), parallel_exec_domain.end(), 1, std::multiplies<size_t>());
parallel_exec_domain = getNormalizedDimsBySize(parallel_exec_domain, tensorRank);
}
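Note the scratchpad is sized as buffer_scratchpad_size * parallel_get_max_threads(), so each worker addresses its own slice. A hedged sketch of the per-thread offset, assuming a byte-vector scratchpad as resized above:
// Illustrative: thread tid owns the byte range [tid * per_thread, (tid + 1) * per_thread).
inline uint8_t* thread_scratchpad(std::vector<uint8_t>& scratchpad, size_t per_thread, int tid) {
    return scratchpad.data() + static_cast<size_t>(tid) * per_thread;
}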
ov::PartialShape Snippet::SnippetJitExecutor::canonicalizeBody(bool reshape) {
ov::snippets::op::Subgraph::BlockedShapeVector input_blocked_shapes = getBlockedShapes(
snippetAttrs.inMemBlockedDims, snippetAttrs.inMemOrders, snippetAttrs.inMemPrecs);
if (reshape) {
const auto& canonicalShape = snippetAttrs.snippet->canonicalized_body_shape_infer(input_blocked_shapes);
return canonicalShape;
} else {
ov::snippets::op::Subgraph::BlockedShapeVector output_blocked_shapes = getBlockedShapes(
snippetAttrs.outMemBlockedDims, snippetAttrs.outMemOrders, snippetAttrs.outMemPrecs);
const auto& canonicalShape = snippetAttrs.snippet->canonicalize(output_blocked_shapes, input_blocked_shapes);
return canonicalShape;
}
}
void Snippet::SnippetJitExecutor::generate(const jit_snippets_compile_args* jcp) {
using Manager = snippets::pass::Manager;
std::vector<Manager::PositionedPass> backend_passes;
#if defined(OPENVINO_ARCH_X86_64)
using PassPosition = snippets::pass::Manager::PassPosition;
using Place = snippets::pass::Manager::PassPosition::Place;
# define SNIPPETS_REGISTER_PASS(PASS_POS, PASS, ...) \
backend_passes.emplace_back(PASS_POS, std::make_shared<PASS>(__VA_ARGS__))
#else
# define SNIPPETS_REGISTER_PASS(PASS_POS, PASS, ...)
#endif // OPENVINO_ARCH_X86_64
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineStart), ConvertToSwishCPU);
if (enforceBF16 && snippet_for_generation->has_domain_sensitive_ops()) {
// enforce BF16 precisions to supported operations
// MatMul has to be decomposed into Brgemm operations before enforcement
// Note: MatMul decomposition will be run again later in case BF16 enforcement did not happen
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineStart), ov::snippets::pass::MatMulToBrgemm);
SNIPPETS_REGISTER_PASS(PassPosition(Place::After, "MatMulToBrgemm"), pass::EnforcePrecision, element::f32, element::bf16);
}
SNIPPETS_REGISTER_PASS(PassPosition(Place::Before, "PropagatePrecision"), ov::intel_cpu::pass::BrgemmToBrgemmCPU);
SNIPPETS_REGISTER_PASS(PassPosition(Place::Before, "PropagatePrecision"), ov::intel_cpu::pass::SetBrgemmCPUBlockingParams);
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineEnd), ov::intel_cpu::pass::RemoveConverts);
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineEnd), ov::intel_cpu::pass::MulAddToFMA);
#undef SNIPPETS_REGISTER_PASS
ov::snippets::lowered::pass::PassPipeline control_flow_markup_pipeline;
CPU_REGISTER_PASS_X64(control_flow_markup_pipeline, ov::intel_cpu::pass::BrgemmBlocking)
ov::snippets::lowered::pass::PassPipeline control_flow_pipeline;
CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::FuseLoadStoreConvert)
CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape);
// Note: we need to pass a valid shapeInfer factory to generate(), so it can be used in the OptimizeDomain pass;
// in all other cases nGraph shape inference will be used until ticket #113209 (PR 18563) is merged
schedule = snippet_for_generation->generate(backend_passes,
control_flow_markup_pipeline,
control_flow_pipeline,
std::make_shared<snippets::CPUShapeInferSnippetsFactory>(),
reinterpret_cast<const void*>(jcp));
schedule = snippetAttrs.snippet->generate_from_linear_ir(control_flow_markup_pipeline,
control_flow_pipeline,
reinterpret_cast<const void*>(jcp));
}
bool Snippet::SnippetJitExecutor::schedule_created() {
return schedule.ptr != nullptr;
return !schedule.lowering_result.compiled_snippet->empty();
}
} // namespace node

View File

@ -48,31 +48,24 @@ public:
// Local copy of subgraph node for canonicalization & code generation
std::shared_ptr<snippets::op::Subgraph> snippet;
uint64_t bodyHash;
std::vector<std::vector<size_t>> inMemBlockedDims;
std::vector<std::vector<size_t>> inMemOrders;
std::vector<VectorDims> inMemBlockedDims;
std::vector<VectorDims> inMemOrders;
std::vector<InferenceEngine::Precision> inMemPrecs;
std::vector<std::vector<size_t>> outMemBlockedDims;
std::vector<std::vector<size_t>> outMemOrders;
std::vector<VectorDims> outMemBlockedDims;
std::vector<VectorDims> outMemOrders;
std::vector<InferenceEngine::Precision> outMemPrecs;
// todo: flag indicating whether we need extra shape inference; can be removed after [121670]
bool has_non_planar_inputs;
};
private:
static const size_t rank6D {6};
typedef void (*kernel)(const void *, const void *);
// Create a deep local copy of the input snippet to perform canonicalization & code generation
// TODO: Probably better to implement a proper copy constructor
void copy_snippet() const;
void init_body_hash();
static uint64_t get_body_hash(const std::shared_ptr<snippets::op::Subgraph>& snippet);
size_t inputNum = 0;
size_t outputNum = 0;
// Original subgraph node
std::shared_ptr<snippets::op::Subgraph> original_snippet;
mutable std::shared_ptr<snippets::op::Subgraph> local_snippet;
// Holds the ISA version used as the code generation target
dnnl::impl::cpu::x64::cpu_isa_t host_isa;
@ -80,18 +73,17 @@ private:
std::vector<MemoryPtr> dstMemPtrs = {};
mutable SnippetAttrs snippetAttrs;
mutable bool is_canonicalized = false;
bool is_dynamic = false;
class SnippetExecutor {
public:
SnippetExecutor(const SnippetAttrs& attrs, bool is_canonicalized, bool is_dynamic, bool enforceBF16);
SnippetExecutor(SnippetAttrs attrs, bool is_dynamic, bool enforceBF16);
virtual void exec(const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs) = 0;
virtual ~SnippetExecutor() = default;
std::shared_ptr<IShapeInfer> shapeInference = nullptr;
protected:
SnippetAttrs snippetAttrs;
bool is_canonicalized = false;
bool is_dynamic = false;
bool enforceBF16 = false;
};
@ -100,7 +92,7 @@ private:
class SnippetJitExecutor : public SnippetExecutor {
public:
SnippetJitExecutor(const SnippetAttrs& attrs, bool is_canonicalized, bool is_dynamic, bool enforceBF16);
SnippetJitExecutor(SnippetAttrs attrs, bool is_dynamic, bool enforceBF16);
void exec(const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs) override;
bool schedule_created();
@ -113,16 +105,12 @@ private:
size_t numInput = 0;
size_t numOutput = 0;
ov::PartialShape canonicalizeBody(bool reshape);
void generate(const jit_snippets_compile_args*);
inline void update_ptrs(jit_snippets_call_args&, const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs);
// Evaluates generated snippet using parallel backend
void schedule_6d(const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs);
void schedule_nt(const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs);
std::shared_ptr<snippets::op::Subgraph> snippet_for_generation;
// Holds generated snippet with information about how to schedule it
snippets::Schedule schedule;

View File

@ -43,3 +43,8 @@ void FusedMulAdd::validate_and_infer_types() {
}
set_output_type(0, element_type, pshape);
}
const ov::op::AutoBroadcastSpec& FusedMulAdd::get_autob() const {
static ov::op::AutoBroadcastSpec autob_spec(ov::op::AutoBroadcastType::NUMPY);
return autob_spec;
}
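Returning a NUMPY AutoBroadcastSpec lets validate_and_infer_types broadcast all three FMA inputs right-aligned; for example:
// NUMPY broadcasting of FusedMulAdd inputs (a * b + c):
//   {1, 128, 9, 30} * {1, 128, 1, 30} + {1}     -> {1, 128, 9, 30}
//   {8, 1}          * {8, 16}         + {1, 16} -> {8, 16}
// The function-local static spec is constructed once and shared by all calls.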

View File

@ -24,6 +24,7 @@ public:
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
const ov::op::AutoBroadcastSpec& get_autob() const override;
};
} // namespace intel_cpu

View File

@ -68,6 +68,9 @@ std::vector<std::vector<InputShape>> inShapesAddPair {
{{{}, {{1, 128, 9, 30}}}, {{}, {{1, 128, 1, 30}}}},
{{{}, {{1, 128, 9, 1}}}, {{}, {{1, 128, 1, 30}}}},
{{{}, {{1, 128, 9, 16}}}, {{}, {{1, 128, 9, 1}}}},
// Test Canonicalization and Dimension collapsing
{{{}, {{2, 17, 3, 4}}}, {{}, {{1, 3, 4}}}},
{{{}, {{2, 17, 3, 4}}}, {{}, {{1, 4}}}},
// DS
{{{1, -1, {1, 10}, {1, 33}}, {{1, 128, 1, 1}, {1, 128, 1, 9}, {1, 128, 1, 17}, {1, 128, 1, 29}, {1, 128, 9, 1}, {1, 128, 1, 1}}},
{{{1, 1}, {128, 128}, {1, 10}, {1, 33}}, {{1, 128, 1, 1}, {1, 128, 1, 9}, {1, 128, 1, 17}, {1, 128, 1, 29}, {1, 128, 1, 30}, {1, 128, 1, 1}}}},

View File

@ -6,8 +6,10 @@
#include <subgraph_simple.hpp>
#include <transformations/snippets/x64/pass/mul_add_to_fma.hpp>
#include <transformations/snippets/x64/op/fused_mul_add.hpp>
#include <transformations/snippets/x64/shape_inference.hpp>
#include "snippets/op/scalar.hpp"
#include "lowering_utils.hpp"
#include "common_test_utils/common_utils.hpp"
#include "snippets/pass_manager.hpp"
namespace ov {
@ -61,7 +63,7 @@ protected:
ParameterVector parameters{data0, data1};
std::shared_ptr<Node> data2;
if (scalar_input) {
data2 = std::make_shared<ov::snippets::op::Scalar>(precision, Shape{}, 2.f);
data2 = std::make_shared<ov::snippets::op::Scalar>(precision, Shape{1}, 2.f);
} else {
auto parameter = std::make_shared<op::v0::Parameter>(precision, input_shapes[2]);
parameters.push_back(parameter);
@ -110,8 +112,8 @@ public:
std::ostringstream result;
for (size_t i = 0; i < inputShapes.size(); i++)
result << "IS[" << i << "]=" << inputShapes[i] << "_";
result << "MS=" << master_shape << "_";
result << "IS[" << i << "]=" << ov::test::utils::partialShape2str({inputShapes[i]}) << "_";
result << "MS=" << ov::test::utils::partialShape2str({master_shape}) << "_";
result << "add_input_idx=" << add_input_idx;
return result.str();
}
@ -146,7 +148,8 @@ TEST_P(MulAddToFMATests, MulAddToFMATests) {
backend_passes,
{},
{},
generator);
generator,
std::make_shared<ov::snippets::CPUShapeInferSnippetsFactory>());
model = subgraph->body_ptr();
model_ref = snippets_model->getLowered();
}