
Commit

Adding GenCast support.
andrewlkd committed Dec 4, 2024
1 parent 97d1ad5 · commit bf9785f
Showing 31 changed files with 5,110 additions and 111 deletions.
220 changes: 158 additions & 62 deletions README.md

Large diffs are not rendered by default.

690 changes: 690 additions & 0 deletions gencast_demo_cloud_vm.ipynb

Large diffs are not rendered by default.

875 changes: 875 additions & 0 deletions gencast_mini_demo.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion graphcast/autoregressive.py
@@ -246,7 +246,7 @@ def add_noise(x):
return x + self._noise_level * jax.random.normal(
hk.next_rng_key(), shape=x.shape)
# Add noise to time-dependent variables of the inputs.
inputs = jax.tree_map(add_noise, inputs)
inputs = jax.tree.map(add_noise, inputs)

# The per-timestep targets passed by scan to one_step_loss below will have
# no leading time axis. We need a treedef without the time axis to use
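This hunk, together with the matching ones in graphcast/casting.py below, swaps the deprecated `jax.tree_map` alias for `jax.tree.map`, its current name in recent JAX releases. A minimal standalone sketch of the new call, with a made-up pytree standing in for the `inputs` dataset in the diff:

import jax
import jax.numpy as jnp

# Hypothetical pytree standing in for the `inputs` dataset in the diff.
inputs = {"2m_temperature": jnp.zeros((2, 3)), "geopotential": jnp.ones((2, 3))}

# jax.tree.map applies a function to every leaf of the pytree, exactly as the
# deprecated jax.tree_map alias did before this commit.
noisy = jax.tree.map(lambda x: x + 0.01, inputs)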
4 changes: 2 additions & 2 deletions graphcast/casting.py
@@ -140,7 +140,7 @@ def _all_inputs_to_bfloat16(
xarray.Dataset,
xarray.Dataset]:
return (inputs.astype(jnp.bfloat16),
jax.tree_map(lambda x: x.astype(jnp.bfloat16), targets),
jax.tree.map(lambda x: x.astype(jnp.bfloat16), targets),
forcings.astype(jnp.bfloat16))


@@ -149,7 +149,7 @@ def tree_map_cast(inputs: PyTree, input_dtype: np.dtype, output_dtype: np.dtype,
def cast_fn(x):
if x.dtype == input_dtype:
return x.astype(output_dtype)
return jax.tree_map(cast_fn, inputs)
return jax.tree.map(cast_fn, inputs)


@contextlib.contextmanager
122 changes: 96 additions & 26 deletions graphcast/deep_typed_graph_net.py
@@ -34,8 +34,11 @@
}
"""

from typing import Mapping, Optional
import functools
from typing import Callable, List, Mapping, Optional, Tuple

import chex
from graphcast import mlp as mlp_builder
from graphcast import typed_graph
from graphcast import typed_graph_net
import haiku as hk
Expand All @@ -44,6 +47,9 @@
import jraph


GraphToGraphNetwork = Callable[[typed_graph.TypedGraph], typed_graph.TypedGraph]


class DeepTypedGraphNet(hk.Module):
"""Deep Graph Neural Network.
@@ -87,6 +93,7 @@ def __init__(self,
edge_output_size: Optional[Mapping[str, int]] = None,
include_sent_messages_in_node_update: bool = False,
use_layer_norm: bool = True,
use_norm_conditioning: bool = False,
activation: str = "relu",
f32_aggregation: bool = False,
aggregate_edges_for_nodes_fn: str = "segment_sum",
@@ -114,6 +121,18 @@ def __init__(self,
include_sent_messages_in_node_update: Whether to include pooled sent
messages from each node in the node update.
use_layer_norm: Whether it uses layer norm or not.
use_norm_conditioning: If True, the latent features output by the
activation normalization that follows the MLPs (e.g. LayerNorm) are
scaled and offset by values produced by a linear layer (with different
weights for each MLP), rather than by learned parameters of the
normalization module itself. That linear layer takes an extra argument
"global_norm_conditioning", which conditions the normalization of all
nodes and all edges (hence global) and would usually only have a batch
and feature axis. This is typically used to condition diffusion models
on the "diffusion time". An error is raised if this is set to True but
"global_norm_conditioning" is not passed to the __call__ method, or if
this is set to False but "global_norm_conditioning" is passed.
activation: name of activation function.
f32_aggregation: Use float32 in the edge aggregation.
aggregate_edges_for_nodes_fn: function used to aggregate messages to each
@@ -141,9 +160,14 @@ def __init__(self,
self._edge_output_size = edge_output_size
self._include_sent_messages_in_node_update = (
include_sent_messages_in_node_update)
if use_norm_conditioning and not use_layer_norm:
raise ValueError(
"`norm_conditioning` can only be used when "
"`use_layer_norm` is true."
)
self._use_layer_norm = use_layer_norm
self._use_norm_conditioning = use_norm_conditioning
self._activation = _get_activation_fn(activation)
self._initialized = False
self._f32_aggregation = f32_aggregation
self._aggregate_edges_for_nodes_fn = _get_aggregate_edges_for_nodes_fn(
aggregate_edges_for_nodes_fn)
@@ -154,24 +178,31 @@
assert aggregate_edges_for_nodes_fn == "segment_sum"

def __call__(self,
input_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
input_graph: typed_graph.TypedGraph,
global_norm_conditioning: Optional[chex.Array] = None
) -> typed_graph.TypedGraph:
"""Forward pass of the learnable dynamics model."""
self._networks_builder(input_graph)
embedder_network, processor_networks, decoder_network = (
self._networks_builder(input_graph, global_norm_conditioning)
)

# Embed input features (if applicable).
latent_graph_0 = self._embed(input_graph)
latent_graph_0 = self._embed(input_graph, embedder_network)

# Do `m` message passing steps in the latent graphs.
latent_graph_m = self._process(latent_graph_0)
latent_graph_m = self._process(latent_graph_0, processor_networks)

# Compute outputs from the last latent graph (if applicable).
return self._output(latent_graph_m)

def _networks_builder(self, graph_template):
if self._initialized:
return
self._initialized = True

return self._output(latent_graph_m, decoder_network)

def _networks_builder(
self,
graph_template: typed_graph.TypedGraph,
global_norm_conditioning: Optional[chex.Array] = None,
) -> Tuple[
GraphToGraphNetwork, List[GraphToGraphNetwork], GraphToGraphNetwork
]:
# TODO(aelkadi): move to mlp_builder.
def build_mlp(name, output_size):
mlp = hk.nets.MLP(
output_sizes=[self._mlp_hidden_size] * self._mlp_num_hidden_layers + [
@@ -180,11 +211,40 @@ def build_mlp_with_maybe_layer_norm(name, output_size):

def build_mlp_with_maybe_layer_norm(name, output_size):
network = build_mlp(name, output_size)
stages = [network]
if self._use_norm_conditioning:
if global_norm_conditioning is None:
raise ValueError(
"When using norm conditioning, `global_norm_conditioning` must"
" be passed to the call method.")
# If using norm conditioning, it is no longer the responsibility of the
# LayerNorm module itself to learn its scale and offset. These will be
# learned for the module by the norm conditioning layer instead.
create_scale = create_offset = False
else:
if global_norm_conditioning is not None:
raise ValueError(
"`global_norm_conditioning` was passed, but `norm_conditioning`"
" is not enabled.")
create_scale = create_offset = True

if self._use_layer_norm:
layer_norm = hk.LayerNorm(
axis=-1, create_scale=True, create_offset=True,
axis=-1, create_scale=create_scale, create_offset=create_offset,
name=name + "_layer_norm")
network = hk.Sequential([network, layer_norm])
stages.append(layer_norm)

if self._use_norm_conditioning:
norm_conditioning_layer = mlp_builder.LinearNormConditioning(
name=name + "_norm_conditioning")
norm_conditioning_layer = functools.partial(
norm_conditioning_layer,
# Broadcast to the node/edge axis.
norm_conditioning=global_norm_conditioning[None],
)
stages.append(norm_conditioning_layer)

network = hk.Sequential(stages)
return jraph.concatenated_args(network)

# The embedder graph network independently embeds edge and node features.
@@ -208,7 +268,7 @@ def build_mlp_with_maybe_layer_norm(name, output_size):
embed_edge_fn=embed_edge_fn,
embed_node_fn=embed_node_fn,
)
self._embedder_network = typed_graph_net.GraphMapFeatures(
embedder_network = typed_graph_net.GraphMapFeatures(
**embedder_kwargs)

if self._f32_aggregation:
@@ -232,9 +292,9 @@ def aggregate_fn(data, *args, **kwargs):
# that update the node and edge latent features.
# Note that we can use `modules.InteractionNetwork` because
# it also outputs the messages as updated edge latent features.
self._processor_networks = []
processor_networks = []
for step_i in range(self._num_message_passing_steps):
self._processor_networks.append(
processor_networks.append(
typed_graph_net.InteractionNetwork(
update_edge_fn=_build_update_fns_for_edge_types(
build_mlp_with_maybe_layer_norm,
@@ -259,11 +319,15 @@ def aggregate_fn(data, *args, **kwargs):
embed_node_fn=_build_update_fns_for_node_types(
build_mlp, graph_template, "decoder_nodes_", self._node_output_size)
if self._node_output_size else None,)
self._output_network = typed_graph_net.GraphMapFeatures(
output_network = typed_graph_net.GraphMapFeatures(
**output_kwargs)
return embedder_network, processor_networks, output_network

def _embed(
self, input_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
self,
input_graph: typed_graph.TypedGraph,
embedder_network: GraphToGraphNetwork,
) -> typed_graph.TypedGraph:
"""Embeds the input graph features into a latent graph."""

# Copy the context to all of the node types, if applicable.
@@ -286,19 +350,22 @@ def _embed(
context=input_graph.context._replace(features=()))

# Embeds the node and edge features.
latent_graph_0 = self._embedder_network(input_graph)
latent_graph_0 = embedder_network(input_graph)
return latent_graph_0

def _process(
self, latent_graph_0: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
self,
latent_graph_0: typed_graph.TypedGraph,
processor_networks: List[GraphToGraphNetwork],
) -> typed_graph.TypedGraph:
"""Processes the latent graph with several steps of message passing."""

# Do `num_message_passing_steps` with each of the `self._processor_networks`
# with unshared weights, and repeat that `self._num_processor_repetitions`
# times.
latent_graph = latent_graph_0
for unused_repetition_i in range(self._num_processor_repetitions):
for processor_network in self._processor_networks:
for processor_network in processor_networks:
latent_graph = self._process_step(processor_network, latent_graph)

return latent_graph
@@ -326,10 +393,13 @@ def _process_step(
nodes=nodes_with_residuals, edges=edges_with_residuals)
return latent_graph_k

def _output(self,
latent_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
def _output(
self,
latent_graph: typed_graph.TypedGraph,
output_network: GraphToGraphNetwork,
) -> typed_graph.TypedGraph:
"""Produces the output from the latent graph."""
return self._output_network(latent_graph)
return output_network(latent_graph)


def _build_update_fns_for_node_types(
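The `use_norm_conditioning` path added above composes each MLP with a LayerNorm whose learned scale and offset are disabled, followed by a linear layer that derives a scale and offset from `global_norm_conditioning` (e.g. an embedding of the diffusion time), broadcast over the node/edge axis. The sketch below is a standalone Haiku approximation of that pattern, not the repository's `mlp.LinearNormConditioning` itself; the module names and shapes are made up, and the `(1 + scale)` affine form is an assumption since the exact parameterisation is not shown in this diff.

import haiku as hk
import jax
import jax.numpy as jnp


class ConditionedMLPBlock(hk.Module):
  """MLP -> LayerNorm (no learned scale/offset) -> conditioned scale/offset."""

  def __init__(self, hidden_size, output_size, name=None):
    super().__init__(name=name)
    self._mlp = hk.nets.MLP([hidden_size, output_size], activate_final=False)
    # With norm conditioning, the LayerNorm does not learn its own parameters;
    # they are supplied by the conditioning layer instead.
    self._norm = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)
    # Linear layer producing a per-feature scale and offset from the global
    # conditioning vector (different weights for each MLP in the real model).
    self._to_scale_offset = hk.Linear(2 * output_size, name="norm_conditioning")

  def __call__(self, x, global_norm_conditioning):
    h = self._norm(self._mlp(x))
    scale, offset = jnp.split(
        self._to_scale_offset(global_norm_conditioning), 2, axis=-1)
    # Broadcast the conditioning over the node/edge axis, as in the diff above.
    return h * (1.0 + scale[None]) + offset[None]


def forward(node_features, conditioning):
  return ConditionedMLPBlock(hidden_size=512, output_size=512)(
      node_features, conditioning)


forward_fn = hk.transform(forward)
rng = jax.random.PRNGKey(0)
nodes = jnp.zeros((10, 256))   # [num_nodes, feature_dim], made-up shapes.
cond = jnp.zeros((16,))        # Global conditioning, e.g. a noise-level embedding.
params = forward_fn.init(rng, nodes, cond)
out = forward_fn.apply(params, rng, nodes, cond)  # Shape: (10, 512).

Disabling `create_scale`/`create_offset` on the LayerNorm avoids learning the affine parameters twice: when conditioning is enabled, they come entirely from the conditioning layer, which is why the diff raises an error if `global_norm_conditioning` is passed without `use_norm_conditioning` (or vice versa).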

