ConvertPrecision transformation: handle Assign, ReadValue and Variable (#21266)

* ConvertPrecision transformation: handle Assign, ReadValue and Variable

* revert debug code

* use correct commit for onednn_gpu

* codestyle
Ivan Tikhonov 2023-11-27 18:04:31 +03:30 committed by GitHub
parent eaa3098920
commit 57d794c810
8 changed files with 172 additions and 12 deletions
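
For context (not part of the diff): a minimal sketch of how the updated pass can be registered and run on a small stateful model. The model mirrors the documentation snippet changed below; the header paths and the opset choice are assumptions, not code from this commit.

    #include "openvino/op/util/variable.hpp"
    #include "openvino/opsets/opset11.hpp"
    #include "openvino/pass/manager.hpp"
    #include "transformations/convert_precision.hpp"

    using namespace ov;

    int main() {
        // A small stateful model: Parameter -> Add -> Result, with a ReadValue/Assign
        // pair bound to a single f32 Variable.
        auto variable = std::make_shared<op::util::Variable>(
            op::util::VariableInfo{PartialShape{1, 3}, element::f32, "state"});
        auto arg = std::make_shared<opset11::Parameter>(element::f32, Shape{1, 3});
        auto init_const = opset11::Constant::create(element::f32, Shape{1, 3}, {0, 0, 0});
        auto read = std::make_shared<opset11::ReadValue>(init_const, variable);
        auto add = std::make_shared<opset11::Add>(arg, read);
        auto assign = std::make_shared<opset11::Assign>(add, variable);
        auto result = std::make_shared<opset11::Result>(add);
        auto model = std::make_shared<Model>(ResultVector{result}, SinkVector{assign}, ParameterVector{arg});

        pass::Manager manager;
        // convert_input_output_precision = false: ReadValue/Assign keep their original f32 type
        // and Convert nodes are inserted around them (wrap_into_original_type).
        // convert_input_output_precision = true: the Variable itself is switched to f16
        // (fuse_type_to_variable).
        manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
                                                      type_to_fuse_map{},
                                                      /*keep_precision_sensitive_in_fp32=*/true,
                                                      /*convert_input_output_precision=*/false);
        manager.run_passes(model);
        return 0;
    }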


@@ -26,7 +26,6 @@ int main(int argc, char* argv[]) {
     // Creating ov::Model
     auto read = std::make_shared<ov::opset11::ReadValue>(init_const, variable);
-    std::vector<std::shared_ptr<ov::Node>> args = {arg, read};
     auto add = std::make_shared<ov::opset11::Add>(arg, read);
     auto assign = std::make_shared<ov::opset11::Assign>(add, variable);
     auto add2 = std::make_shared<ov::opset11::Add>(add, read);


@@ -663,9 +663,7 @@ def test_infer_queue_get_idle_handle(device):
 @pytest.mark.parametrize("data_type",
                          [np.float32,
                           np.int32,
-                          # issue after ConvertPrecision transformation, ticket: TBA
-                          # np.float16
-                          ])
+                          np.float16])
 @pytest.mark.parametrize("mode", ["set_init_memory_state", "reset_memory_state", "normal"])
 @pytest.mark.parametrize("input_shape", [[10], [10, 10], [10, 10, 10], [2, 10, 10, 10]])
 @pytest.mark.skipif(


@@ -101,10 +101,13 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
             input_shapes[param.get()] = param->get_partial_shape();
             param->set_partial_shape(PartialShape::dynamic(param->get_partial_shape().rank()));
         }
+        // After setting dynamic ranks on the Parameters, the initializing subgraph of a ReadValue operation
+        // might also get a dynamic rank, so the shape consistency check between this subgraph and the Variable
+        // might fail. We have to set a dynamic rank on the Variables to keep the ov::Model consistent.
         for (const auto& variable : f->get_variables()) {
             const auto& var_info = variable->get_info();
             variable_shapes[variable.get()] = var_info.data_shape;
-            variable->update_partial_shape(PartialShape::dynamic(var_info.data_shape.rank()));
+            variable->update_data_shape(PartialShape::dynamic(var_info.data_shape.rank()));
         }
         f->validate_nodes_and_infer_types();
     }
@@ -271,7 +274,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
         }
         for (const auto& variable : f->get_variables()) {
-            variable->update_partial_shape(variable_shapes.at(variable.get()));
+            variable->update_data_shape(variable_shapes.at(variable.get()));
         }
     }
     f->validate_nodes_and_infer_types();
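
As a hedged illustration of the comment above (not code from this commit), the save/relax/restore pattern can be written with just the Variable API this change touches:

    #include "openvino/op/util/variable.hpp"

    // Remember the original data shape of a Variable, relax it to a fully dynamic rank
    // while the Parameters are dynamic, then restore it afterwards.
    ov::PartialShape relax_variable(const std::shared_ptr<ov::op::util::Variable>& variable) {
        const auto saved_shape = variable->get_info().data_shape;
        variable->update_data_shape(ov::PartialShape::dynamic(saved_shape.rank()));
        return saved_shape;
    }

    void restore_variable(const std::shared_ptr<ov::op::util::Variable>& variable,
                          const ov::PartialShape& saved_shape) {
        variable->update_data_shape(saved_shape);
    }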


@@ -35,6 +35,12 @@ bool fuse_type_to_parameter(const std::shared_ptr<ov::Node>& node,
                             const precisions_map& precisions,
                             bool convert_input_precision);
 
+// This function inserts Convert operations on the 'data' input and the outputs of `node`
+// so that `node` is executed with its original type.
+bool wrap_into_original_type(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
+
+bool fuse_type_to_variable(const std::shared_ptr<op::util::Variable>& variable, const precisions_map& precisions);
+
 bool fuse_type_to_constant(const std::shared_ptr<ov::Node>& node,
                            const precisions_map& precisions,
                            const std::vector<ov::Input<ov::Node>>& consumers);
@@ -207,6 +213,12 @@ bool convert_function_precision(const std::shared_ptr<Model>& f,
         is_changed |= fuse_type_to_parameter(param, precisions, convert_input_output_precision);
     }
 
+    if (convert_input_output_precision) {
+        for (const auto& variable : f->get_variables()) {
+            is_changed |= fuse_type_to_variable(variable, precisions);
+        }
+    }
+
     if (is_changed)
         ops = f->get_ordered_ops();
@@ -245,6 +257,14 @@ bool convert_function_precision(const std::shared_ptr<Model>& f,
                                              true);
            }
        }

+       // If the convert_input_output_precision flag is set, we don't need to preserve the original precision
+       // for Assign/ReadValue ops: the type has already been changed in the Variable.
+       // Otherwise, we have to insert Convert ops on the inputs/outputs of ReadValue/Assign.
+       if ((as_type_ptr<op::util::AssignBase>(node) || as_type_ptr<op::util::ReadValueBase>(node)) &&
+           convert_input_output_precision) {
+           node->revalidate_and_infer_types();
+           continue;
+       }
+
        is_output_precision_changed |= convert_node_output_precision(node,
                                                                     precisions,
                                                                     type_to_fuse,
@@ -380,6 +400,8 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ov::Model>&
     type_to_fuse_map type_to_fuse{
         {opset4::Convert::get_type_info_static(), fuse_type_to_convert},
         {opset4::ShapeOf::get_type_info_static(), fuse_type_to_shapeof},
+        {opset6::Assign::get_type_info_static(), wrap_into_original_type},
+        {opset6::ReadValue::get_type_info_static(), wrap_into_original_type},
         {opset3::NonMaxSuppression::get_type_info_static(), fuse_type_to_nms3},
         {opset4::NonMaxSuppression::get_type_info_static(), fuse_type_to_nms4},
         {opset5::NonMaxSuppression::get_type_info_static(), fuse_type_to_nms5},
@@ -557,6 +579,38 @@ bool fuse_type_to_parameter(const std::shared_ptr<ov::Node>& node,
     return changed;
 }
 
+bool wrap_into_original_type(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions) {
+    auto it = precisions.find(node->get_output_element_type(0));
+    if (it == precisions.end())
+        return false;
+
+    const auto& to = it->second;
+    const auto& from = it->first;
+
+    auto convert_before = std::make_shared<opset4::Convert>(node->input_value(0), from);
+    node->input(0).replace_source_output(convert_before);
+
+    auto consumers = node->output(0).get_target_inputs();
+    auto convert_after = std::make_shared<opset4::Convert>(node, to);
+    for (auto& input : consumers) {
+        const auto consumer = input.get_node();
+        if (ov::is_type<ov::op::v0::Result>(consumer) || ov::is_type<ov::op::v0::Convert>(consumer)) {
+            continue;
+        }
+        input.replace_source_output(convert_after);
+    }
+
+    return true;
+}
+
+bool fuse_type_to_variable(const std::shared_ptr<op::util::Variable>& variable, const precisions_map& precisions) {
+    auto it = precisions.find(variable->get_info().data_type);
+    if (it == precisions.end())
+        return false;
+    const auto& to = it->second;
+    variable->update_data_type(to);
+    return true;
+}
+
 bool fuse_type_to_convert(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions) {
     auto it = precisions.find(node->get_output_element_type(0));
     if (it == precisions.end())


@@ -2340,3 +2340,106 @@ TEST(TransformationTests, align_mixed_fp16_fp32_with_parameter_for_shape_2) {
     FunctionsComparator::Result result = func_comparator(model_ref, model);
     ASSERT_TRUE(result.valid) << result.message;
 }
+
+TEST(TransformationTests, ConvertPrecision_assign_read_value_preserve_orig_types) {
+    shared_ptr<Model> model, model_ref;
+    pass::Manager manager;
+    {
+        auto variable = std::make_shared<ov::op::util::Variable>(
+            ov::op::util::VariableInfo{ov::PartialShape{10, 10}, ov::element::f32, "variable_name"});
+        auto input = make_shared<opset10::Parameter>(element::f32, Shape{10, 10});
+        auto read_value = make_shared<opset10::ReadValue>(input, variable);
+        auto some_value = opset10::Constant::create(element::f32, Shape{1}, {2});
+        auto mul = make_shared<opset10::Multiply>(read_value, some_value);
+        auto res = make_shared<opset10::Result>(mul);
+        auto assign = make_shared<opset10::Assign>(mul, variable);
+        model = make_shared<Model>(ResultVector{res}, SinkVector{assign}, ParameterVector{input});
+
+        type_to_fuse_map empty_type_to_fuse_map = {};
+        bool keep_precision_sensitive_in_fp32 = true;
+        bool convert_input_output_precision = false;
+        manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
+                                                      empty_type_to_fuse_map,
+                                                      keep_precision_sensitive_in_fp32,
+                                                      convert_input_output_precision);
+        manager.run_passes(model);
+    }
+
+    {
+        auto variable = std::make_shared<ov::op::util::Variable>(
+            ov::op::util::VariableInfo{ov::PartialShape{10, 10}, ov::element::f32, "variable_name"});
+        auto input = make_shared<opset10::Parameter>(element::f32, Shape{10, 10});
+        auto convert_1 = make_shared<opset10::Convert>(input, element::f16);
+        auto convert_2 = make_shared<opset10::Convert>(convert_1, element::f32);
+        auto read_value = make_shared<opset10::ReadValue>(convert_2, variable);
+        auto convert_3 = make_shared<opset10::Convert>(read_value, element::f16);
+        auto some_value = opset10::Constant::create(element::f16, Shape{1}, {2});
+        auto mul = make_shared<opset10::Multiply>(convert_3, some_value);
+        auto convert_4 = make_shared<opset10::Convert>(mul, element::f32);
+        auto res = make_shared<opset10::Result>(convert_4);
+        auto convert_5 = make_shared<opset10::Convert>(mul, element::f32);
+        auto assign = make_shared<opset10::Assign>(convert_5, variable);
+        model_ref = make_shared<Model>(ResultVector{res}, SinkVector{assign}, ParameterVector{input});
+    }
+
+    const FunctionsComparator func_comparator = FunctionsComparator::with_default();
+    FunctionsComparator::Result result = func_comparator(model_ref, model);
+    ASSERT_TRUE(result.valid) << result.message;
+}
+
+TEST(TransformationTests, ConvertPrecision_assign_read_value_change_variable_type) {
+    shared_ptr<Model> model, model_ref;
+    pass::Manager manager;
+    {
+        auto variable = std::make_shared<ov::op::util::Variable>(
+            ov::op::util::VariableInfo{ov::PartialShape{10, 10}, ov::element::f32, "variable_name"});
+        auto input = make_shared<opset10::Parameter>(element::f32, Shape{10, 10});
+        auto read_value = make_shared<opset10::ReadValue>(input, variable);
+        auto some_value = opset10::Constant::create(element::f32, Shape{1}, {2});
+        auto mul = make_shared<opset10::Multiply>(read_value, some_value);
+        auto res = make_shared<opset10::Result>(mul);
+        auto assign = make_shared<opset10::Assign>(mul, variable);
+        model = make_shared<Model>(ResultVector{res}, SinkVector{assign}, ParameterVector{input});
+
+        type_to_fuse_map empty_type_to_fuse_map = {};
+        bool keep_precision_sensitive_in_fp32 = true;
+        bool convert_input_output_precision = true;
+        manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
+                                                      empty_type_to_fuse_map,
+                                                      keep_precision_sensitive_in_fp32,
+                                                      convert_input_output_precision);
+        manager.run_passes(model);
+    }
+
+    {
+        auto variable = std::make_shared<ov::op::util::Variable>(
+            ov::op::util::VariableInfo{ov::PartialShape{10, 10}, ov::element::f16, "variable_name"});
+        auto input = make_shared<opset10::Parameter>(element::f16, Shape{10, 10});
+        auto read_value = make_shared<opset10::ReadValue>(input, variable);
+        auto some_value = opset10::Constant::create(element::f16, Shape{1}, {2});
+        auto mul = make_shared<opset10::Multiply>(read_value, some_value);
+        auto res = make_shared<opset10::Result>(mul);
+        auto assign = make_shared<opset10::Assign>(mul, variable);
+        model_ref = make_shared<Model>(ResultVector{res}, SinkVector{assign}, ParameterVector{input});
+    }
+
+    const FunctionsComparator func_comparator = FunctionsComparator::with_default();
+    FunctionsComparator::Result result = func_comparator(model_ref, model);
+    ASSERT_TRUE(result.valid) << result.message;
+}


@@ -38,10 +38,14 @@ public:
         m_info = variable_info;
     }
 
-    void update_partial_shape(const PartialShape& new_pshape) {
+    void update_data_shape(const PartialShape& new_pshape) {
         m_info.data_shape = new_pshape;
     }
 
+    void update_data_type(const element::Type& new_type) {
+        m_info.data_type = new_type;
+    }
+
 private:
     VariableInfo m_info;
 };
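
A brief usage sketch (assuming the class above is ov::op::util::Variable from variable.hpp) of the new update_data_type setter next to the renamed update_data_shape:

    #include <memory>

    #include "openvino/op/util/variable.hpp"

    int main() {
        // A Variable originally declared as f32 with a static shape...
        auto variable = std::make_shared<ov::op::util::Variable>(
            ov::op::util::VariableInfo{ov::PartialShape{10, 10}, ov::element::f32, "variable_name"});

        // ...can now have its element type retargeted; this is what fuse_type_to_variable
        // does when convert_input_output_precision is set.
        variable->update_data_type(ov::element::f16);

        // The shape setter (formerly update_partial_shape) is used the same way.
        variable->update_data_shape(ov::PartialShape::dynamic(2));
        return 0;
    }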


@@ -10,8 +10,6 @@
 #include "openvino/core/model.hpp"
 #include "openvino/op/read_value.hpp"
 #include "openvino/op/util/variable.hpp"
-#include "openvino/pass/manager.hpp"
-#include "openvino/pass/serialize.hpp"
 
 using namespace std;
 using namespace ov;


@@ -1,12 +1,13 @@
 # Copyright (C) 2018-2023 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
 import numpy as np
 
+from openvino.runtime import PartialShape
+from openvino.tools.mo.ops.op import Op
+from openvino.tools.mo.middle.passes.convert_data_type import np_data_type_to_destination_type
 from openvino.tools.mo.front.common.partial_infer.utils import unmask_shape
 from openvino.tools.mo.graph.graph import Graph, Node
-from openvino.runtime import PartialShape
-from openvino.tools.mo.ops.op import Op
 
 
 class ReadValue(Op):
     op = 'ReadValue'