[Snippets] Added support of BF16/I8/U8 for MatMul (#15063)

parent 253e4eb366
commit 38c924a3ae
@@ -26,7 +26,7 @@ ie_faster_build(${TARGET_NAME}
 )
 
 target_link_libraries(${TARGET_NAME} PUBLIC openvino::runtime
-                                     PRIVATE ngraph_reference ov_shape_inference openvino::runtime::dev)
+                                     PRIVATE ngraph_reference openvino::runtime::dev)
 
 target_include_directories(${TARGET_NAME} PUBLIC $<BUILD_INTERFACE:${PUBLIC_HEADERS_DIR}>
                                           PRIVATE $<BUILD_INTERFACE:${SHAPE_INFER_INCLUDE_DIR}>)
@@ -43,7 +43,6 @@ public:
      */
     virtual size_t get_lanes() const = 0;
 
-
     /**
      * @brief called by generator to all the emitter for a target machine
      * @return a map by node's type info with callbacks to create an instance of emitter for corresponding operation type
@@ -155,7 +154,29 @@ public:
      */
     std::shared_ptr<const TargetMachine> get_target_machine() const;
 
+    /**
+     * @interface opRegType
+     * @brief Register type of operations
+     *        Note that currently there are 4 types of ops:
+     *        gpr->gpr: Parameter, Result, LoopBegin, LoopEnd, etc.
+     *        gpr->vec / vec->gpr: Load/LoadConvert, Store/StoreConvert, BroadcastLoad, etc.
+     *        vec->vec: all other "normal" operations that perform calculations on vector registers: Add, BroadcastMove, Power, etc.
+     */
+    enum opRegType {gpr2gpr, gpr2vec, vec2gpr, vec2vec};
+    /**
+     * @brief gets register type by op type
+     *        TODO: Should be a static attribute of emitters
+     * @return register type
+     */
+    opRegType get_op_reg_type(const std::shared_ptr<Node>& op) const;
+
 protected:
+    /**
+     * @brief gets register type for a plugin-specific op type
+     * @return register type
+     */
+    virtual opRegType get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const;
+
     std::shared_ptr<TargetMachine> target;
     // todo: we need to save lowered code to access compiled brgemm kernels on execution time (normally lowered is destructed by then).
     // This is temporary solution, remove this when kernel caching is implemented. Don't forget to make generate const method.
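A plugin generator extends this classification for its own opset by overriding get_specific_op_reg_type; the base implementation throws for unknown ops. A minimal sketch, assuming a hypothetical MyPluginGenerator with an illustrative MyFusedOp that computes purely on vector registers:

    Generator::opRegType MyPluginGenerator::get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const {
        if (std::dynamic_pointer_cast<MyFusedOp>(op))     // hypothetical plugin-specific op
            return vec2vec;                               // consumes and produces vector registers
        return Generator::get_specific_op_reg_type(op);   // base throws for undetermined ops
    }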
@@ -5,7 +5,7 @@
 #pragma once
 
 #include "ngraph/op/op.hpp"
-#include "ngraph/op/matmul.hpp"
+#include "memory_access.hpp"
 
 namespace ngraph {
 namespace snippets {
@@ -16,30 +16,25 @@ namespace op {
  * @brief Brgemm is a batch-reduced matrix multiplication with the support of arbitrary strides between matrices rows
  * @ingroup snippets
  */
-class Brgemm : public ngraph::op::v0::MatMul {
+class Brgemm : public MemoryAccess {
 public:
-    OPENVINO_OP("Brgemm", "SnippetsOpset", ngraph::op::v0::MatMul);
-    Brgemm(const Output<Node>& A, const Output<Node>& B, const size_t offset_a = 0lu, const size_t offset_b = 0lu, const size_t offset_c = 0lu);
+    OPENVINO_OP("Brgemm", "SnippetsOpset", MemoryAccess);
+    Brgemm(const Output<Node>& A, const Output<Node>& B,
+           const size_t offset_a = 0lu, const size_t offset_b = 0lu, const size_t offset_c = 0lu);
     Brgemm() = default;
 
-    bool visit_attributes(AttributeVisitor& visitor) override;
+    size_t get_offset_a() const { return get_input_offset(0); }
+    size_t get_offset_b() const { return get_input_offset(1); }
+    size_t get_offset_c() const { return get_output_offset(0); }
 
     void validate_and_infer_types() override;
     std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
 
     bool has_evaluate() const override { return false; }
 
-    size_t get_offset_a() const { return m_offset_a; }
-    size_t get_offset_b() const { return m_offset_b; }
-    size_t get_offset_c() const { return m_offset_c; }
-
-    void set_offset_a(const size_t offset) { m_offset_a = offset; }
-    void set_offset_b(const size_t offset) { m_offset_b = offset; }
-    void set_offset_c(const size_t offset) { m_offset_c = offset; }
-
-private:
-    size_t m_offset_a = 0lu;  // offset for first input
-    size_t m_offset_b = 0lu;  // offset for second input
-    size_t m_offset_c = 0lu;  // offset for output
+protected:
+    ov::element::Type get_output_type() const;
+    ov::PartialShape get_output_partial_shape(const std::vector<ov::PartialShape>& input_shapes) const;
 };
 
 } // namespace op
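For orientation, a minimal usage sketch of the reworked interface (shapes and offsets are arbitrary; nothing here is part of the commit itself):

    auto A = std::make_shared<ngraph::opset1::Parameter>(ov::element::u8, ov::Shape{1, 1, 32, 16});
    auto B = std::make_shared<ngraph::opset1::Parameter>(ov::element::i8, ov::Shape{1, 1, 16, 64});
    auto brgemm = std::make_shared<ngraph::snippets::op::Brgemm>(A, B);
    // Offsets now live in the MemoryAccess port descriptors rather than in member fields:
    size_t offset_a = brgemm->get_offset_a();  // == get_input_offset(0), defaults to 0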
@@ -4,7 +4,7 @@
 
 #pragma once
 
-#include <snippets/op/broadcastmove.hpp>
+#include <snippets/op/memory_access.hpp>
 
 #include "ngraph/op/op.hpp"
 
@@ -17,22 +17,21 @@ namespace op {
  * @brief Is generated for broadcasting by least varying dimension for non-blocked cases and the second varying dimension for blocked
  * @ingroup snippets
  */
-class BroadcastLoad : public BroadcastMove {
+class BroadcastLoad : public MemoryAccess {
 public:
-    OPENVINO_OP("BroadcastLoad", "SnippetsOpset", ngraph::snippets::op::BroadcastMove);
+    OPENVINO_OP("BroadcastLoad", "SnippetsOpset", ngraph::snippets::op::MemoryAccess);
 
     BroadcastLoad(const Output<Node>& x, ov::PartialShape output_shape, size_t offset = 0lu);
     BroadcastLoad() = default;
 
-    size_t get_offset() const { return m_offset; }
-    void set_offset(const size_t offset) { m_offset = offset; }
+    size_t get_offset() const { return get_input_offset(0); }
 
     bool visit_attributes(AttributeVisitor& visitor) override;
     std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
     void validate_and_infer_types() override;
 
 private:
-    size_t m_offset = 0lu;
     ov::PartialShape output_shape;
 };
 
 } // namespace op
@@ -12,10 +12,9 @@ namespace op {
 /**
  * @interface Buffer
- * @brief The operation is for intermediate data storage
- *        - m_allocation_rank - rank of shape for memory allocation: shape[shape_rank - normalize(m_allocation_rank) : shape_rank].
- *          It's needed to allocate needed memory size that depends on Tile rank, for example.
- *          Default value is -1 (full shape)
+ * @brief This is a base class for memory storage.
+ *        If Buffer has a parent, the operation is for intermediate data storage - IntermediateMemory type.
+ *        Otherwise, the operation is for allocation of new empty memory with shape `m_shape` - NewMemory type
  *        Notes:
  *        - All buffers in a graph have the same memory pointer. So if we have a few buffers,
  *          each the corresponding MemoryAccess op for Buffer should have offset for common memory pointer of this Buffer
@@ -25,21 +24,30 @@ namespace op {
 class Buffer : public ngraph::op::Op {
 public:
     OPENVINO_OP("Buffer", "SnippetsOpset");
 
-    Buffer(const Output<Node>& x, const int32_t allocation_rank = -1);
+    Buffer(const ov::Shape& shape);
+    Buffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape);
+    Buffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank = -1);
     Buffer() = default;
 
-    int32_t get_allocation_rank() const { return m_allocation_rank; }
-    void set_allocation_rank(int32_t rank) { m_allocation_rank = rank; }
-
-    size_t get_byte_size() const;
-
     bool visit_attributes(AttributeVisitor& visitor) override;
     void validate_and_infer_types() override;
     std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
 
+    enum Type {
+        NewMemory,
+        IntermediateMemory
+    };
+
+    Type get_type() const { return m_type; }
+    ov::Shape get_allocation_shape() const { return m_shape; }
+    size_t get_byte_size() const;
+
+    bool is_intermediate_memory() const { return m_type == Type::IntermediateMemory; }
+    bool is_new_memory() const { return m_type == Type::NewMemory; }
+
 private:
-    int32_t m_allocation_rank = -1;
+    Type m_type = Type::IntermediateMemory;
+    ov::Shape m_shape = {};
 };
 
 } // namespace op
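A sketch of the two construction modes, under the definitions above (shapes are illustrative):

    // NewMemory: no parent; allocates empty u8 memory of the given shape
    auto scratch = std::make_shared<ngraph::snippets::op::Buffer>(ov::Shape{64, 64});
    assert(scratch->is_new_memory());

    // IntermediateMemory: stores a parent op's result; type and shape follow the input.
    // allocation_rank = 2 keeps the two innermost dimensions of {2, 3, 16, 64}:
    auto param = std::make_shared<ngraph::opset1::Parameter>(ov::element::f32, ov::Shape{2, 3, 16, 64});
    auto inter = std::make_shared<ngraph::snippets::op::Buffer>(param->output(0), 2);
    assert(inter->get_allocation_shape() == ov::Shape({16, 64}));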
@@ -20,11 +20,18 @@ namespace op {
  */
 class Load : public MemoryAccess {
 public:
-    OPENVINO_OP("Load", "SnippetsOpset");
+    OPENVINO_OP("Load", "SnippetsOpset", MemoryAccess);
 
     Load(const Output<Node>& x, const size_t count = 1lu, const size_t offset = 0lu);
     Load() = default;
 
+    size_t get_offset() const { return get_input_offset(0); }
+    size_t get_count() const { return get_input_count(0); }
+
+    void set_offset(size_t offset) { set_input_offset(offset, 0); }
+    void set_count(size_t count) { set_input_count(count, 0); }
+
+    void validate_and_infer_types() override;
     std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
 };
 
@@ -41,6 +48,9 @@ public:
     LoadReshape(const Output<Node>& x, size_t count = 1lu, const size_t offset = 0lu, std::vector<size_t> order = {});
     LoadReshape() = default;
 
+    void set_offset(size_t offset) { set_output_offset(offset, 0); }
+    void set_count(size_t count) { set_output_count(count, 0); }
+
     bool visit_attributes(AttributeVisitor& visitor) override;
     std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
     void validate_and_infer_types() override;
@@ -1,4 +1,4 @@
-// Copyright (C) 2018-2022 Intel Corporation
+// Copyright (C) 2018-2023 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
@@ -13,9 +13,9 @@ namespace op {
 /**
  * @interface MemoryAccess
  * @brief This is a base class for memory access operations (like Load and Store).
- *        It provides universal set/get interface to manipulate the number
- *        of elements accessed during one operation call ("count").
- *        Default "count" value is "1" - it means to load/store one element
+ *        It provides a universal interface for memory manipulation: load/store.
+ * @param m_input_ports - vector of input descriptors: variables of PortDescriptor class
+ * @param m_output_ports - vector of output descriptors: variables of PortDescriptor class
  * @ingroup snippets
  */
 
@@ -23,18 +23,54 @@ class MemoryAccess : public ngraph::op::Op {
 public:
     OPENVINO_OP("MemoryAccess", "SnippetsOpset");
 
-    size_t get_count() const;
-    size_t get_offset() const;
-    void set_count(const size_t count);
-    void set_offset(const size_t offset);
+    /**
+     * @interface PortDescriptor
+     * @brief This class describes port of MemoryAccess operation
+     * @param m_count - count of elements to load/store
+     * @param m_offset - starting index of elements to load/store
+     * @param m_index - port index
+     * @ingroup snippets
+     */
+    struct PortDescriptor {
+        PortDescriptor(size_t count, size_t offset) : count(count), offset(offset) {}
+        PortDescriptor() = default;
+
+        size_t count = 0lu;
+        size_t offset = 0lu;
+        size_t index = 0lu;
+
+    private:
+        PortDescriptor(size_t count, size_t offset, size_t index) : count(count), offset(offset), index(index) {}
+
+        friend class MemoryAccess;
+    };
+
+    void set_input_count(size_t count, size_t idx = 0);
+    void set_output_count(size_t count, size_t idx = 0);
+    void set_input_offset(size_t offset, size_t idx = 0);
+    void set_output_offset(size_t offset, size_t idx = 0);
+
+    size_t get_input_count(size_t idx = 0) const;
+    size_t get_output_count(size_t idx = 0) const;
+    size_t get_input_offset(size_t idx = 0) const;
+    size_t get_output_offset(size_t idx = 0) const;
+
+    size_t get_input_port_count() const { return m_input_ports.size(); }
+    size_t get_output_port_count() const { return m_output_ports.size(); }
 
     bool visit_attributes(AttributeVisitor& visitor) override;
-    void validate_and_infer_types() override;
 
 protected:
-    explicit MemoryAccess(const Output<Node>& x, size_t count = 1lu, size_t offset = 0lu);
+    explicit MemoryAccess(const OutputVector& arguments, size_t input_count = 0, size_t output_count = 0);
     MemoryAccess() = default;
-    size_t m_count = 0lu;
-    size_t m_offset = 0lu;
+
+    void set_input_port_descriptor(const PortDescriptor& desc, const size_t i);
+    void set_output_port_descriptor(const PortDescriptor& desc, const size_t i);
+    const PortDescriptor& get_input_port_descriptor(const size_t i) const;
+    const PortDescriptor& get_output_port_descriptor(const size_t i) const;
+
+    std::vector<PortDescriptor> m_input_ports;
+    std::vector<PortDescriptor> m_output_ports;
 };
 
 } // namespace op
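A short sketch of how derived ops are expected to use the per-port interface (Load shown; the count/offset values are arbitrary):

    auto param = std::make_shared<ngraph::opset1::Parameter>(ov::element::f32, ov::Shape{1, 16});
    auto load = std::make_shared<ngraph::snippets::op::Load>(param, /*count*/ 8, /*offset*/ 0);
    load->set_input_count(4);                 // e.g. clip the chunk for a loop tail; idx defaults to 0
    size_t count = load->get_input_count();   // == 4, read back from input port descriptor 0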
@@ -20,11 +20,18 @@ namespace op {
  */
 class Store : public MemoryAccess {
 public:
-    OPENVINO_OP("Store", "SnippetsOpset");
+    OPENVINO_OP("Store", "SnippetsOpset", MemoryAccess);
 
     Store(const Output<Node>& x, const size_t count = 1lu, const size_t offset = 0lu);
     Store() = default;
 
+    size_t get_offset() const { return get_output_offset(0); }
+    size_t get_count() const { return get_output_count(0); }
+
+    void set_offset(size_t offset) { set_output_offset(offset, 0); }
+    void set_count(size_t count) { set_output_count(count, 0); }
+
+    void validate_and_infer_types() override;
     std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
 };
 
@@ -6,6 +6,8 @@
 
 #include <ngraph/pass/pass.hpp>
 
+#include "snippets/generator.hpp"
+
 namespace ngraph {
 namespace snippets {
 namespace pass {
@@ -18,10 +20,13 @@ namespace pass {
  */
 class AssignRegisters : public ngraph::pass::FunctionPass {
 public:
-    explicit AssignRegisters() {
+    explicit AssignRegisters(const std::function<Generator::opRegType(const std::shared_ptr<Node>& op)>& mapper) : m_reg_type_mapper(mapper) {
         set_property(ngraph::pass::PassProperty::REQUIRE_STATIC_SHAPE, true);
     }
     bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
+
+private:
+    std::function<Generator::opRegType(const std::shared_ptr<Node>& op)> m_reg_type_mapper;
 };
 
 } // namespace pass
@@ -29,10 +29,31 @@ ov::PartialShape get_port_planar_shape(const Output<Node>& out);
 ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const std::vector<size_t>& layout);
 std::vector<size_t> get_node_output_layout(const std::shared_ptr<Node>& node);
 std::vector<size_t> get_node_output_layout(const Node* node);
+void set_transpose_output_layout(const ov::Output<Node>& port, const std::shared_ptr<opset1::Transpose>& node);
+void set_output_layout(const ov::Output<Node>& port, const std::vector<size_t>& layout);
 
 inline ov::Dimension get_inner_dim(const ov::PartialShape &shape) { return *(shape.rbegin()); }
 inline ov::Dimension get_outer_dim(const ov::PartialShape &shape) { return *(shape.rbegin() + 1); }
 
+inline auto normalize_rank(int32_t allocation_rank, const size_t shape_rank) -> int32_t {
+    return allocation_rank < 0 ? allocation_rank + static_cast<int32_t>(shape_rank) + 1 : allocation_rank;
+}
+
+template <typename T, typename P>
+constexpr bool one_of(T val, P item) { return val == item; }
+
+template <typename T, typename P, typename... Args>
+constexpr bool one_of(T val, P item, Args... item_others) {
+    return val == item || one_of(val, item_others...);
+}
+
+template <typename T, typename P>
+constexpr bool everyone_is(T val, P item) { return val == item; }
+
+template <typename T, typename P, typename... Args>
+constexpr bool everyone_is(T val, P item, Args... item_others) {
+    return val == item && everyone_is(val, item_others...);
+}
 } // namespace utils
 } // namespace snippets
 } // namespace ngraph
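Both helpers fold over the argument pack: one_of is an OR-chain of equality checks, everyone_is an AND-chain. Since they are constexpr, they also work at compile time:

    static_assert(ngraph::snippets::utils::one_of(2, 1, 2, 3), "2 matches one of {1, 2, 3}");
    // Typical runtime use, as in Brgemm::get_output_type() below:
    //   utils::everyone_is(element::bf16, element_type_a, element_type_b)  -> both inputs are bf16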
@@ -77,8 +77,15 @@ auto tail_transformations(NodeVector& tail, const size_t tail_size, const ngraph
             }
         }
     } else if (const auto memory_access = std::dynamic_pointer_cast<ngraph::snippets::op::MemoryAccess>(op)) {
-        if (memory_access->get_count() != 1) {
-            memory_access->set_count(tail_size);
+        for (size_t i = 0; i < memory_access->get_input_port_count(); ++i) {
+            if (memory_access->get_input_count(i) > 1) {
+                memory_access->set_input_count(tail_size, i);
+            }
+        }
+        for (size_t i = 0; i < memory_access->get_output_port_count(); ++i) {
+            if (memory_access->get_output_count(i) > 1) {
+                memory_access->set_output_count(tail_size, i);
+            }
         }
     }
     updated_tile.push_back(op);
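The tail loop now clips every memory-access port individually, so ports that deliberately handle a single element (count == 1, e.g. scalar accesses) are left untouched. A compressed sketch of the effect on one op, assuming a vector length of 8 and a tail of 3:

    auto param = std::make_shared<ngraph::opset1::Parameter>(ov::element::f32, ov::Shape{1, 16});
    auto load = std::make_shared<ngraph::snippets::op::Load>(param, /*count*/ 8);
    const size_t tail_size = 3;
    for (size_t i = 0; i < load->get_input_port_count(); ++i)
        if (load->get_input_count(i) > 1)         // skip scalar (count == 1) ports
            load->set_input_count(tail_size, i);  // Load now handles 3 elements per call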
@@ -220,5 +227,41 @@ std::shared_ptr<const TargetMachine> Generator::get_target_machine() const {
     return target;
 }
 
+Generator::opRegType Generator::get_op_reg_type(const std::shared_ptr<Node>& op) const {
+    if (std::dynamic_pointer_cast<opset1::Parameter>(op) ||
+        std::dynamic_pointer_cast<opset1::Result>(op) ||
+        std::dynamic_pointer_cast<op::LoopBegin>(op) ||
+        std::dynamic_pointer_cast<op::LoopEnd>(op) ||
+        std::dynamic_pointer_cast<op::Brgemm>(op) ||
+        std::dynamic_pointer_cast<op::Buffer>(op))
+        return gpr2gpr;
+    else if (std::dynamic_pointer_cast<snippets::op::Load>(op) ||
+             std::dynamic_pointer_cast<snippets::op::BroadcastLoad>(op))
+        return gpr2vec;
+    else if (std::dynamic_pointer_cast<snippets::op::Store>(op))
+        return vec2gpr;
+    else if (ov::op::util::is_unary_elementwise_arithmetic(op) ||
+             ov::op::util::is_binary_elementwise_arithmetic(op) ||
+             ov::op::util::is_binary_elementwise_comparison(op) ||
+             ov::op::util::is_binary_elementwise_logical(op) ||
+             std::dynamic_pointer_cast<opset1::LogicalNot>(op) ||
+             std::dynamic_pointer_cast<opset1::PRelu>(op) ||
+             std::dynamic_pointer_cast<opset1::Convert>(op) ||
+             std::dynamic_pointer_cast<opset1::Select>(op) ||
+             std::dynamic_pointer_cast<op::VectorBuffer>(op) ||
+             std::dynamic_pointer_cast<op::BroadcastMove>(op) ||
+             std::dynamic_pointer_cast<op::Scalar>(op) ||
+             std::dynamic_pointer_cast<op::HorizonMax>(op) ||
+             std::dynamic_pointer_cast<op::HorizonSum>(op))
+        return vec2vec;
+    else
+        return get_specific_op_reg_type(op);
+}
+
+Generator::opRegType Generator::get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const {
+    throw ov::Exception("Register type of the operation " + std::string(op->get_type_name()) + " isn't determined!");
+}
+
 }// namespace snippets
 }// namespace ngraph
@@ -7,56 +7,123 @@
 #include "ngraph/runtime/host_tensor.hpp"
 #include "openvino/core/rt_info.hpp"
 #include "snippets/utils.hpp"
-#include "matmul_shape_inference.hpp"
 
 namespace ngraph {
 namespace snippets {
 namespace op {
 
-Brgemm::Brgemm(const Output<Node>& A, const Output<Node>& B, const size_t offset_a, const size_t offset_b, const size_t offset_c)
-    : MatMul(), m_offset_a(offset_a), m_offset_b(offset_b), m_offset_c(offset_c) {
-    set_arguments({A, B});
+Brgemm::Brgemm(const Output<Node>& A, const Output<Node>& B,
+               const size_t offset_a, const size_t offset_b, const size_t offset_c) : MemoryAccess({A, B}, 2, 1) {
+    set_output_size(1);
+    set_input_offset(offset_a, 0);
+    set_input_offset(offset_b, 1);
+    set_output_offset(offset_c, 0);
     constructor_validate_and_infer_types();
 }
 
-bool Brgemm::visit_attributes(AttributeVisitor& visitor) {
-    MatMul::visit_attributes(visitor);
-    visitor.on_attribute("offset_a", m_offset_a);
-    visitor.on_attribute("offset_b", m_offset_b);
-    visitor.on_attribute("offset_c", m_offset_c);
-    return true;
-}
-
 void Brgemm::validate_and_infer_types() {
     INTERNAL_OP_SCOPE(Brgemm_validate_and_infer_types);
-    element::Type result_et;
-    NODE_VALIDATION_CHECK(this,
-                          element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)),
-                          "Arguments do not have the same element type (arg0 element type: ",
-                          get_input_element_type(0),
-                          ", arg1 element type: ",
-                          get_input_element_type(1),
-                          ").");
     // If no leading dimensions are provided, assume dense row-major inputs-outputs
     NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(),
                           "Brgemm currently supports only static shapes.");
 
-    std::vector<ov::PartialShape> planar_input_shapes;
-    for (const auto& in : input_values())
-        planar_input_shapes.emplace_back(utils::get_port_planar_shape(in));
+    std::vector<ov::PartialShape> planar_input_shapes = {
+        utils::get_port_planar_shape(input_value(0)),
+        utils::get_port_planar_shape(input_value(1))
+    };
 
-    std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
-    ov::op::v0::shape_infer(this, planar_input_shapes, output_shapes);
+    auto output_shape = get_output_partial_shape(planar_input_shapes);
     const auto& output_layout = utils::get_node_output_layout(this);
-    output_shapes[0] = utils::get_reordered_planar_shape(output_shapes[0], output_layout);
-    set_output_type(0, result_et, output_shapes[0]);
+    set_output_type(0,
+                    get_output_type(),
+                    utils::get_reordered_planar_shape(output_shape, output_layout));
 }
 
 std::shared_ptr<Node> Brgemm::clone_with_new_inputs(const OutputVector& new_args) const {
     INTERNAL_OP_SCOPE(Brgemm_clone_with_new_inputs);
     check_new_args_count(this, new_args);
-    return std::make_shared<Brgemm>(new_args.at(0), new_args.at(1), m_offset_a, m_offset_b, m_offset_c);
+    return std::make_shared<Brgemm>(new_args.at(0), new_args.at(1), get_offset_a(), get_offset_b(), get_offset_c());
 }
 
+ov::element::Type Brgemm::get_output_type() const {
+    const auto element_type_a = get_input_element_type(0);
+    const auto element_type_b = get_input_element_type(1);
+    const bool is_f32 = utils::everyone_is(element::f32, element_type_a, element_type_b);
+    const bool is_int8 = utils::one_of(element_type_a, element::i8, element::u8) && element_type_b == element::i8;
+    const bool is_bf16 = utils::everyone_is(element::bf16, element_type_a, element_type_b);
+    if (is_f32 || is_bf16) {
+        return element::f32;
+    } else if (is_int8) {
+        return element::i32;
+    } else {
+        throw ngraph_error("Brgemm node has incompatible input element types: " +
+                           element_type_a.get_type_name() +
+                           " and " +
+                           element_type_b.get_type_name());
+    }
+}
+
+ov::PartialShape Brgemm::get_output_partial_shape(const std::vector<ov::PartialShape>& input_shapes) const {
+    NGRAPH_CHECK(input_shapes.size() == 2, "BRGEMM expects 2 input shapes for shape inference");
+
+    // Note: All major checks are omitted because Brgemm is transformed from MatMul with whole shape infer support
+
+    const auto arg0_shape = input_shapes[0];
+    const auto arg1_shape = input_shapes[1];
+
+    size_t arg0_rank = arg0_shape.size(), arg1_rank = arg1_shape.size();
+
+    // temporary shapes to calculate output shape
+    ov::PartialShape arg0_shape_tmp(arg0_shape), arg1_shape_tmp(arg1_shape);
+
+    // one-dimensional tensors unsqueezing is applied to each input independently.
+    if (arg0_rank == 1) {
+        // If the first input is 1D tensor, it is unsqueezed to 2D tensor (row vector)
+        // by adding axes with size 1 at ROW_INDEX_DIM, to the left of the shape.
+        // For example {S} will be reshaped to {1, S}.
+        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), 1);
+        arg0_rank = arg0_shape_tmp.size();
+    }
+    if (arg1_rank == 1) {
+        // If the second input is 1D tensor, it is unsqueezed to 2D tensor (column vector)
+        // by adding axes with size 1 at COL_INDEX_DIM, to the right of the shape.
+        // For example {S} will be reshaped to {S, 1}.
+        arg1_shape_tmp.insert(arg1_shape_tmp.end(), 1);
+        arg1_rank = arg1_shape_tmp.size();
+    }
+    // Check matrices dimensions compatibility
+    using DimType = typename std::iterator_traits<typename ov::PartialShape::iterator>::value_type;
+    auto merged_dimension = DimType();
+    auto arg0_col_dim = arg0_shape_tmp[arg0_rank - 1];
+    auto arg1_row_dim = arg1_shape_tmp[arg1_rank - 2];
+    OPENVINO_ASSERT(DimType::merge(merged_dimension, arg0_col_dim, arg1_row_dim) || arg0_col_dim.is_dynamic() || arg1_row_dim.is_dynamic(),
+                    "Incompatible Brgemm matrix dimension");
+
+    // add 1 to begin to align shape ranks if needed
+    if (arg0_rank < arg1_rank)
+        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), arg1_rank - arg0_rank, 1);
+    else if (arg0_rank > arg1_rank)
+        arg1_shape_tmp.insert(arg1_shape_tmp.begin(), arg0_rank - arg1_rank, 1);
+
+    size_t max_rank = arg0_shape_tmp.size();
+    std::vector<DimType> output_shape(max_rank);
+    for (size_t i = 0; i < max_rank - 2; ++i) {
+        OPENVINO_ASSERT(DimType::broadcast_merge(output_shape[i], arg0_shape_tmp[i], arg1_shape_tmp[i]) ||
+                        arg0_shape_tmp[i].is_dynamic() ||
+                        arg1_shape_tmp[i].is_dynamic(),
+                        "Incompatible Brgemm batch dimension");
+    }
+    output_shape[output_shape.size() - 2] = arg0_shape_tmp[arg0_shape_tmp.size() - 2];  // M
+    output_shape[output_shape.size() - 1] = arg1_shape_tmp[arg1_shape_tmp.size() - 1];  // N
+
+    // removing the temporary axes from originally 1D tensors.
+    if (arg0_shape.rank().get_length() == 1) {
+        output_shape.erase(output_shape.begin() + output_shape.size() - 2);
+    }
+    if (arg1_shape.rank().get_length() == 1) {
+        output_shape.erase(output_shape.begin() + output_shape.size() - 1);
+    }
+    return output_shape;
+}
+
 } // namespace op
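The promotion rules in get_output_type() match the precision pairs the brgemm kernels are expected to accept: f32 x f32 -> f32, bf16 x bf16 -> f32, and u8/i8 x i8 -> i32. A quick sketch of a check, assuming test-style parameters:

    auto A = std::make_shared<ngraph::opset1::Parameter>(ov::element::u8, ov::Shape{1, 1, 4, 8});
    auto B = std::make_shared<ngraph::opset1::Parameter>(ov::element::i8, ov::Shape{1, 1, 8, 4});
    auto gemm = std::make_shared<ngraph::snippets::op::Brgemm>(A, B);
    assert(gemm->get_output_element_type(0) == ov::element::i32);  // (u8, i8) accumulates in i32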
@@ -12,20 +12,20 @@ using namespace std;
 using namespace ngraph;
 
 snippets::op::BroadcastLoad::BroadcastLoad(const Output<Node>& x, ov::PartialShape shape, size_t offset)
-    : BroadcastMove(x, std::move(shape)), m_offset(offset) {
+    : MemoryAccess({x}, 1, 0), output_shape(std::move(shape)) {
+    set_input_port_descriptor({1, offset}, 0);
     constructor_validate_and_infer_types();
 }
 
 bool snippets::op::BroadcastLoad::visit_attributes(AttributeVisitor& visitor) {
-    BroadcastMove::visit_attributes(visitor);
-    visitor.on_attribute("offset", m_offset);
+    MemoryAccess::visit_attributes(visitor);
     return true;
 }
 
 std::shared_ptr<Node> snippets::op::BroadcastLoad::clone_with_new_inputs(const OutputVector& new_args) const {
     INTERNAL_OP_SCOPE(BroadcastLoad);
     check_new_args_count(this, new_args);
-    return std::make_shared<BroadcastLoad>(new_args.at(0), output_shape, m_offset);
+    return std::make_shared<BroadcastLoad>(new_args.at(0), output_shape, get_offset());
 }
 
 void snippets::op::BroadcastLoad::validate_and_infer_types() {
@@ -6,8 +6,8 @@
 
 #include "snippets/op/buffer.hpp"
-#include "snippets/snippets_isa.hpp"
+#include "snippets/utils.hpp"
 
 #include <ngraph/runtime/host_tensor.hpp>
 
 using namespace std;
 using namespace ngraph;
@@ -16,38 +16,64 @@ auto normalize_rank(int32_t allocation_rank, const size_t shape_rank) -> int32_t
-    return allocation_rank < 0 ? allocation_rank + static_cast<int32_t>(shape_rank) : allocation_rank;
-}
-
-snippets::op::Buffer::Buffer(const Output<Node>& x, const int32_t allocation_rank) : Op({x}), m_allocation_rank(allocation_rank) {
+snippets::op::Buffer::Buffer(const ov::Shape& shape)
+    : Op(), m_type(Type::NewMemory), m_shape(shape) {
     constructor_validate_and_infer_types();
 }
 
+snippets::op::Buffer::Buffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape)
+    : Op({arg}), m_type(Type::IntermediateMemory), m_shape(shape) {
+    constructor_validate_and_infer_types();
+}
+
+snippets::op::Buffer::Buffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank)
+    : Op({arg}), m_type(Type::IntermediateMemory) {
+    const auto pshape = arg.get_partial_shape();
+    OPENVINO_ASSERT(pshape.is_static(), "Buffer supports only static input shape");
+    const auto shape = pshape.get_shape();
+    const auto normalize_rank = utils::normalize_rank(static_cast<int32_t>(allocation_rank), shape.size());
+    const auto offset = static_cast<int32_t>(shape.size()) - normalize_rank;
+    m_shape = {shape.begin() + offset, shape.end()};
+    constructor_validate_and_infer_types();
+}
+
 bool snippets::op::Buffer::visit_attributes(AttributeVisitor& visitor) {
     INTERNAL_OP_SCOPE(Buffer_visit_attributes);
-    visitor.on_attribute("allocation_rank", m_allocation_rank);
+    visitor.on_attribute("allocation_shape", m_shape);
     return true;
 }
 
+void snippets::op::Buffer::validate_and_infer_types() {
+    INTERNAL_OP_SCOPE(Buffer_validate_and_infer_types);
+    ov::element::Type output_type;
+    ov::Shape output_shape;
+    if (m_type == Type::NewMemory) {
+        OPENVINO_ASSERT(get_input_size() == 0, "Buffer with new allocated memory must not have arguments!");
+        output_shape = m_shape;
+        output_type = ov::element::u8;  // 1 byte
+    } else if (m_type == Type::IntermediateMemory) {
+        const auto input_shape = get_input_partial_shape(0);
+        OPENVINO_ASSERT(input_shape.is_static(), "Buffer supports only static input shape");
+        output_type = get_input_element_type(0);
+        output_shape = input_shape.get_shape();
+    } else {
+        throw ov::Exception("Buffer supports only the following types: NewMemory and IntermediateMemory");
+    }
+    set_output_type(0, output_type, output_shape);
+}
+
 std::shared_ptr<Node> snippets::op::Buffer::clone_with_new_inputs(const OutputVector& new_args) const {
     INTERNAL_OP_SCOPE(Buffer_clone_with_new_inputs);
     check_new_args_count(this, new_args);
-    auto new_buffer = std::make_shared<Buffer>(new_args.at(0), m_allocation_rank);
-    return new_buffer;
-}
-
-void snippets::op::Buffer::validate_and_infer_types() {
-    INTERNAL_OP_SCOPE(Buffer_validate_and_infer_types);
-    const auto shape_rank = get_input_partial_shape(0).rank();
-    if (shape_rank.is_static()) {
-        const auto normalized_rank = normalize_rank(m_allocation_rank, shape_rank.get_length());
-        NGRAPH_CHECK(normalized_rank >= 0 && normalized_rank <= shape_rank.get_length(),
-                     "Buffer has incorrect allocation rank: " + std::to_string(m_allocation_rank));
+    if (m_type == Type::NewMemory) {
+        return std::make_shared<Buffer>(m_shape);
+    } else if (m_type == Type::IntermediateMemory) {
+        return std::make_shared<Buffer>(new_args.at(0), m_shape);
     }
-    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
+    throw ov::Exception("Buffer supports only the following types: NewMemory and IntermediateMemory");
 }
 
 size_t ngraph::snippets::op::Buffer::get_byte_size() const {
-    const auto pshape = get_input_partial_shape(0);
-    NGRAPH_CHECK(pshape.is_static(), "Buffer should have static shapes for memory allocation");
-    const auto shape = pshape.get_shape();
-    const auto normalized_rank = normalize_rank(m_allocation_rank, shape.size());
-    return ngraph::shape_size(shape.rbegin(), shape.rbegin() + normalized_rank) * get_element_type().size();
+    const auto shape = get_allocation_shape();
+    return ngraph::shape_size(shape) * get_element_type().size();
 }
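A worked example of the new get_byte_size() path under these definitions: for a rank-2 allocation, the byte size is simply the product of the allocation shape times the element size:

    auto param = std::make_shared<ngraph::opset1::Parameter>(ov::element::f32, ov::Shape{2, 3, 16, 64});
    auto buffer = std::make_shared<ngraph::snippets::op::Buffer>(param->output(0), 2);
    // utils::normalize_rank(2, 4) == 2, so the allocation shape is {16, 64};
    // get_byte_size() == shape_size({16, 64}) * sizeof(float) == 1024 * 4 == 4096 bytes
    assert(buffer->get_byte_size() == 16 * 64 * 4);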
@@ -12,17 +12,24 @@ namespace ngraph {
 namespace snippets {
 namespace op {
 
-Load::Load(const Output<Node>& x, const size_t count, const size_t offset) : MemoryAccess({x}, count, offset) {
+Load::Load(const Output<Node>& x, const size_t count, const size_t offset) : MemoryAccess({x}, 1, 0) {
+    set_input_port_descriptor({count, offset}, 0);
     constructor_validate_and_infer_types();
 }
 
+void snippets::op::Load::validate_and_infer_types() {
+    // Load has memory access port only on input
+    OPENVINO_ASSERT(get_input_port_count() == 1, "Load node must have memory access input port");
+    OPENVINO_ASSERT(get_output_port_count() == 0, "Load node mustn't have memory access output port");
+    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
+}
+
 std::shared_ptr<Node> Load::clone_with_new_inputs(const OutputVector& new_args) const {
     INTERNAL_OP_SCOPE(Load);
     check_new_args_count(this, new_args);
-    return std::make_shared<Load>(new_args.at(0), m_count, m_offset);
+    return std::make_shared<Load>(new_args.at(0), get_count(), get_offset());
 }
 
 
 LoadReshape::LoadReshape(const Output<ov::Node>& x, const size_t count, const size_t offset, std::vector<size_t> order)
         : Load(x, count, offset), m_order(std::move(order)) {
     const auto& in_shape = x.get_partial_shape();
@@ -33,6 +40,8 @@ LoadReshape::LoadReshape(const Output<ov::Node>& x, const size_t count, const si
                  *std::min_element(m_order.begin(), m_order.end()) == 0, "LoadReshape detected invalid values in new_order");
     const std::set<size_t> unique_dims(order.begin(), order.end());
     NGRAPH_CHECK(unique_dims.size() == order.size(), "LoadReshape order must not contain repeated elements");
+    m_input_ports.resize(get_input_size());
+    set_input_port_descriptor({count, offset}, 0);
     constructor_validate_and_infer_types();
 }
 
@@ -53,7 +62,7 @@ bool snippets::op::LoadReshape::visit_attributes(AttributeVisitor& visitor) {
 std::shared_ptr<Node> snippets::op::LoadReshape::clone_with_new_inputs(const OutputVector& new_args) const {
     INTERNAL_OP_SCOPE(LoadReshape);
     check_new_args_count(this, new_args);
-    return std::make_shared<LoadReshape>(new_args.at(0), m_count, m_offset, m_order);
+    return std::make_shared<LoadReshape>(new_args.at(0), get_count(), get_offset(), m_order);
 }
 
 }// namespace op
@@ -3,43 +3,80 @@
 //
 
 #include <snippets/itt.hpp>
 
 #include "snippets/op/memory_access.hpp"
 
 #include <ngraph/runtime/host_tensor.hpp>
 
 namespace ngraph {
 namespace snippets {
 namespace op {
 
-MemoryAccess::MemoryAccess(const Output<Node>& x, const size_t count, const size_t offset) : Op({x}), m_count(count), m_offset(offset) {}
+MemoryAccess::MemoryAccess(const OutputVector& arguments, size_t input_count, size_t output_count) : Op(arguments) {
+    while (m_input_ports.size() < input_count) {
+        m_input_ports.push_back({0, 0, m_input_ports.size()});
+    }
+    while (m_output_ports.size() < output_count) {
+        m_output_ports.push_back({0, 0, m_output_ports.size()});
+    }
+}
 
 bool MemoryAccess::visit_attributes(AttributeVisitor& visitor) {
-    visitor.on_attribute("count", m_count);
-    visitor.on_attribute("offset", m_offset);
+    for (size_t i = 0; i < m_input_ports.size(); ++i) {
+        auto& port = m_input_ports[i];
+        visitor.on_attribute("count_in_" + std::to_string(i), port.count);
+        visitor.on_attribute("offset_in_" + std::to_string(i), port.offset);
+    }
+    for (size_t i = 0; i < m_output_ports.size(); ++i) {
+        auto& port = m_output_ports[i];
+        visitor.on_attribute("count_out_" + std::to_string(i), port.count);
+        visitor.on_attribute("offset_out_" + std::to_string(i), port.offset);
+    }
     return true;
 }
 
-size_t MemoryAccess::get_count() const {
-    return m_count;
+void MemoryAccess::set_input_port_descriptor(const PortDescriptor& desc, const size_t i) {
+    NGRAPH_CHECK(i < m_input_ports.size(), "Index of input port descriptor should be less than count of input ports");
+    m_input_ports[i] = {desc.count, desc.offset, i};
 }
 
-size_t MemoryAccess::get_offset() const {
-    return m_offset;
+void MemoryAccess::set_output_port_descriptor(const PortDescriptor& desc, const size_t i) {
+    NGRAPH_CHECK(i < m_output_ports.size(), "Index of output port descriptor should be less than count of output ports");
+    m_output_ports[i] = {desc.count, desc.offset, i};
 }
 
-void MemoryAccess::set_count(const size_t count) {
-    m_count = count;
+const MemoryAccess::PortDescriptor& MemoryAccess::get_input_port_descriptor(const size_t i) const {
+    NGRAPH_CHECK(i < m_input_ports.size(), "Index of input port descriptor should be less than count of input ports");
+    return m_input_ports[i];
 }
 
-void MemoryAccess::set_offset(const size_t offset) {
-    m_offset = offset;
+const MemoryAccess::PortDescriptor& MemoryAccess::get_output_port_descriptor(const size_t i) const {
+    NGRAPH_CHECK(i < m_output_ports.size(), "Index of output port descriptor should be less than count of output ports");
+    return m_output_ports[i];
 }
 
-void MemoryAccess::validate_and_infer_types() {
-    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
+void MemoryAccess::set_input_count(size_t count, size_t idx) {
+    set_input_port_descriptor({count, get_input_port_descriptor(idx).offset, idx}, idx);
+}
+void MemoryAccess::set_output_count(size_t count, size_t idx) {
+    set_output_port_descriptor({count, get_output_port_descriptor(idx).offset, idx}, idx);
+}
+void MemoryAccess::set_input_offset(size_t offset, size_t idx) {
+    set_input_port_descriptor({get_input_port_descriptor(idx).count, offset, idx}, idx);
+}
+void MemoryAccess::set_output_offset(size_t offset, size_t idx) {
+    set_output_port_descriptor({get_output_port_descriptor(idx).count, offset, idx}, idx);
+}
+size_t MemoryAccess::get_input_count(size_t idx) const {
+    return get_input_port_descriptor(idx).count;
+}
+size_t MemoryAccess::get_output_count(size_t idx) const {
+    return get_output_port_descriptor(idx).count;
+}
+size_t MemoryAccess::get_input_offset(size_t idx) const {
+    return get_input_port_descriptor(idx).offset;
+}
+size_t MemoryAccess::get_output_offset(size_t idx) const {
+    return get_output_port_descriptor(idx).offset;
+}
 
 } // namespace op
 } // namespace snippets
 } // namespace ngraph
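With per-port storage, serialized attribute names are suffixed by port index, so a MemoryAccess op with two inputs and one output (e.g. Brgemm) round-trips attributes like:

    // count_in_0 / offset_in_0    - input port 0 (matrix A)
    // count_in_1 / offset_in_1    - input port 1 (matrix B)
    // count_out_0 / offset_out_0  - output port 0 (matrix C)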
@@ -12,13 +12,22 @@ namespace ngraph {
 namespace snippets {
 namespace op {
 
-snippets::op::Store::Store(const Output<Node>& x, const size_t count, const size_t offset) : MemoryAccess({x}, count, offset) {
+snippets::op::Store::Store(const Output<Node>& x, const size_t count, const size_t offset) : MemoryAccess({x}, 0, 1) {
+    set_output_port_descriptor({count, offset}, 0);
     constructor_validate_and_infer_types();
 }
 
+void snippets::op::Store::validate_and_infer_types() {
+    // Store has memory access port only on output
+    OPENVINO_ASSERT(get_input_port_count() == 0, "Store node mustn't have memory access input port");
+    OPENVINO_ASSERT(get_output_port_count() == 1, "Store node must have memory access output port");
+    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
+}
+
 std::shared_ptr<Node> snippets::op::Store::clone_with_new_inputs(const OutputVector& new_args) const {
     INTERNAL_OP_SCOPE(Store_clone_with_new_inputs);
     check_new_args_count(this, new_args);
-    return std::make_shared<Store>(new_args.at(0), m_count, m_offset);
+    return std::make_shared<Store>(new_args.at(0), get_count(), get_offset());
 }
 
 } // namespace op
@@ -434,22 +434,21 @@ void snippets::op::Subgraph::initialize_buffer_scratchpad_size() {
 
         // Propagate to up: in Store. Buffer can have only one Store
-        {
-            auto parent = buffer->get_input_node_shared_ptr(0);
-            auto idx = buffer->input(0).get_source_output().get_index();
-            // There may be graph with several LoopBegin and LoopEnd between Store/Brgemm and Buffer,
-            // so we should iterate through LoopBase
-            while (ov::is_type<snippets::op::LoopBase>(parent)) {
-                const auto source_output = parent->input_value(idx);
-                parent = source_output.get_node_shared_ptr();
-                idx = source_output.get_index();
-            }
-            if (auto store = ov::as_type_ptr<snippets::op::Store>(parent)) {
-                store->set_offset(offset);
-            } else if (const auto brgemm = ov::as_type_ptr<snippets::op::Brgemm>(parent)) {
-                // Brgemm encapsulates work with loading and storing of data
-                brgemm->set_offset_c(offset);
-            } else {
-                throw ngraph_error("Buffer::set_offset() was called when Buffer didn't have the corresponding Store op for offset propagation");
+        if (buffer->is_intermediate_memory()) {
+            OPENVINO_ASSERT(buffer->get_input_size() == 1, "Buffer with intermediate memory must have one parent");
+            auto parent = buffer->get_input_node_shared_ptr(0);
+            auto idx = buffer->input(0).get_source_output().get_index();
+            while (ov::is_type<snippets::op::LoopBase>(parent)) {
+                const auto source_output = parent->input_value(idx);
+                parent = source_output.get_node_shared_ptr();
+                idx = source_output.get_index();
+            }
+            if (auto memory_access = ov::as_type_ptr<ngraph::snippets::op::MemoryAccess>(parent)) {
+                memory_access->set_output_offset(offset, idx);
+            } else {
+                throw ngraph_error(
+                    "Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation");
+            }
         }
     }
 
@@ -466,17 +465,10 @@ void snippets::op::Subgraph::initialize_buffer_scratchpad_size() {
                 for (const auto loop_target_output : child->output(index).get_target_inputs()) {
                     propagate_down(loop_target_output);
                 }
-            } else if (const auto load = ov::as_type_ptr<snippets::op::Load>(child)) {
-                load->set_offset(offset);
-            } else if (const auto brgemm = ov::as_type_ptr<snippets::op::Brgemm>(child)) {
-                // Brgemm encapsulates work with loading and storing of data
-                if (target_input.get_index() == 0) {
-                    brgemm->set_offset_a(offset);
-                } else if (target_input.get_index() == 1) {
-                    brgemm->set_offset_b(offset);
-                }
+            } else if (auto memory_access = ov::as_type_ptr<ngraph::snippets::op::MemoryAccess>(child)) {
+                memory_access->set_input_offset(offset, target_input.get_index());
             } else {
-                throw ngraph_error("Buffer::set_offset() was called when Buffer didn't have the corresponding Load op for offset propagation");
+                throw ngraph_error("Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation");
             }
         };
 
@@ -497,26 +489,25 @@ void snippets::op::Subgraph::initialize_buffer_scratchpad_size() {
             continue;
         }
 
-        // Transpose and MatMul ops should have different memories on inputs and outputs to avoid data corruption,
-        // so after them, we should allocate new memory. Other operations (Eltwises, Convert) can be executed inplace.
-        const auto parent = buffer->get_input_node_shared_ptr(0);
-        if (ov::is_type<op::Brgemm>(parent) || is_transpose_loop(parent)) {
+        if (buffer->is_intermediate_memory()) {
+            // Transpose, MatMul and other non-decomposed ops should have different memories on inputs and outputs to avoid data corruption,
+            // so after them, we should allocate new memory. Other operations (Eltwises, Convert) can be executed inplace inside Loop.
+            OPENVINO_ASSERT(buffer->get_input_size() == 1, "Buffer with intermediate memory must have one parent");
+            const auto parent = buffer->get_input_node_shared_ptr(0);
+            if (!ov::is_type<LoopEnd>(parent) || is_transpose_loop(parent)) {
                 offset = m_buffer_scratchpad;
                 propagate_offset(buffer, offset);
                 m_buffer_scratchpad += buffer_size;
                 continue;
             }
 
+            propagate_offset(buffer, offset);
+        } else {
+            // Single Buffer without input should allocate new memory
+            offset = m_buffer_scratchpad;
+            propagate_offset(buffer, offset);
+            m_buffer_scratchpad += buffer_size;
+            continue;
+        }
 
         // If Buffer op requires memory size more that has been already allocated,
         // we increase current memory size to the needed size
         // For example, it's possible when we have a sequence of Eltwise ops with broadcasting
         const auto current_allocated_memory_size = m_buffer_scratchpad - offset;
         if (buffer_size > current_allocated_memory_size) {
             m_buffer_scratchpad += (buffer_size - current_allocated_memory_size);
             // Note: we don't update offset because we just add memory to needed size
         }
 
-        propagate_offset(buffer, offset);
     }
 }
@@ -644,7 +635,10 @@ snippets::Schedule snippets::op::Subgraph::generate(
     if (config.m_has_domain_sensitive_ops)
         initialize_buffer_scratchpad_size();
 
-    snippets::pass::AssignRegisters().run_on_model(body_ptr());
+    std::function<Generator::opRegType(const std::shared_ptr<Node>& op)> reg_type_mapper = [=](const std::shared_ptr<Node>& op) -> Generator::opRegType {
+        return m_generator->get_op_reg_type(op);
+    };
+    snippets::pass::AssignRegisters(reg_type_mapper).run_on_model(body_ptr());
 
     const auto ops = body_ptr()->get_ops();
     ngraph::snippets::Generator::GeneratorConfig generatorConfig;
@@ -14,6 +14,7 @@
 
 namespace {
 constexpr size_t reg_count = 16lu;
+using opRegType = ngraph::snippets::Generator::opRegType;
 }  // namespace
 
 bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr<ov::Model>& f) {
@@ -22,31 +23,12 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr
     using Reg = size_t;
     using tensor = std::shared_ptr<descriptor::Tensor>;
     auto ops = f->get_ordered_ops();
-    // Note that currently there are 3 types of ops:
-    // * gpr->gpr: (Parameter, Result, LoopBegin, LoopEnd) will also be Buffer?
-    // * gpr->vec: or vec->gpr Load/LoadConvert, Store/StoreConvert, BroadcastLoad etc.
-    // * vec->vec: all other "normal" operations that perform calculations on vector registers: Add, BroadcastMove, Power, etc.
-    enum op_reg_type {gpr2gpr, gpr2vec, vec2gpr, vec2vec};
-
-    auto get_op_reg_type = [](const std::shared_ptr<Node>& op) {
-        if (std::dynamic_pointer_cast<opset1::Parameter>(op) ||
-            std::dynamic_pointer_cast<opset1::Result>(op) ||
-            std::dynamic_pointer_cast<op::LoopBegin>(op) ||
-            std::dynamic_pointer_cast<op::LoopEnd>(op) ||
-            std::dynamic_pointer_cast<op::Brgemm>(op) ||
-            std::dynamic_pointer_cast<op::Buffer>(op))
-            return gpr2gpr;
-        else if (std::dynamic_pointer_cast<snippets::op::Load>(op) ||
-                 std::dynamic_pointer_cast<snippets::op::BroadcastLoad>(op))
-            return gpr2vec;
-        else if (std::dynamic_pointer_cast<snippets::op::Store>(op))
-            return vec2gpr;
-        else
-            return vec2vec;
-    };
-    std::vector<std::pair<op_reg_type, std::shared_ptr<Node>>> typed_ops;
-    for (const auto& op : ops)
-        typed_ops.emplace_back(std::make_pair(get_op_reg_type(op), op));
+    std::vector<std::pair<opRegType, std::shared_ptr<Node>>> typed_ops;
+    for (const auto& op : ops) {
+        typed_ops.emplace_back(std::make_pair(m_reg_type_mapper(op), op));
+    }
 
     size_t counter_vec = 0;
     size_t counter_gpr = 0;
     std::map<tensor, Reg> regs_vec, regs_gpr;
@@ -64,10 +46,12 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr
             // here we use the fact that Result input & output tensors are identical by construction
             manually_assigned_gprs[op->output(0).get_tensor_ptr()] =
                     static_cast<Reg>(f->get_result_index(result) + num_parameters);
-        } else if (const auto& buffer = ov::as_type_ptr<op::Buffer>(op)) {
+        } else if (const auto buffer = ov::as_type_ptr<op::Buffer>(op)) {
             // All buffers have one common data pointer
-            manually_assigned_gprs[op->input(0).get_tensor_ptr()] =
-                    static_cast<Reg>(num_results + num_parameters);
+            if (buffer->is_intermediate_memory()) {
+                manually_assigned_gprs[op->input(0).get_tensor_ptr()] =
+                        static_cast<Reg>(num_results + num_parameters);
+            }
             manually_assigned_gprs[op->output(0).get_tensor_ptr()] =
                     static_cast<Reg>(num_results + num_parameters);
         } else if (ov::is_type<op::HorizonMax>(op) || ov::is_type<op::HorizonSum>(op)) {
@@ -114,12 +98,12 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr
     };
     for (const auto& t_op : typed_ops) {
         switch (t_op.first) {
-            case vec2vec:
-            case gpr2vec:
+            case opRegType::vec2vec:
+            case opRegType::gpr2vec:
                 enumerate_out_tensors(t_op.second, regs_vec, manually_assigned_vecs, counter_vec);
                 break;
-            case gpr2gpr:
-            case vec2gpr:
+            case opRegType::gpr2gpr:
+            case opRegType::vec2gpr:
                 enumerate_out_tensors(t_op.second, regs_gpr, manually_assigned_gprs, counter_gpr);
                 break;
         }
@@ -144,24 +128,25 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr
     for (size_t i = 0; i < typed_ops.size(); i++) {
         const auto& t_op = typed_ops[i];
         std::vector<tensor> used_tensors, defined_tensors;
-        for (const auto& in : t_op.second->inputs())
+        for (const auto& in : t_op.second->inputs()) {
             used_tensors.push_back(in.get_tensor_ptr());
+        }
         for (const auto& out : t_op.second->outputs())
             defined_tensors.push_back(out.get_tensor_ptr());
         switch (t_op.first) {
-            case vec2vec:
+            case opRegType::vec2vec:
                 used_vec[i] = tensor2reg(used_tensors, regs_vec);
                 defined_vec[i] = tensor2reg(defined_tensors, regs_vec);
                 break;
-            case gpr2gpr:
+            case opRegType::gpr2gpr:
                 used_gpr[i] = tensor2reg(used_tensors, regs_gpr);
                 defined_gpr[i] = tensor2reg(defined_tensors, regs_gpr);
                 break;
-            case gpr2vec:
+            case opRegType::gpr2vec:
                 used_gpr[i] = tensor2reg(used_tensors, regs_gpr);
                 defined_vec[i] = tensor2reg(defined_tensors, regs_vec);
                 break;
-            case vec2gpr:
+            case opRegType::vec2gpr:
                 used_vec[i] = tensor2reg(used_tensors, regs_vec);
                 defined_gpr[i] = tensor2reg(defined_tensors, regs_gpr);
                 break;
@@ -196,12 +181,12 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr
             if (k == ops.size())
                 throw ngraph_error("assign registers can't find target op in the body");
             switch (typed_ops[k].first) {
-                case vec2vec:
-                case vec2gpr:
+                case opRegType::vec2vec:
+                case opRegType::vec2gpr:
                     life_out_vec[n].insert(life_in_vec[k].begin(), life_in_vec[k].end());
                     break;
-                case gpr2gpr:
-                case gpr2vec:
+                case opRegType::gpr2gpr:
+                case opRegType::gpr2vec:
                     life_out_gpr[n].insert(life_in_gpr[k].begin(), life_in_gpr[k].end());
                     break;
             }
|
@ -49,9 +49,16 @@ auto outputs_are_not_broadcastable(const std::shared_ptr<const Node>& node) -> b
|
||||
auto is_supported_op(const std::shared_ptr<const Node> &n) -> bool {
|
||||
OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::is_supported_op")
|
||||
auto is_supported_matmul = [](const std::shared_ptr<const Node>& n) -> bool {
|
||||
const auto& matmul = is_type<const opset1::MatMul>(n);
|
||||
const auto& matmul = ov::as_type_ptr<const opset1::MatMul>(n);
|
||||
const auto& out_shape = n->get_output_partial_shape(0);
|
||||
return matmul && out_shape.is_static() && out_shape.size() == 4;
|
||||
if (!matmul || out_shape.is_dynamic() || out_shape.size() != 4)
|
||||
return false;
|
||||
const auto intype_0 = matmul->get_input_element_type(0);
|
||||
const auto intype_1 = matmul->get_input_element_type(1);
|
||||
const bool is_f32 = intype_0 == element::f32 && intype_1 == element::f32;
|
||||
const bool is_int8 = (intype_0 == element::i8 || intype_0 == element::u8) && (intype_1 == element::i8);
|
||||
const bool is_bf16 = intype_0 == element::bf16 && intype_1 == element::bf16;
|
||||
return is_f32 || is_bf16 || is_int8;
|
||||
};
|
||||
auto is_supported_transpose = [](const std::shared_ptr<const Node>& n) -> bool {
|
||||
const auto& transpose = as_type_ptr<const opset1::Transpose>(n);
|
||||
|
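Tokenization now admits exactly the input precision pairs that Brgemm::get_output_type() can map to an accumulator type; any other MatMul stays outside the snippet:

    // Accepted (lhs, rhs) precisions -> Brgemm output precision:
    //   (f32,  f32)  -> f32
    //   (bf16, bf16) -> f32
    //   (u8,   i8)   -> i32
    //   (i8,   i8)   -> i32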
@@ -49,13 +49,8 @@ FuseTransposeBrgemm::FuseTransposeBrgemm() {
 
     auto callback = [=](pattern::Matcher& m) {
         OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "FuseTransposeBrgemm")
-        auto set_layout_from_order = [](const std::shared_ptr<opset1::Transpose>& node, const ov::Output<Node>& port) {
-            const auto& const_order = as_type_ptr<opset1::Constant>(node->get_input_node_shared_ptr(1));
-            std::vector<size_t> layout = const_order->cast_vector<size_t>();
-            auto& rt_info = port.get_node_shared_ptr()->get_rt_info();
-            rt_info["Layout"] = layout;
-        };
         auto brgemm = as_type_ptr<op::Brgemm>(m.get_match_root());
 
         // Transpose on the Brgemm's output
         if (!brgemm) {
             brgemm = as_type_ptr<op::Brgemm>(m.get_match_root()->get_input_node_shared_ptr(0));
@@ -63,13 +58,13 @@ FuseTransposeBrgemm::FuseTransposeBrgemm() {
             const auto& transpose_out = m.get_match_value();
             for (const auto& in : transpose_out.get_target_inputs())
                 in.replace_source_output(brgemm->output(0));
-            set_layout_from_order(as_type_ptr<opset1::Transpose>(transpose_out.get_node_shared_ptr()), brgemm_out);
+            utils::set_transpose_output_layout(brgemm_out, as_type_ptr<opset1::Transpose>(transpose_out.get_node_shared_ptr()));
         }
         for (size_t i = 0; i < brgemm->get_input_size(); i++) {
             const auto& in_value = brgemm->input_value(i);
             if (transpose_matcher->match(in_value)) {
                 const auto& transpose = as_type_ptr<opset1::Transpose>(in_value.get_node_shared_ptr());
-                set_layout_from_order(transpose, transpose->input_value(0));
+                utils::set_transpose_output_layout(transpose->input_value(0), transpose);
                 brgemm->set_argument(i, transpose->input_value(0));
             }
         }
|
@ -31,7 +31,7 @@ ngraph::snippets::pass::InsertBuffer::InsertBuffer(const int32_t allocation_rank
|
||||
if (!ov::is_type<ngraph::snippets::op::Buffer>(input_node) &&
|
||||
!ov::is_type<ngraph::op::v0::Parameter>(input_node) &&
|
||||
!ov::is_type<ngraph::op::v0::Constant>(input_node)) {
|
||||
const auto buffer = std::make_shared<ngraph::snippets::op::Buffer>(input_node, allocation_rank);
|
||||
const auto buffer = std::make_shared<op::Buffer>(input_node, allocation_rank);
|
||||
root->set_argument(input.get_index(), buffer);
|
||||
rewritten |= true;
|
||||
}
|
||||
@@ -68,7 +68,7 @@ ngraph::snippets::pass::InsertBuffer::InsertBuffer(const int32_t allocation_rank
                 }
             }
 
-            const auto buffer = std::make_shared<ngraph::snippets::op::Buffer>(output, allocation_rank);
+            const auto buffer = std::make_shared<op::Buffer>(output, allocation_rank);
             for (const auto& consumer : output.get_target_inputs()) {
                 const auto output_node = consumer.get_node()->shared_from_this();
                 if (output_node != buffer &&
|
@ -30,7 +30,7 @@ ngraph::snippets::pass::InsertLoad::InsertLoad(const size_t count) {
const auto& consumer_node = consumer.get_node();
if (ov::is_type<ngraph::snippets::op::Load>(consumer_node) ||
ov::is_type<ngraph::snippets::op::LoopBegin>(consumer_node) ||
ov::is_type<ngraph::op::v0::MatMul>(consumer_node) ||
ov::is_type<ngraph::snippets::op::Brgemm>(consumer_node) ||
ov::is_type<ngraph::op::v1::Transpose>(consumer_node)) {
return false;
}
@ -67,7 +67,7 @@ ngraph::snippets::pass::InsertStore::InsertStore(const size_t count) {
const auto& parent_node = input.get_source_output().get_node();
if (ov::is_type<ngraph::snippets::op::Store>(parent_node) ||
ov::is_type<ngraph::snippets::op::LoopEnd>(parent_node) ||
ov::is_type<ngraph::op::v0::MatMul>(parent_node) ||
ov::is_type<ngraph::snippets::op::Brgemm>(parent_node) ||
ov::is_type<ngraph::op::v1::Transpose>(parent_node)) {
return false;
}
@ -24,20 +24,20 @@ ngraph::snippets::pass::LoadMoveBroadcastToBroadcastLoad::LoadMoveBroadcastToBro
auto root = m.get_match_root();

const auto &pm = m.get_pattern_value_map();
const auto input = pm.at(load_pattern).get_node_shared_ptr();
const auto load = ov::as_type_ptr<snippets::op::Load>(pm.at(load_pattern).get_node_shared_ptr());
const auto param = pm.at(param_pattern).get_node_shared_ptr();

// Cannot rewrite Broadcast + Load if load has more than 1 user
// or more than one input, or if Broadcast has several inputs
if (input->output(0).get_target_inputs().size() != 1 ||
root->inputs().size() != 1 || input->inputs().size() != 1) {
if (load->output(0).get_target_inputs().size() != 1 ||
root->inputs().size() != 1 || load->inputs().size() != 1) {
return false;
}

auto inshape = root->input(0).get_partial_shape();
auto outshape = root->output(0).get_partial_shape();

auto broadcastload = std::make_shared<snippets::op::BroadcastLoad>(param, outshape, ov::as_type_ptr<snippets::op::Load>(input)->get_offset());
auto broadcastload = std::make_shared<snippets::op::BroadcastLoad>(param, outshape, load->get_offset());
ngraph::copy_runtime_info(root, broadcastload);
ngraph::replace_node(root, broadcastload);
@ -73,7 +73,6 @@ auto get_buffer_and_loop_end(const std::shared_ptr<ngraph::snippets::op::LoopBeg
buffer = ov::as_type_ptr<ngraph::snippets::op::Buffer>(parent_shared);
if (buffer) {
if (buffer->output(0).get_target_inputs().size() == 0 ||
buffer->get_input_size() != 1 ||
buffer->get_input_source_output(0).get_target_inputs().size() != 1)
return false;
@ -6,7 +6,7 @@
#include "snippets/pass/matmul_to_brgemm.hpp"

#include "snippets/op/brgemm.hpp"
#include "snippets/snippets_isa.hpp"

#include "ngraph/opsets/opset1.hpp"
#include "ngraph/rt_info.hpp"
@ -30,9 +30,13 @@ MatMulToBrgemm::MatMulToBrgemm() {
return false;

auto brgemm = std::make_shared<op::Brgemm>(matmul->get_input_source_output(0), matmul->get_input_source_output(1));
ov::NodeVector nodes = { brgemm };
if (brgemm->get_output_element_type(0) != matmul->get_output_element_type(0)) {
nodes.emplace_back(std::make_shared<op::ConvertSaturation>(brgemm, matmul->get_output_element_type(0)));
}
brgemm->set_friendly_name(matmul->get_friendly_name());
ngraph::copy_runtime_info(matmul, brgemm);
ngraph::replace_node(matmul, brgemm);
ngraph::copy_runtime_info(matmul, nodes);
ngraph::replace_node(matmul, nodes.back());
return true;
};
@ -79,10 +79,9 @@ ngraph::snippets::pass::ResetBufferState::ResetBufferState() {
// If a Buffer immediately follows the Loop, we should reset the Buffer pointer for the next calculations
for (size_t i = 0; i < o_size; ++i) {
const auto result_shape = body_shapes[i_size + i].get_shape();
// Checking the first target input is enough to find the Buffer because an operation can have only a single Buffer per output port
const auto consumer = loop_end->output(i).get_target_inputs().begin()->get_node();
if (ov::is_type<ngraph::snippets::op::Buffer>(consumer)) {
if (const auto buffer = ov::as_type_ptr<ngraph::snippets::op::Buffer>(consumer->shared_from_this())) {
// To calculate the finalization offset we should know the index of the nesting Loop
auto loop_index = 0lu;
auto loop = loop_end->input_value(i).get_node_shared_ptr();
@ -93,7 +92,8 @@ ngraph::snippets::pass::ResetBufferState::ResetBufferState() {
port_idx = source_output.get_index();
loop_index++;
}

const auto result_shape = buffer->get_allocation_shape();
NGRAPH_CHECK(loop_index < result_shape.size(), "Buffer has invalid Loop index and allocation shape rank");
const auto work_amount = std::accumulate(result_shape.rbegin(), result_shape.rbegin() + loop_index + 1, size_t(1), std::multiplies<size_t>());
finalization_offsets[i_size + i] =
calculate_required_finalization_offsets(work_amount, *(result_shape.rbegin() + loop_index));
@ -126,7 +126,7 @@ ngraph::snippets::pass::SoftmaxDecomposition::SoftmaxDecomposition(const size_t
apply_increments_sum, finalization_offsets_sum);

const auto horizon_sum = std::make_shared<ngraph::snippets::op::HorizonSum>(sum);
const auto buffer_exp = std::make_shared<ngraph::snippets::op::Buffer>(loop_sum_end->output(0), buffer_allocation_rank);
const auto buffer_exp = std::make_shared<op::Buffer>(loop_sum_end->output(0), buffer_allocation_rank);

/* =========================================== */
@ -24,7 +24,7 @@ ngraph::snippets::pass::SetScalarCountForLoad::SetScalarCountForLoad() {
if (!load)
return false;

load->set_count(1lu);
load->set_input_count(1lu, 0);
return true;
});
}
@ -43,7 +43,7 @@ ngraph::snippets::pass::SetScalarCountForStore::SetScalarCountForStore() {
if (!store)
return false;

store->set_count(1lu);
store->set_output_count(1lu, 0);
return true;
});
}
@ -115,6 +115,17 @@ ov::PartialShape get_port_planar_shape(const Output<Node>& out) {
return get_reordered_planar_shape(tensor_shape, layout);
}

void set_transpose_output_layout(const ov::Output<Node>& port, const std::shared_ptr<opset1::Transpose>& node) {
const auto& const_order = as_type_ptr<opset1::Constant>(node->get_input_node_shared_ptr(1));
OPENVINO_ASSERT(const_order != nullptr, "Transpose order must be Constant to set layout!");
set_output_layout(port, const_order->cast_vector<size_t>());
}

void set_output_layout(const ov::Output<Node>& port, const std::vector<size_t>& layout) {
auto& rt_info = port.get_node_shared_ptr()->get_rt_info();
rt_info["Layout"] = layout;
}

} // namespace utils
} // namespace snippets
} // namespace ngraph
@ -36,6 +36,9 @@ class DummyGenerator : public ngraph::snippets::Generator {
public:
DummyGenerator() : ngraph::snippets::Generator(std::make_shared<DummyTargetMachine>()) {}
DummyGenerator(const std::shared_ptr<ngraph::snippets::TargetMachine>& t) : ngraph::snippets::Generator(t) {}

protected:
opRegType get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const override { return vec2vec; };
};

class LoweringTests : public TransformationTestsF {
@ -19,18 +19,20 @@ using namespace ngraph;
// todo: Rewrite this test using Snippets test infrastructure. See ./include/canonicalization.hpp for example

template<typename T>
size_t get_count(const std::shared_ptr<Function>& f, const std::string& name) {
size_t load_count = std::numeric_limits<size_t>::max();
size_t get_count(const std::shared_ptr<Function>& f, const std::string& name, bool is_load = true) {
size_t count = std::numeric_limits<size_t>::max();
for (auto op : f->get_ops()) {
if (op->get_friendly_name() == name) {
load_count = ov::as_type_ptr<T>(op)->get_count();
if (const auto memory_access = std::dynamic_pointer_cast<snippets::op::MemoryAccess>(op)) {
count = is_load ? memory_access->get_input_count(0)
: memory_access->get_output_count(0);
}
}
}
return load_count;
return count;
}

TEST(TransformationTests, SetScalarCountForLoad) {
TEST(TransformationTests, SetScalarCountForLoadStore) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
const auto count = 16;
{
@ -39,11 +41,13 @@ TEST(TransformationTests, SetScalarCountForLoad) {
load->set_friendly_name("load");
auto neg = std::make_shared<opset1::Negative>(load);
auto store = std::make_shared<snippets::isa::Store>(neg, count);
store->set_friendly_name("store");
f = std::make_shared<Function>(NodeVector{store}, ParameterVector{data});

pass::Manager m;
m.register_pass<ov::pass::InitNodeInfo>();
m.register_pass<snippets::pass::SetScalarCountForLoad>();
m.register_pass<snippets::pass::SetScalarCountForStore>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
@ -52,39 +56,6 @@ TEST(TransformationTests, SetScalarCountForLoad) {
auto load = std::make_shared<snippets::isa::Load>(data, 1lu);
load->set_friendly_name("load_ref");
auto neg = std::make_shared<opset1::Negative>(load);
auto store = std::make_shared<snippets::isa::Store>(neg, count);
f_ref = std::make_shared<Function>(NodeVector{store}, ParameterVector{data});
}

auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;

auto load_count = get_count<ngraph::snippets::op::Load>(f, "load");
auto load_count_ref = get_count<ngraph::snippets::op::Load>(f_ref, "load_ref");
ASSERT_EQ(load_count, load_count_ref);
}

TEST(TransformationTests, SetScalarCountForStore) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
const auto count = 16;
{
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 2});
auto load = std::make_shared<snippets::isa::Load>(data, count);
auto neg = std::make_shared<opset1::Negative>(load);
auto store = std::make_shared<snippets::isa::Store>(neg, count);
store->set_friendly_name("store");
f = std::make_shared<Function>(NodeVector{store}, ParameterVector{data});

pass::Manager m;
m.register_pass<ov::pass::InitNodeInfo>();
m.register_pass<snippets::pass::SetScalarCountForStore>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 2});
auto load = std::make_shared<snippets::isa::Load>(data, count);
auto neg = std::make_shared<opset1::Negative>(load);
auto store = std::make_shared<snippets::isa::Store>(neg, 1lu);
store->set_friendly_name("store_ref");
f_ref = std::make_shared<Function>(NodeVector{store}, ParameterVector{data});
@ -93,7 +64,11 @@ TEST(TransformationTests, SetScalarCountForStore) {
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;

int64_t store_count = get_count<ngraph::snippets::op::Store>(f, "store");
int64_t store_count_ref = get_count<ngraph::snippets::op::Store>(f_ref, "store_ref");
auto load_count = get_count(f, "load");
auto load_count_ref = get_count(f_ref, "load_ref");
ASSERT_EQ(load_count, load_count_ref);

auto store_count = get_count(f, "store", false);
auto store_count_ref = get_count(f_ref, "store_ref", false);
ASSERT_EQ(store_count, store_count_ref);
}
@ -13,6 +13,7 @@
#include <transformations/init_node_info.hpp>

#include "common_test_utils/ngraph_test_utils.hpp"
#include "lowering_utils.hpp"

using namespace testing;
using namespace ngraph;
@ -20,6 +21,7 @@ using namespace ngraph;
// todo: Rewrite this test using Snippets test infrastructure. See ./include/canonicalization.hpp for example

TEST(TransformationTests, AssignRegisters) {
const auto generator = std::make_shared<ov::test::snippets::DummyGenerator>();
std::shared_ptr<Function> f(nullptr);
{
auto p0 = std::make_shared<opset1::Parameter>(element::f32, Shape(1));
@ -37,7 +39,12 @@ TEST(TransformationTests, AssignRegisters) {
pass::Manager m;
m.register_pass<ov::pass::InitNodeInfo>();
m.register_pass<snippets::pass::AssignRegisters>();
std::function<snippets::Generator::opRegType(const std::shared_ptr<Node>& op)> reg_type_mapper =
[=](const std::shared_ptr<Node>& op) -> snippets::Generator::opRegType {
return generator->get_op_reg_type(op);
};
m.register_pass<snippets::pass::AssignRegisters>(reg_type_mapper);

m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
@ -73,6 +80,7 @@ TEST(TransformationTests, AssignRegisters) {
}

TEST(TransformationTests, AssignRegisters2) {
const auto generator = std::make_shared<ov::test::snippets::DummyGenerator>();
std::shared_ptr<Function> f(nullptr);
{
auto p0 = std::make_shared<opset1::Parameter>(ngraph::element::f32, Shape());
@ -126,7 +134,11 @@ TEST(TransformationTests, AssignRegisters2) {
pass::Manager m;
m.register_pass<ov::pass::InitNodeInfo>();
m.register_pass<snippets::pass::AssignRegisters>();
std::function<snippets::Generator::opRegType(const std::shared_ptr<Node>& op)> reg_type_mapper =
[=](const std::shared_ptr<Node>& op) -> snippets::Generator::opRegType {
return generator->get_op_reg_type(op);
};
m.register_pass<snippets::pass::AssignRegisters>(reg_type_mapper);
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
@ -18,7 +18,8 @@
#include "snippets_transformations/op/load_convert.hpp"
#include "snippets_transformations/op/store_convert.hpp"
#include "snippets_transformations/op/fused_mul_add.hpp"
#include "snippets/op/brgemm.hpp"
#include "snippets_transformations/op/brgemm_copy_b.hpp"
#include "snippets_transformations/op/brgemm_cpu.hpp"
#include "ngraph_transformations/op/swish_cpu.hpp"

#include <ngraph/opsets/opset5.hpp>
@ -144,7 +145,8 @@ ov::intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_
jitters[ngraph::snippets::op::Kernel::get_type_info_static()] = CREATE_EMITTER(KernelEmitter);
jitters[ngraph::snippets::op::LoopBegin::get_type_info_static()] = CREATE_EMITTER(LoopBeginEmitter);
jitters[ngraph::snippets::op::LoopEnd::get_type_info_static()] = CREATE_EMITTER(LoopEndEmitter);
jitters[ngraph::snippets::op::Brgemm::get_type_info_static()] = CREATE_EMITTER(BrgemmEmitter);
jitters[ov::intel_cpu::BrgemmCPU::get_type_info_static()] = CREATE_EMITTER(BrgemmEmitter);
jitters[ov::intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_EMITTER(BrgemmCopyBEmitter);
}

size_t ov::intel_cpu::CPUTargetMachine::get_lanes() const {
@ -169,3 +171,15 @@ code ov::intel_cpu::CPUTargetMachine::get_snippet() const {
ov::intel_cpu::CPUGenerator::CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa_) : Generator(std::make_shared<CPUTargetMachine>(isa_)) {
}

ngraph::snippets::Generator::opRegType ov::intel_cpu::CPUGenerator::get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const {
if (std::dynamic_pointer_cast<ov::intel_cpu::BrgemmCPU>(op) ||
std::dynamic_pointer_cast<ov::intel_cpu::BrgemmCopyB>(op))
return gpr2gpr;
else if (
std::dynamic_pointer_cast<ov::intel_cpu::FusedMulAdd>(op) ||
std::dynamic_pointer_cast<ov::intel_cpu::SwishNode>(op))
return vec2vec;
else
throw ov::Exception("Register type of the operation " + std::string(op->get_type_name()) + " isn't determined!");
}
@ -28,6 +28,9 @@ private:
class CPUGenerator : public ngraph::snippets::Generator {
public:
CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa);

protected:
opRegType get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const override;
};

} // namespace intel_cpu
@ -6,9 +6,10 @@
#include <cpu/x64/jit_generator.hpp>

#include "jit_snippets_emitters.hpp"
#include "snippets/op/brgemm.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/utils.hpp"
#include "snippets_transformations/op/brgemm_copy_b.hpp"
#include "snippets_transformations/op/brgemm_cpu.hpp"

using namespace InferenceEngine;
using ngraph::snippets::op::Subgraph;
@ -20,6 +21,10 @@ using namespace dnnl::impl::cpu::x64;
namespace ov {
namespace intel_cpu {

namespace {
constexpr size_t gpr_size = 8;
} // namespace

inline static void transform_idxs_to_regs(const std::vector<size_t>& idxs, std::vector<Reg64>& regs) {
regs.resize(idxs.size());
std::transform(idxs.begin(), idxs.end(), regs.begin(), [](size_t idx){return Reg64(static_cast<int>(idx));});
@ -68,7 +73,8 @@ void jit_container_emitter::map_abstract_registers(mapping_info& gpr_map_pool,
// where all utility emitters align with conventional Op emitters
if (std::dynamic_pointer_cast<LoopBeginEmitter>(emitter) ||
std::dynamic_pointer_cast<LoopEndEmitter>(emitter) ||
std::dynamic_pointer_cast<BrgemmEmitter>(emitter))
std::dynamic_pointer_cast<BrgemmEmitter>(emitter) ||
std::dynamic_pointer_cast<BrgemmCopyBEmitter>(emitter))
in_physical_regs = map_regs(in_abstract_regs, gpr_map_pool);
else
in_physical_regs = std::move(in_abstract_regs);
@ -182,7 +188,8 @@ KernelEmitter::KernelEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl:
// todo: how this will be handled if Brgemm in & out are op::Buffer
// Brgemm is a special case since it incorporates input and output (we use onednn kernel)
// Just like Load & Store it requires offsets calculation
const auto is_brgemm = std::dynamic_pointer_cast<BrgemmEmitter>(emitter) != nullptr;
const auto is_brgemm = std::dynamic_pointer_cast<BrgemmEmitter>(emitter) ||
std::dynamic_pointer_cast<BrgemmCopyBEmitter>(emitter);
return emitter_type == gpr_to_vec || emitter_type == vec_to_gpr || is_brgemm;
});
// Note that we can't use reg_indexes_idx or reg_const_params_idx to store data pointers because these two
@ -567,9 +574,6 @@ LoadEmitter::LoadEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu
IE_THROW() << "LoadEmitter supports only equal input and output types but gets: " << src_prc.name() << " and " << dst_prc.name();

const auto load = std::dynamic_pointer_cast<ngraph::snippets::op::Load>(n);
if (!load)
IE_THROW() << "LoadEmitter expects Load snippets op";

count = load->get_count();
byte_offset = load->get_offset();
in_out_type_ = emitter_in_out_map::gpr_to_vec;
@ -606,9 +610,6 @@ BroadcastLoadEmitter::BroadcastLoadEmitter(dnnl::impl::cpu::x64::jit_generator*
IE_THROW() << "BroadcastEmitters support only equal input and output types but gets: " << src_prc.name() << " and " << dst_prc.name();

const auto broadcast_load = std::dynamic_pointer_cast<ngraph::snippets::op::BroadcastLoad>(n);
if (!broadcast_load)
IE_THROW() << "BroadcastLoadEmitter expects BroadcastLoad snippets op";

byte_offset = broadcast_load->get_offset();
in_out_type_ = emitter_in_out_map::gpr_to_vec;
}
@ -717,12 +718,15 @@ size_t BrgemmEmitter::getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) const {
return mIdx * 4 + kIdx * 2 + nIdx;
}
BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa,
const std::shared_ptr<ov::Node>& node) : jit_emitter(h, isa, node) {
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
const auto& brgemm_node = as_type_ptr<ngraph::snippets::op::Brgemm>(node);
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(node);
if (brgemm_node->is_dynamic())
IE_THROW() << "Snippets don't support code generation for dynamic Brgemm";
const OutputVector io_values {brgemm_node->input_value(0), brgemm_node->input_value(1), brgemm_node->output(0)};
const auto brgemm_copy = brgemm_node->is_with_data_repacking() ? brgemm_node->get_brgemm_copy() : nullptr;
const OutputVector io_values {brgemm_node->input_value(0),
brgemm_copy ? brgemm_copy->input_value(0) : brgemm_node->input_value(1),
brgemm_node->output(0)};
std::vector<size_t> leading_dimensions;
std::vector<std::vector<size_t>> io_layouts;
for (const auto& val : io_values) {
@ -747,51 +751,61 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl:
io_layouts.push_back(layout);
}
}
// todo: leave AMX and VNNI related code for now, it'll help to enable int8 and bf16 support
bool isAMXSupported = mayiuse(avx512_core_amx);

const auto& A_shape = io_values[0].get_shape();
const auto& A_layout = io_layouts[0];
const auto& C_shape = io_values[2].get_shape();
const auto& C_layout = io_layouts[2];

M = C_shape[C_layout[2]];
K = A_shape[A_layout[3]];
M_blk = matmulOptimalM;
M_tail = M % M_blk;
// We need to find the original M, N, K from the layouts and the already reordered shapes
// Layout: 0, 1, 2, 3 => New layout: 0, 2, 1, 3
// Shape: 1, 3, 5, 9 => New Shape: 1, 5, 3, 9
// To find the original 2nd dimension, we should find the index of the value `2` in the new layout
// and take the dimension from the new shape at this index
auto get_ordered_idx = [](const std::vector<size_t>& layout, size_t idx) {
return std::distance(layout.begin(), std::find(layout.begin(), layout.end(), idx));
};

m_M = C_shape[get_ordered_idx(C_layout, C_layout.size() - 2)];
m_K = A_shape[get_ordered_idx(A_layout, A_layout.size() - 1)];
m_M_blk = matmulOptimalM;
m_M_tail = m_M % m_M_blk;
// B_shape[B_layout[3]]
N = C_shape[C_layout[3]];
m_N = C_shape[get_ordered_idx(C_layout, C_layout.size() - 1)];
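To make the index recovery above concrete, here is a minimal self-contained sketch using the example values from the comment (the demo names are invented for illustration and are not part of the commit):
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Same logic as the get_ordered_idx lambda above.
size_t get_ordered_idx_demo(const std::vector<size_t>& layout, size_t idx) {
    return std::distance(layout.begin(), std::find(layout.begin(), layout.end(), idx));
}

int main() {
    const std::vector<size_t> layout {0, 2, 1, 3};  // new layout from the comment above
    const std::vector<size_t> shape {1, 5, 3, 9};   // new (already reordered) shape
    // Original dimension #2 sits at layout index 1, so its extent is shape[1] == 5.
    assert(shape[get_ordered_idx_demo(layout, 2)] == 5);
    // The innermost dimension (#3) stays last, so its extent is shape[3] == 9.
    assert(shape[get_ordered_idx_demo(layout, 3)] == 9);
    return 0;
}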

auto brg0Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(0));
auto brg1Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(1));
io_data_size = {brg0Prc.size(), brg1Prc.size(), brgemm_node->get_output_element_type(0).size()};
brg0VnniFactor = 4 / brg0Prc.size();
bool brg0WithAMX = isAMXSupported && brg0Prc != Precision::FP32 && (K % brg0VnniFactor == 0) && (N % brg0VnniFactor == 0);
m_brg0VnniFactor = 4 / brg0Prc.size();
bool brgWithAMX = brgemm_node->is_amx();

N_blk = brg0Prc == Precision::FP32 ? N :
brg0Prc == Precision::BF16 ? 32 : 64;
N_tail = N % N_blk;
K_blk = brg0WithAMX ? brg0Prc == Precision::BF16 ? 32 : 64
: K;
K_tail = K % K_blk;
m_with_comp = brgemm_node->is_with_compensations();
m_with_scratch = brgemm_node->is_with_scratchpad();

m_N_blk = brg1Prc == Precision::FP32 ? m_N :
brg1Prc == Precision::BF16 ? 32 : 64;
m_N_tail = m_N % m_N_blk;
m_K_blk = brgWithAMX ? brg0Prc == Precision::BF16 ? 32 : 64
: m_K;
m_K_tail = m_K % m_K_blk;
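For intuition, the blocking above plays out as follows in a hypothetical bf16 AMX case (all numbers invented for illustration):
// m_M = 70, matmulOptimalM = 32      -> m_M_blk = 32, m_M_tail = 70 % 32 = 6
// m_N = 100, brg1Prc = BF16          -> m_N_blk = 32, m_N_tail = 100 % 32 = 4
// m_K = 96, AMX path, brg0Prc = BF16 -> m_K_blk = 32, m_K_tail = 96 % 32 = 0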

size_t brg0BaseIdx = -1;
for (size_t m = 0; m < 2; m++) {
for (size_t k = 0; k < 2; k++) {
for (size_t n = 0; n < 2; n++) {
auto& brgemmCtx = brgCtxs0[getBrgIdx(m, k, n)];
auto& brgemmCtx = m_brgCtxs0[getBrgIdx(m, k, n)];

auto M_ = m ? M_tail
: M < M_blk ? 0 : M_blk;
auto N_ = n ? N_tail : N - N_tail;
auto K_ = k ? K_tail : K - K_tail;
auto beta = k && brgCtxs0[getBrgIdx(m, 0, n)].K != 0 ? 1.0f : 0.0f;
auto M_ = m ? m_M_tail
: m_M < m_M_blk ? 0 : m_M_blk;
auto N_ = n ? m_N_tail : m_N - m_N_tail;
auto K_ = k ? m_K_tail : m_K - m_K_tail;
auto beta = k && m_brgCtxs0[getBrgIdx(m, 0, n)].K != 0 ? 1.0f : 0.0f;

brgemmCtx.M = M_;
brgemmCtx.N = N_;
brgemmCtx.K = K_;
brgemmCtx.LDA = leading_dimensions[0];
brgemmCtx.LDB = leading_dimensions[1];
brgemmCtx.LDB = brgemm_node->is_with_data_repacking() ? rnd_up(m_N, m_N_blk) : leading_dimensions[1];
brgemmCtx.LDC = leading_dimensions[2];
brgemmCtx.dt_in0 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(brg0Prc));
brgemmCtx.dt_in1 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(brg1Prc));
@ -801,22 +815,46 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl:
if (M_ != 0 && K_ != 0 && N_ != 0) {
if (brg0BaseIdx == -1)
brg0BaseIdx = getBrgIdx(m, k, n);
initBrgemm(brgemmCtx, brgKernels0[getBrgIdx(m, k, n)], brg0WithAMX);
initBrgemm(brgemmCtx, m_brgKernels0[getBrgIdx(m, k, n)], brgWithAMX);
}
}
}
}

load_offset_a = brgemm_node->get_offset_a();
load_offset_b = brgemm_node->get_offset_b();
store_offset_c = brgemm_node->get_offset_c();
m_load_offset_a = brgemm_node->get_offset_a();
m_load_offset_b = brgemm_node->get_offset_b();
m_store_offset_c = brgemm_node->get_offset_c();
if (m_with_scratch)
m_load_offset_scratch = brgemm_node->get_offset_scratch();
}

std::set<std::vector<element::Type>> BrgemmEmitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
const auto brgemm = as_type_ptr<ov::intel_cpu::BrgemmCPU>(node);
OPENVINO_ASSERT(brgemm, "BrgemmEmitter::get_supported_precisions() expects BrgemmCPU node");
switch (brgemm->get_type()) {
case BrgemmCPU::Type::Floating:
return {{element::f32, element::f32}};
case BrgemmCPU::Type::WithDataRepacking:
return {{element::u8, element::i8},
{element::bf16, element::bf16}};
case BrgemmCPU::Type::WithCompensations:
return {{element::i8, element::i8, element::f32}};
case BrgemmCPU::Type::AMX:
return {{element::i8, element::i8, element::u8},
{element::u8, element::i8, element::u8},
{element::bf16, element::bf16, element::u8}};
default:
throw ov::Exception("BrgemmEmitter got BrgemmCPU node with unsupported type");
}
}
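A hypothetical caller-side check built on top of this table (the helper name is invented and is not part of the commit):
// Returns true if the node's actual precision combination is listed
// in the set produced by get_supported_precisions() above.
bool precisions_supported(const std::shared_ptr<ngraph::Node>& node,
                          const std::vector<element::Type>& actual) {
    const auto supported = BrgemmEmitter::get_supported_precisions(node);
    return supported.count(actual) != 0;
}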

void BrgemmEmitter::initBrgemm(brgemmCtx& ctx, std::unique_ptr<brgemm_kernel_t>& brgKernel, bool use_amx) const {
brgemm_t brgDesc;
brgemm_strides_t strides {static_cast<dnnl_dim_t>(ctx.M * ctx.K), static_cast<dnnl_dim_t>(ctx.K * ctx.N)};
// When implementing int8 support, note that the isa logic is more complicated in the MHA node
auto status = brgemm_desc_init(&brgDesc, host_isa_, brgemm_strd, ctx.dt_in0, ctx.dt_in1,
const bool is_int8 = utils::one_of(ctx.dt_in0, data_type::u8, data_type::s8) && utils::one_of(ctx.dt_in1, data_type::u8, data_type::s8);
auto isa = use_amx ? isa_undef
: ctx.dt_in0 == dnnl_data_type_t::dnnl_bf16 ? avx512_core_bf16 : (is_int8 ? avx512_core_vnni : avx512_core);
auto status = brgemm_desc_init(&brgDesc, isa, brgemm_strd, ctx.dt_in0, ctx.dt_in1,
false, false, brgemm_row_major, 1.f, ctx.beta, ctx.LDA, ctx.LDB, ctx.LDC, ctx.M, ctx.N, ctx.K, &strides);
if (status != dnnl_success)
IE_THROW() << "BrgemmEmitter cannot initialize brgemm descriptor due to invalid params";
@ -837,23 +875,91 @@ void BrgemmEmitter::initBrgemm(brgemmCtx& ctx, std::unique_ptr<brgemm_kernel_t>&

void BrgemmEmitter::emit_impl(const std::vector<size_t>& in,
const std::vector<size_t>& out) const {
if (host_isa_ == cpu::x64::sse41 || host_isa_ == cpu::x64::avx2) {
IE_THROW() << "BrgemmEmitter requires at least avx512_core instruction set";
} else if (host_isa_ == cpu::x64::avx512_core) {
emit_isa<cpu::x64::avx512_core>(in, out);
if (host_isa_ == cpu::x64::avx512_core) {
Xbyak::Reg64 input_0(static_cast<int>(in[0]));
Xbyak::Reg64 input_1(static_cast<int>(in[1]));
Xbyak::Reg64 input_2(static_cast<int>(0)); // scratch. Default reg index is 0 if there isn't scratch
if (m_with_scratch) {
if (in.size() != 3) {
IE_THROW() << "BRGEMM Emitter expects 3 inputs if there are compensations/wsp";
}
input_2 = Xbyak::Reg64(static_cast<int>(in[2]));
}
Xbyak::Reg64 output_0(static_cast<int>(out[0]));

for (size_t mb = 0; mb < div_up(m_M, m_M_blk); mb++) {
const bool is_M_tail = (m_M - mb * m_M_blk < m_M_blk);

size_t brgIdx0 = getBrgIdx(0, 0, 0);
size_t K0_step0 = m_brgCtxs0[brgIdx0].K;
size_t K0_step1 = m_brgCtxs0[brgIdx0].K * m_brgCtxs0[brgIdx0].LDB;
size_t N0_step0 = m_brgCtxs0[brgIdx0].N * m_brg0VnniFactor;
size_t N0_step1 = m_brgCtxs0[brgIdx0].N;
for (size_t n = 0; n < 2; n++) {
for (size_t k = 0; k < 2; k++) {
size_t mIdx = is_M_tail ? 1 : 0;
auto& brgemmCtx = m_brgCtxs0[getBrgIdx(mIdx, k, n)];

if (brgemmCtx.K != 0 && brgemmCtx.N != 0) {
const size_t in0_offset = m_load_offset_a + (k * K0_step0 + mb * m_M_blk * brgemmCtx.LDA) * io_data_size[0];
const size_t in1_offset = m_load_offset_b + (k * K0_step1 + n * N0_step0) * io_data_size[1];
const size_t in2_offset = m_load_offset_scratch + (m_with_comp ? n * N0_step1 * sizeof(int32_t) : 0);
const size_t out0_offset = m_store_offset_c + (n * N0_step1 + mb * m_M_blk * brgemmCtx.LDC) * io_data_size[2];

emit_brgemm_kernel_call(m_brgKernels0[getBrgIdx(mIdx, k, n)].get(),
brgemmCtx,
input_0,
input_1,
input_2,
output_0,
in0_offset,
in1_offset,
in2_offset,
out0_offset);
}
}
}
}
} else {
assert(!"unsupported isa");
IE_THROW() << "BrgemmEmitter requires at least avx512_core instruction set";
}
}
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, int bs,
Reg64 addr_A, Reg64 addr_B,
const brgemm_batch_element_t *batch, Reg64 addr_C, void *scratch,
const size_t in0_kernel_offset, const size_t in1_kernel_offset, const size_t out0_kernel_offset) const {
using Vmm = typename dnnl::impl::utils::conditional3<isa == cpu::x64::sse41, Xmm, isa == cpu::x64::avx2, Ymm, Zmm>::type;
size_t gpr_size = 8;
Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->rax,
h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx};

void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brg_kernel, const brgemmCtx& ctx,
Reg64 addr_A, Reg64 addr_B, Reg64 scratch, Reg64 addr_C,
const size_t in0_kernel_offset, const size_t in1_kernel_offset,
const size_t in2_kernel_offset, const size_t out0_kernel_offset) const {
if (ctx.is_with_amx) {
Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->rax,
h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx};
size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]);

h->sub(h->rsp, n_gprs_to_save * gpr_size);
for (size_t i = 0; i < n_gprs_to_save; ++i)
h->mov(h->ptr[h->rsp + i * gpr_size], gprs_to_save[i]);

// save function address in gpr to pass in call instruction
const auto& overload = static_cast<status_t(*)(const char*)>(amx_tile_configure);
h->mov(h->rbp, reinterpret_cast<uintptr_t>(overload));
h->mov(abi_param1, reinterpret_cast<uintptr_t>(ctx.palette));

// align stack on 16-byte as ABI requires
// note that RBX must not be changed by the callee
h->mov(h->rbx, h->rsp);
h->and_(h->rbx, 0xf);
h->sub(h->rsp, h->rbx);

h->call(h->rbp);

h->add(h->rsp, h->rbx);
// restore gpr registers
for (int i = n_gprs_to_save - 1; i >= 0; --i)
h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]);
h->add(h->rsp, n_gprs_to_save * gpr_size);
}
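The 16-byte alignment sequence used in this branch (and repeated below before the main kernel call) can be summarized with a short arithmetic sketch:
// rbx = rsp & 0xf   -> misalignment of rsp relative to a 16-byte boundary
// rsp -= rbx        -> rsp is now 16-byte aligned, as the ABI requires at a call
// call ...          -> rbx is callee-saved, so it survives the call
// rsp += rbx        -> restore the original stack pointer
// Example: rsp ends in 0x38 -> rbx = 8, rsp -= 8 ends in 0x30, which is 16-byte aligned.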

Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15,
h->rax, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx};
size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]);

h->sub(h->rsp, n_gprs_to_save * gpr_size);
@ -862,14 +968,12 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in

// caller obligation to save k-regs as callee may use them
size_t n_k_regs_to_save = 8;
if (isa == cpu::x64::avx512_core) {
h->sub(h->rsp, n_k_regs_to_save * k_mask_size);
for (size_t i = 0; i < n_k_regs_to_save; ++i) {
if (mayiuse(avx512_core))
h->kmovq(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast<int>(i)));
else
h->kmovw(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast<int>(i)));
}
h->sub(h->rsp, n_k_regs_to_save * k_mask_size);
for (size_t i = 0; i < n_k_regs_to_save; ++i) {
if (mayiuse(avx512_core))
h->kmovq(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast<int>(i)));
else
h->kmovw(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast<int>(i)));
}

// 1. Caller obligation to save vector registers as callee may use them.
@ -879,13 +983,16 @@
// `host_isa::vecs_count`.
h->sub(h->rsp, get_max_vecs_count() * get_vec_length());
for (size_t i = 0; i < get_max_vecs_count(); ++i)
h->uni_vmovups(h->ptr[h->rsp + i * get_vec_length()], Vmm(i));
h->uni_vmovups(h->ptr[h->rsp + i * get_vec_length()], Zmm(i));

size_t num_args_passed_on_stack = 0;
// save function address in gpr to pass in call instruction
const auto& brgemm_kernel_overload = static_cast<void (*)(const brgemm_kernel_t*,
const void*,
const void*,
void*)>(kernel_execute);
void*,
void*,
int)>(kernel_execute);
h->mov(h->rbp, reinterpret_cast<uintptr_t>(brgemm_kernel_overload));
// todo: several of addr_{A, B, C} could be also abi_paramX, so one of them could be corrupted
// if moving directly h->uni_vmovq(abi_paramX, adr_X). Save them to vector regs to avoid corruption.
@ -893,16 +1000,44 @@
h->uni_vmovq(Xmm(0), addr_A);
h->uni_vmovq(Xmm(1), addr_B);
h->uni_vmovq(Xmm(2), addr_C);

if (m_with_scratch)
h->uni_vmovq(Xmm(3), scratch);
// todo: Windows ABI : requires different num of arguments passed in regs and on the stack. Need to align.
const auto data_ptr_reg = [&](Xmm xmm, Xbyak::Reg64 reg, size_t bytes_offset) {
h->uni_vmovq(reg, xmm);
if (bytes_offset) h->add(reg, bytes_offset);
};
h->mov(abi_param1, reinterpret_cast<uintptr_t>(brgKernel));
h->mov(abi_param1, reinterpret_cast<uintptr_t>(brg_kernel));
data_ptr_reg(Xmm(0), abi_param2, in0_kernel_offset);
data_ptr_reg(Xmm(1), abi_param3, in1_kernel_offset);
data_ptr_reg(Xmm(2), abi_param4, out0_kernel_offset);

#ifdef _WIN32
// Before function call we should allocate stack area for
// - register parameters - ABI parameters (shadow space)
// - stack parameters - remaining parameters
num_args_passed_on_stack = 6; // count of function brgemm_kernel_overload() parameters
size_t abi_param_count = sizeof(abi_param_regs) / sizeof(abi_param_regs[0]);
h->sub(h->rsp, num_args_passed_on_stack * gpr_size);

// Push the remaining parameters on the stack
if (m_with_scratch) {
h->uni_vmovq(h->qword[h->rsp + (abi_param_count + 0) * gpr_size], Xmm(3));
if (in2_kernel_offset) h->add(h->qword[h->rsp + (abi_param_count + 0) * gpr_size], in2_kernel_offset);
} else {
h->mov(h->qword[h->rsp + (abi_param_count + 0) * gpr_size], reinterpret_cast<uintptr_t>(nullptr));
}
h->mov(abi_not_param1, static_cast<int>(m_with_comp));
h->mov(h->qword[h->rsp + (abi_param_count + 1) * gpr_size], abi_not_param1);
#else
if (m_with_scratch) {
data_ptr_reg(Xmm(3), abi_param5, in2_kernel_offset);
} else {
h->mov(abi_param5, reinterpret_cast<uintptr_t>(nullptr));
}
h->mov(abi_param6, static_cast<int>(m_with_comp));
#endif
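// Background note on the two branches above: the Windows x64 ABI passes only four
// integer arguments in registers (RCX, RDX, R8, R9) and requires caller-reserved
// shadow space, so arguments 5 and 6 of kernel_execute() (the scratch pointer and
// with_comp) have to be written to the stack at [rsp + 4 * gpr_size] and
// [rsp + 5 * gpr_size]; the System V ABI (the #else branch) has six integer
// argument registers, so abi_param5/abi_param6 suffice.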

// align stack on 16-byte as ABI requires
// note that RBX must not be changed by the callee
h->mov(h->rbx, h->rsp);
@ -912,22 +1047,22 @@
h->call(h->rbp);

h->add(h->rsp, h->rbx);
if (num_args_passed_on_stack > 0)
h->add(h->rsp, num_args_passed_on_stack * gpr_size);
// restore vector registers
for (int i = static_cast<int>(get_max_vecs_count()) - 1; i >= 0; --i) {
h->uni_vmovups(Vmm(i), h->ptr[h->rsp + i * get_vec_length()]);
h->uni_vmovups(Zmm(i), h->ptr[h->rsp + i * get_vec_length()]);
}
h->add(h->rsp, (get_max_vecs_count()) * get_vec_length());

// restore k registers
if (isa == cpu::x64::avx512_core) {
for (int i = n_k_regs_to_save - 1; i >= 0; --i) {
if (mayiuse(avx512_core))
h->kmovq(Opmask(i), h->ptr[h->rsp + i * k_mask_size]);
else
h->kmovw(Opmask(i), h->ptr[h->rsp + i * k_mask_size]);
}
h->add(h->rsp, n_k_regs_to_save * k_mask_size);
for (int i = n_k_regs_to_save - 1; i >= 0; --i) {
if (mayiuse(avx512_core))
h->kmovq(Opmask(i), h->ptr[h->rsp + i * k_mask_size]);
else
h->kmovw(Opmask(i), h->ptr[h->rsp + i * k_mask_size]);
}
h->add(h->rsp, n_k_regs_to_save * k_mask_size);

// restore gpr registers
for (int i = n_gprs_to_save - 1; i >= 0; --i)
@ -935,9 +1070,8 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in
h->add(h->rsp, n_gprs_to_save * gpr_size);
}

void BrgemmEmitter::kernel_execute(const brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C) {
// TODO: There are 4 available abi_params on Windows so we have the copy of brgemm_kernel_execute() function
// with 4 runtime parameters (kernel and I/O) and 4 default parameter values (batch, bs and scratch)
void BrgemmEmitter::kernel_execute(const brgemm_kernel_t *brg_kernel,
const void *A, const void *B, void *C, void *scratch, int with_comp) {
brgemm_kernel_params_t brgemm_p;

brgemm_p.batch = nullptr; // default value
@ -945,54 +1079,266 @@
brgemm_p.ptr_B = B;
brgemm_p.ptr_C = C;
brgemm_p.ptr_D = C;
brgemm_p.ptr_buf = nullptr; // default value
brgemm_p.ptr_buf = scratch;
brgemm_p.ptr_bias = nullptr;
brgemm_p.do_post_ops = 0;
brgemm_p.do_apply_comp = 0;
brgemm_p.do_post_ops = static_cast<size_t>(with_comp);
brgemm_p.do_apply_comp = static_cast<size_t>(with_comp);
brgemm_p.skip_accm = 0;
brgemm_p.BS = 1; // default value
assert(brg_kernel);
(*brg_kernel)(&brgemm_p);
}
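For reference, a sketch of how this entry point corresponds to the JIT-ed call emitted above; in reality the arguments are marshalled into abi_param1..6 by emit_brgemm_kernel_call, and the pointer names here are hypothetical:
// const brgemm_kernel_t* kernel = m_brgKernels0[getBrgIdx(0, 0, 0)].get();
// BrgemmEmitter::kernel_execute(kernel,
//                               ptr_A + in0_offset, ptr_B + in1_offset, ptr_C + out0_offset,
//                               m_with_scratch ? ptr_scratch + in2_offset : nullptr,
//                               static_cast<int>(m_with_comp));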

template <dnnl::impl::cpu::x64::cpu_isa_t isa>
void BrgemmEmitter::emit_isa(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
Reg64 input_0(static_cast<int>(in[0]));
Reg64 input_1(static_cast<int>(in[1]));
Reg64 output_0(static_cast<int>(out[0]));
BrgemmCopyBEmitter::BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n)
: jit_emitter(h, isa, n) {
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
const auto brgemm_repack = ov::as_type_ptr<ov::intel_cpu::BrgemmCopyB>(n);
if (!brgemm_repack)
IE_THROW() << "BrgemmCopyBEmitters expects BrgemmCopyB node";

for (size_t mb = 0; mb < div_up(M, M_blk); mb++) {
const bool is_M_tail = (M - mb * M_blk < M_blk);
m_brgemm_prc_in0 = brgemm_repack->get_src_element_type();
m_brgemm_prc_in1 = brgemm_repack->get_input_element_type(0);
m_brgemmVNNIFactor = 4 / m_brgemm_prc_in0.size();
m_with_comp = brgemm_repack->is_with_compensations();
m_in_offset = brgemm_repack->get_offset_in();
m_out_offset = brgemm_repack->get_offset_out();
if (m_with_comp)
m_comp_offset = brgemm_repack->get_offset_compensations();

size_t brgIdx0 = getBrgIdx(0, 0, 0);
size_t K0_step0 = brgCtxs0[brgIdx0].K;
size_t K0_step1 = brgCtxs0[brgIdx0].K * brgCtxs0[brgIdx0].LDB;
size_t N0_step0 = brgCtxs0[brgIdx0].N * brg0VnniFactor;
size_t N0_step1 = brgCtxs0[brgIdx0].N;
for (size_t n = 0; n < 2; n++) {
for (size_t k = 0; k < 2; k++) {
size_t mIdx = is_M_tail ? 1 : 0;
auto& brgemmCtx = brgCtxs0[getBrgIdx(mIdx, k, n)];

if (brgemmCtx.K != 0 && brgemmCtx.N != 0) {
const size_t in0_offset = load_offset_a + (k * K0_step0 + mb * M_blk * brgemmCtx.LDA) * io_data_size[0];
const size_t in1_offset = load_offset_b + (k * K0_step1 + n * N0_step0) * io_data_size[1];
const size_t out0_offset = store_offset_c + (n * N0_step1 + mb * M_blk * brgemmCtx.LDC) * io_data_size[2];

emit_brgemm_kernel_call<isa>(brgKernels0[getBrgIdx(mIdx, k, n)].get(),
1,
input_0,
input_1,
nullptr,
output_0,
nullptr,
in0_offset,
in1_offset,
out0_offset);
}
}
auto layout = ngraph::snippets::utils::get_node_output_layout(brgemm_repack->get_input_node_shared_ptr(0));
const auto& original_shape = brgemm_repack->get_input_shape(0);
auto transposed_shape = original_shape;
size_t leading_dimension = *(original_shape.rbegin());
if (!layout.empty()) {
transposed_shape.resize(layout.size(), 1);
for (size_t i = 0; i < layout.size(); ++i) {
transposed_shape[i] = original_shape[layout[i]];
}
// The idea here is to find "2" (for 4D shapes) in the layout and multiply dimensions that are to the right
// This implies that "3" is the last layout value, otherwise this layout is not supported.
// counting from the end since shape could be prepended with ones
const int64_t num_last_dims = layout.end() - std::find(layout.begin(), layout.end(), layout.size() - 2) - 1;
if (layout.back() != layout.size() - 1 || num_last_dims < 1)
IE_THROW() << "BrgemmRepackEmitter detected invalid layout values: " <<
"check that this shape + layout combination is schedulable";
leading_dimension = std::accumulate(original_shape.end() - num_last_dims, original_shape.end(), 1, std::multiplies<size_t>());
}
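A worked example of the layout handling above (values invented for illustration):
// original_shape = {1, 2, 32, 64}, layout = {0, 2, 1, 3}
// transposed_shape[i] = original_shape[layout[i]]  -> {1, 32, 2, 64}
// layout.size() - 2 == 2 is found at layout index 1, so
// num_last_dims = 4 - 1 - 1 = 2 and the check passes (layout.back() == 3)
// leading_dimension = 32 * 64 = 2048 (product of the two innermost original dims)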

m_N = *(transposed_shape.rbegin());
m_K = *(transposed_shape.rbegin() + 1);

const bool isAMXSupported = mayiuse(avx512_core_amx);
const auto use_amx = isAMXSupported && m_brgemm_prc_in0 != ov::element::f32 && (m_K % m_brgemmVNNIFactor == 0) && (m_N % m_brgemmVNNIFactor == 0);

m_N_blk = m_brgemm_prc_in1 == ov::element::f32 ? m_N :
m_brgemm_prc_in1 == ov::element::bf16 ? 32 : 64;
m_K_blk = use_amx ? m_brgemm_prc_in0 == ov::element::bf16 ? 32 : 64
: m_K;
m_N_tail = m_N % m_N_blk;
m_K_tail = m_K % m_K_blk;
m_LDB = m_brgemm_prc_in1 == ov::element::f32 ? leading_dimension : rnd_up(m_N, m_N_blk);

const auto dt_in0 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(InferenceEngine::details::convertPrecision(m_brgemm_prc_in0)));
const auto dt_in1 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(InferenceEngine::details::convertPrecision(m_brgemm_prc_in1)));
init_brgemm_copy(m_kernel, leading_dimension, m_N_blk, m_N_tail, m_LDB, m_K - m_K_tail, use_amx, dt_in0, dt_in1);
}

void BrgemmCopyBEmitter::init_brgemm_copy(std::unique_ptr<matmul::jit_brgemm_matmul_copy_b_t>& kernel,
size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K,
bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) const {
matmul::brgemm_matmul_conf_t brgCopyKernelConf;
brgCopyKernelConf.src_dt = dt_in0;
brgCopyKernelConf.wei_dt = dt_in1;
brgCopyKernelConf.wei_n_blk = static_cast<int>(N_blk);
brgCopyKernelConf.wei_tag = dnnl_abcd; // What about other ranks?
brgCopyKernelConf.copy_B_wei_stride = 0;
brgCopyKernelConf.LDB = static_cast<dim_t>(LDB);
brgCopyKernelConf.N = static_cast<dim_t>(N);
brgCopyKernelConf.N_tail = static_cast<dim_t>(N_tail);
brgCopyKernelConf.N_blk = static_cast<dim_t>(N_blk);
brgCopyKernelConf.K = static_cast<dim_t>(K);
brgCopyKernelConf.K_blk = static_cast<dim_t>(K);
brgCopyKernelConf.N_chunk_elems = brgCopyKernelConf.N_blk;
brgCopyKernelConf.b_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast<dnnl::memory::data_type>(brgCopyKernelConf.src_dt));
brgCopyKernelConf.tr_b_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast<dnnl::memory::data_type>(brgCopyKernelConf.src_dt));
brgCopyKernelConf.req_wei_vnni_downconvert = false;

if (is_with_amx) {
brgCopyKernelConf.isa = avx512_core_amx;
brgCopyKernelConf.s8s8_compensation_required = false;
} else {
brgCopyKernelConf.isa = dt_in0 == dnnl_data_type_t::dnnl_bf16 ? avx512_core_bf16 : avx512_core_vnni;
brgCopyKernelConf.s8s8_compensation_required = dt_in0 == dnnl_data_type_t::dnnl_s8;
}

brgCopyKernelConf.has_zero_point_a = false;
brgCopyKernelConf.has_zero_point_b = false;
brgCopyKernelConf.src_zp_type = dnnl::impl::cpu::x64::none;

auto status = matmul::create_brgemm_matmul_copy_b(kernel, &brgCopyKernelConf);
if (status != dnnl_success)
IE_THROW() << "BrgemmRepackEmitter cannot create kernel due to invalid params";
}

void BrgemmCopyBEmitter::emit_impl(const std::vector<size_t>& in,
const std::vector<size_t>& out) const {
if (host_isa_ == cpu::x64::avx512_core) {
Xbyak::Reg64 src(static_cast<int>(in[0]));
Xbyak::Reg64 dst(static_cast<int>(out[0]));
Xbyak::Reg64 comp(static_cast<int>(0)); // Compensations. Default reg idx is 0 if there aren't the compensations
if (m_with_comp) {
if (out.size() != 2) {
IE_THROW() << "BrgemmCopyBEmitter with compensations requires separate register for them";
}
comp = Xbyak::Reg64(static_cast<int>(out[1]));
}

const size_t data_size = m_brgemm_prc_in1.size();
for (size_t nb = 0; nb < div_up(m_N, m_N_blk); nb++) {
const size_t offset_in = m_in_offset + nb * m_N_blk * data_size;
const size_t offset_out = m_out_offset + nb * m_N_blk * m_brgemmVNNIFactor * data_size;
const size_t offset_comp = m_with_comp ? m_comp_offset + nb * m_N_blk * sizeof(int32_t) : 0;

const bool is_N_tail = (m_N - nb * m_N_blk < m_N_blk);
const auto current_N_blk = is_N_tail ? m_N_tail : m_N_blk;

emit_kernel_call(m_kernel.get(), src, dst, comp, current_N_blk, m_K, offset_in, offset_out, offset_comp);
}
} else {
IE_THROW() << "BrgemmCopyBEmitter requires at least avx512_core instruction set";
}
}
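To illustrate the per-block offsets computed in the loop above, a hypothetical bf16 case (invented numbers):
// m_N = 64, m_N_blk = 32, m_brgemmVNNIFactor = 2, data_size = 2, base offsets = 0:
//   nb = 0: offset_in = 0,           offset_out = 0
//   nb = 1: offset_in = 32 * 2 = 64, offset_out = 32 * 2 * 2 = 128
// The destination advances twice as fast because repacking interleaves
// m_brgemmVNNIFactor rows of B into one VNNI block.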

void BrgemmCopyBEmitter::emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, Reg64 src, Reg64 dst, Reg64 comp,
size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp) const {
Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15,
h->rax, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx};
size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]);

h->sub(h->rsp, n_gprs_to_save * gpr_size);
for (size_t i = 0; i < n_gprs_to_save; ++i)
h->mov(h->ptr[h->rsp + i * gpr_size], gprs_to_save[i]);

// caller obligation to save k-regs as callee may use them
size_t n_k_regs_to_save = 8;
h->sub(h->rsp, n_k_regs_to_save * k_mask_size);
for (size_t i = 0; i < n_k_regs_to_save; ++i) {
if (mayiuse(avx512_core))
h->kmovq(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast<int>(i)));
else
h->kmovw(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast<int>(i)));
}

// 1. Caller obligation to save vector registers as callee may use them.
// 2. There is an implicit assumption that the host code uses the same
// `isa` as the injector. Once the assumption is wrong, `vecs_count` and
// `vlen` should be replaced with `host_isa::vlen` and
// `host_isa::vecs_count`.
h->sub(h->rsp, get_max_vecs_count() * get_vec_length());
for (size_t i = 0; i < get_max_vecs_count(); ++i)
h->uni_vmovups(h->ptr[h->rsp + i * get_vec_length()], Zmm(i));

const auto data_ptr = [&](Xmm xmm, Xbyak::Reg64 reg, size_t bytes_offset) {
h->uni_vmovq(reg, xmm);
if (bytes_offset) h->add(reg, bytes_offset);
};
#ifdef _WIN32
const auto push_value = [&](size_t value, size_t index) {
// First we move the integer to a GPR, then we can move the value from the GPR to the stack
h->mov(abi_not_param1, value);
h->mov(h->qword[h->rsp + index * gpr_size], abi_not_param1);
};
#endif

size_t num_args_passed_on_stack = 0;
// save function address in gpr to pass in call instruction
const auto &kernel_overload = static_cast<void (*)(matmul::jit_brgemm_matmul_copy_b_t*,
const void*,
const void*,
const void*,
size_t,
size_t)>(execute);
h->mov(h->rbp, reinterpret_cast<uintptr_t>(kernel_overload));
// todo: several of addr_{A, B, C} could be also abi_paramX, so one of them could be corrupted
// if moving directly h->uni_vmovq(abi_paramX, adr_X). Save them to vector regs to avoid corruption.
// It's likely that a more efficient solution exists.
h->uni_vmovq(Xmm(0), src);
h->uni_vmovq(Xmm(1), dst);
if (m_with_comp)
h->uni_vmovq(Xmm(2), comp);
// todo: Windows ABI : requires different num of arguments passed in regs and on the stack. Need to align.
h->mov(abi_param1, reinterpret_cast<uintptr_t>(kernel));

data_ptr(Xmm(0), abi_param2, offset_in);
data_ptr(Xmm(1), abi_param3, offset_out);
if (m_with_comp) {
data_ptr(Xmm(2), abi_param4, offset_comp);
} else {
h->mov(abi_param4, reinterpret_cast<uintptr_t>(nullptr));
}

#ifdef _WIN32
// Before function call we should allocate stack area for
// - register parameters - ABI parameters (shadow space)
// - stack parameters - remaining parameters
num_args_passed_on_stack = 6; // count of function kernel_overload() parameters
size_t abi_param_count = sizeof(abi_param_regs) / sizeof(abi_param_regs[0]);

h->sub(h->rsp, num_args_passed_on_stack * gpr_size);
push_value(N, abi_param_count + 0);
push_value(K, abi_param_count + 1);
#else
h->mov(abi_param5, N);
h->mov(abi_param6, K);
#endif
// align stack on 16-byte as ABI requires
// note that RBX must not be changed by the callee
h->mov(h->rbx, h->rsp);
h->and_(h->rbx, 0xf);
h->sub(h->rsp, h->rbx);

h->call(h->rbp);

h->add(h->rsp, h->rbx);
if (num_args_passed_on_stack > 0)
h->add(h->rsp, gpr_size * num_args_passed_on_stack);
// restore vector registers
for (int i = static_cast<int>(get_max_vecs_count()) - 1; i >= 0; --i) {
h->uni_vmovups(Zmm(i), h->ptr[h->rsp + i * get_vec_length()]);
}
h->add(h->rsp, (get_max_vecs_count()) * get_vec_length());

// restore k registers
for (int i = n_k_regs_to_save - 1; i >= 0; --i) {
if (mayiuse(avx512_core))
h->kmovq(Opmask(i), h->ptr[h->rsp + i * k_mask_size]);
else
h->kmovw(Opmask(i), h->ptr[h->rsp + i * k_mask_size]);
}
h->add(h->rsp, n_k_regs_to_save * k_mask_size);

// restore gpr registers
for (int i = n_gprs_to_save - 1; i >= 0; --i)
h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]);
h->add(h->rsp, n_gprs_to_save * gpr_size);
}

void BrgemmCopyBEmitter::execute(matmul::jit_brgemm_matmul_copy_b_t *kernel, const void *src,
const void *dst, const void *comp, size_t N, size_t K) {
if (!kernel)
IE_THROW() << "Kernel for `brgemm_copy_b` hasn't been created";

auto ctx = dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t::ctx_t();
ctx.current_N_blk = N;
ctx.src = src;
ctx.tr_src = dst;
ctx.compensation_ptr = comp;
ctx.zp_a_compensation_ptr = nullptr;
ctx.zp_a_neg_value_ptr = nullptr;
ctx.current_K_start = 0;
ctx.current_K_iters = K;

(*kernel)(&ctx);
}

HorizonMaxEmitter::HorizonMaxEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n) :
@ -321,17 +321,13 @@ class BrgemmEmitter : public jit_emitter {
|
||||
public:
|
||||
BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n);
|
||||
|
||||
size_t get_inputs_num() const override {return 2;}
|
||||
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr) {
|
||||
return {{element::f32, element::f32}};
|
||||
}
|
||||
size_t get_inputs_num() const override { return m_with_scratch ? 3 : 2; }
|
||||
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr);
|
||||
|
||||
private:
|
||||
void emit_impl(const std::vector<size_t>& in,
|
||||
const std::vector<size_t>& out) const override;
|
||||
|
||||
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
|
||||
void emit_isa(const std::vector<size_t> &in, const std::vector<size_t> &out) const;
|
||||
std::vector<size_t> io_data_size {};
|
||||
struct brgemmCtx {
|
||||
size_t M, N, K, LDA, LDB, LDC;
|
||||
@ -342,29 +338,68 @@ private:
|
||||
float beta;
|
||||
};
|
||||
void initBrgemm(brgemmCtx& ctx, std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& brgKernel, bool use_amx) const;
|
||||
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
|
||||
void callBrgemm(brgemmCtx& ctx, std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& brgKernel,
|
||||
const void* pin0, const void* pin1, void* pout, void* wsp) const;
|
||||
size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) const;
|
||||
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
|
||||
void emit_brgemm_kernel_call(const dnnl::impl::cpu::x64::brgemm_kernel_t *brg_kernel, int bs,
|
||||
Xbyak::Reg64 addr_A, Xbyak::Reg64 addr_B,
|
||||
const dnnl::impl::cpu::x64::brgemm_batch_element_t *batch, Xbyak::Reg64 addr_C, void *scratch,
|
||||
const size_t in0_kernel_offset, const size_t in1_kernel_offset, const size_t out0_kernel_offset) const;
|
||||
static void kernel_execute(const dnnl::impl::cpu::x64::brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C);
|
||||
|
||||
void emit_brgemm_kernel_call(const dnnl::impl::cpu::x64::brgemm_kernel_t* brg_kernel, const brgemmCtx& ctx,
|
||||
Xbyak::Reg64 addr_A, Xbyak::Reg64 addr_B, Xbyak::Reg64 scratch, Xbyak::Reg64 addr_C,
|
||||
const size_t in0_kernel_offset, const size_t in1_kernel_offset,
|
||||
const size_t in2_kernel_offset, const size_t out0_kernel_offset) const;
|
||||
static void kernel_execute(const dnnl::impl::cpu::x64::brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C, void *scratch, int with_comp);
|
||||
|
||||
static constexpr size_t BRGEMM_KERNELS_NUM = 8;
|
||||
static constexpr size_t matmulOptimalM = 32;
|
||||
brgemmCtx brgCtxs0[BRGEMM_KERNELS_NUM];
|
||||
std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> brgKernels0[BRGEMM_KERNELS_NUM];
|
||||
brgemmCtx m_brgCtxs0[BRGEMM_KERNELS_NUM];
|
||||
std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> m_brgKernels0[BRGEMM_KERNELS_NUM];
|
||||
|
||||
size_t M, M_blk, M_tail;
|
||||
size_t K, K_blk, K_tail;
|
||||
size_t N, N_blk, N_tail;
|
||||
size_t brg0VnniFactor;
|
||||
size_t m_M, m_M_blk, m_M_tail;
|
||||
size_t m_K, m_K_blk, m_K_tail;
|
||||
size_t m_N, m_N_blk, m_N_tail;
|
||||
size_t m_brg0VnniFactor;
|
||||
|
||||
size_t load_offset_a = 0lu;
|
||||
size_t load_offset_b = 0lu;
|
||||
size_t store_offset_c = 0lu;
|
||||
bool m_with_scratch = false;
|
||||
bool m_with_comp = false;
|
||||
|
||||
size_t m_load_offset_a = 0lu;
|
||||
size_t m_load_offset_b = 0lu;
|
||||
size_t m_load_offset_scratch = 0lu;
|
||||
size_t m_store_offset_c = 0lu;
|
||||
};
|
||||
|
||||
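For orientation: BRGEMM_KERNELS_NUM is 8 because each of M, K and N contributes a main-block kernel and a tail kernel, giving 2x2x2 variants; getBrgIdx flattens the three 0/1 indices into that table. A plausible sketch of the mapping (illustrative only; the real definition lives in the .cpp):

    // Hypothetical flattening of (mIdx, kIdx, nIdx), each 0 (main block) or 1 (tail):
    size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) {
        return mIdx * 4 + kIdx * 2 + nIdx;  // 8 = 2 * 2 * 2 kernel variants
    }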
class BrgemmCopyBEmitter : public jit_emitter {
public:
    BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n);

    size_t get_inputs_num() const override {return 1;}
    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr) {
        return {{element::i8}, {element::bf16}};
    }

private:
    void emit_impl(const std::vector<size_t>& in,
                   const std::vector<size_t>& out) const override;

    void init_brgemm_copy(std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t>& kernel,
                          size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K,
                          bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) const;
    void emit_kernel_call(const dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t* kernel,
                          Xbyak::Reg64 src, Xbyak::Reg64 dst, Xbyak::Reg64 comp, size_t N, size_t K,
                          size_t offset_in, size_t offset_out, size_t offset_comp) const;

    static void execute(dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t* kernel,
                        const void* src, const void* dst, const void* comp, size_t N, size_t K);

    std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t> m_kernel;

    ov::element::Type m_brgemm_prc_in0, m_brgemm_prc_in1;
    size_t m_N, m_N_blk, m_N_tail;
    size_t m_K, m_K_blk, m_K_tail;
    size_t m_LDB;
    size_t m_brgemmVNNIFactor;
    bool m_with_comp = false;

    size_t m_in_offset = 0lu;
    size_t m_out_offset = 0lu;
    size_t m_comp_offset = 0lu;
};

class HorizonMaxEmitter : public jit_emitter {
@ -11,6 +11,8 @@
#include "ngraph_transformations/op/mha.hpp"
#include "snippets_transformations/op/load_convert.hpp"
#include "snippets_transformations/op/store_convert.hpp"
#include "snippets_transformations/op/brgemm_cpu.hpp"
#include "snippets_transformations/op/brgemm_copy_b.hpp"

#include <ngraph/ngraph.hpp>
#include <ov_ops/augru_cell.hpp>
@ -54,6 +56,8 @@ std::map<std::string, ngraph::OpSet> Extension::getOpSets() {
    NGRAPH_OP(LoadConvertTruncation, ov::intel_cpu)
    NGRAPH_OP(StoreConvertSaturation, ov::intel_cpu)
    NGRAPH_OP(StoreConvertTruncation, ov::intel_cpu)
    NGRAPH_OP(BrgemmCPU, ov::intel_cpu)
    NGRAPH_OP(BrgemmCopyB, ov::intel_cpu)
#undef NGRAPH_OP

    return opset;
@ -132,9 +136,9 @@ std::map<std::string, ngraph::OpSet> Extension::getOpSets() {

#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
    NGRAPH_OP(Brgemm, ngraph::snippets::op)
    NGRAPH_OP(Buffer, ngraph::snippets::op)
    NGRAPH_OP(BroadcastLoad, ngraph::snippets::op)
    NGRAPH_OP(BroadcastMove, ngraph::snippets::op)
    NGRAPH_OP(Buffer, ngraph::snippets::op)
    NGRAPH_OP(ConvertSaturation, ngraph::snippets::op)
    NGRAPH_OP(ConvertTruncation, ngraph::snippets::op)
    NGRAPH_OP(Fill, ngraph::snippets::op)
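For clarity: the NGRAPH_OP macro defined above expands every entry into an opset registration, e.g.

    opset.insert<ngraph::snippets::op::Brgemm>();  // expansion of NGRAPH_OP(Brgemm, ngraph::snippets::op)

so the list simply registers each snippets operation in the returned opset.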
@ -25,6 +25,7 @@
#include "utils/cpu_utils.hpp"
#include "snippets_transformations/fuse_load_store_and_convert.hpp"
#include "snippets_transformations/mul_add_to_fma.hpp"
#include "snippets_transformations/brgemm_to_brgemm_cpu.hpp"
#include "snippets_transformations/remove_converts.hpp"
#include "ngraph_transformations/convert_to_swish_cpu.hpp"

@ -536,6 +537,7 @@ void Snippet::generate(const jit_snippets_compile_args* jcp) {
    pre_dialect.register_pass<ConvertToSwishCPU>();

    ov::pass::Manager post_dialect;
    post_dialect.register_pass<ov::intel_cpu::pass::BrgemmToBrgemmCPU>();

    ov::pass::Manager post_precision;
    post_precision.register_pass<ov::intel_cpu::pass::RemoveConverts>();
@ -0,0 +1,96 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/itt.hpp"

#include "brgemm_to_brgemm_cpu.hpp"
#include "snippets/snippets_isa.hpp"
#include "snippets/utils.hpp"
#include "op/brgemm_copy_b.hpp"
#include "op/brgemm_cpu.hpp"

#include "ngraph/rt_info.hpp"
#include "ngraph/pattern/op/wrap_type.hpp"

#include <cpu/x64/cpu_isa_traits.hpp>

#include "cpu_shape.h"
#include "utils/general_utils.h"


namespace ov {
namespace intel_cpu {

pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() {
    MATCHER_SCOPE(BrgemmToBrgemmCPU);

    auto m_brgemm = ngraph::pattern::wrap_type<ngraph::snippets::op::Brgemm>();

    auto callback = [=](ngraph::pattern::Matcher& m) {
        OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::BrgemmToBrgemmCPU")
        const auto node = m.get_match_root();
        const auto brgemm = ov::as_type_ptr<ngraph::snippets::op::Brgemm>(node);
        const auto brgemm_plugin = ov::as_type_ptr<BrgemmCPU>(node);
        if (!brgemm || brgemm_plugin)
            throw ov::Exception("BrgemmCPU cannot be in body before BrgemmToBrgemmCPU pass");

        if (brgemm->is_dynamic()) {
            return false;
        }

        const auto dimsMatMulIn0 = ngraph::snippets::utils::get_port_planar_shape(brgemm->input_value(0)).get_shape();
        const auto dimsMatMulIn1 = ngraph::snippets::utils::get_port_planar_shape(brgemm->input_value(1)).get_shape();

        const auto K = *dimsMatMulIn0.rbegin();
        const auto N = *dimsMatMulIn1.rbegin();

        const auto element_type_a = brgemm->get_input_element_type(0);
        const auto brgemmVNNIFactor = 4 / element_type_a.size();
        const bool isAMXSupported = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx);
        const bool with_amx = isAMXSupported && element_type_a != ov::element::f32 && (K % brgemmVNNIFactor == 0) && (N % brgemmVNNIFactor == 0);
        const bool with_comp = element_type_a == ov::element::i8 && !with_amx;
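        // Worked example (illustrative, not part of the original source):
        // element_type_a.size() is 4 for f32, 2 for bf16 and 1 for i8/u8, so
        // brgemmVNNIFactor is 1, 2 and 4 respectively. Hence bf16 takes the AMX
        // path only when K and N are even, int8 only when both are multiples of 4;
        // otherwise one of the non-AMX branches below is taken.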

        const auto offset_a = brgemm->get_offset_a();
        const auto offset_b = brgemm->get_offset_b();
        const auto offset_c = brgemm->get_offset_c();

        std::shared_ptr<ov::Node> brgemm_cpu = nullptr;
        if (element_type_a == ov::element::f32) {
            brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0), brgemm->input_value(1), BrgemmCPU::Type::Floating,
                                                     offset_a, offset_b, offset_c);
        } else {
            const auto layoutIn1 = ngraph::snippets::utils::get_node_output_layout(brgemm->input_value(1).get_node_shared_ptr());
            const auto copy_b_type = with_comp ? BrgemmCopyB::WithCompensations : BrgemmCopyB::OnlyRepacking;
            const auto brgemmRepackIn1 = std::make_shared<BrgemmCopyB>(brgemm->input_value(1), element_type_a, copy_b_type, offset_b);
            const auto buffer = std::make_shared<ngraph::snippets::op::Buffer>(brgemmRepackIn1->output(0));

            if (with_amx) {
                const auto scratch = std::make_shared<ngraph::snippets::op::Buffer>(ov::Shape{BrgemmCPU::SCRATCH_BYTE_SIZE});
                brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0), buffer, scratch, BrgemmCPU::Type::AMX,
                                                         offset_a, offset_b, offset_c);
            } else if (with_comp) {
                const auto scratch = std::make_shared<ngraph::snippets::op::Buffer>(brgemmRepackIn1->output(1));
                brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0), buffer, scratch, BrgemmCPU::Type::WithCompensations,
                                                         offset_a, offset_b, offset_c);
            } else if (one_of(element_type_a, ov::element::u8, ov::element::bf16)) {
                brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0), buffer, BrgemmCPU::Type::WithDataRepacking,
                                                         offset_a, offset_b, offset_c);
            } else {
                IE_THROW() << "Invalid configuration for BRGEMM CPU";
            }
        }

        brgemm_cpu->set_friendly_name(brgemm->get_friendly_name());
        ngraph::snippets::utils::set_output_layout(brgemm_cpu->output(0), ngraph::snippets::utils::get_node_output_layout(brgemm));
        ngraph::copy_runtime_info(brgemm, brgemm_cpu);
        ngraph::replace_node(brgemm, brgemm_cpu);

        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(m_brgemm, matcher_name);
    register_matcher(m, callback);
}
} // namespace intel_cpu
} // namespace ov
@ -0,0 +1,45 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pattern/matcher.hpp"

namespace ov {
namespace intel_cpu {
namespace pass {

/**
 * @interface BrgemmToBrgemmCPU
 * @brief The pass decomposes a Snippets Brgemm into a specific subgraph that depends on the ISA and input precisions:
 *        - f32|f32:
 *                 BrgemmCPU
 *        - u8|i8 or bf16|bf16 (non-AMX system):
 *                 BrgemmCopyB (the operation for data repacking)
 *                      |
 *                    Buffer
 *                      |
 *                 BrgemmCPU
 *        - i8|i8 (non-AMX system) - needs compensations:
 *                         BrgemmCopyB
 *                        /           \
 *            Buffer (repacked data)   Buffer (compensations)
 *                        \           /
 *                          BrgemmCPU
 *        - u8|i8, i8|i8 or bf16|bf16 on AMX system:
 *                         BrgemmCopyB
 *                              |
 *            Buffer (repacked data)   Buffer (new memory)
 *                        \           /
 *                          BrgemmCPU
 * @ingroup snippets
 */
class BrgemmToBrgemmCPU: public ngraph::pass::MatcherPass {
public:
    OPENVINO_RTTI("BrgemmToBrgemmCPU", "0");
    BrgemmToBrgemmCPU();
};


}  // namespace pass
}  // namespace intel_cpu
}  // namespace ov
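A minimal usage sketch for this pass (assuming `model` is a lowered Snippets body, i.e. an ov::Model containing a snippets-opset Brgemm; names illustrative):

    // Hypothetical invocation: rewrites snippets::op::Brgemm into BrgemmCPU
    // (plus BrgemmCopyB/Buffer nodes where repacking or a scratchpad is needed).
    ov::pass::Manager manager;
    manager.register_pass<ov::intel_cpu::pass::BrgemmToBrgemmCPU>();
    manager.run_passes(model);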
@ -10,20 +10,17 @@
#include "snippets_transformations/op/load_convert.hpp"
#include "snippets_transformations/op/store_convert.hpp"

#include "ngraph/opsets/opset1.hpp"
#include "ngraph/rt_info.hpp"
#include "ngraph/pattern/op/wrap_type.hpp"

ov::intel_cpu::pass::FuseLoadConvert::FuseLoadConvert() {
    MATCHER_SCOPE(FuseLoadConvert);
    auto param_pattern = ngraph::pattern::wrap_type<ngraph::opset1::Parameter>();
    auto load_pattern = ngraph::pattern::wrap_type<ngraph::snippets::op::Load>({param_pattern});
    auto load_pattern = ngraph::pattern::wrap_type<ngraph::snippets::op::Load>();
    auto convert_pattern = ngraph::pattern::wrap_type<ngraph::opset1::Convert>({load_pattern});

    auto callback = [=](ngraph::pattern::Matcher& m) {
        OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::FuseLoadConvert")
        auto& pm = m.get_pattern_value_map();
        const auto param = pm.at(param_pattern).get_node_shared_ptr();
        const auto load_shared = pm.at(load_pattern).get_node_shared_ptr();
        if (!load_shared || load_shared->output(0).get_target_inputs().size() != 1) {
            return false;
@ -40,12 +37,12 @@ ov::intel_cpu::pass::FuseLoadConvert::FuseLoadConvert() {
        std::shared_ptr<ngraph::Node> load_convert = nullptr;
        if (const auto convert_saturation =
                std::dynamic_pointer_cast<ngraph::snippets::op::ConvertSaturation>(convert)) {
            load_convert = std::make_shared<ov::intel_cpu::LoadConvertSaturation>(param,
            load_convert = std::make_shared<ov::intel_cpu::LoadConvertSaturation>(load->input_value(0),
                                                                                  convert_saturation->get_destination_type(),
                                                                                  load->get_count(), load->get_offset());
        } else if (const auto convert_truncation =
                std::dynamic_pointer_cast<ngraph::snippets::op::ConvertTruncation>(convert)) {
            load_convert = std::make_shared<ov::intel_cpu::LoadConvertTruncation>(param,
            load_convert = std::make_shared<ov::intel_cpu::LoadConvertTruncation>(load->input_value(0),
                                                                                  convert_truncation->get_destination_type(),
                                                                                  load->get_count(), load->get_offset());
        } else {
@ -102,7 +99,6 @@ ov::intel_cpu::pass::FuseStoreConvert::FuseStoreConvert() {
                "Type of Convert op is undefined. Supports only fusing Store and ConvertTruncation or ConvertSaturation ops");
        }

        if (!store_convert)
            return false;

@ -0,0 +1,78 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/itt.hpp"
#include "snippets/utils.hpp"

#include "brgemm_copy_b.hpp"

#include "utils/general_utils.h"

using namespace std;
using namespace ov;

intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output<Node>& x, const element::Type src_type, const Type type,
                                    const size_t offset_in, const size_t offset_out0, const size_t offset_out1)
    : ngraph::snippets::op::MemoryAccess({x}, 1, type == Type::WithCompensations ? 2 : 1), m_type(type), m_src_type(src_type) {
    set_output_size(get_output_port_count());
    m_input_ports.resize(get_input_size());
    m_output_ports.resize(get_output_size());
    set_input_port_descriptor({0, offset_in}, 0);
    set_output_port_descriptor({0, offset_out0}, 0);
    if (is_with_compensations()) {
        set_output_port_descriptor({0, offset_out1}, 1);
    }
    constructor_validate_and_infer_types();
}

bool intel_cpu::BrgemmCopyB::visit_attributes(AttributeVisitor& visitor) {
    INTERNAL_OP_SCOPE(BrgemmRepack_visit_attributes);
    MemoryAccess::visit_attributes(visitor);
    visitor.on_attribute("src_type", m_src_type);
    return true;
}

void intel_cpu::BrgemmCopyB::validate_and_infer_types() {
    INTERNAL_OP_SCOPE(BrgemmRepack_validate_and_infer_types);

    const auto element_type = get_input_element_type(0);
    NGRAPH_CHECK(one_of(element_type, element::bf16, element::i8),
                 "BrgemmCopyB doesn't support element type " + element_type.get_type_name());

    const auto pshape = ngraph::snippets::utils::get_port_planar_shape(input_value(0));
    if (pshape.is_dynamic()) {
        set_output_type(0, element_type, ov::PartialShape{ov::Dimension::dynamic()});
        if (is_with_compensations()) {
            set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension::dynamic()});
        }
        return;
    }

    const auto shape = pshape.get_shape();
    const auto N = *shape.rbegin();
    const auto K = *(shape.rbegin() + 1);
    const auto N_blk = element_type == element::bf16 ? 32 : 64;
    const auto brgemmVNNIFactor = 4 / m_src_type.size();
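    // Illustrative example (not part of the original source): for bf16
    // (m_src_type.size() == 2) brgemmVNNIFactor is 2 and N_blk is 32, so
    // K = 37, N = 70 yield a repacked shape of {38, 96}; for i8 the factor
    // is 4 and N_blk is 64, so the same dims give {40, 128}.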

    set_output_type(0, element_type, ov::PartialShape{ov::Dimension(rnd_up(K, brgemmVNNIFactor)),
                                                      ov::Dimension(rnd_up(N, N_blk))});
    if (is_with_compensations()) {
        set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension(rnd_up(N, N_blk))});
    }
}

std::shared_ptr<Node> intel_cpu::BrgemmCopyB::clone_with_new_inputs(const OutputVector& new_args) const {
    INTERNAL_OP_SCOPE(BrgemmRepack_clone_with_new_inputs);
    check_new_args_count(this, new_args);
    return std::make_shared<BrgemmCopyB>(new_args.at(0), m_src_type, m_type,
                                         get_offset_in(),
                                         get_offset_out(),
                                         is_with_compensations() ? get_offset_compensations() : 0);
}

size_t intel_cpu::BrgemmCopyB::get_offset_compensations() const {
    OPENVINO_ASSERT(is_with_compensations() && get_output_size() == 2,
                    "The offset for compensations must be in BrgemmCopyB only with compensations and 2 outputs!");
    return get_output_offset(1);
}
@ -0,0 +1,51 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/op/memory_access.hpp"

namespace ov {
namespace intel_cpu {

/**
 * @interface BrgemmCopyB
 * @brief The operation for data repacking of Brgemm with non-fp32 input precisions.
 *        The CPU Generator uses oneDNN primitives for code generation of Brgemm.
 *        OneDNN requires data repacking for the second input of Brgemm with non-fp32 input precisions.
 * @ingroup snippets
 */
class BrgemmCopyB : public ngraph::snippets::op::MemoryAccess {
public:
    OPENVINO_OP("BrgemmCopyB", "SnippetsOpset", MemoryAccess);

    enum Type {
        OnlyRepacking,     // Just data repacking - one output
        WithCompensations, // Repack data and calculate compensations - 2 outputs (needed for BrgemmCPU with compensations)
    };

    BrgemmCopyB(const Output<Node>& x, const element::Type src_type, const Type type = Type::OnlyRepacking,
                const size_t offset_in = 0lu, const size_t offset_out0 = 0lu, const size_t offset_out1 = 0lu);
    BrgemmCopyB() = default;

    size_t get_offset_in() const { return get_input_offset(0); }
    size_t get_offset_out() const { return get_output_offset(0); }
    size_t get_offset_compensations() const;

    Type get_type() const { return m_type; }
    element::Type get_src_element_type() const { return m_src_type; }
    bool is_with_compensations() const { return m_type == Type::WithCompensations; }

    bool visit_attributes(AttributeVisitor& visitor) override;
    void validate_and_infer_types() override;
    bool has_evaluate() const override { return false; }
    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

private:
    Type m_type = Type::OnlyRepacking;
    element::Type m_src_type = ov::element::undefined; // src element type of the corresponding BRGEMM
};

} // namespace intel_cpu
} // namespace ov
@ -0,0 +1,117 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/itt.hpp"
#include "brgemm_cpu.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "openvino/core/rt_info.hpp"
#include "snippets/utils.hpp"
#include "utils/general_utils.h"


namespace ov {
namespace intel_cpu {

BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Type type,
                     const size_t offset_a, const size_t offset_b, const size_t offset_c)
    : Brgemm(), m_type(type) {
    // We call the default ctor of the Brgemm class to avoid incorrect shape inference in the constructor_validate_and_infer_types() call
    set_arguments({A, B});
    set_output_size(1);
    m_input_ports.resize(get_input_size());
    m_output_ports.resize(get_output_size());
    set_input_port_descriptor({0, offset_a}, 0);
    set_input_port_descriptor({0, offset_b}, 1);
    set_output_port_descriptor({0, offset_c}, 0);
    constructor_validate_and_infer_types();
}

BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, const Type type,
                     const size_t offset_a, const size_t offset_b, const size_t offset_scratch, const size_t offset_c)
    : Brgemm(), m_type(type) {
    set_arguments({A, B, scratch});
    set_output_size(1);
    m_input_ports.resize(get_input_size());
    m_output_ports.resize(get_output_size());
    set_input_port_descriptor({0, offset_a}, 0);
    set_input_port_descriptor({0, offset_b}, 1);
    set_output_port_descriptor({0, offset_c}, 0);
    set_input_port_descriptor({0, offset_scratch}, 2);
    constructor_validate_and_infer_types();
}

void BrgemmCPU::validate_and_infer_types() {
    INTERNAL_OP_SCOPE(BrgemmCPU_validate_and_infer_types);
    // If no leading dimensions are provided, assume dense row-major inputs-outputs
    NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(),
                          "BrgemmCPU currently supports only static shapes.");

    OPENVINO_ASSERT(implication(one_of(m_type, Type::Floating, Type::WithDataRepacking), get_input_size() == 2),
                    "BrgemmCPU expects 2 inputs in cases when input precisions are f32|f32, u8|i8 or bf16|bf16 (non-AMX system)");
    OPENVINO_ASSERT(implication(one_of(m_type, Type::WithCompensations, Type::AMX), get_input_size() == 3),
                    "BrgemmCPU expects 3 inputs with input precisions i8|i8 and bf16|bf16 on AMX system");

    const auto brgemm_copy = is_with_data_repacking() ? get_brgemm_copy() : nullptr;
    std::vector<ov::PartialShape> planar_input_shapes = {
        ngraph::snippets::utils::get_port_planar_shape(input_value(0)),
        ngraph::snippets::utils::get_port_planar_shape(brgemm_copy ? brgemm_copy->input_value(0) : input_value(1))
    };

    auto output_shape = get_output_partial_shape(planar_input_shapes);
    const auto& output_layout = ngraph::snippets::utils::get_node_output_layout(this);
    set_output_type(0,
                    get_output_type(),
                    ngraph::snippets::utils::get_reordered_planar_shape(output_shape, output_layout));

    // Additional check for 3rd input
    if (one_of(m_type, Type::WithCompensations, Type::AMX)) {
        const auto shape = get_input_partial_shape(2);
        NGRAPH_CHECK(shape.is_static(), "BRGEMM Scratch must have static shape");
        const auto type = get_input_element_type(2);
        if (is_with_compensations()) {
            const auto element_type_b = get_input_element_type(0);
            const auto shape_b = planar_input_shapes[1].get_shape();
            const auto N = *shape_b.rbegin();
            const auto N_blk = element_type_b == element::f32 ? N :
                               element_type_b == element::bf16 ? 32 : 64;
            const auto expected_shape = ov::Shape{rnd_up(N, N_blk)};
            const auto expected_type = ov::element::f32;
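            // Illustrative numbers (not part of the original source): for an i8
            // Brgemm with N = 70, N_blk is 64, so BrgemmCopyB must supply an f32
            // compensation buffer of shape {128} (rnd_up(70, 64)); the AMX branch
            // below instead expects a raw 32 KB u8 workspace regardless of the
            // GEMM dimensions.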
            NGRAPH_CHECK(expected_shape == shape.get_shape() && expected_type == type,
                         "BRGEMM Scratch with compensations must have shape {rnd_up(N, N_blk)} and FP32 element type");
        } else {
            NGRAPH_CHECK(ngraph::shape_size(shape.get_shape()) == SCRATCH_BYTE_SIZE && type == ov::element::u8,
                         "BRGEMM Scratch for workspace must be static, have U8 element type and size equal to " + std::to_string(SCRATCH_BYTE_SIZE));
        }
    }
}

std::shared_ptr<Node> BrgemmCPU::clone_with_new_inputs(const OutputVector& new_args) const {
    INTERNAL_OP_SCOPE(BrgemmCPU_clone_with_new_inputs);
    check_new_args_count(this, new_args);
    std::shared_ptr<BrgemmCPU> new_node = nullptr;
    if (!is_with_scratchpad()) {
        new_node = std::make_shared<BrgemmCPU>(new_args.at(0), new_args.at(1), m_type,
                                               get_offset_a(), get_offset_b(), get_offset_c());
    } else {
        new_node = std::make_shared<BrgemmCPU>(new_args.at(0), new_args.at(1), new_args.at(2), m_type,
                                               get_offset_a(), get_offset_b(), get_offset_scratch(), get_offset_c());
    }
    return new_node;
}

std::shared_ptr<BrgemmCopyB> BrgemmCPU::get_brgemm_copy() const {
    OPENVINO_ASSERT(one_of(m_type, Type::WithDataRepacking, Type::WithCompensations, Type::AMX), "Brgemm doesn't need BrgemmCopyB");
    if (const auto buffer = ov::as_type_ptr<ngraph::snippets::op::Buffer>(get_input_node_shared_ptr(1))) {
        return ov::as_type_ptr<BrgemmCopyB>(buffer->get_input_node_shared_ptr(0));
    }
    throw ov::Exception("BrgemmCopyB hasn't been found!");
}

size_t BrgemmCPU::get_offset_scratch() const {
    OPENVINO_ASSERT(is_with_scratchpad() && get_input_size() == 3, "Offset of scratchpad must be only in Brgemm with scratchpad on 3rd input");
    return get_input_offset(2);
}

} // namespace intel_cpu
} // namespace ov
@ -0,0 +1,55 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/op/brgemm.hpp"
#include "brgemm_copy_b.hpp"

namespace ov {
namespace intel_cpu {

/**
 * @interface BrgemmCPU
 * @brief BrgemmCPU is a batch-reduced matrix multiplication with the support of arbitrary strides between matrices rows
 *        with support of several precisions on plugin level
 * @ingroup snippets
 */
class BrgemmCPU : public ngraph::snippets::op::Brgemm {
public:
    OPENVINO_OP("BrgemmCPU", "SnippetsOpset", ngraph::snippets::op::Brgemm);

    enum Type {
        Floating,          // f32|f32
        WithDataRepacking, // u8|i8 or bf16|bf16 (non-AMX system) - needs BrgemmCopyB on second input for data repacking
        WithCompensations, // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations
        AMX,               // i8|i8 or bf16|bf16 on AMX system - needs BrgemmCopyB and scratchpad
    };

    BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Type type,
              const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_c = 0);
    BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, const Type type,
              const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_scratch = 0, const size_t offset_c = 0);
    BrgemmCPU() = default;

    void validate_and_infer_types() override;
    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

    Type get_type() const { return m_type; }
    bool is_with_compensations() const { return m_type == Type::WithCompensations; }
    bool is_with_data_repacking() const { return m_type != Type::Floating; }
    bool is_amx() const { return m_type == Type::AMX; }
    bool is_with_scratchpad() const { return is_with_compensations() || is_amx(); }

    size_t get_offset_scratch() const;
    std::shared_ptr<BrgemmCopyB> get_brgemm_copy() const;

    constexpr static size_t SCRATCH_BYTE_SIZE = 32 * 1024;

private:
    Type m_type = Type::Floating;
};

} // namespace intel_cpu
} // namespace ov
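A minimal construction sketch for the AMX flavour, mirroring what BrgemmToBrgemmCPU builds (inputs `a` and `b` are assumed upstream Outputs; illustrative only):

    // Hypothetical wiring: repack B, buffer the result, and add a 32 KB scratchpad.
    auto repack  = std::make_shared<ov::intel_cpu::BrgemmCopyB>(b, ov::element::bf16);
    auto buffer  = std::make_shared<ngraph::snippets::op::Buffer>(repack->output(0));
    auto scratch = std::make_shared<ngraph::snippets::op::Buffer>(
            ov::Shape{ov::intel_cpu::BrgemmCPU::SCRATCH_BYTE_SIZE});
    auto brgemm  = std::make_shared<ov::intel_cpu::BrgemmCPU>(
            a, buffer, scratch, ov::intel_cpu::BrgemmCPU::Type::AMX);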
@ -19,6 +19,7 @@ intel_cpu::LoadConvertSaturation::LoadConvertSaturation(const Output<Node>& x, c

bool intel_cpu::LoadConvertSaturation::visit_attributes(AttributeVisitor& visitor) {
    INTERNAL_OP_SCOPE(LoadConvert_visit_attributes);
    MemoryAccess::visit_attributes(visitor);
    visitor.on_attribute("destination_type", m_destination_type);
    return true;
}
@ -31,7 +32,8 @@ void intel_cpu::LoadConvertSaturation::validate_and_infer_types() {
std::shared_ptr<Node> intel_cpu::LoadConvertSaturation::clone_with_new_inputs(const OutputVector& new_args) const {
    INTERNAL_OP_SCOPE(LoadConvert_clone_with_new_inputs);
    check_new_args_count(this, new_args);
    return std::make_shared<LoadConvertSaturation>(new_args.at(0), m_destination_type, m_count, m_offset);
    return std::make_shared<LoadConvertSaturation>(
        new_args.at(0), m_destination_type, get_count(), get_offset());
}

intel_cpu::LoadConvertTruncation::LoadConvertTruncation(const Output<Node>& x, const ov::element::Type& destination_type,
@ -42,6 +44,7 @@ intel_cpu::LoadConvertTruncation::LoadConvertTruncation(const Output<Node>& x, c

bool intel_cpu::LoadConvertTruncation::visit_attributes(AttributeVisitor& visitor) {
    INTERNAL_OP_SCOPE(LoadConvert_visit_attributes);
    MemoryAccess::visit_attributes(visitor);
    visitor.on_attribute("destination_type", m_destination_type);
    return true;
}
@ -54,5 +57,6 @@ void intel_cpu::LoadConvertTruncation::validate_and_infer_types() {
std::shared_ptr<Node> intel_cpu::LoadConvertTruncation::clone_with_new_inputs(const OutputVector& new_args) const {
    INTERNAL_OP_SCOPE(LoadConvert_clone_with_new_inputs);
    check_new_args_count(this, new_args);
    return std::make_shared<LoadConvertTruncation>(new_args.at(0), m_destination_type, m_count, m_offset);
    return std::make_shared<LoadConvertTruncation>(
        new_args.at(0), m_destination_type, get_count(), get_offset());
}
@ -19,6 +19,7 @@ intel_cpu::StoreConvertSaturation::StoreConvertSaturation(const Output<Node>& x,

bool intel_cpu::StoreConvertSaturation::visit_attributes(AttributeVisitor& visitor) {
    INTERNAL_OP_SCOPE(StoreConvert_visit_attributes);
    MemoryAccess::visit_attributes(visitor);
    visitor.on_attribute("destination_type", m_destination_type);
    return true;
}
@ -31,7 +32,8 @@ void intel_cpu::StoreConvertSaturation::validate_and_infer_types() {
std::shared_ptr<Node> intel_cpu::StoreConvertSaturation::clone_with_new_inputs(const OutputVector& new_args) const {
    INTERNAL_OP_SCOPE(StoreConvert_clone_with_new_inputs);
    check_new_args_count(this, new_args);
    return std::make_shared<StoreConvertSaturation>(new_args.at(0), m_destination_type, m_count, m_offset);
    return std::make_shared<StoreConvertSaturation>(
        new_args.at(0), m_destination_type, get_count(), get_offset());
}

intel_cpu::StoreConvertTruncation::StoreConvertTruncation(const Output<Node>& x, const ov::element::Type& destination_type,
@ -42,6 +44,7 @@ intel_cpu::StoreConvertTruncation::StoreConvertTruncation(const Output<Node>& x,

bool intel_cpu::StoreConvertTruncation::visit_attributes(AttributeVisitor& visitor) {
    INTERNAL_OP_SCOPE(StoreConvert_visit_attributes);
    MemoryAccess::visit_attributes(visitor);
    visitor.on_attribute("destination_type", m_destination_type);
    return true;
}
@ -54,5 +57,6 @@ void intel_cpu::StoreConvertTruncation::validate_and_infer_types() {
std::shared_ptr<Node> intel_cpu::StoreConvertTruncation::clone_with_new_inputs(const OutputVector& new_args) const {
    INTERNAL_OP_SCOPE(StoreConvert_clone_with_new_inputs);
    check_new_args_count(this, new_args);
    return std::make_shared<StoreConvertTruncation>(new_args.at(0), m_destination_type, m_count, m_offset);
    return std::make_shared<StoreConvertTruncation>(
        new_args.at(0), m_destination_type, get_count(), get_offset());
}
@ -556,7 +556,7 @@ void Transformations::PostLpt() {

void Transformations::MainSnippets(void) {
    if (snippetsMode == Config::SnippetsMode::Disable ||
        !dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) // snippets are implemeted only for relevant platforms (avx2+ extentions)
        !dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) // snippets are implemented only for relevant platforms (avx2+ extensions)
        return;

    ngraph::pass::Manager snippetsManager;
@ -207,6 +207,10 @@ std::vector<std::string> disabledTestPatterns() {
        retVector.emplace_back(R"(.*Snippets.*MHA.*)");
        retVector.emplace_back(R"(.*Snippets.*(MatMul|Matmul).*)");
    }
    if (!InferenceEngine::with_cpu_x86_avx512_core_vnni() && !InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) {
        // MatMul in Snippets uses BRGEMM that supports i8 only on platforms with VNNI or AMX instructions
        retVector.emplace_back(R"(.*Snippets.*MatMulFQ.*)");
    }
    if (!InferenceEngine::with_cpu_x86_avx512_core_amx_int8())
        //TODO: Issue 92895
        // on platforms which do not support AMX, we are disabling I8 input tests
@ -4,6 +4,7 @@

#include "snippets/matmul.hpp"
#include "common_test_utils/test_constants.hpp"
#include "ie_system_conf.h"

namespace ov {
namespace test {
@ -16,49 +17,47 @@ std::vector<std::vector<ov::PartialShape>> input_shapes{
        {{3, 1, 32, 14}, {1, 2, 14, 32}},
        {{1, 2, 37, 23}, {2, 1, 23, 37}},
        {{1, 1, 37, 23}, {1, 2, 23, 33}},
        {{2, 1, 69, 43}, {1, 1, 43, 49}}
        {{1, 16, 384, 64}, {1, 16, 64, 384}}
};
std::vector<element::Type> precisions{element::f32};
static inline std::vector<std::vector<element::Type>> precisions(bool only_fp32 = true) {
    std::vector<std::vector<element::Type>> prc = {
            {element::f32, element::f32},
    };
    if (!only_fp32) {
        // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms
        if (InferenceEngine::with_cpu_x86_avx512_core_vnni() || InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) {
            prc.emplace_back(std::vector<element::Type>{element::i8, element::i8});
            prc.emplace_back(std::vector<element::Type>{element::u8, element::i8});
        }
        // In Snippets MatMul BF16 is supported only on bf16/AMX platforms
        if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) {
            prc.emplace_back(std::vector<element::Type>{element::bf16, element::bf16});
        }
    }
    return prc;
}
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, MatMul,
                         ::testing::Combine(
                                 ::testing::ValuesIn(input_shapes),
                                 ::testing::ValuesIn(precisions),
                                 ::testing::Values(1), // MatMu;
                                 ::testing::ValuesIn(precisions(false)),
                                 ::testing::Values(1), // MatMul
                                 ::testing::Values(1), // Tokenized MatMul
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         MatMul::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulFQ, MatMulFQ,
                         ::testing::Combine(
                                 ::testing::ValuesIn(input_shapes),
                                 ::testing::ValuesIn(precisions()),
                                 ::testing::Values(1), // MatMul;
                                 ::testing::Values(1), // Tokenized MatMul
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         MatMul::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulBias, MatMulBias,
                         ::testing::Combine(
                                 ::testing::Values(std::vector<ov::PartialShape>{{1, 2, 69, 43}, {2, 1, 43, 49}, {1, 1, 69, 49}}),
                                 ::testing::ValuesIn(precisions),
                                 ::testing::Values(1), // Subgraph;
                                 ::testing::Values(1), // Tokenized MatMul+Bias
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         MatMul::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_ExplicitTransposeMatMul, ExplicitTransposeMatMul,
                         ::testing::Combine(
                                 ::testing::Values(std::vector<ov::PartialShape>{{1, 2, 69, 43}, {2, 49, 2, 43}}),
                                 ::testing::ValuesIn(precisions),
                                 ::testing::Values(1), // Subgraph;
                                 ::testing::Values(1), // Tokenized MatMul+Bias
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         ExplicitTransposeMatMul::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulBias, ExplicitTransposeMatMulBias,
                         ::testing::Combine(
                                 ::testing::Values(std::vector<ov::PartialShape>{{1, 2, 69, 43}, {2, 49, 2, 43}, {1, 1, 69, 49}}),
                                 ::testing::ValuesIn(precisions),
                                 ::testing::Values(1), // Subgraph;
                                 ::testing::Values(1), // Tokenized MatMul+Bias
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         MatMul::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMulMatMulBias, ExplicitTransposeMulMatMulBias,
                         ::testing::Combine(
                                 ::testing::Values(std::vector<ov::PartialShape>{{1, 2, 69, 43}, {2, 49, 2, 43}, {1, 2, 1, 1}, {1, 1, 69, 49}}),
                                 ::testing::ValuesIn(precisions),
                                 ::testing::ValuesIn(precisions(false)),
                                 ::testing::Values(1), // Subgraph;
                                 ::testing::Values(1), // Tokenized MatMul+Bias
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
@ -4,6 +4,7 @@

#include "snippets/transpose_matmul.hpp"
#include "common_test_utils/test_constants.hpp"
#include "ie_system_conf.h"

namespace ov {
namespace test {
@ -11,7 +12,23 @@ namespace snippets {


namespace {
std::vector<element::Type> precisions{element::f32};
static inline std::vector<std::vector<element::Type>> precisions(bool only_fp32 = true) {
    std::vector<std::vector<element::Type>> prc = {
            {element::f32, element::f32},
    };
    if (!only_fp32) {
        // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms
        if (InferenceEngine::with_cpu_x86_avx512_core_vnni() || InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) {
            prc.emplace_back(std::vector<element::Type>{element::i8, element::i8});
            prc.emplace_back(std::vector<element::Type>{element::u8, element::i8});
        }
        // In Snippets MatMul BF16 is supported only on bf16/AMX platforms
        if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) {
            prc.emplace_back(std::vector<element::Type>{element::bf16, element::bf16});
        }
    }
    return prc;
}
namespace transpose_zero_input {
std::vector<std::vector<ov::PartialShape>> transpose_input_shapes{
        {{1, 49, 2, 23}, {2, 2, 23, 39}}
@ -20,11 +37,23 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul,
                         ::testing::Combine(
                                 ::testing::ValuesIn(transpose_input_shapes),
                                 ::testing::Values(0), // Transpose on 0th Matmul input
                                 ::testing::ValuesIn(precisions),
                                 ::testing::Values(1), // MatMul;
                                 ::testing::ValuesIn(precisions(false)),
                                 ::testing::Values(1), // MatMul
                                 ::testing::Values(1), // Tokenized MatMul + FusedTranspose
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         TransposeMatMul::getTestCaseName);

// TODO: at the moment FuseTransposeToBrgemm supports fusing only if Transpose is before Parameter when Transpose is on an input
// When we support the branch Parameter->FQ->Transpose->MatMul[0th input], uncomment this test case please
// INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulFQ, TransposeMatMulFQ,
//                          ::testing::Combine(
//                                  ::testing::ValuesIn(transpose_input_shapes),
//                                  ::testing::Values(0), // Transpose on 0th Matmul input
//                                  ::testing::Values(ov::element::i8),
//                                  ::testing::Values(1), // MatMul
//                                  ::testing::Values(1), // Tokenized MatMul + FusedTranspose
//                                  ::testing::Values(CommonTestUtils::DEVICE_CPU)),
//                          TransposeMatMulFQ::getTestCaseName);
} // namespace transpose_zero_input

namespace transpose_first_input {
@ -35,11 +64,21 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul,
                         ::testing::Combine(
                                 ::testing::ValuesIn(transpose_input_shapes),
                                 ::testing::Values(1), // Transpose on 1st Matmul input
                                 ::testing::ValuesIn(precisions),
                                 ::testing::Values(1), // MatMu;
                                 ::testing::ValuesIn(precisions(false)),
                                 ::testing::Values(1), // MatMul
                                 ::testing::Values(1), // Tokenized MatMul + FusedTranspose
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         TransposeMatMul::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulFQ, TransposeMatMulFQ,
                         ::testing::Combine(
                                 ::testing::ValuesIn(transpose_input_shapes),
                                 ::testing::Values(1), // Transpose on 1st Matmul input
                                 ::testing::ValuesIn(precisions()),
                                 ::testing::Values(1), // MatMul
                                 ::testing::Values(1), // Tokenized MatMul + FusedTranspose
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         TransposeMatMulFQ::getTestCaseName);
} // namespace transpose_first_input

namespace transpose_output {
@ -50,13 +89,64 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul,
                         ::testing::Combine(
                                 ::testing::ValuesIn(transpose_input_shapes),
                                 ::testing::Values(2), // Transpose on Matmul output
                                 ::testing::ValuesIn(precisions),
                                 ::testing::Values(1), // MatMu;
                                 ::testing::ValuesIn(precisions()),
                                 ::testing::Values(1), // MatMul
                                 ::testing::Values(1), // Tokenized MatMul + FusedTranspose
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         TransposeMatMul::getTestCaseName);

// TODO: at the moment we don't support the branch MatMul[output]->Transpose->FQ.
// When we add support, uncomment this test case please
// INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulFQ, TransposeMatMulFQ,
//                          ::testing::Combine(
//                                  ::testing::ValuesIn(transpose_input_shapes),
//                                  ::testing::Values(2), // Transpose on Matmul output
//                                  ::testing::Values(ov::element::i8),
//                                  ::testing::Values(1), // MatMul
//                                  ::testing::Values(1), // Tokenized MatMul + FusedTranspose
//                                  ::testing::Values(CommonTestUtils::DEVICE_CPU)),
//                          TransposeMatMulFQ::getTestCaseName);
} // namespace transpose_output

namespace explicit_transpose {
static inline std::vector<std::vector<element::Type>> precisions(bool only_fp32 = true) {
    std::vector<std::vector<element::Type>> prc = {
            {element::f32, element::f32},
    };
    if (!only_fp32) {
        // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms
        if (InferenceEngine::with_cpu_x86_avx512_core_vnni() || InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) {
            prc.emplace_back(std::vector<element::Type>{element::i8, element::i8});
            prc.emplace_back(std::vector<element::Type>{element::u8, element::i8});
        }
        // In Snippets MatMul BF16 is supported only on bf16/AMX platforms
        if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) {
            prc.emplace_back(std::vector<element::Type>{element::bf16, element::bf16});
        }
    }
    return prc;
}
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_ExplicitTransposeMatMul, ExplicitTransposeMatMul,
                         ::testing::Combine(
                                 ::testing::Values(std::vector<ov::PartialShape>{{1, 2, 69, 43}, {2, 49, 2, 43}}),
                                 ::testing::Values(1), // Transpose on second input
                                 ::testing::ValuesIn(precisions()),
                                 ::testing::Values(1), // Subgraph;
                                 ::testing::Values(1), // Tokenized MatMul+Bias
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         ExplicitTransposeMatMul::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulBias, ExplicitTransposeMatMulBias,
                         ::testing::Combine(
                                 ::testing::Values(std::vector<ov::PartialShape>{{1, 2, 69, 43}, {2, 49, 2, 43}, {1, 1, 69, 49}}),
                                 ::testing::Values(1), // Transpose on second input
                                 ::testing::ValuesIn(precisions()),
                                 ::testing::Values(1), // Subgraph;
                                 ::testing::Values(1), // Tokenized MatMul+Bias
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         ExplicitTransposeMatMulBias::getTestCaseName);
} // namespace explicit_transpose

} // namespace
} // namespace snippets
} // namespace test
@ -12,21 +12,12 @@ namespace snippets {

typedef std::tuple<
        std::vector<ov::PartialShape>, // Input Shapes
        ov::element::Type,             // Element type
        std::vector<ov::element::Type>,// Input Element types
        size_t,                        // Expected num nodes
        size_t,                        // Expected num subgraphs
        std::string                    // Target Device
> MatMulParams;

typedef std::tuple<
        std::vector<ov::PartialShape>, // Input Shapes
        size_t,                        // Transpose position
        ov::element::Type,             // Element type
        size_t,                        // Expected num nodes
        size_t,                        // Expected num subgraphs
        std::string                    // Target Device
> TransposeMatMulParams;

class MatMul : public testing::WithParamInterface<ov::test::snippets::MatMulParams>,
               virtual public ov::test::SnippetsTestsCommon {
public:
@ -36,26 +27,16 @@ protected:
    void SetUp() override;
};

class MatMulFQ : public MatMul {
protected:
    void SetUp() override;
};

class MatMulBias : public MatMul {
protected:
    void SetUp() override;
};

class ExplicitTransposeMatMul : public MatMul {
protected:
    void SetUp() override;
};

class ExplicitTransposeMatMulBias : public MatMul {
protected:
    void SetUp() override;
};

class ExplicitTransposeMulMatMulBias : public MatMul {
protected:
    void SetUp() override;
};

} // namespace snippets
} // namespace test
} // namespace ov
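For illustration, a parameter tuple for the updated MatMulParams interface might look like this (values hypothetical):

    ov::test::snippets::MatMulParams params{
            std::vector<ov::PartialShape>{{2, 1, 32, 16}, {1, 2, 16, 32}},      // input shapes
            std::vector<ov::element::Type>{ov::element::f32, ov::element::f32}, // input element types
            1,                                                                  // expected num nodes
            1,                                                                  // expected num subgraphs
            CommonTestUtils::DEVICE_CPU};                                       // target device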
@ -13,7 +13,7 @@ namespace snippets {
typedef std::tuple<
        std::vector<ov::PartialShape>, // Input Shapes
        size_t,                        // Transpose position
        ov::element::Type,             // Element type
        std::vector<ov::element::Type>,// Input Element types
        size_t,                        // Expected num nodes
        size_t,                        // Expected num subgraphs
        std::string                    // Target Device
@ -28,6 +28,21 @@ protected:
    void SetUp() override;
};

class TransposeMatMulFQ : public TransposeMatMul {
protected:
    void SetUp() override;
};

class ExplicitTransposeMatMul : public TransposeMatMul {
protected:
    void SetUp() override;
};

class ExplicitTransposeMatMulBias : public TransposeMatMul {
protected:
    void SetUp() override;
};

} // namespace snippets
} // namespace test
} // namespace ov
@ -14,14 +14,15 @@ namespace snippets {

std::string MatMul::getTestCaseName(testing::TestParamInfo<ov::test::snippets::MatMulParams> obj) {
    std::vector<ov::PartialShape> input_shapes;
    ov::element::Type elem_type;
    std::vector<ov::element::Type> elem_types;
    std::string targetDevice;
    size_t num_nodes, num_subgraphs;
    std::tie(input_shapes, elem_type, num_nodes, num_subgraphs, targetDevice) = obj.param;
    std::tie(input_shapes, elem_types, num_nodes, num_subgraphs, targetDevice) = obj.param;
    std::ostringstream result;
    for (size_t i = 0; i < input_shapes.size(); i++)
        result << "IS[" << i << "]=" << CommonTestUtils::partialShape2str({input_shapes[i]}) << "_";
    result << "T=" << elem_type << "_";
    for (size_t i = 0; i < elem_types.size(); i++)
        result << "T[" << i << "]=" << elem_types[i] << "_";
    result << "#N=" << num_nodes << "_";
    result << "#S=" << num_subgraphs << "_";
    result << "targetDevice=" << targetDevice;
@ -30,11 +31,25 @@ std::string MatMul::getTestCaseName(testing::TestParamInfo<ov::test::snippets::M

void MatMul::SetUp() {
    std::vector<ov::PartialShape> input_shapes;
    ov::element::Type elem_type;
    std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
    std::vector<ov::element::Type> elem_types;
    std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
    init_input_shapes(static_partial_shapes_to_test_representation(input_shapes));

    auto f = ov::test::snippets::MatMulFunction(input_shapes);
    auto f = ov::test::snippets::MatMulFunction(input_shapes, elem_types);
    function = f.getOriginal();
    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
    }
}

void MatMulFQ::SetUp() {
    std::vector<ov::PartialShape> input_shapes;
    std::vector<ov::element::Type> elem_types;
    std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
    init_input_shapes(static_partial_shapes_to_test_representation(input_shapes));

    auto f = ov::test::snippets::FQMatMulFunction(input_shapes);
    function = f.getOriginal();
    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
@ -44,53 +59,11 @@ void MatMul::SetUp() {

void MatMulBias::SetUp() {
    std::vector<ov::PartialShape> input_shapes;
    ov::element::Type elem_type;
    std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
    std::vector<ov::element::Type> elem_types;
    std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
    init_input_shapes(static_partial_shapes_to_test_representation(input_shapes));

    auto f = ov::test::snippets::MatMulBiasFunction(input_shapes);
    function = f.getOriginal();
    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
    }
}

void ExplicitTransposeMatMul::SetUp() {
    std::vector<ov::PartialShape> input_shapes;
    ov::element::Type elem_type;
    std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
    init_input_shapes(static_partial_shapes_to_test_representation(input_shapes));

    auto f = ov::test::snippets::TransposeMatMulFunction(input_shapes);
    function = f.getOriginal();
    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
    }
}

void ExplicitTransposeMatMulBias::SetUp() {
    std::vector<ov::PartialShape> input_shapes;
    ov::element::Type elem_type;
    std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
    init_input_shapes(static_partial_shapes_to_test_representation(input_shapes));

    auto f = ov::test::snippets::TransposeMatMulBiasFunction(input_shapes);
    function = f.getOriginal();
    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
    }
}

void ExplicitTransposeMulMatMulBias::SetUp() {
    std::vector<ov::PartialShape> input_shapes;
    ov::element::Type elem_type;
    std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
    init_input_shapes(static_partial_shapes_to_test_representation(input_shapes));

    auto f = ov::test::snippets::TransposeMulMatMulBiasFunction(input_shapes);
    auto f = ov::test::snippets::MatMulBiasFunction(input_shapes, elem_types);
    function = f.getOriginal();
    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
@ -99,26 +72,19 @@ void ExplicitTransposeMulMatMulBias::SetUp() {
}

TEST_P(MatMul, CompareWithRefImpl) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    run();
    validateNumSubgraphs();
}

TEST_P(MatMulFQ, CompareWithRefImpl) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    run();
    validateNumSubgraphs();
}

TEST_P(MatMulBias, CompareWithRefImpl) {
    run();
    validateNumSubgraphs();
}

TEST_P(ExplicitTransposeMatMul, CompareWithRefImpl) {
    run();
    validateNumSubgraphs();
}

TEST_P(ExplicitTransposeMatMulBias, CompareWithRefImpl) {
    run();
    validateNumSubgraphs();
}

TEST_P(ExplicitTransposeMulMatMulBias, CompareWithRefImpl) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    run();
    validateNumSubgraphs();
}
@ -15,17 +15,17 @@ namespace snippets {
|
||||
std::string TransposeMatMul::getTestCaseName(testing::TestParamInfo<ov::test::snippets::TransposeMatMulParams> obj) {
|
||||
std::vector<ov::PartialShape> input_shapes;
|
||||
size_t transpose_position;
|
||||
ov::element::Type elem_type;
|
||||
std::vector<ov::element::Type> elem_types;
|
||||
std::string targetDevice;
|
||||
size_t num_nodes, num_subgraphs;
|
||||
std::tie(input_shapes, transpose_position, elem_type, num_nodes, num_subgraphs, targetDevice) = obj.param;
|
||||
if (input_shapes.size() != 2)
|
||||
IE_THROW() << "Invalid input shapes vector size";
|
||||
std::tie(input_shapes, transpose_position, elem_types, num_nodes, num_subgraphs, targetDevice) = obj.param;
|
||||
std::ostringstream result;
|
||||
result << "IS[0]=" << CommonTestUtils::partialShape2str({input_shapes[0]}) << "_";
|
||||
result << "IS[1]=" << CommonTestUtils::partialShape2str({input_shapes[1]}) << "_";
|
||||
for (size_t i = 0; i < input_shapes.size(); ++i) {
|
||||
result << "IS[" << i << "]=" << CommonTestUtils::partialShape2str({input_shapes[i]}) << "_";
|
||||
}
|
||||
result << "Pos=" << transpose_position << "_";
|
||||
result << "T=" << elem_type << "_";
|
||||
for (size_t i = 0; i < elem_types.size(); i++)
|
||||
result << "T[" << i <<"]=" << elem_types[i] << "_";
|
||||
result << "#N=" << num_nodes << "_";
|
||||
result << "#S=" << num_subgraphs << "_";
|
||||
result << "targetDevice=" << targetDevice;
|
||||
@ -35,11 +35,56 @@ std::string TransposeMatMul::getTestCaseName(testing::TestParamInfo<ov::test::sn
void TransposeMatMul::SetUp() {
    std::vector<ov::PartialShape> input_shapes;
    size_t transpose_position;
    ov::element::Type elem_type;
    std::tie(input_shapes, transpose_position, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
    std::vector<ov::element::Type> elem_types;
    std::tie(input_shapes, transpose_position, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
    init_input_shapes(static_partial_shapes_to_test_representation(input_shapes));

    auto f = ov::test::snippets::Transpose0213MatMulFunction(input_shapes, transpose_position);
    auto f = ov::test::snippets::Transpose0213MatMulFunction(input_shapes, elem_types, transpose_position);
    function = f.getOriginal();
    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
    }
}

void TransposeMatMulFQ::SetUp() {
    std::vector<ov::PartialShape> input_shapes;
    size_t transpose_position;
    std::vector<ov::element::Type> elem_types;
    std::tie(input_shapes, transpose_position, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
    init_input_shapes(static_partial_shapes_to_test_representation(input_shapes));

    auto f = ov::test::snippets::FQMatMulFunction(input_shapes, transpose_position);
    function = f.getOriginal();
    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
    }
}

void ExplicitTransposeMatMul::SetUp() {
    std::vector<ov::PartialShape> input_shapes;
    size_t transpose_position;
    std::vector<ov::element::Type> elem_types;
    std::tie(input_shapes, transpose_position, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
    init_input_shapes(static_partial_shapes_to_test_representation(input_shapes));

    auto f = ov::test::snippets::TransposeMatMulFunction(input_shapes);
    function = f.getOriginal();
    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
    }
}

void ExplicitTransposeMatMulBias::SetUp() {
    std::vector<ov::PartialShape> input_shapes;
    size_t transpose_position;
    std::vector<ov::element::Type> elem_types;
    std::tie(input_shapes, transpose_position, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
    init_input_shapes(static_partial_shapes_to_test_representation(input_shapes));

    auto f = ov::test::snippets::TransposeMatMulBiasFunction(input_shapes);
    function = f.getOriginal();
    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
@ -48,6 +93,25 @@ void TransposeMatMul::SetUp() {
}

TEST_P(TransposeMatMul, CompareWithRefImpl) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    run();
    validateNumSubgraphs();
}

TEST_P(TransposeMatMulFQ, CompareWithRefImpl) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    run();
    validateNumSubgraphs();
}

TEST_P(ExplicitTransposeMatMul, CompareWithRefImpl) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    run();
    validateNumSubgraphs();
}

TEST_P(ExplicitTransposeMatMulBias, CompareWithRefImpl) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    run();
    validateNumSubgraphs();
}

@ -56,7 +56,7 @@ private:
class Transpose0213MatMulLoweredFunction : public Transpose0213MatMulFunction {
public:
    explicit Transpose0213MatMulLoweredFunction(const std::vector<PartialShape>& inputShapes, size_t position = 0) :
        Transpose0213MatMulFunction(inputShapes, position) {
        Transpose0213MatMulFunction(inputShapes, std::vector<ov::element::Type>{ov::element::f32, ov::element::f32}, position) {
    }
protected:
    std::shared_ptr<ov::Model> initLowered() const override;

@ -6,6 +6,7 @@

#include "ngraph/ngraph.hpp"
#include "./snippets_helpers.hpp"
#include "snippets/utils.hpp"

/* This file contains definitions of relatively simple functions (models) that will be used
 * to test snippets-specific behavior. All the functions are expected to be direct descendants of
@ -20,48 +21,77 @@ namespace snippets {
//   in1    in2
//     Matmul
//     Result
// todo: remove once "no subgraph after input" limitation is relaxed
class MatMulFunction : public SnippetsFunctionBase {
public:
    explicit MatMulFunction(const std::vector<PartialShape>& inputShapes)
            : SnippetsFunctionBase(inputShapes) {
    explicit MatMulFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
            : SnippetsFunctionBase(inputShapes), precisions(precisions) {
        NGRAPH_CHECK(input_shapes.size() == 2, "Got invalid number of input shapes");
        verify_precisions(precisions);
    }
    static void verify_precisions(const std::vector<ov::element::Type>& precisions) {
        NGRAPH_CHECK(precisions.size() == 2, "Got invalid number of input element types");
        const bool is_f32 = ngraph::snippets::utils::everyone_is(element::f32, precisions[0], precisions[1]);
        const bool is_int8 = ngraph::snippets::utils::one_of(precisions[0], element::i8, element::u8) && precisions[1] == element::i8;
        const bool is_bf16 = ngraph::snippets::utils::everyone_is(element::bf16, precisions[0], precisions[1]);
        NGRAPH_CHECK(is_f32 || is_bf16 || is_int8, "Invalid precisions");
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;
    std::shared_ptr<ov::Model> initReference() const override;

    std::vector<ov::element::Type> precisions;
};
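In short, verify_precisions admits exactly four input pairs: f32/f32, bf16/bf16, i8/i8, and u8/i8 (in the int8 case the second input must be i8). A minimal usage sketch with hypothetical shapes:

// Illustrative only: u8 x i8 passes verify_precisions, while e.g. {u8, u8} would throw.
std::vector<ov::PartialShape> shapes{{2, 1, 3, 5}, {1, 3, 5, 3}};
MatMulFunction f(shapes, {ov::element::u8, ov::element::i8});
auto model = f.getOriginal();  // the int8 path builds the MatMul via TypeRelaxed, see initOriginal below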
class FQMatMulFunction : public SnippetsFunctionBase {
public:
    explicit FQMatMulFunction(const std::vector<PartialShape>& inputShapes, int pos = -1) : SnippetsFunctionBase({inputShapes[0]}), pos(pos) {
        NGRAPH_CHECK(inputShapes.size() == 2, "Got invalid number of input shapes");
        NGRAPH_CHECK(pos >= -1 && pos <= 2, "Got invalid transpose position");
        const_shape = inputShapes[1];
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;

    ov::PartialShape const_shape;
    int pos = -1;
};

// Same as MatMulFunction, but with a bias added to the MatMul output
class MatMulBiasFunction : public SnippetsFunctionBase {
public:
    explicit MatMulBiasFunction(const std::vector<PartialShape>& inputShapes)
            : SnippetsFunctionBase(inputShapes) {
    explicit MatMulBiasFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
            : SnippetsFunctionBase(inputShapes), precisions(precisions) {
        NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
        MatMulFunction::verify_precisions(precisions);
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;

    std::vector<ov::element::Type> precisions;
};

/// Minimal graph to test MatMul+Transpose combinations. Transpose location is specified via the position argument:
/// 0 - before the first MatMul input; 1 - before the second MatMul input; 2 - after the MatMul output.
/// Tokenized simply by starting subgraph,
//   in1       in2
// Transpose  /
//     Matmul
//     Result
class Transpose0213MatMulFunction : public SnippetsFunctionBase {
public:
    explicit Transpose0213MatMulFunction(const std::vector<PartialShape>& inputShapes, size_t position = 0)
            : SnippetsFunctionBase(inputShapes), transpose_position(position) {
    explicit Transpose0213MatMulFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions,
                                         size_t position = 0)
            : SnippetsFunctionBase(inputShapes), transpose_position(position), precisions(precisions) {
        NGRAPH_CHECK(input_shapes.size() == 2, "Got invalid number of input shapes");
        NGRAPH_CHECK(input_shapes[0].rank().get_length() == 4 && input_shapes[1].rank().get_length() == 4,
                     "Only rank 4 input shapes are supported by this test");
        NGRAPH_CHECK(transpose_position >= 0 && transpose_position <= 2, "Got invalid transpose position");
        MatMulFunction::verify_precisions(precisions);
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;
    size_t transpose_position;
    std::vector<ov::element::Type> precisions;
};

class TransposeMatMulFunction : public SnippetsFunctionBase {

@ -107,8 +107,8 @@ std::shared_ptr<ov::Model> EltwiseThreeInputsLoweredFunction::initLowered() cons
}

std::shared_ptr<ov::Model> Transpose0213MatMulLoweredFunction::initLowered() const {
    ParameterVector data{std::make_shared<op::v0::Parameter>(precision, input_shapes[0]),
                         std::make_shared<op::v0::Parameter>(precision, input_shapes[1])};
    ParameterVector data{std::make_shared<op::v0::Parameter>(precisions[0], input_shapes[0]),
                         std::make_shared<op::v0::Parameter>(precisions[1], input_shapes[1])};
    std::vector<size_t> layout{0, 2, 1, 3};
    // Note: validity of transpose_position values is checked in the Transpose0213MatMulFunction constructor
    if (transpose_position <= 1) {
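For intuition about the layout vector: order {0, 2, 1, 3} swaps the two middle axes of a rank-4 tensor. A small worked example with a made-up shape:

// permuted[i] = shape[layout[i]]; the shape below is purely illustrative
ov::Shape shape{1, 16, 8, 64};
std::vector<size_t> layout{0, 2, 1, 3};
ov::Shape permuted{shape[layout[0]], shape[layout[1]], shape[layout[2]], shape[layout[3]]};
// permuted == ov::Shape{1, 8, 16, 64}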
@ -194,6 +194,7 @@ std::shared_ptr<ov::Model> SoftmaxLoweredFunction::initLowered() const {
    const auto horizon_sum = std::make_shared<ngraph::snippets::op::HorizonSum>(sum);
    horizon_sum->add_control_dependency(loop_sum_end);

    const auto size_exp = std::make_shared<ngraph::opset1::Constant>(ov::element::i32, ov::Shape{2});
    const auto buffer_exp = std::make_shared<ngraph::snippets::op::Buffer>(loop_sum_end->output(0));

    loop_sum_begin->add_control_dependency(vector_buffer_sum);
@ -303,6 +304,7 @@ std::shared_ptr<ov::Model> AddSoftmaxLoweredFunction::initLowered() const {

    /* =========================================== */

    const auto size_add = std::make_shared<ngraph::opset1::Constant>(ov::element::i32, ov::Shape{2});
    const auto buffer_add = std::make_shared<ngraph::snippets::op::Buffer>(loop_max_end->output(0));

    /* === Sub + Exp + ReduceSum decomposition === */
@ -331,6 +333,7 @@ std::shared_ptr<ov::Model> AddSoftmaxLoweredFunction::initLowered() const {
    const auto horizon_sum = std::make_shared<ngraph::snippets::op::HorizonSum>(sum);
    horizon_sum->add_control_dependency(loop_sum_end);

    const auto size_exp = std::make_shared<ngraph::opset1::Constant>(ov::element::i32, ov::Shape{2});
    const auto buffer_exp = std::make_shared<ngraph::snippets::op::Buffer>(loop_sum_end->output(0));

    loop_sum_begin->add_control_dependency(vector_buffer_sum);

@ -5,50 +5,133 @@
#include "subgraph_matmul.hpp"
#include "common_test_utils/data_utils.hpp"
#include <snippets/op/subgraph.hpp>
#include "ngraph_functions/builders.hpp"
#include "ov_ops/type_relaxed.hpp"

namespace ov {
namespace test {
namespace snippets {
std::shared_ptr<ov::Model> MatMulFunction::initOriginal() const {
    auto data0 = std::make_shared<op::v0::Parameter>(precision, input_shapes[0]);
    auto data1 = std::make_shared<op::v0::Parameter>(precision, input_shapes[1]);
    auto matmul = std::make_shared<op::v0::MatMul>(data0, data1);
    auto data0 = std::make_shared<op::v0::Parameter>(precisions[0], input_shapes[0]);
    auto data1 = std::make_shared<op::v0::Parameter>(precisions[1], input_shapes[1]);
    std::shared_ptr<Node> matmul;
    if (precisions[1] == ov::element::i8) {
        matmul = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
                std::vector<element::Type>{ element::f32, element::f32 },
                std::vector<element::Type>{ element::f32 },
                ov::op::TemporaryReplaceOutputType(data0, element::f32).get(),
                ov::op::TemporaryReplaceOutputType(data1, element::f32).get());
    } else {
        matmul = std::make_shared<op::v0::MatMul>(data0, data1);
    }
    return std::make_shared<ov::Model>(NodeVector{matmul}, ParameterVector{data0, data1});
}
std::shared_ptr<ov::Model> MatMulFunction::initReference() const {
    auto data0 = std::make_shared<op::v0::Parameter>(precisions[0], input_shapes[0]);
    auto data1 = std::make_shared<op::v0::Parameter>(precisions[1], input_shapes[1]);
    auto indata0 = std::make_shared<op::v0::Parameter>(precisions[0], data0->get_output_partial_shape(0));
    auto indata1 = std::make_shared<op::v0::Parameter>(precisions[1], data1->get_output_partial_shape(0));
    std::shared_ptr<Node> matmul;
    if (precisions[1] == ov::element::i8) {
        matmul = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
                std::vector<element::Type>{ element::f32, element::f32 },
                std::vector<element::Type>{ element::f32 },
                ov::op::TemporaryReplaceOutputType(indata0, element::f32).get(),
                ov::op::TemporaryReplaceOutputType(indata1, element::f32).get());
    } else {
        matmul = std::make_shared<op::v0::MatMul>(indata0, indata1);
    }
    const auto subgraph = std::make_shared<ngraph::snippets::op::Subgraph>(NodeVector{data0, data1},
                                                                           std::make_shared<ov::Model>(NodeVector{matmul},
                                                                                                       ParameterVector{indata0, indata1}));
    return std::make_shared<ov::Model>(NodeVector{subgraph}, ParameterVector{data0, data1});
}
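The reference model wraps the (possibly type-relaxed) MatMul into a single snippets Subgraph, which is the invariant validateNumSubgraphs checks against. A hedged sketch of that check, with made-up shapes and assuming SnippetsFunctionBase exposes getReference() alongside getOriginal():

std::vector<ov::PartialShape> shapes{{2, 1, 3, 5}, {1, 3, 5, 3}};  // hypothetical
auto ref = MatMulFunction(shapes, {ov::element::f32, ov::element::f32}).getReference();
size_t num_subgraphs = 0;
for (const auto& node : ref->get_ops())
    num_subgraphs += ov::is_type<ngraph::snippets::op::Subgraph>(node) ? 1 : 0;
// expected: num_subgraphs == 1, i.e. the whole MatMul body lives inside one Subgraph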
std::shared_ptr<ov::Model> FQMatMulFunction::initOriginal() const {
    auto const_order = std::make_shared<op::v0::Constant>(ov::element::i32, Shape{4}, std::vector<int>{0, 2, 1, 3});
    auto data0 = std::make_shared<op::v0::Parameter>(precision, input_shapes[0]);
    auto data1 = std::make_shared<op::v0::Parameter>(precision, input_shapes[1]);
    auto indata0 = std::make_shared<op::v0::Parameter>(precision, data0->get_output_partial_shape(0));
    auto indata1 = std::make_shared<op::v0::Parameter>(precision, data1->get_output_partial_shape(0));
    auto matmul = std::make_shared<ngraph::snippets::op::Subgraph>(NodeVector{data0, data1},
                                                                   std::make_shared<ov::Model>(NodeVector{std::make_shared<op::v0::MatMul>(indata0, indata1)},
                                                                                               ParameterVector{indata0, indata1}));
    return std::make_shared<ov::Model>(NodeVector{matmul}, ParameterVector{data0, data1});
    auto ih = std::make_shared<op::v0::Constant>(ov::element::f32, ov::Shape{1}, std::vector<float>{34.7436294});
    auto il = std::make_shared<op::v0::Constant>(ov::element::f32, ov::Shape{1}, std::vector<float>{-35.0172004});
    auto oh = std::make_shared<op::v0::Constant>(ov::element::f32, ov::Shape{1}, std::vector<float>{34.7436294});
    auto ol = std::make_shared<op::v0::Constant>(ov::element::f32, ov::Shape{1}, std::vector<float>{-35.0172004});
    auto fq = std::make_shared<op::v0::FakeQuantize>(data0, il, ih, ol, oh, 256);
    std::shared_ptr<ov::Node> in0 = fq;
    if (pos == 0) {
        in0 = std::make_shared<op::v1::Transpose>(in0, const_order);
    }
    auto constant = ngraph::builder::makeConstant(ov::element::i8, const_shape.get_shape(), std::vector<int8_t>{}, true);
    auto convert = std::make_shared<op::v0::Convert>(constant, ov::element::f32);
    auto deq_mul = std::make_shared<op::v0::Constant>(ov::element::f32, ov::Shape{1}, std::vector<float>{0.00499185826});
    auto mul = std::make_shared<op::v1::Multiply>(convert, deq_mul);
    std::shared_ptr<ov::Node> in1 = mul;
    if (pos == 1) {
        in1 = std::make_shared<op::v1::Transpose>(in1, const_order);
    }
    auto matmul = std::make_shared<op::v0::MatMul>(in0, in1);
    std::shared_ptr<ov::Node> out = matmul;
    if (pos == 2) {
        out = std::make_shared<op::v1::Transpose>(out, const_order);
    }
    return std::make_shared<ov::Model>(NodeVector{out}, ParameterVector{data0});
}
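The constants above encode a standard quantize/dequantize pair; the arithmetic below is a worked illustration of what they imply, not code from the commit:

// FakeQuantize over [-35.0172004, 34.7436294] with 256 levels has a step of
//   (34.7436294 - (-35.0172004)) / 255 ≈ 0.2736 per level on the activation branch.
// The weight branch dequantizes i8 by scale 0.00499185826, so the extreme codes map to
//   127 * 0.00499185826 ≈ 0.634 and -128 * 0.00499185826 ≈ -0.639.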
std::shared_ptr<ov::Model> MatMulBiasFunction::initOriginal() const {
    auto data0 = std::make_shared<op::v0::Parameter>(precision, input_shapes[0]);
    auto data1 = std::make_shared<op::v0::Parameter>(precision, input_shapes[1]);
    auto matmul = std::make_shared<op::v0::MatMul>(data0, data1);
    auto data2 = std::make_shared<op::v0::Parameter>(precision, input_shapes[2]);
    std::shared_ptr<Node> matmul;
    if (precisions[1] == ov::element::i8) {
        matmul = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
                std::vector<element::Type>{ element::f32, element::f32 },
                std::vector<element::Type>{ element::f32 },
                ov::op::TemporaryReplaceOutputType(data0, element::f32).get(),
                ov::op::TemporaryReplaceOutputType(data1, element::f32).get());
    } else {
        matmul = std::make_shared<op::v0::MatMul>(data0, data1);
    }
    auto bias = std::make_shared<op::v1::Add>(matmul, data2);
    return std::make_shared<ov::Model>(NodeVector{bias}, ParameterVector{data0, data1, data2});
}
std::shared_ptr<ov::Model> Transpose0213MatMulFunction::initOriginal() const {
    auto data0 = std::make_shared<op::v0::Parameter>(precision, input_shapes[0]);
    auto data1 = std::make_shared<op::v0::Parameter>(precision, input_shapes[1]);
    auto data0 = std::make_shared<op::v0::Parameter>(precisions[0], input_shapes[0]);
    auto data1 = std::make_shared<op::v0::Parameter>(precisions[1], input_shapes[1]);
    auto const_order = std::make_shared<op::v0::Constant>(ov::element::i32, Shape{4}, std::vector<int>{0, 2, 1, 3});
    std::shared_ptr<Node> result;
    switch (transpose_position) {
        case 0: {
            auto transpose = std::make_shared<op::v1::Transpose>(data0, const_order);
            result = std::make_shared<op::v0::MatMul>(transpose, data1);
            if (precisions[1] == ov::element::i8) {
                result = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
                        std::vector<element::Type>{ element::f32, element::f32 },
                        std::vector<element::Type>{ element::f32 },
                        ov::op::TemporaryReplaceOutputType(transpose, element::f32).get(),
                        ov::op::TemporaryReplaceOutputType(data1, element::f32).get());
            } else {
                result = std::make_shared<op::v0::MatMul>(transpose, data1);
            }
            break;
        } case 1: {
            auto transpose = std::make_shared<op::v1::Transpose>(data1, const_order);
            result = std::make_shared<op::v0::MatMul>(data0, transpose);
            if (precisions[1] == ov::element::i8) {
                result = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
                        std::vector<element::Type>{ element::f32, element::f32 },
                        std::vector<element::Type>{ element::f32 },
                        ov::op::TemporaryReplaceOutputType(data0, element::f32).get(),
                        ov::op::TemporaryReplaceOutputType(transpose, element::f32).get());
            } else {
                result = std::make_shared<op::v0::MatMul>(data0, transpose);
            }
            break;
        } case 2: {
            auto matmul = std::make_shared<op::v0::MatMul>(data0, data1);
            std::shared_ptr<ov::Node> matmul;
            if (precisions[1] == ov::element::i8) {
                matmul = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
                        std::vector<element::Type>{ element::f32, element::f32 },
                        std::vector<element::Type>{ element::f32 },
                        ov::op::TemporaryReplaceOutputType(data0, element::f32).get(),
                        ov::op::TemporaryReplaceOutputType(data1, element::f32).get());
            } else {
                matmul = std::make_shared<op::v0::MatMul>(data0, data1);
            }
            result = std::make_shared<op::v1::Transpose>(matmul, const_order);
            break;
        }