diff --git a/src/plugins/intel_cpu/src/cpu_types.cpp b/src/plugins/intel_cpu/src/cpu_types.cpp
index a4e8b140415..aa75c202c11 100644
--- a/src/plugins/intel_cpu/src/cpu_types.cpp
+++ b/src/plugins/intel_cpu/src/cpu_types.cpp
@@ -214,7 +214,6 @@ static const TypeToNameMap& get_type_to_name_tbl() {
         { "Unique", Type::Unique},
         { "Ngram", Type::Ngram},
         { "ScaledDotProductAttention", Type::ScaledDotProductAttention},
-        { "RoPE", Type::RoPE},
     };
     return type_to_name_tbl;
 }
@@ -329,7 +328,6 @@ std::string NameFromType(const Type type) {
         CASE(Unique);
         CASE(Ngram);
         CASE(ScaledDotProductAttention);
-        CASE(RoPE);
         CASE(Unknown);
     }
 #undef CASE
diff --git a/src/plugins/intel_cpu/src/cpu_types.h b/src/plugins/intel_cpu/src/cpu_types.h
index cf214542b1b..f7f40d2c1fc 100644
--- a/src/plugins/intel_cpu/src/cpu_types.h
+++ b/src/plugins/intel_cpu/src/cpu_types.h
@@ -114,7 +114,6 @@ enum class Type {
     Unique,
     Ngram,
     ScaledDotProductAttention,
-    RoPE,
 };
 
 enum class Algorithm {
diff --git a/src/plugins/intel_cpu/src/nodes/rope.cpp b/src/plugins/intel_cpu/src/nodes/rope.cpp
deleted file mode 100644
index 5ec1aaa2183..00000000000
--- a/src/plugins/intel_cpu/src/nodes/rope.cpp
+++ /dev/null
@@ -1,201 +0,0 @@
-// Copyright (C) 2018-2023 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#include "rope.h"
-
-#include <algorithm>
-#include <cmath>
-#include <limits>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "common/bfloat16.hpp"
-#include "common/cpu_memcpy.h"
-#include "utils/plain_tensor.hpp"
-
-using namespace InferenceEngine;
-
-namespace ov {
-namespace intel_cpu {
-namespace node {
-
-RoPE::RoPE(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
-    : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
-    std::string errorMessage;
-    if (!isSupportedOperation(op, errorMessage)) {
-        OPENVINO_THROW("CPU: " + errorMessage);
-    }
-
-    const auto node = std::dynamic_pointer_cast<RoPENode>(op);
-    m_config = node->get_config();
-}
-
-template <typename T>
-struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor {
-    void execute(dnnl::stream strm,
-                 const RoPENode::Config& config,
-                 const std::vector<MemoryPtr>& inputs,
-                 const std::vector<MemoryPtr>& outputs) override {
-        ov::intel_cpu::PlainTensor<T> t_src(inputs[0]);
-        ov::intel_cpu::PlainTensor<float> t_cos(inputs[1]);
-        ov::intel_cpu::PlainTensor<float> t_sin(inputs[2]);
-        ov::intel_cpu::PlainTensor<T> t_dst(outputs[0]);
-        ov::intel_cpu::PlainTensor<int32_t> gather;
-
-        if (config.slice_stop - config.slice_start > 0) {
-            t_src = t_src.slice(3, config.slice_start, config.slice_stop);
-        }
-        if (config.input_trans0213) {
-            t_src = t_src.permute({0, 2, 1, 3});
-        }
-        if (config.gather_position_arg_id > 0) {
-            gather.reset(inputs[config.gather_position_arg_id]);
-        }
-
-        auto batch_size = t_src.size(0);
-        auto head_cnt = t_src.size(1);
-        auto seq_len = t_src.size(2);
-        auto feature_size = t_src.size(3);
-
-        auto rotary_dims = config.rotary_ndims;
-        auto half_rotary_dims = rotary_dims / 2;
-
-        parallel_for3d(batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) {
-            auto cos_pos = p;
-            if (gather) {
-                if (gather.m_rank == 4)
-                    cos_pos = gather.at({b, h, p, 0}, true);
-                else
-                    cos_pos = gather.at({b, p}, true);
-            }
-            auto* src = &t_src.at({b, h, p, 0});
-            auto* cos = &t_cos.at({b, h, cos_pos, 0}, true);
-            auto* sin = &t_sin.at({b, h, cos_pos, 0}, true);
-            auto* dst = &t_dst.at({b, h, p, 0});
-
-            size_t i = 0;
-            // rotate-half: the first half is paired with the (negated) second half
-            for (; i < half_rotary_dims; i++) {
-                dst[i] = cos[i] * src[i] + sin[i] * (-src[i + half_rotary_dims]);
-            }
-            for (; i < rotary_dims; i++) {
-                dst[i] = cos[i] * src[i] + sin[i] * (src[i - half_rotary_dims]);
-            }
-            // the non-embedded tail is copied through unchanged
-            for (; i < feature_size; i++) {
-                dst[i] = src[i];
-            }
-        });
-    }
-};
-
-template <typename T>
-struct RoPE::RoPEExecutorInterleaved : public RoPE::Executor {
-    void execute(dnnl::stream strm,
-                 const RoPENode::Config& config,
-                 const std::vector<MemoryPtr>& inputs,
-                 const std::vector<MemoryPtr>& outputs) override {
-        ov::intel_cpu::PlainTensor<T> t_src(inputs[0]);
-        ov::intel_cpu::PlainTensor<float> t_sin_cos(inputs[1]);
-        ov::intel_cpu::PlainTensor<T> t_dst(outputs[0]);
-
-        auto batch_size = t_src.size(0);
-        auto seq_len = t_src.size(1);
-        auto head_cnt = t_src.size(2);
-        auto head_dims = t_src.size(3);
-
-        auto rotary_dims = config.rotary_ndims;
-        auto half_rotary_dims = rotary_dims / 2;
-        parallel_for3d(batch_size, seq_len, head_cnt, [&](size_t b, size_t p, size_t h) {
-            auto* x = &t_src.at({b, p, h, 0});
-            float* sin = &t_sin_cos.at({b, p, 0}, true);
-            float* cos = &t_sin_cos.at({b, p, half_rotary_dims}, true);
-            auto* dst = &t_dst.at({b, h, p, 0});
-
-            size_t i = 0;
-            // interleaved: adjacent element pairs (x[i], x[i+1]) are rotated together
-            for (size_t j = 0; i < rotary_dims; i += 2, j++) {
-                dst[i] = cos[j] * x[i] - sin[j] * x[i + 1];
-                dst[i + 1] = cos[j] * x[i + 1] + sin[j] * x[i];
-            }
-            for (; i < head_dims; i++) {
-                dst[i] = x[i];
-            }
-        });
-    }
-};
-
-void RoPE::initSupportedPrimitiveDescriptors() {
-    if (!supportedPrimitiveDescriptors.empty())
-        return;
-    auto srcPrecision = getOriginalInputPrecisionAtPort(0);
-
-    auto rtPrecision = srcPrecision;
-    auto CosSinPrecision = ov::element::f32;
-
-    if (m_config.is_interleaved) {
-        OPENVINO_ASSERT(m_config.input_trans0213 == false);
-        OPENVINO_ASSERT(m_config.slice_start == 0);
-        OPENVINO_ASSERT(m_config.slice_stop == 0);
-        OPENVINO_ASSERT(m_config.gather_position_arg_id == 0);
-        if (rtPrecision == ov::element::bf16) {
-            m_executor = std::make_shared<RoPEExecutorInterleaved<ov::bfloat16>>();
-        } else {
-            m_executor = std::make_shared<RoPEExecutorInterleaved<float>>();
-            rtPrecision = ov::element::f32;
-        }
-    } else {
-        if (rtPrecision == ov::element::bf16) {
-            m_executor = std::make_shared<RoPEExecutorRotateHalf<ov::bfloat16>>();
-        } else {
-            m_executor = std::make_shared<RoPEExecutorRotateHalf<float>>();
-            rtPrecision = ov::element::f32;
-        }
-    }
-
-    // initialize input ports
-    std::vector<PortConfigurator> inPortConfigs;
-    inPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getInputShapeAtPort(0), false, -1);
-    inPortConfigs.emplace_back(LayoutType::ncsp, CosSinPrecision, getInputShapeAtPort(1), false, -1);
-    inPortConfigs.emplace_back(LayoutType::ncsp, CosSinPrecision, getInputShapeAtPort(2), false, -1);
-    if (m_config.gather_position_arg_id > 0) {
-        inPortConfigs.emplace_back(LayoutType::ncsp,
-                                   ov::element::i32,
-                                   getInputShapeAtPort(m_config.gather_position_arg_id),
-                                   false,
-                                   -1);
-    }
-
-    // initialize output port
-    std::vector<PortConfigurator> outPortConfigs;
-    outPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getOutputShapeAtPort(0), false, -1);
-
-    addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any);
-}
-
-void RoPE::execute(dnnl::stream strm) {
-    std::vector<MemoryPtr> inputs(getParentEdges().size()), outputs(getChildEdges().size());
-    for (size_t i = 0; i < inputs.size(); i++) {
-        inputs[i] = getParentEdgeAt(i)->getMemoryPtr();
-    }
-    for (size_t i = 0; i < outputs.size(); i++) {
-        outputs[i] = getChildEdgeAt(i)->getMemoryPtr();
-    }
-    m_executor->execute(strm, m_config, inputs, outputs);
-}
-
-bool RoPE::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
-    try {
-        const auto node = std::dynamic_pointer_cast<const RoPENode>(op);
-        if (!node) {
-            errorMessage = "Only RoPENode operation is supported";
-            return false;
-        }
-    } catch (...) {
-        return false;
-    }
-    return true;
-}
-
-}  // namespace node
-}  // namespace intel_cpu
-}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/rope.h b/src/plugins/intel_cpu/src/nodes/rope.h
deleted file mode 100644
index c1b2bbda3b3..00000000000
--- a/src/plugins/intel_cpu/src/nodes/rope.h
+++ /dev/null
@@ -1,54 +0,0 @@
-// Copyright (C) 2018-2023 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-#include <ie_common.h>
-#include <node.h>
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "transformations/cpu_opset/common/op/rope.hpp"
-
-namespace ov {
-namespace intel_cpu {
-namespace node {
-
-class RoPE : public Node {
-public:
-    RoPE(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
-
-    void getSupportedDescriptors() override {}
-    bool created() const override {
-        return getType() == Type::RoPE;
-    }
-    bool needPrepareParams() const override {
-        return false;
-    };
-    void executeDynamicImpl(dnnl::stream strm) override {
-        execute(strm);
-    }
-    void initSupportedPrimitiveDescriptors() override;
-    void execute(dnnl::stream strm) override;
-    static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
-
-private:
-    struct Executor {
-        virtual void execute(dnnl::stream strm,
-                             const RoPENode::Config& config,
-                             const std::vector<MemoryPtr>& inputs,
-                             const std::vector<MemoryPtr>& outputs) = 0;
-    };
-    template <typename T>
-    struct RoPEExecutorRotateHalf;
-    template <typename T>
-    struct RoPEExecutorInterleaved;
-    RoPENode::Config m_config;
-    std::shared_ptr<Executor> m_executor;
-};
-
-}  // namespace node
-}  // namespace intel_cpu
-}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes_factory.cpp b/src/plugins/intel_cpu/src/nodes_factory.cpp
index bead297d033..963218314fa 100644
--- a/src/plugins/intel_cpu/src/nodes_factory.cpp
+++ b/src/plugins/intel_cpu/src/nodes_factory.cpp
@@ -94,7 +94,6 @@
 #include "nodes/unique.hpp"
 #include "nodes/ngram.h"
 #include "nodes/scaled_attn.h"
-#include "nodes/rope.h"
 
 namespace ov {
 namespace intel_cpu {
@@ -182,7 +181,6 @@ Node::NodesFactory::NodesFactory()
     INTEL_CPU_NODE(Eye, Type::Eye);
     INTEL_CPU_NODE(Unique, Type::Unique);
    INTEL_CPU_NODE(Ngram, Type::Ngram);
-    INTEL_CPU_NODE(RoPE, Type::RoPE);
     INTEL_CPU_NODE(Interpolate, Type::Interpolate);
     INTEL_CPU_NODE(RandomUniform, Type::RandomUniform);
     INTEL_CPU_NODE(Reduce, Type::Reduce);
diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/rope.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/rope.cpp
deleted file mode 100644
index 8b4461b479e..00000000000
--- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/rope.cpp
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright (C) 2018-2023 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-#include "rope.hpp"
-
-#include <algorithm>
-
-#include "transformations/itt.hpp"
-
-ov::intel_cpu::RoPENode::RoPENode(const OutputVector& args, const Config& cfg) : Op(args), m_config(cfg) {
-    constructor_validate_and_infer_types();
-}
-
-std::shared_ptr<ngraph::Node> ov::intel_cpu::RoPENode::clone_with_new_inputs(
-    const ngraph::OutputVector& new_args) const {
-    INTERNAL_OP_SCOPE(RoPENode_with_new_inputs);
-    check_new_args_count(this, new_args);
-    return std::make_shared<RoPENode>(new_args, m_config);
-}
-
-void ov::intel_cpu::RoPENode::validate_and_infer_types() {
-    INTERNAL_OP_SCOPE(RoPENode_validate_and_infer_types);
-    auto input_pshape = get_input_partial_shape(0);
-    auto input_slice_size = m_config.slice_stop - m_config.slice_start;
-    if (input_slice_size > 0) {
-        input_pshape[3] = input_slice_size;
-    }
-    if (m_config.input_trans0213) {
-        // transpose 0213 ([B,L,H,S] => [B,H,L,S]) happens before RoPE
-        std::swap(input_pshape[2], input_pshape[1]);
-    } else if (m_config.is_interleaved) {
-        // transpose 0213 ([B,L,H,S] => [B,H,L,S]) happens after RoPE
-        std::swap(input_pshape[2], input_pshape[1]);
-    }
-
-    set_output_type(0, get_input_element_type(0), input_pshape);
-}
-
-bool ov::intel_cpu::RoPENode::visit_attributes(ngraph::AttributeVisitor& visitor) {
-    INTERNAL_OP_SCOPE(RoPENode_visit_attributes);
-    visitor.start_structure("config");
-    visitor.on_attribute("slice_start", m_config.slice_start);
-    visitor.on_attribute("slice_stop", m_config.slice_stop);
-    visitor.on_attribute("input_trans0213", m_config.input_trans0213);
-    visitor.on_attribute("is_interleaved", m_config.is_interleaved);
-    visitor.on_attribute("rotary_ndims", m_config.rotary_ndims);
-    visitor.on_attribute("gather_position_arg_id", m_config.gather_position_arg_id);
-    visitor.finish_structure();
-    return true;
-}
diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/rope.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/rope.hpp
deleted file mode 100644
index cc6df7ec2b1..00000000000
--- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/rope.hpp
+++ /dev/null
@@ -1,98 +0,0 @@
-// Copyright (C) 2018-2023 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-
-#include <ngraph/node.hpp>
-#include <ngraph/op/op.hpp>
-
-namespace ov {
-namespace intel_cpu {
-
-/**
- * The operation performs the rotary positional embedding described in:
- * ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING by Jianlin Su
- *
- * The core computation applies a 2x2 rotation matrix to each pair of input
- * states x[i0] & x[i1] to produce the rotary-embedded pair of output
- * states y[i0] and y[i1].
- *
- * Suppose the hidden-state dimension (of each attention head) is N, d of which
- * are to be embedded (d <= N); the non-embedded part is copied into the output.
- *
- * for i in 0...(d/2)
- *     if (is_interleaved) {
- *         // interleaving style of indexing
- *         i0 = i*2
- *         i1 = i*2 + 1
- *     } else {
- *         // rotate-half style of indexing
- *         i0 = i
- *         i1 = i + (d/2)
- *     }
- *     y[i0] = x[i0]*cos(m * theta[i]) - x[i1]*sin(m * theta[i])
- *     y[i1] = x[i1]*cos(m * theta[i]) + x[i0]*sin(m * theta[i])
- * Note: m is the token position of the current input
- *
- * Based on the configuration, additional preprocessing steps may be performed as well:
- *  - slicing the last dimension of the input tensor
- *    (when q/k/v are merged and only the q or k part is to be extracted & embedded)
- *  - transposing the input tensor
- *    (when q/k coming from a fully-connected layer has layout [batch, seq_len, head_cnt, head_dim]
- *    but the output of RoPE is required to have layout [batch, head_cnt, seq_length, head_dims])
- *  - gathering sin/cos from inputs 2 & 3 using the position index tensor passed through input 4
- *
- * Inputs:
- *  1. Input hidden states tensor of type T1 - shape:
- *     [batch, seq_length, head_cnt, head_dims] when input_trans0213 == false OR
- *     [batch, head_cnt, seq_length, head_dims] when input_trans0213 == true
- *  2. Pre-calculated cos(m*theta[n]) tensor of type T2 - shape [1, 1, max_position_embeddings, d].
- *  3. Pre-calculated sin(m*theta[n]) tensor of type T2 - shape [1, 1, max_position_embeddings, d].
- *     Input 3 is combined with input 2 when is_interleaved is true.
- *  4. Position index tensor of type T3 - shape [batch, 1, seq_length, 1 or d] OR [batch, seq_length], optional
- * Outputs:
- *  1. New embedding tensor of type T1 and of shape [batch, head_cnt, seq_length, head_dims]
- * Types:
- *  T1 - FP32 or BF16
- *  T2 - FP32
- *  T3 - I32
- */
-class RoPENode : public ngraph::op::Op {
-public:
-    OPENVINO_OP("RoPE", "cpu_plugin_opset");
-
-    RoPENode() = default;
-
-    struct Config {
-        size_t slice_start = 0;  // slice the inner-most dimension of the input
-        size_t slice_stop = 0;
-        bool input_trans0213 = false;  // transpose input dims 1 & 2
-        bool is_interleaved = false;   // interleaved mode; implies trans0213 happens after RoPE
-        size_t rotary_ndims = 0;       // dimensions to be embedded (d in the description)
-        int gather_position_arg_id =
-            0;  // arg id of the position tensor; == 3 when gathering from sin/cos inputs by position is required
-    };
-
-    RoPENode(const OutputVector& args, const Config& cfg);
-
-    bool visit_attributes(ngraph::AttributeVisitor& visitor) override;
-
-    void validate_and_infer_types() override;
-
-    std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override;
-
-    const Config& get_config() const {
-        return m_config;
-    }
-
-    Config& get_config() {
-        return m_config;
-    }
-
-private:
-    Config m_config;
-};
-
-}  // namespace intel_cpu
-}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.cpp
deleted file mode 100644
index cacd0a6f837..00000000000
--- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.cpp
+++ /dev/null
@@ -1,435 +0,0 @@
-// Copyright (C) 2018-2023 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#include "rope_fusion.hpp"
-
-#include <limits>
-#include <openvino/core/rt_info.hpp>
-#include <openvino/opsets/opset1.hpp>
-#include <openvino/opsets/opset3.hpp>
-#include <openvino/opsets/opset6.hpp>
-#include <openvino/opsets/opset8.hpp>
-#include <openvino/pass/pattern/matcher.hpp>
-#include <openvino/pass/pattern/op/or.hpp>
-#include <openvino/pass/pattern/op/wrap_type.hpp>
-
-#include "itt.hpp"
-#include "ov_ops/type_relaxed.hpp"
-#include "transformations/cpu_opset/common/op/rope.hpp"
-#include "utils/gen_pattern.hpp"
-
-using namespace ov::gen_pattern;
-
-ov::intel_cpu::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX() {
-    MATCHER_SCOPE(RoPEFusionGPTNEOX);
-
-    // RoPE pattern matching trips over a small design flaw:
-    //   y1 = mul(x, cos)
-    //   y2 = mul(x, sin)
-    //   y = add(y1, y2)
-    // there is a chance that in the 'y1' branch the pattern x is mapped to the actual value
-    // of cos (mul is commutative); this leads to a matching failure in the 'y2' branch,
-    // because cos didn't appear in that branch.
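    // (For example, Multiply(x, cos) may bind the pattern label x to the graph's
    //  cos tensor; the sin branch then looks for rotate_half(cos), which never
    //  exists in the graph, and the whole match fails.)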
-    // So here we use a WA: only match the path of rotate_half(x)*sin and check the x*cos path
-    // in the callback.
-    auto x = makePattern(ov::Rank(4));
-    auto x_or_cos1 = makePattern(ov::Rank(4));
-    auto x_or_cos2 = makePattern(ov::Rank(4));
-    auto t_sin = makePattern(ov::Rank(4));
-
-    x->set_friendly_name("x");
-
-    auto half_ndims = Symbol("half_ndims");
-    auto int32_max = std::numeric_limits<std::int32_t>::max();
-
-    // rotate half : [-x2, x1]
-    auto x2 = GenSlice(x, half_ndims, int32_max, 1, 3);
-    auto x2neg = makePattern<opset1::Multiply>({x2, -1.0f}, {{"auto_broadcast", "numpy"}});
-    auto x1 = GenSlice(x, 0, half_ndims, 1, 3);
-    auto x_rotate_half = makePattern<opset1::Concat>({x2neg, x1}, {{"axis", -1}});
-
-    auto mul_cos = makePattern<opset1::Multiply>({x_or_cos1, x_or_cos2}, {{"auto_broadcast", "numpy"}});
-    auto mul_sin = makePattern<opset1::Multiply>({x_rotate_half, t_sin}, {{"auto_broadcast", "numpy"}});
-
-    // [x1, x2]*cos + [-x2, x1]*sin
-    auto result = makePattern<opset1::Add>({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}});
-
-    matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
-        PatternValidator validator(m);
-        if (!validator) {
-            return false;
-        }
-
-        const auto& pattern_map = m.get_pattern_value_map();
-        auto root = m.get_match_root();
-
-        // check that mul(x, cos) exists
-        Output<Node> v_cos;
-        if (pattern_map.at(x_or_cos1) == pattern_map.at(x)) {
-            v_cos = pattern_map.at(x_or_cos2);
-        } else if (pattern_map.at(x_or_cos2) == pattern_map.at(x)) {
-            v_cos = pattern_map.at(x_or_cos1);
-        } else {
-            // not a RoPE
-            return false;
-        }
-
-        RoPENode::Config config;
-        OutputVector new_args;
-        config.rotary_ndims = 2 * validator["half_ndims"];
-
-        new_args.push_back(pattern_map.at(x));
-        new_args.push_back(v_cos);
-        new_args.push_back(pattern_map.at(t_sin));
-
-        auto old_node = root;
-        auto new_node = std::make_shared<RoPENode>(new_args, config);
-        new_node->set_friendly_name(old_node->get_friendly_name());
-        ov::replace_node(old_node, new_node);
-
-        // this new node may match the following additional matchers
-        register_new_node(new_node);
-
-        return true;
-    };
-
-    auto m = std::make_shared<ngraph::pattern::Matcher>(result, matcher_name);
-    this->register_matcher(m, callback);
-}
-
-ov::intel_cpu::RoPEFusionCosSinPreprocess::RoPEFusionCosSinPreprocess() {
-    MATCHER_SCOPE(RoPEFusionCosSinPreprocess);
-
-    auto cos_const = makePattern<opset1::Constant>({});  // "f32[1,1,2048,24]"
-    auto sin_const = makePattern<opset1::Constant>({});  // "f32[1,1,2048,24]"
-
-    auto node_batch_size = makePattern("i32[1]");
-    auto tile_batch = makePattern("i32[1]");
-    auto gather_positions = makePattern("i32[?,?,?,?]");
-
-    auto prepare_cos_sin_gptneox = [&](std::shared_ptr<Node> const_tab) {
-        auto slice1 = makePattern<opset1::StridedSlice>({const_tab, {0}, node_batch_size, {1}},
-                                                        {{"begin_mask", {0}},
-                                                         {"end_mask", {0}},
-                                                         {"new_axis_mask", {}},
-                                                         {"shrink_axis_mask", {}},
-                                                         {"ellipsis_mask", {}}});
-        return makePattern<opset6::GatherElements>({slice1, gather_positions}, {{"axis", 2}});
-    };
-
-    auto seq_len = makePattern("i32[1]");
-    auto gather_positions_2d = makePattern("i32[?,?]");
-
-    auto head_dims = Symbol("head_dims");
-    auto prepare_cos_sin_llama = [&](std::shared_ptr<Node> const_tab) {
-        auto ScatterUpdate = makePattern<opset3::ScatterUpdate>({{0, 0, 0}, 2, seq_len, 0});
-        auto slice_Slice = makePattern<opset1::StridedSlice>({const_tab, {0, 0, 0}, ScatterUpdate, {1, 1, 1}},
-                                                             {{"begin_mask", {1, 1, 0}},
-                                                              {"end_mask", {1, 1, 0}},
-                                                              {"new_axis_mask", {}},
-                                                              {"shrink_axis_mask", {}},
-                                                              {"ellipsis_mask", {}}});
-        auto squeeze = makePattern<opset1::Reshape>({slice_Slice, {-1, head_dims}});
-        auto index_Gather = makePattern<opset8::Gather>({squeeze, gather_positions_2d, 0}, {{"batch_dims", 0}});
-        auto unsqueeze = makePattern<opset1::Reshape>({index_Gather, {1, 1, -1, head_dims}});
-        return unsqueeze;
-    };
-
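    // Either preprocessing variant can feed the fused RoPE node below: the GPT-NeoX
    // style slices the constant table by batch and gathers along axis 2 with a 4D
    // position tensor, while the llama style reshapes the table and gathers rows
    // with a 2D position tensor.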
-    auto cos_tab = prepare_cos_sin_gptneox(cos_const) | prepare_cos_sin_llama(cos_const);
-    auto sin_tab = prepare_cos_sin_gptneox(sin_const) | prepare_cos_sin_llama(sin_const);
-
-    auto x = makePattern(ov::Rank(4));
-    auto rope = makePattern<RoPENode>({x, cos_tab, sin_tab});
-
-    matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
-        PatternValidator validator(m);
-        if (!validator) {
-            return false;
-        }
-        const auto& pattern_map = m.get_pattern_value_map();
-        auto root = m.get_match_root();
-        auto rope_node = as_type_ptr<RoPENode>(pattern_map.at(rope).get_node_shared_ptr());
-        if (!rope_node)
-            return false;
-
-        if (pattern_map.count(cos_const)) {
-            rope_node->set_argument(1, pattern_map.at(cos_const));
-        }
-        if (pattern_map.count(sin_const)) {
-            rope_node->set_argument(2, pattern_map.at(sin_const));
-        }
-
-        auto& config = rope_node->get_config();
-        if (pattern_map.count(gather_positions)) {
-            auto arg_id = rope_node->get_input_size();
-            rope_node->set_argument(arg_id, pattern_map.at(gather_positions));
-            config.gather_position_arg_id = arg_id;
-        } else if (pattern_map.count(gather_positions_2d)) {
-            auto arg_id = rope_node->get_input_size();
-            rope_node->set_argument(arg_id, pattern_map.at(gather_positions_2d));
-            config.gather_position_arg_id = arg_id;
-        }
-        rope_node->validate_and_infer_types();
-        register_new_node(rope_node);
-        return true;
-    };
-    auto m = std::make_shared<ngraph::pattern::Matcher>(rope, matcher_name);
-    this->register_matcher(m, callback);
-}
-
-// only a fraction of head_size is rotary-embedded
-ov::intel_cpu::RoPEFusionIOSlicing::RoPEFusionIOSlicing() {
-    MATCHER_SCOPE(RoPEFusionIOSlicing);
-    auto int32_max = std::numeric_limits<std::int32_t>::max();
-    auto data = makePattern(ov::Rank(4));
-
-    auto ndims = Symbol("ndims");
-    auto x = GenSlice(data, 0, ndims, 1, 3);
-    auto y = GenSlice(data, ndims, int32_max, 1, 3);
-    auto x_emb = makePattern<RoPENode>({x, {}, {}}) | makePattern<RoPENode>({x, {}, {}, {}});
-    auto result = makePattern<opset1::Concat>({x_emb, y}, {{"axis", -1}});
-
-    matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
-        const auto& pattern_map = m.get_pattern_value_map();
-        auto root = m.get_match_root();
-
-        auto rope_node = as_type_ptr<RoPENode>(root->input_value(0).get_node_shared_ptr());
-        if (!rope_node)
-            return false;
-
-        PatternValidator validator(m);
-        if (!validator) {
-            return false;
-        }
-        auto ndims = validator["ndims"];
-
-        auto& config = rope_node->get_config();
-        if (config.rotary_ndims != ndims)
-            return false;
-
-        // remove the slice & concat
-        rope_node->set_argument(0, pattern_map.at(data));
-        rope_node->set_friendly_name(root->get_friendly_name());
-        ov::replace_node(root, rope_node);
-
-        rope_node->validate_and_infer_types();
-        register_new_node(rope_node);
-        return true;
-    };
-    auto m = std::make_shared<ngraph::pattern::Matcher>(result, matcher_name);
-    this->register_matcher(m, callback);
-}
-
-ov::intel_cpu::RoPEFusionPreprocess::RoPEFusionPreprocess() {
-    MATCHER_SCOPE(RoPEFusionPreprocess);
-
-    // gptneox-style preprocessing of the input data
-    auto input_to_slice = makePattern(ov::Rank(4));
-    auto input_to_trans = makePattern(ov::Rank(4));  // no need to slice from 3S
-
-    // in some models the qkv projection is combined and
-    // needs to be sliced before RoPE
-    auto slice_start = Symbol("slice_start");
-    auto slice_stop = Symbol("slice_stop");
-    auto input_slice = GenSlice(input_to_slice, slice_start, slice_stop, 1, 3);
-
-    // some models transpose from [B,L,H,S] to [B,H,L,S] before RoPE
-    auto x = makePattern<opset1::Transpose>({input_slice | input_to_trans, {0, 2, 1, 3}});
-    auto result = makePattern<RoPENode>({x, {}, {}}) | makePattern<RoPENode>({x, {}, {}, {}});
-
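    // For illustration (hypothetical shapes): with a fused qkv projection the input is
    // [batch, seq_len, head_cnt, 3 * head_size]; the q part is taken with
    // slice_start = 0, slice_stop = head_size along the last axis, and the
    // transpose 0213 then turns [B,L,H,S] into the [B,H,L,S] layout RoPE emits.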
-    matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
-        PatternValidator validator(m);
-        if (!validator) {
-            return false;
-        }
-
-        const auto& pattern_map = m.get_pattern_value_map();
-        auto root = m.get_match_root();
-        auto rope_node = as_type_ptr<RoPENode>(root);
-        if (!rope_node)
-            return false;
-
-        auto& config = rope_node->get_config();
-
-        if (pattern_map.count(input_to_slice)) {
-            config.slice_start = validator["slice_start"];
-            config.slice_stop = validator["slice_stop"];
-            config.input_trans0213 = true;
-            rope_node->set_argument(0, pattern_map.at(input_to_slice));
-        } else if (pattern_map.count(input_to_trans)) {
-            config.input_trans0213 = true;
-            rope_node->set_argument(0, pattern_map.at(input_to_trans));
-        } else {
-            return false;
-        }
-        rope_node->validate_and_infer_types();
-        register_new_node(rope_node);
-        return true;
-    };
-    auto m = std::make_shared<ngraph::pattern::Matcher>(result, matcher_name);
-    this->register_matcher(m, callback);
-}
-
-// remove StridedSlice from 0 to int32_max with stride 1
-ov::intel_cpu::EliminateStridedSlice::EliminateStridedSlice() {
-    MATCHER_SCOPE(EliminateStridedSlice);
-    auto data = ov::pass::pattern::any_input(ngraph::pattern::has_static_rank());
-    auto begin = ov::pass::pattern::wrap_type<opset1::Constant>(ngraph::pattern::type_matches(ov::element::i32));
-    auto end = ov::pass::pattern::wrap_type<opset1::Constant>(ngraph::pattern::type_matches(ov::element::i32));
-    auto stride = ov::pass::pattern::wrap_type<opset1::Constant>(ngraph::pattern::type_matches(ov::element::i32));
-
-    auto strided_slice =
-        ov::pass::pattern::wrap_type<opset1::StridedSlice>({data, begin, end, stride}, [](const Output<Node>& value) {
-            auto s1 = as_type_ptr<opset1::StridedSlice>(value.get_node_shared_ptr());
-            if (!s1->get_new_axis_mask().empty() || !s1->get_shrink_axis_mask().empty() ||
-                !s1->get_ellipsis_mask().empty()) {
-                return false;
-            }
-
-            auto inputs = s1->input_values();
-
-            auto begin = as_type_ptr<opset1::Constant>(inputs[1].get_node_shared_ptr());
-            auto end = as_type_ptr<opset1::Constant>(inputs[2].get_node_shared_ptr());
-            // the stride must be all 1s
-            auto stride = as_type_ptr<opset1::Constant>(inputs[3].get_node_shared_ptr());
-
-            if (!begin)
-                return false;
-            if (!end)
-                return false;
-            if (!stride)
-                return false;
-
-            auto v_stride = stride->cast_vector<int32_t>();
-            for (auto& v : v_stride) {
-                if (v != 1)
-                    return false;
-            }
-
-            auto v_begin = begin->cast_vector<int32_t>();
-            auto v_end = end->cast_vector<int32_t>();
-
-            auto& begin_mask = s1->get_begin_mask();
-            auto& end_mask = s1->get_end_mask();
-            auto mask_size = begin_mask.size();
-            if (begin_mask.size() != end_mask.size()) {
-                return false;
-            }
-
-            auto int32_max = std::numeric_limits<std::int32_t>::max();
-            for (size_t i = 0; i < mask_size; i++) {
-                if (begin_mask[i] != end_mask[i])
-                    return false;
-                // the only valid [begin, end] range is [0, int32_max]
-                if (begin_mask[i] == 0 && end_mask[i] == 0) {
-                    if (v_begin[i] != 0 || v_end[i] != int32_max)
-                        return false;
-                }
-            }
-            return true;
-        });
-
-    matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
-        auto root = m.get_match_root();
-        return replace_output_update_name(root->output(0), root->input_value(0));
-    };
-
-    auto m = std::make_shared<ngraph::pattern::Matcher>(strided_slice, matcher_name);
-    this->register_matcher(m, callback);
-}
-
-ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {
-    MATCHER_SCOPE(RoPEFusionGPTJ);
-
-    auto int32_max = std::numeric_limits<std::int32_t>::max();
-    auto ndims = Symbol("ndims");
-
-    auto view_Reshape = makePattern(ov::Rank(4));
-
-    // view_Reshape : B,L,H,S
-    auto slice_Slice_965 = GenSlice(view_Reshape, 0, ndims, 1, 3);
-
-    auto gather_sin_cos = makePattern("f32");
-
-    auto varsplit = makePattern<opset1::VariadicSplit>({gather_sin_cos, -1, {ndims / 2, -1}});
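    // The gathered sin/cos table is split on its last axis into the sin half
    // (output 0) and the cos half (output 1); each half is then repeated
    // element-wise (Gather with indices 0,0,1,1,...) to line up with the
    // interleaved x pairs.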
-    varsplit->set_output_size(2);
-    auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {1, -1, 1, 32}});
-    auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {1, -1, 1, 32}});
-    // repeat the cos/sin table
-    auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) {
-        const auto& vec = node.get_vector<int32_t>();
-        int32_t v = 0;
-        for (size_t i = 0; i < vec.size(); i += 2, v++) {
-            if (vec[i] != v || vec[i + 1] != v)
-                return false;
-        }
-        return true;
-    });
-    auto repeat_interleave_sin = makePattern<opset8::Gather>({unsqueeze_sin, const_idx, 3}, {{"batch_dims", 0}});
-    auto repeat_interleave_cos = makePattern<opset8::Gather>({unsqueeze_cos, const_idx, 3}, {{"batch_dims", 0}});
-
-    auto t_cos = makePattern(ov::Rank(4));
-    auto t_sin = makePattern(ov::Rank(4));
-
-    // interleave x : (-x[:,:,:, 1::2], x[:,:,:, 0::2])
-    auto slice_Slice_1174 = GenSlice(slice_Slice_965, 1, int32_max, 2, 3);
-
-    auto neg_Multiply_1177 = makePattern<opset1::Multiply>({slice_Slice_1174, -1.0f}, {{"auto_broadcast", "numpy"}});
-    auto Unsqueeze_65524 = makePattern<opset1::Unsqueeze>({neg_Multiply_1177, -1});
-
-    auto slice_Slice_1168 = GenSlice(slice_Slice_965, 0, int32_max, 2, 3);
-    auto Unsqueeze_65525 = makePattern<opset1::Unsqueeze>({slice_Slice_1168, -1});
-    auto stack_1182 = makePattern<opset1::Concat>({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}});
-
-    auto ShapeOf_169068 = makePattern<opset3::ShapeOf>({stack_1182});
-    auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0);
-    auto flatten_Concat_1197 = makePattern<opset1::Concat>({flatten_Slice_1194, {-1}}, {{"axis", 0}});
-    auto flatten_Reshape_1198 = makePattern<opset1::Reshape>({stack_1182, flatten_Concat_1197});
-
-    // x*cos [B,L,H,ndims]
-    auto mul_cos =
-        makePattern<opset1::Multiply>({slice_Slice_965, repeat_interleave_cos}, {{"auto_broadcast", "numpy"}});
-    auto mul_sin =
-        makePattern<opset1::Multiply>({flatten_Reshape_1198, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}});
-
-    // *cos + *sin
-    auto rotary_emb = makePattern<opset1::Add>({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}});
-
-    auto slice_Slice_971 = GenSlice(view_Reshape, ndims, int32_max, 1, 3);
-    auto cat_Concat_1211 = makePattern<opset1::Concat>({rotary_emb, slice_Slice_971}, {{"axis", -1}});
-    auto permute_Transpose_1213 = makePattern<opset1::Transpose>({cat_Concat_1211, {0, 2, 1, 3}});
-
-    auto result = permute_Transpose_1213;
-
-    matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
-        const auto& pattern_map = m.get_pattern_value_map();
-        auto root = m.get_match_root();
-        PatternValidator validator(m);
-        if (!validator) {
-            return false;
-        }
-
-        RoPENode::Config config;
-        OutputVector new_args;
-        config.rotary_ndims = validator["ndims"];
-
-        config.is_interleaved = true;
-
-        // the input is [B,L,H,S]
-        new_args.push_back(pattern_map.at(view_Reshape));
-        // the sin_cos table (gathered with positions) [1, L, 64]
-        new_args.push_back(pattern_map.at(gather_sin_cos));
-        new_args.push_back(pattern_map.at(gather_sin_cos));
-
-        auto old_node = root;
-
-        auto new_node = std::make_shared<RoPENode>(new_args, config);
-        new_node->set_friendly_name(old_node->get_friendly_name());
-        ov::replace_node(old_node, new_node);
-        return true;
-    };
-
-    auto m = std::make_shared<ngraph::pattern::Matcher>(result, matcher_name);
-    this->register_matcher(m, callback);
-}
\ No newline at end of file
diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.hpp
deleted file mode 100644
index 58bab527504..00000000000
--- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.hpp
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright (C) 2018-2023 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-
-#include <ngraph/pass/graph_rewrite.hpp>
-
-namespace ov {
-namespace intel_cpu {
-
-class RoPEFusionGPTNEOX : public ngraph::pass::MatcherPass {
-public:
-    OPENVINO_RTTI("RoPEFusionGPTNEOX", "0");
-    RoPEFusionGPTNEOX();
-};
-
-class RoPEFusionGPTJ : public ngraph::pass::MatcherPass {
-public:
-    OPENVINO_RTTI("RoPEFusionGPTJ", "0");
-    RoPEFusionGPTJ();
-};
-
-class RoPEFusionIOSlicing : public ngraph::pass::MatcherPass {
-public:
-    OPENVINO_RTTI("RoPEFusionIOSlicing", "0");
-    RoPEFusionIOSlicing();
-};
-
-class RoPEFusionPreprocess : public ngraph::pass::MatcherPass {
-public:
-    OPENVINO_RTTI("RoPEFusionPreprocess", "0");
-    RoPEFusionPreprocess();
-};
-
-class RoPEFusionCosSinPreprocess : public ngraph::pass::MatcherPass {
-public:
-    OPENVINO_RTTI("RoPEFusionCosSinPreprocess", "0");
-    RoPEFusionCosSinPreprocess();
-};
-
-class EliminateStridedSlice : public ngraph::pass::MatcherPass {
-public:
-    OPENVINO_RTTI("EliminateStridedSlice", "0");
-    EliminateStridedSlice();
-};
-
-class RoPEFusion : public ngraph::pass::GraphRewrite {
-public:
-    OPENVINO_RTTI("RoPEFusion", "0");
-    RoPEFusion() {
-        add_matcher<RoPEFusionGPTNEOX>();
-        add_matcher<RoPEFusionGPTJ>();
-        // optional heads & tails are fused in a separate matcher pass,
-        // after the RoPENode has been created.
-        add_matcher<RoPEFusionIOSlicing>();
-        add_matcher<RoPEFusionPreprocess>();
-        add_matcher<RoPEFusionCosSinPreprocess>();
-    }
-};
-
-}  // namespace intel_cpu
-}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
index a034acb2572..2cd69ffdee2 100644
--- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
+++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
@@ -83,7 +83,6 @@
 #include "transformations/smart_reshape/matmul_sr.hpp"
 #include "transformations/init_node_info.hpp"
 #include "utils/ngraph_transformation.hpp"
-#include "utils/print_model.hpp"
 
 // LPT transformations
 #include "low_precision/add.hpp"
@@ -111,7 +110,6 @@
 #include "transformations/cpu_opset/common/pass/insert_convert_after_extension.hpp"
 #include "transformations/cpu_opset/common/pass/move_eltwise_up_data_movement.hpp"
 #include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp"
-#include "transformations/cpu_opset/common/pass/rope_fusion.hpp"
 
 // Snippets
 #include "snippets/pass/tokenization.hpp"
@@ -656,10 +654,6 @@ void Transformations::PostLpt() {
     // Execute before snippets.
Otherwise FQ will be converted to Subgraph CPU_REGISTER_PASS_X64(postLPTPassManager, ConvertFqRnnToQuantizedRnn); - - CPU_REGISTER_PASS_X64(postLPTPassManager, EliminateStridedSlice); - CPU_REGISTER_PASS_X64(postLPTPassManager, RoPEFusion); - postLPTPassManager.run_passes(model); } diff --git a/src/plugins/intel_cpu/src/utils/gen_pattern.hpp b/src/plugins/intel_cpu/src/utils/gen_pattern.hpp deleted file mode 100644 index c562190494e..00000000000 --- a/src/plugins/intel_cpu/src/utils/gen_pattern.hpp +++ /dev/null @@ -1,1304 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "openvino/opsets/opset1.hpp" -#include "openvino/opsets/opset2.hpp" -#include "openvino/opsets/opset3.hpp" -#include "openvino/opsets/opset4.hpp" -#include "openvino/opsets/opset5.hpp" -#include "openvino/opsets/opset6.hpp" -#include "openvino/opsets/opset7.hpp" -#include "openvino/opsets/opset8.hpp" -#include "openvino/pass/pattern/matcher.hpp" -#include "openvino/pass/pattern/op/label.hpp" -#include "openvino/pass/pattern/op/or.hpp" -#include "openvino/pass/pattern/op/wrap_type.hpp" - -namespace ov { -namespace gen_pattern { - -static bool force_matcher_verbose = false; - -#ifdef CPU_DEBUG_CAPS - -template -static inline void _verbose_log(Args&&... args) { - std::stringstream ss; - int dummy[] = {(ss << std::forward(args) << " ", 0)...}; - (void)(dummy); - ss << std::endl; - std::cout << ss.str(); -} - -static int matcher_verbose_enabled() { - static const int enabled = std::getenv("GENP_VERBOSE") ? (atoi(std::getenv("GENP_VERBOSE"))) : 0; - return enabled; -} - -# define _VERBOSE_LOG(...) \ - if (matcher_verbose_enabled() || force_matcher_verbose) \ - _verbose_log(__VA_ARGS__) -#else -# define _VERBOSE_LOG(...) 
-#endif
-
-namespace detail {
-inline std::vector<std::string> split_string(const std::string& s, const std::string& delimiter) {
-    std::vector<std::string> ret;
-    size_t pos = 0, pos_next;
-    std::string token;
-    while ((pos_next = s.find(delimiter, pos)) != std::string::npos) {
-        token = s.substr(pos, pos_next - pos);
-        ret.push_back(token);
-        pos = pos_next + 1;
-    }
-    // return the whole string if no delimiter is found
-    token = s.substr(pos, pos_next);
-    ret.push_back(token);
-    return ret;
-}
-
-template <typename T>
-std::string vec2str(const std::vector<T>& vec, int cnt_limit = 9) {
-    std::stringstream ss;
-    ss << "{";
-    const char* sep = "";
-    for (auto& v : vec) {
-        cnt_limit--;
-        if (cnt_limit == 0) {
-            ss << sep << "...";
-            break;
-        }
-        ss << sep << v;
-        sep = ",";
-    }
-    ss << "}";
-    return ss.str();
-}
-}  // namespace detail
-
-struct values_info {
-    values_info(const char* pattern_list = nullptr) {
-        if (pattern_list == nullptr || pattern_list[0] == 0) {
-            all_type_pshape.clear();
-            return;
-        }
-        auto pattern_vector = detail::split_string(pattern_list, " ");
-        for (auto& pattern : pattern_vector) {
-            if (pattern[0] == '[') {
-                all_type_pshape.emplace_back(ov::element::dynamic, ov::PartialShape(pattern));
-            } else {
-                auto sep = pattern.find("[");
-                if (sep != std::string::npos) {
-                    // ele_type[p_shape]
-                    all_type_pshape.emplace_back(ov::element::Type(pattern.substr(0, sep)),
-                                                 ov::PartialShape(pattern.substr(sep)));
-                } else {
-                    // ele_type
-                    all_type_pshape.emplace_back(ov::element::Type(pattern), ov::PartialShape::dynamic());
-                }
-            }
-        }
-    }
-
-    size_t size() {
-        return all_type_pshape.size();
-    }
-    const std::pair<ov::element::Type, ov::PartialShape>& operator[](int index) {
-        return all_type_pshape[index];
-    }
-
-    //-------------------------------------------------------------
-    bool predicate(const ov::Output<ov::Node>& value) const {
-        if (all_type_pshape.empty())
-            return true;
-        auto index = value.get_index();
-        auto& item = all_type_pshape[index];
-        if (!item.first.compatible(value.get_element_type()) || !item.second.compatible(value.get_partial_shape())) {
-            _VERBOSE_LOG("* mismatched vtype between value & pattern : ",
-                         value.get_element_type(),
-                         value.get_partial_shape(),
-                         "vs",
-                         item.first,
-                         item.second);
-            return false;
-        }
-        return true;
-    }
-
-    std::string to_string() {
-        std::stringstream ss;
-        const char* sep = "";
-        for (auto& t : all_type_pshape) {
-            ss << sep << t.first << t.second;
-            sep = ";";
-        }
-        return ss.str();
-    }
-
-    std::vector<std::pair<ov::element::Type, ov::PartialShape>> all_type_pshape;
-};
-
-// Symbol : a constant that is unknown at the pattern's building time,
-// but is collected and validated after the pattern has been matched
-// against some sub-graph values.
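// A minimal usage sketch (mirroring how rope_fusion.cpp consumes symbols; the
// variable names here are illustrative only):
//
//   auto ndims = Symbol("ndims");                 // value unknown at build time
//   auto x1 = GenSlice(data, 0, ndims, 1, 3);     // used inside a pattern
//   ...
//   // inside the matcher callback:
//   PatternValidator validator(m);
//   if (!validator)
//       return false;                             // inconsistent observations
//   config.rotary_ndims = validator["ndims"];     // actual value from the match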
-class Symbol { -private: - struct Entity { - const char* name = "?"; - char op; - double literal_const_value; - std::shared_ptr lhs; - std::shared_ptr rhs; - // _,+,-,*,/ - // l : literal const - // n : named symbol - double eval(const std::map& value_map) const { - switch (op) { - case 'l': - return literal_const_value; - case 'n': - return value_map.at(this); - case '+': - return lhs->eval(value_map) + rhs->eval(value_map); - case '-': - return lhs->eval(value_map) - rhs->eval(value_map); - case '*': - return lhs->eval(value_map) * rhs->eval(value_map); - case '/': - return lhs->eval(value_map) / rhs->eval(value_map); - case '_': - return -lhs->eval(value_map); - case 'r': - return std::sqrt(lhs->eval(value_map)); - default: - assert(false); - return std::numeric_limits::quiet_NaN(); - } - } - }; - std::shared_ptr entity; - -public: - Symbol() { - entity = std::make_shared(); - entity->op = 'n'; - } - Symbol(const char* name) { - entity = std::make_shared(); - entity->op = 'n'; - entity->name = name; - } - Symbol(const int value) { - entity = std::make_shared(); - entity->op = 'l'; - entity->literal_const_value = value; - } - Symbol(char op, const Symbol& lhs, const Symbol& rhs) { - entity = std::make_shared(); - entity->op = op; - entity->lhs = lhs.entity; - entity->rhs = rhs.entity; - } - double eval(const std::map& value_map) const { - return entity->eval(value_map); - } - bool is_independent_var() const { - return entity->op == 'n'; - } - int is_literal_const() const { - return entity->op == 'l'; - } - char get_op() const { - return entity->op; - } - void* get_id() const { - return entity.get(); - } - const char* get_name() const { - return entity->name; - } - bool operator<(const Symbol& rhs) const { - return get_id() < rhs.get_id(); - } -}; - -inline Symbol operator-(const Symbol& lhs) { - return Symbol('_', lhs, lhs); -} -inline Symbol operator+(const Symbol& lhs, const Symbol& rhs) { - return Symbol('+', lhs, rhs); -} -inline Symbol operator-(const Symbol& lhs, const Symbol& rhs) { - return Symbol('-', lhs, rhs); -} -inline Symbol operator*(const Symbol& lhs, const Symbol& rhs) { - return Symbol('*', lhs, rhs); -} -inline Symbol operator/(const Symbol& lhs, const Symbol& rhs) { - return Symbol('/', lhs, rhs); -} -inline Symbol sqrt(Symbol lhs) { - return Symbol('r', lhs, lhs); -} - -namespace detail { - -// AttrAny is simple wrapper of Any to provide some constructor -// to take advantage of C++ implicit conversion to allow: -// - attribute expressed using initializer_list. 
-// - symbol to be used as attributes -struct AttrAny { - ov::Any any; - - // empty attribute, means empty vector, and error for scalar - AttrAny() {} - - AttrAny(const Symbol& v) : any(v) {} - AttrAny(const ov::element::Type& v) : any(v) {} - AttrAny(const ov::PartialShape& v) : any(v) {} - AttrAny(const ov::Dimension& v) : any(v) {} - AttrAny(bool v) : any(v) {} - AttrAny(int v) : any(v) {} - AttrAny(float v) : any(v) {} - AttrAny(double v) : any(v) {} - AttrAny(long v) : any(static_cast(v)) {} - AttrAny(long long v) : any(static_cast(v)) {} - AttrAny(const char* v) : any(v) {} - AttrAny(const std::string& v) : any(v) {} - - // template ::value>::type = true> - // AttrAny(const T& v) : any(v) {} - - // template ::value>::type = true> - // AttrAny(const std::vector& v) : any(v) {} - - AttrAny(const std::vector& v) : any(v) {} - - // template ::value>::type = true> - // AttrAny(std::initializer_list values) : any(std::vector(values)) {} - AttrAny(std::initializer_list values) : any(std::vector(values)) {} - AttrAny(std::initializer_list values) : any(std::vector(values.begin(), values.end())) {} - AttrAny(std::initializer_list values) : any(std::vector(values)) {} - AttrAny(std::initializer_list values) : any(std::vector(values)) {} - AttrAny(std::initializer_list values) : any(std::vector(values.begin(), values.end())) {} - - AttrAny(std::initializer_list values) : any(std::vector(values)) {} - AttrAny(std::initializer_list values) : any(std::vector(values)) {} - - std::string as_string() { - if (any.is()) - return any.as(); - return any.as(); - } - bool as_bool() { - if (any.is()) - return any.as(); - return any.as(); - } - double as_double() { - if (any.is()) - return any.as(); - if (any.is()) - return any.as(); - return any.as(); - } - int64_t as_int64_t() { - if (any.is()) - return any.as(); - return any.as(); - } - - template - std::vector as_vector() { - if (any.empty()) - return {}; - if (!std::is_same::value) { - if (any.is>()) { - auto ivec = any.as>(); - return std::vector(ivec.begin(), ivec.end()); - } - if (any.is>()) { - auto vec = any.as>(); - return std::vector(vec.begin(), vec.end()); - } - } - if (!std::is_same::value) { - if (any.is>()) { - auto ivec = any.as>(); - return std::vector(ivec.begin(), ivec.end()); - } - if (any.is>()) { - auto vec = any.as>(); - return std::vector(vec.begin(), vec.end()); - } - } - if (any.is>()) { - auto ivec = any.as>(); - return std::vector(ivec.begin(), ivec.end()); - } - return any.as>(); - } - - template - std::vector as_T_vector() { - if (any.empty()) - return {}; - if (any.is()) { - auto to_vec = [](std::initializer_list v) { - return std::vector(v); - }; - return to_vec({any.as()}); - } - if (any.is>()) { - auto ivec = any.as>(); - return std::vector(ivec.begin(), ivec.end()); - } - return any.as>(); - } - - std::vector as_str_vector() { - if (any.empty()) - return {}; - if (any.is>()) { - auto vec = any.as>(); - return std::vector(vec.begin(), vec.end()); - } - return any.as>(); - } - - template - T cast_to() { - if (any.is()) - return any.as(); - if (any.is()) - return any.as(); - if (any.is()) - return any.as(); - if (any.is()) - return any.as(); - if (any.is()) - return any.as(); - if (any.is()) - return any.as(); - if (any.is()) - return any.as(); - if (any.is()) - return any.as(); - if (any.is()) - return any.as(); - if (any.is()) - return any.as(); - return any.as(); - } - - template - bool equal_to(const std::vector& rhs) { - if (any.empty() && rhs.empty()) - return true; - auto& vec = any.as>(); - return 
std::equal(vec.begin(), vec.end(), rhs.begin()); - } - - template - bool equal_to(const std::vector& rhs) { - if (any.empty() && rhs.empty()) - return true; - - if (any.is>()) { - auto& vec = any.as>(); - return vec.size() == rhs.size() && std::equal(vec.begin(), vec.end(), rhs.begin()); - } - return equal_to(rhs); - } - - template - typename std::enable_if::value, bool>::type equal_to(const T& rhs) { - return rhs == any.as(); - } - - template - typename std::enable_if::value, bool>::type equal_to(const T& rhs) { - if (any.is()) { - auto& value = any.as(); - return rhs == static_cast(value); - } - return equal_to(rhs); - } -}; - -using AttrMap = std::map; - -class AttrSetter : public ov::AttributeVisitor { -public: - AttrMap& m_attr_map; - std::vector m_missing_attrs; - - AttrSetter(AttrMap& attrs) : m_attr_map(attrs) {} - - const std::vector& get_missing_attrs() { - return m_missing_attrs; - } - - bool should_skip(const std::string& name) { - if (m_attr_map.count(name) == 0) { - // attributes not specified is recorded as missing - m_missing_attrs.push_back(name); - return true; - } - - if (m_attr_map[name].any.is()) { - m_missing_attrs.push_back(name); - return true; - } - - if (m_attr_map[name].any.empty()) { - // input is set to empty, meaning default value is used. - return true; - } - return false; - } - - void on_adapter(const std::string& name, ov::ValueAccessor& value) override { - if (should_skip(name)) - return; - value.set(m_attr_map[name].as_string()); - } - void on_adapter(const std::string& name, ov::ValueAccessor& value) override { - if (should_skip(name)) - return; - value.set(m_attr_map[name].as_bool()); - } - void on_adapter(const std::string& name, ov::ValueAccessor& adapter) override { - if (should_skip(name)) - return; - auto& any = m_attr_map[name].any; - if (auto a = ov::as_type>(&adapter)) { - static_cast(*a) = any.as(); - } else if (auto a = ov::as_type>(&adapter)) { - a->set(any.as()); - } else if (auto a = ov::as_type>(&adapter)) { - a->set(any.as()); - } else if (auto a = ov::as_type>(&adapter)) { - a->set(m_attr_map[name].as_vector()); - } else if (auto a = ov::as_type>(&adapter)) { - a->set(m_attr_map[name].as_vector()); - } else if (auto a = ov::as_type>>(&adapter)) { -#if defined(__APPLE__) || defined(__EMSCRIPTEN__) - static_cast&>(*a) = m_attr_map[name].as_vector(); -#else - a->set(m_attr_map[name].as_vector()); -#endif - } else if (auto a = ov::as_type>(&adapter)) { - a->set(m_attr_map[name].as_vector()); - //} else if (auto a = ov::as_type>(&adapter)) { - // a->set(m_attr_map[name].as_string()); - } else if (auto a = ov::as_type>(&adapter)) { - a->set(m_attr_map[name].as_string()); - } else if (auto a = ov::as_type>(&adapter)) { - a->set(m_attr_map[name].as_vector()); - } else if (auto a = ov::as_type>(&adapter)) { - a->set(m_attr_map[name].as_T_vector()); - } else { - OPENVINO_THROW("unsupported AttributeAdapter for attribute : ", name); - } - } - - void on_adapter(const std::string& name, ov::ValueAccessor& value) override { - if (should_skip(name)) - return; - value.set(m_attr_map[name].as_double()); - } - void on_adapter(const std::string& name, ov::ValueAccessor& value) override { - if (should_skip(name)) - return; - value.set(m_attr_map[name].as_int64_t()); - } - void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { - if (should_skip(name)) - return; - value.set(m_attr_map[name].as_vector()); - } - - void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { - if (should_skip(name)) - return; - 
value.set(m_attr_map[name].as_vector()); - } - - void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { - if (should_skip(name)) - return; - value.set(m_attr_map[name].as_vector()); - } - - void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { - if (should_skip(name)) - return; - value.set(m_attr_map[name].as_str_vector()); - } -}; - -class GenericPattern : public ov::pass::pattern::op::Pattern { -public: - OPENVINO_RTTI("GenericPattern"); - - explicit GenericPattern(const OutputVector& args = {}, const detail::AttrMap& attrs = {}) - : ov::pass::pattern::op::Pattern(args) { - set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic()); - m_attrs = attrs; - } - - // this allows code inside pred to access pattern node itself - void set_predicate(ov::pass::pattern::op::ValuePredicate pred) { - m_predicate = pred; - } - - bool match_value(ov::pass::pattern::Matcher* matcher, - const Output& pattern_value, - const Output& graph_value) override { - if (m_predicate(graph_value)) { - auto& pattern_map = matcher->get_pattern_value_map(); - pattern_map[shared_from_this()] = graph_value; - matcher->add_node(graph_value); - return (get_input_size() == 0 - ? true - : matcher->match_arguments(pattern_value.get_node(), graph_value.get_node_shared_ptr())); - } - return false; - } - - detail::AttrMap& get_attrs() { - return m_attrs; - } - -private: - detail::AttrMap m_attrs; -}; - -// A glue/syntax-sugar type which allows more types to be used as input to makePattern() -struct PatternNode { - std::shared_ptr node; - int output_port = -1; - - operator ov::Output() const { - return get_output(); - } - - ov::Output get_output() const { - if (output_port >= 0) - return node->output(output_port); - return node->get_default_output(); - } - - PatternNode(const Output& out) : node(out.get_node_shared_ptr()), output_port(out.get_index()) {} - - PatternNode() { - node = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank()); - } - PatternNode(ov::Rank rank) { - node = ov::pass::pattern::any_input([rank](const Output& value) { - if (!rank.compatible(value.get_partial_shape().rank())) { - _VERBOSE_LOG("*mismatched PatternNode rank ", value, " expecting ", rank); - return false; - } - return true; - }); - } - - PatternNode(values_info vt) { - node = ov::pass::pattern::any_input([vt](const Output& value) { - if (!vt.predicate(value)) { - _VERBOSE_LOG("*mismatched PatternNode ", value); - return false; - } - _VERBOSE_LOG(" matched PatternNode ", value); - return true; - }); - } - PatternNode(const std::shared_ptr& node) : node(node) {} - PatternNode(const std::shared_ptr& node) : node(node) {} - PatternNode(const std::shared_ptr& pattern) - : node(std::dynamic_pointer_cast(pattern)) {} - - // 1D-vector & scalar of symbol - PatternNode(std::initializer_list v) { - // initializer_list of Symbol ls special, need to be recorded - // and eval/check in the callback after whole match is complete, - // where all observed actual constant values are known, first - // we will go over all symbols and collect actual value for individual - // symbol(named symbol), and then we go over all derived symbols and - // evaluate their predicated values and compare against what observed, - // and check if they all match. 
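    // e.g. a PatternNode built from {half_ndims} matches any 1D constant; the
    // constant's element value is recorded against the Symbol and is later
    // required to be consistent with every other use of that same Symbol.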
-        // node = ConstVector(std::vector<Symbol>(v), nullptr);
-        node = ov::pass::pattern::wrap_type<opset1::Constant>();
-
-        auto& rt_info = node->get_rt_info();
-        rt_info["symbolic_const_value"] = std::vector<Symbol>(v);
-    }
-    PatternNode(const std::vector<Symbol>& v) {
-        node = ov::pass::pattern::wrap_type<opset1::Constant>();
-        auto& rt_info = node->get_rt_info();
-        rt_info["symbolic_const_value"] = v;
-    }
-
-    PatternNode(Symbol v) {
-        node = ov::pass::pattern::wrap_type<opset1::Constant>();
-        auto& rt_info = node->get_rt_info();
-        rt_info["symbolic_const_value"] = std::vector<Symbol>({v});
-    }
-
-    // scalar constant (treated as a wildcard for a single-element constant of any rank)
-    PatternNode(int v) : node(std::make_shared<opset1::Constant>(element::from<int>(), Shape({}), v)) {}
-    PatternNode(float v) : node(std::make_shared<opset1::Constant>(element::from<float>(), Shape({}), v)) {}
-
-    PatternNode(std::initializer_list<int> v, values_info vi = nullptr) {
-        node = ConstVector(std::vector<int>(v), vi);
-    }
-    PatternNode(std::initializer_list<float> v, values_info vi = nullptr) {
-        node = ConstVector(std::vector<float>(v), vi);
-    }
-    PatternNode(std::initializer_list<long> v, values_info vi = nullptr) {
-        node = ConstVector(std::vector<int64_t>(v.begin(), v.end()), vi);
-    }
-    PatternNode(std::initializer_list<double> v, values_info vi = nullptr) {
-        node = ConstVector(std::vector<float>(v.begin(), v.end()), vi);
-    }
-
-    // 1d const tensor or scalar
-    template <typename T, typename std::enable_if<std::is_arithmetic<T>::value, bool>::type = true>
-    static std::shared_ptr<Node> ConstVector(const std::vector<T>& vec, values_info vi = nullptr) {
-        if (vi.size() > 0)
-            return std::make_shared<opset1::Constant>(vi[0].first, vi[0].second.to_shape(), vec);
-        // an initializer_list w/o values_info means to create a normal 1D vector
-        return std::make_shared<opset1::Constant>(element::from<T>(), Shape({vec.size()}), vec);
-    }
-};
-
-using SymbolObservationVector = std::vector<std::pair<Symbol, double>>;
-
-template <typename T>
-void add_symbol_observed(SymbolObservationVector& sov, const Symbol& sym, const T& value) {
-    auto v = static_cast<double>(value);
-    OPENVINO_ASSERT(static_cast<T>(v) == value);  // ensure no precision is lost in the double
-    sov.push_back(std::make_pair(sym, v));
-}
-/*
-template <typename T>
-static bool vector_equal_to_any(const std::vector<T>& v0, detail::AttrAny& any) {
-    auto v1 = any.cast_to_vector<T>();
-    if (v0.size() != v1.size())
-        return false;
-    return std::equal(v0.begin(), v0.end(), v1.begin());
-}
-
-template <typename T>
-static bool scalar_equal_to_any(const T& v0, detail::AttrAny& any) {
-    if (any.is<int>()) {
-        return v0 == any.as<int>();
-    } else if (any.is<float>()) {
-        return v0 == any.as<float>();
-    }
-    return v0 == any.as<T>();
-}
-*/
-// For arithmetic data types, the attr matcher succeeds as long as the actual attribute
-// is equal to the attribute cast from the pattern, w/o requiring an exact type match.
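// e.g. a pattern attribute {"axis", -1} is considered equal to an actual
// int64_t (or int32_t/float) axis attribute holding -1, whatever concrete
// adapter type the operation's visit_attributes() registers.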
-class AttrMatcher : public ov::AttributeVisitor { -public: - AttrMap& m_attr_map; - std::vector m_missing_attrs; - SymbolObservationVector* m_psov; - bool m_all_matched; - - AttrMatcher(AttrMap& attrs, SymbolObservationVector* psov = nullptr) - : m_attr_map(attrs), - m_psov(psov), - m_all_matched(true) {} - - bool matched() { - return m_all_matched; - } - - const std::vector& get_missing_attrs() { - return m_missing_attrs; - } - - bool should_skip(const std::string& name, bool allow_symbol = false) { - if (m_attr_map.count(name) == 0) { - m_missing_attrs.push_back(name); - return true; - } - - if (!allow_symbol) { - OPENVINO_ASSERT(!m_attr_map[name].any.is(), "Symbol is not allowed."); - } - return false; - } - - void add_match_result(const std::string& name, bool is_matched) { - if (!is_matched) { - _VERBOSE_LOG(" attribute '", name, "' mismatch."); - } - m_all_matched = m_all_matched && is_matched; - } - - void on_adapter(const std::string& name, ov::ValueAccessor& value) override { - if (should_skip(name)) - return; - add_match_result(name, value.get() == m_attr_map[name].as_string()); - } - void on_adapter(const std::string& name, ov::ValueAccessor& value) override { - if (should_skip(name)) - return; - add_match_result(name, m_attr_map[name].equal_to(value.get())); - } - void on_adapter(const std::string& name, ov::ValueAccessor& value) override { - if (should_skip(name)) - return; - add_match_result(name, m_attr_map[name].equal_to(value.get())); - } - void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { - if (should_skip(name)) - return; - add_match_result(name, m_attr_map[name].equal_to(value.get())); - } - - void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { - if (should_skip(name)) - return; - add_match_result(name, m_attr_map[name].equal_to(value.get())); - } - - void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { - if (should_skip(name)) - return; - add_match_result(name, m_attr_map[name].equal_to(value.get())); - } - - void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { - if (should_skip(name)) - return; - add_match_result(name, m_attr_map[name].equal_to(value.get())); - } - - // only integer is allowed to be of symbol type - void on_adapter(const std::string& name, ov::ValueAccessor& value) override { - if (should_skip(name, true)) - return; - auto& any = m_attr_map[name].any; - if (any.is()) { - if (m_psov) { - // collect symbol reference and do comparison later - add_symbol_observed(*m_psov, any.as(), value.get()); - } - return; - } - add_match_result(name, m_attr_map[name].cast_to() == value.get()); - } - void on_adapter(const std::string& name, ov::ValueAccessor& value) override { - if (should_skip(name, true)) - return; - auto& any = m_attr_map[name].any; - if (any.is()) { - if (m_psov) { - // collect symbol reference and do comparison later - add_symbol_observed(*m_psov, any.as(), value.get()); - } - return; - } - add_match_result(name, m_attr_map[name].cast_to() == value.get()); - } - - void on_adapter(const std::string& name, ov::ValueAccessor& adapter) override { - if (should_skip(name)) - return; - OPENVINO_ASSERT(m_attr_map.count(name) > 0); - auto& any = m_attr_map[name].any; - bool is_matched = true; - if (auto a = ov::as_type>(&adapter)) { - is_matched = (static_cast(*a) == any.as()); - } else if (auto a = ov::as_type>(&adapter)) { - is_matched = (a->get() == any.as()); - } else if (auto a = ov::as_type>(&adapter)) { - is_matched = (a->get() == any.as()); 
- } else if (auto a = ov::as_type>(&adapter)) { - is_matched = m_attr_map[name].equal_to(a->get()); - } else if (auto a = ov::as_type>(&adapter)) { - is_matched = m_attr_map[name].equal_to(a->get()); - } else if (auto a = ov::as_type>>(&adapter)) { -#if defined(__APPLE__) || defined(__EMSCRIPTEN__) - is_matched = m_attr_map[name].equal_to(static_cast&>(*a)); -#else - is_matched = m_attr_map[name].equal_to(a->get()); -#endif - } else if (auto a = ov::as_type>(&adapter)) { - is_matched = m_attr_map[name].equal_to(a->get()); - } else if (auto a = ov::as_type>(&adapter)) { - is_matched = (a->get() == any.as()); - } else if (auto a = ov::as_type>(&adapter)) { - is_matched = (a->get() == any.as()); - } else if (auto a = ov::as_type>(&adapter)) { - is_matched = m_attr_map[name].equal_to(a->get()); - } else { - OPENVINO_THROW("AttrSetter met unsupported AttributeAdapter"); - } - add_match_result(name, is_matched); - } -}; -} // namespace detail - -//================================================================================================== - -inline std::shared_ptr GenInput(values_info vt = nullptr) { - return ov::pass::pattern::any_input([vt](const Output& value) { - if (!vt.predicate(value)) { - _VERBOSE_LOG("*mismatched GenInput ", value); - return false; - } - _VERBOSE_LOG(" matched GenInput ", value); - return true; - }); -} - -inline std::shared_ptr makePattern() { - detail::PatternNode g; - return g.node; -} - -inline std::shared_ptr makePattern(ov::Rank rank) { - detail::PatternNode g(rank); - return g.node; -} - -inline std::shared_ptr makePattern(values_info vt) { - detail::PatternNode g(vt); - return g.node; -} - -// unknown const -inline std::shared_ptr makeConst(const ov::element::Type& type, - const ov::PartialShape& pshape, - std::function pred) { - return ov::pass::pattern::wrap_type([type, pshape, pred](const Output& value) { - auto cnode = ov::as_type_ptr(value.get_node_shared_ptr()); - if (!cnode) - return false; - - if (!type.compatible(value.get_element_type()) || !pshape.compatible(value.get_partial_shape())) { - return false; - } - if (pred && !pred(*cnode)) { - return false; - } - return true; - }); -} - -template -std::shared_ptr makeConst(const ov::element::Type& type, - const ov::Shape& shape, - std::initializer_list values) { - return std::make_shared(type, shape, std::vector(values)); -} - -template -std::shared_ptr makeConst(const ov::element::Type& type, const ov::Shape& shape, const std::vector& values) { - return std::make_shared(type, shape, values); -} - -template -std::shared_ptr makePattern(const std::vector& inputs, - detail::AttrMap attrmap = {}, - values_info vt = nullptr, - const char* friendly_name = nullptr) { - auto* p_type_info = &(T::get_type_info_static()); - OutputVector args; - for (auto& in : inputs) - args.push_back(in.get_output()); - - // pattern nodes are better for pattern matching because - // - it can be generic/incomplete, so normal OP node is not working properly - // - it has predicate to correctly decide which branch to take (in Or pattern) - auto pattern_node = std::make_shared(args, attrmap); - - if (friendly_name) { - pattern_node->set_friendly_name(friendly_name); - } else { - std::stringstream ss; - ss << p_type_info->get_version() << "::" << p_type_info->name; - ss << "("; - const char* sep = ""; - for (auto& i : args) { - ss << sep << i.get_node()->get_name(); - sep = ","; - } - ss << ")"; - pattern_node->set_friendly_name(ss.str()); - } - - auto* pnode = pattern_node.get(); - pnode->set_predicate([p_type_info, vt, pnode, 
friendly_name, attrmap](const Output& value) { - (void)friendly_name; - auto value_node = value.get_node_shared_ptr(); - if (!value_node->get_type_info().is_castable(*p_type_info)) { - _VERBOSE_LOG("*mismatched makePattern OP type: ", pnode->get_friendly_name(), "vs", value); - return false; - } - - if (!vt.predicate(value)) { - _VERBOSE_LOG("*mismatched makePattern value info: ", pnode->get_friendly_name(), "vs", value); - return false; - } - - auto& attr_map = pnode->get_attrs(); - if (!attr_map.empty()) { - detail::AttrMatcher visitor(attr_map); - value_node->visit_attributes(visitor); - if (!visitor.matched()) { - _VERBOSE_LOG("*mismatched attributes : ", - pnode->get_friendly_name(), - " vs ", - value_node->get_friendly_name()); - return false; - } - } - - _VERBOSE_LOG(" matched makePattern ", pnode->get_friendly_name(), " == ", value); - return true; - }); - - return pattern_node; -} - -template -std::shared_ptr makeOP(const std::vector& inputs, - detail::AttrMap attrmap = {}, - const char* friendly_name = nullptr) { - std::shared_ptr node = std::make_shared(); - - OutputVector args; - for (auto& in : inputs) - args.push_back(in.get_output()); - node->set_arguments(args); - - detail::AttrSetter visitor(attrmap); - node->visit_attributes(visitor); - - auto missing_attrs = visitor.get_missing_attrs(); - - // when some attribute is missing or is symbol, the returned - // node is suitable for pattern matching only. - OPENVINO_ASSERT(missing_attrs.size() == 0, - "missing ", - missing_attrs.size(), - " attributes : ", - missing_attrs[0], - "..."); - - if (friendly_name) - node->set_friendly_name(friendly_name); - node->constructor_validate_and_infer_types(); - return node; -} - -template -std::shared_ptr GenConst_tril(values_info vt) { - return ov::pass::pattern::wrap_type([vt](const Output& value) { - auto s1 = as_type_ptr(value.get_node_shared_ptr()); - if (!s1) { - _VERBOSE_LOG("*mismatched GenConst_tril op type: opset1::Constant vs", value); - return false; - } - - if (!vt.predicate(value)) { - _VERBOSE_LOG("*mismatched GenConst_tril values_info:", value); - return false; - } - - // ignore higher dimensions, require lowerst 2D to be lower triangular - auto shape = s1->get_output_shape(0); - auto rank = shape.size(); - if (rank < 2) { - _VERBOSE_LOG("*mismatched GenConst_tril rank < 2 (rank=", rank, ")"); - return false; - } - if (shape[rank - 1] != shape[rank - 2]) { - _VERBOSE_LOG("*mismatched GenConst_tril shape[-1] != shape[-2] : ", - shape[rank - 1], - " != ", - shape[rank - 2]); - return false; - } - // NxN const matrix - auto N = shape[rank - 1]; - std::vector output_vector = s1->cast_vector(); - // check if it's unit lower triangular matrix - for (size_t i = 0; i < N; i++) { - for (size_t j = 0; j < N; j++) { - if (static_cast(output_vector[i * N + j]) != static_cast(j <= i)) - return false; - } - } - return true; - }); -} - -inline std::shared_ptr operator|(const Output& lhs, const Output& rhs) { - return std::make_shared(OutputVector{lhs, rhs}); -} - -inline std::shared_ptr operator|(const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return std::make_shared( - OutputVector{lhs->get_default_output(), rhs->get_default_output()}); -} - -inline std::shared_ptr GenSlice(detail::PatternNode data, Symbol start, Symbol stop, Symbol step, size_t axis) { - auto opt1 = makePattern({data, {start}, {stop}, {step}, {static_cast(axis)}}); - - std::vector vbegin(axis + 1, Symbol(0)); - std::vector vend(axis + 1, Symbol(0)); - std::vector vstride(axis + 1, Symbol(1)); - - vbegin[axis] 
= start; - vend[axis] = stop; - vstride[axis] = step; - - detail::PatternNode begin(vbegin); - detail::PatternNode end(vend); - detail::PatternNode stride(vstride); - - std::vector begin_mask(axis + 1, 1); - std::vector end_mask(axis + 1, 1); - std::vector new_axis_mask; - std::vector shrink_axis_mask; - std::vector ellipsis_mask; - - begin_mask[axis] = 0; - end_mask[axis] = 0; - - auto opt2 = makePattern({data, begin, end, stride}, - {{"begin_mask", begin_mask}, - {"end_mask", end_mask}, - {"new_axis_mask", new_axis_mask}, - {"shrink_axis_mask", shrink_axis_mask}, - {"ellipsis_mask", ellipsis_mask}}); - return opt1 | opt2; -} - -//================================================================================================== -class PatternValidator { -public: - PatternValidator(ov::pass::pattern::Matcher& m, bool force_verbose = false) { - auto saved_force_matcher_verbose = force_matcher_verbose; - force_matcher_verbose = force_verbose; - m_is_valid = validate(m); - force_matcher_verbose = saved_force_matcher_verbose; - } - - double& operator[](const char* symbol_name) { - return m_symbol_values[symbol_name]; - } - - operator bool() { - if (!m_is_valid) { - _VERBOSE_LOG("PatternValidator failed."); - } - return m_is_valid; - } - - bool validate(ov::pass::pattern::Matcher& m) { - detail::SymbolObservationVector sov; - - auto& pvmap = m.get_pattern_value_map(); - for (auto& pv : pvmap) { - auto pnode = pv.first; - auto value_node = pv.second.get_node_shared_ptr(); - auto& rt_info = pnode->get_rt_info(); - - if (auto pattern_node = std::dynamic_pointer_cast(pnode)) { - // pattern_node has no attribute and it has been matched in its predicate - if (rt_info.count("symbolic_const_value")) { - // symbolic constant node, a symbol reference is observed - auto& symbols = rt_info["symbolic_const_value"].as>(); - auto constop = std::dynamic_pointer_cast(value_node); - if (!constop) { - _VERBOSE_LOG("symbolic_const_value unexpected OP: ", value_node->get_friendly_name()); - return false; - } - auto ele_cnt = shape_size(constop->get_shape()); - auto ele_type = constop->get_element_type(); - - if (ele_cnt != symbols.size()) { - _VERBOSE_LOG("symbolic_const_value expect ", - symbols.size(), - " but got ", - ele_cnt, - " from ", - value_node->get_friendly_name()); - return false; - } - - if (ele_type == ov::element::i32 || ele_type == ov::element::f32 || ele_type == ov::element::i64) { - auto observed = constop->cast_vector(); - for (size_t i = 0; i < symbols.size(); i++) - detail::add_symbol_observed(sov, symbols[i], observed[i]); - } else { - _VERBOSE_LOG("Unexpect element type ", ele_type, " from ", value_node->get_friendly_name()); - return false; - } - } - continue; - } - if (auto pconst_node = std::dynamic_pointer_cast(pnode)) { - // const_node needs to match type/shape/value - auto vconst_node = std::dynamic_pointer_cast(value_node); - if (!vconst_node) { - _VERBOSE_LOG("expecting Constant op, but got ", value_node); - return false; - } - if (pconst_node->get_output_element_type(0) != vconst_node->get_output_element_type(0)) { - _VERBOSE_LOG("expecting Constant of type ", - pconst_node->get_output_element_type(0), - " but got ", - vconst_node); - return false; - } - // for constant node matched in pattern, a scalar constant is considered to - // be compatible with any shape with 1 element, like {}, {1,1}, {1,1,...} - const auto& expected_shape = pconst_node->get_output_shape(0); - if (expected_shape.size() == 0) { - if (shape_size(vconst_node->get_output_shape(0)) != 1) { - 
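The compatibility rule stated just above, in isolation: a scalar constant on the pattern side accepts any matched constant holding exactly one element, whatever its rank. A standalone sketch:

```cpp
#include <iostream>

#include "openvino/core/shape.hpp"

int main() {
    const ov::Shape pattern_shape{};  // scalar in the pattern
    for (const ov::Shape& s : {ov::Shape{}, ov::Shape{1}, ov::Shape{1, 1}, ov::Shape{2, 1}}) {
        // only the element count matters for a pattern-side scalar
        bool compatible = pattern_shape.empty() && ov::shape_size(s) == 1;
        std::cout << s << " -> " << (compatible ? "compatible" : "mismatch") << "\n";
    }
    return 0;
}
```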
_VERBOSE_LOG("expecting a single element const, but got ", vconst_node); - return false; - } - } else { - if (expected_shape != vconst_node->get_output_shape(0)) { - _VERBOSE_LOG("expecting Constant of shape ", expected_shape, " but got ", vconst_node); - return false; - } - } - auto byte_size = - shape_size(vconst_node->get_output_shape(0)) * vconst_node->get_output_element_type(0).size(); - if (std::memcmp(pconst_node->get_data_ptr(), vconst_node->get_data_ptr(), byte_size) != 0) { - _VERBOSE_LOG("Constant value mismatch."); - return false; - } - continue; - } - - // compare attributes between them - // assume that there is no Symbol in the attributes, we need to fetch each attributes - // from - if (rt_info.count("__attrs__") == 0) { - _VERBOSE_LOG(" attr compare failed: __attrs__ not found for ", pnode->get_friendly_name()); - return false; - } - - // attr not specified is treated as not-care and ignored - // attr with symbol - - detail::AttrMap& attr_map = rt_info["__attrs__"].as(); - detail::AttrMatcher visitor(attr_map, &sov); - value_node->visit_attributes(visitor); - if (!visitor.matched()) { - _VERBOSE_LOG(" attr compare failed: ", - pnode->get_friendly_name(), - " vs ", - value_node->get_friendly_name()); - return false; - } - } - - // check symbol consistency & return independent symbols - // assign independent symbols & check literals - std::map symbol_value_map; - for (auto& ref : sov) { - auto& sym = ref.first; - auto& value = ref.second; - - if (sym.is_independent_var()) { - auto id = sym.get_id(); - if (symbol_value_map.count(id)) { - if (symbol_value_map[id] != value) { - _VERBOSE_LOG(" in-consistency between multiple references of same symbol : ", - symbol_value_map[id], - " != ", - value); - return false; - } - } else { - symbol_value_map[id] = value; - m_symbol_values[sym.get_name()] = value; - _VERBOSE_LOG("Independent Symbol: ", sym.get_name(), " = ", value); - } - } - - if (sym.is_literal_const()) { - auto literal = sym.eval(symbol_value_map); - if (literal != value) { - _VERBOSE_LOG(" mismatch between literal symbol & value : ", literal, " != ", value); - return false; - } - // no need to put literal into value map to eval them. 
- } - } - - // derive/eval dependent symbol's value and check against observed - for (auto& ref : sov) { - auto& sym = ref.first; - if (!sym.is_literal_const() && !sym.is_independent_var()) { - auto derived = sym.eval(symbol_value_map); - auto value = ref.second; - bool is_match; - - if (std::trunc(value) == value) { - // observed integer - is_match = (derived == value); - } else { - auto abs_diff = std::abs(derived - value); - auto avg = 0.5f * std::abs(derived + value); - if (avg != 0) { - is_match = abs_diff < avg * 1e-7; // relative error less than threshold - } else { - is_match = (derived == value); - } - } - if (!is_match) { - _VERBOSE_LOG(" mismatch between derived & value : ", - std::setprecision(std::numeric_limits::max_digits10), - derived, - " != ", - std::setprecision(std::numeric_limits::max_digits10), - value); - return false; - } - } - } - return true; - } - -private: - std::map m_symbol_values; - bool m_is_valid; -}; - -} // namespace gen_pattern -} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/utils/print_model.hpp b/src/plugins/intel_cpu/src/utils/print_model.hpp deleted file mode 100644 index 6b4eb01180a..00000000000 --- a/src/plugins/intel_cpu/src/utils/print_model.hpp +++ /dev/null @@ -1,415 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "openvino/core/attribute_visitor.hpp" -#include "openvino/core/model.hpp" -#include "openvino/core/node.hpp" -#include "openvino/op/constant.hpp" -#include "openvino/pass/pass.hpp" - -namespace ov { -namespace pass { - -namespace detail { - -// to_code convert value into literal/constexpr/initializer_list/factory_calls in C++ source code -inline std::string to_code(bool value) { - return value ? 
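The derived-symbol comparison inside validate() just above (before the deleted print_model.hpp begins) treats integer observations as exact and non-integer observations with a 1e-7 relative tolerance; restated as a standalone check:

```cpp
#include <cmath>
#include <cstdio>

// Integers must match exactly; floats within a 1e-7 relative tolerance.
static bool matches(double derived, double observed) {
    if (std::trunc(observed) == observed)
        return derived == observed;  // integer observation: exact match required
    double abs_diff = std::abs(derived - observed);
    double avg = 0.5 * std::abs(derived + observed);
    return avg != 0 ? abs_diff < avg * 1e-7 : derived == observed;
}

int main() {
    std::printf("%d\n", matches(1.0 / 3.0, 0.33333333));  // 1: within tolerance
    std::printf("%d\n", matches(128.0, 128.0));           // 1: exact integer
    std::printf("%d\n", matches(127.0, 128.0));           // 0: integer mismatch
    return 0;
}
```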
"true" : "false"; -} -inline std::string to_code(const std::string& value) { - return std::string("\"") + value + "\""; -} -inline std::string to_code(const element::Type& value) { - return std::string("element::") + value.to_string(); -} -inline std::string to_code(const ov::Shape& value) { - std::stringstream ss; - ss << "ov::Shape({"; - for (auto& d : value) - ss << d << ","; - ss << "})"; - return ss.str(); -} -inline std::string to_code(int value) { - if (INT_MAX == value) { - return "INT_MAX"; - } - if (INT_MIN == value) { - return "INT_MIN"; - } - return std::to_string(value); -} -inline std::string to_code(int64_t value) { - if (LLONG_MAX == value) { - return "LLONG_MAX"; - } - if (LLONG_MIN == value) { - return "LLONG_MIN"; - } - const char* suffix = "LL"; - if (value == static_cast(static_cast(value))) { - // save suffix since most values can be expressed as int - // this produces more readable code - suffix = ""; - } - return std::to_string(value) + suffix; -} -inline std::string to_code(uint64_t value) { - if (ULLONG_MAX == value) { - return "ULLONG_MAX"; - } - const char* suffix = "uLL"; - if (value == static_cast(static_cast(value))) { - // save suffix since most values can be expressed as int - // this produces more readable code - suffix = ""; - } - return std::to_string(value) + suffix; -} -inline std::string to_code(int8_t value) { - return std::to_string(static_cast(value)); -} -inline std::string to_code(uint8_t value) { - return std::to_string(static_cast(value)); -} - -template -std::string to_code_float(T value) { - if (std::isnan(value)) { - return "NAN"; - } else if (std::isinf(value)) { - return (value > 0 ? "INFINITY" : "-INFINITY"); - } else if (value == FLT_MIN) { - return "FLT_MIN"; - } else if (value == -FLT_MIN) { - return "-FLT_MIN"; - } else if (value == FLT_MAX) { - return "FLT_MAX"; - } else if (value == -FLT_MAX) { - return "-FLT_MAX"; - } - auto strv = std::to_string(value); - if (strv.find(".") == std::string::npos && strv.find("e") == std::string::npos) - strv += ".0"; - if (std::is_same::value) - strv += "f"; - return strv; -} - -inline std::string to_code(float value) { - return to_code_float(value); -} -inline std::string to_code(double value) { - return to_code_float(value); -} -template -std::string to_code(const std::vector& values, bool no_braces = false, int maxsize = 80) { - std::stringstream ss; - if (!no_braces) - ss << "{"; - const char* sep = ""; - for (auto& v : values) { - if (ss.tellp() > maxsize) { - ss << "... 
(" << values.size() << " in total)"; - break; - } - ss << sep << to_code(v); - sep = ","; - } - if (!no_braces) - ss << "}"; - return ss.str(); -} - -template -std::string to_code(std::shared_ptr constop) { - bool no_braces = (constop->get_shape().size() == 0); - auto ele_type = constop->get_element_type(); - if (ele_type == element::Type_t::f32) { - return to_code(constop->get_vector(), no_braces); - } else if (ele_type == element::Type_t::i8) { - return to_code(constop->get_vector(), no_braces); - } else if (ele_type == element::Type_t::u8) { - return to_code(constop->get_vector(), no_braces); - } else if (ele_type == element::Type_t::i32) { - return to_code(constop->get_vector(), no_braces); - } else if (ele_type == element::Type_t::i64) { - return to_code(constop->get_vector(), no_braces); - } - - // general case - std::stringstream ss; - if (!no_braces) - ss << "{"; - auto ele_size = shape_size(constop->get_shape()); - if (ele_size < 9) { - const char* sep = ""; - for (auto v : constop->get_value_strings()) { - ss << sep << v; - sep = ", "; - } - } else { - ss << "..."; - } - if (!no_braces) - ss << "}"; - return ss.str(); -} - -class OstreamAttributeVisitor : public ngraph::AttributeVisitor { - std::ostream& os; - const char* sep = ""; - -public: - OstreamAttributeVisitor(std::ostream& os) : os(os) {} - - void append_attribute(const std::string& name, const std::string& value) { - os << sep << "{\"" << name << "\", " << value << "}"; - sep = ", "; - } - - void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { - if (auto a = ov::as_type>>(&adapter)) { - const auto& strset = a->get(); - std::vector values(strset.begin(), strset.end()); - append_attribute(name, to_code(values)); - } else if (auto a = ov::as_type>>(&adapter)) { - append_attribute(name, to_code(a->get())); - } else if (auto a = ov::as_type>(&adapter)) { - const auto& value = a->get(); - append_attribute(name, value.to_string()); - } else if (auto a = ov::as_type>>(&adapter)) { - const auto& vinfo = a->get()->get_info(); - std::stringstream ss; - ss << vinfo.variable_id << vinfo.data_shape << vinfo.data_type; - append_attribute(name, ss.str()); - } else { - append_attribute(name, "?"); - } - } - - void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { - append_attribute(name, to_code(adapter.get())); - } - void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { - append_attribute(name, to_code(adapter.get())); - } - void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { - append_attribute(name, to_code(adapter.get())); - } - void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { - append_attribute(name, to_code(adapter.get())); - } - void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { - append_attribute(name, to_code(adapter.get())); - } - void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { - append_attribute(name, to_code(adapter.get())); - } - void on_adapter(const std::string& name, ngraph::ValueAccessor>& adapter) override { - append_attribute(name, to_code(adapter.get())); - } - void on_adapter(const std::string& name, ngraph::ValueAccessor>& adapter) override { - append_attribute(name, to_code(adapter.get())); - } - void on_adapter(const std::string& name, ngraph::ValueAccessor>& adapter) override { - append_attribute(name, to_code(adapter.get())); - } - void on_adapter(const std::string& name, 
ngraph::ValueAccessor>& adapter) override { - append_attribute(name, to_code(adapter.get())); - } - void on_adapter(const std::string& name, ngraph::ValueAccessor>& adapter) override { - append_attribute(name, to_code(adapter.get())); - } - void on_adapter(const std::string& name, ngraph::ValueAccessor>& adapter) override { - append_attribute(name, "Model"); - } -}; - -template -void dump_cpp_style(std::ostream& os, const std::shared_ptr& model) { - const ov::Model& f = *model; - std::string prefix = ""; - std::string tag = ""; - std::string sep = ""; - os << prefix; - for (auto op : f.get_results()) { - os << sep << op->get_name(); - sep = ","; - } - os << " " << f.get_friendly_name() << "(\n" << prefix; - for (auto op : f.get_parameters()) { - os << " " << tag << op->get_friendly_name() << ",\n" << prefix; - } - os << ") {\n"; - - // collect all scalar & short 1D vectors for literal-style display - std::map, std::string> literal_consts; - for (auto op : f.get_ordered_ops()) { - if (auto constop = std::dynamic_pointer_cast(op)) { - // only i32/f32 type const literal can be parsed by C++ compiler - if (constop->get_output_element_type(0) != ov::element::i32 && - constop->get_output_element_type(0) != ov::element::i64 && - constop->get_output_element_type(0) != ov::element::f32) - continue; - auto shape = constop->get_shape(); - if (shape.size() > 1) - continue; - if (shape_size(constop->get_shape()) > 64) - continue; - literal_consts[op] = to_code(constop); - } - } - - auto get_output_values_info = [](std::shared_ptr& op) { - std::stringstream ss; - const char* sep = ""; - for (size_t i = 0; i < op->get_output_size(); i++) { - ss << sep << op->get_output_element_type(i) << op->get_output_partial_shape(i); - sep = " "; - } - return ss.str(); - }; - - // change name convension - std::map opname; - std::map opname_count; - for (auto op : f.get_ordered_ops()) { - auto name = op->get_friendly_name(); - std::replace(name.begin(), name.end(), '\\', '_'); - std::replace(name.begin(), name.end(), '/', '_'); - std::replace(name.begin(), name.end(), '.', '_'); - std::replace(name.begin(), name.end(), '[', '_'); - std::replace(name.begin(), name.end(), ']', '_'); - std::replace(name.begin(), name.end(), '-', 'n'); - if (name[0] >= '0' && name[0] <= '9') { - const auto& type_info = op->get_type_info(); - name.insert(0, type_info.name); - } - int idx = 0; - if (opname_count.count(name)) { - idx = opname_count[name] + 1; - } - opname_count[name] = idx; - - if (idx) - name += std::to_string(idx); - - opname[op.get()] = name; - } - - for (auto op : f.get_ordered_ops()) { - if (literal_consts.count(op)) - continue; - - const auto& type_info = op->get_type_info(); - auto version_info = std::string(type_info.get_version()); - auto type = version_info + "::" + type_info.name; - auto name = opname[op.get()]; - os << prefix << " "; - - if (auto constop = std::dynamic_pointer_cast(op)) { - os << "auto " << name << " = makeConst(" << to_code(op->get_output_element_type(0)) << ", " - << to_code(op->get_output_shape(0)) << ", " << to_code(constop) << ");" << std::endl; - } else { - os << "auto " << name << " = makeOP<" << type << ">({"; - // input args - sep = ""; - for (size_t i = 0; i < op->get_input_size(); i++) { - auto vout = op->get_input_source_output(i); - auto iop = vout.get_node_shared_ptr(); - if (iop->get_output_size() > 1) { - auto out_port = vout.get_index(); - os << sep << tag << opname[iop.get()] << "->output(" << out_port << ")"; - } else { - if (literal_consts.count(iop)) - os << sep << tag << 
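The friendly-name mangling in dump_cpp_style above, restated and slightly generalized (the original replaces a fixed set of punctuation characters and prepends the op type name when a name starts with a digit; this sketch prepends a plain "v"):

```cpp
#include <cctype>
#include <iostream>
#include <string>

// Turn a node's friendly_name into a usable C++ identifier.
std::string to_identifier(std::string name) {
    for (char& c : name)
        if (!std::isalnum(static_cast<unsigned char>(c)) && c != '_')
            c = '_';
    if (!name.empty() && std::isdigit(static_cast<unsigned char>(name[0])))
        name.insert(0, "v");
    return name;
}

int main() {
    std::cout << to_identifier("model/layer.0/attn[3]") << "\n";  // model_layer_0_attn_3_
    return 0;
}
```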
literal_consts[iop]; - else - os << sep << tag << opname[iop.get()]; - } - sep = ", "; - } - os << "}"; - - // attributes as AnyMap - std::stringstream ss2; - OstreamAttributeVisitor osvis(ss2); - op->visit_attributes(osvis); - auto str_attr = ss2.str(); - if (str_attr.size()) - os << ", {" << str_attr << "}"; - os << "); // tensor_array<" << get_output_values_info(op) << "> " << op->get_friendly_name(); - - os << "("; - sep = ""; - for (size_t i = 0; i < op->get_input_size(); i++) { - auto vout = op->get_input_source_output(i); - auto iop = vout.get_node_shared_ptr(); - os << sep << tag << iop->get_friendly_name(); - if (iop->get_output_size() > 1) { - auto out_port = vout.get_index(); - os << "[" << out_port << "]"; - } - sep = ", "; - } - os << ")" << std::endl; - } - - // recursively output subgraphs - if (auto msubgraph = std::dynamic_pointer_cast(op)) { - auto cnt = msubgraph->get_internal_subgraphs_size(); - for (size_t i = 0; i < cnt; i++) { - os << " MultiSubGraphOp " << tag << msubgraph->get_friendly_name() << "[" << i << "]" << std::endl; - dump_cpp_style(os, msubgraph->get_function(i)); - } - } - } - os << prefix << "}\n"; -} - -} // namespace detail - -class OPENVINO_API PrintModel : public ov::pass::ModelPass { -public: - OPENVINO_RTTI("ov::pass::PrintModel"); - - PrintModel(std::string file_name) { - static int dump_index = 0; - m_file_name = std::string("modelprint_") + std::to_string(dump_index) + "_" + file_name; - dump_index++; - } - ~PrintModel() {} - - bool run_on_model(const std::shared_ptr& model) override { - if (m_file_name.empty()) - return false; - - std::ofstream ofs(m_file_name); - if (!ofs) { - // OPENVINO_WARN << "Error opening file " << m_file_name << " for output" << std::endl; - return false; - } - detail::dump_cpp_style(ofs, model); - ofs.close(); - return true; - } - -protected: - std::string m_file_name; -}; -} // namespace pass -} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/rotary_pos_emb.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/rotary_pos_emb.cpp deleted file mode 100644 index edfce4dcc95..00000000000 --- a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/rotary_pos_emb.cpp +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "common_test_utils/common_utils.hpp" -#include "functional_test_utils/skip_tests_config.hpp" -#include "ie_precision.hpp" -#include "shared_test_classes/base/ov_subgraph.hpp" -#include "test_utils/cpu_test_utils.hpp" -#include "test_utils/fusing_test_utils.hpp" -#include "utils/gen_pattern.hpp" - -using namespace CPUTestUtils; -using namespace ov::gen_pattern; -using namespace ov::test; -using namespace ov; - -static ov::OutputVector makeCosSinCache(int max_position_embeddings, int rotary_ndims) { - std::vector lut_sin(max_position_embeddings * rotary_ndims, 0.0f); - std::vector lut_cos(max_position_embeddings * rotary_ndims, 0.0f); - - // rotate_half style cos/sin table: - // y1 = cos(m*xita_i) * x1 - sin(m*xita_i) * x2 - // y2 = cos(m*xita_i) * x2 + sin(m*xita_i) * x1 - // - for (int i = 0, k = 0; i < rotary_ndims; i += 2, k++) { - auto xita_i = 1.0 / std::pow(10000.0, static_cast(i) / rotary_ndims); - float* psin = lut_sin.data(); - float* pcos = lut_cos.data(); - for (int m = 0; m < max_position_embeddings; m++, psin += rotary_ndims, pcos += rotary_ndims) { - 
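The table being filled here duplicates the first half of each row into the second half: for rotary_ndims = 4, row m holds [cos(m*theta_0), cos(m*theta_1), cos(m*theta_0), cos(m*theta_1)], which is exactly the layout the rotate_half formula consumes. A standalone check of that layout:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

int main() {
    const int rotary_ndims = 4, max_pos = 8;
    std::vector<float> lut_cos(max_pos * rotary_ndims, 0.0f);
    for (int i = 0, k = 0; i < rotary_ndims; i += 2, k++) {
        double theta_i = 1.0 / std::pow(10000.0, static_cast<double>(i) / rotary_ndims);
        float* pcos = lut_cos.data();
        for (int m = 0; m < max_pos; m++, pcos += rotary_ndims)
            pcos[k] = pcos[k + rotary_ndims / 2] = static_cast<float>(std::cos(theta_i * m));
    }
    // row m=1 is {cos(theta_0), cos(theta_1), cos(theta_0), cos(theta_1)}
    assert(lut_cos[1 * rotary_ndims + 0] == lut_cos[1 * rotary_ndims + 2]);
    assert(lut_cos[1 * rotary_ndims + 1] == lut_cos[1 * rotary_ndims + 3]);
    return 0;
}
```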
auto vsin = std::sin(xita_i * m); - auto vcos = std::cos(xita_i * m); - pcos[k] = pcos[k + rotary_ndims / 2] = vcos; - psin[k] = psin[k + rotary_ndims / 2] = vsin; - } - } - auto shape = ov::Shape({1, 1, static_cast(max_position_embeddings), static_cast(rotary_ndims)}); - auto Cos = makeConst(ov::element::f32, shape, lut_cos); - auto Sin = makeConst(ov::element::f32, shape, lut_sin); - return {Cos, Sin}; -} - -static std::shared_ptr buildROPE_Llama2(const int batch, - const int seq_length, - const int max_position_embeddings, - const int num_head, - const int ndims) { - auto input = std::make_shared(ov::element::f32, PartialShape{batch, -1, num_head, ndims}); - auto pos_id_end = std::make_shared(ov::element::i32, ov::Shape{}); - auto pos_ids = std::make_shared(ov::element::i32, PartialShape{1, -1}); - - auto cos_sin_cache = makeCosSinCache(max_position_embeddings, ndims); - auto Constant582 = cos_sin_cache[0]; - auto Constant585 = cos_sin_cache[1]; - - // concat KV length - auto transpose_Transpose = makeOP({input, {0, 2, 1, 3}}); - auto slice_Unsqueeze_426 = makeOP({pos_id_end, 0}); - auto ScatterUpdate_152236 = makeOP({{0, 0, 0}, {2}, slice_Unsqueeze_426, {0}}); - auto slice_Slice = makeOP({Constant582, {0, 0, 0}, ScatterUpdate_152236, {1, 1, 1}}, - {{"begin_mask", {1, 1, 0}}, - {"end_mask", {1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto squeeze_Squeeze = makeOP({slice_Slice, 1}); - auto squeeze_Squeeze_435 = makeOP({squeeze_Squeeze, 0}); - auto index_441_Gather = makeOP({squeeze_Squeeze_435, pos_ids, 0}, {{"batch_dims", 0}}); - auto unsqueeze_Unsqueeze = makeOP({index_441_Gather, 1}); - auto mul_Multiply = - makeOP({transpose_Transpose, unsqueeze_Unsqueeze}, {{"auto_broadcast", "numpy"}}); - auto size_ShapeOf_448 = makeOP({transpose_Transpose}, {{"output_type", "i32"}}); - auto size_Gather_450 = makeOP({size_ShapeOf_448, 3, 0}, {{"batch_dims", 0}}); - auto floor_divide_Divide = - makeOP({size_Gather_450, 2}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); - auto floor_divide_Floor = makeOP({floor_divide_Divide}); - auto slice_Unsqueeze_452 = makeOP({floor_divide_Floor, 0}); - auto ScatterUpdate_152312 = makeOP({{0, 0, 0, 0}, {3}, slice_Unsqueeze_452, {0}}); - auto slice_Slice_459 = makeOP( - {transpose_Transpose, ScatterUpdate_152312, {0ll, 0ll, 0ll, LLONG_MAX}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto Constant_182988 = makeConst(element::f32, - ov::Shape({ - 1, - 1, - 1, - 1, - }), - {-1.000000f}); - auto neg_Multiply = makeOP({slice_Slice_459, Constant_182988}, {{"auto_broadcast", "numpy"}}); - auto ScatterUpdate_152368 = makeOP({{0, 0, 0, 0}, {3}, slice_Unsqueeze_452, {0}}); - auto slice_Slice2 = - makeOP({transpose_Transpose, {0, 0, 0, 0}, ScatterUpdate_152368, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto cat_Concat = makeOP({neg_Multiply, slice_Slice2}, {{"axis", -1}}); - auto ScatterUpdate_152421 = makeOP({{0, 0, 0}, {2}, slice_Unsqueeze_426, {0}}); - auto slice_Slice_433 = makeOP({Constant585, {0, 0, 0}, ScatterUpdate_152421, {1, 1, 1}}, - {{"begin_mask", {1, 1, 0}}, - {"end_mask", {1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto squeeze_Squeeze_436 = makeOP({slice_Slice_433, 1}); - auto squeeze_Squeeze_437 = 
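Per batch, head, and position, the whole subgraph above reduces to the rotate_half arithmetic; a scalar reference sketch (illustrative only, assuming the duplicated-half cos/sin rows produced by makeCosSinCache):

```cpp
#include <cstddef>
#include <vector>

// y[i] = cos[i] * x[i] - sin[i] * x[i + d/2]   for i in [0, d/2)
// y[i] = cos[i] * x[i] + sin[i] * x[i - d/2]   for i in [d/2, d)
std::vector<float> rotate_half(const std::vector<float>& x,
                               const std::vector<float>& cos_t,
                               const std::vector<float>& sin_t) {
    const std::size_t d = x.size(), h = d / 2;
    std::vector<float> y(d);
    for (std::size_t i = 0; i < h; ++i)
        y[i] = cos_t[i] * x[i] - sin_t[i] * x[i + h];
    for (std::size_t i = h; i < d; ++i)
        y[i] = cos_t[i] * x[i] + sin_t[i] * x[i - h];
    return y;
}
```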
makeOP({squeeze_Squeeze_436, 0}); - auto index_446_Gather = makeOP({squeeze_Squeeze_437, pos_ids, 0}, {{"batch_dims", 0}}); - auto unsqueeze_Unsqueeze_447 = makeOP({index_446_Gather, 1}); - auto mul_Multiply_463 = - makeOP({cat_Concat, unsqueeze_Unsqueeze_447}, {{"auto_broadcast", "numpy"}}); - auto add_Add = makeOP({mul_Multiply, mul_Multiply_463}, {{"auto_broadcast", "numpy"}}); - - return std::make_shared(ov::NodeVector{add_Add}, ov::ParameterVector{input, pos_id_end, pos_ids}); -} - -namespace CPULayerTestsDefinitions { - -class RoPECPUTest : public SubgraphBaseTest { -public: - ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1) { - auto tensor = ov::Tensor(ov::element::i32, shape); - auto* ptr = static_cast(tensor.data()); - for (size_t i = 0; i < tensor.get_size(); i++) { - ptr[i] = start; - start += step; - } - return tensor; - } - - void generate_inputs(const std::vector& targetInputStaticShapes) override { - const auto& funcInputs = function->inputs(); - - const int position_id_start = 15; - auto& input_shape = targetInputStaticShapes[0]; - auto seq_length = input_shape[1]; - - ov::Tensor t_input = - utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768); - ov::Tensor t_position_id_end = create_i32_tensor(ov::Shape({}), position_id_start + seq_length); - ov::Tensor t_position_ids = create_i32_tensor(ov::Shape({1, seq_length}), position_id_start); - - inputs.clear(); - inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input}); - inputs.insert({funcInputs[1].get_node_shared_ptr(), t_position_id_end}); - inputs.insert({funcInputs[2].get_node_shared_ptr(), t_position_ids}); - } - -protected: - void SetUp() override { - targetDevice = ov::test::utils::DEVICE_CPU; - - const int batch = 2; - const int seq_length = 7; - const size_t max_position_embeddings = 2048; - const size_t ndims = 128; - const size_t num_head = 32; - - InputShape inpShape = {{batch, seq_length, num_head, ndims}, {{batch, seq_length, num_head, ndims}}}; - init_input_shapes({inpShape}); - function = buildROPE_Llama2(batch, seq_length, max_position_embeddings, num_head, ndims); - } -}; - -TEST_F(RoPECPUTest, smoke_CompareWithRefs) { - run(); -} - -} // namespace CPULayerTestsDefinitions diff --git a/src/plugins/intel_cpu/tests/unit/ngraph_transformations/convert_to_rope.cpp b/src/plugins/intel_cpu/tests/unit/ngraph_transformations/convert_to_rope.cpp deleted file mode 100644 index 91fae33253a..00000000000 --- a/src/plugins/intel_cpu/tests/unit/ngraph_transformations/convert_to_rope.cpp +++ /dev/null @@ -1,452 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common_test_utils/ov_test_utils.hpp" -#include "utils/gen_pattern.hpp" -#include "utils/print_model.hpp" - -using namespace testing; -using namespace ov::intel_cpu; -using namespace ov::gen_pattern; - -static ov::OutputVector makeCosSinCache(size_t max_position_embeddings, size_t rotary_ndims) { - std::vector lut_sin(max_position_embeddings * rotary_ndims, 0.0f); - std::vector lut_cos(max_position_embeddings * rotary_ndims, 0.0f); - - // rotate_half style cos/sin table: - // y1 = cos(m*xita_i) * x1 - sin(m*xita_i) * x2 - // y2 = cos(m*xita_i) * x2 + sin(m*xita_i) * x1 - // - for (size_t i = 0, k = 0; i < rotary_ndims; i += 2, k++) { - auto xita_i = 1.0 / std::pow(10000.0, static_cast(i) / rotary_ndims); - float* psin = lut_sin.data(); - float* 
pcos = lut_cos.data(); - for (size_t m = 0; m < max_position_embeddings; m++, psin += rotary_ndims, pcos += rotary_ndims) { - auto vsin = std::sin(xita_i * m); - auto vcos = std::cos(xita_i * m); - pcos[k] = pcos[k + rotary_ndims / 2] = vcos; - psin[k] = psin[k + rotary_ndims / 2] = vsin; - } - } - auto Cos = makeConst(ov::element::f32, ov::Shape({1, 1, max_position_embeddings, rotary_ndims}), lut_cos); - auto Sin = makeConst(ov::element::f32, ov::Shape({1, 1, max_position_embeddings, rotary_ndims}), lut_sin); - - return {Cos, Sin}; -} - -static std::shared_ptr buildROPE_Llama2(const size_t batch, - const size_t seq_length, - const size_t max_position_embeddings, - const size_t ndims, - bool sin_cos_preprocessing) { - auto input = std::make_shared(ov::element::f32, ov::Shape{batch, seq_length, 32, ndims}); - auto param_cos = std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length, ndims}); - auto param_sin = std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length, ndims}); - - auto seq_len = std::make_shared(ov::element::i32, ov::Shape{1}); - auto gather_id = std::make_shared(ov::element::i32, ov::Shape{1, seq_length}); - - auto gather_from_sin_cos = [&](const ov::Output& const_tab) { - auto ScatterUpdate_152236 = makeOP({{0, 0, 0}, {2}, seq_len, {0}}); - auto slice_Slice = makeOP({const_tab, {0, 0, 0}, ScatterUpdate_152236, {1, 1, 1}}, - {{"begin_mask", {1, 1, 0}}, - {"end_mask", {1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto squeeze_Squeeze_435 = - makeOP({slice_Slice, {-1, static_cast(ndims)}}, {{"special_zero", false}}); - auto index_441_Gather = makeOP({squeeze_Squeeze_435, gather_id, {0}}, {{"batch_dims", 0}}); - return makeOP({index_441_Gather, {1, 1, -1, static_cast(ndims)}}, - {{"special_zero", false}}); - }; - - ov::OutputVector cos_sin(2); - ov::ParameterVector parameters; - if (sin_cos_preprocessing) { - auto cos_sin_cache = makeCosSinCache(max_position_embeddings, ndims); - cos_sin[0] = gather_from_sin_cos(cos_sin_cache[0]); - cos_sin[1] = gather_from_sin_cos(cos_sin_cache[1]); - parameters = ov::ParameterVector{input, seq_len, gather_id}; - } else { - cos_sin[0] = param_cos; - cos_sin[1] = param_sin; - parameters = ov::ParameterVector{input, param_cos, param_sin}; - } - - auto transpose_Transpose = makeOP({input, {0, 2, 1, 3}}); - auto mul_Multiply = makeOP({transpose_Transpose, cos_sin[0]}, {{"auto_broadcast", "numpy"}}); - auto slice_Slice_459 = - makeOP({transpose_Transpose, {0, 0, 0, 64}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto Constant_182988 = makeConst(ov::element::f32, - ov::Shape({ - 1, - 1, - 1, - 1, - }), - {-1.000000f}); - auto neg_Multiply = makeOP({slice_Slice_459, Constant_182988}, {{"auto_broadcast", "numpy"}}); - auto slice_Slice = - makeOP({transpose_Transpose, {0, 0, 0, 0}, {0, 0, 0, 64}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto cat_Concat = makeOP({neg_Multiply, slice_Slice}, {{"axis", -1}}); - auto mul_Multiply_463 = makeOP({cat_Concat, cos_sin[1]}, {{"auto_broadcast", "numpy"}}); - auto add_Add = makeOP({mul_Multiply, mul_Multiply_463}, {{"auto_broadcast", "numpy"}}); - - return std::make_shared(ov::NodeVector{add_Add}, parameters); -} - -TEST_F(TransformationTestsF, ConvertToROPE_LLama2_no_gather) { - disable_rt_info_check(); 
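A note on the reference models in these tests: the config.* attribute keys map onto the fused RoPE node's config struct. A hypothetical mirror to spell out the semantics (field names are taken from the attribute keys in these tests, not from the real header):

```cpp
#include <cstddef>

// Hypothetical sketch of the RoPE config, for reading the tests below.
struct RoPEConfigSketch {
    std::size_t slice_start = 0;     // feature-axis sub-range fed into RoPE...
    std::size_t slice_stop = 0;      // ...slice_stop == 0 means "no slicing"
    bool input_trans0213 = false;    // true: apply a [0,2,1,3] transpose to the input first
    bool is_interleaved = false;     // true: GPT-J style even/odd lane interleaving
    std::size_t rotary_ndims = 0;    // how many leading feature dims get rotated
    int gather_position_arg_id = 0;  // >0: index of the input carrying gathered position ids
};
```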
- const int batch = 2; - const int seq_length = 16; - const size_t max_position_embeddings = 2048; - const size_t ndims = 128; - const size_t num_head = 32; - - model = buildROPE_Llama2(batch, seq_length, max_position_embeddings, ndims, false); - manager.register_pass(); - - { - auto hidden_states = - std::make_shared(ov::element::f32, ov::Shape{batch, seq_length, num_head, ndims}); - auto param_cos = std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length, ndims}); - auto param_sin = std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length, ndims}); - auto add_Add = makeOP({hidden_states, param_cos, param_sin}, - {{"config.slice_start", 0}, - {"config.slice_stop", 0}, - {"config.input_trans0213", true}, - {"config.is_interleaved", false}, - {"config.rotary_ndims", static_cast(ndims)}, - {"config.gather_position_arg_id", 0}}); - - model_ref = std::make_shared(ov::NodeVector{add_Add}, - ov::ParameterVector{hidden_states, param_cos, param_sin}); - } -} - -TEST_F(TransformationTestsF, ConvertToROPE_LLama2_with_gather) { - disable_rt_info_check(); - const int batch = 2; - const int seq_length = 16; - const size_t max_position_embeddings = 2048; - const size_t ndims = 128; - const size_t num_head = 32; - - model = buildROPE_Llama2(batch, seq_length, max_position_embeddings, ndims, true); - manager.register_pass(); - - { - auto hidden_states = - std::make_shared(ov::element::f32, ov::Shape{batch, seq_length, num_head, ndims}); - auto seq_len = std::make_shared(ov::element::i32, ov::Shape{1}); - auto gather_id = std::make_shared(ov::element::i32, ov::Shape{1, seq_length}); - auto cos_sin_cache = makeCosSinCache(max_position_embeddings, ndims); - - auto add_Add = makeOP({hidden_states, cos_sin_cache[0], cos_sin_cache[1], gather_id}, - {{"config.slice_start", 0}, - {"config.slice_stop", 0}, - {"config.input_trans0213", true}, - {"config.is_interleaved", false}, - {"config.rotary_ndims", static_cast(ndims)}, - {"config.gather_position_arg_id", 3}}); - - model_ref = std::make_shared(ov::NodeVector{add_Add}, - ov::ParameterVector{hidden_states, seq_len, gather_id}); - } -} - -static std::shared_ptr buildROPE_GPTNEOX(const int batch, - const int seq_length, - const int max_position_embeddings, - const int ndims, - const int num_heads, - const int rotary_ndims, - bool sin_cos_preprocessing) { - auto batch_s = static_cast(batch); - auto seq_length_s = static_cast(seq_length); - auto ndims_s = static_cast(ndims); - auto rotary_ndims_s = static_cast(rotary_ndims); - auto num_heads_s = static_cast(num_heads); - - auto input = std::make_shared(ov::element::f32, - ov::Shape{batch_s, seq_length_s, num_heads_s, ndims_s * 3}); - auto seq_len = std::make_shared(ov::element::i32, ov::Shape{1}); - auto gather_idx = - std::make_shared(ov::element::i32, ov::Shape{1, 1, seq_length_s, rotary_ndims_s}); - auto batch_limit = std::make_shared(ov::element::i32, ov::Shape{1}); - - ov::ParameterVector parameters; - ov::OutputVector cos_sin(2); - if (sin_cos_preprocessing) { - auto cos_sin_lut = makeCosSinCache(max_position_embeddings, rotary_ndims); - auto ro_slice_Slice = makeOP({cos_sin_lut[0], {0}, batch_limit, {1}}, - {{"begin_mask", {0}}, - {"end_mask", {0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - cos_sin[0] = makeOP({ro_slice_Slice, gather_idx}, {{"axis", 2}}); - - auto ro_slice_Slice_385 = makeOP({cos_sin_lut[1], {0}, batch_limit, {1}}, - {{"begin_mask", {0}}, - {"end_mask", {0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - 
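Worth keeping in mind for the GPT-NeoX subgraph built below: the input packs Q/K/V as [batch, seq, heads, 3 * ndims], the pattern slices out Q, rotates only the first rotary_ndims channels (20 of 80 in this test), and concatenates the remaining 60 back unrotated. In numbers:

```cpp
#include <cstdio>

int main() {
    const int ndims = 80, rotary_ndims = 20;
    // q = input[..., 0:ndims]; only q[..., 0:rotary_ndims] goes through rotate_half
    std::printf("rotated channels:     %d\n", rotary_ndims);          // 20
    std::printf("passthrough channels: %d\n", ndims - rotary_ndims);  // 60
    return 0;
}
```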
cos_sin[1] = makeOP({ro_slice_Slice_385, gather_idx}, {{"axis", 2}}); - parameters = ov::ParameterVector{input, gather_idx, batch_limit}; - } else { - auto param_cos = - std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length_s, rotary_ndims_s}); - auto param_sin = - std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length_s, rotary_ndims_s}); - parameters = ov::ParameterVector{input, param_cos, param_sin}; - cos_sin[0] = param_cos; - cos_sin[1] = param_sin; - } - - auto slice_Slice = makeOP({input, {0, 0, 0, 0}, {0, 0, 0, ndims}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto permute_Transpose = makeOP({slice_Slice, {0, 2, 1, 3}}); - auto slice_Slice_351 = - makeOP({permute_Transpose, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto mul_Multiply = makeOP({slice_Slice_351, cos_sin[0]}, {{"auto_broadcast", "numpy"}}); - auto slice_Slice_420 = makeOP( - {slice_Slice_351, {0, 0, 0, rotary_ndims / 2}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto Constant_396096 = makeConst(ov::element::f32, - ov::Shape({ - 1, - 1, - 1, - 1, - }), - {-1.000000f}); - auto neg_Multiply = makeOP({slice_Slice_420, Constant_396096}, {{"auto_broadcast", "numpy"}}); - auto slice_Slice_414 = - makeOP({slice_Slice_351, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims / 2}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto cat_Concat = makeOP({neg_Multiply, slice_Slice_414}, {{"axis", -1}}); - auto mul_Multiply_424 = makeOP({cat_Concat, cos_sin[1]}, {{"auto_broadcast", "numpy"}}); - auto add_Add = makeOP({mul_Multiply, mul_Multiply_424}, {{"auto_broadcast", "numpy"}}); - auto slice_Slice_357 = - makeOP({permute_Transpose, {0, 0, 0, rotary_ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto cat_Concat_458 = makeOP({add_Add, slice_Slice_357}, {{"axis", -1}}); - - return std::make_shared(ov::NodeVector{cat_Concat_458}, parameters); -} - -TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_no_gather) { - disable_rt_info_check(); - const int batch = 2; - const int seq_len = 16; - const int ndims = 80; - const int num_heads = 32; - const int rotary_ndims = 20; - const int max_position_embeddings = 2048; - - model = buildROPE_GPTNEOX(batch, seq_len, max_position_embeddings, ndims, num_heads, rotary_ndims, false); - manager.register_pass(); - { - auto input = - std::make_shared(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims * 3}); - auto param_cos = - std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_len, rotary_ndims}); - auto param_sin = - std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_len, rotary_ndims}); - auto rope = makeOP({input, param_cos, param_sin}, - {{"config.slice_start", 0}, - {"config.slice_stop", ndims}, - {"config.input_trans0213", true}, - {"config.is_interleaved", false}, - {"config.rotary_ndims", rotary_ndims}, - {"config.gather_position_arg_id", 0}}); - model_ref = std::make_shared(ov::NodeVector{rope}, 
ov::ParameterVector{input, param_cos, param_sin}); - } -} - -TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_with_gather) { - disable_rt_info_check(); - const int batch = 2; - const int seq_len = 16; - const int ndims = 80; - const int rotary_ndims = 20; - const int num_heads = 32; - const int max_position_embeddings = 2048; - - model = buildROPE_GPTNEOX(batch, seq_len, max_position_embeddings, ndims, num_heads, rotary_ndims, true); - manager.register_pass(); - { - auto cos_sin = makeCosSinCache(max_position_embeddings, rotary_ndims); - auto input = - std::make_shared(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims * 3}); - auto gather_idx = - std::make_shared(ov::element::i32, ov::Shape{1, 1, seq_len, rotary_ndims}); - auto batch_limit = std::make_shared(ov::element::i32, ov::Shape{1}); - - auto rope = makeOP({input, cos_sin[0], cos_sin[1], gather_idx}, - {{"config.slice_start", 0}, - {"config.slice_stop", ndims}, - {"config.input_trans0213", true}, - {"config.is_interleaved", false}, - {"config.rotary_ndims", rotary_ndims}, - {"config.gather_position_arg_id", 3}}); - model_ref = - std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, gather_idx, batch_limit}); - } -} - -TEST_F(TransformationTestsF, ConvertToROPE_GPTJ) { - disable_rt_info_check(); - const int batch = 2; - const int seq_len = 7; - const int num_heads = 16; - const int ndims = 256; - const int rotary_ndims = 64; - { - std::vector rpi_idx(rotary_ndims); - for (int i = 0, index = 0; i < rotary_ndims; i += 2, index++) { - rpi_idx[i] = index; - rpi_idx[i + 1] = index; - } - auto repeat_interleave_index = makeConst(ov::element::i32, ov::Shape({rotary_ndims}), rpi_idx); - - auto input = - std::make_shared(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims}); - auto gather_sin_cos = - std::make_shared(ov::element::f32, ov::Shape{1, seq_len, rotary_ndims}); - - auto split = makeOP({gather_sin_cos, {-1}, {rotary_ndims / 2, -1}}); - auto sin_tab = - makeOP({split->output(0), {1, -1, 1, rotary_ndims / 2}}, {{"special_zero", false}}); - auto cos_tab = - makeOP({split->output(1), {1, -1, 1, rotary_ndims / 2}}, {{"special_zero", false}}); - - auto slice_Slice_576 = - makeOP({input, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto repeat_interleave_Cos = - makeOP({cos_tab, repeat_interleave_index, {3}}, {{"batch_dims", 0}}); - auto mul_Multiply_757 = - makeOP({slice_Slice_576, repeat_interleave_Cos}, {{"auto_broadcast", "numpy"}}); - - auto slice_Slice_787 = - makeOP({slice_Slice_576, {0, 0, 0, 1}, {0, 0, 0, INT_MAX}, {1, 1, 1, 2}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto Constant_191672 = makeConst(ov::element::f32, - ov::Shape({ - 1, - 1, - 1, - 1, - }), - {-1.000000f}); - auto neg_Multiply_790 = - makeOP({slice_Slice_787, Constant_191672}, {{"auto_broadcast", "numpy"}}); - auto Unsqueeze_61918 = makeOP({neg_Multiply_790, {-1}}); - auto slice_Slice_781 = - makeOP({slice_Slice_576, {0, 0, 0, 0}, {0, 0, 0, INT_MAX}, {1, 1, 1, 2}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto Unsqueeze_61919 = makeOP({slice_Slice_781, {-1}}); - auto stack_795 = makeOP({Unsqueeze_61918, Unsqueeze_61919}, {{"axis", -1}}); - auto ShapeOf_165368 = makeOP>( - 
{stack_795}, - {{"type_relax", true}, {"input_data_types", {}}, {"output_data_types", {ov::element::i32}}}); - auto flatten_Slice_811 = makeOP({ShapeOf_165368, {0}, {3}, {1}}, - {{"begin_mask", {0}}, - {"end_mask", {0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto flatten_Concat_814 = makeOP({flatten_Slice_811, {-1}}, {{"axis", 0}}); - auto flatten_Reshape_815 = - makeOP({stack_795, flatten_Concat_814}, {{"special_zero", true}}); - auto repeat_interleave_Sin = - makeOP({sin_tab, repeat_interleave_index, {3}}, {{"batch_dims", 0}}); - auto mul_Multiply_816 = - makeOP({flatten_Reshape_815, repeat_interleave_Sin}, {{"auto_broadcast", "numpy"}}); - auto add_Add_819 = makeOP({mul_Multiply_757, mul_Multiply_816}, {{"auto_broadcast", "numpy"}}); - auto slice_Slice_582 = - makeOP({input, {0, 0, 0, rotary_ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto cat_Concat_826 = makeOP({add_Add_819, slice_Slice_582}, {{"axis", -1}}); - auto permute_Transpose_828 = makeOP({cat_Concat_826, {0, 2, 1, 3}}); - model = std::make_shared(ov::NodeVector{permute_Transpose_828}, - ov::ParameterVector{input, gather_sin_cos}); - } - manager.register_pass(); - { - auto input = - std::make_shared(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims}); - auto cos_sin = std::make_shared(ov::element::f32, ov::Shape{1, seq_len, rotary_ndims}); - auto rope = makeOP({input, cos_sin, cos_sin}, - {{"config.slice_start", 0}, - {"config.slice_stop", 0}, - {"config.input_trans0213", false}, - {"config.is_interleaved", true}, - {"config.rotary_ndims", rotary_ndims}, - {"config.gather_position_arg_id", 0}}); - model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, cos_sin}); - } -} \ No newline at end of file
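The GPT-J case above is the interleaved variant (config.is_interleaved is true): cos/sin apply to even/odd lane pairs instead of to the two halves. A standalone reference of that rotation, matching the negate-odd/stack/flatten construction in the test's source pattern:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Interleaved (GPT-J style) rotation: lanes (2k, 2k+1) form one complex pair.
// y[2k]   = x[2k]   * cos[k] - x[2k+1] * sin[k]
// y[2k+1] = x[2k+1] * cos[k] + x[2k]   * sin[k]
std::vector<float> rope_interleaved(const std::vector<float>& x,
                                    const std::vector<float>& cos_t,
                                    const std::vector<float>& sin_t) {
    std::vector<float> y(x.size());
    for (std::size_t k = 0; 2 * k + 1 < x.size(); ++k) {
        y[2 * k]     = x[2 * k] * cos_t[k] - x[2 * k + 1] * sin_t[k];
        y[2 * k + 1] = x[2 * k + 1] * cos_t[k] + x[2 * k] * sin_t[k];
    }
    return y;
}

int main() {
    std::vector<float> x = {1.0f, 0.0f, 0.0f, 1.0f};
    float a = 0.5f;  // same rotation angle for both pairs
    std::vector<float> c = {std::cos(a), std::cos(a)}, s = {std::sin(a), std::sin(a)};
    auto y = rope_interleaved(x, c, s);
    // rotating (1,0) by a gives (cos a, sin a); rotating (0,1) gives (-sin a, cos a)
    assert(std::abs(y[0] - std::cos(a)) < 1e-6f && std::abs(y[1] - std::sin(a)) < 1e-6f);
    assert(std::abs(y[2] + std::sin(a)) < 1e-6f && std::abs(y[3] - std::cos(a)) < 1e-6f);
    return 0;
}
```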