Snippets shape inference infrastructure (#18887)

This commit is contained in:
Ivan Novoselov 2023-09-08 11:58:21 +03:00 committed by GitHub
parent 25b1b4e26c
commit 8124f5c435
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 1013 additions and 265 deletions

View File

@ -11,13 +11,14 @@
#include "snippets/lowered/port_connector.hpp" #include "snippets/lowered/port_connector.hpp"
#include "snippets/lowered/expression_port.hpp" #include "snippets/lowered/expression_port.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
namespace ov { namespace ov {
namespace snippets { namespace snippets {
namespace lowered { namespace lowered {
class LinearIR; class LinearIR;
using ExpressionPtr = std::shared_ptr<Expression>;
class Expression : public std::enable_shared_from_this<Expression> { class Expression : public std::enable_shared_from_this<Expression> {
friend class LinearIR; friend class LinearIR;
friend class ExpressionPort; friend class ExpressionPort;
@ -49,6 +50,8 @@ public:
ExpressionPort get_input_port(size_t i); ExpressionPort get_input_port(size_t i);
ExpressionPort get_output_port(size_t i); ExpressionPort get_output_port(size_t i);
void updateShapes();
virtual bool needShapeInfer() const {return true; }
std::vector<size_t> get_loop_ids() const; std::vector<size_t> get_loop_ids() const;
void set_loop_ids(const std::vector<size_t>& loops); void set_loop_ids(const std::vector<size_t>& loops);
@ -56,7 +59,7 @@ public:
protected: protected:
// Note: The constructor initialization is private since an expression can be created only by Linear IR. // Note: The constructor initialization is private since an expression can be created only by Linear IR.
// The method must be used only by Linear IR builder of expressions! // The method must be used only by Linear IR builder of expressions!
explicit Expression(const std::shared_ptr<Node>& n); Expression(const std::shared_ptr<Node>& n, const std::shared_ptr<IShapeInferSnippetsFactory>& factory);
void replace_input(size_t port, PortConnectorPtr to); void replace_input(size_t port, PortConnectorPtr to);
@ -68,7 +71,8 @@ protected:
std::vector<PortDescriptorPtr> m_output_port_descriptors{}; std::vector<PortDescriptorPtr> m_output_port_descriptors{};
// The order Loops identifies: Outer ---> Inner // The order Loops identifies: Outer ---> Inner
// Note: The loops with the same dimension index (splitted dimension) should be successively nested // Note: The loops with the same dimension index (splitted dimension) should be successively nested
std::vector<size_t> m_loop_ids; std::vector<size_t> m_loop_ids{};
std::shared_ptr<IShapeInferSnippets> m_shapeInference{nullptr};
}; };
using ExpressionPtr = std::shared_ptr<Expression>; using ExpressionPtr = std::shared_ptr<Expression>;
@ -80,10 +84,11 @@ public:
int64_t get_index() const { return m_index; } int64_t get_index() const { return m_index; }
io_type get_type() const { return m_type; } io_type get_type() const { return m_type; }
// Result needs shapeInfer to copy shape from Parent's output to this expr input
bool needShapeInfer() const override {return m_type == io_type::OUTPUT; }
private: private:
explicit IOExpression(const std::shared_ptr<ov::opset1::Parameter>& n, int64_t index); explicit IOExpression(const std::shared_ptr<ov::opset1::Parameter>& n, int64_t index, const std::shared_ptr<IShapeInferSnippetsFactory>& factory);
explicit IOExpression(const std::shared_ptr<ov::opset1::Result>& n, int64_t index); explicit IOExpression(const std::shared_ptr<ov::opset1::Result>& n, int64_t index, const std::shared_ptr<IShapeInferSnippetsFactory>& factory);
int64_t m_index = -1; int64_t m_index = -1;
io_type m_type = io_type::UNDEFINED; io_type m_type = io_type::UNDEFINED;

View File

@ -38,9 +38,9 @@ private:
const std::shared_ptr<ov::Model>& model); const std::shared_ptr<ov::Model>& model);
/* -- Input Builders - get input port connectors from method parameters and create new output port connectors themselves */ /* -- Input Builders - get input port connectors from method parameters and create new output port connectors themselves */
static ExpressionPtr create(const std::shared_ptr<op::LoopBegin>& n, const std::vector<PortConnectorPtr>& inputs); static ExpressionPtr create(const std::shared_ptr<op::LoopBegin>& n, const std::vector<PortConnectorPtr>& inputs, const LinearIR& linear_ir);
static ExpressionPtr create(const std::shared_ptr<op::LoopEnd>& n, const std::vector<PortConnectorPtr>& inputs); static ExpressionPtr create(const std::shared_ptr<op::LoopEnd>& n, const std::vector<PortConnectorPtr>& inputs, const LinearIR& linear_ir);
static ExpressionPtr create(const std::shared_ptr<ov::Node>& n, const std::vector<PortConnectorPtr>& inputs); static ExpressionPtr create(const std::shared_ptr<ov::Node>& n, const std::vector<PortConnectorPtr>& inputs, const LinearIR& linear_ir);
// Creates inputs for expression using parent output port connectors // Creates inputs for expression using parent output port connectors
static void create_expression_inputs(const LinearIR& linear_ir, const ExpressionPtr& expr); static void create_expression_inputs(const LinearIR& linear_ir, const ExpressionPtr& expr);

View File

@ -8,6 +8,7 @@
#include "expression.hpp" #include "expression.hpp"
#include "snippets/target_machine.hpp" #include "snippets/target_machine.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
namespace ov { namespace ov {
namespace snippets { namespace snippets {
@ -36,7 +37,7 @@ public:
using constExprReverseIt = container::const_reverse_iterator; using constExprReverseIt = container::const_reverse_iterator;
LinearIR() = default; LinearIR() = default;
explicit LinearIR(const std::shared_ptr<ov::Model>& m, Config config = {}); LinearIR(const std::shared_ptr<ov::Model>& m, const std::shared_ptr<IShapeInferSnippetsFactory>& factory, Config config = {});
ExpressionPtr create_expression(const std::shared_ptr<Node>& n, const std::vector<PortConnectorPtr>& inputs); ExpressionPtr create_expression(const std::shared_ptr<Node>& n, const std::vector<PortConnectorPtr>& inputs);
@ -96,12 +97,13 @@ public:
iterator find_after(iterator it, const ExpressionPtr& target) const; iterator find_after(iterator it, const ExpressionPtr& target) const;
void init_emitters(const std::shared_ptr<TargetMachine>& target); void init_emitters(const std::shared_ptr<TargetMachine>& target);
void serialize(const std::string& xml, const std::string& bin); void serialize(const std::string& xml, const std::string& bin) const;
class LoopManager; class LoopManager;
using LoopManagerPtr = std::shared_ptr<LoopManager>; using LoopManagerPtr = std::shared_ptr<LoopManager>;
const LoopManagerPtr& get_loop_manager() const { return m_loop_manager; } const LoopManagerPtr& get_loop_manager() const { return m_loop_manager; }
const std::shared_ptr<IShapeInferSnippetsFactory>& get_shape_infer_factory() { return m_shape_infer_factory; }
private: private:
static ov::NodeVector get_ordered_ops(const std::shared_ptr<ov::Model>& model); static ov::NodeVector get_ordered_ops(const std::shared_ptr<ov::Model>& model);
@ -116,6 +118,7 @@ private:
io_container m_io_expressions; io_container m_io_expressions;
Config m_config{}; Config m_config{};
LoopManagerPtr m_loop_manager = nullptr; LoopManagerPtr m_loop_manager = nullptr;
std::shared_ptr<IShapeInferSnippetsFactory> m_shape_infer_factory;
}; };
template<typename iterator> template<typename iterator>

View File

@ -6,6 +6,7 @@
#include "openvino/core/node.hpp" #include "openvino/core/node.hpp"
#include "openvino/core/attribute_visitor.hpp" #include "openvino/core/attribute_visitor.hpp"
#include "snippets/shape_types.hpp"
namespace ov { namespace ov {
@ -23,28 +24,28 @@ public:
}; };
explicit PortDescriptor(const ov::Input<ov::Node>& node, explicit PortDescriptor(const ov::Input<ov::Node>& node,
std::vector<size_t> subtensor_shape = {}, VectorDims subtensor_shape = {},
std::vector<size_t> layout = {}); std::vector<size_t> layout = {});
explicit PortDescriptor(const ov::Input<const ov::Node>& node, explicit PortDescriptor(const ov::Input<const ov::Node>& node,
std::vector<size_t> subtensor_shape = {}, VectorDims subtensor_shape = {},
std::vector<size_t> layout = {}); std::vector<size_t> layout = {});
explicit PortDescriptor(const ov::Output<ov::Node>& node, explicit PortDescriptor(const ov::Output<ov::Node>& node,
std::vector<size_t> subtensor_shape = {}, VectorDims subtensor_shape = {},
std::vector<size_t> layout = {}); std::vector<size_t> layout = {});
explicit PortDescriptor(const ov::Output<const ov::Node>& node, explicit PortDescriptor(const ov::Output<const ov::Node>& node,
std::vector<size_t> subtensor_shape = {}, VectorDims subtensor_shape = {},
std::vector<size_t> layout = {}); std::vector<size_t> layout = {});
PortDescriptor(std::vector<size_t> shape, std::vector<size_t> subtensor_shape, std::vector<size_t> layout = {}); PortDescriptor(VectorDims shape, VectorDims subtensor_shape, std::vector<size_t> layout = {});
PortDescriptor() = default; PortDescriptor() = default;
std::vector<size_t> get_shape() const {return m_tensor_shape;} VectorDims get_shape() const {return m_tensor_shape;}
std::vector<size_t> get_subtensor() const {return m_subtensor_shape;} VectorDims get_subtensor() const {return m_subtensor_shape;}
std::vector<size_t> get_layout() const {return m_layout;} std::vector<size_t> get_layout() const {return m_layout;}
size_t get_reg() const { return m_reg; } size_t get_reg() const { return m_reg; }
void set_shape(const std::vector<size_t>& tensor) { m_tensor_shape = tensor; } void set_shape(const VectorDims& tensor) { m_tensor_shape = tensor; }
void set_layout(const std::vector<size_t>& layout) { m_layout = layout; } void set_layout(const std::vector<size_t>& layout) { m_layout = layout; }
void set_subtensor(const std::vector<size_t>& subtensor) { m_subtensor_shape = subtensor; } void set_subtensor(const VectorDims& subtensor) { m_subtensor_shape = subtensor; }
void set_reg(size_t reg) { m_reg = reg; } void set_reg(size_t reg) { m_reg = reg; }
std::string serialize() const; std::string serialize() const;
@ -57,11 +58,11 @@ public:
private: private:
void validate_arguments(); void validate_arguments();
/// \brief Original tensor shape /// \brief Original tensor shape
std::vector<size_t> m_tensor_shape{}; VectorDims m_tensor_shape{};
/// \brief Order of dimensions: NCHW == {0, 1, 2, 3}, NHWC == {0, 2, 3, 1}, NCHW16c == {0, 1, 2, 3, 1} /// \brief Order of dimensions: NCHW == {0, 1, 2, 3}, NHWC == {0, 2, 3, 1}, NCHW16c == {0, 1, 2, 3, 1}
std::vector<size_t> m_layout{}; std::vector<size_t> m_layout{};
/// \brief Minimal tensor size that could be processed in one call /// \brief Minimal tensor size that could be processed in one call
std::vector<size_t> m_subtensor_shape{}; VectorDims m_subtensor_shape{};
/// \brief The corresponding abstract/physical register /// \brief The corresponding abstract/physical register
size_t m_reg = 0; size_t m_reg = 0;
}; };

View File

@ -6,6 +6,7 @@
#include "openvino/op/op.hpp" #include "openvino/op/op.hpp"
#include "memory_access.hpp" #include "memory_access.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
namespace ov { namespace ov {
namespace snippets { namespace snippets {
@ -38,6 +39,14 @@ public:
bool has_evaluate() const override { return false; } bool has_evaluate() const override { return false; }
class ShapeInfer : public IShapeInferSnippets {
protected:
std::vector<std::vector<size_t>> m_io_layouts;
public:
explicit ShapeInfer(const std::shared_ptr<Node>& n);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};
protected: protected:
ov::element::Type get_output_type() const; ov::element::Type get_output_type() const;
std::vector<ov::PartialShape> get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const; std::vector<ov::PartialShape> get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const;

View File

@ -4,8 +4,8 @@
#pragma once #pragma once
#include "snippets/shape_inference/shape_infer_instances.hpp"
#include <snippets/op/memory_access.hpp> #include <snippets/op/memory_access.hpp>
#include "openvino/op/op.hpp" #include "openvino/op/op.hpp"
namespace ov { namespace ov {
@ -29,7 +29,15 @@ public:
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override; std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
ov::PartialShape get_output_shape() {return output_shape;}
// Note: BroadcastMove and BroadcastLoad are implemented as separate classes,
// but have identical shapeInfer semantics. In order to avoid code duplication,
// we created dummy ShapeInfer classes that are essentially instantiations
// of a common ShapeInfer class template.
struct ShapeInfer : public BroadcastShapeInfer<BroadcastLoad> {
explicit ShapeInfer(const std::shared_ptr<Node>& n) : BroadcastShapeInfer<BroadcastLoad>(n) {}
};
private: private:
ov::PartialShape output_shape; ov::PartialShape output_shape;
}; };

View File

@ -5,6 +5,7 @@
#pragma once #pragma once
#include "openvino/op/op.hpp" #include "openvino/op/op.hpp"
#include "snippets/shape_inference/shape_infer_instances.hpp"
namespace ov { namespace ov {
namespace snippets { namespace snippets {
@ -27,7 +28,14 @@ public:
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override; std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
ov::PartialShape get_output_shape() {return output_shape;}
// Note: BroadcastMove and BroadcastLoad are implemented as separate classes,
// but have identical shapeInfer semantics. In order to avoid code duplication,
// we created dummy ShapeInfer classes that are essentially instantiations
// of a common ShapeInfer class template.
struct ShapeInfer : public BroadcastShapeInfer<BroadcastMove> {
explicit ShapeInfer(const std::shared_ptr<Node>& n) : BroadcastShapeInfer<BroadcastMove>(n) {}
};
protected: protected:
ov::PartialShape output_shape; ov::PartialShape output_shape;

View File

@ -6,6 +6,7 @@
#include "openvino/op/op.hpp" #include "openvino/op/op.hpp"
#include "snippets/op/memory_access.hpp" #include "snippets/op/memory_access.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
namespace ov { namespace ov {
namespace snippets { namespace snippets {
@ -58,7 +59,15 @@ public:
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override; std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
private: class ShapeInfer : public IShapeInferSnippets {
std::vector<size_t> m_order;
public:
explicit ShapeInfer(const std::shared_ptr<ov::Node>& n);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};
protected:
std::vector<size_t> m_order; std::vector<size_t> m_order;
}; };
} // namespace op } // namespace op

View File

@ -11,6 +11,7 @@
#include "openvino/op/op.hpp" #include "openvino/op/op.hpp"
#include "openvino/core/rt_info.hpp" #include "openvino/core/rt_info.hpp"
#include "snippets/pass_manager.hpp" #include "snippets/pass_manager.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
#include "snippets/generator.hpp" #include "snippets/generator.hpp"
@ -102,17 +103,21 @@ public:
const std::vector<pass::Manager::PositionedPass>& data_flow_passes, const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
const lowered::pass::PassPipeline& control_flow_passes_pre_common, const lowered::pass::PassPipeline& control_flow_passes_pre_common,
const lowered::pass::PassPipeline& control_flow_passes_post_common, const lowered::pass::PassPipeline& control_flow_passes_post_common,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory = nullptr,
const void* compile_params = nullptr); const void* compile_params = nullptr);
snippets::Schedule generate(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes, const void* compile_params = nullptr); snippets::Schedule generate(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes, const void* compile_params = nullptr);
snippets::Schedule generate(const std::vector<pass::Manager::PositionedPass>& data_flow_passes, snippets::Schedule generate(const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
const lowered::pass::PassPipeline& control_flow_passes_pre_common, const lowered::pass::PassPipeline& control_flow_passes_pre_common,
const lowered::pass::PassPipeline& control_flow_passes_post_common, const lowered::pass::PassPipeline& control_flow_passes_post_common,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory = nullptr,
const void* compile_params = nullptr); const void* compile_params = nullptr);
snippets::Schedule generate(const void* compile_params = nullptr); snippets::Schedule generate(const void* compile_params = nullptr);
ov::PartialShape canonicalize(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes); ov::PartialShape canonicalize(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes);
ov::PartialShape canonicalized_body_shape_infer(const BlockedShapeVector& input_shapes); ov::PartialShape canonicalized_body_shape_infer(const BlockedShapeVector& input_shapes);
std::vector<PartialShape> reshape_body(const std::vector<PartialShape>& input_shapes); std::vector<PartialShape> reshape_body(const std::vector<PartialShape>& input_shapes);
std::vector<Shape> reshape_body(const std::vector<Shape>& input_shapes); std::vector<Shape> reshape_body(const std::vector<Shape>& input_shapes);
IShapeInferSnippets::Result shape_infer(const std::vector<VectorDimsRef>& input_shapes);
// plugin sets generator for a snippet to some specific generator. // plugin sets generator for a snippet to some specific generator.
// it's going to be replaced with Jitters table later // it's going to be replaced with Jitters table later
@ -121,7 +126,6 @@ public:
void set_virtual_port_count(const size_t count); void set_virtual_port_count(const size_t count);
void print() const; void print() const;
void print_statistics(bool verbose);
void serialize() const; void serialize() const;
void set_master_shape(ov::PartialShape new_shape) {master_shape = std::move(new_shape);} void set_master_shape(ov::PartialShape new_shape) {master_shape = std::move(new_shape);}
@ -137,6 +141,8 @@ public:
// Return estimated unique buffer count (upper bound). It's needed for tokenization // Return estimated unique buffer count (upper bound). It's needed for tokenization
static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t; static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t;
static auto is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bool; static auto is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bool;
std::shared_ptr<lowered::LinearIR>
convert_body_to_linear_ir(const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory = std::make_shared<IShapeInferSnippetsFactory>()) const;
private: private:
void align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes); void align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes);
@ -158,6 +164,7 @@ private:
size_t tileRank = 0; // set by plugin to specify the number of dimensions processed in a single kernel call size_t tileRank = 0; // set by plugin to specify the number of dimensions processed in a single kernel call
size_t maxInputRank = 0; size_t maxInputRank = 0;
std::vector<size_t> appendOnesForCanonical; std::vector<size_t> appendOnesForCanonical;
std::shared_ptr<lowered::LinearIR> m_linear_ir = nullptr;
/** /**
* @interface SubgraphConfig * @interface SubgraphConfig
@ -172,27 +179,49 @@ private:
// (e.g. Transpose, Softmax, MatMul in general doesn't support dimensions collapsing) // (e.g. Transpose, Softmax, MatMul in general doesn't support dimensions collapsing)
bool m_has_domain_sensitive_ops = false; bool m_has_domain_sensitive_ops = false;
} config; } config;
class ShapeInferSnippetsNode : public IShapeInferSnippets {
public:
const Result& get_last_result() {return m_last_result; }
protected:
Result m_last_result{{}, ShapeInferStatus::success};
};
std::shared_ptr<ShapeInferSnippetsNode> m_shape_infer = nullptr;
class NgraphShapeInfer : public ShapeInferSnippetsNode {
std::shared_ptr<ov::Model> m_ngraph_body;
ParameterVector m_parameters;
ResultVector m_results;
public:
explicit NgraphShapeInfer(const std::shared_ptr<ov::Model>& body);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};
class LIRShapeInfer : public ShapeInferSnippetsNode {
using IOExpression = lowered::IOExpression;
std::shared_ptr<lowered::LinearIR> m_lir_body;
std::vector<std::shared_ptr<IOExpression>> m_param_exprs;
std::vector<std::shared_ptr<IOExpression>> m_result_exprs;
public:
explicit LIRShapeInfer(const std::shared_ptr<lowered::LinearIR>& body);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};
}; };
static inline std::ostream& operator<<(std::ostream& os, const op::Subgraph::BlockedShape& blocked_shape) { static inline auto create_body(const std::string& name, const ov::ResultVector& results, const ov::ParameterVector& parameters) ->
os << std::get<0>(blocked_shape) << " " << std::get<1>(blocked_shape) << " " << std::get<2>(blocked_shape);
return os;
}
static inline auto create_body(std::string name, const ov::ResultVector& results, const ov::ParameterVector& parameters) ->
std::shared_ptr<ov::Model> { std::shared_ptr<ov::Model> {
auto body = std::make_shared<ov::Model>(results, parameters, name); auto body = std::make_shared<ov::Model>(results, parameters, name);
return body; return body;
}; }
static inline auto build_subgraph(const std::shared_ptr<ov::Node>& node, const ov::OutputVector& inputs, static inline auto build_subgraph(const std::shared_ptr<ov::Node>& node, const ov::OutputVector& inputs,
const std::shared_ptr<ov::Model>& body, const std::string name = "") const std::shared_ptr<ov::Model>& body, const std::string& name = "")
-> std::shared_ptr<Subgraph>{ -> std::shared_ptr<Subgraph>{
auto subgraph = std::make_shared<Subgraph>(inputs, body); auto subgraph = std::make_shared<Subgraph>(inputs, body);
copy_runtime_info(node, subgraph); copy_runtime_info(node, subgraph);
subgraph->set_friendly_name(name.empty() ? node->get_friendly_name() : name); subgraph->set_friendly_name(name.empty() ? node->get_friendly_name() : name);
return subgraph; return subgraph;
}; }
// Need to update tensor name manually, since intel_cpu::Graph::Replicate() looks at input.get_shape().get_name(); // Need to update tensor name manually, since intel_cpu::Graph::Replicate() looks at input.get_shape().get_name();
// If subgraph->get_output_size() == 1, then the name will be restored correctly from the node name // If subgraph->get_output_size() == 1, then the name will be restored correctly from the node name

View File

@ -5,6 +5,7 @@
#pragma once #pragma once
#include "openvino/op/op.hpp" #include "openvino/op/op.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
namespace ov { namespace ov {
namespace snippets { namespace snippets {

View File

@ -0,0 +1,60 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "shape_inference.hpp"
namespace ov {
namespace snippets {
class NumpyBroadcastShapeInfer : public IShapeInferSnippets {
public:
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};
template<class BroadcastOP>
class BroadcastShapeInfer : public IShapeInferSnippets {
VectorDims::value_type m_broadcasted_dim;
public:
explicit BroadcastShapeInfer(const std::shared_ptr<Node>& n);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};
class PassThroughShapeInfer : public IShapeInferSnippets {
public:
inline Result infer(const std::vector<VectorDimsRef>& input_shapes) override {
OPENVINO_ASSERT(!input_shapes.empty(), "Empty Input shapes are not allowed for PassThroughShapeInfer");
return {{input_shapes[0].get()}, ShapeInferStatus::success};
}
};
class EmptyShapeInfer : public IShapeInferSnippets {
public:
inline Result infer(const std::vector<VectorDimsRef>& input_shapes) override {
return {{}, ShapeInferStatus::success};
}
};
class SingleElementShapeInfer : public IShapeInferSnippets {
public:
inline Result infer(const std::vector<VectorDimsRef>& input_shapes) override {
return {{{1}}, ShapeInferStatus::success};
}
};
class SelectShapeInfer : public IShapeInferSnippets {
ov::op::AutoBroadcastSpec m_broadcast_spec;
public:
explicit SelectShapeInfer(const std::shared_ptr<Node>& n);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};
class HorizonOpShapeInfer : public IShapeInferSnippets {
public:
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};
} // namespace snippets
} // namespace ov

View File

@ -0,0 +1,72 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/core/core.hpp>
#include "snippets/shape_types.hpp"
namespace ov {
namespace snippets {
enum class ShapeInferStatus {
success, ///< shapes were successfully calculated
skip ///< shape inference was skipped.
};
/**
* This is Snippets specific shape inference interface.
*
*/
class IShapeInferSnippets {
public:
enum {DYNAMIC_DIMENSION = 0xffffffffffffffff};
struct Result {
std::vector<VectorDims> dims;
ShapeInferStatus status;
};
virtual ~IShapeInferSnippets() = default;
/**
* @brief This method actually performs all the necessary shape inference computations
*
* @param input_shapes are the input tensors shapes
* @return Result instance that contains an array of calculated shapes (per each output port) and a status of the shape infer call
*/
virtual Result infer(const std::vector<VectorDimsRef>& input_shapes) = 0;
};
class IShapeInferSnippetsFactory {
public:
// Helper type to define specific Makers map values.
using ShapeInferPtr = std::shared_ptr<IShapeInferSnippets>;
// Helper type to define specific Makers map type.
using TRegistry = std::unordered_map<ov::DiscreteTypeInfo, std::function<ShapeInferPtr (std::shared_ptr<ov::Node>)>>;
/**
* \brief Creates the shape inference object.
*
* \param key Key value to get specified shape inference object maker.
* \param args Inference object args.
*
* \return Pointer to shape inference object or nullptr if failed to construct the object.
*/
ShapeInferPtr make(const ov::DiscreteTypeInfo& key, const std::shared_ptr<ov::Node>& op);
virtual ~IShapeInferSnippetsFactory() = default;
private:
/** \brief Factory makers registry which can be specialized for key and value. */
static const TRegistry registry;
protected:
/**
* @brief get shape infer instances for operations from backend-specific opset
* @return Pointer to shape inference object or nullptr if failed to construct the object.
*/
virtual ShapeInferPtr get_specific_op_shape_infer(const ov::DiscreteTypeInfo& key, const std::shared_ptr<ov::Node>& op) const;
};
std::shared_ptr<IShapeInferSnippets> make_shape_inference(const std::shared_ptr<ov::Node>& op,
const std::shared_ptr<IShapeInferSnippetsFactory>& factory);
} // namespace snippets
} // namespace ov

View File

@ -0,0 +1,18 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
namespace ov {
namespace snippets {
/*
 * This header file contains declarations of shape-relevant classes used across several snippets subsystems.
* The main purpose of storing such declarations here is to eliminate false dependencies. For example,
* both PortDescriptor and IShapeInferSnippets use VectorDims, but these two classes are completely independent semantically.
*/
using VectorDims = std::vector<size_t>;
using VectorDimsRef = std::reference_wrapper<const VectorDims>;
} // namespace snippets
} // namespace ov

View File

@ -10,6 +10,7 @@
#include "snippets_isa.hpp" #include "snippets_isa.hpp"
#include "emitter.hpp" #include "emitter.hpp"
#include "shape_types.hpp"
namespace ov { namespace ov {
@ -24,9 +25,11 @@ inline auto is_scalar_constant(const std::shared_ptr<ov::Node>& source_output_no
return ov::is_type<ov::opset1::Constant>(source_output_node) && ov::shape_size(source_output_node->get_shape()) == 1; return ov::is_type<ov::opset1::Constant>(source_output_node) && ov::shape_size(source_output_node->get_shape()) == 1;
} }
ov::PartialShape get_port_planar_shape(const Input<Node>& out); ov::PartialShape get_planar_pshape(const Input<Node>& out);
ov::PartialShape get_port_planar_shape(const Output<Node>& out); ov::PartialShape get_planar_pshape(const Output<Node>& out);
ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const std::vector<size_t>& layout); ov::PartialShape get_planar_pshape(const ov::PartialShape& shape, const std::vector<size_t>& layout);
VectorDims pshape_to_vdims(const PartialShape&);
ov::PartialShape vdims_to_pshape(const VectorDims&);
inline auto normalize_rank(int32_t allocation_rank, const size_t shape_rank) -> int32_t { inline auto normalize_rank(int32_t allocation_rank, const size_t shape_rank) -> int32_t {
return allocation_rank < 0 ? allocation_rank + static_cast<int32_t>(shape_rank) + 1 : allocation_rank; return allocation_rank < 0 ? allocation_rank + static_cast<int32_t>(shape_rank) + 1 : allocation_rank;
@ -47,6 +50,11 @@ template <typename T, typename P, typename... Args>
constexpr bool everyone_is(T val, P item, Args... item_others) { constexpr bool everyone_is(T val, P item, Args... item_others) {
return val == item && everyone_is(val, item_others...); return val == item && everyone_is(val, item_others...);
} }
VectorDims get_planar_vdims(const VectorDims& shape, const std::vector<size_t>& layout);
VectorDims get_planar_vdims(const snippets::lowered::PortDescriptorPtr& port_desc);
VectorDims get_planar_vdims(const snippets::lowered::ExpressionPort& expr_port);
} // namespace utils } // namespace utils
} // namespace snippets } // namespace snippets
} // namespace ov } // namespace ov

View File

@ -14,8 +14,9 @@ namespace ov {
namespace snippets { namespace snippets {
namespace lowered { namespace lowered {
Expression::Expression(const std::shared_ptr<Node>& n) Expression::Expression(const std::shared_ptr<Node>& n, const std::shared_ptr<IShapeInferSnippetsFactory>& factory)
: m_source_node{n}, m_emitter{nullptr}, m_input_port_connectors{}, m_output_port_connectors{} { : m_source_node{n}, m_emitter{nullptr},
m_input_port_connectors{}, m_output_port_connectors{}, m_shapeInference(make_shape_inference(n, factory)) {
m_input_port_descriptors.reserve(n->get_input_size()); m_input_port_descriptors.reserve(n->get_input_size());
m_output_port_descriptors.reserve(n->get_output_size()); m_output_port_descriptors.reserve(n->get_output_size());
for (const auto& input : n->inputs()) { for (const auto& input : n->inputs()) {
@ -110,10 +111,38 @@ ExpressionPort Expression::get_output_port(size_t i) {
return ExpressionPort(this->shared_from_this(), ExpressionPort::Type::Output, i); return ExpressionPort(this->shared_from_this(), ExpressionPort::Type::Output, i);
} }
IOExpression::IOExpression(const std::shared_ptr<ov::opset1::Parameter>& par, int64_t index) void Expression::updateShapes() {
: Expression(par), m_index(index), m_type{io_type::INPUT} {} IShapeInferSnippets::Result result;
IOExpression::IOExpression(const std::shared_ptr<ov::opset1::Result>& res, int64_t index) try {
: Expression(res), m_index(index), m_type{io_type::OUTPUT} {} std::vector<VectorDimsRef> input_shapes;
const auto& in_connectors = get_input_port_connectors();
const auto& in_descriptors = get_input_port_descriptors();
input_shapes.reserve(in_connectors.size());
for (size_t i = 0; i < in_connectors.size(); i++) {
const auto& src_port = in_connectors[i]->get_source();
const auto i_shape = src_port.get_descriptor_ptr()->get_shape();
// todo: do we really need to store the same shape twice in parent's out_port_desc and this in_port_descs
in_descriptors[i]->set_shape(i_shape);
input_shapes.emplace_back(i_shape);
}
result = m_shapeInference->infer(input_shapes);
}
catch (const std::exception& exp) {
OPENVINO_THROW("Shape inference of " + (get_node()->get_friendly_name()) + " failed: " + exp.what());
}
const auto& out_descriptors = get_output_port_descriptors();
OPENVINO_ASSERT(result.dims.size() == out_descriptors.size(), "shapeInference call returned invalid number of output shapes");
for (size_t i = 0; i < out_descriptors.size(); i++)
out_descriptors[i]->set_shape(result.dims[i]);
}
// I/O expression wrapping a model Parameter: stores the parameter's index within the
// model and tags the expression as an INPUT; shape inference is delegated to the base
// Expression via the provided shape-infer factory.
IOExpression::IOExpression(const std::shared_ptr<ov::opset1::Parameter>& par, int64_t index, const std::shared_ptr<IShapeInferSnippetsFactory>& factory)
        : Expression(par, factory), m_index(index), m_type{io_type::INPUT} {}
// I/O expression wrapping a model Result: stores the result's index within the
// model and tags the expression as an OUTPUT.
IOExpression::IOExpression(const std::shared_ptr<ov::opset1::Result>& res, int64_t index, const std::shared_ptr<IShapeInferSnippetsFactory>& factory)
        : Expression(res, factory), m_index(index), m_type{io_type::OUTPUT} {}
}// namespace lowered }// namespace lowered
}// namespace snippets }// namespace snippets

View File

@ -57,7 +57,7 @@ ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr<ov::op::
const LinearIR& linear_ir, const std::shared_ptr<ov::Model>& model) { const LinearIR& linear_ir, const std::shared_ptr<ov::Model>& model) {
// Note: ctor of shared_ptr isn't friend class for Expression -> we cannot use directly make_shared<Expression>(args) // Note: ctor of shared_ptr isn't friend class for Expression -> we cannot use directly make_shared<Expression>(args)
OPENVINO_ASSERT(model != nullptr, "To create IOExpression from Parameter there must be inited model!"); OPENVINO_ASSERT(model != nullptr, "To create IOExpression from Parameter there must be inited model!");
auto expr = std::shared_ptr<IOExpression>(new IOExpression(par, model->get_parameter_index(par))); auto expr = std::shared_ptr<IOExpression>(new IOExpression(par, model->get_parameter_index(par), linear_ir.m_shape_infer_factory));
create_expression_outputs(expr); create_expression_outputs(expr);
expr->validate(); expr->validate();
return expr; return expr;
@ -67,7 +67,7 @@ ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr<ov::op::
const LinearIR& linear_ir, const std::shared_ptr<ov::Model>& model) { const LinearIR& linear_ir, const std::shared_ptr<ov::Model>& model) {
// Note: ctor of shared_ptr isn't friend class for Expression -> we cannot use directly make_shared<Expression>(args) // Note: ctor of shared_ptr isn't friend class for Expression -> we cannot use directly make_shared<Expression>(args)
OPENVINO_ASSERT(model != nullptr, "To create IOExpression from Result there must be inited model!"); OPENVINO_ASSERT(model != nullptr, "To create IOExpression from Result there must be inited model!");
auto expr = std::shared_ptr<IOExpression>(new IOExpression(res, model->get_result_index(res))); auto expr = std::shared_ptr<IOExpression>(new IOExpression(res, model->get_result_index(res), linear_ir.m_shape_infer_factory));
create_expression_inputs(linear_ir, expr); create_expression_inputs(linear_ir, expr);
// The Result node don't need output port (because of sense of the node). But each node in ngraph must have one output at least. // The Result node don't need output port (because of sense of the node). But each node in ngraph must have one output at least.
// The port descriptors are automatically created in constructor. We manually clean output ports. // The port descriptors are automatically created in constructor. We manually clean output ports.
@ -80,24 +80,28 @@ ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr<ov::Node
const std::shared_ptr<ov::Model>& model) { const std::shared_ptr<ov::Model>& model) {
OPENVINO_ASSERT(!ov::is_type<op::LoopBase>(n), "Default expression builder doesn't support LoopBegin and LoopEnd"); OPENVINO_ASSERT(!ov::is_type<op::LoopBase>(n), "Default expression builder doesn't support LoopBegin and LoopEnd");
// Note: ctor of shared_ptr isn't friend class for Expression // Note: ctor of shared_ptr isn't friend class for Expression
auto expr = std::shared_ptr<Expression>(new Expression(n)); auto expr = std::shared_ptr<Expression>(new Expression(n, linear_ir.m_shape_infer_factory));
create_expression_inputs(linear_ir, expr); create_expression_inputs(linear_ir, expr);
create_expression_outputs(expr); create_expression_outputs(expr);
expr->validate(); expr->validate();
return expr; return expr;
} }
ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr<op::LoopBegin>& n, const std::vector<PortConnectorPtr>& inputs) { ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr<op::LoopBegin>& n,
const std::vector<PortConnectorPtr>& inputs,
const LinearIR& linear_ir) {
OPENVINO_ASSERT(inputs.empty(), "LoopBegin cannot have inputs"); OPENVINO_ASSERT(inputs.empty(), "LoopBegin cannot have inputs");
auto expr = std::make_shared<Expression>(Expression(n)); auto expr = std::make_shared<Expression>(Expression(n, linear_ir.m_shape_infer_factory));
init_expression_inputs(expr, inputs); init_expression_inputs(expr, inputs);
create_expression_outputs(expr); create_expression_outputs(expr);
expr->validate(); expr->validate();
return expr; return expr;
} }
ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr<op::LoopEnd>& n, const std::vector<PortConnectorPtr>& inputs) { ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr<op::LoopEnd>& n,
auto expr = std::shared_ptr<Expression>(new Expression(n)); const std::vector<PortConnectorPtr>& inputs,
const LinearIR& linear_ir) {
auto expr = std::shared_ptr<Expression>(new Expression(n, linear_ir.m_shape_infer_factory));
expr->m_input_port_descriptors.resize(inputs.size(), nullptr); expr->m_input_port_descriptors.resize(inputs.size(), nullptr);
for (size_t i = 0; i < inputs.size() - 1; ++i) { for (size_t i = 0; i < inputs.size() - 1; ++i) {
expr->m_input_port_descriptors[i] = std::make_shared<PortDescriptor>(); expr->m_input_port_descriptors[i] = std::make_shared<PortDescriptor>();
@ -113,14 +117,20 @@ ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr<op::Loop
return expr; return expr;
} }
ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr<ov::Node>& n, const std::vector<PortConnectorPtr>& inputs) { ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr<ov::Node>& n,
const std::vector<PortConnectorPtr>& inputs,
const LinearIR& linear_ir) {
OPENVINO_ASSERT(!ov::is_type<ov::op::v0::Parameter>(n) && OPENVINO_ASSERT(!ov::is_type<ov::op::v0::Parameter>(n) &&
!ov::is_type<ov::op::v0::Result>(n), !ov::is_type<ov::op::v0::Result>(n),
"Expression builder with inputs doesn't support Result and Parameter"); "Expression builder with inputs doesn't support Result and Parameter");
auto expr = std::shared_ptr<Expression>(new Expression(n)); auto expr = std::shared_ptr<Expression>(new Expression(n, linear_ir.m_shape_infer_factory));
init_expression_inputs(expr, inputs); init_expression_inputs(expr, inputs);
create_expression_outputs(expr); create_expression_outputs(expr);
expr->validate(); expr->validate();
// todo: here we blindly synchronize input shapes from parent and child. Remove this when shapes will be stored in
// port connector itself
if (linear_ir.m_shape_infer_factory)
expr->updateShapes();
return expr; return expr;
} }
}// namespace lowered }// namespace lowered

View File

@ -18,8 +18,8 @@ namespace ov {
namespace snippets { namespace snippets {
namespace lowered { namespace lowered {
LinearIR::LinearIR(const std::shared_ptr<ov::Model>& model, Config config) LinearIR::LinearIR(const std::shared_ptr<ov::Model>& model, const std::shared_ptr<IShapeInferSnippetsFactory>& factory, Config config)
: m_io_expressions{}, m_config{std::move(config)}, m_loop_manager(std::make_shared<LoopManager>()) { : m_io_expressions{}, m_config{config}, m_loop_manager(std::make_shared<LoopManager>()), m_shape_infer_factory(factory) {
constExprIt last_param = m_expressions.end(); constExprIt last_param = m_expressions.end();
for (const auto& n : get_ordered_ops(model)) { for (const auto& n : get_ordered_ops(model)) {
constExprIt insertion_pos = m_expressions.end(); constExprIt insertion_pos = m_expressions.end();
@ -48,7 +48,7 @@ ExpressionPtr LinearIR::create_expression(const std::shared_ptr<Node>& n, const
} }
ExpressionPtr LinearIR::create_expression(const std::shared_ptr<Node>& n, const std::vector<PortConnectorPtr>& inputs) { ExpressionPtr LinearIR::create_expression(const std::shared_ptr<Node>& n, const std::vector<PortConnectorPtr>& inputs) {
return ExpressionFactory::build(n, inputs); return ExpressionFactory::build(n, inputs, *this);
} }
ov::NodeVector LinearIR::get_ordered_ops(const std::shared_ptr<ov::Model>& m) { ov::NodeVector LinearIR::get_ordered_ops(const std::shared_ptr<ov::Model>& m) {
@ -66,7 +66,7 @@ ov::NodeVector LinearIR::get_ordered_ops(const std::shared_ptr<ov::Model>& m) {
return ov::topological_sort(nodes); return ov::topological_sort(nodes);
} }
void LinearIR::serialize(const std::string& xml, const std::string& bin) { void LinearIR::serialize(const std::string& xml, const std::string& bin) const {
auto first_node = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{}); auto first_node = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{});
first_node->set_friendly_name("Start"); first_node->set_friendly_name("Start");
first_node->get_rt_info()["execTimeMcs"] = 0; first_node->get_rt_info()["execTimeMcs"] = 0;

View File

@ -182,7 +182,7 @@ void LinearIR::LoopManager::mark_loop(LinearIR::constExprIt loop_begin_pos,
std::vector<size_t> loop_tensor(loop_depth, 1); std::vector<size_t> loop_tensor(loop_depth, 1);
for (const auto& exit_point : loop_exit_points) { for (const auto& exit_point : loop_exit_points) {
const auto& desc = exit_point.get_descriptor_ptr(); const auto& desc = exit_point.get_descriptor_ptr();
const auto shape = utils::get_reordered_planar_shape(ov::PartialShape(desc->get_shape()), desc->get_layout()).get_shape(); const auto shape = utils::get_planar_vdims(desc);
auto subtensor = desc->get_subtensor(); auto subtensor = desc->get_subtensor();
if (subtensor.empty()) { if (subtensor.empty()) {
subtensor.resize(loop_depth, 1); subtensor.resize(loop_depth, 1);

View File

@ -37,13 +37,13 @@ ov::Shape compute_allocation_shape(const LinearIR::LoopManagerPtr& loop_manager,
const std::vector<size_t>& parent_loop_ids, const std::vector<size_t>& parent_loop_ids,
const ov::Output<ov::Node>& parent_output, const ov::Output<ov::Node>& parent_output,
const int allocation_rank) { const int allocation_rank) {
const auto port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(parent_output); const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(parent_output);
const auto planar_shape = utils::get_reordered_planar_shape(ov::Shape{port->get_shape()}, port->get_layout()); const auto planar_shape = utils::get_planar_vdims(port);
const size_t rank = allocation_rank >= 0 ? std::min(static_cast<size_t>(allocation_rank), planar_shape.size()) : planar_shape.size(); const size_t rank = allocation_rank >= 0 ? std::min(static_cast<size_t>(allocation_rank), planar_shape.size()) : planar_shape.size();
ov::Shape allocation_shape(rank); ov::Shape allocation_shape(rank);
for (size_t i = 0; i < rank; ++i) { for (size_t i = 0; i < rank; ++i) {
*(allocation_shape.rbegin() + i) = (planar_shape.rbegin() + i)->get_length(); *(allocation_shape.rbegin() + i) = *(planar_shape.rbegin() + i);
} }
if (buffer_loop_ids.empty() || parent_loop_ids.empty()) { if (buffer_loop_ids.empty() || parent_loop_ids.empty()) {

View File

@ -10,17 +10,17 @@ namespace lowered {
size_t PortDescriptor::ServiceDimensions::FULL_DIM = SIZE_MAX; size_t PortDescriptor::ServiceDimensions::FULL_DIM = SIZE_MAX;
PortDescriptor::PortDescriptor(const ov::Input<ov::Node>& in, std::vector<size_t> subtensor_shape, std::vector<size_t> layout) PortDescriptor::PortDescriptor(const ov::Input<ov::Node>& in, VectorDims subtensor_shape, std::vector<size_t> layout)
: PortDescriptor(ov::Input<const Node>(in.get_node(), in.get_index()), std::move(subtensor_shape), std::move(layout)) {} : PortDescriptor(ov::Input<const Node>(in.get_node(), in.get_index()), std::move(subtensor_shape), std::move(layout)) {}
PortDescriptor::PortDescriptor(const ov::Input<const ov::Node>& in, std::vector<size_t> subtensor_shape, std::vector<size_t> layout) PortDescriptor::PortDescriptor(const ov::Input<const ov::Node>& in, VectorDims subtensor_shape, std::vector<size_t> layout)
: PortDescriptor(in.get_shape(), std::move(subtensor_shape), std::move(layout)) {} : PortDescriptor(in.get_shape(), std::move(subtensor_shape), std::move(layout)) {}
PortDescriptor::PortDescriptor(const ov::Output<ov::Node>& out, std::vector<size_t> subtensor_shape, std::vector<size_t> layout) PortDescriptor::PortDescriptor(const ov::Output<ov::Node>& out, VectorDims subtensor_shape, std::vector<size_t> layout)
: PortDescriptor(ov::Output<const Node>(out.get_node(), out.get_index()), std::move(subtensor_shape), std::move(layout)) {} : PortDescriptor(ov::Output<const Node>(out.get_node(), out.get_index()), std::move(subtensor_shape), std::move(layout)) {}
PortDescriptor::PortDescriptor(const ov::Output<const ov::Node>& out, std::vector<size_t> subtensor_shape, std::vector<size_t> layout) PortDescriptor::PortDescriptor(const ov::Output<const ov::Node>& out, VectorDims subtensor_shape, std::vector<size_t> layout)
: PortDescriptor(out.get_shape(), std::move(subtensor_shape), std::move(layout)) {} : PortDescriptor(out.get_shape(), std::move(subtensor_shape), std::move(layout)) {}
PortDescriptor::PortDescriptor(std::vector<size_t> shape, std::vector<size_t> subtensor_shape, std::vector<size_t> layout) PortDescriptor::PortDescriptor(VectorDims shape, VectorDims subtensor_shape, std::vector<size_t> layout)
: m_tensor_shape(std::move(shape)), m_layout(std::move(layout)), m_subtensor_shape(std::move(subtensor_shape)) { : m_tensor_shape(std::move(shape)), m_layout(std::move(layout)), m_subtensor_shape(std::move(subtensor_shape)) {
validate_arguments(); validate_arguments();
} }

View File

@ -13,6 +13,23 @@ namespace ov {
namespace snippets { namespace snippets {
namespace op { namespace op {
namespace {
// Returns the layout of the node's first output as recorded in the
// PortDescriptorVectorAttribute of its runtime info, or an empty layout
// when the attribute has not been set on the node.
std::vector<size_t> get_output_layout(const std::shared_ptr<const ov::Node>& n) {
    const auto& rt_info = n->get_rt_info();
    const auto it = rt_info.find(lowered::PortDescriptorVectorAttribute::get_type_info_static());
    if (it == rt_info.end())
        return {};
    const auto& out_descs = it->second.as<lowered::PortDescriptorVectorAttribute>().outputs;
    if (out_descs.size() != n->get_output_size())
        OPENVINO_THROW("Get output port descriptor is failed: incorrect count");
    return out_descs[0]->get_layout();
}
} // namespace
Brgemm::Brgemm(const Output<Node>& A, const Output<Node>& B, Brgemm::Brgemm(const Output<Node>& A, const Output<Node>& B,
const size_t offset_a, const size_t offset_b, const size_t offset_c, const size_t offset_a, const size_t offset_b, const size_t offset_c,
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c) std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c)
@ -39,10 +56,10 @@ void Brgemm::custom_constructor_validate_and_infer_types(std::vector<size_t> lay
// During ctor call, Brgemm doesn't know his port descriptors. // During ctor call, Brgemm doesn't know his port descriptors.
// So we use explicit layouts from parameters // So we use explicit layouts from parameters
const auto planar_input_shapes = const auto planar_input_shapes =
std::vector<ov::PartialShape>{ ov::snippets::utils::get_reordered_planar_shape(get_input_partial_shape(0), layout_a), std::vector<ov::PartialShape>{ ov::snippets::utils::get_planar_pshape(get_input_partial_shape(0), layout_a),
ov::snippets::utils::get_reordered_planar_shape(get_input_partial_shape(1), layout_b) }; ov::snippets::utils::get_planar_pshape(get_input_partial_shape(1), layout_b) };
auto output_shape = get_output_partial_shape(planar_input_shapes); auto output_shape = get_output_partial_shape(planar_input_shapes);
set_output_type(0, get_output_type(), ov::snippets::utils::get_reordered_planar_shape(output_shape, layout_c)); set_output_type(0, get_output_type(), ov::snippets::utils::get_planar_pshape(output_shape, layout_c));
} }
void Brgemm::validate_inputs() const { void Brgemm::validate_inputs() const {
@ -97,21 +114,15 @@ ov::element::Type Brgemm::get_output_type() const {
std::vector<ov::PartialShape> Brgemm::get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const { std::vector<ov::PartialShape> Brgemm::get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const {
OPENVINO_ASSERT(inputs.size() == 2, "Brgemm::get_planar_input_shapes() expects 2 inputs"); OPENVINO_ASSERT(inputs.size() == 2, "Brgemm::get_planar_input_shapes() expects 2 inputs");
return { utils::get_port_planar_shape(inputs[0]), utils::get_port_planar_shape(inputs[1]) }; return {utils::get_planar_pshape(inputs[0]), utils::get_planar_pshape(inputs[1]) };
} }
ov::PartialShape Brgemm::get_planar_output_shape(const ov::PartialShape& output_shape) const { ov::PartialShape Brgemm::get_planar_output_shape(const ov::PartialShape& output_shape) const {
// This method can be safely called from validate_and_infer_types() before output creation // This method can be safely called from validate_and_infer_types() before output creation
const auto& key = lowered::PortDescriptorVectorAttribute::get_type_info_static(); const auto& out_layout = get_output_layout(shared_from_this());
auto& rt_info = get_rt_info(); if (!out_layout.empty())
const auto& found = rt_info.find(key); return utils::get_planar_pshape(output_shape, out_layout);
if (found != rt_info.end()) {
const auto& out_descs = found->second.as<lowered::PortDescriptorVectorAttribute>().outputs;
if (out_descs.size() != get_output_size())
OPENVINO_THROW("Get output port descriptor is failed: incorrect count");
const auto& port_desc = out_descs[0];
return utils::get_reordered_planar_shape(output_shape, port_desc->get_layout());
}
return output_shape; return output_shape;
} }
@ -178,6 +189,76 @@ ov::PartialShape Brgemm::get_output_partial_shape(const std::vector<ov::PartialS
return output_shape; return output_shape;
} }
// Captures the I/O layouts needed for shape inference: one layout per input
// (taken from the input port descriptors) followed by the output layout, i.e.
// m_io_layouts = [in_0, ..., in_{k-1}, out].
Brgemm::ShapeInfer::ShapeInfer(const std::shared_ptr<Node>& n) {
    const auto& node_inputs = n->inputs();
    m_io_layouts.reserve(node_inputs.size() + 1);
    for (const auto& input : node_inputs)
        m_io_layouts.emplace_back(lowered::PortDescriptorUtils::get_port_descriptor_ptr(input)->get_layout());
    m_io_layouts.emplace_back(get_output_layout(n));
}
// Computes the Brgemm (batched matmul) output shape from the two input shapes.
// Inputs are first brought to planar form via the layouts captured in the ctor
// (m_io_layouts[0] and m_io_layouts[1]); the resulting shape is permuted back
// through the output layout (m_io_layouts[2]) before being returned.
// Follows MatMul conventions: 1D inputs are temporarily unsqueezed to 2D,
// batch dimensions are broadcast, and the temporary axes are removed at the end.
// NOTE(review): the K dimensions (arg0 last dim vs arg1 second-to-last) are not
// validated for compatibility here — presumably checked elsewhere; confirm.
IShapeInferSnippets::Result Brgemm::ShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
    OPENVINO_ASSERT(input_shapes.size() == 2, "BRGEMM expects 2 input shapes for shape inference");
    // Todo: Ideally we should use the layout stored in PortDescriptors. Can we do it?
    const auto& arg0_shape = snippets::utils::get_planar_vdims(input_shapes[0].get(), m_io_layouts[0]);
    const auto& arg1_shape = snippets::utils::get_planar_vdims(input_shapes[1].get(), m_io_layouts[1]);
    size_t arg0_rank = arg0_shape.size(), arg1_rank = arg1_shape.size();
    // temporary shapes to calculate output shape
    VectorDims arg0_shape_tmp(arg0_shape), arg1_shape_tmp(arg1_shape);
    // one-dimensional tensors unsqueezing is applied to each input independently.
    if (arg0_rank == 1) {
        // If the first input is 1D tensor, it is unsqueezed to 2D tensor (row vector)
        // by adding axes with size 1 at ROW_INDEX_DIM, to the left of the shape.
        // For example {S} will be reshaped to {1, S}.
        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), 1);
        arg0_rank = arg0_shape_tmp.size();
    }
    if (arg1_rank == 1) {
        // If the second input is 1D tensor, it is unsqueezed to 2D tensor (column vector)
        // by adding axes with size 1 at COL_INDEX_DIM, to the right of the shape.
        // For example {S} will be reshaped to {S, 1}.
        arg1_shape_tmp.insert(arg1_shape_tmp.end(), 1);
        arg1_rank = arg1_shape_tmp.size();
    }
    // add 1 to begin to align shape ranks if needed
    if (arg0_rank < arg1_rank)
        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), arg1_rank - arg0_rank, 1);
    else if (arg0_rank > arg1_rank)
        arg1_shape_tmp.insert(arg1_shape_tmp.begin(), arg0_rank - arg1_rank, 1);
    size_t max_rank = arg0_shape_tmp.size();
    VectorDims output_shape(max_rank);
    // Broadcast all batch dimensions (every axis except the last two).
    // A dimension equal to 1 or DYNAMIC_DIMENSION is treated as broadcastable;
    // any other mismatch is an error.
    for (size_t i = 0; i < max_rank - 2; ++i) {
        if (arg0_shape_tmp[i] == arg1_shape_tmp[i]) {
            output_shape[i] = arg0_shape_tmp[i];
        } else {
            if (arg0_shape_tmp[i] == 1 || arg0_shape_tmp[i] == DYNAMIC_DIMENSION)
                output_shape[i] = arg1_shape_tmp[i];
            else if (arg1_shape_tmp[i] == 1 || arg1_shape_tmp[i] == DYNAMIC_DIMENSION)
                output_shape[i] = arg0_shape_tmp[i];
            else
                OPENVINO_THROW("Incompatible Brgemm batch dimension");
        }
    }
    // The two innermost output dims come straight from the inputs:
    output_shape[output_shape.size() - 2] = arg0_shape_tmp[arg0_shape_tmp.size() - 2]; // M
    output_shape[output_shape.size() - 1] = arg1_shape_tmp[arg1_shape_tmp.size() - 1]; // N
    // removing the temporary axes from originally 1D tensors.
    if (arg0_shape.size() == 1) {
        output_shape.erase(output_shape.begin() + output_shape.size() - 2);
    }
    if (arg1_shape.size() == 1) {
        output_shape.erase(output_shape.begin() + output_shape.size() - 1);
    }
    // Permute the planar result back into the layout expected at the output port.
    output_shape = snippets::utils::get_planar_vdims(output_shape, m_io_layouts[2]);
    return {{output_shape}, snippets::ShapeInferStatus::success};
}
} // namespace op } // namespace op
} // namespace snippets } // namespace snippets
} // namespace ov } // namespace ov

View File

@ -5,6 +5,7 @@
#include "snippets/itt.hpp" #include "snippets/itt.hpp"
#include "snippets/op/load.hpp" #include "snippets/op/load.hpp"
#include "snippets/utils.hpp"
namespace ov { namespace ov {
@ -69,6 +70,15 @@ std::shared_ptr<Node> LoadReshape::clone_with_new_inputs(const OutputVector& new
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return std::make_shared<LoadReshape>(new_args.at(0), get_count(), get_offset(), m_order); return std::make_shared<LoadReshape>(new_args.at(0), get_count(), get_offset(), m_order);
} }
// Captures the permutation order from the LoadReshape node so that infer()
// can reorder shapes without touching the node again.
LoadReshape::ShapeInfer::ShapeInfer(const std::shared_ptr<ov::Node>& n) {
    const auto load_reshape = ov::as_type_ptr<LoadReshape>(n);
    OPENVINO_ASSERT(load_reshape, "Got invalid node in LoadReshape::ShapeInfer");
    m_order = load_reshape->m_order;
}
// Shape inference for LoadReshape: the single output shape is the single
// input shape permuted by the stored order.
IShapeInferSnippets::Result LoadReshape::ShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
    OPENVINO_ASSERT(input_shapes.size() == 1, "Got unexpected number of input shapes");
    const auto reordered_shape = utils::get_planar_vdims(input_shapes[0], m_order);
    return {{reordered_shape}, ShapeInferStatus::success};
}
}// namespace op }// namespace op
}// namespace snippets }// namespace snippets

View File

@ -44,7 +44,7 @@
#include "transformations/utils/utils.hpp" #include "transformations/utils/utils.hpp"
#include "snippets/pass_manager.hpp" #include "snippets/pass_manager.hpp"
#include "ngraph/pass/constant_folding.hpp" #include "openvino/pass/constant_folding.hpp"
#include "ov_ops/type_relaxed.hpp" #include "ov_ops/type_relaxed.hpp"
#include <openvino/pass/serialize.hpp> #include <openvino/pass/serialize.hpp>
@ -104,7 +104,7 @@ auto Subgraph::get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t {
for (const auto& op : ops) { for (const auto& op : ops) {
if (const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(op)) { if (const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(op)) {
// At the moment Transposes are supported only on Results and Parameters but // At the moment Transposes are supported only on Results and Parameters, but
// then we should have the different Buffers for Transpose as well (Transpose isn't inplace) // then we should have the different Buffers for Transpose as well (Transpose isn't inplace)
const auto consumers = transpose->get_output_target_inputs(0); const auto consumers = transpose->get_output_target_inputs(0);
// If after Transpose there is Result it means that there won't be Buffer after Transpose. // If after Transpose there is Result it means that there won't be Buffer after Transpose.
@ -119,7 +119,7 @@ auto Subgraph::get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t {
} }
} else if (ov::is_type<ov::op::v1::Softmax>(op) || ov::is_type<ov::op::v8::Softmax>(op)) { } else if (ov::is_type<ov::op::v1::Softmax>(op) || ov::is_type<ov::op::v8::Softmax>(op)) {
// Softmax always uses 2 FP32 Buffers after decomposition. // Softmax always uses 2 FP32 Buffers after decomposition.
// They are inplace and the same so we can push precision size only once // They are inplace and the same, so we can push precision size only once
push_prc_size(ov::element::f32.size()); push_prc_size(ov::element::f32.size());
} else if (const auto matmul = ov::as_type_ptr<ov::op::v0::MatMul>(op)) { } else if (const auto matmul = ov::as_type_ptr<ov::op::v0::MatMul>(op)) {
// Since all buffers around Matmul must be unique, we explicitely add values to the vector without any checks // Since all buffers around Matmul must be unique, we explicitely add values to the vector without any checks
@ -195,7 +195,7 @@ void Subgraph::validate_and_infer_types() {
INTERNAL_OP_SCOPE(Subgraph); INTERNAL_OP_SCOPE(Subgraph);
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::validate_and_infer_types") OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::validate_and_infer_types")
ov::ParameterVector old_parameters; ov::ParameterVector old_parameters;
for (auto op : body_ptr()->get_parameters()) { for (const auto& op : body_ptr()->get_parameters()) {
old_parameters.push_back(op); old_parameters.push_back(op);
} }
@ -257,7 +257,7 @@ auto Subgraph::wrap_node_as_subgraph(const std::shared_ptr<ov::Node>& node) -> s
} }
ov::ResultVector body_results; ov::ResultVector body_results;
for (auto output : node->outputs()) { for (const auto& output : node->outputs()) {
body_results.push_back(std::make_shared<ov::opset1::Result>(body_node->output(output.get_index()))); body_results.push_back(std::make_shared<ov::opset1::Result>(body_node->output(output.get_index())));
} }
@ -469,6 +469,75 @@ bool Subgraph::check_broadcast(const std::shared_ptr<const ov::Node>& node) noex
(elementwise->get_autob().m_type != ov::op::AutoBroadcastType::PDPD); (elementwise->get_autob().m_type != ov::op::AutoBroadcastType::PDPD);
} }
// Lazily creates (or upgrades) the internal shape-infer helper and delegates to it.
// While no LinearIR is attached, an ngraph-based helper that reshapes the body model
// is used; once m_linear_ir is available, the helper is switched to the LIR-based one.
IShapeInferSnippets::Result Subgraph::shape_infer(const std::vector<VectorDimsRef>& input_shapes) {
    if (!m_shape_infer && !m_linear_ir) {
        // No helper and no LinearIR yet: fall back to reshaping the ngraph body.
        OPENVINO_ASSERT(body_ptr(), "Can't create shape infer for Subgraph with an empty body");
        m_shape_infer = std::make_shared<NgraphShapeInfer>(body_ptr());
    } else if (!std::dynamic_pointer_cast<LIRShapeInfer>(m_shape_infer) && m_linear_ir) {
        // A LinearIR exists but the current helper (possibly null or ngraph-based)
        // is not LIR-based — replace it with the LIR-based implementation.
        m_shape_infer = std::make_shared<LIRShapeInfer>(m_linear_ir);
    }
    return m_shape_infer->infer(input_shapes);
}
// Shape-infer helper that works by reshaping the ov::Model body directly.
// Caches the body's Parameters and Results so infer() doesn't have to look them up.
// Fix: the original initializer list dereferenced `body` (get_parameters/get_results)
// before any null check could run; the members are now filled in the ctor body
// after an explicit assert, so a null body fails with a clear error instead of UB.
Subgraph::NgraphShapeInfer::NgraphShapeInfer(const std::shared_ptr<ov::Model>& body) :
        m_ngraph_body(body) {
    OPENVINO_ASSERT(m_ngraph_body, "Can't initialize shape infer with empty body");
    m_parameters = m_ngraph_body->get_parameters();
    m_results = m_ngraph_body->get_results();
}
// Runs shape inference by reshaping the stored ngraph body: sets the new shapes
// on the Parameters, revalidates the model, then collects the inferred shapes
// from the Results. The result is cached in m_last_result and returned.
// Improvements vs original: reserve() on outputDims (consistent with
// LIRShapeInfer::infer) and move of outputDims into m_last_result to avoid a copy.
IShapeInferSnippets::Result Subgraph::NgraphShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
    OPENVINO_ASSERT(m_parameters.size() == input_shapes.size(), "Got invalid number of input shapes to reshape subgraph body");
    for (size_t i = 0; i < m_parameters.size(); ++i)
        m_parameters[i]->set_partial_shape(utils::vdims_to_pshape(input_shapes[i].get()));
    m_ngraph_body->validate_nodes_and_infer_types();
    std::vector<VectorDims> outputDims;
    outputDims.reserve(m_results.size());
    for (const auto& res : m_results)
        outputDims.emplace_back(utils::pshape_to_vdims(res->get_input_partial_shape(0)));
    m_last_result = {std::move(outputDims), ShapeInferStatus::success};
    return m_last_result;
}
// Splits the IO expressions of the LinearIR body into parameter (input) and
// result (output) expression lists for use during shape inference.
Subgraph::LIRShapeInfer::LIRShapeInfer(const std::shared_ptr<lowered::LinearIR>& body) :
        m_lir_body(body) {
    for (const auto& io_expr : m_lir_body->get_IO_ops()) {
        const auto io_type = io_expr->get_type();
        if (io_type == IOExpression::io_type::INPUT) {
            m_param_exprs.push_back(io_expr);
        } else if (io_type == IOExpression::io_type::OUTPUT) {
            m_result_exprs.push_back(io_expr);
        } else {
            OPENVINO_THROW("Undefined io expression type");
        }
    }
}
// Propagates the given input shapes through the LinearIR body expression by
// expression and reports the shapes observed at the Result expressions.
IShapeInferSnippets::Result
Subgraph::LIRShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
    OPENVINO_ASSERT(m_param_exprs.size() == input_shapes.size(), "Got invalid number of input shapes in LIR ShapeInfer");
    // todo: check that order of param_exprs is always the same as that of input_shapes
    //  if not use io_expr index to sort in constructor
    size_t shape_idx = 0;
    for (const auto& param_expr : m_param_exprs)
        param_expr->get_output_port_descriptor(0)->set_shape(input_shapes[shape_idx++]);
    // Expressions are stored in linear (topological) order, so a single pass suffices.
    for (const auto& expr : *m_lir_body) {
        if (expr->needShapeInfer())
            expr->updateShapes();
    }
    std::vector<VectorDims> result_shapes;
    result_shapes.reserve(m_result_exprs.size());
    for (const auto& result_expr : m_result_exprs)
        result_shapes.push_back(result_expr->get_input_port_descriptor(0)->get_shape());
    m_last_result = {result_shapes, ShapeInferStatus::success};
    return m_last_result;
}
// Builds a LinearIR from the subgraph body, deriving the lowering config from
// the subgraph's properties (presence of domain-sensitive ops, tile rank).
std::shared_ptr<lowered::LinearIR>
Subgraph::convert_body_to_linear_ir(const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory) const {
    const bool has_domain_sensitive_ops = config.m_has_domain_sensitive_ops;
    lowered::Config lowering_config;
    lowering_config.m_loop_depth = tileRank;
    lowering_config.m_save_expressions = has_domain_sensitive_ops;
    lowering_config.m_need_fill_tail_register = has_domain_sensitive_ops;
    return std::make_shared<lowered::LinearIR>(body_ptr(), shape_infer_factory, lowering_config);
}
void Subgraph::align_element_types(const BlockedShapeVector& outputShapes, void Subgraph::align_element_types(const BlockedShapeVector& outputShapes,
const BlockedShapeVector& inputShapes) { const BlockedShapeVector& inputShapes) {
// We should insert Convert before Results to set original output element type if needed // We should insert Convert before Results to set original output element type if needed
@ -632,18 +701,21 @@ snippets::Schedule Subgraph::generate(const BlockedShapeVector& output_shapes,
const std::vector<pass::Manager::PositionedPass>& data_flow_passes, const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
const lowered::pass::PassPipeline& control_flow_passes_pre_common, const lowered::pass::PassPipeline& control_flow_passes_pre_common,
const lowered::pass::PassPipeline& control_flow_passes_post_common, const lowered::pass::PassPipeline& control_flow_passes_post_common,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory,
const void* compile_params) { const void* compile_params) {
canonicalize(output_shapes, input_shapes); canonicalize(output_shapes, input_shapes);
return generate(data_flow_passes, control_flow_passes_pre_common, control_flow_passes_post_common, compile_params); return generate(data_flow_passes, control_flow_passes_pre_common, control_flow_passes_post_common,
shape_infer_factory, compile_params);
} }
snippets::Schedule Subgraph::generate(const void* compile_params) { snippets::Schedule Subgraph::generate(const void* compile_params) {
return generate({}, {}, {}, compile_params); return generate({}, {}, {}, nullptr, compile_params);
} }
snippets::Schedule Subgraph::generate(const std::vector<pass::Manager::PositionedPass>& data_flow_passes, snippets::Schedule Subgraph::generate(const std::vector<pass::Manager::PositionedPass>& data_flow_passes,
const lowered::pass::PassPipeline& control_flow_passes_pre_common, const lowered::pass::PassPipeline& control_flow_passes_pre_common,
const lowered::pass::PassPipeline& control_flow_passes_post_common, const lowered::pass::PassPipeline& control_flow_passes_post_common,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory,
const void* compile_params) { const void* compile_params) {
INTERNAL_OP_SCOPE(Subgraph); INTERNAL_OP_SCOPE(Subgraph);
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::generate") OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::generate")
@ -651,16 +723,11 @@ snippets::Schedule Subgraph::generate(const std::vector<pass::Manager::Positione
data_flow_transformations(data_flow_passes); data_flow_transformations(data_flow_passes);
lowered::Config lowering_config; lowered::LinearIR linear_ir = *convert_body_to_linear_ir(shape_infer_factory);
lowering_config.m_save_expressions = config.m_has_domain_sensitive_ops;
lowering_config.m_need_fill_tail_register = config.m_has_domain_sensitive_ops;
lowering_config.m_loop_depth = tileRank;
lowered::LinearIR linear_ir = lowered::LinearIR(body_ptr(), lowering_config);
control_flow_transformations(linear_ir, control_flow_passes_pre_common, control_flow_passes_post_common); control_flow_transformations(linear_ir, control_flow_passes_pre_common, control_flow_passes_post_common);
// actual code emission // actual code emission
const auto& lowering_result = m_generator->generate(linear_ir, lowering_config, compile_params); const auto& lowering_result = m_generator->generate(linear_ir, linear_ir.get_config(), compile_params);
const auto ptr = lowering_result.binary_code; const auto ptr = lowering_result.binary_code;
return {master_shape, false /*canBeLinearized*/, ptr}; return {master_shape, false /*canBeLinearized*/, ptr};
@ -691,66 +758,6 @@ void Subgraph::print() const {
} }
} }
// Prints a one-line CSV-style summary of the subgraph (op/parameter/result/constant
// counts and tensor-byte "inventory") to stdout; with verbose=true also dumps the body.
void Subgraph::print_statistics(bool verbose) {
    INTERNAL_OP_SCOPE(Subgraph);
    // Sum of tensor sizes on all inputs and outputs of a node; for a nested Subgraph,
    // body Constants are added as well (they are not visible as node inputs).
    auto getNodeInventory = [](std::shared_ptr<ov::Node> n) -> size_t {
        size_t total = 0;
        for (auto input : n->inputs()) {
            total += input.get_tensor().size();
        }
        for (auto output : n->outputs()) {
            total += output.get_tensor().size();
        }
        if (auto subgraph = ov::as_type_ptr<op::Subgraph>(n)) {
            for (auto op : subgraph->body_ptr()->get_ordered_ops()) {
                if (ov::as_type_ptr<ov::opset1::Constant>(op)) {
                    total += op->output(0).get_tensor().size();
                }
            }
        }
        return total;
    };
    // Inventory of a whole model, excluding artificial Parameter/Result nodes and
    // Constants (Constants are already counted as inputs of their consumers).
    auto getModelInventory = [getNodeInventory](const ov::Model& f) -> size_t {
        size_t total = 0;
        for (auto op : f.get_ordered_ops()) {
            // Results and parameters are artificially introduced,
            // while Constants are already considered if they are inputs of other operation
            // this should lead to 1:1 inventory for single node operations
            if (!ov::as_type_ptr<ov::opset1::Parameter>(op)
                && !ov::as_type_ptr<ov::opset1::Result>(op)
                && !ov::as_type_ptr<ov::opset1::Constant>(op)) {
                total += getNodeInventory(op);
            }
        }
        return total;
    };
    // Number of Constant nodes in the model.
    auto countConstants = [](const ov::Model& f) -> size_t {
        size_t count = 0;
        for (auto op : f.get_ordered_ops()) {
            count += !!ov::as_type_ptr<ov::opset1::Constant>(op) ? 1 : 0;
        }
        return count;
    };
    // Semicolon-separated record: name; this; #ops; #params; #results; #constants;
    // body inventory; node inventory.
    std::cout << get_friendly_name()
              << ";" << this
              << ";" << body_ptr()->get_ops().size()
              << ";" << body_ptr()->get_parameters().size()
              << ";" << body_ptr()->get_results().size()
              << ";" << countConstants(body())
              << ";" << getModelInventory(body())
              << ";" << getNodeInventory(shared_from_this()) << std::endl;
    if (verbose) {
        this->print();
    }
}
void Subgraph::serialize() const { void Subgraph::serialize() const {
std::stringstream xmlFile, binFile; std::stringstream xmlFile, binFile;

View File

@ -0,0 +1,166 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "snippets/shape_inference/shape_infer_instances.hpp"
#include "snippets/snippets_isa.hpp"
#include "openvino/op/select.hpp"
namespace ov {
namespace snippets {
using Result = IShapeInferSnippets::Result;
namespace {
/*
* Merge SRC to DST with broadcasting rules defined by the Autobroadcast specifier
*/
/*
 * Merge SRC to DST with broadcasting rules defined by the Autobroadcast specifier.
 * Returns false if the shapes are incompatible; on success DST holds the merged shape.
 * Note: the per-dimension merge is intentionally asymmetric w.r.t. dynamic dims —
 * if d1 is 1 or DYNAMIC_DIMENSION the result is d2, otherwise d1 wins.
 */
bool broadcast_merge_into(VectorDims& dst, const VectorDims& src, const ov::op::AutoBroadcastSpec& autob) {
    // Merge two dims into dst; fails only when both are static, non-1 and different.
    auto broadcast_merge_dim = [](size_t& dst, const size_t& d1, const size_t& d2) {
        if (d1 == d2 || d1 == 1 || d1 == IShapeInferSnippets::DYNAMIC_DIMENSION) {
            dst = d2;
        } else if (d2 == 1 || d2 == IShapeInferSnippets::DYNAMIC_DIMENSION) {
            dst = d1;
        } else {
            return false;
        }
        return true;
    };
    // Ranks are both static.
    const auto dst_rank = static_cast<int64_t>(dst.size());
    const auto src_rank = static_cast<int64_t>(src.size());
    switch (autob.m_type) {
        case ov::op::AutoBroadcastType::NONE:
            // No broadcasting requested: DST is left untouched.
            return true;
        case ov::op::AutoBroadcastType::NUMPY: {
            // Numpy rules: right-align both shapes, pad the shorter one with 1s,
            // then merge dimension-wise.
            const auto new_rank = std::max(dst_rank, src_rank);
            VectorDims dims(new_rank);
            bool success = true;
            for (int64_t i = 0; i < new_rank; i++) {
                auto dsti = i < (new_rank - dst_rank) ? 1 : dst[i - (new_rank - dst_rank)];
                auto srci = i < (new_rank - src_rank) ? 1 : src[i - (new_rank - src_rank)];
                success &= broadcast_merge_dim(dims[i], dsti, srci);
            }
            dst = std::move(dims);
            return success;
        }
        case ov::op::AutoBroadcastType::PDPD: {
            // PaddlePaddle rules: SRC is aligned to DST starting at `axis`
            // (axis == -1 means "align to the right"); SRC dims may not exceed DST dims.
            int64_t axis = autob.m_axis;
            if (src_rank > dst_rank || axis < -1)
                return false;
            axis = (axis == -1) ? (dst_rank - src_rank) : axis;
            if (src_rank + axis > dst_rank)
                return false;
            bool success = true;
            for (int64_t i = 0; i < src_rank; ++i) {
                if (dst[axis + i] != IShapeInferSnippets::DYNAMIC_DIMENSION &&
                    src[i] != IShapeInferSnippets::DYNAMIC_DIMENSION) {
                    if (src[i] > dst[axis + i])
                        return false;
                }
                success &= broadcast_merge_dim(dst[axis + i], dst[axis + i], src[i]);
            }
            return success;
        }
        default:
            OPENVINO_THROW("Unsupported auto broadcast type: ", autob.m_type);
    }
    // Unreachable: every switch case returns or throws.
    return false;
}
/*
* Merge SRC to DST, no broadcasting is allowed
*/
/*
 * Merge SRC to DST, no broadcasting is allowed: ranks must match and every pair of
 * static dimensions must be equal. A dynamic dimension is resolved from the other
 * operand when possible. Returns false on any mismatch.
 */
bool merge_into(VectorDims& dst, const VectorDims& src) {
    if (dst.size() != src.size())
        return false;
    // Merge a single pair of dims: equal or dynamic on either side is OK.
    auto resolve_dim = [](size_t& out, size_t a, size_t b) -> bool {
        if (a == b || a == IShapeInferSnippets::DYNAMIC_DIMENSION) {
            out = b;
            return true;
        }
        if (b == IShapeInferSnippets::DYNAMIC_DIMENSION) {
            out = a;
            return true;
        }
        return false;
    };
    bool merged_ok = true;
    for (size_t idx = 0; idx < dst.size(); ++idx)
        merged_ok &= resolve_dim(dst[idx], dst[idx], src[idx]);
    return merged_ok;
}
} // namespace
// Folds all input shapes into a single output shape using numpy broadcasting rules.
// Asserts if no inputs were given or if any pair of shapes is not broadcastable.
Result NumpyBroadcastShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
    OPENVINO_ASSERT(!input_shapes.empty(), "No input shapes were provided for NumpyBroadcastShapeInfer");
    // Start from the first shape and merge every remaining one into it.
    auto merged = input_shapes.front().get();
    for (auto it = input_shapes.begin() + 1; it != input_shapes.end(); ++it) {
        const bool compatible = broadcast_merge_into(merged, *it, ov::op::AutoBroadcastType::NUMPY);
        OPENVINO_ASSERT(compatible, "Failed to broadcast-merge input shapes in NumpyBroadcastShapeInfer");
    }
    return {{std::move(merged)}, ShapeInferStatus::success};
}
// Caches the dimension that BroadcastOP (BroadcastMove or BroadcastLoad) broadcasts
// to, taken from the last dimension of the op's target output shape. A dynamic last
// dimension is stored as the snippets DYNAMIC_DIMENSION marker.
// @param n  the BroadcastMove/BroadcastLoad node to build the shape infer for
template<class BroadcastOP>
BroadcastShapeInfer<BroadcastOP>::BroadcastShapeInfer(const std::shared_ptr<Node>& n) {
    static_assert(std::is_base_of<snippets::op::BroadcastMove, BroadcastOP>() ||
                  std::is_base_of<snippets::op::BroadcastLoad, BroadcastOP>(),
                  "This ShapeInfer class could be used only for BroadcastMove and BroadcastLoad operations.");
    const auto& broadcast = as_type_ptr<BroadcastOP>(n);
    // Fix: the original message was assembled without separators and printed as
    // "...BroadcastShapeInfer.Expected <type>got <type>".
    OPENVINO_ASSERT(broadcast, "Invalid node passed to BroadcastShapeInfer. ",
                    "Expected ", typeid(BroadcastOP).name(), ", got ", n->get_type_name());
    const auto last_dim = *broadcast->get_output_shape().rbegin();
    m_broadcasted_dim = last_dim.is_dynamic() ? IShapeInferSnippets::DYNAMIC_DIMENSION : last_dim.get_length();
}
// The output shape equals the (single) input shape except for the innermost
// dimension, which is replaced by the dimension cached from the broadcast node.
template<class BroadcastOP>
Result BroadcastShapeInfer<BroadcastOP>::infer(const std::vector<VectorDimsRef>& input_shapes) {
    VectorDims result = input_shapes[0];
    result.back() = m_broadcasted_dim;
    return {{std::move(result)}, ShapeInferStatus::success};
}
// Note: we need to manually create template instances here, so they can be reused in Broadcast* headers.
template class BroadcastShapeInfer<op::BroadcastMove>;
template class BroadcastShapeInfer<op::BroadcastLoad>;
// Captures the auto-broadcast specification from a Select node; the actual
// shape computation is performed later in infer().
SelectShapeInfer::SelectShapeInfer(const std::shared_ptr<Node>& n) {
    const auto select_node = as_type_ptr<ov::op::v1::Select>(n);
    OPENVINO_ASSERT(select_node, "Invalid node passed to SelectShapeInfer.");
    m_broadcast_spec = select_node->get_auto_broadcast();
}
// Infers the output shape of Select from its three inputs (cond, then, else)
// according to the auto-broadcast spec captured in the constructor.
Result SelectShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
    OPENVINO_ASSERT(input_shapes.size() == 3, "Invalid number of shapes passed SelectShapeInfer");
    VectorDims result_shape;
    if (m_broadcast_spec == ov::op::AutoBroadcastType::PDPD) {
        result_shape = input_shapes[1]; // 'then' tensor
        // PDPD broadcasting is one-directional: 'else' and 'cond' are merged
        // into 'then', never the other way around.
        OPENVINO_ASSERT(broadcast_merge_into(result_shape, input_shapes[2], m_broadcast_spec),
                        "'Else' tensor shape is not broadcastable.");
        OPENVINO_ASSERT(broadcast_merge_into(result_shape, input_shapes[0], m_broadcast_spec),
                        "'Cond' tensor shape is not broadcastable.");
    } else {
        // NONE/NUMPY: start from the 'else' shape and merge 'then' (port 1)
        // and 'cond' (port 0) into it.
        result_shape = input_shapes[2];
        for (int input_port = 1; input_port >= 0; input_port--) {
            if (m_broadcast_spec.m_type == ov::op::AutoBroadcastType::NONE) {
                // NONE: all shapes must match exactly (modulo dynamic dims).
                OPENVINO_ASSERT(merge_into(result_shape, input_shapes[input_port]),
                                "Argument shapes are inconsistent.");
            } else if (m_broadcast_spec.m_type == ov::op::AutoBroadcastType::NUMPY) {
                OPENVINO_ASSERT(broadcast_merge_into(result_shape, input_shapes[input_port], m_broadcast_spec),
                                "Argument shapes are inconsistent.");
            } else {
                OPENVINO_THROW("Unsupported auto broadcast specification");
            }
        }
    }
    return {{result_shape}, ShapeInferStatus::success};
}
// Horizon ops (HorizonMax/HorizonSum) collapse the innermost dimension, so the
// output shape is the input shape with the last dimension set to 1.
Result HorizonOpShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
    OPENVINO_ASSERT(input_shapes.size() == 1, "Got invalid number of input shapes in HorizonShapeInfer");
    VectorDims output_shape = input_shapes.front();
    if (!output_shape.empty())
        output_shape.back() = 1;
    return {{std::move(output_shape)}, ShapeInferStatus::success};
}
} // namespace snippets
} // namespace ov

View File

@ -0,0 +1,90 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "snippets/shape_inference/shape_inference.hpp"
#include "snippets/shape_inference/shape_infer_instances.hpp"
#include <openvino/op/util/unary_elementwise_arithmetic.hpp>
#include <openvino/op/util/binary_elementwise_arithmetic.hpp>
#include <openvino/op/util/binary_elementwise_comparison.hpp>
#include <openvino/op/util/binary_elementwise_logical.hpp>
#include <openvino/op/parameter.hpp>
#include <openvino/op/result.hpp>
#include <snippets/snippets_isa.hpp>
namespace ov {
namespace snippets {
using ShapeInferPtr = IShapeInferSnippetsFactory::ShapeInferPtr;
// Builds a shape-infer instance for the given op type: first consult the common
// registry, and for types not listed there fall back to get_specific_op_shape_infer.
ShapeInferPtr IShapeInferSnippetsFactory::make(const ov::DiscreteTypeInfo& key, const std::shared_ptr<ov::Node>& op) {
    const auto entry = registry.find(key);
    if (entry == registry.end())
        return get_specific_op_shape_infer(key, op);
    return entry->second(op);
}
// Base-class hook for op types missing from the common registry: returns an empty
// pointer, which make_shape_inference treats as "no snippets-specific infer found".
// NOTE(review): parameters are intentionally unused here — derived factories
// presumably override this to register backend-specific ops; confirm in the header.
ShapeInferPtr IShapeInferSnippetsFactory::get_specific_op_shape_infer(const ov::DiscreteTypeInfo& key,
                                                                      const std::shared_ptr<ov::Node>& op) const {
    return {};
}
// Registry entry builders:
// - SHAPE_INFER_PREDEFINED: InferType is stateless, constructed without the node.
#define SHAPE_INFER_PREDEFINED(OP, InferType) \
    { OP::get_type_info_static(), [](const std::shared_ptr<ov::Node>& n) { return std::make_shared<InferType>();} }
// - SHAPE_INFER_OP_SPECIFIC: the op declares a nested OP::ShapeInfer built from the node.
#define SHAPE_INFER_OP_SPECIFIC(OP) \
    { OP::get_type_info_static(), [](const std::shared_ptr<ov::Node>& n) { return std::make_shared<OP::ShapeInfer>(n);} }
// - SHAPE_INFER_OP_SPECIFIC_EXTERNAL: InferType is declared outside the op and built from the node.
#define SHAPE_INFER_OP_SPECIFIC_EXTERNAL(OP, InferType) \
    { OP::get_type_info_static(), [](const std::shared_ptr<ov::Node>& n) { return std::make_shared<InferType>(n);} }
// Common (backend-independent) op-type -> shape-infer-builder registry.
const IShapeInferSnippetsFactory::TRegistry IShapeInferSnippetsFactory::registry {
    // Ops whose output shape equals the input shape.
    SHAPE_INFER_PREDEFINED(op::ConvertTruncation, PassThroughShapeInfer),
    SHAPE_INFER_PREDEFINED(op::ConvertSaturation, PassThroughShapeInfer),
    SHAPE_INFER_PREDEFINED(op::Load, PassThroughShapeInfer),
    SHAPE_INFER_PREDEFINED(op::Store, PassThroughShapeInfer),
    SHAPE_INFER_PREDEFINED(op::Buffer, PassThroughShapeInfer),
    SHAPE_INFER_PREDEFINED(op::Fill, PassThroughShapeInfer),
    SHAPE_INFER_PREDEFINED(ov::op::v0::Parameter, PassThroughShapeInfer),
    // Note: We should remove Softmax shape infers after the decomposition activity,
    // since there won't be any Softmax ops on LIR. Ticket: 112847
    SHAPE_INFER_PREDEFINED(ov::op::v1::Softmax, PassThroughShapeInfer),
    SHAPE_INFER_PREDEFINED(ov::op::v8::Softmax, PassThroughShapeInfer),
    SHAPE_INFER_PREDEFINED(ov::op::v1::LogicalNot, PassThroughShapeInfer),
    SHAPE_INFER_PREDEFINED(ov::op::v0::PRelu, PassThroughShapeInfer),
    // Horizon ops collapse the innermost dimension to 1.
    SHAPE_INFER_PREDEFINED(op::HorizonMax, HorizonOpShapeInfer),
    SHAPE_INFER_PREDEFINED(op::HorizonSum, HorizonOpShapeInfer),
    // Ops producing a single-element shape.
    SHAPE_INFER_PREDEFINED(op::LoopBegin, SingleElementShapeInfer),
    SHAPE_INFER_PREDEFINED(op::Scalar, SingleElementShapeInfer),
    SHAPE_INFER_PREDEFINED(op::VectorBuffer, SingleElementShapeInfer),
    // Ops with no output shape to infer.
    SHAPE_INFER_PREDEFINED(op::LoopEnd, EmptyShapeInfer),
    SHAPE_INFER_PREDEFINED(op::Nop, EmptyShapeInfer),
    SHAPE_INFER_OP_SPECIFIC_EXTERNAL(opset1::Select, SelectShapeInfer),
    // Note that Result has no output PortConnectors, so the shape must be empty
    SHAPE_INFER_PREDEFINED(ov::op::v0::Result, EmptyShapeInfer),
    // Ops that provide their own nested ShapeInfer implementation.
    SHAPE_INFER_OP_SPECIFIC(op::LoadReshape),
    SHAPE_INFER_OP_SPECIFIC(op::Brgemm),
    SHAPE_INFER_OP_SPECIFIC(op::BroadcastLoad),
    SHAPE_INFER_OP_SPECIFIC(op::BroadcastMove),
};
#undef SHAPE_INFER_OP_SPECIFIC_EXTERNAL
#undef SHAPE_INFER_OP_SPECIFIC
#undef SHAPE_INFER_PREDEFINED
// Entry point of the snippets shape-inference pipeline: resolve a shape-infer
// object for `op` using `factory`, falling back to generic elementwise handling.
// @param op       the node to build shape inference for
// @param factory  factory with the op-type registry; nullptr disables shape inference
// @return nullptr when no factory is given, otherwise a shape-infer instance
// @throws when the op type is not supported by any of the fallbacks
std::shared_ptr<IShapeInferSnippets> make_shape_inference(const std::shared_ptr<ov::Node>& op,
                                                          const std::shared_ptr<IShapeInferSnippetsFactory>& factory) {
    if (!factory) {
        return nullptr;
    } else if (auto shape_infer = factory->make(op->get_type_info(), op)) {
        return shape_infer;
    } else if (ov::is_type<ov::op::util::UnaryElementwiseArithmetic>(op)) {
        // Unary elementwise ops keep the input shape.
        return std::make_shared<PassThroughShapeInfer>();
    } else if (ov::is_type<ov::op::util::BinaryElementwiseArithmetic>(op) ||
               ov::is_type<ov::op::util::BinaryElementwiseComparison>(op) ||
               ov::is_type<ov::op::util::BinaryElementwiseLogical>(op)) {
        // Binary elementwise ops follow numpy broadcasting.
        return std::make_shared<NumpyBroadcastShapeInfer>();
    } else {
        // Fix: pass message parts as variadic args (as done elsewhere in this file)
        // instead of building a temporary std::string with operator+.
        OPENVINO_THROW("Operation type ", op->get_type_info().name, " is not supported in Snippets shape inference pipeline");
    }
}
} // namespace snippets
} // namespace ov

View File

@ -70,7 +70,7 @@ auto get_non_scalar_constant_count_for_fq(const std::shared_ptr<ov::op::v0::Fake
} }
} }
ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const std::vector<size_t>& layout) { ov::PartialShape get_planar_pshape(const ov::PartialShape& shape, const std::vector<size_t>& layout) {
if (layout.empty()) if (layout.empty())
return shape; return shape;
std::vector<Dimension> reordered_shape(layout.size()); std::vector<Dimension> reordered_shape(layout.size());
@ -87,14 +87,47 @@ ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const
return reordered_shape; return reordered_shape;
} }
ov::PartialShape get_port_planar_shape(const Input<Node>& in) { VectorDims pshape_to_vdims(const PartialShape& pshape) {
const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(in); VectorDims result;
return utils::get_reordered_planar_shape(ov::Shape{port->get_shape()}, port->get_layout()); result.reserve(pshape.size());
for (const auto& d : pshape)
result.push_back(d.is_dynamic() ? IShapeInferSnippets::DYNAMIC_DIMENSION : d.get_length());
return result;
} }
ov::PartialShape get_port_planar_shape(const Output<Node>& out) { ov::PartialShape vdims_to_pshape(const VectorDims& vdims) {
const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(out); ov::PartialShape result;
return utils::get_reordered_planar_shape(ov::Shape{port->get_shape()}, port->get_layout()); result.reserve(vdims.size());
for (const auto& v : vdims)
result.push_back(v != IShapeInferSnippets::DYNAMIC_DIMENSION ?
Dimension(static_cast<Dimension::value_type>(v)) :
Dimension());
return result;
}
ov::PartialShape get_planar_pshape(const Input<Node>& in) {
const auto& port = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(in);
return utils::get_planar_pshape(ov::Shape{port->get_shape()}, port->get_layout());
}
ov::PartialShape get_planar_pshape(const Output<Node>& out) {
const auto& port = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(out);
return utils::get_planar_pshape(ov::Shape{port->get_shape()}, port->get_layout());
}
VectorDims get_planar_vdims(const VectorDims& shape, const std::vector<size_t>& layout) {
VectorDims reordered_shape(shape.size());
for (size_t i = 0; i < layout.size(); i++)
reordered_shape[i] = shape[layout[i]];
return reordered_shape;
}
VectorDims get_planar_vdims(const snippets::lowered::PortDescriptorPtr& port_desc) {
return get_planar_vdims(port_desc->get_shape(), port_desc->get_layout());
}
VectorDims get_planar_vdims(const snippets::lowered::ExpressionPort& expr_port) {
return get_planar_vdims(expr_port.get_descriptor_ptr());
} }
} // namespace utils } // namespace utils

View File

@ -13,6 +13,7 @@
#include "snippets/lowered/pass/validate_loops.hpp" #include "snippets/lowered/pass/validate_loops.hpp"
#include "snippets/lowered/pass/insert_loops.hpp" #include "snippets/lowered/pass/insert_loops.hpp"
#include "snippets/lowered/pass/insert_tail_loop.hpp" #include "snippets/lowered/pass/insert_tail_loop.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
#include "snippets/op/loop.hpp" #include "snippets/op/loop.hpp"
@ -26,7 +27,8 @@ constexpr static size_t vector_size = 16;
static void init_linear_ir(const std::vector<ov::PartialShape>& in_shapes, LinearIR& linear_ir, size_t block_size) { static void init_linear_ir(const std::vector<ov::PartialShape>& in_shapes, LinearIR& linear_ir, size_t block_size) {
auto body = ov::test::snippets::AddFunction(in_shapes).getOriginal(); auto body = ov::test::snippets::AddFunction(in_shapes).getOriginal();
linear_ir = LinearIR(body); auto shape_infer_factory = std::make_shared<ov::snippets::IShapeInferSnippetsFactory>();
linear_ir = LinearIR(body, shape_infer_factory);
auto expr_it = std::find_if(linear_ir.cbegin(), linear_ir.cend(), auto expr_it = std::find_if(linear_ir.cbegin(), linear_ir.cend(),
[](const ExpressionPtr& expr) { return ov::is_type<ov::op::v1::Add>(expr->get_node()); }); [](const ExpressionPtr& expr) { return ov::is_type<ov::op::v1::Add>(expr->get_node()); });
ASSERT_TRUE(expr_it != linear_ir.cend()); ASSERT_TRUE(expr_it != linear_ir.cend());

View File

@ -131,8 +131,13 @@ KernelEmitter::KernelEmitter(jit_generator* h, cpu_isa_t isa, const ExpressionPt
IE_THROW() << "Kernel detected unsupported io_type"; IE_THROW() << "Kernel detected unsupported io_type";
} }
} }
io_shapes.push_back(desc->get_shape()); const auto& shape = desc->get_shape();
io_data_layouts.push_back(desc->get_layout()); const auto& layout = desc->get_layout();
OPENVINO_ASSERT(shape.size() == layout.size(), "Shape and layout must have the same length");
const auto max_dim = *std::max_element(layout.begin(), layout.end());
OPENVINO_ASSERT(max_dim < shape.size(), "Max layout index can't be larger than the shape size");
io_shapes.push_back(shape);
io_data_layouts.push_back(layout);
io_data_sizes.push_back(etype.size()); io_data_sizes.push_back(etype.size());
} }

View File

@ -9,13 +9,10 @@
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <tuple>
#include <dnnl_debug.h>
#include <onednn/dnnl.h> #include <onednn/dnnl.h>
#include <dnnl_extension_utils.h> #include <dnnl_extension_utils.h>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pass/visualize_tree.hpp> #include <ngraph/pass/visualize_tree.hpp>
#include <ngraph/rt_info.hpp> #include <ngraph/rt_info.hpp>
#include <ie_ngraph_utils.hpp> #include <ie_ngraph_utils.hpp>
@ -31,6 +28,7 @@
#include "transformations/snippets/x64/pass/remove_converts.hpp" #include "transformations/snippets/x64/pass/remove_converts.hpp"
#include "transformations/snippets/x64/pass/enforce_precision.hpp" #include "transformations/snippets/x64/pass/enforce_precision.hpp"
#include "transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.hpp" #include "transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.hpp"
#include "transformations/snippets/x64/shape_inference.hpp"
#include "transformations/cpu_opset/common/pass/convert_to_swish_cpu.hpp" #include "transformations/cpu_opset/common/pass/convert_to_swish_cpu.hpp"
#include "transformations/defs.hpp" #include "transformations/defs.hpp"
#include "shape_inference/custom/subgraph.hpp" #include "shape_inference/custom/subgraph.hpp"
@ -142,7 +140,7 @@ snippets::op::Subgraph::BlockedShapeVector getBlockedShapes(const std::vector<st
} }
} // namespace } // namespace
Snippet::Snippet(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context) Snippet::Snippet(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr& context)
: Node(op, context, SnippetShapeInferFactory(op)) { : Node(op, context, SnippetShapeInferFactory(op)) {
host_isa = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) ? host_isa = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) ?
dnnl::impl::cpu::x64::avx512_core : dnnl::impl::cpu::x64::avx2; dnnl::impl::cpu::x64::avx512_core : dnnl::impl::cpu::x64::avx2;
@ -724,10 +722,12 @@ void Snippet::SnippetJitExecutor::generate(const jit_snippets_compile_args* jcp)
ov::snippets::lowered::pass::PassPipeline control_flow_pipeline; ov::snippets::lowered::pass::PassPipeline control_flow_pipeline;
CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::FuseLoadStoreConvert); CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::FuseLoadStoreConvert);
// Todo: We don't need shape infer factory now, since shape infer will be done through validate_and_infer_types
// pass std::make_shared<snippets::CPUShapeInferSnippetsFactory>() instead of nullptr, when shape infer is performed on LIR
schedule = snippet_for_generation->generate(backend_passes, schedule = snippet_for_generation->generate(backend_passes,
control_flow_markup_pipeline, control_flow_markup_pipeline,
control_flow_pipeline, control_flow_pipeline,
nullptr,
reinterpret_cast<const void*>(jcp)); reinterpret_cast<const void*>(jcp));
} }

View File

@ -24,7 +24,7 @@ namespace node {
/// precision: fp32 /// precision: fp32
class Snippet : public Node { class Snippet : public Node {
public: public:
Snippet(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context); Snippet(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr& context);
~Snippet() override = default; ~Snippet() override = default;
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
@ -33,10 +33,6 @@ public:
void initOptimalPrimitiveDescriptor() override; void initOptimalPrimitiveDescriptor() override;
InferenceEngine::Precision getRuntimePrecision() const override; InferenceEngine::Precision getRuntimePrecision() const override;
// to avoid collisions in throughput mode with copy of TypeRelaxed nodes
// we should have common shared mutex between streams
void setSharedMutex(const std::shared_ptr<std::mutex>& mutex);
// Here we convert to canonical for & jit everything // Here we convert to canonical for & jit everything
void prepareParams() override; void prepareParams() override;
bool needPrepareParams() const override; bool needPrepareParams() const override;

View File

@ -10,69 +10,19 @@ namespace ov {
namespace intel_cpu { namespace intel_cpu {
namespace node { namespace node {
using Result = IShapeInfer::Result; using Result = IShapeInfer::Result;
class SnippetShapeInfer : public ShapeInferEmptyPads { class SnippetShapeInfer : public ShapeInferEmptyPads {
public: public:
SnippetShapeInfer(std::shared_ptr<ov::Model> body) : m_body(body) {} explicit SnippetShapeInfer(const std::shared_ptr<snippets::op::Subgraph>& s) : m_subgraph(s) {
m_status_map[snippets::ShapeInferStatus::success] = ov::intel_cpu::ShapeInferStatus::success;
m_status_map[snippets::ShapeInferStatus::skip] = ov::intel_cpu::ShapeInferStatus::skip;
}
Result infer( Result infer(
const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes, const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
const std::unordered_map<size_t, MemoryPtr>& data_dependency) override { const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
auto broadcast_merge = [](VectorDims& dst, const VectorDims& src) { const auto& snippets_result = m_subgraph->shape_infer(input_shapes);
// Ranks are both static. OPENVINO_ASSERT(m_status_map.count(snippets_result.status) != 0, "Failed to map snippets shapeInfer status to the plugin one");
auto dst_rank = dst.size(); return {snippets_result.dims, m_status_map.at(snippets_result.status)};
auto src_rank = src.size();
const auto new_rank = std::max(dst_rank, src_rank);
dst.insert(dst.begin(), new_rank - dst_rank, 1);
for (size_t i = 0; i < new_rank; i++) {
auto srci = i < (new_rank - src_rank) ? 1 : src[i - (new_rank - src_rank)];
if (dst[i] != srci && srci != Shape::UNDEFINED_DIM) {
if (dst[i] == 1 || dst[i] == Shape::UNDEFINED_DIM) {
dst[i] = srci;
} else {
if (srci != 1) {
IE_THROW() << "Got imcompatible input shapes in snippets shape infer";
}
}
}
}
};
const size_t out_size = m_body->get_output_size();
if (out_size == 1) {
VectorDims masterShape;
for (size_t i = 0; i < input_shapes.size(); i++) {
if (i == 0)
masterShape = input_shapes[i];
else
broadcast_merge(masterShape, input_shapes[i]);
}
size_t output_rank = m_body->get_output_partial_shape(0).rank().get_length();
if (output_rank > masterShape.size()) {
masterShape.insert(masterShape.begin(), output_rank - masterShape.size(), 1);
}
return {{masterShape}, ShapeInferStatus::success};
} else {
std::vector<VectorDims> outputDims;
std::vector<ov::Shape> new_shapes;
for (const auto& s : input_shapes)
new_shapes.emplace_back(s);
auto& params = m_body->get_parameters();
if (params.size() != input_shapes.size()) {
IE_THROW() << "Got invalid number of input shapes to reshape subgraph body";
}
for (size_t i = 0; i < params.size(); ++i) {
params[i]->set_partial_shape(new_shapes[i]);
}
m_body->validate_nodes_and_infer_types();
for (const auto& res : m_body->get_results()) {
auto& pshape = res->get_input_partial_shape(0);
if (!pshape.is_static()) {
IE_THROW() << "Subgraph inferred dynamic output shape during reshape with static inputs";
}
outputDims.emplace_back(pshape.get_shape());
}
return {outputDims, ShapeInferStatus::success};
}
} }
port_mask_t get_port_mask() const override { port_mask_t get_port_mask() const override {
@ -80,21 +30,22 @@ public:
} }
private: private:
std::shared_ptr<ov::Model> m_body; std::shared_ptr<snippets::op::Subgraph> m_subgraph;
std::map<snippets::ShapeInferStatus, ov::intel_cpu::ShapeInferStatus> m_status_map;
}; };
class SnippetShapeInferFactory : public ShapeInferFactory { class SnippetShapeInferFactory : public ShapeInferFactory {
public: public:
SnippetShapeInferFactory(const std::shared_ptr<ov::Node>& op) { explicit SnippetShapeInferFactory(const std::shared_ptr<ov::Node>& op) {
auto subgraph = ov::as_type_ptr<snippets::op::Subgraph>(op); m_subgraph = ov::as_type_ptr<snippets::op::Subgraph>(op);
snippet_body = subgraph->body_ptr()->clone(); OPENVINO_ASSERT(m_subgraph, "Invalid node type detected in SnippetShapeInferFactory");
} }
ShapeInferPtr makeShapeInfer() const override { ShapeInferPtr makeShapeInfer() const override {
return std::make_shared<SnippetShapeInfer>(snippet_body); return std::make_shared<SnippetShapeInfer>(m_subgraph);
} }
private: private:
std::shared_ptr<ov::Model> snippet_body = nullptr; std::shared_ptr<snippets::op::Subgraph> m_subgraph = nullptr;
}; };
} // namespace node } // namespace node
} // namespace intel_cpu } // namespace intel_cpu

View File

@ -10,13 +10,15 @@
#include "utils/general_utils.h" #include "utils/general_utils.h"
using namespace ov; namespace ov {
namespace intel_cpu {
intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output<Node>& x, const element::Type src_type, const Type type, intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output<Node>& x, const element::Type src_type, const Type type,
const size_t offset_in, const size_t offset_out0, const size_t offset_out1, const size_t offset_in, const size_t offset_out0, const size_t offset_out1,
std::vector<size_t> layout_input, const size_t blk_size_k, const size_t blk_size_n) std::vector<size_t> layout_input, const size_t blk_size_k, const size_t blk_size_n)
: snippets::op::MemoryAccess({x}, 1, type == Type::WithCompensations ? 2 : 1), : snippets::op::MemoryAccess({x}, 1, type == Type::WithCompensations ? 2 : 1),
m_type(type), m_src_type(src_type) { m_type(type), m_src_type(src_type) {
m_brgemmVNNIFactor = 4 / m_src_type.size();
set_output_size(type == Type::WithCompensations ? 2 : 1); set_output_size(type == Type::WithCompensations ? 2 : 1);
set_input_port_descriptor({0, offset_in}, 0); set_input_port_descriptor({0, offset_in}, 0);
set_output_port_descriptor({0, offset_out0}, 0); set_output_port_descriptor({0, offset_out0}, 0);
@ -32,6 +34,7 @@ intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output<Node>& x, const element::Type s
std::vector<size_t> layout_input, const size_t blk_size_k, const size_t blk_size_n) std::vector<size_t> layout_input, const size_t blk_size_k, const size_t blk_size_n)
: snippets::op::MemoryAccess({x}, 1, type == Type::WithCompensations ? 2 : 1), : snippets::op::MemoryAccess({x}, 1, type == Type::WithCompensations ? 2 : 1),
m_type(type), m_src_type(src_type) { m_type(type), m_src_type(src_type) {
m_brgemmVNNIFactor = 4 / m_src_type.size();
set_output_size(type == Type::WithCompensations ? 2 : 1); set_output_size(type == Type::WithCompensations ? 2 : 1);
set_input_port_descriptor(desc_in0, 0); set_input_port_descriptor(desc_in0, 0);
set_output_port_descriptor(desc_out0, 0); set_output_port_descriptor(desc_out0, 0);
@ -42,38 +45,38 @@ intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output<Node>& x, const element::Type s
custom_constructor_validate_and_infer_types(std::move(layout_input)); custom_constructor_validate_and_infer_types(std::move(layout_input));
} }
bool intel_cpu::BrgemmCopyB::visit_attributes(AttributeVisitor& visitor) { bool BrgemmCopyB::visit_attributes(AttributeVisitor& visitor) {
INTERNAL_OP_SCOPE(BrgemmRepack_visit_attributes); INTERNAL_OP_SCOPE(BrgemmRepack_visit_attributes);
MemoryAccess::visit_attributes(visitor); MemoryAccess::visit_attributes(visitor);
visitor.on_attribute("src_type", m_src_type); visitor.on_attribute("src_type", m_src_type);
return true; return true;
} }
void intel_cpu::BrgemmCopyB::custom_constructor_validate_and_infer_types(std::vector<size_t> layout_input) { void BrgemmCopyB::custom_constructor_validate_and_infer_types(std::vector<size_t> layout_input) {
INTERNAL_OP_SCOPE(BrgemmRepack_ctor_validate_and_infer_types); INTERNAL_OP_SCOPE(BrgemmRepack_ctor_validate_and_infer_types);
// During ctor call, BrgemmCopyB doesn't know his port descriptors. // During ctor call, BrgemmCopyB doesn't know his port descriptors.
// So we use port descs from source inputs // So we use port descs from source inputs
const auto element_type = get_input_element_type(0); const auto element_type = get_input_element_type(0);
const auto pshape = snippets::utils::get_reordered_planar_shape(get_input_partial_shape(0), layout_input); const auto pshape = snippets::utils::get_planar_pshape(get_input_partial_shape(0), layout_input);
validate(pshape, element_type); validate(pshape, element_type);
} }
void intel_cpu::BrgemmCopyB::validate_and_infer_types() { void BrgemmCopyB::validate_and_infer_types() {
INTERNAL_OP_SCOPE(BrgemmRepack_validate_and_infer_types); INTERNAL_OP_SCOPE(BrgemmRepack_validate_and_infer_types);
const auto element_type = get_input_element_type(0); const auto& element_type = get_input_element_type(0);
const auto pshape = snippets::utils::get_port_planar_shape(input(0)); const auto& pshape = snippets::utils::get_planar_pshape(input(0));
validate(pshape, element_type); validate(pshape, element_type);
} }
void intel_cpu::BrgemmCopyB::validate(const ov::PartialShape& pshape, const ov::element::Type& element_type) { void BrgemmCopyB::validate(const ov::PartialShape& pshape, const ov::element::Type& element_type) {
NGRAPH_CHECK(one_of(element_type, element::bf16, element::i8), NGRAPH_CHECK(one_of(element_type, element::bf16, element::i8),
"BrgemmCopyB doesn't support element type" + element_type.get_type_name()); "BrgemmCopyB doesn't support element type" + element_type.get_type_name());
if (pshape.is_dynamic()) { if (pshape.is_dynamic()) {
set_output_type(0, element_type, ov::PartialShape{ov::Dimension::dynamic()}); set_output_type(0, element_type, ov::PartialShape {ov::Dimension::dynamic()});
if (is_with_compensations()) { if (is_with_compensations()) {
set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension::dynamic()}); set_output_type(1, ov::element::f32, ov::PartialShape {ov::Dimension::dynamic()});
} }
return; return;
} }
@ -81,9 +84,8 @@ void intel_cpu::BrgemmCopyB::validate(const ov::PartialShape& pshape, const ov::
const auto shape = pshape.get_shape(); const auto shape = pshape.get_shape();
const auto N = *shape.rbegin(); const auto N = *shape.rbegin();
const auto K = *(shape.rbegin() + 1); const auto K = *(shape.rbegin() + 1);
const auto brgemmVNNIFactor = 4 / m_src_type.size();
set_output_type(0, element_type, ov::PartialShape{ov::Dimension(rnd_up(K, brgemmVNNIFactor)), set_output_type(0, element_type, ov::PartialShape{ov::Dimension(rnd_up(K, m_brgemmVNNIFactor)),
ov::Dimension(rnd_up(N, m_N_blk))}); ov::Dimension(rnd_up(N, m_N_blk))});
if (is_with_compensations()) { if (is_with_compensations()) {
set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension(rnd_up(N, m_N_blk))}); set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension(rnd_up(N, m_N_blk))});
@ -91,7 +93,7 @@ void intel_cpu::BrgemmCopyB::validate(const ov::PartialShape& pshape, const ov::
} }
void intel_cpu::BrgemmCopyB::compute_block_size_values(const size_t blk_size_k, const size_t blk_size_n) { void intel_cpu::BrgemmCopyB::compute_block_size_values(const size_t blk_size_k, const size_t blk_size_n) {
const auto input_shape = snippets::utils::get_port_planar_shape(input(0)).get_shape(); const auto& input_shape = snippets::utils::get_planar_pshape(input(0)).get_shape();
m_K_blk = blk_size_k != 0 ? blk_size_k : *(input_shape.rbegin() + 1); m_K_blk = blk_size_k != 0 ? blk_size_k : *(input_shape.rbegin() + 1);
m_N_blk = blk_size_n != 0 ? blk_size_n : *input_shape.rbegin(); m_N_blk = blk_size_n != 0 ? blk_size_n : *input_shape.rbegin();
} }
@ -107,8 +109,41 @@ std::shared_ptr<Node> intel_cpu::BrgemmCopyB::clone_with_new_inputs(const Output
m_K_blk, m_N_blk); m_K_blk, m_N_blk);
} }
size_t intel_cpu::BrgemmCopyB::get_offset_compensations() const { size_t BrgemmCopyB::get_offset_compensations() const {
OPENVINO_ASSERT(is_with_compensations() && get_output_size() == 2, OPENVINO_ASSERT(is_with_compensations() && get_output_size() == 2,
"The offset for compensations must be in BrgemmCopyB only with compensations and 2 outputs!"); "The offset for compensations must be in BrgemmCopyB only with compensations and 2 outputs!");
return get_output_offset(1); return get_output_offset(1);
} }
BrgemmCopyB::ShapeInfer::ShapeInfer(const std::shared_ptr<ov::Node>& n) {
    const auto& copy_b = ov::as_type_ptr<BrgemmCopyB>(n);
    OPENVINO_ASSERT(copy_b, "Got invalid node in BrgemmCopyB::ShapeInfer");
    // Cache everything infer() needs up front, so shape inference never has to
    // touch the node again after construction.
    m_brgemmVNNIFactor = copy_b->m_brgemmVNNIFactor;
    m_N_blk = copy_b->get_n_block_size();
    m_num_outs = copy_b->get_output_size();
    m_layout = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(n->input(0))->get_layout();
}
snippets::IShapeInferSnippets::Result BrgemmCopyB::ShapeInfer::infer(const std::vector<snippets::VectorDimsRef>& input_shapes) {
    OPENVINO_ASSERT(input_shapes.size() == 1, "Got unexpected number of input shapes");
    const auto& old_shape = input_shapes[0].get();
    // Reorder the input shape into planar form according to the cached layout.
    // An empty layout means the shape is already planar; the original code would
    // dereference rbegin() of an empty vector in that case (UB), so guard it.
    snippets::VectorDims planar_shape;
    if (m_layout.empty()) {
        planar_shape = old_shape;
    } else {
        OPENVINO_ASSERT(m_layout.size() == old_shape.size(),
                        "BrgemmCopyB shape infer got a layout whose rank doesn't match the input shape");
        planar_shape.reserve(old_shape.size());
        for (const auto idx : m_layout)
            planar_shape.push_back(old_shape[idx]);
    }
    OPENVINO_ASSERT(planar_shape.size() >= 2, "BrgemmCopyB shape infer expects at least a 2D input shape");
    const auto N = *planar_shape.rbegin();
    const auto K = *(planar_shape.rbegin() + 1);
    OPENVINO_ASSERT(N != DYNAMIC_DIMENSION && K != DYNAMIC_DIMENSION,
                    "BrgemmCopyB shape infer got dynamic N or K dimension, which is not supported");
    // Output 0 is the repacked data: K padded to the VNNI factor, N padded to the N block.
    // Output 1 (present only when m_num_outs == 2) holds compensations sized by padded N.
    std::vector<snippets::VectorDims> new_shapes(m_num_outs);
    new_shapes[0].push_back(rnd_up(K, m_brgemmVNNIFactor));
    new_shapes[0].push_back(rnd_up(N, m_N_blk));
    if (m_num_outs == 2) {
        new_shapes[1].push_back(rnd_up(N, m_N_blk));
    }
    return {new_shapes, snippets::ShapeInferStatus::success};
}
} // namespace intel_cpu
} // namespace ov

View File

@ -5,6 +5,7 @@
#pragma once #pragma once
#include "snippets/op/memory_access.hpp" #include "snippets/op/memory_access.hpp"
#include <snippets/shape_inference/shape_inference.hpp>
namespace ov { namespace ov {
namespace intel_cpu { namespace intel_cpu {
@ -51,6 +52,16 @@ public:
bool has_evaluate() const override { return false; } bool has_evaluate() const override { return false; }
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override; std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
// Shape inference for BrgemmCopyB: planarizes the input shape by the cached
// layout, then pads K/N to the repacking granularity (see the .cpp infer()).
class ShapeInfer : public snippets::IShapeInferSnippets {
std::vector<size_t> m_layout{};  // input port layout used to planarize the incoming shape
size_t m_num_outs = 1;  // 2 when the node also produces a compensations output
size_t m_N_blk = 64;  // N is rounded up to this blocking size
size_t m_brgemmVNNIFactor = 1;  // K is rounded up to this VNNI repacking factor
public:
explicit ShapeInfer(const std::shared_ptr<ov::Node>& n);
Result infer(const std::vector<snippets::VectorDimsRef>& input_shapes) override;
};
private: private:
void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_input = {}); void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_input = {});
void validate(const ov::PartialShape& pshape, const ov::element::Type& element_type); void validate(const ov::PartialShape& pshape, const ov::element::Type& element_type);
@ -61,6 +72,7 @@ private:
size_t m_K_blk = 0; size_t m_K_blk = 0;
size_t m_N_blk = 0; size_t m_N_blk = 0;
size_t m_brgemmVNNIFactor = 1;
}; };
} // namespace intel_cpu } // namespace intel_cpu

View File

@ -7,6 +7,7 @@
#include "snippets/utils.hpp" #include "snippets/utils.hpp"
#include "snippets/lowered/port_descriptor.hpp" #include "snippets/lowered/port_descriptor.hpp"
#include "utils/general_utils.h" #include "utils/general_utils.h"
#include "snippets/utils.hpp"
namespace ov { namespace ov {
@ -78,19 +79,19 @@ void BrgemmCPU::custom_constructor_validate_and_infer_types(std::vector<size_t>
// So we use port descs from source inputs // So we use port descs from source inputs
const auto brgemm_copy = is_with_data_repacking() ? get_brgemm_copy() : nullptr; const auto brgemm_copy = is_with_data_repacking() ? get_brgemm_copy() : nullptr;
const auto planar_input_shapes = const auto planar_input_shapes =
std::vector<ov::PartialShape>{ snippets::utils::get_reordered_planar_shape(get_input_partial_shape(0), layout_a), std::vector<ov::PartialShape>{ snippets::utils::get_planar_pshape(get_input_partial_shape(0), layout_a),
brgemm_copy ? snippets::utils::get_port_planar_shape(brgemm_copy->input(0)) brgemm_copy ? snippets::utils::get_planar_pshape(brgemm_copy->input(0))
: snippets::utils::get_reordered_planar_shape(get_input_partial_shape(1), layout_b) }; : snippets::utils::get_planar_pshape(get_input_partial_shape(1), layout_b) };
auto output_shape = get_output_partial_shape(planar_input_shapes); auto output_shape = get_output_partial_shape(planar_input_shapes);
set_output_type(0, get_output_type(), snippets::utils::get_reordered_planar_shape(output_shape, layout_c)); set_output_type(0, get_output_type(), snippets::utils::get_planar_pshape(output_shape, layout_c));
// Additional check for 3rd input // Additional check for 3rd input
validate_with_scratchpad(planar_input_shapes[1].get_shape()); validate_with_scratchpad(planar_input_shapes[1].get_shape());
} }
void BrgemmCPU::compute_block_size_values(const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n) { void BrgemmCPU::compute_block_size_values(const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n) {
const auto input_shape_0 = snippets::utils::get_port_planar_shape(input(0)).get_shape(); const auto input_shape_0 = snippets::utils::get_planar_pshape(input(0)).get_shape();
const auto input_shape_1 = snippets::utils::get_port_planar_shape(input(1)).get_shape(); const auto input_shape_1 = snippets::utils::get_planar_pshape(input(1)).get_shape();
m_M_blk = blk_size_m != 0 ? blk_size_m : *(input_shape_0.rbegin() + 1); m_M_blk = blk_size_m != 0 ? blk_size_m : *(input_shape_0.rbegin() + 1);
m_K_blk = blk_size_k != 0 ? blk_size_k : *input_shape_0.rbegin(); m_K_blk = blk_size_k != 0 ? blk_size_k : *input_shape_0.rbegin();
m_N_blk = blk_size_n != 0 ? blk_size_n : *input_shape_1.rbegin(); m_N_blk = blk_size_n != 0 ? blk_size_n : *input_shape_1.rbegin();
@ -180,5 +181,13 @@ size_t BrgemmCPU::get_offset_scratch() const {
return get_input_offset(2); return get_input_offset(2);
} }
BrgemmCPU::ShapeInfer::ShapeInfer(const std::shared_ptr<ov::Node>& n) : Brgemm::ShapeInfer(n) {
    const auto& brgemm = ov::as_type_ptr<BrgemmCPU>(n);
    OPENVINO_ASSERT(brgemm, "Got invalid node in BrgemmCPU::ShapeInfer");
    // When the B input goes through data repacking, its effective layout is the one
    // on the BrgemmCopyB input port, so override the layout cached by the base class.
    if (brgemm->is_with_data_repacking()) {
        const auto copy_b = brgemm->get_brgemm_copy();
        if (copy_b)
            m_io_layouts[1] = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(copy_b->input(0))->get_layout();
    }
}
} // namespace intel_cpu } // namespace intel_cpu
} // namespace ov } // namespace ov

View File

@ -69,6 +69,12 @@ public:
constexpr static size_t SCRATCH_BYTE_SIZE = 32 * 1024; constexpr static size_t SCRATCH_BYTE_SIZE = 32 * 1024;
// Extends the base Brgemm shape inference: when data repacking is used, the
// B-input layout is taken from the BrgemmCopyB input port instead (see .cpp ctor).
class ShapeInfer : public Brgemm::ShapeInfer {
public:
explicit ShapeInfer(const std::shared_ptr<ov::Node>& n);
};
private: private:
void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c); void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c);
void compute_block_size_values(const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n); void compute_block_size_values(const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n);

