This repository contains code and resources for training machine learning models for gravity wave flux and for running inference with them. The project is structured for easy setup and execution, supporting both single-node and multi-node GPU training.
Clone the repository with submodules:

```bash
git clone --recurse-submodules git@github.com:NASA-IMPACT/gravity-wave-finetuning.git gravity_wave_finetuning
cd gravity_wave_finetuning
```
Create and activate a Conda environment for the project:

```bash
conda env create -f environment.yml
conda activate pt24
```
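As a quick sanity check (not part of the original setup steps, just a common verification), you can confirm that the environment's PyTorch build can see a GPU:

```bash
conda activate pt24
# Print the PyTorch version and whether a CUDA-capable GPU is visible.
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```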
The dataset and model checkpoints are hosted on Hugging Face. Cloning these files requires Git LFS. If it is not already installed, install it via Conda or directly from the git-lfs page.
Install Git Large File Storage (Git LFS):

```bash
conda install anaconda::git-lfs
git lfs install
```
Clone the Hugging Face repository to get the model checkpoints:

```bash
git clone --no-checkout git@hf.co:Prithvi-WxC/Gravity_wave_Parameterization checkpoint
cd checkpoint
git lfs pull
```
Clone the Hugging Face repository to get the dataset and extract it:

```bash
git clone --no-checkout git@hf.co:datasets/Prithvi-WxC/Gravity_wave_Parameterization dataset
cd dataset
git lfs pull
```
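The exact archive names depend on the dataset release, so the following is only a sketch, assuming the pulled files are gzipped tarballs to be unpacked inside the `dataset` directory:

```bash
# Unpack every gzipped tar archive fetched by git-lfs; adjust the glob to the actual file names.
for archive in *.tar.gz; do
    tar -xzf "$archive"
done
```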
For detailed information about the dataset, refer to the Hugging Face dataset page: [Gravity Wave Parameterization](https://huggingface.co/datasets/Prithvi-WxC/Gravity_wave_Parameterization).
To configure the training process, update the paths for the dataset and checkpoints in the `config.py` file.
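The exact variable names inside `config.py` depend on the repository version, so the snippet below is only an illustration with hypothetical names (`data_dir`, `ckpt_dir`); open the file and adjust whichever path settings it actually defines:

```bash
# Hypothetical variable names -- check config.py for the real ones before running this.
sed -i 's|^data_dir = .*|data_dir = "/path/to/dataset"|' config.py
sed -i 's|^ckpt_dir = .*|ckpt_dir = "/path/to/checkpoint"|' config.py
```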
To run the training on a single node with a single GPU, execute the following command:

```bash
torchrun \
    --nproc_per_node=1 \
    --nnodes=1 \
    --rdzv_backend=c10d \
    finetune_gravity_wave.py \
    --split uvtp122
```
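To use several GPUs on the same node, the launch command should only need a larger `--nproc_per_node`; this is a hedged variant of the command above, assuming the fine-tuning script reads the standard torchrun environment variables for distributed data parallelism:

```bash
# Example: four GPUs on one node; set --nproc_per_node to the number of GPUs available.
torchrun \
    --nproc_per_node=4 \
    --nnodes=1 \
    --rdzv_backend=c10d \
    finetune_gravity_wave.py \
    --split uvtp122
```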
For multi-node training, refer to the `scripts/train.pbs` script, which is provided for running on a PBS-managed cluster. Customize this script according to your system's configuration.
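As a rough, hedged outline of what a PBS launch script can look like (the job name, resource selection, queue, and environment setup below are placeholders, not the contents of the actual `scripts/train.pbs`):

```bash
#!/bin/bash
#PBS -N gw_finetune                      # placeholder job name
#PBS -l select=2:ncpus=32:ngpus=4        # placeholder: 2 nodes, 4 GPUs each
#PBS -l walltime=12:00:00
#PBS -q gpu_queue                        # placeholder queue name

cd "$PBS_O_WORKDIR"
source activate pt24

# The first allocated host acts as the rendezvous endpoint for torchrun.
MASTER_ADDR=$(head -n 1 "$PBS_NODEFILE")

# This launcher must run once on every allocated node (e.g. started via pbsdsh or
# mpiexec); the repository's train.pbs handles that detail for its target cluster.
torchrun \
    --nnodes=2 \
    --nproc_per_node=4 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="$MASTER_ADDR:29500" \
    finetune_gravity_wave.py \
    --split uvtp122
```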
After training, you can run inference using the following command. Make sure to specify the correct paths for the checkpoint, the data, and where the results should be saved:

```bash
torchrun \
    --standalone \
    --nnodes=1 \
    --nproc_per_node=1 \
    --rdzv_backend=c10d \
    inference.py \
    --split=uvtp122 \
    --ckpt_path=/path/to/checkpoint \
    --data_path=/path/to/data \
    --results_dir=/path/to/results_dir
```
If you use this work, please consider citing our papers:
```bibtex
@misc{schmude2024prithviwxcfoundationmodel,
  title={Prithvi WxC: Foundation Model for Weather and Climate},
  author={Johannes Schmude and Sujit Roy and Will Trojak and Johannes Jakubik and Daniel Salles Civitarese and Shraddha Singh and Julian Kuehnert and Kumar Ankur and Aman Gupta and Christopher E Phillips and Romeo Kienzler and Daniela Szwarcman and Vishal Gaur and Rajat Shinde and Rohit Lal and Arlindo Da Silva and Jorge Luis Guevara Diaz and Anne Jones and Simon Pfreundschuh and Amy Lin and Aditi Sheshadri and Udaysankar Nair and Valentine Anantharaj and Hendrik Hamann and Campbell Watson and Manil Maskey and Tsengdar J Lee and Juan Bernabe Moreno and Rahul Ramachandran},
  year={2024},
  eprint={2409.13598},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2409.13598},
}

@article{gupta2024machine,
  title={Machine learning global simulation of nonlocal gravity wave propagation},
  author={Gupta, Aman and Sheshadri, Aditi and Roy, Sujit and Gaur, Vishal and Maskey, Manil and Ramachandran, Rahul},
  journal={arXiv preprint arXiv:2406.14775},
  year={2024}
}
```