Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
junzis
GitHub Repository: junzis/openap
Path: blob/master/tests/test_ndarrayconvert.py
592 views
1
"""Tests for the ndarrayconvert decorator."""
2
3
import numpy as np
4
import pytest
5
6
from openap.extra import ndarrayconvert
7
8
9
class MockModel:
    """Stand-in model whose methods are wrapped by ``ndarrayconvert``.

    Each method performs plain arithmetic so the tests can focus on how the
    decorator converts inputs and outputs rather than on the math itself.
    """

    @ndarrayconvert
    def compute(self, x, y):
        """Return ``x + y`` as a single value (single-output case)."""
        total = x + y
        return total

    @ndarrayconvert
    def compute_multi(self, x, y):
        """Return the tuple ``(x + y, x * y)`` (multi-output case)."""
        total = x + y
        product = x * y
        return total, product

    @ndarrayconvert(column=True)
    def compute_column(self, x, y):
        """Return ``x + y`` with the decorator's ``column=True`` option on."""
        total = x + y
        return total
26
27
28
class TestNdarrayConvertNumPy:
    """Test ndarrayconvert with NumPy and plain-Python inputs."""

    def test_scalar_input_returns_scalar(self):
        """Scalar inputs should return scalar output."""
        model = MockModel()
        result = model.compute(1.0, 2.0)
        # np.float64 subclasses float, so this accepts either a Python float
        # or a NumPy float64 scalar, but rejects a 0-d/1-element ndarray.
        assert isinstance(result, float)
        assert result == 3.0

    def test_int_input_returns_scalar(self):
        """Integer inputs should return scalar output."""
        model = MockModel()
        result = model.compute(1, 2)
        # The decorator may or may not promote ints to floats; accept any
        # Python or NumPy numeric scalar.
        assert isinstance(result, (int, float, np.integer, np.floating))
        assert result == 3

    def test_list_input_returns_array(self):
        """List inputs should return array output."""
        model = MockModel()
        result = model.compute([1, 2, 3], [4, 5, 6])
        assert isinstance(result, np.ndarray)
        np.testing.assert_array_equal(result, [5, 7, 9])

    def test_array_input_returns_array(self):
        """Array inputs should return array output."""
        model = MockModel()
        x = np.array([1.0, 2.0, 3.0])
        y = np.array([4.0, 5.0, 6.0])
        result = model.compute(x, y)
        assert isinstance(result, np.ndarray)
        np.testing.assert_array_equal(result, [5.0, 7.0, 9.0])

    def test_single_element_array_returns_scalar(self):
        """Single-element array should return scalar."""
        model = MockModel()
        result = model.compute(np.array([1.0]), np.array([2.0]))
        assert isinstance(result, float)
        assert result == 3.0

    def test_multi_return_scalar(self):
        """Multiple return values with scalar inputs."""
        model = MockModel()
        sum_result, prod_result = model.compute_multi(2.0, 3.0)
        assert isinstance(sum_result, float)
        assert isinstance(prod_result, float)
        assert sum_result == 5.0
        assert prod_result == 6.0

    def test_multi_return_array(self):
        """Multiple return values with array inputs."""
        model = MockModel()
        sum_result, prod_result = model.compute_multi([1, 2], [3, 4])
        assert isinstance(sum_result, np.ndarray)
        assert isinstance(prod_result, np.ndarray)
        np.testing.assert_array_equal(sum_result, [4, 6])
        np.testing.assert_array_equal(prod_result, [3, 8])

    def test_column_mode(self):
        """Test column=True reshapes to column vectors."""
        model = MockModel()
        result = model.compute_column([1, 2, 3], [4, 5, 6])
        assert isinstance(result, np.ndarray)
        assert result.shape == (3, 1)
        # Shape alone would pass for any (3, 1) output; pin the values too so
        # a reshape that reorders or mixes elements is caught.
        np.testing.assert_array_equal(result, [[5], [7], [9]])

    def test_kwargs(self):
        """Test with keyword arguments."""
        model = MockModel()
        result = model.compute(x=1.0, y=2.0)
        assert isinstance(result, float)
        assert result == 3.0
99
100
101
class TestNdarrayConvertCasadi:
    """Check that ndarrayconvert leaves CasADi objects untouched."""

    @pytest.fixture
    def casadi(self):
        # importorskip skips the requesting test when CasADi is missing.
        return pytest.importorskip("casadi")

    def test_symbolic_passthrough(self, casadi):
        """An SX expression goes in and an SX expression comes out."""
        lhs, rhs = casadi.SX.sym("x"), casadi.SX.sym("y")
        assert isinstance(MockModel().compute(lhs, rhs), casadi.SX)

    def test_mx_passthrough(self, casadi):
        """An MX expression goes in and an MX expression comes out."""
        lhs, rhs = casadi.MX.sym("x"), casadi.MX.sym("y")
        assert isinstance(MockModel().compute(lhs, rhs), casadi.MX)

    def test_dm_passthrough(self, casadi):
        """A numeric DM matrix goes in and a DM matrix comes out."""
        lhs, rhs = casadi.DM([1, 2, 3]), casadi.DM([4, 5, 6])
        assert isinstance(MockModel().compute(lhs, rhs), casadi.DM)
131
132
133
class TestNdarrayConvertJax:
    """Test ndarrayconvert with JAX inputs."""

    @pytest.fixture
    def jax(self):
        # importorskip skips the requesting test when JAX is missing.
        return pytest.importorskip("jax")

    @pytest.fixture
    def jnp(self, jax):
        return jax.numpy

    def test_jax_array_passthrough(self, jax, jnp):
        """JAX arrays should pass through unchanged."""
        model = MockModel()
        x = jnp.array([1.0, 2.0, 3.0])
        y = jnp.array([4.0, 5.0, 6.0])
        result = model.compute(x, y)
        # Result should be a JAX array, not converted to NumPy. ``jax.Array``
        # is JAX's public array type; an isinstance check is stricter than
        # looking for "jax" as a substring of the type's module name.
        assert isinstance(result, jax.Array)
        np.testing.assert_array_equal(np.asarray(result), [5.0, 7.0, 9.0])

    def test_jax_scalar_passthrough(self, jax, jnp):
        """JAX scalars should pass through unchanged."""
        model = MockModel()
        x = jnp.array(1.0)
        y = jnp.array(2.0)
        result = model.compute(x, y)
        assert isinstance(result, jax.Array)
        assert float(result) == 3.0
160
161
162
class TestNdarrayConvertIntegration:
    """Integration tests with actual OpenAP models."""

    def test_thrust_scalar(self):
        """Thrust model with scalar inputs."""
        from openap import Thrust

        thrust = Thrust("A320")
        result = thrust.takeoff(tas=150, alt=0)
        assert isinstance(result, float)
        assert result > 0

    def test_thrust_array(self):
        """Thrust model with array inputs."""
        from openap import Thrust

        thrust = Thrust("A320")
        result = thrust.takeoff(tas=[150, 200, 250], alt=[0, 0, 0])
        assert isinstance(result, np.ndarray)
        assert result.shape == (3,)
        # Vectorised positivity check. NaN compares False, so a NaN element
        # fails here exactly as it did with the element-wise generator.
        assert np.all(result > 0)

    def test_drag_scalar(self):
        """Drag model with scalar inputs."""
        from openap import Drag

        drag = Drag("A320")
        result = drag.clean(mass=65000, tas=250, alt=35000)
        assert isinstance(result, float)
        assert result > 0

    def test_fuelflow_scalar(self):
        """FuelFlow model with scalar inputs."""
        from openap import FuelFlow

        ff = FuelFlow("A320")
        result = ff.enroute(mass=65000, tas=250, alt=35000)
        assert isinstance(result, float)
        assert result > 0
201
202
203
if __name__ == "__main__":
    # pytest.main() returns the session exit code; pass it on so running this
    # file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))
205
206