mirror of
https://github.com/Cantera/cantera.git
synced 2025-02-25 18:55:29 -06:00
[Python] Enhance Func1 API
This commit is contained in:
parent
c496697ee0
commit
267ebb9022
@ -6,13 +6,14 @@
|
||||
|
||||
from .ctcxx cimport *
|
||||
|
||||
|
||||
cdef extern from "cantera/numerics/Func1.h":
|
||||
cdef cppclass CxxFunc1 "Cantera::Func1":
|
||||
double eval(double) except +translate_exception
|
||||
string type()
|
||||
string type_name()
|
||||
string write(string)
|
||||
|
||||
cdef cppclass CxxTabulated1 "Cantera::Tabulated1" (CxxFunc1):
|
||||
CxxTabulated1(int, double*, double*, string) except +translate_exception
|
||||
double eval(double) except +translate_exception
|
||||
|
||||
cdef extern from "cantera/cython/funcWrapper.h":
|
||||
ctypedef double (*callback_wrapper)(double, void*, void**) except? 0.0
|
||||
@ -31,12 +32,20 @@ cdef extern from "cantera/cython/funcWrapper.h":
|
||||
void setExceptionValue(PyObject*)
|
||||
|
||||
|
||||
cdef extern from "cantera/numerics/Func1Factory.h":
|
||||
cdef shared_ptr[CxxFunc1] CxxNewFunc1 "Cantera::newFunc1" (
|
||||
string, double) except +translate_exception
|
||||
cdef shared_ptr[CxxFunc1] CxxNewFunc1 "Cantera::newFunc1" (
|
||||
string, vector[double]&) except +translate_exception
|
||||
cdef shared_ptr[CxxFunc1] CxxNewFunc1 "Cantera::newFunc1" (
|
||||
string, shared_ptr[CxxFunc1], shared_ptr[CxxFunc1]) except +translate_exception
|
||||
cdef shared_ptr[CxxFunc1] CxxNewFunc1 "Cantera::newFunc1" (
|
||||
string, shared_ptr[CxxFunc1], double) except +translate_exception
|
||||
|
||||
|
||||
cdef class Func1:
|
||||
cdef shared_ptr[CxxFunc1] _func
|
||||
cdef CxxFunc1* func
|
||||
cdef object callable
|
||||
cdef object exception
|
||||
cpdef void _set_callback(self, object) except *
|
||||
|
||||
cdef class TabulatedFunction(Func1):
|
||||
cpdef void _set_tables(self, object, object, string) except *
|
||||
|
@ -7,6 +7,7 @@ import numpy as np
|
||||
|
||||
from ._utils cimport *
|
||||
|
||||
|
||||
cdef double func_callback(double t, void* obj, void** err) except? 0.0:
|
||||
"""
|
||||
This function is called from C/C++ to evaluate a `Func1` object ``obj``,
|
||||
@ -65,34 +66,108 @@ cdef class Func1:
|
||||
self.exception = None
|
||||
self.callable = None
|
||||
|
||||
def __init__(self, c):
|
||||
def __init__(self, c, *, init=True):
|
||||
if init is False:
|
||||
# used by 'create' classmethod
|
||||
return
|
||||
if hasattr(c, '__call__'):
|
||||
# callback function
|
||||
self._set_callback(c)
|
||||
else:
|
||||
arr = np.array(c)
|
||||
try:
|
||||
if arr.ndim == 0:
|
||||
# handle constants or unsized numpy arrays
|
||||
k = float(c)
|
||||
self._set_callback(lambda t: k)
|
||||
elif arr.size == 1:
|
||||
# handle lists, tuples or numpy arrays with a single element
|
||||
k = float(c[0])
|
||||
self._set_callback(lambda t: k)
|
||||
else:
|
||||
raise TypeError
|
||||
return
|
||||
|
||||
except TypeError:
|
||||
raise TypeError(
|
||||
"'Func1' objects must be constructed from a number or "
|
||||
"a callable object") from None
|
||||
cdef Func1 func
|
||||
try:
|
||||
arr = np.array(c)
|
||||
if arr.ndim == 0:
|
||||
# handle constants or unsized numpy arrays
|
||||
k = float(c)
|
||||
elif arr.size == 1:
|
||||
# handle lists, tuples or numpy arrays with a single element
|
||||
k = float(c[0])
|
||||
else:
|
||||
raise TypeError
|
||||
func = Func1.cxx_functor("constant", k)
|
||||
self._func = func._func
|
||||
self.func = self._func.get()
|
||||
|
||||
except TypeError:
|
||||
raise TypeError(
|
||||
"'Func1' objects must be constructed from a number or "
|
||||
"a callable object") from None
|
||||
|
||||
cpdef void _set_callback(self, c) except *:
|
||||
self.callable = c
|
||||
self._func.reset(new CxxFunc1Py(func_callback, <void*>self))
|
||||
self.func = self._func.get()
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
"""
|
||||
Return the type of the underlying C++ functor object.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
return pystr(self.func.type())
|
||||
|
||||
@classmethod
|
||||
def cxx_functor(cls, functor_type, *args):
|
||||
"""
|
||||
Retrieve a C++ `Func1` functor (advanced feature).
|
||||
|
||||
For implemented functor types, see the Cantera C++ ``Func1`` documentation.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
cdef shared_ptr[CxxFunc1] func
|
||||
cdef Func1 f0
|
||||
cdef Func1 f1
|
||||
cdef string cxx_string = stringify(functor_type)
|
||||
cdef vector[double] arr
|
||||
if len(args) == 0:
|
||||
# simple functor with no parameter
|
||||
func = CxxNewFunc1(cxx_string, 1.)
|
||||
elif len(args) == 1:
|
||||
if hasattr(args[0], "__len__"):
|
||||
# advanced functor with array and no parameter
|
||||
for v in args[0]:
|
||||
arr.push_back(v)
|
||||
func = CxxNewFunc1(cxx_string, arr)
|
||||
else:
|
||||
# simple functor with scalar parameter
|
||||
func = CxxNewFunc1(cxx_string, float(args[0]))
|
||||
elif len(args) == 2:
|
||||
if isinstance(args[0], Func1) and isinstance(args[1], Func1):
|
||||
# compound functor
|
||||
f0 = args[0]
|
||||
f1 = args[1]
|
||||
func = CxxNewFunc1(cxx_string, f0._func, f1._func)
|
||||
elif isinstance(args[0], Func1):
|
||||
# modified functor
|
||||
f0 = args[0]
|
||||
func = CxxNewFunc1(cxx_string, f0._func, float(args[1]))
|
||||
else:
|
||||
raise ValueError("Invalid arguments")
|
||||
else:
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
cls_name = pystr(func.get().type_name()).split("::")[-1]
|
||||
cdef Func1 out = type(
|
||||
cls_name, (cls, ), {"__module__": cls.__module__})(None, init=False)
|
||||
out._func = func
|
||||
out.func = out._func.get()
|
||||
return out
|
||||
|
||||
def write(self, name="t"):
|
||||
"""
|
||||
Write a :math:`LaTeX` expression representing a functor.
|
||||
|
||||
:param name:
|
||||
Name of the variable to be used.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
return pystr(self.func.write(stringify(name)))
|
||||
|
||||
def __call__(self, t):
|
||||
return self.func.eval(t)
|
||||
|
||||
@ -134,17 +209,7 @@ cdef class TabulatedFunction(Func1):
|
||||
"""
|
||||
|
||||
def __init__(self, time, fval, method='linear'):
|
||||
self._set_tables(time, fval, stringify(method))
|
||||
|
||||
cpdef void _set_tables(self, time, fval, string method) except *:
|
||||
tt = np.asarray(time, dtype=np.double)
|
||||
ff = np.asarray(fval, dtype=np.double)
|
||||
if tt.size != ff.size:
|
||||
raise ValueError("Sizes of arrays do not match "
|
||||
"({} vs {})".format(tt.size, ff.size))
|
||||
elif tt.size == 0:
|
||||
raise ValueError("Arrays must not be empty.")
|
||||
cdef np.ndarray[np.double_t, ndim=1] tvec = tt
|
||||
cdef np.ndarray[np.double_t, ndim=1] fvec = ff
|
||||
self.func = <CxxFunc1*>(new CxxTabulated1(tt.size, &tvec[0], &fvec[0],
|
||||
method))
|
||||
arr = np.hstack([np.array(time), np.array(fval)])
|
||||
cdef Func1 func = Func1.cxx_functor(f"tabulated-{method}", arr)
|
||||
self._func = func._func
|
||||
self.func = self._func.get()
|
||||
|
Loading…
Reference in New Issue
Block a user