Introduce NF4 data type (#19900)

* NF4 reference initial commit

* Compilable version.

* Executable NF4.

* Fixed nf4 unpacking.

* 1) Fixed warnings with nf4.
2) Removed unused functions.

* Added one test for nf4.

* Fixed code-style errors.

* Fixed code-style errors.

* Fixed NamingConventionCheck errors.

* Fixed test with nf4.

* Fixed windows compilation.

* Fixed casting warning.

* Fixed incorrect changes.

* Changed order of elements in nf4 pack/unpack.

* 1) Made Convert only on direction nf4->other type.
2) Applied reviewers suggestions.

* Fixed code style.

* Fised code style.

* 1) Added array header.
2) Added Bitsandbytes to third-party-programs.txt.

* 1) Removed unused code.
2) Fixed style typos.
3) Revert submodule version.

* Added test for nf4 compression.

* NF4 test refactoring.

* Added cpp tests for NF4.

* Removed model compilation from NF4 tests.

* Reverted submodule version.
This commit is contained in:
andreyanufr 2023-09-28 18:56:57 +02:00 committed by GitHub
parent 3de1332838
commit b73b2502b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 431 additions and 17 deletions

View File

@ -18,7 +18,7 @@ VariableReference: '^\w+$'
EnumName: '^[A-Z][\w]+$'
# excepts element_type
EnumConstantName: '^([A-Z\d_]+|undefined|dynamic|boolean|bf16|f16|f32|f64|i4|i8|i16|i32|i64|u1|u4|u8|u16|u32|u64|asymmetric|align_corners|round_prefer_floor|round_prefer_ceil|floor|ceil|simple|nearest|linear|linear_onnx|cubic|area|scales|sizes|half_pixel|tf_half_pixel_for_nn|pytorch_half_pixel|asymetric)$'
EnumConstantName: '^([A-Z\d_]+|undefined|dynamic|boolean|bf16|f16|f32|f64|i4|i8|i16|i32|i64|u1|u4|u8|u16|u32|u64|nf4|asymmetric|align_corners|round_prefer_floor|round_prefer_ceil|floor|ceil|simple|nearest|linear|linear_onnx|cubic|area|scales|sizes|half_pixel|tf_half_pixel_for_nn|pytorch_half_pixel|asymetric)$'
# TODO: align
UsingDeclaration: '^.*$'
TypedefName: '^.*$'

View File

