Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
junzis
GitHub Repository: junzis/openap
Path: blob/master/openap/extra/__init__.py
592 views
1
"""Extra utilities for OpenAP."""
2
3
import functools
4
5
import numpy as np
6
7
8
def _is_symbolic_type(arg):
    """Tell whether ``arg`` should be treated as a symbolic (CasADi/JAX) value.

    The decision is made purely from the name of the module that defines
    ``type(arg)``: a module name beginning with ``"casadi"`` or ``"jax"``
    marks the value as symbolic. This is meant to catch CasADi SX/MX/DM and
    JAX Array types; note the prefix match also accepts module names such as
    ``jaxlib``. ``None`` is never considered symbolic.
    """
    if arg is None:
        return False

    module_name = type(arg).__module__
    # str.startswith accepts a tuple, so both prefixes are tested in one call.
    return module_name.startswith(("casadi", "jax"))
25
26
27
def ndarrayconvert(func=None, column=False):
    """Decorator to convert inputs to NumPy arrays and handle scalar outputs.

    Intended for methods: the first positional argument (``self``) is passed
    through untouched. For every other positional and keyword argument the
    decorator:

    - Passes strings and ``None`` through unchanged
    - Converts scalar inputs to 1-element arrays
    - Converts list (and other array-like) inputs to arrays

    and on the way out it:

    - Converts 1-element array outputs back to Python scalars
    - Squeezes multi-dimensional array outputs (unless ``column`` is True)
    - Applies the above element-wise when the result is a tuple

    For symbolic types (CasADi, JAX), the decorator passes through
    without conversion to allow symbolic computation and autodiff.

    Usable bare (``@ndarrayconvert``) or with options
    (``@ndarrayconvert(column=True)``).

    Args:
        func: Function to decorate. Leave as None when only passing
            keyword options such as ``column``.
        column: If True, reshape array inputs to column vectors of shape
            ``(n, 1)`` and do not squeeze multi-dimensional outputs.

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator.

    Raises:
        TypeError: If ``func`` is given but is not callable, e.g. the
            mistaken form ``@ndarrayconvert(True)``.
    """
    if func is not None and not callable(func):
        raise TypeError(
            "ndarrayconvert() expects a callable or keyword options; "
            "use @ndarrayconvert(column=True) to request column vectors"
        )

    def _to_array(value):
        """Convert one argument to an ndarray (strings and None untouched)."""
        # np.ndim(None) == 0, so without this guard None would become
        # array([None], dtype=object) and defeat ``x is None`` checks in the
        # decorated method when an optional argument is passed explicitly.
        if value is None or isinstance(value, str):
            return value
        arr = np.array([value]) if np.ndim(value) == 0 else np.asarray(value)
        return arr.reshape(-1, 1) if column else arr

    def _to_scalar(value):
        """Collapse a 1-element ndarray result to a Python scalar."""
        # Only ndarrays are post-processed; any other result passes through.
        if not isinstance(value, np.ndarray):
            return value
        if value.size == 1:
            return value.item()
        if not column and value.ndim > 1:
            return value.squeeze()
        return value

    def _decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Any symbolic input (CasADi/JAX) skips conversion entirely so
            # the wrapped method sees the original objects.
            if any(_is_symbolic_type(v) for v in (*args, *kwargs.values())):
                return func(self, *args, **kwargs)

            # NumPy path: convert inputs to arrays.
            new_args = [_to_array(arg) for arg in args]
            new_kwargs = {k: _to_array(arg) for k, arg in kwargs.items()}

            result = func(self, *new_args, **new_kwargs)

            if isinstance(result, tuple):
                return tuple(_to_scalar(r) for r in result)
            return _to_scalar(result)

        # Keep a handle on the undecorated function.
        wrapper.orig_func = func
        return wrapper

    if func is not None:
        return _decorator(func)
    return _decorator
97
98