Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
wiseplat
GitHub Repository: wiseplat/python-code
Path: blob/master/ invest-robot-contest_TinkoffBotTwitch-main/venv/lib/python3.8/site-packages/numpy/ma/testutils.py
7757 views
1
"""Miscellaneous functions for testing masked arrays and subclasses
2
3
:author: Pierre Gerard-Marchant
4
:contact: pierregm_at_uga_dot_edu
5
:version: $Id: testutils.py 3529 2007-11-13 08:01:14Z jarrod.millman $
6
7
"""
8
import operator
9
10
import numpy as np
11
from numpy import ndarray, float_
12
import numpy.core.umath as umath
13
import numpy.testing
14
from numpy.testing import (
15
assert_, assert_allclose, assert_array_almost_equal_nulp,
16
assert_raises, build_err_msg
17
)
18
from .core import mask_or, getmask, masked_array, nomask, masked, filled
19
20
__all__masked = [
    'almost',
    'approx',
    'assert_almost_equal',
    'assert_array_almost_equal',
    'assert_array_approx_equal',
    'assert_array_compare',
    'assert_array_equal',
    'assert_array_less',
    'assert_close',
    'assert_equal',
    'assert_equal_records',
    'assert_mask_equal',
    'assert_not_equal',
    'fail_if_array_equal',
]

# Also re-export a few plain (non-masked) testing helpers, so that projects
# which mistakenly imported them from this module -- SciPy among them -- keep
# working. Some of these helpers were never meant to handle masked arrays,
# but there was no way to tell before.
from unittest import TestCase
__some__from_testing = [
    'TestCase',
    'assert_',
    'assert_allclose',
    'assert_array_almost_equal_nulp',
    'assert_raises',
]

# Public API: the masked-array helpers first, then the re-exported ones.
__all__ = [*__all__masked, *__some__from_testing]
39
40
41
def approx(a, b, fill_value=True, rtol=1e-5, atol=1e-8):
    """
    Elementwise test that a and b are equal within the given tolerances.

    An element passes when ``|x - y| <= atol + rtol * |y|``. ``rtol`` should
    be positive and much smaller than 1. ``atol`` matters for the elements
    of b that are very small or zero.

    Positions masked in either input are filled with ``fill_value`` on the
    ``a`` side and with 1 on the ``b`` side. The default ``True`` therefore
    makes masked positions compare equal, while ``False`` makes them compare
    unequal (0 vs 1).

    If either filled input has object dtype, the tolerances and masks are
    ignored and plain ``np.equal`` is used instead.

    Returns
    -------
    ndarray
        Flattened boolean array with one entry per element.

    """
    common_mask = mask_or(getmask(a), getmask(b))
    left, right = filled(a), filled(b)
    # Object arrays cannot be cast to float; fall back to exact equality.
    if "O" in (left.dtype.char, right.dtype.char):
        return np.equal(left, right).ravel()
    lhs = filled(masked_array(left, copy=False, mask=common_mask), fill_value)
    rhs = filled(masked_array(right, copy=False, mask=common_mask), 1)
    lhs = lhs.astype(float_)
    rhs = rhs.astype(float_)
    tolerance = atol + rtol * umath.absolute(rhs)
    return np.less_equal(umath.absolute(lhs - rhs), tolerance).ravel()
61
62
63
def almost(a, b, decimal=6, fill_value=True):
    """
    Elementwise test that a and b agree to ``decimal`` decimal places.

    An element passes when ``round(|x - y|, decimal) <= 10**-decimal``.

    Positions masked in either input are filled with ``fill_value`` on the
    ``a`` side and with 1 on the ``b`` side. The default ``True`` makes
    masked positions compare equal, while ``False`` makes them unequal.

    If either filled input has object dtype, masks and rounding are ignored
    and plain ``np.equal`` is used instead.

    Returns
    -------
    ndarray
        Flattened boolean array with one entry per element.

    """
    common_mask = mask_or(getmask(a), getmask(b))
    left, right = filled(a), filled(b)
    # Object arrays cannot be cast to float; fall back to exact equality.
    if "O" in (left.dtype.char, right.dtype.char):
        return np.equal(left, right).ravel()
    lhs = filled(masked_array(left, copy=False, mask=common_mask),
                 fill_value).astype(float_)
    rhs = filled(masked_array(right, copy=False, mask=common_mask),
                 1).astype(float_)
    rounded_gap = np.around(np.abs(lhs - rhs), decimal)
    return (rounded_gap <= 10.0 ** (-decimal)).ravel()
80
81
82
def _assert_equal_on_sequences(actual, desired, err_msg=''):
    """
    Assert that two non-array sequences are equal, item by item.

    The lengths are compared first. Each pair of items is then checked with
    `assert_equal`, with the item index prepended to ``err_msg``.

    """
    assert_equal(len(actual), len(desired), err_msg)
    for position, expected_item in enumerate(desired):
        assert_equal(actual[position], expected_item,
                     f'item={position!r}\n{err_msg}')
