Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
junzis
GitHub Repository: junzis/openap
Path: blob/master/tests/test_backends.py
592 views
1
"""Tests for the backend abstraction pattern.
2
3
This module tests that all three backends (NumPy, CasADi, JAX) work
4
correctly and produce consistent results.
5
"""
6
7
import numpy as np
8
import pytest
9
10
from openap import Aero, Drag, Emission, FuelFlow, Thrust
11
from openap.backends import CasadiBackend, JaxBackend, NumpyBackend
12
13
14
# Reference results shared by every backend test, computed with the NumPy
# backend. Conventions used in the per-entry notes: ``tas`` in knots, ``alt``
# in feet, ``h`` in metres, ``mass`` in kg, ``ffac`` (fuel flow) in kg/s.
EXPECTED = dict(
    thrust_takeoff=185981.10,  # N, tas=150kt, alt=0ft
    thrust_climb=72317.87,  # N, tas=280kt, alt=20000ft, roc=2000fpm
    drag_clean=47722.51,  # N, mass=65000kg, tas=250kt, alt=35000ft
    fuelflow_enroute=0.988612,  # kg/s, mass=65000kg, tas=250kt, alt=35000ft
    fuelflow_at_thrust=1.030994,  # kg/s, thrust=50000N
    emission_nox=16.2334,  # g/s, ffac=1.0kg/s, tas=250kt, alt=35000ft
    emission_co2=3160.0,  # g/s, ffac=1.0kg/s
    emission_h2o=1230.0,  # g/s, ffac=1.0kg/s
    aero_temperature=223.15,  # K, h=10000m
    aero_density=0.412604,  # kg/m³, h=10000m
    aero_pressure=26429.70,  # Pa, h=10000m
    # N, takeoff thrust for tas=[150, 200, 250]kt at alt=[0, 10000, 20000]ft
    thrust_array=[185981.10, 141161.74, 96559.57],
)

# Relative tolerance used for every floating-point comparison (0.01 %).
RTOL = 1e-4
32
33
34
class TestNumpyBackend:
    """Tests for NumpyBackend (the default backend).

    Models are constructed without an explicit ``backend=`` argument, and
    their results are compared with the reference values in ``EXPECTED``.
    """

    def test_thrust_takeoff(self):
        """Test thrust calculation at takeoff."""
        thrust = Thrust("A320")
        # Check the class object itself rather than comparing its name.
        assert isinstance(thrust.backend, NumpyBackend)

        T = thrust.takeoff(tas=150, alt=0)
        # Scalar inputs must yield a scalar, not a 0-d array.
        assert isinstance(T, (float, np.floating))
        assert T == pytest.approx(EXPECTED["thrust_takeoff"], rel=RTOL)

    def test_thrust_climb(self):
        """Test thrust calculation during climb."""
        thrust = Thrust("A320")
        T = thrust.climb(tas=280, alt=20000, roc=2000)
        assert isinstance(T, (float, np.floating))
        assert T == pytest.approx(EXPECTED["thrust_climb"], rel=RTOL)

    def test_drag_clean(self):
        """Test drag calculation in clean configuration."""
        drag = Drag("A320")
        assert isinstance(drag.backend, NumpyBackend)

        D = drag.clean(mass=65000, tas=250, alt=35000)
        assert isinstance(D, (float, np.floating))
        assert D == pytest.approx(EXPECTED["drag_clean"], rel=RTOL)

    def test_fuelflow_enroute(self):
        """Test en-route fuel flow calculation."""
        ff = FuelFlow("A320")
        assert isinstance(ff.backend, NumpyBackend)

        fuel = ff.enroute(mass=65000, tas=250, alt=35000)
        assert isinstance(fuel, (float, np.floating))
        assert fuel == pytest.approx(EXPECTED["fuelflow_enroute"], rel=RTOL)

    def test_fuelflow_at_thrust(self):
        """Test fuel flow at a given thrust."""
        ff = FuelFlow("A320")
        fuel = ff.at_thrust(50000)
        assert isinstance(fuel, (float, np.floating))
        assert fuel == pytest.approx(EXPECTED["fuelflow_at_thrust"], rel=RTOL)

    def test_emission_nox(self):
        """Test NOx emission calculation."""
        em = Emission("A320")
        assert isinstance(em.backend, NumpyBackend)

        nox = em.nox(ffac=1.0, tas=250, alt=35000)
        assert isinstance(nox, (float, np.floating))
        assert nox == pytest.approx(EXPECTED["emission_nox"], rel=RTOL)

    def test_emission_co2(self):
        """Test CO2 emission calculation."""
        em = Emission("A320")
        co2 = em.co2(ffac=1.0)
        assert co2 == pytest.approx(EXPECTED["emission_co2"], rel=RTOL)

    def test_emission_h2o(self):
        """Test H2O emission calculation."""
        em = Emission("A320")
        h2o = em.h2o(ffac=1.0)
        assert h2o == pytest.approx(EXPECTED["emission_h2o"], rel=RTOL)

    def test_aero_temperature(self):
        """Test temperature calculation at 10 km."""
        aero = Aero()
        T = aero.temperature(10000)
        assert isinstance(T, (float, np.floating))
        assert T == pytest.approx(EXPECTED["aero_temperature"], rel=RTOL)

    def test_aero_density(self):
        """Test density calculation at 10 km."""
        aero = Aero()
        rho = aero.density(10000)
        assert rho == pytest.approx(EXPECTED["aero_density"], rel=RTOL)

    def test_aero_pressure(self):
        """Test pressure calculation at 10 km."""
        aero = Aero()
        p = aero.pressure(10000)
        assert p == pytest.approx(EXPECTED["aero_pressure"], rel=RTOL)

    def test_array_inputs(self):
        """Test that array inputs are evaluated element-wise."""
        thrust = Thrust("A320")
        tas = np.array([150, 200, 250])
        alt = np.array([0, 10000, 20000])

        T = thrust.takeoff(tas, alt)
        assert isinstance(T, np.ndarray)
        assert T.shape == (3,)
        np.testing.assert_allclose(T, EXPECTED["thrust_array"], rtol=RTOL)
