[CPU] Support chatglm RoPE (#21295)
parent 15ff4e6596
commit fcbb80d372
@@ -54,6 +54,13 @@ struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor {
             gather.reset(inputs[config.gather_position_arg_id]);
         }
 
+        if (t_cos.m_rank == 2) {
+            t_cos = t_cos.reshape({1, 1, t_cos.size(0), t_cos.size(1)});
+        }
+        if (t_sin.m_rank == 2) {
+            t_sin = t_sin.reshape({1, 1, t_sin.size(0), t_sin.size(1)});
+        }
+
         auto batch_size = t_src.size(0);
         auto head_cnt = t_src.size(1);
         auto seq_len = t_src.size(2);
@@ -124,6 +131,48 @@ struct RoPE::RoPEExecutorInterleaved : public RoPE::Executor {
     }
 };
 
+template <typename T>
+struct RoPE::RoPEExecutorChatGLM : public RoPE::Executor {
+    void execute(dnnl::stream strm,
+                 const RoPENode::Config& config,
+                 const std::vector<MemoryPtr>& inputs,
+                 const std::vector<MemoryPtr>& outputs) override {
+        ov::intel_cpu::PlainTensor<T> t_src(inputs[0]);
+        ov::intel_cpu::PlainTensor<float> t_cos_sin(inputs[1]);
+        ov::intel_cpu::PlainTensor<T> t_dst(outputs[0]);
+
+        // [seq_len, batch_size, (hidden_states_q + hidden_states_k + hidden_states_v)]
+        if (config.slice_stop - config.slice_start > 0) {
+            t_src = t_src.slice(2, config.slice_start, config.slice_stop);
+        }
+        auto seq_len = t_src.size(0);
+        auto batch_size = t_src.size(1);
+
+        auto head_cnt = config.head_cnt;
+        auto head_size = config.head_size;
+
+        auto rotary_dims = config.rotary_ndims;
+
+        parallel_for3d(seq_len, batch_size, head_cnt, [&](size_t p, size_t b, size_t h) {
+            auto* src = &t_src.at({p, b, h * head_size});
+            // [length, batch_size, ndims//2, 2]
+            auto* cos_sin = &t_cos_sin.at({p, b, 0, 0}, true);
+            auto* dst = &t_dst.at({p, b, h, 0});
+
+            size_t i = 0;
+            for (; i < rotary_dims; i += 2) {
+                auto cosv = cos_sin[i];
+                auto sinv = cos_sin[i + 1];
+                dst[i] = cosv * src[i] - sinv * src[i + 1];
+                dst[i + 1] = sinv * src[i] + cosv * src[i + 1];
+            }
+            for (; i < head_size; i++) {
+                dst[i] = src[i];
+            }
+        });
+    }
+};
+
 void RoPE::initSupportedPrimitiveDescriptors() {
     if (!supportedPrimitiveDescriptors.empty())
         return;
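For reference, the executor above consumes a cos/sin table that is interleaved along the last axis (cos_sin[2k] holds cos(theta_k), cos_sin[2k + 1] holds sin(theta_k)) and rotates each adjacent channel pair, leaving channels past rotary_ndims untouched. A minimal standalone scalar sketch of that per-head update (hypothetical helper, not the plugin code itself):

#include <cstddef>

// Rotate one head: src/dst point at head_size contiguous channels,
// cos_sin at rotary_dims interleaved (cos, sin) values for this position.
static void rope_chatglm_head(const float* src, const float* cos_sin, float* dst,
                              size_t rotary_dims, size_t head_size) {
    size_t i = 0;
    for (; i < rotary_dims; i += 2) {
        float c = cos_sin[i];
        float s = cos_sin[i + 1];
        dst[i] = c * src[i] - s * src[i + 1];      // even channel: x*cos - y*sin
        dst[i + 1] = s * src[i] + c * src[i + 1];  // odd channel:  x*sin + y*cos
    }
    for (; i < head_size; i++) {
        dst[i] = src[i];  // pass-through for the non-rotary tail
    }
}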
@@ -132,7 +181,14 @@ void RoPE::initSupportedPrimitiveDescriptors() {
     auto rtPrecision = srcPrecision;
     auto CosSinPrecision = ov::element::f32;
 
-    if (m_config.is_interleaved) {
+    if (m_config.is_chatglm) {
+        if (rtPrecision == ov::element::bf16) {
+            m_executor = std::make_shared<RoPEExecutorChatGLM<ov::bfloat16>>();
+        } else {
+            m_executor = std::make_shared<RoPEExecutorChatGLM<float>>();
+            rtPrecision = ov::element::f32;
+        }
+    } else if (m_config.is_interleaved) {
         OPENVINO_ASSERT(m_config.input_trans0213 == false);
         OPENVINO_ASSERT(m_config.slice_start == 0);
         OPENVINO_ASSERT(m_config.slice_stop == 0);
@@ -45,6 +45,8 @@ private:
     struct RoPEExecutorRotateHalf;
     template <typename T>
     struct RoPEExecutorInterleaved;
+    template <typename T>
+    struct RoPEExecutorChatGLM;
     RoPENode::Config m_config;
     std::shared_ptr<Executor> m_executor;
 };
@@ -22,6 +22,18 @@ void ov::intel_cpu::RoPENode::validate_and_infer_types() {
     INTERNAL_OP_SCOPE(RoPENode_validate_and_infer_types);
     auto input_pshape = get_input_partial_shape(0);
     auto input_slice_size = m_config.slice_stop - m_config.slice_start;
+
+    if (m_config.is_chatglm) {
+        // chatGLM specific RoPE
+        // input  [length, batch_size, (hidden_states_q + hidden_states_k + hidden_states_v)]
+        // output [length, batch_size, head_cnt, hidden_states_k]
+        set_output_type(
+            0,
+            get_input_element_type(0),
+            {input_pshape[0], input_pshape[1], ov::Dimension(m_config.head_cnt), ov::Dimension(m_config.head_size)});
+        return;
+    }
+
     if (input_slice_size > 0) {
         input_pshape[3] = input_slice_size;
     }
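Concretely, with the chatGLM-6B geometry used in the tests below (head_cnt = 32, head_size = 128, fused QKV width 4096 + 256 + 256 = 4608), this branch infers input [seq_len, batch, 4608] -> output [seq_len, batch, 32, 128]; the early return skips the regular 4D slicing path that follows.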
@@ -44,6 +56,9 @@ bool ov::intel_cpu::RoPENode::visit_attributes(ngraph::AttributeVisitor& visitor
     visitor.on_attribute("input_trans0213", m_config.input_trans0213);
     visitor.on_attribute("is_interleaved", m_config.is_interleaved);
     visitor.on_attribute("rotary_ndims", m_config.rotary_ndims);
+    visitor.on_attribute("is_chatglm", m_config.is_chatglm);
+    visitor.on_attribute("head_cnt", m_config.head_cnt);
+    visitor.on_attribute("head_size", m_config.head_size);
     visitor.on_attribute("gather_position_arg_id", m_config.gather_position_arg_id);
     visitor.finish_structure();
     return true;
@@ -70,6 +70,9 @@ public:
         bool input_trans0213 = false;  // transpose input dim 1&2
         bool is_interleaved = false;   // interleaved mode, implies trans0213 happens after RoPE
         size_t rotary_ndims = 0;       // dimensions to be embedded (d in the description)
+        bool is_chatglm = false;       // chatglm is special: it overrides the other settings
+        size_t head_cnt = 0;
+        size_t head_size = 0;
         int gather_position_arg_id =
             0;  // arg id of position tensor, ==3 when gather from sin/cos inputs according to position is required
     };
@@ -132,7 +132,17 @@ ov::intel_cpu::RoPEFusionCosSinPreprocess::RoPEFusionCosSinPreprocess() {
                                                               {"ellipsis_mask", {}}});
         auto squeeze = makePattern<opset1::Reshape>({slice_Slice, {-1, head_dims}});
         auto index_Gather = makePattern<opset8::Gather>({squeeze, gather_positions_2d, 0}, {{"batch_dims", 0}});
-        auto unsqueeze = makePattern<opset1::Reshape>({index_Gather, {1, 1, -1, head_dims}});
+
+        // another simplified pattern for gathering at position_ids
+        auto slice_Slice2 = makePattern<opset1::StridedSlice>({const_tab, {0}, seq_len, {1}},
+                                                              {{"begin_mask", {0}},
+                                                               {"end_mask", {0}},
+                                                               {"new_axis_mask", {}},
+                                                               {"shrink_axis_mask", {}},
+                                                               {"ellipsis_mask", {}}});
+        auto index_Gather2 = makePattern<opset8::Gather>({slice_Slice2, gather_positions_2d, 0}, {{"batch_dims", 0}});
+
+        auto unsqueeze = makePattern<opset1::Reshape>({index_Gather | index_Gather2, {1, 1, -1, head_dims}});
         return unsqueeze;
     };
 
@@ -439,6 +449,131 @@ ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {
         return true;
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(result, matcher_name);
     this->register_matcher(m, callback);
 }
+
+ov::intel_cpu::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id) {
+    MATCHER_SCOPE(RoPEFusionChatGLM);
+
+    auto qkv_linear = makePattern("f32[?,?,?]");       // f32[seq_length, batch_size, 4608]
+    auto seq_length = makePattern("i32[1]");
+    auto cos_sin_cache = makePattern("f32[?,?,?,?]");  // [max_pos_embeddings, batch_size, 32, 2]
+
+    auto ndims = Symbol("ndims");
+    auto head_cnt = Symbol("head_cnt");
+    auto head_size = Symbol("head_size");
+    auto total_size_q = Symbol("total_size_q");
+    auto total_size_k = Symbol("total_size_k");
+    auto total_size_v = Symbol("total_size_v");
+
+    auto qkv_proj = makePattern<opset1::VariadicSplit>({qkv_linear, -1, {total_size_q, total_size_k, total_size_v}});
+    qkv_proj->set_output_size(3);
+
+    // get key [L, B, Hkv, S]
+    auto cur_key = makePattern<opset1::Reshape>({qkv_proj->output(split_output_id), {0, 0, head_cnt, head_size}},
+                                                {{"special_zero", true}});
+
+    auto slice_Slice_437 = makePattern<opset1::StridedSlice>({cur_key, {0, 0, 0, 0}, {0, 0, 0, ndims}, {1, 1, 1, 1}},
+                                                             {{"begin_mask", {1, 1, 1, 0}},
+                                                              {"end_mask", {1, 1, 1, 0}},
+                                                              {"new_axis_mask", {}},
+                                                              {"shrink_axis_mask", {}},
+                                                              {"ellipsis_mask", {}}});
+
+    // rotate half
+    auto ListConstruct_452_Concat =
+        makePattern<opset1::Concat>({seq_length, {-1}, {head_cnt}, {ndims / 2}, {2}}, {{"axis", 0}});
+    auto ListConstruct_379_Concat =
+        makePattern<opset1::Concat>({seq_length, {-1}, {1}, {ndims / 2}, {2}}, {{"axis", 0}});
+
+    auto reshape_Reshape_453 =
+        makePattern<opset1::Reshape>({slice_Slice_437, ListConstruct_452_Concat}, {{"special_zero", false}});
+    auto x_even = makePattern<opset8::Gather>({reshape_Reshape_453, 0, -1}, {{"batch_dims", 0}});
+    auto slice_Slice_449 = makePattern<opset1::StridedSlice>({cos_sin_cache, {0}, seq_length, {1}},
+                                                             {{"begin_mask", {0}},
+                                                              {"end_mask", {0}},
+                                                              {"new_axis_mask", {}},
+                                                              {"shrink_axis_mask", {}},
+                                                              {"ellipsis_mask", {}}});
+    auto view_Reshape_460 =
+        makePattern<opset1::Reshape>({slice_Slice_449, ListConstruct_379_Concat}, {{"special_zero", false}});
+    auto cos_tab = makePattern<opset8::Gather>({view_Reshape_460, 0, -1}, {{"batch_dims", 0}});
+    auto x_even_cos = makePattern<opset1::Multiply>({x_even, cos_tab}, {{"auto_broadcast", "numpy"}});
+    auto x_odd = makePattern<opset8::Gather>({reshape_Reshape_453, 1, -1}, {{"batch_dims", 0}});
+    auto sin_tab = makePattern<opset8::Gather>({view_Reshape_460, 1, -1}, {{"batch_dims", 0}});
+    auto x_odd_sin = makePattern<opset1::Multiply>({x_odd, sin_tab}, {{"auto_broadcast", "numpy"}});
+    auto neg_x_odd_sin = makePattern<opset1::Multiply>({x_odd_sin, -1.000000f}, {{"auto_broadcast", "numpy"}});
+    auto sub_Subtract_469 = makePattern<opset1::Add>({x_even_cos, neg_x_odd_sin}, {{"auto_broadcast", "numpy"}});
+
+    auto y_even = makePattern<opset1::Unsqueeze>({sub_Subtract_469, -1});
+    auto x_odd_cos = makePattern<opset1::Multiply>({x_odd, cos_tab}, {{"auto_broadcast", "numpy"}});
+    auto x_even_sin = makePattern<opset1::Multiply>({x_even, sin_tab}, {{"auto_broadcast", "numpy"}});
+    auto add_Add_476 = makePattern<opset1::Add>({x_odd_cos, x_even_sin}, {{"auto_broadcast", "numpy"}});
+    auto y_odd = makePattern<opset1::Unsqueeze>({add_Add_476, -1});
+
+    auto stack_481 = makePattern<opset1::Concat>({y_even, y_odd}, {{"axis", -1}});
+
+    auto ShapeOf_135133 = makePattern<opset1::ShapeOf>({stack_481});
+    auto flatten_Slice_497 = makePattern<opset1::StridedSlice>({ShapeOf_135133, {0}, {3}, {1}},
+                                                               {{"begin_mask", {0}},
+                                                                {"end_mask", {0}},
+                                                                {"new_axis_mask", {}},
+                                                                {"shrink_axis_mask", {}},
+                                                                {"ellipsis_mask", {}}});
+    auto flatten_Concat_500 = makePattern<opset1::Concat>({flatten_Slice_497, {-1}}, {{"axis", 0}});
+    auto const_target_shape = makeConst({0, 0, head_cnt, ndims});
+    // [length, batch, head_cnt, half_rotary_dims, 2]
+    auto flatten_Reshape_501 =
+        makePattern<opset1::Reshape>({stack_481, flatten_Concat_500 | const_target_shape}, {{"special_zero", true}});
+    auto slice_Slice_443 =
+        makePattern<opset1::StridedSlice>({cur_key, {0, 0, 0, ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
+                                          {{"begin_mask", {1, 1, 1, 0}},
+                                           {"end_mask", {1, 1, 1, 0}},
+                                           {"new_axis_mask", {}},
+                                           {"shrink_axis_mask", {}},
+                                           {"ellipsis_mask", {}}});
+    auto cat_Concat_505 = makePattern<opset1::Concat>({flatten_Reshape_501, slice_Slice_443}, {{"axis", -1}});
+
+    auto result = cat_Concat_505;
+
+    matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
+        const auto& pattern_map = m.get_pattern_value_map();
+        auto root = m.get_match_root();
+        PatternValidator validator(m);
+        if (!validator) {
+            return false;
+        }
+
+        RoPENode::Config config;
+        OutputVector new_args;
+        config.rotary_ndims = validator["ndims"];
+        config.is_chatglm = true;
+        config.head_cnt = validator["head_cnt"];
+        config.head_size = validator["head_size"];
+
+        if (split_output_id == 0) {
+            // query : split_output_id == 0
+            config.slice_start = 0;
+            config.slice_stop = validator["total_size_q"];
+        } else {
+            // key : split_output_id == 1
+            config.slice_start = validator["total_size_q"];
+            config.slice_stop = config.slice_start + validator["total_size_k"];
+        }
+
+        new_args.push_back(pattern_map.at(qkv_linear));
+        new_args.push_back(pattern_map.at(cos_sin_cache));
+        new_args.push_back(pattern_map.at(cos_sin_cache));
+
+        auto old_node = root;
+
+        auto new_node = std::make_shared<RoPENode>(new_args, config);
+        new_node->set_friendly_name(old_node->get_friendly_name());
+        ov::replace_node(old_node, new_node);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(result, matcher_name);
+    this->register_matcher(m, callback);
+}
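A note on the argument list in the callback above: the matched ChatGLM graph carries a single combined cos/sin cache, and cos_sin_cache is pushed into new_args twice, presumably to preserve the RoPE node's usual three-input (data, cos, sin) layout; the ChatGLM executor only reads inputs[0] and inputs[1].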
@@ -20,7 +20,11 @@ public:
     OPENVINO_RTTI("RoPEFusionGPTJ", "0");
     RoPEFusionGPTJ();
 };
-
+class RoPEFusionChatGLM : public ngraph::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("RoPEFusionChatGLM", "0");
+    RoPEFusionChatGLM(int split_output_id);
+};
 class RoPEFusionIOSlicing : public ngraph::pass::MatcherPass {
 public:
     OPENVINO_RTTI("RoPEFusionIOSlicing", "0");
@@ -56,6 +60,9 @@ public:
         add_matcher<RoPEFusionCosSinPreprocess>();
         add_matcher<RoPEFusionIOSlicing>();
         add_matcher<RoPEFusionPreprocess>();
+
+        add_matcher<RoPEFusionChatGLM>(0);
+        add_matcher<RoPEFusionChatGLM>(1);
     }
 };
 
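The pass is registered twice because RoPEFusionChatGLM matches one VariadicSplit output at a time: split_output_id == 0 fuses the RoPE applied to the query slice and split_output_id == 1 the key slice, mirroring the slice_start/slice_stop selection in the matcher callback above.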
@@ -357,7 +357,7 @@ struct AttrAny {
         return any.as<std::vector<T>>();
     }
 
-    template<typename T>
+    template <typename T>
     std::vector<T> as_T_vector() {
         if (any.empty())
             return {};
@@ -574,6 +574,10 @@ public:
     bool match_value(ov::pass::pattern::Matcher* matcher,
                      const Output<Node>& pattern_value,
                      const Output<Node>& graph_value) override {
+        // strictly require pattern & graph value to come from the output port with the same index;
+        // this is absolutely necessary when the pattern contains split node connections.
+        if (pattern_value.get_index() != graph_value.get_index())
+            return false;
         if (m_predicate(graph_value)) {
             auto& pattern_map = matcher->get_pattern_value_map();
             pattern_map[shared_from_this()] = graph_value;
@@ -858,7 +862,7 @@ public:
         } else if (auto a = ov::as_type<ov::AttributeAdapter<ov::CoordinateDiff>>(&adapter)) {
             is_matched = m_attr_map[name].equal_to<int64_t, int>(a->get());
         } else {
-            OPENVINO_THROW("AttrSetter met unsupported AttributeAdapter");
+            OPENVINO_THROW("AttrMatcher met unsupported AttributeAdapter ", name);
         }
         add_match_result(name, is_matched);
     }
@@ -919,6 +923,13 @@ std::shared_ptr<Node> makeConst(const ov::element::Type& type,
     return std::make_shared<ov::op::v0::Constant>(type, shape, std::vector<T>(values));
 }
 
+inline std::shared_ptr<Node> makeConst(const std::vector<Symbol>& v) {
+    auto node = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
+    auto& rt_info = node->get_rt_info();
+    rt_info["symbolic_const_value"] = v;
+    return node;
+}
+
 template <typename T>
 std::shared_ptr<Node> makeConst(const ov::element::Type& type, const ov::Shape& shape, const std::vector<T>& values) {
     return std::make_shared<ov::op::v0::Constant>(type, shape, values);
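This symbolic makeConst overload wraps a Constant pattern and stashes the expected (possibly symbolic) values in rt_info, to be checked later by PatternValidator. A hedged usage sketch, mirroring the const_target_shape line in RoPEFusionChatGLM above (some_input is a placeholder, and the gen_pattern.hpp helpers are assumed to be in scope):

    auto head_cnt = Symbol("head_cnt");
    auto ndims = Symbol("ndims");
    // Matches a Reshape whose target-shape Constant must equal
    // {0, 0, head_cnt, ndims} under some consistent symbol binding.
    auto target_shape = makeConst({0, 0, head_cnt, ndims});
    auto reshaped = makePattern<opset1::Reshape>({some_input, target_shape}, {{"special_zero", true}});
    // After a match, PatternValidator resolves the bindings:
    //   PatternValidator validator(matcher);
    //   auto head_cnt_value = validator["head_cnt"];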
@@ -1198,7 +1209,7 @@ public:
             auto byte_size =
                 shape_size(vconst_node->get_output_shape(0)) * vconst_node->get_output_element_type(0).size();
             if (std::memcmp(pconst_node->get_data_ptr(), vconst_node->get_data_ptr(), byte_size) != 0) {
-                _VERBOSE_LOG("Constant value mismatch.");
+                _VERBOSE_LOG("Constant value mismatch on ", pconst_node, " vs ", vconst_node);
                 return false;
             }
             continue;
@@ -319,6 +319,10 @@ void dump_cpp_style(std::ostream& os, const std::shared_ptr<ov::Model>& model) {
     const auto& type_info = op->get_type_info();
     auto version_info = std::string(type_info.get_version());
     auto type = version_info + "::" + type_info.name;
+    auto& rt_info = op->get_rt_info();
+    if (rt_info.count("opset") && rt_info["opset"] == "type_relaxed_opset") {
+        type = std::string("ov::op::TypeRelaxed<") + type + ">";
+    }
     auto name = opname[op.get()];
     os << prefix << " ";
 
@@ -131,7 +131,7 @@ static std::shared_ptr<ov::Model> buildROPE_Llama2(const int batch,
 
 namespace CPULayerTestsDefinitions {
 
-class RoPECPUTest : public SubgraphBaseTest {
+class RoPECPUTestLlama2 : public SubgraphBaseTest {
 public:
     ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1) {
         auto tensor = ov::Tensor(ov::element::i32, shape);
@@ -177,8 +177,155 @@ protected:
     }
 };
 
-TEST_F(RoPECPUTest, smoke_CompareWithRefs) {
+TEST_F(RoPECPUTestLlama2, smoke_CompareWithRefs) {
     run();
     CheckNumberOfNodesWithType(compiledModel, "RoPE", 1);
 }
+
+class RoPECPUTestChatGLM : public SubgraphBaseTest {
+public:
+    ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1) {
+        auto tensor = ov::Tensor(ov::element::i32, shape);
+        auto* ptr = static_cast<int32_t*>(tensor.data());
+        for (size_t i = 0; i < tensor.get_size(); i++) {
+            ptr[i] = start;
+            start += step;
+        }
+        return tensor;
+    }
+
+    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
+        const auto& funcInputs = function->inputs();
+
+        auto& input_shape = targetInputStaticShapes[0];
+        auto seq_length = input_shape[0];
+        // auto batch = input_shape[1];
+
+        ov::Tensor t_input =
+            utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768);
+        ov::Tensor t_cos_sin_cache =
+            utils::create_and_fill_tensor(funcInputs[1].get_element_type(), {32768, 32, 2}, 2, -1.0f, 32768);
+        ov::Tensor t_position_ids = create_i32_tensor(ov::Shape({1, seq_length}), 15);
+
+        inputs.clear();
+        inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input});
+        inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos_sin_cache});
+        inputs.insert({funcInputs[2].get_node_shared_ptr(), t_position_ids});
+    }
+
+protected:
+    std::shared_ptr<ov::Model> buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims) {
+        auto input =
+            std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{-1, batch, 4096 + 256 + 256});
+        auto cos_sin_cache = std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{32768, 32, 2});
+        auto position_ids = std::make_shared<ov::opset1::Parameter>(ov::element::i32, PartialShape{-1, -1});
+
+        auto __module_transformer_index_67_Gather =
+            makeOP<opset8::Gather>({cos_sin_cache, position_ids, 0}, {{"batch_dims", 0}});
+        auto __module_transformer_transpose_Transpose =
+            makeOP<opset1::Transpose>({__module_transformer_index_67_Gather, {1, 0, 2, 3}});
+        auto size_ShapeOf_110 =
+            makeOP<opset3::ShapeOf>({__module_transformer_transpose_Transpose}, {{"output_type", "i32"}});
+        auto __getitem___Gather = makeOP<opset8::Gather>({size_ShapeOf_110, -2, 0}, {{"batch_dims", 0}});
+        auto mul_Multiply = makeOP<opset1::Multiply>({__getitem___Gather, 2}, {{"auto_broadcast", "numpy"}});
+        auto slice_Unsqueeze_112 = makeOP<opset1::Unsqueeze>({mul_Multiply, 0});
+
+        auto floordiv_Divide =
+            makeOP<opset1::Divide>({mul_Multiply, 2}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}});
+        auto floordiv_Floor = makeOP<opset1::Floor>({floordiv_Divide});
+        auto ListConstruct_126_Reshape_2 = makeOP<opset1::Reshape>({floordiv_Floor, {-1}}, {{"special_zero", false}});
+
+        auto ListUnpack_321 = makeOP<opset1::VariadicSplit>({input, -1, {4096, 256, 256}});
+        auto view_Reshape =
+            makeOP<opset1::Reshape>({ListUnpack_321->output(0), {0, 0, 32, 128}}, {{"special_zero", true}});
+
+        auto ScatterUpdate_229053 = makeOP<opset3::ScatterUpdate>({{0, 0, 0, 0}, {3}, slice_Unsqueeze_112, {0}});
+        auto slice_Slice_357 =
+            makeOP<opset1::StridedSlice>({view_Reshape, {0, 0, 0, 0}, ScatterUpdate_229053, {1, 1, 1, 1}},
+                                         {{"begin_mask", {1, 1, 1, 0}},
+                                          {"end_mask", {1, 1, 1, 0}},
+                                          {"new_axis_mask", {}},
+                                          {"shrink_axis_mask", {}},
+                                          {"ellipsis_mask", {}}});
+        auto size_ShapeOf_346 = makeOP<opset3::ShapeOf>({view_Reshape}, {{"output_type", "i32"}});
+        auto size_Gather_348 = makeOP<opset8::Gather>({size_ShapeOf_346, 0, 0}, {{"batch_dims", 0}});
+        auto ListConstruct_372_Reshape = makeOP<opset1::Reshape>({size_Gather_348, {-1}}, {{"special_zero", false}});
+        auto size_Gather_351 = makeOP<opset8::Gather>({size_ShapeOf_346, {2}, 0}, {{"batch_dims", 0}});
+        auto ListConstruct_372_Concat =
+            makeOP<opset1::Concat>({ListConstruct_372_Reshape, {-1}, size_Gather_351, ListConstruct_126_Reshape_2, {2}},
+                                   {{"axis", 0}});
+        auto reshape_Reshape_373 =
+            makeOP<opset1::Reshape>({slice_Slice_357, ListConstruct_372_Concat}, {{"special_zero", false}});
+        auto select_Gather_381 = makeOP<opset8::Gather>({reshape_Reshape_373, 0, -1}, {{"batch_dims", 0}});
+        auto slice_Unsqueeze_367 = makeOP<opset1::Unsqueeze>({size_Gather_348, 0});
+        auto slice_Slice_369 =
+            makeOP<opset1::StridedSlice>({__module_transformer_transpose_Transpose, {0}, slice_Unsqueeze_367, {1}},
+                                         {{"begin_mask", {0}},
+                                          {"end_mask", {0}},
+                                          {"new_axis_mask", {}},
+                                          {"shrink_axis_mask", {}},
+                                          {"ellipsis_mask", {}}});
+        auto size_ShapeOf_374 = makeOP<opset3::ShapeOf>({reshape_Reshape_373}, {{"output_type", "i32"}});
+        auto size_Gather_376 = makeOP<opset8::Gather>({size_ShapeOf_374, {3}, 0}, {{"batch_dims", 0}});
+        auto ListConstruct_379_Concat =
+            makeOP<opset1::Concat>({ListConstruct_372_Reshape, {-1}, {1}, size_Gather_376, {2}}, {{"axis", 0}});
+        auto view_Reshape_380 =
+            makeOP<opset1::Reshape>({slice_Slice_369, ListConstruct_379_Concat}, {{"special_zero", false}});
+        auto select_Gather_382 = makeOP<opset8::Gather>({view_Reshape_380, 0, -1}, {{"batch_dims", 0}});
+        auto mul_Multiply_383 =
+            makeOP<opset1::Multiply>({select_Gather_381, select_Gather_382}, {{"auto_broadcast", "numpy"}});
+        auto select_Gather_384 = makeOP<opset8::Gather>({reshape_Reshape_373, 1, -1}, {{"batch_dims", 0}});
+        auto select_Gather_385 = makeOP<opset8::Gather>({view_Reshape_380, 1, -1}, {{"batch_dims", 0}});
+        auto mul_Multiply_386 =
+            makeOP<opset1::Multiply>({select_Gather_384, select_Gather_385}, {{"auto_broadcast", "numpy"}});
+        auto sub_Subtract_389 =
+            makeOP<opset1::Subtract>({mul_Multiply_383, mul_Multiply_386}, {{"auto_broadcast", "numpy"}});
+        auto Unsqueeze_62716 = makeOP<opset1::Unsqueeze>({sub_Subtract_389, -1});
+        auto mul_Multiply_391 =
+            makeOP<opset1::Multiply>({select_Gather_384, select_Gather_382}, {{"auto_broadcast", "numpy"}});
+        auto mul_Multiply_393 =
+            makeOP<opset1::Multiply>({select_Gather_381, select_Gather_385}, {{"auto_broadcast", "numpy"}});
+        auto add_Add_396 = makeOP<opset1::Add>({mul_Multiply_391, mul_Multiply_393}, {{"auto_broadcast", "numpy"}});
+        auto Unsqueeze_62717 = makeOP<opset1::Unsqueeze>({add_Add_396, -1});
+        auto stack_401 = makeOP<opset1::Concat>({Unsqueeze_62716, Unsqueeze_62717}, {{"axis", -1}});
+        auto flatten_ShapeOf_402 = makeOP<opset3::ShapeOf>({stack_401}, {{"output_type", "i32"}});
+        auto flatten_Slice_417 = makeOP<opset1::StridedSlice>({flatten_ShapeOf_402, {0}, {3}, {1}},
+                                                              {{"begin_mask", {0}},
+                                                               {"end_mask", {0}},
+                                                               {"new_axis_mask", {}},
+                                                               {"shrink_axis_mask", {}},
+                                                               {"ellipsis_mask", {}}});
+        auto flatten_Concat_420 = makeOP<opset1::Concat>({flatten_Slice_417, {-1}}, {{"axis", 0}});
+        auto flatten_Reshape_421 = makeOP<opset1::Reshape>({stack_401, flatten_Concat_420}, {{"special_zero", true}});
+        auto ScatterUpdate_229067 = makeOP<opset3::ScatterUpdate>({{0, 0, 0, 0}, {3}, slice_Unsqueeze_112, {0}});
+        auto slice_Slice_363 =
+            makeOP<opset1::StridedSlice>({view_Reshape, ScatterUpdate_229067, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
+                                         {{"begin_mask", {1, 1, 1, 0}},
+                                          {"end_mask", {1, 1, 1, 0}},
+                                          {"new_axis_mask", {}},
+                                          {"shrink_axis_mask", {}},
+                                          {"ellipsis_mask", {}}});
+        auto cat_Concat_425 = makeOP<opset1::Concat>({flatten_Reshape_421, slice_Slice_363}, {{"axis", -1}});
+        return std::make_shared<ov::Model>(ov::NodeVector{cat_Concat_425},
+                                           ov::ParameterVector{input, cos_sin_cache, position_ids});
+    }
+    void SetUp() override {
+        targetDevice = ov::test::utils::DEVICE_CPU;
+
+        const int batch = 2;
+        const int seq_length = 7;
+        const int num_head = 32;
+        const int rotary_dims = 64;
+
+        InputShape inpShape = {{-1, batch, 4096 + 256 + 256}, {{seq_length, batch, 4096 + 256 + 256}}};
+        init_input_shapes({inpShape});
+        function = buildROPE_ChatGLM(batch, num_head, rotary_dims);
+    }
+};
+
+TEST_F(RoPECPUTestChatGLM, smoke_CompareWithRefs) {
+    run();
+    CheckNumberOfNodesWithType(compiledModel, "RoPE", 1);
+}
 
 }  // namespace CPULayerTestsDefinitions
@@ -139,6 +139,9 @@ TEST_F(TransformationTestsF, ConvertToROPE_LLama2_no_gather) {
                                      {"config.slice_stop", 0},
                                      {"config.input_trans0213", true},
                                      {"config.is_interleaved", false},
+                                     {"config.is_chatglm", false},
+                                     {"config.head_cnt", 0},
+                                     {"config.head_size", 0},
                                      {"config.rotary_ndims", static_cast<int>(ndims)},
                                      {"config.gather_position_arg_id", 0}});
 
@@ -170,6 +173,9 @@ TEST_F(TransformationTestsF, ConvertToROPE_LLama2_with_gather) {
                                      {"config.slice_stop", 0},
                                      {"config.input_trans0213", true},
                                      {"config.is_interleaved", false},
+                                     {"config.is_chatglm", false},
+                                     {"config.head_cnt", 0},
+                                     {"config.head_size", 0},
                                      {"config.rotary_ndims", static_cast<int>(ndims)},
                                      {"config.gather_position_arg_id", 3}});
 
@@ -304,6 +310,9 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_no_gather) {
                                      {"config.slice_stop", ndims},
                                      {"config.input_trans0213", true},
                                      {"config.is_interleaved", false},
+                                     {"config.is_chatglm", false},
+                                     {"config.head_cnt", 0},
+                                     {"config.head_size", 0},
                                      {"config.rotary_ndims", rotary_ndims},
                                      {"config.gather_position_arg_id", 0}});
     model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, param_cos, param_sin});
@@ -334,6 +343,9 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_with_gather) {
                                      {"config.slice_stop", ndims},
                                      {"config.input_trans0213", true},
                                      {"config.is_interleaved", false},
+                                     {"config.is_chatglm", false},
+                                     {"config.head_cnt", 0},
+                                     {"config.head_size", 0},
                                      {"config.rotary_ndims", rotary_ndims},
                                      {"config.gather_position_arg_id", 3}});
     model_ref =
@@ -445,8 +457,119 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTJ) {
                                      {"config.slice_stop", 0},
                                      {"config.input_trans0213", false},
                                      {"config.is_interleaved", true},
+                                     {"config.is_chatglm", false},
+                                     {"config.head_cnt", 0},
+                                     {"config.head_size", 0},
                                      {"config.rotary_ndims", rotary_ndims},
                                      {"config.gather_position_arg_id", 0}});
         model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, cos_sin});
     }
 }
+
+TEST_F(TransformationTestsF, ConvertToROPE_chatGML) {
+    disable_rt_info_check();
+    const int batch = 2;
+    const int seq_len = 7;
+    const int num_heads = 32;
+    const int ndims = 128;
+    const int rotary_ndims = 64;
+    const int max_pos_length = 2048;
+    {
+        auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{seq_len, batch, 4608});
+        auto seq_length = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{1});
+        auto cos_sin_cache =
+            std::make_shared<ov::opset1::Parameter>(ov::element::f32,
+                                                    ov::Shape{max_pos_length, batch, rotary_ndims / 2, 2});
+
+        auto ListUnpack_321 = makeOP<ov::opset1::VariadicSplit>({input, -1, {4096, 256, 256}});
+        auto view_Reshape = makeOP<ov::opset1::Reshape>({ListUnpack_321->output(0), {0, 0, num_heads, ndims}},
+                                                        {{"special_zero", true}});
+        auto aten_slice_Slice_357 =
+            makeOP<ov::opset1::StridedSlice>({view_Reshape, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims}, {1, 1, 1, 1}},
+                                             {{"begin_mask", {1, 1, 1, 0}},
+                                              {"end_mask", {1, 1, 1, 0}},
+                                              {"new_axis_mask", {}},
+                                              {"shrink_axis_mask", {}},
+                                              {"ellipsis_mask", {}}});
+        auto ListConstruct_372_Concat =
+            makeOP<ov::opset1::Concat>({seq_length, {-1}, {num_heads}, {rotary_ndims / 2}, {2}}, {{"axis", 0}});
+        auto aten_reshape_Reshape_373 =
+            makeOP<ov::opset1::Reshape>({aten_slice_Slice_357, ListConstruct_372_Concat}, {{"special_zero", false}});
+        auto aten_select_Gather_381 =
+            makeOP<ov::opset8::Gather>({aten_reshape_Reshape_373, 0, -1}, {{"batch_dims", 0}});
+        auto aten_slice_Slice_369 = makeOP<ov::opset1::StridedSlice>({cos_sin_cache, {0}, seq_length, {1}},
+                                                                     {{"begin_mask", {0}},
+                                                                      {"end_mask", {0}},
+                                                                      {"new_axis_mask", {}},
+                                                                      {"shrink_axis_mask", {}},
+                                                                      {"ellipsis_mask", {}}});
+        auto ListConstruct_379_Concat =
+            makeOP<ov::opset1::Concat>({seq_length, {-1}, {1}, {rotary_ndims / 2}, {2}}, {{"axis", 0}});
+        auto aten_view_Reshape_380 =
+            makeOP<ov::opset1::Reshape>({aten_slice_Slice_369, ListConstruct_379_Concat}, {{"special_zero", false}});
+        auto aten_select_Gather_382 = makeOP<ov::opset8::Gather>({aten_view_Reshape_380, 0, -1}, {{"batch_dims", 0}});
+        auto aten_mul_Multiply_383 = makeOP<ov::opset1::Multiply>({aten_select_Gather_381, aten_select_Gather_382},
+                                                                  {{"auto_broadcast", "numpy"}});
+        auto aten_select_Gather_384 =
+            makeOP<ov::opset8::Gather>({aten_reshape_Reshape_373, 1, -1}, {{"batch_dims", 0}});
+        auto aten_select_Gather_385 = makeOP<ov::opset8::Gather>({aten_view_Reshape_380, 1, -1}, {{"batch_dims", 0}});
+        auto aten_mul_Multiply_386 = makeOP<ov::opset1::Multiply>({aten_select_Gather_384, aten_select_Gather_385},
+                                                                  {{"auto_broadcast", "numpy"}});
+        auto Multiply_101315 =
+            makeOP<ov::opset1::Multiply>({aten_mul_Multiply_386, -1.000000f}, {{"auto_broadcast", "numpy"}});
+        auto aten_sub_Subtract_389 =
+            makeOP<ov::opset1::Add>({aten_mul_Multiply_383, Multiply_101315}, {{"auto_broadcast", "numpy"}});
+        auto Unsqueeze_62716 = makeOP<ov::opset1::Unsqueeze>({aten_sub_Subtract_389, -1});
+        auto aten_mul_Multiply_391 = makeOP<ov::opset1::Multiply>({aten_select_Gather_384, aten_select_Gather_382},
+                                                                  {{"auto_broadcast", "numpy"}});
+        auto aten_mul_Multiply_393 = makeOP<ov::opset1::Multiply>({aten_select_Gather_381, aten_select_Gather_385},
+                                                                  {{"auto_broadcast", "numpy"}});
+        auto aten_add_Add_396 =
+            makeOP<ov::opset1::Add>({aten_mul_Multiply_391, aten_mul_Multiply_393}, {{"auto_broadcast", "numpy"}});
+        auto Unsqueeze_62717 = makeOP<ov::opset1::Unsqueeze>({aten_add_Add_396, -1});
+        auto aten_stack_401 = makeOP<ov::opset1::Concat>({Unsqueeze_62716, Unsqueeze_62717}, {{"axis", -1}});
+        auto ShapeOf_134820 = makeOP<ov::op::TypeRelaxed<ov::opset1::ShapeOf>>(
+            {aten_stack_401},
+            {{"type_relax", true}, {"input_data_types", {}}, {"output_data_types", {ov::element::i32}}});
+        auto aten_flatten_Slice_417 = makeOP<ov::opset1::StridedSlice>({ShapeOf_134820, {0}, {3}, {1}},
+                                                                       {{"begin_mask", {0}},
+                                                                        {"end_mask", {0}},
+                                                                        {"new_axis_mask", {}},
+                                                                        {"shrink_axis_mask", {}},
+                                                                        {"ellipsis_mask", {}}});
+        auto aten_flatten_Concat_420 = makeOP<ov::opset1::Concat>({aten_flatten_Slice_417, {-1}}, {{"axis", 0}});
+        auto aten_flatten_Reshape_421 =
+            makeOP<ov::opset1::Reshape>({aten_stack_401, aten_flatten_Concat_420}, {{"special_zero", true}});
+        auto aten_slice_Slice_363 =
+            makeOP<ov::opset1::StridedSlice>({view_Reshape, {0, 0, 0, rotary_ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
+                                             {{"begin_mask", {1, 1, 1, 0}},
+                                              {"end_mask", {1, 1, 1, 0}},
+                                              {"new_axis_mask", {}},
+                                              {"shrink_axis_mask", {}},
+                                              {"ellipsis_mask", {}}});
+        auto aten_cat_Concat_425 =
+            makeOP<ov::opset1::Concat>({aten_flatten_Reshape_421, aten_slice_Slice_363}, {{"axis", -1}});
+        model = std::make_shared<ov::Model>(ov::NodeVector{aten_cat_Concat_425},
+                                            ov::ParameterVector{input, seq_length, cos_sin_cache});
+    }
+    manager.register_pass<RoPEFusion>();
+    {
+        auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{seq_len, batch, 4608});
+        auto seq_length = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{1});
+        auto cos_sin_cache =
+            std::make_shared<ov::opset1::Parameter>(ov::element::f32,
+                                                    ov::Shape{max_pos_length, batch, rotary_ndims / 2, 2});
+        auto rope = makeOP<RoPENode>({input, cos_sin_cache, cos_sin_cache},
+                                     {{"config.slice_start", 0},
+                                      {"config.slice_stop", 4096},
+                                      {"config.input_trans0213", false},
+                                      {"config.is_interleaved", false},
+                                      {"config.rotary_ndims", rotary_ndims},
+                                      {"config.is_chatglm", true},
+                                      {"config.head_cnt", num_heads},
+                                      {"config.head_size", ndims},
+                                      {"config.gather_position_arg_id", 0}});
+        model_ref =
+            std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, seq_length, cos_sin_cache});
+    }
+}