diff --git a/README.md b/README.md index c59044cf..2596691e 100644 --- a/README.md +++ b/README.md @@ -1,68 +1,42 @@ -# GraphCast: Learning skillful medium-range global weather forecasting +# Google DeepMind GraphCast and GenCast -This package contains example code to run and train [GraphCast](https://www.science.org/doi/10.1126/science.adi2336). -It also provides three pretrained models: +This package contains example code to run and train the weather models used in the research papers [GraphCast](https://www.science.org/doi/10.1126/science.adi2336) and [GenCast](https://arxiv.org/abs/2312.15796). -1. `GraphCast`, the high-resolution model used in the GraphCast paper (0.25 degree -resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017, - -2. `GraphCast_small`, a smaller, low-resolution version of GraphCast (1 degree -resolution, 13 pressure levels, and a smaller mesh), trained on ERA5 data from -1979 to 2015, useful to run a model with lower memory and compute constraints, - -3. `GraphCast_operational`, a high-resolution model (0.25 degree resolution, 13 -pressure levels) pre-trained on ERA5 data from 1979 to 2017 and fine-tuned on -HRES data from 2016 to 2021. This model can be initialized from HRES data (does -not require precipitation inputs). - -The model weights, normalization statistics, and example inputs are available on [Google Cloud Bucket](https://console.cloud.google.com/storage/browser/dm_graphcast). +It also provides pretrained model weights, normalization statistics and example input data on [Google Cloud Bucket](https://console.cloud.google.com/storage/browser/dm_graphcast). Full model training requires downloading the [ERA5](https://www.ecmwf.int/en/forecasts/datasets/reanalysis-datasets/era5) dataset, available from [ECMWF](https://www.ecmwf.int/). This can best be -accessed as Zarr from [Weatherbench2's ERA5 data](https://weatherbench2.readthedocs.io/en/latest/data-guide.html#era5) (see the 6h downsampled versions). +accessed as Zarr from [Weatherbench2's ERA5 data](https://weatherbench2.readthedocs.io/en/latest/data-guide.html#era5). -## Overview of files +Data for operational fine-tuning can similarly be accessed at [Weatherbench2's HRES 0th frame data](https://weatherbench2.readthedocs.io/en/latest/data-guide.html#ifs-hres-t-0-analysis). -The best starting point is to open `graphcast_demo.ipynb` in [Colaboratory](https://colab.research.google.com/github/deepmind/graphcast/blob/master/graphcast_demo.ipynb), which gives an -example of loading data, generating random weights or load a pre-trained -snapshot, generating predictions, computing the loss and computing gradients. -The one-step implementation of GraphCast architecture, is provided in -`graphcast.py`. +These datasets may be governed by separate terms and conditions or license provisions. Your use of such third-party materials is subject to any such terms and you should check that you can comply with any applicable restrictions or terms and conditions before use. -### Brief description of library files: +## Overview of files common to models -* `autoregressive.py`: Wrapper used to run (and train) the one-step GraphCast +* `autoregressive.py`: Wrapper used to run (and train) the one-step predictions to produce a sequence of predictions by auto-regressively feeding the outputs back as inputs at each step, in JAX a differentiable way. -* `casting.py`: Wrapper used around GraphCast to make it work using - BFloat16 precision. * `checkpoint.py`: Utils to serialize and deserialize trees. * `data_utils.py`: Utils for data preprocessing. * `deep_typed_graph_net.py`: General purpose deep graph neural network (GNN) that operates on `TypedGraph`'s where both inputs and outputs are flat - vectors of features for each of the nodes and edges. `graphcast.py` uses - three of these for the Grid2Mesh GNN, the Multi-mesh GNN and the Mesh2Grid - GNN, respectively. -* `graphcast.py`: The main GraphCast model architecture for one-step of - predictions. + vectors of features for each of the nodes and edges. * `grid_mesh_connectivity.py`: Tools for converting between regular grids on a sphere and triangular meshes. * `icosahedral_mesh.py`: Definition of an icosahedral multi-mesh. * `losses.py`: Loss computations, including latitude-weighting. +* `mlp.py`: Utils for building MLPs with norm conditioning layers. * `model_utils.py`: Utilities to produce flat node and edge vector features from input grid data, and to manipulate the node output vectors back into a multilevel grid data. -* `normalization.py`: Wrapper for the one-step GraphCast used to normalize - inputs according to historical values, and targets according to historical - time differences. -* `predictor_base.py`: Defines the interface of the predictor, which GraphCast +* `normalization.py`: Wrapper used to normalize inputs according to historical + values, and targets according to historical time differences. +* `predictor_base.py`: Defines the interface of the predictor, which models and all of the wrappers implement. * `rollout.py`: Similar to `autoregressive.py` but used only at inference time using a python loop to produce longer, but non-differentiable trajectories. -* `solar_radiation.py`: Computes Top-Of-the-Atmosphere (TOA) incident solar - radiation compatible with ERA5. This is used as a forcing variable and thus - needs to be computed for target lead times in an operational setting. * `typed_graph.py`: Definition of `TypedGraph`'s. * `typed_graph_net.py`: Implementation of simple graph neural network building blocks defined over `TypedGraph`'s that can be combined to build @@ -71,11 +45,115 @@ The one-step implementation of GraphCast architecture, is provided in * `xarray_tree.py`: An implementation of tree.map_structure that works with `xarray`s. +## GenCast: Diffusion-based ensemble forecasting for medium-range weather + +This package provides four pretrained models: + +1. `GenCast 0p25deg <2019`, GenCast model at 0.25deg resolution with 13 +pressure levels and a 6 times refined icosahedral mesh. This model is trained on +ERA5 data from 1979 to 2018 (inclusive), and can be causally evaluated on 2019 +and later years. This model was described in the paper +`GenCast: Diffusion-based ensemble forecasting for medium-range weather` +(https://arxiv.org/abs/2312.15796) + +2. `GenCast 0p25deg Operational <2019`, GenCast model at 0.25deg resolution, with 13 pressure levels and a 6 +times refined icosahedral mesh. This model is trained on ERA5 data from +1979 to 2018, and fine-tuned on HRES-fc0 data from +2016 to 2021 and can be causally evaluated on 2022 and later years. +This model can make predictions in an operational setting (i.e., initialised +from HRES-fc0) + +3. `GenCast 1p0deg <2019`, GenCast model at 1deg resolution, with 13 pressure +levels and a 5 times refined icosahedral mesh. This model is +trained on ERA5 data from 1979 to 2018, and can be causally evaluated on 2019 and later years. +This model has a smaller memory footprint than the 0.25deg models + +4. `GenCast 1p0deg Mini <2019`, GenCast model at 1deg resolution, with 13 pressure levels and a +4 times refined icosahedral mesh. This model is trained on ERA5 data +from 1979 to 2018, and can be causally evaluated on 2019 and later years. +This model has the smallest memory footprint of those provided and has been +provided to enable low cost demonstrations (for example, it is runnable in a free Colab notebook). +While its performance is reasonable, it is not representative of the performance +of the GenCast models (1-3) above. For reference, a scorecard comparing its performance to ENS can be found in [docs/](https://github.com/google-deepmind/graphcast/docs/GenCast_1p0deg_Mini_ENS_scorecard.png). Note that in this scorecard, +GenCast Mini only uses 8 member ensembles (vs. ENS' 50) so we use the fair (unbiased) +CRPS to allow for fair comparison. + +The best starting point is to open `gencast_mini_demo.ipynb` in [Colaboratory](https://colab.research.google.com/github/deepmind/graphcast/blob/master/gencast_mini_demo.ipynb), which gives an +example of loading data, generating random weights or loading a `GenCast 1p0deg Mini <2019` +snapshot, generating predictions, computing the loss and computing gradients. +The one-step implementation of GenCast architecture is provided in +`gencast.py` and the relevant data, weights and statistics are in the `gencast/` +subdir of the Google Cloud Bucket. + +### Instructions for running GenCast on Google Cloud compute + +[cloud_vm_setup.md](https://github.com/google-deepmind/graphcast/blob/main/docs/cloud_vm_setup.md) +contains detailed instructions on launching a Google Cloud TPU VM. This provides +a means of running models (1-3) in the separate `gencast_demo_cloud_vm.ipynb` through [Colaboratory](https://colab.research.google.com/github/deepmind/graphcast/blob/master/gencast_demo_cloud_vm.ipynb). + +### Brief description of relevant library files + +* `denoiser.py`: The GenCast denoiser for one step predictions. +* `denoisers_base.py`: Defines the interface of the denoiser. +* `dpm_solver_plus_plus_2s.py`: Sampler using DPM-Solver++ 2S from [1]. +* `gencast.py`: Combines the GenCast model architecture, wrapped as a + denoiser, with a sampler to generate predictions. +* `nan_cleaning.py`: Wraps a predictor to allow it to work with data + cleaned of NaNs. Used to remove NaNs from sea surface temperature. +* `samplers_base.py`: Defines the interface of the sampler. +* `samplers_utils.py`: Utility methods for the sampler. +* `sparse_transformer.py`: General purpose sparse transformer that + operates on `TypedGraph`'s where both inputs and outputs are flat vectors of + features for each of the nodes and edges. `predictor.py` uses one of these + for the mesh GNN. +* `sparse_transformer_utils.py`: Utility methods for the sparse + transformer. +* `transformer.py`: Wraps the mesh transformer, swapping the leading + two axes of the nodes in the input graph. + +[1] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic + Models, https://arxiv.org/abs/2211.01095 + +## GraphCast: Learning skillful medium-range global weather forecasting + +This package provides three pretrained models: + +1. `GraphCast`, the high-resolution model used in the GraphCast paper (0.25 degree +resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017, + +2. `GraphCast_small`, a smaller, low-resolution version of GraphCast (1 degree +resolution, 13 pressure levels, and a smaller mesh), trained on ERA5 data from +1979 to 2015, useful to run a model with lower memory and compute constraints, + +3. `GraphCast_operational`, a high-resolution model (0.25 degree resolution, 13 +pressure levels) pre-trained on ERA5 data from 1979 to 2017 and fine-tuned on +HRES data from 2016 to 2021. This model can be initialized from HRES data (does +not require precipitation inputs). + +The best starting point is to open `graphcast_demo.ipynb` in [Colaboratory](https://colab.research.google.com/github/deepmind/graphcast/blob/master/graphcast_demo.ipynb), which gives an +example of loading data, generating random weights or load a pre-trained +snapshot, generating predictions, computing the loss and computing gradients. +The one-step implementation of GraphCast architecture, is provided in +`graphcast.py` and the relevant data, weights and statistics are in the `graphcast/` +subdir of the Google Cloud Bucket. + +WARNING: For backwards compatibility, we have also left GraphCast data in the top level of the bucket. These will eventually be deleted in favour of the `graphcast/` subdir. + +### Brief description of relevant library files: + +* `casting.py`: Wrapper used around GraphCast to make it work using + BFloat16 precision. +* `graphcast.py`: The main GraphCast model architecture for one-step of + predictions. +* `solar_radiation.py`: Computes Top-Of-the-Atmosphere (TOA) incident solar + radiation compatible with ERA5. This is used as a forcing variable and thus + needs to be computed for target lead times in an operational setting. -### Dependencies. +## Dependencies. [Chex](https://github.com/deepmind/chex), [Dask](https://github.com/dask/dask), +[Dinosaur](https://github.com/google-research/dinosaur), [Haiku](https://github.com/deepmind/dm-haiku), [JAX](https://github.com/google/jax), [JAXline](https://github.com/deepmind/jaxline), @@ -85,39 +163,29 @@ The one-step implementation of GraphCast architecture, is provided in [Python](https://www.python.org/), [SciPy](https://scipy.org/), [Tree](https://github.com/deepmind/tree), -[Trimesh](https://github.com/mikedh/trimesh) and -[XArray](https://github.com/pydata/xarray). +[Trimesh](https://github.com/mikedh/trimesh) +[XArray](https://github.com/pydata/xarray) and +[XArray-TensorStore](https://github.com/google/xarray-tensorstore). -### License and attribution +## License and Disclaimers -The Colab notebook and the associated code are licensed under the Apache -License, Version 2.0. You may obtain a copy of the License at: -https://www.apache.org/licenses/LICENSE-2.0. +The Colab notebooks and the associated code are licensed under the Apache License, Version 2.0. You may obtain a copy of the License at: https://www.apache.org/licenses/LICENSE-2.0. -The model weights are made available for use under the terms of the Creative -Commons Attribution-NonCommercial-ShareAlike 4.0 International -(CC BY-NC-SA 4.0). You may obtain a copy of the License at: -https://creativecommons.org/licenses/by-nc-sa/4.0/. +The model weights are made available for use under the terms of the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You may obtain a copy of the License at: https://creativecommons.org/licenses/by-nc-sa/4.0/. -The weights were trained on ECMWF's ERA5 and HRES data. The colab includes a few -examples of ERA5 and HRES data that can be used as inputs to the models. -ECMWF data product are subject to the following terms: +This is not an officially supported Google product. -1. Copyright statement: Copyright "© 2023 European Centre for Medium-Range Weather Forecasts (ECMWF)". -2. Source www.ecmwf.int -3. Licence Statement: ECMWF data is published under a Creative Commons Attribution 4.0 International (CC BY 4.0). https://creativecommons.org/licenses/by/4.0/ -4. Disclaimer: ECMWF does not accept any liability whatsoever for any error or omission in the data, their availability, or for any loss or damage arising from their use. +Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY-NC-SA 4.0 licenses are distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses. -### Disclaimer +GenCast and GraphCast are part of an experimental research project. You are solely responsible for determining the appropriateness of using or distributing GenCast, GraphCast or any outputs generated and assume all risks associated with your use or distribution of GenCast, GraphCast and outputs and your exercise of rights and permissions granted by Google to you under the relevant License. Use discretion before relying on, publishing, downloading or otherwise using GenCast, GraphCast or any outputs generated. GenCast, GraphCast or any outputs generated (i) are not based on data published by; (ii) have not been produced in collaboration with; and (iii) have not been endorsed by any government meteorological agency or department and in no way replaces official alerts, warnings or notices published by such agencies. -This is not an officially supported Google product. +Copyright 2024 DeepMind Technologies Limited. -Copyright 2023 DeepMind Technologies Limited. -### Citation +## Citations -If you use this work, consider citing our paper ([blog post](https://deepmind.google/discover/blog/graphcast-ai-model-for-faster-and-more-accurate-global-weather-forecasting/), [Science](https://www.science.org/doi/10.1126/science.adi2336), [arXiv](https://arxiv.org/abs/2212.12794)): +If you use this work, consider citing our papers ([blog post](https://deepmind.google/discover/blog/graphcast-ai-model-for-faster-and-more-accurate-global-weather-forecasting/), [Science](https://www.science.org/doi/10.1126/science.adi2336), [arXiv](https://arxiv.org/abs/2212.12794), [arxiv GenCast](https://arxiv.org/abs/2312.15796)): ```latex @article{lam2023learning, @@ -131,3 +199,31 @@ If you use this work, consider citing our paper ([blog post](https://deepmind.go publisher={American Association for the Advancement of Science} } ``` + + +```latex +@article{price2023gencast, + title={GenCast: Diffusion-based ensemble forecasting for medium-range weather}, + author={Price, Ilan and Sanchez-Gonzalez, Alvaro and Alet, Ferran and Andersson, Tom R and El-Kadi, Andrew and Masters, Dominic and Ewalds, Timo and Stott, Jacklynn and Mohamed, Shakir and Battaglia, Peter and Lam, Remi and Willson, Matthew}, + journal={arXiv preprint arXiv:2312.15796}, + year={2023} +} +``` + +## Acknowledgements + +The (i) GenCast and GraphCast communicate with and/or reference with the following separate libraries and packages and the colab notebooks include a few examples of ECMWF’s ERA5 and HRES data that can be used as input to the models. +Data and products of the European Centre for Medium-range Weather Forecasts (ECMWF), as modified by Google. +Modified Copernicus Climate Change Service information 2023. Neither the European Commission nor ECMWF is responsible for any use that may be made of the Copernicus information or data it contains. +ECMWF HRES datasets +Copyright statement: Copyright "© 2023 European Centre for Medium-Range Weather Forecasts (ECMWF)". +Source: www.ecmwf.int +License Statement: ECMWF open data is published under a Creative Commons Attribution 4.0 International (CC BY 4.0). https://creativecommons.org/licenses/by/4.0/ +Disclaimer: ECMWF does not accept any liability whatsoever for any error or omission in the data, their availability, or for any loss or damage arising from their use. + +Use of the third-party materials referred to above may be governed by separate terms and conditions or license provisions. Your use of the third-party materials is subject to any such terms and you should check that you can comply with any applicable restrictions or terms and conditions before use. + + +## Contact + +For feedback and questions, contact us at gencast@google.com. diff --git a/gencast_demo_cloud_vm.ipynb b/gencast_demo_cloud_vm.ipynb new file mode 100644 index 00000000..122baa79 --- /dev/null +++ b/gencast_demo_cloud_vm.ipynb @@ -0,0 +1,690 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "9KHpdDlzYNDI" + }, + "source": [ + "\u003e Copyright 2024 DeepMind Technologies Limited.\n", + "\u003e\n", + "\u003e Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "\u003e you may not use this file except in compliance with the License.\n", + "\u003e You may obtain a copy of the License at\n", + "\u003e\n", + "\u003e http://www.apache.org/licenses/LICENSE-2.0\n", + "\u003e\n", + "\u003e Unless required by applicable law or agreed to in writing, software\n", + "\u003e distributed under the License is distributed on an \"AS-IS\" BASIS,\n", + "\u003e WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "\u003e See the License for the specific language governing permissions and\n", + "\u003e limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GnMHywhUhlgJ" + }, + "source": [ + "# GenCast Demo\n", + "\n", + "This notebook demonstrates running all GenCast models provided in the repository:\n", + "\n", + "1. `GenCast 0p25deg \u003c2019`\n", + "2. `GenCast 0p25deg Operational \u003c2019`\n", + "3. `GenCast 1p0deg \u003c2019`\n", + "4. `GenCast 1p0deg Mini \u003c2019`\n", + "\n", + "While `GenCast 1p0deg Mini \u003c2019` is runnable with the freely provided TPUv2-8 configuration in Colab, the other models require compute that can be accessed via Google Cloud.\n", + "\n", + "See [cloud_vm_setup.md](https://github.com/google-deepmind/graphcast/blob/main/docs/cloud_vm_setup.md) for detailed instructions on launching a Google Cloud TPU VM and connecting to it via this notebook. This document also provides some more information on the memory requirements of the models.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yMbbXFl4msJw" + }, + "source": [ + "# Installation and Initialization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "233zaiZYqCnc" + }, + "outputs": [], + "source": [ + "# @title Pip install repo and dependencies\n", + "\n", + "%pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Z_j8ej4Pyg1L" + }, + "outputs": [], + "source": [ + "# @title Imports\n", + "\n", + "import dataclasses\n", + "import datetime\n", + "import math\n", + "from typing import Optional\n", + "import haiku as hk\n", + "from IPython.display import HTML\n", + "from IPython import display\n", + "import ipywidgets as widgets\n", + "import jax\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib import animation\n", + "import numpy as np\n", + "import xarray\n", + "\n", + "from graphcast import rollout\n", + "from graphcast import xarray_jax\n", + "from graphcast import normalization\n", + "from graphcast import checkpoint\n", + "from graphcast import data_utils\n", + "from graphcast import xarray_tree\n", + "from graphcast import gencast\n", + "from graphcast import denoiser\n", + "from graphcast import nan_cleaning\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OzYgQ0QN-kn8" + }, + "outputs": [], + "source": [ + "# @title Plotting functions\n", + "\n", + "def select(\n", + " data: xarray.Dataset,\n", + " variable: str,\n", + " level: Optional[int] = None,\n", + " max_steps: Optional[int] = None\n", + " ) -\u003e xarray.Dataset:\n", + " data = data[variable]\n", + " if \"batch\" in data.dims:\n", + " data = data.isel(batch=0)\n", + " if max_steps is not None and \"time\" in data.sizes and max_steps \u003c data.sizes[\"time\"]:\n", + " data = data.isel(time=range(0, max_steps))\n", + " if level is not None and \"level\" in data.coords:\n", + " data = data.sel(level=level)\n", + " return data\n", + "\n", + "def scale(\n", + " data: xarray.Dataset,\n", + " center: Optional[float] = None,\n", + " robust: bool = False,\n", + " ) -\u003e tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:\n", + " vmin = np.nanpercentile(data, (2 if robust else 0))\n", + " vmax = np.nanpercentile(data, (98 if robust else 100))\n", + " if center is not None:\n", + " diff = max(vmax - center, center - vmin)\n", + " vmin = center - diff\n", + " vmax = center + diff\n", + " return (data, matplotlib.colors.Normalize(vmin, vmax),\n", + " (\"RdBu_r\" if center is not None else \"viridis\"))\n", + "\n", + "def plot_data(\n", + " data: dict[str, xarray.Dataset],\n", + " fig_title: str,\n", + " plot_size: float = 5,\n", + " robust: bool = False,\n", + " cols: int = 4\n", + " ) -\u003e tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:\n", + "\n", + " first_data = next(iter(data.values()))[0]\n", + " max_steps = first_data.sizes.get(\"time\", 1)\n", + " assert all(max_steps == d.sizes.get(\"time\", 1) for d, _, _ in data.values())\n", + "\n", + " cols = min(cols, len(data))\n", + " rows = math.ceil(len(data) / cols)\n", + " figure = plt.figure(figsize=(plot_size * 2 * cols,\n", + " plot_size * rows))\n", + " figure.suptitle(fig_title, fontsize=16)\n", + " figure.subplots_adjust(wspace=0, hspace=0)\n", + " figure.tight_layout()\n", + "\n", + " images = []\n", + " for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):\n", + " ax = figure.add_subplot(rows, cols, i+1)\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])\n", + " ax.set_title(title)\n", + " im = ax.imshow(\n", + " plot_data.isel(time=0, missing_dims=\"ignore\"), norm=norm,\n", + " origin=\"lower\", cmap=cmap)\n", + " plt.colorbar(\n", + " mappable=im,\n", + " ax=ax,\n", + " orientation=\"vertical\",\n", + " pad=0.02,\n", + " aspect=16,\n", + " shrink=0.75,\n", + " cmap=cmap,\n", + " extend=(\"both\" if robust else \"neither\"))\n", + " images.append(im)\n", + "\n", + " def update(frame):\n", + " if \"time\" in first_data.dims:\n", + " td = datetime.timedelta(microseconds=first_data[\"time\"][frame].item() / 1000)\n", + " figure.suptitle(f\"{fig_title}, {td}\", fontsize=16)\n", + " else:\n", + " figure.suptitle(fig_title, fontsize=16)\n", + " for im, (plot_data, norm, cmap) in zip(images, data.values()):\n", + " im.set_data(plot_data.isel(time=frame, missing_dims=\"ignore\"))\n", + "\n", + " ani = animation.FuncAnimation(\n", + " fig=figure, func=update, frames=max_steps, interval=250)\n", + " plt.close(figure.number)\n", + " return HTML(ani.to_jshtml())\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rQWk0RRuCjDN" + }, + "source": [ + "# Load the Data and initialize the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jTRvMoexzjYm" + }, + "outputs": [], + "source": [ + "# @title Set paths\n", + "\n", + "MODEL_PATH = \"\" # E.g. \"GenCast 1p0deg _2019.npz\"\n", + "DATA_PATH = \"\" # E.g. \"source-era5_date-2019-03-29_res-1.0_levels-13_steps-04.nc\"\n", + "STATS_DIR = \"\" # E.g. \"stats/\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cgfYjE1YhALA" + }, + "outputs": [], + "source": [ + "# @title Load the model\n", + "\n", + "with open(MODEL_PATH, \"rb\") as f:\n", + " ckpt = checkpoint.load(f, gencast.CheckPoint)\n", + "params = ckpt.params\n", + "state = {}\n", + "\n", + "task_config = ckpt.task_config\n", + "sampler_config = ckpt.sampler_config\n", + "noise_config = ckpt.noise_config\n", + "noise_encoder_config = ckpt.noise_encoder_config\n", + "denoiser_architecture_config = ckpt.denoiser_architecture_config\n", + "print(\"Model description:\\n\", ckpt.description, \"\\n\")\n", + "print(\"Model license:\\n\", ckpt.license, \"\\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z2AqgxUgiALy" + }, + "source": [ + "## Load the example data\n", + "\n", + "Example ERA5 datasets are available at 0.25 degree and 1 degree resolution.\n", + "\n", + "Example HRES-fc0 datasets are available at 0.25 degree resolution.\n", + "\n", + "Some transformations were done from the base datasets:\n", + "- We accumulated precipitation over 12 hours instead of the default 1 hour.\n", + "- For HRES-fc0 sea surface temperature, we assigned NaNs to grid cells in which sea surface temperature was NaN in the ERA5 dataset (this remains fixed at all times).\n", + "\n", + "The data resolution must match the model that is loaded.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5XGzOww0y_BC" + }, + "outputs": [], + "source": [ + "# @title Check example dataset matches model\n", + "\n", + "def parse_file_parts(file_name):\n", + " return dict(part.split(\"-\", 1) for part in file_name.split(\"_\"))\n", + "\n", + "def data_valid_for_model(file_name: str, params_file_name: str):\n", + " \"\"\"Check data type and resolution matches.\"\"\"\n", + " data_file_parts = parse_file_parts(file_name.removesuffix(\".nc\"))\n", + " res_matches = data_file_parts[\"res\"].replace(\".\", \"p\") in params_file_name.lower()\n", + " source_matches = \"Operational\" in params_file_name\n", + " if data_file_parts[\"source\"] == \"era5\":\n", + " source_matches = not source_matches\n", + " return res_matches and source_matches\n", + "\n", + "assert data_valid_for_model(DATA_PATH, MODEL_PATH)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Yz-ekISoJxeZ" + }, + "outputs": [], + "source": [ + "# @title Load weather data\n", + "\n", + "with open(DATA_PATH, \"rb\") as f:\n", + " example_batch = xarray.load_dataset(f).compute()\n", + "\n", + "assert example_batch.dims[\"time\"] \u003e= 3 # 2 for input, \u003e=1 for targets\n", + "\n", + "print(\", \".join([f\"{k}: {v}\" for k, v in parse_file_parts(DATA_PATH.removesuffix(\".nc\")).items()]))\n", + "\n", + "example_batch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iqzXVpn9_b15" + }, + "outputs": [], + "source": [ + "# @title Plot example data\n", + "\n", + "plot_size = 7\n", + "variable = \"geopotential\"\n", + "level = 500\n", + "steps = example_batch.dims[\"time\"]\n", + "\n", + "\n", + "data = {\n", + " \" \": scale(select(example_batch, variable, level, steps), robust=True),\n", + "}\n", + "fig_title = variable\n", + "if \"level\" in example_batch[variable].coords:\n", + " fig_title += f\" at {level} hPa\"\n", + "\n", + "plot_data(data, fig_title, plot_size, robust=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "njD4jsPTPKvJ" + }, + "outputs": [], + "source": [ + "# @title Extract training and eval data\n", + "\n", + "train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(\n", + " example_batch, target_lead_times=slice(\"12h\", \"12h\"), # Only 1AR training.\n", + " **dataclasses.asdict(task_config))\n", + "\n", + "eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(\n", + " example_batch, target_lead_times=slice(\"12h\", f\"{(example_batch.dims['time']-2)*12}h\"), # All but 2 input frames.\n", + " **dataclasses.asdict(task_config))\n", + "\n", + "print(\"All Examples: \", example_batch.dims.mapping)\n", + "print(\"Train Inputs: \", train_inputs.dims.mapping)\n", + "print(\"Train Targets: \", train_targets.dims.mapping)\n", + "print(\"Train Forcings:\", train_forcings.dims.mapping)\n", + "print(\"Eval Inputs: \", eval_inputs.dims.mapping)\n", + "print(\"Eval Targets: \", eval_targets.dims.mapping)\n", + "print(\"Eval Forcings: \", eval_forcings.dims.mapping)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-DJzie5me2-H" + }, + "outputs": [], + "source": [ + "# @title Load normalization data\n", + "\n", + "with open(STATS_DIR +\"diffs_stddev_by_level.nc\", \"rb\") as f:\n", + " diffs_stddev_by_level = xarray.load_dataset(f).compute()\n", + "with open(STATS_DIR +\"mean_by_level.nc\", \"rb\") as f:\n", + " mean_by_level = xarray.load_dataset(f).compute()\n", + "with open(STATS_DIR +\"stddev_by_level.nc\", \"rb\") as f:\n", + " stddev_by_level = xarray.load_dataset(f).compute()\n", + "with open(STATS_DIR +\"min_by_level.nc\", \"rb\") as f:\n", + " min_by_level = xarray.load_dataset(f).compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ke2zQyuT_sMA" + }, + "outputs": [], + "source": [ + "# @title Build jitted functions, and possibly initialize random weights\n", + "\n", + "\n", + "def construct_wrapped_gencast():\n", + " \"\"\"Constructs and wraps the GenCast Predictor.\"\"\"\n", + " predictor = gencast.GenCast(\n", + " sampler_config=sampler_config,\n", + " task_config=task_config,\n", + " denoiser_architecture_config=denoiser_architecture_config,\n", + " noise_config=noise_config,\n", + " noise_encoder_config=noise_encoder_config,\n", + " )\n", + "\n", + " predictor = normalization.InputsAndResiduals(\n", + " predictor,\n", + " diffs_stddev_by_level=diffs_stddev_by_level,\n", + " mean_by_level=mean_by_level,\n", + " stddev_by_level=stddev_by_level,\n", + " )\n", + "\n", + " predictor = nan_cleaning.NaNCleaner(\n", + " predictor=predictor,\n", + " reintroduce_nans=True,\n", + " fill_value=min_by_level,\n", + " var_to_clean='sea_surface_temperature',\n", + " )\n", + "\n", + " return predictor\n", + "\n", + "\n", + "@hk.transform_with_state\n", + "def run_forward(inputs, targets_template, forcings):\n", + " predictor = construct_wrapped_gencast()\n", + " return predictor(inputs, targets_template=targets_template, forcings=forcings)\n", + "\n", + "\n", + "@hk.transform_with_state\n", + "def loss_fn(inputs, targets, forcings):\n", + " predictor = construct_wrapped_gencast()\n", + " loss, diagnostics = predictor.loss(inputs, targets, forcings)\n", + " return xarray_tree.map_structure(\n", + " lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),\n", + " (loss, diagnostics),\n", + " )\n", + "\n", + "\n", + "def grads_fn(params, state, inputs, targets, forcings):\n", + " def _aux(params, state, i, t, f):\n", + " (loss, diagnostics), next_state = loss_fn.apply(\n", + " params, state, jax.random.PRNGKey(0), i, t, f\n", + " )\n", + " return loss, (diagnostics, next_state)\n", + "\n", + " (loss, (diagnostics, next_state)), grads = jax.value_and_grad(\n", + " _aux, has_aux=True\n", + " )(params, state, inputs, targets, forcings)\n", + " return loss, diagnostics, next_state, grads\n", + "\n", + "\n", + "if params is None:\n", + " init_jitted = jax.jit(loss_fn.init)\n", + " params, state = init_jitted(\n", + " rng=jax.random.PRNGKey(0),\n", + " inputs=train_inputs,\n", + " targets=train_targets,\n", + " forcings=train_forcings,\n", + " )\n", + "\n", + "\n", + "loss_fn_jitted = jax.jit(\n", + " lambda rng, i, t, f: loss_fn.apply(params, state, rng, i, t, f)[0]\n", + ")\n", + "grads_fn_jitted = jax.jit(grads_fn)\n", + "run_forward_jitted = jax.jit(\n", + " lambda rng, i, t, f: run_forward.apply(params, state, rng, i, t, f)[0]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VBNutliiCyqA" + }, + "source": [ + "# Run the model\n", + "\n", + "The `chunked_prediction_generator_multiple_runs` iterates over forecast steps, where the 1 step forecast is jitted and samples are pmapped across the chips.\n", + "This allows us to make efficient use of all devices and parallelise generating an ensemble across them. We then combine the chunks at the end to form our final forecast.\n", + "\n", + "Note that the cell will take longer than the standard inference time to run when executed for the first time, as this will include code compilation time. This cost does not increase with the number of devices, it is a fixed-cost one time operation whose result can be reused across any number of devices." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "t-6ik5tU1yr7" + }, + "outputs": [], + "source": [ + "# The number of ensemble members should be a multiple of the number of devices.\n", + "len(jax.local_devices())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7obeY9i9oTtD" + }, + "outputs": [], + "source": [ + "# @title Autoregressive rollout (loop in python)\n", + "\n", + "print(\"Inputs: \", eval_inputs.dims.mapping)\n", + "print(\"Targets: \", eval_targets.dims.mapping)\n", + "print(\"Forcings:\", eval_forcings.dims.mapping)\n", + "\n", + "num_ensemble_members = 8 # @param int\n", + "rng = jax.random.PRNGKey(0)\n", + "# We fold-in the ensemble member, this way the first N members should always\n", + "# match across different runs which use take the same inputs\n", + "# regardless of total ensemble size.\n", + "rngs = np.stack(\n", + " [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)\n", + "\n", + "chunks = []\n", + "for chunk in rollout.chunked_prediction_generator_multiple_runs(\n", + " xarray_jax.pmap(run_forward_jitted, dim=\"sample\"),\n", + " rngs=rngs,\n", + " inputs=eval_inputs,\n", + " targets_template=eval_targets * np.nan,\n", + " forcings=eval_forcings,\n", + " num_steps_per_chunk = 1,\n", + " num_samples = num_ensemble_members,\n", + " pmap_devices=jax.local_devices()\n", + " ):\n", + " chunks.append(chunk)\n", + "predictions = xarray.combine_by_coords(chunks)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wn7dccXO5R7C" + }, + "outputs": [], + "source": [ + "# @title Plot prediction samples and diffs\n", + "\n", + "plot_size = 5\n", + "variable = \"2m_temperature\"\n", + "level = None\n", + "steps = predictions.dims[\"time\"]\n", + "\n", + "fig_title = variable\n", + "if \"level\" in predictions[variable].coords:\n", + " fig_title += f\" at {level} hPa\"\n", + "\n", + "for sample_idx in range(num_ensemble_members):\n", + " data = {\n", + " \"Targets\": scale(select(eval_targets, variable, level, steps), robust=True),\n", + " \"Predictions\": scale(select(predictions.isel(sample=sample_idx), variable, level, steps), robust=True),\n", + " \"Diff\": scale((select(eval_targets, variable, level, steps) -\n", + " select(predictions.isel(sample=sample_idx), variable, level, steps)),\n", + " robust=True, center=0),\n", + " }\n", + " display.display(plot_data(data, fig_title + f\", Sample {sample_idx}\", plot_size, robust=True))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "X3m9lW5fN4oL" + }, + "outputs": [], + "source": [ + "# @title Plot ensemble mean and CRPS\n", + "\n", + "def crps(targets, predictions, bias_corrected = True):\n", + " if predictions.sizes.get(\"sample\", 1) \u003c 2:\n", + " raise ValueError(\n", + " \"predictions must have dim 'sample' with size at least 2.\")\n", + " sum_dims = [\"sample\", \"sample2\"]\n", + " preds2 = predictions.rename({\"sample\": \"sample2\"})\n", + " num_samps = predictions.sizes[\"sample\"]\n", + " num_samps2 = (num_samps - 1) if bias_corrected else num_samps\n", + " mean_abs_diff = np.abs(\n", + " predictions - preds2).sum(\n", + " dim=sum_dims, skipna=False) / (num_samps * num_samps2)\n", + " mean_abs_err = np.abs(targets - predictions).sum(dim=\"sample\", skipna=False) / num_samps\n", + " return mean_abs_err - 0.5 * mean_abs_diff\n", + "\n", + "\n", + "plot_size = 5\n", + "variable = \"2m_temperature\"\n", + "level = None\n", + "steps = predictions.dims[\"time\"]\n", + "\n", + "fig_title = variable\n", + "if \"level\" in predictions[variable].coords:\n", + " fig_title += f\" at {level} hPa\"\n", + "\n", + "data = {\n", + " \"Targets\": scale(select(eval_targets, variable, level, steps), robust=True),\n", + " \"Ensemble Mean\": scale(select(predictions.mean(dim=[\"sample\"]), variable, level, steps), robust=True),\n", + " \"Ensemble CRPS\": scale(crps((select(eval_targets, variable, level, steps)),\n", + " select(predictions, variable, level, steps)),\n", + " robust=True, center=0),\n", + "}\n", + "display.display(plot_data(data, fig_title, plot_size, robust=True))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZLI0DhWog3Rg" + }, + "outputs": [], + "source": [ + "# @title (Optional) Save the predictions.\n", + "predictions.to_zarr(\"predictions.zarr\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O6ZhRFBPD0kq" + }, + "source": [ + "# Train the model\n", + "\n", + "The following operations requires larger amounts of memory than running inference.\n", + "\n", + "The first time executing the cell takes more time, as it includes the time to jit the function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Nv-u3dAP7IRZ" + }, + "outputs": [], + "source": [ + "# @title Loss computation\n", + "loss, diagnostics = loss_fn_jitted(\n", + " jax.random.PRNGKey(0),\n", + " train_inputs,\n", + " train_targets,\n", + " train_forcings)\n", + "print(\"Loss:\", float(loss))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mBNFq1IGZNLz" + }, + "outputs": [], + "source": [ + "# @title Gradient computation\n", + "loss, diagnostics, next_state, grads = grads_fn_jitted(\n", + " params=params,\n", + " state=state,\n", + " inputs=train_inputs,\n", + " targets=train_targets,\n", + " forcings=train_forcings)\n", + "mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])\n", + "print(f\"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}\")" + ] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "//gdm/weather/colab_base:weather_notebook", + "kind": "private" + }, + "name": "GenCast Cloud VM", + "private_outputs": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/gencast_mini_demo.ipynb b/gencast_mini_demo.ipynb new file mode 100644 index 00000000..4d0959ac --- /dev/null +++ b/gencast_mini_demo.ipynb @@ -0,0 +1,875 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "9KHpdDlzYNDI" + }, + "source": [ + "\u003e Copyright 2024 DeepMind Technologies Limited.\n", + "\u003e\n", + "\u003e Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "\u003e you may not use this file except in compliance with the License.\n", + "\u003e You may obtain a copy of the License at\n", + "\u003e\n", + "\u003e http://www.apache.org/licenses/LICENSE-2.0\n", + "\u003e\n", + "\u003e Unless required by applicable law or agreed to in writing, software\n", + "\u003e distributed under the License is distributed on an \"AS-IS\" BASIS,\n", + "\u003e WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "\u003e See the License for the specific language governing permissions and\n", + "\u003e limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GnMHywhUhlgJ" + }, + "source": [ + "# GenCast Mini Demo\n", + "\n", + "This notebook demonstrates running `GenCast 1p0deg Mini \u003c2019`.\n", + "\n", + "`GenCast 1p0deg Mini \u003c2019` is a GenCast model at 1deg resolution, with 13 pressure levels and a 4 times refined icosahedral mesh. It is trained on ERA5 data from 1979 to 2018, and can be causally evaluated on 2019 and later years.\n", + "\n", + "While other GenCast models are [available](https://github.com/google-deepmind/graphcast/blob/main/README.md), this model has the smallest memory footprint of those provided and is the only one runnable with the freely provided TPUv2-8 configuration in Colab. You can select this configuration in `Runtime\u003eChange Runtime Type`.\n", + "\n", + "**N.B.** The performance of `GenCast 1p0deg Mini \u003c2019` is reasonable but is not representative of the performance of the other GenCast models described in the [README](https://github.com/google-deepmind/graphcast/blob/main/README.md).\n", + "\n", + "To run the other models using Google Cloud Compute, refer to [gencast_demo_cloud_vm.ipynb](https://colab.research.google.com/github/deepmind/graphcast/blob/master/gencast_demo_cloud_vm.ipynb)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yMbbXFl4msJw" + }, + "source": [ + "# Installation and Initialization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-gAH79SRwp9G" + }, + "outputs": [], + "source": [ + "# @title Upgrade packages (kernel needs to be restarted after running this cell).\n", + "\n", + "%pip install -U importlib_metadata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "233zaiZYqCnc" + }, + "outputs": [], + "source": [ + "# @title Pip install repo and dependencies\n", + "\n", + "%pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Z_j8ej4Pyg1L" + }, + "outputs": [], + "source": [ + "# @title Imports\n", + "\n", + "import dataclasses\n", + "import datetime\n", + "import math\n", + "from google.cloud import storage\n", + "from typing import Optional\n", + "import haiku as hk\n", + "from IPython.display import HTML\n", + "from IPython import display\n", + "import ipywidgets as widgets\n", + "import jax\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib import animation\n", + "import numpy as np\n", + "import xarray\n", + "\n", + "from graphcast import rollout\n", + "from graphcast import xarray_jax\n", + "from graphcast import normalization\n", + "from graphcast import checkpoint\n", + "from graphcast import data_utils\n", + "from graphcast import xarray_tree\n", + "from graphcast import gencast\n", + "from graphcast import denoiser\n", + "from graphcast import nan_cleaning\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "OzYgQ0QN-kn8" + }, + "outputs": [], + "source": [ + "# @title Plotting functions\n", + "\n", + "def select(\n", + " data: xarray.Dataset,\n", + " variable: str,\n", + " level: Optional[int] = None,\n", + " max_steps: Optional[int] = None\n", + " ) -\u003e xarray.Dataset:\n", + " data = data[variable]\n", + " if \"batch\" in data.dims:\n", + " data = data.isel(batch=0)\n", + " if max_steps is not None and \"time\" in data.sizes and max_steps \u003c data.sizes[\"time\"]:\n", + " data = data.isel(time=range(0, max_steps))\n", + " if level is not None and \"level\" in data.coords:\n", + " data = data.sel(level=level)\n", + " return data\n", + "\n", + "def scale(\n", + " data: xarray.Dataset,\n", + " center: Optional[float] = None,\n", + " robust: bool = False,\n", + " ) -\u003e tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:\n", + " vmin = np.nanpercentile(data, (2 if robust else 0))\n", + " vmax = np.nanpercentile(data, (98 if robust else 100))\n", + " if center is not None:\n", + " diff = max(vmax - center, center - vmin)\n", + " vmin = center - diff\n", + " vmax = center + diff\n", + " return (data, matplotlib.colors.Normalize(vmin, vmax),\n", + " (\"RdBu_r\" if center is not None else \"viridis\"))\n", + "\n", + "def plot_data(\n", + " data: dict[str, xarray.Dataset],\n", + " fig_title: str,\n", + " plot_size: float = 5,\n", + " robust: bool = False,\n", + " cols: int = 4\n", + " ) -\u003e tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:\n", + "\n", + " first_data = next(iter(data.values()))[0]\n", + " max_steps = first_data.sizes.get(\"time\", 1)\n", + " assert all(max_steps == d.sizes.get(\"time\", 1) for d, _, _ in data.values())\n", + "\n", + " cols = min(cols, len(data))\n", + " rows = math.ceil(len(data) / cols)\n", + " figure = plt.figure(figsize=(plot_size * 2 * cols,\n", + " plot_size * rows))\n", + " figure.suptitle(fig_title, fontsize=16)\n", + " figure.subplots_adjust(wspace=0, hspace=0)\n", + " figure.tight_layout()\n", + "\n", + " images = []\n", + " for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):\n", + " ax = figure.add_subplot(rows, cols, i+1)\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])\n", + " ax.set_title(title)\n", + " im = ax.imshow(\n", + " plot_data.isel(time=0, missing_dims=\"ignore\"), norm=norm,\n", + " origin=\"lower\", cmap=cmap)\n", + " plt.colorbar(\n", + " mappable=im,\n", + " ax=ax,\n", + " orientation=\"vertical\",\n", + " pad=0.02,\n", + " aspect=16,\n", + " shrink=0.75,\n", + " cmap=cmap,\n", + " extend=(\"both\" if robust else \"neither\"))\n", + " images.append(im)\n", + "\n", + " def update(frame):\n", + " if \"time\" in first_data.dims:\n", + " td = datetime.timedelta(microseconds=first_data[\"time\"][frame].item() / 1000)\n", + " figure.suptitle(f\"{fig_title}, {td}\", fontsize=16)\n", + " else:\n", + " figure.suptitle(fig_title, fontsize=16)\n", + " for im, (plot_data, norm, cmap) in zip(images, data.values()):\n", + " im.set_data(plot_data.isel(time=frame, missing_dims=\"ignore\"))\n", + "\n", + " ani = animation.FuncAnimation(\n", + " fig=figure, func=update, frames=max_steps, interval=250)\n", + " plt.close(figure.number)\n", + " return HTML(ani.to_jshtml())\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rQWk0RRuCjDN" + }, + "source": [ + "# Load the Data and initialize the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ttMHeiGCppjB" + }, + "outputs": [], + "source": [ + "# @title Authenticate with Google Cloud Storage\n", + "\n", + "# Gives you an authenticated client, in case you want to use a private bucket.\n", + "gcs_client = storage.Client.create_anonymous_client()\n", + "gcs_bucket = gcs_client.get_bucket(\"dm_graphcast\")\n", + "dir_prefix = \"gencast/\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ty5WSDRjhDBF" + }, + "source": [ + "## Load the model params\n", + "\n", + "Choose one of the two ways of getting model params:\n", + "- **random**: You'll get random predictions, but you can change the model architecture, which may run faster or fit on your device.\n", + "- **checkpoint**: You'll get sensible predictions, but are limited to the model architecture that it was trained with, which may not fit on your device.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PMoFXuZXs-xg" + }, + "outputs": [], + "source": [ + "# @title Choose the model\n", + "\n", + "params_file_options = [\n", + " name for blob in gcs_bucket.list_blobs(prefix=(dir_prefix+\"params/\"))\n", + " if (name := blob.name.removeprefix(dir_prefix+\"params/\"))] # Drop empty string.\n", + "\n", + "latent_value_options = [int(2**i) for i in range(4, 10)]\n", + "random_latent_size = widgets.Dropdown(\n", + " options=latent_value_options, value=512,description=\"Latent size:\")\n", + "random_attention_type = widgets.Dropdown(\n", + " options=[\"splash_mha\", \"triblockdiag_mha\", \"mha\"], value=\"splash_mha\", description=\"Attention:\")\n", + "random_mesh_size = widgets.IntSlider(\n", + " value=4, min=4, max=6, description=\"Mesh size:\")\n", + "random_num_heads = widgets.Dropdown(\n", + " options=[int(2**i) for i in range(0, 3)], value=4,description=\"Num heads:\")\n", + "random_attention_k_hop = widgets.Dropdown(\n", + " options=[int(2**i) for i in range(2, 5)], value=16,description=\"Attn k hop:\")\n", + "\n", + "def update_latent_options(*args):\n", + " def _latent_valid_for_attn(attn, latent, heads):\n", + " head_dim, rem = divmod(latent, heads)\n", + " if rem != 0:\n", + " return False\n", + " # Required for splash attn.\n", + " if head_dim % 128 != 0:\n", + " return attn != \"splash_mha\"\n", + " return True\n", + " attn = random_attention_type.value\n", + " heads = random_num_heads.value\n", + " random_latent_size.options = [\n", + " latent for latent in latent_value_options\n", + " if _latent_valid_for_attn(attn, latent, heads)]\n", + "\n", + "# Observe changes to only allow for valid combinations.\n", + "random_attention_type.observe(update_latent_options, \"value\")\n", + "random_latent_size.observe(update_latent_options, \"value\")\n", + "random_num_heads.observe(update_latent_options, \"value\")\n", + "\n", + "params_file = widgets.Dropdown(\n", + " options=[f for f in params_file_options if \"Mini\" in f],\n", + " description=\"Params file:\",\n", + " layout={\"width\": \"max-content\"})\n", + "\n", + "source_tab = widgets.Tab([\n", + " widgets.VBox([\n", + " random_attention_type,\n", + " random_mesh_size,\n", + " random_num_heads,\n", + " random_latent_size,\n", + " random_attention_k_hop\n", + " ]),\n", + " params_file,\n", + "])\n", + "source_tab.set_title(0, \"Random\")\n", + "source_tab.set_title(1, \"Checkpoint\")\n", + "widgets.VBox([\n", + " source_tab,\n", + " widgets.Label(value=\"Run the next cell to load the model. Rerunning this cell clears your selection.\")\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "cgfYjE1YhALA" + }, + "outputs": [], + "source": [ + "# @title Load the model\n", + "\n", + "source = source_tab.get_title(source_tab.selected_index)\n", + "\n", + "if source == \"Random\":\n", + " params = None # Filled in below\n", + " state = {}\n", + " task_config = gencast.TASK\n", + " # Use default values.\n", + " sampler_config = gencast.SamplerConfig()\n", + " noise_config = gencast.NoiseConfig()\n", + " noise_encoder_config = denoiser.NoiseEncoderConfig()\n", + " # Configure, otherwise use default values.\n", + " denoiser_architecture_config = denoiser.DenoiserArchitectureConfig(\n", + " sparse_transformer_config = denoiser.SparseTransformerConfig(\n", + " attention_k_hop=random_attention_k_hop.value,\n", + " attention_type=random_attention_type.value,\n", + " d_model=random_latent_size.value,\n", + " num_heads=random_num_heads.value\n", + " ),\n", + " mesh_size=random_mesh_size.value,\n", + " latent_size=random_latent_size.value,\n", + " )\n", + "else:\n", + " assert source == \"Checkpoint\"\n", + " with gcs_bucket.blob(dir_prefix + f\"params/{params_file.value}\").open(\"rb\") as f:\n", + " ckpt = checkpoint.load(f, gencast.CheckPoint)\n", + " params = ckpt.params\n", + " state = {}\n", + "\n", + " task_config = ckpt.task_config\n", + " sampler_config = ckpt.sampler_config\n", + " noise_config = ckpt.noise_config\n", + " noise_encoder_config = ckpt.noise_encoder_config\n", + " denoiser_architecture_config = ckpt.denoiser_architecture_config\n", + " print(\"Model description:\\n\", ckpt.description, \"\\n\")\n", + " print(\"Model license:\\n\", ckpt.license, \"\\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z2AqgxUgiALy" + }, + "source": [ + "## Load the example data\n", + "\n", + "Example ERA5 datasets are available at 0.25 degree and 1 degree resolution.\n", + "\n", + "Example HRES-fc0 datasets are available at 0.25 degree resolution.\n", + "\n", + "Some transformations were done from the base datasets:\n", + "- We accumulated precipitation over 12 hours instead of the default 1 hour.\n", + "- For HRES-fc0 sea surface temperature, we assigned NaNs to grid cells in which sea surface temperature was NaN in the ERA5 dataset (this remains fixed at all times).\n", + "\n", + "The data resolution must match the model that is loaded. Since we are running GenCast Mini, this will be 1 degree.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "5XGzOww0y_BC" + }, + "outputs": [], + "source": [ + "# @title Get and filter the list of available example datasets\n", + "\n", + "dataset_file_options = [\n", + " name for blob in gcs_bucket.list_blobs(prefix=(dir_prefix + \"dataset/\"))\n", + " if (name := blob.name.removeprefix(dir_prefix+\"dataset/\"))] # Drop empty string.\n", + "\n", + "def parse_file_parts(file_name):\n", + " return dict(part.split(\"-\", 1) for part in file_name.split(\"_\"))\n", + "\n", + "\n", + "def data_valid_for_model(file_name: str, params_file_name: str):\n", + " \"\"\"Check data type and resolution matches.\"\"\"\n", + " if source == \"Random\":\n", + " return True\n", + " data_file_parts = parse_file_parts(file_name.removesuffix(\".nc\"))\n", + " res_matches = data_file_parts[\"res\"].replace(\".\", \"p\") in params_file_name.lower()\n", + " source_matches = \"Operational\" in params_file_name\n", + " if data_file_parts[\"source\"] == \"era5\":\n", + " source_matches = not source_matches\n", + " return res_matches and source_matches\n", + "\n", + "dataset_file = widgets.Dropdown(\n", + " options=[\n", + " (\", \".join([f\"{k}: {v}\" for k, v in parse_file_parts(option.removesuffix(\".nc\")).items()]), option)\n", + " for option in dataset_file_options\n", + " if data_valid_for_model(option, params_file.value)\n", + " ],\n", + " description=\"Dataset file:\",\n", + " layout={\"width\": \"max-content\"})\n", + "widgets.VBox([\n", + " dataset_file,\n", + " widgets.Label(value=\"Run the next cell to load the dataset. Rerunning this cell clears your selection and refilters the datasets that match your model.\")\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "Yz-ekISoJxeZ" + }, + "outputs": [], + "source": [ + "# @title Load weather data\n", + "\n", + "with gcs_bucket.blob(dir_prefix+f\"dataset/{dataset_file.value}\").open(\"rb\") as f:\n", + " example_batch = xarray.load_dataset(f).compute()\n", + "\n", + "assert example_batch.dims[\"time\"] \u003e= 3 # 2 for input, \u003e=1 for targets\n", + "\n", + "print(\", \".join([f\"{k}: {v}\" for k, v in parse_file_parts(dataset_file.value.removesuffix(\".nc\")).items()]))\n", + "\n", + "example_batch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "lXjFvdE6qStr" + }, + "outputs": [], + "source": [ + "# @title Choose data to plot\n", + "\n", + "plot_example_variable = widgets.Dropdown(\n", + " options=example_batch.data_vars.keys(),\n", + " value=\"2m_temperature\",\n", + " description=\"Variable\")\n", + "plot_example_level = widgets.Dropdown(\n", + " options=example_batch.coords[\"level\"].values,\n", + " value=500,\n", + " description=\"Level\")\n", + "plot_example_robust = widgets.Checkbox(value=True, description=\"Robust\")\n", + "plot_example_max_steps = widgets.IntSlider(\n", + " min=1, max=example_batch.dims[\"time\"], value=example_batch.dims[\"time\"],\n", + " description=\"Max steps\")\n", + "\n", + "widgets.VBox([\n", + " plot_example_variable,\n", + " plot_example_level,\n", + " plot_example_robust,\n", + " plot_example_max_steps,\n", + " widgets.Label(value=\"Run the next cell to plot the data. Rerunning this cell clears your selection.\")\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "iqzXVpn9_b15" + }, + "outputs": [], + "source": [ + "# @title Plot example data\n", + "\n", + "plot_size = 7\n", + "\n", + "data = {\n", + " \" \": scale(select(example_batch, plot_example_variable.value, plot_example_level.value, plot_example_max_steps.value),\n", + " robust=plot_example_robust.value),\n", + "}\n", + "fig_title = plot_example_variable.value\n", + "if \"level\" in example_batch[plot_example_variable.value].coords:\n", + " fig_title += f\" at {plot_example_level.value} hPa\"\n", + "\n", + "plot_data(data, fig_title, plot_size, plot_example_robust.value)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "njD4jsPTPKvJ" + }, + "outputs": [], + "source": [ + "# @title Extract training and eval data\n", + "\n", + "train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(\n", + " example_batch, target_lead_times=slice(\"12h\", \"12h\"), # Only 1AR training.\n", + " **dataclasses.asdict(task_config))\n", + "\n", + "eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(\n", + " example_batch, target_lead_times=slice(\"12h\", f\"{(example_batch.dims['time']-2)*12}h\"), # All but 2 input frames.\n", + " **dataclasses.asdict(task_config))\n", + "\n", + "print(\"All Examples: \", example_batch.dims.mapping)\n", + "print(\"Train Inputs: \", train_inputs.dims.mapping)\n", + "print(\"Train Targets: \", train_targets.dims.mapping)\n", + "print(\"Train Forcings:\", train_forcings.dims.mapping)\n", + "print(\"Eval Inputs: \", eval_inputs.dims.mapping)\n", + "print(\"Eval Targets: \", eval_targets.dims.mapping)\n", + "print(\"Eval Forcings: \", eval_forcings.dims.mapping)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-DJzie5me2-H" + }, + "outputs": [], + "source": [ + "# @title Load normalization data\n", + "\n", + "with gcs_bucket.blob(dir_prefix+\"stats/diffs_stddev_by_level.nc\").open(\"rb\") as f:\n", + " diffs_stddev_by_level = xarray.load_dataset(f).compute()\n", + "with gcs_bucket.blob(dir_prefix+\"stats/mean_by_level.nc\").open(\"rb\") as f:\n", + " mean_by_level = xarray.load_dataset(f).compute()\n", + "with gcs_bucket.blob(dir_prefix+\"stats/stddev_by_level.nc\").open(\"rb\") as f:\n", + " stddev_by_level = xarray.load_dataset(f).compute()\n", + "with gcs_bucket.blob(dir_prefix+\"stats/min_by_level.nc\").open(\"rb\") as f:\n", + " min_by_level = xarray.load_dataset(f).compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ke2zQyuT_sMA" + }, + "outputs": [], + "source": [ + "# @title Build jitted functions, and possibly initialize random weights\n", + "\n", + "\n", + "def construct_wrapped_gencast():\n", + " \"\"\"Constructs and wraps the GenCast Predictor.\"\"\"\n", + " predictor = gencast.GenCast(\n", + " sampler_config=sampler_config,\n", + " task_config=task_config,\n", + " denoiser_architecture_config=denoiser_architecture_config,\n", + " noise_config=noise_config,\n", + " noise_encoder_config=noise_encoder_config,\n", + " )\n", + "\n", + " predictor = normalization.InputsAndResiduals(\n", + " predictor,\n", + " diffs_stddev_by_level=diffs_stddev_by_level,\n", + " mean_by_level=mean_by_level,\n", + " stddev_by_level=stddev_by_level,\n", + " )\n", + "\n", + " predictor = nan_cleaning.NaNCleaner(\n", + " predictor=predictor,\n", + " reintroduce_nans=True,\n", + " fill_value=min_by_level,\n", + " var_to_clean='sea_surface_temperature',\n", + " )\n", + "\n", + " return predictor\n", + "\n", + "\n", + "@hk.transform_with_state\n", + "def run_forward(inputs, targets_template, forcings):\n", + " predictor = construct_wrapped_gencast()\n", + " return predictor(inputs, targets_template=targets_template, forcings=forcings)\n", + "\n", + "\n", + "@hk.transform_with_state\n", + "def loss_fn(inputs, targets, forcings):\n", + " predictor = construct_wrapped_gencast()\n", + " loss, diagnostics = predictor.loss(inputs, targets, forcings)\n", + " return xarray_tree.map_structure(\n", + " lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),\n", + " (loss, diagnostics),\n", + " )\n", + "\n", + "\n", + "def grads_fn(params, state, inputs, targets, forcings):\n", + " def _aux(params, state, i, t, f):\n", + " (loss, diagnostics), next_state = loss_fn.apply(\n", + " params, state, jax.random.PRNGKey(0), i, t, f\n", + " )\n", + " return loss, (diagnostics, next_state)\n", + "\n", + " (loss, (diagnostics, next_state)), grads = jax.value_and_grad(\n", + " _aux, has_aux=True\n", + " )(params, state, inputs, targets, forcings)\n", + " return loss, diagnostics, next_state, grads\n", + "\n", + "\n", + "if params is None:\n", + " init_jitted = jax.jit(loss_fn.init)\n", + " params, state = init_jitted(\n", + " rng=jax.random.PRNGKey(0),\n", + " inputs=train_inputs,\n", + " targets=train_targets,\n", + " forcings=train_forcings,\n", + " )\n", + "\n", + "\n", + "loss_fn_jitted = jax.jit(\n", + " lambda rng, i, t, f: loss_fn.apply(params, state, rng, i, t, f)[0]\n", + ")\n", + "grads_fn_jitted = jax.jit(grads_fn)\n", + "run_forward_jitted = jax.jit(\n", + " lambda rng, i, t, f: run_forward.apply(params, state, rng, i, t, f)[0]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VBNutliiCyqA" + }, + "source": [ + "# Run the model\n", + "\n", + "The `chunked_prediction_generator_multiple_runs` iterates over forecast steps, where the 1 step forecast is jitted and samples are pmapped across the chips.\n", + "This allows us to make efficient use of all devices and parallelise generating an ensemble across them. We then combine the chunks at the end to form our final forecast.\n", + "\n", + "Note that the cell will take longer than the standard inference time to run when executed for the first time, as this will include code compilation time. This cost does not increase with the number of devices, it is a fixed-cost one time operation whose result can be reused across any number of devices." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7obeY9i9oTtD" + }, + "outputs": [], + "source": [ + "# @title Autoregressive rollout (loop in python)\n", + "\n", + "print(\"Inputs: \", eval_inputs.dims.mapping)\n", + "print(\"Targets: \", eval_targets.dims.mapping)\n", + "print(\"Forcings:\", eval_forcings.dims.mapping)\n", + "\n", + "num_ensemble_members = 8 # @param int\n", + "rng = jax.random.PRNGKey(0)\n", + "# We fold-in the ensemble member, this way the first N members should always\n", + "# match across different runs which use take the same inputs, regardless of\n", + "# total ensemble size.\n", + "rngs = np.stack(\n", + " [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)\n", + "\n", + "chunks = []\n", + "for chunk in rollout.chunked_prediction_generator_multiple_runs(\n", + " xarray_jax.pmap(run_forward_jitted, dim=\"sample\"),\n", + " rngs=rngs,\n", + " inputs=eval_inputs,\n", + " targets_template=eval_targets * np.nan,\n", + " forcings=eval_forcings,\n", + " num_steps_per_chunk = 1,\n", + " num_samples = num_ensemble_members,\n", + " pmap_devices=jax.local_devices()\n", + " ):\n", + " chunks.append(chunk)\n", + "predictions = xarray.combine_by_coords(chunks)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "nIoUfBxBAwqm" + }, + "outputs": [], + "source": [ + "# @title Choose predictions to plot\n", + "\n", + "plot_pred_variable = widgets.Dropdown(\n", + " options=predictions.data_vars.keys(),\n", + " value=\"2m_temperature\",\n", + " description=\"Variable\")\n", + "plot_pred_level = widgets.Dropdown(\n", + " options=predictions.coords[\"level\"].values,\n", + " value=500,\n", + " description=\"Level\")\n", + "plot_pred_robust = widgets.Checkbox(value=True, description=\"Robust\")\n", + "plot_pred_max_steps = widgets.IntSlider(\n", + " min=1,\n", + " max=predictions.dims[\"time\"],\n", + " value=predictions.dims[\"time\"],\n", + " description=\"Max steps\")\n", + "plot_pred_samples = widgets.IntSlider(\n", + " min=1,\n", + " max=num_ensemble_members,\n", + " value=num_ensemble_members,\n", + " description=\"Samples\")\n", + "\n", + "widgets.VBox([\n", + " plot_pred_variable,\n", + " plot_pred_level,\n", + " plot_pred_robust,\n", + " plot_pred_max_steps,\n", + " plot_pred_samples,\n", + " widgets.Label(value=\"Run the next cell to plot the predictions. Rerunning this cell clears your selection.\")\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "wn7dccXO5R7C" + }, + "outputs": [], + "source": [ + "# @title Plot prediction samples and diffs\n", + "\n", + "plot_size = 5\n", + "plot_max_steps = min(predictions.dims[\"time\"], plot_pred_max_steps.value)\n", + "\n", + "fig_title = plot_pred_variable.value\n", + "if \"level\" in predictions[plot_pred_variable.value].coords:\n", + " fig_title += f\" at {plot_pred_level.value} hPa\"\n", + "\n", + "for sample_idx in range(plot_pred_samples.value):\n", + " data = {\n", + " \"Targets\": scale(select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),\n", + " \"Predictions\": scale(select(predictions.isel(sample=sample_idx), plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),\n", + " \"Diff\": scale((select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -\n", + " select(predictions.isel(sample=sample_idx), plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),\n", + " robust=plot_pred_robust.value, center=0),\n", + " }\n", + " display.display(plot_data(data, fig_title + f\", Sample {sample_idx}\", plot_size, plot_pred_robust.value))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "X3m9lW5fN4oL" + }, + "outputs": [], + "source": [ + "# @title Plot ensemble mean and CRPS\n", + "\n", + "def crps(targets, predictions, bias_corrected = True):\n", + " if predictions.sizes.get(\"sample\", 1) \u003c 2:\n", + " raise ValueError(\n", + " \"predictions must have dim 'sample' with size at least 2.\")\n", + " sum_dims = [\"sample\", \"sample2\"]\n", + " preds2 = predictions.rename({\"sample\": \"sample2\"})\n", + " num_samps = predictions.sizes[\"sample\"]\n", + " num_samps2 = (num_samps - 1) if bias_corrected else num_samps\n", + " mean_abs_diff = np.abs(\n", + " predictions - preds2).sum(\n", + " dim=sum_dims, skipna=False) / (num_samps * num_samps2)\n", + " mean_abs_err = np.abs(targets - predictions).sum(dim=\"sample\", skipna=False) / num_samps\n", + " return mean_abs_err - 0.5 * mean_abs_diff\n", + "\n", + "\n", + "plot_size = 5\n", + "plot_max_steps = min(predictions.dims[\"time\"], plot_pred_max_steps.value)\n", + "\n", + "fig_title = plot_pred_variable.value\n", + "if \"level\" in predictions[plot_pred_variable.value].coords:\n", + " fig_title += f\" at {plot_pred_level.value} hPa\"\n", + "\n", + "data = {\n", + " \"Targets\": scale(select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),\n", + " \"Ensemble Mean\": scale(select(predictions.mean(dim=[\"sample\"]), plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),\n", + " \"Ensemble CRPS\": scale(crps((select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),\n", + " select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),\n", + " robust=plot_pred_robust.value, center=0),\n", + "}\n", + "display.display(plot_data(data, fig_title, plot_size, plot_pred_robust.value))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O6ZhRFBPD0kq" + }, + "source": [ + "# Train the model\n", + "\n", + "The following operations requires larger amounts of memory than running inference.\n", + "\n", + "The first time executing the cell takes more time, as it includes the time to jit the function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Nv-u3dAP7IRZ" + }, + "outputs": [], + "source": [ + "# @title Loss computation\n", + "loss, diagnostics = loss_fn_jitted(\n", + " jax.random.PRNGKey(0),\n", + " train_inputs,\n", + " train_targets,\n", + " train_forcings)\n", + "print(\"Loss:\", float(loss))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mBNFq1IGZNLz" + }, + "outputs": [], + "source": [ + "# @title Gradient computation\n", + "loss, diagnostics, next_state, grads = grads_fn_jitted(\n", + " params=params,\n", + " state=state,\n", + " inputs=train_inputs,\n", + " targets=train_targets,\n", + " forcings=train_forcings)\n", + "mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])\n", + "print(f\"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}\")" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "z2AqgxUgiALy" + ], + "name": "GenCast Mini Demo", + "private_outputs": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/graphcast/autoregressive.py b/graphcast/autoregressive.py index 1cf13247..624c8a9d 100644 --- a/graphcast/autoregressive.py +++ b/graphcast/autoregressive.py @@ -246,7 +246,7 @@ def add_noise(x): return x + self._noise_level * jax.random.normal( hk.next_rng_key(), shape=x.shape) # Add noise to time-dependent variables of the inputs. - inputs = jax.tree_map(add_noise, inputs) + inputs = jax.tree.map(add_noise, inputs) # The per-timestep targets passed by scan to one_step_loss below will have # no leading time axis. We need a treedef without the time axis to use diff --git a/graphcast/casting.py b/graphcast/casting.py index cb0fcdc2..9678f91c 100644 --- a/graphcast/casting.py +++ b/graphcast/casting.py @@ -140,7 +140,7 @@ def _all_inputs_to_bfloat16( xarray.Dataset, xarray.Dataset]: return (inputs.astype(jnp.bfloat16), - jax.tree_map(lambda x: x.astype(jnp.bfloat16), targets), + jax.tree.map(lambda x: x.astype(jnp.bfloat16), targets), forcings.astype(jnp.bfloat16)) @@ -149,7 +149,7 @@ def tree_map_cast(inputs: PyTree, input_dtype: np.dtype, output_dtype: np.dtype, def cast_fn(x): if x.dtype == input_dtype: return x.astype(output_dtype) - return jax.tree_map(cast_fn, inputs) + return jax.tree.map(cast_fn, inputs) @contextlib.contextmanager diff --git a/graphcast/deep_typed_graph_net.py b/graphcast/deep_typed_graph_net.py index 93a1bd38..27a7a670 100644 --- a/graphcast/deep_typed_graph_net.py +++ b/graphcast/deep_typed_graph_net.py @@ -34,8 +34,11 @@ } """ -from typing import Mapping, Optional +import functools +from typing import Callable, List, Mapping, Optional, Tuple +import chex +from graphcast import mlp as mlp_builder from graphcast import typed_graph from graphcast import typed_graph_net import haiku as hk @@ -44,6 +47,9 @@ import jraph +GraphToGraphNetwork = Callable[[typed_graph.TypedGraph], typed_graph.TypedGraph] + + class DeepTypedGraphNet(hk.Module): """Deep Graph Neural Network. @@ -87,6 +93,7 @@ def __init__(self, edge_output_size: Optional[Mapping[str, int]] = None, include_sent_messages_in_node_update: bool = False, use_layer_norm: bool = True, + use_norm_conditioning: bool = False, activation: str = "relu", f32_aggregation: bool = False, aggregate_edges_for_nodes_fn: str = "segment_sum", @@ -114,6 +121,18 @@ def __init__(self, include_sent_messages_in_node_update: Whether to include pooled sent messages from each node in the node update. use_layer_norm: Whether it uses layer norm or not. + use_norm_conditioning: If True, the latent feaures outputted by the + activation normalization that follows the MLPs (e.g. LayerNorm), rather + than being scaled/offset by learned parameters of the normalization + module, will be scaled/offset by offsets/biases produced by a linear + layer (with different weights for each MLP), which takes an extra + argument "global_norm_conditioning". This argument is used to condition + the normalization of all nodes and all edges (hence global), and would + usually only have a batch and feature axis. This is typically used to + condition diffusion models on the "diffusion time". Will raise an error + if this is set to True but the "global_norm_conditioning" is not passed + to the __call__ method, as well as if this is set to False, but + "global_norm_conditioning" is passed to the call method. activation: name of activation function. f32_aggregation: Use float32 in the edge aggregation. aggregate_edges_for_nodes_fn: function used to aggregate messages to each @@ -141,9 +160,14 @@ def __init__(self, self._edge_output_size = edge_output_size self._include_sent_messages_in_node_update = ( include_sent_messages_in_node_update) + if use_norm_conditioning and not use_layer_norm: + raise ValueError( + "`norm_conditioning` can only be used when " + "`use_layer_norm` is true." + ) self._use_layer_norm = use_layer_norm + self._use_norm_conditioning = use_norm_conditioning self._activation = _get_activation_fn(activation) - self._initialized = False self._f32_aggregation = f32_aggregation self._aggregate_edges_for_nodes_fn = _get_aggregate_edges_for_nodes_fn( aggregate_edges_for_nodes_fn) @@ -154,24 +178,31 @@ def __init__(self, assert aggregate_edges_for_nodes_fn == "segment_sum" def __call__(self, - input_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph: + input_graph: typed_graph.TypedGraph, + global_norm_conditioning: Optional[chex.Array] = None + ) -> typed_graph.TypedGraph: """Forward pass of the learnable dynamics model.""" - self._networks_builder(input_graph) + embedder_network, processor_networks, decoder_network = ( + self._networks_builder(input_graph, global_norm_conditioning) + ) # Embed input features (if applicable). - latent_graph_0 = self._embed(input_graph) + latent_graph_0 = self._embed(input_graph, embedder_network) # Do `m` message passing steps in the latent graphs. - latent_graph_m = self._process(latent_graph_0) + latent_graph_m = self._process(latent_graph_0, processor_networks) # Compute outputs from the last latent graph (if applicable). - return self._output(latent_graph_m) - - def _networks_builder(self, graph_template): - if self._initialized: - return - self._initialized = True - + return self._output(latent_graph_m, decoder_network) + + def _networks_builder( + self, + graph_template: typed_graph.TypedGraph, + global_norm_conditioning: Optional[chex.Array] = None, + ) -> Tuple[ + GraphToGraphNetwork, List[GraphToGraphNetwork], GraphToGraphNetwork + ]: + # TODO(aelkadi): move to mlp_builder. def build_mlp(name, output_size): mlp = hk.nets.MLP( output_sizes=[self._mlp_hidden_size] * self._mlp_num_hidden_layers + [ @@ -180,11 +211,40 @@ def build_mlp(name, output_size): def build_mlp_with_maybe_layer_norm(name, output_size): network = build_mlp(name, output_size) + stages = [network] + if self._use_norm_conditioning: + if global_norm_conditioning is None: + raise ValueError( + "When using norm conditioning, `global_norm_conditioning` must" + "be passed to the call method.") + # If using norm conditioning, it is no longer the responsibility of the + # LayerNorm module itself to learn its scale and offset. These will be + # learned for the module by the norm conditioning layer instead. + create_scale = create_offset = False + else: + if global_norm_conditioning is not None: + raise ValueError( + "`globa_norm_conditioning` was passed, but `norm_conditioning`" + " is not enabled.") + create_scale = create_offset = True + if self._use_layer_norm: layer_norm = hk.LayerNorm( - axis=-1, create_scale=True, create_offset=True, + axis=-1, create_scale=create_scale, create_offset=create_offset, name=name + "_layer_norm") - network = hk.Sequential([network, layer_norm]) + stages.append(layer_norm) + + if self._use_norm_conditioning: + norm_conditioning_layer = mlp_builder.LinearNormConditioning( + name=name + "_norm_conditioning") + norm_conditioning_layer = functools.partial( + norm_conditioning_layer, + # Broadcast to the node/edge axis. + norm_conditioning=global_norm_conditioning[None], + ) + stages.append(norm_conditioning_layer) + + network = hk.Sequential(stages) return jraph.concatenated_args(network) # The embedder graph network independently embeds edge and node features. @@ -208,7 +268,7 @@ def build_mlp_with_maybe_layer_norm(name, output_size): embed_edge_fn=embed_edge_fn, embed_node_fn=embed_node_fn, ) - self._embedder_network = typed_graph_net.GraphMapFeatures( + embedder_network = typed_graph_net.GraphMapFeatures( **embedder_kwargs) if self._f32_aggregation: @@ -232,9 +292,9 @@ def aggregate_fn(data, *args, **kwargs): # that update the node and edge latent features. # Note that we can use `modules.InteractionNetwork` because # it also outputs the messages as updated edge latent features. - self._processor_networks = [] + processor_networks = [] for step_i in range(self._num_message_passing_steps): - self._processor_networks.append( + processor_networks.append( typed_graph_net.InteractionNetwork( update_edge_fn=_build_update_fns_for_edge_types( build_mlp_with_maybe_layer_norm, @@ -259,11 +319,15 @@ def aggregate_fn(data, *args, **kwargs): embed_node_fn=_build_update_fns_for_node_types( build_mlp, graph_template, "decoder_nodes_", self._node_output_size) if self._node_output_size else None,) - self._output_network = typed_graph_net.GraphMapFeatures( + output_network = typed_graph_net.GraphMapFeatures( **output_kwargs) + return embedder_network, processor_networks, output_network def _embed( - self, input_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph: + self, + input_graph: typed_graph.TypedGraph, + embedder_network: GraphToGraphNetwork, + ) -> typed_graph.TypedGraph: """Embeds the input graph features into a latent graph.""" # Copy the context to all of the node types, if applicable. @@ -286,11 +350,14 @@ def _embed( context=input_graph.context._replace(features=())) # Embeds the node and edge features. - latent_graph_0 = self._embedder_network(input_graph) + latent_graph_0 = embedder_network(input_graph) return latent_graph_0 def _process( - self, latent_graph_0: typed_graph.TypedGraph) -> typed_graph.TypedGraph: + self, + latent_graph_0: typed_graph.TypedGraph, + processor_networks: List[GraphToGraphNetwork], + ) -> typed_graph.TypedGraph: """Processes the latent graph with several steps of message passing.""" # Do `num_message_passing_steps` with each of the `self._processor_networks` @@ -298,7 +365,7 @@ def _process( # times. latent_graph = latent_graph_0 for unused_repetition_i in range(self._num_processor_repetitions): - for processor_network in self._processor_networks: + for processor_network in processor_networks: latent_graph = self._process_step(processor_network, latent_graph) return latent_graph @@ -326,10 +393,13 @@ def _process_step( nodes=nodes_with_residuals, edges=edges_with_residuals) return latent_graph_k - def _output(self, - latent_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph: + def _output( + self, + latent_graph: typed_graph.TypedGraph, + output_network: GraphToGraphNetwork, + ) -> typed_graph.TypedGraph: """Produces the output from the latent graph.""" - return self._output_network(latent_graph) + return output_network(latent_graph) def _build_update_fns_for_node_types( diff --git a/graphcast/denoiser.py b/graphcast/denoiser.py new file mode 100644 index 00000000..1091964c --- /dev/null +++ b/graphcast/denoiser.py @@ -0,0 +1,851 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Support for wrapping a general Predictor to act as a Denoiser.""" + +import dataclasses +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple + +import chex +from graphcast import deep_typed_graph_net +from graphcast import denoisers_base as base +from graphcast import grid_mesh_connectivity +from graphcast import icosahedral_mesh +from graphcast import model_utils +from graphcast import sparse_transformer +from graphcast import transformer +from graphcast import typed_graph +from graphcast import xarray_jax +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +from scipy import sparse +import xarray + + +Kwargs = Mapping[str, Any] +NoiseLevelEncoder = Callable[[jnp.ndarray], jnp.ndarray] + + +class FourierFeaturesMLP(hk.Module): + """A simple MLP applied to Fourier features of values or their logarithms.""" + + def __init__(self, + base_period: float, + num_frequencies: int, + output_sizes: Sequence[int], + apply_log_first: bool = False, + w_init: ... = None, + activation: ... = jax.nn.gelu, + **mlp_kwargs + ): + """Initializes the module. + + Args: + base_period: + See model_utils.fourier_features. Note this would apply to log inputs if + apply_log_first is used. + num_frequencies: + See model_utils.fourier_features. + output_sizes: + Layer sizes for the MLP. + apply_log_first: + Whether to take the log of the inputs before computing Fourier features. + w_init: + Weights initializer for the MLP, default setting aims to produce + approx unit-variance outputs given the input sin/cos features. + activation: + **mlp_kwargs: + Further settings for the MLP. + """ + super().__init__() + self._base_period = base_period + self._num_frequencies = num_frequencies + self._apply_log_first = apply_log_first + if w_init is None: + # Scale of 2 is appropriate for input layer as sin/cos fourier features + # have variance 0.5 for random inputs. Also reasonable to use for later + # layers as relu activation cuts variance in half for inputs to later + # layers and gelu something close enough too. + w_init = hk.initializers.VarianceScaling( + 2.0, mode="fan_in", distribution="uniform" + ) + self._mlp = hk.nets.MLP( + output_sizes=output_sizes, + w_init=w_init, + activation=activation, + **mlp_kwargs) + + def __call__(self, values: jnp.ndarray) -> jnp.ndarray: + if self._apply_log_first: + values = jnp.log(values) + + features = model_utils.fourier_features( + values, self._base_period, self._num_frequencies) + + return self._mlp(features) + + +@chex.dataclass(frozen=True, eq=True) +class NoiseEncoderConfig: + """Configures the noise level encoding. + + Properties: + apply_log_first: Whether to take the log of the inputs before computing + Fourier features. + base_period: The base period to use. This should be greater or equal to the + range of the values, or to the period if the values have periodic + semantics (e.g. 2pi if they represent angles). Frequencies used will be + integer multiples of 1/base_period. + num_frequencies: The number of frequencies to use, we will use integer + multiples of 1/base_period from 1 up to num_frequencies inclusive. (We + don't include a zero frequency as this would just give constant features + which are redundant if a bias term is present). + output_sizes: Layer sizes for the MLP. + """ + apply_log_first: bool = True + base_period: float = 16.0 + num_frequencies: int = 32 + # 2-layer MLP applied to Fourier features + output_sizes: tuple[int, int] = (32, 16) + + +@chex.dataclass(frozen=True, eq=True) +class SparseTransformerConfig: + """Sparse Transformer config.""" + # Neighbours to attend to. + attention_k_hop: int + # Primary width, the number of channels on the carrier path. + d_model: int + # Depth, or num transformer blocks. One 'layer' is attn + ffw. + num_layers: int = 16 + # Number of heads for self-attention. + num_heads: int = 4 + # Attention type. + attention_type: str = "splash_mha" + # mask type if splash attention being used. + mask_type: str = "lazy" + block_q: int = 1024 + block_kv: int = 512 + block_kv_compute: int = 256 + block_q_dkv: int = 512 + block_kv_dkv: int = 1024 + block_kv_dkv_compute: int = 1024 + # Init scale for final ffw layer (divided by depth) + ffw_winit_final_mult: float = 0.0 + # Init scale for mha w (divided by depth). + attn_winit_final_mult: float = 0.0 + # Number of hidden units in the MLP blocks. Defaults to 4 * d_model. + ffw_hidden: int = 2048 + # Name for haiku module. + name: Optional[str] = None + + +@chex.dataclass(eq=True) +class DenoiserArchitectureConfig: + """Defines the GenCast architecture. + + Properties: + sparse_transformer_config: Config for the mesh transformer. + mesh_size: How many refinements to do on the multi-mesh. + latent_size: How many latent features to include in the various MLPs. + hidden_layers: How many hidden layers for each MLP. + radius_query_fraction_edge_length: Scalar that will be multiplied by the + length of the longest edge of the finest mesh to define the radius of + connectivity to use in the Grid2Mesh graph. Reasonable values are + between 0.6 and 1. 0.6 reduces the number of grid points feeding into + multiple mesh nodes and therefore reduces edge count and memory use, but + 1 gives better predictions. + norm_conditioning_features: List of feature names which will be used to + condition the GNN via norm_conditioning, rather than as regular + features. If this is provided, the GNN has to support the + `global_norm_conditioning` argument. For now it only supports global + norm conditioning (e.g. the same vector conditions all edges and nodes + normalization), which means features passed here must not have "lat" or + "lon" axes. In the future it may support node level norm conditioning + too. + grid2mesh_aggregate_normalization: Optional constant to normalize the output + of aggregate_edges_for_nodes_fn in the mesh2grid GNN. This can be used to + reduce the shock the model undergoes when switching resolution, which + increases the number of edges connected to a node. + node_output_size: Size of the output node representations for + each node type. For node types not specified here, the latent node + representation from the output of the processor will be returned. + """ + + sparse_transformer_config: SparseTransformerConfig + mesh_size: int + latent_size: int = 512 + hidden_layers: int = 1 + radius_query_fraction_edge_length: float = 0.6 + norm_conditioning_features: tuple[str, ...] = ("noise_level_encodings",) + grid2mesh_aggregate_normalization: Optional[float] = None + node_output_size: Optional[int] = None + + +class Denoiser(base.Denoiser): + """Wraps a general deterministic Predictor to act as a Denoiser. + + This passes an encoding of the noise level as an additional input to the + Predictor as an additional input 'noise_level_encodings' with shape + ('batch', 'noise_level_encoding_channels'). It passes the noisy_targets as + additional forcings (since they are also per-target-timestep data that the + predictor needs to condition on) with the same names as the original target + variables. + """ + + def __init__( + self, + noise_encoder_config: Optional[NoiseEncoderConfig], + denoiser_architecture_config: DenoiserArchitectureConfig, + ): + self._predictor = _DenoiserArchitecture( + denoiser_architecture_config=denoiser_architecture_config, + ) + # Use default values if not specified. + if noise_encoder_config is None: + noise_encoder_config = NoiseEncoderConfig() + self._noise_level_encoder = FourierFeaturesMLP(**noise_encoder_config) + + def __call__( + self, + inputs: xarray.Dataset, + noisy_targets: xarray.Dataset, + noise_levels: xarray.DataArray, + forcings: Optional[xarray.Dataset] = None, + **kwargs) -> xarray.Dataset: + if forcings is None: forcings = xarray.Dataset() + forcings = forcings.assign(noisy_targets) + + if noise_levels.dims != ("batch",): + raise ValueError("noise_levels expected to be shape (batch,).") + noise_level_encodings = self._noise_level_encoder( + xarray_jax.unwrap_data(noise_levels) + ) + noise_level_encodings = xarray_jax.Variable( + ("batch", "noise_level_encoding_channels"), noise_level_encodings + ) + inputs = inputs.assign(noise_level_encodings=noise_level_encodings) + + return self._predictor( + inputs=inputs, + targets_template=noisy_targets, + forcings=forcings, + **kwargs) + + +class _DenoiserArchitecture: + """GenCast Predictor. + + The model works on graphs that take into account: + * Mesh nodes: nodes for the vertices of the mesh. + * Grid nodes: nodes for the points of the grid. + * Nodes: When referring to just "nodes", this means the joint set of + both mesh nodes, concatenated with grid nodes. + + The model works with 3 graphs: + * Grid2Mesh graph: Graph that contains all nodes. This graph is strictly + bipartite with edges going from grid nodes to mesh nodes using a + fixed radius query. The grid2mesh_gnn will operate in this graph. The output + of this stage will be a latent representation for the mesh nodes, and a + latent representation for the grid nodes. + * Mesh graph: Graph that contains mesh nodes only. The mesh_gnn will + operate in this graph. It will update the latent state of the mesh nodes + only. + * Mesh2Grid graph: Graph that contains all nodes. This graph is strictly + bipartite with edges going from mesh nodes to grid nodes such that each grid + node is connected to 3 nodes of the mesh triangular face that contains + the grid points. The mesh2grid_gnn will operate in this graph. It will + process the updated latent state of the mesh nodes, and the latent state + of the grid nodes, to produce the final output for the grid nodes. + + The model is built on top of `TypedGraph`s so the different types of nodes and + edges can be stored and treated separately. + """ + + def __init__( + self, + denoiser_architecture_config: DenoiserArchitectureConfig, + ): + """Initializes the predictor.""" + self._spatial_features_kwargs = dict( + add_node_positions=False, + add_node_latitude=True, + add_node_longitude=True, + add_relative_positions=True, + relative_longitude_local_coordinates=True, + relative_latitude_local_coordinates=True, + ) + + # Construct the mesh. + mesh = icosahedral_mesh.get_last_triangular_mesh_for_sphere( + splits=denoiser_architecture_config.mesh_size + ) + # Permute the mesh to a banded structure so we can run sparse attention + # operations. + self._mesh = _permute_mesh_to_banded(mesh=mesh) + + # Encoder, which moves data from the grid to the mesh with a single message + # passing step. + self._grid2mesh_gnn = ( + deep_typed_graph_net.DeepTypedGraphNet( + activation="swish", + aggregate_normalization=( + denoiser_architecture_config.grid2mesh_aggregate_normalization + ), + edge_latent_size=dict( + grid2mesh=denoiser_architecture_config.latent_size + ), + embed_edges=True, + embed_nodes=True, + f32_aggregation=True, + include_sent_messages_in_node_update=False, + mlp_hidden_size=denoiser_architecture_config.latent_size, + mlp_num_hidden_layers=denoiser_architecture_config.hidden_layers, + name="grid2mesh_gnn", + node_latent_size=dict( + grid_nodes=denoiser_architecture_config.latent_size, + mesh_nodes=denoiser_architecture_config.latent_size + ), + node_output_size=None, + num_message_passing_steps=1, + use_layer_norm=True, + use_norm_conditioning=True, + ) + ) + + # Processor - performs multiple rounds of message passing on the mesh. + self._mesh_gnn = transformer.MeshTransformer( + name="mesh_transformer", + transformer_ctor=sparse_transformer.Transformer, + transformer_kwargs=dataclasses.asdict( + denoiser_architecture_config.sparse_transformer_config + ), + ) + + # Decoder, which moves data from the mesh back into the grid with a single + # message passing step. + self._mesh2grid_gnn = ( + deep_typed_graph_net.DeepTypedGraphNet( + activation="swish", + edge_latent_size=dict( + mesh2grid=denoiser_architecture_config.latent_size + ), + embed_nodes=False, + f32_aggregation=False, + include_sent_messages_in_node_update=False, + mlp_hidden_size=denoiser_architecture_config.latent_size, + mlp_num_hidden_layers=denoiser_architecture_config.hidden_layers, + name="mesh2grid_gnn", + node_latent_size=dict( + grid_nodes=denoiser_architecture_config.latent_size, + mesh_nodes=denoiser_architecture_config.latent_size, + ), + node_output_size={ + "grid_nodes": denoiser_architecture_config.node_output_size + }, + num_message_passing_steps=1, + use_layer_norm=True, + use_norm_conditioning=True, + ) + ) + + self._norm_conditioning_features = ( + denoiser_architecture_config.norm_conditioning_features + ) + # Obtain the query radius in absolute units for the unit-sphere for the + # grid2mesh model, by rescaling the `radius_query_fraction_edge_length`. + self._query_radius = ( + _get_max_edge_distance(self._mesh) + * denoiser_architecture_config.radius_query_fraction_edge_length + ) + + # Other initialization is delayed until the first call (`_maybe_init`) + # when we get some sample data so we know the lat/lon values. + self._initialized = False + + # A "_init_mesh_properties": + # This one could be initialized at init but we delay it for consistency too. + self._num_mesh_nodes = None # num_mesh_nodes + self._mesh_nodes_lat = None # [num_mesh_nodes] + self._mesh_nodes_lon = None # [num_mesh_nodes] + + # A "_init_grid_properties": + self._grid_lat = None # [num_lat_points] + self._grid_lon = None # [num_lon_points] + self._num_grid_nodes = None # num_lat_points * num_lon_points + self._grid_nodes_lat = None # [num_grid_nodes] + self._grid_nodes_lon = None # [num_grid_nodes] + + # A "_init_{grid2mesh,processor,mesh2grid}_graph" + self._grid2mesh_graph_structure = None + self._mesh_graph_structure = None + self._mesh2grid_graph_structure = None + + def __call__(self, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: xarray.Dataset, + ) -> xarray.Dataset: + self._maybe_init(inputs) + + # Convert all input data into flat vectors for each of the grid nodes. + # xarray (batch, time, lat, lon, level, multiple vars, forcings) + # -> [num_grid_nodes, batch, num_channels] + grid_node_features, global_norm_conditioning = ( + self._inputs_to_grid_node_features_and_norm_conditioning( + inputs, forcings + ) + ) + + # [num_mesh_nodes, batch, latent_size], [num_grid_nodes, batch, latent_size] + (latent_mesh_nodes, latent_grid_nodes) = self._run_grid2mesh_gnn( + grid_node_features, global_norm_conditioning + ) + + # Run message passing in the multimesh. + # [num_mesh_nodes, batch, latent_size] + updated_latent_mesh_nodes = self._run_mesh_gnn( + latent_mesh_nodes, global_norm_conditioning + ) + + # Transfer data from the mesh to the grid. + # [num_grid_nodes, batch, output_size] + output_grid_nodes = self._run_mesh2grid_gnn( + updated_latent_mesh_nodes, latent_grid_nodes, global_norm_conditioning + ) + + # Convert output flat vectors for the grid nodes to the format of the + # output. [num_grid_nodes, batch, output_size] -> xarray (batch, one time + # step, lat, lon, level, multiple vars) + return self._grid_node_outputs_to_prediction( + output_grid_nodes, targets_template + ) + + def _maybe_init(self, sample_inputs: xarray.Dataset): + """Inits everything that has a dependency on the input coordinates.""" + if not self._initialized: + self._init_mesh_properties() + self._init_grid_properties( + grid_lat=sample_inputs.lat, grid_lon=sample_inputs.lon) + self._grid2mesh_graph_structure = self._init_grid2mesh_graph() + self._mesh_graph_structure = self._init_mesh_graph() + self._mesh2grid_graph_structure = self._init_mesh2grid_graph() + + self._initialized = True + + def _init_mesh_properties(self): + """Inits static properties that have to do with mesh nodes.""" + self._num_mesh_nodes = self._mesh.vertices.shape[0] + mesh_phi, mesh_theta = model_utils.cartesian_to_spherical( + self._mesh.vertices[:, 0], + self._mesh.vertices[:, 1], + self._mesh.vertices[:, 2]) + ( + mesh_nodes_lat, + mesh_nodes_lon, + ) = model_utils.spherical_to_lat_lon( + phi=mesh_phi, theta=mesh_theta) + # Convert to f32 to ensure the lat/lon features aren't in f64. + self._mesh_nodes_lat = mesh_nodes_lat.astype(np.float32) + self._mesh_nodes_lon = mesh_nodes_lon.astype(np.float32) + + def _init_grid_properties(self, grid_lat: np.ndarray, grid_lon: np.ndarray): + """Inits static properties that have to do with grid nodes.""" + self._grid_lat = grid_lat.astype(np.float32) + self._grid_lon = grid_lon.astype(np.float32) + # Initialized the counters. + self._num_grid_nodes = grid_lat.shape[0] * grid_lon.shape[0] + + # Initialize lat and lon for the grid. + grid_nodes_lon, grid_nodes_lat = np.meshgrid(grid_lon, grid_lat) + self._grid_nodes_lon = grid_nodes_lon.reshape([-1]).astype(np.float32) + self._grid_nodes_lat = grid_nodes_lat.reshape([-1]).astype(np.float32) + + def _init_grid2mesh_graph(self) -> typed_graph.TypedGraph: + """Build Grid2Mesh graph.""" + + # Create some edges according to distance between mesh and grid nodes. + assert self._grid_lat is not None and self._grid_lon is not None + (grid_indices, mesh_indices) = grid_mesh_connectivity.radius_query_indices( + grid_latitude=self._grid_lat, + grid_longitude=self._grid_lon, + mesh=self._mesh, + radius=self._query_radius) + + # Edges sending info from grid to mesh. + senders = grid_indices + receivers = mesh_indices + + # Precompute structural node and edge features according to config options. + # Structural features are those that depend on the fixed values of the + # latitude and longitudes of the nodes. + (senders_node_features, receivers_node_features, + edge_features) = model_utils.get_bipartite_graph_spatial_features( + senders_node_lat=self._grid_nodes_lat, + senders_node_lon=self._grid_nodes_lon, + receivers_node_lat=self._mesh_nodes_lat, + receivers_node_lon=self._mesh_nodes_lon, + senders=senders, + receivers=receivers, + edge_normalization_factor=None, + **self._spatial_features_kwargs, + ) + + n_grid_node = np.array([self._num_grid_nodes]) + n_mesh_node = np.array([self._num_mesh_nodes]) + n_edge = np.array([mesh_indices.shape[0]]) + grid_node_set = typed_graph.NodeSet( + n_node=n_grid_node, features=senders_node_features) + mesh_node_set = typed_graph.NodeSet( + n_node=n_mesh_node, features=receivers_node_features) + edge_set = typed_graph.EdgeSet( + n_edge=n_edge, + indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers), + features=edge_features) + nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set} + edges = { + typed_graph.EdgeSetKey("grid2mesh", ("grid_nodes", "mesh_nodes")): + edge_set + } + grid2mesh_graph = typed_graph.TypedGraph( + context=typed_graph.Context(n_graph=np.array([1]), features=()), + nodes=nodes, + edges=edges) + return grid2mesh_graph + + def _init_mesh_graph(self) -> typed_graph.TypedGraph: + """Build Mesh graph.""" + # Work simply on the mesh edges. + # N.B.To make sure ordering is preserved, any changes to faces_to_edges here + # should be reflected in the other 2 calls to faces_to_edges in this file. + senders, receivers = icosahedral_mesh.faces_to_edges(self._mesh.faces) + + # Precompute structural node and edge features according to config options. + # Structural features are those that depend on the fixed values of the + # latitude and longitudes of the nodes. + assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None + node_features, edge_features = model_utils.get_graph_spatial_features( + node_lat=self._mesh_nodes_lat, + node_lon=self._mesh_nodes_lon, + senders=senders, + receivers=receivers, + **self._spatial_features_kwargs, + ) + + n_mesh_node = np.array([self._num_mesh_nodes]) + n_edge = np.array([senders.shape[0]]) + assert n_mesh_node == len(node_features) + mesh_node_set = typed_graph.NodeSet( + n_node=n_mesh_node, features=node_features) + edge_set = typed_graph.EdgeSet( + n_edge=n_edge, + indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers), + features=edge_features) + nodes = {"mesh_nodes": mesh_node_set} + edges = { + typed_graph.EdgeSetKey("mesh", ("mesh_nodes", "mesh_nodes")): edge_set + } + mesh_graph = typed_graph.TypedGraph( + context=typed_graph.Context(n_graph=np.array([1]), features=()), + nodes=nodes, + edges=edges) + + return mesh_graph + + def _init_mesh2grid_graph(self) -> typed_graph.TypedGraph: + """Build Mesh2Grid graph.""" + + # Create some edges according to how the grid nodes are contained by + # mesh triangles. + (grid_indices, + mesh_indices) = grid_mesh_connectivity.in_mesh_triangle_indices( + grid_latitude=self._grid_lat, + grid_longitude=self._grid_lon, + mesh=self._mesh) + + # Edges sending info from mesh to grid. + senders = mesh_indices + receivers = grid_indices + + # Precompute structural node and edge features according to config options. + assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None + (senders_node_features, receivers_node_features, + edge_features) = model_utils.get_bipartite_graph_spatial_features( + senders_node_lat=self._mesh_nodes_lat, + senders_node_lon=self._mesh_nodes_lon, + receivers_node_lat=self._grid_nodes_lat, + receivers_node_lon=self._grid_nodes_lon, + senders=senders, + receivers=receivers, + edge_normalization_factor=None, + **self._spatial_features_kwargs, + ) + + n_grid_node = np.array([self._num_grid_nodes]) + n_mesh_node = np.array([self._num_mesh_nodes]) + n_edge = np.array([senders.shape[0]]) + grid_node_set = typed_graph.NodeSet( + n_node=n_grid_node, features=receivers_node_features) + mesh_node_set = typed_graph.NodeSet( + n_node=n_mesh_node, features=senders_node_features) + edge_set = typed_graph.EdgeSet( + n_edge=n_edge, + indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers), + features=edge_features) + nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set} + edges = { + typed_graph.EdgeSetKey("mesh2grid", ("mesh_nodes", "grid_nodes")): + edge_set + } + mesh2grid_graph = typed_graph.TypedGraph( + context=typed_graph.Context(n_graph=np.array([1]), features=()), + nodes=nodes, + edges=edges) + return mesh2grid_graph + + def _run_grid2mesh_gnn(self, grid_node_features: chex.Array, + global_norm_conditioning: Optional[chex.Array] = None, + ) -> tuple[chex.Array, chex.Array]: + """Runs the grid2mesh_gnn, extracting latent mesh and grid nodes.""" + + # Concatenate node structural features with input features. + batch_size = grid_node_features.shape[1] + + grid2mesh_graph = self._grid2mesh_graph_structure + assert grid2mesh_graph is not None + grid_nodes = grid2mesh_graph.nodes["grid_nodes"] + mesh_nodes = grid2mesh_graph.nodes["mesh_nodes"] + new_grid_nodes = grid_nodes._replace( + features=jnp.concatenate([ + grid_node_features, + _add_batch_second_axis( + grid_nodes.features.astype(grid_node_features.dtype), + batch_size) + ], + axis=-1)) + + # To make sure capacity of the embedded is identical for the grid nodes and + # the mesh nodes, we also append some dummy zero input features for the + # mesh nodes. + dummy_mesh_node_features = jnp.zeros( + (self._num_mesh_nodes,) + grid_node_features.shape[1:], + dtype=grid_node_features.dtype) + new_mesh_nodes = mesh_nodes._replace( + features=jnp.concatenate([ + dummy_mesh_node_features, + _add_batch_second_axis( + mesh_nodes.features.astype(dummy_mesh_node_features.dtype), + batch_size) + ], + axis=-1)) + + # Broadcast edge structural features to the required batch size. + grid2mesh_edges_key = grid2mesh_graph.edge_key_by_name("grid2mesh") + edges = grid2mesh_graph.edges[grid2mesh_edges_key] + + new_edges = edges._replace( + features=_add_batch_second_axis( + edges.features.astype(dummy_mesh_node_features.dtype), batch_size)) + + input_graph = self._grid2mesh_graph_structure._replace( + edges={grid2mesh_edges_key: new_edges}, + nodes={ + "grid_nodes": new_grid_nodes, + "mesh_nodes": new_mesh_nodes + }) + + # Run the GNN. + grid2mesh_out = self._grid2mesh_gnn(input_graph, global_norm_conditioning) + latent_mesh_nodes = grid2mesh_out.nodes["mesh_nodes"].features + latent_grid_nodes = grid2mesh_out.nodes["grid_nodes"].features + return latent_mesh_nodes, latent_grid_nodes + + def _run_mesh_gnn(self, latent_mesh_nodes: chex.Array, + global_norm_conditioning: Optional[chex.Array] = None + ) -> chex.Array: + """Runs the mesh_gnn, extracting updated latent mesh nodes.""" + + # Add the structural edge features of this graph. Note we don't need + # to add the structural node features, because these are already part of + # the latent state, via the original Grid2Mesh gnn, however, we need + # the edge ones, because it is the first time we are seeing this particular + # set of edges. + batch_size = latent_mesh_nodes.shape[1] + + mesh_graph = self._mesh_graph_structure + assert mesh_graph is not None + mesh_edges_key = mesh_graph.edge_key_by_name("mesh") + edges = mesh_graph.edges[mesh_edges_key] + + # We are assuming here that the mesh gnn uses a single set of edge keys + # named "mesh" for the edges and that it uses a single set of nodes named + # "mesh_nodes" + msg = ("The setup currently requires to only have one kind of edge in the" + " mesh GNN.") + assert len(mesh_graph.edges) == 1, msg + + new_edges = edges._replace( + features=_add_batch_second_axis( + edges.features.astype(latent_mesh_nodes.dtype), batch_size)) + + nodes = mesh_graph.nodes["mesh_nodes"] + nodes = nodes._replace(features=latent_mesh_nodes) + + input_graph = mesh_graph._replace( + edges={mesh_edges_key: new_edges}, nodes={"mesh_nodes": nodes}) + + # Run the GNN. + return self._mesh_gnn(input_graph, + global_norm_conditioning=global_norm_conditioning + ).nodes["mesh_nodes"].features + + def _run_mesh2grid_gnn(self, + updated_latent_mesh_nodes: chex.Array, + latent_grid_nodes: chex.Array, + global_norm_conditioning: Optional[chex.Array] = None, + ) -> chex.Array: + """Runs the mesh2grid_gnn, extracting the output grid nodes.""" + + # Add the structural edge features of this graph. Note we don't need + # to add the structural node features, because these are already part of + # the latent state, via the original Grid2Mesh gnn, however, we need + # the edge ones, because it is the first time we are seeing this particular + # set of edges. + batch_size = updated_latent_mesh_nodes.shape[1] + + mesh2grid_graph = self._mesh2grid_graph_structure + assert mesh2grid_graph is not None + mesh_nodes = mesh2grid_graph.nodes["mesh_nodes"] + grid_nodes = mesh2grid_graph.nodes["grid_nodes"] + new_mesh_nodes = mesh_nodes._replace(features=updated_latent_mesh_nodes) + new_grid_nodes = grid_nodes._replace(features=latent_grid_nodes) + mesh2grid_key = mesh2grid_graph.edge_key_by_name("mesh2grid") + edges = mesh2grid_graph.edges[mesh2grid_key] + + new_edges = edges._replace( + features=_add_batch_second_axis( + edges.features.astype(latent_grid_nodes.dtype), batch_size)) + + input_graph = mesh2grid_graph._replace( + edges={mesh2grid_key: new_edges}, + nodes={ + "mesh_nodes": new_mesh_nodes, + "grid_nodes": new_grid_nodes + }) + + # Run the GNN. + output_graph = self._mesh2grid_gnn(input_graph, global_norm_conditioning) + output_grid_nodes = output_graph.nodes["grid_nodes"].features + + return output_grid_nodes + + def _inputs_to_grid_node_features_and_norm_conditioning( + self, + inputs: xarray.Dataset, + forcings: xarray.Dataset, + ) -> Tuple[chex.Array, Optional[chex.Array]]: + """xarray ->[n_grid_nodes, batch, n_channels], [batch, n_cond channels].""" + + if self._norm_conditioning_features: + norm_conditioning_inputs = inputs[list(self._norm_conditioning_features)] + inputs = inputs.drop_vars(list(self._norm_conditioning_features)) + + if "lat" in norm_conditioning_inputs or "lon" in norm_conditioning_inputs: + raise ValueError("Features with lat or lon dims are not currently " + "supported for norm conditioning.") + global_norm_conditioning = xarray_jax.unwrap_data( + model_utils.dataset_to_stacked(norm_conditioning_inputs, + preserved_dims=("batch",), + ).transpose("batch", ...)) + + else: + global_norm_conditioning = None + + # xarray `Dataset` (batch, time, lat, lon, level, multiple vars) + # to xarray `DataArray` (batch, lat, lon, channels) + stacked_inputs = model_utils.dataset_to_stacked(inputs) + stacked_forcings = model_utils.dataset_to_stacked(forcings) + stacked_inputs = xarray.concat( + [stacked_inputs, stacked_forcings], dim="channels") + + # xarray `DataArray` (batch, lat, lon, channels) + # to single numpy array with shape [lat_lon_node, batch, channels] + grid_xarray_lat_lon_leading = model_utils.lat_lon_to_leading_axes( + stacked_inputs) + # ["node", "batch", "features"] + grid_node_features = xarray_jax.unwrap( + grid_xarray_lat_lon_leading.data + ).reshape((-1,) + grid_xarray_lat_lon_leading.data.shape[2:]) + return grid_node_features, global_norm_conditioning + + def _grid_node_outputs_to_prediction( + self, + grid_node_outputs: chex.Array, + targets_template: xarray.Dataset, + ) -> xarray.Dataset: + """[num_grid_nodes, batch, num_outputs] -> xarray.""" + + # numpy array with shape [lat_lon_node, batch, channels] + assert self._grid_lat is not None and self._grid_lon is not None + grid_shape = (self._grid_lat.shape[0], self._grid_lon.shape[0]) + grid_outputs_lat_lon_leading = grid_node_outputs.reshape( + grid_shape + grid_node_outputs.shape[1:]) + dims = ("lat", "lon", "batch", "channels") + grid_xarray_lat_lon_leading = xarray_jax.DataArray( + data=grid_outputs_lat_lon_leading, + dims=dims) + grid_xarray = model_utils.restore_leading_axes(grid_xarray_lat_lon_leading) + + # xarray `DataArray` (batch, lat, lon, channels) + # to xarray `Dataset` (batch, one time step, lat, lon, level, multiple vars) + return model_utils.stacked_to_dataset( + grid_xarray.variable, targets_template) + + +def _add_batch_second_axis(data, batch_size): + # data [leading_dim, trailing_dim] + assert data.ndim == 2 + ones = jnp.ones([batch_size, 1], dtype=data.dtype) + return data[:, None] * ones # [leading_dim, batch, trailing_dim] + + +def _get_max_edge_distance(mesh): + # N.B.To make sure ordering is preserved, any changes to faces_to_edges here + # should be reflected in the other 2 calls to faces_to_edges in this file. + senders, receivers = icosahedral_mesh.faces_to_edges(mesh.faces) + edge_distances = np.linalg.norm( + mesh.vertices[senders] - mesh.vertices[receivers], axis=-1) + return edge_distances.max() + + +def _permute_mesh_to_banded(mesh): + """Permutes the mesh nodes such that adjacency matrix has banded structure.""" + # Build adjacency matrix. + # N.B.To make sure ordering is preserved, any changes to faces_to_edges here + # should be reflected in the other 2 calls to faces_to_edges in this file. + senders, receivers = icosahedral_mesh.faces_to_edges(mesh.faces) + num_mesh_nodes = mesh.vertices.shape[0] + adj_mat = sparse.csr_matrix((num_mesh_nodes, num_mesh_nodes)) + adj_mat[senders, receivers] = 1 + # Permutation to banded (this algorithm is deterministic, a given sparse + # adjacency matrix will yield the same permutation every time this is run). + mesh_permutation = sparse.csgraph.reverse_cuthill_mckee( + adj_mat, symmetric_mode=True + ) + vertex_permutation_map = {j: i for i, j in enumerate(mesh_permutation)} + permute_func = np.vectorize(lambda x: vertex_permutation_map[x]) + return icosahedral_mesh.TriangularMesh( + vertices=mesh.vertices[mesh_permutation], faces=permute_func(mesh.faces) + ) diff --git a/graphcast/denoisers_base.py b/graphcast/denoisers_base.py new file mode 100644 index 00000000..b30b39bd --- /dev/null +++ b/graphcast/denoisers_base.py @@ -0,0 +1,53 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base class for Denoisers used in diffusion Predictors. + +Denoisers are a bit like deterministic Predictors, except: +* Their __call__ method also conditions on noisy_targets and the noise_levels + of those noisy targets +* They don't have an overrideable loss function (the loss is assumed to be some + form of MSE and is implemented outside the Denoiser itself) +""" + +from typing import Optional, Protocol + +import xarray + + +class Denoiser(Protocol): + """A denoising model that conditions on inputs as well as noise level.""" + + def __call__( + self, + inputs: xarray.Dataset, + noisy_targets: xarray.Dataset, + noise_levels: xarray.DataArray, + forcings: Optional[xarray.Dataset] = None, + **kwargs) -> xarray.Dataset: + """Computes denoised targets from noisy targets. + + Args: + inputs: Inputs to condition on, as for Predictor.__call__. + noisy_targets: Targets which have had i.i.d. zero-mean Gaussian noise + added to them (where the noise level used may vary along the 'batch' + dimension). + noise_levels: A DataArray with dimensions ('batch',) specifying the noise + levels that were used for each example in the batch. + forcings: Optional additional per-target-timestep forcings to condition + on, as for Predictor.__call__. + **kwargs: Any additional custom kwargs. + + Returns: + Denoised predictions with the same shape as noisy_targets. + """ diff --git a/graphcast/docs/GenCast_0p25deg_accelerator_scorecard.png b/graphcast/docs/GenCast_0p25deg_accelerator_scorecard.png new file mode 100644 index 00000000..d9173acb Binary files /dev/null and b/graphcast/docs/GenCast_0p25deg_accelerator_scorecard.png differ diff --git a/graphcast/docs/GenCast_1p0deg_Mini_ENS_scorecard.png b/graphcast/docs/GenCast_1p0deg_Mini_ENS_scorecard.png new file mode 100644 index 00000000..f7c94ca7 Binary files /dev/null and b/graphcast/docs/GenCast_1p0deg_Mini_ENS_scorecard.png differ diff --git a/graphcast/docs/cloud_vm_setup.md b/graphcast/docs/cloud_vm_setup.md new file mode 100644 index 00000000..62602397 --- /dev/null +++ b/graphcast/docs/cloud_vm_setup.md @@ -0,0 +1,242 @@ +This document describes how to run `gencast_demo_cloud_vm.ipynb` through [Colaboratory](https://colab.research.google.com/github/deepmind/graphcast/blob/master/gencast_demo_cloud_vm.ipynb) using Google Cloud compute. + +## TPU Cost and Availability + +- There are 2 ways to access a Cloud TPU VM: + - "Spot" + - As per https://cloud.google.com/tpu/docs/spot, "Spot VMs + make unused capacity available at highly-discounted rates. Spot VMs can be + preempted (shut down) at any time… You can't restart TPU Spot VMs, and you + must recreate them after preemption." + - "On Demand" + - As per https://cloud.google.com/tpu/docs/quota#tpu_quota, + "On-demand resources won't be preempted, but on-demand quota does not + guarantee there will be enough available Cloud TPU resources to satisfy + your request." +- N.B. you may come across references to "preemptible" TPUs, these have been + [deprecated](https://cloud.google.com/tpu/docs/preemptible) +- As per https://cloud.google.com/compute/docs/instances/spot#pricing, + "Spot prices give you 60-91% discounts compared to the standard price for most + machine types and GPUs." +- As per https://cloud.google.com/tpu/pricing, at the time of + writing, "On Demand" prices (price pre-discount applied above) are: + - v4p: ~$3.20 per chip hour + - v5e: ~$1.20 per chip hour + - v5p: ~$5.20 per chip hour +- Reference pricing, estimating ~10 minutes for TPU set up/data transfer + - v5e: 4 sample, 30 step rollout of 1deg GenCast takes ~5 minutes, including + tracing and compilation, this will cost ~ 15 (minutes) * 1.20 ($/chip hour) * + 1/60 (minutes to hour) * 4 (chips) * 0.09-0.4 (discount) = $0.11 - $0.48 + - Once the TPU is set up and compilation/tracing is done, the same rollout + takes ~3 minutes = $0.07 - $0.29 + - v5p: 8 sample, 30 step rollout of 0.25deg GenCast takes ~30 minutes, + including tracing and compilation, this will cost ~ 40 (minutes) * 5.20 + ($/chip hour) * 1/60 (minutes to hour) * 8 (chips) * 0.09-0.4 (discount) = + $2.50 - $11.10 + - Once the TPU is set up and compilation/tracing is done, the same rollout + takes ~8 minutes = $0.50 - $2.22 + +## GenCast Memory Requirements + +### Running inference on TPU + +- This requires: + - 0.25deg GenCast: ~250GB of Host RAM (System Memory) and ~32GB of HBM (vRAM) + - 1deg GenCast: ~21GB of Host RAM (System Memory) and ~8GB of HBM (vRAM) +- Host RAM/System Memory is required for compilation. One way of accessing + sufficient host memory is to request a multi-device host, such as a 2x2x1 + configuration of TPUs +- HBM/vRAM is required for running inference. Each TPU in a 2x2x1 TPUv4p+ + configuration has enough HBM to run GenCast inference. As such, we make + efficient use of all devices and parallelise generating an ensemble across + them +- Note that compilation cost does not increase with the number of devices, it is + a fixed-cost one time operation whose result can be reused across any number + of devices + +### Running Inference on GPU + +- Inference can also be run on GPU, however, splash attention is currently not + supported for GPU in JAX +- Instead, in the `SparseTransformerConfig`, set + ``` + attention_type = "triblockdiag_mha" + mask_type = "full" + ``` +- We tried running the model on a H100 with this attention mechanism and found that + while the performance is comparable, there is a small degradation (on average ~0.3% on unbiased Ensemble Mean RMSE and ~0.4% on unbiased CRPS). We suspect that + this originates from the attention mechanisms being algebraically equivalent, but + not numerically equivalent. For reference, a scorecard comparing GenCast + forecasts produced on a TPUv4 with `splash attention` vs. on a H100 with `triblockdiag_mha` + can be found in [docs/](https://github.com/google-deepmind/graphcast/docs/GenCast_0p25deg_accelerator_scorecard.png). + Note that this scorecard differs from those found in the GenCast paper + in a number of ways: + - 8 member ensembles (vs. 50 in the paper) + - 30 hour initialisation strides starting from 01-01-2019T00, i.e. a comparison + of 292 initialisations (vs. 730 in the paper) + - Colorbar limits of +-3% (vs. +-20% in the paper) +- `triblockdiag_mha` also requires more memory, as such running inference on GPU + requires: + - 0.25deg GenCast: ~300GB of System Memory and ~60GB of vRAM + - 1deg GenCast: ~24GB of System Memory and ~16GB vRAM + +## Prerequisites + +### Create a Google Cloud account + +- This can be done via https://cloud.google.com/ +- Creating an account will prompt you to enter credit card information + - While this is necessary, please note that Google Cloud provides $300 worth + of free credits to start with, for 90 days. As per Cloud: "You won't be + auto-charged once your trial ends. **However, your resources are marked for + deletion and may be lost immediately**. Activate your full account anytime + during your trial to avoid losing your resources. After activating your full + account, you'll pay only for what you use." +- This will create a project by default + +### Set up `gcloud` on your personal device + +- Install `python` https://www.python.org/downloads +- Install `gcloud` CLI https://cloud.google.com/sdk/docs/install + and follow instructions to initialise + - Don't forget to restart your terminal once you have completed this to enable + `gcloud` + +## Provisioning a Cloud VM TPU + +- In what follows, set fields: + - Name: as preferred, we will assume they are the same and refer to them as + `` from here on + - Zone: at the time of writing, for each TPU type, availability can be found + in + - v4p: None ☹️ + - v5e: `us-south1-a` or `asia-southeast1-b` + - v5p: `us-east5-a` or `europe-west4-b` + - Note that availability is variable and subject to change - we will refer to this + as `` from here on + - TPU type: based on the GenCast model being run, number of chips N desired and quota limits. + - The "GenCast Memory Requirements" section above details the accelerator recommendation for each model, where, in the TPU type dropdown: + - v5e → `v5litepod-N` + - v5p → `v5p-N` + - e.g. ![tpu types](tpu_types.png) + - The number of chips `N` will correspond to the number of forecasts samples that can be generated in parallel + - Some zones have limits on number of chips that can be requested, e.g. at the time of writing `us-south1-a` limits to `v5litepod-4` + - Some TPU types have restrictions on the minimum number of chips that can be request, e.g. at the time of writing v5p appears to come in a minimum 8 chip configuration: `v5p-8` + - TPU software version: based on the TPU type + - v5e: `v2-tpuv5-litepod` + - v5p: `v2-alpha-tpuv5` +- Enabling "queueing" puts the creation request into a queue until resources become available. If you do not select this then if capacity is not instantaneously available at the time of the creation request, it will result in a `Stockout` error +- After creating the instance, you will find your request in the Console on "Compute Engine>TPUs>Queued Resources" tab + - Its status should change to "provisioning" and eventually "active" + +**NOTE**: if you see a failed job in Queued Resources, you may have to delete it before launching again. This is because it is considered as quota and may cause you to hit the quota ceiling. Fortunately, this will not consume credits. + +Summary: + +| Model | Recommended Accelerator | Cells | Software Version | | +|:----------------------:|:-----------------------:|:---------------------------------:|:----------------:|:-:| +| GenCast 0p25deg (Oper) | v5p-N | us-east5-a or europe-west4-b | v2-alpha-tpuv5 | | +| GenCast 1deg (Mini) | v5litepod-N | us-south1-a or asia-southeast1-b | v2-tpuv5-litepod | | + +**OPTION 1: Via the command line (required if requesting "Spot" TPUs)**: + +- As per https://cloud.google.com/tpu/docs/spot#create-tpu-spot-vms +- If enabling queueing, set fields as above, and set + - `QUEUED_RESOURCE_ID` = `` + - `node_id` = `` + - `accelerator-type` = desired TPU type e.g. `v5litepod-4` + - `runtime-version` = desired TPU software version e.g. `v2-tpuv5-litepod` +- If not enabling queuing, set fields as above, and set + - `TPU_NAME` = `` + - `accelerator-type` = desired TPU type e.g. `v5litepod-4` + - `version` = desired TPU software version e.g. `v2-tpuv5-litepod` +- To request a spot device, append `--spot` to the command +- E.g. + ``` + gcloud compute tpus queued-resources create node-1 --node-id=node-1 --zone=us-south1-a --accelerator-type=v5litepod-4 --runtime-version=v2-tpuv5-litepod --spot + ``` + +**OPTION 2: Via the Cloud Console**: + +- In the Cloud console sidebar, navigate to "Create a VM>Compute Engine" and then to "TPUs" in the menu on the left + - On the first visit this will prompt you to "Enable" Compute Engine API" +- Click "Create TPU" +- Set fields as above + - N.B. The UI presents the option "Management>Preemptibility>Turn on pre-emptibility for this TPU node" + - This is actually a deprecated field and will result in a `Preemptibility is not available for the selected zone and type` error + - To request "Spot" TPUs see the command line method above + - Enable queuing by toggling the provided "Enable queuing" button +- E.g. ![Provisioning a TPU](provision_tpu.png) + + +## Preparing an "active" Cloud VM TPU + +- SSH into the TPU VM and port forward to your personal device + ``` + gcloud compute tpus tpu-vm ssh --zone --project -- -L 8081:localhost:8081 + ``` + - Where `` can be found in the Cloud Console, e.g. ![Cloud project](project.png) +- Transfer relevant files onto the VM (this is required because as per https://github.com/googlecolab/colabtools/issues/2533, mounting Google Drive or Cloud Buckets is not possible in local runtimes), e.g. to run a 30 step rollout of the 1 degree model + ``` + gcloud storage cp gs://dm_graphcast/gencast/dataset/source-era5_date-2019-03-29_res-1.0_levels-13_steps-30.nc . + gcloud storage cp "gs://dm_graphcast/gencast/params/GenCast 1p0deg <2019.npz" . + gcloud storage cp --recursive gs://dm_graphcast/gencast/stats/ . + ``` +- Install + - JAX, to ensure the proper runtime libraries are installed to access the TPUs + ``` + pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + ``` + - Jupyter, to launch a local runtime server + ``` + pip install jupyter + ``` + +### (Optional) Set up a bucket to store predictions + +- Now that the zone of the TPU is known, create a bucket, where + - `` is as preferred + - `` is the zone used for the TPU, using the names listed as per https://cloud.google.com/storage/docs/locations#available-locations, + - E,g, `us-south1-a` → `US-SOUTH-1` + ``` + gcloud storage buckets create gs:// --location + ``` +- Create a Cloud TPU service account + ``` + gcloud beta services identity create --service tpu.googleapis.com --project + ``` + - The above command returns a Cloud TPU Service Account with following format (we refer to this as ``): + ``` + service-@cloud-tpu.iam.gserviceaccount.com + ``` +- Grant writing permissions + ``` + gcloud storage buckets add-iam-policy-binding gs:// --member=serviceAccount: --role=roles/storage.objectCreator + ``` +- You can view the bucket in the Console UI by visiting Storage>Buckets in the sidebar + +## Connecting Notebook to a prepared Cloud VM TPU + +- While still SSHd to the TPU VM, start the Jupyter server + ``` + python3 -m notebook --NotebookApp.allow_origin='https://colab.research.google.com' --port=8081 --NotebookApp.port_retries=0 --no-browser + ``` + - This will generate a http://localhost… URL, e.g. http://localhost:8081/tree?token=bdb95985daa075099fa8103688120cca0a477326df24164f + ![local runtime url](local_runtime_url.png) +- Connect to the local runtime in gencast_demo_cloud_vm.ipynb with this URL + ![local runtime popup 1](local_runtime_popup_1.png) +- Happy forecasting! (Don't forget to delete the TPU VM once done) + ![local runtime popup 2](local_runtime_popup_2.png) + +## (Optional) Transferring data off of the TPU + +- Run the "Save the predictions" cell in the notebook to save the predictions to the TPU filesystem + ``` + # @title (Optional) Save the predictions. + predictions.to_zarr("predictions.zarr") + ``` +- In the command line copy these predictions to the bucket created as per the instructions above + ``` + gcloud storage cp --recursive predictions.zarr gs:/// + ``` + diff --git a/graphcast/docs/local_runtime_popup_1.png b/graphcast/docs/local_runtime_popup_1.png new file mode 100644 index 00000000..50a3a7b8 Binary files /dev/null and b/graphcast/docs/local_runtime_popup_1.png differ diff --git a/graphcast/docs/local_runtime_popup_2.png b/graphcast/docs/local_runtime_popup_2.png new file mode 100644 index 00000000..36a299fb Binary files /dev/null and b/graphcast/docs/local_runtime_popup_2.png differ diff --git a/graphcast/docs/local_runtime_url.png b/graphcast/docs/local_runtime_url.png new file mode 100644 index 00000000..5d5ffef7 Binary files /dev/null and b/graphcast/docs/local_runtime_url.png differ diff --git a/graphcast/docs/project.png b/graphcast/docs/project.png new file mode 100644 index 00000000..99ca5051 Binary files /dev/null and b/graphcast/docs/project.png differ diff --git a/graphcast/docs/provision_tpu.png b/graphcast/docs/provision_tpu.png new file mode 100644 index 00000000..cc825c40 Binary files /dev/null and b/graphcast/docs/provision_tpu.png differ diff --git a/graphcast/docs/tpu_types.png b/graphcast/docs/tpu_types.png new file mode 100644 index 00000000..954c20fd Binary files /dev/null and b/graphcast/docs/tpu_types.png differ diff --git a/graphcast/dpm_solver_plus_plus_2s.py b/graphcast/dpm_solver_plus_plus_2s.py new file mode 100644 index 00000000..b3685618 --- /dev/null +++ b/graphcast/dpm_solver_plus_plus_2s.py @@ -0,0 +1,187 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DPM-Solver++ 2S sampler from https://arxiv.org/abs/2211.01095.""" + +from typing import Optional + +from graphcast import casting +from graphcast import denoisers_base +from graphcast import samplers_base as base +from graphcast import samplers_utils as utils +from graphcast import xarray_jax +import haiku as hk +import jax.numpy as jnp +import xarray + + +class Sampler(base.Sampler): + """Sampling using DPM-Solver++ 2S from [1]. + + This is combined with optional stochastic churn as described in [2]. + + The '2S' terminology from [1] means that this is a second-order (2), + single-step (S) solver. Here 'single-step' here distinguishes it from + 'multi-step' methods where the results of function evaluations from previous + steps are reused in computing updates for subsequent steps. The solver still + uses multiple steps though. + + [1] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic + Models, https://arxiv.org/abs/2211.01095 + [2] Elucidating the Design Space of Diffusion-Based Generative Models, + https://arxiv.org/abs/2206.00364 + """ + + def __init__(self, + denoiser: denoisers_base.Denoiser, + max_noise_level: float, + min_noise_level: float, + num_noise_levels: int, + rho: float, + stochastic_churn_rate: float, + churn_min_noise_level: float, + churn_max_noise_level: float, + noise_level_inflation_factor: float + ): + """Initializes the sampler. + + Args: + denoiser: A Denoiser which predicts noise-free targets. + max_noise_level: The highest noise level used at the start of the + sequence of reverse diffusion steps. + min_noise_level: The lowest noise level used at the end of the sequence of + reverse diffusion steps. + num_noise_levels: Determines the number of noise levels used and hence the + number of reverse diffusion steps performed. + rho: Parameter affecting the spacing of noise steps. Higher values will + concentrate noise steps more around zero. + stochastic_churn_rate: S_churn from the paper. This controls the rate + at which noise is re-injected/'churned' during the sampling algorithm. + If this is set to zero then we are performing deterministic sampling + as described in Algorithm 1. + churn_min_noise_level: Minimum noise level at which stochastic churn + occurs. S_min from the paper. Only used if stochastic_churn_rate > 0. + churn_max_noise_level: Maximum noise level at which stochastic churn + occurs. S_min from the paper. Only used if stochastic_churn_rate > 0. + noise_level_inflation_factor: This can be used to set the actual amount of + noise injected higher than what the denoiser is told has been added. + The motivation is to compensate for a tendency of L2-trained denoisers + to remove slightly too much noise / blur too much. S_noise from the + paper. Only used if stochastic_churn_rate > 0. + """ + super().__init__(denoiser) + self._noise_levels = utils.noise_schedule( + max_noise_level, min_noise_level, num_noise_levels, rho) + self._stochastic_churn = stochastic_churn_rate > 0 + self._per_step_churn_rates = utils.stochastic_churn_rate_schedule( + self._noise_levels, stochastic_churn_rate, churn_min_noise_level, + churn_max_noise_level) + self._noise_level_inflation_factor = noise_level_inflation_factor + + def __call__( + self, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + **kwargs) -> xarray.Dataset: + + dtype = casting.infer_floating_dtype(targets_template) # pytype: disable=wrong-arg-types + noise_levels = jnp.array(self._noise_levels).astype(dtype) + per_step_churn_rates = jnp.array(self._per_step_churn_rates).astype(dtype) + + def denoiser(noise_level: jnp.ndarray, x: xarray.Dataset) -> xarray.Dataset: + """Computes D(x, sigma, y).""" + bcast_noise_level = xarray_jax.DataArray( + jnp.tile(noise_level, x.sizes['batch']), dims=('batch',)) + # Estimate the expectation of the fully-denoised target x0, conditional on + # inputs/forcings, noisy targets and their noise level: + return self._denoiser( + inputs=inputs, + noisy_targets=x, + noise_levels=bcast_noise_level, + forcings=forcings) + + def body_fn(i: jnp.ndarray, x: xarray.Dataset) -> xarray.Dataset: + """One iteration of the sampling algorithm. + + Args: + i: Sampling iteration. + x: Noisy targets at iteration i, these will have noise level + self._noise_levels[i]. + + Returns: + Noisy targets at the next lowest noise level self._noise_levels[i+1]. + """ + def init_noise(template): + return noise_levels[0] * utils.spherical_white_noise_like(template) + + # Initialise the inputs if i == 0. + # This is done here to ensure both noise sampler calls can use the same + # spherical harmonic basis functions. While there may be a small compute + # cost the memory savings can be significant. + # TODO(dominicmasters): Figure out if we can merge the two noise sampler + # calls into one to avoid this hack. + maybe_init_noise = (i == 0).astype(noise_levels[0].dtype) + x = x + init_noise(x) * maybe_init_noise + + noise_level = noise_levels[i] + + if self._stochastic_churn: + # We increase the noise level of x a bit before taking it down again: + x, noise_level = utils.apply_stochastic_churn( + x, noise_level, + stochastic_churn_rate=per_step_churn_rates[i], + noise_level_inflation_factor=self._noise_level_inflation_factor) + + # Apply one step of the ODE solver to take x down to the next lowest + # noise level. + + # Note that the Elucidating paper's choice of sigma(t)=t and s(t)=1 + # (corresponding to alpha(t)=1 in the DPM paper) as well as the standard + # choice of r=1/2 (corresponding to a geometric mean for the s_i + # midpoints) greatly simplifies the update from the DPM-Solver++ paper. + # You need to do a bit of algebraic fiddling to arrive at the below after + # substituting these choices into DPMSolver++'s Algorithm 1. The simpler + # update we arrive at helps with intuition too. + + next_noise_level = noise_levels[i + 1] + # This is s_{i+1} from the paper. They don't explain how the s_i are + # chosen, but the default choice seems to be a geometric mean, which is + # equivalent to setting all the r_i = 1/2. + mid_noise_level = jnp.sqrt(noise_level * next_noise_level) + + mid_over_current = mid_noise_level / noise_level + x_denoised = denoiser(noise_level, x) + # This turns out to be a convex combination of current and denoised x, + # which isn't entirely apparent from the paper formulae: + x_mid = mid_over_current * x + (1 - mid_over_current) * x_denoised + + next_over_current = next_noise_level / noise_level + x_mid_denoised = denoiser(mid_noise_level, x_mid) # pytype: disable=wrong-arg-types + x_next = next_over_current * x + (1 - next_over_current) * x_mid_denoised + + # For the final step to noise level 0, we do an Euler update which + # corresponds to just returning the denoiser's prediction directly. + # + # In fact the behaviour above when next_noise_level == 0 is almost + # equivalent, except that it runs the denoiser a second time to denoise + # from noise level 0. The denoiser should just be the identity function in + # this case, but it hasn't necessarily been trained at noise level 0 so + # we avoid relying on this. + return utils.tree_where(next_noise_level == 0, x_denoised, x_next) + + # Init with zeros but apply additional noise at step 0 to initialise the + # state. + noise_init = xarray.zeros_like(targets_template) + return hk.fori_loop( + 0, len(noise_levels) - 1, body_fun=body_fn, init_val=noise_init) diff --git a/graphcast/gencast.py b/graphcast/gencast.py new file mode 100644 index 00000000..1454de76 --- /dev/null +++ b/graphcast/gencast.py @@ -0,0 +1,284 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Denoising diffusion models based on the framework of [1]. + +Throughout we will refer to notation and equations from [1]. + + [1] Elucidating the Design Space of Diffusion-Based Generative Models + Karras, Aittala, Aila and Laine, 2022 + https://arxiv.org/abs/2206.00364 +""" + +from typing import Any, Optional, Tuple + +import chex +from graphcast import casting +from graphcast import denoiser +from graphcast import dpm_solver_plus_plus_2s +from graphcast import graphcast +from graphcast import losses +from graphcast import predictor_base +from graphcast import samplers_utils +from graphcast import xarray_jax +import haiku as hk +import jax +import xarray + + +TARGET_SURFACE_VARS = ( + '2m_temperature', + 'mean_sea_level_pressure', + '10m_v_component_of_wind', + '10m_u_component_of_wind', # GenCast predicts in 12hr timesteps. + 'total_precipitation_12hr', + 'sea_surface_temperature', +) + +TARGET_SURFACE_NO_PRECIP_VARS = ( + '2m_temperature', + 'mean_sea_level_pressure', + '10m_v_component_of_wind', + '10m_u_component_of_wind', + 'sea_surface_temperature', +) + + +TASK = graphcast.TaskConfig( + input_variables=( + # GenCast doesn't take precipitation as input. + TARGET_SURFACE_NO_PRECIP_VARS + + graphcast.TARGET_ATMOSPHERIC_VARS + + graphcast.GENERATED_FORCING_VARS + + graphcast.STATIC_VARS + ), + target_variables=TARGET_SURFACE_VARS + graphcast.TARGET_ATMOSPHERIC_VARS, + # GenCast doesn't take incident solar radiation as a forcing. + forcing_variables=graphcast.GENERATED_FORCING_VARS, + pressure_levels=graphcast.PRESSURE_LEVELS_WEATHERBENCH_13, + # GenCast takes the current frame and the frame 12 hours prior. + input_duration='24h', +) + + +@chex.dataclass(frozen=True, eq=True) +class SamplerConfig: + """Configures the sampler used to draw samples from GenCast. + + max_noise_level: The highest noise level used at the start of the + sequence of reverse diffusion steps. + min_noise_level: The lowest noise level used at the end of the sequence of + reverse diffusion steps. + num_noise_levels: Determines the number of noise levels used and hence the + number of reverse diffusion steps performed. + rho: Parameter affecting the spacing of noise steps. Higher values will + concentrate noise steps more around zero. + stochastic_churn_rate: S_churn from the paper. This controls the rate + at which noise is re-injected/'churned' during the sampling algorithm. + If this is set to zero then we are performing deterministic sampling + as described in Algorithm 1. + churn_max_noise_level: Maximum noise level at which stochastic churn + occurs. S_min from the paper. Only used if stochastic_churn_rate > 0. + churn_min_noise_level: Minimum noise level at which stochastic churn + occurs. S_min from the paper. Only used if stochastic_churn_rate > 0. + noise_level_inflation_factor: This can be used to set the actual amount of + noise injected higher than what the denoiser is told has been added. + The motivation is to compensate for a tendency of L2-trained denoisers + to remove slightly too much noise / blur too much. S_noise from the + paper. Only used if stochastic_churn_rate > 0. + """ + max_noise_level: float = 80. + min_noise_level: float = 0.03 + num_noise_levels: int = 20 + rho: float = 7. + # Stochastic sampler settings. + stochastic_churn_rate: float = 2.5 + churn_min_noise_level: float = 0.75 + churn_max_noise_level: float = float('inf') + noise_level_inflation_factor: float = 1.05 + + +@chex.dataclass(frozen=True, eq=True) +class NoiseConfig: + training_noise_level_rho: float = 7.0 + training_max_noise_level: float = 88.0 + training_min_noise_level: float = 0.02 + + +@chex.dataclass(frozen=True, eq=True) +class CheckPoint: + description: str + license: str + params: dict[str, Any] + task_config: graphcast.TaskConfig + denoiser_architecture_config: denoiser.DenoiserArchitectureConfig + sampler_config: SamplerConfig + noise_config: NoiseConfig + noise_encoder_config: denoiser.NoiseEncoderConfig + + +class GenCast(predictor_base.Predictor): + """Predictor for a denoising diffusion model following the framework of [1]. + + [1] Elucidating the Design Space of Diffusion-Based Generative Models + Karras, Aittala, Aila and Laine, 2022 + https://arxiv.org/abs/2206.00364 + + Unlike the paper, we have a conditional model and our denoising function + conditions on previous timesteps. + + As the paper demonstrates, the sampling algorithm can be varied independently + of the denoising model and its training procedure, and it is separately + configurable here. + """ + + def __init__( + self, + task_config: graphcast.TaskConfig, + denoiser_architecture_config: denoiser.DenoiserArchitectureConfig, + sampler_config: Optional[SamplerConfig] = None, + noise_config: Optional[NoiseConfig] = None, + noise_encoder_config: Optional[denoiser.NoiseEncoderConfig] = None, + ): + """Constructs GenCast.""" + # Output size depends on number of variables being predicted. + num_surface_vars = len( + set(task_config.target_variables) + - set(graphcast.ALL_ATMOSPHERIC_VARS) + ) + num_atmospheric_vars = len( + set(task_config.target_variables) + & set(graphcast.ALL_ATMOSPHERIC_VARS) + ) + num_outputs = ( + num_surface_vars + + len(task_config.pressure_levels) * num_atmospheric_vars + ) + denoiser_architecture_config.node_output_size = num_outputs + self._denoiser = denoiser.Denoiser( + noise_encoder_config, + denoiser_architecture_config, + ) + self._sampler_config = sampler_config + # Singleton to avoid re-initializing the sampler for each inference call. + self._sampler = None + self._noise_config = noise_config + + def _c_in(self, noise_scale: xarray.DataArray) -> xarray.DataArray: + """Scaling applied to the noisy targets input to the underlying network.""" + return (noise_scale**2 + 1)**-0.5 + + def _c_out(self, noise_scale: xarray.DataArray) -> xarray.DataArray: + """Scaling applied to the underlying network's raw outputs.""" + return noise_scale * (noise_scale**2 + 1)**-0.5 + + def _c_skip(self, noise_scale: xarray.DataArray) -> xarray.DataArray: + """Scaling applied to the skip connection.""" + return 1 / (noise_scale**2 + 1) + + def _loss_weighting(self, noise_scale: xarray.DataArray) -> xarray.DataArray: + r"""The loss weighting \lambda(\sigma) from the paper.""" + return self._c_out(noise_scale) ** -2 + + def _preconditioned_denoiser( + self, + inputs: xarray.Dataset, + noisy_targets: xarray.Dataset, + noise_levels: xarray.DataArray, + forcings: Optional[xarray.Dataset] = None, + **kwargs) -> xarray.Dataset: + """The preconditioned denoising function D from the paper (Eqn 7).""" + raw_predictions = self._denoiser( + inputs=inputs, + noisy_targets=noisy_targets * self._c_in(noise_levels), + noise_levels=noise_levels, + forcings=forcings, + **kwargs) + return (raw_predictions * self._c_out(noise_levels) + + noisy_targets * self._c_skip(noise_levels)) + + def loss_and_predictions( + self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + ) -> Tuple[predictor_base.LossAndDiagnostics, xarray.Dataset]: + return self.loss(inputs, targets, forcings), self(inputs, targets, forcings) + + def loss(self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + ) -> predictor_base.LossAndDiagnostics: + + if self._noise_config is None: + raise ValueError('Noise config must be specified to train GenCast.') + + # Sample noise levels: + dtype = casting.infer_floating_dtype(targets) # pytype: disable=wrong-arg-types + key = hk.next_rng_key() + batch_size = inputs.sizes['batch'] + noise_levels = xarray_jax.DataArray( + data=samplers_utils.rho_inverse_cdf( + min_value=self._noise_config.training_min_noise_level, + max_value=self._noise_config.training_max_noise_level, + rho=self._noise_config.training_noise_level_rho, + cdf=jax.random.uniform(key, shape=(batch_size,), dtype=dtype)), + dims=('batch',)) + + # Sample noise and apply it to targets: + noise = ( + samplers_utils.spherical_white_noise_like(targets) * noise_levels + ) + noisy_targets = targets + noise + + denoised_predictions = self._preconditioned_denoiser( + inputs, noisy_targets, noise_levels, forcings) + + loss, diagnostics = losses.weighted_mse_per_level( + denoised_predictions, + targets, + # Weights are same as we used for GraphCast. + per_variable_weights={ + # Any variables not specified here are weighted as 1.0. + # A single-level variable, but an important headline variable + # and also one which we have struggled to get good performance + # on at short lead times, so leaving it weighted at 1.0, equal + # to the multi-level variables: + '2m_temperature': 1.0, + # New single-level variables, which we don't weight too highly + # to avoid hurting performance on other variables. + '10m_u_component_of_wind': 0.1, + '10m_v_component_of_wind': 0.1, + 'mean_sea_level_pressure': 0.1, + 'sea_surface_temperature': 0.1, + 'total_precipitation_12hr': 0.1 + }, + ) + loss *= self._loss_weighting(noise_levels) + return loss, diagnostics + + def __call__(self, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + **kwargs) -> xarray.Dataset: + if self._sampler_config is None: + raise ValueError( + 'Sampler config must be specified to run inference on GenCast.' + ) + if self._sampler is None: + self._sampler = dpm_solver_plus_plus_2s.Sampler( + self._preconditioned_denoiser, **self._sampler_config + ) + return self._sampler(inputs, targets_template, forcings, **kwargs) diff --git a/graphcast/icosahedral_mesh.py b/graphcast/icosahedral_mesh.py index 4c43642b..62837834 100644 --- a/graphcast/icosahedral_mesh.py +++ b/graphcast/icosahedral_mesh.py @@ -279,3 +279,7 @@ def faces_to_edges(faces: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: senders = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]]) receivers = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]]) return senders, receivers + + +def get_last_triangular_mesh_for_sphere(splits: int) -> TriangularMesh: + return get_hierarchy_of_triangular_meshes_for_sphere(splits=splits)[-1] diff --git a/graphcast/mlp.py b/graphcast/mlp.py new file mode 100644 index 00000000..a60eb661 --- /dev/null +++ b/graphcast/mlp.py @@ -0,0 +1,45 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Constructors for MLPs.""" + +import haiku as hk +import jax +import jax.numpy as jnp + + +# TODO(aelkadi): Move the mlp factory here from `deep_typed_graph_net.py`. + + +class LinearNormConditioning(hk.Module): + """Module for norm conditioning. + + Conditions the normalization of "inputs" by applying a linear layer to the + "norm_conditioning" which produces the scale and variance which are applied to + each channel (across the last dim) of "inputs". + """ + + def __init__(self, name="norm_conditioning"): + super().__init__(name=name) + + def __call__(self, inputs: jax.Array, norm_conditioning: jax.Array): + + feature_size = inputs.shape[-1] + conditional_linear_layer = hk.Linear( + output_size=2 * feature_size, + w_init=hk.initializers.TruncatedNormal(stddev=1e-8), + ) + conditional_scale_offset = conditional_linear_layer(norm_conditioning) + scale_minus_one, offset = jnp.split(conditional_scale_offset, 2, axis=-1) + scale = scale_minus_one + 1. + return inputs * scale + offset diff --git a/graphcast/model_utils.py b/graphcast/model_utils.py index 949088c0..11209f11 100644 --- a/graphcast/model_utils.py +++ b/graphcast/model_utils.py @@ -15,6 +15,7 @@ from typing import Mapping, Optional, Tuple +import jax.numpy as jnp import numpy as np from scipy.spatial import transform import xarray @@ -722,3 +723,36 @@ def stacked_to_dataset( name=template_var.name, ) return type(template_dataset)(data_vars) # pytype:disable=not-callable,wrong-arg-count + + +def fourier_features( + values: jnp.ndarray, + base_period: float, + num_frequencies: int, + ) -> jnp.ndarray: + """Maps values to sin/cos features for a range of frequencies. + + Args: + values: Values to compute Fourier features for. + base_period: The base period to use. This should be greater or equal to the + range of the values, or to the period if the values have periodic + semantics (e.g. 2pi if they represent angles). Frequencies used will be + integer multiples of 1/base_period. + num_frequencies: The number of frequencies to use, we will use integer + multiples of 1/base_period from 1 up to num_frequencies inclusive. (We + don't include a zero frequency as this would just give constant features + which are redundant if a bias term is present). + + Returns: + Array with same shape as values except with an extra trailing dimension + of size 2*num_frequencies, which contains a sin and a cos feature for each + frequency. + """ + frequencies = np.arange(1, num_frequencies + 1) / base_period + angular_frequencies = jnp.array(2 * np.pi * frequencies, dtype=values.dtype) + values_times_angular_freqs = values[..., None] * angular_frequencies + return jnp.concatenate( + [jnp.cos(values_times_angular_freqs), + jnp.sin(values_times_angular_freqs)], + axis=-1) + diff --git a/graphcast/nan_cleaning.py b/graphcast/nan_cleaning.py new file mode 100644 index 00000000..ced9e9fc --- /dev/null +++ b/graphcast/nan_cleaning.py @@ -0,0 +1,125 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Wrappers for Predictors which allow them to work with data cleaned of NaNs. + +The Predictor which is wrapped sees inputs and targets without NaNs, and makes +NaNless predictions. +""" + +from typing import Optional, Tuple + +from graphcast import predictor_base as base +import numpy as np +import xarray + + +class NaNCleaner(base.Predictor): + """A predictor wrapper than removes NaNs from ingested data. + + The Predictor which is wrapped sees inputs and targets without NaNs. + """ + + def __init__( + self, + predictor: base.Predictor, + var_to_clean: str, + fill_value: xarray.Dataset, + reintroduce_nans: bool = False, + ): + """Initializes the NaNCleaner.""" + self._predictor = predictor + self._fill_value = fill_value[var_to_clean] + self._var_to_clean = var_to_clean + self._reintroduce_nans = reintroduce_nans + + def _clean(self, dataset: xarray.Dataset) -> xarray.Dataset: + """Cleans the dataset of NaNs.""" + data_array = dataset[self._var_to_clean] + dataset = dataset.assign( + {self._var_to_clean: data_array.fillna(self._fill_value)} + ) + return dataset + + def _maybe_reintroduce_nans( + self, stale_inputs: xarray.Dataset, predictions: xarray.Dataset + ) -> xarray.Dataset: + # NaN positions don't change between input frames, if they do then + # we should be more careful about re-introducing them. + if self._var_to_clean in predictions.keys(): + nan_mask = np.isnan(stale_inputs[self._var_to_clean]).any(dim='time') + with_nan_values = predictions[self._var_to_clean].where(~nan_mask, np.nan) + predictions = predictions.assign({self._var_to_clean: with_nan_values}) + return predictions + + def __call__( + self, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + **kwargs, + ) -> xarray.Dataset: + if self._reintroduce_nans: + # Copy inputs before cleaning so that we can reintroduce NaNs later. + original_inputs = inputs.copy() + if self._var_to_clean in inputs.keys(): + inputs = self._clean(inputs) + if forcings and self._var_to_clean in forcings.keys(): + forcings = self._clean(forcings) + predictions = self._predictor( + inputs, targets_template, forcings, **kwargs + ) + if self._reintroduce_nans: + predictions = self._maybe_reintroduce_nans(original_inputs, predictions) + return predictions + + def loss( + self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + **kwargs, + ) -> base.LossAndDiagnostics: + if self._var_to_clean in inputs.keys(): + inputs = self._clean(inputs) + if self._var_to_clean in targets.keys(): + targets = self._clean(targets) + if forcings and self._var_to_clean in forcings.keys(): + forcings = self._clean(forcings) + return self._predictor.loss( + inputs, targets, forcings, **kwargs + ) + + def loss_and_predictions( + self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + **kwargs, + ) -> Tuple[base.LossAndDiagnostics, xarray.Dataset]: + if self._reintroduce_nans: + # Copy inputs before cleaning so that we can reintroduce NaNs later. + original_inputs = inputs.copy() + if self._var_to_clean in inputs.keys(): + inputs = self._clean(inputs) + if self._var_to_clean in targets.keys(): + targets = self._clean(targets) + if forcings and self._var_to_clean in forcings.keys(): + forcings = self._clean(forcings) + + loss, predictions = self._predictor.loss_and_predictions( + inputs, targets, forcings, **kwargs + ) + if self._reintroduce_nans: + predictions = self._maybe_reintroduce_nans(original_inputs, predictions) + return loss, predictions diff --git a/graphcast/rollout.py b/graphcast/rollout.py index b243d0fe..d10bc1c0 100644 --- a/graphcast/rollout.py +++ b/graphcast/rollout.py @@ -13,11 +13,12 @@ # limitations under the License. """Utils for rolling out models.""" -from typing import Iterator +from typing import Iterator, Optional, Sequence from absl import logging import chex import dask.array +from graphcast import xarray_jax from graphcast import xarray_tree import jax import numpy as np @@ -37,6 +38,170 @@ def __call__( ... +def _replicate_dataset( + data: xarray.Dataset, replica_dim: str, + replicate_to_device: bool, + devices: Sequence[jax.Device], + ) -> xarray.Dataset: + """Used to prepare for xarray_jax.pmap.""" + + def replicate_variable(variable: xarray.Variable) -> xarray.Variable: + if replica_dim in variable.dims: + # TODO(pricei): call device_put_replicated when replicate_to_device==True + return variable.transpose(replica_dim, ...) + else: + data = len(devices) * [variable.data] + if replicate_to_device: + assert devices is not None + # TODO(pricei): Refactor code to use "device_put_replicated" instead of + # device_put_sharded. + data = jax.device_put_sharded(data, devices) + else: + data = np.stack(data, axis=0) + return xarray_jax.Variable( + data=data, dims=(replica_dim,) + variable.dims, attrs=variable.attrs + ) + + def replicate_dataset(dataset: xarray.Dataset) -> xarray.Dataset: + if dataset is None: + return None + data_variables = { + name: replicate_variable(var) + for name, var in dataset.data_vars.variables.items() + } + coords = {name: coord.variable for name, coord in dataset.coords.items()} + return xarray.Dataset(data_variables, coords=coords, attrs=dataset.attrs) + + return replicate_dataset(data) + + +def chunked_prediction_generator_multiple_runs( + predictor_fn: PredictorFn, + rngs: chex.PRNGKey, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: Optional[xarray.Dataset], + num_samples: Optional[int], + pmap_devices: Optional[Sequence[jax.Device]] = None, + **chunked_prediction_kwargs, +) -> Iterator[xarray.Dataset]: + """Outputs a trajectory of multiple samples by yielding chunked predictions. + + Args: + predictor_fn: Function to use to make predictions for each chunk. + rngs: RNG sequence to be used for each ensemble member. + inputs: Inputs for the model. + targets_template: Template for the target prediction, requires targets + equispaced in time. + forcings: Optional forcing for the model. + num_samples: The number of runs / samples to rollout. + pmap_devices: List of devices over which predictor_fn is pmapped, or None if + it is not pmapped. + **chunked_prediction_kwargs: + See chunked_prediction, some of these are required arguments. + + Yields: + The predictions for each chunked step of the chunked rollout, such that + if all predictions are concatenated in time and sample dimension squeezed, + this would match the targets template in structure. + + """ + if pmap_devices is not None: + assert ( + num_samples % jax.device_count() == 0 + ), "num_samples must be a multiple of jax.device_count()" + + def predictor_fn_pmap_named_args(rng, inputs, targets_template, forcings): + targets_template = _replicate_dataset( + targets_template, + replica_dim="sample", + replicate_to_device=True, + devices=pmap_devices, + ) + return predictor_fn(rng, inputs, targets_template, forcings) + + for i in range(0, num_samples, jax.device_count()): + sample_idx = slice(i, i + jax.device_count()) + logging.info("Samples %s out of %s", sample_idx, num_samples) + logging.flush() + sample_group_rngs = rngs[sample_idx] + + if "sample" not in inputs.dims: + sample_inputs = inputs + else: + sample_inputs = inputs.isel(sample=sample_idx, drop=True) + + sample_inputs = _replicate_dataset( + sample_inputs, + replica_dim="sample", + replicate_to_device=True, + devices=pmap_devices, + ) + + if forcings is not None: + if "sample" not in forcings.dims: + sample_forcings = forcings + else: + sample_forcings = forcings.isel(sample=sample_idx, drop=True) + + # TODO(pricei): We are replicating the full forcings for all rollout + # timesteps here, rather than inside `predictor_fn_pmap_named_args` like + # the targets_template above, because the forcings are concatenated with + # the inputs which will already be replicated. We should refactor this + # so that chunked prediction is aware of whether it is being run with + # pmap, and if so do the replication and device_put only of the + # necessary timesteps, as part of the chunked prediction function. + sample_forcings = _replicate_dataset( + sample_forcings, + replica_dim="sample", + replicate_to_device=False, + devices=pmap_devices, + ) + else: + sample_forcings = None + + for prediction_chunk in chunked_prediction_generator( + predictor_fn=predictor_fn_pmap_named_args, + rng=sample_group_rngs, + inputs=sample_inputs, + targets_template=targets_template, + forcings=sample_forcings, + pmap_devices=pmap_devices, + **chunked_prediction_kwargs, + ): + prediction_chunk.coords["sample"] = np.arange( + sample_idx.start, sample_idx.stop, sample_idx.step + ) + yield prediction_chunk + del prediction_chunk + else: + for i in range(num_samples): + logging.info("Sample %d/%d", i, num_samples) + logging.flush() + this_sample_rng = rngs[i] + + if "sample" in inputs.dims: + sample_inputs = inputs.isel(sample=i, drop=True) + else: + sample_inputs = inputs + + sample_forcings = forcings + if sample_forcings is not None: + if "sample" in sample_forcings.dims: + sample_forcings = sample_forcings.isel(sample=i, drop=True) + + for prediction_chunk in chunked_prediction_generator( + predictor_fn=predictor_fn, + rng=this_sample_rng, + inputs=sample_inputs, + targets_template=targets_template, + forcings=sample_forcings, + **chunked_prediction_kwargs): + prediction_chunk.coords["sample"] = i + yield prediction_chunk + del prediction_chunk + + def chunked_prediction( predictor_fn: PredictorFn, rng: chex.PRNGKey, @@ -85,6 +250,7 @@ def chunked_prediction_generator( forcings: xarray.Dataset, num_steps_per_chunk: int = 1, verbose: bool = False, + pmap_devices: Optional[Sequence[jax.Device]] = None ) -> Iterator[xarray.Dataset]: """Outputs a long trajectory by yielding chunked predictions. @@ -99,6 +265,8 @@ def chunked_prediction_generator( at each call of `predictor_fn`. It must evenly divide the number of steps in `targets_template`. verbose: Whether to log the current chunk being predicted. + pmap_devices: List of devices over which predictor_fn is pmapped, or None if + it is not pmapped. Yields: The predictions for each chunked step of the chunked rollout, such as @@ -140,6 +308,19 @@ def chunked_prediction_generator( time=slice(0, num_steps_per_chunk)) current_inputs = inputs + + def split_rng_fn(rng): + # Note, this is *not* equivalent to `return jax.random.split(rng)`, because + # by assigning to a tuple, the single numpy array returned by + # `jax.random.split` actually gets split into two arrays, so when calling + # the function with pmap the output is Tuple[Array, Array], where the + # leading axis of each array is `num devices`. + rng1, rng2 = jax.random.split(rng) + return rng1, rng2 + + if pmap_devices is not None: + split_rng_fn = jax.pmap(split_rng_fn, devices=pmap_devices) + for chunk_index in range(num_chunks): if verbose: logging.info("Chunk %d/%d", chunk_index, num_chunks) @@ -160,13 +341,24 @@ def chunked_prediction_generator( current_forcings = current_forcings.assign_coords(time=targets_chunk_time) current_forcings = current_forcings.compute() # Make predictions for the chunk. - rng, this_rng = jax.random.split(rng) + rng, this_rng = split_rng_fn(rng) predictions = predictor_fn( rng=this_rng, inputs=current_inputs, targets_template=current_targets_template, forcings=current_forcings) + # In the pmapped case, profiling reveals that the predictions, forcings and + # inputs are all copied onto a single TPU, causing OOM. To avoid this + # we pull all of the input/output data off the devices. This will have + # some performance impact, but maximise the memory efficiency. + # TODO(aelkadi): Pmap `_get_next_inputs` when running under pmap, and + # remove the device_get. + if pmap_devices is not None: + predictions = jax.device_get(predictions) + current_forcings = jax.device_get(current_forcings) + current_inputs = jax.device_get(current_inputs) + next_frame = xarray.merge([predictions, current_forcings]) next_inputs = _get_next_inputs(current_inputs, next_frame) diff --git a/graphcast/samplers_base.py b/graphcast/samplers_base.py new file mode 100644 index 00000000..6d60aa83 --- /dev/null +++ b/graphcast/samplers_base.py @@ -0,0 +1,47 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base class for diffusion samplers.""" + +import abc +from typing import Optional + +from graphcast import denoisers_base +import xarray + + +class Sampler(abc.ABC): + """A sampling algorithm for a denoising diffusion model. + + This is constructed with a denoising function, and uses it to draw samples. + """ + + _denoiser: denoisers_base.Denoiser + + def __init__(self, denoiser: denoisers_base.Denoiser): + """Constructs Sampler. + + Args: + denoiser: A Denoiser which has been trained with an MSE loss to predict + the noise-free targets. + """ + self._denoiser = denoiser + + @abc.abstractmethod + def __call__( + self, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + **kwargs) -> xarray.Dataset: + """Draws a sample using self._denoiser. Contract like Predictor.__call__.""" diff --git a/graphcast/samplers_utils.py b/graphcast/samplers_utils.py new file mode 100644 index 00000000..7b0125d1 --- /dev/null +++ b/graphcast/samplers_utils.py @@ -0,0 +1,431 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils for diffusion samplers. Makes use of dinosaur.spherical_harmonic.""" + +import dataclasses +import functools +from typing import Any, cast, Optional, Tuple + +import chex +from dinosaur import spherical_harmonic +from graphcast import xarray_jax +from graphcast import xarray_tree +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import xarray + +# Some useful constants useful when dealing with Earth's geometry. +# The earth isn't really a sphere so these are only approximate, this is the +# average radius according to https://en.wikipedia.org/wiki/Earth_radius, +# with the actual value varying from 6378 to 6357km. +EARTH_RADIUS_KM = 6371. +# And this is also approximate, but we've chosen to make it consistent with the +# radius above when modelling the earth as a sphere. This gives a value of +# around 40030; the actual value varies from 40008 to 40075. +EARTH_CIRCUMFERENCE_KM = EARTH_RADIUS_KM * 2 * np.pi + + +@dataclasses.dataclass(frozen=True) +class _ArrayGrid: + """A class that performs operations and transformations in the spectral basis. + + Attributes: + longitude_wavenumbers: num of longitude wavenumbers in the spectral basis. + total_wavenumbers: number of total wavenumbers in the spectral basis. + longitude_nodes: number of quadrature nodes along the lon direction. + latitude_nodes: number of quadrature nodes along the lat direction. + latitude_spacing: either 'gauss' or 'equiangular'. This determines the + spacing of nodal grid points in the latitudinal (north-south) direction. + """ + longitude_wavenumbers: int + total_wavenumbers: int + longitude_nodes: int + latitude_nodes: int + latitude_spacing: str + + @classmethod + def with_lat_lon( + cls, + lat: np.ndarray, + lon: np.ndarray, + ) -> '_ArrayGrid': + """_ArrayGrid for use with data in specified lat/lon grid (in degrees).""" + + latitude_nodes = lat.shape[0] + longitude_nodes = lon.shape[0] + latitude_spacing = _infer_latitude_spacing(lat) + if latitude_spacing in ['equiangular', 'gauss']: + if longitude_nodes != 2 * latitude_nodes: + # Technically not a requirement but useful to ensure `max_wavenumber` + # makes sense. + raise ValueError( + 'Unexpected number of longitude nodes. ' + f'Expected {2 * latitude_nodes}, got {longitude_nodes}') + elif latitude_spacing == 'equiangular_with_poles': + if longitude_nodes != 2 * (latitude_nodes - 1): + # Technically not a requirement but useful to ensure `max_wavenumber` + # makes s + raise ValueError( + 'Unexpected number of longitude nodes. ' + f'Expected {2 * (latitude_nodes - 1)}, got {longitude_nodes}') + else: + raise ValueError(f'Unexpected latitude_spacing={latitude_spacing}') + max_wavenumber = int(longitude_nodes // 2) - 1 + grid = cls( + longitude_wavenumbers=max_wavenumber+1, + # total_wavenumbers should be one larger than max_wavenumber as the + # wavenumbers go from 0 to max_wavenumber inclusive. + total_wavenumbers=max_wavenumber+1, + longitude_nodes=longitude_nodes, + latitude_nodes=latitude_nodes, + latitude_spacing=latitude_spacing, + ) + _verify_nodal_axes(lat, lon, grid.nodal_axes) + return grid + + @functools.cached_property + def _grid(self) -> spherical_harmonic.Grid: + return spherical_harmonic.Grid( + spherical_harmonics_impl=spherical_harmonic.RealSphericalHarmonics, + **dataclasses.asdict(self), + ) + + @functools.cached_property + def nodal_axes(self) -> Tuple[np.ndarray, np.ndarray]: + """Longitude and sin(latitude) coordinates of the nodal basis.""" + return self._grid.nodal_axes + + @functools.cached_property + def modal_axes(self) -> Tuple[np.ndarray, np.ndarray]: + """Longitudinal and total wavenumbers (m, l) of the modal basis.""" + return self._grid.modal_axes + + def to_nodal(self, x: chex.Array) -> chex.Array: + """Maps `x` from a modal to nodal representation.""" + return self._grid.to_nodal(x) + + +def _infer_latitude_spacing(lat: np.ndarray) -> str: + """Infers the type of latitude spacing given the latitude.""" + if not np.all(np.diff(lat) > 0.): + raise ValueError('Latitude values are expected to be sorted.') + + if np.allclose(np.diff(lat), lat[1] - lat[0]): + if np.isclose(max(lat), 90.): + spacing = 'equiangular_with_poles' + else: + spacing = 'equiangular' + else: + spacing = 'gauss' + return spacing + + +def _verify_nodal_axes(lat_coords: np.ndarray, lon_coords: np.ndarray, + nodal_axes: Tuple[np.ndarray, np.ndarray]): + nodal_axes_lon, nodal_axes_sin_lat = nodal_axes + if not np.allclose(nodal_axes_sin_lat, np.sin(np.deg2rad(lat_coords))): + raise ValueError( + "Latitude coords don't match those used by " + "spherical_harmonic.SphericalHarmonicBasis.") + if not np.allclose(nodal_axes_lon, np.deg2rad(lon_coords)): + raise ValueError( + "Longitude coords don't match those used by " + "spherical_harmonic.SphericalHarmonicBasis.") + + +class Grid: + """xarray wrapper around _ArrayGrid.""" + + @classmethod + def for_nodal_data( + cls, + nodal_data: xarray.DataArray, + ) -> 'Grid': + """A Grid for use with a given shape of nodal (lat/lon grid) data. + + This uses the maximum number of spherical harmonics that the grid is able + to resolve. + + This class supports data arrays with latitude spacings as defined by + "dinosaur.spherical_harmonic". In summary: + * 'equiangular': equally spaced (by `d_lat`) values between -90 + d_lat and + 90 - d_lat / 2. In our case, longitude must also be spaced by `d_lat`. + * 'equiangular_with_poles': equally spaced (by `d_lat`) values between -90 + and 90. In our case, longitude must also be spaced by `d_lat`. + * 'gauss': Gauss-Legendre nodes. + + Args: + nodal_data: An xarray with 'lat' and 'lon' dimensions and coordinates in + degrees. + + Returns: + A grid with the specified latitude_nodes, with + longitude_nodes=2*latitude_nodes and max_wavenumber=latitude_nodes-1. + """ + + grid = _ArrayGrid.with_lat_lon( + nodal_data.coords['lat'].data, + nodal_data.coords['lon'].data) + return cls(grid, + nodal_data.coords['lat'].data, + nodal_data.coords['lon'].data) + + def __init__(self, + grid: _ArrayGrid, + lat_coords: np.ndarray, + lon_coords: np.ndarray): + _verify_nodal_axes(lat_coords, lon_coords, grid.nodal_axes) + self._underlying = grid + # Record the exact original lat/lon coords so we can return them exactly + # from an inverse transform, avoiding any xarray merge issues if coordinates + # are off by a rounding error. + self._lat_coords = lat_coords + self._lon_coords = lon_coords + self._longitude_wavenumber_coords, self._total_wavenumber_coords = ( + grid.modal_axes) + + @property + def total_wavenumber_coords(self) -> xarray.DataArray: + """Coords that must be used for 'total_wavenumber' dimension.""" + return xarray.DataArray( + data=self._total_wavenumber_coords, + dims=('total_wavenumber',), + coords={'total_wavenumber': self._total_wavenumber_coords}) + + @property + def longitude_wavenumber_coords(self) -> xarray.DataArray: + """Coords that must be used for 'longitude_wavenumber' dimension.""" + return xarray.DataArray( + data=self._longitude_wavenumber_coords, + dims=('longitude_wavenumber',), + coords={'longitude_wavenumber': self._longitude_wavenumber_coords}) + + def to_nodal( + self, modal_data: xarray.DataArray) -> xarray.DataArray: + """Applies the inverse spherical harmonic transform. + + Args: + modal_data: A tree of xarray.DataArray with 'longitude_wavenumber' and + 'total_wavenumber' dimensions with coords + `self.longitude_wavenumber_coords` and `self.total_wavenumber_coords` + respectively, and with the same sparsity pattern described under + `to_modal`. + + Returns: + Corresponding tree where the 'longitude_wavenumber' and + 'total_wavenumber' dimensions are replaced by 'lat', 'lon' dimensions. + """ + def inverse_transform(modal: xarray.DataArray) -> xarray.DataArray: + if (not np.all(modal.coords['longitude_wavenumber'] == + self._longitude_wavenumber_coords) or + not np.all(modal.coords['total_wavenumber'] == + self._total_wavenumber_coords)): + raise ValueError('Wavenumber coords don\'t follow required convention.') + + return xarray_jax.apply_ufunc( + self._underlying.to_nodal, modal, + input_core_dims=[['longitude_wavenumber', 'total_wavenumber']], + output_core_dims=[['lon', 'lat']], + ).assign_coords( + lon=self._lon_coords, + lat=self._lat_coords, + ) + + return xarray_tree.map_structure(inverse_transform, modal_data) + + +def sample( + key: jnp.ndarray, + power_spectrum: xarray.DataArray, + template: xarray.DataArray, + grid: Optional[Grid] = None, + ) -> xarray.DataArray: + """Samples Gaussian Process noise on a sphere, with a given power spectrum. + + This means the noise will have the given power spectrum *in expectation*; the + power spectrum of individual samples may vary. + + The noise will be isotropic, meaning the distribution is invariant to + rotations of the sphere. + + The marginal variance of the returned values will be equal to the total power, + i.e. the sum of power_spectrum. So if you want unit marginal variance, just + make sure to normalize the power_spectrum to sum to 1. + + Args: + key: JAX rng key. + power_spectrum: An array with shape (total_wavenumber,) giving the power + which is desired at each total wavenumber (corresponding to a wavelength + EARTH_CIRCUMFERENCE/total_wavenumber) for total wavenumbers 0 up to some + maximum. This is in squared units of the quantity being sampled. + template: An array with the shape that you want the samples in, containing + 'lat' and 'lon' dimensions. If other dimensions are present, we draw + multiple independent samples along these other dimensions. + grid: spherical_harmonic.Grid on which to sample the noise. If not specified + a grid will be created based on `template`, however note you may save some + RAM and compute by re-using a single Grid instance across multiple calls. + + Returns: + DataArray with the same shape as template. + """ + if grid is None: + grid = Grid.for_nodal_data(template) + dims = [d for d in template.dims if d not in ('lat', 'lon')] + shape = [template.sizes[d] for d in dims] + coords = {name: coord for name, coord in template.coords.items() + if name not in ('lat', 'lon')} + dims.extend(('total_wavenumber', 'longitude_wavenumber')) + shape.extend((len(grid.total_wavenumber_coords), + len(grid.longitude_wavenumber_coords))) + coords.update({'total_wavenumber': grid.total_wavenumber_coords, + 'longitude_wavenumber': grid.longitude_wavenumber_coords}) + coeffs = xarray_jax.DataArray( + data=jax.random.normal(key, shape), dims=dims, coords=coords) + # Mask out coefficients which are out of range. This broadcasts to a + # triangular mask with shape (total_wavenumber, longitude_wavenumber): + mask = ( + abs(coeffs.longitude_wavenumber) <= coeffs.total_wavenumber + ).astype(np.float32) + # For total_wavenumber t, there will be 2t+1 non-zero coefficients at + # different longitude_wavenumbers. We must normalize the coefficients so that + # summing their squares at each total_wavenumber, sums to the corresponding + # value in the power spectrum: + multiplier = mask * np.sqrt(power_spectrum / mask.sum( + 'longitude_wavenumber', skipna=False)) + # And a standard normalization factor used in this implementation of the + # spherical harmonic transform: + multiplier *= np.sqrt(4 * np.pi) + # Only finally multiply by coeffs to avoid too many broadcasting + # multiplications: + coeffs *= multiplier + result = cast(xarray.DataArray, grid.to_nodal(coeffs)) + result = result.astype(template.dtype) + return result.transpose(*template.dims) + + +def spherical_white_noise_like(template: xarray.Dataset) -> xarray.Dataset: + """Samples isotropic mean 0 variance 1 white noise on the sphere.""" + def spherical_white_noise_like_dataarray(data_array: xarray.DataArray + ) -> xarray.DataArray: + num_wavenumbers = data_array.lon.shape[0] // 2 + key = hk.next_rng_key() + return sample( + key=key, + power_spectrum=xarray_jax.DataArray( + data=np.array([1/num_wavenumbers for _ in range(num_wavenumbers)]), + dims=['total_wavenumber']), + template=data_array) + return template.map(spherical_white_noise_like_dataarray) + + +def rho_inverse_cdf( + min_value: float, + max_value: float, + rho: float, + cdf: Any) -> Any: + """Quantiles of rho distribution used for noise levels at sampling time. + + This is parameterised by rho as in Eqn 5 from the Elucidating paper + (but with max/min flipped so that quantiles are given in ascending not + descending order). It's equivalent to a Beta[rho, 1] distribution rescaled to + [min_value, max_value]. + + At sampling time we use noise levels at fixed quantiles of this distribution. + Unlike in the paper, we also use the same distribution for noise levels at + training time (albeit potentially with different parameters, and sampling from + it at random). + + Args: + min_value: + max_value: + Define the support of the distribution. + rho: + Shape parameter. + cdf: + Value or values between 0 and 1 indicating which quantile you want. Can + be a numpy or jax array. + + Returns: + Quantiles of the distribution, with same shape/type as `cdf`. + """ + return ( + min_value**(1 / rho) + cdf * + (max_value**(1 / rho) - min_value**(1 / rho)) + )**rho + + +def tree_where( + cond: jnp.ndarray, + xs: Any, + ys: Any + ) -> Any: + """Like jnp.where but works with trees for xs and ys (but not for cond).""" + return jax.tree_util.tree_map(lambda x, y: jnp.where(cond, x, y), xs, ys) + + +def noise_schedule( + max_noise_level: float = 80., + min_noise_level: float = 0.002, + num_noise_levels: int = 30, + rho: float = 7., +) -> np.ndarray: + """Computes a descending noise schedule for sampling, ending with zero.""" + noise_levels = rho_inverse_cdf( + min_value=min_noise_level, + max_value=max_noise_level, + rho=rho, + # We want the noise levels in descending order, so ask for quantiles + # 1 down to 0: + cdf=np.linspace(1, 0, num_noise_levels)) + # The final zero noise level is somewhat special-cased. We don't actually + # denoise from this noise level but appending it here is convenient for + # sampling loop implementations. + return np.append(noise_levels, 0.) + + +def stochastic_churn_rate_schedule( + noise_levels: np.ndarray, + stochastic_churn_rate: float = 0., + churn_min_noise_level: float = 0.05, + churn_max_noise_level: float = 50.0, +) -> np.ndarray: + """Computes a stochastic churn rate for each noise level.""" + num_noise_levels = len(noise_levels)-1 # Exclude final zero noise level. + # As in the Elucidated Diffusion paper, clamp this so it doesn't increase the + # variance by a factor of more than 2, no matter how few noise levels are + # used: + per_step_churn_rate = min(stochastic_churn_rate / num_noise_levels, + np.sqrt(2) - 1) + return ( + (churn_min_noise_level <= noise_levels[:-1]) & + (noise_levels[:-1] <= churn_max_noise_level) + ) * per_step_churn_rate + + +def apply_stochastic_churn( + x: Any, + noise_level: jax.typing.ArrayLike, + stochastic_churn_rate: jax.typing.ArrayLike, + noise_level_inflation_factor: jax.typing.ArrayLike, +) -> tuple[Any, jax.typing.ArrayLike]: + """Returns x at higher noise level, and the higher noise level itself.""" + # We increase the noise level of x a bit before taking it down again: + new_noise_level = noise_level * (1.0 + stochastic_churn_rate) + extra_noise_stddev = (jnp.sqrt(new_noise_level**2 - noise_level**2) + * noise_level_inflation_factor) + updated_x = x + spherical_white_noise_like(x) * extra_noise_stddev + return updated_x, new_noise_level + diff --git a/graphcast/sparse_transformer.py b/graphcast/sparse_transformer.py new file mode 100644 index 00000000..502bd840 --- /dev/null +++ b/graphcast/sparse_transformer.py @@ -0,0 +1,577 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Transformer with either dense or sparse attention. + +The sparse attention implemented here is for nodes to attend only to themselves +and their neighbours on the graph). It assumes that the adjacency matrix has a +banded structure, and is implemented with dense operations computing with only +the diagonal, super diagonal, and subdiagonal blocks of the tri-block-diagonal +attention matrix. + +The basic model structure of the transformer and some functions were adapted +from xlm's transformer_simple.py. +""" + +import dataclasses +import logging +from typing import Any, Callable, Literal, Optional, Tuple + +from graphcast import mlp as mlp_builder +from graphcast import sparse_transformer_utils as utils +import haiku as hk +import jax +from jax.experimental.pallas.ops.tpu import splash_attention +import jax.numpy as jnp +import numpy as np +import scipy as sp + + +@dataclasses.dataclass +class _ModelConfig: + """Transformer config.""" + # Depth, or num transformer blocks. One 'layer' is attn + ffw. + num_layers: int + # Primary width, the number of channels on the carrier path. + d_model: int + # Number of heads for self-attention. + num_heads: int + # Mask block size. + mask_block_size: int + # Attention type - 'mha' or 'triblockdiag_mha' + attention_type: str = 'triblockdiag_mha' + block_q: Optional[int] = None + block_kv: Optional[int] = None + block_kv_compute: Optional[int] = None + block_q_dkv: Optional[int] = None + block_kv_dkv: Optional[int] = None + block_kv_dkv_compute: Optional[int] = None + # mask type if splash attention being used - 'full' or 'lazy' + mask_type: Optional[str] = 'full' + # Number of channels per-head for self-attn QK computation. + key_size: Optional[int] = None + # Number of channels per-head for self-attn V computation. + value_size: Optional[int] = None + # Activation to use, any in jax.nn. + activation: str = 'gelu' + # Init scale for ffw layers (divided by num_layers) + ffw_winit_mult: float = 2.0 + # Init scale for final ffw layer (divided by depth) + ffw_winit_final_mult: float = 2.0 + # Init scale for mha proj (divided by depth). + attn_winit_mult: float = 2.0 + # Init scale for mha w (divided by depth). + attn_winit_final_mult: float = 2.0 + # Number of hidden units in the MLP blocks. Defaults to 4 * d_model. + ffw_hidden: Optional[int] = None + + def __post_init__(self): + if self.ffw_hidden is None: + self.ffw_hidden = 4 * self.d_model + # Compute key_size and value_size from d_model // num_heads. + if self.key_size is None: + if self.d_model % self.num_heads != 0: + raise ValueError('num_heads has to divide d_model exactly') + self.key_size = self.d_model // self.num_heads + if self.value_size is None: + if self.d_model % self.num_heads != 0: + raise ValueError('num_heads has to divide d_model exactly') + self.value_size = self.d_model // self.num_heads + + +def get_mask_block_size(mask: sp.sparse.csr_matrix) -> int: + """Get blocksize of the adjacency matrix (attn mask) for the permuted mesh.""" + # sub-diagonal bandwidth + lbandwidth = ( + np.arange(mask.shape[0]) - (mask != 0).argmax(axis=0) + 1).max() + # super-diagonal bandwidth + ubandwidth = ( + (mask.shape[0]-1) - np.argmax(mask[::-1,:] != 0, axis=0 + ) - np.arange(mask.shape[0]) + 1).max() + block_size = np.maximum(lbandwidth, ubandwidth) + return block_size + + +def ffw(x: jnp.ndarray, cfg: _ModelConfig) -> jnp.ndarray: + """Feed-forward block.""" + ffw_winit = hk.initializers.VarianceScaling(cfg.ffw_winit_mult / + cfg.num_layers) + ffw_winit_final = hk.initializers.VarianceScaling(cfg.ffw_winit_final_mult / + cfg.num_layers) + x = hk.Linear(cfg.ffw_hidden, name='ffw_up', w_init=ffw_winit)(x) + x = getattr(jax.nn, cfg.activation)(x) + return hk.Linear(cfg.d_model, name='ffw_down', w_init=ffw_winit_final)(x) + + +def triblockdiag_softmax(logits: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Softmax given the diag, upper diag, and lower diag logit blocks.""" + + logits_d, logits_u, logits_l = logits + + m = jnp.max(jnp.stack([ + jax.lax.stop_gradient(logits_d.max(-1, keepdims=True)), + jax.lax.stop_gradient(logits_u.max(-1, keepdims=True)), + jax.lax.stop_gradient(logits_l.max(-1, keepdims=True))]), axis=0) + + unnormalized_d = jnp.exp(logits_d - m) + unnormalized_u = jnp.exp(logits_u - m) + unnormalized_l = jnp.exp(logits_l - m) + + denom = ( + unnormalized_d.sum(-1, keepdims=True) + + unnormalized_u.sum(-1, keepdims=True) + + unnormalized_l.sum(-1, keepdims=True) + ) + + logits_d = unnormalized_d / denom + logits_u = unnormalized_u / denom + logits_l = unnormalized_l / denom + + return (logits_d, logits_u, logits_l) + + +def triblockdiag_mha(q_input: jnp.ndarray, kv_input: jnp.ndarray, + mask: jnp.ndarray, cfg: _ModelConfig, + ) -> jnp.ndarray: + """Triblockdiag multihead attention.""" + + # q_inputs, kv_input: (batch, num_blocks, block_size, num_heads, d_model) + q = multihead_linear(q_input, 'q', cfg) + k = multihead_linear(kv_input, 'k', cfg) + v = multihead_linear(kv_input, 'v', cfg) + + k = jnp.pad(k, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0))) + v = jnp.pad(v, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0))) + + def qk_prod(queries, keys): + return jnp.einsum('bnqhd,bnkhd->bnhqk', queries, keys) + + # q shape is (batch, num_blocks, block_size, num_heads, qk_dim) + # k shape is (batch, num_blocks + 2, block_size, num_heads, qk_dim) + logits_d = qk_prod(q, k[:, 1:-1, ...]) * cfg.key_size**-0.5 + logits_u = qk_prod(q, k[:, 2:, ...]) * cfg.key_size**-0.5 + logits_l = qk_prod(q, k[:, :-2, ...]) * cfg.key_size**-0.5 + + # apply mask + logits_d = jnp.where(mask[:, 0, ...], logits_d, -1e30) + logits_u = jnp.where(mask[:, 1, ...], logits_u, -1e30) + logits_l = jnp.where(mask[:, 2, ...], logits_l, -1e30) + + logits_d, logits_u, logits_l = utils.wrap_fn_for_upcast_downcast( + (logits_d, logits_u, logits_l), + triblockdiag_softmax + ) + + def av_prod(attn_weights, values): + return jnp.einsum('bnhqk,bnkhd->bnqhd', attn_weights, values) + + out_d = av_prod(logits_d, v[:, 1:-1, ...]) + out_u = av_prod(logits_u, v[:, 2:, ...]) + out_l = av_prod(logits_l, v[:, :-2, ...]) + # x shape is (batch, num_blocks, block_size, num_heads, d_model) + x = out_d + out_u + out_l + + x = jnp.reshape(x, x.shape[:-2] + (cfg.num_heads * cfg.value_size,)) + attn_winit_final = hk.initializers.VarianceScaling( + cfg.attn_winit_final_mult / cfg.num_layers) + x = hk.Linear(cfg.d_model, name='mha_final', w_init=attn_winit_final)(x) + return x + + +def multihead_linear( + x: jnp.ndarray, qkv: str, cfg: _ModelConfig +) -> jnp.ndarray: + """Linearly project `x` to have `head_size` dimensions per head.""" + head_size = cfg.value_size if qkv == 'v' else cfg.key_size + attn_winit = hk.initializers.VarianceScaling(cfg.attn_winit_mult / + cfg.num_layers) + out = hk.Linear( + cfg.num_heads * head_size, + w_init=attn_winit, + name='mha_proj_' + qkv, + with_bias=False, + )(x) + shape = out.shape[:-1] + (cfg.num_heads, head_size) + return jnp.reshape(out, shape) + + +def mha(q_input: jnp.ndarray, kv_input: jnp.ndarray, + mask: jnp.ndarray, cfg: _ModelConfig, + normalize_logits: bool = True, + ) -> jnp.ndarray: + """Multi head attention.""" + + q = multihead_linear(q_input, 'q', cfg) + k = multihead_linear(kv_input, 'k', cfg) + v = multihead_linear(kv_input, 'v', cfg) + + logits = jnp.einsum('bthd, bThd->bhtT', q, k) + if normalize_logits: + logits *= cfg.key_size**-0.5 + if mask is not None: + def apply_mask(m, l): + return jnp.where(m, l, -1e30) + logits = jax.vmap(jax.vmap( + apply_mask, in_axes=[None, 0]), in_axes=[None, 0])(mask, logits) + + # Wrap softmax weights for upcasting & downcasting in case of BF16 activations + weights = utils.wrap_fn_for_upcast_downcast(logits, jax.nn.softmax) + + # Note: our mask never has all 0 rows, since nodes always have self edges, + # so no need to account for that possibility explicitly. + + x = jnp.einsum('bhtT,bThd->bthd', weights, v) + x = jnp.reshape(x, x.shape[:-2] + (cfg.num_heads * cfg.value_size,)) + + attn_winit_final = hk.initializers.VarianceScaling( + cfg.attn_winit_final_mult / cfg.num_layers) + + x = hk.Linear(cfg.d_model, name='mha_final', w_init=attn_winit_final)(x) + return x + + +def _make_splash_mha( + mask, + mask_type: str, + num_heads: int, + block_q: Optional[int] = None, + block_kv: Optional[int] = None, + block_kv_compute: Optional[int] = None, + block_q_dkv: Optional[int] = None, + block_kv_dkv: Optional[int] = None, + block_kv_dkv_compute: Optional[int] = None, + tanh_soft_cap: Optional[float] = None, +) -> Callable[..., jnp.ndarray]: + """Construct attention kernel.""" + if mask_type == 'full': + mask = np.broadcast_to(mask[None], + (num_heads, *mask.shape)).astype(np.bool_) + + block_sizes = splash_attention.BlockSizes( + block_q=block_q, + block_kv=block_kv, + block_kv_compute=block_kv_compute, + block_q_dkv=block_q_dkv, + block_kv_dkv=block_kv_dkv, + block_kv_dkv_compute=block_kv_dkv_compute, + use_fused_bwd_kernel=True, + ) + attn = splash_attention.make_splash_mha(mask, block_sizes=block_sizes, + head_shards=1, + q_seq_shards=1, + attn_logits_soft_cap=tanh_soft_cap, + ) + return attn + + +def splash_mha(q_input: jnp.ndarray, kv_input: jnp.ndarray, + mask: jnp.ndarray | splash_attention.splash_attention_mask.Mask, + cfg: _ModelConfig, + tanh_soft_cap: Optional[float] = None, + normalize_q: bool = True) -> jnp.ndarray: + """Splash attention.""" + + q = multihead_linear(q_input, 'q', cfg) + k = multihead_linear(kv_input, 'k', cfg) + v = multihead_linear(kv_input, 'v', cfg) + + _, _, num_heads, head_dim = q.shape + + assert head_dim % 128 == 0 # splash attention kernel requires this + + attn = _make_splash_mha( + mask=mask, + mask_type=cfg.mask_type, + num_heads=num_heads, + block_q=cfg.block_q, + block_kv=cfg.block_kv, + block_kv_compute=cfg.block_kv_compute, + block_q_dkv=cfg.block_q_dkv, + block_kv_dkv=cfg.block_kv_dkv, + block_kv_dkv_compute=cfg.block_kv_dkv_compute, + tanh_soft_cap=tanh_soft_cap, + ) + attn = jax.vmap(attn) # Add batch axis. + + if normalize_q: + q *= cfg.key_size**-0.5 + + # (batch, nodes, num_heads, head_dim) -> (batch, num_heads, nodes, head_dim) + reformat = lambda y: y.transpose(0, 2, 1, 3) + x = attn(q=reformat(q), k=reformat(k), v=reformat(v)) + x = x.transpose(0, 2, 1, 3) + + x = jnp.reshape(x, x.shape[:-2] + (cfg.num_heads * cfg.value_size,)) + + attn_winit_final = hk.initializers.VarianceScaling( + cfg.attn_winit_final_mult / cfg.num_layers) + + x = hk.Linear(cfg.d_model, name='mha_final', w_init=attn_winit_final)(x) + return x + + +def layernorm( + x: jnp.ndarray, create_scale: bool, create_offset: bool +) -> jnp.ndarray: + return hk.LayerNorm( + axis=-1, create_scale=create_scale, create_offset=create_offset, + name='norm')(x) + + +def mask_block_diags(mask: sp.sparse.csr_matrix, + num_padding_nodes: int, + block_size: int) -> jnp.ndarray: + """Pad and reshape mask diag, super-siag and sub-diag blocks.""" + # add zero padding to mask + mask_padding_rows = sp.sparse.csr_matrix( + (num_padding_nodes, mask.shape[1]), dtype=jnp.int32) + mask = sp.sparse.vstack([mask, mask_padding_rows]) + mask_padding_cols = sp.sparse.csr_matrix( + (mask.shape[0], num_padding_nodes), dtype=jnp.int32) + mask = sp.sparse.hstack([mask, mask_padding_cols]) + + assert (mask.shape[-1] % block_size) == 0 + mask_daig_blocks = jnp.stack( + [jnp.array(mask[i * block_size : (i + 1) * block_size, + i * block_size : (i + 1) * block_size, + ].toarray()) + for i in range(mask.shape[0] // block_size)]) + mask_upper_diag_blocks = jnp.stack( + [jnp.array(mask[i * block_size : (i + 1) * block_size, + (i + 1) * block_size : (i + 2) * block_size, + ].toarray()) + for i in range(mask.shape[0] // block_size - 1)] + + [jnp.zeros((block_size, block_size), dtype=mask.dtype)]) + mask_lower_diag_blocks = jnp.stack( + [jnp.zeros((block_size, block_size), dtype=mask.dtype)] + + [jnp.array(mask[(i + 1) * block_size : (i + 2) * block_size, + i * block_size : (i + 1) * block_size, + ].toarray()) + for i in range(mask.shape[0] // block_size - 1)]) + mask = jnp.stack( + [mask_daig_blocks, mask_upper_diag_blocks, mask_lower_diag_blocks] + ) + mask = jnp.expand_dims(mask, (0, 3)) + return mask + + +def _pad_mask(mask, num_padding_nodes: Tuple[int, int]) -> jnp.ndarray: + q_padding, kv_padding = num_padding_nodes + mask_padding_rows = sp.sparse.csr_matrix( + (q_padding, mask.shape[1]), dtype=np.bool_) + mask = sp.sparse.vstack([mask, mask_padding_rows]) + mask_padding_cols = sp.sparse.csr_matrix( + (mask.shape[0], kv_padding), dtype=np.bool_) + mask = sp.sparse.hstack([mask, mask_padding_cols]) + return mask + + +class WeatherMeshMask(splash_attention.splash_attention_mask.Mask): + """Lazy local mask, prevent attention to embeddings outside window. + + Attributes: + mask: + """ + + _shape: Tuple[int, int] + mask: sp.sparse.spmatrix + + def __init__( + self, + mask: Any + ): + self._shape = mask.shape + self.mask = mask + + @property + def shape(self) -> Tuple[int, int]: + return self._shape + + def __getitem__(self, idx) -> np.ndarray: + if len(idx) != 2: + raise NotImplementedError(f'Unsupported slice: {idx}') + q_slice, kv_slice = idx + if not isinstance(q_slice, slice) or not isinstance(kv_slice, slice): + raise NotImplementedError(f'Unsupported slice: {idx}') + + return self.mask[q_slice, kv_slice].toarray() + + +class Block(hk.Module): + """Transformer block (mha and ffw).""" + + def __init__(self, cfg, mask, num_nodes, num_padding_nodes, name=None): + super().__init__(name=name) + self._cfg = cfg + self.mask = mask + self.num_nodes = num_nodes + self.num_padding_nodes = num_padding_nodes + + def __call__(self, x, global_norm_conditioning=jax.Array): + # x shape is (batch, num_nodes, feature_dim) + def attn(x): + if self._cfg.attention_type == 'triblockdiag_mha': + # We pad -> reshape -> compute attn -> reshape -> select at each block + # so as to avoid complications involved in making the norm layers and + # ffw blocks account for the padding. However, this might be decreasing + # efficiency. + + # Add padding so that number of nodes is divisible into blocks + x = jnp.pad(x, ((0, 0), (0, self.num_padding_nodes), (0, 0))) + x = x.reshape(x.shape[0], + x.shape[1]//self._cfg.mask_block_size, + self._cfg.mask_block_size, + x.shape[-1]) + x = triblockdiag_mha(x, x, mask=self.mask, cfg=self._cfg) + x = x.reshape(x.shape[0], + self.num_nodes + self.num_padding_nodes, + x.shape[-1]) + return x[:,:self.num_nodes, :] + + elif self._cfg.attention_type == 'mha': + return mha(x, x, mask=self.mask, cfg=self._cfg) + + elif self._cfg.attention_type == 'splash_mha': + # We pad -> reshape -> compute attn -> reshape -> select at each block + # so as to avoid complications involved in making the norm layers and + # ffw blocks account for the padding. However, this might be decreasing + # efficiency. + + # Add padding so that number of nodes is divisible by block sizes. + x = jnp.pad(x, ((0, 0), (0, self.num_padding_nodes[0]), (0, 0))) + x = splash_mha(x, x, mask=self.mask, cfg=self._cfg) + return x[:,:self.num_nodes, :] + + else: + raise NotImplementedError() + + def norm_conditioning_layer(x): + return mlp_builder.LinearNormConditioning( + name=self.name+'_norm_conditioning')( + x, + norm_conditioning=jnp.expand_dims(global_norm_conditioning, 1) + ) + + x = x + attn( + norm_conditioning_layer( + layernorm(x, create_scale=False, create_offset=False) + ) + ) + x = x + ffw( + norm_conditioning_layer( + layernorm(x, create_scale=False, create_offset=False) + ), + self._cfg, + ) + return x + + +class Transformer(hk.Module): + """Main transformer module that processes embeddings. + + All but the very first and very last layer of a 'classic' Transformer: + Receives already embedded inputs instead of discrete tokens. + Outputs an embedding for each 'node'/'position' rather than logits. + """ + + def __init__(self, + adj_mat: sp.sparse.csr_matrix, + attention_k_hop: int, + attention_type: Literal['splash_mha', 'triblockdiag_mha', 'mha'], + mask_type: Literal['full', 'lazy'], + num_heads=1, + name=None, + block_q: Optional[int] = None, + block_kv: Optional[int] = None, + block_kv_compute: Optional[int] = None, + block_q_dkv: Optional[int] = None, + block_kv_dkv: Optional[int] = None, + block_kv_dkv_compute: Optional[int] = None, + **kwargs): + super().__init__(name=name) + + # Construct mask and deduce block size. + mask = adj_mat ** attention_k_hop + mask_block_size = get_mask_block_size(mask) + logging.info('mask_block_size: %s.', mask_block_size) + + if attention_type == 'triblockdiag_mha': + # we will stack the nodes in blocks of 'block_size' nodes, so we need to + # pad the input such that (num_nodes + num_padding_nodes) % block_size = 0 + self.num_padding_nodes = int(np.ceil( + mask.shape[0]/mask_block_size)*mask_block_size + - mask.shape[0]) + self.mask = mask_block_diags( + mask, self.num_padding_nodes, mask_block_size) + elif attention_type == 'splash_mha': + max_q_block_size = np.maximum(block_q, block_q_dkv) + max_kv_block_size = np.maximum(block_kv, block_kv_dkv) + q_padding = int(np.ceil( + mask.shape[0]/max_q_block_size)*max_q_block_size - mask.shape[0]) + kv_padding = int(np.ceil( + mask.shape[1]/max_kv_block_size)*max_kv_block_size - mask.shape[1]) + self.num_padding_nodes = (q_padding, kv_padding) + mask = _pad_mask(mask, self.num_padding_nodes) + if mask_type == 'lazy': + splash_mask = [ + WeatherMeshMask(mask) + for _ in range(num_heads) + ] + self.mask = splash_attention.splash_attention_mask.MultiHeadMask( + splash_mask) + elif mask_type == 'full': + self.mask = mask.toarray() + elif attention_type == 'mha': + self.mask = jnp.array(mask.toarray()) + self.num_padding_nodes = 0 + else: + raise ValueError( + 'Unsupported attention type: %s' % attention_type + ) + + # Construct config for use within class. + self._cfg = _ModelConfig( + mask_block_size=mask_block_size, + attention_type=attention_type, + mask_type=mask_type, + num_heads=num_heads, + block_q=block_q, + block_kv=block_kv, + block_kv_compute=block_kv_compute, + block_q_dkv=block_q_dkv, + block_kv_dkv=block_kv_dkv, + block_kv_dkv_compute=block_kv_dkv_compute, + **kwargs) + + def __call__(self, node_features, global_norm_conditioning: jax.Array): + # node_features expected to have shape (batch, num_nodes, d) + x = node_features + for i_layer in range(self._cfg.num_layers): + x = Block(cfg=self._cfg, mask=self.mask, + num_nodes=node_features.shape[1], + num_padding_nodes=self.num_padding_nodes, + name='block_%02d' % i_layer + )(x, global_norm_conditioning=global_norm_conditioning) + + def norm_conditioning_layer(x): + return mlp_builder.LinearNormConditioning( + name=self.name+'_final_norm_conditioning')( + x, + norm_conditioning=jnp.expand_dims(global_norm_conditioning, 1) + ) + x = norm_conditioning_layer( + layernorm(x, create_scale=False, create_offset=False) + ) + + return x diff --git a/graphcast/sparse_transformer_utils.py b/graphcast/sparse_transformer_utils.py new file mode 100644 index 00000000..4af0f33f --- /dev/null +++ b/graphcast/sparse_transformer_utils.py @@ -0,0 +1,76 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils for training models in low precision.""" + +import functools +from typing import Callable, Tuple, Union + +import jax +import jax.numpy as jnp + + +# Wrappers for jax.lax.reduce_precision which is non-differentiable. +@functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2)) +def reduce_precision(x, exponent_bits, mantissa_bits): + return jax.tree_util.tree_map( + lambda y: jax.lax.reduce_precision(y, exponent_bits, mantissa_bits), x) + + +def reduce_precision_fwd(x, exponent_bits, mantissa_bits): + return reduce_precision(x, exponent_bits, mantissa_bits), None + + +def reduce_precision_bwd(exponent_bits, mantissa_bits, res, dout): + del res # Unused. + return reduce_precision(dout, exponent_bits, mantissa_bits), + + +reduce_precision.defvjp(reduce_precision_fwd, reduce_precision_bwd) + + +def wrap_fn_for_upcast_downcast(inputs: Union[jnp.ndarray, + Tuple[jnp.ndarray, ...]], + fn: Callable[[Union[jnp.ndarray, + Tuple[jnp.ndarray, ...]]], + Union[jnp.ndarray, + Tuple[jnp.ndarray, ...]]], + f32_upcast: bool = True, + guard_against_excess_precision: bool = True + ) -> Union[jnp.ndarray, + Tuple[jnp.ndarray, ...]]: + """Wraps `fn` to upcast to float32 and then downcast, for use with BF16.""" + # Do not upcast if the inputs are already in float32. + # This removes a no-op `jax.lax.reduce_precision` which is unsupported + # in jax2tf at the moment. + if isinstance(inputs, Tuple): + f32_upcast = f32_upcast and inputs[0].dtype != jnp.float32 + orig_dtype = inputs[0].dtype + else: + f32_upcast = f32_upcast and inputs.dtype != jnp.float32 + orig_dtype = inputs.dtype + + if f32_upcast: + inputs = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), inputs) + + if guard_against_excess_precision: + # This is evil magic to guard against differences in precision in the QK + # calculation between the forward pass and backwards pass. This is like + # --xla_allow_excess_precision=false but scoped here. + finfo = jnp.finfo(orig_dtype) # jnp important! + inputs = reduce_precision(inputs, finfo.nexp, finfo.nmant) + + output = fn(inputs) + if f32_upcast: + output = jax.tree_util.tree_map(lambda x: x.astype(orig_dtype), output) + return output diff --git a/graphcast/transformer.py b/graphcast/transformer.py new file mode 100644 index 00000000..6788368e --- /dev/null +++ b/graphcast/transformer.py @@ -0,0 +1,124 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Transformer model for weather predictions. + +This model wraps the a transformer model and swaps the leading two axes of the +nodes in the input graph prior to evaluating the model to make it compatible +with a [nodes, batch, ...] ordering of the inputs. +""" + +from typing import Any, Mapping, Optional + +from graphcast import typed_graph +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +from scipy import sparse + + +Kwargs = Mapping[str, Any] + + +def _get_adj_matrix_for_edge_set( + graph: typed_graph.TypedGraph, + edge_set_name: str, + add_self_edges: bool, +): + """Returns the adjacency matrix for the given graph and edge set.""" + # Get nodes and edges of the graph. + edge_set_key = graph.edge_key_by_name(edge_set_name) + sender_node_set, receiver_node_set = edge_set_key.node_sets + + # Compute number of sender and receiver nodes. + sender_n_node = graph.nodes[sender_node_set].n_node[0] + receiver_n_node = graph.nodes[receiver_node_set].n_node[0] + + # Build adjacency matrix. + adj_mat = sparse.csr_matrix((sender_n_node, receiver_n_node), dtype=np.bool_) + edge_set = graph.edges[edge_set_key] + s, r = edge_set.indices + adj_mat[s, r] = True + if add_self_edges: + # Should only do this if we are certain the adjacency matrix is square. + assert sender_node_set == receiver_node_set + adj_mat[np.arange(sender_n_node), np.arange(receiver_n_node)] = True + return adj_mat + + +class MeshTransformer(hk.Module): + """A Transformer for inputs with ordering [nodes, batch, ...].""" + + def __init__(self, + transformer_ctor, + transformer_kwargs: Kwargs, + name: Optional[str] = None): + """Initialises the Transformer model. + + Args: + transformer_ctor: Constructor for transformer. + transformer_kwargs: Kwargs to pass to the transformer module. + name: Optional name for haiku module. + """ + super().__init__(name=name) + # We defer the transformer initialisation to the first call to __call__, + # where we can build the mask senders and receivers of the TypedGraph + self._batch_first_transformer = None + self._transformer_ctor = transformer_ctor + self._transformer_kwargs = transformer_kwargs + + @hk.name_like('__init__') + def _maybe_init_batch_first_transformer(self, x: typed_graph.TypedGraph): + if self._batch_first_transformer is not None: + return + self._batch_first_transformer = self._transformer_ctor( + adj_mat=_get_adj_matrix_for_edge_set( + graph=x, + edge_set_name='mesh', + add_self_edges=True, + ), + **self._transformer_kwargs, + ) + + def __call__( + self, x: typed_graph.TypedGraph, + global_norm_conditioning: jax.Array + ) -> typed_graph.TypedGraph: + """Applies the model to the input graph and returns graph of same shape.""" + + if set(x.nodes.keys()) != {'mesh_nodes'}: + raise ValueError( + f'Expected x.nodes to have key `mesh_nodes`, got {x.nodes.keys()}.' + ) + features = x.nodes['mesh_nodes'].features + if features.ndim != 3: # pytype: disable=attribute-error # jax-ndarray + raise ValueError( + 'Expected `x.nodes["mesh_nodes"].features` to be 3, got' + f' {features.ndim}.' + ) # pytype: disable=attribute-error # jax-ndarray + + # Initialise transformer and mask. + self._maybe_init_batch_first_transformer(x) + + y = jnp.transpose(features, axes=[1, 0, 2]) + y = self._batch_first_transformer(y, global_norm_conditioning) + y = jnp.transpose(y, axes=[1, 0, 2]) + x = x._replace( + nodes={ + 'mesh_nodes': x.nodes['mesh_nodes']._replace( + features=y.astype(features.dtype) + ) + } + ) + return x diff --git a/graphcast_demo.ipynb b/graphcast_demo.ipynb index 65ce6a1a..ece15fec 100644 --- a/graphcast_demo.ipynb +++ b/graphcast_demo.ipynb @@ -114,7 +114,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "4wagX1TL_f15" }, "outputs": [], @@ -122,14 +121,14 @@ "# @title Authenticate with Google Cloud Storage\n", "\n", "gcs_client = storage.Client.create_anonymous_client()\n", - "gcs_bucket = gcs_client.get_bucket(\"dm_graphcast\")" + "gcs_bucket = gcs_client.get_bucket(\"dm_graphcast\")\n", + "dir_prefix = \"graphcast/\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "5JUymx84dI2m" }, "outputs": [], @@ -258,7 +257,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "KGaJ6V9MdI2n" }, "outputs": [], @@ -266,8 +264,8 @@ "# @title Choose the model\n", "\n", "params_file_options = [\n", - " name for blob in gcs_bucket.list_blobs(prefix=\"params/\")\n", - " if (name := blob.name.removeprefix(\"params/\"))] # Drop empty string.\n", + " name for blob in gcs_bucket.list_blobs(prefix=dir_prefix+\"params/\")\n", + " if (name := blob.name.removeprefix(dir_prefix+\"params/\"))] # Drop empty string.\n", "\n", "random_mesh_size = widgets.IntSlider(\n", " value=4, min=4, max=6, description=\"Mesh size:\")\n", @@ -305,7 +303,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "lYQgrPgPdI2n" }, "outputs": [], @@ -333,7 +330,7 @@ " )\n", "else:\n", " assert source == \"Checkpoint\"\n", - " with gcs_bucket.blob(f\"params/{params_file.value}\").open(\"rb\") as f:\n", + " with gcs_bucket.blob(f\"{dir_prefix}params/{params_file.value}\").open(\"rb\") as f:\n", " ckpt = checkpoint.load(f, graphcast.CheckPoint)\n", " params = ckpt.params\n", " state = {}\n", @@ -376,7 +373,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "-DJzie5me2-H" }, "outputs": [], @@ -384,8 +380,8 @@ "# @title Get and filter the list of available example datasets\n", "\n", "dataset_file_options = [\n", - " name for blob in gcs_bucket.list_blobs(prefix=\"dataset/\")\n", - " if (name := blob.name.removeprefix(\"dataset/\"))] # Drop empty string.\n", + " name for blob in gcs_bucket.list_blobs(prefix=dir_prefix+\"dataset/\")\n", + " if (name := blob.name.removeprefix(dir_prefix+\"dataset/\"))] # Drop empty string.\n", "\n", "def data_valid_for_model(\n", " file_name: str, model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig):\n", @@ -420,7 +416,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "Yz-ekISoJxeZ" }, "outputs": [], @@ -431,7 +426,7 @@ " raise ValueError(\n", " \"Invalid dataset file, rerun the cell above and choose a valid dataset file.\")\n", "\n", - "with gcs_bucket.blob(f\"dataset/{dataset_file.value}\").open(\"rb\") as f:\n", + "with gcs_bucket.blob(f\"{dir_prefix}dataset/{dataset_file.value}\").open(\"rb\") as f:\n", " example_batch = xarray.load_dataset(f).compute()\n", "\n", "assert example_batch.dims[\"time\"] \u003e= 3 # 2 for input, \u003e=1 for targets\n", @@ -552,18 +547,17 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "Q--ZRhpTdI2o" }, "outputs": [], "source": [ "# @title Load normalization data\n", "\n", - "with gcs_bucket.blob(\"stats/diffs_stddev_by_level.nc\").open(\"rb\") as f:\n", + "with gcs_bucket.blob(dir_prefix+\"stats/diffs_stddev_by_level.nc\").open(\"rb\") as f:\n", " diffs_stddev_by_level = xarray.load_dataset(f).compute()\n", - "with gcs_bucket.blob(\"stats/mean_by_level.nc\").open(\"rb\") as f:\n", + "with gcs_bucket.blob(dir_prefix+\"stats/mean_by_level.nc\").open(\"rb\") as f:\n", " mean_by_level = xarray.load_dataset(f).compute()\n", - "with gcs_bucket.blob(\"stats/stddev_by_level.nc\").open(\"rb\") as f:\n", + "with gcs_bucket.blob(dir_prefix+\"stats/stddev_by_level.nc\").open(\"rb\") as f:\n", " stddev_by_level = xarray.load_dataset(f).compute()" ] }, diff --git a/setup.py b/setup.py index c7ce01c0..14976ce7 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ setup( name="graphcast", - version="0.1.1", + version="0.2.0.dev", description=description, long_description=description, author="DeepMind", @@ -34,6 +34,7 @@ "chex", "colabtools", "dask", + "dinosaur-dycore", "dm-haiku", "dm-tree", "jax", @@ -46,6 +47,7 @@ "trimesh", "typing_extensions", "xarray", + "xarray_tensorstore" ], classifiers=[ "Development Status :: 3 - Alpha",