128
129
130
class TestCasadiBackend:
    """Tests for CasadiBackend.

    Models are called with CasADi ``SX`` symbols. The resulting symbolic
    expressions are wrapped in ``casadi.Function`` objects and evaluated
    numerically against ``EXPECTED``.
    """

    @pytest.fixture
    def casadi(self):
        """Import casadi, skipping the test if it is not installed."""
        return pytest.importorskip("casadi")

    def test_thrust_symbolic(self, casadi):
        """Test thrust with symbolic inputs."""
        thrust = Thrust("A320", backend=CasadiBackend())
        assert isinstance(thrust.backend, CasadiBackend)

        tas = casadi.SX.sym("tas")
        alt = casadi.SX.sym("alt")
        T = thrust.takeoff(tas, alt)

        assert isinstance(T, casadi.SX)

        # Evaluate the symbolic expression at numeric values.
        f = casadi.Function("f", [tas, alt], [T])
        result = float(f(150, 0))
        assert result == pytest.approx(EXPECTED["thrust_takeoff"], rel=RTOL)

    def test_drag_symbolic(self, casadi):
        """Test drag with symbolic inputs."""
        drag = Drag("A320", backend=CasadiBackend())

        mass = casadi.SX.sym("mass")
        tas = casadi.SX.sym("tas")
        alt = casadi.SX.sym("alt")
        D = drag.clean(mass, tas, alt)

        assert isinstance(D, casadi.SX)

        f = casadi.Function("f", [mass, tas, alt], [D])
        result = float(f(65000, 250, 35000))
        assert result == pytest.approx(EXPECTED["drag_clean"], rel=RTOL)

    def test_fuelflow_symbolic(self, casadi):
        """Test fuel flow with symbolic inputs."""
        ff = FuelFlow("A320", backend=CasadiBackend())

        mass = casadi.SX.sym("mass")
        tas = casadi.SX.sym("tas")
        alt = casadi.SX.sym("alt")
        fuel = ff.enroute(mass, tas, alt)

        assert isinstance(fuel, casadi.SX)

        f = casadi.Function("f", [mass, tas, alt], [fuel])
        result = float(f(65000, 250, 35000))
        assert result == pytest.approx(EXPECTED["fuelflow_enroute"], rel=RTOL)

    def test_emission_symbolic(self, casadi):
        """Test emission with symbolic inputs."""
        em = Emission("A320", backend=CasadiBackend())

        ffac = casadi.SX.sym("ffac")
        tas = casadi.SX.sym("tas")
        alt = casadi.SX.sym("alt")
        nox = em.nox(ffac, tas, alt)

        assert isinstance(nox, casadi.SX)

        f = casadi.Function("f", [ffac, tas, alt], [nox])
        result = float(f(1.0, 250, 35000))
        assert result == pytest.approx(EXPECTED["emission_nox"], rel=RTOL)

    def test_jacobian(self, casadi):
        """Test that the Jacobian of thrust w.r.t. airspeed can be computed."""
        thrust = Thrust("A320", backend=CasadiBackend())

        tas = casadi.SX.sym("tas")
        alt = casadi.SX.sym("alt")
        T = thrust.takeoff(tas, alt)

        jacobian = casadi.jacobian(T, tas)
        assert isinstance(jacobian, casadi.SX)

        # Evaluate (note: 'jac' is a reserved name in CasADi)
        jac_fn = casadi.Function("thrust_jacobian", [tas, alt], [jacobian])
        result = jac_fn(150, 0)
        assert result.shape == (1, 1)

        # dT/dtas should be negative (thrust decreases with speed at takeoff).
        assert float(result) < 0
        assert float(result) == pytest.approx(-276.19, rel=0.01)

    def test_aero_symbolic(self, casadi):
        """Test aero functions with symbolic inputs."""
        aero = Aero(backend=CasadiBackend())

        h = casadi.SX.sym("h")
        T = aero.temperature(h)
        assert isinstance(T, casadi.SX)

        f = casadi.Function("f", [h], [T])
        result = float(f(10000))
        assert result == pytest.approx(EXPECTED["aero_temperature"], rel=RTOL)
236
237
238
class TestJaxBackend:
    """Tests for JaxBackend.

    Inputs are passed as JAX arrays. Besides plain evaluation, the tests
    cover tracing through ``jax.jit`` and differentiating with ``jax.grad``.
    """

    @pytest.fixture
    def jax(self):
        """Import jax, skipping the test if it is not installed."""
        return pytest.importorskip("jax")

    @pytest.fixture
    def jnp(self, jax):
        """Return the ``jax.numpy`` namespace."""
        return jax.numpy

    def test_thrust_jax(self, jnp):
        """Test thrust with JAX arrays."""
        thrust = Thrust("A320", backend=JaxBackend())
        assert isinstance(thrust.backend, JaxBackend)

        T = thrust.takeoff(jnp.array(150.0), jnp.array(0.0))
        assert float(T) == pytest.approx(EXPECTED["thrust_takeoff"], rel=RTOL)

    def test_drag_jax(self, jnp):
        """Test drag with JAX arrays."""
        drag = Drag("A320", backend=JaxBackend())

        D = drag.clean(jnp.array(65000.0), jnp.array(250.0), jnp.array(35000.0))
        assert float(D) == pytest.approx(EXPECTED["drag_clean"], rel=RTOL)

    def test_fuelflow_jax(self, jnp):
        """Test fuel flow with JAX arrays."""
        ff = FuelFlow("A320", backend=JaxBackend())

        fuel = ff.enroute(
            jnp.array(65000.0), jnp.array(250.0), jnp.array(35000.0)
        )
        assert float(fuel) == pytest.approx(EXPECTED["fuelflow_enroute"], rel=RTOL)

    def test_emission_jax(self, jnp):
        """Test emission with JAX arrays."""
        em = Emission("A320", backend=JaxBackend())

        nox = em.nox(jnp.array(1.0), jnp.array(250.0), jnp.array(35000.0))
        assert float(nox) == pytest.approx(EXPECTED["emission_nox"], rel=RTOL)

    def test_jit_compilation(self, jax, jnp):
        """Test that the model can be JIT-compiled and re-evaluated."""
        thrust = Thrust("A320", backend=JaxBackend())
        # Reference takeoff thrust at tas=200kt, alt=0ft.
        expected_thrust_200kt = 173103.59

        @jax.jit
        def compute_thrust(tas, alt):
            return thrust.takeoff(tas, alt)

        # First call traces and compiles.
        result1 = compute_thrust(jnp.array(150.0), jnp.array(0.0))
        # Same input shapes/dtypes, so this call reuses the compiled version.
        result2 = compute_thrust(jnp.array(200.0), jnp.array(0.0))

        assert float(result1) == pytest.approx(EXPECTED["thrust_takeoff"], rel=RTOL)
        # A different input must give a different, correct output.
        assert float(result2) == pytest.approx(expected_thrust_200kt, rel=RTOL)

    def test_gradient(self, jax, jnp):
        """Test that gradients can be computed through the model."""
        thrust = Thrust("A320", backend=JaxBackend())

        def thrust_fn(tas):
            return thrust.takeoff(tas, 0.0)

        dT_dtas = jax.grad(thrust_fn)(150.0)

        # isfinite rejects both NaN and inf from a broken derivative path.
        assert jnp.isfinite(dT_dtas)
        # Gradient should match the CasADi Jacobian reference value.
        assert float(dT_dtas) == pytest.approx(-276.19, rel=0.01)

    def test_aero_jax(self, jnp):
        """Test aero functions with JAX."""
        aero = Aero(backend=JaxBackend())

        T = aero.temperature(jnp.array(10000.0))
        assert float(T) == pytest.approx(EXPECTED["aero_temperature"], rel=RTOL)

        rho = aero.density(jnp.array(10000.0))
        assert float(rho) == pytest.approx(EXPECTED["aero_density"], rel=RTOL)

        p = aero.pressure(jnp.array(10000.0))
        assert float(p) == pytest.approx(EXPECTED["aero_pressure"], rel=RTOL)
328
329
330
class TestBackendConsistency:
    """Tests that all backends produce consistent results.

    Each test evaluates the same model call with the NumPy, CasADi and JAX
    backends. Every result is checked against the stored reference in
    ``EXPECTED`` and also directly against the NumPy result.
    """

    @pytest.fixture
    def casadi(self):
        """Import casadi, skipping the test if it is not installed."""
        return pytest.importorskip("casadi")

    @pytest.fixture
    def jax(self):
        """Import jax, skipping the test if it is not installed."""
        return pytest.importorskip("jax")

    @staticmethod
    def _eval_casadi(casadi, func, *values):
        """Trace ``func`` on fresh SX symbols and evaluate it at ``values``.

        One scalar symbol is created per value, passed to ``func``
        positionally, and the expression is evaluated via ``casadi.Function``.
        """
        symbols = [casadi.SX.sym(f"x{i}") for i in range(len(values))]
        fn = casadi.Function("f", symbols, [func(*symbols)])
        return float(fn(*values))

    @staticmethod
    def _eval_jax(jnp, func, *values):
        """Call ``func`` with each value wrapped as a floating JAX scalar."""
        return float(func(*(jnp.array(float(v)) for v in values)))

    @staticmethod
    def _assert_consistent(key, numpy_value, casadi_value, jax_value):
        """Check all backend results against ``EXPECTED[key]`` and NumPy."""
        expected = EXPECTED[key]
        assert numpy_value == pytest.approx(expected, rel=RTOL)
        assert casadi_value == pytest.approx(expected, rel=RTOL)
        assert jax_value == pytest.approx(expected, rel=RTOL)
        # Compare against the NumPy result too, so the test checks agreement
        # between backends and not only agreement with a stored constant.
        assert casadi_value == pytest.approx(numpy_value, rel=RTOL)
        assert jax_value == pytest.approx(numpy_value, rel=RTOL)

    def test_thrust_consistency(self, casadi, jax):
        """Test that all backends give the same takeoff thrust."""
        T_np = Thrust("A320", backend=NumpyBackend()).takeoff(tas=150, alt=0)
        T_ca = self._eval_casadi(
            casadi, Thrust("A320", backend=CasadiBackend()).takeoff, 150, 0
        )
        T_jax = self._eval_jax(
            jax.numpy, Thrust("A320", backend=JaxBackend()).takeoff, 150, 0
        )
        self._assert_consistent("thrust_takeoff", T_np, T_ca, T_jax)

    def test_drag_consistency(self, casadi, jax):
        """Test that all backends give the same clean drag."""
        D_np = Drag("A320", backend=NumpyBackend()).clean(
            mass=65000, tas=250, alt=35000
        )
        D_ca = self._eval_casadi(
            casadi, Drag("A320", backend=CasadiBackend()).clean, 65000, 250, 35000
        )
        D_jax = self._eval_jax(
            jax.numpy, Drag("A320", backend=JaxBackend()).clean, 65000, 250, 35000
        )
        self._assert_consistent("drag_clean", D_np, D_ca, D_jax)

    def test_fuelflow_consistency(self, casadi, jax):
        """Test that all backends give the same en-route fuel flow."""
        fuel_np = FuelFlow("A320", backend=NumpyBackend()).enroute(
            mass=65000, tas=250, alt=35000
        )
        fuel_ca = self._eval_casadi(
            casadi,
            FuelFlow("A320", backend=CasadiBackend()).enroute,
            65000,
            250,
            35000,
        )
        fuel_jax = self._eval_jax(
            jax.numpy,
            FuelFlow("A320", backend=JaxBackend()).enroute,
            65000,
            250,
            35000,
        )
        self._assert_consistent("fuelflow_enroute", fuel_np, fuel_ca, fuel_jax)

    def test_emission_consistency(self, casadi, jax):
        """Test that all backends give the same NOx emission."""
        nox_np = Emission("A320", backend=NumpyBackend()).nox(
            ffac=1.0, tas=250, alt=35000
        )
        nox_ca = self._eval_casadi(
            casadi, Emission("A320", backend=CasadiBackend()).nox, 1.0, 250, 35000
        )
        nox_jax = self._eval_jax(
            jax.numpy, Emission("A320", backend=JaxBackend()).nox, 1.0, 250, 35000
        )
        self._assert_consistent("emission_nox", nox_np, nox_ca, nox_jax)

    def test_aero_consistency(self, casadi, jax):
        """Test that all backends give the same ISA values at 10 km."""
        aero_np = Aero(backend=NumpyBackend())
        aero_ca = Aero(backend=CasadiBackend())
        aero_jax = Aero(backend=JaxBackend())

        for method, key in (
            ("temperature", "aero_temperature"),
            ("density", "aero_density"),
            ("pressure", "aero_pressure"),
        ):
            value_np = getattr(aero_np, method)(10000)
            value_ca = self._eval_casadi(casadi, getattr(aero_ca, method), 10000)
            value_jax = self._eval_jax(
                jax.numpy, getattr(aero_jax, method), 10000
            )
            self._assert_consistent(key, value_np, value_ca, value_jax)
495
496
497
class TestConvenienceModules:
    """Tests for the convenience modules (openap.casadi, openap.jax)."""

    def test_casadi_module(self):
        """Test openap.casadi convenience module."""
        casadi = pytest.importorskip("casadi")

        from openap.casadi import Drag, Emission, FuelFlow, Thrust, aero, prop

        # Every exported model class should default to CasadiBackend.
        thrust = Thrust("A320")
        for model in (thrust, Drag("A320"), FuelFlow("A320"), Emission("A320")):
            assert isinstance(model.backend, CasadiBackend)

        # Check prop is available.
        ac = prop.aircraft("A320")
        assert "mtow" in ac
        assert ac["mtow"] == pytest.approx(78000, rel=0.01)

        # Check aero works symbolically and gives correct values.
        h = casadi.SX.sym("h")
        T = aero.temperature(h)
        assert isinstance(T, casadi.SX)

        f = casadi.Function("f", [h], [T])
        result = float(f(10000))
        assert result == pytest.approx(EXPECTED["aero_temperature"], rel=RTOL)

    def test_jax_module(self):
        """Test openap.jax convenience module."""
        jax = pytest.importorskip("jax")
        jnp = jax.numpy

        from openap.jax import Drag, Emission, FuelFlow, Thrust, aero

        # Every exported model class should default to JaxBackend.
        thrust = Thrust("A320")
        for model in (thrust, Drag("A320"), FuelFlow("A320"), Emission("A320")):
            assert isinstance(model.backend, JaxBackend)

        # Check JIT works and gives correct values.
        @jax.jit
        def compute(tas, alt):
            return thrust.takeoff(tas, alt)

        result = compute(jnp.array(150.0), jnp.array(0.0))
        assert float(result) == pytest.approx(EXPECTED["thrust_takeoff"], rel=RTOL)

        # Check aero gives correct values.
        T = aero.temperature(jnp.array(10000.0))
        assert float(T) == pytest.approx(EXPECTED["aero_temperature"], rel=RTOL)
546
547
548
if __name__ == "__main__":
    # Propagate pytest's exit code so script runs (e.g. in CI) fail when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))
550
551