1616from prophet .diagnostics import performance_metrics
1717
1818if TYPE_CHECKING :
19- from typing import Literal , Sequence , TypeVar
19+ from typing import Literal , Sequence , TypeVar , type_check_only
20+ from typing_extensions import TypedDict
21+
2022 from prophet .forecaster import Prophet
2123
2224 import matplotlib .pyplot as plt
2325 import plotly .graph_objs as go
2426
2527 _AxT = TypeVar ('_AxT' , bound = plt .Axes )
2628
29+ @type_check_only
30+ class _PlotlyProps (TypedDict ):
31+ traces : list [go .Scatter ]
32+ xaxis : go .layout .XAxis
33+ yaxis : go .layout .YAxis
34+
2735
28- logger = logging .getLogger ('prophet.plot' )
36+ logger : logging . Logger = logging .getLogger ('prophet.plot' )
2937
3038
3139try :
@@ -798,7 +806,12 @@ def plot_plotly(
798806
799807
800808def plot_components_plotly (
801- m , fcst , uncertainty = True , plot_cap = True , figsize = (900 , 200 )):
809+ m : Prophet ,
810+ fcst : pd .DataFrame ,
811+ uncertainty : bool = True ,
812+ plot_cap : bool = True ,
813+ figsize : tuple [int , int ] = (900 , 200 ),
814+ ) -> go .Figure :
802815 """Plot the Prophet forecast components using Plotly.
803816 See plot_plotly() for Plotly setup instructions
804817
@@ -860,7 +873,14 @@ def plot_components_plotly(
860873 return fig
861874
862875
863- def plot_forecast_component_plotly (m , fcst , name , uncertainty = True , plot_cap = False , figsize = (900 , 300 )):
876+ def plot_forecast_component_plotly (
877+ m : Prophet ,
878+ fcst : pd .DataFrame ,
879+ name : str ,
880+ uncertainty : bool = True ,
881+ plot_cap : bool = False ,
882+ figsize : tuple [int , int ] = (900 , 300 )
883+ ) -> go .Figure :
864884 """Plot a particular component of the forecast using Plotly.
865885 See plot_plotly() for Plotly setup instructions
866886
@@ -891,7 +911,12 @@ def plot_forecast_component_plotly(m, fcst, name, uncertainty=True, plot_cap=Fal
891911 return fig
892912
893913
894- def plot_seasonality_plotly (m , name , uncertainty = True , figsize = (900 , 300 )):
914+ def plot_seasonality_plotly (
915+ m : Prophet ,
916+ name : str ,
917+ uncertainty : bool = True ,
918+ figsize : tuple [int , int ] = (900 , 300 )
919+ ) -> go .Figure :
895920 """Plot a custom seasonal component using Plotly.
896921 See plot_plotly() for Plotly setup instructions
897922
@@ -919,7 +944,13 @@ def plot_seasonality_plotly(m, name, uncertainty=True, figsize=(900, 300)):
919944 return fig
920945
921946
922- def get_forecast_component_plotly_props (m , fcst , name , uncertainty = True , plot_cap = False ):
947+ def get_forecast_component_plotly_props (
948+ m : Prophet ,
949+ fcst : pd .DataFrame ,
950+ name : str ,
951+ uncertainty : bool = True ,
952+ plot_cap : bool = False ,
953+ ) -> _PlotlyProps :
923954 """Prepares a dictionary for plotting the selected forecast component with Plotly
924955
925956 Parameters
@@ -956,8 +987,10 @@ def get_forecast_component_plotly_props(m, fcst, name, uncertainty=True, plot_ca
956987 holiday_features .columns = holiday_features .columns .str .replace ('+0' , '' , regex = False )
957988 text = pd .Series (data = '' , index = holiday_features .index )
958989 for holiday_feature , idxs in holiday_features .items ():
990+ # https://github.com/facebook/pyrefly/issues/2248
991+ # pyrefly:ignore[unsupported-operation]
959992 text [idxs .astype (bool ) & (text != '' )] += '<br>' # Add newline if additional holiday
960- text [idxs .astype (bool )] += holiday_feature
993+ text [idxs .astype (bool )] += holiday_feature # pyrefly:ignore[unsupported-operation]
961994
962995 traces = []
963996 traces .append (go .Scatter (
@@ -1020,12 +1053,17 @@ def get_forecast_component_plotly_props(m, fcst, name, uncertainty=True, plot_ca
10201053 yaxis = go .layout .YAxis (rangemode = 'normal' if name == 'trend' else 'tozero' ,
10211054 title = go .layout .yaxis .Title (text = name ),
10221055 zerolinecolor = zeroline_color )
1056+ assert m .component_modes
10231057 if name in m .component_modes ['multiplicative' ]:
10241058 yaxis .update (tickformat = '%' , hoverformat = '.2%' )
10251059 return {'traces' : traces , 'xaxis' : xaxis , 'yaxis' : yaxis }
10261060
10271061
1028- def get_seasonality_plotly_props (m , name , uncertainty = True ):
1062+ def get_seasonality_plotly_props (
1063+ m : Prophet ,
1064+ name : str ,
1065+ uncertainty : bool = True ,
1066+ ) -> _PlotlyProps :
10291067 """Prepares a dictionary for plotting the selected seasonality with Plotly
10301068
10311069 Parameters
@@ -1048,6 +1086,7 @@ def get_seasonality_plotly_props(m, name, uncertainty=True):
10481086 start = pd .to_datetime ('2017-01-01 0000' )
10491087 period = m .seasonalities [name ]['period' ]
10501088 end = start + pd .Timedelta (days = period )
1089+ assert m .history is not None
10511090 if (m .history ['ds' ].dt .hour == 0 ).all (): # Day Precision
10521091 plot_points = np .floor (period ).astype (int )
10531092 elif (m .history ['ds' ].dt .minute == 0 ).all (): # Hour Precision
0 commit comments