91
92
93
def assert_equal_records(a, b):
    """
    Asserts that two records are equal.

    Pretty crude for now. The dtypes must match exactly. Each field named by
    ``a.dtype.names`` is then compared with `assert_equal`. A field is
    skipped when either side yields the `masked` constant.

    """
    assert_equal(a.dtype, b.dtype)
    for field in a.dtype.names:
        # Look each field up once and reuse the result; the original
        # implementation indexed every field twice.
        a_field, b_field = a[field], b[field]
        if a_field is not masked and b_field is not masked:
            assert_equal(a_field, b_field)
106
107
108
def assert_equal(actual, desired, err_msg=''):
    """
    Asserts that two items are equal.

    The comparison dispatches on the types involved:

    * dicts are compared by length, then key by key (recursively);
    * pairs of lists/tuples are compared item by item;
    * other non-ndarray objects are compared with ``==``;
    * anything involving an ndarray is compared with `assert_array_equal`,
      except that two byte-string ("S") arrays are compared as Python lists.

    Parameters
    ----------
    actual, desired : object
        The items to compare.
    err_msg : str, optional
        Message included in the failure report. It is forwarded to every
        nested comparison.

    Raises
    ------
    AssertionError
        If the items are not equal.
    ValueError
        If exactly one of the two items is the `masked` constant.

    """
    # Case #1: dictionary .....
    if isinstance(desired, dict):
        if not isinstance(actual, dict):
            raise AssertionError(repr(type(actual)))
        assert_equal(len(actual), len(desired), err_msg)
        for k in desired:
            if k not in actual:
                raise AssertionError(f"{k} not in {actual}")
            assert_equal(actual[k], desired[k], f'key={k!r}\n{err_msg}')
        return
    # Case #2: lists .....
    if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
        # Forward the caller's message; it used to be replaced by ''.
        return _assert_equal_on_sequences(actual, desired, err_msg=err_msg)
    # Case #3: scalars and other non-array objects .....
    if not (isinstance(actual, ndarray) or isinstance(desired, ndarray)):
        if not desired == actual:
            # Only build the (repr-heavy) message when it is needed.
            raise AssertionError(build_err_msg([actual, desired], err_msg,))
        return
    # Case #4. arrays or equivalent
    # Exactly one side being the `masked` constant is reported as ValueError.
    if (actual is masked) != (desired is masked):
        msg = build_err_msg([actual, desired],
                            err_msg, header='', names=('x', 'y'))
        raise ValueError(msg)
    actual = np.asanyarray(actual)
    desired = np.asanyarray(desired)
    if actual.dtype.char == "S" and desired.dtype.char == "S":
        return _assert_equal_on_sequences(actual.tolist(),
                                          desired.tolist(),
                                          err_msg=err_msg)
    return assert_array_equal(actual, desired, err_msg)
145
146
147
def fail_if_equal(actual, desired, err_msg='',):
    """
    Raises an assertion error if two items are equal.

    This mirrors the dispatch of `assert_equal` with the test inverted.
    Dicts and lists/tuples are checked component by component. ndarrays are
    delegated to `fail_if_array_equal`. Everything else fails when
    ``desired != actual`` is false.

    NOTE(review): for dicts and lists/tuples the check also runs on the two
    lengths, so two containers of the same length always raise, even when
    their items differ. Confirm whether callers rely on this before changing
    it.

    """
    if isinstance(desired, dict):
        if not isinstance(actual, dict):
            raise AssertionError(repr(type(actual)))
        fail_if_equal(len(actual), len(desired), err_msg)
        for key in desired:
            if key not in actual:
                raise AssertionError(repr(key))
            fail_if_equal(actual[key], desired[key], f'key={key!r}\n{err_msg}')
        return None

    both_sequences = (isinstance(desired, (list, tuple))
                      and isinstance(actual, (list, tuple)))
    if both_sequences:
        fail_if_equal(len(actual), len(desired), err_msg)
        for idx, expected_item in enumerate(desired):
            fail_if_equal(actual[idx], expected_item, f'item={idx!r}\n{err_msg}')
        return None

    if isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray):
        return fail_if_array_equal(actual, desired, err_msg)

    msg = build_err_msg([actual, desired], err_msg)
    items_differ = desired != actual
    if not items_differ:
        raise AssertionError(msg)
171
172
173
# Exported name (listed in ``__all__masked``) for `fail_if_equal`. It is an
# alias, so ``assert_not_equal is fail_if_equal`` holds.
assert_not_equal = fail_if_equal
174
175
176
def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
    """
    Asserts that two items are almost equal.

    If either item is an ndarray, the check is delegated to
    `assert_array_almost_equal` with the same ``decimal``. Otherwise
    ``abs(desired - actual)`` is rounded to ``decimal`` places and must be
    zero, which is roughly ``abs(desired-actual) < 0.5 * 10**(-decimal)``.

    Raises
    ------
    AssertionError
        If the scalar difference does not round to zero.

    """
    arrays_involved = (isinstance(actual, np.ndarray)
                       or isinstance(desired, np.ndarray))
    if arrays_involved:
        return assert_array_almost_equal(actual, desired, decimal=decimal,
                                         err_msg=err_msg, verbose=verbose)
    msg = build_err_msg([actual, desired], err_msg=err_msg, verbose=verbose)
    rounded_gap = round(abs(desired - actual), decimal)
    if not rounded_gap == 0:
        raise AssertionError(msg)
190
191
192
# Exported alias (listed in ``__all__masked``) for `assert_almost_equal`; the
# two names refer to the same function object.
assert_close = assert_almost_equal
193
194
195
def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
                         fill_value=True):
    """
    Asserts that an elementwise comparison between two masked arrays holds.

    Both inputs are re-wrapped with the union of their masks. The masked
    positions are filled with ``fill_value`` on both sides, and the filled
    arrays are passed to ``np.testing.assert_array_compare`` together with
    ``comparison``, ``err_msg``, ``verbose`` and ``header``.

    Raises
    ------
    ValueError
        If exactly one of the re-wrapped inputs is the `masked` constant.

    """
    shared_mask = mask_or(getmask(x), getmask(y))

    def _rewrap(arr):
        # Same mask for both sides; drop any mask carried by the input.
        return masked_array(arr, copy=False, mask=shared_mask,
                            keep_mask=False, subok=False)

    x, y = _rewrap(x), _rewrap(y)
    # NOTE(review): x and y were just rebuilt by masked_array(); if that
    # always returns a fresh array, this identity check can never fire.
    # Confirm before relying on it.
    if (x is masked) != (y is masked):
        msg = build_err_msg([x, y], err_msg=err_msg, verbose=verbose,
                            header=header, names=('x', 'y'))
        raise ValueError(msg)
    x_filled = x.filled(fill_value)
    y_filled = y.filled(fill_value)
    return np.testing.assert_array_compare(comparison, x_filled, y_filled,
                                           err_msg=err_msg, verbose=verbose,
                                           header=header)
218
219
220
def assert_array_equal(x, y, err_msg='', verbose=True):
    """
    Checks the elementwise equality of two masked arrays.

    This is a thin wrapper around `assert_array_compare` using ``==``. With
    that function's default ``fill_value``, positions masked in either input
    are filled with the same value on both sides before comparing.

    """
    equality = operator.eq
    assert_array_compare(equality, x, y, err_msg=err_msg, verbose=verbose,
                         header='Arrays are not equal')
228
229
230
def fail_if_array_equal(x, y, err_msg='', verbose=True):
    """
    Raises an assertion error if two masked arrays are equal elementwise.

    Equality is judged with `approx` (default tolerances) on the filled
    arrays produced by `assert_array_compare`. The check fails only when
    every element is approximately equal; a single differing element is
    enough to pass.

    """
    def compare(xf, yf):
        "Return True when at least one element differs beyond tolerance."
        # np.all replaces np.alltrue (deprecated in 1.25, removed in 2.0).
        return not np.all(approx(xf, yf))
    # The failure fires when the arrays ARE equal, so say so in the header.
    assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
                         header='Arrays are equal')
239
240
241
def assert_array_approx_equal(x, y, decimal=6, err_msg='', verbose=True):
    """
    Checks the equality of two masked arrays, up to a given number of decimals.

    The check is elementwise. It uses `approx` with a relative tolerance of
    ``10**-decimal``; the absolute tolerance is left at its default.

    """
    relative_tolerance = 10. ** -decimal

    def _loosely_equal(xf, yf):
        "Returns the result of the loose comparison between xf and yf."
        return approx(xf, yf, rtol=relative_tolerance)

    assert_array_compare(_loosely_equal, x, y, err_msg=err_msg,
                         verbose=verbose, header='Arrays are not almost equal')
253
254
255
def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
    """
    Checks the equality of two masked arrays, up to a given number of decimals.

    The check is elementwise and uses `almost` with the same ``decimal``.

    """
    def _matches_to_decimal(xf, yf):
        "Returns the result of the rounded comparison between xf and yf."
        return almost(xf, yf, decimal)

    assert_array_compare(_matches_to_decimal, x, y, err_msg=err_msg,
                         verbose=verbose, header='Arrays are not almost equal')
267
268
269
def assert_array_less(x, y, err_msg='', verbose=True):
    """
    Checks that x is smaller than y elementwise.

    This is a thin wrapper around `assert_array_compare` using a strict ``<``.

    NOTE(review): with `assert_array_compare`'s default ``fill_value``, both
    sides of a masked position receive the same value. The strict ``<`` then
    looks like it fails there. Confirm whether inputs with masks are expected.

    """
    strictly_less = operator.lt
    assert_array_compare(strictly_less, x, y, err_msg=err_msg,
                         verbose=verbose, header='Arrays are not less-ordered')
277
278
279
def assert_mask_equal(m1, m2, err_msg=''):
    """
    Asserts the equality of two masks.

    ``nomask`` only matches ``nomask``: if exactly one side is ``nomask``, an
    AssertionError is raised immediately. Otherwise the two masks are compared
    elementwise with `assert_array_equal`.

    """
    # Either both sides are nomask or neither is.
    assert_((m1 is nomask) == (m2 is nomask))
    assert_array_equal(m1, m2, err_msg=err_msg)
289
290