-
Notifications
You must be signed in to change notification settings - Fork 686
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
The GraphCast generates different predictions with the same inputs. How to fix it? #123
Comments
Thank you for your reply. I'm sorry I didn't find time to organize the relevant materials a few days ago. I've attached screenshots of the main code and related output below. Here is a comparison between two predictions made with the same inputs, forcings and output_template: Below is the main code: import time
import random
import pathlib
import datetime
import pandas as pd
import jax
import haiku
import xarray
import functools
import numpy as np
from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
class GraphCast:
    """Thin wrapper around Google DeepMind's GraphCast inference code.

    Loads a pretrained checkpoint plus the per-level normalization
    statistics, builds the jitted forward function once, and exposes a
    single `predict` method that runs the autoregressive rollout.
    """
    def __init__(self, model_path: str, dir_path_stats: str):
        """
        Initialize the GraphCast wrapper.
        params:
            model_path: file path for the model parameters (checkpoint file)
            dir_path_stats: directory containing the input statistics files
                (diffs_stddev_by_level.nc, mean_by_level.nc, stddev_by_level.nc)
        """
        with open(model_path, 'rb') as model:
            ckpt = checkpoint.load(model, graphcast.CheckPoint)
        self.params = ckpt.params
        # Haiku state; the model is stateless (see drop_state below), so
        # this stays empty for the lifetime of the object.
        self.state = {}
        self.model_config = ckpt.model_config
        self.task_config = ckpt.task_config
        # Normalization statistics. .compute() forces eager loading so the
        # datasets are fully in memory before the file handles close.
        with open(f"{dir_path_stats}/diffs_stddev_by_level.nc", "rb") as f:
            self.diffs_stddev_by_level = xarray.load_dataset(f).compute()
        with open(f"{dir_path_stats}/mean_by_level.nc", "rb") as f:
            self.mean_by_level = xarray.load_dataset(f).compute()
        with open(f"{dir_path_stats}/stddev_by_level.nc", "rb") as f:
            self.stddev_by_level = xarray.load_dataset(f).compute()
        # Jitted forward pass with configs and params pre-bound and the
        # (always empty) haiku state dropped from the outputs.
        self.run_forward_jitted = self.drop_state(
            self.with_params(
                jax.jit(
                    self.with_configs(
                        self.run_forward.apply))))

    def predict(self,
                inputs: xarray.Dataset,
                targets: xarray.Dataset,
                forcings: xarray.Dataset) -> xarray.Dataset:
        """
        Make weather predictions.
        params: xarray.Dataset
            inputs: input data
            targets: predictions template (values are replaced by NaN below;
                only its coordinates/variables define what is predicted)
            forcings: forcing data
        return: xarray.Dataset
            the weather predictions
        """
        # NOTE(review): assert is stripped under `python -O`; a ValueError
        # would be more robust for input validation — confirm before changing.
        assert self.model_config.resolution in (0, 360. / inputs.sizes["lon"]), (
            "Model resolution doesn't match the data resolution. You likely want to "
            "re-filter the dataset list, and download the correct data.")
        # The RNG key is fixed at 0, so repeated calls with identical inputs
        # use identical randomness.
        predictions = rollout.chunked_prediction(
            self.run_forward_jitted,
            rng=jax.random.PRNGKey(0),
            inputs=inputs,
            targets_template=targets * np.nan,
            forcings=forcings)
        return predictions

    @staticmethod
    def construct_wrapped_graphcast(
            model_config: graphcast.ModelConfig,
            task_config: graphcast.TaskConfig,
            diffs_stddev_by_level: xarray.Dataset,
            mean_by_level: xarray.Dataset,
            stddev_by_level: xarray.Dataset):
        """Constructs and wraps the GraphCast Predictor."""
        # Deeper one-step predictor.
        predictor = graphcast.GraphCast(model_config, task_config)
        # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
        # from/to float32 to/from BFloat16.
        predictor = casting.Bfloat16Cast(predictor)
        # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
        # BFloat16 happens after applying normalization to the inputs/targets.
        predictor = normalization.InputsAndResiduals(
            predictor,
            diffs_stddev_by_level=diffs_stddev_by_level,
            mean_by_level=mean_by_level,
            stddev_by_level=stddev_by_level)
        # Wraps everything so the one-step model can produce trajectories.
        predictor = autoregressive.Predictor(predictor, gradient_checkpointing = True)
        return predictor

    # Forward computation. Note: the decorator turns this into a
    # haiku.Transformed class attribute (with .init/.apply), not a method.
    @haiku.transform_with_state
    def run_forward(model_config, task_config,
                    diffs_stddev_by_level, mean_by_level, stddev_by_level,
                    inputs, targets_template, forcings):
        predictor = GraphCast.construct_wrapped_graphcast(model_config, task_config,
                                                          diffs_stddev_by_level,
                                                          mean_by_level, stddev_by_level
                                                          )
        return predictor(inputs, targets_template=targets_template, forcings=forcings)

    # Jax doesn't seem to like passing configs as args through the jit. Passing it
    # in via partial (instead of capture by closure) forces jax to invalidate the
    # jit cache if you change configs.
    def with_configs(self, fn):
        return functools.partial(fn, model_config=self.model_config, task_config=self.task_config,
                                 diffs_stddev_by_level=self.diffs_stddev_by_level,
                                 mean_by_level=self.mean_by_level, stddev_by_level=self.stddev_by_level)

    # Always pass params and state, so the usage below are simpler
    def with_params(self, fn):
        return functools.partial(fn, params=self.params, state=self.state)

    # Our models aren't stateful, so the state is always empty, so just return the
    # predictions. This is required by our rollout code, and generally simpler.
    def drop_state(self, fn):
        return lambda **kw: fn(**kw)[0]
############ The running settings ############
initial_time = datetime.datetime(2023, 7, 29, 0)
predictions_steps = 40
gap = 6  # hours between successive forecast steps
lookback_datetimes = pd.to_datetime([initial_time - datetime.timedelta(hours=6), initial_time])
forecast_datetimes = pd.to_datetime(
    [initial_time + datetime.timedelta(hours=step * gap) for step in range(1, predictions_steps + 1)])


def run_once(seed: int = 42) -> xarray.Dataset:
    """Build the model, load the test data and return one forecast.

    Both runs below were previously duplicated inline; factoring them into
    one function guarantees the two executions are truly identical.

    params:
        seed: seed for the Python and NumPy global RNGs.
    return: xarray.Dataset with the predictions.
    """
    random.seed(seed)
    np.random.seed(seed)
    # BUG in the original: `jax.random.PRNGKey(seed)` was called and its
    # return value discarded. JAX has no global RNG state, so that call had
    # no effect; the key actually used for inference is the fixed
    # PRNGKey(0) created inside GraphCast.predict.
    stats_dir = "model/stats"
    params_dir = "model/params"
    model_path = params_dir + '/GraphCast_operational.npz'
    model = GraphCast(model_path, stats_dir)
    inputs = xarray.open_dataset("test_inputs.nc")
    forcings = xarray.open_dataset("test_forcings.nc")
    # Build the targets template from the inputs: copy a surface variable as
    # a placeholder for precipitation, keep only the target variables, and
    # reindex onto the forecast times (values become NaN placeholders).
    output_template = inputs.assign(total_precipitation_6hr=inputs["2m_temperature"])[
        list(model.task_config.target_variables)].reindex(time=forecast_datetimes.values)
    output_template["total_precipitation_6hr"] = output_template["total_precipitation_6hr"].assign_attrs(
        units="m", long_name="Total precipitation")
    return model.predict(inputs, output_template, forcings)


############## The first running ##############
# Fixed: the original assigned the first run to `prediction_2` and called a
# misspelled name `GraphCastModelq`, so `prediction_1` was never defined.
prediction_1 = run_once(seed=42)
############## The second running ##############
prediction_2 = run_once(seed=42)
############## Compare predictions of these two runnings ##############
for var in prediction_1.data_vars:
    print(var, (prediction_1[var] != prediction_2[var]).sum().data)
for var in prediction_1.data_vars:
    print(var, np.abs((prediction_1[var] - prediction_2[var]).data).sum())
Hey, can you elaborate on what you've found/are trying? Some reproducer code would be useful.
No description provided.
The text was updated successfully, but these errors were encountered: