{ "cells": [ { "cell_type": "markdown", "id": "45b018f1-04fc-4b44-a785-faa0fdd9f6b7", "metadata": {}, "source": [ "# Keras integration\n", "\n", "This notebook describes how bt_ocean can be combined with Keras.\n", "\n", "bt_ocean allows a neural network to define a right-hand-side forcing term, which can then be trained using time-dependent data. Here, to demonstrate the principles, we consider an extremely simple case: a Keras model consisting of a single layer which outputs the degrees of freedom for a function, so that training reduces to the problem of finding this function. Specifically, we will try to recover the wind forcing term $Q$ used on the right-hand side of the barotropic vorticity equation. This toy problem demonstrates the key ideas while remaining small enough to run quickly." ] }, { "cell_type": "markdown", "id": "c755be30-a57c-4554-8d28-27da9052f4f9", "metadata": {}, "source": [ "## Training data\n", "\n", "For this toy problem we run a low-resolution model for a short amount of time."
] }, { "cell_type": "code", "execution_count": null, "id": "508b3675-1cd7-434d-989e-09055583b039", "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "from bt_ocean.model import CNAB2Solver\n", "from bt_ocean.network import Dynamics\n", "from bt_ocean.parameters import parameters, tau_0, rho_0, D, Q\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "import keras\n", "import matplotlib.pyplot as plt\n", "\n", "# Use double precision in both JAX and Keras\n", "jax.config.update(\"jax_enable_x64\", True)\n", "keras.backend.set_floatx(\"float64\")\n", "\n", "# Low resolution, with n_hour timesteps per hour\n", "n_hour = 2\n", "parameters = dict(parameters)\n", "parameters.update({\"N_x\": 64,\n", "                   \"N_y\": 64,\n", "                   \"dt\": 3600 / n_hour,\n", "                   \"nu\": 2500})\n", "model = CNAB2Solver(parameters)\n", "# Reference wind forcing, which we later try to recover\n", "model.fields[\"Q\"] = Q(model.grid)\n", "\n", "# Timesteps per day, and total number of timesteps (five days)\n", "n_day = 24 * n_hour\n", "N = 5 * n_day\n", "\n", "model.initialize()\n", "# Store the vorticity at the initial time and after every timestep\n", "data = [model.fields[\"zeta\"]]\n", "for _ in range(N):\n", "    model.step()\n", "    data.append(model.fields[\"zeta\"])\n", "    if model.n % n_day == 0:\n", "        print(f\"{model.n // n_day=} {model.ke()=}\")\n", "data = jnp.array(data)" ] }, { "cell_type": "markdown", "id": "d387e29e-ea96-4ba7-8327-2ae19241b045", "metadata": {}, "source": [ "## A simple Keras model\n", "\n", "We now set up our Keras model. To do this we define a `keras.models.Model` which maps from our state to a forcing term added to the right-hand side of the barotropic vorticity equation, and a callable which uses this map."
] }, { "cell_type": "code", "execution_count": null, "id": "da7d0234-e47b-4449-8c44-89f7eff02437", "metadata": {}, "outputs": [], "source": [ "import keras\n", "\n", "\n", "class Term(keras.layers.Layer):\n", "    def __init__(self, grid, **kwargs):\n", "        super().__init__(**kwargs)\n", "        # Trainable degrees of freedom for the forcing function\n", "        self.__b = self.add_weight(shape=(grid.N_x + 1, grid.N_y + 1), dtype=grid.fdtype, initializer=\"zeros\")\n", "\n", "    def compute_output_shape(self, input_shape):\n", "        return input_shape[:-1] + self.__b.shape\n", "\n", "    def call(self, inputs):\n", "        # Broadcast the weights over the leading (batch) dimensions; the input values are unused\n", "        return jnp.tensordot(jnp.ones_like(inputs, shape=inputs.shape[:-1]), self.__b, axes=0)\n", "\n", "\n", "# The network takes an empty input, so its output depends only on the weights\n", "Q_input_layer = keras.layers.Input((0,))\n", "Q_network = keras.models.Model(inputs=Q_input_layer, outputs=Term(model.grid)(Q_input_layer))\n", "Q_weight = tau_0 * jnp.pi / (D * rho_0 * model.grid.L_y)\n", "\n", "\n", "def update(dynamics, Q_network):\n", "    dynamics.fields[\"Q\"] = Q_weight * Q_network(jnp.zeros(shape=(1, 0)))[0, :, :]" ] }, { "cell_type": "markdown", "id": "d23edf22-f916-41fc-8095-2bdd8d9be6c9", "metadata": {}, "source": [ "We now set up a `Dynamics` layer. This is a custom Keras layer which represents the mapping from an initial condition, in terms of the initial vorticity fields, to dynamical trajectories, in terms of the vorticity fields evaluated at later times. The trajectories are computed by solving the barotropic vorticity equation, forced with an extra right-hand-side term defined by the given `keras.models.Model`.\n", "\n", "In the following we use the AdamW optimizer, although here we have only one batch of size one. In fact we are solving a standard variational optimization problem, but without an explicit regularization term, and so for this problem it might be better to use a deterministic optimizer. We do, however, increase the learning rate."
] }, { "cell_type": "code", "execution_count": null, "id": "a94eda1c-2556-4c99-be23-b57df4e73967", "metadata": {}, "outputs": [], "source": [ "output_weight = (model.grid.N_x + 1) * (model.grid.N_y + 1) * jnp.sqrt(model.grid.W / (4 * model.grid.L_x * model.grid.L_y)) / (model.beta * model.grid.L_y)\n", "dynamics_input_layer = keras.layers.Input((model.grid.N_x + 1, model.grid.N_y + 1))\n", "dynamics_layer = Dynamics(\n", "    model, update, Q_network, N=1, n_output=N, output_weight=output_weight)\n", "dynamics_network = keras.models.Model(\n", "    inputs=dynamics_input_layer, outputs=dynamics_layer(dynamics_input_layer))\n", "dynamics_network.compile(optimizer=keras.optimizers.AdamW(learning_rate=0.1),\n", "                         loss=\"mean_squared_error\")" ] }, { "cell_type": "markdown", "id": "aaea48a7-2329-49d6-9c5f-6e0f7434c678", "metadata": {}, "source": [ "## Training\n", "\n", "We are now ready to train. In this toy problem we use the full trajectory as a single input-output pair and, since we know the answer, supply no validation data." ] }, { "cell_type": "code", "execution_count": null, "id": "327c0337-d2af-47d5-8047-cc3e9b2b0be0", "metadata": {}, "outputs": [], "source": [ "input_data = jnp.reshape(data[0, :, :], (1,) + data.shape[1:])\n", "output_data = jnp.reshape(data[1:, :, :], (1, data.shape[0] - 1) + data.shape[1:])\n", "_ = dynamics_network.fit(input_data, output_data * output_weight, epochs=40, verbose=2)" ] }, { "cell_type": "markdown", "id": "b49517ff-68d1-403b-ad5e-abc88a7386d4", "metadata": {}, "source": [ "Let's see how well we've done."
] }, { "cell_type": "code", "execution_count": null, "id": "c56ce9f5-406f-4a6a-a922-134d9726a00d", "metadata": {}, "outputs": [], "source": [ "result = Q_weight * dynamics_network.get_weights()[0]\n", "ref = model.fields[\"Q\"]\n", "\n", "plt.figure()\n", "m = abs(result).max()\n", "plt.contourf(model.grid.X, model.grid.Y, result,\n", "             jnp.linspace(-m * (1 + 1.0e-10), m * (1 + 1.0e-10), 64), cmap=\"bwr\")\n", "plt.gca().set_aspect(1)\n", "plt.colorbar()\n", "plt.title(\"Result\")\n", "\n", "plt.figure()\n", "m = abs(ref).max()\n", "plt.contourf(model.grid.X, model.grid.Y, ref,\n", "             jnp.linspace(-m * (1 + 1.0e-10), m * (1 + 1.0e-10), 64), cmap=\"bwr\")\n", "plt.gca().set_aspect(1)\n", "plt.colorbar()\n", "plt.title(\"Reference\")\n", "\n", "plt.figure()\n", "m = abs(result - ref).max()\n", "plt.contourf(model.grid.X, model.grid.Y, result - ref,\n", "             jnp.linspace(-m * (1 + 1.0e-10), m * (1 + 1.0e-10), 64), cmap=\"bwr\")\n", "plt.gca().set_aspect(1)\n", "plt.colorbar()\n", "plt.title(\"Error\")\n", "\n", "assert abs((result - ref)[8:-8, 8:-8]).max() < 0.2 * abs(ref[8:-8, 8:-8]).max()" ] }, { "cell_type": "markdown", "id": "5619785b-1578-4f02-ab3a-a6a4c0797dd1", "metadata": {}, "source": [ "We have a reasonable result, except at the east and west boundaries.\n", "\n", "While we might expect to do much better for this problem if we applied a deterministic optimizer, there is also a more fundamental problem here. bt_ocean applies free-slip boundary conditions, meaning that the vorticity satisfies homogeneous Dirichlet boundary conditions. This means that there are perturbations we can apply to the right-hand side (contributions to $Q$) which do not affect the dynamics, and consequently derivatives of the loss with respect to the right-hand-side forcing, evaluated in directions associated with these perturbations, are zero. 
This non-regularized inverse problem is ill-posed, and we see this here: given only the trajectory of the numerical model, we cannot recover the wind stress curl term on the boundary." ] }, { "cell_type": "markdown", "id": "62bfc02f-a6f0-41d7-a535-7fef207ae5dd", "metadata": {}, "source": [ "## Increasing the complexity\n", "\n", "Here we have used Keras to solve a standard variational optimization problem, by defining a very simple Keras model. However, we can make the Keras model `Q_network` much more complex, and can also use the `Dynamics` layer itself as part of a more complicated 'outer' Keras model. That is, we can embed neural networks within bt_ocean, and can also embed bt_ocean within neural networks. The main restriction is that use of the embedded neural network (by the `update` callable in this example) may change only the `dynamics` argument: it must not change other arguments (here `Q_network`) or have other side effects. This means, for example, that evaluating the embedded neural network cannot change the neural network itself, as occurs e.g. with batch normalization." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }