import os
import shutil
import sys
import tempfile
from importlib.metadata import version
from importlib.util import find_spec
from pathlib import Path
import pytest
import yaml
from packaging.version import Version
from yt.config import ytcfg
from yt.utilities.answer_testing.testing_utilities import (
_compare_raw_arrays,
_hash_results,
_save_raw_arrays,
_save_result,
_streamline_for_io,
data_dir_load,
)
NUMPY_VERSION = Version(version("numpy"))
PILLOW_VERSION = Version(version("Pillow"))
MATPLOTLIB_VERSION = Version(version("matplotlib"))
if find_spec("setuptools") is not None:
SETUPTOOLS_VERSION = Version(version("setuptools"))
else:
SETUPTOOLS_VERSION = None
if find_spec("pandas") is not None:
PANDAS_VERSION = Version(version("pandas"))
else:
PANDAS_VERSION = None
def pytest_addoption(parser):
    """
    Registers the command-line options and ini values used by the answer tests.
    """
    parser.addoption(
        "--with-answer-testing",
        action="store_true",
        help="Run the answer tests.",
    )
    parser.addoption(
        "--answer-big-data",
        action="store_true",
        help="Run answer tests that require large data files.",
    )
    parser.addoption(
        "--answer-store",
        action="store_true",
        help="Save the hashed answers instead of comparing against them.",
    )
    parser.addoption(
        "--answer-raw-arrays",
        action="store_true",
        help="Save or compare the raw (unhashed) answer arrays.",
    )
    parser.addoption(
        "--raw-answer-store",
        action="store_true",
        help="Save the raw answer arrays instead of comparing against them.",
    )
    parser.addoption(
        "--force-overwrite",
        action="store_true",
        help="Overwrite an existing answer file.",
    )
    parser.addoption(
        "--no-hash",
        action="store_true",
        help="Skip hashing of the answers.",
    )
parser.addoption("--local-dir", default=None, help="Where answers are saved.")
parser.addini(
"local-dir",
default=str(Path(__file__).parent / "answer-store"),
help="answer directory.",
)
parser.addini(
"test_data_dir",
default=ytcfg.get("yt", "test_data_dir"),
help="Directory where data for tests is stored.",
)
def pytest_configure(config):
    r"""
    Registers the custom answer-test markers and the warning filters
    applied to the whole test suite.
    """
config.addinivalue_line("markers", "answer_test: Run the answer tests.")
config.addinivalue_line(
"markers", "big_data: Run answer tests that require large data files."
)
for value in (
"error",
"ignore::pytest.PytestCollectionWarning",
"ignore:Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.:UserWarning",
r"ignore:tight_layout.+falling back to Agg renderer:UserWarning",
"ignore:invalid value encountered in log10:RuntimeWarning",
"ignore:divide by zero encountered in log10:RuntimeWarning",
"ignore:unclosed file.*:ResourceWarning",
):
config.addinivalue_line("filterwarnings", value)
if SETUPTOOLS_VERSION is not None and SETUPTOOLS_VERSION >= Version("67.3.0"):
config.addinivalue_line(
"filterwarnings",
r"ignore:(Deprecated call to `pkg_resources\.declare_namespace\('.*'\)`\.\n)?"
r"Implementing implicit namespace packages \(as specified in PEP 420\) "
r"is preferred to `pkg_resources\.declare_namespace`\.:DeprecationWarning",
)
if SETUPTOOLS_VERSION is not None and SETUPTOOLS_VERSION >= Version("67.5.0"):
config.addinivalue_line(
"filterwarnings",
"ignore:pkg_resources is deprecated as an API:DeprecationWarning",
)
if NUMPY_VERSION >= Version("1.25"):
if find_spec("h5py") is not None and (
Version(version("h5py")) < Version("3.9")
):
config.addinivalue_line(
"filterwarnings",
"ignore:`product` is deprecated as of NumPy 1.25.0:DeprecationWarning",
)
if PILLOW_VERSION >= Version("11.3.0") and MATPLOTLIB_VERSION <= Version("3.10.3"):
config.addinivalue_line(
"filterwarnings",
r"ignore:'mode' parameter is deprecated:DeprecationWarning",
)
if PANDAS_VERSION is not None and PANDAS_VERSION >= Version("2.2.0"):
config.addinivalue_line(
"filterwarnings",
r"ignore:\s*Pyarrow will become a required dependency of pandas:DeprecationWarning",
)
if sys.version_info >= (3, 12):
config.addinivalue_line(
"filterwarnings",
r"ignore:datetime\.datetime\.utcfromtimestamp\(\) is deprecated:DeprecationWarning",
)
if find_spec("ratarmount"):
config.addinivalue_line(
"filterwarnings",
r"ignore:This process \(pid=\d+\) is multi-threaded, use of fork\(\) "
r"may lead to deadlocks in the child\."
":DeprecationWarning",
)
if find_spec("datatree"):
config.addinivalue_line(
"filterwarnings",
"ignore:" r"Engine.*loading failed.*" ":RuntimeWarning",
)
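# Each addinivalue_line("filterwarnings", ...) call above has the same effect
# as listing the pattern under filterwarnings in an ini file, e.g. (sketch):
#
#   [pytest]
#   filterwarnings =
#       error
#       ignore::pytest.PytestCollectionWarning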
def pytest_collection_modifyitems(config, items):
r"""
Decide which tests to skip based on command-line options.
"""
skip_answer = pytest.mark.skip(reason="--with-answer-testing not set.")
skip_unit = pytest.mark.skip(reason="Running answer tests, so skipping unit tests.")
skip_big = pytest.mark.skip(reason="--answer-big-data not set.")
for item in items:
if "answer_test" in item.keywords and not config.getoption(
"--with-answer-testing"
):
item.add_marker(skip_answer)
        if "big_data" in item.keywords and not config.getoption(
            "--answer-big-data"
        ):
            item.add_marker(skip_big)
if "answer_test" not in item.keywords and config.getoption(
"--with-answer-testing"
):
item.add_marker(skip_unit)
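# A minimal sketch of a test that opts in to both markers handled above; the
# class and method names are hypothetical:
#
#   @pytest.mark.answer_test
#   class TestGalaxy:
#       @pytest.mark.big_data
#       def test_large_projection(self):
#           ...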
def pytest_itemcollected(item):
mpl_marker = item.get_closest_marker("mpl_image_compare")
if mpl_marker is not None:
mpl_marker.kwargs.setdefault("tolerance", 0.5)
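# The hook above only fills in a default, so individual tests can still
# override it; an illustrative pytest-mpl usage (hypothetical test):
#
#   @pytest.mark.mpl_image_compare(tolerance=2.0)
#   def test_plot():
#       ...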
def _param_list(request):
    r"""
    Collects the test's function arguments, minus the blacklisted
    fixtures, so they can be saved to the answer file.
    """
blacklist = [
"hashing",
"answer_file",
"request",
"answer_compare",
"temp_dir",
"orbit_traj",
"etc_traj",
]
test_params = {}
for key, val in request.node.funcargs.items():
if key not in blacklist:
if key == "callback":
val = val[0]
test_params[key] = str(val)
test_params = _streamline_for_io(test_params)
return test_params
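# As a sketch: for a test parametrized with, e.g., axis=0 and
# weight=("gas", "density") (hypothetical values), this returns
# {"axis": "0", "weight": "('gas', 'density')"}; every kept value is
# stringified before being passed to _streamline_for_io.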
def _get_answer_files(request):
    """
    Builds the paths where the hashed and raw answers are saved.
    """
    answer_file = f"{request.cls.__name__}_{request.cls.answer_version}.yaml"
    raw_answer_file = f"{request.cls.__name__}_{request.cls.answer_version}.h5"
    # The command-line option takes precedence over the ini-file value.
    local_dir = request.config.getoption("--local-dir")
    if local_dir is None:
        local_dir = request.config.getini("local-dir")
    local_dir = os.path.expanduser(local_dir)
    answer_file = os.path.join(local_dir, answer_file)
    raw_answer_file = os.path.join(local_dir, raw_answer_file)
overwrite = request.config.getoption("--force-overwrite")
storing = request.config.getoption("--answer-store")
raw_storing = request.config.getoption("--raw-answer-store")
raw = request.config.getoption("--answer-raw-arrays")
if os.path.exists(answer_file) and storing and not overwrite:
raise FileExistsError(
"Use `--force-overwrite` to overwrite an existing answer file."
)
if os.path.exists(raw_answer_file) and raw_storing and raw and not overwrite:
raise FileExistsError(
"Use `--force-overwrite` to overwrite an existing raw answer file."
)
if os.path.exists(answer_file) and storing and overwrite:
os.remove(answer_file)
if os.path.exists(raw_answer_file) and raw_storing and raw and overwrite:
os.remove(raw_answer_file)
print(os.path.abspath(answer_file))
return answer_file, raw_answer_file
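# For a hypothetical test class TestFoo with answer_version = "000" and no
# --local-dir on the command line, this resolves to
# <conftest dir>/answer-store/TestFoo_000.yaml and .../TestFoo_000.h5, per
# the "local-dir" ini default registered in pytest_addoption.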
@pytest.fixture(scope="function")
def hashing(request):
r"""
Handles initialization, generation, and saving of answer test
result hashes.
"""
no_hash = request.config.getoption("--no-hash")
store_hash = request.config.getoption("--answer-store")
raw = request.config.getoption("--answer-raw-arrays")
raw_store = request.config.getoption("--raw-answer-store")
if request.cls.answer_file is None:
request.cls.answer_file, request.cls.raw_answer_file = _get_answer_files(
request
)
if not no_hash and not store_hash and request.cls.saved_hashes is None:
try:
with open(request.cls.answer_file) as fd:
request.cls.saved_hashes = yaml.safe_load(fd)
except FileNotFoundError:
module_filename = f"{request.function.__module__.replace('.', os.sep)}.py"
with open(f"generate_test_{os.getpid()}.txt", "a") as fp:
fp.write(f"{module_filename}::{request.cls.__name__}\n")
            pytest.fail("Answer file not found.", pytrace=False)
request.cls.hashes = {}
yield
params = _param_list(request)
hashes = _hash_results(request.cls.hashes)
hashes.update(params)
hashes = {request.node.name: hashes}
if not no_hash and store_hash:
_save_result(hashes, request.cls.answer_file)
elif not no_hash and not store_hash:
try:
for test_name, test_hash in hashes.items():
assert test_name in request.cls.saved_hashes
assert test_hash == request.cls.saved_hashes[test_name]
except AssertionError:
pytest.fail(f"Comparison failure: {request.node.name}", pytrace=False)
if raw and raw_store:
_save_raw_arrays(
request.cls.hashes, request.cls.raw_answer_file, request.node.name
)
if raw and not raw_store:
_compare_raw_arrays(
request.cls.hashes, request.cls.raw_answer_file, request.node.name
)
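# A minimal sketch of a consumer of this fixture; the attribute layout is what
# the fixture expects, but TestDensity and its contents are hypothetical:
#
#   @pytest.mark.answer_test
#   @pytest.mark.usefixtures("hashing")
#   class TestDensity:
#       answer_file = None
#       saved_hashes = None
#       answer_version = "000"
#
#       def test_density(self, ds):
#           values = ...  # compute the answer array
#           self.hashes.update({"grid_values": values})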
@pytest.fixture(scope="function")
def temp_dir():
r"""
Creates a temporary directory needed by certain tests.
"""
curdir = os.getcwd()
if int(os.environ.get("GENERATE_YTDATA", 0)):
tmpdir = os.getcwd()
else:
tmpdir = tempfile.mkdtemp()
os.chdir(tmpdir)
yield tmpdir
os.chdir(curdir)
if tmpdir != curdir:
shutil.rmtree(tmpdir)
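# Illustrative use (hypothetical test): files written under the yielded path
# are cleaned up afterwards unless GENERATE_YTDATA pins the working directory:
#
#   def test_save_image(temp_dir):
#       some_plot.save(os.path.join(temp_dir, "projection.png"))  # placeholder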
@pytest.fixture(scope="class")
def ds(request):
if isinstance(request.param, str):
ds_fn = request.param
opts = {}
else:
ds_fn, opts = request.param
try:
return data_dir_load(
ds_fn, cls=opts.get("cls"), args=opts.get("args"), kwargs=opts.get("kwargs")
)
    except FileNotFoundError:
        pytest.skip(f"Data file: `{ds_fn}` not found.")
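# Because this fixture reads request.param, tests feed it through indirect
# parametrization; the dataset path below is illustrative:
#
#   @pytest.mark.parametrize(
#       "ds", ["IsolatedGalaxy/galaxy0030/galaxy0030"], indirect=True
#   )
#   class TestSomething:
#       def test_load(self, ds):
#           ...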
@pytest.fixture(scope="class")
def field(request):
"""
Fixture for returning the field. Needed because indirect=True is
used for loading the datasets.
"""
return request.param
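# The same indirect pattern drives this and the remaining pass-through
# fixtures (dobj, axis, weight, ds_repr, Npart); a hypothetical sketch:
#
#   @pytest.mark.parametrize("field", [("gas", "density")], indirect=True)
#   def test_field_values(field):
#       ...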
@pytest.fixture(scope="class")
def dobj(request):
"""
Fixture for returning the ds_obj. Needed because indirect=True is
used for loading the datasets.
"""
return request.param
@pytest.fixture(scope="class")
def axis(request):
"""
Fixture for returning the axis. Needed because indirect=True is
used for loading the datasets.
"""
return request.param
@pytest.fixture(scope="class")
def weight(request):
"""
Fixture for returning the weight_field. Needed because
indirect=True is used for loading the datasets.
"""
return request.param
@pytest.fixture(scope="class")
def ds_repr(request):
"""
Fixture for returning the string representation of a dataset.
Needed because indirect=True is used for loading the datasets.
"""
return request.param
@pytest.fixture(scope="class")
def Npart(request):
"""
Fixture for returning the number of particles in a dataset.
Needed because indirect=True is used for loading the datasets.
"""
return request.param