[PT FE] Support set/get element type, shape and value in PyTorch FE InputModel (#16100)

* Support setting and getting element type, shape and value in PyTorch FE InputModel

* Fix code style

* Fix code style

* Fix rsub layer test

* Fix py style

* Apply review feedback

* Fix code style

* Fix initial values of input and output flags in Place
This commit is contained in:
Maxim Vafin
2023-03-07 15:45:29 +01:00
committed by GitHub
parent e6ad0a5154
commit 82584543ba
8 changed files with 388 additions and 40 deletions

View File

@@ -0,0 +1,136 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "input_model.hpp"
#include "place.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
// Builds the model representation: registers a Place for every graph input and
// output under all of its names, and records an editable descriptor (element
// type + partial shape) per tensor index.
InputModel::InputModel(std::shared_ptr<TorchDecoder> model_decoder) : m_model_decoder(model_decoder) {
    // Interpret a decoder-reported type: use it when it resolves to a concrete
    // element type, otherwise fall back to a fully dynamic type.
    // (Factored out: the input and output loops previously duplicated this.)
    const auto as_element_type = [](const auto& type_any) -> element::Type {
        if (type_any.is<element::Type>()) {
            return type_any.as<element::Type>();
        }
        return element::dynamic;
    };
    const auto& inputs = m_model_decoder->inputs();
    for (size_t i = 0; i < inputs.size(); ++i) {
        auto in_place = std::make_shared<pytorch::Place>(*this, inputs[i]);
        // A place is addressable by each of its names (tensor index and, when
        // present, the decoder debug name).
        for (const auto& name : in_place->get_names()) {
            m_name_to_place.emplace(name, std::dynamic_pointer_cast<frontend::Place>(in_place));
        }
        const auto dtype = as_element_type(simplified_type_interpret(m_model_decoder->get_input_type(i)));
        m_descriptors.emplace(inputs[i], PlaceDesc(dtype, m_model_decoder->get_input_shape(i)));
    }
    const auto& outputs = m_model_decoder->outputs();
    for (size_t i = 0; i < outputs.size(); ++i) {
        auto out_place = std::make_shared<pytorch::Place>(*this, outputs[i]);
        for (const auto& name : out_place->get_names()) {
            m_name_to_place.emplace(name, std::dynamic_pointer_cast<frontend::Place>(out_place));
        }
        const auto dtype = as_element_type(simplified_type_interpret(m_model_decoder->get_output_type(i)));
        m_descriptors.emplace(outputs[i], PlaceDesc(dtype, m_model_decoder->get_output_shape(i)));
    }
}
std::vector<ov::frontend::Place::Ptr> InputModel::get_inputs() const {
std::vector<ov::frontend::Place::Ptr> res;
for (const auto& input_idx : m_model_decoder->inputs()) {
auto place_it = m_name_to_place.find(std::to_string(input_idx));
FRONT_END_GENERAL_CHECK(place_it != m_name_to_place.end(), "Couldn't find Place for input.");
res.push_back(place_it->second);
}
return res;
}
std::vector<ov::frontend::Place::Ptr> InputModel::get_outputs() const {
std::vector<ov::frontend::Place::Ptr> res;
for (const auto& output_idx : m_model_decoder->outputs()) {
auto place_it = m_name_to_place.find(std::to_string(output_idx));
FRONT_END_GENERAL_CHECK(place_it != m_name_to_place.end(), "Couldn't find Place for output.");
res.push_back(place_it->second);
}
return res;
}
// Looks up a Place by one of its registered names (stringified tensor index or
// decoder debug name). Returns nullptr when the name is unknown, as the
// frontend API contract requires for a failed lookup.
Place::Ptr InputModel::get_place_by_tensor_name(const std::string& tensor_name) const {
    const auto found = m_name_to_place.find(tensor_name);
    return found == m_name_to_place.end() ? nullptr : found->second;
}
// Overrides the partial shape recorded for an input place. The new shape is
// picked up during conversion when the corresponding Parameter is created.
// \throws if the place is null, not an input, not a pytorch::Place, or unknown.
void InputModel::set_partial_shape(const Place::Ptr& place, const ov::PartialShape& shape) {
    FRONT_END_GENERAL_CHECK(place && place->is_input(),
                            "Provided place is invalid, only inputs are supported for setting shape.");
    auto pytorch_place = std::dynamic_pointer_cast<pytorch::Place>(place);
    FRONT_END_GENERAL_CHECK(pytorch_place, "Only place produced by PyTorch Frontend is supported");
    auto it = m_descriptors.find(pytorch_place->get_tensor_index());
    // Fail loudly instead of silently ignoring an unknown place, consistent
    // with set_tensor_value's handling of the same condition.
    FRONT_END_GENERAL_CHECK(it != m_descriptors.end(), "Place is not known.");
    it->second.m_pshape = shape;
}
// Returns the partial shape recorded for the given place, or a fully dynamic
// shape when no descriptor exists for it.
ov::PartialShape InputModel::get_partial_shape(const Place::Ptr& place) const {
    const auto pt_place = std::dynamic_pointer_cast<pytorch::Place>(place);
    FRONT_END_GENERAL_CHECK(
        pt_place,
        "Provided place is invalid. Only place of input or output is supported by PyTorch Frontend.");
    const auto found = m_descriptors.find(pt_place->get_tensor_index());
    return found == m_descriptors.end() ? PartialShape::dynamic() : found->second.m_pshape;
}
// Overrides the element type recorded for an input place. The new type is
// picked up during conversion when the corresponding Parameter is created.
// \throws if the place is null, not an input, not a pytorch::Place, or unknown.
void InputModel::set_element_type(const Place::Ptr& place, const ov::element::Type& type) {
    FRONT_END_GENERAL_CHECK(place && place->is_input(),
                            "Provided place is invalid, only inputs are supported for setting element type.");
    auto pytorch_place = std::dynamic_pointer_cast<pytorch::Place>(place);
    FRONT_END_GENERAL_CHECK(pytorch_place, "Only place produced by PyTorch Frontend is supported");
    auto it = m_descriptors.find(pytorch_place->get_tensor_index());
    // Fail loudly instead of silently ignoring an unknown place, consistent
    // with set_tensor_value's handling of the same condition.
    FRONT_END_GENERAL_CHECK(it != m_descriptors.end(), "Place is not known.");
    it->second.m_type = type;
}
// Returns the element type recorded for the given place, or element::dynamic
// when no descriptor exists for it.
ov::element::Type InputModel::get_element_type(const Place::Ptr& place) const {
    const auto pt_place = std::dynamic_pointer_cast<pytorch::Place>(place);
    FRONT_END_GENERAL_CHECK(
        pt_place,
        "Provided place is invalid. Only place of input or output is supported by PyTorch Frontend.");
    const auto found = m_descriptors.find(pt_place->get_tensor_index());
    return found == m_descriptors.end() ? element::dynamic : found->second.m_type;
}
// Freezes an input place to a constant value. Requires the element type and
// shape to be statically defined beforehand (via set_element_type /
// set_partial_shape); the raw buffer `value` is copied into a Constant node.
// \throws if the place is invalid/unknown, value is null, or type/shape are dynamic.
void InputModel::set_tensor_value(const Place::Ptr& place, const void* value) {
    FRONT_END_GENERAL_CHECK(place && place->is_input(),
                            "Provided place is invalid, only inputs are supported for setting tensor value.");
    // Guard against dereferencing a null buffer inside Constant::create.
    FRONT_END_GENERAL_CHECK(value != nullptr, "Provided value pointer must not be null.");
    auto pytorch_place = std::dynamic_pointer_cast<pytorch::Place>(place);
    FRONT_END_GENERAL_CHECK(pytorch_place, "Only place produced by PyTorch Frontend is supported");
    auto it = m_descriptors.find(pytorch_place->get_tensor_index());
    // Check the lookup up front instead of burying the failure in an else branch.
    FRONT_END_GENERAL_CHECK(it != m_descriptors.end(), "Place is not known.");
    const auto& el_type = it->second.m_type;
    const auto& p_shape = it->second.m_pshape;
    FRONT_END_GENERAL_CHECK(el_type.is_static() && p_shape.is_static(),
                            "Shape and type must be statically defined before calling set_tensor_value");
    it->second.m_value = ov::op::v0::Constant::create(el_type, p_shape.to_shape(), value);
}
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@@ -4,20 +4,49 @@
#pragma once
#include "openvino/frontend/pytorch/decoder.hpp"
#include "translate_session.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/frontend/input_model.hpp"
#include "openvino/frontend/place.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
class TranslateSession;
class Place;
class TorchDecoder;
// Per-tensor descriptor holding the user-overridable element type, partial
// shape and (optionally) a frozen constant value for a model input/output.
struct PlaceDesc {
    PlaceDesc(const element::Type& type, const PartialShape& pshape)
        : m_type(type),
          m_pshape(pshape),
          m_value(nullptr) {}
    element::Type m_type;           // element type; element::dynamic when unknown
    PartialShape m_pshape;          // partial shape; may be fully dynamic
    std::shared_ptr<Node> m_value;  // constant set via set_tensor_value; nullptr if unset
};
class InputModel : public ov::frontend::InputModel {
friend class ::ov::frontend::pytorch::TranslateSession;
std::shared_ptr<TorchDecoder> m_model;
friend class ::ov::frontend::pytorch::Place;
public:
explicit InputModel(std::shared_ptr<TorchDecoder> model) : m_model(model) {}
// TODO: pass telemetry extension to this ctor
explicit InputModel(std::shared_ptr<TorchDecoder> model_decoder);
std::vector<frontend::Place::Ptr> get_inputs() const override;
std::vector<frontend::Place::Ptr> get_outputs() const override;
frontend::Place::Ptr get_place_by_tensor_name(const std::string& tensor_name) const override;
void set_partial_shape(const frontend::Place::Ptr& place, const ov::PartialShape& shape) override;
ov::PartialShape get_partial_shape(const frontend::Place::Ptr& place) const override;
void set_element_type(const frontend::Place::Ptr& place, const ov::element::Type& type) override;
ov::element::Type get_element_type(const frontend::Place::Ptr& place) const override;
void set_tensor_value(const frontend::Place::Ptr& place, const void* value) override;
private:
std::shared_ptr<TorchDecoder> m_model_decoder;
std::unordered_map<std::string, std::shared_ptr<frontend::Place>> m_name_to_place;
std::unordered_map<size_t, PlaceDesc> m_descriptors;
};
} // namespace pytorch

