[Snippets] Delegate domain optimization to a LIR pass (#18991)

This commit is contained in:
Ivan Novoselov 2023-10-10 15:23:28 +01:00 committed by GitHub
parent 4426486e6f
commit c385c13185
25 changed files with 525 additions and 236 deletions

View File

@ -12,6 +12,7 @@
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/pass/pass.hpp"
#include "snippets/shape_types.hpp"
namespace ov {
namespace snippets {
@ -23,17 +24,13 @@ namespace snippets {
*/
class Schedule {
public:
/**
* @brief Default constructor
*/
Schedule() : work_size({}), is_flat(false), ptr(nullptr) {}
Schedule() = default;
/**
* @brief Constructor to create a schedule from specific parameters
* @param ws work size for kernel execution
* @param f can this kernel be linearized to a 1D range
* @param wd work domain for kernel execution
* @param p pointer to generated code
*/
Schedule(const ov::PartialShape& ws, bool f, code p) : work_size(ws), is_flat(f), ptr(p) {}
Schedule(const VectorDims& wd, code p) : parallel_exec_domain(wd), ptr(p) {}
/**
* @brief Returns a callable instance of the code pointer
*/
@ -41,8 +38,7 @@ public:
return reinterpret_cast<K>(const_cast<unsigned char*>(ptr));
}
ov::PartialShape work_size {};
bool is_flat {false};
VectorDims parallel_exec_domain {};
code ptr {nullptr};
};

View File

@ -21,6 +21,14 @@ public:
// True if we should check runtime info for nodes to call specific needed transformations
bool m_need_fill_tail_register = false;
size_t m_loop_depth = 1;
// Some Subgraphs don't support domain optimization due to operations' semantics
bool m_enable_domain_optimization = false;
// Minimal advised work amount for parallel execution.
// Set by a backend, typically equal to the number of threads available on the machine.
size_t m_min_parallel_work_amount = 8;
// Minimal advised work amount that should be processed during one call of the executable produced by Subgraph::generate
// Set by a backend, should be large enough to compensate for the kernel call overheads
size_t m_min_kernel_work_amount = 256;
};
/* The control flow of Snippets is built on Linear Intermediate Representation (Linear IR).
@ -46,6 +54,7 @@ public:
const container& get_ops() const {return m_expressions; }
const io_container& get_IO_ops() const {return m_io_expressions; }
Config get_config() {return m_config; }
void set_loop_depth(size_t loop_depth) { m_config.m_loop_depth = loop_depth; }
const ExpressionPtr& get_expr_by_node(const std::shared_ptr<Node>& n) const;
@ -103,9 +112,26 @@ public:
using LoopManagerPtr = std::shared_ptr<LoopManager>;
const LoopManagerPtr& get_loop_manager() const { return m_loop_manager; }
const std::shared_ptr<IShapeInferSnippetsFactory>& get_shape_infer_factory() { return m_shape_infer_factory; }
IShapeInferSnippets::Result shape_infer(const std::vector<VectorDimsRef>& input_shapes);
const std::shared_ptr<ShapeInferSnippetsNode>& get_shape_infer_instance() const {return m_shape_infer; }
VectorDims get_master_shape() const;
private:
std::shared_ptr<ShapeInferSnippetsNode> m_shape_infer = nullptr;
class LIRShapeInfer : public ShapeInferSnippetsNode {
public:
using IOExpression = lowered::IOExpression;
explicit LIRShapeInfer(container& body_exprs, io_container& io_exprs);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
private:
const std::shared_ptr<container> m_exprs = nullptr;
std::vector<std::shared_ptr<IOExpression>> m_input_exprs {};
std::vector<std::shared_ptr<IOExpression>> m_output_exprs {};
};
static ov::NodeVector get_ordered_ops(const std::shared_ptr<ov::Model>& model);
// Default ctor - can be called only from Linear IR initialization as default way
ExpressionPtr create_expression(const std::shared_ptr<Node>& n, const std::shared_ptr<ov::Model>& model = nullptr);

View File

@ -0,0 +1,68 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "snippets/lowered/pass/pass.hpp"
#include "snippets/shape_types.hpp"
namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
/**
* @interface OptimizeDomain
* @brief Collapse input/output dimensions to balance parallel/per-thread load. The pass consists of two steps:
* 1. Dimension collapsing: if neither of the last two dimensions is broadcasted, the last dimension's size
* is less than min_kernel_work_amount, and the remaining dimensions provide a work amount larger than
* min_parallel_work_amount (both thresholds are specified in the LinearIR config),
* then these two dimensions are collapsed into one and the collapsing attempt is repeated.
* 2. Tile rank increment: tile rank is the rank of the tensor that is processed during one kernel call. If all
* dimensions except the last two provide a work amount larger than min_parallel_work_amount, then tile_rank
* is incremented, which effectively increases the kernel work amount.
* Examples of graphs before and after this transformation are depicted below.
* @param tile_rank (taken by reference) rank of the tensor processed during one kernel call. Incremented if dimensions are collapsed.
* @ingroup snippets
*/
// Example:
// min_jit_work_amount = 256
// min_parallel_work_amount = 4
//
// Before OptimizeDomain | After OptimizeDomain
// -------------------------------------------------------------------
// tile_rank = 1 | tile_rank = 2
// |
// in1 in2 | in1 in2
// [14, 15, 16, 17] [14, 15, 16, 17] | [1, 14, 15, 272] [1, 14, 15, 272]
// \ / | \ /
// Add | Add
// [14, 15, 16, 17] | [1, 14, 15, 272]
// | | |
// Result | Result
// [14, 15, 16, 17] | [1, 14, 15, 272]
class OptimizeDomain : public snippets::lowered::pass::Pass {
public:
OPENVINO_RTTI("OptimizeDomain", "Pass")
explicit OptimizeDomain(size_t& tile_rank);
bool run(LinearIR& linear_ir) override;
private:
size_t& m_tile_rank;
static size_t optimize(std::vector<VectorDims>& input_shapes,
VectorDims& master_shape,
size_t total_work_amount,
size_t min_parallel_work_amount,
size_t min_jit_work_amount);
inline static bool can_increase_jit_work_amount(const VectorDims& master_shape,
size_t min_parallel_work_amount,
size_t total_work_amount);
};
} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
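
For illustration, here is a minimal standalone sketch of the logic described in the header above (the helper name is hypothetical; it works on a plain shape vector and omits the broadcast check and the LinearIR plumbing of the real pass):

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Sketch: repeatedly merge the two innermost dims while the innermost dim is still smaller than
// min_jit_work_amount and the outer dims keep at least min_parallel_work_amount iterations;
// returns the number of collapsed dims and reports the resulting tile rank via tile_rank.
inline size_t collapse_dims_sketch(std::vector<size_t>& shape,
                                   size_t min_jit_work_amount,
                                   size_t min_parallel_work_amount,
                                   size_t& tile_rank) {
    const size_t total = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
    auto outer_work_is_sufficient = [&]() {
        return shape.size() > 2 &&
               shape[shape.size() - 1] * shape[shape.size() - 2] * min_parallel_work_amount <= total;
    };
    size_t collapsed = 0;
    tile_rank = 1;
    while (shape.back() < min_jit_work_amount && outer_work_is_sufficient() && collapsed < shape.size() - 1) {
        shape[shape.size() - 1] *= shape[shape.size() - 2];  // merge the two innermost dims
        for (size_t i = shape.size() - 2; i > 0; i--)        // shift the remaining dims right
            shape[i] = shape[i - 1];
        shape[0] = 1;                                        // pad with a leading 1 to keep the rank
        collapsed++;
    }
    // Step 2 of the pass: if the outer dims alone still provide enough parallel work, let the kernel own a 2D tile
    if (outer_work_is_sufficient() && collapsed != shape.size() - 1)
        tile_rank = 2;
    return collapsed;
}

// Example (matches the diagram above): shape = {14, 15, 16, 17}, min_jit_work_amount = 256,
// min_parallel_work_amount = 4  ->  shape becomes {1, 14, 15, 272}, tile_rank == 2.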

View File

@ -38,9 +38,9 @@ public:
PortDescriptor(VectorDims shape, VectorDims subtensor_shape, std::vector<size_t> layout = {});
PortDescriptor() = default;
VectorDims get_shape() const {return m_tensor_shape;}
VectorDims get_subtensor() const {return m_subtensor_shape;}
std::vector<size_t> get_layout() const {return m_layout;}
const VectorDims& get_shape() const {return m_tensor_shape;}
const VectorDims& get_subtensor() const {return m_subtensor_shape;}
const std::vector<size_t>& get_layout() const {return m_layout;}
size_t get_reg() const { return m_reg; }
void set_shape(const VectorDims& tensor) { m_tensor_shape = tensor; }

View File

@ -124,6 +124,8 @@ public:
void set_generator(std::shared_ptr<ov::snippets::Generator> generator);
void set_tile_rank(size_t newRank) {tileRank = newRank;}
void set_virtual_port_count(const size_t count);
void set_min_jit_work_amount(const size_t jit_work_amount);
void set_min_parallel_work_amount(const size_t parallel_work_amount);
void print() const;
@ -178,34 +180,22 @@ private:
// True if body has operations that don't support plugin-side domain optimizations
// (e.g. Transpose, Softmax; MatMul in general doesn't support dimension collapsing)
bool m_has_domain_sensitive_ops = false;
// Minimal advised work amount for parallel execution.
// Set by a backend, typically equals to the number of threads available on the machine.
size_t m_min_parallel_work_amount = 8;
// Minimal advised work amount every JIT kernel should process during one execution call
// Set by a backend, should be large enough to compensate for the kernel call overheads
size_t m_min_jit_work_amount = 256;
} config;
class ShapeInferSnippetsNode : public IShapeInferSnippets {
public:
const Result& get_last_result() {return m_last_result; }
protected:
Result m_last_result{{}, ShapeInferStatus::success};
};
std::shared_ptr<ShapeInferSnippetsNode> m_shape_infer = nullptr;
class NgraphShapeInfer : public ShapeInferSnippetsNode {
std::shared_ptr<ov::Model> m_ngraph_body;
ParameterVector m_parameters;
ResultVector m_results;
public:
explicit NgraphShapeInfer(const std::shared_ptr<ov::Model>& body);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};
class LIRShapeInfer : public ShapeInferSnippetsNode {
using IOExpression = lowered::IOExpression;
std::shared_ptr<lowered::LinearIR> m_lir_body;
std::vector<std::shared_ptr<IOExpression>> m_param_exprs;
std::vector<std::shared_ptr<IOExpression>> m_result_exprs;
public:
explicit LIRShapeInfer(const std::shared_ptr<lowered::LinearIR>& body);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};
};
static inline auto create_body(const std::string& name, const ov::ResultVector& results, const ov::ParameterVector& parameters) ->

View File

@ -8,6 +8,11 @@
namespace ov {
namespace snippets {
bool broadcast_merge_into(VectorDims& dst, const VectorDims& src, const ov::op::AutoBroadcastSpec& autob = ov::op::AutoBroadcastType::NUMPY);
bool merge_into(VectorDims& dst, const VectorDims& src);
class NumpyBroadcastShapeInfer : public IShapeInferSnippets {
public:
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;

View File

@ -37,6 +37,18 @@ public:
virtual Result infer(const std::vector<VectorDimsRef>& input_shapes) = 0;
};
/**
* Shape inference class for Subgraph node (both nGraph and Linear IRs).
* It stores the result of the last shape inference, so it can be reused in the optimization pipeline.
*
*/
class ShapeInferSnippetsNode : public IShapeInferSnippets {
public:
const Result& get_last_result() {return m_last_result; }
protected:
Result m_last_result{{}, ShapeInferStatus::success};
};
class IShapeInferSnippetsFactory {
public:
// Helper type to define specific Makers map values.

View File

@ -112,6 +112,7 @@ ExpressionPort Expression::get_output_port(size_t i) {
}
void Expression::updateShapes() {
OPENVINO_ASSERT(m_shapeInference, "Attempt to UpdateShapes without initialized shapeInference");
IShapeInferSnippets::Result result;
try {
std::vector<VectorDimsRef> input_shapes;
@ -121,11 +122,10 @@ void Expression::updateShapes() {
input_shapes.reserve(in_connectors.size());
for (size_t i = 0; i < in_connectors.size(); i++) {
const auto& src_port = in_connectors[i]->get_source();
const auto i_shape = src_port.get_descriptor_ptr()->get_shape();
// todo: do we really need to store the same shape twice in parent's out_port_desc and this in_port_descs
in_descriptors[i]->set_shape(i_shape);
input_shapes.emplace_back(i_shape);
const auto& src_port_desc = in_connectors[i]->get_source().get_descriptor_ptr();
in_descriptors[i]->set_shape(src_port_desc->get_shape());
// Note that input_shape is a reference, so we should always bind it to an object with a longer lifetime
input_shapes.emplace_back(in_descriptors[i]->get_shape());
}
result = m_shapeInference->infer(input_shapes);
@ -133,6 +133,8 @@ void Expression::updateShapes() {
catch (const std::exception& exp) {
OPENVINO_THROW("Shape inference of " + (get_node()->get_friendly_name()) + " failed: " + exp.what());
}
OPENVINO_ASSERT(result.status == ShapeInferStatus::success,
"Shape inference of " + (get_node()->get_friendly_name()) + " didn't return success status");
const auto& out_descriptors = get_output_port_descriptors();
OPENVINO_ASSERT(result.dims.size() == out_descriptors.size(), "shapeInference call returned invalid number of output shapes");
for (size_t i = 0; i < out_descriptors.size(); i++)

View File

@ -9,10 +9,10 @@
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/lowered/expression_factory.hpp"
#include "snippets/op/serialization_node.hpp"
#include "snippets/utils.hpp"
#include "openvino/core/graph_util.hpp"
#include "openvino/core/type.hpp"
#include "snippets/utils.hpp"
namespace ov {
namespace snippets {
@ -41,6 +41,7 @@ LinearIR::LinearIR(const std::shared_ptr<ov::Model>& model, const std::shared_pt
last_param = it;
}
}
m_shape_infer = std::make_shared<LIRShapeInfer>(m_expressions, m_io_expressions);
}
ExpressionPtr LinearIR::create_expression(const std::shared_ptr<Node>& n, const std::shared_ptr<ov::Model>& model) {
@ -296,6 +297,68 @@ LinearIR::constExprReverseIt LinearIR::find_after(LinearIR::constExprReverseIt i
return find(it, crend(), target);
}
IShapeInferSnippets::Result LinearIR::shape_infer(const std::vector<VectorDimsRef>& input_shapes) {
OPENVINO_ASSERT(m_shape_infer, "Attempt to call shape_infer when the shapeInfer instance was not created");
return m_shape_infer->infer(input_shapes);
}
VectorDims LinearIR::get_master_shape() const {
VectorDims master_shape{};
// Note: inputs and outputs must be broadcastable, so it's enough to broadcast-merge only outputs
std::vector<std::shared_ptr<IOExpression>> out_exprs;
for (const auto& ioe : m_io_expressions) {
if (ioe->get_type() == IOExpression::io_type::OUTPUT)
out_exprs.push_back(ioe);
}
// Note: Snippets would benefit from a more generic master_shape calculation approach.
// It will be implemented in the scope of ROI propagation activity (ticket 120505)
const auto& result_parent = out_exprs[0]->get_input_port_connector(0)->get_source().get_expr();
if (!m_config.m_enable_domain_optimization && out_exprs.size() == 1 &&
ov::is_type<snippets::op::Brgemm>(result_parent->get_node())) {
master_shape = utils::get_planar_vdims(out_exprs[0]->get_input_port_descriptor(0));
} else {
for (const auto& oe : out_exprs) {
const auto& port_desc = oe->get_input_port_descriptor(0);
OPENVINO_ASSERT(ov::snippets::broadcast_merge_into(master_shape, port_desc->get_shape()),
"Failed to merge input shapes in OptimizeDomain pass");
}
}
return master_shape;
}
LinearIR::LIRShapeInfer::LIRShapeInfer(container& body_exprs, io_container& io_exprs)
: ShapeInferSnippetsNode(),
m_exprs{std::make_shared<container>(body_exprs)} {
// Note that here we rely on the assumption that io_expressions can't be changed after the LIR was created
for (const auto& expr : io_exprs) {
if (expr->get_type() == IOExpression::io_type::INPUT) {
m_input_exprs.push_back(expr);
} else if (expr->get_type() == IOExpression::io_type::OUTPUT) {
m_output_exprs.emplace_back(expr);
} else {
OPENVINO_THROW("Invalid io expression type detected");
}
}
}
IShapeInferSnippets::Result LinearIR::LIRShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
OPENVINO_ASSERT(m_input_exprs.size() == input_shapes.size(), "Got invalid number of input shapes in LIR ShapeInfer");
for (size_t i = 0; i < m_input_exprs.size(); i++)
m_input_exprs[i]->get_output_port_descriptor(0)->set_shape(input_shapes[i]);
for (const auto& expr : *m_exprs) {
if (expr->needShapeInfer())
expr->updateShapes();
}
std::vector<VectorDims> outputDims;
outputDims.reserve(m_output_exprs.size());
for (const auto& expr : m_output_exprs) {
outputDims.push_back(expr->get_input_port_descriptor(0)->get_shape());
}
m_last_result = {outputDims, ShapeInferStatus::success};
return m_last_result;
}
}// namespace lowered
}// namespace snippets
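
For reference, a small sketch of the NUMPY broadcast-merge semantics assumed by broadcast_merge_into in get_master_shape above (illustrative function name only; dims are aligned from the right, and two dims are compatible when they are equal or one of them is 1):

#include <cstddef>
#include <vector>

inline bool broadcast_merge_sketch(std::vector<size_t>& dst, const std::vector<size_t>& src) {
    if (src.size() > dst.size())
        dst.insert(dst.begin(), src.size() - dst.size(), 1);  // pad dst with leading 1s
    for (size_t i = 0; i < src.size(); ++i) {
        auto& d = dst[dst.size() - 1 - i];
        const auto s = src[src.size() - 1 - i];
        if (d == s || s == 1)
            continue;       // nothing to update
        if (d == 1) {
            d = s;          // dst broadcasts up to the src value
            continue;
        }
        return false;       // incompatible dims, e.g. 3 vs 5
    }
    return true;
}

// e.g. merging {14, 15, 16, 1} into {14, 15, 1, 17} yields {14, 15, 16, 17}; broadcast-merging all
// output shapes this way produces the master shape computed above.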

View File

@ -24,7 +24,7 @@ size_t InsertLoadStore::get_count(const PortDescriptorPtr& port_desc) const {
const auto shape = port_desc->get_shape();
// Find last dimension by layout
const auto last_dim_idx = std::find(layout.begin(), layout.end(), layout.size() - 1);
OPENVINO_ASSERT(last_dim_idx != layout.end(), "Load/Store expression have incorrect layout");
OPENVINO_ASSERT(last_dim_idx != layout.end() && *last_dim_idx < shape.size(), "Load/Store expression have incorrect layout");
const auto dim = shape[*last_dim_idx];
return dim == 1 ? 1 : m_vector_size;
}

View File

@ -0,0 +1,124 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "snippets/lowered/pass/optimize_domain.hpp"
#include "snippets/itt.hpp"
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/snippets_isa.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
OptimizeDomain::OptimizeDomain(size_t& tile_rank) : Pass(), m_tile_rank(tile_rank) {
}
size_t OptimizeDomain::optimize(std::vector<VectorDims>& input_shapes,
VectorDims& master_shape,
const size_t total_work_amount,
const size_t min_parallel_work_amount,
const size_t min_jit_work_amount) {
if (master_shape.size() <= 2)
return false;
auto CollapseLastDim = [](VectorDims& dims) {
OPENVINO_ASSERT(dims.size() >= 2, "CollapseLastDim can't process shape with less than two dims");
dims[dims.size() - 1] *= dims[dims.size() - 2];
for (auto i = dims.size() - 2; i > 0; i--)
dims[i] = dims[i - 1];
dims[0] = 1;
};
// Check that neither of the two last dims is broadcasted, so they can be collapsed
auto LastDimsNotBroadcasted = [] (const std::vector<VectorDims>& input_shapes, const VectorDims& master_shape) {
const auto master_last = *master_shape.rbegin();
const auto master_prelast = *++master_shape.rbegin();
return std::all_of(input_shapes.begin(), input_shapes.end(),
[=](const VectorDims& s) {
return *s.rbegin() == master_last &&
*++s.rbegin() == master_prelast;
});
};
size_t jit_work_amount = master_shape.back();
size_t num_dims_collapsed = 0;
while (jit_work_amount < min_jit_work_amount &&
can_increase_jit_work_amount(master_shape, min_parallel_work_amount, total_work_amount) &&
LastDimsNotBroadcasted(input_shapes, master_shape) &&
num_dims_collapsed < master_shape.size() - 1) {
for (auto &s : input_shapes)
CollapseLastDim(s);
CollapseLastDim(master_shape);
num_dims_collapsed++;
jit_work_amount = master_shape.back();
}
return num_dims_collapsed;
}
inline bool OptimizeDomain::can_increase_jit_work_amount(const VectorDims& master_shape,
const size_t min_parallel_work_amount,
const size_t total_work_amount) {
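// i.e. even after the kernel takes the two innermost dims, the remaining outer dims
// still provide at least min_parallel_work_amount iterations for the parallel loop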
return master_shape.size() > 2 &&
master_shape[master_shape.size() - 1] * master_shape[master_shape.size() - 2] *
min_parallel_work_amount <= total_work_amount;
}
bool OptimizeDomain::run(snippets::lowered::LinearIR& linear_ir) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::OptimizeDomain")
const auto& config = linear_ir.get_config();
if (linear_ir.empty())
return false;
m_tile_rank = 1;
if (!config.m_enable_domain_optimization) {
// Note: this is a special case: if optimization is not allowed, always assume 2D tile
m_tile_rank = 2;
return false;
}
OPENVINO_ASSERT(config.m_min_parallel_work_amount != 0, "OptimizeDomain: Min parallel work amount can't equal to zero");
std::vector<std::shared_ptr<snippets::lowered::IOExpression>> input_exprs;
std::vector<VectorDims> input_shapes;
VectorDims master_shape = linear_ir.get_master_shape();
for (const auto& expr : linear_ir.get_IO_ops()) {
if (expr->get_type() == snippets::lowered::IOExpression::io_type::INPUT) {
input_exprs.push_back(expr);
const auto& shape = expr->get_output_port_descriptor(0)->get_shape();
OPENVINO_ASSERT(std::none_of(shape.begin(), shape.end(),
[](size_t d) {return d == snippets::IShapeInferSnippets::DYNAMIC_DIMENSION; }),
"OptimizeDomain pass does not support dynamic shapes");
OPENVINO_ASSERT(ov::snippets::broadcast_merge_into(master_shape, shape),
"Failed to merge input shapes in OptimizeDomain pass");
input_shapes.emplace_back(shape);
}
}
const auto total_work_amount = std::accumulate(master_shape.begin(),
master_shape.end(),
(size_t)1,
std::multiplies<size_t>());
const auto num_dims_collapsed = optimize(input_shapes,
master_shape,
total_work_amount,
config.m_min_parallel_work_amount,
config.m_min_kernel_work_amount);
if (num_dims_collapsed > 0) {
std::vector<VectorDimsRef> infer_shapes;
infer_shapes.reserve(input_shapes.size());
for (const auto& is : input_shapes)
infer_shapes.emplace_back(is);
// Need to propagate updated shapes through LIR
linear_ir.shape_infer(infer_shapes);
}
// We can still try to increment tile rank after dimension collapsing
if (can_increase_jit_work_amount(master_shape, config.m_min_parallel_work_amount, total_work_amount) &&
num_dims_collapsed != master_shape.size() - 1)
m_tile_rank++;
return num_dims_collapsed > 0;
}
} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov

View File

@ -40,6 +40,7 @@
#include "snippets/lowered/pass/identify_buffers.hpp"
#include "snippets/lowered/pass/validate_loops.hpp"
#include "snippets/lowered/pass/insert_loops.hpp"
#include "snippets/lowered/pass/optimize_domain.hpp"
#include "transformations/utils/utils.hpp"
@ -67,6 +68,14 @@ void Subgraph::set_virtual_port_count(const size_t count) {
m_virtual_port_count = count;
}
void Subgraph::set_min_jit_work_amount(const size_t jit_work_amount) {
config.m_min_jit_work_amount = jit_work_amount;
}
void Subgraph::set_min_parallel_work_amount(const size_t parallel_work_amount) {
config.m_min_parallel_work_amount = parallel_work_amount;
}
auto Subgraph::is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bool {
return ov::is_type<ov::op::v1::Transpose>(op) ||
ov::is_type<ov::op::v1::Softmax>(op) ||
@ -151,6 +160,7 @@ Subgraph::Subgraph(const OutputVector& args, const std::shared_ptr<ov::Model>& b
for (size_t i = 0; i < body->get_output_size(); ++i)
m_output_descriptions[0].push_back(std::make_shared<BodyOutputDescription>(i, i));
m_transformations_allowed = false;
m_shape_infer = std::make_shared<NgraphShapeInfer>(body);
}
Subgraph::Subgraph(const NodeVector& args, const std::shared_ptr<ov::Model>& body)
@ -470,70 +480,38 @@ bool Subgraph::check_broadcast(const std::shared_ptr<const ov::Node>& node) noex
}
IShapeInferSnippets::Result Subgraph::shape_infer(const std::vector<VectorDimsRef>& input_shapes) {
if (!m_shape_infer && !m_linear_ir) {
OPENVINO_ASSERT(body_ptr(), "Can't create shape infer for Subgraph with an empty body");
m_shape_infer = std::make_shared<NgraphShapeInfer>(body_ptr());
} else if (!std::dynamic_pointer_cast<LIRShapeInfer>(m_shape_infer) && m_linear_ir) {
m_shape_infer = std::make_shared<LIRShapeInfer>(m_linear_ir);
}
OPENVINO_ASSERT(m_shape_infer, "Attempt to call shape_infer when it's not initialized");
return m_shape_infer->infer(input_shapes);
}
Subgraph::NgraphShapeInfer::NgraphShapeInfer(const std::shared_ptr<ov::Model>& body) :
m_ngraph_body(body), m_parameters(body->get_parameters()), m_results(body->get_results()) {
m_ngraph_body(body) {
OPENVINO_ASSERT(m_ngraph_body, "Can't initialize shape infer with empty body");
}
IShapeInferSnippets::Result Subgraph::NgraphShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
OPENVINO_ASSERT(m_parameters.size() == input_shapes.size(), "Got invalid number of input shapes to reshape subgraph body");
for (size_t i = 0; i < m_parameters.size(); ++i)
m_parameters[i]->set_partial_shape(utils::vdims_to_pshape(input_shapes[i].get()));
const ParameterVector& parameters = m_ngraph_body->get_parameters();
const ResultVector& results = m_ngraph_body->get_results();
OPENVINO_ASSERT(parameters.size() == input_shapes.size(), "Got invalid number of input shapes to reshape subgraph body");
for (size_t i = 0; i < parameters.size(); ++i)
parameters[i]->set_partial_shape(utils::vdims_to_pshape(input_shapes[i].get()));
m_ngraph_body->validate_nodes_and_infer_types();
std::vector<VectorDims> outputDims;
for (const auto& res : m_results)
for (const auto& res : results)
outputDims.emplace_back(utils::pshape_to_vdims(res->get_input_partial_shape(0)));
m_last_result = {outputDims, ShapeInferStatus::success};
return m_last_result;
}
Subgraph::LIRShapeInfer::LIRShapeInfer(const std::shared_ptr<lowered::LinearIR>& body) :
m_lir_body(body) {
for (const auto& io_expr : m_lir_body->get_IO_ops()) {
switch (io_expr->get_type()) {
case IOExpression::io_type::INPUT : m_param_exprs.push_back(io_expr); break;
case IOExpression::io_type::OUTPUT : m_result_exprs.push_back(io_expr); break;
default : OPENVINO_THROW("Undefined io expression type");
}
}
}
IShapeInferSnippets::Result
Subgraph::LIRShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
OPENVINO_ASSERT(m_param_exprs.size() == input_shapes.size(), "Got invalid number of input shapes in LIR ShapeInfer");
// todo: check that order of param_exprs is always the same as that of input_shapes
// if not use io_expr index to sort in constructor
for (size_t i = 0; i < m_param_exprs.size(); ++i) {
m_param_exprs[i]->get_output_port_descriptor(0)->set_shape(input_shapes[i]);
}
for (const auto& expr : *m_lir_body) {
if (expr->needShapeInfer())
expr->updateShapes();
}
std::vector<VectorDims> outputDims;
outputDims.reserve(m_result_exprs.size());
for (const auto& r : m_result_exprs) {
outputDims.push_back(r->get_input_port_descriptor(0)->get_shape());
}
m_last_result = {outputDims, ShapeInferStatus::success};
return m_last_result;
}
std::shared_ptr<lowered::LinearIR>
Subgraph::convert_body_to_linear_ir(const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory) const {
lowered::Config lowering_config;
lowering_config.m_save_expressions = config.m_has_domain_sensitive_ops;
lowering_config.m_need_fill_tail_register = config.m_has_domain_sensitive_ops;
lowering_config.m_loop_depth = tileRank;
lowering_config.m_enable_domain_optimization = !config.m_has_domain_sensitive_ops;
lowering_config.m_min_parallel_work_amount = config.m_min_parallel_work_amount;
lowering_config.m_min_kernel_work_amount = config.m_min_jit_work_amount;
return std::make_shared<lowered::LinearIR>(body_ptr(), shape_infer_factory, lowering_config);
}
@ -650,6 +628,11 @@ void Subgraph::control_flow_transformations(lowered::LinearIR& linear_ir,
INTERNAL_OP_SCOPE(Subgraph);
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::control_flow_transformations")
// Domain optimization must be the first pass, because all other transformations may depend on PortDescriptor shapes
size_t loop_depth = 1;
lowered::pass::OptimizeDomain(loop_depth).run(linear_ir);
linear_ir.set_loop_depth(loop_depth);
const size_t vector_size = get_generator()->get_target_machine()->get_lanes();
const int32_t buffer_allocation_rank = static_cast<int32_t>(linear_ir.get_config().m_loop_depth);
@ -730,7 +713,13 @@ snippets::Schedule Subgraph::generate(const std::vector<pass::Manager::Positione
const auto& lowering_result = m_generator->generate(linear_ir, linear_ir.get_config(), compile_params);
const auto ptr = lowering_result.binary_code;
return {master_shape, false /*canBeLinearized*/, ptr};
VectorDims parallel_exec_domain = linear_ir.get_master_shape();
const size_t loop_depth = linear_ir.get_config().m_loop_depth;
for (size_t i = 0; i < loop_depth; i++)
parallel_exec_domain[parallel_exec_domain.size() - 1 - i] = 1;
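// e.g. a master shape of {1, 14, 15, 272} with loop_depth == 2 yields a parallel domain of {1, 14, 1, 1}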
return {parallel_exec_domain, ptr};
}
void Subgraph::print() const {

View File

@ -175,7 +175,7 @@ auto update_intermediate_supported_ops(std::shared_ptr<ov::Node>& interm_op, ov:
interm_op = interm_op->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
}
return true;
};
}
} // namespace
bool ov::snippets::pass::TokenizeMHASnippets::is_matmul0_supported(const std::shared_ptr<ov::opset1::MatMul>& matmul) {

View File

@ -7,7 +7,6 @@
namespace ov {
namespace snippets {
using Result = IShapeInferSnippets::Result;
namespace {
/*
* Merge SRC to DST with broadcasting rules defined by the Autobroadcast specifier
*/
@ -87,7 +86,6 @@ bool merge_into(VectorDims& dst, const VectorDims& src) {
success &= merge_dim(dst[i], dst[i], src[i]);
return success;
}
} // namespace
Result NumpyBroadcastShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
OPENVINO_ASSERT(!input_shapes.empty(), "No input shapes were provided for NumpyBroadcastShapeInfer");

View File

@ -55,6 +55,7 @@ const IShapeInferSnippetsFactory::TRegistry IShapeInferSnippetsFactory::registry
SHAPE_INFER_PREDEFINED(op::Scalar, SingleElementShapeInfer),
SHAPE_INFER_PREDEFINED(op::VectorBuffer, SingleElementShapeInfer),
SHAPE_INFER_PREDEFINED(op::LoopEnd, EmptyShapeInfer),
SHAPE_INFER_PREDEFINED(op::Kernel, EmptyShapeInfer),
SHAPE_INFER_PREDEFINED(op::Nop, EmptyShapeInfer),
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(opset1::Select, SelectShapeInfer),
// Note that Result has no output PortConnectors, so the shape must be empty

View File

@ -117,8 +117,10 @@ ov::PartialShape get_planar_pshape(const Output<Node>& out) {
VectorDims get_planar_vdims(const VectorDims& shape, const std::vector<size_t>& layout) {
VectorDims reordered_shape(shape.size());
for (size_t i = 0; i < layout.size(); i++)
for (size_t i = 0; i < layout.size(); i++) {
OPENVINO_ASSERT(layout[i] < shape.size(), "get_planar_vdims: layout index is greater than the shape size");
reordered_shape[i] = shape[layout[i]];
}
return reordered_shape;
}

View File

@ -0,0 +1,36 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <common_test_utils/ov_test_utils.hpp>
#include "snippets/shape_types.hpp"
namespace ov {
namespace test {
namespace snippets {
struct OptimizeDomainParams {
OptimizeDomainParams() = default;
OptimizeDomainParams(size_t, size_t, std::vector<ov::PartialShape>, ov::snippets::VectorDims, size_t);
size_t min_jit_work_amount = 0;
size_t min_parallel_work_amount = 0;
std::vector<ov::PartialShape> input_shapes;
ov::snippets::VectorDims exp_master_shape; // Expected master_shape
size_t exp_loop_depth = 0; // Expected loop depth (aka tile rank)
};
class OptimizeDomainTest : public testing::TestWithParam<OptimizeDomainParams> {
public:
using VectorDims = ov::snippets::VectorDims;
static std::string getTestCaseName(testing::TestParamInfo<OptimizeDomainParams> obj);
protected:
void SetUp() override;
std::shared_ptr<ov::Model> m_model;
OptimizeDomainParams m_domain_opt_params;
};
} // namespace snippets
} // namespace test
} // namespace ov

View File

@ -47,7 +47,6 @@ public:
void SetUp() override;
void TearDown() override;
protected:
static std::shared_ptr<ov::snippets::op::Subgraph> getSubgraph(const std::shared_ptr<Model>& f);
static std::shared_ptr<ov::snippets::op::Subgraph>
getLoweredSubgraph(const std::shared_ptr<Model>& f,
@ -57,6 +56,8 @@ protected:
const ov::snippets::lowered::pass::PassPipeline& lowered_post_common = {},
const std::shared_ptr<ov::snippets::Generator>& generator = nullptr);
static std::shared_ptr<ov::snippets::op::Subgraph> getTokenizedSubgraph(const std::shared_ptr<Model>& f);
protected:
ov::PartialShape master_shape{};
};

View File

@ -0,0 +1,98 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "common_test_utils/common_utils.hpp"
#include "snippets/lowered/pass/optimize_domain.hpp"
#include "snippets/lowered/pass/pass.hpp"
#include "lowered/pass/optimize_domain.hpp"
#include "subgraph_simple.hpp"
#include "lowering_utils.hpp"
namespace ov {
namespace test {
namespace snippets {
OptimizeDomainParams::OptimizeDomainParams(size_t min_jit_work_amount,
size_t min_parallel_work_amount,
std::vector<ov::PartialShape> input_shapes,
ov::snippets::VectorDims exp_master_shape,
size_t exp_loop_depth) :
min_jit_work_amount(min_jit_work_amount),
min_parallel_work_amount(min_parallel_work_amount),
input_shapes(std::move(input_shapes)),
exp_master_shape(std::move(exp_master_shape)),
exp_loop_depth(exp_loop_depth) {
}
std::string OptimizeDomainTest::getTestCaseName(testing::TestParamInfo<OptimizeDomainParams> obj) {
OptimizeDomainParams domain_opt_params = obj.param;
std::ostringstream result;
result << "MinJitWork=" << domain_opt_params.min_jit_work_amount << "_";
result << "MinParWork=" << domain_opt_params.min_parallel_work_amount << "_";
for (size_t i = 0; i < domain_opt_params.input_shapes.size(); i++)
result << "IS[" << i << "]=" << ov::test::utils::partialShape2str({domain_opt_params.input_shapes[i]}) << "_";
result << "ExpMS=" << ov::test::utils::vec2str(domain_opt_params.exp_master_shape) << "_";
result << "ExpLD=" << domain_opt_params.exp_loop_depth << "_";
return result.str();
}
void OptimizeDomainTest::SetUp() {
m_domain_opt_params = this->GetParam();
m_model = std::make_shared<EltwiseFunction>(m_domain_opt_params.input_shapes)->getOriginal();
}
TEST_P(OptimizeDomainTest, DomainOptimization) {
auto subgraph = LoweringTests::getTokenizedSubgraph(m_model);
subgraph->set_min_jit_work_amount(m_domain_opt_params.min_jit_work_amount);
subgraph->set_min_parallel_work_amount(m_domain_opt_params.min_parallel_work_amount);
auto linear_ir = *subgraph->convert_body_to_linear_ir();
size_t loop_depth = 1;
ov::snippets::lowered::pass::OptimizeDomain(loop_depth).run(linear_ir);
const auto& master_shape = linear_ir.get_master_shape();
EXPECT_EQ(loop_depth, m_domain_opt_params.exp_loop_depth) << "Inconsistent loop depth detected";
EXPECT_THAT(master_shape, testing::ContainerEq(m_domain_opt_params.exp_master_shape)) << "Inconsistent master_shape detected";
}
namespace OptimizeDomainTestsInstantiation {
std::vector<OptimizeDomainParams> dopt_params = {
// No broadcasting => dimensions collapsed
{256, 4, {{14, 15, 1, 17}, {14, 15, 1, 17}}, {1, 1, 14, 255}, 1},
{256, 4, {{14, 15, 16, 1}, {14, 15, 16, 1}}, {1, 1, 14, 240}, 1},
// Same dimensions, but larger num threads => collapsing omitted
{256, 18, {{14, 15, 1, 17}, {14, 15, 1, 17}}, {1, 14, 15, 17}, 1},
{256, 18, {{14, 15, 16, 1}, {14, 15, 16, 1}}, {1, 14, 15, 16}, 1},
// No broadcasting => collapsing and loop_depth increment
{256, 4, {{14, 15, 16, 17}, {14, 15, 16, 17}}, {1, 14, 15, 272}, 2},
// Same dimensions, but smaller jit work amount => collapsing omitted
{16, 4, {{14, 15, 16, 17}, {14, 15, 16, 17}}, {14, 15, 16, 17}, 2},
// Same dimensions, but higher parallel work amount => collapsing but no loop_depth increment
{256, 18, {{14, 15, 16, 17}, {14, 15, 16, 17}}, {1, 14, 15, 272}, 1},
// Broadcasting breaks dimension collapsing => loop depth incremented
{256, 4, {{14, 15, 16, 1}, {14, 15, 1, 17}}, {14, 15, 16, 17}, 2},
{256, 4, {{14, 15, 1, 17}, {14, 15, 16, 17}}, {14, 15, 16, 17}, 2},
// Collapse even if not enough work to cover min_jit_work_amount
{256, 18, {{4, 5, 6, 7}, {4, 5, 6, 7}}, {1, 4, 5, 42}, 1},
// Same dims, but higher parallel work amount => do not collapse to load all the threads
{256, 32, {{4, 5, 6, 7}, {4, 5, 6, 7}}, {4, 5, 6, 7}, 1},
// 2D and 1D shapes are too small, so no collapsing should be done in such cases
{256, 32, {{4, 5}, {4, 5}}, {4, 5}, 1},
{256, 32, {{5}, {5}}, {5}, 1},
// min_parallel_work_amount = 1 is a special case that would cause all dimensions to collapse (up to min_jit_work_amount of course)
{256, 1, {{4, 1, 6, 7}, {4, 1, 6, 7}}, {1, 1, 1, 168}, 1},
};
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_DomainOptimization, OptimizeDomainTest,
::testing::ValuesIn(dopt_params),
OptimizeDomainTest::getTestCaseName);
} // namespace OptimizeDomainTestsInstantiation
} // namespace snippets
} // namespace test
} // namespace ov

View File

@ -67,7 +67,6 @@ namespace CanonicalizationTestsInstantiation {
using ov::snippets::op::Subgraph;
std::vector<Shape> input_shapes;
Shape expected_output_shape;
Subgraph::BlockedShapeVector input_blocked_shapes;
using ov::Shape;
ov::element::Type_t prec = ov::element::f32;

View File

@ -109,6 +109,10 @@ KernelEmitter::KernelEmitter(jit_generator* h, cpu_isa_t isa, const ExpressionPt
IE_THROW() << "KernelEmitter invoked with op::Kernel that contains no compile_params";
body = kernel->region;
jcp = *reinterpret_cast<const jit_snippets_compile_args*>(kernel->compile_params);
master_shape = body.get_master_shape();
// Note: plugin can prepend master shape with 1 to facilitate parallel execution (usually up to 6D tensor)
// so we have to reproduce this behavior here
master_shape.insert(master_shape.begin(), jcp.parallel_executor_ndims - master_shape.size(), 1);
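// e.g. a master_shape of {14, 15, 272} with parallel_executor_ndims == 6 becomes {1, 1, 1, 14, 15, 272}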
const auto& io_exprs = body.get_IO_ops();
num_inputs = 0;
num_outputs = 0;
@ -217,7 +221,7 @@ void KernelEmitter::init_data_pointers(const Xbyak::Reg64& reg_indexes, const Xb
const std::vector<Xbyak::Reg64>& data_ptr_regs) const {
const auto num_params = num_inputs + num_outputs;
// Note that we don't need offset for the last dim, since it's handled directly by Tile emitter
const size_t offset_rank = jcp.master_shape.size() - 1;
const size_t offset_rank = master_shape.size() - 1;
std::vector<std::vector<size_t>> data_offsets(num_params, std::vector<size_t>{});
auto offset_calculation = [=](const std::vector<size_t>& shape, const std::vector<size_t>& layout, const size_t data_size) {
// Strides represent distance between consecutive elements of corresponding dimension.
@ -243,11 +247,8 @@ void KernelEmitter::init_data_pointers(const Xbyak::Reg64& reg_indexes, const Xb
strides = std::move(reordered_strides);
}
// the last stride is ignored, since the entire last dim is processed by kernel
// and no parallel_for data_ptr offsets can be applied in this case (cover tile_rank == 1)
// and no parallel_for data_ptr offsets can be applied in this case
strides.pop_back();
// if tile_rank > 1, then zero corresponding strides since no external offset can be applied
// for (auto j = 0; j < tile_rank - 1; j++)
// strides[strides.size() - 1 - j] = 0;
// actual offset size might be larger than the shape size due to 6D scheduling
strides.insert(strides.begin(), offset_rank - strides.size(), 0);
@ -260,7 +261,7 @@ void KernelEmitter::init_data_pointers(const Xbyak::Reg64& reg_indexes, const Xb
std::function<void(Reg64, const std::vector<size_t>&, Reg64)> init_ptr_with_offset;
init_ptr_with_offset = [&](Reg64 pointer, const std::vector<size_t>& offsets, Reg64 reg_tmp) {
for (size_t j = 0; j < offset_rank; j++) {
if (jcp.master_shape[j] != 1 && offsets[j] != 0) {
if (master_shape[j] != 1 && offsets[j] != 0) {
h->mov(reg_tmp, offsets[j]);
h->imul(reg_tmp, h->ptr[reg_indexes + j * sizeof(size_t)]);
h->add(pointer, reg_tmp);

View File

@ -35,8 +35,7 @@ struct jit_snippets_call_args {
};
struct jit_snippets_compile_args {
std::vector<size_t> master_shape{};
size_t tile_rank = 0;
size_t parallel_executor_ndims = 1;
};
///
/// \brief jit_container_emitter designed to wrap Emitters that contain other Emitters (for example, KernelEmitter)
@ -94,6 +93,7 @@ private:
jit_snippets_compile_args jcp;
std::vector<size_t> gp_regs_pool;
std::vector<size_t> master_shape;
size_t num_inputs;
size_t num_outputs;
size_t num_unique_buffers;

View File

@ -18,6 +18,7 @@
#include <ie_ngraph_utils.hpp>
#include <snippets/op/subgraph.hpp>
#include <snippets/lowered/pass/optimize_domain.hpp>
#include "snippets/pass/matmul_to_brgemm.hpp"
#include "utils/cpu_utils.hpp"
#include "emitters/x64/cpu_generator.hpp"
@ -461,7 +462,7 @@ void Snippet::SnippetJitExecutor::update_ptrs(jit_snippets_call_args& call_args,
}
void Snippet::SnippetJitExecutor::schedule_6d(const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs) {
const auto& dom = exec_domain;
const auto& dom = parallel_exec_domain;
// < N, C, H, W > < 1, 1, N, C*H*W>
parallel_for5d(dom[0], dom[1], dom[2], dom[3], dom[4],
[&](int64_t d0, int64_t d1, int64_t d2, int64_t d3, int64_t d4) {
@ -474,7 +475,7 @@ void Snippet::SnippetJitExecutor::schedule_6d(const std::vector<MemoryPtr>& inMe
}
void Snippet::SnippetJitExecutor::schedule_nt(const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs) {
const auto& work_size = exec_domain;
const auto& work_size = parallel_exec_domain;
parallel_nt(0, [&](const int ithr, const int nthr) {
jit_snippets_call_args call_args;
update_ptrs(call_args, inMemPtrs, outMemPtrs);
@ -552,65 +553,20 @@ Snippet::SnippetJitExecutor::SnippetJitExecutor(const SnippetAttrs& attrs, bool
if (canonicalShape.is_dynamic())
IE_THROW() << "Snippets: Canonicalization returned dynamic shape in static pipeline";
masterShape = canonicalShape.get_shape();
const auto &body = snippet_for_generation->body_ptr();
normInputShapes.clear();
for (const auto& p : body->get_parameters())
normInputShapes.emplace_back(p->get_output_shape(0));
normOutputShapes.clear();
for (const auto& r : body->get_results())
normOutputShapes.emplace_back(r->get_input_shape(0));
// prepare
masterShape = getNormalizedDimsBySize(masterShape, tensorRank);
std::vector<size_t> original_input_shape_ranks;
for (auto& pshape : normInputShapes) {
original_input_shape_ranks.push_back(pshape.size());
pshape = getNormalizedDimsBySize(pshape, tensorRank);
}
for (auto& pshape : normOutputShapes)
pshape = getNormalizedDimsBySize(pshape, tensorRank);
tileRank = 1;
bool dims_collapsed = false;
fullWorkAmount = std::accumulate(masterShape.begin(), masterShape.end(), 1, std::multiplies<size_t>());
if (snippet_for_generation->has_domain_sensitive_ops()) {
tileRank = 2;
} else {
dims_collapsed = optimizeExecDomain(normInputShapes, normOutputShapes, masterShape, tileRank);
}
exec_domain = masterShape;
std::vector<size_t> scheduler_work_amounts;
// rename schedulerWorkAmount to harnessWorkAmount?
harnessWorkAmount = fullWorkAmount;
const auto rank = exec_domain.size();
for (auto i = rank - tileRank; i < rank; i++) {
auto& dim = exec_domain[i];
harnessWorkAmount /= dim;
scheduler_work_amounts.push_back(dim);
dim = 1;
}
if (dims_collapsed) {
std::vector<ov::Shape> new_shapes;
for (size_t i = 0; i < normInputShapes.size(); i++) {
const auto norm_shape = normInputShapes[i];
size_t ndims_to_skip = norm_shape.size() - original_input_shape_ranks[i];
new_shapes.emplace_back(norm_shape.begin() + ndims_to_skip, norm_shape.end());
}
snippet_for_generation->reshape_body(new_shapes);
}
snippet_for_generation->set_master_shape(ov::PartialShape(masterShape));
snippet_for_generation->set_tile_rank(tileRank);
snippet_for_generation->set_min_parallel_work_amount(static_cast<size_t>(parallel_get_max_threads()));
// Note: minimal JIT work amount is a predefined value that describes the number of kernel iterations (work amount)
// needed to cover kernel call overhead. It is used for balancing between parallel and JIT work amounts in domain optimization.
snippet_for_generation->set_min_jit_work_amount(256);
// generate
jit_snippets_compile_args jcp;
jcp.master_shape = masterShape;
jcp.tile_rank = tileRank;
jcp.parallel_executor_ndims = tensorRank;
generate(&jcp);
buffer_scratchpad_size = snippet_for_generation->get_buffer_scratchpad_size();
buffer_scratchpad.resize(buffer_scratchpad_size * parallel_get_max_threads(), 0);
parallel_exec_domain = schedule.parallel_exec_domain;
harnessWorkAmount = std::accumulate(parallel_exec_domain.begin(), parallel_exec_domain.end(), 1, std::multiplies<size_t>());
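// i.e. the number of kernel invocations the parallel executor will schedule (the kernel-owned dims are already 1)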
parallel_exec_domain = getNormalizedDimsBySize(parallel_exec_domain, tensorRank);
}
ov::PartialShape Snippet::SnippetJitExecutor::canonicalizeBody(bool reshape) {
@ -628,74 +584,6 @@ ov::PartialShape Snippet::SnippetJitExecutor::canonicalizeBody(bool reshape) {
}
}
bool Snippet::SnippetJitExecutor::optimizeExecDomain(std::vector<VectorDims>& inputShapes, std::vector<VectorDims>& outputShapes,
VectorDims &domain, size_t& TileRank) const {
const size_t minimalConcurrency = parallel_get_max_threads();
const size_t minimalJitWorkAmount = 256;
const size_t ds = domain.size();
if ( ds <= 2 || // not enough dimensions to collapse
domain[ds-1] >= minimalJitWorkAmount || // There is enough work for 1D Tiles, no need to collapse
domain[ds-1] * domain[ds-2] >= fullWorkAmount / minimalConcurrency) // There won't be enough work for every thread (even one iter) if we collapse
return false;
auto findDimsToCollapse = [&]() {
auto collapseLastDims = [](VectorDims& dims, size_t dimsToCollapse) {
if (dimsToCollapse >= dims.size() - 1)
IE_THROW() << "Got invalid number of dims to collapse. Expected < " << dims.size() - 1 << " got " << dimsToCollapse;
for (int i = dims.size() - 2; i > static_cast<int>(dims.size() - dimsToCollapse - 2); i--) {
dims[dims.size() - 1] *= dims[i];
}
for (int i = dims.size() - 2; i >= static_cast<int>(dimsToCollapse); i--) {
dims[i] = dims[i - dimsToCollapse];
}
for (int i = dimsToCollapse - 1; i >= 0; i--) {
dims[i] = 1;
}
};
int collapsedDims = 0;
size_t currentJitWorkAmount = domain[domain.size() - 1];
while (currentJitWorkAmount < minimalJitWorkAmount && currentJitWorkAmount < fullWorkAmount) {
if (static_cast<int>(domain.size()) - collapsedDims - 2 < 0)
break;
bool canCollapse = true;
for (size_t i = 0; i < inputShapes.size(); i++) {
const size_t last = inputShapes[i].size() - 1;
if ((inputShapes[i][last - 1] != 1 && inputShapes[i][last] == 1) ||
(inputShapes[i][last - 1] == 1 && inputShapes[i][last] != 1)) {
canCollapse = false;
break;
}
}
size_t nextJitWorkAmount = currentJitWorkAmount * domain[domain.size() - 2];
if (fullWorkAmount / nextJitWorkAmount >= minimalConcurrency) {
currentJitWorkAmount = nextJitWorkAmount;
// if we cannot use dim collapsing we should use tile2D
if (!canCollapse) {
if (TileRank < maxTileRank) {
TileRank++;
continue;
}
break;
}
collapsedDims++;
for (auto &d : inputShapes)
collapseLastDims(d, 1);
for (auto &d : outputShapes)
collapseLastDims(d, 1);
collapseLastDims(domain, 1);
} else {
break;
}
}
return collapsedDims > 0;
};
return findDimsToCollapse();
}
void Snippet::SnippetJitExecutor::generate(const jit_snippets_compile_args* jcp) {
using Manager = snippets::pass::Manager;
std::vector<Manager::PositionedPass> backend_passes;
@ -726,16 +614,16 @@ void Snippet::SnippetJitExecutor::generate(const jit_snippets_compile_args* jcp)
#undef SNIPPETS_REGISTER_PASS
ov::snippets::lowered::pass::PassPipeline control_flow_markup_pipeline;
CPU_REGISTER_PASS_X64(control_flow_markup_pipeline, ov::intel_cpu::pass::BrgemmBlocking);
CPU_REGISTER_PASS_X64(control_flow_markup_pipeline, ov::intel_cpu::pass::BrgemmBlocking)
ov::snippets::lowered::pass::PassPipeline control_flow_pipeline;
CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::FuseLoadStoreConvert);
// Todo: We don't need shape infer factory now, since shape infer will be done through validate_and_infer_types
// pass std::make_shared<snippets::CPUShapeInferSnippetsFactory>() instead of nullptr, when shape infer is performed on LIR
CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::FuseLoadStoreConvert)
// Note: we need to pass valid shapeInfer factory to generate, so it can be used in OptimizeDomain pass
// in all other cases nGraph shape inference will be used until ticket # 113209 (PR 18563) is merged
schedule = snippet_for_generation->generate(backend_passes,
control_flow_markup_pipeline,
control_flow_pipeline,
nullptr,
std::make_shared<snippets::CPUShapeInferSnippetsFactory>(),
reinterpret_cast<const void*>(jcp));
}

View File

@ -114,8 +114,6 @@ private:
size_t numOutput = 0;
ov::PartialShape canonicalizeBody(bool reshape);
// returns true if exec domain was modified
bool optimizeExecDomain(std::vector<VectorDims>&, std::vector<VectorDims>&, VectorDims&, size_t&) const;
void generate(const jit_snippets_compile_args*);
inline void update_ptrs(jit_snippets_call_args&, const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs);
@ -130,22 +128,14 @@ private:
// Holds index of output used as in execution domain
// it should be compatible with a schedule's work size
std::vector<size_t> exec_domain = {};
std::vector<size_t> parallel_exec_domain = {};
/// scheduling info
size_t tensorRank = 0;
size_t tileRank = 1;
size_t fullWorkAmount = 0;
size_t harnessWorkAmount = 0;
const size_t maxTileRank = 2;
std::vector<size_t> dataSize = {};
// master shape is mutable since we need to modify it inside const shapeInfer method
mutable VectorDims masterShape = {};
mutable std::vector<VectorDims> normInputShapes = {};
mutable std::vector<VectorDims> normOutputShapes = {};
std::vector<ptrdiff_t> start_offset_in = {};
std::vector<ptrdiff_t> start_offset_out = {};