Decompose/flatten tuple inputs (#18092)

* Added a transformation in the PyTorch frontend (PT FE) that removes prim::TupleUnpack and prim::ListUnpack to flatten input lists and tuples

* Enabled tuples and lists as items in example_inputs

* Applied code style

* Added tests for tuples as inputs and extended test infrastructure to support it

* Negligible performance optimizations

Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fixed duplicated names of test classes

* Added description for tuple flattening transformation

* Removed any support for list flattening on inputs; fixed layer tests

* Fixed style

* Fixed order of new Parameters and Results while flattening tuples

* Fixed style

* Better diagnostics when not all prim::TupleUnpack ops after Parameters are decomposed

* Small fix in diagnostics message

---------

Co-authored-by: Ekaterina Aidova <ekaterina.aidova@intel.com>
Co-authored-by: Andrei Kochin <andrei.kochin@intel.com>
Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
Co-authored-by: Alina Kladieva <alina.kladieva@intel.com>
Sergey Lyalin 2023-07-06 11:05:26 +04:00 committed by GitHub
parent 553dab43b4
commit dfba702c74
7 changed files with 328 additions and 30 deletions
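
For context, a minimal sketch of how the flattened tuple inputs are intended to be used from Python after this change. The module, shapes, and device are illustrative only; convert_model with example_input is assumed to be the openvino.tools.mo API available in the 2023.x releases:

# Hedged sketch: illustrative model; convert_model/example_input assumed from openvino.tools.mo.
from typing import Tuple
import numpy as np
import torch
from openvino.runtime import Core
from openvino.tools.mo import convert_model

class ModelWithTupleInput(torch.nn.Module):
    def forward(self, a: torch.Tensor, xy: Tuple[torch.Tensor, torch.Tensor]):
        x, y = xy  # traced as prim::TupleUnpack
        return a + x * y

a = np.ones((2, 3), dtype=np.float32)
x = np.full((2, 3), 2.0, dtype=np.float32)
y = np.full((2, 3), 3.0, dtype=np.float32)

# Tuples are now accepted as items in example_input; the tuple parameter is
# decomposed into separate flat inputs by the DecomposeTupleParameters pass.
ov_model = convert_model(ModelWithTupleInput(),
                         example_input=(torch.from_numpy(a), (torch.from_numpy(x), torch.from_numpy(y))))
compiled = Core().compile_model(ov_model, "CPU")
print(len(compiled.inputs))    # expected 3 flattened inputs: a, x, y
print(compiled([a, x, y])[0])  # inference takes the flattened inputs in the original order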

View File

@@ -34,6 +34,7 @@
#include "transforms/prim_list_construct_pad.hpp"
#include "transforms/prim_list_tuple_construct_replacer.hpp"
#include "transforms/prim_list_unpack_replacer.hpp"
#include "transforms/prim_tuple_unpack_parameter_replacer.hpp"
#include "transforms/rfftn_complex_replacer.hpp"
#include "transforms/string_equality_replacer.hpp"
#include "transforms/tuple_unpack_replacer.hpp"
@@ -174,6 +175,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::AtenEinsumListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::MinMaxPrimListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::StringEqualityReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeTupleParameters>();
manager.register_pass<ov::frontend::pytorch::pass::RFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::IRFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimTupleUnpackReplacer>();

View File

@@ -3,7 +3,7 @@
//
#include "prim_list_tuple_construct_replacer.hpp"
#include <queue>
#include <deque>
#include "openvino/frontend/pytorch/decoder.hpp"
#include "openvino/op/result.hpp"
@@ -17,21 +17,24 @@ namespace pass {
bool DecomposeListTupleResults::run_on_model(const std::shared_ptr<Model>& model) {
bool at_least_one_decomposed = false;
std::queue<std::shared_ptr<ov::op::v0::Result>> results;
for (auto res : model->get_results()) {
results.push(res);
}
const auto& orig_results = model->get_results();
std::deque<std::shared_ptr<ov::op::v0::Result>> results(orig_results.begin(), orig_results.end());
ov::ResultVector updated_results; // will hold final fully unpacked results list
while (!results.empty()) {
auto result = results.front();
results.pop();
results.pop_front();
auto input_node = result->get_input_node_shared_ptr(0);
auto tuple_construct = cast_fw_node(input_node, "prim::TupleConstruct");
auto list_construct = cast_fw_node(input_node, "prim::ListConstruct");
if (!tuple_construct && !list_construct) {
updated_results.push_back(result);
continue;
}
for (const auto& input : input_node->inputs()) {
const auto& out = input.get_source_output();
const auto& inputs = input_node->inputs();
// enumerating inputs in reverse order because of results.push_front below
for (auto pinput = inputs.rbegin(); pinput != inputs.rend(); ++pinput) {
const auto& out = pinput->get_source_output();
if (const auto& fw_node = cast_fw_node(out.get_node_shared_ptr(), "prim::Constant")) {
const auto& attrs = fw_node->get_attrs();
if (attrs.find("none_value") != attrs.end()) {
@@ -42,13 +45,19 @@ bool DecomposeListTupleResults::run_on_model(const std::shared_ptr<Model>& model
}
}
auto new_result = std::make_shared<ov::op::v0::Result>(out);
model->add_results({new_result});
results.push(new_result);
model->remove_result(result);
results.push_front(new_result);
at_least_one_decomposed = true;
}
}
if (at_least_one_decomposed) {
// remove all results
while (!model->get_results().empty())
model->remove_result(model->get_results()[0]);
// and replace them all with the updated list of results
model->add_results(updated_results);
}
return at_least_one_decomposed;
};
} // namespace pass
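
The switch from std::queue/push to std::deque/push_front, together with iterating the inputs in reverse, is what keeps the decomposed Results in the original left-to-right order (the "Fixed order of new Parameters and Results" item above). A simplified Python analogue of the worklist, with plain tuples standing in for prim::TupleConstruct nodes:

from collections import deque

def flatten_preserving_order(outputs):
    # Depth-first, left-to-right flattening: children are pushed to the FRONT
    # in REVERSE order, so the leftmost child is processed next and the
    # relative order of the leaves is preserved.
    work = deque(outputs)
    flat = []
    while work:
        item = work.popleft()
        if isinstance(item, tuple):  # stands in for prim::TupleConstruct / prim::ListConstruct
            for child in reversed(item):
                work.appendleft(child)
            continue
        flat.append(item)
    return flat

print(flatten_preserving_order(["a", ("b", ("c", "d")), "e"]))
# ['a', 'b', 'c', 'd', 'e'] -- same order as in the original nested result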

View File

@@ -0,0 +1,122 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "prim_tuple_unpack_parameter_replacer.hpp"
#include <deque>
#include <sstream>
#include "openvino/frontend/pytorch/decoder.hpp"
#include "openvino/op/result.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
bool DecomposeTupleParameters::run_on_model(const std::shared_ptr<Model>& model) {
bool at_least_one_decomposed = false;
const auto& orig_parameters = model->get_parameters();
std::deque<std::shared_ptr<ov::op::v0::Parameter>> parameters(orig_parameters.begin(), orig_parameters.end());
ov::ParameterVector updated_parameters; // will hold final fully unpacked parameters list
while (!parameters.empty()) {
auto parameter = parameters.front();
parameters.pop_front();
auto consumers = parameter->get_output_target_inputs(0);
size_t num_outputs = 0; // the number of outputs must match across all unpack consumers
bool all_unpacks = true;
// collect all outputs of each consumer operation for this tuple Parameter
std::vector<OutputVector> consumer_outputs;
// The following vector tracks consumer nodes of prim::TupleUnpack type to form a detailed
// error message in case parameter replacement is required but not possible.
std::vector<std::shared_ptr<Node>> consumer_unpacks;
for (const auto& consumer : consumers) {
auto node = consumer.get_node()->shared_from_this();
auto tuple_unpack = cast_fw_node(node, "prim::TupleUnpack");
if (!tuple_unpack) {
all_unpacks = false;
continue; // need to look at all consumers to form good diagnostics
}
consumer_unpacks.push_back(node);
if (num_outputs == 0) {
num_outputs = node->get_output_size();
} else if (num_outputs != node->get_output_size()) {
std::stringstream message;
message << "Unpack node " << node
<< " as one of the consumers of a tuple, which is introduced by parameter "
<< parameter->output(0) << ", has number of outputs " << node->get_output_size()
<< " not matching number of outputs " << num_outputs << " for other consumer(s) found earlier.";
add_exception_to_fw_node(node, message.str());
all_unpacks = false;
break;
}
consumer_outputs.push_back(node->outputs());
}
if (!all_unpacks || consumer_outputs.empty()) {
// if at least one consumer is not an unpack-like op, or the numbers of unpacked objects do not match,
// we cannot replace any of the unpacks even if they exist, leaving the Unpack op(s) in the graph for this Parameter
updated_parameters.push_back(parameter);
// If at least one Unpack exists, there is an opportunity to attach diagnostics
for (const auto& consumer : consumer_unpacks) {
std::stringstream message;
message << "Not prim::TupleUnpack operations exist except this one: " << consumer
<< " found as one of the consumers of a tuple, which is introduced by parameter "
<< parameter->output(0) << ".";
add_exception_to_fw_node(consumer, message.str());
}
continue;
}
// enumerating outputs in reverse order because of parameters.push_front below
for (size_t i = num_outputs; i--;) {
// Partial shape and element type merged across all consumers of the i-th output of the unpack ops
PartialShape ps = PartialShape::dynamic();
element::Type et = element::dynamic;
std::set<Input<Node>> inputs;
for (const auto& outputs : consumer_outputs) {
auto output = outputs[i];
OPENVINO_ASSERT(PartialShape::merge_into(ps, output.get_partial_shape()),
"Consumers for unpack op have incompatible shape");
OPENVINO_ASSERT(element::Type::merge(et, et, output.get_element_type()),
"Consumers for unpack op have incompatible types");
auto target_inputs = output.get_target_inputs();
inputs.insert(target_inputs.begin(), target_inputs.end());
}
auto new_parameter = std::make_shared<ov::op::v0::Parameter>(et, ps);
for (auto input : inputs) {
auto names = input.get_tensor().get_names();
input.replace_source_output(new_parameter->output(0));
new_parameter->output(0).add_names(names);
}
// TODO: Assign correct names
parameters.push_front(new_parameter);
at_least_one_decomposed = true;
}
}
if (at_least_one_decomposed) {
// remove all parameters
while (!model->get_parameters().empty())
model->remove_parameter(model->get_parameters()[0]);
// and replace them with the updated list of parameters
model->add_parameters(updated_parameters);
}
return at_least_one_decomposed;
};
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov
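
When a tuple Parameter feeds several prim::TupleUnpack consumers, the i-th outputs of all those consumers must be compatible: the pass merges their partial shapes and element types into the new Parameter. A simplified Python analogue of the shape merge, using None for a dynamic dimension (illustrative only, not the actual PartialShape API):

def merge_dims(a, b):
    # None is a dynamic dimension and is refined by any static one
    if a is None:
        return b
    if b is None or a == b:
        return a
    raise ValueError(f"incompatible dimensions {a} and {b}")

def merge_shapes(shapes):
    # Start fully dynamic (like PartialShape::dynamic()) and refine with each
    # consumer shape; the result is the most specific shape compatible with all.
    merged = None
    for shape in shapes:
        if merged is None:
            merged = list(shape)
        elif len(merged) != len(shape):
            raise ValueError("incompatible ranks")
        else:
            merged = [merge_dims(x, y) for x, y in zip(merged, shape)]
    return merged

# Two unpack consumers agree on rank and refine each other's dynamic dims:
print(merge_shapes([(1, None, 10), (None, 2, 10)]))  # [1, 2, 10]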

View File

@@ -0,0 +1,37 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
// This transformation replaces all prim::TupleUnpack operations coming after Parameters with
// more Parameters -- one new parameter for each prim::TupleUnpack output. The original Parameter
// is replaced with these new Parameters preserving the order relative to other Parameters in a model.
// Order of new parameters is the same as the order of prim::TupleUnpack outputs.
// If prim::TupleUnpack has a consumer that is also prim::TupleUnpack, the transformation applies
// the replacement recursively until all prim::TupleUnpacks that take a Parameter output are eliminated.
//
// For example, if a model has the following signature: a, (b, (c, d)), e, where a, b, c, d, and e are
// tensors and (x1, x2) denotes a tuple consisting of two elements x1 and x2, then after the
// transformation the resulting model will have a, b, c, d, e as inputs (flattened, without tuples).
// Note that there is no special 'tuple' input type; the tuple structure is recovered only by
// following prim::TupleUnpack operations in the graph, assuming they can be applied to tuples only
// and that the innermost objects in those tuples are tensors.
class DecomposeTupleParameters : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::DecomposeTupleParameters");
bool run_on_model(const std::shared_ptr<Model>& model) override;
};
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@@ -50,8 +50,15 @@ class PytorchLayerTest:
else:
inputs = self._prepare_input()
torch_inputs = [torch.from_numpy(inp) if isinstance(
inp, np.ndarray) else inp for inp in inputs]
def numpy_to_torch_recursively(x):
if isinstance(x, tuple):
return tuple(numpy_to_torch_recursively(y) for y in x)
elif isinstance(x, np.ndarray):
return torch.from_numpy(x)
else:
return x
torch_inputs = [numpy_to_torch_recursively(inp) for inp in inputs]
if 'custom_eps' in kwargs and kwargs['custom_eps'] is not None:
custom_eps = kwargs['custom_eps']
@@ -61,6 +68,8 @@ class PytorchLayerTest:
def use_ts_backend():
return(os.environ.get('USE_TS_BACKEND', False))
ov_inputs = flattenize_inputs(inputs)
if use_ts_backend():
self.ts_backend_test(model, torch_inputs, custom_eps)
else:
@@ -68,7 +77,7 @@
model.eval()
trace_model = kwargs.get('trace_model', False)
freeze_model = kwargs.get('freeze_model', True)
model, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, inputs, freeze_model)
model, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
graph = model.inlined_graph
if kind is not None and not isinstance(kind, (tuple, list)):
@@ -80,7 +89,7 @@
# OV infer:
core = Core()
compiled = core.compile_model(converted_model, ie_device)
infer_res = compiled(deepcopy(inputs))
infer_res = compiled(deepcopy(ov_inputs))
if hasattr(self, 'skip_framework') and self.skip_framework:
warnings.warn('Framework is skipped')
@@ -266,25 +275,33 @@ def get_params(ie_device=None, precision=None):
return test_args
def flattenize_dict_outputs(res):
def flattenize_dict_outputs(res, types):
if isinstance(res, dict):
return flattenize_outputs(res.values())
return flattenize(res.values(), types)
def flattenize_outputs(res):
def flattenize(res, types: list):
results = []
for res_item in res:
# skip None values in the output
if res_item is None:
continue
# flatten lists and tuples whose type is listed in `types`
if isinstance(res_item, (list, tuple)):
decomposed_res = flattenize_outputs(res_item)
if isinstance(res_item, (list, tuple)) and type(res_item) in types:
decomposed_res = flattenize(res_item, types)
results.extend(decomposed_res)
continue
if isinstance(res_item, dict):
decomposed_res = flattenize_dict_outputs(res_item)
if isinstance(res_item, dict) and type(res_item) in types:
decomposed_res = flattenize_dict_outputs(res_item, types)
results.extend(decomposed_res)
continue
results.append(res_item)
return results
def flattenize_outputs(res):
return flattenize(res, [list, tuple, dict])
def flattenize_inputs(res):
return flattenize(res, [tuple])
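
A short usage note on the refactored helpers (values are hypothetical): flattenize_outputs still unpacks lists, tuples and dicts, while flattenize_inputs only unpacks tuples, so list inputs reach the frontend untouched, which matches the removal of list flattening on inputs.

nested = (1, [2, 3], (4, (5, 6)))
print(flattenize_outputs(nested))  # [1, 2, 3, 4, 5, 6] -- lists and tuples are both unpacked
print(flattenize_inputs(nested))   # [1, [2, 3], 4, 5, 6] -- only tuples are unpacked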

View File

@@ -33,6 +33,11 @@ class TestTupleConstruct(PytorchLayerTest):
def forward(self, x):
return (x, [None, x + x], None)
class prim_tuple_construct_with_tensor_tail(torch.nn.Module):
def forward(self, x):
return ((x, x + x), x + x + x)
class prim_tuple_construct_with_list_and_tuple(torch.nn.Module):
def forward(self, x):
@@ -43,6 +48,7 @@ class TestTupleConstruct(PytorchLayerTest):
"multiple": prim_tuple_construct,
"none": prim_tuple_construct_with_none,
"list": prim_tuple_construct_with_list,
"tensor_tail": prim_tuple_construct_with_tensor_tail,
"list_and_tuple": prim_tuple_construct_with_list_and_tuple
}
@@ -51,11 +57,11 @@
return model(), ref_net, "prim::TupleConstruct"
@pytest.mark.parametrize("case", ["single", "multiple", "none", "list", "list_and_tuple"])
@pytest.mark.parametrize("case", ["single", "multiple", "none", "list", "tensor_tail", "list_and_tuple"])
@pytest.mark.nightly
def test_tuple_construct(self, case, ie_device, precision, ir_version):
self._test(*self.create_model(case), ie_device, precision, ir_version)
class TestTupleConstructTupleUnpack(PytorchLayerTest):
def _prepare_input(self):
@@ -69,7 +75,7 @@ class TestTupleConstructTupleUnpack(PytorchLayerTest):
def forward(self, x):
x1, x2, x3, x4, x5 = self.prepare_input(x)
return x1, x2, x3, x4, x5
def prepare_input(self, x):
return x, x + 2, None, x.reshape(-1), (x * 10).to(torch.int32)
@@ -80,4 +86,106 @@
@pytest.mark.nightly
def test_tuple_construct_unpack(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False)
self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False)
class TestTupleUnpackParameterSingle(PytorchLayerTest):
def _prepare_input(self):
def tensor_gen():
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
return ( (tensor_gen(), tensor_gen()), )
def create_model(self):
import torch
from typing import Tuple
class model(torch.nn.Module):
def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
x1, x2 = x
return x1, x2
return model(), None, ["prim::TupleUnpack"]
@pytest.mark.nightly
def test(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version)
class TestTupleUnpackParameterSingleMixed(PytorchLayerTest):
def _prepare_input(self):
def tensor_gen():
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
# generate tensor with a different shape for easier mismatch detection in case of mixed input order
def tensor_gen_2():
return np.random.uniform(0, 50, (2, 3)).astype(np.float32)
return (tensor_gen_2(), (tensor_gen(), tensor_gen()), tensor_gen_2())
def create_model(self):
import torch
from typing import Tuple
class model(torch.nn.Module):
def forward(self, y1, x: Tuple[torch.Tensor, torch.Tensor], y2):
x1, x2 = x
return x1, x2, y1, y2
return model(), None, ["prim::TupleUnpack"]
@pytest.mark.nightly
def test(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version)
class TestTupleUnpackParameterNested(PytorchLayerTest):
def _prepare_input(self):
def tensor_gen():
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
return ( ((tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen())), )
def create_model(self):
import torch
from typing import Tuple
class model(torch.nn.Module):
def forward(self, x: Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]):
x1, x2 = x
y1, y2 = x1
y3, y4 = x2
return y1, y2, y3, y4
return model(), None, ["prim::TupleUnpack"]
@pytest.mark.nightly
def test(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version)
class TestTupleUnpackParameterMultiple(PytorchLayerTest):
def _prepare_input(self):
def tensor_gen():
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
return ( (tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen()) )
def create_model(self):
import torch
from typing import Tuple
class model(torch.nn.Module):
def forward(self, x: Tuple[torch.Tensor, torch.Tensor], y: Tuple[torch.Tensor, torch.Tensor]):
z1, z2 = x
z3, z4 = y
return z1, z2, z3, z4
return model(), None, ["prim::TupleUnpack"]
@pytest.mark.nightly
def test(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version)

View File

@@ -43,7 +43,7 @@ def update_list_or_dict(container, name, idx, value):
container[idx] = value
return
def get_value_from_list_or_dict(container, name, idx):
if isinstance(container, dict):
if name is None:
@@ -87,8 +87,8 @@ def extract_input_info_from_example(args, inputs):
example_dtype = pt_to_ov_type_map.get(str(dtype))
user_dtype = get_value_from_list_or_dict(data_types, input_name, input_id)
if user_dtype is not None and example_dtype.to_dtype() != user_dtype:
raise Error(f"Defined input type {user_dtype} is not equal to provided example_input type {example_dtype.to_dtype()}")
raise Error(f"Defined input type {user_dtype} is not equal to provided example_input type {example_dtype.to_dtype()}")
data_rank = getattr(example_input, "ndim", 0)
user_input_shape = get_value_from_list_or_dict(input_shapes, input_name, input_id)
if user_input_shape.rank.get_length() != data_rank:
@@ -108,7 +108,7 @@ def extract_input_info_from_example(args, inputs):
input_name = input_names[input_id] if input_names else None
update_list_or_dict(input_shapes, input_name, input_id, input_shape)
update_list_or_dict(data_types, input_name, input_id, ov_dtype.to_dtype())
args.placeholder_data_types = data_types
args.placeholder_shapes = input_shapes
if not args.input and input_names:
@@ -126,6 +126,9 @@ def to_torch_tensor(tensor):
return torch.tensor(tensor.data)
if isinstance(tensor, (float, int, bool)):
return tensor
if isinstance(tensor, tuple):
# TODO: to_torch_tensor should be renamed, since it handles more than just tensors
return tuple(to_torch_tensor(x) for x in tensor)
else:
raise Error("Unexpected type of example_input. Supported types torch.Tensor, np.array or ov.Tensor. "
"Got {}".format(type(tensor)))