Convert operation optimization for FP16 -> INT8 and FP32 -> INT8 (#5275)
Improvements for ConvertPrecision transformation pass
This commit is contained in:
parent c9672ee9ec
commit d2c9bddff1
@@ -67,7 +67,7 @@ std::shared_ptr<ngraph::Function> TransformNetwork(const std::shared_ptr<const n
|
|||||||
// Example: register CommonOptimizations transformation from transformations library
|
// Example: register CommonOptimizations transformation from transformations library
|
||||||
passManager.register_pass<ngraph::pass::CommonOptimizations>();
|
passManager.register_pass<ngraph::pass::CommonOptimizations>();
|
||||||
// Template plugin handles only FP32 networks
|
// Template plugin handles only FP32 networks
|
||||||
passManager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
passManager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ngraph::element::f16, ngraph::element::f32 }});
|
||||||
// Example: register plugin specific transformation
|
// Example: register plugin specific transformation
|
||||||
passManager.register_pass<ngraph::pass::DecomposeDivideMatcher>();
|
passManager.register_pass<ngraph::pass::DecomposeDivideMatcher>();
|
||||||
passManager.register_pass<ngraph::pass::ReluReluFusionMatcher>();
|
passManager.register_pass<ngraph::pass::ReluReluFusionMatcher>();
|
||||||
|
@@ -175,7 +175,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
|
|||||||
manager.register_pass<ngraph::pass::ConvertNMSToNMSIEInternal>();
|
manager.register_pass<ngraph::pass::ConvertNMSToNMSIEInternal>();
|
||||||
manager.register_pass<ngraph::pass::ConvertGather0D>();
|
manager.register_pass<ngraph::pass::ConvertGather0D>();
|
||||||
|
|
||||||
std::vector<std::pair<ngraph::element::Type, ngraph::element::Type>> convert_precision_list {
|
static const precisions_array convert_precision_list {
|
||||||
{ngraph::element::i64, ngraph::element::i32},
|
{ngraph::element::i64, ngraph::element::i32},
|
||||||
{ngraph::element::u64, ngraph::element::i32},
|
{ngraph::element::u64, ngraph::element::i32},
|
||||||
{ngraph::element::u16, ngraph::element::i32},
|
{ngraph::element::u16, ngraph::element::i32},
|
||||||
@@ -185,9 +185,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
|
|||||||
{ngraph::element::u4, ngraph::element::u8},
|
{ngraph::element::u4, ngraph::element::u8},
|
||||||
};
|
};
|
||||||
|
|
||||||
for (auto& precision : convert_precision_list) {
|
manager.register_pass<ngraph::pass::ConvertPrecision>(convert_precision_list);
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(precision.first, precision.second);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto pass_config = manager.get_pass_config();
|
auto pass_config = manager.get_pass_config();
|
||||||
|
|
||||||
@@ -366,7 +364,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
|
|||||||
// Conversion to FP32 might be needed for quantized models that face any fp16 related issues (e.g. overflow) for non-quantized layers
|
// Conversion to FP32 might be needed for quantized models that face any fp16 related issues (e.g. overflow) for non-quantized layers
|
||||||
// With this key users can work-around such issues
|
// With this key users can work-around such issues
|
||||||
if (!config.enable_fp16_for_quantized_models) {
|
if (!config.enable_fp16_for_quantized_models) {
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::f16, ngraph::element::f32 }});
|
||||||
}
|
}
|
||||||
auto lptPrerequisites = manager.register_pass<ngraph::pass::GraphRewrite>();
|
auto lptPrerequisites = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||||
const std::vector<ngraph::element::Type> supportedTypes = { ngraph::element::i8, ngraph::element::u8 };
|
const std::vector<ngraph::element::Type> supportedTypes = { ngraph::element::i8, ngraph::element::u8 };
|
||||||
|
@@ -13,6 +13,7 @@
|
|||||||
#include <ie_plugin_config.hpp>
|
#include <ie_plugin_config.hpp>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
#include <unordered_set>
|
||||||
#include <ie_system_conf.h>
|
#include <ie_system_conf.h>
|
||||||
#include <nodes/list.hpp>
|
#include <nodes/list.hpp>
|
||||||
#include <ie_ngraph_utils.hpp>
|
#include <ie_ngraph_utils.hpp>
|
||||||
@@ -65,6 +66,7 @@
|
|||||||
#include <ngraph/opsets/opset6.hpp>
|
#include <ngraph/opsets/opset6.hpp>
|
||||||
#include <ngraph/op/util/op_types.hpp>
|
#include <ngraph/op/util/op_types.hpp>
|
||||||
#include <ngraph/pass/manager.hpp>
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
#include <ngraph/graph_util.hpp>
|
||||||
|
|
||||||
#include <transformations/common_optimizations/lin_op_sequence_fusion.hpp>
|
#include <transformations/common_optimizations/lin_op_sequence_fusion.hpp>
|
||||||
|
|
||||||
@@ -122,6 +124,28 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
|
|||||||
std::vector<ngraph::element::Type>{ ngraph::element::i8, ngraph::element::u8, ngraph::element::i4, ngraph::element::u4 });
|
std::vector<ngraph::element::Type>{ ngraph::element::i8, ngraph::element::u8, ngraph::element::i4, ngraph::element::u4 });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto get_convert_precisions = []() {
|
||||||
|
precisions_array array = {
|
||||||
|
{ngraph::element::i64, ngraph::element::i32},
|
||||||
|
{ngraph::element::u64, ngraph::element::i32},
|
||||||
|
{ngraph::element::i16, ngraph::element::i32},
|
||||||
|
{ngraph::element::u16, ngraph::element::i32},
|
||||||
|
{ngraph::element::u32, ngraph::element::i32},
|
||||||
|
{ngraph::element::f64, ngraph::element::f32},
|
||||||
|
{ngraph::element::f16, ngraph::element::f32},
|
||||||
|
{ngraph::element::boolean, ngraph::element::u8},
|
||||||
|
{ngraph::element::i4, ngraph::element::i8},
|
||||||
|
{ngraph::element::u4, ngraph::element::u8}
|
||||||
|
};
|
||||||
|
|
||||||
|
if (!with_cpu_x86_avx512_core())
|
||||||
|
array.push_back({ngraph::element::bf16, ngraph::element::f32});
|
||||||
|
|
||||||
|
return array;
|
||||||
|
};
|
||||||
|
|
||||||
|
static const auto precisions = get_convert_precisions();
|
||||||
|
|
||||||
// WA: ConvertPriorBox must be executed before the 1st ConstantFolding pass
|
// WA: ConvertPriorBox must be executed before the 1st ConstantFolding pass
|
||||||
manager.register_pass<ngraph::pass::CommonOptimizations>();
|
manager.register_pass<ngraph::pass::CommonOptimizations>();
|
||||||
manager.register_pass<ngraph::pass::ConvertRNNSequenceToTensorIterator>();
|
manager.register_pass<ngraph::pass::ConvertRNNSequenceToTensorIterator>();
|
||||||
@@ -140,27 +164,7 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
|
|||||||
manager.register_pass<ngraph::pass::ConvertNMS4ToNMS5>();
|
manager.register_pass<ngraph::pass::ConvertNMS4ToNMS5>();
|
||||||
manager.register_pass<ngraph::pass::ConvertNMSToNMSIEInternal>();
|
manager.register_pass<ngraph::pass::ConvertNMSToNMSIEInternal>();
|
||||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
std::vector<std::pair<ngraph::element::Type, ngraph::element::Type>> convert_precision_list{
|
|
||||||
{ngraph::element::i64, ngraph::element::i32},
|
|
||||||
{ngraph::element::u64, ngraph::element::i32},
|
|
||||||
{ngraph::element::i16, ngraph::element::i32},
|
|
||||||
{ngraph::element::u16, ngraph::element::i32},
|
|
||||||
{ngraph::element::u32, ngraph::element::i32},
|
|
||||||
{ngraph::element::f64, ngraph::element::f32},
|
|
||||||
{ngraph::element::f16, ngraph::element::f32},
|
|
||||||
{ngraph::element::boolean, ngraph::element::u8},
|
|
||||||
{ngraph::element::i4, ngraph::element::i8},
|
|
||||||
{ngraph::element::u4, ngraph::element::u8},
|
|
||||||
};
|
|
||||||
|
|
||||||
// In case BF16 is not supported by the target CPU we explicitly convert it to FP32
|
|
||||||
if (!with_cpu_x86_avx512_core())
|
|
||||||
convert_precision_list.push_back({ngraph::element::bf16, ngraph::element::f32});
|
|
||||||
|
|
||||||
for (auto &precision : convert_precision_list) {
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(precision.first, precision.second);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto pass_config = manager.get_pass_config();
|
auto pass_config = manager.get_pass_config();
|
||||||
|
|
||||||
|
@@ -42,7 +42,7 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr<ngraph::Function> &nGraphF
|
|||||||
manager.register_pass<ReshapeFullyConnectedFusion>();
|
manager.register_pass<ReshapeFullyConnectedFusion>();
|
||||||
}
|
}
|
||||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i64, ngraph::element::i32 }});
|
||||||
manager.run_passes(nGraphFunc);
|
manager.run_passes(nGraphFunc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -7,6 +7,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
#include <transformations_visibility.hpp>
|
#include <transformations_visibility.hpp>
|
||||||
|
|
||||||
@@ -69,18 +70,24 @@ class NGRAPH_API ConvertPrecision;
|
|||||||
* LessEqual
|
* LessEqual
|
||||||
*/
|
*/
|
||||||
|
|
||||||
using type_to_fuse_map = std::map<ngraph::NodeTypeInfo, std::function<bool(const std::shared_ptr<ngraph::Node>&, ngraph::element::Type, size_t idx)>>;
|
using type_to_fuse_map = std::unordered_map<ngraph::NodeTypeInfo, std::function<bool(const std::shared_ptr<ngraph::Node>&, ngraph::element::Type, size_t idx)>>;
|
||||||
|
using precisions_array = std::vector<std::pair<ngraph::element::Type, ngraph::element::Type>>;
|
||||||
|
|
||||||
class ngraph::pass::ConvertPrecision : public ngraph::pass::FunctionPass {
|
class ngraph::pass::ConvertPrecision : public ngraph::pass::FunctionPass {
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
ConvertPrecision(ngraph::element::Type_t from, ngraph::element::Type_t to, type_to_fuse_map additional_type_to_fuse_map = {})
|
ConvertPrecision(ngraph::element::Type_t from, ngraph::element::Type_t to, type_to_fuse_map additional_type_to_fuse_map = {})
|
||||||
: FunctionPass(),
|
: FunctionPass(),
|
||||||
m_from(from),
|
m_precisions(precisions_array {{ from, to }}),
|
||||||
m_to(to),
|
m_additional_type_to_fuse_map(additional_type_to_fuse_map) {}
|
||||||
|
|
||||||
|
ConvertPrecision(const precisions_array& precisions, const type_to_fuse_map & additional_type_to_fuse_map = {})
|
||||||
|
: FunctionPass(),
|
||||||
|
m_precisions(precisions),
|
||||||
m_additional_type_to_fuse_map(additional_type_to_fuse_map) {}
|
m_additional_type_to_fuse_map(additional_type_to_fuse_map) {}
|
||||||
|
|
||||||
bool run_on_function(std::shared_ptr<Function> f) override;
|
bool run_on_function(std::shared_ptr<Function> f) override;
|
||||||
private:
|
private:
|
||||||
element::Type m_from, m_to;
|
precisions_array m_precisions;
|
||||||
type_to_fuse_map m_additional_type_to_fuse_map;
|
type_to_fuse_map m_additional_type_to_fuse_map;
|
||||||
};
|
};
|
||||||
|
@@ -15,10 +15,11 @@
|
|||||||
#include "vpu/ngraph/operations/out_shape_of_reshape.hpp"
|
#include "vpu/ngraph/operations/out_shape_of_reshape.hpp"
|
||||||
#include <stack>
|
#include <stack>
|
||||||
#include <deque>
|
#include <deque>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
namespace vpu {
|
namespace vpu {
|
||||||
|
|
||||||
using typeToFuseMap = std::map<ngraph::NodeTypeInfo, std::function<bool(const std::shared_ptr<ngraph::Node>&, ngraph::element::Type, size_t idx)>>;
|
using typeToFuseMap = std::unordered_map<ngraph::NodeTypeInfo, std::function<bool(const std::shared_ptr<ngraph::Node>&, ngraph::element::Type, size_t idx)>>;
|
||||||
|
|
||||||
bool fuseTypeToStaticShapeNonMaxSuppression(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
|
bool fuseTypeToStaticShapeNonMaxSuppression(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
|
||||||
bool fuseTypeToStaticShapeNonZero(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
|
bool fuseTypeToStaticShapeNonZero(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
|
||||||
|
@@ -192,14 +192,17 @@ ie::CNNNetwork FrontEnd::convertNetwork(ie::CNNNetwork& network) {
|
|||||||
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
|
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
|
||||||
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
|
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
|
||||||
// ConvertPrecision must be executed before ConvertOpSet1ToLegacy due to this pass works with operations from opsets only
|
// ConvertPrecision must be executed before ConvertOpSet1ToLegacy due to this pass works with operations from opsets only
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32, myriadTypeToFuseMap);
|
static const precisions_array precisions = {
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::u64, ngraph::element::i32, myriadTypeToFuseMap);
|
{ ngraph::element::i64, ngraph::element::i32 },
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::u32, ngraph::element::i32, myriadTypeToFuseMap);
|
{ ngraph::element::u64, ngraph::element::i32 },
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::i32, myriadTypeToFuseMap);
|
{ ngraph::element::u32, ngraph::element::i32 },
|
||||||
|
{ ngraph::element::boolean, ngraph::element::i32 }
|
||||||
|
};
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions, myriadTypeToFuseMap);
|
||||||
|
|
||||||
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
|
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
|
||||||
// ConvertOpSet1ToLegacy can produce constants with I64 precision
|
// ConvertOpSet1ToLegacy can produce constants with I64 precision
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32, myriadTypeToFuseMap);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i64, ngraph::element::i32 }}, myriadTypeToFuseMap);
|
||||||
manager.register_pass<vpu::MergeSubsequentDSROperations>();
|
manager.register_pass<vpu::MergeSubsequentDSROperations>();
|
||||||
|
|
||||||
auto pass_config = manager.get_pass_config();
|
auto pass_config = manager.get_pass_config();
|
||||||
|
@@ -174,7 +174,7 @@ TEST(ConvertFunctionToCNNNetworkTests, ConvertTopKWithOneInput) {
|
|||||||
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
|
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
|
||||||
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
|
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
|
||||||
|
|
||||||
std::vector<std::pair<ngraph::element::Type, ngraph::element::Type>> convert_precision_list {
|
static const precisions_array convert_precision_list {
|
||||||
{ngraph::element::i64, ngraph::element::i32},
|
{ngraph::element::i64, ngraph::element::i32},
|
||||||
{ngraph::element::u64, ngraph::element::i32},
|
{ngraph::element::u64, ngraph::element::i32},
|
||||||
{ngraph::element::u16, ngraph::element::i32},
|
{ngraph::element::u16, ngraph::element::i32},
|
||||||
@@ -183,12 +183,9 @@ TEST(ConvertFunctionToCNNNetworkTests, ConvertTopKWithOneInput) {
|
|||||||
{ngraph::element::boolean, ngraph::element::u8},
|
{ngraph::element::boolean, ngraph::element::u8},
|
||||||
};
|
};
|
||||||
|
|
||||||
for (auto & precision : convert_precision_list) {
|
manager.register_pass<ngraph::pass::ConvertPrecision>(convert_precision_list);
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(precision.first, precision.second);
|
|
||||||
}
|
|
||||||
|
|
||||||
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
|
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i64, ngraph::element::i32 }});
|
||||||
|
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
@@ -54,8 +54,13 @@ TEST(TransformationTests, ConvertPrecision_NMS3) {
|
|||||||
f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
|
f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
static const precisions_array precisions = {
|
||||||
|
{ ngraph::element::i64, ngraph::element::i32 },
|
||||||
|
{ ngraph::element::f16, ngraph::element::f32 }
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,8 +82,13 @@ TEST(TransformationTests, ConvertPrecision_NMS4) {
|
|||||||
f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
|
f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
static const precisions_array precisions = {
|
||||||
|
{ ngraph::element::i64, ngraph::element::i32 },
|
||||||
|
{ ngraph::element::f16, ngraph::element::f32 }
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,8 +105,13 @@ TEST(TransformationTests, ConvertPrecision_ShapeOf) {
|
|||||||
f = std::make_shared<Function>(NodeVector{shape_of}, ParameterVector{input});
|
f = std::make_shared<Function>(NodeVector{shape_of}, ParameterVector{input});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
static const precisions_array precisions = {
|
||||||
|
{ ngraph::element::i64, ngraph::element::i32 },
|
||||||
|
{ ngraph::element::f16, ngraph::element::f32 }
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,8 +128,13 @@ TEST(TransformationTests, ConvertPrecision_Convert) {
|
|||||||
f = std::make_shared<Function>(NodeVector{convert}, ParameterVector{input});
|
f = std::make_shared<Function>(NodeVector{convert}, ParameterVector{input});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
static const precisions_array precisions = {
|
||||||
|
{ ngraph::element::i64, ngraph::element::i32 },
|
||||||
|
{ ngraph::element::f16, ngraph::element::f32 }
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,7 +152,7 @@ TEST(TransformationTests, ConvertPrecision_ConvertElimination) {
|
|||||||
f = std::make_shared<Function>(NodeVector{convert}, ParameterVector{input});
|
f = std::make_shared<Function>(NodeVector{convert}, ParameterVector{input});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::f16, ngraph::element::f32 }});
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
|
ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
|
||||||
}
|
}
|
||||||
@@ -158,8 +178,13 @@ TEST(TransformationTests, ConvertPrecision_TopK) {
|
|||||||
f = std::make_shared<Function>(OutputVector{topk->output(0), topk->output(1)}, ParameterVector{input});
|
f = std::make_shared<Function>(OutputVector{topk->output(0), topk->output(1)}, ParameterVector{input});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
static const precisions_array precisions = {
|
||||||
|
{ ngraph::element::i64, ngraph::element::i32 },
|
||||||
|
{ ngraph::element::f16, ngraph::element::f32 }
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,8 +201,13 @@ TEST(TransformationTests, ConvertPrecision_NonZero) {
|
|||||||
f = std::make_shared<Function>(OutputVector{non_zero}, ParameterVector{input});
|
f = std::make_shared<Function>(OutputVector{non_zero}, ParameterVector{input});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
static const precisions_array precisions = {
|
||||||
|
{ ngraph::element::i64, ngraph::element::i32 },
|
||||||
|
{ ngraph::element::f16, ngraph::element::f32 }
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,8 +225,13 @@ TEST(TransformationTests, ConvertPrecision_Bucketize) {
|
|||||||
f = std::make_shared<Function>(OutputVector{b}, ParameterVector{input});
|
f = std::make_shared<Function>(OutputVector{b}, ParameterVector{input});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
static const precisions_array precisions = {
|
||||||
|
{ ngraph::element::i64, ngraph::element::i32 },
|
||||||
|
{ ngraph::element::f16, ngraph::element::f32 }
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -223,8 +258,13 @@ TEST(TransformationTests, ConvertPrecision_Roundings) {
|
|||||||
f = std::make_shared<Function>(OutputVector{ss}, ParameterVector{input});
|
f = std::make_shared<Function>(OutputVector{ss}, ParameterVector{input});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
static const precisions_array precisions = {
|
||||||
|
{ ngraph::element::i64, ngraph::element::i32 },
|
||||||
|
{ ngraph::element::f16, ngraph::element::f32 }
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
|
|
||||||
auto casted_end = std::dynamic_pointer_cast<opset1::Constant>(ss->input_value(2).get_node_shared_ptr());
|
auto casted_end = std::dynamic_pointer_cast<opset1::Constant>(ss->input_value(2).get_node_shared_ptr());
|
||||||
@@ -279,8 +319,13 @@ TEST(TransformationTests, ConvertPrecision_TIBody) {
|
|||||||
ngraph::ParameterVector{X, Y});
|
ngraph::ParameterVector{X, Y});
|
||||||
|
|
||||||
ngraph::pass::Manager manager;
|
ngraph::pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
static const precisions_array precisions = {
|
||||||
|
{ ngraph::element::i64, ngraph::element::i32 },
|
||||||
|
{ ngraph::element::f16, ngraph::element::f32 }
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
|
|
||||||
ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
|
ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
|
||||||
@@ -300,8 +345,13 @@ TEST(TransformationTests, ConvertPrecision_Equal) {
|
|||||||
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
static const precisions_array precisions = {
|
||||||
|
{ ngraph::element::boolean, ngraph::element::u8 },
|
||||||
|
{ ngraph::element::f16, ngraph::element::f32 }
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -320,8 +370,13 @@ TEST(TransformationTests, ConvertPrecision_NotEqual) {
|
|||||||
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
static const precisions_array precisions = {
|
||||||
|
{ ngraph::element::boolean, ngraph::element::u8 },
|
||||||
|
{ ngraph::element::f16, ngraph::element::f32 }
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -340,8 +395,13 @@ TEST(TransformationTests, ConvertPrecision_Greater) {
|
|||||||
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
static const precisions_array precisions = {
|
||||||
|
{ ngraph::element::boolean, ngraph::element::u8 },
|
||||||
|
{ ngraph::element::f16, ngraph::element::f32 }
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -360,8 +420,13 @@ TEST(TransformationTests, ConvertPrecision_GreaterEqual) {
|
|||||||
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
static const precisions_array precisions = {
|
||||||
|
{ ngraph::element::boolean, ngraph::element::u8 },
|
||||||
|
{ ngraph::element::f16, ngraph::element::f32 }
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -380,8 +445,13 @@ TEST(TransformationTests, ConvertPrecision_Less) {
|
|||||||
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
static const precisions_array precisions = {
|
||||||
|
{ ngraph::element::boolean, ngraph::element::u8 },
|
||||||
|
{ ngraph::element::f16, ngraph::element::f32 }
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -400,8 +470,13 @@ TEST(TransformationTests, ConvertPrecision_LessEqual) {
|
|||||||
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
|
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
static const precisions_array precisions = {
|
||||||
|
{ ngraph::element::boolean, ngraph::element::u8 },
|
||||||
|
{ ngraph::element::f16, ngraph::element::f32 }
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -420,7 +495,7 @@ TEST(TransformationTests, ConvertPrecision_LogicalAnd) {
|
|||||||
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::boolean, ngraph::element::u8 }});
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -438,7 +513,7 @@ TEST(TransformationTests, ConvertPrecision_LogicalOr) {
|
|||||||
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::boolean, ngraph::element::u8 }});
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -456,7 +531,7 @@ TEST(TransformationTests, ConvertPrecision_LogicalXor) {
|
|||||||
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::boolean, ngraph::element::u8 }});
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -473,7 +548,7 @@ TEST(TransformationTests, ConvertPrecision_LogicalNot) {
|
|||||||
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1});
|
f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::boolean, ngraph::element::u8 }});
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -491,7 +566,7 @@ TEST(TransformationTests, ConvertPrecision_Select) {
|
|||||||
f = std::make_shared<Function>(OutputVector{select}, ParameterVector{input1});
|
f = std::make_shared<Function>(OutputVector{select}, ParameterVector{input1});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::boolean, ngraph::element::u8 }});
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -509,8 +584,8 @@ TEST(TransformationTests, ConvertPrecision_TypeRelaxedWithSelect) {
|
|||||||
f = std::make_shared<Function>(OutputVector{select}, ParameterVector{input1});
|
f = std::make_shared<Function>(OutputVector{select}, ParameterVector{input1});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::i32);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::boolean, ngraph::element::i32 }});
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i32, ngraph::element::i64);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i32, ngraph::element::i64 }});
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -529,8 +604,8 @@ TEST(TransformationTests, ConvertPrecision_TypeRelaxed) {
|
|||||||
f = std::make_shared<Function>(OutputVector{type_relaxed}, ParameterVector{input1});
|
f = std::make_shared<Function>(OutputVector{type_relaxed}, ParameterVector{input1});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::i32);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::boolean, ngraph::element::i32 }});
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i32, ngraph::element::i64);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i32, ngraph::element::i64 }});
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
|
|
||||||
ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
|
ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
|
||||||
@ -555,7 +630,7 @@ TEST(TransformationTests, ConvertPrecision_Variables) {
|
|||||||
f = std::make_shared<Function>(NodeVector{mul}, ParameterVector{inp});
|
f = std::make_shared<Function>(NodeVector{mul}, ParameterVector{inp});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::f16, ngraph::element::f32 }});
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -573,7 +648,7 @@ void constant_convert_test(element::Type type_from, element::Type type_to, const
|
|||||||
f = std::make_shared<Function>(NodeVector{c}, ParameterVector{});
|
f = std::make_shared<Function>(NodeVector{c}, ParameterVector{});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(type_from, type_to);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ type_from, type_to }});
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
auto ops = f->get_ordered_ops();
|
auto ops = f->get_ordered_ops();
|
||||||
@ -603,7 +678,7 @@ void constant_convert_test(element::Type_t type_from, element::Type_t type_to, F
|
|||||||
f = std::make_shared<Function>(NodeVector{c}, ParameterVector{});
|
f = std::make_shared<Function>(NodeVector{c}, ParameterVector{});
|
||||||
|
|
||||||
pass::Manager manager;
|
pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::ConvertPrecision>(type_from, type_to);
|
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ type_from, type_to }});
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
auto ops = f->get_ordered_ops();
|
auto ops = f->get_ordered_ops();
|
||||||
|
@ -28,6 +28,10 @@ namespace ngraph
|
|||||||
void convert<uint8_t, float16>(const uint8_t* arg, float16* out, size_t count);
|
void convert<uint8_t, float16>(const uint8_t* arg, float16* out, size_t count);
|
||||||
template <>
|
template <>
|
||||||
void convert<float16, float>(const float16* arg, float* out, size_t count);
|
void convert<float16, float>(const float16* arg, float* out, size_t count);
|
||||||
|
template <>
|
||||||
|
void convert<float, int8_t>(const float* arg, int8_t* out, size_t count);
|
||||||
|
template <>
|
||||||
|
void convert<float16, int8_t>(const float16* arg, int8_t* out, size_t count);
|
||||||
|
|
||||||
// overload to handle ngraph::boolean (it is stored as char)
|
// overload to handle ngraph::boolean (it is stored as char)
|
||||||
template <typename TI, typename TO>
|
template <typename TI, typename TO>
|
||||||
@ -39,7 +43,6 @@ namespace ngraph
|
|||||||
out[i] = static_cast<char>(static_cast<bool>(arg[i]));
|
out[i] = static_cast<char>(static_cast<bool>(arg[i]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace reference
|
} // namespace reference
|
||||||
|
|
||||||
} // namespace runtime
|
} // namespace runtime
|
||||||
|
@ -16,6 +16,11 @@ namespace ngraph
|
|||||||
template <typename src_t, typename dst_t>
|
template <typename src_t, typename dst_t>
|
||||||
void jit_convert_vec(jit::Generator&, const Xbyak::RegExp&, const Xbyak::RegExp&);
|
void jit_convert_vec(jit::Generator&, const Xbyak::RegExp&, const Xbyak::RegExp&);
|
||||||
|
|
||||||
|
template <typename src_t, typename dst_t>
|
||||||
|
void jit_convert_vec_prepare(jit::Generator&)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void jit_convert_vec<uint8_t, float16>(jit::Generator& gen,
|
void jit_convert_vec<uint8_t, float16>(jit::Generator& gen,
|
||||||
const Xbyak::RegExp& src,
|
const Xbyak::RegExp& src,
|
||||||
@ -47,6 +52,61 @@ namespace ngraph
|
|||||||
gen.vmovups(gen.yword[dst], f32vec);
|
gen.vmovups(gen.yword[dst], f32vec);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void jit_convert_vec_prepare<float, int8_t>(jit::Generator& gen)
|
||||||
|
{
|
||||||
|
auto order = gen.ymm1;
|
||||||
|
auto addr = gen.r15;
|
||||||
|
|
||||||
|
static const int8_t offsets[32] = {0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1,
|
||||||
|
-1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 4,
|
||||||
|
8, 12, -1, -1, -1, -1, -1, -1, -1, -1};
|
||||||
|
|
||||||
|
gen.mov(addr, (size_t)offsets); // get offsets[] address
|
||||||
|
gen.vmovdqu(order, gen.yword[addr]); // save offsets[] to ymm register
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void jit_convert_vec<float, int8_t>(jit::Generator& gen,
|
||||||
|
const Xbyak::RegExp& src,
|
||||||
|
const Xbyak::RegExp& dst)
|
||||||
|
{
|
||||||
|
auto order = gen.ymm1;
|
||||||
|
auto p32vec = gen.ymm2;
|
||||||
|
auto p32vec_lo = gen.xmm2;
|
||||||
|
auto p32vec_hi = gen.xmm3;
|
||||||
|
|
||||||
|
gen.vcvtps2dq(p32vec, gen.yword[src]); // convert 8 floats to 8 ints
|
||||||
|
gen.vpshufb(p32vec, p32vec, order); // Shuffle the bytes according to the order
|
||||||
|
gen.vextracti128(p32vec_hi, p32vec, 1); // extract upper part of p32vec
|
||||||
|
gen.vpor(p32vec_lo, p32vec_lo, p32vec_hi); // p32vec_lo = p32vec_lo | p32vec_hi
|
||||||
|
gen.movq(gen.qword[dst], p32vec_lo); // save the result
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void jit_convert_vec_prepare<float16, int8_t>(jit::Generator& gen)
|
||||||
|
{
|
||||||
|
jit_convert_vec_prepare<float, int8_t>(gen);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void jit_convert_vec<float16, int8_t>(jit::Generator& gen,
|
||||||
|
const Xbyak::RegExp& src,
|
||||||
|
const Xbyak::RegExp& dst)
|
||||||
|
{
|
||||||
|
auto order = gen.ymm1;
|
||||||
|
auto p32vec = gen.ymm2;
|
||||||
|
auto p32vec_lo = gen.xmm2;
|
||||||
|
auto p32vec_hi = gen.xmm3;
|
||||||
|
|
||||||
|
gen.vcvtph2ps(p32vec, gen.xword[src]); // convert 8 fp16's to 8 floats
|
||||||
|
gen.vcvtps2dq(p32vec, p32vec); // convert 8 floats to 8 ints
|
||||||
|
gen.vpshufb(p32vec, p32vec, order); // Shuffle the bytes according to the order
|
||||||
|
gen.vextracti128(p32vec_hi, p32vec, 1); // extract upper part of p32vec
|
||||||
|
gen.vpor(p32vec_lo, p32vec_lo, p32vec_hi); // p32vec_lo = p32vec_lo | p32vec_hi
|
||||||
|
gen.movq(gen.qword[dst], p32vec_lo); // save the result
|
||||||
|
}
|
||||||
|
|
||||||
class jit_convert_array : public jit::Generator
|
class jit_convert_array : public jit::Generator
|
||||||
{
|
{
|
||||||
typedef struct context
|
typedef struct context
|
||||||
@ -61,6 +121,7 @@ namespace ngraph
|
|||||||
void (*convert_vec)(jit::Generator&,
|
void (*convert_vec)(jit::Generator&,
|
||||||
const Xbyak::RegExp&,
|
const Xbyak::RegExp&,
|
||||||
const Xbyak::RegExp&);
|
const Xbyak::RegExp&);
|
||||||
|
void (*prepare)(jit::Generator&);
|
||||||
} context_t;
|
} context_t;
|
||||||
|
|
||||||
jit_convert_array(const context_t& ctx)
|
jit_convert_array(const context_t& ctx)
|
||||||
@ -77,6 +138,8 @@ namespace ngraph
|
|||||||
|
|
||||||
preamble();
|
preamble();
|
||||||
|
|
||||||
|
ctx.prepare(*this);
|
||||||
|
|
||||||
mov(reg_src, ptr[param + offsetof(args_t, src)]);
|
mov(reg_src, ptr[param + offsetof(args_t, src)]);
|
||||||
mov(reg_dst, ptr[param + offsetof(args_t, out)]);
|
mov(reg_dst, ptr[param + offsetof(args_t, out)]);
|
||||||
mov(reg_sz, ptr[param + offsetof(args_t, count)]);
|
mov(reg_sz, ptr[param + offsetof(args_t, count)]);
|
||||||
@ -137,7 +200,8 @@ namespace ngraph
|
|||||||
static const jit_convert_array::context_t context{
|
static const jit_convert_array::context_t context{
|
||||||
{sizeof(src_t), &jit::Generator::copy<src_t>},
|
{sizeof(src_t), &jit::Generator::copy<src_t>},
|
||||||
{sizeof(dst_t), &jit::Generator::copy<dst_t>},
|
{sizeof(dst_t), &jit::Generator::copy<dst_t>},
|
||||||
jit_convert_vec<src_t, dst_t>};
|
jit_convert_vec<src_t, dst_t>,
|
||||||
|
jit_convert_vec_prepare<src_t, dst_t>};
|
||||||
|
|
||||||
static jit_convert_array generator(context);
|
static jit_convert_array generator(context);
|
||||||
|
|
||||||
@ -146,44 +210,49 @@ namespace ngraph
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename TI, typename TO>
|
||||||
|
void convert_impl(const TI* arg, TO* out, size_t count)
|
||||||
|
{
|
||||||
|
auto converter = jit_convert_array::get<TI, TO>();
|
||||||
|
|
||||||
|
if (converter)
|
||||||
|
{
|
||||||
|
jit_convert_array::args_t args = {arg, out, count};
|
||||||
|
converter(&args);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < count; ++i)
|
||||||
|
{
|
||||||
|
out[i] = static_cast<TO>(arg[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void convert<uint8_t, float16>(const uint8_t* arg, float16* out, size_t count)
|
void convert<uint8_t, float16>(const uint8_t* arg, float16* out, size_t count)
|
||||||
{
|
{
|
||||||
auto converter = jit_convert_array::get<uint8_t, float16>();
|
convert_impl(arg, out, count);
|
||||||
|
|
||||||
if (converter)
|
|
||||||
{
|
|
||||||
jit_convert_array::args_t args = {arg, out, count};
|
|
||||||
converter(&args);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
for (size_t i = 0; i < count; ++i)
|
|
||||||
{
|
|
||||||
out[i] = static_cast<float16>(arg[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void convert<float16, float>(const float16* arg, float* out, size_t count)
|
void convert<float16, float>(const float16* arg, float* out, size_t count)
|
||||||
{
|
{
|
||||||
auto converter = jit_convert_array::get<float16, float>();
|
convert_impl(arg, out, count);
|
||||||
|
}
|
||||||
|
|
||||||
if (converter)
|
template <>
|
||||||
{
|
void convert<float, int8_t>(const float* arg, int8_t* out, size_t count)
|
||||||
jit_convert_array::args_t args = {arg, out, count};
|
{
|
||||||
converter(&args);
|
convert_impl(arg, out, count);
|
||||||
}
|
}
|
||||||
else
|
|
||||||
{
|
template <>
|
||||||
for (size_t i = 0; i < count; ++i)
|
void convert<float16, int8_t>(const float16* arg, int8_t* out, size_t count)
|
||||||
{
|
{
|
||||||
out[i] = static_cast<float>(arg[i]);
|
convert_impl(arg, out, count);
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} // namespace reference
|
} // namespace reference
|
||||||
} // namespace runtime
|
} // namespace runtime
|
||||||
|
@ -158,6 +158,26 @@ namespace ngraph
|
|||||||
pop(rsi);
|
pop(rsi);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void Generator::copy<int8_t>(const Xbyak::Reg64& dst,
|
||||||
|
const Xbyak::Reg64& src,
|
||||||
|
const Xbyak::Reg64& size)
|
||||||
|
{
|
||||||
|
push(rsi);
|
||||||
|
push(r15);
|
||||||
|
|
||||||
|
xor_(rsi, rsi);
|
||||||
|
|
||||||
|
foreach (rsi, 1, size, [&, this](const Xbyak::Reg64& idx) {
|
||||||
|
mov(r15b, byte[src + idx * sizeof(int8_t)]);
|
||||||
|
mov(byte[dst + idx * sizeof(int8_t)], r15b);
|
||||||
|
})
|
||||||
|
;
|
||||||
|
|
||||||
|
pop(r15);
|
||||||
|
pop(rsi);
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void Generator::copy<uint16_t>(const Xbyak::Reg64& dst,
|
void Generator::copy<uint16_t>(const Xbyak::Reg64& dst,
|
||||||
const Xbyak::Reg64& src,
|
const Xbyak::Reg64& src,
|
||||||
|
@ -15,7 +15,8 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertFP32ToFP16, "ConvertFP32ToFP16", 0);
|
|||||||
bool ngraph::pass::ConvertFP32ToFP16::run_on_function(std::shared_ptr<ngraph::Function> f)
|
bool ngraph::pass::ConvertFP32ToFP16::run_on_function(std::shared_ptr<ngraph::Function> f)
|
||||||
{
|
{
|
||||||
ngraph::pass::Manager m(get_pass_config());
|
ngraph::pass::Manager m(get_pass_config());
|
||||||
m.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f32, ngraph::element::f16);
|
m.register_pass<ngraph::pass::ConvertPrecision>(
|
||||||
|
precisions_array{{ngraph::element::f32, ngraph::element::f16}});
|
||||||
m.run_passes(f);
|
m.run_passes(f);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -125,6 +125,194 @@ bool fuse_type_to_reduce_logical(const std::shared_ptr<ngraph::Node>& node,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
void validate_nodes_and_infer_types(const std::vector<std::shared_ptr<Node>>& ops)
|
||||||
|
{
|
||||||
|
for (auto& node : ops)
|
||||||
|
{
|
||||||
|
node->revalidate_and_infer_types();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rewrites function `f` (and, recursively, the bodies of its SubGraphOp
/// nodes) so that element type `from` is replaced by `to`.
///
/// @param pass            Pass whose transformation_callback is invoked per node.
/// @param f               Function to transform.
/// @param type_to_fuse    Per-node-type handlers that change an output type.
/// @param type_to_extend  Per-node-type handlers that extend an input type.
/// @param from            Element type to replace.
/// @param to              Element type to use instead.
/// @return true if any handler reported a change.
bool convert_precision(pass::PassBase& pass,
                       const std::shared_ptr<ngraph::Function>& f,
                       const type_to_fuse_map& type_to_fuse,
                       const type_to_fuse_map& type_to_extend,
                       element::Type from,
                       element::Type to)
{
    // Constants may be shared between several nGraph Functions, so before
    // changing precision we record which Constant consumers belong to the
    // Function being processed.
    std::unordered_map<const ngraph::Node*, std::vector<Input<Node>>> const_to_internal_output;

    // Collects, for every Constant feeding one of `ops`, the consuming inputs.
    auto register_constants =
        [&const_to_internal_output](const std::vector<std::shared_ptr<Node>>& ops) {
            for (auto& consumer : ops)
            {
                for (auto& input : consumer->inputs())
                {
                    auto producer = input.get_source_output().get_node_shared_ptr();
                    auto const_node = std::dynamic_pointer_cast<opset4::Constant>(producer);
                    if (const_node)
                    {
                        const_to_internal_output[const_node.get()].emplace_back(input);
                    }
                }
            }
        };

    // Tries to change the first matching `from` output of `node` to `to`.
    auto convert_node_output_precision = [&](const std::shared_ptr<ngraph::Node>& node) {
        for (auto output : node->outputs())
        {
            if (output.get_element_type() != from)
            {
                continue;
            }

            // Registered Constants are handled separately because they can
            // have consumers in other nGraph Function objects.
            auto const_it = const_to_internal_output.find(node.get());
            if (const_it != const_to_internal_output.end())
            {
                return fuse_type_to_constant(node, to, const_it->second);
            }

            // Fuse the type into the node if a handler exists for its type.
            auto fuse_it = type_to_fuse.find(node->get_type_info());
            if (fuse_it != type_to_fuse.end() && fuse_it->second(node, to, output.get_index()))
            {
                // Stop here: the original node may have been replaced.
                return true;
            }
        }
        return false;
    };

    // Tries to extend the first matching `from` input of `node` to accept `to`.
    auto convert_node_input_precision = [&](const std::shared_ptr<ngraph::Node>& node) {
        for (auto input : node->inputs())
        {
            if (input.get_element_type() != from)
            {
                continue;
            }

            // Some operations need their input types extended to support the new type.
            auto extend_it = type_to_extend.find(node->get_type_info());
            if (extend_it != type_to_extend.end() &&
                extend_it->second(node, to, input.get_index()))
            {
                return true;
            }
        }
        return false;
    };

    // Declared as std::function so the lambda can call itself for sub-graphs.
    std::function<bool(const std::shared_ptr<Function>&, bool)> convert_function_precision =
        [&](const std::shared_ptr<Function>& func, bool is_subgraph) {
            bool is_changed = false;

            auto ops = func->get_ordered_ops();

            // First sweep, in topological order: run the callback, descend
            // into sub-graph bodies, then fix up input types.
            for (auto& node : ops)
            {
                pass.transformation_callback(node);
                // Recursively apply transformation for sub-graph based operations
                if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
                {
                    if (auto sub_graph = sub_graph_node->get_function())
                    {
                        is_changed |= convert_function_precision(sub_graph, true);
                    }
                }
                is_changed |= convert_node_input_precision(node);
            }

            // Refresh the node list if the first sweep changed anything.
            if (is_changed)
            {
                ops = func->get_ordered_ops();
            }

            // Register internal constants only after input types are fixed,
            // since that step can lead to node replacement.
            register_constants(ops);

            // Second sweep: change output types.
            bool is_output_precision_changed = false;
            for (auto& node : ops)
            {
                is_output_precision_changed |= convert_node_output_precision(node);
            }

            if (is_output_precision_changed)
            {
                ops = func->get_ordered_ops();
                is_changed |= is_output_precision_changed;
            }

            // Validation and Convert clean-up run for the top-level function only.
            if (!is_subgraph)
            {
                if (is_changed)
                {
                    validate_nodes_and_infer_types(ops);
                }

                // TODO: we need to split NopElimination pass to separate MatcherPasses and call
                // Convert elimination here
                for (auto& node : ops)
                {
                    auto convert = std::dynamic_pointer_cast<opset4::Convert>(node);
                    if (!convert)
                    {
                        continue;
                    }

                    // WA for topK, dont remove fake convert
                    if (convert->input(0).get_element_type() ==
                            convert->get_convert_element_type() &&
                        convert->input_value(0).get_node_shared_ptr()->get_output_size() == 1)
                    {
                        replace_output_update_name(convert->output(0), convert->input_value(0));
                    }
                }
            }

            return is_changed;
        };

    return convert_function_precision(f, false);
}
|
||||||
|
|
||||||
|
/// Hash functor for enumeration keys: the hash is the enumerator's value
/// converted to std::size_t.
struct EnumClassHash
{
    /// @param value  Enumerator (or any value castable to std::size_t).
    /// @return `value` converted with static_cast.
    template <typename EnumT>
    std::size_t operator()(EnumT value) const
    {
        return static_cast<std::size_t>(value);
    }
};
|
||||||
|
|
||||||
|
using precisions_set_t = std::unordered_set<ngraph::element::Type_t, EnumClassHash>;
|
||||||
|
|
||||||
|
/// Collects the element types of all node outputs reachable in `fn`,
/// including the bodies of SubGraphOp nodes (visited recursively).
///
/// @param fn  Function to inspect.
/// @return Set of element types found on node outputs.
precisions_set_t find_all_used_precisions(const std::shared_ptr<ngraph::Function>& fn)
{
    precisions_set_t used_precisions;

    ngraph::traverse_nodes(fn, [&](std::shared_ptr<ngraph::Node> node) {
        // Record the type of every output of this node.
        for (const auto& output : node->outputs())
        {
            used_precisions.emplace(output.get_element_type());
        }

        // Only sub-graph operations with a body need further work.
        auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node);
        if (!sub_graph_node)
        {
            return;
        }
        auto sub_graph = sub_graph_node->get_function();
        if (!sub_graph)
        {
            return;
        }

        // Merge in the precisions used inside the body.
        auto nested_precisions = find_all_used_precisions(sub_graph);
        used_precisions.insert(nested_precisions.begin(), nested_precisions.end());
    });

    return used_precisions;
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertPrecision, "ConvertPrecision", 0);
|
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertPrecision, "ConvertPrecision", 0);
|
||||||
|
|
||||||
bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Function> f)
|
bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Function> f)
|
||||||
@ -161,117 +349,25 @@ bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Fun
|
|||||||
{opset4::Select::type_info, extend_select_type},
|
{opset4::Select::type_info, extend_select_type},
|
||||||
};
|
};
|
||||||
|
|
||||||
// As Constant operations can be shared between multiple nGraph Functions so before
|
bool is_changed = false;
|
||||||
// changing precision we need to understand which Constant consumers belongs
|
|
||||||
// to the current nGraph Function
|
|
||||||
std::map<const std::shared_ptr<ngraph::Node>, std::vector<Input<Node>>>
|
|
||||||
const_to_internal_output;
|
|
||||||
|
|
||||||
std::function<void(const std::shared_ptr<Function>&)> register_constants =
|
auto const used_precisions = find_all_used_precisions(f);
|
||||||
[&const_to_internal_output](const std::shared_ptr<Function>& f) {
|
|
||||||
for (auto& node : f->get_ordered_ops())
|
|
||||||
{
|
|
||||||
for (auto& input : node->inputs())
|
|
||||||
{
|
|
||||||
if (auto const_node = std::dynamic_pointer_cast<opset4::Constant>(
|
|
||||||
input.get_source_output().get_node_shared_ptr()))
|
|
||||||
{
|
|
||||||
const_to_internal_output[const_node].emplace_back(input);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
auto convert_node_output_precision = [this, &const_to_internal_output, &type_to_fuse](
|
for (auto const& p : m_precisions)
|
||||||
const std::shared_ptr<ngraph::Node>& node) {
|
|
||||||
for (auto output : node->outputs())
|
|
||||||
{
|
|
||||||
if (output.get_element_type() == m_from)
|
|
||||||
{
|
|
||||||
// Handle case with Constants as they can have consumers from other nGraph Function
|
|
||||||
// object
|
|
||||||
if (ngraph::op::is_constant(node) && const_to_internal_output.count(node))
|
|
||||||
{
|
|
||||||
fuse_type_to_constant(node, m_to, const_to_internal_output.at(node));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that node type exists in map and we can fuse type into node
|
|
||||||
if (type_to_fuse.count(node->get_type_info()) &&
|
|
||||||
type_to_fuse.at(node->get_type_info())(node, m_to, output.get_index()))
|
|
||||||
{
|
|
||||||
// We need to break if original node was replaced
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
auto convert_node_input_precision = [this](const std::shared_ptr<ngraph::Node>& node) {
|
|
||||||
for (auto input : node->inputs())
|
|
||||||
{
|
|
||||||
if (input.get_element_type() == m_from)
|
|
||||||
{
|
|
||||||
// For some operations we need to extend their input types to support new type
|
|
||||||
if (type_to_extend.count(node->get_type_info()) &&
|
|
||||||
type_to_extend.at(node->get_type_info())(node, m_to, input.get_index()))
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
std::function<void(const std::shared_ptr<Function>&)> convert_function_precision =
|
|
||||||
[this,
|
|
||||||
&register_constants,
|
|
||||||
&convert_node_output_precision,
|
|
||||||
&convert_node_input_precision,
|
|
||||||
&convert_function_precision](const std::shared_ptr<Function>& f) {
|
|
||||||
// Iterate over all nodes in topological order and then iterate over node outputs.
|
|
||||||
// If output type mismatch given type we try to fuse type into this operation
|
|
||||||
// otherwise we insert Convert operation.
|
|
||||||
for (auto& node : f->get_ordered_ops())
|
|
||||||
{
|
|
||||||
transformation_callback(node);
|
|
||||||
// Recursively apply transformation for sub-graph based operations
|
|
||||||
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
|
|
||||||
{
|
|
||||||
if (auto sub_graph = sub_graph_node->get_function())
|
|
||||||
{
|
|
||||||
convert_function_precision(sub_graph);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
convert_node_input_precision(node);
|
|
||||||
}
|
|
||||||
// Register internal constants only after fixing input type that could lead to nodes
|
|
||||||
// replacement
|
|
||||||
register_constants(f);
|
|
||||||
|
|
||||||
for (auto& node : f->get_ordered_ops())
|
|
||||||
{
|
|
||||||
convert_node_output_precision(node);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
convert_function_precision(f);
|
|
||||||
f->validate_nodes_and_infer_types();
|
|
||||||
|
|
||||||
// TODO: we need to split NopElimination pass to separate MatcherPasses and call Convert
|
|
||||||
// elimination here
|
|
||||||
for (auto& node : f->get_ordered_ops())
|
|
||||||
{
|
{
|
||||||
if (auto convert = std::dynamic_pointer_cast<opset4::Convert>(node))
|
if (used_precisions.count(p.first))
|
||||||
{
|
is_changed =
|
||||||
// WA for topK, dont remove fake convert
|
is_changed |
|
||||||
if (convert->input(0).get_element_type() == convert->get_convert_element_type() &&
|
convert_precision(*this, f, type_to_fuse, type_to_extend, p.first, p.second);
|
||||||
convert->input_value(0).get_node_shared_ptr()->get_output_size() == 1)
|
|
||||||
{
|
|
||||||
replace_output_update_name(convert->output(0), convert->input_value(0));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return true;
|
|
||||||
|
(void)is_changed; // ignored
|
||||||
|
|
||||||
|
// Returning value is false because pass::Manager always apply Validation pass
|
||||||
|
// if function was changed. This helps to avoid excess Validations after applying
|
||||||
|
// this pass. In future when we will return more meaningful status code it will be
|
||||||
|
// replaced with real status reported by manager.run_passes() method call.
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool fuse_type_to_shapeof(const std::shared_ptr<ngraph::Node>& node, element::Type to, size_t idx)
|
bool fuse_type_to_shapeof(const std::shared_ptr<ngraph::Node>& node, element::Type to, size_t idx)
|
||||||
@ -532,15 +628,6 @@ namespace
|
|||||||
return new_constant;
|
return new_constant;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct EnumClassHash
|
|
||||||
{
|
|
||||||
template <class T>
|
|
||||||
std::size_t operator()(T t) const
|
|
||||||
{
|
|
||||||
return static_cast<size_t>(t);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Method converts low precision integer types
|
* @brief Method converts low precision integer types
|
||||||
* The method uses the next logic for conversion:
|
* The method uses the next logic for conversion:
|
||||||
@ -631,8 +718,8 @@ namespace
|
|||||||
element::Type to)
|
element::Type to)
|
||||||
{
|
{
|
||||||
// Supported integer precisions
|
// Supported integer precisions
|
||||||
static const std::unordered_set<element::Type_t, EnumClassHash>
|
static const precisions_set_t supported_integer_precisions = {
|
||||||
supported_integer_precisions = {element::i4, element::u4, element::u1};
|
element::i4, element::u4, element::u1};
|
||||||
// Get source element type and source data
|
// Get source element type and source data
|
||||||
auto src_type = constant->get_element_type();
|
auto src_type = constant->get_element_type();
|
||||||
const auto* src_data = reinterpret_cast<const uint8_t*>(constant->get_data_ptr());
|
const auto* src_data = reinterpret_cast<const uint8_t*>(constant->get_data_ptr());
|
||||||
|
@ -439,3 +439,22 @@ NGRAPH_TEST(${BACKEND_NAME}, convert_u8_to_u64)
|
|||||||
|
|
||||||
ConvertTest(input, input_shape, input_type, expected_output, expected_output_type);
|
ConvertTest(input, input_shape, input_type, expected_output, expected_output_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks reference FP32 -> INT8 conversion against the result of converting
// each element with the language's built-in float -> int8_t conversion.
NGRAPH_TEST(${BACKEND_NAME}, convert_float32_int8)
{
    std::vector<float> f32vec = {-100.5, -20.5, -15, -10.5, -0.5, 0, 0.5, 10.5, 15, 20.5, 100.5};

    // Expected values: element-wise conversion performed by the vector constructor.
    std::vector<int8_t> i8vec(f32vec.begin(), f32vec.end());

    // Actual values produced by the reference implementation.
    std::vector<int8_t> result(f32vec.size());
    runtime::reference::convert(f32vec.data(), result.data(), f32vec.size());

    EXPECT_EQ(result, i8vec);
}
|
||||||
|
|
||||||
|
// Checks reference FP16 -> INT8 conversion against the result of converting
// each float16 element with its own conversion to int8_t.
NGRAPH_TEST(${BACKEND_NAME}, convert_fp16_int8)
{
    std::vector<float> f32vec = {-100.5, -20.5, -15, -10.5, -0.5, 0, 0.5, 10.5, 15, 20.5, 100.5};

    // Build the FP16 input from the FP32 literals.
    std::vector<float16> f16vec(f32vec.begin(), f32vec.end());

    // Expected values: element-wise conversion performed by the vector constructor.
    std::vector<int8_t> i8vec(f16vec.begin(), f16vec.end());

    // Actual values produced by the reference implementation.
    std::vector<int8_t> result(i8vec.size());
    runtime::reference::convert(f16vec.data(), result.data(), f16vec.size());

    EXPECT_EQ(result, i8vec);
}
|
||||||
|
Loading…
Reference in New Issue
Block a user