View File

@@ -0,0 +1,49 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "place.hpp"
#include "input_model.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/frontend/pytorch/decoder.hpp"
#include "openvino/util/log.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
// Constructs a Place for the tensor with the given index. Resolves the
// input/output flags against the decoder's input and output lists and collects
// all names the place can be addressed by.
// \throws if input_model is not a pytorch::InputModel.
Place::Place(const ov::frontend::InputModel& input_model, size_t tensor_index)
    : m_input_model(input_model),
      m_tensor_index(tensor_index),
      m_is_input(false),
      m_is_output(false) {
    // The stringified tensor index is always the primary name.
    m_names.push_back(std::to_string(tensor_index));
    const auto im = dynamic_cast<const ov::frontend::pytorch::InputModel*>(&input_model);
    FRONT_END_GENERAL_CHECK(im, "PyTorch Place requires PyTorch InputModel class.");
    const auto& inputs = im->m_model_decoder->inputs();
    const auto& outputs = im->m_model_decoder->outputs();
    auto in_it = std::find(inputs.begin(), inputs.end(), tensor_index);
    if (in_it != inputs.end()) {
        m_is_input = true;
        // Also expose the decoder debug name, unless it duplicates the index name.
        const auto& debug_name = im->m_model_decoder->get_input_debug_name(std::distance(inputs.begin(), in_it));
        if (debug_name != m_names.at(0)) {
            m_names.push_back(debug_name);
        }
    }
    auto out_it = std::find(outputs.begin(), outputs.end(), tensor_index);
    if (out_it != outputs.end()) {
        m_is_output = true;
        const auto& debug_name = im->m_model_decoder->get_output_debug_name(std::distance(outputs.begin(), out_it));
        if (debug_name != m_names.at(0)) {
            m_names.push_back(debug_name);
        }
    }
    if (m_is_input && m_is_output) {
        // Fixed grammar in the warning message ("at a same time" -> "at the same time").
        OPENVINO_DEBUG << "[WARNING] Place " << tensor_index << " is input and output at the same time.";
    }
}
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@@ -0,0 +1,46 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/frontend/input_model.hpp"
#include "openvino/frontend/place.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
/// A frontend Place backed by a single PyTorch graph tensor, identified by its
/// tensor index. A place may be both an input and an output at the same time.
class Place : public ov::frontend::Place {
public:
    /// \param input_model must be a pytorch::InputModel (checked at runtime).
    /// \param tensor_index index of the tensor in the source graph.
    Place(const ov::frontend::InputModel& input_model, size_t tensor_index);
    ~Place() override = default;
    /// True when the tensor index is listed among the model inputs.
    bool is_input() const override {
        return m_is_input;
    }
    /// True when the tensor index is listed among the model outputs.
    bool is_output() const override {
        return m_is_output;
    }
    /// Two places are equal only when they are the same object instance.
    bool is_equal(const Ptr& another) const override {
        return this == another.get();
    }
    /// Names this place answers to: the stringified tensor index, plus the
    /// decoder debug name when it differs (see the constructor).
    std::vector<std::string> get_names() const override {
        return m_names;
    }
    /// Index of the underlying tensor in the source graph.
    size_t get_tensor_index() const {
        return m_tensor_index;
    }
private:
    const ov::frontend::InputModel& m_input_model;  // not owned; must outlive this Place
    const size_t m_tensor_index;
    std::vector<std::string> m_names;
    bool m_is_input;   // resolved in the constructor from the decoder's input list
    bool m_is_output;  // resolved in the constructor from the decoder's output list
};
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@@ -36,11 +36,13 @@ std::shared_ptr<ov::Model> TranslateSession::get_converted_model() {
std::shared_ptr<ov::Model> TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& input_model) {
auto pytorch_model = std::dynamic_pointer_cast<pytorch::InputModel>(input_model);
FRONT_END_GENERAL_CHECK(pytorch_model != nullptr, "Invalid input model");
return convert_pytorch_model(pytorch_model->m_model);
return convert_pytorch_model(pytorch_model->m_model_decoder, {}, pytorch_model->m_descriptors);
}
std::shared_ptr<Model> TranslateSession::convert_pytorch_model(std::shared_ptr<TorchDecoder> pytorch_model,
const TensorMap& external_tensor_map) {
std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
std::shared_ptr<TorchDecoder> pytorch_model,
const TensorMap& external_tensor_map,
const std::unordered_map<size_t, PlaceDesc>& external_descriptors) {
std::shared_ptr<Model> resulting_model; // define here to make a conversion in a nested scope
{
ParameterVector parameters;
@@ -50,28 +52,46 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(std::shared_ptr<T
// Go over all pytorch_model inputs and register them in the tensor map:
auto inputs = pytorch_model->inputs();
for (size_t i = 0; i < inputs.size(); ++i) {
PartialShape ps = pytorch_model->get_input_shape(i);
auto type = simplified_type_interpret(pytorch_model->get_input_type(i));
// TODO: Use special API to set custom type detalization
auto parameter = std::make_shared<v0::Parameter>(element::dynamic, ps);
encode_tensor_name(parameter->output(0), inputs.at(i), pytorch_model->get_input_debug_name(i));
parameters.push_back(parameter);
auto order = pytorch_model->get_input_transpose_order(i);
if (order.size() > 0 && !std::is_sorted(order.begin(), order.end())) {
FRONT_END_GENERAL_CHECK(ps.is_static(), "Shape must be static."); // TODO: make dynamic
auto sh = ps.get_shape();
Shape new_shape(sh.size());
for (size_t i = 0; i < sh.size(); i++) {
new_shape[order[i]] = sh[i];
std::shared_ptr<Node> input_node;
element::Type type = element::dynamic;
PartialShape pshape;
auto desc = external_descriptors.find(inputs[i]);
if (desc != external_descriptors.end()) {
if (desc->second.m_value) {
input_node = desc->second.m_value;
} else {
pshape = desc->second.m_pshape;
type = desc->second.m_type;
}
auto shape_const = v0::Constant::create(element::i32, {new_shape.size()}, new_shape);
auto reshape = std::make_shared<v1::Reshape>(parameter, shape_const, false);
auto order_const = v0::Constant::create(element::i32, {order.size()}, order);
auto transpose = std::make_shared<v1::Transpose>(reshape, order_const);
tensor_map[inputs.at(i)] = transpose;
} else {
tensor_map[inputs.at(i)] = parameter;
pshape = pytorch_model->get_input_shape(i);
auto type_any = simplified_type_interpret(pytorch_model->get_input_type(i));
// TODO: Use special API to set custom type specification
if (type_any.is<element::Type>()) {
type = type_any.as<element::Type>();
}
}
if (!input_node) {
auto parameter = std::make_shared<v0::Parameter>(type, pshape);
encode_tensor_name(parameter->output(0), inputs.at(i), pytorch_model->get_input_debug_name(i));
parameters.push_back(parameter);
input_node = parameter;
auto order = pytorch_model->get_input_transpose_order(i);
if (order.size() > 0 && !std::is_sorted(order.begin(), order.end())) {
FRONT_END_GENERAL_CHECK(pshape.is_static(), "Shape must be static."); // TODO: make dynamic
auto sh = pshape.get_shape();
Shape new_shape(sh.size());
for (size_t i = 0; i < sh.size(); i++) {
new_shape[order[i]] = sh[i];
}
auto shape_const = v0::Constant::create(element::i32, {new_shape.size()}, new_shape);
auto reshape = std::make_shared<v1::Reshape>(parameter, shape_const, false);
auto order_const = v0::Constant::create(element::i32, {order.size()}, order);
auto transpose = std::make_shared<v1::Transpose>(reshape, order_const);
input_node = transpose;
}
}
tensor_map[inputs.at(i)] = input_node;
}
auto node_visitor = [&](std::shared_ptr<TorchDecoder> node) {
@@ -88,7 +108,7 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(std::shared_ptr<T
// TODO: Eliminate duplication with the main code for Parameters creation
PartialShape ps = node->get_input_shape(i);
auto type = simplified_type_interpret(node->get_input_type(i));
// TODO: Use special API to set custom type detalization
// TODO: Use special API to set custom type specification
auto parameter = std::make_shared<v0::Parameter>(element::dynamic, ps);
// TODO: Missing get_input_transpose_order handling for not trivial layouts
tensor_map[input] = parameter;

View File

@@ -4,8 +4,7 @@
#pragma once
#include "openvino/frontend/input_model.hpp"
#include "openvino/frontend/pytorch/frontend.hpp"
#include "input_model.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
namespace ov {
@@ -28,8 +27,10 @@ public:
/// context which is visible from nested model. Empty external_tensor_map is used as an indication that this is a
/// main body conversion.
/// \return fully converted OV Model
std::shared_ptr<Model> convert_pytorch_model(std::shared_ptr<TorchDecoder> pytorch_model,
const TensorMap& external_tensor_map = {});
std::shared_ptr<Model> convert_pytorch_model(
std::shared_ptr<TorchDecoder> pytorch_model,
const TensorMap& external_tensor_map = {},
const std::unordered_map<size_t, PlaceDesc>& external_descriptors = {});
void encode_tensor_name(Output<Node> tensor_desc, size_t tensor_idx, std::string debug_name = "");
size_t decode_tensor_name(const Output<Node>& tensor_desc);

View File

@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import torch
import numpy as np
from openvino.frontend import FrontEndManager
from openvino.runtime import PartialShape, Type
class aten_relu(torch.nn.Module):
    """Pass-through + ReLU: returns the input unchanged alongside its ReLU."""

    def forward(self, x):
        activated = torch.nn.functional.relu(x)
        return x, activated
def get_scripted_model(model):
    """Script, switch to eval mode and freeze a torch module.

    Prints the inlined graph as a debugging aid and returns the frozen
    ScriptModule.
    """
    with torch.no_grad():
        scripted = torch.jit.script(model)
        scripted.eval()
        frozen = torch.jit.freeze(scripted)
        print(frozen.inlined_graph)  # will help debugging
        return frozen
def test_pytorch_fe_set_input_shape():
    """A shape set on an input place must propagate to the converted Parameter."""
    from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
    decoder = TorchScriptPythonDecoder(get_scripted_model(aten_relu()))
    frontend = FrontEndManager().load_by_framework("pytorch")
    input_model = frontend.load(decoder)
    x_place = input_model.get_place_by_tensor_name("x.1")
    input_model.set_partial_shape(x_place, PartialShape([1, 2, 3, 4]))
    converted = frontend.convert(input_model)
    param_shape = converted.get_parameters()[0].get_partial_shape()
    assert param_shape == PartialShape([1, 2, 3, 4])
def test_pytorch_fe_set_input_type():
    """An element type set on an input place must propagate to the converted Parameter."""
    from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
    decoder = TorchScriptPythonDecoder(get_scripted_model(aten_relu()))
    frontend = FrontEndManager().load_by_framework("pytorch")
    input_model = frontend.load(decoder)
    x_place = input_model.get_place_by_tensor_name("x.1")
    input_model.set_element_type(x_place, Type.f32)
    converted = frontend.convert(input_model)
    assert converted.get_parameters()[0].get_element_type() == Type.f32
def test_pytorch_fe_set_input_value():
    """Freezing an input to a constant value must remove it from the model parameters."""
    from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
    decoder = TorchScriptPythonDecoder(get_scripted_model(aten_relu()))
    frontend = FrontEndManager().load_by_framework("pytorch")
    input_model = frontend.load(decoder)
    x_place = input_model.get_place_by_tensor_name("x.1")
    # Value freezing requires both a static shape and a static element type.
    input_model.set_partial_shape(x_place, PartialShape([1, 2, 3, 4]))
    input_model.set_element_type(x_place, Type.f32)
    input_model.set_tensor_value(x_place, np.random.randn(1, 2, 3, 4).astype(np.float32))
    converted = frontend.convert(input_model)
    assert len(converted.get_parameters()) == 0

View File

@@ -55,7 +55,7 @@ class TestRsub(PytorchLayerTest):
self._test(*self.create_model(second_type="int"), ie_device, precision, ir_version)
class TestRSubTypes(PytorchLayerTest):
class TestRsubTypes(PytorchLayerTest):
def _prepare_input(self):
return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),
@@ -67,7 +67,7 @@ class TestRSubTypes(PytorchLayerTest):
def __init__(self, lhs_type, rhs_type):
super().__init__()
self.lhs_type = lhs_type
if rhs_type == "int":
if rhs_type == np.int32:
self.forward = self.forward2
else:
self.forward = self.forward1
@@ -83,17 +83,17 @@ class TestRSubTypes(PytorchLayerTest):
return aten_rsub(lhs_type, rhs_type), ref_net, "aten::rsub"
@pytest.mark.parametrize(("lhs_type", "rhs_type"),
[[torch.int32, "int"],
# [torch.int32, "float"], fp64 produce ov error of eltwise constant fold
[torch.int64, "int"],
# [torch.int64, "float"], fp64 produce ov error of eltwise constant fold
[torch.float32, "int"],
[torch.float32, "float"],
[[torch.int32, np.int32],
[torch.int32, np.float32],
[torch.int64, np.int32],
[torch.int64, np.float32],
[torch.float32, np.int32],
[torch.float32, np.float32],
])
@pytest.mark.parametrize(("lhs_shape"), [[2, 3], [3], [2, 3, 4]])
@pytest.mark.nightly
@pytest.mark.precommit
def test_sub_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type):
def test_rsub_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type):
self.lhs_type = lhs_type
self.lhs_shape = lhs_shape
self.rhs_type = rhs_type