Files
Time-Series-Analysis/venv/lib/python3.11/site-packages/prophet/tests/test_serialize.py
2025-08-01 04:33:03 -04:00

143 lines
5.4 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
from pathlib import Path
import numpy as np
import pandas as pd
import pytest
from prophet import Prophet
from prophet.serialize import PD_DATAFRAME, PD_SERIES, model_from_json, model_to_json
class TestSerialize:
def test_simple_serialize(self, daily_univariate_ts, backend):
m = Prophet(stan_backend=backend)
df = daily_univariate_ts.head(daily_univariate_ts.shape[0] - 30)
m.fit(df)
future = m.make_future_dataframe(2, include_history=False)
fcst = m.predict(future)
model_str = model_to_json(m)
# Make sure json doesn't get too large in the future
assert len(model_str) < 200000
m2 = model_from_json(model_str)
# Check that m and m2 are equal
assert m.__dict__.keys() == m2.__dict__.keys()
for k, v in m.__dict__.items():
if k in ["stan_fit", "stan_backend"]:
continue
if k == "params":
assert v.keys() == m2.params.keys()
for kk, vv in v.items():
assert np.array_equal(vv, m2.params[kk])
elif k in PD_SERIES and v is not None:
assert v.equals(m2.__dict__[k])
elif k in PD_DATAFRAME and v is not None:
pd.testing.assert_frame_equal(v, m2.__dict__[k], check_index_type=False)
elif k == "changepoints_t":
assert np.array_equal(v, m.__dict__[k])
else:
assert v == m2.__dict__[k]
assert m2.stan_fit is None
assert m2.stan_backend is None
# Check that m2 makes the same forecast
future2 = m2.make_future_dataframe(2, include_history=False)
fcst2 = m2.predict(future2)
assert np.array_equal(fcst["yhat"].values, fcst2["yhat"].values)
def test_full_serialize(self, daily_univariate_ts, backend):
# Construct a model with all attributes
holidays = pd.DataFrame(
{
"ds": pd.to_datetime(["2012-06-06", "2013-06-06"]),
"holiday": ["seans-bday"] * 2,
"lower_window": [0] * 2,
"upper_window": [1] * 2,
}
)
# Test with holidays and country_holidays
m = Prophet(
holidays=holidays,
seasonality_mode="multiplicative",
changepoints=["2012-07-01", "2012-10-01", "2013-01-01"],
stan_backend=backend,
)
m.add_country_holidays(country_name="US")
m.add_seasonality(
name="conditional_weekly",
period=7,
fourier_order=3,
prior_scale=2.0,
condition_name="is_conditional_week",
)
m.add_seasonality(name="normal_monthly", period=30.5, fourier_order=5, prior_scale=2.0)
df = daily_univariate_ts.copy()
df["is_conditional_week"] = [0] * 255 + [1] * 255
m.add_regressor("binary_feature", prior_scale=0.2)
m.add_regressor("numeric_feature", prior_scale=0.5)
m.add_regressor("numeric_feature2", prior_scale=0.5, mode="multiplicative")
m.add_regressor("binary_feature2", standardize=True)
df["binary_feature"] = ["0"] * 255 + ["1"] * 255
df["numeric_feature"] = range(510)
df["numeric_feature2"] = range(510)
df["binary_feature2"] = [1] * 100 + [0] * 410
train = df.head(400)
test = df.tail(100)
m.fit(train)
fcst = m.predict(test)
# Serialize!
m2 = model_from_json(model_to_json(m))
# Check that m and m2 are equal
assert m.__dict__.keys() == m2.__dict__.keys()
for k, v in m.__dict__.items():
if k in ["stan_fit", "stan_backend"]:
continue
if k == "params":
assert v.keys() == m2.params.keys()
for kk, vv in v.items():
assert np.array_equal(vv, m2.params[kk])
elif k in PD_SERIES and v is not None:
assert v.equals(m2.__dict__[k])
elif k in PD_DATAFRAME and v is not None:
pd.testing.assert_frame_equal(v, m2.__dict__[k], check_index_type=False)
elif k == "changepoints_t":
assert np.array_equal(v, m.__dict__[k])
else:
assert v == m2.__dict__[k]
assert m2.stan_fit is None
assert m2.stan_backend is None
# Check that m2 makes the same forecast
fcst2 = m2.predict(test)
assert np.array_equal(fcst["yhat"].values, fcst2["yhat"].values)
def test_backwards_compatibility(self):
old_versions = {
"0.6.1.dev0": (29.3669923968994, "fb"),
"0.7.1": (29.282810844704414, "fb"),
"1.0.1": (29.282810844704414, ""),
}
for v, (pred_val, v_str) in old_versions.items():
fname = Path(__file__).parent / f"serialized_model_v{v}.json"
with open(fname, "r") as fin:
model_str = json.load(fin)
# Check that deserializes
m = model_from_json(model_str)
assert json.loads(model_str)[f"__{v_str}prophet_version"] == v
# Predict
future = m.make_future_dataframe(10)
fcst = m.predict(future)
assert fcst["yhat"].values[-1] == pytest.approx(pred_val)