
Latent ODE

This example trains a Latent ODE.

In this case, it's trained on a simple dataset of decaying oscillators; that is, 2-dimensional time series that look like:

xx    ***
    **   *
  x*      **
  *x
    x       *
 *           *                  xxxxx
*    x        *               xx    xx *******
                             x        x       **
      x        *            x        * x        *                  xxxxxxxx  ******
       x        *          x        *   x        *              xxx       *xx      *
                          x        *     xx       **           x        **   xx
        x        *       x        *        x        *        xx       **       xx
                  *     x        *          x        **     x        *           xxx
         x         *            *            x         *  xx       **
          x         *  x       *              xx        xx*     ***
           x         *x       *                 xxx  xxx   *****
            x        x*      *                     xx
             x     xx  ******
              xxxxx

The model is trained to generate samples that look like this.

What's really nice about this example is that we will take the underlying data to be irregularly sampled. We will have different observation times for different batch elements.

Most differential equation libraries will struggle with this, as they usually mandate that the differential equation be solved over the same timespan for all batch elements. Working around this can involve programming complexity, like outputting at lots and lots of times (the union of all the observation times in the batch), or mathematical complexity, like reparameterising the differential equation.

However, Diffrax is capable of handling this without such issues! You can vmap over different integration times for different batch elements.
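
For example, here is a minimal sketch (separate from the model trained below, with a made-up vector field and the _demo_* names invented purely for illustration) of vmapping a Diffrax solve over per-batch-element observation times:

import diffrax
import jax
import jax.numpy as jnp


def _demo_vector_field(t, y, args):
    # Simple exponential decay; only for illustration.
    return -y


def _demo_solve(ts, y0):
    # `ts` holds this batch element's own observation times.
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(_demo_vector_field),
        diffrax.Tsit5(),
        ts[0],
        ts[-1],
        0.1,
        y0,
        saveat=diffrax.SaveAt(ts=ts),
    )
    return sol.ys


# Two batch elements, observed over different timespans.
demo_ts = jnp.stack([jnp.linspace(0, 1, 20), jnp.linspace(0, 3, 20)])
demo_y0 = jnp.ones((2, 2))
demo_ys = jax.vmap(_demo_solve)(demo_ts, demo_y0)  # shape (2, 20, 2)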

Reference:

@incollection{rubanova2019latent,
    title={{L}atent {O}rdinary {D}ifferential {E}quations for {I}rregularly-{S}ampled
           {T}ime {S}eries},
    author={Rubanova, Yulia and Chen, Ricky T. Q. and Duvenaud, David K.},
    booktitle={Advances in Neural Information Processing Systems},
    publisher={Curran Associates, Inc.},
    year={2019},
}

This example is available as a Jupyter notebook here.

import time

import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import optax


matplotlib.rcParams.update({"font.size": 30})

The vector field. Note its overall structure of scalar * tanh(mlp(y)), which is a good structure for Latent ODEs. (Here, the tanh is the final activation of self.mlp.)

class Func(eqx.Module):
    scale: jnp.ndarray
    mlp: eqx.nn.MLP

    def __call__(self, t, y, args):
        return self.scale * self.mlp(y)

Wrap up the differential equation solve into a model.

class LatentODE(eqx.Module):
    func: Func
    rnn_cell: eqx.nn.GRUCell

    hidden_to_latent: eqx.nn.Linear
    latent_to_hidden: eqx.nn.MLP
    hidden_to_data: eqx.nn.Linear

    hidden_size: int
    latent_size: int

    def __init__(
        self, *, data_size, hidden_size, latent_size, width_size, depth, key, **kwargs
    ):
        super().__init__(**kwargs)

        mkey, gkey, hlkey, lhkey, hdkey = jr.split(key, 5)

        scale = jnp.ones(())
        mlp = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=hidden_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            final_activation=jnn.tanh,
            key=mkey,
        )
        self.func = Func(scale, mlp)
        self.rnn_cell = eqx.nn.GRUCell(data_size + 1, hidden_size, key=gkey)

        self.hidden_to_latent = eqx.nn.Linear(hidden_size, 2 * latent_size, key=hlkey)
        self.latent_to_hidden = eqx.nn.MLP(
            latent_size, hidden_size, width_size=width_size, depth=depth, key=lhkey
        )
        self.hidden_to_data = eqx.nn.Linear(hidden_size, data_size, key=hdkey)

        self.hidden_size = hidden_size
        self.latent_size = latent_size

    # Encoder of the VAE
    def _latent(self, ts, ys, key):
        data = jnp.concatenate([ts[:, None], ys], axis=1)
        hidden = jnp.zeros((self.hidden_size,))
        for data_i in reversed(data):
            hidden = self.rnn_cell(data_i, hidden)
        context = self.hidden_to_latent(hidden)
        mean, logstd = context[: self.latent_size], context[self.latent_size :]
        std = jnp.exp(logstd)
        latent = mean + jr.normal(key, (self.latent_size,)) * std
        return latent, mean, std

    # Decoder of the VAE
    def _sample(self, ts, latent):
        dt0 = 0.4  # selected as a reasonable choice for this problem
        y0 = self.latent_to_hidden(latent)
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            ts[0],
            ts[-1],
            dt0,
            y0,
            saveat=diffrax.SaveAt(ts=ts),
        )
        return jax.vmap(self.hidden_to_data)(sol.ys)

    @staticmethod
    def _loss(ys, pred_ys, mean, std):
        # -log p_θ with Gaussian p_θ
        reconstruction_loss = 0.5 * jnp.sum((ys - pred_ys) ** 2)
        # KL(N(mean, std^2) || N(0, 1))
        variational_loss = 0.5 * jnp.sum(mean**2 + std**2 - 2 * jnp.log(std) - 1)
        return reconstruction_loss + variational_loss

    # Run both encoder and decoder during training.
    def train(self, ts, ys, *, key):
        latent, mean, std = self._latent(ts, ys, key)
        pred_ys = self._sample(ts, latent)
        return self._loss(ys, pred_ys, mean, std)

    # Run just the decoder during inference.
    def sample(self, ts, *, key):
        latent = jr.normal(key, (self.latent_size,))
        return self._sample(ts, latent)
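
For reference, the variational_loss term in _loss above is the standard closed-form KL divergence between the diagonal-Gaussian approximate posterior and the standard normal prior,

$$\mathrm{KL}\bigl(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)\bigr) = \frac{1}{2} \sum_i \bigl(\mu_i^2 + \sigma_i^2 - 2 \log \sigma_i - 1\bigr),$$

which matches 0.5 * jnp.sum(mean**2 + std**2 - 2 * jnp.log(std) - 1). Likewise, the reconstruction term is the negative log-likelihood of a unit-variance Gaussian decoder, up to an additive constant.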

Toy dataset of decaying oscillators.

By way of illustration, we set this up as a differential equation and solve it using Diffrax as well. (This is despite it being an autonomous linear ODE, for which a closed-form solution is available; see the sketch after get_data below.)

def get_data(dataset_size, *, key):
    ykey, tkey1, tkey2 = jr.split(key, 3)

    y0 = jr.normal(ykey, (dataset_size, 2))

    t0 = 0
    t1 = 2 + jr.uniform(tkey1, (dataset_size,))
    ts = jr.uniform(tkey2, (dataset_size, 20)) * (t1[:, None] - t0) + t0
    ts = jnp.sort(ts)
    dt0 = 0.1

    def func(t, y, args):
        return jnp.array([[-0.1, 1.3], [-1, -0.1]]) @ y

    def solve(ts, y0):
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(func),
            diffrax.Tsit5(),
            ts[0],
            ts[-1],
            dt0,
            y0,
            saveat=diffrax.SaveAt(ts=ts),
        )
        return sol.ys

    ys = jax.vmap(solve)(ts, y0)

    return ts, ys
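
Since the data-generating ODE above is linear and autonomous, dy/dt = A y, the closed-form solution mentioned earlier is y(t) = exp((t - t0) A) y(t0). Here is a minimal sketch of that closed form, assuming jax.scipy.linalg.expm is available in your JAX version; it is only a cross-check of get_data and isn't used in training:

import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

_A = jnp.array([[-0.1, 1.3], [-1.0, -0.1]])


def closed_form(ts, y0):
    # In get_data, y0 is the state at the first observation time ts[0], so
    # y(t) = expm((t - ts[0]) A) @ y0.
    return jax.vmap(lambda t: expm((t - ts[0]) * _A) @ y0)(ts)

Next, a simple dataloader: it shuffles the dataset and yields batches indefinitely.
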
def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while start < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size

The main entry point. Try running main() to train a model.

def main(
    dataset_size=10000,
    batch_size=256,
    lr=1e-2,
    steps=250,
    save_every=50,
    hidden_size=16,
    latent_size=16,
    width_size=16,
    depth=2,
    seed=5678,
):
    key = jr.PRNGKey(seed)
    data_key, model_key, loader_key, train_key, sample_key = jr.split(key, 5)

    ts, ys = get_data(dataset_size, key=data_key)

    model = LatentODE(
        data_size=ys.shape[-1],
        hidden_size=hidden_size,
        latent_size=latent_size,
        width_size=width_size,
        depth=depth,
        key=model_key,
    )

    @eqx.filter_value_and_grad
    def loss(model, ts_i, ys_i, key_i):
        batch_size, _ = ts_i.shape
        key_i = jr.split(key_i, batch_size)
        loss = jax.vmap(model.train)(ts_i, ys_i, key=key_i)
        return jnp.mean(loss)

    @eqx.filter_jit
    def make_step(model, opt_state, ts_i, ys_i, key_i):
        value, grads = loss(model, ts_i, ys_i, key_i)
        key_i = jr.split(key_i, 1)[0]
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return value, model, opt_state, key_i

    optim = optax.adam(lr)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    # Plot results
    num_plots = 1 + (steps - 1) // save_every
    if ((steps - 1) % save_every) != 0:
        num_plots += 1
    fig, axs = plt.subplots(1, num_plots, figsize=(num_plots * 8, 8))
    axs[0].set_ylabel("x")
    axs = iter(axs)
    for step, (ts_i, ys_i) in zip(
        range(steps), dataloader((ts, ys), batch_size, key=loader_key)
    ):
        start = time.time()
        value, model, opt_state, train_key = make_step(
            model, opt_state, ts_i, ys_i, train_key
        )
        end = time.time()
        print(f"Step: {step}, Loss: {value}, Computation time: {end - start}")

        if (step % save_every) == 0 or step == steps - 1:
            ax = next(axs)
            # Sample over a longer time interval than we trained on. The model will be
            # sufficiently good that it will correctly extrapolate!
            sample_t = jnp.linspace(0, 12, 300)
            sample_y = model.sample(sample_t, key=sample_key)
            sample_t = np.asarray(sample_t)
            sample_y = np.asarray(sample_y)
            ax.plot(sample_t, sample_y[:, 0])
            ax.plot(sample_t, sample_y[:, 1])
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xlabel("t")

    plt.savefig("latent_ode.png")
    plt.show()

main()

Step: 0, Loss: 19.934764862060547, Computation time: 27.07537531852722
Step: 1, Loss: 17.945302963256836, Computation time: 0.1743943691253662
Step: 2, Loss: 16.862319946289062, Computation time: 0.16676902770996094
Step: 3, Loss: 17.838266372680664, Computation time: 0.1676805019378662
Step: 4, Loss: 15.913865089416504, Computation time: 0.16959643363952637
Step: 5, Loss: 15.387907028198242, Computation time: 0.16565966606140137
Step: 6, Loss: 16.50263214111328, Computation time: 0.16969871520996094
Step: 7, Loss: 17.307086944580078, Computation time: 0.17042207717895508
Step: 8, Loss: 15.414609909057617, Computation time: 0.16952204704284668
Step: 9, Loss: 16.912670135498047, Computation time: 0.16579079627990723
Step: 10, Loss: 17.230003356933594, Computation time: 0.16723251342773438
Step: 11, Loss: 18.290681838989258, Computation time: 0.16434955596923828
Step: 12, Loss: 15.541263580322266, Computation time: 0.16330623626708984
Step: 13, Loss: 15.520601272583008, Computation time: 0.16518783569335938
Step: 14, Loss: 14.719974517822266, Computation time: 0.16350150108337402
Step: 15, Loss: 15.513769149780273, Computation time: 0.16359448432922363
Step: 16, Loss: 16.30827522277832, Computation time: 0.1634058952331543
Step: 17, Loss: 14.704435348510742, Computation time: 0.16392016410827637
Step: 18, Loss: 14.534599304199219, Computation time: 0.16302919387817383
Step: 19, Loss: 14.99282455444336, Computation time: 0.1640028953552246
Step: 20, Loss: 15.04023551940918, Computation time: 0.16433429718017578
Step: 21, Loss: 15.750327110290527, Computation time: 0.16364169120788574
Step: 22, Loss: 14.745054244995117, Computation time: 0.163421630859375
Step: 23, Loss: 15.654170989990234, Computation time: 0.16426348686218262
Step: 24, Loss: 14.102017402648926, Computation time: 0.16342639923095703
Step: 25, Loss: 13.730924606323242, Computation time: 0.16349434852600098
Step: 26, Loss: 14.454326629638672, Computation time: 0.162459135055542
Step: 27, Loss: 16.074562072753906, Computation time: 0.16372108459472656
Step: 28, Loss: 14.457178115844727, Computation time: 0.16365718841552734
Step: 29, Loss: 14.899832725524902, Computation time: 0.16407418251037598
Step: 30, Loss: 14.21741771697998, Computation time: 0.16400694847106934
Step: 31, Loss: 12.896212577819824, Computation time: 0.16325831413269043
Step: 32, Loss: 13.572277069091797, Computation time: 0.16397356986999512
Step: 33, Loss: 14.58654499053955, Computation time: 0.1686105728149414
Step: 34, Loss: 14.236112594604492, Computation time: 0.1673274040222168
Step: 35, Loss: 13.96904182434082, Computation time: 0.16666364669799805
Step: 36, Loss: 13.717779159545898, Computation time: 0.16426467895507812
Step: 37, Loss: 13.212942123413086, Computation time: 0.16362261772155762
Step: 38, Loss: 13.356792449951172, Computation time: 0.16526198387145996
Step: 39, Loss: 13.750845909118652, Computation time: 26.91799235343933
Step: 40, Loss: 15.398611068725586, Computation time: 0.1675868034362793
Step: 41, Loss: 11.830371856689453, Computation time: 0.16466093063354492
Step: 42, Loss: 12.59495735168457, Computation time: 0.16176891326904297
Step: 43, Loss: 13.213092803955078, Computation time: 0.16349530220031738
Step: 44, Loss: 12.40422534942627, Computation time: 0.16125273704528809
Step: 45, Loss: 13.30964469909668, Computation time: 0.16145730018615723
Step: 46, Loss: 12.55689811706543, Computation time: 0.16156625747680664
Step: 47, Loss: 11.785927772521973, Computation time: 0.1622486114501953
Step: 48, Loss: 11.325067520141602, Computation time: 0.16244864463806152
Step: 49, Loss: 11.61506462097168, Computation time: 0.1624457836151123
Step: 50, Loss: 10.890422821044922, Computation time: 0.16366934776306152
Step: 51, Loss: 13.305912017822266, Computation time: 0.16304707527160645
Step: 52, Loss: 11.54366397857666, Computation time: 0.16243696212768555
Step: 53, Loss: 11.796025276184082, Computation time: 0.16330742835998535
Step: 54, Loss: 12.504520416259766, Computation time: 0.16342830657958984
Step: 55, Loss: 11.736138343811035, Computation time: 0.16159415245056152
Step: 56, Loss: 11.351236343383789, Computation time: 0.16047382354736328
Step: 57, Loss: 11.916851997375488, Computation time: 0.16179728507995605
Step: 58, Loss: 11.83980655670166, Computation time: 0.16157770156860352
Step: 59, Loss: 11.1612548828125, Computation time: 0.16280055046081543
Step: 60, Loss: 11.311992645263672, Computation time: 0.1631929874420166
Step: 61, Loss: 11.657142639160156, Computation time: 0.16200017929077148
Step: 62, Loss: 10.814916610717773, Computation time: 0.16182494163513184
Step: 63, Loss: 10.638484001159668, Computation time: 0.16114020347595215
Step: 64, Loss: 9.871231079101562, Computation time: 0.16211938858032227
Step: 65, Loss: 10.842245101928711, Computation time: 0.16185402870178223
Step: 66, Loss: 11.241954803466797, Computation time: 0.16134214401245117
Step: 67, Loss: 10.528236389160156, Computation time: 0.16387319564819336
Step: 68, Loss: 10.252235412597656, Computation time: 0.16159725189208984
Step: 69, Loss: 10.343666076660156, Computation time: 0.16295313835144043
Step: 70, Loss: 9.838155746459961, Computation time: 0.16141152381896973
Step: 71, Loss: 10.129756927490234, Computation time: 0.16135191917419434
Step: 72, Loss: 10.172172546386719, Computation time: 0.16157221794128418
Step: 73, Loss: 9.98276424407959, Computation time: 0.16115164756774902
Step: 74, Loss: 9.925966262817383, Computation time: 0.16163945198059082
Step: 75, Loss: 9.98451042175293, Computation time: 0.16181254386901855
Step: 76, Loss: 10.033723831176758, Computation time: 0.1613597869873047
Step: 77, Loss: 9.620193481445312, Computation time: 0.1607823371887207
Step: 78, Loss: 9.448945045471191, Computation time: 0.1607818603515625
Step: 79, Loss: 7.9748687744140625, Computation time: 0.1488492488861084
Step: 80, Loss: 9.215356826782227, Computation time: 0.16275405883789062
Step: 81, Loss: 9.691690444946289, Computation time: 0.1624891757965088
Step: 82, Loss: 8.748353958129883, Computation time: 0.16045212745666504
Step: 83, Loss: 8.528343200683594, Computation time: 0.16178536415100098
Step: 84, Loss: 8.34644889831543, Computation time: 0.16109156608581543
Step: 85, Loss: 9.200542449951172, Computation time: 0.16094589233398438
Step: 86, Loss: 8.57141399383545, Computation time: 0.1619279384613037
Step: 87, Loss: 7.508444786071777, Computation time: 0.1600663661956787
Step: 88, Loss: 7.279205322265625, Computation time: 0.16137266159057617
Step: 89, Loss: 7.090503215789795, Computation time: 0.16118311882019043
Step: 90, Loss: 7.453930377960205, Computation time: 0.16112112998962402
Step: 91, Loss: 7.0916032791137695, Computation time: 0.16120529174804688
Step: 92, Loss: 7.136333465576172, Computation time: 0.16111302375793457
Step: 93, Loss: 7.14594841003418, Computation time: 0.16206598281860352
Step: 94, Loss: 6.871617317199707, Computation time: 0.19673919677734375
Step: 95, Loss: 7.352797031402588, Computation time: 0.16296100616455078
Step: 96, Loss: 6.726633548736572, Computation time: 0.16156458854675293
Step: 97, Loss: 6.9557905197143555, Computation time: 0.16250896453857422
Step: 98, Loss: 7.102599143981934, Computation time: 0.1620466709136963
Step: 99, Loss: 7.049860954284668, Computation time: 0.16131353378295898
Step: 100, Loss: 6.750383377075195, Computation time: 0.16186952590942383
Step: 101, Loss: 7.038060188293457, Computation time: 0.16181278228759766
Step: 102, Loss: 7.034355640411377, Computation time: 0.16237926483154297
Step: 103, Loss: 6.82716178894043, Computation time: 0.16185402870178223
Step: 104, Loss: 6.787952423095703, Computation time: 0.16224908828735352
Step: 105, Loss: 6.880023002624512, Computation time: 0.16243886947631836
Step: 106, Loss: 6.616780757904053, Computation time: 0.1620333194732666
Step: 107, Loss: 6.402748107910156, Computation time: 0.16213607788085938
Step: 108, Loss: 6.7207746505737305, Computation time: 0.16174864768981934
Step: 109, Loss: 5.961440563201904, Computation time: 0.16174983978271484
Step: 110, Loss: 6.086441993713379, Computation time: 0.16232728958129883
Step: 111, Loss: 5.67965030670166, Computation time: 0.1625194549560547
Step: 112, Loss: 5.820930480957031, Computation time: 0.1604611873626709
Step: 113, Loss: 6.119414329528809, Computation time: 0.16963505744934082
Step: 114, Loss: 6.096449851989746, Computation time: 0.16268205642700195
Step: 115, Loss: 5.988513469696045, Computation time: 0.1606006622314453
Step: 116, Loss: 6.118512153625488, Computation time: 0.16241216659545898
Step: 117, Loss: 5.241769790649414, Computation time: 0.16131067276000977
Step: 118, Loss: 6.166355609893799, Computation time: 0.16092491149902344
Step: 119, Loss: 6.842771530151367, Computation time: 0.1441802978515625
Step: 120, Loss: 6.375185489654541, Computation time: 0.16277027130126953
Step: 121, Loss: 5.80587100982666, Computation time: 0.1614992618560791
Step: 122, Loss: 5.733676433563232, Computation time: 0.16245174407958984
Step: 123, Loss: 5.918340682983398, Computation time: 0.16118144989013672
Step: 124, Loss: 5.5885467529296875, Computation time: 0.16121363639831543
Step: 125, Loss: 5.8133063316345215, Computation time: 0.16047954559326172
Step: 126, Loss: 5.448032379150391, Computation time: 0.1612851619720459
Step: 127, Loss: 5.919766902923584, Computation time: 0.16178321838378906
Step: 128, Loss: 5.811756610870361, Computation time: 0.16073966026306152
Step: 129, Loss: 5.2886857986450195, Computation time: 0.16239547729492188
Step: 130, Loss: 5.062446594238281, Computation time: 0.1623084545135498
Step: 131, Loss: 5.370600700378418, Computation time: 0.16302895545959473
Step: 132, Loss: 5.032846450805664, Computation time: 0.16185617446899414
Step: 133, Loss: 5.3186492919921875, Computation time: 0.16357207298278809
Step: 134, Loss: 4.988264083862305, Computation time: 0.16092920303344727
Step: 135, Loss: 5.364264488220215, Computation time: 0.16193294525146484
Step: 136, Loss: 5.038562774658203, Computation time: 0.16143488883972168
Step: 137, Loss: 5.195552825927734, Computation time: 0.16141676902770996
Step: 138, Loss: 4.877957344055176, Computation time: 0.16106271743774414
Step: 139, Loss: 4.971206188201904, Computation time: 0.15976953506469727
Step: 140, Loss: 4.850249767303467, Computation time: 0.16672515869140625
Step: 141, Loss: 5.053151607513428, Computation time: 0.16182613372802734
Step: 142, Loss: 4.553808212280273, Computation time: 0.16060352325439453
Step: 143, Loss: 4.6004109382629395, Computation time: 0.16107678413391113
Step: 144, Loss: 4.889383316040039, Computation time: 0.1608583927154541
Step: 145, Loss: 4.736492156982422, Computation time: 0.16157317161560059
Step: 146, Loss: 4.708489894866943, Computation time: 0.16304683685302734
Step: 147, Loss: 4.679104804992676, Computation time: 0.1609785556793213
Step: 148, Loss: 4.689470291137695, Computation time: 0.16070127487182617
Step: 149, Loss: 4.528751850128174, Computation time: 0.16136622428894043
Step: 150, Loss: 4.48677396774292, Computation time: 0.1604769229888916
Step: 151, Loss: 4.637646675109863, Computation time: 0.16101288795471191
Step: 152, Loss: 4.762913703918457, Computation time: 0.16133403778076172
Step: 153, Loss: 4.44551944732666, Computation time: 0.1619107723236084
Step: 154, Loss: 4.5776472091674805, Computation time: 0.1616075038909912
Step: 155, Loss: 4.562440395355225, Computation time: 0.16150236129760742
Step: 156, Loss: 4.409887313842773, Computation time: 0.16173315048217773
Step: 157, Loss: 4.46767520904541, Computation time: 0.16112399101257324
Step: 158, Loss: 4.25125789642334, Computation time: 0.16138744354248047
Step: 159, Loss: 4.785336971282959, Computation time: 0.1468524932861328
Step: 160, Loss: 5.054254055023193, Computation time: 0.16128849983215332
Step: 161, Loss: 4.8799567222595215, Computation time: 0.1611628532409668
Step: 162, Loss: 4.688265800476074, Computation time: 0.16042160987854004
Step: 163, Loss: 4.51352596282959, Computation time: 0.1602628231048584
Step: 164, Loss: 4.331615447998047, Computation time: 0.1609640121459961
Step: 165, Loss: 4.137004852294922, Computation time: 0.16290068626403809
Step: 166, Loss: 4.654952049255371, Computation time: 0.16114187240600586
Step: 167, Loss: 4.4677629470825195, Computation time: 0.16231393814086914
Step: 168, Loss: 4.510952949523926, Computation time: 0.16344356536865234
Step: 169, Loss: 4.258943557739258, Computation time: 0.16016602516174316
Step: 170, Loss: 4.283701419830322, Computation time: 0.1614704132080078
Step: 171, Loss: 4.368310451507568, Computation time: 0.1617722511291504
Step: 172, Loss: 4.095067024230957, Computation time: 0.16355204582214355
Step: 173, Loss: 4.290921211242676, Computation time: 0.16144156455993652
Step: 174, Loss: 4.135052680969238, Computation time: 0.16065239906311035
Step: 175, Loss: 4.188730239868164, Computation time: 0.16092491149902344
Step: 176, Loss: 3.9966931343078613, Computation time: 0.16103458404541016
Step: 177, Loss: 4.127541542053223, Computation time: 0.16103053092956543
Step: 178, Loss: 4.2538557052612305, Computation time: 0.1615607738494873
Step: 179, Loss: 4.453568458557129, Computation time: 0.1603102684020996
Step: 180, Loss: 4.0408525466918945, Computation time: 0.16083049774169922
Step: 181, Loss: 4.516185760498047, Computation time: 0.1609797477722168
Step: 182, Loss: 4.250395774841309, Computation time: 0.1612706184387207
Step: 183, Loss: 4.046529769897461, Computation time: 0.16176581382751465
Step: 184, Loss: 4.198785781860352, Computation time: 0.16283583641052246
Step: 185, Loss: 3.9407706260681152, Computation time: 0.16234254837036133
Step: 186, Loss: 4.026411056518555, Computation time: 0.1624460220336914
Step: 187, Loss: 4.224530220031738, Computation time: 0.16072320938110352
Step: 188, Loss: 4.028736591339111, Computation time: 0.16074919700622559
Step: 189, Loss: 3.837322950363159, Computation time: 0.16036534309387207
Step: 190, Loss: 4.123674392700195, Computation time: 0.16191387176513672
Step: 191, Loss: 3.9622178077697754, Computation time: 0.16129708290100098
Step: 192, Loss: 3.969315528869629, Computation time: 0.16092944145202637
Step: 193, Loss: 3.7825825214385986, Computation time: 0.16073131561279297
Step: 194, Loss: 3.9199018478393555, Computation time: 0.16074514389038086
Step: 195, Loss: 4.052471160888672, Computation time: 0.16427040100097656
Step: 196, Loss: 3.7691221237182617, Computation time: 0.16066265106201172
Step: 197, Loss: 3.937032699584961, Computation time: 0.16099143028259277
Step: 198, Loss: 4.042672634124756, Computation time: 0.16167831420898438
Step: 199, Loss: 3.7281570434570312, Computation time: 0.14007043838500977
Step: 200, Loss: 4.159261226654053, Computation time: 0.16143798828125
Step: 201, Loss: 4.408998489379883, Computation time: 0.16060853004455566
Step: 202, Loss: 4.1045427322387695, Computation time: 0.16067767143249512
Step: 203, Loss: 4.352884292602539, Computation time: 0.1615588665008545
Step: 204, Loss: 4.170437335968018, Computation time: 0.16057705879211426
Step: 205, Loss: 3.970756769180298, Computation time: 0.1603851318359375
Step: 206, Loss: 4.299739837646484, Computation time: 0.16051793098449707
Step: 207, Loss: 4.127477645874023, Computation time: 0.16169023513793945
Step: 208, Loss: 4.360357761383057, Computation time: 0.1614537239074707
Step: 209, Loss: 3.9281232357025146, Computation time: 0.16314291954040527
Step: 210, Loss: 3.9255576133728027, Computation time: 0.16143369674682617
Step: 211, Loss: 4.089841842651367, Computation time: 0.162628173828125
Step: 212, Loss: 4.131923675537109, Computation time: 0.1637284755706787
Step: 213, Loss: 4.047548294067383, Computation time: 0.16175484657287598
Step: 214, Loss: 4.078159809112549, Computation time: 0.1614534854888916
Step: 215, Loss: 4.092671871185303, Computation time: 0.16064238548278809
Step: 216, Loss: 4.069928169250488, Computation time: 0.16089081764221191
Step: 217, Loss: 3.7901744842529297, Computation time: 0.16229534149169922
Step: 218, Loss: 4.05171012878418, Computation time: 0.16241717338562012
Step: 219, Loss: 4.072657585144043, Computation time: 0.16231489181518555
Step: 220, Loss: 4.119385719299316, Computation time: 0.16376709938049316
Step: 221, Loss: 3.946767568588257, Computation time: 0.16153383255004883
Step: 222, Loss: 3.8579845428466797, Computation time: 0.16051745414733887
Step: 223, Loss: 3.955892324447632, Computation time: 0.16411495208740234
Step: 224, Loss: 4.090612411499023, Computation time: 0.16119980812072754
Step: 225, Loss: 3.871494770050049, Computation time: 0.1633768081665039
Step: 226, Loss: 4.001490116119385, Computation time: 0.1612398624420166
Step: 227, Loss: 3.856689453125, Computation time: 0.16136479377746582
Step: 228, Loss: 3.854506254196167, Computation time: 0.16175079345703125
Step: 229, Loss: 3.920146942138672, Computation time: 0.16027593612670898
Step: 230, Loss: 3.8486571311950684, Computation time: 0.16107869148254395
Step: 231, Loss: 4.150424003601074, Computation time: 0.161329984664917
Step: 232, Loss: 4.034335613250732, Computation time: 0.16145658493041992
Step: 233, Loss: 3.862642288208008, Computation time: 0.16074752807617188
Step: 234, Loss: 3.879786491394043, Computation time: 0.16097068786621094
Step: 235, Loss: 3.9150876998901367, Computation time: 0.1610715389251709
Step: 236, Loss: 3.6582045555114746, Computation time: 0.16137981414794922
Step: 237, Loss: 4.022642612457275, Computation time: 0.16101980209350586
Step: 238, Loss: 3.920273780822754, Computation time: 0.16168999671936035
Step: 239, Loss: 4.942720890045166, Computation time: 0.139939546585083
Step: 240, Loss: 3.820035457611084, Computation time: 0.16096997261047363
Step: 241, Loss: 4.027595520019531, Computation time: 0.1608715057373047
Step: 242, Loss: 3.9767158031463623, Computation time: 0.16132664680480957
Step: 243, Loss: 3.927661895751953, Computation time: 0.16009283065795898
Step: 244, Loss: 4.054908275604248, Computation time: 0.16004633903503418
Step: 245, Loss: 4.072584629058838, Computation time: 0.1604931354522705
Step: 246, Loss: 4.165594100952148, Computation time: 0.16080093383789062
Step: 247, Loss: 3.9277215003967285, Computation time: 0.16055607795715332
Step: 248, Loss: 4.001946449279785, Computation time: 0.1610417366027832
Step: 249, Loss: 3.9720990657806396, Computation time: 0.16184639930725098