View File

@ -62,8 +62,8 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() {
const auto& brgemm_in1_desc = PortDescriptorUtils::get_port_descriptor_ptr(brgemm->input(1)); const auto& brgemm_in1_desc = PortDescriptorUtils::get_port_descriptor_ptr(brgemm->input(1));
const auto& brgemm_out_desc = PortDescriptorUtils::get_port_descriptor_ptr(brgemm->output(0)); const auto& brgemm_out_desc = PortDescriptorUtils::get_port_descriptor_ptr(brgemm->output(0));
const auto dimsMatMulIn0 = snippets::utils::get_port_planar_shape(brgemm->input_value(0)).get_shape(); const auto dimsMatMulIn0 = snippets::utils::get_planar_pshape(brgemm->input_value(0)).get_shape();
const auto dimsMatMulIn1 = snippets::utils::get_port_planar_shape(brgemm->input_value(1)).get_shape(); const auto dimsMatMulIn1 = snippets::utils::get_planar_pshape(brgemm->input_value(1)).get_shape();
const auto K = *dimsMatMulIn0.rbegin(); const auto K = *dimsMatMulIn0.rbegin();
const auto N = *dimsMatMulIn1.rbegin(); const auto N = *dimsMatMulIn1.rbegin();

View File

@ -47,8 +47,8 @@ pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() {
return false; return false;
} }
const auto dimsMatMulIn0 = snippets::utils::get_port_planar_shape(brgemm->input_value(0)).get_shape(); const auto dimsMatMulIn0 = snippets::utils::get_planar_pshape(brgemm->input_value(0)).get_shape();
const auto dimsMatMulIn1 = snippets::utils::get_port_planar_shape(brgemm->input_value(1)).get_shape(); const auto dimsMatMulIn1 = snippets::utils::get_planar_pshape(brgemm->input_value(1)).get_shape();
const auto K = *dimsMatMulIn0.rbegin(); const auto K = *dimsMatMulIn0.rbegin();
const auto N = *dimsMatMulIn1.rbegin(); const auto N = *dimsMatMulIn1.rbegin();

