[CPU] Support chatglm RoPE (#21295)

Tingqian Li 2023-12-04 20:18:01 +08:00 committed by GitHub
parent 15ff4e6596
commit fcbb80d372
10 changed files with 511 additions and 8 deletions

View File

@@ -54,6 +54,13 @@ struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor {
gather.reset(inputs[config.gather_position_arg_id]);
}
if (t_cos.m_rank == 2) {
t_cos = t_cos.reshape({1, 1, t_cos.size(0), t_cos.size(1)});
}
if (t_sin.m_rank == 2) {
t_sin = t_sin.reshape({1, 1, t_sin.size(0), t_sin.size(1)});
}
auto batch_size = t_src.size(0);
auto head_cnt = t_src.size(1);
auto seq_len = t_src.size(2);
@@ -124,6 +131,48 @@ struct RoPE::RoPEExecutorInterleaved : public RoPE::Executor {
}
};
template <typename T>
struct RoPE::RoPEExecutorChatGLM : public RoPE::Executor {
void execute(dnnl::stream strm,
const RoPENode::Config& config,
const std::vector<MemoryPtr>& inputs,
const std::vector<MemoryPtr>& outputs) override {
ov::intel_cpu::PlainTensor<T> t_src(inputs[0]);
ov::intel_cpu::PlainTensor<float> t_cos_sin(inputs[1]);
ov::intel_cpu::PlainTensor<T> t_dst(outputs[0]);
// [seq_len, batch_size, (hidden_states_q + hidden_states_k + hidden_states_v)]
if (config.slice_stop - config.slice_start > 0) {
t_src = t_src.slice(2, config.slice_start, config.slice_stop);
}
auto seq_len = t_src.size(0);
auto batch_size = t_src.size(1);
auto head_cnt = config.head_cnt;
auto head_size = config.head_size;
auto rotary_dims = config.rotary_ndims;
parallel_for3d(seq_len, batch_size, head_cnt, [&](size_t p, size_t b, size_t h) {
auto* src = &t_src.at({p, b, h * head_size});
// [length, batch_size, ndims/2, 2]
auto* cos_sin = &t_cos_sin.at({p, b, 0, 0}, true);
auto* dst = &t_dst.at({p, b, h, 0});
size_t i = 0;
for (; i < rotary_dims; i += 2) {
auto cosv = cos_sin[i];
auto sinv = cos_sin[i + 1];
dst[i] = cosv * src[i] - sinv * src[i + 1];
dst[i + 1] = sinv * src[i] + cosv * src[i + 1];
}
for (; i < head_size; i++) {
dst[i] = src[i];
}
});
}
};
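// Editor's sketch (not part of this patch): the rotation implemented by
// execute() above, isolated for one head at one position. cos_sin points at
// that position's interleaved [cos, sin] pairs, matching the cache's
// innermost [rotary_ndims/2, 2] layout; elements past rotary_ndims pass
// through unchanged. Assumes <cstddef>; the function name is hypothetical.
static void rotate_chatglm_sketch(const float* src,
                                  const float* cos_sin,
                                  float* dst,
                                  size_t rotary_ndims,
                                  size_t head_size) {
    size_t i = 0;
    for (; i < rotary_ndims; i += 2) {
        float c = cos_sin[i];      // cos of pair i/2
        float s = cos_sin[i + 1];  // sin of pair i/2
        dst[i] = c * src[i] - s * src[i + 1];
        dst[i + 1] = s * src[i] + c * src[i + 1];
    }
    for (; i < head_size; i++)
        dst[i] = src[i];  // tail beyond rotary_ndims is copied as-is
}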
void RoPE::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
@@ -132,7 +181,14 @@ void RoPE::initSupportedPrimitiveDescriptors() {
auto rtPrecision = srcPrecision;
auto CosSinPrecision = ov::element::f32;
if (m_config.is_interleaved) {
if (m_config.is_chatglm) {
if (rtPrecision == ov::element::bf16) {
m_executor = std::make_shared<RoPEExecutorChatGLM<ov::bfloat16>>();
} else {
m_executor = std::make_shared<RoPEExecutorChatGLM<float>>();
rtPrecision = ov::element::f32;
}
} else if (m_config.is_interleaved) {
OPENVINO_ASSERT(m_config.input_trans0213 == false);
OPENVINO_ASSERT(m_config.slice_start == 0);
OPENVINO_ASSERT(m_config.slice_stop == 0);

View File

@@ -45,6 +45,8 @@ private:
struct RoPEExecutorRotateHalf;
template <typename T>
struct RoPEExecutorInterleaved;
template <typename T>
struct RoPEExecutorChatGLM;
RoPENode::Config m_config;
std::shared_ptr<Executor> m_executor;
};

View File

@@ -22,6 +22,18 @@ void ov::intel_cpu::RoPENode::validate_and_infer_types() {
INTERNAL_OP_SCOPE(RoPENode_validate_and_infer_types);
auto input_pshape = get_input_partial_shape(0);
auto input_slice_size = m_config.slice_stop - m_config.slice_start;
if (m_config.is_chatglm) {
// ChatGLM-specific RoPE
// input [length, batch_size, (hidden_states_q + hidden_states_k + hidden_states_v)]
// output [length, batch_size, head_cnt, hidden_states_k]
set_output_type(
0,
get_input_element_type(0),
{input_pshape[0], input_pshape[1], ov::Dimension(m_config.head_cnt), ov::Dimension(m_config.head_size)});
return;
}
if (input_slice_size > 0) {
input_pshape[3] = input_slice_size;
}
@@ -44,6 +56,9 @@ bool ov::intel_cpu::RoPENode::visit_attributes(ngraph::AttributeVisitor& visitor
visitor.on_attribute("input_trans0213", m_config.input_trans0213);
visitor.on_attribute("is_interleaved", m_config.is_interleaved);
visitor.on_attribute("rotary_ndims", m_config.rotary_ndims);
visitor.on_attribute("is_chatglm", m_config.is_chatglm);
visitor.on_attribute("head_cnt", m_config.head_cnt);
visitor.on_attribute("head_size", m_config.head_size);
visitor.on_attribute("gather_position_arg_id", m_config.gather_position_arg_id);
visitor.finish_structure();
return true;

View File

@@ -70,6 +70,9 @@ public:
bool input_trans0213 = false; // transpose input dim 1&2
bool is_interleaved = false; // interleaved mode, implies trans0213 happens after RoPE
size_t rotary_ndims = 0; // dimensions to be embedded (d in the description)
bool is_chatglm = false;     // chatglm is special and overrides the other settings
size_t head_cnt = 0;
size_t head_size = 0;
int gather_position_arg_id =
0; // arg id of the position tensor; ==3 when gathering from the sin/cos inputs according to position is required
};
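
For orientation, this is how the new fields end up populated for the ChatGLM query branch, using the 4096/256/256 QKV widths and 32x128 heads from the tests at the end of this patch (an illustrative sketch, not code from the diff):

    RoPENode::Config config;
    config.is_chatglm = true;   // overrides the other mode flags above
    config.rotary_ndims = 64;
    config.head_cnt = 32;
    config.head_size = 128;
    config.slice_start = 0;     // query slice of the fused QKV projection
    config.slice_stop = 4096;   // the key branch uses [4096, 4096 + 256)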

View File

@@ -132,7 +132,17 @@ ov::intel_cpu::RoPEFusionCosSinPreprocess::RoPEFusionCosSinPreprocess() {
{"ellipsis_mask", {}}});
auto squeeze = makePattern<opset1::Reshape>({slice_Slice, {-1, head_dims}});
auto index_Gather = makePattern<opset8::Gather>({squeeze, gather_positions_2d, 0}, {{"batch_dims", 0}});
auto unsqueeze = makePattern<opset1::Reshape>({index_Gather, {1, 1, -1, head_dims}});
// another simplified pattern for gathering at position_ids
auto slice_Slice2 = makePattern<opset1::StridedSlice>({const_tab, {0}, seq_len, {1}},
{{"begin_mask", {0}},
{"end_mask", {0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto index_Gather2 = makePattern<opset8::Gather>({slice_Slice2, gather_positions_2d, 0}, {{"batch_dims", 0}});
auto unsqueeze = makePattern<opset1::Reshape>({index_Gather | index_Gather2, {1, 1, -1, head_dims}});
return unsqueeze;
};
@@ -439,6 +449,131 @@ ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(result, matcher_name);
this->register_matcher(m, callback);
}
ov::intel_cpu::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id) {
MATCHER_SCOPE(RoPEFusionChatGLM);
auto qkv_linear = makePattern("f32[?,?,?]"); // f32[seq_length, batch_size, 4608]
auto seq_length = makePattern("i32[1]");
auto cos_sin_cache = makePattern("f32[?,?,?,?]"); // [max_pos_embeddings, batch_size, 32, 2]
auto ndims = Symbol("ndims");
auto head_cnt = Symbol("head_cnt");
auto head_size = Symbol("head_size");
auto total_size_q = Symbol("total_size_q");
auto total_size_k = Symbol("total_size_k");
auto total_size_v = Symbol("total_size_v");
auto qkv_proj = makePattern<opset1::VariadicSplit>({qkv_linear, -1, {total_size_q, total_size_k, total_size_v}});
qkv_proj->set_output_size(3);
// get key [L, B, Hkv, S]
auto cur_key = makePattern<opset1::Reshape>({qkv_proj->output(split_output_id), {0, 0, head_cnt, head_size}},
{{"special_zero", true}});
auto slice_Slice_437 = makePattern<opset1::StridedSlice>({cur_key, {0, 0, 0, 0}, {0, 0, 0, ndims}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
// rotate half
auto ListConstruct_452_Concat =
makePattern<opset1::Concat>({seq_length, {-1}, {head_cnt}, {ndims / 2}, {2}}, {{"axis", 0}});
auto ListConstruct_379_Concat =
makePattern<opset1::Concat>({seq_length, {-1}, {1}, {ndims / 2}, {2}}, {{"axis", 0}});
auto reshape_Reshape_453 =
makePattern<opset1::Reshape>({slice_Slice_437, ListConstruct_452_Concat}, {{"special_zero", false}});
auto x_even = makePattern<opset8::Gather>({reshape_Reshape_453, 0, -1}, {{"batch_dims", 0}});
auto slice_Slice_449 = makePattern<opset1::StridedSlice>({cos_sin_cache, {0}, seq_length, {1}},
{{"begin_mask", {0}},
{"end_mask", {0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto view_Reshape_460 =
makePattern<opset1::Reshape>({slice_Slice_449, ListConstruct_379_Concat}, {{"special_zero", false}});
auto cos_tab = makePattern<opset8::Gather>({view_Reshape_460, 0, -1}, {{"batch_dims", 0}});
auto x_even_cos = makePattern<opset1::Multiply>({x_even, cos_tab}, {{"auto_broadcast", "numpy"}});
auto x_odd = makePattern<opset8::Gather>({reshape_Reshape_453, 1, -1}, {{"batch_dims", 0}});
auto sin_tab = makePattern<opset8::Gather>({view_Reshape_460, 1, -1}, {{"batch_dims", 0}});
auto x_odd_sin = makePattern<opset1::Multiply>({x_odd, sin_tab}, {{"auto_broadcast", "numpy"}});
auto neg_x_odd_sin = makePattern<opset1::Multiply>({x_odd_sin, -1.000000f}, {{"auto_broadcast", "numpy"}});
auto sub_Subtract_469 = makePattern<opset1::Add>({x_even_cos, neg_x_odd_sin}, {{"auto_broadcast", "numpy"}});
auto y_even = makePattern<opset1::Unsqueeze>({sub_Subtract_469, -1});
auto x_odd_cos = makePattern<opset1::Multiply>({x_odd, cos_tab}, {{"auto_broadcast", "numpy"}});
auto x_even_sin = makePattern<opset1::Multiply>({x_even, sin_tab}, {{"auto_broadcast", "numpy"}});
auto add_Add_476 = makePattern<opset1::Add>({x_odd_cos, x_even_sin}, {{"auto_broadcast", "numpy"}});
auto y_odd = makePattern<opset1::Unsqueeze>({add_Add_476, -1});
auto stack_481 = makePattern<opset1::Concat>({y_even, y_odd}, {{"axis", -1}});
auto ShapeOf_135133 = makePattern<opset1::ShapeOf>({stack_481});
auto flatten_Slice_497 = makePattern<opset1::StridedSlice>({ShapeOf_135133, {0}, {3}, {1}},
{{"begin_mask", {0}},
{"end_mask", {0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto flatten_Concat_500 = makePattern<opset1::Concat>({flatten_Slice_497, {-1}}, {{"axis", 0}});
auto const_target_shape = makeConst({0, 0, head_cnt, ndims});
// [length, batch, head_cnt, half_rotary_dims, 2]
auto flatten_Reshape_501 =
makePattern<opset1::Reshape>({stack_481, flatten_Concat_500 | const_target_shape}, {{"special_zero", true}});
auto slice_Slice_443 =
makePattern<opset1::StridedSlice>({cur_key, {0, 0, 0, ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto cat_Concat_505 = makePattern<opset1::Concat>({flatten_Reshape_501, slice_Slice_443}, {{"axis", -1}});
auto result = cat_Concat_505;
matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto root = m.get_match_root();
PatternValidator validator(m);
if (!validator) {
return false;
}
RoPENode::Config config;
OutputVector new_args;
config.rotary_ndims = validator["ndims"];
config.is_chatglm = true;
config.head_cnt = validator["head_cnt"];
config.head_size = validator["head_size"];
if (split_output_id == 0) {
// query : split_output_id == 0
config.slice_start = 0;
config.slice_stop = validator["total_size_q"];
} else {
// key : split_output_id == 1
config.slice_start = validator["total_size_q"];
config.slice_stop = config.slice_start + validator["total_size_k"];
}
new_args.push_back(pattern_map.at(qkv_linear));
new_args.push_back(pattern_map.at(cos_sin_cache));
new_args.push_back(pattern_map.at(cos_sin_cache));
auto old_node = root;
auto new_node = std::make_shared<RoPENode>(new_args, config);
new_node->set_friendly_name(old_node->get_friendly_name());
ov::replace_node(old_node, new_node);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(result, matcher_name);
this->register_matcher(m, callback);
}
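
As a worked example of the callback's slice computation, here are the slices the two registered instances produce for the 4096/256/256 QKV widths used by the tests in this patch (editor's sketch, not part of the diff):

    #include <cstdio>

    int main() {
        const int total_size_q = 4096, total_size_k = 256;
        for (int split_output_id = 0; split_output_id <= 1; ++split_output_id) {
            // split_output_id == 0 fuses the query branch, == 1 the key branch
            const int slice_start = split_output_id == 0 ? 0 : total_size_q;
            const int slice_stop =
                split_output_id == 0 ? total_size_q : slice_start + total_size_k;
            std::printf("output %d: slice [%d, %d)\n", split_output_id, slice_start, slice_stop);
        }
        return 0;
    }
    // prints:
    //   output 0: slice [0, 4096)
    //   output 1: slice [4096, 4352)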

View File

@@ -20,7 +20,11 @@ public:
OPENVINO_RTTI("RoPEFusionGPTJ", "0");
RoPEFusionGPTJ();
};
class RoPEFusionChatGLM : public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("RoPEFusionChatGLM", "0");
RoPEFusionChatGLM(int split_output_id);
};
class RoPEFusionIOSlicing : public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("RoPEFusionIOSlicing", "0");
@@ -56,6 +60,9 @@ public:
add_matcher<RoPEFusionCosSinPreprocess>();
add_matcher<RoPEFusionIOSlicing>();
add_matcher<RoPEFusionPreprocess>();
add_matcher<RoPEFusionChatGLM>(0);
add_matcher<RoPEFusionChatGLM>(1);
}
};

View File

@@ -357,7 +357,7 @@ struct AttrAny {
return any.as<std::vector<T>>();
}
template<typename T>
template <typename T>
std::vector<T> as_T_vector() {
if (any.empty())
return {};
@@ -574,6 +574,10 @@ public:
bool match_value(ov::pass::pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override {
// strictly require the pattern & graph values to come from output ports with the same index;
// this is absolutely necessary when the pattern contains split node connections.
if (pattern_value.get_index() != graph_value.get_index())
return false;
if (m_predicate(graph_value)) {
auto& pattern_map = matcher->get_pattern_value_map();
pattern_map[shared_from_this()] = graph_value;
@@ -858,7 +862,7 @@ public:
} else if (auto a = ov::as_type<ov::AttributeAdapter<ov::CoordinateDiff>>(&adapter)) {
is_matched = m_attr_map[name].equal_to<int64_t, int>(a->get());
} else {
OPENVINO_THROW("AttrSetter met unsupported AttributeAdapter");
OPENVINO_THROW("AttrMatcher met unsupported AttributeAdapter ", name);
}
add_match_result(name, is_matched);
}
@@ -919,6 +923,13 @@ std::shared_ptr<Node> makeConst(const ov::element::Type& type,
return std::make_shared<ov::op::v0::Constant>(type, shape, std::vector<T>(values));
}
inline std::shared_ptr<Node> makeConst(const std::vector<Symbol>& v) {
auto node = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
auto& rt_info = node->get_rt_info();
rt_info["symbolic_const_value"] = v;
return node;
}
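// Editor's example (not part of this patch): this overload builds a pattern
// Constant whose values are validated against Symbols at match time, as the
// ChatGLM fusion above does for its reshape target shape. Hypothetical helper:
inline std::shared_ptr<Node> make_chatglm_target_shape_pattern() {
    auto head_cnt = Symbol("head_cnt");
    auto ndims = Symbol("ndims");
    // matches a 4-element Constant whose values are [0, 0, head_cnt, ndims]
    return makeConst({0, 0, head_cnt, ndims});
}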
template <typename T>
std::shared_ptr<Node> makeConst(const ov::element::Type& type, const ov::Shape& shape, const std::vector<T>& values) {
return std::make_shared<ov::op::v0::Constant>(type, shape, values);
@@ -1198,7 +1209,7 @@ public:
auto byte_size =
shape_size(vconst_node->get_output_shape(0)) * vconst_node->get_output_element_type(0).size();
if (std::memcmp(pconst_node->get_data_ptr(), vconst_node->get_data_ptr(), byte_size) != 0) {
_VERBOSE_LOG("Constant value mismatch.");
_VERBOSE_LOG("Constant value mismatch on ", pconst_node, " vs ", vconst_node);
return false;
}
continue;

View File

@@ -319,6 +319,10 @@ void dump_cpp_style(std::ostream& os, const std::shared_ptr<ov::Model>& model) {
const auto& type_info = op->get_type_info();
auto version_info = std::string(type_info.get_version());
auto type = version_info + "::" + type_info.name;
auto& rt_info = op->get_rt_info();
if (rt_info.count("opset") && rt_info["opset"] == "type_relaxed_opset") {
type = std::string("ov::op::TypeRelaxed<") + type + ">";
}
auto name = opname[op.get()];
os << prefix << " ";

View File

@@ -131,7 +131,7 @@ static std::shared_ptr<ov::Model> buildROPE_Llama2(const int batch,
namespace CPULayerTestsDefinitions {
class RoPECPUTest : public SubgraphBaseTest {
class RoPECPUTestLlama2 : public SubgraphBaseTest {
public:
ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1) {
auto tensor = ov::Tensor(ov::element::i32, shape);
@@ -177,8 +177,155 @@ protected:
}
};
TEST_F(RoPECPUTest, smoke_CompareWithRefs) {
TEST_F(RoPECPUTestLlama2, smoke_CompareWithRefs) {
run();
CheckNumberOfNodesWithType(compiledModel, "RoPE", 1);
}
class RoPECPUTestChatGLM : public SubgraphBaseTest {
public:
ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1) {
auto tensor = ov::Tensor(ov::element::i32, shape);
auto* ptr = static_cast<int32_t*>(tensor.data());
for (size_t i = 0; i < tensor.get_size(); i++) {
ptr[i] = start;
start += step;
}
return tensor;
}
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
const auto& funcInputs = function->inputs();
auto& input_shape = targetInputStaticShapes[0];
auto seq_length = input_shape[0];
// auto batch = input_shape[1];
ov::Tensor t_input =
utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768);
ov::Tensor t_cos_sin_cache =
utils::create_and_fill_tensor(funcInputs[1].get_element_type(), {32768, 32, 2}, 2, -1.0f, 32768);
ov::Tensor t_position_ids = create_i32_tensor(ov::Shape({1, seq_length}), 15);
inputs.clear();
inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input});
inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos_sin_cache});
inputs.insert({funcInputs[2].get_node_shared_ptr(), t_position_ids});
}
protected:
std::shared_ptr<ov::Model> buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims) {
auto input =
std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{-1, batch, 4096 + 256 + 256});
auto cos_sin_cache = std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{32768, 32, 2});
auto position_ids = std::make_shared<ov::opset1::Parameter>(ov::element::i32, PartialShape{-1, -1});
auto __module_transformer_index_67_Gather =
makeOP<opset8::Gather>({cos_sin_cache, position_ids, 0}, {{"batch_dims", 0}});
auto __module_transformer_transpose_Transpose =
makeOP<opset1::Transpose>({__module_transformer_index_67_Gather, {1, 0, 2, 3}});
auto size_ShapeOf_110 =
makeOP<opset3::ShapeOf>({__module_transformer_transpose_Transpose}, {{"output_type", "i32"}});
auto __getitem___Gather = makeOP<opset8::Gather>({size_ShapeOf_110, -2, 0}, {{"batch_dims", 0}});
auto mul_Multiply = makeOP<opset1::Multiply>({__getitem___Gather, 2}, {{"auto_broadcast", "numpy"}});
auto slice_Unsqueeze_112 = makeOP<opset1::Unsqueeze>({mul_Multiply, 0});
auto floordiv_Divide =
makeOP<opset1::Divide>({mul_Multiply, 2}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}});
auto floordiv_Floor = makeOP<opset1::Floor>({floordiv_Divide});
auto ListConstruct_126_Reshape_2 = makeOP<opset1::Reshape>({floordiv_Floor, {-1}}, {{"special_zero", false}});
auto ListUnpack_321 = makeOP<opset1::VariadicSplit>({input, -1, {4096, 256, 256}});
auto view_Reshape =
makeOP<opset1::Reshape>({ListUnpack_321->output(0), {0, 0, 32, 128}}, {{"special_zero", true}});
auto ScatterUpdate_229053 = makeOP<opset3::ScatterUpdate>({{0, 0, 0, 0}, {3}, slice_Unsqueeze_112, {0}});
auto slice_Slice_357 =
makeOP<opset1::StridedSlice>({view_Reshape, {0, 0, 0, 0}, ScatterUpdate_229053, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto size_ShapeOf_346 = makeOP<opset3::ShapeOf>({view_Reshape}, {{"output_type", "i32"}});
auto size_Gather_348 = makeOP<opset8::Gather>({size_ShapeOf_346, 0, 0}, {{"batch_dims", 0}});
auto ListConstruct_372_Reshape = makeOP<opset1::Reshape>({size_Gather_348, {-1}}, {{"special_zero", false}});
auto size_Gather_351 = makeOP<opset8::Gather>({size_ShapeOf_346, {2}, 0}, {{"batch_dims", 0}});
auto ListConstruct_372_Concat =
makeOP<opset1::Concat>({ListConstruct_372_Reshape, {-1}, size_Gather_351, ListConstruct_126_Reshape_2, {2}},
{{"axis", 0}});
auto reshape_Reshape_373 =
makeOP<opset1::Reshape>({slice_Slice_357, ListConstruct_372_Concat}, {{"special_zero", false}});
auto select_Gather_381 = makeOP<opset8::Gather>({reshape_Reshape_373, 0, -1}, {{"batch_dims", 0}});
auto slice_Unsqueeze_367 = makeOP<opset1::Unsqueeze>({size_Gather_348, 0});
auto slice_Slice_369 =
makeOP<opset1::StridedSlice>({__module_transformer_transpose_Transpose, {0}, slice_Unsqueeze_367, {1}},
{{"begin_mask", {0}},
{"end_mask", {0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto size_ShapeOf_374 = makeOP<opset3::ShapeOf>({reshape_Reshape_373}, {{"output_type", "i32"}});
auto size_Gather_376 = makeOP<opset8::Gather>({size_ShapeOf_374, {3}, 0}, {{"batch_dims", 0}});
auto ListConstruct_379_Concat =
makeOP<opset1::Concat>({ListConstruct_372_Reshape, {-1}, {1}, size_Gather_376, {2}}, {{"axis", 0}});
auto view_Reshape_380 =
makeOP<opset1::Reshape>({slice_Slice_369, ListConstruct_379_Concat}, {{"special_zero", false}});
auto select_Gather_382 = makeOP<opset8::Gather>({view_Reshape_380, 0, -1}, {{"batch_dims", 0}});
auto mul_Multiply_383 =
makeOP<opset1::Multiply>({select_Gather_381, select_Gather_382}, {{"auto_broadcast", "numpy"}});
auto select_Gather_384 = makeOP<opset8::Gather>({reshape_Reshape_373, 1, -1}, {{"batch_dims", 0}});
auto select_Gather_385 = makeOP<opset8::Gather>({view_Reshape_380, 1, -1}, {{"batch_dims", 0}});
auto mul_Multiply_386 =
makeOP<opset1::Multiply>({select_Gather_384, select_Gather_385}, {{"auto_broadcast", "numpy"}});
auto sub_Subtract_389 =
makeOP<opset1::Subtract>({mul_Multiply_383, mul_Multiply_386}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze_62716 = makeOP<opset1::Unsqueeze>({sub_Subtract_389, -1});
auto mul_Multiply_391 =
makeOP<opset1::Multiply>({select_Gather_384, select_Gather_382}, {{"auto_broadcast", "numpy"}});
auto mul_Multiply_393 =
makeOP<opset1::Multiply>({select_Gather_381, select_Gather_385}, {{"auto_broadcast", "numpy"}});
auto add_Add_396 = makeOP<opset1::Add>({mul_Multiply_391, mul_Multiply_393}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze_62717 = makeOP<opset1::Unsqueeze>({add_Add_396, -1});
auto stack_401 = makeOP<opset1::Concat>({Unsqueeze_62716, Unsqueeze_62717}, {{"axis", -1}});
auto flatten_ShapeOf_402 = makeOP<opset3::ShapeOf>({stack_401}, {{"output_type", "i32"}});
auto flatten_Slice_417 = makeOP<opset1::StridedSlice>({flatten_ShapeOf_402, {0}, {3}, {1}},
{{"begin_mask", {0}},
{"end_mask", {0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto flatten_Concat_420 = makeOP<opset1::Concat>({flatten_Slice_417, {-1}}, {{"axis", 0}});
auto flatten_Reshape_421 = makeOP<opset1::Reshape>({stack_401, flatten_Concat_420}, {{"special_zero", true}});
auto ScatterUpdate_229067 = makeOP<opset3::ScatterUpdate>({{0, 0, 0, 0}, {3}, slice_Unsqueeze_112, {0}});
auto slice_Slice_363 =
makeOP<opset1::StridedSlice>({view_Reshape, ScatterUpdate_229067, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto cat_Concat_425 = makeOP<opset1::Concat>({flatten_Reshape_421, slice_Slice_363}, {{"axis", -1}});
return std::make_shared<ov::Model>(ov::NodeVector{cat_Concat_425},
ov::ParameterVector{input, cos_sin_cache, position_ids});
}
void SetUp() override {
targetDevice = ov::test::utils::DEVICE_CPU;
const int batch = 2;
const int seq_length = 7;
const int num_head = 32;
const int rotary_dims = 64;
InputShape inpShape = {{-1, batch, 4096 + 256 + 256}, {{seq_length, batch, 4096 + 256 + 256}}};
init_input_shapes({inpShape});
function = buildROPE_ChatGLM(batch, num_head, rotary_dims);
}
};
TEST_F(RoPECPUTestChatGLM, smoke_CompareWithRefs) {
run();
CheckNumberOfNodesWithType(compiledModel, "RoPE", 1);
}
} // namespace CPULayerTestsDefinitions

View File

@@ -139,6 +139,9 @@ TEST_F(TransformationTestsF, ConvertToROPE_LLama2_no_gather) {
{"config.slice_stop", 0},
{"config.input_trans0213", true},
{"config.is_interleaved", false},
{"config.is_chatglm", false},
{"config.head_cnt", 0},
{"config.head_size", 0},
{"config.rotary_ndims", static_cast<int>(ndims)},
{"config.gather_position_arg_id", 0}});
@@ -170,6 +173,9 @@ TEST_F(TransformationTestsF, ConvertToROPE_LLama2_with_gather) {
{"config.slice_stop", 0},
{"config.input_trans0213", true},
{"config.is_interleaved", false},
{"config.is_chatglm", false},
{"config.head_cnt", 0},
{"config.head_size", 0},
{"config.rotary_ndims", static_cast<int>(ndims)},
{"config.gather_position_arg_id", 3}});
@@ -304,6 +310,9 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_no_gather) {
{"config.slice_stop", ndims},
{"config.input_trans0213", true},
{"config.is_interleaved", false},
{"config.is_chatglm", false},
{"config.head_cnt", 0},
{"config.head_size", 0},
{"config.rotary_ndims", rotary_ndims},
{"config.gather_position_arg_id", 0}});
model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, param_cos, param_sin});
@@ -334,6 +343,9 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_with_gather) {
{"config.slice_stop", ndims},
{"config.input_trans0213", true},
{"config.is_interleaved", false},
{"config.is_chatglm", false},
{"config.head_cnt", 0},
{"config.head_size", 0},
{"config.rotary_ndims", rotary_ndims},
{"config.gather_position_arg_id", 3}});
model_ref =
@@ -445,8 +457,119 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTJ) {
{"config.slice_stop", 0},
{"config.input_trans0213", false},
{"config.is_interleaved", true},
{"config.is_chatglm", false},
{"config.head_cnt", 0},
{"config.head_size", 0},
{"config.rotary_ndims", rotary_ndims},
{"config.gather_position_arg_id", 0}});
model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, cos_sin});
}
}
TEST_F(TransformationTestsF, ConvertToROPE_chatGLM) {
disable_rt_info_check();
const int batch = 2;
const int seq_len = 7;
const int num_heads = 32;
const int ndims = 128;
const int rotary_ndims = 64;
const int max_pos_length = 2048;
{
auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{seq_len, batch, 4608});
auto seq_length = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{1});
auto cos_sin_cache =
std::make_shared<ov::opset1::Parameter>(ov::element::f32,
ov::Shape{max_pos_length, batch, rotary_ndims / 2, 2});
auto ListUnpack_321 = makeOP<ov::opset1::VariadicSplit>({input, -1, {4096, 256, 256}});
auto view_Reshape = makeOP<ov::opset1::Reshape>({ListUnpack_321->output(0), {0, 0, num_heads, ndims}},
{{"special_zero", true}});
auto aten_slice_Slice_357 =
makeOP<ov::opset1::StridedSlice>({view_Reshape, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto ListConstruct_372_Concat =
makeOP<ov::opset1::Concat>({seq_length, {-1}, {num_heads}, {rotary_ndims / 2}, {2}}, {{"axis", 0}});
auto aten_reshape_Reshape_373 =
makeOP<ov::opset1::Reshape>({aten_slice_Slice_357, ListConstruct_372_Concat}, {{"special_zero", false}});
auto aten_select_Gather_381 =
makeOP<ov::opset8::Gather>({aten_reshape_Reshape_373, 0, -1}, {{"batch_dims", 0}});
auto aten_slice_Slice_369 = makeOP<ov::opset1::StridedSlice>({cos_sin_cache, {0}, seq_length, {1}},
{{"begin_mask", {0}},
{"end_mask", {0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto ListConstruct_379_Concat =
makeOP<ov::opset1::Concat>({seq_length, {-1}, {1}, {rotary_ndims / 2}, {2}}, {{"axis", 0}});
auto aten_view_Reshape_380 =
makeOP<ov::opset1::Reshape>({aten_slice_Slice_369, ListConstruct_379_Concat}, {{"special_zero", false}});
auto aten_select_Gather_382 = makeOP<ov::opset8::Gather>({aten_view_Reshape_380, 0, -1}, {{"batch_dims", 0}});
auto aten_mul_Multiply_383 = makeOP<ov::opset1::Multiply>({aten_select_Gather_381, aten_select_Gather_382},
{{"auto_broadcast", "numpy"}});
auto aten_select_Gather_384 =
makeOP<ov::opset8::Gather>({aten_reshape_Reshape_373, 1, -1}, {{"batch_dims", 0}});
auto aten_select_Gather_385 = makeOP<ov::opset8::Gather>({aten_view_Reshape_380, 1, -1}, {{"batch_dims", 0}});
auto aten_mul_Multiply_386 = makeOP<ov::opset1::Multiply>({aten_select_Gather_384, aten_select_Gather_385},
{{"auto_broadcast", "numpy"}});
auto Multiply_101315 =
makeOP<ov::opset1::Multiply>({aten_mul_Multiply_386, -1.000000f}, {{"auto_broadcast", "numpy"}});
auto aten_sub_Subtract_389 =
makeOP<ov::opset1::Add>({aten_mul_Multiply_383, Multiply_101315}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze_62716 = makeOP<ov::opset1::Unsqueeze>({aten_sub_Subtract_389, -1});
auto aten_mul_Multiply_391 = makeOP<ov::opset1::Multiply>({aten_select_Gather_384, aten_select_Gather_382},
{{"auto_broadcast", "numpy"}});
auto aten_mul_Multiply_393 = makeOP<ov::opset1::Multiply>({aten_select_Gather_381, aten_select_Gather_385},
{{"auto_broadcast", "numpy"}});
auto aten_add_Add_396 =
makeOP<ov::opset1::Add>({aten_mul_Multiply_391, aten_mul_Multiply_393}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze_62717 = makeOP<ov::opset1::Unsqueeze>({aten_add_Add_396, -1});
auto aten_stack_401 = makeOP<ov::opset1::Concat>({Unsqueeze_62716, Unsqueeze_62717}, {{"axis", -1}});
auto ShapeOf_134820 = makeOP<ov::op::TypeRelaxed<ov::opset1::ShapeOf>>(
{aten_stack_401},
{{"type_relax", true}, {"input_data_types", {}}, {"output_data_types", {ov::element::i32}}});
auto aten_flatten_Slice_417 = makeOP<ov::opset1::StridedSlice>({ShapeOf_134820, {0}, {3}, {1}},
{{"begin_mask", {0}},
{"end_mask", {0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto aten_flatten_Concat_420 = makeOP<ov::opset1::Concat>({aten_flatten_Slice_417, {-1}}, {{"axis", 0}});
auto aten_flatten_Reshape_421 =
makeOP<ov::opset1::Reshape>({aten_stack_401, aten_flatten_Concat_420}, {{"special_zero", true}});
auto aten_slice_Slice_363 =
makeOP<ov::opset1::StridedSlice>({view_Reshape, {0, 0, 0, rotary_ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto aten_cat_Concat_425 =
makeOP<ov::opset1::Concat>({aten_flatten_Reshape_421, aten_slice_Slice_363}, {{"axis", -1}});
model = std::make_shared<ov::Model>(ov::NodeVector{aten_cat_Concat_425},
ov::ParameterVector{input, seq_length, cos_sin_cache});
}
manager.register_pass<RoPEFusion>();
{
auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{seq_len, batch, 4608});
auto seq_length = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{1});
auto cos_sin_cache =
std::make_shared<ov::opset1::Parameter>(ov::element::f32,
ov::Shape{max_pos_length, batch, rotary_ndims / 2, 2});
auto rope = makeOP<RoPENode>({input, cos_sin_cache, cos_sin_cache},
{{"config.slice_start", 0},
{"config.slice_stop", 4096},
{"config.input_trans0213", false},
{"config.is_interleaved", false},
{"config.rotary_ndims", rotary_ndims},
{"config.is_chatglm", true},
{"config.head_cnt", num_heads},
{"config.head_size", ndims},
{"config.gather_position_arg_id", 0}});
model_ref =
std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, seq_length, cos_sin_cache});
}
}