First training attempt #119

Closed · dkokron opened this issue Dec 30, 2024 · 9 comments

dkokron commented Dec 30, 2024

I am working with the gencast_mini_demo.ipynb demo.
Source data is source-era5_date-2019-03-29_res-1.0_levels-13_steps-01.nc.

I added the optimizer steps and the loop. Inference does not work on CPUs (#113), so I'll ask here: do these Loss and Mean |grad| values look reasonable?

import jax
import numpy as np
import optax

# Start from the pretrained checkpoint parameters and compute the initial
# loss and gradients on the training sample.
params = ckpt.params
loss, diagnostics, next_state, grads = grads_fn_jitted(
    params=params,
    state=state,
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")

# Take five AdamW steps, recomputing loss and gradients after each update.
optimizer = optax.adamw(0.0004)
opt_state = optimizer.init(params)
for i in range(5):
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    loss, diagnostics, next_state, grads = grads_fn_jitted(
        params=params,
        state=state,
        inputs=train_inputs,
        targets=train_targets,
        forcings=train_forcings)
    mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
    print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")

Loss: 4.0923, Mean |grad|: 0.006474
Loss: 83.6939, Mean |grad|: 0.272431
Loss: 20.1864, Mean |grad|: 0.044424
Loss: 15.5214, Mean |grad|: 0.022913
Loss: 10.1925, Mean |grad|: 0.008500
Loss: 8.7714, Mean |grad|: 0.005301
dkokron changed the title from "First training attmept" to "First training attempt" on Dec 30, 2024

dkokron commented Dec 30, 2024

This version follows the paper (https://arxiv.org/pdf/2312.15796) more closely; see Table D1, "Diffusion model training hyperparameters", on page 27.

params = ckpt.params
loss, diagnostics, next_state, grads = grads_fn_jitted(
    params=params,
    state=state,
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")

lr_schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value = .00001,
    peak_value = .003,
    warmup_steps = 1000,
    decay_steps = 1000000,
    end_value = 0.0,
    exponent = 0.1,
)
optimizer = optax.adamw(lr_schedule)
opt_state = optimizer.init(params)
for i in range(5):
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    loss, diagnostics, next_state, grads = grads_fn_jitted(
        params=params,
        state=state,
        inputs=train_inputs,
        targets=train_targets,
        forcings=train_forcings)
    mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
    print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")

Loss: 4.0923, Mean |grad|: 0.006474
Loss: 5.3288, Mean |grad|: 0.073276
Loss: 4.4514, Mean |grad|: 0.045505
Loss: 4.5843, Mean |grad|: 0.025861
Loss: 4.5447, Mean |grad|: 0.023869
Loss: 4.3157, Mean |grad|: 0.014323


dkokron commented Jan 1, 2025

With the lr_schedule init_value set to zero, training on source-era5_date-2019-03-29_res-1.0_levels-13_steps-01.nc produces zero change in the params. That seems odd to me; there ought to be some numerical difference after training on new data. Can someone tell me what I'm doing wrong?

params = ckpt.params
loss, diagnostics, next_state, grads = grads_fn_jitted(
    params=params,
    state=state,
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
print(f"Loss: {loss:.6f}, Mean |grad|: {mean_grad:.6f}")

lr_schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value = 0.0,
    peak_value = .003,
    warmup_steps = 1000,
    decay_steps = 1000000,
    end_value = 0.0,
    exponent = 0.1,
)
optimizer = optax.adamw(lr_schedule)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

print(np.array_equal(ckpt.params["fourier_features_mlp/~/mlp/~/linear_0"]["b"], params["fourier_features_mlp/~/mlp/~/linear_0"]["b"]))
print(np.array_equal(ckpt.params["fourier_features_mlp/~/mlp/~/linear_0"]["w"], params["fourier_features_mlp/~/mlp/~/linear_0"]["w"]))
print(np.array_equal(ckpt.params["fourier_features_mlp/~/mlp/~/linear_1"]["b"], params["fourier_features_mlp/~/mlp/~/linear_1"]["b"]))
print(np.array_equal(ckpt.params["fourier_features_mlp/~/mlp/~/linear_1"]["w"], params["fourier_features_mlp/~/mlp/~/linear_1"]["w"]))

Loss: 4.092268, Mean |grad|: 0.006474
True
True
True
True

andrewlkd (Collaborator) commented

An initial learning rate of 0 means that the first step of optimisation multiplies all gradients by 0, so the update effectively applies the identity to the parameters and leaves them unchanged.

I suspect running a second iteration, or using a non-zero initial rate, will show a different outcome.
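
A minimal sketch illustrating this behaviour, using toy parameters rather than the GenCast checkpoint (the schedule values mirror the snippet above; everything else here is illustrative only):

import jax.numpy as jnp
import numpy as np
import optax

# Toy parameter/gradient pytrees standing in for the real model.
toy_params = {"w": jnp.ones((3,)), "b": jnp.zeros((2,))}
toy_grads = {"w": jnp.full((3,), 0.1), "b": jnp.full((2,), -0.2)}

lr_schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=0.003, warmup_steps=1000,
    decay_steps=1000000, end_value=0.0, exponent=0.1)
optimizer = optax.adamw(lr_schedule)
opt_state = optimizer.init(toy_params)

# Step 0: the schedule evaluates to 0, so the AdamW update is all zeros
# and the parameters come back unchanged.
updates, opt_state = optimizer.update(toy_grads, opt_state, toy_params)
step0_params = optax.apply_updates(toy_params, updates)
print(np.array_equal(toy_params["w"], step0_params["w"]))  # True

# Step 1: the warmup has reached a non-zero rate, so the parameters move.
updates, opt_state = optimizer.update(toy_grads, opt_state, step0_params)
step1_params = optax.apply_updates(step0_params, updates)
print(np.array_equal(step0_params["w"], step1_params["w"]))  # False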


dkokron commented Jan 6, 2025

Both of your suspicions are true. I was following the training schedule from Figure 7 of https://arxiv.org/pdf/2212.12794, where it appears they used an initial learning rate of 0.0. Table D1 of https://arxiv.org/pdf/2312.15796 doesn't mention the value used in training GenCast. Can you share the value that was used to train GenCast?

andrewlkd (Collaborator) commented

GenCast also uses an initial learning rate of 0. The other hyperparameters are as listed in Table D1.


dkokron commented Jan 6, 2025

If your initial learning rate was zero, how many grad-optimize-apply_updates iterations were used to train GenCast?


dkokron commented Jan 6, 2025

What values did you use for the learning rate schedule when fine-tuning a model (from a checkpoint) with newer data (HRES)?


dkokron commented Jan 13, 2025

Is a response to this latest question forthcoming, or should I just close this issue?

andrewlkd (Collaborator) commented

As the GenCast paper states in Table D1, the model is trained in two stages: 2 million updates are applied in the first, and 64 thousand in the second.

Fine-tuning with newer HRES data is done by repeating Stage 2; i.e. operational GenCast has undergone Stage 1, Stage 2, and then Stage 2 again with 0.25° HRES data.
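
For illustration, one possible way to express that two-stage structure as a single optax schedule is to join two warmup/cosine schedules at the stage boundary. The 2,000,000 and 64,000 update counts come from the comment above; the peak learning rates and warmup lengths below are placeholders, not the values actually used (see Table D1 of the paper for those):

import optax

stage1_steps = 2_000_000   # Stage 1: pre-training updates (from the comment above).
stage2_steps = 64_000      # Stage 2: fine-tuning updates (repeated for HRES data).

# Placeholder peak rates and warmup lengths -- consult Table D1 for the real values.
stage1_schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-3, warmup_steps=1000, decay_steps=stage1_steps)
stage2_schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-4, warmup_steps=1000, decay_steps=stage2_steps)

# join_schedules switches to the second schedule once the global step count passes
# the boundary, restarting that schedule's own step count from zero.
lr_schedule = optax.join_schedules(
    schedules=[stage1_schedule, stage2_schedule],
    boundaries=[stage1_steps])

optimizer = optax.adamw(lr_schedule)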

dkokron closed this as completed Jan 13, 2025