Improvements for 2023.1 release (#19168)

* TorchFX caching bugfix and improvements

* Fixed inconsistent env variable for Backend device

* Identify PyTorch FrontEnd Decoder type

* Added import statement in init files

* Registered ts_openvino as a separate backend

* Added caching fix and removed extraneous code

* Changed the name of ts backend

* Fixed issue with local temporary object

* Removed import statement from init files

* Changed the documentation

* Added get_supported_ops method for decoders

---------

Co-authored-by: Cavus Mustafa <mustafa.cavus@intel.com>
Co-authored-by: ynimmaga <yamini.nimmagadda@intel.com>
Surya Siddharth Pemmaraju 2023-08-16 06:13:21 -07:00 committed by GitHub
parent daa4f17a0a
commit a6719ef2be
14 changed files with 167 additions and 174 deletions
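
For context, a minimal usage sketch of the behavior this PR touches; the model, input, and device value are illustrative assumptions, and it presumes an OpenVINO build that registers the `openvino` and `openvino_ts` torch.compile backends shown in the diff below.

```python
import os
import torch

# Illustrative assumptions: OpenVINO's torch.compile integration is installed and the
# "openvino" / "openvino_ts" backends from this PR are registered; the model is a stand-in.
os.environ["OPENVINO_TORCH_BACKEND_DEVICE"] = "CPU"  # device variable unified by this PR
os.environ["OPENVINO_TORCH_MODEL_CACHING"] = "1"     # enables the caching path added below

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
example = torch.randn(1, 8)

compiled = torch.compile(model, backend="openvino")       # TorchFX path (now the default)
# compiled = torch.compile(model, backend="openvino_ts")  # TorchScript path, now a separate backend
print(compiled(example).shape)
```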

View File

@ -159,6 +159,9 @@ class TorchFXPythonDecoder (Decoder):
return 0
return len(self.get_subgraphs()) if hasattr(self.pt_module, 'blocks') else 1
def decoder_type_name(self) -> str:
return "fx"
def visit_subgraph(self, node_visitor):
# make sure topological order is satisfied
for node in self._nodes:
@ -374,4 +377,4 @@ class TorchFXPythonDecoder (Decoder):
return self.alias_db.may_contain_alias(self._raw_input(in_index), self._raw_output(out_index))
except:
# Sometimes pytorch fails to get result with IndexError exception while these indexes exist in node
return False
return False

View File

@ -19,12 +19,15 @@ from openvino.frontend import FrontEndManager
from openvino.runtime import Core, Type, PartialShape
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
from openvino.frontend.pytorch.torchdynamo.partition import Partitioner
from openvino.frontend.pytorch.torchdynamo.execute import execute
from openvino.frontend.pytorch.torchdynamo.execute import execute, execute_cached
from openvino.frontend.pytorch.torchdynamo.compile import cached_model_name, cache_root_path, get_device, openvino_compile_cached_model
from openvino.runtime import Core, Type, PartialShape
log = logging.getLogger(__name__)
"""
This is a preview feature in OpenVINO. Torchscript backend
This is a preview feature in OpenVINO. This feature
enables users to compile PyTorch models using torch.compile
with OpenVINO as a target backend in PyTorch applications
@ -42,11 +45,12 @@ log = logging.getLogger(__name__)
@register_backend
@fake_tensor_unsupported
def openvino(subgraph, example_inputs):
if (os.getenv("PYTORCH_TRACING_MODE") is not None):
if (os.getenv("PYTORCH_TRACING_MODE") == "TORCHFX"):
return fx_openvino(subgraph, example_inputs)
return ts_openvino(subgraph, example_inputs)
return fx_openvino(subgraph, example_inputs)
@register_backend
@fake_tensor_unsupported
def openvino_ts(subgraph, example_inputs):
return ts_openvino(subgraph, example_inputs)
def ts_openvino(subgraph, example_inputs):
try:
@ -79,8 +83,8 @@ def ts_openvino(subgraph, example_inputs):
om.validate_nodes_and_infer_types()
device = "CPU"
if (os.getenv("OPENVINO_TS_BACKEND_DEVICE") is not None):
device = os.getenv("OPENVINO_TS_BACKEND_DEVICE")
if (os.getenv("OPENVINO_TORCH_BACKEND_DEVICE") is not None):
device = os.getenv("OPENVINO_TORCH_BACKEND_DEVICE")
assert device in core.available_devices, "Specified device " + device + " is not in the list of OpenVINO Available Devices"
compiled_model = core.compile_model(om, device)
@ -110,15 +114,36 @@ def ts_openvino(subgraph, example_inputs):
def fx_openvino(subgraph, example_inputs):
try:
executor_parameters = None
inputs_reversed = False
if os.getenv("OPENVINO_TORCH_MODEL_CACHING") is not None:
# Create a hash to be used for caching
model_hash_str = sha256(subgraph.code.encode('utf-8')).hexdigest()
executor_parameters = {"model_hash_str": model_hash_str}
# Check if the model was fully supported and already cached
example_inputs.reverse()
inputs_reversed = True
maybe_fs_cached_name = cached_model_name(model_hash_str + "_fs", get_device(), example_inputs, cache_root_path())
if os.path.isfile(maybe_fs_cached_name + ".xml") and os.path.isfile(maybe_fs_cached_name + ".bin"):
# Model is fully supported and already cached. Run the cached OV model directly.
compiled_model = openvino_compile_cached_model(maybe_fs_cached_name, *example_inputs)
def _call(*args):
res = execute_cached(compiled_model, *args)
return res
return _call
if inputs_reversed:
example_inputs.reverse()
model = make_fx(subgraph)(*example_inputs)
with torch.no_grad():
model.eval()
partitioner = Partitioner()
compiled_model = partitioner.make_partitions(model)
if executor_parameters is not None and 'model_hash_str' in executor_parameters:
# Check if the model is fully supported.
fully_supported = partitioner.check_fully_supported(compiled_model)
if fully_supported:
executor_parameters["model_hash_str"] += "_fs"
def _call(*args):
res = execute(compiled_model, *args, executor="openvino",
executor_parameters=executor_parameters)
@ -128,6 +153,5 @@ def fx_openvino(subgraph, example_inputs):
log.debug(f"Failed in OpenVINO execution: {e}")
return compile_fx(subgraph, example_inputs)
def reset():
clear_caches()
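
The hunk above adds a warm-cache fast path: when caching is enabled and a previously compiled, fully supported model is found on disk, `fx_openvino` skips `make_fx` and partitioning and returns a callable that runs the cached OpenVINO model via `execute_cached`. A hedged usage sketch (paths and the toy model are assumptions):

```python
import os
import torch

# Assumptions for this sketch: the openvino backend from this PR is registered and the
# process can write to the cache directory.
os.environ["OPENVINO_TORCH_MODEL_CACHING"] = "1"
os.environ["OPENVINO_TORCH_CACHE_DIR"] = "./cache"   # same default as cache_root_path()

model = torch.compile(torch.nn.Linear(4, 4).eval(), backend="openvino")
x = torch.randn(1, 4)

model(x)  # cold run: make_fx + partitioning, then the OV model is serialized under ./cache/model/
model(x)  # warm runs (or a fresh process) find <hash>_fs_<device><inputs-hash>.xml/.bin and skip tracing
```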

View File

@ -9,6 +9,7 @@ import os
import torch
import torch.overrides
from hashlib import sha256
from torch.fx import GraphModule
from openvino.frontend import FrontEndManager
@ -17,29 +18,77 @@ from openvino.runtime import Core, Type, PartialShape, serialize
from typing import Callable, Optional
def cached_model_name(model_hash_str, device, args, cache_root, reversed = False):
if model_hash_str is None:
return None
def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None):
core = Core()
model_cache_dir = cache_root + "/model/"
try:
os.makedirs(model_cache_dir, exist_ok=True)
file_name = model_cache_dir + model_hash_str + "_" + device
except OSError as error:
print("Cache directory ", cache_root, " cannot be created. Model caching is disabled. Error: ", error)
return None
inputs_str = ""
for idx, input_data in enumerate(args):
if reversed:
inputs_str = "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "") + inputs_str
else:
inputs_str += "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "")
inputs_str = sha256(inputs_str.encode('utf-8')).hexdigest()
file_name += inputs_str
return file_name
def cache_root_path():
cache_root = "./cache/"
if os.getenv("OPENVINO_TORCH_CACHE_DIR") is not None:
cache_root = os.getenv("OPENVINO_TORCH_CACHE_DIR")
return cache_root
def get_device():
device = "CPU"
if os.getenv("OPENVINO_TORCH_BACKEND_DEVICE") is not None:
device = os.getenv("OPENVINO_TORCH_BACKEND_DEVICE")
assert device in core.available_devices, "Specified device " + device + " is not in the list of OpenVINO Available Devices"
file_name = None
cache_root = "./cache/"
if os.getenv("OPENVINO_TORCH_CACHE_DIR") is not None:
cache_root = os.getenv("OPENVINO_TORCH_CACHE_DIR")
if model_hash_str is not None:
model_cache_dir = cache_root + "/model/"
try:
os.makedirs(model_cache_dir, exist_ok=True)
file_name = model_cache_dir + model_hash_str + "_" + device
except OSError as error:
print("Cache directory ", cache_root, " cannot be created. Model caching is disabled. Error: ", error)
file_name = None
model_hash_str = None
return device
def openvino_compile_cached_model(cached_model_path, *example_inputs):
core = Core()
om = core.read_model(cached_model_path + ".xml")
dtype_mapping = {
torch.float32: Type.f32,
torch.float64: Type.f64,
torch.float16: Type.f16,
torch.int64: Type.i64,
torch.int32: Type.i32,
torch.uint8: Type.u8,
torch.int8: Type.i8,
torch.bool: Type.boolean
}
for idx, input_data in enumerate(example_inputs):
om.inputs[idx].get_node().set_element_type(dtype_mapping[input_data.dtype])
om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
om.validate_nodes_and_infer_types()
core.set_property({'CACHE_DIR': cache_root_path() + '/blob'})
compiled_model = core.compile_model(om, get_device())
return compiled_model
def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None):
core = Core()
device = get_device()
cache_root = cache_root_path()
file_name = cached_model_name(model_hash_str, device, args, cache_root)
if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"):
om = core.read_model(file_name + ".xml")
@ -49,11 +98,9 @@ def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None):
input_shapes = []
input_types = []
for idx, input_data in enumerate(args): # subgraph.example_inputs):
for idx, input_data in enumerate(args):
input_types.append(input_data.type())
input_shapes.append(input_data.size())
if file_name is not None:
file_name += "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "")
decoder = TorchFXPythonDecoder(gm, gm, input_shapes=input_shapes, input_types=input_types)
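
To make the file-name scheme above concrete, here is a hedged standalone restatement of how `cached_model_name`, `cache_root_path`, and `get_device` combine into a cache path; the helper name is an illustration, not the module's API, and the tensor formatting mirrors the code above.

```python
import os
from hashlib import sha256
import torch

def sketch_cache_file_name(model_hash_str: str, example_inputs) -> str:
    # Device and cache root come from the same environment variables read above.
    device = os.getenv("OPENVINO_TORCH_BACKEND_DEVICE", "CPU")
    cache_root = os.getenv("OPENVINO_TORCH_CACHE_DIR", "./cache/")
    # Per-tensor signature such as "_torch.FloatTensor[1,3,224,224]", hashed to keep names short.
    inputs_str = ""
    for t in example_inputs:
        inputs_str += "_" + str(t.type()) + str(t.size())[11:-1].replace(" ", "")
    inputs_str = sha256(inputs_str.encode("utf-8")).hexdigest()
    return cache_root + "/model/" + model_hash_str + "_" + device + inputs_str

# The fully supported variant used by the backend appends "_fs" to the model hash.
print(sketch_cache_file_name("deadbeef_fs", [torch.randn(1, 3, 224, 224)]))
```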

View File

@ -42,7 +42,7 @@ partitioned_modules = {}
def execute(
gm: GraphModule,
*args,
executor: str = "aten",
executor: str = "openvino",
executor_parameters: Optional[dict] = None,
):
if executor == "openvino":
@ -57,6 +57,14 @@ def execute(
import numpy as np
def execute_cached(compiled_model, *args):
ov_inputs = [a.detach().cpu().numpy() for a in args]
ov_inputs.reverse()
res = compiled_model(ov_inputs)
result = [torch.from_numpy(res[out]) for out in compiled_model.outputs]
return result
def openvino_execute(gm: GraphModule, *args, executor_parameters=None, partition_id):
executor_parameters = executor_parameters or DEFAULT_OPENVINO_PYTHON_CONFIG
@ -69,7 +77,11 @@ def openvino_execute(gm: GraphModule, *args, executor_parameters=None, partition
model_hash_str = executor_parameters.get("model_hash_str", None)
if model_hash_str is not None:
model_hash_str = model_hash_str + str(partition_id)
fully_supported = False
if len(model_hash_str) > 3 and model_hash_str[-3:] == "_fs":
fully_supported = True
if not fully_supported:
model_hash_str = model_hash_str + "_p" + str(partition_id)
if use_cache and (partition_id in compiled_cache):
compiled = compiled_cache[partition_id]
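
A compact restatement of the suffix convention introduced above (illustrative helper; the real logic lives inline in `openvino_execute`): a model hash already tagged `_fs` by the backend is reused as a single cached blob, while partially supported models get a per-partition `_p<id>` suffix so each fused submodule caches separately.

```python
def partition_cache_suffix(model_hash_str: str, partition_id: int) -> str:
    # "_fs" was appended by fx_openvino when the whole graph fused into one partition.
    if len(model_hash_str) > 3 and model_hash_str[-3:] == "_fs":
        return model_hash_str
    # Otherwise each partition gets its own cache entry.
    return model_hash_str + "_p" + str(partition_id)

assert partition_cache_suffix("deadbeef_fs", 0) == "deadbeef_fs"
assert partition_cache_suffix("deadbeef", 2) == "deadbeef_p2"
```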

View File

@ -10,15 +10,10 @@ import torch
from torch.nn import Module
from torch._ops import OpOverload
from torch.fx import GraphModule
from torch.fx.node import Node, _get_qualified_name
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.fx.experimental.proxy_tensor import DecompositionInterpreter
from torch._decomp import decomposition_table
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
import typing as t
import logging
@ -26,135 +21,12 @@ import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
def aten_to_dtype(self, dtype: torch.dtype, **kwargs):
if len(kwargs) > 0 or not dtype:
raise RuntimeError(
"No support for other to.dtype() formats other than to.dtype(self, dtype)"
)
return torch._prims.convert_element_type(self, dtype)
# decomposition_table currently contains both aten2aten and aten2prim decomposition
# this is a hack to separate them, as we only need aten2prim decomposition for nvfuser-supported aten graph lowering
aten2aten_decomp = {}
aten2prim_decomp = {}
for op, decomp_fn in decomposition_table.items():
if "torch._refs" in decomp_fn.__module__:
aten2prim_decomp[op] = decomp_fn
else:
aten2aten_decomp[op] = decomp_fn
aten2aten_decomp_skips = {
"aten.native_layer_norm_backward.default",
"aten.embedding_dense_backward.default", # This is hurting nvfuser's perf
"aten.addmm.default",
}
for op, decomp_fn in decomposition_table.items():
if "torch._refs" in decomp_fn.__module__:
aten2prim_decomp[op] = decomp_fn
else:
if str(op) not in aten2aten_decomp_skips:
aten2aten_decomp[op] = decomp_fn
aten2prim_decomp[torch.ops.aten.to.dtype] = aten_to_dtype
class OperatorSupport(OperatorSupport):
"""
Operator support for OpenVINO backend.
Currently, partitioning is based on the FX ATen graph. The fused subgraph will later be decomposed into prims.
To determine if an ATen op is supported by nvFuser, we check the prim ops used in its ref decomposition.
Only if all the prim ops in the ref have an nvfuser_impl do we say this ATen op is supported by nvFuser.
Note: When adding a rule, please add it to the corresponding section and follow the
alphabetical order.
"""
def __init__(self):
# TODO: current list copied from torch/csrc/jit/codegen/cuda/parser.cpp is incorrect,
# as that file is solely for TorchScript and doesn't represent the actual status of
# whether an operation would be runnable by primTorch+nvFuser.
# We will iterate on this list to reflect the reality.
"""
support_dict = {
# ===============================================================
# call_function aten
# ===============================================================
# Following supported aten ops is copied from torch/csrc/jit/codegen/cuda/parser.cpp
# TODO: might need to update according to supported input types
"torch.ops.aten.relu": None,
"torch.ops.aten.relu_": None,
"torch.ops.aten.conv2d": None,
"torch.ops.aten._convolution": None,
"torch.ops.aten.convolution": None,
"torch.ops.aten.batch_norm": None,
"torch.ops.aten.layer_norm": None,
"torch.ops.aten.add": None,
"torch.ops.aten.add_": None,
"torch.ops.aten.mul": None,
"torch.ops.aten.mul_": None,
"torch.ops.aten.div": None,
"torch.ops.aten.floordiv": None,
"torch.ops.aten.tanh": None,
"torch.ops.aten.elu": None,
"torch.ops.aten.sigmoid": None,
"torch.ops.aten.gelu": None,
"torch.ops.aten.sqrt": None,
"torch.ops.aten.abs": None,
"torch.ops.aten.square": None,
"torch.ops.aten.hardtanh": None,
"torch.ops.aten.hardtanh_": None,
"torch.ops.aten.hardsigmoid": None,
"torch.ops.aten.hardswish": None,
"torch.ops.aten.hardswish_": None,
"torch.ops.aten.silu_": None,
"torch.ops.aten.relu6": None,
"torch.ops.aten.softmax": None,
"torch.ops.aten.matmul": None,
"torch.ops.aten.mm": None,
"torch.ops.aten.linear": None,
"torch.ops.aten.max_pool2d": None,
"torch.ops.aten.avg_pool2d": None,
"torch.ops.aten.adaptive_avg_pool2d": None,
"torch.ops.aten.adaptive_max_pool2d": None,
#"torch.ops.aten.max_pool2d_with_indices": None,
"torch.ops.aten.mean": None,
"torch.ops.aten.flatten": None,
#"torch.ops.prim.NumToTensor": None,
"torch.ops.aten.contiguous": None,
"torch.ops.aten.as_tensor": None,
"torch.ops.aten.Int": None,
"torch.ops.aten.to": None,
"torch.ops.aten.permute": None,
"torch.ops.aten.embedding": None,
"torch.ops.aten.transpose": None,
"torch.ops.aten.size": None,
"torch.ops.aten.view": None,
"torch.ops.aten.unsqueeze": None,
"torch.ops.aten.rsub": None,
"torch.ops.aten.slice": None,
#"torch.ops.prim.Loop": None,
#"torch.ops.prim.If": None,
#"torch.ops.prim.Constant": None,
"torch.ops.aten.dim": None,
"torch.ops.aten.reciprocal": None,
"torch.ops.aten.sub": None,
"torch.ops.aten.eq": None,
"torch.ops.aten.ne": None,
"torch.ops.aten.gt": None,
"torch.ops.aten.lt": None,
"torch.ops.aten.neg": None,
#"torch.ops.prim.TupleConstruct": None,
"torch.ops.aten.append": None,
"getattr": None,
"_operator.getitem": None,
}
"""
# Just added Resnet50 supported iterations
support_dict = {
"_operator.getitem": None,
"torch.ops.aten._adaptive_avg_pool2d.default": None,

View File

@ -30,7 +30,7 @@ class Partitioner:
def fx_serialize(self, graph_module: GraphModule, *args, **kwargs):
fx_gm = make_fx(graph_module)(*args)
return fx_gm # prim_module
return fx_gm
def add_get_attr_inputs(self, partitions: t.List[Partition]):
# TODO: Find a more efficient way to include input
@ -44,9 +44,18 @@ class Partitioner:
for getattr_node, getattr_part in getattr_to_merge.items():
getattr_part.add_node(getattr_node)
def check_fully_supported(self, graph_module: GraphModule) -> bool:
num_fused = 0
for node in graph_module.graph.nodes:
if node.op == "call_module" and "fused_" in node.name:
num_fused += 1
elif node.op != "placeholder" and node.op != "output":
return False
if num_fused == 1:
return True
return False
def make_partitions(self, graph_module: GraphModule) -> GraphModule:
# entry function for nvFuser backend
# FX graph based partitioning based on nvfuser supported ops
partitioner = CapabilityBasedPartitioner(
graph_module, self.supported_ops, allows_single_node_partition=False)
partitions = partitioner.propose_partitions()
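
End to end, the new `check_fully_supported` is consulted right after `make_partitions`, as the backend hunk earlier shows. A hedged sketch of that sequence on a toy model (model and input are assumptions):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from openvino.frontend.pytorch.torchdynamo.partition import Partitioner

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
example = torch.randn(1, 4)

fx_model = make_fx(model)(example)   # ATen-level FX graph, as produced in fx_openvino
partitioner = Partitioner()
partitioned = partitioner.make_partitions(fx_model)
print("fully supported:", partitioner.check_fully_supported(partitioned))
```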

View File

@ -274,6 +274,9 @@ class TorchScriptPythonDecoder (Decoder):
self.m_decoders.append(decoder)
node_visitor(decoder)
def decoder_type_name(self) -> str:
return "ts"
def get_subgraphs(self) -> list:
if self.graph_element.kind() == "prim::PythonOp":
if "Subgraph" in self.graph_element.attributeNames():

View File

@ -108,6 +108,10 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
ov::OutputVector inlined_inputs(size_t start_index) const override {
PYBIND11_OVERRIDE_PURE(ov::OutputVector, TorchDecoder, inlined_inputs, start_index); }
const std::string& decoder_type_name() const override {
PYBIND11_OVERRIDE_PURE(const std::string&, TorchDecoder, decoder_type_name);
}
};
void regclass_frontend_pytorch_decoder(py::module m);

View File

@ -107,6 +107,9 @@ public:
/// Returns new nodes for inputs inlined in the op itself
// Used in Torch.FX decoder
virtual OutputVector inlined_inputs(size_t start_index) const = 0;
/// Returns the type of the decoder ("fx" for TorchFX, "ts" for TorchScript)
virtual const std::string& decoder_type_name() const = 0;
};
} // namespace pytorch

View File

@ -61,8 +61,9 @@ public:
protected:
bool supported_impl(const std::vector<ov::Any>& variants) const override;
ov::frontend::InputModel::Ptr load_impl(const std::vector<ov::Any>& variants) const override;
std::map<std::string, CreatorFunction> get_supported_ops(const ov::frontend::InputModel::Ptr& model) const;
std::map<std::string, CreatorFunction> m_op_translators;
std::map<std::string, CreatorFunction> m_op_extension_translators;
std::vector<ConversionExtensionBase::Ptr> m_conversion_extensions;
TelemetryExtension::Ptr m_telemetry;
};

View File

@ -110,20 +110,14 @@ std::string pack_detailed_failure_report(const std::map<std::string, std::string
}
} // namespace
FrontEnd::FrontEnd() {
const char* torch_tracing_mode = std::getenv("PYTORCH_TRACING_MODE");
if ((torch_tracing_mode != nullptr) && std::strcmp(torch_tracing_mode, "TORCHFX") == 0) {
m_op_translators = get_supported_ops_fx();
} else {
m_op_translators = get_supported_ops_ts();
}
}
FrontEnd::FrontEnd() {}
std::shared_ptr<Model> FrontEnd::convert(const ov::frontend::InputModel::Ptr& model) const {
FRONT_END_GENERAL_CHECK(std::dynamic_pointer_cast<pytorch::InputModel>(model), "Invalid input model");
std::map<std::string, CreatorFunction> supported_ops = get_supported_ops(model);
std::shared_ptr<Model> converted_model;
{
TranslateSession translate_session(model, m_op_translators, m_telemetry);
TranslateSession translate_session(model, supported_ops, m_telemetry);
converted_model = translate_session.get_converted_model();
}
@ -151,9 +145,10 @@ void FrontEnd::convert(const std::shared_ptr<Model>& partiallyConverted) const {
std::shared_ptr<Model> FrontEnd::convert_partially(const ov::frontend::InputModel::Ptr& model) const {
FRONT_END_GENERAL_CHECK(std::dynamic_pointer_cast<pytorch::InputModel>(model), "Invalid input model");
std::map<std::string, CreatorFunction> supported_ops = get_supported_ops(model);
std::shared_ptr<Model> partial_model;
{
TranslateSession translate_session(model, m_op_translators, m_telemetry);
TranslateSession translate_session(model, supported_ops, m_telemetry);
partial_model = translate_session.get_converted_model();
}
try {
@ -231,12 +226,12 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
void FrontEnd::add_extension(const std::shared_ptr<ov::Extension>& extension) {
if (auto conv_ext = std::dynamic_pointer_cast<ov::frontend::ConversionExtension>(extension)) {
m_conversion_extensions.push_back(conv_ext);
m_op_translators[conv_ext->get_op_type()] = [=](const NodeContext& context) {
m_op_extension_translators[conv_ext->get_op_type()] = [=](const NodeContext& context) {
return conv_ext->get_converter()(context);
};
} else if (auto conv_ext = std::dynamic_pointer_cast<ov::frontend::pytorch::ConversionExtension>(extension)) {
m_conversion_extensions.push_back(conv_ext);
m_op_translators[conv_ext->get_op_type()] = [=](const NodeContext& context) {
m_op_extension_translators[conv_ext->get_op_type()] = [=](const NodeContext& context) {
return conv_ext->get_converter()(context);
};
} else if (const auto& so_ext = std::dynamic_pointer_cast<ov::detail::SOExtension>(extension)) {
@ -273,6 +268,17 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& va
return std::make_shared<pytorch::InputModel>(tdecoder);
}
std::map<std::string, CreatorFunction> FrontEnd::get_supported_ops(const ov::frontend::InputModel::Ptr& model) const {
std::map<std::string, CreatorFunction> supported_ops = get_supported_ops_fx();
if (std::dynamic_pointer_cast<pytorch::InputModel>(model)->decoder_type_name() == "fx")
supported_ops = get_supported_ops_fx();
else
supported_ops = get_supported_ops_ts();
for (auto i = m_op_extension_translators.begin(); i != m_op_extension_translators.end(); i++)
supported_ops[i->first] = i->second;
return supported_ops;
}
} // namespace pytorch
} // namespace frontend
} // namespace ov
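
On the Python side the same dispatch is observable through the decoders: each now reports a type string, and the frontend selects the TorchFX or TorchScript translator table from it instead of the old PYTORCH_TRACING_MODE environment variable. A small hedged check, assuming the TorchScript decoder can be built directly from a scripted module (it scripts and freezes internally):

```python
import torch
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder

# Assumption: the constructor shown here matches this release; arguments may differ elsewhere.
scripted = torch.jit.script(torch.nn.ReLU().eval())
decoder = TorchScriptPythonDecoder(scripted)
print(decoder.decoder_type_name())  # "ts"; the TorchFX decoder reports "fx", quantized nodes "qt"
```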

View File

@ -139,6 +139,10 @@ void InputModel::set_tensor_value(const Place::Ptr& place, const void* value) {
}
}
const std::string& InputModel::decoder_type_name() const {
return m_model_decoder->decoder_type_name();
}
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -41,6 +41,7 @@ public:
void set_element_type(const frontend::Place::Ptr& place, const ov::element::Type& type) override;
ov::element::Type get_element_type(const frontend::Place::Ptr& place) const override;
void set_tensor_value(const frontend::Place::Ptr& place, const void* value) override;
const std::string& decoder_type_name() const;
private:
std::shared_ptr<TorchDecoder> m_model_decoder;

View File

@ -29,11 +29,15 @@ public:
virtual size_t get_subgraph_size() const override {
return 0;
}
virtual const std::string& decoder_type_name() const override {
return m_decoder_type;
}
private:
const Output<Node> m_qinput;
const std::string m_op_type = "QuantizedPtNode";
const std::string m_schema = "NONE";
const std::string m_decoder_type = "qt";
};
enum QuantizedPtNodeType { QUANTIZE_PER_TENSOR, QUANTIZE_PER_CHANNEL };