@ -1640,3 +1640,29 @@ INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.
-------------------------------------------------------------
30. Bitsandbytes (https://github.com/TimDettmers/bitsandbytes)
MIT License
Copyright (c) Facebook, Inc. and its affiliates.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -186,6 +186,7 @@ typedef enum {
U16, //!< u16 element type
U32, //!< u32 element type
U64, //!< u64 element type
NF4, //!< nf4 element type
} ov_element_type_e;
/**

View File

@ -23,7 +23,8 @@ const std::map<ov_element_type_e, ov::element::Type> element_type_map = {
{ov_element_type_e::U8, ov::element::u8},
{ov_element_type_e::U16, ov::element::u16},
{ov_element_type_e::U32, ov::element::u32},
{ov_element_type_e::U64, ov::element::u64}};
{ov_element_type_e::U64, ov::element::u64},
{ov_element_type_e::NF4, ov::element::nf4}};
inline ov_element_type_e find_ov_element_type_e(ov::element::Type type) {
for (auto iter = element_type_map.begin(); iter != element_type_map.end(); iter++) {

View File

@ -23,7 +23,7 @@ def pack_data(array: np.ndarray, type: Type) -> np.ndarray:
:param type: Type to interpret the array values. Type must be u1, u4 or i4.
:type type: openvino.runtime.Type
"""
assert type in [Type.u1, Type.u4, Type.i4], "Packing algorithm for the" "data types stored in 1, 2 or 4 bits"
assert type in [Type.u1, Type.u4, Type.i4, Type.nf4], "Packing algorithm for the" "data types stored in 1, 2 or 4 bits"
minimum_regular_dtype = np.int8 if type == Type.i4 else np.uint8
casted_to_regular_type = array.astype(dtype=minimum_regular_dtype, casting="unsafe")
@ -62,7 +62,7 @@ def unpack_data(array: np.ndarray, type: Type, shape: Union[list, Shape]) -> np.
:param shape: the new shape for the unpacked array.
:type shape: Union[list, openvino.runtime.Shape]
"""
assert type in [Type.u1, Type.u4, Type.i4], "Unpacking algorithm for the" "data types stored in 1, 2 or 4 bits"
assert type in [Type.u1, Type.u4, Type.i4, Type.nf4], "Unpacking algorithm for the" "data types stored in 1, 2 or 4 bits"
unpacked = np.unpackbits(array.view(np.uint8))
shape = list(shape)
if type.bitwidth == 1:

View File

@ -30,6 +30,7 @@ const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype() {
{ov::element::boolean, py::dtype("bool")},
{ov::element::u1, py::dtype("uint8")},
{ov::element::u4, py::dtype("uint8")},
{ov::element::nf4, py::dtype("uint8")},
{ov::element::i4, py::dtype("int8")},
};
return ov_type_to_dtype_mapping;

View File

@ -48,6 +48,7 @@ void regclass_graph_Type(py::module m) {
type.attr("u32") = ov::element::u32;
type.attr("u64") = ov::element::u64;
type.attr("bf16") = ov::element::bf16;
type.attr("nf4") = ov::element::nf4;
type.def("__hash__", &ov::element::Type::hash);
type.def("__repr__", [](const ov::element::Type& self) {

View File

@ -377,6 +377,7 @@ def test_init_with_packed_buffer(dtype, ov_type):
(0, 2, ov.Type.u1, np.uint8),
(0, 16, ov.Type.u4, np.uint8),
(-8, 7, ov.Type.i4, np.int8),
(0, 16, ov.Type.nf4, np.uint8),
])
def test_packing(shape, low, high, ov_type, dtype):
ov_tensor = Tensor(ov_type, shape)

View File

@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from openvino.runtime import opset9 as opset
import openvino as ov
import pytest
@pytest.mark.parametrize(("ov_type", "numpy_dtype"), [
(ov.Type.f32, np.float32),
(ov.Type.f64, np.float64),
(ov.Type.f16, np.float16),
])
def test_float_to_nf4_convert(ov_type, numpy_dtype):
data = np.linspace(-1.5, 1.5, num=41, dtype=numpy_dtype)
compressed_const = opset.constant(data, dtype=ov.Type.nf4, name="nf4_constant")
convert = opset.convert(compressed_const, data.dtype)
parameter = opset.parameter(ov.PartialShape([-1]), ov_type)
add_op = opset.add(parameter, convert)
model = ov.Model([add_op], [parameter])
compiled = ov.compile_model(model)
tensor = np.zeros(data.shape, dtype=numpy_dtype)
result = compiled(tensor)[0]
uniq = []
for res_val in result:
if res_val not in uniq:
uniq.append(res_val)
uniq = np.array(uniq)
assert len(uniq) == 16
target = [-1.0, -0.6961928009986877, -0.5250730514526367,
-0.39491748809814453, -0.28444138169288635,
-0.18477343022823334, -0.09105003625154495,
0.0, 0.07958029955625534, 0.16093020141124725,
0.24611230194568634, 0.33791524171829224,
0.44070982933044434, 0.5626170039176941,
0.7229568362236023, 1.0]
target = np.array(target)
diff = np.max(np.abs(target - uniq))
assert diff < 0.001

View File

@ -88,6 +88,7 @@ public:
switch (precision) {
case element::i4:
case element::u4:
case element::nf4:
return (levels == low_precision::levels::int4) || (levels == low_precision::levels::int4_narrow_range);
case element::i8:
case element::u8:

View File

@ -99,6 +99,9 @@ std::shared_ptr<Node> make_constant(const element::Type& type, const Shape& shap
case element::Type_t::u4:
unsupported_data_type = "u4";
break;
case element::Type_t::nf4:
unsupported_data_type = "nf4";
break;
case element::Type_t::undefined:
unsupported_data_type = "undefined";
break;

View File

@ -40,6 +40,7 @@ using ov::element::i32;
using ov::element::i4;
using ov::element::i64;
using ov::element::i8;
using ov::element::nf4;
using ov::element::u1;
using ov::element::u16;
using ov::element::u32;

View File

@ -20,6 +20,7 @@
#include "openvino/core/rtti.hpp"
#include "openvino/core/type/bfloat16.hpp"
#include "openvino/core/type/float16.hpp"
#include "openvino/core/type/nf4.hpp"
/**
* @defgroup ov_element_cpp_api Element types
@ -50,7 +51,8 @@ enum class Type_t {
u8, //!< u8 element type
u16, //!< u16 element type
u32, //!< u32 element type
u64 //!< u64 element type
u64, //!< u64 element type
nf4 //!< nf4 element type
};
/// \brief Base class to define element type
@ -177,6 +179,9 @@ constexpr Type u32(Type_t::u32);
/// \brief u64 element type
/// \ingroup ov_element_cpp_api
constexpr Type u64(Type_t::u64);
/// \brief nf4 element type
/// \ingroup ov_element_cpp_api
constexpr Type nf4(Type_t::nf4);
template <typename T>
Type from() {

View File

@ -92,4 +92,9 @@ template <>
struct element_type_traits<element::Type_t::u64> {
using value_type = uint64_t;
};
template <>
struct element_type_traits<element::Type_t::nf4> {
using value_type = int8_t;
};
} // namespace ov

View File

@ -0,0 +1,47 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <array>
#include <cmath>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <vector>
#include "openvino/core/core_visibility.hpp"
namespace ov {
class OPENVINO_API ConvertNF4 {
public:
constexpr ConvertNF4() = default;
template <typename T, typename std::enable_if<!std::is_integral<T>::value, bool>::type = true>
static void unpack(T* dst, const uint8_t* src, std::size_t idx) {
uint8_t nf4_idx = get_u4(src, idx);
float val = dequantize(nf4_idx);
dst[idx] = static_cast<T>(val);
}
template <typename T, typename std::enable_if<std::is_integral<T>::value, bool>::type = true>
static void unpack(T* dst, const uint8_t* src, std::size_t idx) {
uint8_t nf4_idx = get_u4(src, idx);
dst[idx] = static_cast<T>(nf4_idx);
}
static float dequantize(uint8_t val);
static uint8_t quantize(float x);
private:
static inline uint8_t get_u4(const uint8_t* buf, size_t idx) {
const size_t byte_idx = idx / 2;
const uint8_t bit_shift = 4 * (idx % 2);
return (buf[byte_idx] >> bit_shift) & 0xF;
}
};
}; // namespace ov

View File

@ -143,6 +143,9 @@ public:
case Type_t::u64:
fill_data<Type_t::u64>(value);
break;
case Type_t::nf4:
fill_data<Type_t::nf4>(value);
break;
case Type_t::undefined:
case Type_t::dynamic:
OPENVINO_THROW("unsupported type");
@ -408,7 +411,7 @@ private:
template <element::Type_t Type,
typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type != element::Type_t::u1 && Type != element::Type_t::u4 &&
Type != element::Type_t::i4,
Type != element::Type_t::i4 && Type != element::Type_t::nf4,
bool>::type = true>
StorageDataType get_element_value(size_t index) const {
return get_data_ptr<Type>()[index];
@ -428,6 +431,13 @@ private:
return (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F;
}
template <element::Type_t Type,
typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::nf4, bool>::type = true>
StorageDataType get_element_value(size_t index) const {
return (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 4 : 0)) & 0x0F;
}
template <element::Type_t Type,
typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::i4, bool>::type = true>
@ -554,7 +564,7 @@ private:
typename T,
typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type != element::Type_t::u1 && Type != element::Type_t::u4 &&
Type != element::Type_t::i4,
Type != element::Type_t::i4 && Type != element::Type_t::nf4,
bool>::type = true>
void fill_data(const T& value) {
#ifdef __clang__
@ -607,7 +617,9 @@ private:
template <element::Type_t Type,
typename T,
typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::u4 || Type == element::Type_t::i4, bool>::type = true>
typename std::enable_if<Type == element::Type_t::u4 || Type == element::Type_t::i4 ||
Type == element::Type_t::nf4,
bool>::type = true>
void fill_data(const T& value) {
uint8_t v = value_in_range<Type>(value);
v &= 0x0F;
@ -640,8 +652,8 @@ private:
template <element::Type_t Type,
typename T,
typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type != element::Type_t::u1 && Type != element::Type_t::u4 &&
Type != element::Type_t::i4,
typename std::enable_if<Type != element::Type_t::nf4 && Type != element::Type_t::u1 &&
Type != element::Type_t::u4 && Type != element::Type_t::i4,
bool>::type = true>
void write_buffer(const std::vector<T>& source) {
auto p = get_data_ptr_nc<Type>();
@ -670,6 +682,50 @@ private:
}
}
template <element::Type_t Type,
typename T,
typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::nf4 && std::is_integral<T>::value, bool>::type = true>
void write_buffer(const std::vector<T>& source) {
auto p = get_data_ptr_nc<Type>();
size_t i = 0;
for (; i < source.size() / 2; i++) {
const auto v1 = value_in_range<Type>(source[i * 2]) & 0x0F;
const auto v2 = value_in_range<Type>(source[i * 2 + 1]) & 0x0F;
const auto v = (v2 << 4) | v1;
p[i] = static_cast<StorageDataType>(v);
}
if (source.size() % 2) {
const auto v = value_in_range<Type>(source[i * 2]) & 0x0F;
p[i] = static_cast<StorageDataType>(v);
}
}
template <element::Type_t Type,
typename T,
typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::nf4 &&
(std::is_floating_point<T>::value || std::is_same<T, bfloat16>::value ||
std::is_same<T, float16>::value),
bool>::type = true>
void write_buffer(const std::vector<T>& source) {
auto p = get_data_ptr_nc<Type>();
size_t i = 0;
for (; i < source.size() / 2; i++) {
const auto idx1 = ConvertNF4::quantize(static_cast<float>(source[i * 2]));
const auto idx2 = ConvertNF4::quantize(static_cast<float>(source[i * 2 + 1]));
const auto v1 = value_in_range<Type>(idx1) & 0x0F;
const auto v2 = value_in_range<Type>(idx2) & 0x0F;
const auto v = (v2 << 4) | v1;
p[i] = static_cast<StorageDataType>(v);
}
if (source.size() % 2) {
const auto idx1 = ConvertNF4::quantize(static_cast<float>(source[i * 2]));
const auto v = value_in_range<Type>(idx1) & 0x0F;
p[i] = static_cast<StorageDataType>(v);
}
}
template <element::Type_t Type,
typename T,
typename StorageDataType = fundamental_type_for<Type>,
@ -755,6 +811,9 @@ private:
case Type_t::u64:
write_buffer<Type_t::u64>(source);
break;
case Type_t::nf4:
write_buffer<Type_t::nf4>(source);
break;
case element::Type_t::undefined:
case element::Type_t::dynamic:
OPENVINO_THROW("unsupported type");
@ -765,7 +824,9 @@ private:
}
template <ov::element::Type_t Type,
typename ValueT,
typename std::enable_if<Type == ov::element::Type_t::u4, bool>::type = true>
typename std::enable_if<Type == ov::element::Type_t::u4 || Type == ov::element::Type_t::u4 ||
Type == ov::element::Type_t::nf4,
bool>::type = true>
static ov::fundamental_type_for<Type> value_in_range(const ValueT& value) {
const auto result = ov::fundamental_type_for<Type>(value);
OPENVINO_ASSERT(0 <= result && result <= 15, "assigned value out of range u4 values");

View File

@ -70,6 +70,7 @@ TO get_value(const uint8_t* buf, size_t idx, element::Type from_type) {
if (from_type == element::i4) {
return detail::get_i4(buf, idx);
}
auto v = reinterpret_cast<const TI*>(buf);
return static_cast<TO>(v[idx]);
}
@ -85,6 +86,8 @@ void lp_convert(const TI* arg, TO* out, size_t count, element::Type_t src_type,
detail::set_u4(output, i, detail::get_value<uint8_t, TI>(input, i, src_type));
} else if (dst_type == element::i4) {
detail::set_i4(output, i, detail::get_value<int8_t, TI>(input, i, src_type));
} else if (src_type == element::nf4) {
ConvertNF4::unpack(out, input, i);
} else {
out[i] = detail::get_value<TO, TI>(input, i, src_type);
}

View File

@ -53,7 +53,7 @@ TResult get_raw_data_as(const element::Type_t et, const void* const ptr, const s
auto out_it = std::inserter(out, out.end());
using namespace ov::element;
IfTypeOf<bf16, f16, f32, f64, i4, i8, i16, i32, i64, u4, u8, u16, u32, u64>::apply<TensorTransform>(
IfTypeOf<bf16, f16, f32, f64, i4, i8, i16, i32, i64, u4, u8, u16, u32, u64, nf4>::apply<TensorTransform>(
et,
ptr,
size,

View File

@ -130,6 +130,9 @@ ov::op::v0::Constant::Constant(const element::Type& type,
case Type_t::u64:
fill_data<Type_t::u64>(ngraph::parse_string<uint64_t>(values[0]));
break;
case Type_t::nf4:
fill_data<Type_t::nf4>(ngraph::parse_string<uint64_t>(values[0]));
break;
case Type_t::undefined:
OPENVINO_THROW("deserialize unsupported type undefined");
case Type_t::dynamic:
@ -186,6 +189,9 @@ ov::op::v0::Constant::Constant(const element::Type& type,
case Type_t::u64:
write_buffer<Type_t::u64>(ngraph::parse_string<uint64_t>(values));
break;
case Type_t::nf4:
write_buffer<Type_t::nf4>(ngraph::parse_string<uint8_t>(values));
break;
case Type_t::undefined:
OPENVINO_THROW("deserialize unsupported type undefined");
case Type_t::dynamic:
@ -296,6 +302,9 @@ string ov::op::v0::Constant::convert_value_to_string(size_t index) const {
case Type_t::u64:
rc = to_string(get_element_value<Type_t::u64>(index));
break;
case Type_t::nf4:
rc = to_string(get_element_value<Type_t::nf4>(index));
break;
case Type_t::undefined:
case Type_t::dynamic:
OPENVINO_THROW("unsupported type");
@ -367,6 +376,7 @@ vector<string> ov::op::v0::Constant::get_value_strings() const {
break;
case element::Type_t::u1:
case element::Type_t::u4:
case element::Type_t::nf4:
for (auto value : cast_vector<uint8_t>()) {
rc.push_back(to_string(value));
}
@ -523,6 +533,7 @@ bool ov::op::v0::Constant::are_all_data_elements_bitwise_identical() const {
case element::Type_t::i4:
case element::Type_t::u1:
case element::Type_t::u4:
case element::Type_t::nf4:
case element::Type_t::undefined:
case element::Type_t::dynamic:
break;

View File

@ -52,7 +52,8 @@ bool evaluate(const HostTensorPtr& arg, const HostTensorPtr& out) {
}
if (((INPUT_ET == element::u1) || (OUTPUT_ET == element::u1)) ||
((INPUT_ET == element::u4) || (OUTPUT_ET == element::u4)) ||
((INPUT_ET == element::i4) || (OUTPUT_ET == element::i4))) {
((INPUT_ET == element::i4) || (OUTPUT_ET == element::i4)) ||
((INPUT_ET == element::nf4) || (OUTPUT_ET == element::nf4))) {
ov::reference::detail::lp_convert(arg->get_data_ptr<INPUT_ET>(),
out->get_data_ptr<OUTPUT_ET>(),
element_count,
@ -91,6 +92,7 @@ bool evaluate(const HostTensorPtr& arg, const HostTensorPtr& out) {
TYPE_OUT_CASE(f32, arg, out);
TYPE_OUT_CASE(f64, arg, out);
TYPE_OUT_CASE(boolean, arg, out);
TYPE_OUT_CASE(nf4, arg, out);
default:
rc = false;
break;
@ -117,6 +119,7 @@ bool evaluate_convert(const HostTensorPtr& arg, const HostTensorPtr& out) {
NGRAPH_TYPE_CASE(evaluate_convert, f32, arg, out);
NGRAPH_TYPE_CASE(evaluate_convert, f64, arg, out);
NGRAPH_TYPE_CASE(evaluate_convert, boolean, arg, out);
NGRAPH_TYPE_CASE(evaluate_convert, nf4, arg, out);
default:
rc = false;
break;

View File

@ -741,6 +741,8 @@ std::string get_precision_name(const ov::element::Type& elem_type) {
return "BIN";
case ::ov::element::Type_t::boolean:
return "BOOL";
case ::ov::element::Type_t::nf4:
return "NF4";
default:
OPENVINO_THROW("Unsupported precision: ", elem_type);
}

View File

@ -374,6 +374,7 @@ static std::string get_value(const std::shared_ptr<ov::op::v0::Constant>& consta
case ov::element::Type_t::dynamic:
case ov::element::Type_t::u1:
case ov::element::Type_t::u4:
case ov::element::Type_t::nf4:
case ov::element::Type_t::i4:
ss << constant->get_output_element_type(0).get_type_name() << " value";
break;

View File

@ -69,6 +69,8 @@ inline TypeInfo get_type_info(ov::element::Type_t type) {
return {32, false, false, false, "uint32_t", "u32"};
case ov::element::Type_t::u64:
return {64, false, false, false, "uint64_t", "u64"};
case ov::element::Type_t::nf4:
return {4, false, false, true, "nfloat4", "nf4"};
default:
OPENVINO_THROW("ov::element::Type_t not supported: ", type);
}
@ -111,6 +113,8 @@ ov::element::Type type_from_string(const std::string& type) {
return ::ov::element::Type(::ov::element::Type_t::undefined);
} else if (type == "dynamic") {
return ::ov::element::Type(::ov::element::Type_t::dynamic);
} else if (type == "nf4" || type == "NF4") {
return ::ov::element::Type(::ov::element::Type_t::nf4);
} else {
OPENVINO_THROW("Incorrect type: ", type);
}
@ -163,6 +167,7 @@ ov::element::Type::Type(size_t bitwidth,
{ov::element::Type_t::u16, {16, false, false, false, "uint16_t", "u16"}},
{ov::element::Type_t::u32, {32, false, false, false, "uint32_t", "u32"}},
{ov::element::Type_t::u64, {64, false, false, false, "uint64_t", "u64"}},
{ov::element::Type_t::u4, {4, false, false, false, "uint4_t", "nf4"}},
};
for (const auto& t : elements_map) {
const TypeInfo& info = t.second;
@ -319,6 +324,7 @@ std::istream& ov::element::operator>>(std::istream& in, ov::element::Type& obj)
{"FP64", ov::element::f64},
{"FP16", ov::element::f16},
{"BIN", ov::element::u1},
{"NF4", ov::element::nf4},
};
std::string str;
in >> str;
@ -400,6 +406,7 @@ inline size_t compiler_byte_size(ov::element::Type_t et) {
ET_CASE(u16);
ET_CASE(u32);
ET_CASE(u64);
ET_CASE(nf4);
#undef ET_CASE
case ov::element::Type_t::undefined:
return 0;
@ -431,7 +438,8 @@ OPENVINO_API EnumNames<element::Type_t>& EnumNames<element::Type_t>::get() {
{"u8", element::Type_t::u8},
{"u16", element::Type_t::u16},
{"u32", element::Type_t::u32},
{"u64", element::Type_t::u64}});
{"u64", element::Type_t::u64},
{"nf4", element::Type_t::nf4}});
return enum_names;
}

81
src/core/src/type/nf4.cpp Normal file
View File

@ -0,0 +1,81 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
// Contains logic derived from bitsandbytes
// https://github.com/TimDettmers/bitsandbytes/blob/c82f51c0f784d8a43ebcb9cdefbf94e3f3b9c6c3/csrc/kernels.cu#L223
// implementation.
// Copyright notice from original source file is as follows.
//*******************************************************************************
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
//==============================================================================
#include "openvino/core/type/nf4.hpp"
using namespace ov;
float ConvertNF4::dequantize(uint8_t val) {
static const std::array<float, 16> lookup = {-1.0f,
-0.6961928009986877f,
-0.5250730514526367f,
-0.39491748809814453f,
-0.28444138169288635f,
-0.18477343022823334f,
-0.09105003625154495f,
0.0f,
0.07958029955625534f,
0.16093020141124725f,
0.24611230194568634f,
0.33791524171829224f,
0.44070982933044434f,
0.5626170039176941f,
0.7229568362236023f,
1.0f};
return lookup[val];
}
uint8_t ConvertNF4::quantize(float x) {
if (x > 0.03979014977812767f)
if (x > 0.3893125355243683f) // 1
if (x > 0.6427869200706482f) // 11
if (x > 0.8614784181118011f) // 111
return 0b1111;
else
return 0b1110;
else if (x > 0.5016634166240692f) // 110
return 0b1101;
else
return 0b1100;
else if (x > 0.2035212516784668f) // 10
if (x > 0.2920137718319893f) // 101
return 0b1011;
else
return 0b1010;
else if (x > 0.1202552504837513f) // 100
return 0b1001;
else
return 0b1000;
else if (x > -0.33967943489551544f) // 0
if (x > -0.13791173323988914f) // 01
if (x > -0.045525018125772476f) // 011
return 0b0111;
else
return 0b0110;
else if (x > -0.23460740596055984f) // 010
return 0b0101;
else
return 0b0100;
else if (x > -0.6106329262256622f) // 00
if (x > -0.4599952697753906f) // 001
return 0b0011;
else
return 0b0010;
else if (x > -0.8480964004993439f) // 000
return 0b0001;
else
return 0b0000;
}

View File

@ -64,6 +64,8 @@ TEST(element_type, from_string) {
EXPECT_EQ(element::Type("U32"), element::u32);
EXPECT_EQ(element::Type("u64"), element::u64);
EXPECT_EQ(element::Type("U64"), element::u64);
EXPECT_EQ(element::Type("nf4"), element::nf4);
EXPECT_EQ(element::Type("NF4"), element::nf4);
EXPECT_EQ(element::Type("undefined"), element::undefined);
EXPECT_EQ(element::Type("UNSPECIFIED"), element::undefined);

96
src/core/tests/nf4.cpp Normal file
View File

@ -0,0 +1,96 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/core/type/nf4.hpp"
#include "common_test_utils/test_tools.hpp"
#include "gtest/gtest.h"
#include "openvino/op/constant.hpp"
using namespace std;
TEST(nf4, convert_nf4_to_string) {
vector<uint8_t> values{186, 17};
auto constant = make_shared<ov::op::v0::Constant>(ov::element::nf4, ov::Shape{3}, &values[0]);
vector<string> ref{"10", "11", "1"};
for (size_t i = 0; i < 3; ++i) {
ASSERT_EQ(constant->convert_value_to_string(i), ref[i]);
}
}
TEST(nf4, tensor_or_constant_size) {
vector<uint8_t> values{171, 16};
auto constant = make_shared<ov::op::v0::Constant>(ov::element::nf4, ov::Shape{3}, &values[0]);
EXPECT_EQ(2, constant->get_byte_size());
ov::Tensor runtime_tensor(ov::element::nf4, ov::Shape{3});
EXPECT_EQ(constant->get_byte_size(), runtime_tensor.get_byte_size());
}
template <typename T>
void test_nf4_convert() {
vector<float> const_data_f{-1.5f, -1.425f, -1.35f, -1.275f, -1.2f, -1.125f, -1.05f, -0.975f, -0.9f,
-0.825f, -0.75f, -0.675f, -0.6f, -0.525f, -0.45f, -0.375f, -0.3f, -0.225f,
-0.15f, -0.075f, 0.0f, 0.075f, 0.15f, 0.225f, 0.3f, 0.375f, 0.45f,
0.525f, 0.6f, 0.675f, 0.75f, 0.825f, 0.9f, 0.975f, 1.05f, 1.125f,
1.2f, 1.275f, 1.35f, 1.425f, 1.5};
vector<float> target_f{-1.0f,
-0.6961928009986877f,
-0.5250730514526367f,
-0.39491748809814453f,
-0.28444138169288635f,
-0.18477343022823334f,
-0.09105003625154495f,
0.0f,
0.07958029955625534f,
0.16093020141124725f,
0.24611230194568634f,
0.33791524171829224f,
0.44070982933044434f,
0.5626170039176941f,
0.7229568362236023f,
1.0f};
vector<T> const_data;
const_data.reserve(const_data_f.size());
for (auto& val : const_data_f) {
const_data.push_back(static_cast<T>(val));
}
vector<T> target;
target.reserve(target_f.size());
for (auto& val : target_f) {
target.push_back(static_cast<T>(val));
}
auto constant = ov::op::v0::Constant::create(ov::element::nf4, ov::Shape{const_data.size()}, const_data);
const uint8_t* p = static_cast<const uint8_t*>(constant->get_data_ptr());
EXPECT_NE(p, nullptr);
std::vector<uint8_t> packed_data(p, p + const_data.size() / 2 + const_data.size() % 2);
std::vector<T> decompressed_data(const_data.size(), 0);
for (size_t i = 0; i < const_data.size(); i++) {
ov::ConvertNF4::unpack(&decompressed_data[0], &packed_data[0], i);
}
auto it = std::unique(decompressed_data.begin(), decompressed_data.end());
decompressed_data.resize(std::distance(decompressed_data.begin(), it));
EXPECT_EQ(16, decompressed_data.size());
float max_diff = 0.0;
for (size_t i = 0; i < 16; i++) {
float diff = fabs(static_cast<float>(decompressed_data[i] - target[i]));
max_diff = std::max(max_diff, diff);
}
EXPECT_LE(max_diff, 0.001);
}
TEST(nf4, convert_float) {
test_nf4_convert<float>();
test_nf4_convert<ov::float16>();
test_nf4_convert<ov::bfloat16>();
}

View File

@ -16,8 +16,10 @@ inline void evaluate(const std::shared_ptr<ov::op::v1::ConvertLike>& op,
outputs[0].set_shape(inputs[0].get_shape());
size_t element_count = ov::shape_size(outputs[0].get_shape());
if (((ti == ov::element::u1) || (to == ov::element::u1)) || ((ti == ov::element::u4) || (to == ov::element::u4)) ||
((ti == ov::element::i4) || (to == ov::element::i4))) {
if (((ti == ngraph::element::u1) || (to == ngraph::element::u1)) ||
((ti == ngraph::element::u4) || (to == ngraph::element::u4)) ||
((ti == ngraph::element::i4) || (to == ngraph::element::i4)) ||
((ti == ngraph::element::nf4) || (to == ngraph::element::nf4))) {
ov::reference::detail::lp_convert(inputs[0].data<T_I>(), outputs[0].data<T_O>(), element_count, ti, to);
} else {
ov::reference::convert(inputs[0].data<T_I>(), outputs[0].data<T_O>(), element_count);

View File

@ -44,6 +44,7 @@ ov::Tensor create_and_fill_tensor(const ov::element::Type element_type,
case ov::element::Type_t::u1:
case ov::element::Type_t::i4:
case ov::element::Type_t::u4:
case ov::element::Type_t::nf4:
fill_data_random(static_cast<uint8_t*>(tensor.data()),
tensor.get_byte_size(),
range,

View File

@ -99,7 +99,7 @@ inline bool operator==(const __itt_id& left, const __itt_id& right) {
namespace sea {
uint64_t g_nRingBuffer = 1000000000ll * atoi(get_environ_value("INTEL_SEA_RING").c_str()); // in nanoseconds
uint64_t g_nRingBuffer = 1000000000ll * atoi(get_environ_value("INTEL_SEA_RING").c_str()); // in nanoseconds
uint64_t g_nAutoCut = 1024ull * 1024 * atoi(get_environ_value("INTEL_SEA_AUTOCUT").c_str()); // in MB
uint64_t g_features = sea::GetFeatureSet();