-
Notifications
You must be signed in to change notification settings - Fork 371
Expand file tree
/
Copy pathconftest.py
More file actions
92 lines (80 loc) · 3.26 KB
/
conftest.py
File metadata and controls
92 lines (80 loc) · 3.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import importlib
import warnings

import pytest

from tslearn.backend import check_keras_backend
from tslearn.datasets import UCR_UEA_datasets
# Optional keras dependency. ``keras`` is the imported module when both
# ``check_keras_backend()`` and ``import keras`` succeed, and ``None`` when
# either of them raises ImportError. Any other exception propagates and
# aborts loading this conftest.
keras = None
try:
    check_keras_backend()
    import keras
except ImportError:
    pass  # keras stays None: shapelets tests are ignored at collection time
def pytest_ignore_collect(collection_path, *args, **kwargs):
if keras is None and "shapelets" in collection_path.parts:
return True
# Doctest items (matched on ``item.name``) that require torch.
_TORCH_DOCTESTS = frozenset({
    "tslearn.metrics._dtw.dtw",
    "tslearn.metrics.dtw_variants.dtw",
    "tslearn.metrics.softdtw_variants.cdist_soft_dtw_normalized",
    "tslearn.metrics.softdtw_variants.soft_dtw",
    "tslearn.metrics.softdtw_variants.soft_dtw_alignment",
    "tslearn.metrics.softdtw_variants.cdist_soft_dtw",
    "tslearn.metrics._frechet.frechet",
})

# Doctest items that require pandas.
_PANDAS_DOCTESTS = frozenset({
    "tslearn.utils.cast.from_tsfresh_dataset",
    "tslearn.utils.cast.to_tsfresh_dataset",
    "tslearn.utils.cast.from_sktime_dataset",
    "tslearn.utils.cast.to_sktime_dataset",
    "tslearn.utils.cast.from_pyflux_dataset",
    "tslearn.utils.cast.to_pyflux_dataset",
    "tslearn.utils.cast.from_cesium_dataset",
    "tslearn.utils.cast.to_cesium_dataset",
})

# Doctest items that require cesium.
_CESIUM_DOCTESTS = frozenset({
    "tslearn.utils.cast.to_cesium_dataset",
    "tslearn.utils.cast.from_cesium_dataset",
})

# Doctest items that need the UCR/UEA dataset listing to be available.
_UCR_UEA_DOCTESTS = frozenset({
    "tslearn.datasets.ucr_uea.UCR_UEA_datasets.list_datasets",
    "tslearn.datasets.ucr_uea.UCR_UEA_datasets.list_multivariate_datasets",
    "tslearn.datasets.ucr_uea.UCR_UEA_datasets.list_univariate_datasets",
    "tslearn.datasets.ucr_uea.UCR_UEA_datasets.load_dataset",
    "tslearn.datasets.ucr_uea.UCR_UEA_datasets.baseline_accuracy",
    "tslearn.datasets.ucr_uea.UCR_UEA_datasets.list_cached_datasets",
})


def _is_importable(module_name, errors=ImportError):
    """Return whether ``module_name`` can be imported.

    Parameters
    ----------
    module_name : str
        Fully qualified module name.
    errors : exception type or tuple of exception types
        Exceptions raised during the import that mean "not available".
        Any other exception propagates.
    """
    try:
        importlib.import_module(module_name)
    except errors:
        return False
    return True


def _skip_items(items, names, reason):
    """Add a ``pytest.mark.skip(reason=reason)`` marker to items named in ``names``."""
    skip_marker = pytest.mark.skip(reason=reason)
    for item in items:
        if item.name in names:
            item.add_marker(skip_marker)


def _ucr_uea_datasets_available():
    """Return whether the UCR/UEA dataset listing is non-empty.

    Any exception raised while listing is reported as a warning and
    treated as "unavailable".
    """
    try:
        datasets = UCR_UEA_datasets()
        return bool(datasets.list_datasets())
    except Exception as exc:
        warnings.warn("Error listing UCR UEA datasets: {}".format(exc))
        return False


def pytest_collection_modifyitems(config, items):
    """Pytest hook: skip doctests whose optional requirements are missing.

    Parameters
    ----------
    config : pytest config object
        Unused; part of the hook signature.
    items : list of collected test items
        Items are modified in place by adding ``skip`` markers. Matching is
        done on ``item.name``.
    """
    # Any Exception while importing torch counts as "not installed"
    # (best effort), but KeyboardInterrupt/SystemExit still propagate,
    # unlike with a bare ``except:``.
    if not _is_importable("torch", errors=Exception):
        _skip_items(items, _TORCH_DOCTESTS, "torch not installed!")
    if not _is_importable("pandas"):
        _skip_items(items, _PANDAS_DOCTESTS, "pandas not installed!")
    if not _is_importable("cesium"):
        _skip_items(items, _CESIUM_DOCTESTS, "cesium not installed!")

    # Skip related doctests if UCR UEA datasets cannot be fetched. Probing
    # the listing may be slow or hit the network, so only do it when such
    # doctests were actually collected.
    needs_ucr_uea = any(item.name in _UCR_UEA_DOCTESTS for item in items)
    if needs_ucr_uea and not _ucr_uea_datasets_available():
        warnings.warn("Skipping doctests requiring UCR UEA dataset download")
        _skip_items(items, _UCR_UEA_DOCTESTS, "Datasets not cached!")