
The GraphCast generates different predictions with the same inputs. How to fix it? #123

Open

LQscience opened this issue Jan 12, 2025 · 2 comments

@LQscience
No description provided.

@LQscience
Author

LQscience commented Jan 12, 2025

Thank you for your reply, and sorry for the delay; I didn't find time to organize the relevant materials until now.

I've attached screenshots of the main code and related output below.
Furthermore, I've shared the complete data, Jupyter Notebook, and the YAML file for the conda environment on Google Drive for your convenience in reproducing the results.

Here is the comparison between the two predictions made with the same inputs, forcings, and output_template:
[Screenshot 2025-01-15 16:25:13: output comparing the two prediction runs]

Below is the main code:

import time
import random
import pathlib
import datetime
import pandas as pd

import jax
import haiku
import xarray
import functools
import numpy as np

from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout

class GraphCast:
    """
    A thin class wrapper around the GraphCast model code released by Google DeepMind.
    """
    def __init__(self, model_path, dir_path_stats):
        """
        Initialize the GraphCast
        
        params:
            model_path: file path for the model parameters
            dir_path_stats: directory containing the input normalization statistics
        """
        with open(model_path, 'rb') as model:
            ckpt = checkpoint.load(model, graphcast.CheckPoint)
            self.params = ckpt.params
            self.state = {}
            self.model_config = ckpt.model_config
            self.task_config = ckpt.task_config

        with open(f"{dir_path_stats}/diffs_stddev_by_level.nc", "rb") as f:
            self.diffs_stddev_by_level = xarray.load_dataset(f).compute()
        with open(f"{dir_path_stats}/mean_by_level.nc", "rb") as f:
            self.mean_by_level = xarray.load_dataset(f).compute()
        with open(f"{dir_path_stats}/stddev_by_level.nc", "rb") as f:
            self.stddev_by_level = xarray.load_dataset(f).compute()

        self.run_forward_jitted = self.drop_state(
            self.with_params(
                jax.jit(
                    self.with_configs(
                        self.run_forward.apply))))

    def predict(self, 
                inputs:xarray.Dataset, 
                targets:xarray.Dataset, 
                forcings:xarray.Dataset) -> xarray.Dataset:
        """
        Make weather predictions
        params: xarray.Dataset
            inputs: input data
            targets: predictions template
            forcings: forcing data
        return: xarray.Dataset
            the weather predictions
        
        """
        assert self.model_config.resolution in (0, 360. / inputs.sizes["lon"]), (
            "Model resolution doesn't match the data resolution. You likely want to "
            "re-filter the dataset list, and download the correct data.")
        predictions = rollout.chunked_prediction(
            self.run_forward_jitted, 
            rng=jax.random.PRNGKey(0), 
            inputs=inputs, 
            targets_template=targets * np.nan, 
            forcings=forcings)

        return predictions

    @staticmethod
    def construct_wrapped_graphcast(
        model_config:graphcast.ModelConfig, 
        task_config:graphcast.TaskConfig, 
        diffs_stddev_by_level:xarray.Dataset,
        mean_by_level:xarray.Dataset,
        stddev_by_level:xarray.Dataset):
        """Constructs and wraps the GraphCast Predictor."""
        # Deeper one-step predictor.   
        predictor = graphcast.GraphCast(model_config, task_config)
        # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
        # from/to float32 to/from BFloat16.
        predictor = casting.Bfloat16Cast(predictor)
        # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
        # BFloat16 happens after applying normalization to the inputs/targets.
        predictor = normalization.InputsAndResiduals(
            predictor, 
            diffs_stddev_by_level=diffs_stddev_by_level, 
            mean_by_level=mean_by_level, 
            stddev_by_level=stddev_by_level)
        # Wraps everything so the one-step model can produce trajectories.
        predictor = autoregressive.Predictor(predictor, gradient_checkpointing = True)
        
        return predictor

    # Forward pass
    @haiku.transform_with_state
    def run_forward(model_config, task_config, 
                    diffs_stddev_by_level, mean_by_level, stddev_by_level, 
                    inputs, targets_template, forcings):
        predictor = GraphCast.construct_wrapped_graphcast(model_config, task_config,
                                                         diffs_stddev_by_level,
                                                         mean_by_level, stddev_by_level
                                                         )
        return predictor(inputs, targets_template=targets_template, forcings=forcings)


    # Jax doesn't seem to like passing configs as args through the jit. Passing it
    # in via partial (instead of capture by closure) forces jax to invalidate the
    # jit cache if you change configs.
    def with_configs(self, fn):
        return functools.partial(fn, model_config=self.model_config, task_config=self.task_config,
                                 diffs_stddev_by_level=self.diffs_stddev_by_level,
                                 mean_by_level=self.mean_by_level, stddev_by_level=self.stddev_by_level)

    # Always pass params and state, so the usage below is simpler.
    def with_params(self, fn):
        return functools.partial(fn, params=self.params, state=self.state)

    # Our models aren't stateful, so the state is always empty; just return the
    # predictions. This is required by our rollout code, and generally simpler.
    def drop_state(self, fn):
        return lambda **kw: fn(**kw)[0]

############ The run settings ############
initial_time = datetime.datetime(2023, 7, 29, 0)
predictions_steps = 40
gap = 6

lookback_datetimes = pd.to_datetime([initial_time-datetime.timedelta(hours=6), initial_time])
forecast_datetimes = pd.to_datetime(
    [initial_time+datetime.timedelta(hours=step*gap) for step in range(1,predictions_steps+1)])

############## The first run ##############
seed = 42
random.seed(seed)
np.random.seed(seed)
jax.random.PRNGKey(seed)

stats_dir = "model/stats"
params_dir = "model/params"

model_path = params_dir+'/GraphCast_operational.npz'

GraphCastModel = GraphCast(model_path, stats_dir)

inputs = xarray.open_dataset("test_inputs.nc")
forcings = xarray.open_dataset("test_forcings.nc")

output_template = inputs.assign(total_precipitation_6hr=inputs["2m_temperature"])[
    list(GraphCastModel.task_config.target_variables)].reindex(time=forecast_datetimes.values)
output_template["total_precipitation_6hr"] = output_template["total_precipitation_6hr"].assign_attrs(
    units="m", long_name="Total precipitation")

prediction_1 = GraphCastModel.predict(inputs, output_template, forcings)

############## The second run ##############
seed = 42
random.seed(seed)
np.random.seed(seed)
jax.random.PRNGKey(seed)

stats_dir = "model/stats"
params_dir = "model/params"

model_path = params_dir+'/GraphCast_operational.npz'

GraphCastModel = GraphCast(model_path, stats_dir)

inputs = xarray.open_dataset("test_inputs.nc")
forcings = xarray.open_dataset("test_forcings.nc")

output_template = inputs.assign(total_precipitation_6hr=inputs["2m_temperature"])[
    list(GraphCastModel.task_config.target_variables)].reindex(time=forecast_datetimes.values)
output_template["total_precipitation_6hr"] = output_template["total_precipitation_6hr"].assign_attrs(
    units="m", long_name="Total precipitation")

prediction_2 = GraphCastModel.predict(inputs, output_template, forcings)

############## Compare the predictions from these two runs ##############

for var in prediction_1.data_vars:
    print(var, (prediction_1[var]!=prediction_2[var]).sum().data)

for var in prediction_1.data_vars:
    print(var, np.abs((prediction_1[var]-prediction_2[var]).data).sum())
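
For completeness, here is one further check I could run (a sketch only, not yet executed; pred_a and pred_b are just placeholder names). It calls predict twice on the same GraphCast instance, so the same jitted forward function and the same fixed PRNGKey(0) are reused, which should show whether the difference also appears within a single instance or only across the two separately constructed ones:

pred_a = GraphCastModel.predict(inputs, output_template, forcings)
pred_b = GraphCastModel.predict(inputs, output_template, forcings)

# Largest absolute difference per variable; all zeros would mean the forward
# pass is deterministic once the jitted function and RNG key are fixed.
for var in pred_a.data_vars:
    print(var, float(np.abs((pred_a[var] - pred_b[var]).data).max()))

# Exact comparison via xarray; raises an AssertionError if any value differs.
import xarray.testing
xarray.testing.assert_allclose(pred_a, pred_b, rtol=0, atol=0)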

LQscience changed the title from "The GraphCast generates different predictions with the same inputs" to "The GraphCast generates different predictions with the same inputs. How to fix it?" on Jan 12, 2025
@andrewlkd
Collaborator

Hey, can you elaborate on what you've found/are trying? Some reproducer code would be useful.
