[CPU] Support MHA optimization (#12643)
* [CPU] Support MHA optimization * [CPU] Extend pattern supported by MHA node * [CPU] MHA: fixed int8 perf issue Co-authored-by: Gu, Jianan <jianan.gu@intel.com>
This commit is contained in:
@@ -197,6 +197,7 @@ const InferenceEngine::details::caseless_unordered_map<std::string, Type> type_t
|
||||
{ "Subgraph", Type::Subgraph},
|
||||
{ "PriorBox", Type::PriorBox},
|
||||
{ "PriorBoxClustered", Type::PriorBoxClustered},
|
||||
{ "MHA", Type::MHA},
|
||||
};
|
||||
|
||||
Type TypeFromName(const std::string& type) {
|
||||
@@ -388,6 +389,8 @@ std::string NameFromType(const Type type) {
|
||||
return "Reference";
|
||||
case Type::Subgraph:
|
||||
return "Subgraph";
|
||||
case Type::MHA:
|
||||
return "MHA";
|
||||
default:
|
||||
return "Unknown";
|
||||
}
|
||||
|
||||
@@ -108,6 +108,7 @@ enum class Type {
|
||||
Subgraph,
|
||||
PriorBox,
|
||||
PriorBoxClustered,
|
||||
MHA
|
||||
};
|
||||
|
||||
enum class Algorithm {
|
||||
|
||||
@@ -144,9 +144,6 @@ void jit_emitter::emitter_preamble(const std::vector<size_t> &in_idxs, const std
|
||||
for (size_t i = 0; i < preserved_vec_idxs.size(); ++i) {
|
||||
push_vec(h->ptr[h->rsp + i * get_vec_length()], preserved_vec_idxs[i]);
|
||||
}
|
||||
|
||||
if (!entry_map_.empty())
|
||||
load_table_addr();
|
||||
}
|
||||
|
||||
|
||||
@@ -204,6 +201,9 @@ void jit_emitter::emit_code(const std::vector<size_t> &in_idxs, const std::vecto
|
||||
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) const {
|
||||
emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs);
|
||||
|
||||
if (!entry_map_.empty())
|
||||
load_table_addr();
|
||||
|
||||
emit_impl(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs, nullptr);
|
||||
|
||||
emitter_postamble();
|
||||
@@ -214,6 +214,9 @@ void jit_emitter::emit_code(const std::vector<size_t> &in_idxs, const std::vecto
|
||||
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) {
|
||||
emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs);
|
||||
|
||||
if (!entry_map_.empty())
|
||||
load_table_addr();
|
||||
|
||||
emit_impl(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs, emit_context.get());
|
||||
|
||||
emitter_postamble();
|
||||
|
||||
@@ -516,6 +516,34 @@ void jit_load_emitter::register_table_entries() {
|
||||
push_arg_entry_of("float_max", 0x7f7fffff, true);
|
||||
}
|
||||
|
||||
void jit_load_emitter::emit_code(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
|
||||
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) const {
|
||||
emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs);
|
||||
|
||||
emit_impl(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs, nullptr);
|
||||
|
||||
emitter_postamble();
|
||||
}
|
||||
|
||||
void jit_load_emitter::emit_code(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
|
||||
const std::shared_ptr<const emitter_context> &emit_context,
|
||||
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) {
|
||||
emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs);
|
||||
|
||||
const auto* load_emitter_context = dynamic_cast<const ov::intel_cpu::load_emitter_context*>(emit_context.get());
|
||||
if (load_emitter_context == nullptr) {
|
||||
IE_THROW() << "Load emitter in " << name << " does not get load emmiter context.";
|
||||
}
|
||||
|
||||
if (!entry_map_.empty() && load_emitter_context->is_fill_)
|
||||
load_table_addr();
|
||||
|
||||
emit_impl(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs, emit_context.get());
|
||||
|
||||
emitter_postamble();
|
||||
}
|
||||
|
||||
|
||||
/// STORE ///
|
||||
jit_store_emitter::jit_store_emitter(jit_generator *host, cpu_isa_t host_isa,
|
||||
Precision exec_prc, emitter_in_out_map in_out_type)
|
||||
|
||||
@@ -47,6 +47,14 @@ class jit_load_emitter : public jit_emitter {
|
||||
public:
|
||||
jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
|
||||
InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::gpr_to_vec);
|
||||
|
||||
void emit_code(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
|
||||
const std::vector<size_t> &pool_vec_idxs = {}, const std::vector<size_t> &pool_gpr_idxs = {}) const override;
|
||||
|
||||
void emit_code(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
|
||||
const std::shared_ptr<const emitter_context> &emit_context,
|
||||
const std::vector<size_t> &pool_vec_idxs = {}, const std::vector<size_t> &pool_gpr_idxs = {}) override;
|
||||
|
||||
/**
|
||||
* load_num values with src_prc precision are loaded from ptr[Reg64(in_idxs[0]) + offset_byte] address to Vmm[out_idxs[0]] as dst_prc.
|
||||
* is_fill: when load_num can not fully fit in vector register, whether fill_value should be filled as default values.
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ngraph_transformations/op/leaky_relu.hpp"
|
||||
#include "ngraph_transformations/op/power_static.hpp"
|
||||
#include "ngraph_transformations/op/swish_cpu.hpp"
|
||||
#include "ngraph_transformations/op/mha.hpp"
|
||||
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <ngraph_ops/type_relaxed.hpp>
|
||||
@@ -40,6 +41,7 @@ std::map<std::string, ngraph::OpSet> Extension::getOpSets() {
|
||||
NGRAPH_OP(LeakyReluNode, ov::intel_cpu)
|
||||
NGRAPH_OP(PowerStaticNode, ov::intel_cpu)
|
||||
NGRAPH_OP(SwishNode, ov::intel_cpu)
|
||||
NGRAPH_OP(MHANode, ov::intel_cpu)
|
||||
#undef NGRAPH_OP
|
||||
|
||||
return opset;
|
||||
|
||||
559
src/plugins/intel_cpu/src/ngraph_transformations/mha_fusion.cpp
Normal file
559
src/plugins/intel_cpu/src/ngraph_transformations/mha_fusion.cpp
Normal file
@@ -0,0 +1,559 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "mha_fusion.hpp"
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include "op/mha.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
|
||||
// TODO: draw pattern
|
||||
ov::intel_cpu::MHAFusion::MHAFusion() {
|
||||
MATCHER_SCOPE(MHAFusion);
|
||||
|
||||
auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in2 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in6 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in7 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
|
||||
auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
|
||||
auto mul = std::make_shared<ngraph::opset3::Multiply>(transpose1, in2);
|
||||
auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, mul);
|
||||
auto add = std::make_shared<ngraph::opset4::Add>(matmul0, in3);
|
||||
auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, in6, true);
|
||||
auto softmax = std::make_shared<ngraph::opset1::Softmax>(reshape0);
|
||||
auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softmax, in7, true);
|
||||
auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
|
||||
auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(reshape1, transpose2);
|
||||
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(matmul1, in10);
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose0_in = pattern_to_output.at(in0);
|
||||
auto transpose1_in = pattern_to_output.at(in1);
|
||||
auto mul_in1 = pattern_to_output.at(in2);
|
||||
auto add_in1 = pattern_to_output.at(in3);
|
||||
auto transpose2_in = pattern_to_output.at(in8);
|
||||
|
||||
// TODO: check transpose order
|
||||
|
||||
auto const_node = std::dynamic_pointer_cast<ngraph::opset3::Constant>(mul_in1.get_node_shared_ptr());
|
||||
if (!const_node)
|
||||
return false;
|
||||
|
||||
std::vector<float> mul_scales;
|
||||
if (auto mul_node = std::dynamic_pointer_cast<ngraph::opset3::Multiply>(pattern_to_output.at(mul).get_node_shared_ptr())) {
|
||||
mul_scales = std::dynamic_pointer_cast<ngraph::opset4::Constant>(mul_node->get_input_node_shared_ptr(1))->cast_vector<float>();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto matmul0_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
|
||||
if (!matmul0_node)
|
||||
return false;
|
||||
if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
|
||||
return false;
|
||||
|
||||
auto reshape0_node = std::dynamic_pointer_cast<ngraph::opset1::Reshape>(pattern_to_output.at(reshape0).get_node_shared_ptr());
|
||||
if (!reshape0_node)
|
||||
return false;
|
||||
|
||||
if (auto reshape0_pattern = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(in6).get_node_shared_ptr())) {
|
||||
// TODO: add valid condition based on reshape pattern
|
||||
auto reshape0_pattern_values = reshape0_pattern->cast_vector<int64_t>();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto reshape1_node = std::dynamic_pointer_cast<ngraph::opset1::Reshape>(pattern_to_output.at(reshape1).get_node_shared_ptr());
|
||||
if (!reshape1_node)
|
||||
return false;
|
||||
|
||||
if (auto reshape1_pattern = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(in7).get_node_shared_ptr())) {
|
||||
// TODO: add valid condition based on reshape pattern
|
||||
auto reshape1_pattern_values = reshape1_pattern->cast_vector<int64_t>();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (reshape0_node->get_output_partial_shape(0).rank() != 2 || reshape0_node->get_input_partial_shape(0) != reshape1_node->get_output_partial_shape(0))
|
||||
return false;
|
||||
|
||||
auto softmax_node = std::dynamic_pointer_cast<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
|
||||
if (!softmax_node)
|
||||
return false;
|
||||
if (softmax_node->get_axis() != 1)
|
||||
return false;
|
||||
|
||||
auto matmul1_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
|
||||
if (!matmul1_node)
|
||||
return false;
|
||||
if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
|
||||
return false;
|
||||
|
||||
bool is_mul_first = true;
|
||||
auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
|
||||
auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, mul_scales, is_mul_first,
|
||||
transpose3_node->output(0).get_element_type());
|
||||
mha->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(mul).get_node_shared_ptr(),
|
||||
pattern_to_output.at(matmul0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(add).get_node_shared_ptr(),
|
||||
pattern_to_output.at(reshape0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(softmax).get_node_shared_ptr(),
|
||||
pattern_to_output.at(reshape1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose2).get_node_shared_ptr(),
|
||||
pattern_to_output.at(matmul1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose3).get_node_shared_ptr(),
|
||||
},
|
||||
mha);
|
||||
|
||||
if (transformation_callback(mha)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ngraph::replace_node(m.get_match_root(), mha);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ov::intel_cpu::MHAFusion2::MHAFusion2() {
|
||||
MATCHER_SCOPE(MHAFusion2);
|
||||
|
||||
auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in6 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in7 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
|
||||
auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
|
||||
auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1);
|
||||
auto add = std::make_shared<ngraph::opset4::Add>(matmul0, in3);
|
||||
auto softmax = std::make_shared<ngraph::opset1::Softmax>(add);
|
||||
auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
|
||||
auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, transpose2);
|
||||
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(matmul1, in10);
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose0_in = pattern_to_output.at(in0);
|
||||
auto transpose1_in = pattern_to_output.at(in1);
|
||||
auto add_in1 = pattern_to_output.at(in3);
|
||||
auto transpose2_in = pattern_to_output.at(in8);
|
||||
|
||||
// TODO: check transpose order
|
||||
|
||||
auto matmul0_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
|
||||
if (!matmul0_node)
|
||||
return false;
|
||||
if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
|
||||
return false;
|
||||
|
||||
auto softmax_node = std::dynamic_pointer_cast<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
|
||||
if (!softmax_node)
|
||||
return false;
|
||||
if (softmax_node->get_axis() != 3)
|
||||
return false;
|
||||
|
||||
auto matmul1_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
|
||||
if (!matmul1_node)
|
||||
return false;
|
||||
if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
|
||||
return false;
|
||||
|
||||
auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
|
||||
auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, std::vector<float>(), false,
|
||||
transpose3_node->output(0).get_element_type());
|
||||
mha->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(matmul0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(add).get_node_shared_ptr(),
|
||||
pattern_to_output.at(softmax).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose2).get_node_shared_ptr(),
|
||||
pattern_to_output.at(matmul1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose3).get_node_shared_ptr(),
|
||||
},
|
||||
mha);
|
||||
|
||||
if (transformation_callback(mha)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ngraph::replace_node(m.get_match_root(), mha);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
static std::vector<float> simplifyToScale(const std::shared_ptr<ngraph::opset1::FakeQuantize>& fq_node) {
|
||||
auto levels = fq_node->get_levels();
|
||||
auto input_low = std::dynamic_pointer_cast<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(1))->cast_vector<float>();
|
||||
auto input_high = std::dynamic_pointer_cast<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(2))->cast_vector<float>();
|
||||
auto output_low = std::dynamic_pointer_cast<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(3))->cast_vector<float>();
|
||||
auto output_high = std::dynamic_pointer_cast<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(4))->cast_vector<float>();
|
||||
|
||||
std::vector<float> cl, ch, isc, ish, osc, osh;
|
||||
for (int i = 0; i < input_low.size(); i++) {
|
||||
cl.push_back(input_low[i]);
|
||||
}
|
||||
for (int i = 0; i < input_high.size(); i++) {
|
||||
ch.push_back(input_high[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < std::max(input_low.size(), input_high.size()); i++) {
|
||||
float il = input_low[input_low.size() == 1 ? 0 : i];
|
||||
float ih = input_high[input_high.size() == 1 ? 0 : i];
|
||||
|
||||
isc.push_back((levels - 1) / (ih - il));
|
||||
ish.push_back(-il * (levels - 1) / (ih - il));
|
||||
}
|
||||
|
||||
for (int i = 0; i < std::max(output_low.size(), output_high.size()); i++) {
|
||||
float ol = output_low[output_low.size() == 1 ? 0 : i];
|
||||
float oh = output_high[output_high.size() == 1 ? 0 : i];
|
||||
|
||||
osc.push_back((oh - ol) / (levels - 1));
|
||||
osh.push_back(ol);
|
||||
}
|
||||
|
||||
std::vector<float> outScale;
|
||||
|
||||
if (fq_node->get_output_element_type(0) == ngraph::element::u8 &&
|
||||
std::all_of(cl.cbegin(), cl.cend(), [](float val) { return val == 0.0f; }) &&
|
||||
std::all_of(ish.cbegin(), ish.cend(), [](float val) { return val == 0.0f; }) &&
|
||||
std::all_of(osc.cbegin(), osc.cend(), [](float val) { return val == 1.0f; }) &&
|
||||
std::all_of(osh.cbegin(), osh.cend(), [](float val) { return val == 0.0f; })) {
|
||||
outScale = isc;
|
||||
}
|
||||
|
||||
if (fq_node->get_output_element_type(0) == ngraph::element::i8 &&
|
||||
std::all_of(ish.cbegin(), ish.cend(), [](float val) { return std::abs(val - 128.f) < 0.0001f; }) &&
|
||||
std::all_of(osc.cbegin(), osc.cend(), [](float val) { return val == 1.f; }) &&
|
||||
std::all_of(osh.cbegin(), osh.cend(), [](float val) { return std::abs(val + 128.f) < 0.0001f; })) {
|
||||
bool isCropAligned = true;
|
||||
for (int i = 0; i < std::max(cl.size(), isc.size()); i++) {
|
||||
if (std::abs(cl[cl.size() == 1 ? 0 : i] * isc[isc.size() == 1 ? 0 : i] + 128.f) > 0.0001f) {
|
||||
isCropAligned = false;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < std::max(ch.size(), isc.size()); i++) {
|
||||
if (std::abs(ch[ch.size() == 1 ? 0 : i] * isc[isc.size() == 1 ? 0 : i] - 127.f) > 0.0001f) {
|
||||
isCropAligned = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (isCropAligned) {
|
||||
outScale = isc;
|
||||
}
|
||||
}
|
||||
|
||||
return outScale;
|
||||
}
|
||||
|
||||
// TODO: draw pattern
|
||||
ov::intel_cpu::MHAQuantFusion::MHAQuantFusion() {
|
||||
MATCHER_SCOPE(MHAQuantFusion);
|
||||
|
||||
auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in2 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in6 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in7 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
|
||||
auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
|
||||
auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1);
|
||||
auto fakeQuantize0 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({matmul0,
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
|
||||
auto add = std::make_shared<ngraph::opset4::Add>(fakeQuantize0, in3);
|
||||
auto mul = std::make_shared<ngraph::opset3::Multiply>(add, in2);
|
||||
auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(mul, in6, true);
|
||||
auto softmax = std::make_shared<ngraph::opset1::Softmax>(reshape0);
|
||||
auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softmax, in7, true);
|
||||
auto fakeQuantize1 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({reshape1,
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
|
||||
auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
|
||||
auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(fakeQuantize1, transpose2);
|
||||
auto fakeQuantize2 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({matmul1,
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
|
||||
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(fakeQuantize2, in10);
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose0_in = pattern_to_output.at(in0);
|
||||
auto transpose1_in = pattern_to_output.at(in1);
|
||||
auto add_in1 = pattern_to_output.at(in3);
|
||||
auto transpose2_in = pattern_to_output.at(in8);
|
||||
|
||||
// todo: check mul shape
|
||||
std::vector<float> mul_scales;
|
||||
if (auto mul_node = std::dynamic_pointer_cast<ngraph::opset3::Multiply>(pattern_to_output.at(mul).get_node_shared_ptr())) {
|
||||
mul_scales = std::dynamic_pointer_cast<ngraph::opset4::Constant>(mul_node->get_input_node_shared_ptr(1))->cast_vector<float>();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: check transpose order
|
||||
auto matmul0_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
|
||||
if (!matmul0_node)
|
||||
return false;
|
||||
if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
|
||||
return false;
|
||||
|
||||
std::vector<float> fq0_scale;
|
||||
if (auto fq_node = std::dynamic_pointer_cast<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize0).get_node_shared_ptr())) {
|
||||
fq0_scale = simplifyToScale(fq_node);
|
||||
if (!fq0_scale.size())
|
||||
return false;
|
||||
}
|
||||
|
||||
auto reshape0_node = std::dynamic_pointer_cast<ngraph::opset1::Reshape>(pattern_to_output.at(reshape0).get_node_shared_ptr());
|
||||
if (!reshape0_node)
|
||||
return false;
|
||||
|
||||
if (auto reshape0_pattern = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(in6).get_node_shared_ptr())) {
|
||||
// TODO: add valid condition based on reshape pattern
|
||||
auto reshape0_pattern_values = reshape0_pattern->cast_vector<int64_t>();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto reshape1_node = std::dynamic_pointer_cast<ngraph::opset1::Reshape>(pattern_to_output.at(reshape1).get_node_shared_ptr());
|
||||
if (!reshape1_node)
|
||||
return false;
|
||||
|
||||
if (auto reshape1_pattern = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(in7).get_node_shared_ptr())) {
|
||||
// TODO: add valid condition based on reshape pattern
|
||||
auto reshape1_pattern_values = reshape1_pattern->cast_vector<int64_t>();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (reshape0_node->get_output_partial_shape(0).rank() != 2 || reshape0_node->get_input_partial_shape(0) != reshape1_node->get_output_partial_shape(0))
|
||||
return false;
|
||||
|
||||
auto softmax_node = std::dynamic_pointer_cast<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
|
||||
if (!softmax_node)
|
||||
return false;
|
||||
if (softmax_node->get_axis() != 1)
|
||||
return false;
|
||||
|
||||
std::vector<float> fq1_scale;
|
||||
if (auto fq_node = std::dynamic_pointer_cast<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr())) {
|
||||
fq1_scale = simplifyToScale(fq_node);
|
||||
if (!fq1_scale.size())
|
||||
return false;
|
||||
}
|
||||
|
||||
auto matmul1_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
|
||||
if (!matmul1_node)
|
||||
return false;
|
||||
if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
|
||||
return false;
|
||||
|
||||
std::vector<float> fq2_scale;
|
||||
if (auto fq_node = std::dynamic_pointer_cast<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize2).get_node_shared_ptr())) {
|
||||
fq2_scale = simplifyToScale(fq_node);
|
||||
if (!fq2_scale.size())
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_mul_first = false;
|
||||
auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
|
||||
auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, mul_scales, is_mul_first,
|
||||
std::vector<float>(), fq0_scale, fq1_scale, fq2_scale,
|
||||
transpose3_node->output(0).get_element_type());
|
||||
mha->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(matmul0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(fakeQuantize0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(add).get_node_shared_ptr(),
|
||||
pattern_to_output.at(mul).get_node_shared_ptr(),
|
||||
pattern_to_output.at(reshape0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(softmax).get_node_shared_ptr(),
|
||||
pattern_to_output.at(reshape1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(fakeQuantize1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose2).get_node_shared_ptr(),
|
||||
pattern_to_output.at(matmul1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(fakeQuantize2).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose3).get_node_shared_ptr(),
|
||||
},
|
||||
mha);
|
||||
|
||||
if (transformation_callback(mha)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ngraph::replace_node(m.get_match_root(), mha);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
// TODO: draw pattern
|
||||
ov::intel_cpu::MHAQuantFusion2::MHAQuantFusion2() {
|
||||
MATCHER_SCOPE(MHAQuantFusion2);
|
||||
|
||||
auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in2 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||
auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
|
||||
auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
|
||||
auto fakeQuantize0 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({transpose1,
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
|
||||
auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, fakeQuantize0);
|
||||
auto mul = std::make_shared<ngraph::opset3::Multiply>(matmul0, in2);
|
||||
auto add = std::make_shared<ngraph::opset4::Add>(mul, in3);
|
||||
auto softmax = std::make_shared<ngraph::opset1::Softmax>(add);
|
||||
auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
|
||||
auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, transpose2);
|
||||
auto fakeQuantize1 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({matmul1,
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
|
||||
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(fakeQuantize1, in10);
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose0_in = pattern_to_output.at(in0);
|
||||
auto transpose1_in = pattern_to_output.at(in1);
|
||||
auto add_in1 = pattern_to_output.at(in3);
|
||||
auto transpose2_in = pattern_to_output.at(in8);
|
||||
|
||||
// todo: check mul shape
|
||||
std::vector<float> mul_scales;
|
||||
if (auto mul_node = std::dynamic_pointer_cast<ngraph::opset3::Multiply>(pattern_to_output.at(mul).get_node_shared_ptr())) {
|
||||
mul_scales = std::dynamic_pointer_cast<ngraph::opset4::Constant>(mul_node->get_input_node_shared_ptr(1))->cast_vector<float>();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: check transpose order
|
||||
auto matmul0_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
|
||||
if (!matmul0_node)
|
||||
return false;
|
||||
if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
|
||||
return false;
|
||||
|
||||
std::vector<float> fq0_scale;
|
||||
if (auto fq_node = std::dynamic_pointer_cast<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize0).get_node_shared_ptr())) {
|
||||
fq0_scale = simplifyToScale(fq_node);
|
||||
if (!fq0_scale.size())
|
||||
return false;
|
||||
}
|
||||
|
||||
auto softmax_node = std::dynamic_pointer_cast<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
|
||||
if (!softmax_node)
|
||||
return false;
|
||||
if (softmax_node->get_axis() != 3)
|
||||
return false;
|
||||
|
||||
std::vector<float> fq1_scale;
|
||||
if (auto fq_node = std::dynamic_pointer_cast<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr())) {
|
||||
fq1_scale = simplifyToScale(fq_node);
|
||||
if (!fq1_scale.size())
|
||||
return false;
|
||||
}
|
||||
|
||||
auto matmul1_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
|
||||
if (!matmul1_node)
|
||||
return false;
|
||||
if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
|
||||
return false;
|
||||
|
||||
std::vector<float> fq2_scale;
|
||||
if (auto fq_node = std::dynamic_pointer_cast<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr())) {
|
||||
fq2_scale = simplifyToScale(fq_node);
|
||||
if (!fq2_scale.size())
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_mul_first = true;
|
||||
auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
|
||||
auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, mul_scales, is_mul_first,
|
||||
fq0_scale, std::vector<float>(), std::vector<float>(), fq1_scale,
|
||||
transpose3_node->output(0).get_element_type());
|
||||
mha->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(fakeQuantize0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(matmul0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(mul).get_node_shared_ptr(),
|
||||
pattern_to_output.at(add).get_node_shared_ptr(),
|
||||
pattern_to_output.at(softmax).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose2).get_node_shared_ptr(),
|
||||
pattern_to_output.at(matmul1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(fakeQuantize1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose3).get_node_shared_ptr(),
|
||||
},
|
||||
mha);
|
||||
|
||||
if (transformation_callback(mha)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ngraph::replace_node(m.get_match_root(), mha);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
|
||||
class MHAFusion: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("MHAFusion", "0");
|
||||
MHAFusion();
|
||||
};
|
||||
|
||||
class MHAFusion2: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("MHAFusion2", "0");
|
||||
MHAFusion2();
|
||||
};
|
||||
|
||||
class MHAQuantFusion: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("MHAQuantFusion", "0");
|
||||
MHAQuantFusion();
|
||||
};
|
||||
|
||||
class MHAQuantFusion2: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("MHAQuantFusion2", "0");
|
||||
MHAQuantFusion2();
|
||||
};
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
98
src/plugins/intel_cpu/src/ngraph_transformations/op/mha.cpp
Normal file
98
src/plugins/intel_cpu/src/ngraph_transformations/op/mha.cpp
Normal file
@@ -0,0 +1,98 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "mha.hpp"
|
||||
#include "../itt.hpp"
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <matmul_shape_inference.hpp>
|
||||
|
||||
ov::intel_cpu::MHANode::MHANode(const ngraph::Output<ngraph::Node> &in0,
|
||||
const ngraph::Output<ngraph::Node> &in1,
|
||||
const ngraph::Output<ngraph::Node> &in2,
|
||||
const ngraph::Output<ngraph::Node> &in3,
|
||||
const std::vector<float> &mul_scales,
|
||||
bool is_mul_first,
|
||||
const ngraph::element::Type output_type)
|
||||
: Op({in0, in1, in2, in3}), m_output_type(output_type) {
|
||||
this->mul_scales = mul_scales;
|
||||
this->is_mul_first = is_mul_first;
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
ov::intel_cpu::MHANode::MHANode(const ngraph::Output<ngraph::Node> &in0,
|
||||
const ngraph::Output<ngraph::Node> &in1,
|
||||
const ngraph::Output<ngraph::Node> &in2,
|
||||
const ngraph::Output<ngraph::Node> &in3,
|
||||
const std::vector<float> &mul_scales,
|
||||
bool is_mul_first,
|
||||
const std::vector<float> &fq_scales0,
|
||||
const std::vector<float> &fq_scales1,
|
||||
const std::vector<float> &fq_scales2,
|
||||
const std::vector<float> &fq_scales3,
|
||||
const ngraph::element::Type output_type)
|
||||
: Op({in0, in1, in2, in3}), m_output_type(output_type) {
|
||||
this->mul_scales = mul_scales;
|
||||
this->is_mul_first = is_mul_first;
|
||||
this->fq_scales0 = fq_scales0;
|
||||
this->fq_scales1 = fq_scales1;
|
||||
this->fq_scales2 = fq_scales2;
|
||||
this->fq_scales3 = fq_scales3;
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> ov::intel_cpu::MHANode::clone_with_new_inputs(const ngraph::OutputVector& new_args) const {
|
||||
INTERNAL_OP_SCOPE(MHANode_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return std::make_shared<ov::intel_cpu::MHANode>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3),
|
||||
mul_scales, is_mul_first, fq_scales0, fq_scales1, fq_scales2, fq_scales3, m_output_type);
|
||||
}
|
||||
|
||||
void ov::intel_cpu::MHANode::validate_and_infer_types() {
|
||||
INTERNAL_OP_SCOPE(MHANode_validate_and_infer_types);
|
||||
|
||||
auto transpose = [](const ov::Shape& shape, const std::vector<size_t>& order) -> ov::Shape {
|
||||
std::vector<size_t> new_shape(shape.size());
|
||||
for (int i = 0; i < shape.size(); i++) {
|
||||
new_shape[i] = shape[order[i]];
|
||||
}
|
||||
return new_shape;
|
||||
};
|
||||
|
||||
const auto matmul0_shape0 = transpose(get_input_partial_shape(0).get_shape(), {0, 2, 1, 3});
|
||||
const auto matmul0_shape1 = transpose(get_input_partial_shape(1).get_shape(), {0, 2, 3, 1});
|
||||
|
||||
auto matmul0_in0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul0_shape0);
|
||||
auto matmul0_in1 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul0_shape1);
|
||||
auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(matmul0_in0, matmul0_in1);
|
||||
|
||||
std::vector<ov::PartialShape> matmul0_input_shapes = {matmul0_shape0, matmul0_shape1};
|
||||
std::vector<ov::PartialShape> matmul0_output_shapes = {ov::PartialShape{}};
|
||||
|
||||
shape_infer(matmul0.get(), matmul0_input_shapes, matmul0_output_shapes);
|
||||
|
||||
const auto matmul1_shape0 = matmul0_output_shapes[0];
|
||||
const auto matmul1_shape1 = transpose(get_input_partial_shape(3).get_shape(), {0, 2, 1, 3});
|
||||
|
||||
auto matmul1_in0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul1_shape0);
|
||||
auto matmul1_in1 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul1_shape1);
|
||||
auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(matmul1_in0, matmul1_in1);
|
||||
|
||||
std::vector<ov::PartialShape> matmul1_input_shapes = {matmul1_shape0, matmul1_shape1};
|
||||
std::vector<ov::PartialShape> matmul1_output_shapes = {ov::PartialShape{}};
|
||||
|
||||
shape_infer(matmul1.get(), matmul1_input_shapes, matmul1_output_shapes);
|
||||
|
||||
const auto output_shape = transpose(matmul1_output_shapes[0].get_shape(), {0, 2, 1, 3});
|
||||
|
||||
set_output_type(
|
||||
0,
|
||||
m_output_type == ngraph::element::undefined || m_output_type == ngraph::element::dynamic ? get_input_element_type(0) : m_output_type,
|
||||
output_shape);
|
||||
}
|
||||
|
||||
bool ov::intel_cpu::MHANode::visit_attributes(ngraph::AttributeVisitor &visitor) {
|
||||
INTERNAL_OP_SCOPE(MHANode_visit_attributes);
|
||||
visitor.on_attribute("out-type", m_output_type);
|
||||
return true;
|
||||
}
|
||||
78
src/plugins/intel_cpu/src/ngraph_transformations/op/mha.hpp
Normal file
78
src/plugins/intel_cpu/src/ngraph_transformations/op/mha.hpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/op/op.hpp>
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
|
||||
class MHANode : public ngraph::op::Op {
|
||||
public:
|
||||
OPENVINO_OP("MHA", "cpu_plugin_opset");
|
||||
|
||||
MHANode() = default;
|
||||
|
||||
MHANode(const ngraph::Output<ngraph::Node> &in0,
|
||||
const ngraph::Output<ngraph::Node> &in1,
|
||||
const ngraph::Output<ngraph::Node> &in2,
|
||||
const ngraph::Output<ngraph::Node> &in3,
|
||||
const std::vector<float> &mul_scales,
|
||||
bool is_mul_first,
|
||||
const ngraph::element::Type output_type);
|
||||
|
||||
MHANode(const ngraph::Output<ngraph::Node> &in0,
|
||||
const ngraph::Output<ngraph::Node> &in1,
|
||||
const ngraph::Output<ngraph::Node> &in2,
|
||||
const ngraph::Output<ngraph::Node> &in3,
|
||||
const std::vector<float> &mul_scales,
|
||||
bool is_mul_first,
|
||||
const std::vector<float> &fq_scales0,
|
||||
const std::vector<float> &fq_scales1,
|
||||
const std::vector<float> &fq_scales2,
|
||||
const std::vector<float> &fq_scales3,
|
||||
const ngraph::element::Type output_type);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
bool visit_attributes(ngraph::AttributeVisitor &visitor) override;
|
||||
|
||||
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector &new_args) const override;
|
||||
|
||||
ngraph::element::Type get_output_type() const { return m_output_type; }
|
||||
|
||||
const std::vector<float>& get_mul_scales() const {
|
||||
return mul_scales;
|
||||
}
|
||||
|
||||
const std::vector<float>& get_fq_scales0() const {
|
||||
return fq_scales0;
|
||||
}
|
||||
const std::vector<float>& get_fq_scales1() const {
|
||||
return fq_scales1;
|
||||
}
|
||||
const std::vector<float>& get_fq_scales2() const {
|
||||
return fq_scales2;
|
||||
}
|
||||
const std::vector<float>& get_fq_scales3() const {
|
||||
return fq_scales3;
|
||||
}
|
||||
|
||||
bool get_is_mul_first() const {
|
||||
return is_mul_first;
|
||||
}
|
||||
|
||||
private:
|
||||
ngraph::element::Type m_output_type;
|
||||
std::vector<float> mul_scales;
|
||||
bool is_mul_first;
|
||||
std::vector<float> fq_scales0;
|
||||
std::vector<float> fq_scales1;
|
||||
std::vector<float> fq_scales2;
|
||||
std::vector<float> fq_scales3;
|
||||
};
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
1370
src/plugins/intel_cpu/src/nodes/mha.cpp
Normal file
1370
src/plugins/intel_cpu/src/nodes/mha.cpp
Normal file
File diff suppressed because it is too large
Load Diff
242
src/plugins/intel_cpu/src/nodes/mha.h
Normal file
242
src/plugins/intel_cpu/src/nodes/mha.h
Normal file
@@ -0,0 +1,242 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <node.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cpu/x64/brgemm/brgemm.hpp>
|
||||
#include <cpu/x64/matmul/brgemm_matmul_copy_utils.hpp>
|
||||
#include <cpu/x64/matmul/brgemm_matmul_utils.hpp>
|
||||
#include <cpu/x64/amx_tile_configure.hpp>
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
namespace node {
|
||||
|
||||
struct jit_mul_add_softmax_compile_params {
|
||||
InferenceEngine::Precision src_prc;
|
||||
InferenceEngine::Precision dst_prc;
|
||||
size_t work_amount;
|
||||
bool with_mul_scales;
|
||||
bool is_mul_first;
|
||||
bool with_scales0;
|
||||
bool broadcast_scales0;
|
||||
bool with_scales1;
|
||||
bool broadcast_scales1;
|
||||
};
|
||||
|
||||
struct jit_mul_add_softmax_call_args {
|
||||
const void *p_in0;
|
||||
const void *p_mul_in1;
|
||||
const void *p_add_in1;
|
||||
void *p_out;
|
||||
void *p_buffer;
|
||||
const void *p_scales0;
|
||||
const void *p_scales1;
|
||||
};
|
||||
|
||||
struct jit_uni_mul_add_softmax_kernel {
|
||||
void (*ker_)(const jit_mul_add_softmax_call_args*);
|
||||
|
||||
void operator()(const jit_mul_add_softmax_call_args* call_args) {
|
||||
assert(ker_);
|
||||
ker_(call_args);
|
||||
}
|
||||
|
||||
explicit jit_uni_mul_add_softmax_kernel(const jit_mul_add_softmax_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {}
|
||||
virtual ~jit_uni_mul_add_softmax_kernel() {}
|
||||
|
||||
virtual void create_ker() = 0;
|
||||
|
||||
jit_mul_add_softmax_compile_params jcp_;
|
||||
};
|
||||
|
||||
struct jit_convert_reorder_compile_params {
|
||||
InferenceEngine::Precision src_prc;
|
||||
InferenceEngine::Precision dst_prc;
|
||||
size_t inner_work_amount;
|
||||
bool with_scales;
|
||||
bool broadcast_scales;
|
||||
size_t src_stride;
|
||||
size_t dst_stride;
|
||||
};
|
||||
|
||||
struct jit_convert_reorder_call_args {
|
||||
const void *p_in;
|
||||
void *p_out;
|
||||
const void *p_scales;
|
||||
size_t outter_work_amount;
|
||||
};
|
||||
|
||||
struct jit_uni_convert_reorder_kernel {
|
||||
void (*ker_)(const jit_convert_reorder_call_args*);
|
||||
|
||||
void operator()(const jit_convert_reorder_call_args* call_args) {
|
||||
assert(ker_);
|
||||
ker_(call_args);
|
||||
}
|
||||
|
||||
explicit jit_uni_convert_reorder_kernel(const jit_convert_reorder_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {}
|
||||
virtual ~jit_uni_convert_reorder_kernel() {}
|
||||
|
||||
virtual void create_ker() = 0;
|
||||
|
||||
jit_convert_reorder_compile_params jcp_;
|
||||
};
|
||||
|
||||
struct jit_convert_transpose_compile_params {
|
||||
InferenceEngine::Precision src_prc;
|
||||
InferenceEngine::Precision dst_prc;
|
||||
size_t inner_work_amount;
|
||||
size_t outter_work_amount;
|
||||
bool with_scales;
|
||||
bool broadcast_scales;
|
||||
size_t inner_src_stride;
|
||||
size_t outter_src_stride;
|
||||
size_t outter_dst_stride;
|
||||
};
|
||||
|
||||
struct jit_convert_transpose_call_args {
|
||||
const void *p_in;
|
||||
void *p_out;
|
||||
const void *p_scales;
|
||||
};
|
||||
|
||||
struct jit_uni_convert_transpose_kernel {
|
||||
void (*ker_)(const jit_convert_transpose_call_args*);
|
||||
|
||||
void operator()(const jit_convert_transpose_call_args* call_args) {
|
||||
assert(ker_);
|
||||
ker_(call_args);
|
||||
}
|
||||
|
||||
explicit jit_uni_convert_transpose_kernel(const jit_convert_transpose_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {}
|
||||
virtual ~jit_uni_convert_transpose_kernel() {}
|
||||
|
||||
virtual void create_ker() = 0;
|
||||
|
||||
jit_convert_transpose_compile_params jcp_;
|
||||
};
|
||||
|
||||
#define MHA_BRGEMM_KERNELS_NUM 8
|
||||
|
||||
class MHA : public Node {
|
||||
public:
|
||||
MHA(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache);
|
||||
|
||||
void getSupportedDescriptors() override {};
|
||||
void initSupportedPrimitiveDescriptors() override;
|
||||
void execute(dnnl::stream strm) override;
|
||||
bool created() const override;
|
||||
|
||||
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
|
||||
|
||||
protected:
|
||||
void executeDynamicImpl(dnnl::stream strm) override;
|
||||
void prepareParams() override;
|
||||
|
||||
private:
|
||||
struct brgemmCtx {
|
||||
size_t M, N, K, LDA, LDB, LDC;
|
||||
dnnl_data_type_t dt_in0, dt_in1;
|
||||
char palette[64];
|
||||
bool is_with_amx;
|
||||
bool is_with_comp;
|
||||
float beta;
|
||||
};
|
||||
|
||||
template <typename in1_type>
|
||||
void mhaImpl();
|
||||
|
||||
void init_brgemm(brgemmCtx& ctx, std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& brgKernel, bool use_amx);
|
||||
void init_brgemm_copy_a(std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_a_t>& brgCopyKernel,
|
||||
size_t K, size_t K_blk, size_t K_tail, size_t LDA, dnnl_data_type_t dt_in0);
|
||||
void init_brgemm_copy_b(std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t>& brgCopyKernel,
|
||||
size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1);
|
||||
|
||||
void callBrgemm(brgemmCtx& ctx, std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& brgKernel,
|
||||
const void* pin0, const void* pin1, void* pout, void* wsp);
|
||||
|
||||
size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) {
|
||||
return mIdx * 4 + kIdx * 2 + nIdx;
|
||||
}
|
||||
|
||||
std::vector<InferenceEngine::Precision> inputPrecisions;
|
||||
InferenceEngine::Precision accPrecision0;
|
||||
InferenceEngine::Precision accPrecision1;
|
||||
|
||||
VectorDims dimsTranspose0In0;
|
||||
VectorDims dimsTranspose1In0;
|
||||
VectorDims dimsMulIn1;
|
||||
VectorDims dimsAddIn1;
|
||||
VectorDims dimsTranspose2In0;
|
||||
VectorDims dimsOut;
|
||||
|
||||
VectorDims strTranspose0In0;
|
||||
VectorDims strTranspose1In0;
|
||||
VectorDims strMulIn1;
|
||||
VectorDims strAddIn1;
|
||||
VectorDims strTranspose2In0;
|
||||
VectorDims strOut;
|
||||
|
||||
VectorDims dimsMatMul0In0;
|
||||
VectorDims dimsMatMul0In1;
|
||||
VectorDims dimsMatMul0Out;
|
||||
VectorDims dimsMatMul1In1;
|
||||
VectorDims dimsMatMul1Out;
|
||||
|
||||
size_t batch0, batch1;
|
||||
size_t M, M_blk, M_tail;
|
||||
size_t K0, K0_blk, K0_tail, N0, N0_blk, N0_tail;
|
||||
size_t K1, K1_blk, K1_tail, N1, N1_blk, N1_tail;
|
||||
|
||||
size_t bufferMatMul0In0Size;
|
||||
size_t bufferMatMul0In1Size;
|
||||
size_t bufferMatMul0OutSize;
|
||||
size_t bufferMatMul1In1Size;
|
||||
size_t bufferMatMul1OutSize;
|
||||
size_t bufferCompensation0Size;
|
||||
size_t bufferCompensation1Size;
|
||||
size_t wsp_size_per_thread = 4 * 1024;
|
||||
|
||||
std::vector<uint8_t> bufferMatMul0In0;
|
||||
std::vector<uint8_t> bufferMatMul0In1;
|
||||
std::vector<uint8_t> bufferMatMul0Out;
|
||||
std::vector<uint8_t> bufferMatMul1In1;
|
||||
std::vector<uint8_t> bufferMatMul1Out;
|
||||
std::vector<int32_t> bufferCompensation0;
|
||||
std::vector<int32_t> bufferCompensation1;
|
||||
std::vector<size_t> wsp;
|
||||
|
||||
bool isMulFirst;
|
||||
|
||||
std::vector<float> mulScales;
|
||||
std::vector<float> fqScales0;
|
||||
std::vector<float> fqScales1;
|
||||
std::vector<float> fqScales2;
|
||||
std::vector<float> fqScales3;
|
||||
|
||||
size_t brg0VnniFactor;
|
||||
brgemmCtx brgCtxs0[MHA_BRGEMM_KERNELS_NUM];
|
||||
std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> brgKernels0[MHA_BRGEMM_KERNELS_NUM];
|
||||
std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_a_t> brgCopyAKernel0;
|
||||
std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t> brgCopyBKernel0;
|
||||
|
||||
size_t brg1VnniFactor;
|
||||
brgemmCtx brgCtxs1[MHA_BRGEMM_KERNELS_NUM];
|
||||
std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> brgKernels1[MHA_BRGEMM_KERNELS_NUM];
|
||||
std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t> brgCopyBKernel1;
|
||||
|
||||
std::unique_ptr<jit_uni_mul_add_softmax_kernel> mulAddSoftmaxKernel;
|
||||
std::unique_ptr<jit_uni_convert_reorder_kernel> convertReorderKernel;
|
||||
std::unique_ptr<jit_uni_convert_transpose_kernel> convertTransposeKernel;
|
||||
};
|
||||
|
||||
} // namespace node
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
@@ -88,6 +88,7 @@
|
||||
#include "nodes/priorbox.h"
|
||||
#include "nodes/priorbox_clustered.h"
|
||||
#include "nodes/eye.h"
|
||||
#include "nodes/mha.h"
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
@@ -188,6 +189,7 @@ Node::NodesFactory::NodesFactory()
|
||||
INTEL_CPU_NODE(PriorBox, Type::PriorBox);
|
||||
INTEL_CPU_NODE(PriorBoxClustered, Type::PriorBoxClustered);
|
||||
INTEL_CPU_NODE(Eye, Type::Eye);
|
||||
INTEL_CPU_NODE(MHA, Type::MHA);
|
||||
}
|
||||
|
||||
#undef INTEL_CPU_NODE
|
||||
|
||||
@@ -86,6 +86,7 @@
|
||||
#include <transformations/op_conversions/convert_roi_align_v9_to_v3.hpp>
|
||||
#include <transformations/op_conversions/convert_roi_align_v3_to_v9.hpp>
|
||||
#include <transformations/op_conversions/softsign_decomposition.hpp>
|
||||
#include "ngraph_transformations/mha_fusion.hpp"
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset2.hpp>
|
||||
@@ -117,6 +118,7 @@
|
||||
#include "nodes/mvn.h"
|
||||
#include "nodes/fake_quantize.h"
|
||||
#include "nodes/normalize.h"
|
||||
#include "nodes/mha.h"
|
||||
#include "ngraph_transformations/convert_to_cpu_specific_opset.hpp"
|
||||
#include "ngraph_transformations/move_eltwise_up_data_movement.hpp"
|
||||
#include "transformations/smart_reshape/smart_reshape.hpp"
|
||||
@@ -248,7 +250,7 @@ Engine::~Engine() {
|
||||
executorManager()->clear("CPUCallbackExecutor");
|
||||
}
|
||||
|
||||
static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function> nGraphFunc, const bool _enableLPT,
|
||||
static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function> nGraphFunc, const bool _enableLPT, const bool _enableBF16,
|
||||
const bool _enableSnippets, const bool isLegacyApi) {
|
||||
ngraph::pass::Manager manager;
|
||||
manager.set_per_pass_validation(false);
|
||||
@@ -598,6 +600,34 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
|
||||
});
|
||||
|
||||
postLPTPassManager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
|
||||
auto isMHASupported = [_enableBF16](const std::shared_ptr<const ov::Node>& n) -> bool {
|
||||
std::string errorMessage;
|
||||
|
||||
if (!node::MHA::isSupportedOperation(n, errorMessage))
|
||||
return true;
|
||||
|
||||
// Implementation calls AMX BF16 brgemm only for tensors with K and N aligned on 2, otherwise fallbacks on vector impl
|
||||
// Vector madd BF16 instruction on SPR has reduced performance on HW level, which results in overall perf degradation
|
||||
size_t bf16Factor = 2;
|
||||
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16_amx_bf16) &&
|
||||
(n->get_input_element_type(0) == element::bf16 || (n->get_input_element_type(0) == element::f32 && _enableBF16)) &&
|
||||
(n->get_input_shape(0)[3] % bf16Factor != 0 || n->get_input_shape(1)[1] % bf16Factor != 0 || n->get_input_shape(3)[3] % bf16Factor != 0)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
// Snippets may brake MHA patterns so the fusion has to performed before
|
||||
postLPTPassManager.register_pass<MHAFusion>();
|
||||
postLPTPassManager.get_pass_config()->set_callback<MHAFusion>(isMHASupported);
|
||||
postLPTPassManager.register_pass<MHAFusion2>();
|
||||
postLPTPassManager.get_pass_config()->set_callback<MHAFusion2>(isMHASupported);
|
||||
postLPTPassManager.register_pass<MHAQuantFusion>();
|
||||
postLPTPassManager.get_pass_config()->set_callback<MHAQuantFusion>(isMHASupported);
|
||||
postLPTPassManager.register_pass<MHAQuantFusion2>();
|
||||
postLPTPassManager.get_pass_config()->set_callback<MHAQuantFusion2>(isMHASupported);
|
||||
postLPTPassManager.run_passes(nGraphFunc);
|
||||
|
||||
if (!useLpt && _enableSnippets && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) {
|
||||
@@ -629,9 +659,9 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
|
||||
}
|
||||
}
|
||||
|
||||
static void Transformation(CNNNetwork& clonedNetwork, const bool _enableLPT, const bool _enableSnippets, const bool isLegacyApi) {
|
||||
static void Transformation(CNNNetwork& clonedNetwork, const bool _enableLPT, const bool _enableBF16, const bool _enableSnippets, const bool isLegacyApi) {
|
||||
auto nGraphFunc = clonedNetwork.getFunction();
|
||||
TransformationUpToCPUSpecificOpSet(nGraphFunc, _enableLPT, _enableSnippets, isLegacyApi);
|
||||
TransformationUpToCPUSpecificOpSet(nGraphFunc, _enableLPT, _enableBF16, _enableSnippets, isLegacyApi);
|
||||
ConvertToCPUSpecificOpset(nGraphFunc);
|
||||
}
|
||||
|
||||
@@ -772,7 +802,7 @@ Engine::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork &network, const std
|
||||
|| engConfig.enableDynamicBatch;
|
||||
const bool enableSnippets = !(enableModelCache || enableDynamicBatch || enableBF16);
|
||||
auto nGraphFunc = clonedNetwork.getFunction();
|
||||
TransformationUpToCPUSpecificOpSet(nGraphFunc, enableLPT, enableSnippets, isLegacyAPI());
|
||||
TransformationUpToCPUSpecificOpSet(nGraphFunc, enableLPT, enableBF16, enableSnippets, isLegacyAPI());
|
||||
|
||||
// need to check that all outputs have static shapes
|
||||
// checking that all inputs have static shapes is performed in the common part
|
||||
@@ -1020,7 +1050,7 @@ QueryNetworkResult Engine::QueryNetwork(const CNNNetwork& network, const std::ma
|
||||
|| Config::LPTransformsMode::On == engConfig.lpTransformsMode /* or already enabled */;
|
||||
const bool enableSnippets = !(conf.cache_dir.empty() || conf.enableDynamicBatch || (conf.enforceBF16
|
||||
&& dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)));
|
||||
Transformation(clonedNetwork, enableLPT, enableSnippets, isLegacyAPI());
|
||||
Transformation(clonedNetwork, enableLPT, conf.enforceBF16, enableSnippets, isLegacyAPI());
|
||||
auto ops = clonnedFunction->get_ordered_ops();
|
||||
|
||||
//Mark removed nodes as supported
|
||||
|
||||
Reference in New Issue
Block a user