View File

@ -0,0 +1,47 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "shape_inference.hpp"
#include <snippets/shape_inference/shape_infer_instances.hpp>
#include "op/brgemm_copy_b.hpp"
#include "op/brgemm_cpu.hpp"
#include "op/fused_mul_add.hpp"
#include "op/load_convert.hpp"
#include "op/store_convert.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
namespace ov {
namespace snippets {
using ShapeInferPtr = IShapeInferSnippetsFactory::ShapeInferPtr;
ShapeInferPtr CPUShapeInferSnippetsFactory::get_specific_op_shape_infer(const ov::DiscreteTypeInfo& key,
                                                                        const std::shared_ptr<ov::Node>& op) const {
    // Look up a CPU-specific shape-infer maker for this op type; an empty pointer
    // tells the caller to fall back to the common snippets registry.
    const auto found = specific_ops_registry.find(key);
    return found == specific_ops_registry.end() ? ShapeInferPtr{} : found->second(op);
}
// Registers a stateless, predefined shape-infer for OP: the node instance is not
// needed, so the lambda parameter is unnamed to avoid unused-parameter warnings.
#define SHAPE_INFER_PREDEFINED(OP, InferType) \
    { OP::get_type_info_static(), [](const std::shared_ptr<ov::Node>&) { return std::make_shared<InferType>();} }
// Registers an op-defined shape-infer that must be constructed from the node itself
// (it caches node state such as blocking parameters and port layouts).
#define SHAPE_INFER_OP_SPECIFIC(OP) \
    { OP::get_type_info_static(), [](const std::shared_ptr<ov::Node>& n) { return std::make_shared<OP::ShapeInfer>(n);} }
const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::specific_ops_registry {
    SHAPE_INFER_PREDEFINED(ov::intel_cpu::FusedMulAdd, NumpyBroadcastShapeInfer),
    SHAPE_INFER_PREDEFINED(ov::intel_cpu::SwishNode, PassThroughShapeInfer),
    SHAPE_INFER_PREDEFINED(ov::intel_cpu::LoadConvertSaturation, PassThroughShapeInfer),
    SHAPE_INFER_PREDEFINED(ov::intel_cpu::LoadConvertTruncation, PassThroughShapeInfer),
    SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertSaturation, PassThroughShapeInfer),
    SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertTruncation, PassThroughShapeInfer),
    // Ops whose output shapes depend on node state (blocking params, layouts):
    SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCopyB),
    SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCPU),
};
#undef SHAPE_INFER_OP_SPECIFIC
#undef SHAPE_INFER_PREDEFINED
} // namespace snippets
} // namespace ov

View File

@ -0,0 +1,28 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <snippets/shape_inference/shape_inference.hpp>
namespace ov {
namespace snippets {
/**
* \brief Shape infer factory that can create shape-infer instances for cpu-specific operations
*/
class CPUShapeInferSnippetsFactory : public IShapeInferSnippetsFactory{
/** \brief Factory makers registry which can be specialized for key and value. */
static const TRegistry specific_ops_registry;
protected:
/**
* @brief get shape infer instances for operations from backend-specific opset
* @param key type info of the operation to look up
* @param op the node a shape-infer instance is created for
* @return ShapeInferPtr for the requested operation, or an empty pointer when the
*         op is not in the CPU-specific registry (caller falls back to the common one)
*/
ShapeInferPtr get_specific_op_shape_infer(const ov::DiscreteTypeInfo& key, const std::shared_ptr<ov::Node>& op) const override;
};
} // namespace snippets
} // namespace ov