Add explicit converts for Parameter and Result in ConvertPrecision tr… (#18183)
* Add explicit converts for Parameter and Result in ConvertPrecision transformation * set friendly name for convert on output * tests
This commit is contained in:
parent
82c65c25da
commit
0296008c7e
@ -81,17 +81,21 @@ public:
|
||||
ConvertPrecision(ov::element::Type_t from,
|
||||
ov::element::Type_t to,
|
||||
type_to_fuse_map additional_type_to_fuse_map = {},
|
||||
bool keep_precision_sensitive_in_fp32 = false)
|
||||
bool keep_precision_sensitive_in_fp32 = false,
|
||||
bool convert_input_output_precision = true)
|
||||
: m_precisions(precisions_map{{from, to}}),
|
||||
m_additional_type_to_fuse_map(additional_type_to_fuse_map),
|
||||
m_keep_precision_sensitive_in_fp32(keep_precision_sensitive_in_fp32) {}
|
||||
m_keep_precision_sensitive_in_fp32(keep_precision_sensitive_in_fp32),
|
||||
m_convert_input_output_precision(convert_input_output_precision) {}
|
||||
|
||||
ConvertPrecision(const precisions_map& precisions,
|
||||
const type_to_fuse_map& additional_type_to_fuse_map = {},
|
||||
bool keep_precision_sensitive_in_fp32 = false)
|
||||
bool keep_precision_sensitive_in_fp32 = false,
|
||||
bool convert_input_output_precision = true)
|
||||
: m_precisions(precisions),
|
||||
m_additional_type_to_fuse_map(additional_type_to_fuse_map),
|
||||
m_keep_precision_sensitive_in_fp32(keep_precision_sensitive_in_fp32) {}
|
||||
m_keep_precision_sensitive_in_fp32(keep_precision_sensitive_in_fp32),
|
||||
m_convert_input_output_precision(convert_input_output_precision) {}
|
||||
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
|
||||
|
||||
@ -99,4 +103,5 @@ private:
|
||||
precisions_map m_precisions;
|
||||
type_to_fuse_map m_additional_type_to_fuse_map;
|
||||
bool m_keep_precision_sensitive_in_fp32;
|
||||
bool m_convert_input_output_precision;
|
||||
};
|
||||
|
@ -27,9 +27,14 @@
|
||||
#include "transformations/rt_info/decompression.hpp"
|
||||
#include "transformations/rt_info/disable_fp16_compression.hpp"
|
||||
#include "transformations/rt_info/keep_fp16_const.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
|
||||
bool fuse_type_to_parameter(const std::shared_ptr<ngraph::Node>& node,
|
||||
const precisions_map& precisions,
|
||||
bool convert_input_precision);
|
||||
|
||||
bool fuse_type_to_constant(const std::shared_ptr<ngraph::Node>& node,
|
||||
const precisions_map& precisions,
|
||||
const std::vector<ngraph::Input<ngraph::Node>>& consumers);
|
||||
@ -39,7 +44,6 @@ bool fuse_type_to_random_uniform_v8(const std::shared_ptr<ngraph::Node>& node, c
|
||||
bool fuse_type_to_unique_v10(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
|
||||
bool fuse_type_to_range_v4(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
|
||||
bool fuse_type_to_eye_v9(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
|
||||
bool fuse_type_to_parameter(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
|
||||
bool fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
|
||||
bool fuse_type_to_nms3(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
|
||||
bool fuse_type_to_nms4(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
|
||||
@ -147,9 +151,11 @@ bool convert_node_output_precision(
|
||||
if (t2f_it != type_to_fuse.end()) {
|
||||
node_changed = t2f_it->second(node, precisions);
|
||||
}
|
||||
|
||||
if ((function_changed || node_changed) && !node_is_replaced(node)) {
|
||||
node->revalidate_and_infer_types();
|
||||
}
|
||||
|
||||
return node_changed;
|
||||
}
|
||||
|
||||
@ -173,20 +179,33 @@ bool convert_function_precision(
|
||||
bool has_fp16_compression,
|
||||
bool skip_precision_sensitive,
|
||||
bool is_changed,
|
||||
bool is_subgraph) {
|
||||
bool is_subgraph,
|
||||
bool convert_input_output_precision) {
|
||||
bool is_output_precision_changed = false;
|
||||
|
||||
auto ops = f->get_ordered_ops();
|
||||
ov::element::TypeVector orig_result_types;
|
||||
if (!convert_input_output_precision) {
|
||||
const auto& results = f->get_results();
|
||||
orig_result_types.reserve(results.size());
|
||||
for (const auto& result : results) {
|
||||
orig_result_types.push_back(result->get_input_element_type(0));
|
||||
}
|
||||
}
|
||||
|
||||
// Iterate over all nodes in topological order and then iterate over node outputs.
|
||||
// If output type mismatch given type we try to fuse type into this operation
|
||||
// otherwise we insert Convert operation.
|
||||
auto ops = f->get_ordered_ops();
|
||||
for (auto& node : ops) {
|
||||
if (skip_precision_sensitive && fp16_compression_is_disabled(node) && has_fp16_compression)
|
||||
continue;
|
||||
is_changed |= convert_node_input_precision(node, precisions, type_to_extend);
|
||||
}
|
||||
|
||||
for (const auto& param : f->get_parameters()) {
|
||||
is_changed |= fuse_type_to_parameter(param, precisions, convert_input_output_precision);
|
||||
}
|
||||
|
||||
if (is_changed)
|
||||
ops = f->get_ordered_ops();
|
||||
|
||||
@ -221,6 +240,7 @@ bool convert_function_precision(
|
||||
has_fp16_compression,
|
||||
skip_precision_sensitive,
|
||||
is_changed || is_output_precision_changed,
|
||||
true,
|
||||
true);
|
||||
}
|
||||
}
|
||||
@ -252,6 +272,37 @@ bool convert_function_precision(
|
||||
}
|
||||
}
|
||||
|
||||
if (is_changed && !convert_input_output_precision) {
|
||||
auto& results = f->get_results();
|
||||
for (size_t i = 0; i < results.size(); i++) {
|
||||
auto& result = results[i];
|
||||
if (result->get_input_element_type(0) != orig_result_types[i]) {
|
||||
auto result_input = result->input_value(0);
|
||||
const auto convert = std::make_shared<ov::op::v0::Convert>(result_input, orig_result_types[i]);
|
||||
if (result_input.get_node()->get_output_size() > 1) {
|
||||
convert->set_friendly_name(result_input.get_node()->get_friendly_name() + "." +
|
||||
std::to_string(result_input.get_index()));
|
||||
} else {
|
||||
convert->set_friendly_name(result_input.get_node()->get_friendly_name());
|
||||
result_input.get_node()->set_friendly_name("");
|
||||
}
|
||||
|
||||
auto& convert_output_tensor = convert->get_output_tensor(0);
|
||||
convert_output_tensor.set_names(result_input.get_names());
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
const auto& legacy_name = ov::descriptor::get_ov_tensor_legacy_name(result_input.get_tensor());
|
||||
if (!legacy_name.empty()) {
|
||||
ov::descriptor::set_ov_tensor_legacy_name(convert_output_tensor, legacy_name);
|
||||
}
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
result_input.set_names({});
|
||||
result->input(0).replace_source_output(convert->output(0));
|
||||
result->revalidate_and_infer_types();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return is_changed;
|
||||
}
|
||||
|
||||
@ -261,7 +312,8 @@ bool convert_precision(ov::pass::PassBase& pass,
|
||||
const type_to_fuse_map& type_to_extend,
|
||||
const precisions_map& precisions,
|
||||
bool has_fp16_compression,
|
||||
bool skip_precision_sensitive = false) {
|
||||
bool skip_precision_sensitive,
|
||||
bool convert_input_output_precision) {
|
||||
// As Constant operations can be shared between multiple nGraph Functions so before
|
||||
// changing precision we need to understand which Constant consumers belongs
|
||||
// to the current nGraph Function
|
||||
@ -274,7 +326,8 @@ bool convert_precision(ov::pass::PassBase& pass,
|
||||
has_fp16_compression,
|
||||
skip_precision_sensitive,
|
||||
false,
|
||||
false);
|
||||
false,
|
||||
convert_input_output_precision);
|
||||
}
|
||||
|
||||
using precisions_set_t = std::unordered_set<ngraph::element::Type_t, EnumClassHash>;
|
||||
@ -324,7 +377,6 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ngraph::Func
|
||||
}
|
||||
|
||||
type_to_fuse_map type_to_fuse{
|
||||
{opset4::Parameter::get_type_info_static(), fuse_type_to_parameter},
|
||||
{opset4::Convert::get_type_info_static(), fuse_type_to_convert},
|
||||
{opset4::ShapeOf::get_type_info_static(), fuse_type_to_shapeof},
|
||||
{opset3::NonMaxSuppression::get_type_info_static(), fuse_type_to_nms3},
|
||||
@ -378,7 +430,8 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ngraph::Func
|
||||
type_to_extend,
|
||||
used_precisions,
|
||||
has_fp16_compression,
|
||||
m_keep_precision_sensitive_in_fp32);
|
||||
m_keep_precision_sensitive_in_fp32,
|
||||
m_convert_input_output_precision);
|
||||
|
||||
// to remove extra converts
|
||||
if (m_keep_precision_sensitive_in_fp32) {
|
||||
@ -470,17 +523,33 @@ bool fuse_type_to_eye_v9(const std::shared_ptr<ngraph::Node>& node, const precis
|
||||
return false;
|
||||
}
|
||||
|
||||
bool fuse_type_to_parameter(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
|
||||
bool fuse_type_to_parameter(const std::shared_ptr<ngraph::Node>& node,
|
||||
const precisions_map& precisions,
|
||||
bool convert_input_precision) {
|
||||
auto it = precisions.find(node->get_output_element_type(0));
|
||||
if (it == precisions.end())
|
||||
return false;
|
||||
bool changed = false;
|
||||
const auto& to = it->second;
|
||||
if (auto param = ov::as_type_ptr<opset4::Parameter>(node)) {
|
||||
if (convert_input_precision) {
|
||||
param->set_element_type(to);
|
||||
param->validate_and_infer_types();
|
||||
return true;
|
||||
changed = true;
|
||||
} else {
|
||||
auto param_consumers = param->output(0).get_target_inputs();
|
||||
auto convert = std::make_shared<opset4::Convert>(param, to);
|
||||
for (auto& input : param_consumers) {
|
||||
const auto consumer = input.get_node();
|
||||
if (ov::is_type<ov::op::v0::Result>(consumer) || ov::is_type<ov::op::v0::Convert>(consumer)) {
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
input.replace_source_output(convert);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
|
||||
|
@ -1840,3 +1840,298 @@ TEST(TransformationTests, ConvertPrecision_disable_for_quantized_nodes_2) {
|
||||
FunctionsComparator::Result result = func_comparator(model_ref, model);
|
||||
ASSERT_TRUE(result.valid) << result.message;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecisionExplicitConvertsForParameterAndResult) {
|
||||
shared_ptr<Model> model, model_ref;
|
||||
pass::Manager manager;
|
||||
{
|
||||
auto param_1 = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto sin = make_shared<opset10::Sin>(param_1);
|
||||
sin->set_friendly_name("sine");
|
||||
sin->get_output_tensor(0).add_names({"sine:0"});
|
||||
auto result_sin = make_shared<opset10::Result>(sin);
|
||||
model = make_shared<Model>(result_sin, ParameterVector{param_1});
|
||||
|
||||
type_to_fuse_map empty_type_to_fuse_map = {};
|
||||
bool keep_precision_sensitive_in_fp32 = false;
|
||||
bool convert_input_output_precision = false;
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f64, element::f32}},
|
||||
empty_type_to_fuse_map,
|
||||
keep_precision_sensitive_in_fp32,
|
||||
convert_input_output_precision);
|
||||
manager.run_passes(model);
|
||||
}
|
||||
|
||||
{
|
||||
auto param_1 = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto converted_param = make_shared<opset10::Convert>(param_1, element::f32);
|
||||
auto sin = make_shared<opset10::Sin>(converted_param);
|
||||
auto converted_sin = make_shared<opset10::Convert>(sin, element::f64);
|
||||
converted_sin->get_output_tensor(0).add_names({"sine:0"});
|
||||
auto result_sin = make_shared<opset10::Result>(converted_sin);
|
||||
model_ref = make_shared<Model>(result_sin, ParameterVector{param_1});
|
||||
}
|
||||
|
||||
const FunctionsComparator func_comparator = FunctionsComparator::with_default();
|
||||
FunctionsComparator::Result result = func_comparator(model_ref, model);
|
||||
ASSERT_TRUE(result.valid) << result.message;
|
||||
|
||||
const auto& results = model->get_results();
|
||||
ASSERT_EQ("sine", results[0]->get_input_node_ptr(0)->get_friendly_name());
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecisionExplicitConvertsMultiParam) {
|
||||
shared_ptr<Model> model, model_ref;
|
||||
pass::Manager manager;
|
||||
{
|
||||
auto param_1 = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto convert_1 = make_shared<opset10::Convert>(param_1, element::f32);
|
||||
|
||||
auto param_2 = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto convert_2 = make_shared<opset10::Convert>(param_2, element::i64);
|
||||
|
||||
auto param_3 = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto param_4 = make_shared<opset10::Parameter>(element::i64, Shape{3});
|
||||
|
||||
auto add = make_shared<opset10::Add>(convert_2, param_4);
|
||||
auto mul = make_shared<opset10::Multiply>(param_1, param_3);
|
||||
auto sin = make_shared<opset10::Sin>(convert_1);
|
||||
|
||||
add->set_friendly_name("add");
|
||||
add->get_output_tensor(0).add_names({"add:0"});
|
||||
mul->set_friendly_name("mul");
|
||||
mul->get_output_tensor(0).add_names({"mul:0"});
|
||||
sin->set_friendly_name("sine");
|
||||
sin->get_output_tensor(0).add_names({"sine:0"});
|
||||
|
||||
auto result_add = make_shared<opset10::Result>(add);
|
||||
auto result_mul = make_shared<opset10::Result>(mul);
|
||||
auto result_sin = make_shared<opset10::Result>(sin);
|
||||
|
||||
model = make_shared<Model>(ResultVector{result_add, result_mul, result_sin},
|
||||
ParameterVector{param_1, param_2, param_3, param_4});
|
||||
|
||||
type_to_fuse_map empty_type_to_fuse_map = {};
|
||||
bool keep_precision_sensitive_in_fp32 = false;
|
||||
bool convert_input_output_precision = false;
|
||||
manager.register_pass<pass::ConvertPrecision>(
|
||||
precisions_map{{element::f64, element::f32}, {element::i64, element::i32}},
|
||||
empty_type_to_fuse_map,
|
||||
keep_precision_sensitive_in_fp32,
|
||||
convert_input_output_precision);
|
||||
manager.run_passes(model);
|
||||
}
|
||||
|
||||
{
|
||||
auto param_1 = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto convert_1 = make_shared<opset10::Convert>(param_1, element::f32);
|
||||
|
||||
auto param_2 = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto convert_2 = make_shared<opset10::Convert>(param_2, element::i32);
|
||||
|
||||
auto param_3 = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto convert_3 = make_shared<opset10::Convert>(param_3, element::f32);
|
||||
auto param_4 = make_shared<opset10::Parameter>(element::i64, Shape{3});
|
||||
auto convert_4 = make_shared<opset10::Convert>(param_4, element::i32);
|
||||
|
||||
auto add = make_shared<opset10::Add>(convert_2, convert_4);
|
||||
auto converted_add = make_shared<opset10::Convert>(add, element::i64);
|
||||
auto convert_1_2 = make_shared<opset10::Convert>(param_1, element::f32);
|
||||
auto mul = make_shared<opset10::Multiply>(convert_1_2, convert_3);
|
||||
auto converted_mul = make_shared<opset10::Convert>(mul, element::f64);
|
||||
auto sin = make_shared<opset10::Sin>(convert_1);
|
||||
|
||||
converted_add->get_output_tensor(0).add_names({"add:0"});
|
||||
converted_mul->get_output_tensor(0).add_names({"mul:0"});
|
||||
sin->get_output_tensor(0).add_names({"sine:0"});
|
||||
|
||||
auto result_add = make_shared<opset10::Result>(converted_add);
|
||||
auto result_mul = make_shared<opset10::Result>(converted_mul);
|
||||
auto result_sin = make_shared<opset10::Result>(sin);
|
||||
|
||||
model_ref = make_shared<Model>(ResultVector{result_add, result_mul, result_sin},
|
||||
ParameterVector{param_1, param_2, param_3, param_4});
|
||||
}
|
||||
|
||||
const FunctionsComparator func_comparator = FunctionsComparator::with_default();
|
||||
FunctionsComparator::Result result = func_comparator(model_ref, model);
|
||||
ASSERT_TRUE(result.valid) << result.message;
|
||||
|
||||
const auto& results = model->get_results();
|
||||
ASSERT_EQ("add", results[0]->get_input_node_ptr(0)->get_friendly_name());
|
||||
ASSERT_EQ("mul", results[1]->get_input_node_ptr(0)->get_friendly_name());
|
||||
ASSERT_EQ("sine", results[2]->get_input_node_ptr(0)->get_friendly_name());
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecisionExplicitConvertsSingleNodeMultipleOutputs) {
|
||||
shared_ptr<Model> model, model_ref;
|
||||
pass::Manager manager;
|
||||
{
|
||||
auto param_1 = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto axis = opset10::Constant::create(element::i32, Shape{}, {0});
|
||||
auto split = make_shared<opset10::Split>(param_1, axis, 3);
|
||||
split->set_friendly_name("split");
|
||||
split->get_output_tensor(0).add_names({"split:0"});
|
||||
split->get_output_tensor(1).add_names({"split:1"});
|
||||
split->get_output_tensor(2).add_names({"split:2"});
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
ov::descriptor::set_ov_tensor_legacy_name(split->get_output_tensor(0), "legacy_split:0");
|
||||
ov::descriptor::set_ov_tensor_legacy_name(split->get_output_tensor(1), "legacy_split:1");
|
||||
ov::descriptor::set_ov_tensor_legacy_name(split->get_output_tensor(2), "legacy_split:2");
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
model = make_shared<Model>(split->outputs(), ParameterVector{param_1});
|
||||
|
||||
type_to_fuse_map empty_type_to_fuse_map = {};
|
||||
bool keep_precision_sensitive_in_fp32 = false;
|
||||
bool convert_input_output_precision = false;
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f64, element::f32}},
|
||||
empty_type_to_fuse_map,
|
||||
keep_precision_sensitive_in_fp32,
|
||||
convert_input_output_precision);
|
||||
manager.run_passes(model);
|
||||
}
|
||||
|
||||
{
|
||||
auto param_1 = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto convert_1 = make_shared<opset10::Convert>(param_1, element::f32);
|
||||
auto axis = opset10::Constant::create(element::i32, Shape{}, {0});
|
||||
auto split = make_shared<opset10::Split>(convert_1, axis, 3);
|
||||
|
||||
auto convert_split_0 = make_shared<opset10::Convert>(split->output(0), element::f64);
|
||||
auto convert_split_1 = make_shared<opset10::Convert>(split->output(1), element::f64);
|
||||
auto convert_split_2 = make_shared<opset10::Convert>(split->output(2), element::f64);
|
||||
convert_split_0->get_output_tensor(0).add_names({"split:0"});
|
||||
convert_split_1->get_output_tensor(0).add_names({"split:1"});
|
||||
convert_split_2->get_output_tensor(0).add_names({"split:2"});
|
||||
model_ref =
|
||||
make_shared<Model>(NodeVector{convert_split_0, convert_split_1, convert_split_2}, ParameterVector{param_1});
|
||||
}
|
||||
|
||||
const FunctionsComparator func_comparator = FunctionsComparator::with_default();
|
||||
FunctionsComparator::Result result = func_comparator(model_ref, model);
|
||||
ASSERT_TRUE(result.valid) << result.message;
|
||||
|
||||
const auto& results = model->get_results();
|
||||
ASSERT_EQ("split.0", results[0]->get_input_node_ptr(0)->get_friendly_name());
|
||||
ASSERT_EQ("split.1", results[1]->get_input_node_ptr(0)->get_friendly_name());
|
||||
ASSERT_EQ("split.2", results[2]->get_input_node_ptr(0)->get_friendly_name());
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
ASSERT_EQ("legacy_split:0", ov::descriptor::get_ov_tensor_legacy_name(results[0]->get_input_tensor(0)));
|
||||
ASSERT_EQ("legacy_split:1", ov::descriptor::get_ov_tensor_legacy_name(results[1]->get_input_tensor(0)));
|
||||
ASSERT_EQ("legacy_split:2", ov::descriptor::get_ov_tensor_legacy_name(results[2]->get_input_tensor(0)));
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecisionExplicitConvertsMultiSubgraphs) {
|
||||
shared_ptr<Model> model, model_ref;
|
||||
pass::Manager manager;
|
||||
{
|
||||
auto cond = make_shared<opset10::Parameter>(element::boolean, Shape{});
|
||||
auto param_1 = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto param_2 = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
|
||||
auto if_op = make_shared<opset10::If>(cond);
|
||||
|
||||
auto param_1_then = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto param_2_then = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto add = make_shared<opset10::Add>(param_1_then, param_2_then);
|
||||
auto result_then = make_shared<opset10::Result>(add);
|
||||
auto then_body = make_shared<Model>(result_then, ParameterVector{param_1_then, param_2_then});
|
||||
|
||||
auto param_1_else = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto param_2_else = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
|
||||
auto trip_count = op::v0::Constant::create(element::i32, Shape{}, {2});
|
||||
auto term_cond = op::v0::Constant::create(element::boolean, Shape{}, {true});
|
||||
auto loop = make_shared<opset10::Loop>(trip_count, term_cond);
|
||||
|
||||
auto param_1_loop = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto param_2_loop = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto mul = make_shared<opset10::Multiply>(param_1_loop, param_2_loop);
|
||||
auto result_mul = make_shared<opset10::Result>(mul);
|
||||
auto result_cond = make_shared<opset10::Result>(term_cond);
|
||||
auto loop_body =
|
||||
make_shared<Model>(ResultVector{result_cond, result_mul}, ParameterVector{param_1_loop, param_2_loop});
|
||||
|
||||
loop->set_function(loop_body);
|
||||
loop->set_special_body_ports({-1, 0});
|
||||
loop->set_merged_input(param_1_loop, param_1_else, result_mul);
|
||||
|
||||
auto result_else = make_shared<opset10::Result>(loop->get_iter_value(result_mul));
|
||||
auto else_body = make_shared<Model>(result_else, ParameterVector{param_1_else, param_2_else});
|
||||
|
||||
if_op->set_then_body(then_body);
|
||||
if_op->set_else_body(else_body);
|
||||
if_op->set_input(param_1, param_1_then, param_1_else);
|
||||
if_op->set_input(param_2, param_2_then, param_2_else);
|
||||
auto result = if_op->set_output(result_then, result_else);
|
||||
|
||||
result.get_node()->set_friendly_name("if_result");
|
||||
result.add_names({"if_result:0"});
|
||||
model = make_shared<Model>(OutputVector{result}, ParameterVector{cond, param_1, param_2});
|
||||
|
||||
type_to_fuse_map empty_type_to_fuse_map = {};
|
||||
bool keep_precision_sensitive_in_fp32 = false;
|
||||
bool convert_input_output_precision = false;
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f64, element::f32}},
|
||||
empty_type_to_fuse_map,
|
||||
keep_precision_sensitive_in_fp32,
|
||||
convert_input_output_precision);
|
||||
manager.run_passes(model);
|
||||
}
|
||||
|
||||
{
|
||||
auto cond = make_shared<opset10::Parameter>(element::boolean, Shape{});
|
||||
auto param_1 = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
auto param_2 = make_shared<opset10::Parameter>(element::f64, Shape{3});
|
||||
|
||||
auto if_op = make_shared<opset10::If>(cond);
|
||||
|
||||
auto param_1_then = make_shared<opset10::Parameter>(element::f32, Shape{3});
|
||||
auto param_2_then = make_shared<opset10::Parameter>(element::f32, Shape{3});
|
||||
auto add = make_shared<opset10::Add>(param_1_then, param_2_then);
|
||||
auto result_then = make_shared<opset10::Result>(add);
|
||||
auto then_body = make_shared<Model>(result_then, ParameterVector{param_1_then, param_2_then});
|
||||
|
||||
auto param_1_else = make_shared<opset10::Parameter>(element::f32, Shape{3});
|
||||
auto param_2_else = make_shared<opset10::Parameter>(element::f32, Shape{3});
|
||||
|
||||
auto trip_count = op::v0::Constant::create(element::i32, Shape{}, {2});
|
||||
auto term_cond = op::v0::Constant::create(element::boolean, Shape{}, {true});
|
||||
auto loop = make_shared<opset10::Loop>(trip_count, term_cond);
|
||||
|
||||
auto param_1_loop = make_shared<opset10::Parameter>(element::f32, Shape{3});
|
||||
auto param_2_loop = make_shared<opset10::Parameter>(element::f32, Shape{3});
|
||||
auto mul = make_shared<opset10::Multiply>(param_1_loop, param_2_loop);
|
||||
auto result_mul = make_shared<opset10::Result>(mul);
|
||||
auto result_cond = make_shared<opset10::Result>(term_cond);
|
||||
auto loop_body =
|
||||
make_shared<Model>(ResultVector{result_cond, result_mul}, ParameterVector{param_1_loop, param_2_loop});
|
||||
|
||||
loop->set_function(loop_body);
|
||||
loop->set_special_body_ports({-1, 0});
|
||||
loop->set_merged_input(param_1_loop, param_1_else, result_mul);
|
||||
|
||||
auto result_else = make_shared<opset10::Result>(loop->get_iter_value(result_mul));
|
||||
auto else_body = make_shared<Model>(result_else, ParameterVector{param_1_else, param_2_else});
|
||||
|
||||
if_op->set_then_body(then_body);
|
||||
if_op->set_else_body(else_body);
|
||||
auto convert_1 = make_shared<opset10::Convert>(param_1, element::f32);
|
||||
auto convert_2 = make_shared<opset10::Convert>(param_2, element::f32);
|
||||
if_op->set_input(convert_1, param_1_then, param_1_else);
|
||||
if_op->set_input(convert_2, param_2_then, param_2_else);
|
||||
auto result = if_op->set_output(result_then, result_else);
|
||||
auto converted_result = make_shared<opset10::Convert>(result, element::f64);
|
||||
converted_result->get_output_tensor(0).add_names({"if_result:0"});
|
||||
|
||||
model_ref = make_shared<Model>(converted_result, ParameterVector{cond, param_1, param_2});
|
||||
}
|
||||
|
||||
const FunctionsComparator func_comparator = FunctionsComparator::with_default();
|
||||
FunctionsComparator::Result result = func_comparator(model_ref, model);
|
||||
ASSERT_TRUE(result.valid) << result.message;
|
||||
|
||||
const auto& results = model->get_results();
|
||||
ASSERT_EQ("if_result", results[0]->get_input_node_ptr(0)->get_friendly_name());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user