Clean up imports and apply suggested changes in drawnetwork.py

This commit is contained in:
Niclas Garan 2023-10-12 03:28:11 +02:00 committed by Ray Speth
parent 7b20d8e27c
commit 7e795389f6
2 changed files with 24 additions and 20 deletions

View File

@ -1,8 +1,8 @@
# This file is part of Cantera. See License.txt in the top-level directory or
# at https://cantera.org/license.txt for license and copyright information.
import importlib.metadata
from functools import wraps
import importlib.metadata as _metadata
from functools import wraps as _wraps
_graphviz = None
def _import_graphviz():
@ -11,16 +11,18 @@ def _import_graphviz():
if _graphviz is not None:
return
try:
importlib.metadata.version("graphviz")
except importlib.metadata.PackageNotFoundError:
raise ImportError("This requires the graphviz package.")
_metadata.version("graphviz")
except _metadata.PackageNotFoundError:
raise ImportError("This requires a python interface to graphviz.\n"
"It can be installed using conda (``conda install "
"python-graphviz``) or pip (``pip install graphviz``)")
else:
import graphviz as _graphviz
def needs_graphviz(func):
def _needs_graphviz(func):
# decorator function to load graphviz when needed
@wraps(func)
@_wraps(func)
def inner(*args, **kwargs):
if not _graphviz:
_import_graphviz()
@ -29,7 +31,7 @@ def needs_graphviz(func):
return inner
@needs_graphviz
@_needs_graphviz
def draw_reactor(r, dot=None, print_state=False, species=None, **kwargs):
"""
Draw `ReactorBase` object as ``graphviz`` ``dot`` node.
@ -72,11 +74,11 @@ def draw_reactor(r, dot=None, print_state=False, species=None, **kwargs):
s_label = ""
if species == "X":
X = r.thermo.mole_fraction_dict()
X = r.thermo.mole_fraction_dict(1e-4)
s_percents = "\\n".join([f"{s}: {v*100:.2f}" for s, v in X.items()])
s_label += "X (%)\\n" + s_percents
elif species == "Y":
Y = r.thermo.mass_fraction_dict()
Y = r.thermo.mass_fraction_dict(1e-4)
s_percents = "\\n".join([f"{s}: {v*100:.2f}" for s, v in Y.items()])
s_label += "Y (%)\\n" + s_percents
else:
@ -87,11 +89,11 @@ def draw_reactor(r, dot=None, print_state=False, species=None, **kwargs):
s_label += "X (%)\\n" + s_percents
# For full state output, shape must be 'Mrecord'
node_attr = {k:v for k,v in node_attr.items() if k != "shape"}
node_attr.pop("shape", None)
dot.node(r.name, shape="Mrecord",
label="{"+ T_label +"|"+ P_label +"}"+"|"+ s_label,
xlabel=r.name,
**node_attr)
label=f"{{{T_label}|{P_label}}}|{s_label}",
xlabel=r.name,
**node_attr)
else:
dot.node(r.name, **node_attr)
@ -99,7 +101,7 @@ def draw_reactor(r, dot=None, print_state=False, species=None, **kwargs):
return dot
@needs_graphviz
@_needs_graphviz
def draw_reactor_net(n, **kwargs):
"""
Draw `ReactorNet` object as ``graphviz.graphs.DiGraph``. Connecting flow
@ -130,7 +132,7 @@ def draw_reactor_net(n, **kwargs):
draw_surface(surface, dot, **kwargs)
# some Reactors or Reservoirs only exist as connecting nodes
connected_reactors = get_connected_reactors(connections)
connected_reactors = _get_connected_reactors(connections)
# remove already drawn reactors and draw new reactors
connected_reactors.difference_update(reactors)
@ -142,7 +144,7 @@ def draw_reactor_net(n, **kwargs):
return dot
def get_connected_reactors(connections):
def _get_connected_reactors(connections):
"""
Collect and returned all connected reactors.
@ -198,7 +200,9 @@ def draw_surface(surface, dot=None, **kwargs):
dot.edge(r.name, name, **edge_attr)
return dot
def draw_connections(connections, dot=None, **kwargs):
@_needs_graphviz
def draw_connections(connections, dot=None, show_wall_velocity=True, **kwargs):
"""
Draw connections between reactors and reservoirs. This includes flow
@ -209,7 +213,7 @@ def draw_connections(connections, dot=None, show_wall_velocity=True, **kwargs):
`FlowDevice` or `WallBase`.
:param dot:
``graphviz.graphs.BaseGraph`` object to which the connection is added.
If not provided, a new ``DiGraph`` is created. Defaults to ``None``
If not provided, a new ``DiGraph`` is created. Defaults to ``None``.
:param **kwargs:
Keyword options can contain ``graph_attr`` and general ``node_attr``,
``edge_attr``, ``heat_flow_attr``, and ``mass_flow_attr`` to be passed

View File

@ -10,7 +10,7 @@ from .thermo cimport *
from ._utils cimport pystr, stringify, comp_map, py_to_anymap, anymap_to_py
from ._utils import *
from .delegator cimport *
from .drawnetwork import draw_reactor, draw_reactor_net, draw_connections
from .drawnetwork import *
_reactor_counts = _defaultdict(int)