This is the official repository for our paper "Audio Mamba: Selective State Spaces for Self-Supervised Audio Representations", set to appear in Proc. INTERSPEECH 2024. It contains:
- Pre-trained weights for the default Self-Supervised Audio Mamba (SSAM) configurations
- Our local copy of hear-eval-kit for easy downstream reproducibility. The original can be found here
- A feature extraction API compatible with the hear-eval-kit format.
- Code used to train the SSAM models.
- Helper code to extract features and run downstream experiments with the provided pre-trained models
- Required: CUDA 11.x or newer, cuDNN 8.2 or newer.
- Create a new conda environment with Python 3.10 or later.
- Requires torch 2.1.2 or newer.
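To confirm that an environment meets these minimums, a small check along the following lines may help. It is a sketch, not part of the repository: the helper names (`parse_version`, `meets_minimum`, `check_environment`) are hypothetical. Note that `torch.version.cuda` reports the CUDA version the PyTorch wheel was built against, which is not necessarily the toolkit installed on the system.

```python
import re
import sys
from typing import Dict, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Parse the numeric release part of a version string, e.g. '2.1.2+cu118' -> (2, 1, 2)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(version: str, minimum: str) -> bool:
    """Return True if `version` >= `minimum`, treating missing components as 0."""
    have, need = parse_version(version), parse_version(minimum)
    width = max(len(have), len(need))
    have = have + (0,) * (width - len(have))
    need = need + (0,) * (width - len(need))
    return have >= need


def check_environment() -> Dict[str, bool]:
    """Compare the interpreter and (if installed) PyTorch against the stated minimums."""
    python_version = ".".join(str(part) for part in sys.version_info[:3])
    results = {"python": meets_minimum(python_version, "3.10")}
    try:
        import torch
    except ImportError:
        results["torch"] = False  # torch is not installed yet
        return results
    results["torch"] = meets_minimum(torch.__version__, "2.1.2")
    # CUDA version PyTorch was built with; None for CPU-only builds
    cuda = torch.version.cuda
    results["cuda"] = cuda is not None and meets_minimum(cuda, "11.0")
    # cuDNN 8.x is reported as an integer such as 8200 for 8.2.0; None if unavailable
    cudnn = torch.backends.cudnn.version()
    results["cudnn"] = cudnn is not None and cudnn >= 8200
    return results


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name}: {'ok' if ok else 'below minimum or missing'}")
```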
Follow these steps:

```bash
conda create -n mamba-env python=3.10 -y
conda activate mamba-env
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
# install hear-eval-kit specific requirements
pip install -r external_sources/hear-eval-kit/requirements.txt
# install hear-eval-kit, WITHOUT AUTO DEPS
cd external_sources/hear-eval-kit && pip install --no-deps . && cd -
# install causal-conv1d
pip install git+https://github.com/Dao-AILab/[email protected]
# install mamba-ssm
pip install git+https://github.com/state-spaces/[email protected]
```
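Once the installation steps finish, one quick way to confirm that the dependencies are visible to Python is to look up their top-level modules without importing them. This sketch assumes the module names `mamba_ssm`, `causal_conv1d` and `heareval` for mamba-ssm, causal-conv1d and hear-eval-kit respectively; `find_spec` only locates a module, so it will not catch a compiled CUDA extension that fails at import time.

```python
import importlib.util
from typing import Iterable, List

# Assumed top-level module names: mamba-ssm -> mamba_ssm,
# causal-conv1d -> causal_conv1d, hear-eval-kit -> heareval
REQUIRED_MODULES = ("torch", "torchaudio", "causal_conv1d", "mamba_ssm", "heareval")


def missing_modules(names: Iterable[str] = REQUIRED_MODULES) -> List[str]:
    """Return the top-level module names that cannot be located on the Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("missing:", ", ".join(missing))
    else:
        print("all required modules found")
```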
- Follow https://hearbenchmark.com/hear-tasks.html to obtain the data. By default, the data on HEAR's Zenodo page is sampled at 48000 Hz.
- We recommend downloading the data directly from HEAR's GCS bucket, where you can find preprocessed 16000 Hz data.
- Extract all the files to a folder `$TASKS_DIR`.
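The following sketch lists which tasks under `$TASKS_DIR` contain 16000 Hz audio. It assumes the HEAR-2021 task layout, where each task folder holds a `task_metadata.json` file and one subfolder per sample rate (for example `16000/`); the helper name `tasks_with_sample_rate` is hypothetical.

```python
import os
from pathlib import Path
from typing import List


def tasks_with_sample_rate(tasks_dir: str, sample_rate: int = 16000) -> List[str]:
    """Return names of task folders that have task_metadata.json and a <sample_rate>/ folder."""
    root = Path(tasks_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"tasks directory not found: {root}")
    found = []
    for task in sorted(p for p in root.iterdir() if p.is_dir()):
        # assumed HEAR-2021 layout: <task>/task_metadata.json and <task>/<sample_rate>/...
        if (task / "task_metadata.json").is_file() and (task / str(sample_rate)).is_dir():
            found.append(task.name)
    return found


if __name__ == "__main__":
    tasks_dir = os.environ.get("TASKS_DIR")
    if tasks_dir:
        print(tasks_with_sample_rate(tasks_dir))
```

Tasks missing from the output were most likely downloaded only at 48000 Hz.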
- Pre-trained weights can be downloaded from Google Drive
- Download the entire folder and export its path as `$PT_MAMBA_MODEL_DIR`.
Then extract features:

```bash
export PT_MAMBA_MODEL_DIR=/path/to/pretrained_weights
./extract_features.sh $TASKS_DIR $OUTPUT_DIR
```
where `TASKS_DIR` is the directory into which you extracted the HEAR-2021 tasks, and `OUTPUT_DIR` is the base directory where the extracted features will be stored. The script extracts features for the SSAST and SSAM Tiny configurations; edit it to use other configurations as needed.
This also prepares a `todo_audioset` directory in `OUTPUT_DIR`, which sets up downstream classification over 10 random seeds.
After extracting features, run the downstream experiments for a specific configuration with:

```bash
./downstream_experiments.sh ssam_tiny_200_16x4 $OUTPUT_DIR/todo_audioset
```
This will run downstream experiments on all the extracted features for the tiny SSAM configuration on 10 random seeds.
Finally, run the following script to aggregate the downstream results for both models:

```bash
python stats_aggregation_v2.py --base_dir ${OUTPUT_DIR}/todo_audioset --output_dir ${OUTPUT_DIR}/parsed_results
```
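Because each configuration is evaluated over 10 random seeds, downstream results are typically summarized as a mean and a spread across seeds. The sketch below shows that reduction for a hypothetical `{task: [score per seed]}` mapping; it is not the logic of `stats_aggregation_v2.py`, whose input and output formats may differ, and `summarize_seeds` is a hypothetical name.

```python
from statistics import mean, stdev
from typing import Dict, List, Tuple


def summarize_seeds(scores_by_task: Dict[str, List[float]]) -> Dict[str, Tuple[float, float]]:
    """Collapse per-seed scores into (mean, sample standard deviation) per task.

    Tasks with no scores are skipped; a single score gets a spread of 0.0.
    """
    summary = {}
    for task, scores in scores_by_task.items():
        if not scores:
            continue
        spread = stdev(scores) if len(scores) > 1 else 0.0
        summary[task] = (mean(scores), spread)
    return summary
```

The sample standard deviation (`stdev`) is used here because the seeds are a small sample of possible runs; whether the repository's script reports that, the population deviation, or a confidence interval is not stated here.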
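The `hear_api` module can be used to extract features from your own audio files:

```python
import torchaudio
from importlib import import_module
from hear_api import RuntimeSSAST

config = import_module("configs.ssam_tiny_200_16x4").get_config()
ssam = RuntimeSSAST(config, "path/to/pretrained_dir").cuda()
# alternatively, if your paths are set up correctly, just use:
# ssam = import_module("hear_configs.ssam_tiny_200_16x4").load_model().cuda()

# torchaudio.load returns a [channels, samples] waveform and its sample rate
x, sr = torchaudio.load("path/to/audio.wav")
# downmix to mono so the batch holds a single clip: shape [1, samples]
x = x.mean(dim=0, keepdim=True)
# assumption: the model consumes 16 kHz audio, like the preprocessed HEAR data above
if sr != 16000:
    x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=16000)
x = x.cuda()
o = ssam.get_scene_embeddings(x)  # scene-level embeddings for the batch
```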
Pretraining code is included in this release. Each model configuration was trained with a command of the following form (shown here for `ssam_tiny_200_16x4`):

```bash
torchrun --standalone --nnodes=1 --nproc-per-node=4 train.py --config configs.ssam_tiny_200_16x4 --workdir $EXP_DIR/ssam_tiny_200_16x4_4x256_fp16_r1 --precision float16 --print_freq 50 --num_workers 16 --no_wandb
```
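When launching several configurations, it can be convenient to build this command programmatically. The sketch below only reproduces the flags that appear in the command; the helper name `build_pretrain_command` is hypothetical, and its defaults simply mirror the values used for `ssam_tiny_200_16x4`.

```python
import os
import shlex
from typing import List


def build_pretrain_command(config: str, exp_dir: str, run_name: str,
                           nproc_per_node: int = 4, precision: str = "float16",
                           print_freq: int = 50, num_workers: int = 16) -> List[str]:
    """Return the torchrun pretraining invocation as an argument list."""
    return [
        "torchrun", "--standalone", "--nnodes=1", f"--nproc-per-node={nproc_per_node}",
        "train.py",
        "--config", f"configs.{config}",
        "--workdir", os.path.join(exp_dir, run_name),
        "--precision", precision,
        "--print_freq", str(print_freq),
        "--num_workers", str(num_workers),
        "--no_wandb",
    ]


if __name__ == "__main__":
    cmd = build_pretrain_command(
        "ssam_tiny_200_16x4",
        os.environ.get("EXP_DIR", "experiments"),
        "ssam_tiny_200_16x4_4x256_fp16_r1",
    )
    print(shlex.join(cmd))  # inspect the command before launching it
    # to launch: import subprocess; subprocess.run(cmd, check=True)
```

Returning a list rather than a single string avoids shell quoting issues when the list is passed to `subprocess.run`.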
We use a torchdata-based datapipe for data loading, operating on precomputed log-mel spectrogram features stored in WebDataset archive(s). You can adapt the data loading to your own use case.
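For readers adapting the data loading, the sketch below shows how a WebDataset-style tar shard groups files into samples, using only the standard library. It is not the repository's torchdata datapipe: the function name `iter_webdataset_samples` is hypothetical, and the member extensions inside the repository's archives (for example, how the log-mel features are serialized) are not specified here, so each field is returned as raw bytes.

```python
import tarfile
from typing import Dict, Iterator, Union


def iter_webdataset_samples(tar_path: str) -> Iterator[Dict[str, Union[str, bytes]]]:
    """Yield one dict per sample from a WebDataset-style tar shard.

    Consecutive members sharing a key (the path up to the first '.' in the file
    name) form one sample, e.g. 'clip001.npy' and 'clip001.cls' become
    {'__key__': 'clip001', 'npy': b'...', 'cls': b'...'}.
    """
    current_key, sample = None, {}
    with tarfile.open(tar_path, "r") as archive:  # "r" also handles compressed shards
        for member in archive:
            if not member.isfile():
                continue
            directory, _, filename = member.name.rpartition("/")
            stem, dot, extension = filename.partition(".")
            if not dot:
                continue  # a file without an extension carries no field name
            key = f"{directory}/{stem}" if directory else stem
            if key != current_key:
                if sample:
                    yield sample  # previous sample is complete
                current_key, sample = key, {"__key__": key}
            sample[extension] = archive.extractfile(member).read()
    if sample:
        yield sample
```

Decoding each field is left to the caller; for instance, if features were stored as `.npy` files, `numpy.load(io.BytesIO(sample["npy"]))` would recover the array, but that storage format is an assumption about your own archives.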