Reduce the number of validate and infer types in ConvertPrecision (#15277)
* Reduce the number of validate and infer types in ConvertPrecision

Currently, the ConvertPrecision pass runs validate and infer types frequently: it iterates over every precision pair and, for each pair, over the whole model, followed by validate and infer types. The proposed solution is to iterate over the model once: for each node, iterate over the precisions, update the node if required, and then run validate and infer types on that node only.

Ticket: 81311

* use map
* clang format
* move enum hasher
* fix gpu
* revalidate
* reinvalidate if node has changed
* remove validate for input prec changes
* fix gpu
* review
* find
* fix pytorch case
* revalidate

---------

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
Parent: 7578c636b9
Commit: 8477bc8897
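The core of the change, sketched in simplified form below. The loop and try_fuse are illustrative placeholders, not the committed code: instead of one full-model traversal plus whole-model revalidation per (from, to) pair, the pass makes a single traversal in which each node looks its type up in a hash map, and only nodes that actually changed are revalidated.

// Illustrative sketch only; 'model' and 'try_fuse' are placeholders.
// Before: for each (from, to) pair, traverse the whole model, then run
// validate and infer types over the entire model.
// After: traverse once; each node consults the precisions map directly.
for (const auto& node : model->get_ordered_ops()) {
    bool node_changed = try_fuse(node, precisions);  // hash-map lookup by element type
    if (node_changed)
        node->revalidate_and_infer_types();  // per node, only when something changed
}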
@@ -64,10 +64,16 @@ class TRANSFORMATIONS_API ConvertPrecision;
* LessEqual
*/

struct EnumClassHash {
template <class T>
std::size_t operator()(T t) const {
return static_cast<size_t>(t);
}
};

using precisions_map = std::unordered_map<ov::element::Type_t, ov::element::Type, EnumClassHash>;
using type_to_fuse_map =
std::unordered_map<ov::NodeTypeInfo,
std::function<bool(const std::shared_ptr<ov::Node>&, ov::element::Type, size_t idx)>>;
using precisions_array = std::vector<std::pair<ov::element::Type, ov::element::Type>>;
std::unordered_map<ov::NodeTypeInfo, std::function<bool(const std::shared_ptr<ov::Node>&, const precisions_map&)>>;
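A note on the typedefs above: EnumClassHash exists because std::unordered_map keyed by an enum class requires an explicit hasher on older standard libraries (C++14 added std::hash support for enumeration types). Constructing a map is then straightforward; a usage sketch, not part of the diff:

// Usage sketch: maps each source precision to its replacement.
precisions_map precisions{
    {ov::element::i64, ov::element::i32},
    {ov::element::f16, ov::element::f32},
};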

class ov::pass::ConvertPrecision : public ov::pass::ModelPass {
public:
@@ -76,11 +82,11 @@ public:
ov::element::Type_t to,
type_to_fuse_map additional_type_to_fuse_map = {},
bool keep_precision_sensitive_in_fp32 = false)
: m_precisions(precisions_array{{from, to}}),
: m_precisions(precisions_map{{from, to}}),
m_additional_type_to_fuse_map(additional_type_to_fuse_map),
m_keep_precision_sensitive_in_fp32(keep_precision_sensitive_in_fp32) {}

ConvertPrecision(const precisions_array& precisions,
ConvertPrecision(const precisions_map& precisions,
const type_to_fuse_map& additional_type_to_fuse_map = {},
bool keep_precision_sensitive_in_fp32 = false)
: m_precisions(precisions),
@@ -90,7 +96,7 @@ public:
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;

private:
precisions_array m_precisions;
precisions_map m_precisions;
type_to_fuse_map m_additional_type_to_fuse_map;
bool m_keep_precision_sensitive_in_fp32;
};

@@ -20,8 +20,8 @@ bool ov::pass::ConvertCompressedOnlyToLegacy::run_on_model(const std::shared_ptr
if (ov::op::util::has_decompression_converts(f)) {
Manager manager(get_pass_config());

const precisions_array convert_precision_list{{ov::element::f32, ov::element::f16}};
manager.register_pass<ConvertPrecision>(convert_precision_list);
const precisions_map convert_precision_map{{ov::element::f32, ov::element::f16}};
manager.register_pass<ConvertPrecision>(convert_precision_map);
using namespace ov::pass;
REGISTER_PASS(manager, EnableDecompressionConvertConstantFolding)
REGISTER_PASS(manager, ConstantFolding)
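For callers, migration is mechanical: the pair vector becomes a map keyed by the source precision. A hedged sketch of the new registration pattern ('model' is an illustrative std::shared_ptr<ov::Model>):

ov::pass::Manager manager;
// Previously: precisions_array{{ov::element::f32, ov::element::f16}}
manager.register_pass<ov::pass::ConvertPrecision>(precisions_map{{ov::element::f32, ov::element::f16}});
manager.run_passes(model);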

@@ -28,37 +28,40 @@
using namespace ov;

bool fuse_type_to_constant(const std::shared_ptr<ngraph::Node>& node,
ngraph::element::Type to,
const precisions_map& precisions,
const std::vector<ngraph::Input<ngraph::Node>>& consumers);
bool fuse_type_to_shapeof(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_shapeof_v0(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_random_uniform_v8(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_unique_v10(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_range_v4(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_parameter(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_nms3(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_nms4(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_nms5(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_nms9(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_matrix_nms(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_multiclass_nms(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_generate_proposals(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_topk(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_maxpool(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_nonzero(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_bucketize(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_shapeof(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_shapeof_v0(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_random_uniform_v8(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_unique_v10(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_range_v4(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_parameter(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_nms3(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_nms4(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_nms5(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_nms9(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_matrix_nms(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_multiclass_nms(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_generate_proposals(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_topk(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_maxpool(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_nonzero(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_bucketize(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool fuse_type_to_ctc_greedy_decoder_seq_len(const std::shared_ptr<ngraph::Node>& node,
ngraph::element::Type to,
size_t idx);
const precisions_map& precisions);

bool fuse_type_to_random_uniform_v8(const std::shared_ptr<ngraph::Node>& node, ov::element::Type to, size_t idx);
bool fuse_type_to_random_uniform_v8(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);

bool extend_select_type(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool extend_reverse_type(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool extend_select_type(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool extend_reverse_type(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);

template <typename T>
bool fuse_type_to_binary_comparision(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool fuse_type_to_binary_comparision(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
auto it = precisions.find(node->get_output_element_type(0));
if (it == precisions.end())
return false;
const auto& to = it->second;
if (auto type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node)) {
type_relaxed->set_overridden_output_type(to);
return true;
@@ -72,7 +75,11 @@ bool fuse_type_to_binary_comparision(const std::shared_ptr<ngraph::Node>& node,
}
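Every fuse_type_to_* helper above and below switches from the old (to, idx) signature to the same lookup idiom: find the node's current output element type in the map, return early if that type is not scheduled for conversion, and only then read the target type. Extracted here for illustration; this standalone helper is not part of the patch:

// Hypothetical helper showing the recurring idiom; not in the commit.
inline bool lookup_target_type(const std::shared_ptr<ov::Node>& node,
                               const precisions_map& precisions,
                               ov::element::Type& to) {
    auto it = precisions.find(node->get_output_element_type(0));
    if (it == precisions.end())
        return false;  // this node's output type is not being converted
    to = it->second;
    return true;
}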

template <typename T>
bool fuse_type_to_logical(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool fuse_type_to_logical(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
auto it = precisions.find(node->get_output_element_type(0));
if (it == precisions.end())
return false;
const auto& to = it->second;
if (auto type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node)) {
type_relaxed->set_overridden_output_type(to);
type_relaxed->set_origin_input_type(ov::element::boolean, 0);
@@ -90,7 +97,11 @@ bool fuse_type_to_logical(const std::shared_ptr<ngraph::Node>& node, ngraph::ele
}

template <class T>
bool fuse_type_to_reduce_logical(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool fuse_type_to_reduce_logical(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
auto it = precisions.find(node->get_output_element_type(0));
if (it == precisions.end())
return false;
const auto& to = it->second;
if (auto type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node)) {
type_relaxed->set_overridden_output_type(to);
type_relaxed->set_origin_input_type(ov::element::boolean, 0);
@@ -107,23 +118,75 @@ bool fuse_type_to_reduce_logical(const std::shared_ptr<ngraph::Node>& node, ngra

namespace {

void validate_nodes_and_infer_types(const std::vector<std::shared_ptr<Node>>& ops) {
for (auto& node : ops) {
node->revalidate_and_infer_types();
}
bool node_is_replaced(const std::shared_ptr<Node>& node) {
const auto outputs = node->outputs();
bool has_consumers = std::all_of(outputs.begin(), outputs.end(), [](const Output<Node>& output) {
return output.get_target_inputs().size() == 0;
});
return has_consumers && !(is_type<op::v0::Result>(node) || is_type<op::Sink>(node));
}

bool convert_precision(ov::pass::PassBase& pass,
const std::shared_ptr<ngraph::Function>& f,
const type_to_fuse_map& type_to_fuse,
const type_to_fuse_map& type_to_extend,
ov::element::Type from,
ov::element::Type to,
bool skip_precision_sensitive = false) {
// As Constant operations can be shared between multiple nGraph Functions so before
// changing precision we need to understand which Constant consumers belongs
// to the current nGraph Function
std::unordered_map<const ngraph::Node*, std::vector<Input<Node>>> const_to_internal_output;
bool convert_node_output_precision(
const std::shared_ptr<ngraph::Node>& node,
const precisions_map& precisions,
const type_to_fuse_map& type_to_fuse,
const std::unordered_map<const ngraph::Node*, std::vector<Input<Node>>>& const_to_internal_output,
bool function_changed) {
bool node_changed = false;
// Handle case with Constants as they can have consumers from other nGraph Function object
const auto constant = ov::as_type_ptr<opset10::Constant>(node);
const auto it = const_to_internal_output.find(node.get());
if (constant && it != const_to_internal_output.end()) {
return fuse_type_to_constant(node, precisions, it->second);
}

// Check that node type exists in map and we can fuse type into node
const auto t2f_it = type_to_fuse.find(node->get_type_info());
if (t2f_it != type_to_fuse.end()) {
node_changed = t2f_it->second(node, precisions);
}
if ((function_changed || node_changed) && !node_is_replaced(node)) {
node->revalidate_and_infer_types();
}
return node_changed;
}

bool convert_node_input_precision(const std::shared_ptr<ngraph::Node>& node,
const precisions_map& precisions,
const type_to_fuse_map& type_to_extend) {
// For some operations we need to extend their input types to support new type
auto it = type_to_extend.find(node->get_type_info());
if (it != type_to_extend.end()) {
return it->second(node, precisions);
}
return false;
}

bool convert_function_precision(
const std::shared_ptr<Model>& f,
const type_to_fuse_map& type_to_fuse,
const type_to_fuse_map& type_to_extend,
const precisions_map& precisions,
std::unordered_map<const ngraph::Node*, std::vector<Input<Node>>>& const_to_internal_output,
bool has_fp16_compression,
bool skip_precision_sensitive,
bool is_changed,
bool is_subgraph) {
bool is_output_precision_changed = false;

auto ops = f->get_ordered_ops();

// Iterate over all nodes in topological order and then iterate over node outputs.
// If output type mismatch given type we try to fuse type into this operation
// otherwise we insert Convert operation.
for (auto& node : ops) {
if (skip_precision_sensitive && fp16_compression_is_disabled(node) && has_fp16_compression)
continue;
is_changed |= convert_node_input_precision(node, precisions, type_to_extend);
}

if (is_changed)
ops = f->get_ordered_ops();

auto register_constants = [&const_to_internal_output](const std::vector<std::shared_ptr<Node>>& ops) {
for (auto& node : ops) {
@@ -136,117 +199,81 @@ bool convert_precision(ov::pass::PassBase& pass,
}
};

auto convert_node_output_precision = [&](const std::shared_ptr<ngraph::Node>& node) {
bool res = false;
// Handle case with Constants as they can have consumers from other nGraph Function object
const auto constant = ov::as_type_ptr<opset10::Constant>(node);
const auto it = const_to_internal_output.find(node.get());
if (constant && constant->get_output_element_type(0) == from && it != const_to_internal_output.end()) {
return fuse_type_to_constant(node, to, it->second);
}
// Register internal constants only after fixing input type that could lead to nodes
// replacement
register_constants(ops);

for (const auto& output : node->outputs()) {
if (output.get_element_type() == from) {
// Check that node type exists in map and we can fuse type into node
const auto t2f_it = type_to_fuse.find(node->get_type_info());
if (t2f_it != type_to_fuse.end()) {
res |= t2f_it->second(node, to, output.get_index());
}
for (auto& node : ops) {
// skip precision sensitive nodes
if (skip_precision_sensitive && fp16_compression_is_disabled(node) && has_fp16_compression)
continue;
// Recursively apply transformation for sub-graph based operations
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(node)) {
size_t sub_graphs_num = sub_graph_node->get_internal_subgraphs_size();
for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) {
is_changed |= convert_function_precision(sub_graph_node->get_function(static_cast<int>(sub_graph_ind)),
type_to_fuse,
type_to_extend,
precisions,
const_to_internal_output,
has_fp16_compression,
skip_precision_sensitive,
is_changed || is_output_precision_changed,
true);
}
}
return res;
};
is_output_precision_changed |= convert_node_output_precision(node,
precisions,
type_to_fuse,
const_to_internal_output,
is_changed || is_output_precision_changed);
}

auto convert_node_input_precision = [&](const std::shared_ptr<ngraph::Node>& node) {
for (auto input : node->inputs()) {
if (input.get_element_type() == from) {
// For some operations we need to extend their input types to support new type
auto it = type_to_extend.find(node->get_type_info());
if (it != type_to_extend.end() && it->second(node, to, input.get_index())) {
return true;
}
}
}
return false;
};
if (is_output_precision_changed) {
ops = f->get_ordered_ops();
is_changed |= is_output_precision_changed;
}

std::function<bool(const std::shared_ptr<Model>&, bool)> convert_function_precision =
[&](const std::shared_ptr<Model>& f, bool is_subgraph) {
bool is_changed = false;

auto ops = f->get_ordered_ops();

// Iterate over all nodes in topological order and then iterate over node outputs.
// If output type mismatch given type we try to fuse type into this operation
// otherwise we insert Convert operation.
for (auto& node : ops) {
if (skip_precision_sensitive && fp16_compression_is_disabled(node) && to == element::f16)
if (!is_subgraph) {
// TODO: we need to split NopElimination pass to separate MatcherPasses and call
// Convert elimination here
for (auto& node : ops) {
if (auto convert = std::dynamic_pointer_cast<opset4::Convert>(node)) {
if (pass::constant_folding_is_disabled(node))
continue;

// Recursively apply transformation for sub-graph based operations
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(node)) {
size_t sub_graphs_num = sub_graph_node->get_internal_subgraphs_size();
for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) {
is_changed |=
convert_function_precision(sub_graph_node->get_function(static_cast<int>(sub_graph_ind)),
true);
}
}
is_changed |= convert_node_input_precision(node);
}

if (is_changed)
ops = f->get_ordered_ops();

// Register internal constants only after fixing input type that could lead to nodes
// replacement
register_constants(ops);

bool is_output_precision_changed = false;

for (auto& node : ops) {
// skip precision sensitive nodes
if (skip_precision_sensitive && fp16_compression_is_disabled(node) && to == element::f16)
continue;
is_output_precision_changed |= convert_node_output_precision(node);
}

if (is_output_precision_changed) {
ops = f->get_ordered_ops();
is_changed |= is_output_precision_changed;
}

if (!is_subgraph) {
if (is_changed)
validate_nodes_and_infer_types(ops);

// TODO: we need to split NopElimination pass to separate MatcherPasses and call
// Convert elimination here
for (auto& node : ops) {
if (auto convert = std::dynamic_pointer_cast<opset4::Convert>(node)) {
if (pass::constant_folding_is_disabled(node))
continue;
// WA for topK, dont remove fake convert
if (convert->input(0).get_element_type() == convert->get_convert_element_type() &&
convert->input_value(0).get_node_shared_ptr()->get_output_size() == 1) {
replace_output_update_name(convert->output(0), convert->input_value(0));
}
}
// WA for topK, dont remove fake convert
if (convert->input(0).get_element_type() == convert->get_convert_element_type() &&
convert->input_value(0).get_node_shared_ptr()->get_output_size() == 1) {
replace_output_update_name(convert->output(0), convert->input_value(0));
}
}
}
}

return is_changed;
};

return convert_function_precision(f, false);
return is_changed;
}
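The hunk above interleaves the removed lambda-based implementation with the new free function, which makes the new control flow easy to lose. In outline (a reading aid, not verbatim code):

// New convert_function_precision, in outline (sketch):
// 1. For every node, extend input types where needed (convert_node_input_precision).
// 2. If anything changed, re-fetch the ordered ops, then register shared Constants.
// 3. Single pass over the nodes: recurse into MultiSubGraphOp bodies, then fuse
//    output types; each touched node is revalidated individually.
// 4. At the top level only, run the Convert-elimination workaround.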

struct EnumClassHash {
template <class T>
std::size_t operator()(T t) const {
return static_cast<size_t>(t);
}
};
bool convert_precision(ov::pass::PassBase& pass,
const std::shared_ptr<ngraph::Function>& f,
const type_to_fuse_map& type_to_fuse,
const type_to_fuse_map& type_to_extend,
const precisions_map& precisions,
bool has_fp16_compression,
bool skip_precision_sensitive = false) {
// As Constant operations can be shared between multiple nGraph Functions so before
// changing precision we need to understand which Constant consumers belongs
// to the current nGraph Function
std::unordered_map<const ngraph::Node*, std::vector<Input<Node>>> const_to_internal_output;
return convert_function_precision(f,
type_to_fuse,
type_to_extend,
precisions,
const_to_internal_output,
has_fp16_compression,
skip_precision_sensitive,
false,
false);
}

using precisions_set_t = std::unordered_set<ngraph::element::Type_t, EnumClassHash>;

@@ -273,6 +300,27 @@ precisions_set_t find_all_used_precisions(const std::shared_ptr<ngraph::Function
} // namespace

bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ngraph::Function>& f) {
const auto used_precisions_set = find_all_used_precisions(f);
precisions_map used_precisions;
for (const auto& p : used_precisions_set) {
auto it = m_precisions.find(p);
if (it != m_precisions.end())
used_precisions.insert(*it);
}

if (used_precisions.empty())
return false;

bool has_fp16_compression = m_precisions.count(element::f32) > 0 && m_precisions[element::f32] == element::f16;

if (m_keep_precision_sensitive_in_fp32 && has_fp16_compression) {
pass::Manager manager(get_pass_config());
// Mark subgraphs with disable_fp16_compression to keep them in FP32
manager.register_pass<pass::MarkSugraphsToKeepInMixedPrecision>();
manager.register_pass<pass::AlignMixedFP32FP16Types>();
manager.run_passes(f);
}

type_to_fuse_map type_to_fuse{
{opset4::Parameter::get_type_info_static(), fuse_type_to_parameter},
{opset4::Convert::get_type_info_static(), fuse_type_to_convert},
@@ -308,17 +356,6 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ngraph::Func
{opset10::Unique::get_type_info_static(), fuse_type_to_unique_v10},
{opset8::RandomUniform::get_type_info_static(), fuse_type_to_random_uniform_v8}};

std::pair<ov::element::Type, ov::element::Type> compress_f16_pair = {ov::element::f32, ov::element::f16};
bool has_compress_f16 = std::count(m_precisions.begin(), m_precisions.end(), compress_f16_pair) > 0;

if (m_keep_precision_sensitive_in_fp32 && has_compress_f16) {
pass::Manager manager(get_pass_config());
// Mark subgraphs with disable_fp16_compression to keep them in FP32
manager.register_pass<pass::MarkSugraphsToKeepInMixedPrecision>();
manager.register_pass<pass::AlignMixedFP32FP16Types>();
manager.run_passes(f);
}

for (const auto& it : m_additional_type_to_fuse_map) {
type_to_fuse[it.first] = it.second;
}
@@ -330,20 +367,13 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ngraph::Func
{opset1::Reverse::get_type_info_static(), extend_reverse_type},
};

bool is_changed = false;

auto const used_precisions = find_all_used_precisions(f);

for (auto const& p : m_precisions) {
if (used_precisions.count(p.first))
is_changed = is_changed | convert_precision(*this,
f,
type_to_fuse,
type_to_extend,
p.first,
p.second,
m_keep_precision_sensitive_in_fp32);
}
bool is_changed = convert_precision(*this,
f,
type_to_fuse,
type_to_extend,
used_precisions,
has_fp16_compression,
m_keep_precision_sensitive_in_fp32);

// to remove extra converts
if (m_keep_precision_sensitive_in_fp32) {
@@ -361,7 +391,11 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ngraph::Func
return false;
}

bool fuse_type_to_shapeof(const std::shared_ptr<ngraph::Node>& node, ov::element::Type to, size_t idx) {
bool fuse_type_to_shapeof(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
auto it = precisions.find(node->get_output_element_type(0));
if (it == precisions.end())
return false;
const auto& to = it->second;
if (auto shapeof = ov::as_type_ptr<opset4::ShapeOf>(node)) {
if (to == ov::element::i32 || to == ov::element::i64) {
shapeof->set_output_type(to);
@@ -371,7 +405,11 @@ bool fuse_type_to_shapeof(const std::shared_ptr<ngraph::Node>& node, ov::element
return false;
}

bool fuse_type_to_random_uniform_v8(const std::shared_ptr<ngraph::Node>& node, ov::element::Type to, size_t idx) {
bool fuse_type_to_random_uniform_v8(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
auto it = precisions.find(node->get_output_element_type(0));
if (it == precisions.end())
return false;
const auto& to = it->second;
if (auto random_uniform = ov::as_type_ptr<opset8::RandomUniform>(node)) {
if (to.is_integral_number() || to.is_real()) {
random_uniform->set_out_type(to);
@@ -381,23 +419,28 @@ bool fuse_type_to_random_uniform_v8(const std::shared_ptr<ngraph::Node>& node, o
return false;
}

bool fuse_type_to_unique_v10(const std::shared_ptr<Node>& node, ov::element::Type to, size_t idx) {
bool fuse_type_to_unique_v10(const std::shared_ptr<Node>& node, const precisions_map& precisions) {
bool res = false;
if (auto unique = ov::as_type_ptr<opset10::Unique>(node)) {
if (to == ov::element::i32 || to == ov::element::i64) {
if (idx == 1 || idx == 2) {
unique->set_index_element_type(to);
res = true;
} else if (idx == 3) {
unique->set_count_element_type(to);
res = true;
}
auto it = precisions.find(node->get_output_element_type(1));
if (it != precisions.end()) {
unique->set_index_element_type(it->second);
res = true;
}
it = precisions.find(node->get_output_element_type(3));
if (it != precisions.end()) {
unique->set_count_element_type(it->second);
res = true;
}
}
return res;
}

bool fuse_type_to_range_v4(const std::shared_ptr<ngraph::Node>& node, ov::element::Type to, size_t idx) {
bool fuse_type_to_range_v4(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
auto it = precisions.find(node->get_output_element_type(0));
if (it == precisions.end())
return false;
const auto& to = it->second;
if (auto range = ov::as_type_ptr<opset4::Range>(node)) {
if (to.is_integral_number() || to.is_real()) {
range->set_output_type(to);
@@ -407,7 +450,11 @@ bool fuse_type_to_range_v4(const std::shared_ptr<ngraph::Node>& node, ov::elemen
return false;
}

bool fuse_type_to_parameter(const std::shared_ptr<ngraph::Node>& node, ov::element::Type to, size_t idx) {
bool fuse_type_to_parameter(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
auto it = precisions.find(node->get_output_element_type(0));
if (it == precisions.end())
return false;
const auto& to = it->second;
if (auto param = ov::as_type_ptr<opset4::Parameter>(node)) {
param->set_element_type(to);
param->validate_and_infer_types();
@@ -416,7 +463,11 @@ bool fuse_type_to_parameter(const std::shared_ptr<ngraph::Node>& node, ov::eleme
return false;
}

bool fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, ov::element::Type to, size_t idx) {
bool fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
auto it = precisions.find(node->get_output_element_type(0));
if (it == precisions.end())
return false;
const auto& to = it->second;
if (auto convert = ov::as_type_ptr<opset4::Convert>(node)) {
convert->set_convert_element_type(to);
return true;
@@ -424,7 +475,11 @@ bool fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, ov::element
return false;
}

bool fuse_type_to_nms3(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool fuse_type_to_nms3(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
auto it = precisions.find(node->get_output_element_type(0));
if (it == precisions.end())
return false;
const auto& to = it->second;
if (auto nms = ov::as_type_ptr<opset3::NonMaxSuppression>(node)) {
if (to == ov::element::i32 || to == ov::element::i64) {
nms->set_output_type(to);
@@ -436,7 +491,11 @@ bool fuse_type_to_nms3(const std::shared_ptr<ngraph::Node>& node, ngraph::elemen
return false;
}

bool fuse_type_to_nms4(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool fuse_type_to_nms4(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
auto it = precisions.find(node->get_output_element_type(0));
if (it == precisions.end())
return false;
const auto& to = it->second;
if (auto nms = ov::as_type_ptr<opset4::NonMaxSuppression>(node)) {
if (to == ov::element::i32 || to == ov::element::i64) {
nms->set_output_type(to);
@@ -448,75 +507,125 @@ bool fuse_type_to_nms4(const std::shared_ptr<ngraph::Node>& node, ngraph::elemen
return false;
}

bool fuse_type_to_nms5(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool fuse_type_to_nms5(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
auto nms = ov::as_type_ptr<opset5::NonMaxSuppression>(node);
if (!nms) {
return false;
}

if ((idx == 0 || idx == 2) && (to == ov::element::i32 || to == ov::element::i64)) {
nms->set_output_type(to);
return true;
}

if (auto type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node)) {
type_relaxed->set_overridden_output_type(to, idx);
return true;
bool res = false;
auto it = precisions.find(node->get_output_element_type(0));
if (it != precisions.end()) {
const auto& to = it->second;
if (to == ov::element::i32 || to == ov::element::i64) {
nms->set_output_type(to);
res = true;
if (precisions.count(node->get_output_element_type(1)) == 0) {
return res;
}
}
}

auto type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node);
ov::element::TypeVector output_types;
for (const auto& output : nms->outputs()) {
output_types.emplace_back(output.get_element_type());
for (size_t i = 0; i < node->get_output_size(); i++) {
it = precisions.find(node->get_output_element_type(i));
if (it == precisions.end()) {
output_types.push_back(node->get_output_element_type(i));
continue;
}
const auto& to = it->second;
if (type_relaxed) {
type_relaxed->set_overridden_output_type(to, i);
res = true;
}
output_types.push_back(to);
}
output_types[idx] = to;
auto relaxed_op =
std::make_shared<ov::op::TypeRelaxed<opset5::NonMaxSuppression>>(*nms, ov::element::TypeVector{}, output_types);
replace_node(node, relaxed_op);
return true;

if (!type_relaxed) {
auto relaxed_op = std::make_shared<ov::op::TypeRelaxed<opset5::NonMaxSuppression>>(*nms,
ov::element::TypeVector{},
output_types);
replace_node(node, relaxed_op);
res = true;
}

return res;
}

bool fuse_type_to_nms9(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool fuse_type_to_nms9(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
auto nms = ov::as_type_ptr<opset9::NonMaxSuppression>(node);
if (!nms) {
return false;
}

if ((idx == 0 || idx == 2) && (to == ov::element::i32 || to == ov::element::i64)) {
nms->set_output_type(to);
return true;
}

if (auto type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node)) {
type_relaxed->set_overridden_output_type(to, idx);
return true;
bool res = false;
auto it = precisions.find(node->get_output_element_type(0));
if (it != precisions.end()) {
const auto& to = it->second;
if (to == ov::element::i32 || to == ov::element::i64) {
nms->set_output_type(to);
res = true;
if (precisions.count(node->get_output_element_type(1)) == 0) {
return res;
}
}
}

auto type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node);
ov::element::TypeVector output_types;
for (const auto& output : nms->outputs()) {
output_types.emplace_back(output.get_element_type());
for (size_t i = 0; i < node->get_output_size(); i++) {
it = precisions.find(node->get_output_element_type(i));
if (it == precisions.end()) {
output_types.push_back(node->get_output_element_type(i));
continue;
}
const auto& to = it->second;
if (type_relaxed) {
type_relaxed->set_overridden_output_type(to, i);
res = true;
}
output_types.push_back(to);
}
output_types[idx] = to;
auto relaxed_op =
std::make_shared<ov::op::TypeRelaxed<opset9::NonMaxSuppression>>(*nms, ov::element::TypeVector{}, output_types);
replace_node(node, relaxed_op);
return true;

if (!type_relaxed) {
auto relaxed_op = std::make_shared<ov::op::TypeRelaxed<opset9::NonMaxSuppression>>(*nms,
ov::element::TypeVector{},
output_types);
replace_node(node, relaxed_op);
res = true;
}

return res;
}

bool fuse_type_to_matrix_nms(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool update_type(size_t idx,
const std::shared_ptr<ngraph::Node>& node,
const precisions_map& precisions,
std::function<void(const element::Type&)> update_method) {
auto it = precisions.find(node->get_output_element_type(idx));
if (it != precisions.end()) {
const auto& to = it->second;
if (to == ov::element::i32 || to == ov::element::i64) {
update_method(to);
return true;
}
}
return false;
}
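update_type factors that lookup idiom out for single-port updates guarded to i32/i64; the MatrixNms, MulticlassNms, GenerateProposals, TopK, MaxPool, CTCGreedyDecoderSeqLen, NonZero, and Bucketize rewrites below each reduce to one call. A hypothetical helper for an op with an index output on port 1 might look like this (SomeOp is a placeholder type, not a real operation):

// Hypothetical example only; 'SomeOp' is a placeholder operation type.
bool fuse_type_to_some_op(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
    if (auto op = ov::as_type_ptr<SomeOp>(node)) {
        // The callback runs only when port 1's type maps to i32 or i64.
        return update_type(1, node, precisions, [&](const element::Type& to) {
            op->set_index_element_type(to);
        });
    }
    return false;
}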

bool fuse_type_to_matrix_nms(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
auto nms = ov::as_type_ptr<opset8::MatrixNms>(node);
if (!nms) {
return false;
}

if ((idx == 1 || idx == 2) && (to == ov::element::i32 || to == ov::element::i64)) {
return update_type(1, node, precisions, [&](const element::Type& to) {
nms->set_output_type(to);
return true;
}

return false;
});
}

bool fuse_type_to_multiclass_nms(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool fuse_type_to_multiclass_nms(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
std::shared_ptr<ov::op::util::MulticlassNmsBase> nms;
if (ov::is_type<ov::op::v8::MulticlassNms>(node)) {
nms = ov::as_type_ptr<opset8::MulticlassNms>(node);
@@ -527,85 +636,81 @@ bool fuse_type_to_multiclass_nms(const std::shared_ptr<ngraph::Node>& node, ngra
return false;
}

if ((idx == 1 || idx == 2) && (to == ov::element::i32 || to == ov::element::i64)) {
return update_type(1, node, precisions, [&](const element::Type& to) {
nms->set_output_type(to);
return true;
}

return false;
});
}

bool fuse_type_to_generate_proposals(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool fuse_type_to_generate_proposals(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
auto generate_proposals = ov::as_type_ptr<opset9::GenerateProposals>(node);
if (!generate_proposals) {
return false;
}

if ((idx == 2) && (to == ov::element::i32 || to == ov::element::i64)) {
return update_type(2, node, precisions, [&](const element::Type& to) {
generate_proposals->set_roi_num_type(to);
return true;
}

return false;
});
}

bool fuse_type_to_topk(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool fuse_type_to_topk(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
if (auto topk = ov::as_type_ptr<opset4::TopK>(node)) {
if (idx == 1 && (to == ov::element::i32 || to == ov::element::i64)) {
return update_type(1, node, precisions, [&](const element::Type& to) {
topk->set_index_element_type(to);
return true;
}
});
}
return false;
}

bool fuse_type_to_maxpool(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool fuse_type_to_maxpool(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
if (auto maxpool = ov::as_type_ptr<opset8::MaxPool>(node)) {
if (idx == 1 && (to == ov::element::i32 || to == ov::element::i64)) {
return update_type(1, node, precisions, [&](const element::Type& to) {
maxpool->set_index_element_type(to);
return true;
}
});
}
return false;
}

bool fuse_type_to_ctc_greedy_decoder_seq_len(const std::shared_ptr<ngraph::Node>& node,
ngraph::element::Type to,
size_t idx) {
const precisions_map& precisions) {
bool res = false;
if (auto ctc_decoder = ov::as_type_ptr<opset6::CTCGreedyDecoderSeqLen>(node)) {
if (idx == 0 && (to == ov::element::i32 || to == ov::element::i64)) {
res = update_type(0, node, precisions, [&](const element::Type& to) {
ctc_decoder->set_classes_index_type(to);
return true;
}
if (idx == 1 && (to == ov::element::i32 || to == ov::element::i64)) {
ctc_decoder->set_sequence_length_type(to);
return true;
}
});
res = update_type(1,
node,
precisions,
[&](const element::Type& to) {
ctc_decoder->set_sequence_length_type(to);
}) ||
res;
}
return false;
return res;
}

bool fuse_type_to_nonzero(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool fuse_type_to_nonzero(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
if (auto nonzero = ov::as_type_ptr<opset4::NonZero>(node)) {
if (to == ov::element::i32 || to == ov::element::i64) {
return update_type(0, node, precisions, [&](const element::Type& to) {
nonzero->set_output_type(to);
return true;
}
});
}
return false;
}

bool fuse_type_to_bucketize(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool fuse_type_to_bucketize(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
if (auto b = ov::as_type_ptr<opset4::Bucketize>(node)) {
if (to == ov::element::i32 || to == ov::element::i64) {
return update_type(0, node, precisions, [&](const element::Type& to) {
b->set_output_type(to);
return true;
}
});
}
return false;
}

bool fuse_type_to_shapeof_v0(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool fuse_type_to_shapeof_v0(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
auto it = precisions.find(node->get_output_element_type(0));
if (it == precisions.end())
return false;
const auto& to = it->second;
if (auto type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node)) {
type_relaxed->set_overridden_output_type(to);
return true;
@@ -619,7 +724,7 @@ bool fuse_type_to_shapeof_v0(const std::shared_ptr<ngraph::Node>& node, ngraph::
return false;
}

bool extend_select_type(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool extend_select_type(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
if (auto type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node)) {
type_relaxed->set_origin_input_type(ov::element::boolean, 0);
return true;
@@ -634,7 +739,7 @@ bool extend_select_type(const std::shared_ptr<ngraph::Node>& node, ngraph::eleme
return false;
}

bool extend_reverse_type(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
bool extend_reverse_type(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
if (const auto casted = std::dynamic_pointer_cast<opset1::Reverse>(node)) {
if (casted->get_mode() == ov::op::v1::Reverse::Mode::MASK) {
auto relaxed_op = std::make_shared<op::TypeRelaxed<opset1::Reverse>>(
@@ -903,10 +1008,14 @@ std::shared_ptr<Node> convert_low_precisions_int(std::shared_ptr<opset4::Constan
} // namespace

bool fuse_type_to_constant(const std::shared_ptr<ngraph::Node>& node,
ov::element::Type to,
const precisions_map& precisions,
const std::vector<Input<Node>>& consumers) {
auto from = node->get_element_type();
auto it = precisions.find(from);
if (it == precisions.end())
return false;
const auto& to = it->second;
if (auto constant = ov::as_type_ptr<opset4::Constant>(node)) {
auto from = constant->get_element_type();
std::shared_ptr<ngraph::Node> new_const;
if (from == ov::element::u64 && to == ov::element::i32) {
new_const = change_constant_precision<ov::element::Type_t::u64, ov::element::Type_t::i32>(constant);

@ -64,7 +64,7 @@ TEST(TransformationTests, ConvertPrecision_NMS3) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
@ -94,7 +94,7 @@ TEST(TransformationTests, ConvertPrecision_NMS4) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
@ -127,7 +127,7 @@ TEST(TransformationTests, ConvertPrecision_NMS5) {
|
||||
}
|
||||
|
||||
pass::Manager manager;
|
||||
static const precisions_array precisions = {{element::i64, element::i32}, {element::f32, element::f16}};
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f32, element::f16}};
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
@ -150,7 +150,7 @@ TEST(TransformationTests, ConvertPrecision_MatrixNms) {
|
||||
}
|
||||
|
||||
pass::Manager manager;
|
||||
static const precisions_array precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
@ -173,7 +173,7 @@ TEST(TransformationTests, ConvertPrecision_MulticlassNms) {
|
||||
}
|
||||
|
||||
pass::Manager manager;
|
||||
static const precisions_array precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
@ -190,7 +190,7 @@ TEST(TransformationTests, ConvertPrecision_ShapeOf) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
@ -212,7 +212,7 @@ TEST(TransformationTests, ConvertPrecision_Range) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
@ -233,7 +233,7 @@ TEST(TransformationTests, ConvertPrecision_ConstantRelu) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
@ -253,7 +253,7 @@ TEST(TransformationTests, ConvertPrecision_Convert) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
@ -273,7 +273,7 @@ TEST(TransformationTests, ConvertPrecision_ConvertElimination) {
|
||||
f = std::make_shared<Model>(NodeVector{convert}, ParameterVector{input});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f16, element::f32}});
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f16, element::f32}});
|
||||
manager.run_passes(f);
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
}
|
||||
@ -300,7 +300,7 @@ TEST(TransformationTests, ConvertPrecision_TopK) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
@ -320,7 +320,7 @@ TEST(TransformationTests, ConvertPrecision_Unique10) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<ov::pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(model);
|
||||
@ -348,7 +348,7 @@ TEST(TransformationTests, ConvertPrecision_NonZero) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
@ -369,7 +369,7 @@ TEST(TransformationTests, ConvertPrecision_Bucketize) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
@ -399,7 +399,7 @@ TEST(TransformationTests, ConvertPrecision_Roundings) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
@ -456,7 +456,7 @@ TEST(TransformationTests, ConvertPrecision_TIBody) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
@ -479,7 +479,7 @@ TEST(TransformationTests, ConvertPrecision_Equal) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
@ -501,7 +501,7 @@ TEST(TransformationTests, ConvertPrecision_NotEqual) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
@ -523,7 +523,7 @@ TEST(TransformationTests, ConvertPrecision_Greater) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
@ -545,7 +545,7 @@ TEST(TransformationTests, ConvertPrecision_GreaterEqual) {
|
||||
|
||||
pass::Manager manager;
|
||||
|
||||
static const precisions_array precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
|
||||
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
@ -567,7 +567,7 @@ TEST(TransformationTests, ConvertPrecision_Less) {

pass::Manager manager;

static const precisions_array precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};

manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
@ -589,7 +589,7 @@ TEST(TransformationTests, ConvertPrecision_LessEqual) {

pass::Manager manager;

static const precisions_array precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};

manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
@ -610,7 +610,7 @@ TEST(TransformationTests, ConvertPrecision_LogicalAnd) {
f = std::make_shared<Model>(OutputVector{node}, ParameterVector{input1, input2});

pass::Manager manager;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::boolean, element::u8}});
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::u8}});
manager.run_passes(f);
}

@ -628,7 +628,7 @@ TEST(TransformationTests, ConvertPrecision_LogicalOr) {
f = std::make_shared<Model>(OutputVector{node}, ParameterVector{input1, input2});

pass::Manager manager;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::boolean, element::u8}});
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::u8}});
manager.run_passes(f);
}

@ -646,7 +646,7 @@ TEST(TransformationTests, ConvertPrecision_LogicalXor) {
f = std::make_shared<Model>(OutputVector{node}, ParameterVector{input1, input2});

pass::Manager manager;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::boolean, element::u8}});
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::u8}});
manager.run_passes(f);
}

@ -663,7 +663,7 @@ TEST(TransformationTests, ConvertPrecision_LogicalNot) {
f = std::make_shared<Model>(OutputVector{node}, ParameterVector{input1});

pass::Manager manager;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::boolean, element::u8}});
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::u8}});
manager.run_passes(f);
}

@ -681,7 +681,7 @@ TEST(TransformationTests, ConvertPrecision_Select) {
f = std::make_shared<Model>(OutputVector{select}, ParameterVector{input1});

pass::Manager manager;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::boolean, element::u8}});
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::u8}});
manager.run_passes(f);
}

@ -699,8 +699,8 @@ TEST(TransformationTests, ConvertPrecision_TypeRelaxedWithSelect) {
f = std::make_shared<Model>(OutputVector{select}, ParameterVector{input1});

pass::Manager manager;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::boolean, element::i32}});
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::i32, element::i64}});
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::i32}});
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::i32, element::i64}});
manager.run_passes(f);
}

@ -721,8 +721,8 @@ TEST(TransformationTests, ConvertPrecision_TypeRelaxed) {
f = std::make_shared<Model>(OutputVector{type_relaxed}, ParameterVector{input1});

pass::Manager manager;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::boolean, element::i32}});
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::i32, element::i64}});
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::i32}});
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::i32, element::i64}});
manager.run_passes(f);

ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
@ -747,7 +747,7 @@ TEST(TransformationTests, ConvertPrecision_Variables) {
f = std::make_shared<Model>(NodeVector{mul}, ParameterVector{inp});

pass::Manager manager;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f16, element::f32}});
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f16, element::f32}});
manager.run_passes(f);
}

@ -778,7 +778,7 @@ TEST(TransformationTests, ConvertPrecision_skip_precision_sensitive) {
pass::Manager manager;
type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
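
The keep-in-FP32 tests that follow all use the same three-argument registration; a condensed sketch with the values used above:

    // Sketch: the three-argument form. An empty fuse map plus
    // keep_precision_sensitive_in_fp32 = true converts the model to f16 while
    // leaving precision-sensitive (e.g. shape-computing) subgraphs in f32.
    type_to_fuse_map empty_type_to_fuse_map = {};
    bool keep_precision_sensitive_in_fp32 = true;
    manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
                                                  empty_type_to_fuse_map,
                                                  keep_precision_sensitive_in_fp32);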
@ -812,7 +812,7 @@ TEST(TransformationTests, ConvertPrecision_without_keep_precision_sensitive_in_f
pass::Manager manager;
type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = false;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
@ -834,7 +834,7 @@ TEST(TransformationTests, ConvertPrecision_check_marking_does_not_leak_in_trivia
pass::Manager manager;
type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
@ -873,7 +873,7 @@ TEST(TransformationTests, ConvertPrecision_whole_shape_subgraph_is_marked_1) {
pass::Manager manager;
type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
@ -926,7 +926,7 @@ TEST(TransformationTests, ConvertPrecision_whole_shape_subgraph_is_marked_2) {
pass::Manager manager;
type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
@ -1005,7 +1005,7 @@ TEST(TransformationTests, ConvertPrecision_whole_shape_subgraph_is_marked_3) {
pass::Manager manager;
type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
@ -1083,7 +1083,7 @@ TEST(TransformationTests, ConvertCompressedToMixedPrecission_do_not_keep_in_fp32
pass::Manager manager;
type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = false; // didn't keep in FP32 intentionally
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
@ -1108,7 +1108,7 @@ void constant_convert_test(element::Type type_from,
f = std::make_shared<Model>(NodeVector{c}, ParameterVector{});

pass::Manager manager;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{type_from, type_to}});
manager.register_pass<pass::ConvertPrecision>(precisions_map{{type_from, type_to}});
manager.run_passes(f);
}
auto ops = f->get_ordered_ops();
@ -1138,7 +1138,7 @@ void constant_convert_test(element::Type_t type_from, element::Type_t type_to, F
f = std::make_shared<Model>(NodeVector{c}, ParameterVector{});

pass::Manager manager;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{type_from, type_to}});
manager.register_pass<pass::ConvertPrecision>(precisions_map{{type_from, type_to}});
manager.run_passes(f);
}
auto ops = f->get_ordered_ops();
@ -1336,7 +1336,7 @@ TEST(TransformationTests, ConvertPrecision_keep_precission_sensitive_fp32_with_e

type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
@ -1383,7 +1383,7 @@ TEST(TransformationTests, ConvertPrecision_keep_precission_sensitive_fp32_with_r

type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
@ -1431,7 +1431,7 @@ TEST(TransformationTests, ConvertPrecision_reducesum_without_exp) {

type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
@ -1470,7 +1470,7 @@ TEST(TransformationTests, ConvertPrecision_MarkNormalizationOps_1) {

type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
@ -1509,7 +1509,7 @@ TEST(TransformationTests, ConvertPrecision_MarkNormalizationOps_2) {

type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
@ -1587,7 +1587,7 @@ TEST(TransformationTests, ConvertPrecision_keep_precission_sensitive_fp32_t2t_su

type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
@ -1667,7 +1667,7 @@ TEST(TransformationTests, ConvertPrecision_DivisionByZeroMinimalPattern) {

type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
@ -1708,7 +1708,7 @@ TEST(TransformationTests, ConvertPrecision_PowWithNegativeExponent) {

type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
@ -1756,7 +1756,7 @@ TEST(TransformationTests, ConvertPrecision_exp_through_unsqueeze) {

type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::ConvertPrecision>(precisions_array{{element::f32, element::f16}},
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);

@ -15,7 +15,7 @@ using namespace std;
bool ov::pass::ConvertFP32ToFP16::run_on_model(const std::shared_ptr<ov::Model>& f) {
RUN_ON_MODEL_SCOPE(ConvertFP32ToFP16);
ov::pass::Manager m(get_pass_config());
m.register_pass<ov::pass::ConvertPrecision>(precisions_array{{ngraph::element::f32, ngraph::element::f16}});
m.register_pass<ov::pass::ConvertPrecision>(precisions_map{{ngraph::element::f32, ngraph::element::f16}});
m.run_passes(f);
return false;
}
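
After this change, ConvertFP32ToFP16 stays a thin wrapper over ConvertPrecision; registering it has the same effect as a direct registration, roughly (a sketch, `model` is illustrative):

    // Sketch: equivalent direct registration with the map-based API.
    ov::pass::Manager m;
    m.register_pass<ov::pass::ConvertPrecision>(precisions_map{{ngraph::element::f32, ngraph::element::f16}});
    m.run_passes(model);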

@ -145,7 +145,7 @@ TEST_F(GetSupportedNodesTest, SupportedCompressedConstantNop) {
[&](std::shared_ptr<ov::Model>& model) {
ov::pass::Manager m;
m.register_pass<ov::pass::InitNodeInfo>();
m.register_pass<ov::pass::ConvertPrecision>(precisions_array{{ngraph::element::f16, ngraph::element::f32}});
m.register_pass<ov::pass::ConvertPrecision>(precisions_map{{ngraph::element::f16, ngraph::element::f32}});
m.register_pass<ov::pass::NopElimination>();
m.run_passes(model);
},

@ -44,7 +44,7 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr<ngraph::Function> &nGraphF
// after transformation "MoveEltwiseUpThroughDataMov" there can be Reshape sequences that should be eliminated or fused
manager.register_pass<ov::pass::ReshapeSequenceFusion>();
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ov::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i64, ngraph::element::i32 }});
manager.register_pass<ov::pass::ConvertPrecision>(precisions_map {{ ngraph::element::i64, ngraph::element::i32 }});
manager.register_pass<ov::pass::Validate>();

manager.run_passes(nGraphFunc);

@ -112,7 +112,12 @@ namespace intel_cpu {

using const_node_ptr = const std::shared_ptr<const ov::Node>;

bool Transformations::fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, ov::element::Type to, size_t idx) {
bool Transformations::fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
const auto& from = node->get_output_element_type(0);
auto it = precisions.find(from);
if (it == precisions.end())
return false;
const auto& to = it->second;
if (auto convert = ov::as_type_ptr<ov::opset10::Convert>(node)) {
// For Convert node, converting precision from floating point to boolean will lead to mathematical
// error, because here the output precision boolean is replaced by u8. E.g. floating point value 0.01
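
The hunk above shows the new callback contract: instead of receiving a single destination type and input index, a fuse callback now gets the whole precisions_map and resolves its own source type. A hypothetical callback following that contract (the op name is illustrative, not from the patch):

    // Hypothetical sketch of a type_to_fuse_map callback under the new signature.
    bool fuse_type_to_my_op(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
        const auto& from = node->get_output_element_type(0);
        auto it = precisions.find(from);
        if (it == precisions.end())
            return false;  // this node's output type is not scheduled for conversion
        const auto& to = it->second;
        // ... re-type the node so output 0 becomes `to`, as fuse_type_to_convert does above ...
        return true;
    }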
@ -187,7 +192,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
}

auto get_convert_precisions = []() {
precisions_array array = {
precisions_map map = {
{ov::element::i64, ov::element::i32},
{ov::element::u64, ov::element::i32},
{ov::element::i16, ov::element::i32},
@ -201,9 +206,9 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
};

if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core))
array.push_back({ov::element::bf16, ov::element::f32});
map.insert({ov::element::bf16, ov::element::f32});

return array;
return map;
};
static const auto precisions = get_convert_precisions();
type_to_fuse_map type_to_fuse = {{ov::opset10::Convert::get_type_info_static(), fuse_type_to_convert}};
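
One behavioral nuance of the push_back-to-insert migration above is worth noting: std::unordered_map::insert does nothing when the key is already present, so a later insert cannot override a conversion that an earlier entry established. A minimal illustration:

    // insert() is a no-op for an existing key:
    precisions_map map = {{ov::element::i64, ov::element::i32}};
    map.insert({ov::element::bf16, ov::element::f32});  // added: no bf16 entry yet
    map.insert({ov::element::i64, ov::element::i16});   // ignored: i64 already maps to i32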

@ -8,6 +8,7 @@
#include "utils/debug_capabilities.h"
#include "low_precision/low_precision.hpp"
#include "config.h"
#include "transformations/convert_precision.hpp"

#include "itt.h"

@ -61,7 +62,7 @@ private:

void Snippets(void);

static bool fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, ov::element::Type to, size_t idx);
static bool fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
};

} // namespace intel_cpu

@ -180,7 +180,7 @@ TEST(ConvertFunctionToCNNNetworkTests, ConvertTopKWithOneInput) {
manager.register_pass<ov::pass::ConvertOpSet3ToOpSet2>();
manager.register_pass<ov::pass::ConvertOpSet2ToOpSet1>();

static const precisions_array convert_precision_list{
static const precisions_map convert_precision_map{
{ngraph::element::i64, ngraph::element::i32},
{ngraph::element::u64, ngraph::element::i32},
{ngraph::element::u16, ngraph::element::i32},
@ -189,9 +189,9 @@ TEST(ConvertFunctionToCNNNetworkTests, ConvertTopKWithOneInput) {
{ngraph::element::boolean, ngraph::element::u8},
};

manager.register_pass<ov::pass::ConvertPrecision>(convert_precision_list);
manager.register_pass<ov::pass::ConvertPrecision>(convert_precision_map);
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
manager.register_pass<ov::pass::ConvertPrecision>(precisions_array{{ngraph::element::i64, ngraph::element::i32}});
manager.register_pass<ov::pass::ConvertPrecision>(precisions_map{{ngraph::element::i64, ngraph::element::i32}});

manager.run_passes(f);

@ -67,7 +67,7 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model) {

// In OV API 2.0 (IRv10) the default conversion to fp32 (inputs, outputs and weights) is disabled
// and we need to run the ConvertPrecision transformation to support old networks.
manager.register_pass<ov::pass::ConvertPrecision>(precisions_array{{ngraph::element::f16, ngraph::element::f32}});
manager.register_pass<ov::pass::ConvertPrecision>(precisions_map{{ngraph::element::f16, ngraph::element::f32}});
manager.register_pass<ov::pass::ConvertMVN1ToMVN6>();
manager.register_pass<ov::intel_gna::pass::DecomposeMVN>();
manager.register_pass<ov::pass::CommonOptimizations>();
@ -148,9 +148,9 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model) {
manager.register_pass<ov::intel_gna::pass::InsertCopyBeforeConcatLayer>();
manager.register_pass<ov::intel_gna::pass::HandleMultiConnectedLayerToConcatAndMemory>();
manager.register_pass<ov::intel_gna::pass::HandleNonFunctionalSubgraphs>();
manager.register_pass<ov::pass::ConvertPrecision>(precisions_array{{ov::element::i64, ov::element::i32},
{ov::element::u64, ov::element::i32},
{ov::element::u32, ov::element::i32}});
manager.register_pass<ov::pass::ConvertPrecision>(precisions_map{{ov::element::i64, ov::element::i32},
{ov::element::u64, ov::element::i32},
{ov::element::u32, ov::element::i32}});
const auto& pass_config = manager.get_pass_config();

// Allowing FP16 Converts to be folded and FP16 constants to upgrade to FP32 data type

@ -140,7 +140,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
manager.register_pass<ov::pass::InitNodeInfo>();
manager.register_pass<EinsumDecomposition>();

precisions_array fp_convert_precision_list = {
precisions_map fp_convert_precision_map = {
{ov::element::f64, ov::element::f32}
};

@ -171,7 +171,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {

for (auto& et : fp_element_types) {
if (et != infer_precision) {
fp_convert_precision_list.push_back({et, infer_precision});
fp_convert_precision_map.insert({et, infer_precision});
}
}
}
@ -179,14 +179,9 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
// Add conversion from unsupported FP data types to f32 if we don't have a conversion to something valid already in the list
for (auto& et : fp_element_types) {
if (!fp_precision_supported(et)) {
auto et_pair = std::make_pair(et, fallback_precision);
bool has_valid_conversion = std::find_if(fp_convert_precision_list.begin(), fp_convert_precision_list.end(),
[&](std::pair<ov::element::Type, ov::element::Type> v) -> bool {
return v.first == et_pair.first && fp_precision_supported(v.second);
}) != fp_convert_precision_list.end();

bool has_valid_conversion = fp_convert_precision_map.count(et) && fp_precision_supported(fp_convert_precision_map[et]);
if (!has_valid_conversion) {
fp_convert_precision_list.push_back(et_pair);
fp_convert_precision_map.insert(std::make_pair(et, fallback_precision));
}
}
}
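
The rewrite above works because the map can hold at most one destination per source type, so the linear find_if scan collapses into a keyed lookup. An equivalent formulation using the same variables, sketched here to avoid the second lookup through operator[]:

    auto found = fp_convert_precision_map.find(et);
    bool has_valid_conversion = found != fp_convert_precision_map.end() && fp_precision_supported(found->second);
    if (!has_valid_conversion) {
        fp_convert_precision_map.insert({et, fallback_precision});
    }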
@ -194,7 +189,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
type_to_fuse_map empty_fuse_map = {};
manager.register_pass<ov::pass::Validate>();
// call ConvertPrecision with keep_precision_sensitive_in_fp32 = true
manager.register_pass<ov::pass::ConvertPrecision>(fp_convert_precision_list, empty_fuse_map, true);
manager.register_pass<ov::pass::ConvertPrecision>(fp_convert_precision_map, empty_fuse_map, true);

manager.register_pass<ov::pass::CommonOptimizations>();

@ -232,7 +227,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
manager.register_pass<ov::pass::ConvertPriorBox8To0, false>();
manager.register_pass<ov::pass::ConvertMulticlassNmsToMulticlassNmsIE>();

precisions_array int_convert_precision_list {
precisions_map int_convert_precision_map {
{ngraph::element::i64, ngraph::element::i32},
{ngraph::element::u64, ngraph::element::i32},
{ngraph::element::u16, ngraph::element::i32},
@ -243,7 +238,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
};

manager.register_pass<ngraph::pass::Validate>();
manager.register_pass<ov::pass::ConvertPrecision>(int_convert_precision_list);
manager.register_pass<ov::pass::ConvertPrecision>(int_convert_precision_map);

auto pass_config = manager.get_pass_config();
pass_config->disable<ov::pass::EyeDecomposition>();

@ -283,7 +283,7 @@ std::vector<ov::Tensor> SubgraphBaseTest::calculate_refs() {

auto functionToProcess = functionRefs->clone();
// TODO: remove these conversions as soon as the function interpreter fully supports bf16 and f16
precisions_array precisions = {
precisions_map precisions = {
{ ngraph::element::bf16, ngraph::element::f32 }
};
auto convert_added = false;
@ -299,7 +299,7 @@ std::vector<ov::Tensor> SubgraphBaseTest::calculate_refs() {
}
}
if (!convert_added) {
precisions.push_back({ ngraph::element::f16, ngraph::element::f32});
precisions.insert({ ngraph::element::f16, ngraph::element::f32});
}
pass::Manager manager;
manager.register_pass<ov::pass::ConvertPrecision>(precisions);