[Transformations] FuseU4WeightsAndZeroPoint transformation (#20503)
This commit is contained in:
parent
d490ab68d1
commit
afda7ad70f
@ -0,0 +1,26 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API ConvertU4WeightsZeroPointToScalar;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Converts U4 weights zero point to scalar if all values are equal
|
||||
*/
|
||||
class ov::pass::ConvertU4WeightsZeroPointToScalar : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ConvertU4WeightsZeroPointToScalar", "0");
|
||||
ConvertU4WeightsZeroPointToScalar();
|
||||
};
|
@ -0,0 +1,80 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/rt_info.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "openvino/pass/pattern/op/or.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/reference/autobroadcast_binop.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
/**
 * Registers a matcher rooted at a Subtract whose inputs are
 *   - a Convert (single consumer) fed by a u4 Constant, and
 *   - a zero point that is either an f32/f16 Constant (single consumer) or a u4 Constant
 *     (single consumer) followed by a Convert producing f32/f16 (single consumer).
 *
 * When the zero point Constant holds one repeated value and collapsing it cannot change the
 * Subtract output shape, the callback replaces it with a scalar Constant of the same element type.
 */
ov::pass::ConvertU4WeightsZeroPointToScalar::ConvertU4WeightsZeroPointToScalar() {
    MATCHER_SCOPE(ConvertU4WeightsZeroPointToScalar);
    using ConstantOp = ov::op::v0::Constant;
    using ConvertOp = ov::op::v0::Convert;

    // Weights branch: Constant(u4) -> Convert, where the Convert feeds only the Subtract.
    auto weights_m = pattern::wrap_type<ConstantOp>(pattern::type_matches(ov::element::u4));
    auto convert_m = pattern::wrap_type<ConvertOp>({weights_m}, pattern::consumers_count(1));

    // Zero point branch, variant 1: a float Constant with a single consumer.
    auto float_zp_predicate = [](ov::Output<ov::Node> value) -> bool {
        if (!pattern::type_matches_any({ov::element::f32, ov::element::f16})(value))
            return false;
        return pattern::consumers_count(1)(value);
    };
    auto float_zero_point_m = pattern::wrap_type<ConstantOp>(float_zp_predicate);

    // Zero point branch, variant 2: a u4 Constant with a single consumer, converted to float.
    auto u4_zp_predicate = [](ov::Output<ov::Node> value) -> bool {
        if (!pattern::type_matches(ov::element::u4)(value))
            return false;
        return pattern::consumers_count(1)(value);
    };
    auto u4_zero_point_m = pattern::wrap_type<ConstantOp>(u4_zp_predicate);
    auto zero_point_convert_m = pattern::wrap_type<ConvertOp>({u4_zero_point_m}, float_zp_predicate);

    auto zero_point_m = std::make_shared<pattern::op::Or>(OutputVector{float_zero_point_m, zero_point_convert_m});
    auto subtract_m = pattern::wrap_type<ov::op::v1::Subtract>({convert_m, zero_point_m});

    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
        auto& matched = m.get_pattern_value_map();
        auto weights = ov::as_type_ptr<ConstantOp>(matched.at(weights_m).get_node_shared_ptr());

        // Use the float zero point when that branch was recorded, otherwise the u4 one.
        const bool has_float_zp = matched.count(float_zero_point_m) != 0;
        const auto& zp_output = matched.at(has_float_zp ? float_zero_point_m : u4_zero_point_m);
        auto zero_point = ov::as_type_ptr<ConstantOp>(zp_output.get_node_shared_ptr());
        if (!(weights && zero_point))
            return false;

        // Both Subtract branches may look alike (u4 Constant -> Convert), so the matcher can bind
        // them the other way round; the constant with fewer elements is treated as the zero point.
        if (ov::shape_size(weights->get_shape()) < ov::shape_size(zero_point->get_shape()))
            std::swap(zero_point, weights);

        // Nothing to do when the zero point already holds a single element.
        auto zp_dims = zero_point->get_shape();
        if (ov::shape_size(zp_dims) == 1)
            return false;

        // The zero point may become a scalar only if that keeps the Subtract output shape intact:
        // its rank must not exceed the weights rank ...
        const auto& w_dims = weights->get_shape();
        const size_t w_rank = w_dims.size();
        const size_t zp_rank = zp_dims.size();
        if (w_rank < zp_rank)
            return false;

        // ... and, once left-padded with ones, no dimension may exceed the matching weights dimension.
        zp_dims.insert(zp_dims.begin(), w_rank - zp_rank, 1);
        size_t axis = 0;
        while (axis < w_rank) {
            if (zp_dims[axis] > w_dims[axis])
                return false;
            ++axis;
        }

        // get_single_value fails unless the whole constant can be represented by one value.
        float single_value;
        if (!ov::op::util::get_single_value(zero_point, single_value))
            return false;

        const auto scalar_zp = ConstantOp::create(zero_point->get_element_type(), ov::Shape{}, {single_value});
        return ov::replace_node_update_name(zero_point, scalar_zp);
    };

    register_matcher(std::make_shared<ov::pass::pattern::Matcher>(subtract_m, matcher_name), callback);
}
|
@ -21,6 +21,7 @@
|
||||
#include "transformations/common_optimizations/conv_to_binary_conv.hpp"
|
||||
#include "transformations/common_optimizations/convert_nms_gather_path_to_unsigned.hpp"
|
||||
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
|
||||
#include "transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.hpp"
|
||||
#include "transformations/common_optimizations/convolution_to_group_convolution_fusion.hpp"
|
||||
#include "transformations/common_optimizations/depth_to_space_fusion.hpp"
|
||||
#include "transformations/common_optimizations/dilated_convolution_converter.hpp"
|
||||
@ -212,6 +213,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
|
||||
ADD_MATCHER(common_fusions, ShuffleChannelsFusion, !m_use_shapes)
|
||||
ADD_MATCHER(common_fusions, NonZeroHorizontalFusion)
|
||||
ADD_MATCHER(common_fusions, AdaptivePoolToReduce)
|
||||
ADD_MATCHER(common_fusions, ConvertU4WeightsZeroPointToScalar)
|
||||
common_fusions->set_name("ov::pass::CommonFusions");
|
||||
|
||||
REGISTER_PASS(manager, BinarizeWeights)
|
||||
|
@ -31,6 +31,8 @@ bool get_single_value(const std::shared_ptr<op::v0::Constant>& const_node, float
|
||||
return util::normalize_single_value(const_node->get_vector<bfloat16>(), value, check_value_range);
|
||||
case element::Type_t::f64:
|
||||
return util::normalize_single_value(const_node->get_vector<double>(), value, check_value_range);
|
||||
case element::Type_t::i4:
|
||||
return util::normalize_single_value(const_node->cast_vector<int8_t>(), value, check_value_range);
|
||||
case element::Type_t::i8:
|
||||
return util::normalize_single_value(const_node->get_vector<int8_t>(), value, check_value_range);
|
||||
case element::Type_t::i16:
|
||||
@ -39,6 +41,8 @@ bool get_single_value(const std::shared_ptr<op::v0::Constant>& const_node, float
|
||||
return util::normalize_single_value(const_node->get_vector<int32_t>(), value, check_value_range);
|
||||
case element::Type_t::i64:
|
||||
return util::normalize_single_value(const_node->get_vector<int64_t>(), value, check_value_range);
|
||||
case element::Type_t::u4:
|
||||
return util::normalize_single_value(const_node->cast_vector<int8_t>(), value, check_value_range);
|
||||
case element::Type_t::u8:
|
||||
return util::normalize_single_value(const_node->get_vector<uint8_t>(), value, check_value_range);
|
||||
case element::Type_t::u16:
|
||||
|
@ -0,0 +1,208 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/ov_test_utils.hpp"
|
||||
#include "openvino/core/model.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ov;
|
||||
|
||||
// Positive case: an f32 zero point of shape {32, 1, 64} filled with one value is expected to be
// replaced by a scalar constant holding the same value.
TEST_F(TransformationTestsF, ConvertU4WeightsFloatZeroPointToScalar) {
    const auto stored_type = ov::element::u4;
    const auto compute_type = ov::element::f32;
    const ov::Shape weights_dims{32, 128, 64};
    const ov::Shape per_channel_dims{32, 1, 64};

    // u4 weights -> Convert -> Subtract(zero point) -> Multiply(scale); only the zero point shape varies.
    const auto make_model = [&](const ov::Shape& zp_dims) {
        const auto w = ov::op::v0::Constant::create(stored_type, weights_dims, {4});
        const auto w_convert = std::make_shared<ov::op::v0::Convert>(w, compute_type);
        const auto zp = ov::op::v0::Constant::create(compute_type, zp_dims, {8.1f});
        const auto sub = std::make_shared<ov::op::v1::Subtract>(w_convert, zp);
        const auto scale_const = ov::op::v0::Constant::create(compute_type, per_channel_dims, {3.f});
        const auto mul = std::make_shared<ov::op::v1::Multiply>(sub, scale_const);
        return std::make_shared<Model>(NodeVector{mul}, ParameterVector{});
    };

    model = make_model(per_channel_dims);
    manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
    model_ref = make_model(ov::Shape{});

    comparator.enable(FunctionsComparator::ACCURACY);
    comparator.enable(FunctionsComparator::CONST_VALUES);
}
|
||||
|
||||
// Positive case: a u4 zero point of shape {32, 1, 64} (followed by a Convert to f32) filled with
// one value is expected to be replaced by a scalar u4 constant.
TEST_F(TransformationTestsF, ConvertU4WeightsU4ZeroPointToScalar) {
    const auto stored_type = ov::element::u4;
    const auto compute_type = ov::element::f32;
    const ov::Shape weights_dims{32, 128, 64};
    const ov::Shape per_channel_dims{32, 1, 64};

    // u4 weights -> Convert -> Subtract(u4 zero point -> Convert) -> Multiply(scale);
    // only the zero point shape varies between the tested and the reference model.
    const auto make_model = [&](const ov::Shape& zp_dims) {
        const auto w = ov::op::v0::Constant::create(stored_type, weights_dims, {4});
        const auto w_convert = std::make_shared<ov::op::v0::Convert>(w, compute_type);
        const auto zp = ov::op::v0::Constant::create(stored_type, zp_dims, {8});
        const auto zp_convert = std::make_shared<ov::op::v0::Convert>(zp, compute_type);
        const auto sub = std::make_shared<ov::op::v1::Subtract>(w_convert, zp_convert);
        const auto scale_const = ov::op::v0::Constant::create(compute_type, per_channel_dims, {3.f});
        const auto mul = std::make_shared<ov::op::v1::Multiply>(sub, scale_const);
        return std::make_shared<Model>(NodeVector{mul}, ParameterVector{});
    };

    model = make_model(per_channel_dims);
    manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
    model_ref = make_model(ov::Shape{});

    comparator.enable(FunctionsComparator::ACCURACY);
    comparator.enable(FunctionsComparator::CONST_VALUES);
}
|
||||
|
||||
// Positive case: the f32 zero point has a lower rank ({64}) than the weights ({32, 128, 64});
// it is still expected to be collapsed into a scalar constant.
TEST_F(TransformationTestsF, ConvertU4WeightsFloatZeroPointToScalarWeightsWithBiggerRank) {
    const auto stored_type = ov::element::u4;
    const auto compute_type = ov::element::f32;
    const ov::Shape weights_dims{32, 128, 64};
    const ov::Shape low_rank_dims{64};

    // u4 weights -> Convert -> Subtract(zero point) -> Multiply(scale); only the zero point shape varies.
    const auto make_model = [&](const ov::Shape& zp_dims) {
        const auto w = ov::op::v0::Constant::create(stored_type, weights_dims, {4});
        const auto w_convert = std::make_shared<ov::op::v0::Convert>(w, compute_type);
        const auto zp = ov::op::v0::Constant::create(compute_type, zp_dims, {8});
        const auto sub = std::make_shared<ov::op::v1::Subtract>(w_convert, zp);
        const auto scale_const = ov::op::v0::Constant::create(compute_type, low_rank_dims, {3.f});
        const auto mul = std::make_shared<ov::op::v1::Multiply>(sub, scale_const);
        return std::make_shared<Model>(NodeVector{mul}, ParameterVector{});
    };

    model = make_model(low_rank_dims);
    manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
    model_ref = make_model(ov::Shape{});

    comparator.enable(FunctionsComparator::ACCURACY);
    comparator.enable(FunctionsComparator::CONST_VALUES);
}
|
||||
|
||||
// Negative case: the zero point values are not all equal, so it must not be collapsed to a scalar.
// No model_ref is set; the fixture is presumed to compare against the untouched original model
// (confirm in TransformationTestsF).
TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPointNotScalarLikeZP) {
    // The weights must be u4, otherwise the pattern does not match at all and the
    // "values differ" guard this test is named after would never be reached.
    auto weights_precision = ov::element::u4;
    auto decompression_precision = ov::element::f32;
    ov::Shape weights_shape{32, 128, 64};
    ov::Shape decompression_shape{32, 1, 64};
    auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4});
    auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
    // All zero point elements are 8 except the last one, so there is no single representative value.
    std::vector<std::int8_t> zero_point_values(ov::shape_size(decompression_shape), 8);
    zero_point_values.back() = 6;
    auto zero_point = ov::op::v0::Constant::create(weights_precision, decompression_shape, zero_point_values);
    auto zero_point_convert = std::make_shared<ov::op::v0::Convert>(zero_point, decompression_precision);
    auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point_convert);
    auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
    auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
    model = std::make_shared<Model>(NodeVector{multiply}, ParameterVector{});
    manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
}
|
||||
|
||||
// Negative case: weights and zero point are stored as u8, so the u4-only pattern must not match.
TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPointNotU4Weights) {
    using ConstantOp = ov::op::v0::Constant;
    using ConvertOp = ov::op::v0::Convert;

    const auto stored_type = ov::element::u8;
    const auto compute_type = ov::element::f32;
    const ov::Shape weights_dims{32, 128, 64};
    const ov::Shape per_channel_dims{32, 1, 64};

    const auto w = ConstantOp::create(stored_type, weights_dims, {4});
    const auto w_convert = std::make_shared<ConvertOp>(w, compute_type);
    const auto zp = ConstantOp::create(stored_type, per_channel_dims, {8});
    const auto zp_convert = std::make_shared<ConvertOp>(zp, compute_type);
    const auto sub = std::make_shared<ov::op::v1::Subtract>(w_convert, zp_convert);
    const auto scale_const = ConstantOp::create(compute_type, per_channel_dims, {3.f});
    const auto mul = std::make_shared<ov::op::v1::Multiply>(sub, scale_const);

    model = std::make_shared<Model>(NodeVector{mul}, ParameterVector{});
    manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
}
|
||||
|
||||
// Negative case: the f32 zero point Constant also feeds a ShapeOf, so it has two consumers
// and the single-consumer predicate of the pattern must reject it.
TEST_F(TransformationTestsF, ConvertU4WeightsFloatZeroPointToScalarAdditionalZPConsumer) {
    using ConstantOp = ov::op::v0::Constant;

    const auto stored_type = ov::element::u4;
    const auto compute_type = ov::element::f32;
    const ov::Shape weights_dims{32, 128, 64};
    const ov::Shape per_channel_dims{32, 1, 64};

    const auto w = ConstantOp::create(stored_type, weights_dims, {4});
    const auto w_convert = std::make_shared<ov::op::v0::Convert>(w, compute_type);
    const auto zp = ConstantOp::create(compute_type, per_channel_dims, {8});
    const auto zp_shape_of = std::make_shared<ov::op::v3::ShapeOf>(zp);
    const auto sub = std::make_shared<ov::op::v1::Subtract>(w_convert, zp);
    const auto scale_const = ConstantOp::create(compute_type, per_channel_dims, {3.f});
    const auto mul = std::make_shared<ov::op::v1::Multiply>(sub, scale_const);

    model = std::make_shared<Model>(NodeVector{mul, zp_shape_of}, ParameterVector{});
    manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
}
|
||||
|
||||
// Negative case: the u4 zero point Constant feeds both a ShapeOf and the Convert, so it has two
// consumers and the single-consumer predicate of the pattern must reject it.
TEST_F(TransformationTestsF, ConvertU4WeightsU4ZeroPointToScalarAdditionalZPConsumer) {
    using ConstantOp = ov::op::v0::Constant;
    using ConvertOp = ov::op::v0::Convert;

    const auto stored_type = ov::element::u4;
    const auto compute_type = ov::element::f32;
    const ov::Shape weights_dims{32, 128, 64};
    const ov::Shape per_channel_dims{32, 1, 64};

    const auto w = ConstantOp::create(stored_type, weights_dims, {4});
    const auto w_convert = std::make_shared<ConvertOp>(w, compute_type);
    const auto zp = ConstantOp::create(stored_type, per_channel_dims, {8});
    const auto zp_shape_of = std::make_shared<ov::op::v3::ShapeOf>(zp);
    const auto zp_convert = std::make_shared<ConvertOp>(zp, compute_type);
    const auto sub = std::make_shared<ov::op::v1::Subtract>(w_convert, zp_convert);
    const auto scale_const = ConstantOp::create(compute_type, per_channel_dims, {3.f});
    const auto mul = std::make_shared<ov::op::v1::Multiply>(sub, scale_const);

    model = std::make_shared<Model>(NodeVector{mul, zp_shape_of}, ParameterVector{});
    manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
}
|
||||
|
||||
// Negative case: the Convert after the u4 zero point feeds both a ShapeOf and the Subtract, so it
// has two consumers and the single-consumer predicate of the pattern must reject it.
TEST_F(TransformationTestsF, ConvertU4WeightsU4ZeroPointToScalarAdditionalZPConvertConsumer) {
    using ConstantOp = ov::op::v0::Constant;
    using ConvertOp = ov::op::v0::Convert;

    const auto stored_type = ov::element::u4;
    const auto compute_type = ov::element::f32;
    const ov::Shape weights_dims{32, 128, 64};
    const ov::Shape per_channel_dims{32, 1, 64};

    const auto w = ConstantOp::create(stored_type, weights_dims, {4});
    const auto w_convert = std::make_shared<ConvertOp>(w, compute_type);
    const auto zp = ConstantOp::create(stored_type, per_channel_dims, {8});
    const auto zp_convert = std::make_shared<ConvertOp>(zp, compute_type);
    const auto zp_convert_shape_of = std::make_shared<ov::op::v3::ShapeOf>(zp_convert);
    const auto sub = std::make_shared<ov::op::v1::Subtract>(w_convert, zp_convert);
    const auto scale_const = ConstantOp::create(compute_type, per_channel_dims, {3.f});
    const auto mul = std::make_shared<ov::op::v1::Multiply>(sub, scale_const);

    model = std::make_shared<Model>(NodeVector{mul, zp_convert_shape_of}, ParameterVector{});
    manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
}
|
||||
|
||||
// Negative case: the zero point rank (4) exceeds the weights rank (3). Collapsing it to a scalar
// would change the Subtract output shape, so the transformation must leave the model untouched.
// No model_ref is set; the fixture is presumed to compare against the untouched original model
// (confirm in TransformationTestsF).
TEST_F(TransformationTestsF, ConvertU4WeightsU4ZeroPointToScalarZPWithBiggerRank) {
    auto weights_precision = ov::element::u4;
    auto decompression_precision = ov::element::f32;
    ov::Shape weights_shape{32, 128, 64};
    ov::Shape decompression_shape{1, 32, 1, 64};
    auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4});
    auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
    auto zero_point = ov::op::v0::Constant::create(weights_precision, decompression_shape, {8});
    // The zero point Convert intentionally has a single consumer: an additional consumer would make
    // the pattern fail on the consumers-count predicate and the rank guard would never be exercised.
    auto zero_point_convert = std::make_shared<ov::op::v0::Convert>(zero_point, decompression_precision);
    auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point_convert);
    auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
    auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
    model = std::make_shared<Model>(NodeVector{multiply}, ParameterVector{});
    manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
}
|
@ -319,7 +319,8 @@ bool replace_output_update_name(Output<Node> output, const Output<Node>& replace
|
||||
|
||||
bool replace_node_update_name(const std::shared_ptr<Node>& target, const std::shared_ptr<Node>& replacement) {
|
||||
for (auto& output : target->output(0).get_target_inputs()) {
|
||||
if (ov::as_type<op::v0::Parameter>(replacement->input_value(0).get_node()) &&
|
||||
if (replacement->get_input_size() > 0 &&
|
||||
ov::as_type<op::v0::Parameter>(replacement->input_value(0).get_node()) &&
|
||||
ov::as_type<op::v0::Result>(output.get_node())) {
|
||||
return false;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user