Skip to content

Commit bd3e17f

Browse files
committed
add remaining annotations
1 parent 6660cb7 commit bd3e17f

File tree

4 files changed

+67
-21
lines changed

4 files changed

+67
-21
lines changed

python/prophet/diagnostics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class _SupportsMap(Protocol):
4242
def map(self, __f: Callable[..., object], *its: Iterable[object]) -> Any: ...
4343

4444

45-
logger = logging.getLogger('prophet')
45+
logger: logging.Logger = logging.getLogger('prophet')
4646

4747

4848
def generate_cutoffs(

python/prophet/forecaster.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from prophet.models import StanBackendEnum, ModelInputData, ModelParams, TrendIndicator, IStanBackend
3030
from prophet.plot import (plot, plot_components)
3131

32-
logger = logging.getLogger('prophet')
32+
logger: logging.Logger = logging.getLogger('prophet')
3333
logger.setLevel(logging.INFO)
3434
NANOSECONDS_TO_SECONDS = 1000 * 1000 * 1000
3535

@@ -91,13 +91,14 @@ class Prophet:
9191
growth: Literal["linear", "logistic", "flat"]
9292
changepoints: pd.Series[pd.Timestamp] | None
9393
n_changepoints: int
94+
specified_changepoints: bool
9495
changepoint_range: float
9596
yearly_seasonality: Literal["auto"] | int
9697
weekly_seasonality: Literal["auto"] | int
9798
daily_seasonality: Literal["auto"] | int
9899
holidays: pd.DataFrame | None
99-
seasonality_mode: Literal["additive", "multiplicative"]
100-
holidays_mode: Literal["additive", "multiplicative"]
100+
seasonality_mode: _Mode
101+
holidays_mode: _Mode
101102
seasonality_prior_scale: float
102103
holidays_prior_scale: float
103104
changepoint_prior_scale: float
@@ -112,8 +113,8 @@ class Prophet:
112113
logistic_floor: bool
113114
t_scale: pd.Timedelta | None
114115
changepoints_t: npt.NDArray[np.float64] | None
115-
seasonalities: OrderedDict[str, dict]
116-
extra_regressors: OrderedDict[str, dict]
116+
seasonalities: OrderedDict[str, dict[str, Any]]
117+
extra_regressors: OrderedDict[str, dict[str, Any]]
117118
country_holidays: str | None
118119
stan_fit: Any | None
119120
params: dict[str, Any]
@@ -135,7 +136,7 @@ def __init__(
135136
weekly_seasonality: Literal["auto"] | int = "auto",
136137
daily_seasonality: Literal["auto"] | int = "auto",
137138
holidays: pd.DataFrame | None = None,
138-
seasonality_mode: Literal["additive", "multiplicative"] = "additive",
139+
seasonality_mode: _Mode = "additive",
139140
seasonality_prior_scale: SupportsFloat = 10.0,
140141
holidays_prior_scale: SupportsFloat = 10.0,
141142
changepoint_prior_scale: SupportsFloat = 0.05,
@@ -144,7 +145,7 @@ def __init__(
144145
uncertainty_samples: int = 1000,
145146
stan_backend: str | None = None,
146147
scaling: Literal["absmax", "minmax"] = "absmax",
147-
holidays_mode: Literal["additive", "multiplicative"] | None = None,
148+
holidays_mode: _Mode | None = None,
148149
) -> None:
149150
self.growth = growth
150151

@@ -690,7 +691,7 @@ def add_regressor(
690691
prior_scale: float | None = None,
691692
standardize: Literal['auto'] | bool = 'auto',
692693
mode: _Mode | None = None,
693-
):
694+
) -> Self:
694695
"""Add an additional regressor to be used for fitting and predicting.
695696
696697
The dataframe passed to `fit` and `predict` will have a column with the
@@ -1021,7 +1022,13 @@ def add_group_component(
10211022
components = pd.concat([components, new_comp], ignore_index=True)
10221023
return components
10231024

1024-
def parse_seasonality_args(self, name: str, arg, auto_disable: bool, default_order: int) -> int:
1025+
def parse_seasonality_args(
1026+
self,
1027+
name: str,
1028+
arg: Literal['auto'] | int,
1029+
auto_disable: bool,
1030+
default_order: int,
1031+
) -> int:
10251032
"""Get number of fourier components for built-in seasonalities.
10261033
10271034
Parameters
@@ -1201,7 +1208,7 @@ def flat_growth_init(df: pd.DataFrame) -> tuple[float, float]:
12011208
m = df['y_scaled'].mean()
12021209
return k, m
12031210

1204-
def preprocess(self, df: pd.DataFrame, **kwargs) -> ModelInputData:
1211+
def preprocess(self, df: pd.DataFrame, **kwargs: Any) -> ModelInputData:
12051212
"""
12061213
Reformats historical data, standardizes y and extra regressors, sets seasonalities and changepoints.
12071214
@@ -1985,7 +1992,7 @@ def predictive_samples(self, df: pd.DataFrame, vectorized: bool = True) -> dict[
19851992
df = self.setup_dataframe(df.copy())
19861993
return self.sample_posterior_predictive(df, vectorized)
19871994

1988-
def percentile(self, a, *args, **kwargs):
1995+
def percentile(self, a: npt.ArrayLike, *args: Any, **kwargs: Any) -> np.ndarray:
19891996
"""
19901997
We rely on np.nanpercentile in the rare instances where there
19911998
are a small number of bad samples with MCMC that contain NaNs.

python/prophet/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pandas as pd
2020

2121
import logging
22-
logger = logging.getLogger('prophet.models')
22+
logger: logging.Logger = logging.getLogger('prophet.models')
2323

2424
if TYPE_CHECKING:
2525
from cmdstanpy import CmdStanMCMC, CmdStanMLE, CmdStanModel

python/prophet/plot.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,24 @@
1616
from prophet.diagnostics import performance_metrics
1717

1818
if TYPE_CHECKING:
19-
from typing import Literal, Sequence, TypeVar
19+
from typing import Literal, Sequence, TypeVar, type_check_only
20+
from typing_extensions import TypedDict
21+
2022
from prophet.forecaster import Prophet
2123

2224
import matplotlib.pyplot as plt
2325
import plotly.graph_objs as go
2426

2527
_AxT = TypeVar('_AxT', bound=plt.Axes)
2628

29+
@type_check_only
30+
class _PlotlyProps(TypedDict):
31+
traces: list[go.Scatter]
32+
xaxis: go.layout.XAxis
33+
yaxis: go.layout.YAxis
34+
2735

28-
logger = logging.getLogger('prophet.plot')
36+
logger: logging.Logger = logging.getLogger('prophet.plot')
2937

3038

3139
try:
@@ -798,7 +806,12 @@ def plot_plotly(
798806

799807

800808
def plot_components_plotly(
801-
m, fcst, uncertainty=True, plot_cap=True, figsize=(900, 200)):
809+
m: Prophet,
810+
fcst: pd.DataFrame,
811+
uncertainty: bool = True,
812+
plot_cap: bool = True,
813+
figsize: tuple[int, int] = (900, 200),
814+
) -> go.Figure:
802815
"""Plot the Prophet forecast components using Plotly.
803816
See plot_plotly() for Plotly setup instructions
804817
@@ -860,7 +873,14 @@ def plot_components_plotly(
860873
return fig
861874

862875

863-
def plot_forecast_component_plotly(m, fcst, name, uncertainty=True, plot_cap=False, figsize=(900, 300)):
876+
def plot_forecast_component_plotly(
877+
m: Prophet,
878+
fcst: pd.DataFrame,
879+
name: str,
880+
uncertainty: bool = True,
881+
plot_cap: bool = False,
882+
figsize: tuple[int, int] = (900, 300)
883+
) -> go.Figure:
864884
"""Plot a particular component of the forecast using Plotly.
865885
See plot_plotly() for Plotly setup instructions
866886
@@ -891,7 +911,12 @@ def plot_forecast_component_plotly(m, fcst, name, uncertainty=True, plot_cap=Fal
891911
return fig
892912

893913

894-
def plot_seasonality_plotly(m, name, uncertainty=True, figsize=(900, 300)):
914+
def plot_seasonality_plotly(
915+
m: Prophet,
916+
name: str,
917+
uncertainty: bool = True,
918+
figsize: tuple[int, int] = (900, 300)
919+
) -> go.Figure:
895920
"""Plot a custom seasonal component using Plotly.
896921
See plot_plotly() for Plotly setup instructions
897922
@@ -919,7 +944,13 @@ def plot_seasonality_plotly(m, name, uncertainty=True, figsize=(900, 300)):
919944
return fig
920945

921946

922-
def get_forecast_component_plotly_props(m, fcst, name, uncertainty=True, plot_cap=False):
947+
def get_forecast_component_plotly_props(
948+
m: Prophet,
949+
fcst: pd.DataFrame,
950+
name: str,
951+
uncertainty: bool = True,
952+
plot_cap: bool = False,
953+
) -> _PlotlyProps:
923954
"""Prepares a dictionary for plotting the selected forecast component with Plotly
924955
925956
Parameters
@@ -956,8 +987,10 @@ def get_forecast_component_plotly_props(m, fcst, name, uncertainty=True, plot_ca
956987
holiday_features.columns = holiday_features.columns.str.replace('+0', '', regex=False)
957988
text = pd.Series(data='', index=holiday_features.index)
958989
for holiday_feature, idxs in holiday_features.items():
990+
# https://github.com/facebook/pyrefly/issues/2248
991+
# pyrefly:ignore[unsupported-operation]
959992
text[idxs.astype(bool) & (text != '')] += '<br>' # Add newline if additional holiday
960-
text[idxs.astype(bool)] += holiday_feature
993+
text[idxs.astype(bool)] += holiday_feature # pyrefly:ignore[unsupported-operation]
961994

962995
traces = []
963996
traces.append(go.Scatter(
@@ -1020,12 +1053,17 @@ def get_forecast_component_plotly_props(m, fcst, name, uncertainty=True, plot_ca
10201053
yaxis = go.layout.YAxis(rangemode='normal' if name == 'trend' else 'tozero',
10211054
title=go.layout.yaxis.Title(text=name),
10221055
zerolinecolor=zeroline_color)
1056+
assert m.component_modes
10231057
if name in m.component_modes['multiplicative']:
10241058
yaxis.update(tickformat='%', hoverformat='.2%')
10251059
return {'traces': traces, 'xaxis': xaxis, 'yaxis': yaxis}
10261060

10271061

1028-
def get_seasonality_plotly_props(m, name, uncertainty=True):
1062+
def get_seasonality_plotly_props(
1063+
m: Prophet,
1064+
name: str,
1065+
uncertainty: bool = True,
1066+
) -> _PlotlyProps:
10291067
"""Prepares a dictionary for plotting the selected seasonality with Plotly
10301068
10311069
Parameters
@@ -1048,6 +1086,7 @@ def get_seasonality_plotly_props(m, name, uncertainty=True):
10481086
start = pd.to_datetime('2017-01-01 0000')
10491087
period = m.seasonalities[name]['period']
10501088
end = start + pd.Timedelta(days=period)
1089+
assert m.history is not None
10511090
if (m.history['ds'].dt.hour == 0).all(): # Day Precision
10521091
plot_points = np.floor(period).astype(int)
10531092
elif (m.history['ds'].dt.minute == 0).all(): # Hour Precision

0 commit comments

Comments
 (0)