This reverts commit a61a9be495.

parent 2ee8b29f64
commit a5e33f10ff
@@ -214,7 +214,6 @@ static const TypeToNameMap& get_type_to_name_tbl() {
{ "Unique", Type::Unique},
{ "Ngram", Type::Ngram},
{ "ScaledDotProductAttention", Type::ScaledDotProductAttention},
{ "RoPE", Type::RoPE},
};
return type_to_name_tbl;
}
@@ -329,7 +328,6 @@ std::string NameFromType(const Type type) {
CASE(Unique);
CASE(Ngram);
CASE(ScaledDotProductAttention);
CASE(RoPE);
CASE(Unknown);
}
#undef CASE
@@ -114,7 +114,6 @@ enum class Type {
Unique,
Ngram,
ScaledDotProductAttention,
RoPE,
};

enum class Algorithm {
@@ -1,201 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "rope.h"

#include <chrono>
#include <cpu/x64/cpu_isa_traits.hpp>
#include <ie_ngraph_utils.hpp>
#include <shape_inference/shape_inference_internal_dyn.hpp>
#include <string>
#include <vector>

#include "common/bfloat16.hpp"
#include "common/cpu_memcpy.h"
#include "utils/plain_tensor.hpp"

using namespace InferenceEngine;

namespace ov {
namespace intel_cpu {
namespace node {

RoPE::RoPE(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
: Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) {
OPENVINO_THROW("CPU: " + errorMessage);
}

const auto node = std::dynamic_pointer_cast<const RoPENode>(op);
m_config = node->get_config();
}

template <typename T>
struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor {
void execute(dnnl::stream strm,
const RoPENode::Config& config,
const std::vector<MemoryPtr>& inputs,
const std::vector<MemoryPtr>& outputs) override {
ov::intel_cpu::PlainTensor<T> t_src(inputs[0]);
ov::intel_cpu::PlainTensor<float> t_cos(inputs[1]);
ov::intel_cpu::PlainTensor<float> t_sin(inputs[2]);
ov::intel_cpu::PlainTensor<T> t_dst(outputs[0]);
ov::intel_cpu::PlainTensor<int32_t> gather;

if (config.slice_stop - config.slice_start > 0) {
t_src = t_src.slice(3, config.slice_start, config.slice_stop);
}
if (config.input_trans0213) {
t_src = t_src.permute({0, 2, 1, 3});
}
if (config.gather_position_arg_id > 0) {
gather.reset(inputs[config.gather_position_arg_id]);
}

auto batch_size = t_src.size(0);
auto head_cnt = t_src.size(1);
auto seq_len = t_src.size(2);
auto feature_size = t_src.size(3);

auto rotary_dims = config.rotary_ndims;
auto half_rotary_dims = rotary_dims / 2;

parallel_for3d(batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) {
auto cos_pos = p;
if (gather) {
if (gather.m_rank == 4)
cos_pos = gather.at({b, h, p, 0}, true);
else
cos_pos = gather.at({b, p}, true);
}
auto* src = &t_src.at({b, h, p, 0});
auto* cos = &t_cos.at({b, h, cos_pos, 0}, true);
auto* sin = &t_sin.at({b, h, cos_pos, 0}, true);
auto* dst = &t_dst.at({b, h, p, 0});

size_t i = 0;
for (; i < half_rotary_dims; i++) {
dst[i] = cos[i] * src[i] + sin[i] * (-src[i + half_rotary_dims]);
}
for (; i < rotary_dims; i++) {
dst[i] = cos[i] * src[i] + sin[i] * (src[i - half_rotary_dims]);
}
for (; i < feature_size; i++) {
dst[i] = src[i];
}
});
}
};

template <typename T>
struct RoPE::RoPEExecutorInterleaved : public RoPE::Executor {
void execute(dnnl::stream strm,
const RoPENode::Config& config,
const std::vector<MemoryPtr>& inputs,
const std::vector<MemoryPtr>& outputs) override {
ov::intel_cpu::PlainTensor<T> t_src(inputs[0]);
ov::intel_cpu::PlainTensor<float> t_sin_cos(inputs[1]);
ov::intel_cpu::PlainTensor<T> t_dst(outputs[0]);

auto batch_size = t_src.size(0);
auto seq_len = t_src.size(1);
auto head_cnt = t_src.size(2);
auto head_dims = t_src.size(3);

auto rotary_dims = config.rotary_ndims;
auto half_rotary_dims = rotary_dims / 2;
parallel_for3d(batch_size, seq_len, head_cnt, [&](size_t b, size_t p, size_t h) {
auto* x = &t_src.at({b, p, h, 0});
float* sin = &t_sin_cos.at({b, p, 0}, true);
float* cos = &t_sin_cos.at({b, p, half_rotary_dims}, true);
auto* dst = &t_dst.at({b, h, p, 0});

size_t i = 0;
for (size_t j = 0; i < rotary_dims; i += 2, j++) {
dst[i] = cos[j] * x[i] - sin[j] * x[i + 1];
dst[i + 1] = cos[j] * x[i + 1] + sin[j] * x[i];
}
for (; i < head_dims; i++) {
dst[i] = x[i];
}
});
}
};

void RoPE::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
auto srcPrecision = getOriginalInputPrecisionAtPort(0);

auto rtPrecision = srcPrecision;
auto CosSinPrecision = ov::element::f32;

if (m_config.is_interleaved) {
OPENVINO_ASSERT(m_config.input_trans0213 == false);
OPENVINO_ASSERT(m_config.slice_start == 0);
OPENVINO_ASSERT(m_config.slice_stop == 0);
OPENVINO_ASSERT(m_config.gather_position_arg_id == 0);
if (rtPrecision == ov::element::bf16) {
m_executor = std::make_shared<RoPEExecutorInterleaved<ov::bfloat16>>();
} else {
m_executor = std::make_shared<RoPEExecutorInterleaved<float>>();
rtPrecision = ov::element::f32;
}
} else {
if (rtPrecision == ov::element::bf16) {
m_executor = std::make_shared<RoPEExecutorRotateHalf<ov::bfloat16>>();
} else {
m_executor = std::make_shared<RoPEExecutorRotateHalf<float>>();
rtPrecision = ov::element::f32;
}
}

// initialize input ports
std::vector<PortConfigurator> inPortConfigs;
inPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getInputShapeAtPort(0), false, -1);
inPortConfigs.emplace_back(LayoutType::ncsp, CosSinPrecision, getInputShapeAtPort(1), false, -1);
inPortConfigs.emplace_back(LayoutType::ncsp, CosSinPrecision, getInputShapeAtPort(2), false, -1);
if (m_config.gather_position_arg_id > 0) {
inPortConfigs.emplace_back(LayoutType::ncsp,
ov::element::i32,
getInputShapeAtPort(m_config.gather_position_arg_id),
false,
-1);
}

// initialize output port
std::vector<PortConfigurator> outPortConfigs;
outPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getOutputShapeAtPort(0), false, -1);

addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any);
}

void RoPE::execute(dnnl::stream strm) {
std::vector<MemoryPtr> inputs(getParentEdges().size()), outputs(getChildEdges().size());
for (size_t i = 0; i < inputs.size(); i++) {
inputs[i] = getParentEdgeAt(i)->getMemoryPtr();
}
for (size_t i = 0; i < outputs.size(); i++) {
outputs[i] = getChildEdgeAt(i)->getMemoryPtr();
}
m_executor->execute(strm, m_config, inputs, outputs);
}

bool RoPE::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
try {
const auto node = std::dynamic_pointer_cast<const RoPENode>(op);
if (!node) {
errorMessage = "Only RoPENode operation is supported";
return false;
}
} catch (...) {
return false;
}
return true;
}

} // namespace node
} // namespace intel_cpu
} // namespace ov
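(Illustration only, not part of the reverted files.) A minimal standalone sketch, assuming a plain float buffer for a single token, of the rotate-half arithmetic used by RoPEExecutorRotateHalf above; its two loops over the halves are the per-pair 2x2 rotation of (src[i], src[i + half]) written out, with the cos/sin tables duplicated across both halves. All names here are hypothetical.

#include <cstddef>
#include <vector>

// Reference helper: cos_tab/sin_tab have length rotary_dims and are assumed to
// satisfy cos_tab[i] == cos_tab[i + half] (and likewise for sin_tab).
static void rotate_half_reference(const std::vector<float>& src,
                                  const std::vector<float>& cos_tab,
                                  const std::vector<float>& sin_tab,
                                  std::vector<float>& dst) {
    const size_t rotary_dims = src.size();
    const size_t half = rotary_dims / 2;
    for (size_t i = 0; i < half; i++) {
        // pair (src[i], src[i + half]) rotated by the angle encoded at index i
        dst[i] = cos_tab[i] * src[i] - sin_tab[i] * src[i + half];
        dst[i + half] = cos_tab[i + half] * src[i + half] + sin_tab[i + half] * src[i];
    }
}

int main() {
    std::vector<float> src = {1.f, 2.f, 3.f, 4.f};         // rotary_dims = 4, half = 2
    std::vector<float> cos_tab = {0.8f, 0.6f, 0.8f, 0.6f}; // duplicated across halves
    std::vector<float> sin_tab = {0.6f, 0.8f, 0.6f, 0.8f};
    std::vector<float> dst(src.size());
    rotate_half_reference(src, cos_tab, sin_tab, dst);
    return 0;
}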
@@ -1,54 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include <ie_common.h>
#include <node.h>

#include <memory>
#include <string>
#include <vector>

#include "transformations/cpu_opset/common/op/rope.hpp"

namespace ov {
namespace intel_cpu {
namespace node {

class RoPE : public Node {
public:
RoPE(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);

void getSupportedDescriptors() override {}
bool created() const override {
return getType() == Type::RoPE;
}
bool needPrepareParams() const override {
return false;
};
void executeDynamicImpl(dnnl::stream strm) override {
execute(strm);
}
void initSupportedPrimitiveDescriptors() override;
void execute(dnnl::stream strm) override;
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;

private:
struct Executor {
virtual void execute(dnnl::stream strm,
const RoPENode::Config& config,
const std::vector<MemoryPtr>& inputs,
const std::vector<MemoryPtr>& outputs) = 0;
};
template <typename T>
struct RoPEExecutorRotateHalf;
template <typename T>
struct RoPEExecutorInterleaved;
RoPENode::Config m_config;
std::shared_ptr<Executor> m_executor;
};

} // namespace node
} // namespace intel_cpu
} // namespace ov
@@ -94,7 +94,6 @@
#include "nodes/unique.hpp"
#include "nodes/ngram.h"
#include "nodes/scaled_attn.h"
#include "nodes/rope.h"

namespace ov {
namespace intel_cpu {
@@ -182,7 +181,6 @@ Node::NodesFactory::NodesFactory()
INTEL_CPU_NODE(Eye, Type::Eye);
INTEL_CPU_NODE(Unique, Type::Unique);
INTEL_CPU_NODE(Ngram, Type::Ngram);
INTEL_CPU_NODE(RoPE, Type::RoPE);
INTEL_CPU_NODE(Interpolate, Type::Interpolate);
INTEL_CPU_NODE(RandomUniform, Type::RandomUniform);
INTEL_CPU_NODE(Reduce, Type::Reduce);
@@ -1,50 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "rope.hpp"

#include <algorithm>

#include "transformations/itt.hpp"

ov::intel_cpu::RoPENode::RoPENode(const OutputVector& args, const Config& cfg) : Op(args), m_config(cfg) {
constructor_validate_and_infer_types();
}

std::shared_ptr<ngraph::Node> ov::intel_cpu::RoPENode::clone_with_new_inputs(
const ngraph::OutputVector& new_args) const {
INTERNAL_OP_SCOPE(RoPENode_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<ov::intel_cpu::RoPENode>(new_args, m_config);
}

void ov::intel_cpu::RoPENode::validate_and_infer_types() {
INTERNAL_OP_SCOPE(RoPENode_validate_and_infer_types);
auto input_pshape = get_input_partial_shape(0);
auto input_slice_size = m_config.slice_stop - m_config.slice_start;
if (input_slice_size > 0) {
input_pshape[3] = input_slice_size;
}
if (m_config.input_trans0213) {
// transpose 0213 ([B,L,H,S]=>[B,H,L,S]) happens before RoPE
std::swap(input_pshape[2], input_pshape[1]);
} else if (m_config.is_interleaved) {
// transpose 0213 ([B,L,H,S]=>[B,H,L,S]) happens after RoPE
std::swap(input_pshape[2], input_pshape[1]);
}

set_output_type(0, get_input_element_type(0), input_pshape);
}

bool ov::intel_cpu::RoPENode::visit_attributes(ngraph::AttributeVisitor& visitor) {
INTERNAL_OP_SCOPE(RoPENode_visit_attributes);
visitor.start_structure("config");
visitor.on_attribute("slice_start", m_config.slice_start);
visitor.on_attribute("slice_stop", m_config.slice_stop);
visitor.on_attribute("input_trans0213", m_config.input_trans0213);
visitor.on_attribute("is_interleaved", m_config.is_interleaved);
visitor.on_attribute("rotary_ndims", m_config.rotary_ndims);
visitor.on_attribute("gather_position_arg_id", m_config.gather_position_arg_id);
visitor.finish_structure();
return true;
}
@@ -1,98 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <ngraph/node.hpp>
#include <ngraph/op/op.hpp>

namespace ov {
namespace intel_cpu {

/**
 * The operation performs the rotary positional embedding described in:
 * ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING by Jianlin Su
 *
 * The core computation is the application of a 2x2 rotation matrix to a pair of
 * input states x[i0] & x[i1] to get the rotary-embedded pair of output
 * states y[i0] and y[i1]:
 *
 * Suppose the dimension of the hidden states (of each attention head) is N, d of which
 * are to be embedded (d <= N); the non-embedded part is copied into the output.
 *
 * for i in 0...(d/2)
 *   if (is_interleaved) {
 *     // interleaving style of indexing
 *     i0 = i*2
 *     i1 = i*2 + 1
 *   } else {
 *     // rotate-half style of indexing
 *     i0 = i
 *     i1 = i + (d/2)
 *   }
 *   y[i0] = x[i0]*cos(m * xita[i]) - x[i1]*sin(m * xita[i])
 *   y[i1] = x[i1]*cos(m * xita[i]) + x[i0]*sin(m * xita[i])
 * Note: m is the token position of the current input
 *
 * Based on the configuration, additional preprocessing steps may be performed as well:
 *  - slicing the last dimension of the input tensor
 *    (when q/k/v are merged and only the q or k part is to be extracted & embedded)
 *  - transposing the input tensor
 *    (when q/k comes from a FullyConnected layer with layout [batch, seq_len, head_cnt, head_dim]
 *     but the RoPE output is required to have layout [batch, head_cnt, seq_length, head_dims])
 *  - gathering sin/cos from input tensors 2 & 3 using the position index tensor passed through input 4
 *
 * Inputs:
 *  1. Input hidden states tensor of type T1 - shape:
 *     [batch, seq_length, head_cnt, head_dims] when input_trans0213 == false OR
 *     [batch, head_cnt, seq_length, head_dims] when input_trans0213 == true
 *  2. Pre-calculated cos(m*xita[n]) tensor of type T2 - shape [1, 1, max_position_embeddings, d].
 *  3. Pre-calculated sin(m*xita[n]) tensor of type T2 - shape [1, 1, max_position_embeddings, d].
 *     Inputs 2 and 3 are combined when is_interleaved is true.
 *  4. Position index tensor of type T3 (optional) - shape [batch, 1, seq_length, 1 or d] OR [batch, seq_length]
 * Outputs:
 *  1. New embedding tensor of type T1 and of shape [batch, head_cnt, seq_length, head_dims]
 * Types:
 *  T1 - FP32 or BF16
 *  T2 - FP32
 *  T3 - I32
 */
class RoPENode : public ngraph::op::Op {
public:
OPENVINO_OP("RoPE", "cpu_plugin_opset");

RoPENode() = default;

struct Config {
size_t slice_start = 0;  // slice inner-most dimensions of input
size_t slice_stop = 0;
bool input_trans0213 = false;  // transpose input dim 1&2
bool is_interleaved = false;   // interleaved mode, implies trans0213 happens after RoPE
size_t rotary_ndims = 0;       // dimensions to be embedded (d in the description)
int gather_position_arg_id = 0;  // arg id of position tensor, ==3 when gathering from sin/cos inputs according to position is required
};

RoPENode(const OutputVector& args, const Config& cfg);

bool visit_attributes(ngraph::AttributeVisitor& visitor) override;

void validate_and_infer_types() override;

std::shared_ptr<Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override;

const Config& get_config() const {
return m_config;
}

Config& get_config() {
return m_config;
}

private:
Config m_config;
};

} // namespace intel_cpu
} // namespace ov
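(Reading aid only, not part of the reverted files.) A minimal sketch, with hypothetical names, of the two indexing schemes described in the RoPENode documentation above; it applies the 2x2 rotation in place to the first rotary_ndims elements of one token's hidden state at position m and leaves the rest untouched:

#include <cmath>
#include <cstddef>
#include <vector>

// theta[i] plays the role of xita[i] from the documentation above
static void rope_reference(std::vector<float>& x,
                           const std::vector<float>& theta,
                           size_t rotary_ndims,
                           size_t m,
                           bool is_interleaved) {
    const size_t half = rotary_ndims / 2;
    for (size_t i = 0; i < half; i++) {
        const size_t i0 = is_interleaved ? 2 * i : i;            // interleaved vs rotate-half indexing
        const size_t i1 = is_interleaved ? 2 * i + 1 : i + half;
        const float c = std::cos(m * theta[i]);
        const float s = std::sin(m * theta[i]);
        const float x0 = x[i0];
        const float x1 = x[i1];
        x[i0] = x0 * c - x1 * s;
        x[i1] = x1 * c + x0 * s;
    }
    // elements in [rotary_ndims, x.size()) are not modified
}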
@ -1,435 +0,0 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "rope_fusion.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <openvino/core/rt_info.hpp>
|
||||
#include <openvino/opsets/opset1.hpp>
|
||||
#include <openvino/opsets/opset6.hpp>
|
||||
#include <openvino/opsets/opset8.hpp>
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include <openvino/pass/pattern/op/wrap_type.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ov_ops/type_relaxed.hpp"
|
||||
#include "transformations/cpu_opset/common/op/rope.hpp"
|
||||
#include "utils/gen_pattern.hpp"
|
||||
|
||||
using namespace ov::gen_pattern;
|
||||
|
||||
ov::intel_cpu::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX() {
|
||||
MATCHER_SCOPE(RoPEFusionGPTNEOX);
|
||||
|
||||
// rope pattern matching triggers a little design flaw:
|
||||
// y1 = mul(x, cos)
|
||||
// y2 = mul(x, sin)
|
||||
// y = add(y1, y2)
|
||||
// there is a chance that in 'y1' branch, pattern x is mapped to actual value of cos (mul is commutable)
|
||||
// this leads to the matching failure of 'y2' branch, because cos didn't appear in that
|
||||
// branch.
|
||||
// so here we use a WA, only match the path of rotate_hal(x)*sin and check the x*cos path
|
||||
// in the callback
|
||||
auto x = makePattern(ov::Rank(4));
|
||||
auto x_or_cos1 = makePattern(ov::Rank(4));
|
||||
auto x_or_cos2 = makePattern(ov::Rank(4));
|
||||
auto t_sin = makePattern(ov::Rank(4));
|
||||
|
||||
x->set_friendly_name("x");
|
||||
|
||||
auto half_ndims = Symbol("half_ndims");
|
||||
auto int32_max = std::numeric_limits<std::int32_t>::max();
|
||||
|
||||
// rotate half : [-x2, x1]
|
||||
auto x2 = GenSlice(x, half_ndims, int32_max, 1, 3);
|
||||
auto x2neg = makePattern<opset1::Multiply>({x2, -1.0f}, {{"auto_broadcast", "numpy"}});
|
||||
auto x1 = GenSlice(x, 0, half_ndims, 1, 3);
|
||||
auto x_rotate_half = makePattern<opset1::Concat>({x2neg, x1}, {{"axis", -1}});
|
||||
|
||||
auto mul_cos = makePattern<opset1::Multiply>({x_or_cos1, x_or_cos2}, {{"auto_broadcast", "numpy"}});
|
||||
auto mul_sin = makePattern<opset1::Multiply>({x_rotate_half, t_sin}, {{"auto_broadcast", "numpy"}});
|
||||
|
||||
// [x1, x2]*cos + [-x2, x1]*sin
|
||||
auto result = makePattern<opset1::Add>({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}});
|
||||
|
||||
matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
PatternValidator validator(m);
|
||||
if (!validator) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
auto root = m.get_match_root();
|
||||
|
||||
// check mul(x, cos) exists
|
||||
Output<Node> v_cos;
|
||||
if (pattern_map.at(x_or_cos1) == pattern_map.at(x)) {
|
||||
v_cos = pattern_map.at(x_or_cos2);
|
||||
} else if (pattern_map.at(x_or_cos2) == pattern_map.at(x)) {
|
||||
v_cos = pattern_map.at(x_or_cos1);
|
||||
} else {
|
||||
// not a RoPE
|
||||
return false;
|
||||
}
|
||||
|
||||
RoPENode::Config config;
|
||||
OutputVector new_args;
|
||||
config.rotary_ndims = 2 * validator["half_ndims"];
|
||||
|
||||
new_args.push_back(pattern_map.at(x));
|
||||
new_args.push_back(v_cos);
|
||||
new_args.push_back(pattern_map.at(t_sin));
|
||||
|
||||
auto old_node = root;
|
||||
auto new_node = std::make_shared<RoPENode>(new_args, config);
|
||||
new_node->set_friendly_name(old_node->get_friendly_name());
|
||||
ov::replace_node(old_node, new_node);
|
||||
|
||||
// this new node may match following additional matchers
|
||||
register_new_node(new_node);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(result, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ov::intel_cpu::RoPEFusionCosSinPreprocess::RoPEFusionCosSinPreprocess() {
|
||||
MATCHER_SCOPE(RoPEFusionCosSinPreprocess);
|
||||
|
||||
auto cos_const = makePattern<opset1::Constant>({}); // "f32[1,1,2048,24]"
|
||||
auto sin_const = makePattern<opset1::Constant>({}); // "f32[1,1,2048,24]"
|
||||
|
||||
auto node_batch_size = makePattern("i32[1]");
|
||||
auto tile_batch = makePattern("i32[1]");
|
||||
auto gather_positions = makePattern("i32[?,?,?,?]");
|
||||
|
||||
auto prepare_cos_sin_gptneox = [&](std::shared_ptr<Node> const_tab) {
|
||||
auto slice1 = makePattern<opset1::StridedSlice>({const_tab, {0}, node_batch_size, {1}},
|
||||
{{"begin_mask", {0}},
|
||||
{"end_mask", {0}},
|
||||
{"new_axis_mask", {}},
|
||||
{"shrink_axis_mask", {}},
|
||||
{"ellipsis_mask", {}}});
|
||||
return makePattern<opset6::GatherElements>({slice1, gather_positions}, {{"axis", 2}});
|
||||
};
|
||||
|
||||
auto seq_len = makePattern("i32[1]");
|
||||
auto gather_positions_2d = makePattern("i32[?,?]");
|
||||
|
||||
auto head_dims = Symbol("head_dims");
|
||||
auto prepare_cos_sin_llama = [&](std::shared_ptr<Node> const_tab) {
|
||||
auto ScatterUpdate = makePattern<opset3::ScatterUpdate>({{0, 0, 0}, 2, seq_len, 0});
|
||||
auto slice_Slice = makePattern<opset1::StridedSlice>({const_tab, {0, 0, 0}, ScatterUpdate, {1, 1, 1}},
|
||||
{{"begin_mask", {1, 1, 0}},
|
||||
{"end_mask", {1, 1, 0}},
|
||||
{"new_axis_mask", {}},
|
||||
{"shrink_axis_mask", {}},
|
||||
{"ellipsis_mask", {}}});
|
||||
auto squeeze = makePattern<opset1::Reshape>({slice_Slice, {-1, head_dims}});
|
||||
auto index_Gather = makePattern<opset8::Gather>({squeeze, gather_positions_2d, 0}, {{"batch_dims", 0}});
|
||||
auto unsqueeze = makePattern<opset1::Reshape>({index_Gather, {1, 1, -1, head_dims}});
|
||||
return unsqueeze;
|
||||
};
|
||||
|
||||
auto cos_tab = prepare_cos_sin_gptneox(cos_const) | prepare_cos_sin_llama(cos_const);
|
||||
auto sin_tab = prepare_cos_sin_gptneox(sin_const) | prepare_cos_sin_llama(sin_const);
|
||||
|
||||
auto x = makePattern(ov::Rank(4));
|
||||
auto rope = makePattern<RoPENode>({x, cos_tab, sin_tab});
|
||||
|
||||
matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
PatternValidator validator(m);
|
||||
if (!validator) {
|
||||
return false;
|
||||
}
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
auto root = m.get_match_root();
|
||||
auto rope_node = as_type_ptr<RoPENode>(pattern_map.at(rope).get_node_shared_ptr());
|
||||
if (!rope_node)
|
||||
return false;
|
||||
|
||||
if (pattern_map.count(cos_const)) {
|
||||
rope_node->set_argument(1, pattern_map.at(cos_const));
|
||||
}
|
||||
if (pattern_map.count(sin_const)) {
|
||||
rope_node->set_argument(2, pattern_map.at(sin_const));
|
||||
}
|
||||
|
||||
auto& config = rope_node->get_config();
|
||||
if (pattern_map.count(gather_positions)) {
|
||||
auto arg_id = rope_node->get_input_size();
|
||||
rope_node->set_argument(arg_id, pattern_map.at(gather_positions));
|
||||
config.gather_position_arg_id = arg_id;
|
||||
} else if (pattern_map.count(gather_positions_2d)) {
|
||||
auto arg_id = rope_node->get_input_size();
|
||||
rope_node->set_argument(arg_id, pattern_map.at(gather_positions_2d));
|
||||
config.gather_position_arg_id = arg_id;
|
||||
}
|
||||
rope_node->validate_and_infer_types();
|
||||
register_new_node(rope_node);
|
||||
return true;
|
||||
};
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(rope, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
// only a fraction of head_size is rotary-embedded
|
||||
ov::intel_cpu::RoPEFusionIOSlicing::RoPEFusionIOSlicing() {
|
||||
MATCHER_SCOPE(RoPEFusionIOSlicing);
|
||||
auto int32_max = std::numeric_limits<std::int32_t>::max();
|
||||
auto data = makePattern(ov::Rank(4));
|
||||
|
||||
auto ndims = Symbol("ndims");
|
||||
auto x = GenSlice(data, 0, ndims, 1, 3);
|
||||
auto y = GenSlice(data, ndims, int32_max, 1, 3);
|
||||
auto x_emb = makePattern<RoPENode>({x, {}, {}}) | makePattern<RoPENode>({x, {}, {}, {}});
|
||||
auto result = makePattern<opset1::Concat>({x_emb, y}, {{"axis", -1}});
|
||||
|
||||
matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
auto root = m.get_match_root();
|
||||
|
||||
auto rope_node = as_type_ptr<RoPENode>(root->input_value(0).get_node_shared_ptr());
|
||||
if (!rope_node)
|
||||
return false;
|
||||
|
||||
PatternValidator validator(m);
|
||||
if (!validator) {
|
||||
return false;
|
||||
}
|
||||
auto ndims = validator["ndims"];
|
||||
|
||||
auto& config = rope_node->get_config();
|
||||
if (config.rotary_ndims != ndims)
|
||||
return false;
|
||||
|
||||
// remove slice & concat
|
||||
rope_node->set_argument(0, pattern_map.at(data));
|
||||
rope_node->set_friendly_name(root->get_friendly_name());
|
||||
ov::replace_node(root, rope_node);
|
||||
|
||||
rope_node->validate_and_infer_types();
|
||||
register_new_node(rope_node);
|
||||
return true;
|
||||
};
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(result, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ov::intel_cpu::RoPEFusionPreprocess::RoPEFusionPreprocess() {
|
||||
MATCHER_SCOPE(RoPEFusionPreprocess);
|
||||
|
||||
// gptneox-preprocess of input data
|
||||
auto input_to_slice = makePattern(ov::Rank(4));
|
||||
auto input_to_trans = makePattern(ov::Rank(4)); // no need to slice from 3S
|
||||
|
||||
// in some model qkv prejection is combined and
|
||||
// needs to be sliced before RoPE
|
||||
auto slice_start = Symbol("slice_start");
|
||||
auto slice_stop = Symbol("slice_stop");
|
||||
auto input_slice = GenSlice(input_to_slice, slice_start, slice_stop, 1, 3);
|
||||
|
||||
// some model will transpose from [B,L,H,S] to [B,H,L,S] before RoPE
|
||||
auto x = makePattern<opset1::Transpose>({input_slice | input_to_trans, {0, 2, 1, 3}});
|
||||
auto result = makePattern<RoPENode>({x, {}, {}}) | makePattern<RoPENode>({x, {}, {}, {}});
|
||||
|
||||
matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
PatternValidator validator(m);
|
||||
if (!validator) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
auto root = m.get_match_root();
|
||||
auto rope_node = as_type_ptr<RoPENode>(root);
|
||||
if (!rope_node)
|
||||
return false;
|
||||
|
||||
auto& config = rope_node->get_config();
|
||||
|
||||
if (pattern_map.count(input_to_slice)) {
|
||||
config.slice_start = validator["slice_start"];
|
||||
config.slice_stop = validator["slice_stop"];
|
||||
config.input_trans0213 = true;
|
||||
rope_node->set_argument(0, pattern_map.at(input_to_slice));
|
||||
} else if (pattern_map.count(input_to_trans)) {
|
||||
config.input_trans0213 = true;
|
||||
rope_node->set_argument(0, pattern_map.at(input_to_trans));
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
rope_node->validate_and_infer_types();
|
||||
register_new_node(rope_node);
|
||||
return true;
|
||||
};
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(result, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
// remove stridedslice from 0 to int32_max with stride 1
|
||||
ov::intel_cpu::EliminateStridedSlice::EliminateStridedSlice() {
|
||||
MATCHER_SCOPE(EliminateStridedSlice);
|
||||
auto data = ov::pass::pattern::any_input(ngraph::pattern::has_static_rank());
|
||||
auto begin = ov::pass::pattern::wrap_type<opset1::Constant>(ngraph::pattern::type_matches(ov::element::i32));
|
||||
auto end = ov::pass::pattern::wrap_type<opset1::Constant>(ngraph::pattern::type_matches(ov::element::i32));
|
||||
auto stride = ov::pass::pattern::wrap_type<opset1::Constant>(ngraph::pattern::type_matches(ov::element::i32));
|
||||
|
||||
auto strided_slice =
|
||||
ov::pass::pattern::wrap_type<opset1::StridedSlice>({data, begin, end, stride}, [](const Output<Node>& value) {
|
||||
auto s1 = as_type_ptr<opset1::StridedSlice>(value.get_node_shared_ptr());
|
||||
if (!s1->get_new_axis_mask().empty() || !s1->get_shrink_axis_mask().empty() ||
|
||||
!s1->get_ellipsis_mask().empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto inputs = s1->input_values();
|
||||
|
||||
auto begin = as_type_ptr<opset1::Constant>(inputs[1].get_node_shared_ptr());
|
||||
auto end = as_type_ptr<opset1::Constant>(inputs[2].get_node_shared_ptr());
|
||||
// stride is all 1
|
||||
auto stride = as_type_ptr<opset1::Constant>(inputs[3].get_node_shared_ptr());
|
||||
|
||||
if (!begin)
|
||||
return false;
|
||||
if (!end)
|
||||
return false;
|
||||
if (!stride)
|
||||
return false;
|
||||
|
||||
auto v_stride = stride->cast_vector<int32_t>();
|
||||
for (auto& v : v_stride) {
|
||||
if (v != 1)
|
||||
return false;
|
||||
}
|
||||
|
||||
auto v_begin = begin->cast_vector<int32_t>();
|
||||
auto v_end = end->cast_vector<int32_t>();
|
||||
|
||||
auto& begin_mask = s1->get_begin_mask();
|
||||
auto& end_mask = s1->get_end_mask();
|
||||
auto mask_size = begin_mask.size();
|
||||
if (begin_mask.size() != end_mask.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto int32_max = std::numeric_limits<std::int32_t>::max();
|
||||
for (size_t i = 0; i < mask_size; i++) {
|
||||
if (begin_mask[i] != end_mask[i])
|
||||
return false;
|
||||
// all valid [begin, end] are [0, int32_max]
|
||||
if (begin_mask[i] == 0 && end_mask[i] == 0) {
|
||||
if (v_begin[i] != 0 || v_end[i] != int32_max)
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
auto root = m.get_match_root();
|
||||
return replace_output_update_name(root->output(0), root->input_value(0));
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(strided_slice, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {
|
||||
MATCHER_SCOPE(RoPEFusionGPTJ);
|
||||
|
||||
auto int32_max = std::numeric_limits<std::int32_t>::max();
|
||||
auto ndims = Symbol("ndims");
|
||||
|
||||
auto view_Reshape = makePattern(ov::Rank(4));
|
||||
|
||||
// view_Reshape : B,L,H,S
|
||||
auto slice_Slice_965 = GenSlice(view_Reshape, 0, ndims, 1, 3);
|
||||
|
||||
auto gather_sin_cos = makePattern("f32");
|
||||
|
||||
auto varsplit = makePattern<opset1::VariadicSplit>({gather_sin_cos, -1, {ndims / 2, -1}});
|
||||
varsplit->set_output_size(2);
|
||||
auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {1, -1, 1, 32}});
|
||||
auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {1, -1, 1, 32}});
|
||||
// repeate cos/sin table
|
||||
auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) {
|
||||
const auto& vec = node.get_vector<int32_t>();
|
||||
int32_t v = 0;
|
||||
for (size_t i = 0; i < vec.size(); i += 2, v++) {
|
||||
if (vec[i] != v || vec[i + 1] != v)
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
auto repeat_interleave_sin = makePattern<opset8::Gather>({unsqueeze_sin, const_idx, 3}, {{"batch_dims", 0}});
|
||||
auto repeat_interleave_cos = makePattern<opset8::Gather>({unsqueeze_cos, const_idx, 3}, {{"batch_dims", 0}});
|
||||
|
||||
auto t_cos = makePattern(ov::Rank(4));
|
||||
auto t_sin = makePattern(ov::Rank(4));
|
||||
|
||||
// x interleave (-x[:,:,:, 1::2], x[:,:,:, 0::2])
|
||||
auto slice_Slice_1174 = GenSlice(slice_Slice_965, 1, int32_max, 2, 3);
|
||||
|
||||
auto neg_Multiply_1177 = makePattern<opset1::Multiply>({slice_Slice_1174, -1.0f}, {{"auto_broadcast", "numpy"}});
|
||||
auto Unsqueeze_65524 = makePattern<opset1::Unsqueeze>({neg_Multiply_1177, -1});
|
||||
|
||||
auto slice_Slice_1168 = GenSlice(slice_Slice_965, 0, int32_max, 2, 3);
|
||||
auto Unsqueeze_65525 = makePattern<opset1::Unsqueeze>({slice_Slice_1168, -1});
|
||||
auto stack_1182 = makePattern<opset1::Concat>({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}});
|
||||
|
||||
auto ShapeOf_169068 = makePattern<opset1::ShapeOf>({stack_1182});
|
||||
auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0);
|
||||
auto flatten_Concat_1197 = makePattern<opset1::Concat>({flatten_Slice_1194, {-1}}, {{"axis", 0}});
|
||||
auto flatten_Reshape_1198 = makePattern<opset1::Reshape>({stack_1182, flatten_Concat_1197});
|
||||
|
||||
// x*cos [B,L,H,ndims]
|
||||
auto mul_cos =
|
||||
makePattern<opset1::Multiply>({slice_Slice_965, repeat_interleave_cos}, {{"auto_broadcast", "numpy"}});
|
||||
auto mul_sin =
|
||||
makePattern<opset1::Multiply>({flatten_Reshape_1198, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}});
|
||||
|
||||
// *cos + *sin
|
||||
auto rotary_emb = makePattern<opset1::Add>({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}});
|
||||
|
||||
auto slice_Slice_971 = GenSlice(view_Reshape, ndims, int32_max, 1, 3);
|
||||
auto cat_Concat_1211 = makePattern<opset1::Concat>({rotary_emb, slice_Slice_971}, {{"axis", -1}});
|
||||
auto permute_Transpose_1213 = makePattern<opset1::Transpose>({cat_Concat_1211, {0, 2, 1, 3}});
|
||||
|
||||
auto result = permute_Transpose_1213;
|
||||
|
||||
matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
auto root = m.get_match_root();
|
||||
PatternValidator validator(m);
|
||||
if (!validator) {
|
||||
return false;
|
||||
}
|
||||
|
||||
RoPENode::Config config;
|
||||
OutputVector new_args;
|
||||
config.rotary_ndims = validator["ndims"];
|
||||
|
||||
config.is_interleaved = true;
|
||||
|
||||
// input is [B,L,H,S]
|
||||
new_args.push_back(pattern_map.at(view_Reshape));
|
||||
// sin_cos table (gathered with positions) [1, L, 64]
|
||||
new_args.push_back(pattern_map.at(gather_sin_cos));
|
||||
new_args.push_back(pattern_map.at(gather_sin_cos));
|
||||
|
||||
auto old_node = root;
|
||||
|
||||
auto new_node = std::make_shared<RoPENode>(new_args, config);
|
||||
new_node->set_friendly_name(old_node->get_friendly_name());
|
||||
ov::replace_node(old_node, new_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(result, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
@@ -1,63 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <ngraph/pass/graph_rewrite.hpp>

namespace ov {
namespace intel_cpu {

class RoPEFusionGPTNEOX : public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("RoPEFusionGPTNEOX", "0");
RoPEFusionGPTNEOX();
};

class RoPEFusionGPTJ : public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("RoPEFusionGPTJ", "0");
RoPEFusionGPTJ();
};

class RoPEFusionIOSlicing : public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("RoPEFusionIOSlicing", "0");
RoPEFusionIOSlicing();
};

class RoPEFusionPreprocess : public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("RoPEFusionPreprocess", "0");
RoPEFusionPreprocess();
};

class RoPEFusionCosSinPreprocess : public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("RoPEFusionCosSinPreprocess", "0");
RoPEFusionCosSinPreprocess();
};

class EliminateStridedSlice : public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("EliminateStridedSlice", "0");
EliminateStridedSlice();
};

class RoPEFusion : public ngraph::pass::GraphRewrite {
public:
OPENVINO_RTTI("RoPEFusion", "0");
RoPEFusion() {
add_matcher<RoPEFusionGPTNEOX>();
add_matcher<RoPEFusionGPTJ>();
// optional heads & tails are fused in separate matcher pass,
// after RoPENode has been created.
add_matcher<RoPEFusionCosSinPreprocess>();
add_matcher<RoPEFusionIOSlicing>();
add_matcher<RoPEFusionPreprocess>();
}
};

} // namespace intel_cpu
} // namespace ov
@@ -83,7 +83,6 @@
#include "transformations/smart_reshape/matmul_sr.hpp"
#include "transformations/init_node_info.hpp"
#include "utils/ngraph_transformation.hpp"
#include "utils/print_model.hpp"

// LPT transformations
#include "low_precision/add.hpp"
@@ -111,7 +110,6 @@
#include "transformations/cpu_opset/common/pass/insert_convert_after_extension.hpp"
#include "transformations/cpu_opset/common/pass/move_eltwise_up_data_movement.hpp"
#include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp"
#include "transformations/cpu_opset/common/pass/rope_fusion.hpp"

// Snippets
#include "snippets/pass/tokenization.hpp"
@@ -656,10 +654,6 @@ void Transformations::PostLpt() {

// Execute before snippets. Otherwise FQ will be converted to Subgraph
CPU_REGISTER_PASS_X64(postLPTPassManager, ConvertFqRnnToQuantizedRnn);

CPU_REGISTER_PASS_X64(postLPTPassManager, EliminateStridedSlice);
CPU_REGISTER_PASS_X64(postLPTPassManager, RoPEFusion);

postLPTPassManager.run_passes(model);
}
File diff suppressed because it is too large
@ -1,415 +0,0 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cfloat>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "openvino/core/attribute_visitor.hpp"
|
||||
#include "openvino/core/model.hpp"
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// to_code convert value into literal/constexpr/initializer_list/factory_calls in C++ source code
|
||||
inline std::string to_code(bool value) {
|
||||
return value ? "true" : "false";
|
||||
}
|
||||
inline std::string to_code(const std::string& value) {
|
||||
return std::string("\"") + value + "\"";
|
||||
}
|
||||
inline std::string to_code(const element::Type& value) {
|
||||
return std::string("element::") + value.to_string();
|
||||
}
|
||||
inline std::string to_code(const ov::Shape& value) {
|
||||
std::stringstream ss;
|
||||
ss << "ov::Shape({";
|
||||
for (auto& d : value)
|
||||
ss << d << ",";
|
||||
ss << "})";
|
||||
return ss.str();
|
||||
}
|
||||
inline std::string to_code(int value) {
|
||||
if (INT_MAX == value) {
|
||||
return "INT_MAX";
|
||||
}
|
||||
if (INT_MIN == value) {
|
||||
return "INT_MIN";
|
||||
}
|
||||
return std::to_string(value);
|
||||
}
|
||||
inline std::string to_code(int64_t value) {
|
||||
if (LLONG_MAX == value) {
|
||||
return "LLONG_MAX";
|
||||
}
|
||||
if (LLONG_MIN == value) {
|
||||
return "LLONG_MIN";
|
||||
}
|
||||
const char* suffix = "LL";
|
||||
if (value == static_cast<int64_t>(static_cast<int>(value))) {
|
||||
// save suffix since most values can be expressed as int
|
||||
// this produces more readable code
|
||||
suffix = "";
|
||||
}
|
||||
return std::to_string(value) + suffix;
|
||||
}
|
||||
inline std::string to_code(uint64_t value) {
|
||||
if (ULLONG_MAX == value) {
|
||||
return "ULLONG_MAX";
|
||||
}
|
||||
const char* suffix = "uLL";
|
||||
if (value == static_cast<uint64_t>(static_cast<int>(value))) {
|
||||
// save suffix since most values can be expressed as int
|
||||
// this produces more readable code
|
||||
suffix = "";
|
||||
}
|
||||
return std::to_string(value) + suffix;
|
||||
}
|
||||
inline std::string to_code(int8_t value) {
|
||||
return std::to_string(static_cast<int>(value));
|
||||
}
|
||||
inline std::string to_code(uint8_t value) {
|
||||
return std::to_string(static_cast<int>(value));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::string to_code_float(T value) {
|
||||
if (std::isnan(value)) {
|
||||
return "NAN";
|
||||
} else if (std::isinf(value)) {
|
||||
return (value > 0 ? "INFINITY" : "-INFINITY");
|
||||
} else if (value == FLT_MIN) {
|
||||
return "FLT_MIN";
|
||||
} else if (value == -FLT_MIN) {
|
||||
return "-FLT_MIN";
|
||||
} else if (value == FLT_MAX) {
|
||||
return "FLT_MAX";
|
||||
} else if (value == -FLT_MAX) {
|
||||
return "-FLT_MAX";
|
||||
}
|
||||
auto strv = std::to_string(value);
|
||||
if (strv.find(".") == std::string::npos && strv.find("e") == std::string::npos)
|
||||
strv += ".0";
|
||||
if (std::is_same<T, float>::value)
|
||||
strv += "f";
|
||||
return strv;
|
||||
}
|
||||
|
||||
inline std::string to_code(float value) {
|
||||
return to_code_float(value);
|
||||
}
|
||||
inline std::string to_code(double value) {
|
||||
return to_code_float(value);
|
||||
}
|
||||
template <typename T>
|
||||
std::string to_code(const std::vector<T>& values, bool no_braces = false, int maxsize = 80) {
|
||||
std::stringstream ss;
|
||||
if (!no_braces)
|
||||
ss << "{";
|
||||
const char* sep = "";
|
||||
for (auto& v : values) {
|
||||
if (ss.tellp() > maxsize) {
|
||||
ss << "... (" << values.size() << " in total)";
|
||||
break;
|
||||
}
|
||||
ss << sep << to_code(v);
|
||||
sep = ",";
|
||||
}
|
||||
if (!no_braces)
|
||||
ss << "}";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
template <typename T = void>
|
||||
std::string to_code(std::shared_ptr<ov::op::v0::Constant> constop) {
|
||||
bool no_braces = (constop->get_shape().size() == 0);
|
||||
auto ele_type = constop->get_element_type();
|
||||
if (ele_type == element::Type_t::f32) {
|
||||
return to_code(constop->get_vector<float>(), no_braces);
|
||||
} else if (ele_type == element::Type_t::i8) {
|
||||
return to_code(constop->get_vector<int8_t>(), no_braces);
|
||||
} else if (ele_type == element::Type_t::u8) {
|
||||
return to_code(constop->get_vector<uint8_t>(), no_braces);
|
||||
} else if (ele_type == element::Type_t::i32) {
|
||||
return to_code(constop->get_vector<int32_t>(), no_braces);
|
||||
} else if (ele_type == element::Type_t::i64) {
|
||||
return to_code(constop->get_vector<int64_t>(), no_braces);
|
||||
}
|
||||
|
||||
// general case
|
||||
std::stringstream ss;
|
||||
if (!no_braces)
|
||||
ss << "{";
|
||||
auto ele_size = shape_size(constop->get_shape());
|
||||
if (ele_size < 9) {
|
||||
const char* sep = "";
|
||||
for (auto v : constop->get_value_strings()) {
|
||||
ss << sep << v;
|
||||
sep = ", ";
|
||||
}
|
||||
} else {
|
||||
ss << "...";
|
||||
}
|
||||
if (!no_braces)
|
||||
ss << "}";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
class OstreamAttributeVisitor : public ngraph::AttributeVisitor {
|
||||
std::ostream& os;
|
||||
const char* sep = "";
|
||||
|
||||
public:
|
||||
OstreamAttributeVisitor(std::ostream& os) : os(os) {}
|
||||
|
||||
void append_attribute(const std::string& name, const std::string& value) {
|
||||
os << sep << "{\"" << name << "\", " << value << "}";
|
||||
sep = ", ";
|
||||
}
|
||||
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override {
|
||||
if (auto a = ov::as_type<ov::AttributeAdapter<std::set<std::string>>>(&adapter)) {
|
||||
const auto& strset = a->get();
|
||||
std::vector<std::string> values(strset.begin(), strset.end());
|
||||
append_attribute(name, to_code(values));
|
||||
} else if (auto a = ov::as_type<ov::AttributeAdapter<std::vector<ov::element::Type>>>(&adapter)) {
|
||||
append_attribute(name, to_code(a->get()));
|
||||
} else if (auto a = ov::as_type<ov::AttributeAdapter<ov::PartialShape>>(&adapter)) {
|
||||
const auto& value = a->get();
|
||||
append_attribute(name, value.to_string());
|
||||
} else if (auto a = ov::as_type<ov::AttributeAdapter<std::shared_ptr<ov::op::util::Variable>>>(&adapter)) {
|
||||
const auto& vinfo = a->get()->get_info();
|
||||
std::stringstream ss;
|
||||
ss << vinfo.variable_id << vinfo.data_shape << vinfo.data_type;
|
||||
append_attribute(name, ss.str());
|
||||
} else {
|
||||
append_attribute(name, "?");
|
||||
}
|
||||
}
|
||||
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<bool>& adapter) override {
|
||||
append_attribute(name, to_code(adapter.get()));
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::string>& adapter) override {
|
||||
append_attribute(name, to_code(adapter.get()));
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<int64_t>& adapter) override {
|
||||
append_attribute(name, to_code(adapter.get()));
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<double>& adapter) override {
|
||||
append_attribute(name, to_code(adapter.get()));
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<int32_t>& adapter) override {
|
||||
append_attribute(name, to_code(adapter.get()));
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<float>& adapter) override {
|
||||
append_attribute(name, to_code(adapter.get()));
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<int>>& adapter) override {
|
||||
append_attribute(name, to_code(adapter.get()));
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override {
|
||||
append_attribute(name, to_code(adapter.get()));
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override {
|
||||
append_attribute(name, to_code(adapter.get()));
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<float>>& adapter) override {
|
||||
append_attribute(name, to_code(adapter.get()));
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<std::string>>& adapter) override {
|
||||
append_attribute(name, to_code(adapter.get()));
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::shared_ptr<ov::Model>>& adapter) override {
|
||||
append_attribute(name, "Model");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UNUSED_T = void>
|
||||
void dump_cpp_style(std::ostream& os, const std::shared_ptr<ov::Model>& model) {
|
||||
const ov::Model& f = *model;
|
||||
std::string prefix = "";
|
||||
std::string tag = "";
|
||||
std::string sep = "";
|
||||
os << prefix;
|
||||
for (auto op : f.get_results()) {
|
||||
os << sep << op->get_name();
|
||||
sep = ",";
|
||||
}
|
||||
os << " " << f.get_friendly_name() << "(\n" << prefix;
|
||||
for (auto op : f.get_parameters()) {
|
||||
os << " " << tag << op->get_friendly_name() << ",\n" << prefix;
|
||||
}
|
||||
os << ") {\n";
|
||||
|
||||
// collect all scalar & short 1D vectors for literal-style display
|
||||
std::map<std::shared_ptr<ov::Node>, std::string> literal_consts;
|
||||
for (auto op : f.get_ordered_ops()) {
|
||||
if (auto constop = std::dynamic_pointer_cast<op::v0::Constant>(op)) {
|
||||
// only i32/f32 type const literal can be parsed by C++ compiler
|
||||
if (constop->get_output_element_type(0) != ov::element::i32 &&
|
||||
constop->get_output_element_type(0) != ov::element::i64 &&
|
||||
constop->get_output_element_type(0) != ov::element::f32)
|
||||
continue;
|
||||
auto shape = constop->get_shape();
|
||||
if (shape.size() > 1)
|
||||
continue;
|
||||
if (shape_size(constop->get_shape()) > 64)
|
||||
continue;
|
||||
literal_consts[op] = to_code(constop);
|
||||
}
|
||||
}
|
||||
|
||||
auto get_output_values_info = [](std::shared_ptr<ov::Node>& op) {
|
||||
std::stringstream ss;
|
||||
const char* sep = "";
|
||||
for (size_t i = 0; i < op->get_output_size(); i++) {
|
||||
ss << sep << op->get_output_element_type(i) << op->get_output_partial_shape(i);
|
||||
sep = " ";
|
||||
}
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
// change name convension
|
||||
std::map<ov::Node*, std::string> opname;
|
||||
std::map<std::string, int> opname_count;
|
||||
for (auto op : f.get_ordered_ops()) {
|
||||
auto name = op->get_friendly_name();
|
||||
std::replace(name.begin(), name.end(), '\\', '_');
|
||||
std::replace(name.begin(), name.end(), '/', '_');
|
||||
std::replace(name.begin(), name.end(), '.', '_');
|
||||
std::replace(name.begin(), name.end(), '[', '_');
|
||||
std::replace(name.begin(), name.end(), ']', '_');
|
||||
std::replace(name.begin(), name.end(), '-', 'n');
|
||||
if (name[0] >= '0' && name[0] <= '9') {
|
||||
const auto& type_info = op->get_type_info();
|
||||
name.insert(0, type_info.name);
|
||||
}
|
||||
int idx = 0;
|
||||
if (opname_count.count(name)) {
|
||||
idx = opname_count[name] + 1;
|
||||
}
|
||||
opname_count[name] = idx;
|
||||
|
||||
if (idx)
|
||||
name += std::to_string(idx);
|
||||
|
||||
opname[op.get()] = name;
|
||||
}
|
||||
|
||||
for (auto op : f.get_ordered_ops()) {
|
||||
if (literal_consts.count(op))
|
||||
continue;
|
||||
|
||||
const auto& type_info = op->get_type_info();
|
||||
auto version_info = std::string(type_info.get_version());
|
||||
auto type = version_info + "::" + type_info.name;
|
||||
auto name = opname[op.get()];
|
||||
os << prefix << " ";
|
||||
|
||||
if (auto constop = std::dynamic_pointer_cast<op::v0::Constant>(op)) {
|
||||
os << "auto " << name << " = makeConst(" << to_code(op->get_output_element_type(0)) << ", "
|
||||
<< to_code(op->get_output_shape(0)) << ", " << to_code(constop) << ");" << std::endl;
|
||||
} else {
|
||||
os << "auto " << name << " = makeOP<" << type << ">({";
|
||||
// input args
|
||||
sep = "";
|
||||
for (size_t i = 0; i < op->get_input_size(); i++) {
|
||||
auto vout = op->get_input_source_output(i);
|
||||
auto iop = vout.get_node_shared_ptr();
|
||||
if (iop->get_output_size() > 1) {
|
||||
auto out_port = vout.get_index();
|
||||
os << sep << tag << opname[iop.get()] << "->output(" << out_port << ")";
|
||||
} else {
|
||||
if (literal_consts.count(iop))
|
||||
os << sep << tag << literal_consts[iop];
|
||||
else
|
||||
os << sep << tag << opname[iop.get()];
|
||||
}
|
||||
sep = ", ";
|
||||
}
|
||||
os << "}";
|
||||
|
||||
// attributes as AnyMap
|
||||
std::stringstream ss2;
|
||||
OstreamAttributeVisitor osvis(ss2);
|
||||
op->visit_attributes(osvis);
|
||||
auto str_attr = ss2.str();
|
||||
if (str_attr.size())
|
||||
os << ", {" << str_attr << "}";
|
||||
os << "); // tensor_array<" << get_output_values_info(op) << "> " << op->get_friendly_name();
|
||||
|
||||
os << "(";
|
||||
sep = "";
|
||||
for (size_t i = 0; i < op->get_input_size(); i++) {
|
||||
auto vout = op->get_input_source_output(i);
|
||||
auto iop = vout.get_node_shared_ptr();
|
||||
os << sep << tag << iop->get_friendly_name();
|
||||
if (iop->get_output_size() > 1) {
|
||||
auto out_port = vout.get_index();
|
||||
os << "[" << out_port << "]";
|
||||
}
|
||||
sep = ", ";
|
||||
}
|
||||
os << ")" << std::endl;
|
||||
}
|
||||
|
||||
// recursively output subgraphs
|
||||
if (auto msubgraph = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(op)) {
|
||||
auto cnt = msubgraph->get_internal_subgraphs_size();
|
||||
for (size_t i = 0; i < cnt; i++) {
|
||||
os << " MultiSubGraphOp " << tag << msubgraph->get_friendly_name() << "[" << i << "]" << std::endl;
|
||||
dump_cpp_style(os, msubgraph->get_function(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
os << prefix << "}\n";
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
class OPENVINO_API PrintModel : public ov::pass::ModelPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::PrintModel");
|
||||
|
||||
PrintModel(std::string file_name) {
|
||||
static int dump_index = 0;
|
||||
m_file_name = std::string("modelprint_") + std::to_string(dump_index) + "_" + file_name;
|
||||
dump_index++;
|
||||
}
|
||||
~PrintModel() {}
|
||||
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& model) override {
|
||||
if (m_file_name.empty())
|
||||
return false;
|
||||
|
||||
std::ofstream ofs(m_file_name);
|
||||
if (!ofs) {
|
||||
// OPENVINO_WARN << "Error opening file " << m_file_name << " for output" << std::endl;
|
||||
return false;
|
||||
}
|
||||
detail::dump_cpp_style(ofs, model);
|
||||
ofs.close();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
std::string m_file_name;
|
||||
};
|
||||
} // namespace pass
|
||||
} // namespace ov
|
@@ -1,184 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <debug.h>

#include <common_test_utils/ov_tensor_utils.hpp>
#include <openvino/opsets/opset1.hpp>
#include <openvino/opsets/opset8.hpp>
#include <ov_models/builders.hpp>
#include <shared_test_classes/base/ov_subgraph.hpp>
#include <string>
#include <tuple>

#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/skip_tests_config.hpp"
#include "ie_precision.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "test_utils/fusing_test_utils.hpp"
#include "utils/gen_pattern.hpp"

using namespace CPUTestUtils;
using namespace ov::gen_pattern;
using namespace ov::test;
using namespace ov;

static ov::OutputVector makeCosSinCache(int max_position_embeddings, int rotary_ndims) {
std::vector<float> lut_sin(max_position_embeddings * rotary_ndims, 0.0f);
std::vector<float> lut_cos(max_position_embeddings * rotary_ndims, 0.0f);

// rotate_half style cos/sin table:
// y1 = cos(m*xita_i) * x1 - sin(m*xita_i) * x2
// y2 = cos(m*xita_i) * x2 + sin(m*xita_i) * x1
//
for (int i = 0, k = 0; i < rotary_ndims; i += 2, k++) {
auto xita_i = 1.0 / std::pow(10000.0, static_cast<double>(i) / rotary_ndims);
float* psin = lut_sin.data();
float* pcos = lut_cos.data();
for (int m = 0; m < max_position_embeddings; m++, psin += rotary_ndims, pcos += rotary_ndims) {
auto vsin = std::sin(xita_i * m);
auto vcos = std::cos(xita_i * m);
pcos[k] = pcos[k + rotary_ndims / 2] = vcos;
psin[k] = psin[k + rotary_ndims / 2] = vsin;
}
}
auto shape = ov::Shape({1, 1, static_cast<size_t>(max_position_embeddings), static_cast<size_t>(rotary_ndims)});
auto Cos = makeConst(ov::element::f32, shape, lut_cos);
auto Sin = makeConst(ov::element::f32, shape, lut_sin);
return {Cos, Sin};
}

static std::shared_ptr<ov::Model> buildROPE_Llama2(const int batch,
                                                   const int seq_length,
                                                   const int max_position_embeddings,
                                                   const int num_head,
                                                   const int ndims) {
    auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{batch, -1, num_head, ndims});
    auto pos_id_end = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{});
    auto pos_ids = std::make_shared<ov::opset1::Parameter>(ov::element::i32, PartialShape{1, -1});

    auto cos_sin_cache = makeCosSinCache(max_position_embeddings, ndims);
    auto Constant582 = cos_sin_cache[0];
    auto Constant585 = cos_sin_cache[1];

    // concat KV length
    auto transpose_Transpose = makeOP<opset1::Transpose>({input, {0, 2, 1, 3}});
    auto slice_Unsqueeze_426 = makeOP<opset1::Unsqueeze>({pos_id_end, 0});
    auto ScatterUpdate_152236 = makeOP<opset3::ScatterUpdate>({{0, 0, 0}, {2}, slice_Unsqueeze_426, {0}});
    auto slice_Slice = makeOP<opset1::StridedSlice>({Constant582, {0, 0, 0}, ScatterUpdate_152236, {1, 1, 1}},
                                                    {{"begin_mask", {1, 1, 0}},
                                                     {"end_mask", {1, 1, 0}},
                                                     {"new_axis_mask", {}},
                                                     {"shrink_axis_mask", {}},
                                                     {"ellipsis_mask", {}}});
    auto squeeze_Squeeze = makeOP<opset1::Squeeze>({slice_Slice, 1});
    auto squeeze_Squeeze_435 = makeOP<opset1::Squeeze>({squeeze_Squeeze, 0});
    auto index_441_Gather = makeOP<opset8::Gather>({squeeze_Squeeze_435, pos_ids, 0}, {{"batch_dims", 0}});
    auto unsqueeze_Unsqueeze = makeOP<opset1::Unsqueeze>({index_441_Gather, 1});
    auto mul_Multiply =
        makeOP<opset1::Multiply>({transpose_Transpose, unsqueeze_Unsqueeze}, {{"auto_broadcast", "numpy"}});
    auto size_ShapeOf_448 = makeOP<opset3::ShapeOf>({transpose_Transpose}, {{"output_type", "i32"}});
    auto size_Gather_450 = makeOP<opset8::Gather>({size_ShapeOf_448, 3, 0}, {{"batch_dims", 0}});
    auto floor_divide_Divide =
        makeOP<opset1::Divide>({size_Gather_450, 2}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}});
    auto floor_divide_Floor = makeOP<opset1::Floor>({floor_divide_Divide});
    auto slice_Unsqueeze_452 = makeOP<opset1::Unsqueeze>({floor_divide_Floor, 0});
    auto ScatterUpdate_152312 = makeOP<opset3::ScatterUpdate>({{0, 0, 0, 0}, {3}, slice_Unsqueeze_452, {0}});
    auto slice_Slice_459 = makeOP<opset1::StridedSlice>(
        {transpose_Transpose, ScatterUpdate_152312, {0ll, 0ll, 0ll, LLONG_MAX}, {1, 1, 1, 1}},
        {{"begin_mask", {1, 1, 1, 0}},
         {"end_mask", {1, 1, 1, 0}},
         {"new_axis_mask", {}},
         {"shrink_axis_mask", {}},
         {"ellipsis_mask", {}}});
    auto Constant_182988 = makeConst(element::f32,
                                     ov::Shape({
                                         1,
                                         1,
                                         1,
                                         1,
                                     }),
                                     {-1.000000f});
    auto neg_Multiply = makeOP<opset1::Multiply>({slice_Slice_459, Constant_182988}, {{"auto_broadcast", "numpy"}});
    auto ScatterUpdate_152368 = makeOP<opset3::ScatterUpdate>({{0, 0, 0, 0}, {3}, slice_Unsqueeze_452, {0}});
    auto slice_Slice2 =
        makeOP<opset1::StridedSlice>({transpose_Transpose, {0, 0, 0, 0}, ScatterUpdate_152368, {1, 1, 1, 1}},
                                     {{"begin_mask", {1, 1, 1, 0}},
                                      {"end_mask", {1, 1, 1, 0}},
                                      {"new_axis_mask", {}},
                                      {"shrink_axis_mask", {}},
                                      {"ellipsis_mask", {}}});
    auto cat_Concat = makeOP<opset1::Concat>({neg_Multiply, slice_Slice2}, {{"axis", -1}});
    auto ScatterUpdate_152421 = makeOP<opset3::ScatterUpdate>({{0, 0, 0}, {2}, slice_Unsqueeze_426, {0}});
    auto slice_Slice_433 = makeOP<opset1::StridedSlice>({Constant585, {0, 0, 0}, ScatterUpdate_152421, {1, 1, 1}},
                                                        {{"begin_mask", {1, 1, 0}},
                                                         {"end_mask", {1, 1, 0}},
                                                         {"new_axis_mask", {}},
                                                         {"shrink_axis_mask", {}},
                                                         {"ellipsis_mask", {}}});
    auto squeeze_Squeeze_436 = makeOP<opset1::Squeeze>({slice_Slice_433, 1});
    auto squeeze_Squeeze_437 = makeOP<opset1::Squeeze>({squeeze_Squeeze_436, 0});
    auto index_446_Gather = makeOP<opset8::Gather>({squeeze_Squeeze_437, pos_ids, 0}, {{"batch_dims", 0}});
    auto unsqueeze_Unsqueeze_447 = makeOP<opset1::Unsqueeze>({index_446_Gather, 1});
    auto mul_Multiply_463 =
        makeOP<opset1::Multiply>({cat_Concat, unsqueeze_Unsqueeze_447}, {{"auto_broadcast", "numpy"}});
    auto add_Add = makeOP<opset1::Add>({mul_Multiply, mul_Multiply_463}, {{"auto_broadcast", "numpy"}});

    return std::make_shared<ov::Model>(ov::NodeVector{add_Add}, ov::ParameterVector{input, pos_id_end, pos_ids});
}

namespace CPULayerTestsDefinitions {

class RoPECPUTest : public SubgraphBaseTest {
public:
    ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1) {
        auto tensor = ov::Tensor(ov::element::i32, shape);
        auto* ptr = static_cast<int32_t*>(tensor.data());
        for (size_t i = 0; i < tensor.get_size(); i++) {
            ptr[i] = start;
            start += step;
        }
        return tensor;
    }

    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
        const auto& funcInputs = function->inputs();

        const int position_id_start = 15;
        auto& input_shape = targetInputStaticShapes[0];
        auto seq_length = input_shape[1];

        ov::Tensor t_input =
            utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768);
        ov::Tensor t_position_id_end = create_i32_tensor(ov::Shape({}), position_id_start + seq_length);
        ov::Tensor t_position_ids = create_i32_tensor(ov::Shape({1, seq_length}), position_id_start);

        inputs.clear();
        inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input});
        inputs.insert({funcInputs[1].get_node_shared_ptr(), t_position_id_end});
        inputs.insert({funcInputs[2].get_node_shared_ptr(), t_position_ids});
    }

protected:
    void SetUp() override {
        targetDevice = ov::test::utils::DEVICE_CPU;

        const int batch = 2;
        const int seq_length = 7;
        const size_t max_position_embeddings = 2048;
        const size_t ndims = 128;
        const size_t num_head = 32;

        InputShape inpShape = {{batch, seq_length, num_head, ndims}, {{batch, seq_length, num_head, ndims}}};
        init_input_shapes({inpShape});
        function = buildROPE_Llama2(batch, seq_length, max_position_embeddings, num_head, ndims);
    }
};

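// Smoke test: infers the llama2-style RoPE subgraph on CPU and compares the
// outputs against the reference implementation.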
TEST_F(RoPECPUTest, smoke_CompareWithRefs) {
    run();
}

} // namespace CPULayerTestsDefinitions
@ -1,452 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <climits>
#include <cmath>
#include <openvino/core/model.hpp>
#include <openvino/opsets/opset1.hpp>
#include <openvino/opsets/opset3.hpp>
#include <openvino/opsets/opset6.hpp>
#include <openvino/opsets/opset8.hpp>
#include <openvino/pass/manager.hpp>
#include <ov_ops/type_relaxed.hpp>
#include <transformations/cpu_opset/common/op/rope.hpp>
#include <transformations/cpu_opset/common/pass/rope_fusion.hpp>

#include "common_test_utils/ov_test_utils.hpp"
#include "utils/gen_pattern.hpp"
#include "utils/print_model.hpp"

using namespace testing;
using namespace ov::intel_cpu;
using namespace ov::gen_pattern;

static ov::OutputVector makeCosSinCache(size_t max_position_embeddings, size_t rotary_ndims) {
    std::vector<float> lut_sin(max_position_embeddings * rotary_ndims, 0.0f);
    std::vector<float> lut_cos(max_position_embeddings * rotary_ndims, 0.0f);

    // rotate_half style cos/sin table:
    //   y1 = cos(m*xita_i) * x1 - sin(m*xita_i) * x2
    //   y2 = cos(m*xita_i) * x2 + sin(m*xita_i) * x1
    //
    for (size_t i = 0, k = 0; i < rotary_ndims; i += 2, k++) {
        auto xita_i = 1.0 / std::pow(10000.0, static_cast<double>(i) / rotary_ndims);
        float* psin = lut_sin.data();
        float* pcos = lut_cos.data();
        for (size_t m = 0; m < max_position_embeddings; m++, psin += rotary_ndims, pcos += rotary_ndims) {
            auto vsin = std::sin(xita_i * m);
            auto vcos = std::cos(xita_i * m);
            pcos[k] = pcos[k + rotary_ndims / 2] = vcos;
            psin[k] = psin[k + rotary_ndims / 2] = vsin;
        }
    }
    auto Cos = makeConst(ov::element::f32, ov::Shape({1, 1, max_position_embeddings, rotary_ndims}), lut_cos);
    auto Sin = makeConst(ov::element::f32, ov::Shape({1, 1, max_position_embeddings, rotary_ndims}), lut_sin);

    return {Cos, Sin};
}

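// The sin_cos_preprocessing flag selects how cos/sin reach the subgraph: when true they are
// produced from the constant LUT above via StridedSlice + Gather on the position ids (so the
// fusion pass has to absorb that preprocessing); when false they are fed in directly as
// Parameters.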
static std::shared_ptr<ov::Model> buildROPE_Llama2(const size_t batch,
                                                   const size_t seq_length,
                                                   const size_t max_position_embeddings,
                                                   const size_t ndims,
                                                   bool sin_cos_preprocessing) {
    auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{batch, seq_length, 32, ndims});
    auto param_cos = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 1, seq_length, ndims});
    auto param_sin = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 1, seq_length, ndims});

    auto seq_len = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{1});
    auto gather_id = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{1, seq_length});

    auto gather_from_sin_cos = [&](const ov::Output<ov::Node>& const_tab) {
        auto ScatterUpdate_152236 = makeOP<ov::opset3::ScatterUpdate>({{0, 0, 0}, {2}, seq_len, {0}});
        auto slice_Slice = makeOP<ov::opset1::StridedSlice>({const_tab, {0, 0, 0}, ScatterUpdate_152236, {1, 1, 1}},
                                                            {{"begin_mask", {1, 1, 0}},
                                                             {"end_mask", {1, 1, 0}},
                                                             {"new_axis_mask", {}},
                                                             {"shrink_axis_mask", {}},
                                                             {"ellipsis_mask", {}}});
        auto squeeze_Squeeze_435 =
            makeOP<ov::opset1::Reshape>({slice_Slice, {-1, static_cast<int>(ndims)}}, {{"special_zero", false}});
        auto index_441_Gather = makeOP<ov::opset8::Gather>({squeeze_Squeeze_435, gather_id, {0}}, {{"batch_dims", 0}});
        return makeOP<ov::opset1::Reshape>({index_441_Gather, {1, 1, -1, static_cast<int>(ndims)}},
                                           {{"special_zero", false}});
    };

    ov::OutputVector cos_sin(2);
    ov::ParameterVector parameters;
    if (sin_cos_preprocessing) {
        auto cos_sin_cache = makeCosSinCache(max_position_embeddings, ndims);
        cos_sin[0] = gather_from_sin_cos(cos_sin_cache[0]);
        cos_sin[1] = gather_from_sin_cos(cos_sin_cache[1]);
        parameters = ov::ParameterVector{input, seq_len, gather_id};
    } else {
        cos_sin[0] = param_cos;
        cos_sin[1] = param_sin;
        parameters = ov::ParameterVector{input, param_cos, param_sin};
    }

    auto transpose_Transpose = makeOP<ov::opset1::Transpose>({input, {0, 2, 1, 3}});
    auto mul_Multiply = makeOP<ov::opset1::Multiply>({transpose_Transpose, cos_sin[0]}, {{"auto_broadcast", "numpy"}});
    auto slice_Slice_459 =
        makeOP<ov::opset1::StridedSlice>({transpose_Transpose, {0, 0, 0, 64}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
                                         {{"begin_mask", {1, 1, 1, 0}},
                                          {"end_mask", {1, 1, 1, 0}},
                                          {"new_axis_mask", {}},
                                          {"shrink_axis_mask", {}},
                                          {"ellipsis_mask", {}}});
    auto Constant_182988 = makeConst(ov::element::f32,
                                     ov::Shape({
                                         1,
                                         1,
                                         1,
                                         1,
                                     }),
                                     {-1.000000f});
    auto neg_Multiply = makeOP<ov::opset1::Multiply>({slice_Slice_459, Constant_182988}, {{"auto_broadcast", "numpy"}});
    auto slice_Slice =
        makeOP<ov::opset1::StridedSlice>({transpose_Transpose, {0, 0, 0, 0}, {0, 0, 0, 64}, {1, 1, 1, 1}},
                                         {{"begin_mask", {1, 1, 1, 0}},
                                          {"end_mask", {1, 1, 1, 0}},
                                          {"new_axis_mask", {}},
                                          {"shrink_axis_mask", {}},
                                          {"ellipsis_mask", {}}});
    auto cat_Concat = makeOP<ov::opset1::Concat>({neg_Multiply, slice_Slice}, {{"axis", -1}});
    auto mul_Multiply_463 = makeOP<ov::opset1::Multiply>({cat_Concat, cos_sin[1]}, {{"auto_broadcast", "numpy"}});
    auto add_Add = makeOP<ov::opset1::Add>({mul_Multiply, mul_Multiply_463}, {{"auto_broadcast", "numpy"}});

    return std::make_shared<ov::Model>(ov::NodeVector{add_Add}, parameters);
}

TEST_F(TransformationTestsF, ConvertToROPE_LLama2_no_gather) {
    disable_rt_info_check();
    const int batch = 2;
    const int seq_length = 16;
    const size_t max_position_embeddings = 2048;
    const size_t ndims = 128;
    const size_t num_head = 32;

    model = buildROPE_Llama2(batch, seq_length, max_position_embeddings, ndims, false);
    manager.register_pass<RoPEFusion>();

    {
        auto hidden_states =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{batch, seq_length, num_head, ndims});
        auto param_cos = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 1, seq_length, ndims});
        auto param_sin = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 1, seq_length, ndims});
        auto add_Add = makeOP<RoPENode>({hidden_states, param_cos, param_sin},
                                        {{"config.slice_start", 0},
                                         {"config.slice_stop", 0},
                                         {"config.input_trans0213", true},
                                         {"config.is_interleaved", false},
                                         {"config.rotary_ndims", static_cast<int>(ndims)},
                                         {"config.gather_position_arg_id", 0}});

        model_ref = std::make_shared<ov::Model>(ov::NodeVector{add_Add},
                                                ov::ParameterVector{hidden_states, param_cos, param_sin});
    }
}

TEST_F(TransformationTestsF, ConvertToROPE_LLama2_with_gather) {
    disable_rt_info_check();
    const int batch = 2;
    const int seq_length = 16;
    const size_t max_position_embeddings = 2048;
    const size_t ndims = 128;
    const size_t num_head = 32;

    model = buildROPE_Llama2(batch, seq_length, max_position_embeddings, ndims, true);
    manager.register_pass<RoPEFusion>();

    {
        auto hidden_states =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{batch, seq_length, num_head, ndims});
        auto seq_len = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{1});
        auto gather_id = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{1, seq_length});
        auto cos_sin_cache = makeCosSinCache(max_position_embeddings, ndims);

        auto add_Add = makeOP<RoPENode>({hidden_states, cos_sin_cache[0], cos_sin_cache[1], gather_id},
                                        {{"config.slice_start", 0},
                                         {"config.slice_stop", 0},
                                         {"config.input_trans0213", true},
                                         {"config.is_interleaved", false},
                                         {"config.rotary_ndims", static_cast<int>(ndims)},
                                         {"config.gather_position_arg_id", 3}});

        model_ref = std::make_shared<ov::Model>(ov::NodeVector{add_Add},
                                                ov::ParameterVector{hidden_states, seq_len, gather_id});
    }
}

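// GPT-NeoX packs q/k/v along the last dimension (ndims * 3) and only rotates the first
// rotary_ndims channels of the q slice; the remaining channels pass through and are
// concatenated back, which is why the expected RoPENode below uses config.slice_stop = ndims
// together with a smaller config.rotary_ndims.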
static std::shared_ptr<ov::Model> buildROPE_GPTNEOX(const int batch,
                                                    const int seq_length,
                                                    const int max_position_embeddings,
                                                    const int ndims,
                                                    const int num_heads,
                                                    const int rotary_ndims,
                                                    bool sin_cos_preprocessing) {
    auto batch_s = static_cast<size_t>(batch);
    auto seq_length_s = static_cast<size_t>(seq_length);
    auto ndims_s = static_cast<size_t>(ndims);
    auto rotary_ndims_s = static_cast<size_t>(rotary_ndims);
    auto num_heads_s = static_cast<size_t>(num_heads);

    auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32,
                                                         ov::Shape{batch_s, seq_length_s, num_heads_s, ndims_s * 3});
    auto seq_len = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{1});
    auto gather_idx =
        std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{1, 1, seq_length_s, rotary_ndims_s});
    auto batch_limit = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{1});

    ov::ParameterVector parameters;
    ov::OutputVector cos_sin(2);
    if (sin_cos_preprocessing) {
        auto cos_sin_lut = makeCosSinCache(max_position_embeddings, rotary_ndims);
        auto ro_slice_Slice = makeOP<ov::opset1::StridedSlice>({cos_sin_lut[0], {0}, batch_limit, {1}},
                                                               {{"begin_mask", {0}},
                                                                {"end_mask", {0}},
                                                                {"new_axis_mask", {}},
                                                                {"shrink_axis_mask", {}},
                                                                {"ellipsis_mask", {}}});
        cos_sin[0] = makeOP<ov::opset6::GatherElements>({ro_slice_Slice, gather_idx}, {{"axis", 2}});

        auto ro_slice_Slice_385 = makeOP<ov::opset1::StridedSlice>({cos_sin_lut[1], {0}, batch_limit, {1}},
                                                                   {{"begin_mask", {0}},
                                                                    {"end_mask", {0}},
                                                                    {"new_axis_mask", {}},
                                                                    {"shrink_axis_mask", {}},
                                                                    {"ellipsis_mask", {}}});
        cos_sin[1] = makeOP<ov::opset6::GatherElements>({ro_slice_Slice_385, gather_idx}, {{"axis", 2}});
        parameters = ov::ParameterVector{input, gather_idx, batch_limit};
    } else {
        auto param_cos =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 1, seq_length_s, rotary_ndims_s});
        auto param_sin =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 1, seq_length_s, rotary_ndims_s});
        parameters = ov::ParameterVector{input, param_cos, param_sin};
        cos_sin[0] = param_cos;
        cos_sin[1] = param_sin;
    }

    auto slice_Slice = makeOP<ov::opset1::StridedSlice>({input, {0, 0, 0, 0}, {0, 0, 0, ndims}, {1, 1, 1, 1}},
                                                        {{"begin_mask", {1, 1, 1, 0}},
                                                         {"end_mask", {1, 1, 1, 0}},
                                                         {"new_axis_mask", {}},
                                                         {"shrink_axis_mask", {}},
                                                         {"ellipsis_mask", {}}});
    auto permute_Transpose = makeOP<ov::opset1::Transpose>({slice_Slice, {0, 2, 1, 3}});
    auto slice_Slice_351 =
        makeOP<ov::opset1::StridedSlice>({permute_Transpose, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims}, {1, 1, 1, 1}},
                                         {{"begin_mask", {1, 1, 1, 0}},
                                          {"end_mask", {1, 1, 1, 0}},
                                          {"new_axis_mask", {}},
                                          {"shrink_axis_mask", {}},
                                          {"ellipsis_mask", {}}});
    auto mul_Multiply = makeOP<ov::opset1::Multiply>({slice_Slice_351, cos_sin[0]}, {{"auto_broadcast", "numpy"}});
    auto slice_Slice_420 = makeOP<ov::opset1::StridedSlice>(
        {slice_Slice_351, {0, 0, 0, rotary_ndims / 2}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
        {{"begin_mask", {1, 1, 1, 0}},
         {"end_mask", {1, 1, 1, 0}},
         {"new_axis_mask", {}},
         {"shrink_axis_mask", {}},
         {"ellipsis_mask", {}}});
    auto Constant_396096 = makeConst(ov::element::f32,
                                     ov::Shape({
                                         1,
                                         1,
                                         1,
                                         1,
                                     }),
                                     {-1.000000f});
    auto neg_Multiply = makeOP<ov::opset1::Multiply>({slice_Slice_420, Constant_396096}, {{"auto_broadcast", "numpy"}});
    auto slice_Slice_414 =
        makeOP<ov::opset1::StridedSlice>({slice_Slice_351, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims / 2}, {1, 1, 1, 1}},
                                         {{"begin_mask", {1, 1, 1, 0}},
                                          {"end_mask", {1, 1, 1, 0}},
                                          {"new_axis_mask", {}},
                                          {"shrink_axis_mask", {}},
                                          {"ellipsis_mask", {}}});
    auto cat_Concat = makeOP<ov::opset1::Concat>({neg_Multiply, slice_Slice_414}, {{"axis", -1}});
    auto mul_Multiply_424 = makeOP<ov::opset1::Multiply>({cat_Concat, cos_sin[1]}, {{"auto_broadcast", "numpy"}});
    auto add_Add = makeOP<ov::opset1::Add>({mul_Multiply, mul_Multiply_424}, {{"auto_broadcast", "numpy"}});
    auto slice_Slice_357 =
        makeOP<ov::opset1::StridedSlice>({permute_Transpose, {0, 0, 0, rotary_ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
                                         {{"begin_mask", {1, 1, 1, 0}},
                                          {"end_mask", {1, 1, 1, 0}},
                                          {"new_axis_mask", {}},
                                          {"shrink_axis_mask", {}},
                                          {"ellipsis_mask", {}}});
    auto cat_Concat_458 = makeOP<ov::opset1::Concat>({add_Add, slice_Slice_357}, {{"axis", -1}});

    return std::make_shared<ov::Model>(ov::NodeVector{cat_Concat_458}, parameters);
}

TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_no_gather) {
    disable_rt_info_check();
    const int batch = 2;
    const int seq_len = 16;
    const int ndims = 80;
    const int num_heads = 32;
    const int rotary_ndims = 20;
    const int max_position_embeddings = 2048;

    model = buildROPE_GPTNEOX(batch, seq_len, max_position_embeddings, ndims, num_heads, rotary_ndims, false);
    manager.register_pass<RoPEFusion>();
    {
        auto input =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims * 3});
        auto param_cos =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 1, seq_len, rotary_ndims});
        auto param_sin =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 1, seq_len, rotary_ndims});
        auto rope = makeOP<RoPENode>({input, param_cos, param_sin},
                                     {{"config.slice_start", 0},
                                      {"config.slice_stop", ndims},
                                      {"config.input_trans0213", true},
                                      {"config.is_interleaved", false},
                                      {"config.rotary_ndims", rotary_ndims},
                                      {"config.gather_position_arg_id", 0}});
        model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, param_cos, param_sin});
    }
}

TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_with_gather) {
    disable_rt_info_check();
    const int batch = 2;
    const int seq_len = 16;
    const int ndims = 80;
    const int rotary_ndims = 20;
    const int num_heads = 32;
    const int max_position_embeddings = 2048;

    model = buildROPE_GPTNEOX(batch, seq_len, max_position_embeddings, ndims, num_heads, rotary_ndims, true);
    manager.register_pass<RoPEFusion>();
    {
        auto cos_sin = makeCosSinCache(max_position_embeddings, rotary_ndims);
        auto input =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims * 3});
        auto gather_idx =
            std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{1, 1, seq_len, rotary_ndims});
        auto batch_limit = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{1});

        auto rope = makeOP<RoPENode>({input, cos_sin[0], cos_sin[1], gather_idx},
                                     {{"config.slice_start", 0},
                                      {"config.slice_stop", ndims},
                                      {"config.input_trans0213", true},
                                      {"config.is_interleaved", false},
                                      {"config.rotary_ndims", rotary_ndims},
                                      {"config.gather_position_arg_id", 3}});
        model_ref =
            std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, gather_idx, batch_limit});
    }
}

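// GPT-J uses the interleaved rotary variant: cos/sin values are gathered and repeat-interleaved
// over even/odd lanes instead of being split into halves, so the expected RoPENode carries
// config.is_interleaved = true and the same cos_sin input is wired to both the cos and sin ports.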
TEST_F(TransformationTestsF, ConvertToROPE_GPTJ) {
    disable_rt_info_check();
    const int batch = 2;
    const int seq_len = 7;
    const int num_heads = 16;
    const int ndims = 256;
    const int rotary_ndims = 64;
    {
        std::vector<int32_t> rpi_idx(rotary_ndims);
        for (int i = 0, index = 0; i < rotary_ndims; i += 2, index++) {
            rpi_idx[i] = index;
            rpi_idx[i + 1] = index;
        }
        auto repeat_interleave_index = makeConst(ov::element::i32, ov::Shape({rotary_ndims}), rpi_idx);

        auto input =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims});
        auto gather_sin_cos =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, seq_len, rotary_ndims});

        auto split = makeOP<ov::opset1::VariadicSplit>({gather_sin_cos, {-1}, {rotary_ndims / 2, -1}});
        auto sin_tab =
            makeOP<ov::opset1::Reshape>({split->output(0), {1, -1, 1, rotary_ndims / 2}}, {{"special_zero", false}});
        auto cos_tab =
            makeOP<ov::opset1::Reshape>({split->output(1), {1, -1, 1, rotary_ndims / 2}}, {{"special_zero", false}});

        auto slice_Slice_576 =
            makeOP<ov::opset1::StridedSlice>({input, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims}, {1, 1, 1, 1}},
                                             {{"begin_mask", {1, 1, 1, 0}},
                                              {"end_mask", {1, 1, 1, 0}},
                                              {"new_axis_mask", {}},
                                              {"shrink_axis_mask", {}},
                                              {"ellipsis_mask", {}}});
        auto repeat_interleave_Cos =
            makeOP<ov::opset8::Gather>({cos_tab, repeat_interleave_index, {3}}, {{"batch_dims", 0}});
        auto mul_Multiply_757 =
            makeOP<ov::opset1::Multiply>({slice_Slice_576, repeat_interleave_Cos}, {{"auto_broadcast", "numpy"}});

        auto slice_Slice_787 =
            makeOP<ov::opset1::StridedSlice>({slice_Slice_576, {0, 0, 0, 1}, {0, 0, 0, INT_MAX}, {1, 1, 1, 2}},
                                             {{"begin_mask", {1, 1, 1, 0}},
                                              {"end_mask", {1, 1, 1, 0}},
                                              {"new_axis_mask", {}},
                                              {"shrink_axis_mask", {}},
                                              {"ellipsis_mask", {}}});
        auto Constant_191672 = makeConst(ov::element::f32,
                                         ov::Shape({
                                             1,
                                             1,
                                             1,
                                             1,
                                         }),
                                         {-1.000000f});
        auto neg_Multiply_790 =
            makeOP<ov::opset1::Multiply>({slice_Slice_787, Constant_191672}, {{"auto_broadcast", "numpy"}});
        auto Unsqueeze_61918 = makeOP<ov::opset1::Unsqueeze>({neg_Multiply_790, {-1}});
        auto slice_Slice_781 =
            makeOP<ov::opset1::StridedSlice>({slice_Slice_576, {0, 0, 0, 0}, {0, 0, 0, INT_MAX}, {1, 1, 1, 2}},
                                             {{"begin_mask", {1, 1, 1, 0}},
                                              {"end_mask", {1, 1, 1, 0}},
                                              {"new_axis_mask", {}},
                                              {"shrink_axis_mask", {}},
                                              {"ellipsis_mask", {}}});
        auto Unsqueeze_61919 = makeOP<ov::opset1::Unsqueeze>({slice_Slice_781, {-1}});
        auto stack_795 = makeOP<ov::opset1::Concat>({Unsqueeze_61918, Unsqueeze_61919}, {{"axis", -1}});
        auto ShapeOf_165368 = makeOP<ov::op::TypeRelaxed<ov::opset1::ShapeOf>>(
            {stack_795},
            {{"type_relax", true}, {"input_data_types", {}}, {"output_data_types", {ov::element::i32}}});
        auto flatten_Slice_811 = makeOP<ov::opset1::StridedSlice>({ShapeOf_165368, {0}, {3}, {1}},
                                                                  {{"begin_mask", {0}},
                                                                   {"end_mask", {0}},
                                                                   {"new_axis_mask", {}},
                                                                   {"shrink_axis_mask", {}},
                                                                   {"ellipsis_mask", {}}});
        auto flatten_Concat_814 = makeOP<ov::opset1::Concat>({flatten_Slice_811, {-1}}, {{"axis", 0}});
        auto flatten_Reshape_815 =
            makeOP<ov::opset1::Reshape>({stack_795, flatten_Concat_814}, {{"special_zero", true}});
        auto repeat_interleave_Sin =
            makeOP<ov::opset8::Gather>({sin_tab, repeat_interleave_index, {3}}, {{"batch_dims", 0}});
        auto mul_Multiply_816 =
            makeOP<ov::opset1::Multiply>({flatten_Reshape_815, repeat_interleave_Sin}, {{"auto_broadcast", "numpy"}});
        auto add_Add_819 = makeOP<ov::opset1::Add>({mul_Multiply_757, mul_Multiply_816}, {{"auto_broadcast", "numpy"}});
        auto slice_Slice_582 =
            makeOP<ov::opset1::StridedSlice>({input, {0, 0, 0, rotary_ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
                                             {{"begin_mask", {1, 1, 1, 0}},
                                              {"end_mask", {1, 1, 1, 0}},
                                              {"new_axis_mask", {}},
                                              {"shrink_axis_mask", {}},
                                              {"ellipsis_mask", {}}});
        auto cat_Concat_826 = makeOP<ov::opset1::Concat>({add_Add_819, slice_Slice_582}, {{"axis", -1}});
        auto permute_Transpose_828 = makeOP<ov::opset1::Transpose>({cat_Concat_826, {0, 2, 1, 3}});
        model = std::make_shared<ov::Model>(ov::NodeVector{permute_Transpose_828},
                                            ov::ParameterVector{input, gather_sin_cos});
    }
    manager.register_pass<RoPEFusion>();
    {
        auto input =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims});
        auto cos_sin = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, seq_len, rotary_ndims});
        auto rope = makeOP<RoPENode>({input, cos_sin, cos_sin},
                                     {{"config.slice_start", 0},
                                      {"config.slice_stop", 0},
                                      {"config.input_trans0213", false},
                                      {"config.is_interleaved", true},
                                      {"config.rotary_ndims", rotary_ndims},
                                      {"config.gather_position_arg_id", 0}});
        model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, cos_sin});
    }
}