[Python] Enhance Func1 API

This commit is contained in:
Ingmar Schoegl 2023-06-29 20:47:28 -06:00 committed by Ray Speth
parent c496697ee0
commit 267ebb9022
2 changed files with 112 additions and 38 deletions

View File

@ -6,13 +6,14 @@
from .ctcxx cimport *
cdef extern from "cantera/numerics/Func1.h":
cdef cppclass CxxFunc1 "Cantera::Func1":
double eval(double) except +translate_exception
string type()
string type_name()
string write(string)
cdef cppclass CxxTabulated1 "Cantera::Tabulated1" (CxxFunc1):
CxxTabulated1(int, double*, double*, string) except +translate_exception
double eval(double) except +translate_exception
cdef extern from "cantera/cython/funcWrapper.h":
ctypedef double (*callback_wrapper)(double, void*, void**) except? 0.0
@ -31,12 +32,20 @@ cdef extern from "cantera/cython/funcWrapper.h":
void setExceptionValue(PyObject*)
cdef extern from "cantera/numerics/Func1Factory.h":
cdef shared_ptr[CxxFunc1] CxxNewFunc1 "Cantera::newFunc1" (
string, double) except +translate_exception
cdef shared_ptr[CxxFunc1] CxxNewFunc1 "Cantera::newFunc1" (
string, vector[double]&) except +translate_exception
cdef shared_ptr[CxxFunc1] CxxNewFunc1 "Cantera::newFunc1" (
string, shared_ptr[CxxFunc1], shared_ptr[CxxFunc1]) except +translate_exception
cdef shared_ptr[CxxFunc1] CxxNewFunc1 "Cantera::newFunc1" (
string, shared_ptr[CxxFunc1], double) except +translate_exception
cdef class Func1:
cdef shared_ptr[CxxFunc1] _func
cdef CxxFunc1* func
cdef object callable
cdef object exception
cpdef void _set_callback(self, object) except *
cdef class TabulatedFunction(Func1):
cpdef void _set_tables(self, object, object, string) except *

View File

@ -7,6 +7,7 @@ import numpy as np
from ._utils cimport *
cdef double func_callback(double t, void* obj, void** err) except? 0.0:
"""
This function is called from C/C++ to evaluate a `Func1` object ``obj``,
@ -65,34 +66,108 @@ cdef class Func1:
self.exception = None
self.callable = None
def __init__(self, c):
def __init__(self, c, *, init=True):
if init is False:
# used by 'create' classmethod
return
if hasattr(c, '__call__'):
# callback function
self._set_callback(c)
else:
arr = np.array(c)
try:
if arr.ndim == 0:
# handle constants or unsized numpy arrays
k = float(c)
self._set_callback(lambda t: k)
elif arr.size == 1:
# handle lists, tuples or numpy arrays with a single element
k = float(c[0])
self._set_callback(lambda t: k)
else:
raise TypeError
return
except TypeError:
raise TypeError(
"'Func1' objects must be constructed from a number or "
"a callable object") from None
cdef Func1 func
try:
arr = np.array(c)
if arr.ndim == 0:
# handle constants or unsized numpy arrays
k = float(c)
elif arr.size == 1:
# handle lists, tuples or numpy arrays with a single element
k = float(c[0])
else:
raise TypeError
func = Func1.cxx_functor("constant", k)
self._func = func._func
self.func = self._func.get()
except TypeError:
raise TypeError(
"'Func1' objects must be constructed from a number or "
"a callable object") from None
cpdef void _set_callback(self, c) except *:
self.callable = c
self._func.reset(new CxxFunc1Py(func_callback, <void*>self))
self.func = self._func.get()
@property
def type(self):
"""
Return the type of the underlying C++ functor object.
.. versionadded:: 3.0
"""
return pystr(self.func.type())
@classmethod
def cxx_functor(cls, functor_type, *args):
"""
Retrieve a C++ `Func1` functor (advanced feature).
For implemented functor types, see the Cantera C++ ``Func1`` documentation.
.. versionadded:: 3.0
"""
cdef shared_ptr[CxxFunc1] func
cdef Func1 f0
cdef Func1 f1
cdef string cxx_string = stringify(functor_type)
cdef vector[double] arr
if len(args) == 0:
# simple functor with no parameter
func = CxxNewFunc1(cxx_string, 1.)
elif len(args) == 1:
if hasattr(args[0], "__len__"):
# advanced functor with array and no parameter
for v in args[0]:
arr.push_back(v)
func = CxxNewFunc1(cxx_string, arr)
else:
# simple functor with scalar parameter
func = CxxNewFunc1(cxx_string, float(args[0]))
elif len(args) == 2:
if isinstance(args[0], Func1) and isinstance(args[1], Func1):
# compound functor
f0 = args[0]
f1 = args[1]
func = CxxNewFunc1(cxx_string, f0._func, f1._func)
elif isinstance(args[0], Func1):
# modified functor
f0 = args[0]
func = CxxNewFunc1(cxx_string, f0._func, float(args[1]))
else:
raise ValueError("Invalid arguments")
else:
raise ValueError("Invalid arguments")
cls_name = pystr(func.get().type_name()).split("::")[-1]
cdef Func1 out = type(
cls_name, (cls, ), {"__module__": cls.__module__})(None, init=False)
out._func = func
out.func = out._func.get()
return out
def write(self, name="t"):
"""
Write a :math:`LaTeX` expression representing a functor.
:param name:
Name of the variable to be used.
.. versionadded:: 3.0
"""
return pystr(self.func.write(stringify(name)))
def __call__(self, t):
return self.func.eval(t)
@ -134,17 +209,7 @@ cdef class TabulatedFunction(Func1):
"""
def __init__(self, time, fval, method='linear'):
self._set_tables(time, fval, stringify(method))
cpdef void _set_tables(self, time, fval, string method) except *:
tt = np.asarray(time, dtype=np.double)
ff = np.asarray(fval, dtype=np.double)
if tt.size != ff.size:
raise ValueError("Sizes of arrays do not match "
"({} vs {})".format(tt.size, ff.size))
elif tt.size == 0:
raise ValueError("Arrays must not be empty.")
cdef np.ndarray[np.double_t, ndim=1] tvec = tt
cdef np.ndarray[np.double_t, ndim=1] fvec = ff
self.func = <CxxFunc1*>(new CxxTabulated1(tt.size, &tvec[0], &fvec[0],
method))
arr = np.hstack([np.array(time), np.array(fval)])
cdef Func1 func = Func1.cxx_functor(f"tabulated-{method}", arr)
self._func = func._func
self.func = self._func.get()