
Kalman Filter

This example optimizes the noise-covariance parameters (Q and R) of a Kalman-Filter using gradient-based optimization.

This example is available as a Jupyter notebook here.

from types import SimpleNamespace
from typing import Optional

import diffrax as dfx
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import matplotlib.pyplot as plt
import optax

We use Equinox to build the Kalman-Filter implementation and to represent linear, time-invariant systems (LTI systems).

We use Optax for optimisers (Adam, etc.).

Problem Formulation

Assume that there exists some unknown dynamical system of the form

\(\frac{dx}{dt}(t)= f(x(t), u(t), t)\)

\(y(t) = g(x(t)) + \epsilon(t)\)

where
- \(u(t)\) denotes the time-dependent input to the system
- \(y(t)\) denotes the time-dependent output (measurement) of the system
- \(x(t)\) denotes the time-dependent state of the system, which in general is not directly measurable
- \(f, g\) denote the dynamics function and the measurement function, respectively
- \(\epsilon(t)\) denotes random measurement noise

The goal is to infer \(x\) from \(y\) even though \(f, g, \epsilon\) are unknown.

A Kalman-Filter represents a possible solution.


From Wikipedia:

Kalman filtering, also known as linear quadratic estimation (LQE), is an algorithm that uses a series of measurements observed over time, including statistical noise and other inaccuracies, and produces estimates of unknown variables that tend to be more accurate than those based on a single measurement alone, by estimating a joint probability distribution over the variables for each timeframe. The filter is named after Rudolf E. Kálmán, who was one of the primary developers of its theory.

The algorithm works by a two-phase process. For the prediction phase, the Kalman filter produces estimates of the current state variables, along with their uncertainties. Once the outcome of the next measurement (necessarily corrupted with some error, including random noise) is observed, these estimates are updated using a weighted average, with more weight being given to estimates with greater certainty. The algorithm is recursive. It can operate in real time, using only the present input measurements and the state calculated previously and its uncertainty matrix; no additional past information is required.
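The predict/update cycle in the quote can be sketched for a scalar state in plain Python. This is a minimal discrete-time sketch with hypothetical helpers predict and update. It assumes the model x_k = a * x_{k-1} + w with noise variance q and measurements z = x + v with noise variance r. It is not the continuous-time filter used in this example. The update is the weighted average the quote describes: the gain k is the weight given to the measurement, and it grows as the prediction's variance p grows relative to r.

```python
def predict(x, p, a, q):
    """Propagate estimate x and variance p through x_k = a * x_{k-1} + w, Var(w) = q."""
    return a * x, a * a * p + q


def update(x, p, z, r):
    """Fuse a measurement z = x + v, Var(v) = r, into the estimate (x, p)."""
    k = p / (p + r)  # Kalman gain: the weight given to the measurement
    x_new = (1.0 - k) * x + k * z  # weighted average of prediction and measurement
    p_new = (1.0 - k) * p  # uncertainty shrinks after the update
    return x_new, p_new, k


# A very uncertain prediction (variance 100) mostly adopts the measurement
x, p = predict(0.0, 99.0, a=1.0, q=1.0)
x, p, k = update(x, p, z=2.0, r=1.0)

# A confident prediction (variance 0.01) mostly ignores the measurement
_, _, k_confident = update(0.0, 0.01, z=2.0, r=1.0)
```

Only the estimate and its variance from the previous step are needed, which matches the recursive, real-time property described above.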

For the sake of simplicity, here we assume that the dynamical system takes the form

\(\frac{dx}{dt}(t) = Ax(t) + Bu(t)\)

\(y(t) = Cx(t) + \epsilon(t)\)

where \(A,B,C\) are constant matrices, i.e. the system is linear in its state \(x(t)\) and its input \(u(t)\). Further, the dynamics- and measurement-functions do not depend on time.

Hence, the above represents the general form of (physical) linear, time-invariant systems (LTI systems).

Here we define a container object for LTI systems.

class LTISystem(eqx.Module):
    A: jnp.ndarray
    B: jnp.ndarray
    C: jnp.ndarray

A harmonic oscillator is an LTI system. The following function returns one, parameterized by its damping and a time-scaling factor.

def harmonic_oscillator(damping: float = 0.0, time_scaling: float = 1.0) -> LTISystem:
    A = jnp.array([[0.0, time_scaling], [-time_scaling, -2 * damping]])
    B = jnp.array([[0.0], [1.0]])
    C = jnp.array([[0.0, 1.0]])
    return LTISystem(A, B, C)
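Write \(\zeta\) for damping and \(s\) for time_scaling. The matrix A above then has trace \(-2\zeta\) and determinant \(s^2\), so its eigenvalues are \(\lambda = -\zeta \pm \sqrt{\zeta^2 - s^2}\). For \(\zeta = 0\) they are \(\pm i s\), an undamped oscillation with angular frequency \(s\). For \(0 < \zeta < s\), the real part \(-\zeta\) makes the oscillation decay. A small sketch with a hypothetical helper, oscillator_eigenvalues, evaluates this closed form in plain Python:

```python
import cmath


def oscillator_eigenvalues(damping: float, time_scaling: float):
    """Eigenvalues of A = [[0, s], [-s, -2*zeta]] from its trace and determinant."""
    trace = -2.0 * damping  # A[0][0] + A[1][1]
    det = time_scaling**2  # 0 * (-2*zeta) - s * (-s)
    # Roots of the characteristic polynomial lambda^2 - trace*lambda + det = 0
    disc = cmath.sqrt(trace**2 / 4.0 - det)
    return trace / 2.0 + disc, trace / 2.0 - disc


undamped = oscillator_eigenvalues(0.0, 1.0)  # purely imaginary pair
damped = oscillator_eigenvalues(0.3, 1.0)  # real part -0.3: decaying oscillation
```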

Here we define some utility functions that allow us to simulate LTI systems.

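def interpolate_us(ts, us, B):
    # Return an object with an `evaluate(t)` method giving the input u(t)
    if us is None:
        # No input given: u(t) is identically zero, with dimension m = B.shape[-1]
        m = B.shape[-1]
        u_t = SimpleNamespace(evaluate=lambda t: jnp.zeros((m,)))
    else:
        # Piecewise-linear interpolation of the input samples us taken at times ts
        u_t = dfx.LinearInterpolation(ts=ts, ys=us)
    return u_t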
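interpolate_us returns an object whose evaluate(t) gives the input at any time t. This lets the ODE right-hand sides query u (and, inside the filter, y) between sample times. Conceptually, dfx.LinearInterpolation joins consecutive samples with straight lines. The sketch below illustrates that idea for scalar samples using a hypothetical helper, linear_interpolate. It is not diffrax's implementation, which works on JAX arrays.

```python
import bisect


def linear_interpolate(ts, ys, t):
    """Piecewise-linear interpolation of samples (ts[i], ys[i]) at time t.

    Assumes ts is sorted ascending; times outside [ts[0], ts[-1]] are
    extrapolated from the first or last segment.
    """
    # Index of the segment [ts[i], ts[i + 1]] containing t, clamped to a valid segment
    i = bisect.bisect_right(ts, t) - 1
    i = max(0, min(i, len(ts) - 2))
    t0, t1 = ts[i], ts[i + 1]
    y0, y1 = ys[i], ys[i + 1]
    w = (t - t0) / (t1 - t0)  # fractional position within the segment
    return (1.0 - w) * y0 + w * y1


ts = [0.0, 1.0, 2.0]
us = [0.0, 10.0, 0.0]
u_mid = linear_interpolate(ts, us, 0.5)  # halfway between 0.0 and 10.0
```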

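def diffeqsolve(
    rhs,
    ts: jnp.ndarray,
    y0: jnp.ndarray,
    solver: dfx.AbstractSolver = dfx.Dopri5(),
    stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(),
    dt0: float = 0.01,
) -> jnp.ndarray:
    # Solve dy/dt = rhs(t, y, args) from ts[0] to ts[-1] and return y at every time in ts
    return dfx.diffeqsolve(
        dfx.ODETerm(rhs),
        solver=solver,
        t0=ts[0],
        t1=ts[-1],
        dt0=dt0,
        y0=y0,
        saveat=dfx.SaveAt(ts=ts),
        stepsize_controller=stepsize_controller,
    ).ys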

def simulate_lti_system(
    sys: LTISystem,
    y0: jnp.ndarray,
    ts: jnp.ndarray,
    us: Optional[jnp.ndarray] = None,
    std_measurement_noise: float = 0.0,
    key=jr.PRNGKey(1),  # PRNG key for the measurement noise (seed value assumed)
):
    u_t = interpolate_us(ts, us, sys.B)

    def rhs(t, y, args):
        return sys.A @ y + sys.B @ u_t.evaluate(t)

    xs = diffeqsolve(rhs, ts, y0)
    # noisy measurements
    ys = xs @ sys.C.transpose()
    ys = ys + jr.normal(key, shape=ys.shape) * std_measurement_noise
    return xs, ys
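simulate_lti_system integrates dx/dt = Ax + Bu and then forms noisy measurements y = Cx + ε, where ε is Gaussian with standard deviation std_measurement_noise. The same pipeline can be sketched without JAX. This sketch swaps in a hand-written fixed-step classic Runge-Kutta (RK4) integrator for diffrax's Dopri5, assumes zero input, and uses Python's random.gauss for the noise. The helper names matvec and rk4_simulate are hypothetical.

```python
import math
import random


def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def rk4_simulate(A, x0, dt, n_steps):
    """Integrate dx/dt = A x (zero input) with fixed-step RK4; return all states."""
    def f(x):
        return matvec(A, x)

    x = list(x0)
    xs = [x]
    for _ in range(n_steps):
        k1 = f(x)
        k2 = f([xi + 0.5 * dt * ki for xi, ki in zip(x, k1)])
        k3 = f([xi + 0.5 * dt * ki for xi, ki in zip(x, k2)])
        k4 = f([xi + dt * ki for xi, ki in zip(x, k3)])
        x = [xi + dt / 6.0 * (a + 2 * b + 2 * c + d)
             for xi, a, b, c, d in zip(x, k1, k2, k3, k4)]
        xs.append(x)
    return xs


# A and C of harmonic_oscillator(0.0, 1.0): exact solution x1 = cos t, x2 = -sin t
A = [[0.0, 1.0], [-1.0, 0.0]]
C = [[0.0, 1.0]]
dt, n_steps = 0.01, 500
xs = rk4_simulate(A, [1.0, 0.0], dt, n_steps)
exact_final = [math.cos(n_steps * dt), -math.sin(n_steps * dt)]  # for comparison

# Noisy measurements y = C x + eps, seeded for reproducibility
rng = random.Random(1)
std_measurement_noise = 1.0
ys = [matvec(C, x)[0] + rng.gauss(0.0, std_measurement_noise) for x in xs]
```

Only the velocity (the second state) is measured here, as in harmonic_oscillator, so the filter has to infer the position.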

Here we define the Kalman-Filter.

Note how we use Equinox to combine the Kalman-Filter logic in __call__ and the Kalman-Filter parameters Q and R in one object.
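Because one object holds both the model matrices and Q and R, the optimiser must be told which leaves to train. main() does this with a boolean filter_spec together with eqx.partition and eqx.combine. The idea can be sketched with a plain dictionary standing in for the module's leaves. partition and combine below are hypothetical stand-ins that mirror the roles of the Equinox functions: leaves that are filtered out become None. They are not Equinox's implementation, and the matrix values are illustrative.

```python
def partition(tree, spec):
    """Split a flat dict into (dynamic, static) parts according to a boolean spec.

    Leaves whose spec is True go into `dynamic`, the rest into `static`;
    filtered-out leaves are None, so both parts keep the same keys.
    """
    dynamic = {key: (val if spec[key] else None) for key, val in tree.items()}
    static = {key: (None if spec[key] else val) for key, val in tree.items()}
    return dynamic, static


def combine(dynamic, static):
    """Rebuild the full dict, taking each leaf from whichever part is not None."""
    return {key: (dynamic[key] if dynamic[key] is not None else static[key]) for key in dynamic}


# Flattened stand-in for a KalmanFilter: only Q and R should be trained
kmf = {"A": [[0.0, 1.0], [-1.0, 0.0]], "x0": [0.0, 0.0], "Q": 0.1, "R": 1.0}
filter_spec = {"A": False, "x0": False, "Q": True, "R": True}

dynamic, static = partition(kmf, filter_spec)
# Pretend an optimiser step halved every trainable leaf; frozen leaves stay None
dynamic = {key: (val * 0.5 if val is not None else None) for key, val in dynamic.items()}
updated = combine(dynamic, static)
```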

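class KalmanFilter(eqx.Module):
    """Continuous-time Kalman Filter

    Ref:
        [1] Optimal and robust estimation. 2nd edition. Page 154.
    """

    sys: LTISystem  # model of the system (A, B, C)
    x0: jnp.ndarray  # initial state estimate
    P0: jnp.ndarray  # initial error covariance
    Q: jnp.ndarray  # process-noise covariance (trust in the model)
    R: jnp.ndarray  # measurement-noise covariance (trust in the measurements)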

    def __call__(self, ts, ys, us: Optional[jnp.ndarray] = None):
        A, B, C = self.sys.A, self.sys.B, self.sys.C

        y_t = dfx.LinearInterpolation(ts=ts, ys=ys)
        u_t = interpolate_us(ts, us, B)

        y0 = (self.x0, self.P0)

        def rhs(t, y, args):
            x, P = y

            # eq 3.22 of Ref [1]
            K = P @ C.transpose() @ jnp.linalg.inv(self.R)

            # eq 3.21 of Ref [1]
            dPdt = (
                A @ P
                + P @ A.transpose()
                + self.Q
                - P @ C.transpose() @ jnp.linalg.inv(self.R) @ C @ P
            )

            # eq 3.23 of Ref [1]
            dxdt = A @ x + B @ u_t.evaluate(t) + K @ (y_t.evaluate(t) - C @ x)

            return (dxdt, dPdt)

        return diffeqsolve(rhs, ts, y0)[0]
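For reference, rhs in __call__ integrates the continuous-time Kalman filter equations (eqs. 3.21 to 3.23 of Ref [1], as cited in the code comments). Here \(\hat{x}\) is the state estimate, \(P\) its error covariance, \(Q\) the process-noise covariance and \(R\) the measurement-noise covariance:

\[
\begin{aligned}
K(t) &= P(t)\,C^\top R^{-1}, \\
\frac{dP}{dt}(t) &= A P(t) + P(t) A^\top + Q - P(t)\,C^\top R^{-1} C\,P(t), \\
\frac{d\hat{x}}{dt}(t) &= A\hat{x}(t) + B u(t) + K(t)\bigl(y(t) - C\hat{x}(t)\bigr),
\end{aligned}
\]

with \(\hat{x}(0) = x_0\) and \(P(0) = P_0\). Only the \(\hat{x}\) trajectory is returned, which is why __call__ takes element [0] of the solution.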

Main entry point. Try running main().

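def main(
    # evaluate at these timepoints
    ts=jnp.arange(0, 5.0, 0.01),
    # system that generates data (damping value assumed)
    sys_true=harmonic_oscillator(0.3),
    # initial state of our data generating system
    sys_true_x0=jnp.array([1.0, 0.0]),
    # standard deviation of measurement noise (value assumed)
    sys_true_std_measurement_noise=1.0,
    # our model for system `true`, it's not perfect (damping value assumed)
    sys_model=harmonic_oscillator(0.7),
    # initial state guess, it's not perfect
    sys_model_x0=jnp.array([0.0, 0.0]),
    # weighs how much we trust our model of the system
    Q=jnp.diag(jnp.ones((2,))) * 0.1,
    # weighs how much we trust in the measurements of the system
    # (1x1 because there is one measured output; value assumed)
    R=jnp.diag(jnp.ones((1,))),
    # weighs how much we trust our initial guess
    P0=jnp.diag(jnp.ones((2,))) * 10.0,
    # plotting, number of Adam steps on Q/R, and logging interval (defaults assumed)
    plot=True,
    n_gradient_steps=0,
    print_every=10,
):
    xs, ys = simulate_lti_system(
        sys_true, sys_true_x0, ts, std_measurement_noise=sys_true_std_measurement_noise
    )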

    kmf = KalmanFilter(sys_model, sys_model_x0, P0, Q, R)

    print(f"Initial Q: \n{kmf.Q}\n Initial R: \n{kmf.R}")

    # gradients should only be able to change Q/R parameters
    # *not* the model (well at least not in this example :)
    filter_spec = jtu.tree_map(lambda arr: False, kmf)
    filter_spec = eqx.tree_at(
        lambda tree: (tree.Q, tree.R), filter_spec, replace=(True, True)
    )

    opt = optax.adam(1e-2)
    opt_state = opt.init(kmf)

    # returns (loss, gradient w.r.t. the floating-point arrays in dynamic_kmf)
    @eqx.filter_value_and_grad
    def loss_fn(dynamic_kmf, static_kmf, ts, ys, xs):
        kmf = eqx.combine(dynamic_kmf, static_kmf)
        xhats = kmf(ts, ys)
        return jnp.mean((xs - xhats) ** 2)

    @eqx.filter_jit
    def make_step(kmf, opt_state, ts, ys, xs):
        # split into trainable (Q, R) and frozen parts according to filter_spec
        dynamic_kmf, static_kmf = eqx.partition(kmf, filter_spec)
        value, grads = loss_fn(dynamic_kmf, static_kmf, ts, ys, xs)
        updates, opt_state = opt.update(grads, opt_state)
        kmf = eqx.apply_updates(kmf, updates)
        return value, kmf, opt_state

    for step in range(n_gradient_steps):
        value, kmf, opt_state = make_step(kmf, opt_state, ts, ys, xs)
        if step % print_every == 0:
            print("Current MSE: ", value)

    print(f"Final Q: \n{kmf.Q}\n Final R: \n{kmf.R}")

    if plot:
        xhats = kmf(ts, ys)
        plt.plot(ts, xs[:, 0], label="true position", color="orange")
        plt.plot(
            ts,
            xhats[:, 0],
            label="estimated position",
            color="orange",
            linestyle="dashed",
        )
        plt.plot(ts, xs[:, 1], label="true velocity", color="blue")
        plt.plot(
            ts,
            xhats[:, 1],
            label="estimated velocity",
            color="blue",
            linestyle="dashed",
        )
        plt.xlabel("time")
        plt.ylabel("position / velocity")
        plt.grid()
        plt.legend()
        plt.title("Kalman-Filter optimization w.r.t Q/R")
        plt.show()
Initial Q: 
[[0.1 0. ]
 [0.  0.1]]
 Initial R: 
Final Q: 
[[0.1 0. ]
 [0.  0.1]]
 Final R: 

Initial Q: 
[[0.1 0. ]
 [0.  0.1]]
 Initial R: 
Current MSE:  0.10257154
Current MSE:  0.09800266
Current MSE:  0.09304534
Current MSE:  0.087538235
Current MSE:  0.08122826
Current MSE:  0.07371251
Current MSE:  0.06441843
Current MSE:  0.05347546
Current MSE:  0.046111725
Current MSE:  0.03786327
Final Q: 
[[-0.44275677  1.3142775 ]
 [-1.1867669   0.9120258 ]]
 Final R: 

We can see that the MSE is smaller after optimization.

After optimization, we trust the measurements more, because our model of the system has a significant modeling error.

This shows up in the plot as added noise in our state estimate; recall that the measurements are noisy.
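The link between Q, R and "trusting the measurements" can be made concrete for a hypothetical scalar system. Here scalars a, c, q and r replace A, C, Q and R. Setting \(\frac{dP}{dt} = 0\) in the Riccati equation gives \(0 = 2aP + q - c^2P^2/r\). Its positive root is \(P_\infty = r\,(a + \sqrt{a^2 + c^2 q / r})/c^2\), and the steady-state gain is \(K_\infty = P_\infty c / r\). The sketch below integrates the scalar Riccati equation with forward Euler and checks it against that closed form. It then shows that the gain grows as q/r grows, so the estimate follows the noisy measurements more closely. It covers only the scalar case; the example above uses the full matrix equations.

```python
import math


def steady_state(a, c, q, r):
    """Closed-form steady-state variance and gain of a scalar continuous-time Kalman filter."""
    p_inf = r * (a + math.sqrt(a * a + c * c * q / r)) / (c * c)
    return p_inf, p_inf * c / r


def integrate_riccati(a, c, q, r, p0, dt=1e-3, t_end=20.0):
    """Forward-Euler integration of dP/dt = 2aP + q - c^2 P^2 / r."""
    p = p0
    for _ in range(int(round(t_end / dt))):
        p += dt * (2.0 * a * p + q - c * c * p * p / r)
    return p


a, c = -0.7, 1.0  # illustrative stable scalar model
p_num = integrate_riccati(a, c, q=0.1, r=1.0, p0=10.0)
p_inf, _ = steady_state(a, c, q=0.1, r=1.0)

# Gains for decreasing r (increasing q/r): more weight on the measurements
gains = [steady_state(a, c, q=0.1, r=r)[1] for r in (10.0, 1.0, 0.1, 0.01)]
```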