Convert operation optimization for FP16 -> INT8 and FP32 -> INT8 (#5275)
Improvements to the ConvertPrecision transformation pass
Parent: c9672ee9ec
Commit: d2c9bddff1
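Editor's note: the central API change in this commit is that ngraph::pass::ConvertPrecision now accepts a precisions_array (a vector of {from, to} pairs), so a plugin registers one pass instead of one pass per conversion. A minimal usage sketch follows; the include path and surrounding pass-manager setup are assumptions for illustration, not part of the diff:

```cpp
#include <ngraph/pass/manager.hpp>
#include <transformations/convert_precision.hpp> // assumed location of ngraph::pass::ConvertPrecision

void register_precision_conversions(ngraph::pass::Manager& manager) {
    static const precisions_array conversions{
        {ngraph::element::i64, ngraph::element::i32},
        {ngraph::element::f16, ngraph::element::f32},
    };
    // One registered pass handles the whole list; internally it only converts
    // precisions that actually occur in the function (see convert_precision.cpp below).
    manager.register_pass<ngraph::pass::ConvertPrecision>(conversions);
}
```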
@@ -67,7 +67,7 @@ std::shared_ptr<ngraph::Function> TransformNetwork(const std::shared_ptr<const n
// Example: register CommonOptimizations transformation from transformations library
passManager.register_pass<ngraph::pass::CommonOptimizations>();
// Template plugin handles only FP32 networks
passManager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
passManager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ngraph::element::f16, ngraph::element::f32 }});
// Example: register plugin specific transformation
passManager.register_pass<ngraph::pass::DecomposeDivideMatcher>();
passManager.register_pass<ngraph::pass::ReluReluFusionMatcher>();
@@ -175,7 +175,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
manager.register_pass<ngraph::pass::ConvertNMSToNMSIEInternal>();
manager.register_pass<ngraph::pass::ConvertGather0D>();

std::vector<std::pair<ngraph::element::Type, ngraph::element::Type>> convert_precision_list {
static const precisions_array convert_precision_list {
{ngraph::element::i64, ngraph::element::i32},
{ngraph::element::u64, ngraph::element::i32},
{ngraph::element::u16, ngraph::element::i32},
@@ -185,9 +185,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
{ngraph::element::u4, ngraph::element::u8},
};

for (auto& precision : convert_precision_list) {
manager.register_pass<ngraph::pass::ConvertPrecision>(precision.first, precision.second);
}
manager.register_pass<ngraph::pass::ConvertPrecision>(convert_precision_list);

auto pass_config = manager.get_pass_config();

@@ -366,7 +364,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
// Conversion to FP32 might be needed for quantized models that face any fp16 related issues (e.g. overflow) for non-quantized layers
// With this key users can work around such issues
if (!config.enable_fp16_for_quantized_models) {
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::f16, ngraph::element::f32 }});
}
auto lptPrerequisites = manager.register_pass<ngraph::pass::GraphRewrite>();
const std::vector<ngraph::element::Type> supportedTypes = { ngraph::element::i8, ngraph::element::u8 };
@@ -13,6 +13,7 @@
#include <ie_plugin_config.hpp>
#include <vector>
#include <tuple>
#include <unordered_set>
#include <ie_system_conf.h>
#include <nodes/list.hpp>
#include <ie_ngraph_utils.hpp>
@@ -65,6 +66,7 @@
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/op/util/op_types.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/graph_util.hpp>

#include <transformations/common_optimizations/lin_op_sequence_fusion.hpp>

@@ -122,6 +124,28 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
std::vector<ngraph::element::Type>{ ngraph::element::i8, ngraph::element::u8, ngraph::element::i4, ngraph::element::u4 });
}

auto get_convert_precisions = []() {
precisions_array array = {
{ngraph::element::i64, ngraph::element::i32},
{ngraph::element::u64, ngraph::element::i32},
{ngraph::element::i16, ngraph::element::i32},
{ngraph::element::u16, ngraph::element::i32},
{ngraph::element::u32, ngraph::element::i32},
{ngraph::element::f64, ngraph::element::f32},
{ngraph::element::f16, ngraph::element::f32},
{ngraph::element::boolean, ngraph::element::u8},
{ngraph::element::i4, ngraph::element::i8},
{ngraph::element::u4, ngraph::element::u8}
};

if (!with_cpu_x86_avx512_core())
array.push_back({ngraph::element::bf16, ngraph::element::f32});

return array;
};

static const auto precisions = get_convert_precisions();

// WA: ConvertPriorBox must be executed before the 1st ConstantFolding pass
manager.register_pass<ngraph::pass::CommonOptimizations>();
manager.register_pass<ngraph::pass::ConvertRNNSequenceToTensorIterator>();
@@ -140,27 +164,7 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
manager.register_pass<ngraph::pass::ConvertNMS4ToNMS5>();
manager.register_pass<ngraph::pass::ConvertNMSToNMSIEInternal>();
manager.register_pass<ngraph::pass::ConstantFolding>();

std::vector<std::pair<ngraph::element::Type, ngraph::element::Type>> convert_precision_list{
{ngraph::element::i64, ngraph::element::i32},
{ngraph::element::u64, ngraph::element::i32},
{ngraph::element::i16, ngraph::element::i32},
{ngraph::element::u16, ngraph::element::i32},
{ngraph::element::u32, ngraph::element::i32},
{ngraph::element::f64, ngraph::element::f32},
{ngraph::element::f16, ngraph::element::f32},
{ngraph::element::boolean, ngraph::element::u8},
{ngraph::element::i4, ngraph::element::i8},
{ngraph::element::u4, ngraph::element::u8},
};

// In case BF16 is not supported by the target CPU we explicitly convert it to FP32
if (!with_cpu_x86_avx512_core())
convert_precision_list.push_back({ngraph::element::bf16, ngraph::element::f32});

for (auto &precision : convert_precision_list) {
manager.register_pass<ngraph::pass::ConvertPrecision>(precision.first, precision.second);
}
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);

auto pass_config = manager.get_pass_config();
@@ -42,7 +42,7 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr<ngraph::Function> &nGraphF
manager.register_pass<ReshapeFullyConnectedFusion>();
}
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i64, ngraph::element::i32 }});
manager.run_passes(nGraphFunc);
}
@@ -7,6 +7,7 @@
#include <vector>
#include <memory>
#include <algorithm>
#include <unordered_map>

#include <transformations_visibility.hpp>

@@ -69,18 +70,24 @@ class NGRAPH_API ConvertPrecision;
 * LessEqual
 */

using type_to_fuse_map = std::map<ngraph::NodeTypeInfo, std::function<bool(const std::shared_ptr<ngraph::Node>&, ngraph::element::Type, size_t idx)>>;
using type_to_fuse_map = std::unordered_map<ngraph::NodeTypeInfo, std::function<bool(const std::shared_ptr<ngraph::Node>&, ngraph::element::Type, size_t idx)>>;
using precisions_array = std::vector<std::pair<ngraph::element::Type, ngraph::element::Type>>;

class ngraph::pass::ConvertPrecision : public ngraph::pass::FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertPrecision(ngraph::element::Type_t from, ngraph::element::Type_t to, type_to_fuse_map additional_type_to_fuse_map = {})
: FunctionPass(),
m_from(from),
m_to(to),
m_precisions(precisions_array {{ from, to }}),
m_additional_type_to_fuse_map(additional_type_to_fuse_map) {}

ConvertPrecision(const precisions_array& precisions, const type_to_fuse_map & additional_type_to_fuse_map = {})
: FunctionPass(),
m_precisions(precisions),
m_additional_type_to_fuse_map(additional_type_to_fuse_map) {}

bool run_on_function(std::shared_ptr<Function> f) override;
private:
element::Type m_from, m_to;
precisions_array m_precisions;
type_to_fuse_map m_additional_type_to_fuse_map;
};
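For clarity: both constructors shown above coexist. The legacy (from, to) form simply wraps the pair into m_precisions, so existing call sites keep compiling while new code passes a whole list. A hedged sketch of the two call styles (include path assumed, not taken from the diff):

```cpp
#include <ngraph/pass/manager.hpp>
#include <transformations/convert_precision.hpp> // assumed header location

void register_both_styles(ngraph::pass::Manager& manager) {
    // Legacy single-pair constructor, still supported (wrapped into m_precisions internally):
    manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
    // New list constructor introduced by this change:
    manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array{
        {ngraph::element::f16, ngraph::element::f32},
        {ngraph::element::i64, ngraph::element::i32}});
}
```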
@@ -15,10 +15,11 @@
#include "vpu/ngraph/operations/out_shape_of_reshape.hpp"
#include <stack>
#include <deque>
#include <unordered_map>

namespace vpu {

using typeToFuseMap = std::map<ngraph::NodeTypeInfo, std::function<bool(const std::shared_ptr<ngraph::Node>&, ngraph::element::Type, size_t idx)>>;
using typeToFuseMap = std::unordered_map<ngraph::NodeTypeInfo, std::function<bool(const std::shared_ptr<ngraph::Node>&, ngraph::element::Type, size_t idx)>>;

bool fuseTypeToStaticShapeNonMaxSuppression(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuseTypeToStaticShapeNonZero(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
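Both the common type_to_fuse_map and the VPU typeToFuseMap move from std::map to std::unordered_map here; the callback signature is unchanged. As a hedged illustration of how a plugin supplies such a callback together with the new list-based constructor (the callback body and the chosen op are placeholders, not code from this commit):

```cpp
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/convert_precision.hpp> // assumed header location

// Return true if the node absorbed the new element type on output `idx` itself,
// false to let ConvertPrecision fall back to its generic handling.
static bool fuse_type_to_my_op(const std::shared_ptr<ngraph::Node>& node,
                               ngraph::element::Type to,
                               size_t idx) {
    (void)node; (void)to; (void)idx;
    return false; // placeholder: nothing fused here
}

void register_with_custom_fusing(ngraph::pass::Manager& manager) {
    const type_to_fuse_map extra{{ngraph::opset4::NonZero::type_info, fuse_type_to_my_op}};
    manager.register_pass<ngraph::pass::ConvertPrecision>(
        precisions_array{{ngraph::element::i64, ngraph::element::i32}}, extra);
}
```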
@@ -192,14 +192,17 @@ ie::CNNNetwork FrontEnd::convertNetwork(ie::CNNNetwork& network) {
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
// ConvertPrecision must be executed before ConvertOpSet1ToLegacy because this pass works only with operations from opsets
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32, myriadTypeToFuseMap);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::u64, ngraph::element::i32, myriadTypeToFuseMap);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::u32, ngraph::element::i32, myriadTypeToFuseMap);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::i32, myriadTypeToFuseMap);
static const precisions_array precisions = {
{ ngraph::element::i64, ngraph::element::i32 },
{ ngraph::element::u64, ngraph::element::i32 },
{ ngraph::element::u32, ngraph::element::i32 },
{ ngraph::element::boolean, ngraph::element::i32 }
};
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions, myriadTypeToFuseMap);

manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
// ConvertOpSet1ToLegacy can produce constants with I64 precision
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32, myriadTypeToFuseMap);
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i64, ngraph::element::i32 }}, myriadTypeToFuseMap);
manager.register_pass<vpu::MergeSubsequentDSROperations>();

auto pass_config = manager.get_pass_config();
@@ -174,7 +174,7 @@ TEST(ConvertFunctionToCNNNetworkTests, ConvertTopKWithOneInput) {
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();

std::vector<std::pair<ngraph::element::Type, ngraph::element::Type>> convert_precision_list {
static const precisions_array convert_precision_list {
{ngraph::element::i64, ngraph::element::i32},
{ngraph::element::u64, ngraph::element::i32},
{ngraph::element::u16, ngraph::element::i32},
@@ -183,12 +183,9 @@ TEST(ConvertFunctionToCNNNetworkTests, ConvertTopKWithOneInput) {
{ngraph::element::boolean, ngraph::element::u8},
};

for (auto & precision : convert_precision_list) {
manager.register_pass<ngraph::pass::ConvertPrecision>(precision.first, precision.second);
}

manager.register_pass<ngraph::pass::ConvertPrecision>(convert_precision_list);
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i64, ngraph::element::i32 }});

manager.run_passes(f);
@@ -54,8 +54,13 @@ TEST(TransformationTests, ConvertPrecision_NMS3) {
f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);

static const precisions_array precisions = {
{ ngraph::element::i64, ngraph::element::i32 },
{ ngraph::element::f16, ngraph::element::f32 }
};

manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}

@@ -77,8 +82,13 @@ TEST(TransformationTests, ConvertPrecision_NMS4) {
f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);

static const precisions_array precisions = {
{ ngraph::element::i64, ngraph::element::i32 },
{ ngraph::element::f16, ngraph::element::f32 }
};

manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}

@@ -95,8 +105,13 @@ TEST(TransformationTests, ConvertPrecision_ShapeOf) {
f = std::make_shared<Function>(NodeVector{shape_of}, ParameterVector{input});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);

static const precisions_array precisions = {
{ ngraph::element::i64, ngraph::element::i32 },
{ ngraph::element::f16, ngraph::element::f32 }
};

manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}

@@ -113,8 +128,13 @@ TEST(TransformationTests, ConvertPrecision_Convert) {
f = std::make_shared<Function>(NodeVector{convert}, ParameterVector{input});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);

static const precisions_array precisions = {
{ ngraph::element::i64, ngraph::element::i32 },
{ ngraph::element::f16, ngraph::element::f32 }
};

manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}

@@ -132,7 +152,7 @@ TEST(TransformationTests, ConvertPrecision_ConvertElimination) {
f = std::make_shared<Function>(NodeVector{convert}, ParameterVector{input});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::f16, ngraph::element::f32 }});
manager.run_passes(f);
ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
}

@@ -158,8 +178,13 @@ TEST(TransformationTests, ConvertPrecision_TopK) {
f = std::make_shared<Function>(OutputVector{topk->output(0), topk->output(1)}, ParameterVector{input});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);

static const precisions_array precisions = {
{ ngraph::element::i64, ngraph::element::i32 },
{ ngraph::element::f16, ngraph::element::f32 }
};

manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}

@@ -176,8 +201,13 @@ TEST(TransformationTests, ConvertPrecision_NonZero) {
f = std::make_shared<Function>(OutputVector{non_zero}, ParameterVector{input});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);

static const precisions_array precisions = {
{ ngraph::element::i64, ngraph::element::i32 },
{ ngraph::element::f16, ngraph::element::f32 }
};

manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}

@@ -195,8 +225,13 @@ TEST(TransformationTests, ConvertPrecision_Bucketize) {
f = std::make_shared<Function>(OutputVector{b}, ParameterVector{input});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);

static const precisions_array precisions = {
{ ngraph::element::i64, ngraph::element::i32 },
{ ngraph::element::f16, ngraph::element::f32 }
};

manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}

@@ -223,8 +258,13 @@ TEST(TransformationTests, ConvertPrecision_Roundings) {
f = std::make_shared<Function>(OutputVector{ss}, ParameterVector{input});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);

static const precisions_array precisions = {
{ ngraph::element::i64, ngraph::element::i32 },
{ ngraph::element::f16, ngraph::element::f32 }
};

manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.run_passes(f);

auto casted_end = std::dynamic_pointer_cast<opset1::Constant>(ss->input_value(2).get_node_shared_ptr());

@@ -279,8 +319,13 @@ TEST(TransformationTests, ConvertPrecision_TIBody) {
ngraph::ParameterVector{X, Y});

ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);

static const precisions_array precisions = {
{ ngraph::element::i64, ngraph::element::i32 },
{ ngraph::element::f16, ngraph::element::f32 }
};

manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.run_passes(f);

ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));

@@ -300,8 +345,13 @@ TEST(TransformationTests, ConvertPrecision_Equal) {
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);

static const precisions_array precisions = {
{ ngraph::element::boolean, ngraph::element::u8 },
{ ngraph::element::f16, ngraph::element::f32 }
};

manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}

@@ -320,8 +370,13 @@ TEST(TransformationTests, ConvertPrecision_NotEqual) {
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);

static const precisions_array precisions = {
{ ngraph::element::boolean, ngraph::element::u8 },
{ ngraph::element::f16, ngraph::element::f32 }
};

manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}

@@ -340,8 +395,13 @@ TEST(TransformationTests, ConvertPrecision_Greater) {
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);

static const precisions_array precisions = {
{ ngraph::element::boolean, ngraph::element::u8 },
{ ngraph::element::f16, ngraph::element::f32 }
};

manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}

@@ -360,8 +420,13 @@ TEST(TransformationTests, ConvertPrecision_GreaterEqual) {
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);

static const precisions_array precisions = {
{ ngraph::element::boolean, ngraph::element::u8 },
{ ngraph::element::f16, ngraph::element::f32 }
};

manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}

@@ -380,8 +445,13 @@ TEST(TransformationTests, ConvertPrecision_Less) {
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);

static const precisions_array precisions = {
{ ngraph::element::boolean, ngraph::element::u8 },
{ ngraph::element::f16, ngraph::element::f32 }
};

manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}

@@ -400,8 +470,13 @@ TEST(TransformationTests, ConvertPrecision_LessEqual) {
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);

static const precisions_array precisions = {
{ ngraph::element::boolean, ngraph::element::u8 },
{ ngraph::element::f16, ngraph::element::f32 }
};

manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}

@@ -420,7 +495,7 @@ TEST(TransformationTests, ConvertPrecision_LogicalAnd) {
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::boolean, ngraph::element::u8 }});
manager.run_passes(f);
}

@@ -438,7 +513,7 @@ TEST(TransformationTests, ConvertPrecision_LogicalOr) {
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::boolean, ngraph::element::u8 }});
manager.run_passes(f);
}

@@ -456,7 +531,7 @@ TEST(TransformationTests, ConvertPrecision_LogicalXor) {
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::boolean, ngraph::element::u8 }});
manager.run_passes(f);
}

@@ -473,7 +548,7 @@ TEST(TransformationTests, ConvertPrecision_LogicalNot) {
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::boolean, ngraph::element::u8 }});
manager.run_passes(f);
}

@@ -491,7 +566,7 @@ TEST(TransformationTests, ConvertPrecision_Select) {
f = std::make_shared<Function>(OutputVector{select}, ParameterVector{input1});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::boolean, ngraph::element::u8 }});
manager.run_passes(f);
}

@@ -509,8 +584,8 @@ TEST(TransformationTests, ConvertPrecision_TypeRelaxedWithSelect) {
f = std::make_shared<Function>(OutputVector{select}, ParameterVector{input1});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::i32);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i32, ngraph::element::i64);
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::boolean, ngraph::element::i32 }});
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i32, ngraph::element::i64 }});
manager.run_passes(f);
}

@@ -529,8 +604,8 @@ TEST(TransformationTests, ConvertPrecision_TypeRelaxed) {
f = std::make_shared<Function>(OutputVector{type_relaxed}, ParameterVector{input1});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::i32);
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i32, ngraph::element::i64);
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::boolean, ngraph::element::i32 }});
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i32, ngraph::element::i64 }});
manager.run_passes(f);

ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));

@@ -555,7 +630,7 @@ TEST(TransformationTests, ConvertPrecision_Variables) {
f = std::make_shared<Function>(NodeVector{mul}, ParameterVector{inp});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::f16, ngraph::element::f32 }});
manager.run_passes(f);
}
@@ -573,7 +648,7 @@ void constant_convert_test(element::Type type_from, element::Type type_to, const
f = std::make_shared<Function>(NodeVector{c}, ParameterVector{});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(type_from, type_to);
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ type_from, type_to }});
manager.run_passes(f);
}
auto ops = f->get_ordered_ops();

@@ -603,7 +678,7 @@ void constant_convert_test(element::Type_t type_from, element::Type_t type_to, F
f = std::make_shared<Function>(NodeVector{c}, ParameterVector{});

pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(type_from, type_to);
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ type_from, type_to }});
manager.run_passes(f);
}
auto ops = f->get_ordered_ops();

@@ -28,6 +28,10 @@ namespace ngraph
void convert<uint8_t, float16>(const uint8_t* arg, float16* out, size_t count);
template <>
void convert<float16, float>(const float16* arg, float* out, size_t count);
template <>
void convert<float, int8_t>(const float* arg, int8_t* out, size_t count);
template <>
void convert<float16, int8_t>(const float16* arg, int8_t* out, size_t count);

// overload to handle ngraph::boolean (it is stored as char)
template <typename TI, typename TO>
@@ -39,7 +43,6 @@ namespace ngraph
out[i] = static_cast<char>(static_cast<bool>(arg[i]));
}
}

} // namespace reference

} // namespace runtime

@@ -16,6 +16,11 @@ namespace ngraph
template <typename src_t, typename dst_t>
void jit_convert_vec(jit::Generator&, const Xbyak::RegExp&, const Xbyak::RegExp&);

template <typename src_t, typename dst_t>
void jit_convert_vec_prepare(jit::Generator&)
{
}

template <>
void jit_convert_vec<uint8_t, float16>(jit::Generator& gen,
const Xbyak::RegExp& src,
@@ -47,6 +52,61 @@
gen.vmovups(gen.yword[dst], f32vec);
}

template <>
void jit_convert_vec_prepare<float, int8_t>(jit::Generator& gen)
{
auto order = gen.ymm1;
auto addr = gen.r15;

static const int8_t offsets[32] = {0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 4,
8, 12, -1, -1, -1, -1, -1, -1, -1, -1};

gen.mov(addr, (size_t)offsets); // get offsets[] address
gen.vmovdqu(order, gen.yword[addr]); // save offsets[] to ymm register
}

template <>
void jit_convert_vec<float, int8_t>(jit::Generator& gen,
const Xbyak::RegExp& src,
const Xbyak::RegExp& dst)
{
auto order = gen.ymm1;
auto p32vec = gen.ymm2;
auto p32vec_lo = gen.xmm2;
auto p32vec_hi = gen.xmm3;

gen.vcvtps2dq(p32vec, gen.yword[src]); // convert 8 floats to 8 ints
gen.vpshufb(p32vec, p32vec, order); // Shuffle the bytes according to the order
gen.vextracti128(p32vec_hi, p32vec, 1); // extract upper part of p32vec
gen.vpor(p32vec_lo, p32vec_lo, p32vec_hi); // p32vec_lo = p32vec_lo | p32vec_hi
gen.movq(gen.qword[dst], p32vec_lo); // save the result
}

template <>
void jit_convert_vec_prepare<float16, int8_t>(jit::Generator& gen)
{
jit_convert_vec_prepare<float, int8_t>(gen);
}

template <>
void jit_convert_vec<float16, int8_t>(jit::Generator& gen,
const Xbyak::RegExp& src,
const Xbyak::RegExp& dst)
{
auto order = gen.ymm1;
auto p32vec = gen.ymm2;
auto p32vec_lo = gen.xmm2;
auto p32vec_hi = gen.xmm3;

gen.vcvtph2ps(p32vec, gen.xword[src]); // convert 8 fp16's to 8 floats
gen.vcvtps2dq(p32vec, p32vec); // convert 8 floats to 8 ints
gen.vpshufb(p32vec, p32vec, order); // Shuffle the bytes according to the order
gen.vextracti128(p32vec_hi, p32vec, 1); // extract upper part of p32vec
gen.vpor(p32vec_lo, p32vec_lo, p32vec_hi); // p32vec_lo = p32vec_lo | p32vec_hi
gen.movq(gen.qword[dst], p32vec_lo); // save the result
}

class jit_convert_array : public jit::Generator
{
typedef struct context
@@ -61,6 +121,7 @@
void (*convert_vec)(jit::Generator&,
const Xbyak::RegExp&,
const Xbyak::RegExp&);
void (*prepare)(jit::Generator&);
} context_t;

jit_convert_array(const context_t& ctx)
@@ -77,6 +138,8 @@

preamble();

ctx.prepare(*this);

mov(reg_src, ptr[param + offsetof(args_t, src)]);
mov(reg_dst, ptr[param + offsetof(args_t, out)]);
mov(reg_sz, ptr[param + offsetof(args_t, count)]);
@@ -137,7 +200,8 @@
static const jit_convert_array::context_t context{
{sizeof(src_t), &jit::Generator::copy<src_t>},
{sizeof(dst_t), &jit::Generator::copy<dst_t>},
jit_convert_vec<src_t, dst_t>};
jit_convert_vec<src_t, dst_t>,
jit_convert_vec_prepare<src_t, dst_t>};

static jit_convert_array generator(context);

@@ -146,44 +210,49 @@
return nullptr;
}
};

template <typename TI, typename TO>
void convert_impl(const TI* arg, TO* out, size_t count)
{
auto converter = jit_convert_array::get<TI, TO>();

if (converter)
{
jit_convert_array::args_t args = {arg, out, count};
converter(&args);
}
else
{
for (size_t i = 0; i < count; ++i)
{
out[i] = static_cast<TO>(arg[i]);
}
}
}
} // namespace

template <>
void convert<uint8_t, float16>(const uint8_t* arg, float16* out, size_t count)
{
auto converter = jit_convert_array::get<uint8_t, float16>();

if (converter)
{
jit_convert_array::args_t args = {arg, out, count};
converter(&args);
}
else
{
for (size_t i = 0; i < count; ++i)
{
out[i] = static_cast<float16>(arg[i]);
}
}
convert_impl(arg, out, count);
}

template <>
void convert<float16, float>(const float16* arg, float* out, size_t count)
{
auto converter = jit_convert_array::get<float16, float>();
convert_impl(arg, out, count);
}

if (converter)
template <>
void convert<float, int8_t>(const float* arg, int8_t* out, size_t count)
{
jit_convert_array::args_t args = {arg, out, count};
converter(&args);
convert_impl(arg, out, count);
}
else

template <>
void convert<float16, int8_t>(const float16* arg, int8_t* out, size_t count)
{
for (size_t i = 0; i < count; ++i)
{
out[i] = static_cast<float>(arg[i]);
}
}
convert_impl(arg, out, count);
}
} // namespace reference
} // namespace runtime
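A note on the new float/float16 -> int8 kernels above: vcvtps2dq leaves eight 32-bit integers in a YMM register, the offsets table fed to vpshufb keeps only byte 0 of each dword within each 128-bit lane, and vextracti128 plus vpor then merge the two lanes so that movq can store eight contiguous int8 values. A rough scalar emulation of one vector iteration (an editor's sketch, assuming the values fit into the int8 range):

```cpp
#include <cmath>
#include <cstdint>

// Scalar emulation of one 8-element iteration of the generated float -> int8_t body.
void convert8_f32_to_i8(const float* src, int8_t* dst) {
    for (int i = 0; i < 8; ++i) {
        // vcvtps2dq rounds with the current MXCSR mode (round-to-nearest-even by default).
        const int32_t dword = static_cast<int32_t>(std::nearbyint(src[i]));
        // vpshufb/vextracti128/vpor keep only the low byte of each dword and pack them.
        dst[i] = static_cast<int8_t>(dword & 0xFF);
    }
}
```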
@@ -158,6 +158,26 @@ namespace ngraph
pop(rsi);
}

template <>
void Generator::copy<int8_t>(const Xbyak::Reg64& dst,
const Xbyak::Reg64& src,
const Xbyak::Reg64& size)
{
push(rsi);
push(r15);

xor_(rsi, rsi);

foreach (rsi, 1, size, [&, this](const Xbyak::Reg64& idx) {
mov(r15b, byte[src + idx * sizeof(int8_t)]);
mov(byte[dst + idx * sizeof(int8_t)], r15b);
})
;

pop(r15);
pop(rsi);
}

template <>
void Generator::copy<uint16_t>(const Xbyak::Reg64& dst,
const Xbyak::Reg64& src,
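The copy<int8_t> specialization added above mirrors the existing copy<uint16_t>: it moves leftover elements one byte at a time through r15b. A scalar equivalent of what the emitted loop amounts to (editor's sketch, not code from the commit):

```cpp
#include <cstddef>
#include <cstdint>

void copy_tail_i8(int8_t* dst, const int8_t* src, std::size_t count) {
    for (std::size_t i = 0; i < count; ++i)
        dst[i] = src[i]; // same element-by-element move the generated code performs
}
```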
@@ -15,7 +15,8 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertFP32ToFP16, "ConvertFP32ToFP16", 0);
bool ngraph::pass::ConvertFP32ToFP16::run_on_function(std::shared_ptr<ngraph::Function> f)
{
ngraph::pass::Manager m(get_pass_config());
m.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f32, ngraph::element::f16);
m.register_pass<ngraph::pass::ConvertPrecision>(
precisions_array{{ngraph::element::f32, ngraph::element::f16}});
m.run_passes(f);
return false;
}
@@ -125,6 +125,194 @@ bool fuse_type_to_reduce_logical(const std::shared_ptr<ngraph::Node>& node,
return false;
}

namespace
{
void validate_nodes_and_infer_types(const std::vector<std::shared_ptr<Node>>& ops)
{
for (auto& node : ops)
{
node->revalidate_and_infer_types();
}
}

bool convert_precision(pass::PassBase& pass,
const std::shared_ptr<ngraph::Function>& f,
const type_to_fuse_map& type_to_fuse,
const type_to_fuse_map& type_to_extend,
element::Type from,
element::Type to)
{
// As Constant operations can be shared between multiple nGraph Functions, before
// changing precision we need to understand which Constant consumers belong
// to the current nGraph Function
std::unordered_map<const ngraph::Node*, std::vector<Input<Node>>> const_to_internal_output;

auto register_constants =
[&const_to_internal_output](const std::vector<std::shared_ptr<Node>>& ops) {
for (auto& node : ops)
{
for (auto& input : node->inputs())
{
if (auto const_node = std::dynamic_pointer_cast<opset4::Constant>(
input.get_source_output().get_node_shared_ptr()))
{
const_to_internal_output[const_node.get()].emplace_back(input);
}
}
}
};

auto convert_node_output_precision = [&](const std::shared_ptr<ngraph::Node>& node) {
for (auto output : node->outputs())
{
if (output.get_element_type() == from)
{
// Handle case with Constants as they can have consumers from other nGraph
// Function object
auto it = const_to_internal_output.find(node.get());
if (it != const_to_internal_output.end())
{
return fuse_type_to_constant(node, to, it->second);
}

// Check that node type exists in map and we can fuse type into node
auto t2f_it = type_to_fuse.find(node->get_type_info());
if (t2f_it != type_to_fuse.end() &&
t2f_it->second(node, to, output.get_index()))
{
// We need to break if original node was replaced
return true;
}
}
}
return false;
};

auto convert_node_input_precision = [&](const std::shared_ptr<ngraph::Node>& node) {
for (auto input : node->inputs())
{
if (input.get_element_type() == from)
{
// For some operations we need to extend their input types to support new type
auto it = type_to_extend.find(node->get_type_info());
if (it != type_to_extend.end() && it->second(node, to, input.get_index()))
{
return true;
}
}
}
return false;
};

std::function<bool(const std::shared_ptr<Function>&, bool)> convert_function_precision =
[&](const std::shared_ptr<Function>& f, bool is_subgraph) {
bool is_changed = false;

auto ops = f->get_ordered_ops();

// Iterate over all nodes in topological order and then iterate over node outputs.
// If the output type mismatches the given type we try to fuse the type into this operation,
// otherwise we insert a Convert operation.
for (auto& node : ops)
{
pass.transformation_callback(node);
// Recursively apply transformation for sub-graph based operations
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
{
if (auto sub_graph = sub_graph_node->get_function())
{
is_changed |= convert_function_precision(sub_graph, true);
}
}
is_changed |= convert_node_input_precision(node);
}

if (is_changed)
ops = f->get_ordered_ops();

// Register internal constants only after fixing input type that could lead to nodes
// replacement
register_constants(ops);

bool is_output_precision_changed = false;

for (auto& node : ops)
{
is_output_precision_changed |= convert_node_output_precision(node);
}

if (is_output_precision_changed)
{
ops = f->get_ordered_ops();
is_changed |= is_output_precision_changed;
}

if (!is_subgraph)
{
if (is_changed)
validate_nodes_and_infer_types(ops);

// TODO: we need to split NopElimination pass to separate MatcherPasses and call
// Convert elimination here
for (auto& node : ops)
{
if (auto convert = std::dynamic_pointer_cast<opset4::Convert>(node))
{
// WA for topK, don't remove fake convert
if (convert->input(0).get_element_type() ==
convert->get_convert_element_type() &&
convert->input_value(0).get_node_shared_ptr()->get_output_size() ==
1)
{
replace_output_update_name(convert->output(0),
convert->input_value(0));
}
}
}
}

return is_changed;
};

return convert_function_precision(f, false);
}

struct EnumClassHash
{
template <class T>
std::size_t operator()(T t) const
{
return static_cast<size_t>(t);
}
};

using precisions_set_t = std::unordered_set<ngraph::element::Type_t, EnumClassHash>;

precisions_set_t find_all_used_precisions(const std::shared_ptr<ngraph::Function>& fn)
{
precisions_set_t used_precisions;

ngraph::traverse_nodes(fn, [&](std::shared_ptr<ngraph::Node> node) {
for (auto output : node->outputs())
{
used_precisions.emplace(output.get_element_type());
}
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node))
{
if (auto sub_graph = sub_graph_node->get_function())
{
auto sub_graph_precisions = find_all_used_precisions(sub_graph);
used_precisions.insert(sub_graph_precisions.begin(),
sub_graph_precisions.end());
}
}
});

return used_precisions;
}

} // namespace

NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertPrecision, "ConvertPrecision", 0);

bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Function> f)
@@ -161,117 +349,25 @@ bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Fun
{opset4::Select::type_info, extend_select_type},
};

// As Constant operations can be shared between multiple nGraph Functions, before
// changing precision we need to understand which Constant consumers belong
// to the current nGraph Function
std::map<const std::shared_ptr<ngraph::Node>, std::vector<Input<Node>>>
const_to_internal_output;
bool is_changed = false;

std::function<void(const std::shared_ptr<Function>&)> register_constants =
[&const_to_internal_output](const std::shared_ptr<Function>& f) {
for (auto& node : f->get_ordered_ops())
{
for (auto& input : node->inputs())
{
if (auto const_node = std::dynamic_pointer_cast<opset4::Constant>(
input.get_source_output().get_node_shared_ptr()))
{
const_to_internal_output[const_node].emplace_back(input);
}
}
}
};
auto const used_precisions = find_all_used_precisions(f);

auto convert_node_output_precision = [this, &const_to_internal_output, &type_to_fuse](
const std::shared_ptr<ngraph::Node>& node) {
for (auto output : node->outputs())
for (auto const& p : m_precisions)
{
if (output.get_element_type() == m_from)
{
// Handle case with Constants as they can have consumers from other nGraph Function
// object
if (ngraph::op::is_constant(node) && const_to_internal_output.count(node))
{
fuse_type_to_constant(node, m_to, const_to_internal_output.at(node));
break;
if (used_precisions.count(p.first))
is_changed =
is_changed |
convert_precision(*this, f, type_to_fuse, type_to_extend, p.first, p.second);
}

// Check that node type exists in map and we can fuse type into node
if (type_to_fuse.count(node->get_type_info()) &&
type_to_fuse.at(node->get_type_info())(node, m_to, output.get_index()))
{
// We need to break if original node was replaced
break;
}
}
}
};
(void)is_changed; // ignored

auto convert_node_input_precision = [this](const std::shared_ptr<ngraph::Node>& node) {
for (auto input : node->inputs())
{
if (input.get_element_type() == m_from)
{
// For some operations we need to extend their input types to support new type
if (type_to_extend.count(node->get_type_info()) &&
type_to_extend.at(node->get_type_info())(node, m_to, input.get_index()))
{
break;
}
}
}
};

std::function<void(const std::shared_ptr<Function>&)> convert_function_precision =
[this,
&register_constants,
&convert_node_output_precision,
&convert_node_input_precision,
&convert_function_precision](const std::shared_ptr<Function>& f) {
// Iterate over all nodes in topological order and then iterate over node outputs.
// If the output type mismatches the given type we try to fuse the type into this operation,
// otherwise we insert a Convert operation.
for (auto& node : f->get_ordered_ops())
{
transformation_callback(node);
// Recursively apply transformation for sub-graph based operations
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
{
if (auto sub_graph = sub_graph_node->get_function())
{
convert_function_precision(sub_graph);
}
}
convert_node_input_precision(node);
}
// Register internal constants only after fixing input type that could lead to nodes
// replacement
register_constants(f);

for (auto& node : f->get_ordered_ops())
{
convert_node_output_precision(node);
}
};

convert_function_precision(f);
f->validate_nodes_and_infer_types();

// TODO: we need to split NopElimination pass to separate MatcherPasses and call Convert
// elimination here
for (auto& node : f->get_ordered_ops())
{
if (auto convert = std::dynamic_pointer_cast<opset4::Convert>(node))
{
// WA for topK, don't remove fake convert
if (convert->input(0).get_element_type() == convert->get_convert_element_type() &&
convert->input_value(0).get_node_shared_ptr()->get_output_size() == 1)
{
replace_output_update_name(convert->output(0), convert->input_value(0));
}
}
}
return true;
// The returned value is false because pass::Manager always applies the Validation pass
// if the function was changed. This helps to avoid excess Validations after applying
// this pass. In the future, when we return a more meaningful status code, it will be
// replaced with the real status reported by the manager.run_passes() method call.
return false;
}

bool fuse_type_to_shapeof(const std::shared_ptr<ngraph::Node>& node, element::Type to, size_t idx)
@@ -532,15 +628,6 @@ namespace
return new_constant;
}

struct EnumClassHash
{
template <class T>
std::size_t operator()(T t) const
{
return static_cast<size_t>(t);
}
};

/**
 * @brief Method converts low precision integer types
 * The method uses the next logic for conversion:
@@ -631,8 +718,8 @@ namespace
element::Type to)
{
// Supported integer precisions
static const std::unordered_set<element::Type_t, EnumClassHash>
supported_integer_precisions = {element::i4, element::u4, element::u1};
static const precisions_set_t supported_integer_precisions = {
element::i4, element::u4, element::u1};
// Get source element type and source data
auto src_type = constant->get_element_type();
const auto* src_data = reinterpret_cast<const uint8_t*>(constant->get_data_ptr());
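The run_on_function rewrite above hinges on find_all_used_precisions: a conversion pair is only applied when its source type actually occurs somewhere in the function, including sub-graph bodies. A standalone sketch of that idea, using the same traverse_nodes pattern (the hash functor is kept for older standard libraries, as in the commit; this is an editor's illustration, not the commit's code):

```cpp
#include <memory>
#include <unordered_set>
#include <ngraph/function.hpp>
#include <ngraph/graph_util.hpp>

struct EnumClassHash {
    template <class T>
    std::size_t operator()(T t) const { return static_cast<std::size_t>(t); }
};

// Collect every element type that appears on some node output of the function.
std::unordered_set<ngraph::element::Type_t, EnumClassHash>
collect_used_types(const std::shared_ptr<ngraph::Function>& f) {
    std::unordered_set<ngraph::element::Type_t, EnumClassHash> used;
    ngraph::traverse_nodes(f, [&](std::shared_ptr<ngraph::Node> node) {
        for (auto output : node->outputs())
            used.emplace(output.get_element_type());
    });
    return used;
}
```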
@@ -439,3 +439,22 @@ NGRAPH_TEST(${BACKEND_NAME}, convert_u8_to_u64)

ConvertTest(input, input_shape, input_type, expected_output, expected_output_type);
}

NGRAPH_TEST(${BACKEND_NAME}, convert_float32_int8)
{
std::vector<float> f32vec = {-100.5, -20.5, -15, -10.5, -0.5, 0, 0.5, 10.5, 15, 20.5, 100.5};
std::vector<int8_t> result(f32vec.size());
std::vector<int8_t> i8vec(std::begin(f32vec), std::end(f32vec));
runtime::reference::convert(f32vec.data(), result.data(), f32vec.size());
EXPECT_EQ(result, i8vec);
}

NGRAPH_TEST(${BACKEND_NAME}, convert_fp16_int8)
{
std::vector<float> f32vec = {-100.5, -20.5, -15, -10.5, -0.5, 0, 0.5, 10.5, 15, 20.5, 100.5};
std::vector<float16> f16vec(std::begin(f32vec), std::end(f32vec));
std::vector<int8_t> i8vec(std::begin(f16vec), std::end(f16vec));
std::vector<int8_t> result(i8vec.size());
runtime::reference::convert(f16vec.data(), result.data(), f16vec.size());
EXPECT_EQ(result, i8vec);
}
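One observation about the expected values in these tests (editor's note, not part of the commit): i8vec is built with static_cast semantics, i.e. truncation toward zero, while the vectorized kernels convert through vcvtps2dq, which rounds to the nearest integer (ties to even) under the default rounding mode. The listed inputs appear to be chosen so that both schemes produce the same int8 results:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

int main() {
    const std::vector<float> f32vec = {-100.5f, -20.5f, -15.f, -10.5f, -0.5f, 0.f,
                                       0.5f, 10.5f, 15.f, 20.5f, 100.5f};
    for (float v : f32vec) {
        const auto truncated = static_cast<int8_t>(v);               // scalar fallback behaviour
        const auto rounded = static_cast<int8_t>(std::nearbyint(v)); // vcvtps2dq-style rounding
        assert(truncated == rounded); // holds for these particular values
    }
    return 0;
}
```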