{ "cells": [ { "cell_type": "markdown", "id": "71577fda-4a7c-4ba2-aaf1-c49c3962e5de", "metadata": {}, "source": [ "# Steady-state problems\n", "\n", "This notebook outlines the solution and differentiation of steady-state problems using bt_ocean. The key approach used is *implicit* differentiation. For further details on implicit differentiation see:\n", "\n", " - Andreas Griewank and Andrea Walther, 'Evaluating derivatives', second edition, Society for Industrial and Applied Mathematics, 2008, ISBN: 978-0-898716-59-7, chapter 15\n", " - Bruce Christianson, 'Reverse accumulation and attractive fixed points', Optimization Methods and Software 3(4), pp. 311–326, 1994, doi: https://doi.org/10.1080/10556789408805572\n", " - Zico Kolter, David Duvenaud, and Matt Johnson, 'Deep implicit layers - neural ODEs, deep equilibirum models, and beyond', https://implicit-layers-tutorial.org/ [accessed 2024-08-26], chapter 2" ] }, { "cell_type": "markdown", "id": "50e03bdb-de66-4384-9182-529d8f9cd176", "metadata": {}, "source": [ "## Forward problem\n", "\n", "Here we consider a non-linear Stommel problem with a single gyre. We start by configuring the model."
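, "\n", "\n", "For reference, and as a reading of the configuration code rather than a statement from the bt_ocean documentation, the wind forcing profile set up in the next cell is\n", "\n", "$$Q(y) = -\\frac{\\tau_0 \\pi}{2 D \\rho_0 L_y} \\cos \\left( \\frac{\\pi y}{2 L_y} \\right),$$\n", "\n", "where $\\tau_0$, $\\rho_0$, and $D$ are imported from `bt_ocean.parameters` and $y$ corresponds to `grid.Y`. Assuming the domain spans $-L_y \\le y \\le L_y$ (an assumption made here, suggested by the perturbation fields used later in this notebook), the cosine does not change sign across the domain, so the forcing has a single sign, consistent with a single gyre."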
] }, { "cell_type": "code", "execution_count": null, "id": "ce9eaeba-924a-4e5b-947f-e8b839860e7a", "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "from bt_ocean.model import CNAB2Solver\n", "from bt_ocean.parameters import parameters, tau_0, rho_0, D\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", "\n", "from functools import partial\n", "\n", "# Use double precision throughout\n", "jax.config.update(\"jax_enable_x64\", True)\n", "\n", "# Number of timesteps per hour, giving dt = 1800 s\n", "n_hour = 2\n", "parameters = dict(parameters)\n", "parameters.update({\"N_x\": 64,\n", "                   \"N_y\": 64,\n", "                   \"dt\": 3600 / n_hour,\n", "                   \"r\": 0.1 * parameters[\"beta\"] * parameters[\"L_y\"],\n", "                   \"nu\": 0})\n", "\n", "\n", "# Wind forcing profile, a function of y only\n", "def Q(grid):\n", "    return -((tau_0 * jnp.pi / (D * rho_0 * 2 * grid.L_y))\n", "             * jnp.cos(jnp.pi * grid.Y / (2 * grid.L_y)))\n", "\n", "\n", "model = CNAB2Solver(parameters)\n", "model.fields[\"Q\"] = Q(model.grid)" ] }, { "cell_type": "markdown", "id": "17ef64b5-2b1c-49e6-b36f-1118848834c7", "metadata": {}, "source": [ "We now solve to steady state. Rather than writing our own time loop we can use the `steady_state_solve` method. Under the hood this timesteps the model until an approximate steady state is reached." ] }, { "cell_type": "code", "execution_count": null, "id": "cf5aec0d-163e-4e3a-9752-25d6d3d55bb4", "metadata": {}, "outputs": [], "source": [ "model.steady_state_solve(tol=1.0e-5)\n", "\n", "# Plot the transport stream function, D * psi, in Sv\n", "plt.figure()\n", "plt.contourf(model.grid.X, model.grid.Y, D * model.fields[\"psi\"] / 1.0e6, 64, cmap=\"magma\")\n", "plt.gca().set_aspect(1)\n", "plt.colorbar()\n", "\n", "print(f\"{(D * model.fields['psi'] / 1.0e6).max()=}\")" ] }, { "cell_type": "markdown", "id": "0947c56b-6321-4401-a923-4d7de65640ea", "metadata": {}, "source": [ "We see a classic Stommel gyre solution, with a transport of approximately $10$ Sv."
] }, { "cell_type": "markdown", "id": "53697c00-c5ca-44a4-a78e-9259d9e592e1", "metadata": {}, "source": [ "## Autodiff\n", "\n", "We now differentiate the transport with respect to the wind forcing profile $Q$ appearing on the right-hand side of the barotropic vorticity equation. We write a single function which computes the diagnostic, and differentiate it using `jax.vjp`. Under the hood, differentiating through `steady_state_solve` applies *implicit* differentiation, constructing a fixed-point iteration to solve an associated adjoint problem." ] }, { "cell_type": "code", "execution_count": null, "id": "8bcb169f-112b-46c3-8816-f61b552cf451", "metadata": {}, "outputs": [], "source": [ "def forward(Q):\n", "    model = CNAB2Solver(parameters)\n", "    model.fields[\"Q\"] = Q\n", "    model.steady_state_solve(tol=1.0e-6)\n", "    # Maximum of D * psi, in Sv\n", "    return D * model.fields[\"psi\"].max() / 1.0e6\n", "\n", "\n", "transport, vjp = jax.vjp(forward, model.fields[\"Q\"])\n", "print(f\"{transport=}\")\n", "\n", "# Derivative of the scalar transport with respect to Q\n", "dtransport_dual, = vjp(1.0)\n", "# Scale by grid.W to obtain a primal field for plotting\n", "dtransport_primal = dtransport_dual / model.grid.W\n", "\n", "plt.figure()\n", "plt.contourf(model.grid.X, model.grid.Y, dtransport_primal, 64, cmap=\"magma\")\n", "plt.gca().set_aspect(1)\n", "plt.colorbar()" ] }, { "cell_type": "markdown", "id": "906c5fef-bdbb-4857-812a-ca66551dbf18", "metadata": {}, "source": [ "Let's perform a simple check against finite differencing."
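, "\n", "\n", "As a sketch of what is being checked, in notation introduced for this note rather than taken from bt_ocean: suppose the steady state $x^*$ is a fixed point of a timestep map $F$, so that $x^* = F(x^*, Q)$, and write the transport as $J(x^*)$. This is an idealization, since for a multi-step scheme the state $x$ may need to hold more than one time level. Implicit differentiation of the fixed-point condition then gives\n", "\n", "$$\\frac{dJ}{dQ} = \\lambda^T \\frac{\\partial F}{\\partial Q}, \\qquad \\lambda = \\left( \\frac{\\partial F}{\\partial x} \\right)^T \\lambda + \\left( \\frac{\\partial J}{\\partial x} \\right)^T,$$\n", "\n", "where the adjoint equation for $\\lambda$ is itself a fixed-point problem, which can be solved by iteration (see Christianson, 1994, in the references at the top of this notebook).\n", "\n", "The check contracts the derivative returned by the VJP with a perturbation direction $\\eta$, and compares the result against a second-order central difference with step size $\\epsilon$,\n", "\n", "$$\\sum_{i,j} \\frac{\\partial J}{\\partial Q_{i,j}} \\eta_{i,j} \\approx \\frac{J(Q + \\epsilon \\eta) - J(Q - \\epsilon \\eta)}{2 \\epsilon}.$$\n", "\n", "The central difference has truncation error $O(\\epsilon^2)$, but the agreement is presumably also limited by the tolerance of the steady-state solve, so the assertion in the next cell uses a fairly loose relative tolerance of $10^{-3}$."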
] }, { "cell_type": "code", "execution_count": null, "id": "c0910989-3471-4df5-a06d-89dd04cdef73", "metadata": {}, "outputs": [], "source": [ "eta = ((tau_0 * jnp.pi / (D * rho_0 * 2 * model.grid.L_y))\n", " * jnp.sin(jnp.pi * (model.grid.X + model.grid.L_x) / (2 * model.grid.L_x))\n", " * jnp.sin(jnp.pi * (model.grid.Y + model.grid.L_y) / (2 * model.grid.L_y)))\n", "dtransport_dual_eta = jnp.tensordot(dtransport_dual, eta)\n", "print(f\"{dtransport_dual_eta=}\")\n", "\n", "eps = 1.0e-5\n", "dtransport_dual_eta_fd = (forward(model.fields[\"Q\"] + eps * eta) - forward(model.fields[\"Q\"] - eps * eta)) / (2 * eps)\n", "print(f\"{dtransport_dual_eta_fd=}\")\n", "assert abs(dtransport_dual_eta - dtransport_dual_eta_fd) < 1.0e-3 * abs(dtransport_dual_eta_fd)" ] }, { "cell_type": "markdown", "id": "4a976a0b-b65c-4275-a67b-8a018d607f50", "metadata": {}, "source": [ "## Custom terms\n", "\n", "We can also add further terms, and differentiate with respect to extra parameters. Here we differentiate with respect to a perturbation to the bottom drag parameter." 
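, "\n", "\n", "One way to read the construction in the next cell, assuming that the model's linear drag enters the vorticity tendency as $-r_0 \\zeta$ alongside the forcing $Q$ (here $r_0$ denotes the configured scalar `model.r`; both the notation and the assumed form of the drag term are introduced for this note): the `update` function replaces the forcing with $Q - (r - r_0) \\zeta$, so that the combined drag and forcing contribution becomes\n", "\n", "$$-r_0 \\zeta + Q - (r - r_0) \\zeta = -r \\zeta + Q.$$\n", "\n", "Under that assumption $r$ acts as a spatially varying drag coefficient. Since $r$ is initialized to the constant field $r_0$, the perturbation term vanishes at the point of differentiation, and the computed derivative is the sensitivity of the transport to local changes in the bottom drag."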
] }, { "cell_type": "code", "execution_count": null, "id": "f27744e7-7da1-4e21-bfcd-08d805f26e38", "metadata": {}, "outputs": [], "source": [ "def forward(Q, r):\n", "    def update(model, r):\n", "        # Fold the drag perturbation (r - model.r) * zeta into the forcing\n", "        model.fields[\"Q\"] = Q - (r - model.r) * model.fields[\"zeta\"]\n", "\n", "    model = CNAB2Solver(parameters)\n", "    model.steady_state_solve(r, update=update, tol=1.0e-6)\n", "    return D * model.fields[\"psi\"].max() / 1.0e6\n", "\n", "\n", "# Differentiate about the unperturbed, spatially constant, drag\n", "r = jnp.full_like(model.fields[\"Q\"], model.r)\n", "transport, vjp = jax.vjp(partial(forward, model.fields[\"Q\"]), r)\n", "print(f\"{transport=}\")\n", "\n", "dtransport_dual, = vjp(1.0)\n", "dtransport_primal = dtransport_dual / model.grid.W\n", "\n", "plt.figure()\n", "plt.contourf(model.grid.X, model.grid.Y, dtransport_primal, 64, cmap=\"magma\")\n", "plt.gca().set_aspect(1)\n", "plt.colorbar()" ] }, { "cell_type": "markdown", "id": "ae804da7-01b9-493e-8875-0e28cec56c44", "metadata": {}, "source": [ "We can again compare against the result from finite differencing."
] }, { "cell_type": "code", "execution_count": null, "id": "346cb5f3-85e0-4050-8361-4d3c2078d262", "metadata": {}, "outputs": [], "source": [ "eta = (model.r\n", " * jnp.sin(jnp.pi * (model.grid.X + model.grid.L_x) / (2 * model.grid.L_x))\n", " * jnp.sin(jnp.pi * (model.grid.Y + model.grid.L_y) / (2 * model.grid.L_y)))\n", "dtransport_dual_eta = jnp.tensordot(dtransport_dual, eta)\n", "print(f\"{dtransport_dual_eta=}\")\n", "\n", "eps = 1.0e-5\n", "dtransport_dual_eta_fd = (forward(model.fields[\"Q\"], r + eps * eta) - forward(model.fields[\"Q\"], r - eps * eta)) / (2 * eps)\n", "print(f\"{dtransport_dual_eta_fd=}\")\n", "assert abs(dtransport_dual_eta - dtransport_dual_eta_fd) < 1.0e-3 * abs(dtransport_dual_eta_fd)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }