Improvements for 2023.1 release (#19168)
* TorchFX caching bugfix and improvements
* Fixed inconsistent env variable for Backend device
* Identify PyTorch FrontEnd Decoder type
* Added import statement in init files
* Registered ts_openvino as a separate backend
* Added caching fix and removed extraneous code
* Changed the name of ts backend
* Fixed issue with local temporary object
* Removed import statement from init files
* Changed the documentation
* Added get_supported_ops method for decoders

Co-authored-by: Cavus Mustafa <mustafa.cavus@intel.com>
Co-authored-by: ynimmaga <yamini.nimmagadda@intel.com>
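The diff below registers two torchdynamo backends ("openvino" for the TorchFX path and "openvino_ts" for the TorchScript preview path) and reads several environment variables. The following sketch is not part of the commit; it only illustrates how these pieces might be exercised from user code. The module path used to trigger backend registration (`openvino.frontend.pytorch.torchdynamo.backend`) is an assumption inferred from the imports visible in this diff.

```python
# Hedged usage sketch for the backends and env variables touched in this commit.
import os
import torch

os.environ["OPENVINO_TORCH_BACKEND_DEVICE"] = "CPU"   # device later read by get_device()
os.environ["OPENVINO_TORCH_MODEL_CACHING"] = "1"      # any value enables the caching path in fx_openvino
os.environ["OPENVINO_TORCH_CACHE_DIR"] = "./cache/"   # overrides the default "./cache/" root

# Assumption: importing the backend module runs the @register_backend decorators shown below.
import openvino.frontend.pytorch.torchdynamo.backend  # noqa: F401

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
compiled = torch.compile(model, backend="openvino")        # TorchFX path (fx_openvino)
# compiled = torch.compile(model, backend="openvino_ts")   # TorchScript preview path (ts_openvino)
print(compiled(torch.randn(1, 8)).shape)
```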
Commit a6719ef2be (parent daa4f17a0a)
@@ -159,6 +159,9 @@ class TorchFXPythonDecoder (Decoder):
            return 0
        return len(self.get_subgraphs()) if hasattr(self.pt_module, 'blocks') else 1

    def decoder_type_name(self) -> str:
        return "fx"

    def visit_subgraph(self, node_visitor):
        # make sure topological order is satisfied
        for node in self._nodes:
@@ -374,4 +377,4 @@ class TorchFXPythonDecoder (Decoder):
            return self.alias_db.may_contain_alias(self._raw_input(in_index), self._raw_output(out_index))
        except:
            # Sometimes pytorch fails to get result with IndexError exception while these indexes exist in node
            return False
        return False
@@ -19,12 +19,15 @@ from openvino.frontend import FrontEndManager
from openvino.runtime import Core, Type, PartialShape
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
from openvino.frontend.pytorch.torchdynamo.partition import Partitioner
from openvino.frontend.pytorch.torchdynamo.execute import execute
from openvino.frontend.pytorch.torchdynamo.execute import execute, execute_cached
from openvino.frontend.pytorch.torchdynamo.compile import cached_model_name, cache_root_path, get_device, openvino_compile_cached_model

from openvino.runtime import Core, Type, PartialShape

log = logging.getLogger(__name__)

"""
This is a preview feature in OpenVINO. Torchscript backend
This is a preview feature in OpenVINO. This feature
enables users to compile PyTorch models using torch.compile
with OpenVINO as a target backend in PyTorch applications

@@ -42,11 +45,12 @@ log = logging.getLogger(__name__)
@register_backend
@fake_tensor_unsupported
def openvino(subgraph, example_inputs):
    if (os.getenv("PYTORCH_TRACING_MODE") is not None):
        if (os.getenv("PYTORCH_TRACING_MODE") == "TORCHFX"):
            return fx_openvino(subgraph, example_inputs)
    return ts_openvino(subgraph, example_inputs)
    return fx_openvino(subgraph, example_inputs)

@register_backend
@fake_tensor_unsupported
def openvino_ts(subgraph, example_inputs):
    return ts_openvino(subgraph, example_inputs)

def ts_openvino(subgraph, example_inputs):
    try:
@@ -79,8 +83,8 @@ def ts_openvino(subgraph, example_inputs):
        om.validate_nodes_and_infer_types()

        device = "CPU"
        if (os.getenv("OPENVINO_TS_BACKEND_DEVICE") is not None):
            device = os.getenv("OPENVINO_TS_BACKEND_DEVICE")
        if (os.getenv("OPENVINO_TORCH_BACKEND_DEVICE") is not None):
            device = os.getenv("OPENVINO_TORCH_BACKEND_DEVICE")
        assert device in core.available_devices, "Specified device " + device + " is not in the list of OpenVINO Available Devices"

        compiled_model = core.compile_model(om, device)
@@ -110,15 +114,36 @@ def ts_openvino(subgraph, example_inputs):
def fx_openvino(subgraph, example_inputs):
    try:
        executor_parameters = None
        inputs_reversed = False
        if os.getenv("OPENVINO_TORCH_MODEL_CACHING") is not None:
            # Create a hash to be used for caching
            model_hash_str = sha256(subgraph.code.encode('utf-8')).hexdigest()
            executor_parameters = {"model_hash_str": model_hash_str}
            # Check if the model was fully supported and already cached
            example_inputs.reverse()
            inputs_reversed = True
            maybe_fs_cached_name = cached_model_name(model_hash_str + "_fs", get_device(), example_inputs, cache_root_path())
            if os.path.isfile(maybe_fs_cached_name + ".xml") and os.path.isfile(maybe_fs_cached_name + ".bin"):
                # Model is fully supported and already cached. Run the cached OV model directly.
                compiled_model = openvino_compile_cached_model(maybe_fs_cached_name, *example_inputs)
                def _call(*args):
                    res = execute_cached(compiled_model, *args)
                    return res
                return _call
        if inputs_reversed:
            example_inputs.reverse()
        model = make_fx(subgraph)(*example_inputs)
        with torch.no_grad():
            model.eval()
        partitioner = Partitioner()
        compiled_model = partitioner.make_partitions(model)

        if executor_parameters is not None and 'model_hash_str' in executor_parameters:
            # Check if the model is fully supported.
            fully_supported = partitioner.check_fully_supported(compiled_model)
            if fully_supported:
                executor_parameters["model_hash_str"] += "_fs"

        def _call(*args):
            res = execute(compiled_model, *args, executor="openvino",
                          executor_parameters=executor_parameters)
@@ -128,6 +153,5 @@ def fx_openvino(subgraph, example_inputs):
        log.debug(f"Failed in OpenVINO execution: {e}")
        return compile_fx(subgraph, example_inputs)


def reset():
    clear_caches()
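A side note on the caching changes above (illustrative only, not part of the commit): the cache key starts from the SHA-256 of the FX subgraph's code, gains a `_fs` suffix in `fx_openvino` when the partitioner reports the model as fully supported, and otherwise receives a per-partition `_p<id>` suffix inside `openvino_execute`; `cached_model_name` then appends the device and a hash of the input signature. The snippet below just mirrors that naming scheme with a made-up stand-in for `subgraph.code`.

```python
# Illustrative sketch of the cache-key composition used in this diff.
from hashlib import sha256

subgraph_code = "def forward(self, x):\n    return torch.relu(x)\n"  # hypothetical FX code
model_hash_str = sha256(subgraph_code.encode("utf-8")).hexdigest()

fully_supported_key = model_hash_str + "_fs"     # whole model cached as a single OV model
partition_key = model_hash_str + "_p" + str(0)   # fallback: one cached model per partition
# cached_model_name() later appends "_" + device and a hash of the input shapes/types.
```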
@@ -9,6 +9,7 @@ import os
import torch
import torch.overrides

from hashlib import sha256
from torch.fx import GraphModule

from openvino.frontend import FrontEndManager
@@ -17,29 +18,77 @@ from openvino.runtime import Core, Type, PartialShape, serialize

from typing import Callable, Optional

def cached_model_name(model_hash_str, device, args, cache_root, reversed = False):
    if model_hash_str is None:
        return None

def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None):
    core = Core()
    model_cache_dir = cache_root + "/model/"

    try:
        os.makedirs(model_cache_dir, exist_ok=True)
        file_name = model_cache_dir + model_hash_str + "_" + device
    except OSError as error:
        print("Cache directory ", cache_root, " cannot be created. Model caching is disabled. Error: ", error)
        return None

    inputs_str = ""
    for idx, input_data in enumerate(args):
        if reversed:
            inputs_str = "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "") + inputs_str
        else:
            inputs_str += "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "")
    inputs_str = sha256(inputs_str.encode('utf-8')).hexdigest()
    file_name += inputs_str

    return file_name

def cache_root_path():
    cache_root = "./cache/"
    if os.getenv("OPENVINO_TORCH_CACHE_DIR") is not None:
        cache_root = os.getenv("OPENVINO_TORCH_CACHE_DIR")
    return cache_root

def get_device():
    device = "CPU"

    if os.getenv("OPENVINO_TORCH_BACKEND_DEVICE") is not None:
        device = os.getenv("OPENVINO_TORCH_BACKEND_DEVICE")
        assert device in core.available_devices, "Specified device " + device + " is not in the list of OpenVINO Available Devices"

    file_name = None
    cache_root = "./cache/"
    if os.getenv("OPENVINO_TORCH_CACHE_DIR") is not None:
        cache_root = os.getenv("OPENVINO_TORCH_CACHE_DIR")
    if model_hash_str is not None:
        model_cache_dir = cache_root + "/model/"
        try:
            os.makedirs(model_cache_dir, exist_ok=True)
            file_name = model_cache_dir + model_hash_str + "_" + device
        except OSError as error:
            print("Cache directory ", cache_root, " cannot be created. Model caching is disabled. Error: ", error)
            file_name = None
            model_hash_str = None
    return device

def openvino_compile_cached_model(cached_model_path, *example_inputs):
    core = Core()
    om = core.read_model(cached_model_path + ".xml")

    dtype_mapping = {
        torch.float32: Type.f32,
        torch.float64: Type.f64,
        torch.float16: Type.f16,
        torch.int64: Type.i64,
        torch.int32: Type.i32,
        torch.uint8: Type.u8,
        torch.int8: Type.i8,
        torch.bool: Type.boolean
    }

    for idx, input_data in enumerate(example_inputs):
        om.inputs[idx].get_node().set_element_type(dtype_mapping[input_data.dtype])
        om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
    om.validate_nodes_and_infer_types()

    core.set_property({'CACHE_DIR': cache_root_path() + '/blob'})

    compiled_model = core.compile_model(om, get_device())

    return compiled_model

def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None):
    core = Core()

    device = get_device()
    cache_root = cache_root_path()
    file_name = cached_model_name(model_hash_str, device, args, cache_root)

    if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"):
        om = core.read_model(file_name + ".xml")
@@ -49,11 +98,9 @@ def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None):

        input_shapes = []
        input_types = []
        for idx, input_data in enumerate(args): # subgraph.example_inputs):
        for idx, input_data in enumerate(args):
            input_types.append(input_data.type())
            input_shapes.append(input_data.size())
            if file_name is not None:
                file_name += "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "")

        decoder = TorchFXPythonDecoder(gm, gm, input_shapes=input_shapes, input_types=input_types)
@@ -42,7 +42,7 @@ partitioned_modules = {}
def execute(
    gm: GraphModule,
    *args,
    executor: str = "aten",
    executor: str = "openvino",
    executor_parameters: Optional[dict] = None,
):
    if executor == "openvino":
@@ -57,6 +57,14 @@ def execute(
import numpy as np


def execute_cached(compiled_model, *args):
    ov_inputs = [a.detach().cpu().numpy() for a in args]
    ov_inputs.reverse()
    res = compiled_model(ov_inputs)
    result = [torch.from_numpy(res[out]) for out in compiled_model.outputs]
    return result


def openvino_execute(gm: GraphModule, *args, executor_parameters=None, partition_id):

    executor_parameters = executor_parameters or DEFAULT_OPENVINO_PYTHON_CONFIG
@@ -69,7 +77,11 @@ def openvino_execute(gm: GraphModule, *args, executor_parameters=None, partition

    model_hash_str = executor_parameters.get("model_hash_str", None)
    if model_hash_str is not None:
        model_hash_str = model_hash_str + str(partition_id)
        fully_supported = False
        if len(model_hash_str) > 3 and model_hash_str[-3:] == "_fs":
            fully_supported = True
        if not fully_supported:
            model_hash_str = model_hash_str + "_p" + str(partition_id)

    if use_cache and (partition_id in compiled_cache):
        compiled = compiled_cache[partition_id]
@@ -10,15 +10,10 @@ import torch
from torch.nn import Module
from torch._ops import OpOverload

from torch.fx import GraphModule
from torch.fx.node import Node, _get_qualified_name
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS

from torch.fx.experimental.proxy_tensor import DecompositionInterpreter
from torch._decomp import decomposition_table
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

import typing as t

import logging
@@ -26,135 +21,12 @@ import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


def aten_to_dtype(self, dtype: torch.dtype, **kwargs):
    if len(kwargs) > 0 or not dtype:
        raise RuntimeError(
            "No support for other to.dtype() formats other than to.dtype(self, dtype)"
        )
    return torch._prims.convert_element_type(self, dtype)


# decomposition_table currently contains both aten2aten and aten2prim decomposition
# this is a hack to separate them, as we only need aten2prim decomposition for nvfuser-supported aten graph lowering
aten2aten_decomp = {}
aten2prim_decomp = {}

for op, decomp_fn in decomposition_table.items():
    if "torch._refs" in decomp_fn.__module__:
        aten2prim_decomp[op] = decomp_fn
    else:
        aten2aten_decomp[op] = decomp_fn

aten2aten_decomp_skips = {
    "aten.native_layer_norm_backward.default",
    "aten.embedding_dense_backward.default", # This is hurting nvfuser's perf
    "aten.addmm.default",
}

for op, decomp_fn in decomposition_table.items():
    if "torch._refs" in decomp_fn.__module__:
        aten2prim_decomp[op] = decomp_fn
    else:
        if str(op) not in aten2aten_decomp_skips:
            aten2aten_decomp[op] = decomp_fn


aten2prim_decomp[torch.ops.aten.to.dtype] = aten_to_dtype


class OperatorSupport(OperatorSupport):
    """
    Operator support for OpenVINO backend.
    Currently, partitioning is based on FX ATen graph. The fused subgraph will latter be decomposed into prims.
    To determine if an ATen ops is supported by nvFuser, we shall check the prim ops used in its ref decomposition.
    Only if all the prim ops in the ref has a nvfuser_impl, we say this Aten op is suppported by nvFuser.
    Note: When adding a rule, please add it to the corresponding section and follow the
    alphabetical order.
    """

    def __init__(self):
        # TODO: current list copied from torch/csrc/jit/codegen/cuda/parser.cpp is incorrect,
        # as that file is solely for TorchScript and doesn't represent the actual status
        # whether operation would be runnable by primTorch+nvFuser.
        # We will iterate on this list to reflect the the reality.
        """
        support_dict = {
            # ===============================================================
            # call_function aten
            # ===============================================================
            # Following supported aten ops is copied from torch/csrc/jit/codegen/cuda/parser.cpp
            # TODO: might need to update according to supported input types

            "torch.ops.aten.relu": None,
            "torch.ops.aten.relu_": None,
            "torch.ops.aten.conv2d": None,
            "torch.ops.aten._convolution": None,
            "torch.ops.aten.convolution": None,
            "torch.ops.aten.batch_norm": None,
            "torch.ops.aten.layer_norm": None,
            "torch.ops.aten.add": None,
            "torch.ops.aten.add_": None,
            "torch.ops.aten.mul": None,
            "torch.ops.aten.mul_": None,
            "torch.ops.aten.div": None,
            "torch.ops.aten.floordiv": None,
            "torch.ops.aten.tanh": None,
            "torch.ops.aten.elu": None,
            "torch.ops.aten.sigmoid": None,
            "torch.ops.aten.gelu": None,
            "torch.ops.aten.sqrt": None,
            "torch.ops.aten.abs": None,
            "torch.ops.aten.square": None,
            "torch.ops.aten.hardtanh": None,
            "torch.ops.aten.hardtanh_": None,
            "torch.ops.aten.hardsigmoid": None,
            "torch.ops.aten.hardswish": None,
            "torch.ops.aten.hardswish_": None,
            "torch.ops.aten.silu_": None,
            "torch.ops.aten.relu6": None,
            "torch.ops.aten.softmax": None,
            "torch.ops.aten.matmul": None,
            "torch.ops.aten.mm": None,
            "torch.ops.aten.linear": None,
            "torch.ops.aten.max_pool2d": None,
            "torch.ops.aten.avg_pool2d": None,
            "torch.ops.aten.adaptive_avg_pool2d": None,
            "torch.ops.aten.adaptive_max_pool2d": None,
            #"torch.ops.aten.max_pool2d_with_indices": None,
            "torch.ops.aten.mean": None,
            "torch.ops.aten.flatten": None,
            #"torch.ops.prim.NumToTensor": None,
            "torch.ops.aten.contiguous": None,
            "torch.ops.aten.as_tensor": None,
            "torch.ops.aten.Int": None,
            "torch.ops.aten.to": None,
            "torch.ops.aten.permute": None,
            "torch.ops.aten.embedding": None,
            "torch.ops.aten.transpose": None,
            "torch.ops.aten.size": None,
            "torch.ops.aten.view": None,
            "torch.ops.aten.unsqueeze": None,
            "torch.ops.aten.rsub": None,
            "torch.ops.aten.slice": None,
            #"torch.ops.prim.Loop": None,
            #"torch.ops.prim.If": None,
            #"torch.ops.prim.Constant": None,
            "torch.ops.aten.dim": None,
            "torch.ops.aten.reciprocal": None,
            "torch.ops.aten.sub": None,
            "torch.ops.aten.eq": None,
            "torch.ops.aten.ne": None,
            "torch.ops.aten.gt": None,
            "torch.ops.aten.lt": None,
            "torch.ops.aten.neg": None,
            #"torch.ops.prim.TupleConstruct": None,
            "torch.ops.aten.append": None,
            "getattr": None,
            "_operator.getitem": None,
        }
        """
        # Just added Resnet50 supported iterations
        support_dict = {
            "_operator.getitem": None,
            "torch.ops.aten._adaptive_avg_pool2d.default": None,
@@ -30,7 +30,7 @@ class Partitioner:

    def fx_serialize(self, graph_module: GraphModule, *args, **kwargs):
        fx_gm = make_fx(graph_module)(*args)
        return fx_gm # prim_module
        return fx_gm

    def add_get_attr_inputs(self, partitions: t.List[Partition]):
        # TODO: Find a more efficient way to include input
@@ -44,9 +44,18 @@ class Partitioner:
        for getattr_node, getattr_part in getattr_to_merge.items():
            getattr_part.add_node(getattr_node)

    def check_fully_supported(self, graph_module: GraphModule) -> bool:
        num_fused = 0
        for node in graph_module.graph.nodes:
            if node.op == "call_module" and "fused_" in node.name:
                num_fused += 1
            elif node.op != "placeholder" and node.op != "output":
                return False
        if num_fused == 1:
            return True
        return False

    def make_partitions(self, graph_module: GraphModule) -> GraphModule:
        # entry function for nvFuser backend
        # FX graph based partitioning based on nvfuser supported ops
        partitioner = CapabilityBasedPartitioner(
            graph_module, self.supported_ops, allows_single_node_partition=False)
        partitions = partitioner.propose_partitions()
@@ -274,6 +274,9 @@ class TorchScriptPythonDecoder (Decoder):
            self.m_decoders.append(decoder)
            node_visitor(decoder)

    def decoder_type_name(self) -> str:
        return "ts"

    def get_subgraphs(self) -> list:
        if self.graph_element.kind() == "prim::PythonOp":
            if "Subgraph" in self.graph_element.attributeNames():
@@ -108,6 +108,10 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {

    ov::OutputVector inlined_inputs(size_t start_index) const override {
        PYBIND11_OVERRIDE_PURE(ov::OutputVector, TorchDecoder, inlined_inputs, start_index); }

    const std::string& decoder_type_name() const override {
        PYBIND11_OVERRIDE_PURE(const std::string&, TorchDecoder, decoder_type_name);
    }
};

void regclass_frontend_pytorch_decoder(py::module m);
@@ -107,6 +107,9 @@ public:
    /// Returns new nodes for inputs inlined in the op itself
    // Used in Torch.FX decoder
    virtual OutputVector inlined_inputs(size_t start_index) const = 0;

    /// Returns the id of the deccoder type (0: TorchFX, 1: TorchScript)
    virtual const std::string& decoder_type_name() const = 0;
};

} // namespace pytorch
@@ -61,8 +61,9 @@ public:
protected:
    bool supported_impl(const std::vector<ov::Any>& variants) const override;
    ov::frontend::InputModel::Ptr load_impl(const std::vector<ov::Any>& variants) const override;
    std::map<std::string, CreatorFunction> get_supported_ops(const ov::frontend::InputModel::Ptr& model) const;

    std::map<std::string, CreatorFunction> m_op_translators;
    std::map<std::string, CreatorFunction> m_op_extension_translators;
    std::vector<ConversionExtensionBase::Ptr> m_conversion_extensions;
    TelemetryExtension::Ptr m_telemetry;
};
@@ -110,20 +110,14 @@ std::string pack_detailed_failure_report(const std::map<std::string, std::string
}
} // namespace

FrontEnd::FrontEnd() {
    const char* torch_tracing_mode = std::getenv("PYTORCH_TRACING_MODE");
    if ((torch_tracing_mode != nullptr) && std::strcmp(torch_tracing_mode, "TORCHFX") == 0) {
        m_op_translators = get_supported_ops_fx();
    } else {
        m_op_translators = get_supported_ops_ts();
    }
}
FrontEnd::FrontEnd() {}

std::shared_ptr<Model> FrontEnd::convert(const ov::frontend::InputModel::Ptr& model) const {
    FRONT_END_GENERAL_CHECK(std::dynamic_pointer_cast<pytorch::InputModel>(model), "Invalid input model");
    std::map<std::string, CreatorFunction> supported_ops = get_supported_ops(model);
    std::shared_ptr<Model> converted_model;
    {
        TranslateSession translate_session(model, m_op_translators, m_telemetry);
        TranslateSession translate_session(model, supported_ops, m_telemetry);
        converted_model = translate_session.get_converted_model();
    }

@@ -151,9 +145,10 @@ void FrontEnd::convert(const std::shared_ptr<Model>& partiallyConverted) const {

std::shared_ptr<Model> FrontEnd::convert_partially(const ov::frontend::InputModel::Ptr& model) const {
    FRONT_END_GENERAL_CHECK(std::dynamic_pointer_cast<pytorch::InputModel>(model), "Invalid input model");
    std::map<std::string, CreatorFunction> supported_ops = get_supported_ops(model);
    std::shared_ptr<Model> partial_model;
    {
        TranslateSession translate_session(model, m_op_translators, m_telemetry);
        TranslateSession translate_session(model, supported_ops, m_telemetry);
        partial_model = translate_session.get_converted_model();
    }
    try {
@@ -231,12 +226,12 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
void FrontEnd::add_extension(const std::shared_ptr<ov::Extension>& extension) {
    if (auto conv_ext = std::dynamic_pointer_cast<ov::frontend::ConversionExtension>(extension)) {
        m_conversion_extensions.push_back(conv_ext);
        m_op_translators[conv_ext->get_op_type()] = [=](const NodeContext& context) {
        m_op_extension_translators[conv_ext->get_op_type()] = [=](const NodeContext& context) {
            return conv_ext->get_converter()(context);
        };
    } else if (auto conv_ext = std::dynamic_pointer_cast<ov::frontend::pytorch::ConversionExtension>(extension)) {
        m_conversion_extensions.push_back(conv_ext);
        m_op_translators[conv_ext->get_op_type()] = [=](const NodeContext& context) {
        m_op_extension_translators[conv_ext->get_op_type()] = [=](const NodeContext& context) {
            return conv_ext->get_converter()(context);
        };
    } else if (const auto& so_ext = std::dynamic_pointer_cast<ov::detail::SOExtension>(extension)) {
@@ -273,6 +268,17 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& va
    return std::make_shared<pytorch::InputModel>(tdecoder);
}

std::map<std::string, CreatorFunction> FrontEnd::get_supported_ops(const ov::frontend::InputModel::Ptr& model) const {
    std::map<std::string, CreatorFunction> supported_ops = get_supported_ops_fx();
    if (std::dynamic_pointer_cast<pytorch::InputModel>(model)->decoder_type_name() == "fx")
        supported_ops = get_supported_ops_fx();
    else
        supported_ops = get_supported_ops_ts();
    for (auto i = m_op_extension_translators.begin(); i != m_op_extension_translators.end(); i++)
        supported_ops[i->first] = i->second;
    return supported_ops;
}

} // namespace pytorch
} // namespace frontend
} // namespace ov
@@ -139,6 +139,10 @@ void InputModel::set_tensor_value(const Place::Ptr& place, const void* value) {
    }
}

const std::string& InputModel::decoder_type_name() const {
    return m_model_decoder->decoder_type_name();
}

} // namespace pytorch
} // namespace frontend
} // namespace ov
@@ -41,6 +41,7 @@ public:
    void set_element_type(const frontend::Place::Ptr& place, const ov::element::Type& type) override;
    ov::element::Type get_element_type(const frontend::Place::Ptr& place) const override;
    void set_tensor_value(const frontend::Place::Ptr& place, const void* value) override;
    const std::string& decoder_type_name() const;

private:
    std::shared_ptr<TorchDecoder> m_model_decoder;
@@ -29,11 +29,15 @@ public:
    virtual size_t get_subgraph_size() const override {
        return 0;
    }
    virtual const std::string& decoder_type_name() const override {
        return m_decoder_type;
    }

private:
    const Output<Node> m_qinput;
    const std::string m_op_type = "QuantizedPtNode";
    const std::string m_schema = "NONE";
    const std::string m_decoder_type = "qt";
};

enum QuantizedPtNodeType { QUANTIZE_PER_TENSOR, QUANTIZE_PER_CHANNEL };