From 1960536e8ec68e89a10d16a45c0fd6d582f4884f Mon Sep 17 00:00:00 2001 From: Sergey Lyalin Date: Fri, 3 Nov 2023 13:47:51 +0400 Subject: [PATCH] Fix GPTQ model conversion after two breaking changes (#20823) * Fix GPTQ model conversion after two breaking changes * Code style fix * Remove redundant check --- .../pytorch/src/transforms/u4_block_repack.cpp | 3 +-- src/frontends/pytorch/src/utils_quantize.cpp | 13 ++++++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/frontends/pytorch/src/transforms/u4_block_repack.cpp b/src/frontends/pytorch/src/transforms/u4_block_repack.cpp index 9dcd4569ea8..ed0e5b6bbf5 100644 --- a/src/frontends/pytorch/src/transforms/u4_block_repack.cpp +++ b/src/frontends/pytorch/src/transforms/u4_block_repack.cpp @@ -85,8 +85,7 @@ U4BlockRepack::U4BlockRepack() { } } - copy_runtime_info({std::move(constant), std::move(reshape1), std::move(transpose), std::move(reshape2)}, - new_const); + copy_runtime_info({std::move(constant), std::move(reshape1), std::move(transpose), reshape2}, new_const); replace_node(reshape2, new_const); return true; diff --git a/src/frontends/pytorch/src/utils_quantize.cpp b/src/frontends/pytorch/src/utils_quantize.cpp index 1346fd76971..70253b7f757 100644 --- a/src/frontends/pytorch/src/utils_quantize.cpp +++ b/src/frontends/pytorch/src/utils_quantize.cpp @@ -5,6 +5,7 @@ #include "utils_quantize.hpp" #include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/bitwise_and.hpp" #include "openvino/op/broadcast.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/convert.hpp" @@ -175,9 +176,15 @@ std::shared_ptr u4_compression_stack(const OutputVector& list_elems, int64 if (list_elems.size() != 2) return nullptr; - auto bitwise_and = cast_fw_node(list_elems[0].get_node_shared_ptr(), "aten::bitwise_and"); - if (!bitwise_and) - return nullptr; + + auto bitwise_and_candidate = list_elems[0].get_node_shared_ptr(); + std::shared_ptr bitwise_and = cast_fw_node(bitwise_and_candidate, "aten::bitwise_and"); + if (!bitwise_and) { + bitwise_and = std::dynamic_pointer_cast(bitwise_and_candidate); + if (!bitwise_and) + return nullptr; + } + auto bitwise_shift = cast_fw_node(list_elems[1].get_node_shared_ptr(), "aten::bitwise_right_shift"); if (!bitwise_shift) return nullptr;