{"id":20590429,"url":"https://github.com/probml/sts-jax","last_synced_at":"2026-03-03T19:03:06.378Z","repository":{"id":75182357,"uuid":"558966300","full_name":"probml/sts-jax","owner":"probml","description":"Structural Time Series in JAX","archived":false,"fork":false,"pushed_at":"2024-05-08T20:08:28.000Z","size":11419,"stargazers_count":190,"open_issues_count":3,"forks_count":9,"subscribers_count":18,"default_branch":"main","last_synced_at":"2025-05-07T15:09:49.525Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":null,"language":"Jupyter Notebook","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/probml.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-10-28T17:48:56.000Z","updated_at":"2025-04-27T08:40:19.000Z","dependencies_parsed_at":"2024-05-08T21:27:18.287Z","dependency_job_id":"e4c8e389-5a2b-48a2-a9a1-936288e87f6b","html_url":"https://github.com/probml/sts-jax","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"purl":"pkg:github/probml/sts-jax","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/probml%2Fsts-jax","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/probml%2Fsts-jax/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/probml%2Fsts-jax/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/probml%2Fsts-jax/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/probml","download_url":"https://codeload.github.com/probml/sts-jax/
tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/probml%2Fsts-jax/sbom","scorecard":null,"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":286080680,"owners_count":30056056,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2026-03-03T18:21:05.932Z","status":"ssl_error","status_checked_at":"2026-03-03T18:20:59.341Z","response_time":61,"last_error":"SSL_connect returned=1 errno=0 peeraddr=140.82.121.5:443 state=error: unexpected eof while reading","robots_txt_status":"success","robots_txt_updated_at":"2025-07-24T06:49:26.215Z","robots_txt_url":"https://github.com/robots.txt","online":false,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-11-16T07:36:35.949Z","updated_at":"2026-03-03T19:03:06.362Z","avatar_url":"https://github.com/probml.png","language":"Jupyter Notebook","funding_links":[],"categories":[],"sub_categories":[],"readme":"# sts-jax\nStructural Time Series (STS) in JAX\n\nThis library has a similar design to [tfp.sts](https://www.tensorflow.org/probability/api_docs/python/tfp/sts),\nbut is built entirely in JAX,\nand uses the [Dynamax](https://github.com/probml/dynamax/tree/main/dynamax) library\nfor state-space models.\nWe also include an implementation of the\n[causal impact](https://google.github.io/CausalImpact/) method.\nThis has a similar design to [tfcausalimpact](https://github.com/WillianFuks/tfcausalimpact),\nbut is built entirely in JAX.\n\n## Installation\n\nTo install the latest development branch:\n\n``` {.console}\npip install git+https://github.com/probml/sts-jax\n```\nor use\n``` 
{.console}\ngit clone git@github.com:probml/sts-jax.git\ncd sts-jax\npip install -e .\n```\n\n## What are structural time series (STS) models?\n\nThe STS model is a linear state space model with a specific structure. In particular,\nthe latent state $z_t$ is the concatenation of the states of all latent components:\n\n$$z_t = [c_{1, t}, c_{2, t}, ...]$$\n\nwhere $c_{i,t}$ is the state of latent component $c_i$ at time step $t$.\n\nThe STS model (with scalar Gaussian observations) takes the form:\n\n$$y_t = H_t z_t + u_t + \\epsilon_t, \\qquad  \\epsilon_t \\sim \\mathcal{N}(0, \\sigma^2_t)$$\n\n$$z_{t+1} = F_t z_t + R_t \\eta_t, \\qquad \\eta_t \\sim \\mathcal{N}(0, Q_t)$$\n\nwhere\n\n* $y_t$: observation (emission) at time $t$.\n* $\\sigma^2_t$: variance of the observation noise.\n* $H_t$: emission matrix, which sums up the contributions of all latent components.\n* $u_t = x_t^T \\beta$: regression component from external inputs.\n* $F_t$: fixed transition matrix of the latent dynamics.\n* $R_t$: selection matrix, whose columns are a subset of the standard basis vectors $e_i$;\n    it maps the non-singular covariance matrix $Q_t$ of $\\eta_t$ into the (possibly singular)\n    covariance matrix of the noise on the latent state $z_{t+1}$.\n* $Q_t$: non-singular covariance matrix of the dynamics noise $\\eta_t$; its dimension\n    can be smaller than the dimension of $z_t$.\n\nThe covariance matrix of the latent dynamics noise therefore takes the form $R Q R^T$, where $Q$ is\na non-singular (block-diagonal) matrix and $R$ is the selection matrix.\n\nMore information on STS models can be found in these books:\n\n\u003e -   \\\"Probabilistic Machine Learning: Advanced Topics\\\", K. 
Murphy, MIT Press 2023.\n\u003e     Available at \u003chttps://probml.github.io/pml-book/book2.html\u003e.\n\u003e -   \\\"Time Series Analysis by State Space Methods (2nd edn)\\\", James Durbin, Siem Jan Koopman,\n\u003e     Oxford University Press, 2012.\n\n## Usage\n\nIn this library, an STS model is constructed by providing the observed time series and specifying a list of\ncomponents and the distribution family of the observations. This library implements\ncommon STS components, including the **local linear trend**, **seasonal**,\n**cycle**, **autoregressive**, and **regression** components.\nThe observed time series can follow either a **Gaussian**\ndistribution or a **Poisson** distribution. (Other likelihood functions can also be added.)\n\nInternally, the STS model is converted to the corresponding state space model (SSM), and inference\nand parameter learning are performed on the SSM.\nIf the observation $Y_t$ follows a Gaussian distribution, inference of the latent variables\n$Z_{1:T}$ (given the parameters) is based on the\n[Kalman filter](https://github.com/probml/dynamax/tree/main/dynamax/linear_gaussian_ssm).\nAlternatively, if the observation $Y_t$ follows a Poisson distribution with\nmean $E[Y_t|Z_t] = e^{H_t Z_t + u_t}$, inference of the\nlatent variables $Z_{1:T}$ is based on a generalization of the extended\nKalman filter, which we call the\n[conditional moment Gaussian filter](https://github.com/probml/dynamax/tree/main/dynamax/generalized_gaussian_ssm),\nbased on [Tronarp 2018](https://acris.aalto.fi/ws/portalfiles/portal/17669270/cm_parapub.pdf).\n\nThe marginal likelihood of $Y_{1:T}$ conditioned on the parameters can be evaluated as a\nbyproduct of the forward filtering process.\nThis can then be used to learn the parameters of the STS model,\nusing **MLE** (based on SGD implemented in the library [optax](https://github.com/deepmind/optax)),\n**ADVI** (using a Gaussian posterior 
approximation on the unconstrained parameter space),\nor **HMC** (from the library [blackjax](https://github.com/blackjax-devs/blackjax)).\nParameter estimation is done offline, given one or more historical time series.\nThe estimated parameters can then be used to forecast the future.\n\nBelow we illustrate the API applied to some example datasets.\n\n## Electricity demand\n\nThis example is adapted from the [TFP blog](https://blog.tensorflow.org/2019/03/structural-time-series-modeling-in.html).\nSee [this file](./sts_jax/structural_time_series/demos/sts_electric_demo.ipynb) for a runnable version\nof this demo.\n\nThe problem of interest is to forecast electricity demand in Victoria, Australia.\nThe dataset contains hourly records of electricity demand and temperature\nfrom the first 8 weeks of 2014. The following plot shows the training\nset, which contains the measurements from the first 6 weeks.\n\n\u003cp align=\"center\"\u003e\n\u003cimg src=\"./sts_jax/figures/electr_obs.png\" alt=\"drawing\" style=\"width:600px;\"/\u003e\n\u003c/p\u003e\n\nWe now build a model in which the demand depends linearly on the temperature\nand also has two seasonal components (hour of day and day of week) and an autoregressive component.\n\n```python\nimport sts_jax.structural_time_series.sts_model as sts\n\nhour_of_day_effect = sts.SeasonalDummy(num_seasons=24,\n                                       name='hour_of_day_effect')\nday_of_week_effect = sts.SeasonalTrig(num_seasons=7, num_steps_per_season=24,\n                                      name='day_of_week_effect')\ntemperature_effect = sts.LinearRegression(dim_covariates=1, add_bias=True,\n                                          name='temperature_effect')\nautoregress_effect = sts.Autoregressive(order=1,\n                                        name='autoregress_effect')\n\n# The STS model is constructed by providing the observed time series,\n# specifying a list of components and the distribution family of the observations.\nmodel = 
sts.StructuralTimeSeries(\n    [hour_of_day_effect, day_of_week_effect, temperature_effect, autoregress_effect],\n    obs_time_series,\n    obs_distribution='Gaussian',\n    covariates=temperature_training_data)\n\n```\n\nIn this case, we choose to fit the model using MLE.\n\n```python\n# Estimate the parameters by maximum likelihood (MLE), using SGD.\nopt_param, _losses = model.fit_mle(obs_time_series,\n                                   covariates=temperature_training_data,\n                                   num_steps=2000)\n```\n\nWe can now plug in the fitted parameters and the future inputs,\nand use ancestral sampling from the\nfiltered posterior to forecast future observations.\n\n```python\n# The 'forecast' method samples the future means and future observations from the\n# predictive distribution, conditioned on the parameters of the model.\nforecast_means, forecasts = model.forecast(opt_param,\n                                           obs_time_series,\n                                           num_forecast_steps,\n                                           past_covariates=temperature_training_data,\n                                           forecast_covariates=temperature_predict_data)\n```\n\nThe following plot shows the mean and 95\\% probability interval of the forecast.\n\n\u003cp align=\"center\"\u003e\n\u003cimg src=\"./sts_jax/figures/electr_forecast.png\" alt=\"drawing\" style=\"width:600px;\"/\u003e\n\u003c/p\u003e\n\n## CO2 levels\n\nThis example is adapted from the [TFP blog](https://blog.tensorflow.org/2019/03/structural-time-series-modeling-in.html).\nSee [this file](./sts_jax/structural_time_series/demos/sts_co2_demo.ipynb) for a runnable version\nof the demo, which is similar to the electricity example.\n\n## Time series with Poisson observations\n\nWe can also fit STS models to discrete (count) observations that follow a Poisson\ndistribution. 
Internally, the inference of the latent states $Z_{1:T}$ in the corresponding SSM\nis based on the (generalized) extended Kalman filter implemented\nin the library dynamax. An STS model for a Poisson-distributed time series can be constructed\nsimply by setting the observation distribution to 'Poisson'. Everything else is the same\nas in the Gaussian case.\n\nBelow we create a synthetic dataset, following [this TFP example](https://www.tensorflow.org/probability/examples/STS_approximate_inference_for_models_with_non_Gaussian_observations).\nSee [this file](./sts_jax/structural_time_series/demos/sts_poisson_demo.ipynb) for a runnable version\nof this demo.\n\n```python\nimport sts_jax.structural_time_series.sts_model as sts\n\n# This example uses a synthetic dataset, and the STS model contains only a\n# local linear trend component.\ntrend = sts.LocalLinearTrend()\nmodel = sts.StructuralTimeSeries([trend],\n                                 obs_distribution='Poisson',\n                                 obs_time_series=counts_training)\n\n# Fit the model using the HMC algorithm.\nparam_samples, _log_probs = model.fit_hmc(num_samples=200,\n                                          obs_time_series=counts_training)\n\n# Forecast into the future given the parameter samples returned by HMC.\n# 'forecast' returns (forecast_means, forecasts); index [1] keeps the sampled observations.\nforecasts = model.forecast(param_samples, counts_training, num_forecast_steps)[1]\n```\n\n\u003cp align=\"center\"\u003e\n\u003cimg src=\"./sts_jax/figures/poisson_forecast.png\" alt=\"drawing\" style=\"width:600px;\"/\u003e\n\u003c/p\u003e\n\n### Comparison to TFP\n\nThe TFP approach to STS with non-conjugate likelihoods is to perform\nHMC on the joint distribution of the latent states $Z_{1:T}$ and the parameters, conditioned\non the observations $Y_{1:T}$. Since the number of latent variables HMC must sample grows linearly\nwith the length $T$ of the time series, this approach becomes inefficient\nwhen $T$ is large. 
By contrast, we (approximately) marginalize out $Z_{1:T}$\nusing a generalized extended Kalman filter,\nand perform HMC only in the collapsed parameter space. This is much faster, yet yields\ncomparable forecast error, as we show below. (The number of HMC burn-in steps in the TFP-STS\nimplementation is adjusted so that the forecast errors of the two implementations\nare comparable.)\n\n\u003cp align=\"center\"\u003e\n\u003cimg src=\"./sts_jax/figures/comparison.png\" alt=\"drawing\" style=\"width:600px;\"/\u003e\n\u003c/p\u003e\n\n## Causal Impact\n\nThe [causal impact](https://google.github.io/CausalImpact/CausalImpact.html)\nmethod is implemented on top of the STS-JAX package.\n\nBelow we show an example where Y is the output time series and X is a parallel\nset of input covariates. There is a sudden change in the response variable at time $t=70$,\ncaused by some kind of intervention (e.g., launching an ad campaign).\nWe define the causal impact of this intervention\nto be the change in the observed output compared to what we would have\nexpected had the intervention not happened.\nSee [this file](./sts_jax/causal_impact/causal_impact_demo.ipynb)\nfor a runnable version of this demo.\n(See also the [CausalPy](https://www.pymc-labs.io/blog-posts/causalpy-a-new-package-for-bayesian-causal-inference-for-quasi-experiments/)\npackage for some related methods.)\n\n\u003cp align=\"center\"\u003e\n\u003cimg src=\"./sts_jax/figures/causal_obs.png\" alt=\"drawing\" style=\"width:600px;\"/\u003e\n\u003c/p\u003e\n\nThis is how we run inference:\n\n```python\nfrom sts_jax.causal_impact.causal_impact import causal_impact\n\n# The causal impact is inferred by providing the target time series and covariates,\n# specifying the intervention time and the distribution family of the observations.\n# If no STS model is given, a default model is constructed internally, consisting of\n# a local linear trend component plus the regression component.\nimpact = 
causal_impact(obs_time_series,\n                       intervention_timepoint,\n                       'Gaussian',\n                       covariates,\n                       sts_model=None)\n\n```\n\nThe format of the output from our\ncausal impact code follows that of the R package\n[CausalImpact](https://google.github.io/CausalImpact/CausalImpact.html),\nand is shown below.\n\n```python\nimpact.plot()\n```\n\n\u003cp align=\"center\"\u003e\n\u003cimg src=\"./sts_jax/figures/causal_forecast.png\" alt=\"drawing\" style=\"width:600px;\"/\u003e\n\u003c/p\u003e\n\n```python\nimpact.print_summary()\n\nPosterior inference of the causal impact:\n\n                               Average            Cumulative     \nActual                          129.93             3897.88       \n\nPrediction (s.d.)           120.01 (2.04)      3600.42 (61.31)   \n95% CI                     [114.82, 123.07]   [3444.72, 3692.09] \n\nAbsolute effect (s.d.)       9.92 (2.04)        297.45 (61.31)   \n95% CI                      [6.86, 15.11]      [205.78, 453.16]  \n\nRelative effect (s.d.)      8.29% (1.89%)       8.29% (1.89%)    \n95% CI                     [5.57%, 13.16%]     [5.57%, 13.16%]   \n\nPosterior tail-area probability p: 0.0050\nPosterior prob of a causal effect: 99.50%\n```\n\n## About\n\nAuthors: [Xinglong Li](https://www.stat.ubc.ca/users/xinglong-li),\n[Kevin Murphy](https://www.cs.ubc.ca/~murphyk/).\n\nMIT License. 2022\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fprobml%2Fsts-jax","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fprobml%2Fsts-jax","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fprobml%2Fsts-jax/lists"}