diff --git a/src/common/snippets/include/snippets/lowered/expression.hpp b/src/common/snippets/include/snippets/lowered/expression.hpp index 05584347c45..c5a1b2b8cb6 100644 --- a/src/common/snippets/include/snippets/lowered/expression.hpp +++ b/src/common/snippets/include/snippets/lowered/expression.hpp @@ -11,13 +11,14 @@ #include "snippets/lowered/port_connector.hpp" #include "snippets/lowered/expression_port.hpp" +#include "snippets/shape_inference/shape_inference.hpp" namespace ov { namespace snippets { namespace lowered { class LinearIR; - +using ExpressionPtr = std::shared_ptr; class Expression : public std::enable_shared_from_this { friend class LinearIR; friend class ExpressionPort; @@ -49,6 +50,8 @@ public: ExpressionPort get_input_port(size_t i); ExpressionPort get_output_port(size_t i); + void updateShapes(); + virtual bool needShapeInfer() const {return true; } std::vector get_loop_ids() const; void set_loop_ids(const std::vector& loops); @@ -56,7 +59,7 @@ public: protected: // Note: The constructor initialization is private since an expression can be created only by Linear IR. // The method must be used only by Linear IR builder of expressions! 
- explicit Expression(const std::shared_ptr& n); + Expression(const std::shared_ptr& n, const std::shared_ptr& factory); void replace_input(size_t port, PortConnectorPtr to); @@ -68,7 +71,8 @@ protected: std::vector m_output_port_descriptors{}; // The order Loops identifies: Outer ---> Inner // Note: The loops with the same dimension index (splitted dimension) should be successively nested - std::vector m_loop_ids; + std::vector m_loop_ids{}; + std::shared_ptr m_shapeInference{nullptr}; }; using ExpressionPtr = std::shared_ptr; @@ -80,10 +84,11 @@ public: int64_t get_index() const { return m_index; } io_type get_type() const { return m_type; } - + // Result needs shapeInfer to copy shape from Parent's output to this expr input + bool needShapeInfer() const override {return m_type == io_type::OUTPUT; } private: - explicit IOExpression(const std::shared_ptr& n, int64_t index); - explicit IOExpression(const std::shared_ptr& n, int64_t index); + explicit IOExpression(const std::shared_ptr& n, int64_t index, const std::shared_ptr& factory); + explicit IOExpression(const std::shared_ptr& n, int64_t index, const std::shared_ptr& factory); int64_t m_index = -1; io_type m_type = io_type::UNDEFINED; diff --git a/src/common/snippets/include/snippets/lowered/expression_factory.hpp b/src/common/snippets/include/snippets/lowered/expression_factory.hpp index 947bbd3c823..bb238356dfa 100644 --- a/src/common/snippets/include/snippets/lowered/expression_factory.hpp +++ b/src/common/snippets/include/snippets/lowered/expression_factory.hpp @@ -38,9 +38,9 @@ private: const std::shared_ptr& model); /* -- Input Builders - get input port connectors from method parameters and create new output port connectors themselves */ - static ExpressionPtr create(const std::shared_ptr& n, const std::vector& inputs); - static ExpressionPtr create(const std::shared_ptr& n, const std::vector& inputs); - static ExpressionPtr create(const std::shared_ptr& n, const std::vector& inputs); + static 
ExpressionPtr create(const std::shared_ptr& n, const std::vector& inputs, const LinearIR& linear_ir); + static ExpressionPtr create(const std::shared_ptr& n, const std::vector& inputs, const LinearIR& linear_ir); + static ExpressionPtr create(const std::shared_ptr& n, const std::vector& inputs, const LinearIR& linear_ir); // Creates inputs for expression using parent output port connectors static void create_expression_inputs(const LinearIR& linear_ir, const ExpressionPtr& expr); diff --git a/src/common/snippets/include/snippets/lowered/linear_ir.hpp b/src/common/snippets/include/snippets/lowered/linear_ir.hpp index 5e14e308c40..fce7a80d964 100644 --- a/src/common/snippets/include/snippets/lowered/linear_ir.hpp +++ b/src/common/snippets/include/snippets/lowered/linear_ir.hpp @@ -8,6 +8,7 @@ #include "expression.hpp" #include "snippets/target_machine.hpp" +#include "snippets/shape_inference/shape_inference.hpp" namespace ov { namespace snippets { @@ -36,7 +37,7 @@ public: using constExprReverseIt = container::const_reverse_iterator; LinearIR() = default; - explicit LinearIR(const std::shared_ptr& m, Config config = {}); + LinearIR(const std::shared_ptr& m, const std::shared_ptr& factory, Config config = {}); ExpressionPtr create_expression(const std::shared_ptr& n, const std::vector& inputs); @@ -96,12 +97,13 @@ public: iterator find_after(iterator it, const ExpressionPtr& target) const; void init_emitters(const std::shared_ptr& target); - void serialize(const std::string& xml, const std::string& bin); + void serialize(const std::string& xml, const std::string& bin) const; class LoopManager; using LoopManagerPtr = std::shared_ptr; const LoopManagerPtr& get_loop_manager() const { return m_loop_manager; } + const std::shared_ptr& get_shape_infer_factory() { return m_shape_infer_factory; } private: static ov::NodeVector get_ordered_ops(const std::shared_ptr& model); @@ -116,6 +118,7 @@ private: io_container m_io_expressions; Config m_config{}; LoopManagerPtr 
m_loop_manager = nullptr; + std::shared_ptr m_shape_infer_factory; }; template diff --git a/src/common/snippets/include/snippets/lowered/port_descriptor.hpp b/src/common/snippets/include/snippets/lowered/port_descriptor.hpp index 8a070cac31c..ce3e0c641f2 100644 --- a/src/common/snippets/include/snippets/lowered/port_descriptor.hpp +++ b/src/common/snippets/include/snippets/lowered/port_descriptor.hpp @@ -6,6 +6,7 @@ #include "openvino/core/node.hpp" #include "openvino/core/attribute_visitor.hpp" +#include "snippets/shape_types.hpp" namespace ov { @@ -23,28 +24,28 @@ public: }; explicit PortDescriptor(const ov::Input& node, - std::vector subtensor_shape = {}, + VectorDims subtensor_shape = {}, std::vector layout = {}); explicit PortDescriptor(const ov::Input& node, - std::vector subtensor_shape = {}, + VectorDims subtensor_shape = {}, std::vector layout = {}); explicit PortDescriptor(const ov::Output& node, - std::vector subtensor_shape = {}, + VectorDims subtensor_shape = {}, std::vector layout = {}); explicit PortDescriptor(const ov::Output& node, - std::vector subtensor_shape = {}, + VectorDims subtensor_shape = {}, std::vector layout = {}); - PortDescriptor(std::vector shape, std::vector subtensor_shape, std::vector layout = {}); + PortDescriptor(VectorDims shape, VectorDims subtensor_shape, std::vector layout = {}); PortDescriptor() = default; - std::vector get_shape() const {return m_tensor_shape;} - std::vector get_subtensor() const {return m_subtensor_shape;} + VectorDims get_shape() const {return m_tensor_shape;} + VectorDims get_subtensor() const {return m_subtensor_shape;} std::vector get_layout() const {return m_layout;} size_t get_reg() const { return m_reg; } - void set_shape(const std::vector& tensor) { m_tensor_shape = tensor; } + void set_shape(const VectorDims& tensor) { m_tensor_shape = tensor; } void set_layout(const std::vector& layout) { m_layout = layout; } - void set_subtensor(const std::vector& subtensor) { m_subtensor_shape = subtensor; } + 
void set_subtensor(const VectorDims& subtensor) { m_subtensor_shape = subtensor; } void set_reg(size_t reg) { m_reg = reg; } std::string serialize() const; @@ -57,11 +58,11 @@ public: private: void validate_arguments(); /// \brief Original tensor shape - std::vector m_tensor_shape{}; + VectorDims m_tensor_shape{}; /// \brief Order of dimensions: NCHW == {0, 1, 2, 3}, NHWC == {0, 2, 3, 1}, NCHW16c == {0, 1, 2, 3, 1} std::vector m_layout{}; /// \brief Minimal tensor size that could be processed in one call - std::vector m_subtensor_shape{}; + VectorDims m_subtensor_shape{}; /// \brief The corresponding abstract/physical register size_t m_reg = 0; }; diff --git a/src/common/snippets/include/snippets/op/brgemm.hpp b/src/common/snippets/include/snippets/op/brgemm.hpp index c6c73af44e1..50cca60bbbc 100644 --- a/src/common/snippets/include/snippets/op/brgemm.hpp +++ b/src/common/snippets/include/snippets/op/brgemm.hpp @@ -6,6 +6,7 @@ #include "openvino/op/op.hpp" #include "memory_access.hpp" +#include "snippets/shape_inference/shape_inference.hpp" namespace ov { namespace snippets { @@ -38,6 +39,14 @@ public: bool has_evaluate() const override { return false; } + class ShapeInfer : public IShapeInferSnippets { + protected: + std::vector> m_io_layouts; + public: + explicit ShapeInfer(const std::shared_ptr& n); + Result infer(const std::vector& input_shapes) override; + }; + protected: ov::element::Type get_output_type() const; std::vector get_planar_input_shapes(const std::vector>& inputs) const; diff --git a/src/common/snippets/include/snippets/op/broadcastload.hpp b/src/common/snippets/include/snippets/op/broadcastload.hpp index 337c698dbd6..540be423bb8 100644 --- a/src/common/snippets/include/snippets/op/broadcastload.hpp +++ b/src/common/snippets/include/snippets/op/broadcastload.hpp @@ -4,8 +4,8 @@ #pragma once +#include "snippets/shape_inference/shape_infer_instances.hpp" #include - #include "openvino/op/op.hpp" namespace ov { @@ -29,7 +29,15 @@ public: bool 
visit_attributes(AttributeVisitor& visitor) override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; void validate_and_infer_types() override; + ov::PartialShape get_output_shape() {return output_shape;} + // Note:BroadcastMove and BroadcastLoad are implemented as separate classes, + // but have identical shapeInfer semantics. In order to avoid code duplication, + // we created dummy ShapeInfer classes that are essentially instantiations + // of a common ShapeInfer class template; + struct ShapeInfer : public BroadcastShapeInfer { + explicit ShapeInfer(const std::shared_ptr& n) : BroadcastShapeInfer(n) {} + }; private: ov::PartialShape output_shape; }; diff --git a/src/common/snippets/include/snippets/op/broadcastmove.hpp b/src/common/snippets/include/snippets/op/broadcastmove.hpp index ac369d2419b..d915fbc2863 100644 --- a/src/common/snippets/include/snippets/op/broadcastmove.hpp +++ b/src/common/snippets/include/snippets/op/broadcastmove.hpp @@ -5,6 +5,7 @@ #pragma once #include "openvino/op/op.hpp" +#include "snippets/shape_inference/shape_infer_instances.hpp" namespace ov { namespace snippets { @@ -27,7 +28,14 @@ public: std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; void validate_and_infer_types() override; - + ov::PartialShape get_output_shape() {return output_shape;} + // Note:BroadcastMove and BroadcastLoad are implemented as separate classes, + // but have identical shapeInfer semantics. 
In order to avoid code duplication, + // we created dummy ShapeInfer classes that are essentially instantiations + // of a common ShapeInfer class template; + struct ShapeInfer : public BroadcastShapeInfer { + explicit ShapeInfer(const std::shared_ptr& n) : BroadcastShapeInfer(n) {} + }; protected: ov::PartialShape output_shape; diff --git a/src/common/snippets/include/snippets/op/load.hpp b/src/common/snippets/include/snippets/op/load.hpp index a10d7c5ca16..abd93d82648 100644 --- a/src/common/snippets/include/snippets/op/load.hpp +++ b/src/common/snippets/include/snippets/op/load.hpp @@ -6,6 +6,7 @@ #include "openvino/op/op.hpp" #include "snippets/op/memory_access.hpp" +#include "snippets/shape_inference/shape_inference.hpp" namespace ov { namespace snippets { @@ -58,7 +59,15 @@ public: std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; void validate_and_infer_types() override; -private: + class ShapeInfer : public IShapeInferSnippets { + std::vector m_order; + public: + explicit ShapeInfer(const std::shared_ptr& n); + Result infer(const std::vector& input_shapes) override; + }; + + +protected: std::vector m_order; }; } // namespace op diff --git a/src/common/snippets/include/snippets/op/subgraph.hpp b/src/common/snippets/include/snippets/op/subgraph.hpp index f192d58a498..a357c52266b 100644 --- a/src/common/snippets/include/snippets/op/subgraph.hpp +++ b/src/common/snippets/include/snippets/op/subgraph.hpp @@ -11,6 +11,7 @@ #include "openvino/op/op.hpp" #include "openvino/core/rt_info.hpp" #include "snippets/pass_manager.hpp" +#include "snippets/shape_inference/shape_inference.hpp" #include "snippets/generator.hpp" @@ -102,17 +103,21 @@ public: const std::vector& data_flow_passes, const lowered::pass::PassPipeline& control_flow_passes_pre_common, const lowered::pass::PassPipeline& control_flow_passes_post_common, + const std::shared_ptr& shape_infer_factory = nullptr, const void* compile_params = nullptr); snippets::Schedule 
generate(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes, const void* compile_params = nullptr); snippets::Schedule generate(const std::vector& data_flow_passes, const lowered::pass::PassPipeline& control_flow_passes_pre_common, const lowered::pass::PassPipeline& control_flow_passes_post_common, + const std::shared_ptr& shape_infer_factory = nullptr, const void* compile_params = nullptr); snippets::Schedule generate(const void* compile_params = nullptr); + ov::PartialShape canonicalize(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes); ov::PartialShape canonicalized_body_shape_infer(const BlockedShapeVector& input_shapes); std::vector reshape_body(const std::vector& input_shapes); std::vector reshape_body(const std::vector& input_shapes); + IShapeInferSnippets::Result shape_infer(const std::vector& input_shapes); // plugin sets generator for a snippet to some specific generator. // it's going to be replaced with Jitters table later @@ -121,7 +126,6 @@ public: void set_virtual_port_count(const size_t count); void print() const; - void print_statistics(bool verbose); void serialize() const; void set_master_shape(ov::PartialShape new_shape) {master_shape = std::move(new_shape);} @@ -137,6 +141,8 @@ public: // Return estimated unique buffer count (upper bound). 
It's needed for tokenization static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t; static auto is_domain_sensitive_op(const std::shared_ptr& op) -> bool; + std::shared_ptr + convert_body_to_linear_ir(const std::shared_ptr& shape_infer_factory = std::make_shared()) const; private: void align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes); @@ -158,6 +164,7 @@ private: size_t tileRank = 0; // set by plugin to specify the number of dimensions processed in a single kernel call size_t maxInputRank = 0; std::vector appendOnesForCanonical; + std::shared_ptr m_linear_ir = nullptr; /** * @interface SubgraphConfig @@ -172,27 +179,49 @@ private: // (e.g. Transpose, Softmax, MatMul in general doesn't support dimensions collapsing) bool m_has_domain_sensitive_ops = false; } config; + + class ShapeInferSnippetsNode : public IShapeInferSnippets { + public: + const Result& get_last_result() {return m_last_result; } + protected: + Result m_last_result{{}, ShapeInferStatus::success}; + }; + + std::shared_ptr m_shape_infer = nullptr; + + class NgraphShapeInfer : public ShapeInferSnippetsNode { + std::shared_ptr m_ngraph_body; + ParameterVector m_parameters; + ResultVector m_results; + public: + explicit NgraphShapeInfer(const std::shared_ptr& body); + Result infer(const std::vector& input_shapes) override; + }; + class LIRShapeInfer : public ShapeInferSnippetsNode { + using IOExpression = lowered::IOExpression; + std::shared_ptr m_lir_body; + std::vector> m_param_exprs; + std::vector> m_result_exprs; + public: + explicit LIRShapeInfer(const std::shared_ptr& body); + Result infer(const std::vector& input_shapes) override; + }; }; -static inline std::ostream& operator<<(std::ostream& os, const op::Subgraph::BlockedShape& blocked_shape) { - os << std::get<0>(blocked_shape) << " " << std::get<1>(blocked_shape) << " " << std::get<2>(blocked_shape); - return os; -} - -static inline auto create_body(std::string name, const 
ov::ResultVector& results, const ov::ParameterVector& parameters) -> +static inline auto create_body(const std::string& name, const ov::ResultVector& results, const ov::ParameterVector& parameters) -> std::shared_ptr { auto body = std::make_shared(results, parameters, name); return body; -}; +} static inline auto build_subgraph(const std::shared_ptr& node, const ov::OutputVector& inputs, - const std::shared_ptr& body, const std::string name = "") + const std::shared_ptr& body, const std::string& name = "") -> std::shared_ptr{ auto subgraph = std::make_shared(inputs, body); copy_runtime_info(node, subgraph); subgraph->set_friendly_name(name.empty() ? node->get_friendly_name() : name); return subgraph; -}; +} // Need to update tensor name manually, since intel_cpu::Graph::Replicate() looks at input.get_shape().get_name(); // If subgraph->get_output_size() == 1, then the name will be restored correctly from the node name diff --git a/src/common/snippets/include/snippets/op/vector_buffer.hpp b/src/common/snippets/include/snippets/op/vector_buffer.hpp index c5ff01e7b4d..4bd1fb1cf93 100644 --- a/src/common/snippets/include/snippets/op/vector_buffer.hpp +++ b/src/common/snippets/include/snippets/op/vector_buffer.hpp @@ -5,6 +5,7 @@ #pragma once #include "openvino/op/op.hpp" +#include "snippets/shape_inference/shape_inference.hpp" namespace ov { namespace snippets { diff --git a/src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp b/src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp new file mode 100644 index 00000000000..0ca0668111b --- /dev/null +++ b/src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp @@ -0,0 +1,60 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "shape_inference.hpp" + +namespace ov { +namespace snippets { +class NumpyBroadcastShapeInfer : public IShapeInferSnippets { +public: + Result infer(const 
std::vector& input_shapes) override; +}; + + +template +class BroadcastShapeInfer : public IShapeInferSnippets { + VectorDims::value_type m_broadcasted_dim; +public: + explicit BroadcastShapeInfer(const std::shared_ptr& n); + Result infer(const std::vector& input_shapes) override; +}; + +class PassThroughShapeInfer : public IShapeInferSnippets { +public: + inline Result infer(const std::vector& input_shapes) override { + OPENVINO_ASSERT(!input_shapes.empty(), "Empty Input shapes are not allowed for PassThroughShapeInfer"); + return {{input_shapes[0].get()}, ShapeInferStatus::success}; + } +}; + +class EmptyShapeInfer : public IShapeInferSnippets { +public: + inline Result infer(const std::vector& input_shapes) override { + return {{}, ShapeInferStatus::success}; + } +}; + +class SingleElementShapeInfer : public IShapeInferSnippets { +public: + inline Result infer(const std::vector& input_shapes) override { + return {{{1}}, ShapeInferStatus::success}; + } +}; + +class SelectShapeInfer : public IShapeInferSnippets { + ov::op::AutoBroadcastSpec m_broadcast_spec; +public: + explicit SelectShapeInfer(const std::shared_ptr& n); + Result infer(const std::vector& input_shapes) override; +}; + +class HorizonOpShapeInfer : public IShapeInferSnippets { +public: + Result infer(const std::vector& input_shapes) override; +}; + +} // namespace snippets +} // namespace ov diff --git a/src/common/snippets/include/snippets/shape_inference/shape_inference.hpp b/src/common/snippets/include/snippets/shape_inference/shape_inference.hpp new file mode 100644 index 00000000000..719ca69610e --- /dev/null +++ b/src/common/snippets/include/snippets/shape_inference/shape_inference.hpp @@ -0,0 +1,72 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include "snippets/shape_types.hpp" + +namespace ov { +namespace snippets { + +enum class ShapeInferStatus { + success, ///< shapes were successfully calculated + skip ///< shape 
inference was skipped. +}; +/** + * This is Snippets specific shape inference interface. + * + */ +class IShapeInferSnippets { +public: + enum {DYNAMIC_DIMENSION = 0xffffffffffffffff}; + struct Result { + std::vector dims; + ShapeInferStatus status; + }; + + virtual ~IShapeInferSnippets() = default; + + /** + * @brief This method actually performs all the necessary shape inference computations + * + * @param input_shapes are the input tensors shapes + * @return Result instance that contains an array of calculated shapes (per each output port) and a status of the shape infer call + */ + virtual Result infer(const std::vector& input_shapes) = 0; +}; + +class IShapeInferSnippetsFactory { +public: + // Helper type to define specific Makers map values. + using ShapeInferPtr = std::shared_ptr; + // Helper type to define specific Makers map type. + using TRegistry = std::unordered_map)>>; + + /** + * \brief Creates the shape inference object. + * + * \param key Key value to get specified shape inference object maker. + * \param args Inference object args. + * + * \return Pointer to shape inference object or nullptr if failed to construct the object. + */ + ShapeInferPtr make(const ov::DiscreteTypeInfo& key, const std::shared_ptr& op); + virtual ~IShapeInferSnippetsFactory() = default; + +private: + /** \brief Factory makers registry which can be specialized for key and value. */ + static const TRegistry registry; + +protected: + /** + * @brief get shape infer instances for operations from backend-specific opset + * @return Pointer to shape inference object or nullptr if failed to construct the object. 
+ */ + virtual ShapeInferPtr get_specific_op_shape_infer(const ov::DiscreteTypeInfo& key, const std::shared_ptr& op) const; +}; +std::shared_ptr make_shape_inference(const std::shared_ptr& op, + const std::shared_ptr& factory); +} // namespace snippets +} // namespace ov diff --git a/src/common/snippets/include/snippets/shape_types.hpp b/src/common/snippets/include/snippets/shape_types.hpp new file mode 100644 index 00000000000..089dbea8055 --- /dev/null +++ b/src/common/snippets/include/snippets/shape_types.hpp @@ -0,0 +1,18 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include +namespace ov { +namespace snippets { +/* + * This header file contain declarations of shape-relevant classes used cross several snippets subsystems. + * The main purpose of storing such declarations here is to eliminate false dependencies. For example, + * both PortDescriptor and IShapeInferSnippets use VectorDims, but these two classes are completely independent semantically. 
+ */ +using VectorDims = std::vector; +using VectorDimsRef = std::reference_wrapper; + +} // namespace snippets +} // namespace ov diff --git a/src/common/snippets/include/snippets/utils.hpp b/src/common/snippets/include/snippets/utils.hpp index 8217e252d3e..00badfb0469 100644 --- a/src/common/snippets/include/snippets/utils.hpp +++ b/src/common/snippets/include/snippets/utils.hpp @@ -10,6 +10,7 @@ #include "snippets_isa.hpp" #include "emitter.hpp" +#include "shape_types.hpp" namespace ov { @@ -24,9 +25,11 @@ inline auto is_scalar_constant(const std::shared_ptr& source_output_no return ov::is_type(source_output_node) && ov::shape_size(source_output_node->get_shape()) == 1; } -ov::PartialShape get_port_planar_shape(const Input& out); -ov::PartialShape get_port_planar_shape(const Output& out); -ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const std::vector& layout); +ov::PartialShape get_planar_pshape(const Input& out); +ov::PartialShape get_planar_pshape(const Output& out); +ov::PartialShape get_planar_pshape(const ov::PartialShape& shape, const std::vector& layout); +VectorDims pshape_to_vdims(const PartialShape&); +ov::PartialShape vdims_to_pshape(const VectorDims&); inline auto normalize_rank(int32_t allocation_rank, const size_t shape_rank) -> int32_t { return allocation_rank < 0 ? allocation_rank + static_cast(shape_rank) + 1 : allocation_rank; @@ -47,6 +50,11 @@ template constexpr bool everyone_is(T val, P item, Args... 
item_others) { return val == item && everyone_is(val, item_others...); } + +VectorDims get_planar_vdims(const VectorDims& shape, const std::vector& layout); +VectorDims get_planar_vdims(const snippets::lowered::PortDescriptorPtr& port_desc); +VectorDims get_planar_vdims(const snippets::lowered::ExpressionPort& expr_port); + } // namespace utils } // namespace snippets } // namespace ov diff --git a/src/common/snippets/src/lowered/expression.cpp b/src/common/snippets/src/lowered/expression.cpp index cbd50f935df..53ab34049bd 100644 --- a/src/common/snippets/src/lowered/expression.cpp +++ b/src/common/snippets/src/lowered/expression.cpp @@ -14,8 +14,9 @@ namespace ov { namespace snippets { namespace lowered { -Expression::Expression(const std::shared_ptr& n) - : m_source_node{n}, m_emitter{nullptr}, m_input_port_connectors{}, m_output_port_connectors{} { +Expression::Expression(const std::shared_ptr& n, const std::shared_ptr& factory) + : m_source_node{n}, m_emitter{nullptr}, + m_input_port_connectors{}, m_output_port_connectors{}, m_shapeInference(make_shape_inference(n, factory)) { m_input_port_descriptors.reserve(n->get_input_size()); m_output_port_descriptors.reserve(n->get_output_size()); for (const auto& input : n->inputs()) { @@ -110,10 +111,38 @@ ExpressionPort Expression::get_output_port(size_t i) { return ExpressionPort(this->shared_from_this(), ExpressionPort::Type::Output, i); } -IOExpression::IOExpression(const std::shared_ptr& par, int64_t index) - : Expression(par), m_index(index), m_type{io_type::INPUT} {} -IOExpression::IOExpression(const std::shared_ptr& res, int64_t index) - : Expression(res), m_index(index), m_type{io_type::OUTPUT} {} +void Expression::updateShapes() { + IShapeInferSnippets::Result result; + try { + std::vector input_shapes; + + const auto& in_connectors = get_input_port_connectors(); + const auto& in_descriptors = get_input_port_descriptors(); + + input_shapes.reserve(in_connectors.size()); + for (size_t i = 0; i < 
in_connectors.size(); i++) { + const auto& src_port = in_connectors[i]->get_source(); + const auto i_shape = src_port.get_descriptor_ptr()->get_shape(); + // todo: do we really need to store the same shape twice in parent's out_port_desc and this in_port_descs + in_descriptors[i]->set_shape(i_shape); + input_shapes.emplace_back(i_shape); + } + + result = m_shapeInference->infer(input_shapes); + } + catch (const std::exception& exp) { + OPENVINO_THROW("Shape inference of " + (get_node()->get_friendly_name()) + " failed: " + exp.what()); + } + const auto& out_descriptors = get_output_port_descriptors(); + OPENVINO_ASSERT(result.dims.size() == out_descriptors.size(), "shapeInference call returned invalid number of output shapes"); + for (size_t i = 0; i < out_descriptors.size(); i++) + out_descriptors[i]->set_shape(result.dims[i]); +} + +IOExpression::IOExpression(const std::shared_ptr& par, int64_t index, const std::shared_ptr& factory) + : Expression(par, factory), m_index(index), m_type{io_type::INPUT} {} +IOExpression::IOExpression(const std::shared_ptr& res, int64_t index, const std::shared_ptr& factory) + : Expression(res, factory), m_index(index), m_type{io_type::OUTPUT} {} }// namespace lowered }// namespace snippets diff --git a/src/common/snippets/src/lowered/expression_factory.cpp b/src/common/snippets/src/lowered/expression_factory.cpp index 224dcb86830..34651fd6dbb 100644 --- a/src/common/snippets/src/lowered/expression_factory.cpp +++ b/src/common/snippets/src/lowered/expression_factory.cpp @@ -57,7 +57,7 @@ ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr& model) { // Note: ctor of shared_ptr isn't friend class for Expression -> we cannot use directly make_shared(args) OPENVINO_ASSERT(model != nullptr, "To create IOExpression from Parameter there must be inited model!"); - auto expr = std::shared_ptr(new IOExpression(par, model->get_parameter_index(par))); + auto expr = std::shared_ptr(new IOExpression(par, 
model->get_parameter_index(par), linear_ir.m_shape_infer_factory)); create_expression_outputs(expr); expr->validate(); return expr; @@ -67,7 +67,7 @@ ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr& model) { // Note: ctor of shared_ptr isn't friend class for Expression -> we cannot use directly make_shared(args) OPENVINO_ASSERT(model != nullptr, "To create IOExpression from Result there must be inited model!"); - auto expr = std::shared_ptr(new IOExpression(res, model->get_result_index(res))); + auto expr = std::shared_ptr(new IOExpression(res, model->get_result_index(res), linear_ir.m_shape_infer_factory)); create_expression_inputs(linear_ir, expr); // The Result node don't need output port (because of sense of the node). But each node in ngraph must have one output at least. // The port descriptors are automatically created in constructor. We manually clean output ports. @@ -80,24 +80,28 @@ ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr& model) { OPENVINO_ASSERT(!ov::is_type(n), "Default expression builder doesn't support LoopBegin and LoopEnd"); // Note: ctor of shared_ptr isn't friend class for Expression - auto expr = std::shared_ptr(new Expression(n)); + auto expr = std::shared_ptr(new Expression(n, linear_ir.m_shape_infer_factory)); create_expression_inputs(linear_ir, expr); create_expression_outputs(expr); expr->validate(); return expr; } -ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr& n, const std::vector& inputs) { +ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr& n, + const std::vector& inputs, + const LinearIR& linear_ir) { OPENVINO_ASSERT(inputs.empty(), "LoopBegin cannot have inputs"); - auto expr = std::make_shared(Expression(n)); + auto expr = std::make_shared(Expression(n, linear_ir.m_shape_infer_factory)); init_expression_inputs(expr, inputs); create_expression_outputs(expr); expr->validate(); return expr; } -ExpressionPtr 
LinearIR::ExpressionFactory::create(const std::shared_ptr& n, const std::vector& inputs) { - auto expr = std::shared_ptr(new Expression(n)); +ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr& n, + const std::vector& inputs, + const LinearIR& linear_ir) { + auto expr = std::shared_ptr(new Expression(n, linear_ir.m_shape_infer_factory)); expr->m_input_port_descriptors.resize(inputs.size(), nullptr); for (size_t i = 0; i < inputs.size() - 1; ++i) { expr->m_input_port_descriptors[i] = std::make_shared(); @@ -113,14 +117,20 @@ ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr& n, const std::vector& inputs) { +ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr& n, + const std::vector& inputs, + const LinearIR& linear_ir) { OPENVINO_ASSERT(!ov::is_type(n) && !ov::is_type(n), "Expression builder with inputs doesn't support Result and Parameter"); - auto expr = std::shared_ptr(new Expression(n)); + auto expr = std::shared_ptr(new Expression(n, linear_ir.m_shape_infer_factory)); init_expression_inputs(expr, inputs); create_expression_outputs(expr); expr->validate(); + // todo: here we blindly synchronize input shapes from parent and child. 
Remove this when shapes will be stored in + // port connector itself + if (linear_ir.m_shape_infer_factory) + expr->updateShapes(); return expr; } }// namespace lowered diff --git a/src/common/snippets/src/lowered/linear_ir.cpp b/src/common/snippets/src/lowered/linear_ir.cpp index 42ac45cac33..4f039d837fb 100644 --- a/src/common/snippets/src/lowered/linear_ir.cpp +++ b/src/common/snippets/src/lowered/linear_ir.cpp @@ -18,8 +18,8 @@ namespace ov { namespace snippets { namespace lowered { -LinearIR::LinearIR(const std::shared_ptr& model, Config config) - : m_io_expressions{}, m_config{std::move(config)}, m_loop_manager(std::make_shared()) { +LinearIR::LinearIR(const std::shared_ptr& model, const std::shared_ptr& factory, Config config) + : m_io_expressions{}, m_config{config}, m_loop_manager(std::make_shared()), m_shape_infer_factory(factory) { constExprIt last_param = m_expressions.end(); for (const auto& n : get_ordered_ops(model)) { constExprIt insertion_pos = m_expressions.end(); @@ -48,7 +48,7 @@ ExpressionPtr LinearIR::create_expression(const std::shared_ptr& n, const } ExpressionPtr LinearIR::create_expression(const std::shared_ptr& n, const std::vector& inputs) { - return ExpressionFactory::build(n, inputs); + return ExpressionFactory::build(n, inputs, *this); } ov::NodeVector LinearIR::get_ordered_ops(const std::shared_ptr& m) { @@ -66,7 +66,7 @@ ov::NodeVector LinearIR::get_ordered_ops(const std::shared_ptr& m) { return ov::topological_sort(nodes); } -void LinearIR::serialize(const std::string& xml, const std::string& bin) { +void LinearIR::serialize(const std::string& xml, const std::string& bin) const { auto first_node = std::make_shared(element::f32, Shape{}); first_node->set_friendly_name("Start"); first_node->get_rt_info()["execTimeMcs"] = 0; diff --git a/src/common/snippets/src/lowered/loop_manager.cpp b/src/common/snippets/src/lowered/loop_manager.cpp index 90c003b33d8..2bef20bb54e 100644 --- a/src/common/snippets/src/lowered/loop_manager.cpp +++ 
b/src/common/snippets/src/lowered/loop_manager.cpp @@ -182,7 +182,7 @@ void LinearIR::LoopManager::mark_loop(LinearIR::constExprIt loop_begin_pos, std::vector loop_tensor(loop_depth, 1); for (const auto& exit_point : loop_exit_points) { const auto& desc = exit_point.get_descriptor_ptr(); - const auto shape = utils::get_reordered_planar_shape(ov::PartialShape(desc->get_shape()), desc->get_layout()).get_shape(); + const auto shape = utils::get_planar_vdims(desc); auto subtensor = desc->get_subtensor(); if (subtensor.empty()) { subtensor.resize(loop_depth, 1); diff --git a/src/common/snippets/src/lowered/pass/insert_buffers.cpp b/src/common/snippets/src/lowered/pass/insert_buffers.cpp index 409e956c315..91cbe55ef98 100644 --- a/src/common/snippets/src/lowered/pass/insert_buffers.cpp +++ b/src/common/snippets/src/lowered/pass/insert_buffers.cpp @@ -37,13 +37,13 @@ ov::Shape compute_allocation_shape(const LinearIR::LoopManagerPtr& loop_manager, const std::vector& parent_loop_ids, const ov::Output& parent_output, const int allocation_rank) { - const auto port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(parent_output); - const auto planar_shape = utils::get_reordered_planar_shape(ov::Shape{port->get_shape()}, port->get_layout()); + const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(parent_output); + const auto planar_shape = utils::get_planar_vdims(port); const size_t rank = allocation_rank >= 0 ? 
std::min(static_cast(allocation_rank), planar_shape.size()) : planar_shape.size(); ov::Shape allocation_shape(rank); for (size_t i = 0; i < rank; ++i) { - *(allocation_shape.rbegin() + i) = (planar_shape.rbegin() + i)->get_length(); + *(allocation_shape.rbegin() + i) = *(planar_shape.rbegin() + i); } if (buffer_loop_ids.empty() || parent_loop_ids.empty()) { diff --git a/src/common/snippets/src/lowered/port_descriptor.cpp b/src/common/snippets/src/lowered/port_descriptor.cpp index 719f77e7a56..96e8c718cc9 100644 --- a/src/common/snippets/src/lowered/port_descriptor.cpp +++ b/src/common/snippets/src/lowered/port_descriptor.cpp @@ -10,17 +10,17 @@ namespace lowered { size_t PortDescriptor::ServiceDimensions::FULL_DIM = SIZE_MAX; -PortDescriptor::PortDescriptor(const ov::Input& in, std::vector subtensor_shape, std::vector layout) +PortDescriptor::PortDescriptor(const ov::Input& in, VectorDims subtensor_shape, std::vector layout) : PortDescriptor(ov::Input(in.get_node(), in.get_index()), std::move(subtensor_shape), std::move(layout)) {} -PortDescriptor::PortDescriptor(const ov::Input& in, std::vector subtensor_shape, std::vector layout) +PortDescriptor::PortDescriptor(const ov::Input& in, VectorDims subtensor_shape, std::vector layout) : PortDescriptor(in.get_shape(), std::move(subtensor_shape), std::move(layout)) {} -PortDescriptor::PortDescriptor(const ov::Output& out, std::vector subtensor_shape, std::vector layout) +PortDescriptor::PortDescriptor(const ov::Output& out, VectorDims subtensor_shape, std::vector layout) : PortDescriptor(ov::Output(out.get_node(), out.get_index()), std::move(subtensor_shape), std::move(layout)) {} -PortDescriptor::PortDescriptor(const ov::Output& out, std::vector subtensor_shape, std::vector layout) +PortDescriptor::PortDescriptor(const ov::Output& out, VectorDims subtensor_shape, std::vector layout) : PortDescriptor(out.get_shape(), std::move(subtensor_shape), std::move(layout)) {} -PortDescriptor::PortDescriptor(std::vector shape, 
std::vector subtensor_shape, std::vector layout) +PortDescriptor::PortDescriptor(VectorDims shape, VectorDims subtensor_shape, std::vector layout) : m_tensor_shape(std::move(shape)), m_layout(std::move(layout)), m_subtensor_shape(std::move(subtensor_shape)) { validate_arguments(); } diff --git a/src/common/snippets/src/op/brgemm.cpp b/src/common/snippets/src/op/brgemm.cpp index 6f40256f9bd..1f415a4f64b 100644 --- a/src/common/snippets/src/op/brgemm.cpp +++ b/src/common/snippets/src/op/brgemm.cpp @@ -13,6 +13,23 @@ namespace ov { namespace snippets { namespace op { +namespace { +std::vector get_output_layout(const std::shared_ptr& n) { + const auto& key = lowered::PortDescriptorVectorAttribute::get_type_info_static(); + auto& rt_info = n->get_rt_info(); + const auto& found = rt_info.find(key); + if (found != rt_info.end()) { + const auto& out_descs = found->second.as().outputs; + if (out_descs.size() != n->get_output_size()) + OPENVINO_THROW("Get output port descriptor is failed: incorrect count"); + const auto& port_desc = out_descs[0]; + return port_desc->get_layout(); + } + return {}; +} + +} // namespace + Brgemm::Brgemm(const Output& A, const Output& B, const size_t offset_a, const size_t offset_b, const size_t offset_c, std::vector layout_a, std::vector layout_b, std::vector layout_c) @@ -39,10 +56,10 @@ void Brgemm::custom_constructor_validate_and_infer_types(std::vector lay // During ctor call, Brgemm doesn't know his port descriptors. 
// So we use explicit layouts from parameters const auto planar_input_shapes = - std::vector{ ov::snippets::utils::get_reordered_planar_shape(get_input_partial_shape(0), layout_a), - ov::snippets::utils::get_reordered_planar_shape(get_input_partial_shape(1), layout_b) }; + std::vector{ ov::snippets::utils::get_planar_pshape(get_input_partial_shape(0), layout_a), + ov::snippets::utils::get_planar_pshape(get_input_partial_shape(1), layout_b) }; auto output_shape = get_output_partial_shape(planar_input_shapes); - set_output_type(0, get_output_type(), ov::snippets::utils::get_reordered_planar_shape(output_shape, layout_c)); + set_output_type(0, get_output_type(), ov::snippets::utils::get_planar_pshape(output_shape, layout_c)); } void Brgemm::validate_inputs() const { @@ -97,21 +114,15 @@ ov::element::Type Brgemm::get_output_type() const { std::vector Brgemm::get_planar_input_shapes(const std::vector>& inputs) const { OPENVINO_ASSERT(inputs.size() == 2, "Brgemm::get_planar_input_shapes() expects 2 inputs"); - return { utils::get_port_planar_shape(inputs[0]), utils::get_port_planar_shape(inputs[1]) }; + return {utils::get_planar_pshape(inputs[0]), utils::get_planar_pshape(inputs[1]) }; } ov::PartialShape Brgemm::get_planar_output_shape(const ov::PartialShape& output_shape) const { // This method can be safely called from validate_and_infer_types() before output creation - const auto& key = lowered::PortDescriptorVectorAttribute::get_type_info_static(); - auto& rt_info = get_rt_info(); - const auto& found = rt_info.find(key); - if (found != rt_info.end()) { - const auto& out_descs = found->second.as().outputs; - if (out_descs.size() != get_output_size()) - OPENVINO_THROW("Get output port descriptor is failed: incorrect count"); - const auto& port_desc = out_descs[0]; - return utils::get_reordered_planar_shape(output_shape, port_desc->get_layout()); - } + const auto& out_layout = get_output_layout(shared_from_this()); + if (!out_layout.empty()) + return 
utils::get_planar_pshape(output_shape, out_layout); + return output_shape; } @@ -178,6 +189,76 @@ ov::PartialShape Brgemm::get_output_partial_shape(const std::vector& n) { + for (const auto& in : n->inputs()) { + const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(in); + m_io_layouts.push_back(port->get_layout()); + } + m_io_layouts.push_back(get_output_layout(n)); +} + +IShapeInferSnippets::Result Brgemm::ShapeInfer::infer(const std::vector& input_shapes) { + OPENVINO_ASSERT(input_shapes.size() == 2, "BRGEMM expects 2 input shapes for shape inference"); + + // Todo: Ideally we should use the layout stored in PortDescriptors. Can we do it? + const auto& arg0_shape = snippets::utils::get_planar_vdims(input_shapes[0].get(), m_io_layouts[0]); + const auto& arg1_shape = snippets::utils::get_planar_vdims(input_shapes[1].get(), m_io_layouts[1]); + + size_t arg0_rank = arg0_shape.size(), arg1_rank = arg1_shape.size(); + + // temporary shapes to calculate output shape + VectorDims arg0_shape_tmp(arg0_shape), arg1_shape_tmp(arg1_shape); + + // one-dimensional tensors unsqueezing is applied to each input independently. + if (arg0_rank == 1) { + // If the first input is 1D tensor, it is unsqueezed to 2D tensor (row vector) + // by adding axes with size 1 at ROW_INDEX_DIM, to the left of the shape. + // For example {S} will be reshaped to {1, S}. + arg0_shape_tmp.insert(arg0_shape_tmp.begin(), 1); + arg0_rank = arg0_shape_tmp.size(); + } + if (arg1_rank == 1) { + // If the second input is 1D tensor, it is unsqueezed to 2D tensor (column vector) + // by adding axes with size 1 at COL_INDEX_DIM, to the right of the shape. + // For example {S} will be reshaped to {S, 1}. 
+ arg1_shape_tmp.insert(arg1_shape_tmp.end(), 1); + arg1_rank = arg1_shape_tmp.size(); + } + + // add 1 to begin to align shape ranks if needed + if (arg0_rank < arg1_rank) + arg0_shape_tmp.insert(arg0_shape_tmp.begin(), arg1_rank - arg0_rank, 1); + else if (arg0_rank > arg1_rank) + arg1_shape_tmp.insert(arg1_shape_tmp.begin(), arg0_rank - arg1_rank, 1); + + size_t max_rank = arg0_shape_tmp.size(); + VectorDims output_shape(max_rank); + for (size_t i = 0; i < max_rank - 2; ++i) { + if (arg0_shape_tmp[i] == arg1_shape_tmp[i]) { + output_shape[i] = arg0_shape_tmp[i]; + } else { + if (arg0_shape_tmp[i] == 1 || arg0_shape_tmp[i] == DYNAMIC_DIMENSION) + output_shape[i] = arg1_shape_tmp[i]; + else if (arg1_shape_tmp[i] == 1 || arg1_shape_tmp[i] == DYNAMIC_DIMENSION) + output_shape[i] = arg0_shape_tmp[i]; + else + OPENVINO_THROW("Incompatible Brgemm batch dimension"); + } + } + output_shape[output_shape.size() - 2] = arg0_shape_tmp[arg0_shape_tmp.size() - 2]; // M + output_shape[output_shape.size() - 1] = arg1_shape_tmp[arg1_shape_tmp.size() - 1]; // N + + // removing the temporary axes from originally 1D tensors. 
+ if (arg0_shape.size() == 1) { + output_shape.erase(output_shape.begin() + output_shape.size() - 2); + } + if (arg1_shape.size() == 1) { + output_shape.erase(output_shape.begin() + output_shape.size() - 1); + } + output_shape = snippets::utils::get_planar_vdims(output_shape, m_io_layouts[2]); + return {{output_shape}, snippets::ShapeInferStatus::success}; +} + } // namespace op } // namespace snippets } // namespace ov diff --git a/src/common/snippets/src/op/load.cpp b/src/common/snippets/src/op/load.cpp index 84dfb000e1c..d1a7d0f2cb5 100644 --- a/src/common/snippets/src/op/load.cpp +++ b/src/common/snippets/src/op/load.cpp @@ -5,6 +5,7 @@ #include "snippets/itt.hpp" #include "snippets/op/load.hpp" +#include "snippets/utils.hpp" namespace ov { @@ -69,6 +70,15 @@ std::shared_ptr LoadReshape::clone_with_new_inputs(const OutputVector& new check_new_args_count(this, new_args); return std::make_shared(new_args.at(0), get_count(), get_offset(), m_order); } +LoadReshape::ShapeInfer::ShapeInfer(const std::shared_ptr& n) { + const auto& loadReshape = ov::as_type_ptr(n); + OPENVINO_ASSERT(loadReshape, "Got invalid node in LoadReshape::ShapeInfer"); + m_order = loadReshape->m_order; +} +IShapeInferSnippets::Result LoadReshape::ShapeInfer::infer(const std::vector& input_shapes) { + OPENVINO_ASSERT(input_shapes.size() == 1, "Got unexpected number of input shapes"); + return {{utils::get_planar_vdims(input_shapes[0], m_order)}, ShapeInferStatus::success}; +} }// namespace op }// namespace snippets diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index 760b4252ae8..5e67d6cad9e 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -44,7 +44,7 @@ #include "transformations/utils/utils.hpp" #include "snippets/pass_manager.hpp" -#include "ngraph/pass/constant_folding.hpp" +#include "openvino/pass/constant_folding.hpp" #include "ov_ops/type_relaxed.hpp" #include @@ -104,7 +104,7 @@ auto 
Subgraph::get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t { for (const auto& op : ops) { if (const auto transpose = ov::as_type_ptr(op)) { - // At the moment Transposes are supported only on Results and Parameters but + // At the moment Transposes are supported only on Results and Parameters, but // then we should have the different Buffers for Transpose as well (Transpose isn't inplace) const auto consumers = transpose->get_output_target_inputs(0); // If after Transpose there is Result it means that there won't be Buffer after Transpose. @@ -119,7 +119,7 @@ auto Subgraph::get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t { } } else if (ov::is_type(op) || ov::is_type(op)) { // Softmax always uses 2 FP32 Buffers after decomposition. - // They are inplace and the same so we can push precision size only once + // They are inplace and the same, so we can push precision size only once push_prc_size(ov::element::f32.size()); } else if (const auto matmul = ov::as_type_ptr(op)) { // Since all buffers around Matmul must be unique, we explicitely add values to the vector without any checks @@ -195,7 +195,7 @@ void Subgraph::validate_and_infer_types() { INTERNAL_OP_SCOPE(Subgraph); OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::validate_and_infer_types") ov::ParameterVector old_parameters; - for (auto op : body_ptr()->get_parameters()) { + for (const auto& op : body_ptr()->get_parameters()) { old_parameters.push_back(op); } @@ -257,7 +257,7 @@ auto Subgraph::wrap_node_as_subgraph(const std::shared_ptr& node) -> s } ov::ResultVector body_results; - for (auto output : node->outputs()) { + for (const auto& output : node->outputs()) { body_results.push_back(std::make_shared(body_node->output(output.get_index()))); } @@ -469,6 +469,75 @@ bool Subgraph::check_broadcast(const std::shared_ptr& node) noex (elementwise->get_autob().m_type != ov::op::AutoBroadcastType::PDPD); } +IShapeInferSnippets::Result 
Subgraph::shape_infer(const std::vector& input_shapes) { + if (!m_shape_infer && !m_linear_ir) { + OPENVINO_ASSERT(body_ptr(), "Can't create shape infer for Subgraph with an empty body"); + m_shape_infer = std::make_shared(body_ptr()); + } else if (!std::dynamic_pointer_cast(m_shape_infer) && m_linear_ir) { + m_shape_infer = std::make_shared(m_linear_ir); + } + return m_shape_infer->infer(input_shapes); +} + +Subgraph::NgraphShapeInfer::NgraphShapeInfer(const std::shared_ptr& body) : + m_ngraph_body(body), m_parameters(body->get_parameters()), m_results(body->get_results()) { +} + +IShapeInferSnippets::Result Subgraph::NgraphShapeInfer::infer(const std::vector& input_shapes) { + OPENVINO_ASSERT(m_parameters.size() == input_shapes.size(), "Got invalid number of input shapes to reshape subgraph body"); + for (size_t i = 0; i < m_parameters.size(); ++i) + m_parameters[i]->set_partial_shape(utils::vdims_to_pshape(input_shapes[i].get())); + m_ngraph_body->validate_nodes_and_infer_types(); + std::vector outputDims; + for (const auto& res : m_results) + outputDims.emplace_back(utils::pshape_to_vdims(res->get_input_partial_shape(0))); + m_last_result = {outputDims, ShapeInferStatus::success}; + return m_last_result; +} + +Subgraph::LIRShapeInfer::LIRShapeInfer(const std::shared_ptr& body) : + m_lir_body(body) { + for (const auto& io_expr : m_lir_body->get_IO_ops()) { + switch (io_expr->get_type()) { + case IOExpression::io_type::INPUT : m_param_exprs.push_back(io_expr); break; + case IOExpression::io_type::OUTPUT : m_result_exprs.push_back(io_expr); break; + default : OPENVINO_THROW("Undefined io expression type"); + } + } +} + +IShapeInferSnippets::Result +Subgraph::LIRShapeInfer::infer(const std::vector& input_shapes) { + OPENVINO_ASSERT(m_param_exprs.size() == input_shapes.size(), "Got invalid number of input shapes in LIR ShapeInfer"); + // todo: check that order of param_exprs is always the same as that of input_shapes + // if not use io_expr index to sort in 
constructor + + for (size_t i = 0; i < m_param_exprs.size(); ++i) { + m_param_exprs[i]->get_output_port_descriptor(0)->set_shape(input_shapes[i]); + } + for (const auto& expr : *m_lir_body) { + if (expr->needShapeInfer()) + expr->updateShapes(); + } + std::vector outputDims; + outputDims.reserve(m_result_exprs.size()); + for (const auto& r : m_result_exprs) { + outputDims.push_back(r->get_input_port_descriptor(0)->get_shape()); + } + m_last_result = {outputDims, ShapeInferStatus::success}; + return m_last_result; +} + +std::shared_ptr +Subgraph::convert_body_to_linear_ir(const std::shared_ptr& shape_infer_factory) const { + lowered::Config lowering_config; + lowering_config.m_save_expressions = config.m_has_domain_sensitive_ops; + lowering_config.m_need_fill_tail_register = config.m_has_domain_sensitive_ops; + lowering_config.m_loop_depth = tileRank; + + return std::make_shared(body_ptr(), shape_infer_factory, lowering_config); +} + void Subgraph::align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes) { // We should insert Convert before Results to set original output element type if needed @@ -632,18 +701,21 @@ snippets::Schedule Subgraph::generate(const BlockedShapeVector& output_shapes, const std::vector& data_flow_passes, const lowered::pass::PassPipeline& control_flow_passes_pre_common, const lowered::pass::PassPipeline& control_flow_passes_post_common, + const std::shared_ptr& shape_infer_factory, const void* compile_params) { canonicalize(output_shapes, input_shapes); - return generate(data_flow_passes, control_flow_passes_pre_common, control_flow_passes_post_common, compile_params); + return generate(data_flow_passes, control_flow_passes_pre_common, control_flow_passes_post_common, + shape_infer_factory, compile_params); } snippets::Schedule Subgraph::generate(const void* compile_params) { - return generate({}, {}, {}, compile_params); + return generate({}, {}, {}, nullptr, compile_params); } snippets::Schedule 
Subgraph::generate(const std::vector& data_flow_passes, const lowered::pass::PassPipeline& control_flow_passes_pre_common, const lowered::pass::PassPipeline& control_flow_passes_post_common, + const std::shared_ptr& shape_infer_factory, const void* compile_params) { INTERNAL_OP_SCOPE(Subgraph); OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::generate") @@ -651,16 +723,11 @@ snippets::Schedule Subgraph::generate(const std::vectorgenerate(linear_ir, lowering_config, compile_params); + const auto& lowering_result = m_generator->generate(linear_ir, linear_ir.get_config(), compile_params); const auto ptr = lowering_result.binary_code; return {master_shape, false /*canBeLinearized*/, ptr}; @@ -691,66 +758,6 @@ void Subgraph::print() const { } } -void Subgraph::print_statistics(bool verbose) { - INTERNAL_OP_SCOPE(Subgraph); - auto getNodeInventory = [](std::shared_ptr n) -> size_t { - size_t total = 0; - - for (auto input : n->inputs()) { - total += input.get_tensor().size(); - } - - for (auto output : n->outputs()) { - total += output.get_tensor().size(); - } - - if (auto subgraph = ov::as_type_ptr(n)) { - for (auto op : subgraph->body_ptr()->get_ordered_ops()) { - if (ov::as_type_ptr(op)) { - total += op->output(0).get_tensor().size(); - } - } - } - - return total; - }; - - auto getModelInventory = [getNodeInventory](const ov::Model& f) -> size_t { - size_t total = 0; - for (auto op : f.get_ordered_ops()) { - // Results and parameters are artificially introduced, - // while Constants are already considered if they are inputs of other operation - // this should lead to 1:1 inventory for single node operations - if (!ov::as_type_ptr(op) - && !ov::as_type_ptr(op) - && !ov::as_type_ptr(op)) { - total += getNodeInventory(op); - } - } - return total; - }; - - auto countConstants = [](const ov::Model& f) -> size_t { - size_t count = 0; - for (auto op : f.get_ordered_ops()) { - count += !!ov::as_type_ptr(op) ? 
1 : 0; - } - return count; - }; - - std::cout << get_friendly_name() - << ";" << this - << ";" << body_ptr()->get_ops().size() - << ";" << body_ptr()->get_parameters().size() - << ";" << body_ptr()->get_results().size() - << ";" << countConstants(body()) - << ";" << getModelInventory(body()) - << ";" << getNodeInventory(shared_from_this()) << std::endl; - - if (verbose) { - this->print(); - } -} void Subgraph::serialize() const { std::stringstream xmlFile, binFile; diff --git a/src/common/snippets/src/shape_inference/shape_infer_instances.cpp b/src/common/snippets/src/shape_inference/shape_infer_instances.cpp new file mode 100644 index 00000000000..b254adbdc64 --- /dev/null +++ b/src/common/snippets/src/shape_inference/shape_infer_instances.cpp @@ -0,0 +1,166 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include "snippets/shape_inference/shape_infer_instances.hpp" +#include "snippets/snippets_isa.hpp" +#include "openvino/op/select.hpp" +namespace ov { +namespace snippets { +using Result = IShapeInferSnippets::Result; +namespace { +/* + * Merge SRC to DST with broadcasting rules defined by the Autobroadcast specifier + */ +bool broadcast_merge_into(VectorDims& dst, const VectorDims& src, const ov::op::AutoBroadcastSpec& autob) { + auto broadcast_merge_dim = [](size_t& dst, const size_t& d1, const size_t& d2) { + if (d1 == d2 || d1 == 1 || d1 == IShapeInferSnippets::DYNAMIC_DIMENSION) { + dst = d2; + } else if (d2 == 1 || d2 == IShapeInferSnippets::DYNAMIC_DIMENSION) { + dst = d1; + } else { + return false; + } + return true; + }; + // Ranks are both static. 
+ const auto dst_rank = static_cast(dst.size()); + const auto src_rank = static_cast(src.size()); + switch (autob.m_type) { + case ov::op::AutoBroadcastType::NONE: + return true; + case ov::op::AutoBroadcastType::NUMPY: { + const auto new_rank = std::max(dst_rank, src_rank); + VectorDims dims(new_rank); + bool success = true; + for (int64_t i = 0; i < new_rank; i++) { + auto dsti = i < (new_rank - dst_rank) ? 1 : dst[i - (new_rank - dst_rank)]; + auto srci = i < (new_rank - src_rank) ? 1 : src[i - (new_rank - src_rank)]; + success &= broadcast_merge_dim(dims[i], dsti, srci); + } + dst = std::move(dims); + return success; + } + case ov::op::AutoBroadcastType::PDPD: { + int64_t axis = autob.m_axis; + if (src_rank > dst_rank || axis < -1) + return false; + + axis = (axis == -1) ? (dst_rank - src_rank) : axis; + if (src_rank + axis > dst_rank) + return false; + + bool success = true; + for (int64_t i = 0; i < src_rank; ++i) { + if (dst[axis + i] != IShapeInferSnippets::DYNAMIC_DIMENSION && + src[i] != IShapeInferSnippets::DYNAMIC_DIMENSION) { + if (src[i] > dst[axis + i]) + return false; + } + success &= broadcast_merge_dim(dst[axis + i], dst[axis + i], src[i]); + } + return success; + } + default: + OPENVINO_THROW("Unsupported auto broadcast type: ", autob.m_type); + } + return false; +} +/* + * Merge SRC to DST, no broadcasting is allowed + */ +bool merge_into(VectorDims& dst, const VectorDims& src) { + auto merge_dim = [](size_t& dst, const size_t& d1, const size_t& d2) { + if (d1 == d2 || d1 == IShapeInferSnippets::DYNAMIC_DIMENSION) { + dst = d2; + } else if (d2 == IShapeInferSnippets::DYNAMIC_DIMENSION) { + dst = d1; + } else { + return false; + } + return true; + }; + if (dst.size() != src.size()) + return false; + + bool success = true; + for (size_t i = 0; i < dst.size(); i++) + success &= merge_dim(dst[i], dst[i], src[i]); + return success; +} +} // namespace + +Result NumpyBroadcastShapeInfer::infer(const std::vector& input_shapes) { + 
OPENVINO_ASSERT(!input_shapes.empty(), "No input shapes were provided for NumpyBroadcastShapeInfer"); + auto output_shape = input_shapes[0].get(); + for (size_t i = 1; i < input_shapes.size(); i++) { + OPENVINO_ASSERT(broadcast_merge_into(output_shape, input_shapes[i], ov::op::AutoBroadcastType::NUMPY), + "Failed to broadcast-merge input shapes in NumpyBroadcastShapeInfer"); + } + return {{std::move(output_shape)}, ShapeInferStatus::success}; +} + +template +BroadcastShapeInfer::BroadcastShapeInfer(const std::shared_ptr& n) { + static_assert(std::is_base_of() || + std::is_base_of(), + "This ShapeInfer class could be used only for BroadcastMove and BroadcastLoad operations."); + const auto& broadcast = as_type_ptr(n); + OPENVINO_ASSERT(broadcast, "Invalid node passed to BroadcastShapeInfer.", + "Expected ", typeid(BroadcastOP).name(), "got ", n->get_type_name()); + const auto last_dim = *broadcast->get_output_shape().rbegin(); + m_broadcasted_dim = last_dim.is_dynamic() ? IShapeInferSnippets::DYNAMIC_DIMENSION : last_dim.get_length(); +} +template +Result BroadcastShapeInfer::infer(const std::vector& input_shapes) { + auto out_shape = input_shapes[0].get(); + out_shape.back() = m_broadcasted_dim; + return {{out_shape}, ShapeInferStatus::success}; +} + +//// Note: we need to manually create template instances here, so they can be reused in Broadcast* headers. 
+template class BroadcastShapeInfer; +template class BroadcastShapeInfer; + +SelectShapeInfer::SelectShapeInfer(const std::shared_ptr& n) { + const auto& select = as_type_ptr(n); + OPENVINO_ASSERT(select, "Invalid node passed to SelectShapeInfer."); + m_broadcast_spec = select->get_auto_broadcast(); +} + +Result SelectShapeInfer::infer(const std::vector& input_shapes) { + OPENVINO_ASSERT(input_shapes.size() == 3, "Invalid number of shapes passed SelectShapeInfer"); + VectorDims result_shape; + if (m_broadcast_spec == ov::op::AutoBroadcastType::PDPD) { + result_shape = input_shapes[1]; // 'then' tensor + // in PDPD type, Broadcast-merging 'else' into 'then' one way not each other. + OPENVINO_ASSERT(broadcast_merge_into(result_shape, input_shapes[2], m_broadcast_spec), + "'Else' tensor shape is not broadcastable."); + OPENVINO_ASSERT(broadcast_merge_into(result_shape, input_shapes[0], m_broadcast_spec), + "'Cond' tensor shape is not broadcastable."); + } else { + result_shape = input_shapes[2]; + for (int input_port = 1; input_port >= 0; input_port--) { + if (m_broadcast_spec.m_type == ov::op::AutoBroadcastType::NONE) { + OPENVINO_ASSERT(merge_into(result_shape, input_shapes[input_port]), + "Argument shapes are inconsistent."); + } else if (m_broadcast_spec.m_type == ov::op::AutoBroadcastType::NUMPY) { + OPENVINO_ASSERT(broadcast_merge_into(result_shape, input_shapes[input_port], m_broadcast_spec), + "Argument shapes are inconsistent."); + } else { + OPENVINO_THROW("Unsupported auto broadcast specification"); + } + } + } + return {{result_shape}, ShapeInferStatus::success}; +} + +Result HorizonOpShapeInfer::infer(const std::vector& input_shapes) { + OPENVINO_ASSERT(input_shapes.size() == 1, "Got invalid number of input shapes in HorizonShapeInfer"); + auto output_shapes = input_shapes[0].get(); + if (!output_shapes.empty()) + output_shapes.back() = 1; + return {{output_shapes}, ShapeInferStatus::success}; +} + +} // namespace snippets +} // namespace ov diff --git 
a/src/common/snippets/src/shape_inference/shape_inference.cpp b/src/common/snippets/src/shape_inference/shape_inference.cpp new file mode 100644 index 00000000000..bc9534f7b08 --- /dev/null +++ b/src/common/snippets/src/shape_inference/shape_inference.cpp @@ -0,0 +1,90 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include "snippets/shape_inference/shape_inference.hpp" +#include "snippets/shape_inference/shape_infer_instances.hpp" +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace snippets { +using ShapeInferPtr = IShapeInferSnippetsFactory::ShapeInferPtr; + +ShapeInferPtr IShapeInferSnippetsFactory::make(const ov::DiscreteTypeInfo& key, const std::shared_ptr& op) { + const auto& maker_iter = registry.find(key); + if (maker_iter != registry.end()) + return maker_iter->second(op); + return get_specific_op_shape_infer(key, op); +} + +ShapeInferPtr IShapeInferSnippetsFactory::get_specific_op_shape_infer(const ov::DiscreteTypeInfo& key, + const std::shared_ptr& op) const { + return {}; +} + +#define SHAPE_INFER_PREDEFINED(OP, InferType) \ + { OP::get_type_info_static(), [](const std::shared_ptr& n) { return std::make_shared();} } +#define SHAPE_INFER_OP_SPECIFIC(OP) \ + { OP::get_type_info_static(), [](const std::shared_ptr& n) { return std::make_shared(n);} } +#define SHAPE_INFER_OP_SPECIFIC_EXTERNAL(OP, InferType) \ + { OP::get_type_info_static(), [](const std::shared_ptr& n) { return std::make_shared(n);} } + +const IShapeInferSnippetsFactory::TRegistry IShapeInferSnippetsFactory::registry { + SHAPE_INFER_PREDEFINED(op::ConvertTruncation, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(op::ConvertSaturation, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(op::Load, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(op::Store, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(op::Buffer, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(op::Fill, PassThroughShapeInfer), + 
SHAPE_INFER_PREDEFINED(ov::op::v0::Parameter, PassThroughShapeInfer), + // Note: We should remove Softmax shape infers after the decomposition activity, + // since there won't be any Softmax ops on LIR. Ticket: 112847 + SHAPE_INFER_PREDEFINED(ov::op::v1::Softmax, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(ov::op::v8::Softmax, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(ov::op::v1::LogicalNot, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(ov::op::v0::PRelu, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(op::HorizonMax, HorizonOpShapeInfer), + SHAPE_INFER_PREDEFINED(op::HorizonSum, HorizonOpShapeInfer), + // + SHAPE_INFER_PREDEFINED(op::LoopBegin, SingleElementShapeInfer), + SHAPE_INFER_PREDEFINED(op::Scalar, SingleElementShapeInfer), + SHAPE_INFER_PREDEFINED(op::VectorBuffer, SingleElementShapeInfer), + SHAPE_INFER_PREDEFINED(op::LoopEnd, EmptyShapeInfer), + SHAPE_INFER_PREDEFINED(op::Nop, EmptyShapeInfer), + SHAPE_INFER_OP_SPECIFIC_EXTERNAL(opset1::Select, SelectShapeInfer), + // Note that Result has no output PortConnectors, so the shape must be empty + SHAPE_INFER_PREDEFINED(ov::op::v0::Result, EmptyShapeInfer), + // + SHAPE_INFER_OP_SPECIFIC(op::LoadReshape), + SHAPE_INFER_OP_SPECIFIC(op::Brgemm), + SHAPE_INFER_OP_SPECIFIC(op::BroadcastLoad), + SHAPE_INFER_OP_SPECIFIC(op::BroadcastMove), +}; +#undef SHAPE_INFER_OP_SPECIFIC_EXTERNAL +#undef SHAPE_INFER_OP_SPECIFIC +#undef SHAPE_INFER_PREDEFINED + +std::shared_ptr make_shape_inference(const std::shared_ptr& op, + const std::shared_ptr& factory) { + if (!factory) { + return nullptr; + } else if (auto shape_infer = factory->make(op->get_type_info(), op)) { + return shape_infer; + } else if (ov::is_type(op)) { + return std::make_shared(); + } else if (ov::is_type(op) || + ov::is_type(op) || + ov::is_type(op)) { + return std::make_shared(); + } else { + OPENVINO_THROW("Operation type " + std::string(op->get_type_info().name) + " is not supported in Snippets shape inference pipeline"); + } +} + +} // 
namespace snippets +} // namespace ov diff --git a/src/common/snippets/src/utils.cpp b/src/common/snippets/src/utils.cpp index 5b8937e9f0d..621c6c9bf67 100644 --- a/src/common/snippets/src/utils.cpp +++ b/src/common/snippets/src/utils.cpp @@ -70,7 +70,7 @@ auto get_non_scalar_constant_count_for_fq(const std::shared_ptr& layout) { +ov::PartialShape get_planar_pshape(const ov::PartialShape& shape, const std::vector& layout) { if (layout.empty()) return shape; std::vector reordered_shape(layout.size()); @@ -87,14 +87,47 @@ ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const return reordered_shape; } -ov::PartialShape get_port_planar_shape(const Input& in) { - const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(in); - return utils::get_reordered_planar_shape(ov::Shape{port->get_shape()}, port->get_layout()); +VectorDims pshape_to_vdims(const PartialShape& pshape) { + VectorDims result; + result.reserve(pshape.size()); + for (const auto& d : pshape) + result.push_back(d.is_dynamic() ? IShapeInferSnippets::DYNAMIC_DIMENSION : d.get_length()); + return result; } -ov::PartialShape get_port_planar_shape(const Output& out) { - const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(out); - return utils::get_reordered_planar_shape(ov::Shape{port->get_shape()}, port->get_layout()); +ov::PartialShape vdims_to_pshape(const VectorDims& vdims) { + ov::PartialShape result; + result.reserve(vdims.size()); + for (const auto& v : vdims) + result.push_back(v != IShapeInferSnippets::DYNAMIC_DIMENSION ? 
+ Dimension(static_cast(v)) : + Dimension()); + return result; +} + +ov::PartialShape get_planar_pshape(const Input& in) { + const auto& port = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(in); + return utils::get_planar_pshape(ov::Shape{port->get_shape()}, port->get_layout()); +} + +ov::PartialShape get_planar_pshape(const Output& out) { + const auto& port = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(out); + return utils::get_planar_pshape(ov::Shape{port->get_shape()}, port->get_layout()); +} + +VectorDims get_planar_vdims(const VectorDims& shape, const std::vector& layout) { + VectorDims reordered_shape(shape.size()); + for (size_t i = 0; i < layout.size(); i++) + reordered_shape[i] = shape[layout[i]]; + return reordered_shape; +} + +VectorDims get_planar_vdims(const snippets::lowered::PortDescriptorPtr& port_desc) { + return get_planar_vdims(port_desc->get_shape(), port_desc->get_layout()); +} + +VectorDims get_planar_vdims(const snippets::lowered::ExpressionPort& expr_port) { + return get_planar_vdims(expr_port.get_descriptor_ptr()); } } // namespace utils diff --git a/src/common/snippets/tests/src/pass/lowered/loop.cpp b/src/common/snippets/tests/src/pass/lowered/loop.cpp index a9e1b67c5e3..27b4e3ce95b 100644 --- a/src/common/snippets/tests/src/pass/lowered/loop.cpp +++ b/src/common/snippets/tests/src/pass/lowered/loop.cpp @@ -13,6 +13,7 @@ #include "snippets/lowered/pass/validate_loops.hpp" #include "snippets/lowered/pass/insert_loops.hpp" #include "snippets/lowered/pass/insert_tail_loop.hpp" +#include "snippets/shape_inference/shape_inference.hpp" #include "snippets/op/loop.hpp" @@ -26,7 +27,8 @@ constexpr static size_t vector_size = 16; static void init_linear_ir(const std::vector& in_shapes, LinearIR& linear_ir, size_t block_size) { auto body = ov::test::snippets::AddFunction(in_shapes).getOriginal(); - linear_ir = LinearIR(body); + auto shape_infer_factory = std::make_shared(); + linear_ir = LinearIR(body, 
shape_infer_factory); auto expr_it = std::find_if(linear_ir.cbegin(), linear_ir.cend(), [](const ExpressionPtr& expr) { return ov::is_type(expr->get_node()); }); ASSERT_TRUE(expr_it != linear_ir.cend()); diff --git a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp index 7eeb701cbfa..a146141c85c 100644 --- a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp @@ -131,8 +131,13 @@ KernelEmitter::KernelEmitter(jit_generator* h, cpu_isa_t isa, const ExpressionPt IE_THROW() << "Kernel detected unsupported io_type"; } } - io_shapes.push_back(desc->get_shape()); - io_data_layouts.push_back(desc->get_layout()); + const auto& shape = desc->get_shape(); + const auto& layout = desc->get_layout(); + OPENVINO_ASSERT(shape.size() == layout.size(), "Shape and layout must have the same length"); + const auto max_dim = *std::max_element(layout.begin(), layout.end()); + OPENVINO_ASSERT(max_dim < shape.size(), "Max layout index can't be larger than the shape size"); + io_shapes.push_back(shape); + io_data_layouts.push_back(layout); io_data_sizes.push_back(etype.size()); } diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index 15fbcf9dc83..ce0216b4fd5 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -9,13 +9,10 @@ #include #include #include -#include -#include #include #include -#include #include #include #include @@ -31,6 +28,7 @@ #include "transformations/snippets/x64/pass/remove_converts.hpp" #include "transformations/snippets/x64/pass/enforce_precision.hpp" #include "transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.hpp" +#include "transformations/snippets/x64/shape_inference.hpp" #include "transformations/cpu_opset/common/pass/convert_to_swish_cpu.hpp" #include 
"transformations/defs.hpp" #include "shape_inference/custom/subgraph.hpp" @@ -142,7 +140,7 @@ snippets::op::Subgraph::BlockedShapeVector getBlockedShapes(const std::vector& op, const GraphContext::CPtr context) +Snippet::Snippet(const std::shared_ptr& op, const GraphContext::CPtr& context) : Node(op, context, SnippetShapeInferFactory(op)) { host_isa = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) ? dnnl::impl::cpu::x64::avx512_core : dnnl::impl::cpu::x64::avx2; @@ -724,10 +722,12 @@ void Snippet::SnippetJitExecutor::generate(const jit_snippets_compile_args* jcp) ov::snippets::lowered::pass::PassPipeline control_flow_pipeline; CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::FuseLoadStoreConvert); - + // Todo: We don't need shape infer factory now, since shape infer will be done through validate_and_infer_types + // pass std::make_shared() instead of nullptr, when shape infer is performed on LIR schedule = snippet_for_generation->generate(backend_passes, control_flow_markup_pipeline, control_flow_pipeline, + nullptr, reinterpret_cast(jcp)); } diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.h b/src/plugins/intel_cpu/src/nodes/subgraph.h index 7d4947cea97..ed706443e68 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.h +++ b/src/plugins/intel_cpu/src/nodes/subgraph.h @@ -24,7 +24,7 @@ namespace node { /// precision: fp32 class Snippet : public Node { public: - Snippet(const std::shared_ptr& op, const GraphContext::CPtr context); + Snippet(const std::shared_ptr& op, const GraphContext::CPtr& context); ~Snippet() override = default; void getSupportedDescriptors() override {}; @@ -33,10 +33,6 @@ public: void initOptimalPrimitiveDescriptor() override; InferenceEngine::Precision getRuntimePrecision() const override; - // to avoid collisions in throughput mode with copy of TypeRelaxed nodes - // we should have common shared mutex between streams - void setSharedMutex(const std::shared_ptr& mutex); - // Here we convert to 
canonical for & jit everything void prepareParams() override; bool needPrepareParams() const override; diff --git a/src/plugins/intel_cpu/src/shape_inference/custom/subgraph.hpp b/src/plugins/intel_cpu/src/shape_inference/custom/subgraph.hpp index 7c1bad348eb..42c951ce533 100644 --- a/src/plugins/intel_cpu/src/shape_inference/custom/subgraph.hpp +++ b/src/plugins/intel_cpu/src/shape_inference/custom/subgraph.hpp @@ -10,69 +10,19 @@ namespace ov { namespace intel_cpu { namespace node { using Result = IShapeInfer::Result; + class SnippetShapeInfer : public ShapeInferEmptyPads { public: - SnippetShapeInfer(std::shared_ptr body) : m_body(body) {} + explicit SnippetShapeInfer(const std::shared_ptr& s) : m_subgraph(s) { + m_status_map[snippets::ShapeInferStatus::success] = ov::intel_cpu::ShapeInferStatus::success; + m_status_map[snippets::ShapeInferStatus::skip] = ov::intel_cpu::ShapeInferStatus::skip; + } Result infer( - const std::vector>& input_shapes, - const std::unordered_map& data_dependency) override { - auto broadcast_merge = [](VectorDims& dst, const VectorDims& src) { - // Ranks are both static. - auto dst_rank = dst.size(); - auto src_rank = src.size(); - const auto new_rank = std::max(dst_rank, src_rank); - dst.insert(dst.begin(), new_rank - dst_rank, 1); - for (size_t i = 0; i < new_rank; i++) { - auto srci = i < (new_rank - src_rank) ? 
1 : src[i - (new_rank - src_rank)]; - if (dst[i] != srci && srci != Shape::UNDEFINED_DIM) { - if (dst[i] == 1 || dst[i] == Shape::UNDEFINED_DIM) { - dst[i] = srci; - } else { - if (srci != 1) { - IE_THROW() << "Got imcompatible input shapes in snippets shape infer"; - } - } - } - } - }; - - const size_t out_size = m_body->get_output_size(); - if (out_size == 1) { - VectorDims masterShape; - for (size_t i = 0; i < input_shapes.size(); i++) { - if (i == 0) - masterShape = input_shapes[i]; - else - broadcast_merge(masterShape, input_shapes[i]); - } - size_t output_rank = m_body->get_output_partial_shape(0).rank().get_length(); - if (output_rank > masterShape.size()) { - masterShape.insert(masterShape.begin(), output_rank - masterShape.size(), 1); - } - return {{masterShape}, ShapeInferStatus::success}; - } else { - std::vector outputDims; - std::vector new_shapes; - for (const auto& s : input_shapes) - new_shapes.emplace_back(s); - auto& params = m_body->get_parameters(); - if (params.size() != input_shapes.size()) { - IE_THROW() << "Got invalid number of input shapes to reshape subgraph body"; - } - for (size_t i = 0; i < params.size(); ++i) { - params[i]->set_partial_shape(new_shapes[i]); - } - m_body->validate_nodes_and_infer_types(); - for (const auto& res : m_body->get_results()) { - auto& pshape = res->get_input_partial_shape(0); - if (!pshape.is_static()) { - IE_THROW() << "Subgraph inferred dynamic output shape during reshape with static inputs"; - } - outputDims.emplace_back(pshape.get_shape()); - } - - return {outputDims, ShapeInferStatus::success}; - } + const std::vector>& input_shapes, + const std::unordered_map& data_dependency) override { + const auto& snippets_result = m_subgraph->shape_infer(input_shapes); + OPENVINO_ASSERT(m_status_map.count(snippets_result.status) != 0, "Failed to map snippets shapeInfer status to the plugin one"); + return {snippets_result.dims, m_status_map.at(snippets_result.status)}; } port_mask_t get_port_mask() const override 
{ @@ -80,21 +30,22 @@ public: } private: - std::shared_ptr m_body; + std::shared_ptr m_subgraph; + std::map m_status_map; }; class SnippetShapeInferFactory : public ShapeInferFactory { public: - SnippetShapeInferFactory(const std::shared_ptr& op) { - auto subgraph = ov::as_type_ptr(op); - snippet_body = subgraph->body_ptr()->clone(); + explicit SnippetShapeInferFactory(const std::shared_ptr& op) { + m_subgraph = ov::as_type_ptr(op); + OPENVINO_ASSERT(m_subgraph, "Invalid node type detected in SnippetShapeInferFactory"); } ShapeInferPtr makeShapeInfer() const override { - return std::make_shared(snippet_body); + return std::make_shared(m_subgraph); } private: - std::shared_ptr snippet_body = nullptr; + std::shared_ptr m_subgraph = nullptr; }; } // namespace node } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp index a1e46db97ad..e16088a1567 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp @@ -10,13 +10,15 @@ #include "utils/general_utils.h" -using namespace ov; +namespace ov { +namespace intel_cpu { intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output& x, const element::Type src_type, const Type type, const size_t offset_in, const size_t offset_out0, const size_t offset_out1, std::vector layout_input, const size_t blk_size_k, const size_t blk_size_n) : snippets::op::MemoryAccess({x}, 1, type == Type::WithCompensations ? 2 : 1), m_type(type), m_src_type(src_type) { + m_brgemmVNNIFactor = 4 / m_src_type.size(); set_output_size(type == Type::WithCompensations ? 
2 : 1); set_input_port_descriptor({0, offset_in}, 0); set_output_port_descriptor({0, offset_out0}, 0); @@ -32,6 +34,7 @@ intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output& x, const element::Type s std::vector layout_input, const size_t blk_size_k, const size_t blk_size_n) : snippets::op::MemoryAccess({x}, 1, type == Type::WithCompensations ? 2 : 1), m_type(type), m_src_type(src_type) { + m_brgemmVNNIFactor = 4 / m_src_type.size(); set_output_size(type == Type::WithCompensations ? 2 : 1); set_input_port_descriptor(desc_in0, 0); set_output_port_descriptor(desc_out0, 0); @@ -42,38 +45,38 @@ intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output& x, const element::Type s custom_constructor_validate_and_infer_types(std::move(layout_input)); } -bool intel_cpu::BrgemmCopyB::visit_attributes(AttributeVisitor& visitor) { +bool BrgemmCopyB::visit_attributes(AttributeVisitor& visitor) { INTERNAL_OP_SCOPE(BrgemmRepack_visit_attributes); MemoryAccess::visit_attributes(visitor); visitor.on_attribute("src_type", m_src_type); return true; } -void intel_cpu::BrgemmCopyB::custom_constructor_validate_and_infer_types(std::vector layout_input) { +void BrgemmCopyB::custom_constructor_validate_and_infer_types(std::vector layout_input) { INTERNAL_OP_SCOPE(BrgemmRepack_ctor_validate_and_infer_types); // During ctor call, BrgemmCopyB doesn't know his port descriptors. 
// So we use port descs from source inputs const auto element_type = get_input_element_type(0); - const auto pshape = snippets::utils::get_reordered_planar_shape(get_input_partial_shape(0), layout_input); + const auto pshape = snippets::utils::get_planar_pshape(get_input_partial_shape(0), layout_input); validate(pshape, element_type); } -void intel_cpu::BrgemmCopyB::validate_and_infer_types() { +void BrgemmCopyB::validate_and_infer_types() { INTERNAL_OP_SCOPE(BrgemmRepack_validate_and_infer_types); - const auto element_type = get_input_element_type(0); - const auto pshape = snippets::utils::get_port_planar_shape(input(0)); + const auto& element_type = get_input_element_type(0); + const auto& pshape = snippets::utils::get_planar_pshape(input(0)); validate(pshape, element_type); } -void intel_cpu::BrgemmCopyB::validate(const ov::PartialShape& pshape, const ov::element::Type& element_type) { +void BrgemmCopyB::validate(const ov::PartialShape& pshape, const ov::element::Type& element_type) { NGRAPH_CHECK(one_of(element_type, element::bf16, element::i8), - "BrgemmCopyB doesn't support element type" + element_type.get_type_name()); + "BrgemmCopyB doesn't support element type" + element_type.get_type_name()); if (pshape.is_dynamic()) { - set_output_type(0, element_type, ov::PartialShape{ov::Dimension::dynamic()}); + set_output_type(0, element_type, ov::PartialShape {ov::Dimension::dynamic()}); if (is_with_compensations()) { - set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension::dynamic()}); + set_output_type(1, ov::element::f32, ov::PartialShape {ov::Dimension::dynamic()}); } return; } @@ -81,9 +84,8 @@ void intel_cpu::BrgemmCopyB::validate(const ov::PartialShape& pshape, const ov:: const auto shape = pshape.get_shape(); const auto N = *shape.rbegin(); const auto K = *(shape.rbegin() + 1); - const auto brgemmVNNIFactor = 4 / m_src_type.size(); - set_output_type(0, element_type, ov::PartialShape{ov::Dimension(rnd_up(K, brgemmVNNIFactor)), + 
set_output_type(0, element_type, ov::PartialShape{ov::Dimension(rnd_up(K, m_brgemmVNNIFactor)), ov::Dimension(rnd_up(N, m_N_blk))}); if (is_with_compensations()) { set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension(rnd_up(N, m_N_blk))}); @@ -91,7 +93,7 @@ void intel_cpu::BrgemmCopyB::validate(const ov::PartialShape& pshape, const ov:: } void intel_cpu::BrgemmCopyB::compute_block_size_values(const size_t blk_size_k, const size_t blk_size_n) { - const auto input_shape = snippets::utils::get_port_planar_shape(input(0)).get_shape(); + const auto& input_shape = snippets::utils::get_planar_pshape(input(0)).get_shape(); m_K_blk = blk_size_k != 0 ? blk_size_k : *(input_shape.rbegin() + 1); m_N_blk = blk_size_n != 0 ? blk_size_n : *input_shape.rbegin(); } @@ -107,8 +109,41 @@ std::shared_ptr intel_cpu::BrgemmCopyB::clone_with_new_inputs(const Output m_K_blk, m_N_blk); } -size_t intel_cpu::BrgemmCopyB::get_offset_compensations() const { +size_t BrgemmCopyB::get_offset_compensations() const { OPENVINO_ASSERT(is_with_compensations() && get_output_size() == 2, "The offset for compensations must be in BrgemmCopyB only with compensations and 2 outputs!"); return get_output_offset(1); } + +BrgemmCopyB::ShapeInfer::ShapeInfer(const std::shared_ptr& n) { + const auto& brg_copyb = ov::as_type_ptr(n); + OPENVINO_ASSERT(brg_copyb, "Got invalid node in BrgemmCopyB::ShapeInfer"); + m_layout = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(n->input(0))->get_layout(); + m_num_outs = brg_copyb->get_output_size(); + m_N_blk = brg_copyb->get_n_block_size(); + m_brgemmVNNIFactor = brg_copyb->m_brgemmVNNIFactor; +} + +snippets::IShapeInferSnippets::Result BrgemmCopyB::ShapeInfer::infer(const std::vector& input_shapes) { + OPENVINO_ASSERT(input_shapes.size() == 1, "Got unexpected number of input shapes"); + const auto& old_shape = input_shapes[0].get(); + snippets::VectorDims planar_shape; + planar_shape.reserve(old_shape.size()); + for (const auto idx : 
m_layout) + planar_shape.push_back(old_shape[idx]); + const auto N = *planar_shape.rbegin(); + const auto K = *(planar_shape.rbegin() + 1); + OPENVINO_ASSERT(N != DYNAMIC_DIMENSION && K != DYNAMIC_DIMENSION, + "BrgemmCopyB shape infer got dynamic N or K dimension, which is not supported"); + + std::vector new_shapes(m_num_outs); + new_shapes[0].push_back(rnd_up(K, m_brgemmVNNIFactor)); + new_shapes[0].push_back(rnd_up(N, m_N_blk)); + if (m_num_outs == 2) { + new_shapes[1].push_back(rnd_up(N, m_N_blk)); + } + return {new_shapes, snippets::ShapeInferStatus::success}; +} + +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp index d2e5a06d738..62703049aea 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp @@ -5,6 +5,7 @@ #pragma once #include "snippets/op/memory_access.hpp" +#include namespace ov { namespace intel_cpu { @@ -51,6 +52,16 @@ public: bool has_evaluate() const override { return false; } std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + class ShapeInfer : public snippets::IShapeInferSnippets { + std::vector m_layout{}; + size_t m_num_outs = 1; + size_t m_N_blk = 64; + size_t m_brgemmVNNIFactor = 1; + public: + explicit ShapeInfer(const std::shared_ptr& n); + Result infer(const std::vector& input_shapes) override; + }; + private: void custom_constructor_validate_and_infer_types(std::vector layout_input = {}); void validate(const ov::PartialShape& pshape, const ov::element::Type& element_type); @@ -61,6 +72,7 @@ private: size_t m_K_blk = 0; size_t m_N_blk = 0; + size_t m_brgemmVNNIFactor = 1; }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp 
b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp index 362262ad9a7..03e3325376c 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp @@ -7,6 +7,7 @@ #include "snippets/utils.hpp" #include "snippets/lowered/port_descriptor.hpp" #include "utils/general_utils.h" +#include "snippets/utils.hpp" namespace ov { @@ -78,19 +79,19 @@ void BrgemmCPU::custom_constructor_validate_and_infer_types(std::vector // So we use port descs from source inputs const auto brgemm_copy = is_with_data_repacking() ? get_brgemm_copy() : nullptr; const auto planar_input_shapes = - std::vector{ snippets::utils::get_reordered_planar_shape(get_input_partial_shape(0), layout_a), - brgemm_copy ? snippets::utils::get_port_planar_shape(brgemm_copy->input(0)) - : snippets::utils::get_reordered_planar_shape(get_input_partial_shape(1), layout_b) }; + std::vector{ snippets::utils::get_planar_pshape(get_input_partial_shape(0), layout_a), + brgemm_copy ? 
snippets::utils::get_planar_pshape(brgemm_copy->input(0)) + : snippets::utils::get_planar_pshape(get_input_partial_shape(1), layout_b) }; auto output_shape = get_output_partial_shape(planar_input_shapes); - set_output_type(0, get_output_type(), snippets::utils::get_reordered_planar_shape(output_shape, layout_c)); + set_output_type(0, get_output_type(), snippets::utils::get_planar_pshape(output_shape, layout_c)); // Additional check for 3rd input validate_with_scratchpad(planar_input_shapes[1].get_shape()); } void BrgemmCPU::compute_block_size_values(const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n) { - const auto input_shape_0 = snippets::utils::get_port_planar_shape(input(0)).get_shape(); - const auto input_shape_1 = snippets::utils::get_port_planar_shape(input(1)).get_shape(); + const auto input_shape_0 = snippets::utils::get_planar_pshape(input(0)).get_shape(); + const auto input_shape_1 = snippets::utils::get_planar_pshape(input(1)).get_shape(); m_M_blk = blk_size_m != 0 ? blk_size_m : *(input_shape_0.rbegin() + 1); m_K_blk = blk_size_k != 0 ? blk_size_k : *input_shape_0.rbegin(); m_N_blk = blk_size_n != 0 ? blk_size_n : *input_shape_1.rbegin(); @@ -180,5 +181,13 @@ size_t BrgemmCPU::get_offset_scratch() const { return get_input_offset(2); } +BrgemmCPU::ShapeInfer::ShapeInfer(const std::shared_ptr& n) : Brgemm::ShapeInfer(n) { + const auto& brg = ov::as_type_ptr(n); + OPENVINO_ASSERT(brg, "Got invalid node in BrgemmCPU::ShapeInfer"); + const auto brgemm_copy = brg->is_with_data_repacking() ? 
brg->get_brgemm_copy() : nullptr; + if (brgemm_copy) + m_io_layouts[1] = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(brgemm_copy->input(0))->get_layout(); +} + } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp index bf07b7a8546..e1957bb66d2 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp @@ -69,6 +69,12 @@ public: constexpr static size_t SCRATCH_BYTE_SIZE = 32 * 1024; + class ShapeInfer : public Brgemm::ShapeInfer { + public: + explicit ShapeInfer(const std::shared_ptr& n); + }; + + private: void custom_constructor_validate_and_infer_types(std::vector layout_a, std::vector layout_b, std::vector layout_c); void compute_block_size_values(const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp index 6bd7a2c72b5..40dc488a254 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp @@ -62,8 +62,8 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() { const auto& brgemm_in1_desc = PortDescriptorUtils::get_port_descriptor_ptr(brgemm->input(1)); const auto& brgemm_out_desc = PortDescriptorUtils::get_port_descriptor_ptr(brgemm->output(0)); - const auto dimsMatMulIn0 = snippets::utils::get_port_planar_shape(brgemm->input_value(0)).get_shape(); - const auto dimsMatMulIn1 = snippets::utils::get_port_planar_shape(brgemm->input_value(1)).get_shape(); + const auto dimsMatMulIn0 = snippets::utils::get_planar_pshape(brgemm->input_value(0)).get_shape(); + const auto 
dimsMatMulIn1 = snippets::utils::get_planar_pshape(brgemm->input_value(1)).get_shape(); const auto K = *dimsMatMulIn0.rbegin(); const auto N = *dimsMatMulIn1.rbegin(); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp index 71abbb7f7b1..db6f34a4e74 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp @@ -47,8 +47,8 @@ pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() { return false; } - const auto dimsMatMulIn0 = snippets::utils::get_port_planar_shape(brgemm->input_value(0)).get_shape(); - const auto dimsMatMulIn1 = snippets::utils::get_port_planar_shape(brgemm->input_value(1)).get_shape(); + const auto dimsMatMulIn0 = snippets::utils::get_planar_pshape(brgemm->input_value(0)).get_shape(); + const auto dimsMatMulIn1 = snippets::utils::get_planar_pshape(brgemm->input_value(1)).get_shape(); const auto K = *dimsMatMulIn0.rbegin(); const auto N = *dimsMatMulIn1.rbegin(); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp new file mode 100644 index 00000000000..d09f3f218e6 --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp @@ -0,0 +1,47 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "shape_inference.hpp" +#include +#include "op/brgemm_copy_b.hpp" +#include "op/brgemm_cpu.hpp" +#include "op/fused_mul_add.hpp" +#include "op/load_convert.hpp" +#include "op/store_convert.hpp" +#include "transformations/cpu_opset/common/op/swish_cpu.hpp" + +namespace ov { +namespace snippets { +using ShapeInferPtr = IShapeInferSnippetsFactory::ShapeInferPtr; + 
+ShapeInferPtr CPUShapeInferSnippetsFactory::get_specific_op_shape_infer(const ov::DiscreteTypeInfo& key, + const std::shared_ptr& op) const { + const auto& maker_iter = specific_ops_registry.find(key); + if (maker_iter != specific_ops_registry.end()) + return maker_iter->second(op); + return {}; +} + + +#define SHAPE_INFER_PREDEFINED(OP, InferType) \ + { OP::get_type_info_static(), [](const std::shared_ptr& n) { return std::make_shared();} } +#define SHAPE_INFER_OP_SPECIFIC(OP) \ + { OP::get_type_info_static(), [](const std::shared_ptr& n) { return std::make_shared(n);} } + +const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::specific_ops_registry { + SHAPE_INFER_PREDEFINED(ov::intel_cpu::FusedMulAdd, NumpyBroadcastShapeInfer), + SHAPE_INFER_PREDEFINED(ov::intel_cpu::SwishNode, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(ov::intel_cpu::LoadConvertSaturation, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(ov::intel_cpu::LoadConvertTruncation, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertSaturation, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertTruncation, PassThroughShapeInfer), + // + SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCopyB), + SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCPU), +}; +#undef SHAPE_INFER_OP_SPECIFIC +#undef SHAPE_INFER_PREDEFINED + +} // namespace snippets +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.hpp new file mode 100644 index 00000000000..bc2eb7684a2 --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace ov { +namespace snippets { + +/** + * \brief Shape infer factory that can create shape-infer instances for cpu-specific operations 
+ */ +class CPUShapeInferSnippetsFactory : public IShapeInferSnippetsFactory{ + /** \brief Factory makers registry which can be specialized for key and value. */ + static const TRegistry specific_ops_registry; + +protected: + /** + * @brief get shape infer instances for operations from backend-specific opset + * @return registered ShapeInferPtr + */ + ShapeInferPtr get_specific_op_shape_infer(const ov::DiscreteTypeInfo& key, const std::shared_ptr& op) const override; +}; + +} // namespace snippets +} // namespace ov