Add universal masking and padding functions for bert-style models
Also corrects a number of geneformer scripts that relied on a bug in the
previous masking function that assigned the wrong number of random
tokens. With the new function that assigns the correct number of random
tokens, we set the random token mask percentage lower to match previous
results.

Signed-off-by: Peter St. John <[email protected]>
pstjohn committed Aug 9, 2024
1 parent f47e508 commit e0af43c
Showing 13 changed files with 842 additions and 111 deletions.
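The random_token_prob change in pretrain.py just below follows directly from the commit message: in the old per-item masking code (removed further down in the SingleCellDataset diff), the random-token draw was an independent coin flip on top of the mask that excluded positions already replaced by [MASK], so only about random_token_prob * (1 - mask_token_prob) of the masked positions actually received a random token. A back-of-the-envelope check, assuming the values visible in this diff (mask_token_prob=0.8 from the process_item defaults, old random_token_prob=0.1):

mask_token_prob = 0.8        # default shown in the process_item signature below
old_random_token_prob = 0.1  # the old setting in pretrain.py

# Under the old implementation the random-token draw was independent of the [MASK]
# draw and excluded positions already replaced by [MASK], so the effective rate was:
effective_random = old_random_token_prob * (1.0 - mask_token_prob)
print(round(effective_random, 2))  # 0.02 -- the value now passed explicitly

With the corrected function, random_token_prob applies directly to the masked positions, which is why 0.02 reproduces the previous results.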
12 changes: 9 additions & 3 deletions .vscode/settings.json
@@ -1,8 +1,14 @@
{
"cSpell.words": [
// List of words to be added to the spell-checker dictionary.
// In vscode, use the "Add <word> to workspace settings" in the quick-fix menu to add words to this list.
"bionemo"
"allclose",
"bionemo",
"dtype",
"nemo",
"pretraining",
"rampup",
"resamplers",
"singlecell",
"uniref"
],
"editor.rulers": [
120
2 changes: 1 addition & 1 deletion scripts/singlecell/geneformer/pretrain.py
@@ -169,7 +169,7 @@ def main(
train_dataset_path=train_data_path,
val_dataset_path=val_data_path,
test_dataset_path=test_data_path,
random_token_prob=0.1, # this is the incorrect setting we originally used.
random_token_prob=0.02, # changed to represent the incorrect setting we originally used.
median_dict=median_dict,
micro_batch_size=micro_batch_size,
global_batch_size=micro_batch_size * int(num_nodes * devices / pipeline_model_parallel_size),
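As a side note on the unchanged global_batch_size expression above: it multiplies the micro batch size by the number of data-parallel replicas (GPUs divided by pipeline stages). A quick check with hypothetical numbers (2 nodes of 8 devices, no pipeline parallelism):

micro_batch_size = 8              # hypothetical
num_nodes, devices = 2, 8         # hypothetical
pipeline_model_parallel_size = 1  # hypothetical

# Same expression as in pretrain.py above.
global_batch_size = micro_batch_size * int(num_nodes * devices / pipeline_model_parallel_size)
print(global_batch_size)  # 8 * 16 = 128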
Original file line number Diff line number Diff line change
@@ -44,3 +44,11 @@ def random_numpy_context(seed: int = 42) -> Iterator[None]:
yield
finally:
np.random.set_state(state)


def get_seed_from_rng(rng: np.random.Generator) -> int:
"""Generates a deterministic random seed from an existing random generator.
Used to seed a torch random generator from a numpy random generator.
"""
return rng.integers(np.iinfo(np.int64).max)
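A minimal usage sketch for the new helper; the function itself only draws an int64 from the NumPy generator as shown above, and the torch side of this example is an assumption about how such a seed would typically be consumed (the data-module diff below uses it to hand each dataset split its own seed):

import numpy as np
import torch

rng = np.random.default_rng(42)

# Equivalent to get_seed_from_rng(rng): draw a deterministic int64 seed. Each call
# advances the generator, so successive calls (e.g. for the train/val/test datasets
# in the data module below) produce different, but reproducible, seeds.
seed = int(rng.integers(np.iinfo(np.int64).max))
torch_gen = torch.Generator().manual_seed(seed)
print(torch.randint(0, 100, (3,), generator=torch_gen))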
Original file line number Diff line number Diff line change
@@ -14,9 +14,11 @@
# limitations under the License.


import functools
from pathlib import Path
from typing import List, Optional, Sequence

import numpy as np
import pytorch_lightning as pl
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.utils import logging
@@ -25,7 +27,10 @@
from torch.utils.data import DataLoader

from bionemo.core.data.resamplers import PRNGDatasetShuffler
from bionemo.core.utils import random_utils
from bionemo.geneformer.data.singlecell.dataset import SingleCellDataset
from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer
from bionemo.llm.data import collate


__all__: Sequence[str] = ("SingleCellDataModule",)
@@ -90,6 +95,8 @@ def __init__( # noqa: D107
self.persistent_workers = persistent_workers
self.pin_memory = pin_memory
self.index_mapping_dir = index_mapping_dir or str(Path(self.data_path_train).parent)

rng = np.random.default_rng(seed)
self._train_dataset_ori = SingleCellDataset(
self.data_path_train,
self.tokenizer,
@@ -98,6 +105,7 @@ def __init__( # noqa: D107
mask_prob=self.mask_prob,
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
)
self._val_dataset_ori = SingleCellDataset(
self.data_path_val,
@@ -107,6 +115,7 @@ def __init__( # noqa: D107
mask_prob=self.mask_prob,
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
)
self._test_dataset_ori = SingleCellDataset(
self.data_path_test,
@@ -116,6 +125,7 @@ def __init__( # noqa: D107
mask_prob=self.mask_prob,
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
)

# This is needed here, or you need to specify it in the megatron adapter thing TODO name?
@@ -169,7 +179,12 @@ def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
num_workers=self.num_workers,
pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers,
# collate_fn=dataset.collate_fn, No special work happens in this dataloader outside of getitem
collate_fn=functools.partial(
collate.bert_padding_collate_fn,
padding_value=self.tokenizer.token_to_id(GeneTokenizer.pad_token),
min_length=None,
max_length=self.max_len,
),
**kwargs,
)
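The replaced comment said no special collation was needed, which was true while every item was padded to max_len inside the dataset; the SingleCellDataset diff below drops that padding (sample_or_truncate only truncates), so padding now happens at batch time via collate.bert_padding_collate_fn. That function's implementation is not part of this excerpt, so the following is only an illustrative sketch of what a BERT padding collate has to do — the field names come from the dataset's return dict and the keyword arguments from the call above; everything else is assumed:

from typing import Dict, List, Optional

import torch
import torch.nn.functional as F


def illustrative_bert_padding_collate(
    batch: List[Dict[str, torch.Tensor]],
    padding_value: int,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
    """Pad variable-length BERT samples to a common length and stack them (illustration only)."""
    target = max(len(item["text"]) for item in batch)
    if min_length is not None:
        target = max(target, min_length)
    if max_length is not None:
        target = min(target, max_length)

    # Token ids get the tokenizer's pad id; labels get the ignore value (-1, as in the
    # old dataset code); the masks get 0 so padded positions are ignored by attention
    # and excluded from the loss.
    pad_values = {"text": padding_value, "labels": -1, "types": 0,
                  "attention_mask": 0, "loss_mask": 0, "is_random": 0}

    def pad_to_target(x: torch.Tensor, value: int) -> torch.Tensor:
        x = x[:target]
        return F.pad(x, (0, target - len(x)), value=value)

    return {
        key: torch.stack([pad_to_target(item[key], pad_values[key]) for item in batch])
        for key in batch[0]
    }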

Original file line number Diff line number Diff line change
@@ -16,32 +16,25 @@

import json
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Tuple, TypedDict
from typing import Any, Dict, Optional, Sequence, Tuple

import numpy as np
import torch
from nemo.utils import logging
from torch.utils.data import Dataset

from bionemo.geneformer.data.singlecell.utils import sample_or_truncate_plus_pad
from bionemo.core.utils import random_utils
from bionemo.geneformer.data.singlecell.utils import sample_or_truncate
from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer
from bionemo.llm.data import masking, types


__all__: Sequence[str] = (
"SingleCellDataset",
"Item",
"process_item",
)


class Item(TypedDict): # noqa: D101
text: np.ndarray
types: np.ndarray
padding_mask: np.ndarray
labels: np.ndarray
loss_mask: np.ndarray
is_random: np.ndarray


class SingleCellDataset(Dataset):
"""A dataset class for single-cell pre-training. These can be generated using the sc_memmap.py script. Future
updates will contain more comprehensive workflows for generating a Sparse Memmap from scRNA-seq.
@@ -94,6 +87,7 @@ def __init__( # noqa: D107
random_token_prob: float = 0.1,
prepend_cls_token: bool = True,
assert_increasing_columns: bool = True,
seed: int = np.random.SeedSequence().entropy, # type: ignore
):
super().__init__()
self.data_path = data_path
@@ -102,6 +96,7 @@ def __init__( # noqa: D107
self.mask_token_prob = mask_token_prob
self.mask_prob = mask_prob
self.prepend_cls_token = prepend_cls_token
self._seed = seed
# check if column indices are increasing for looking up genes. This is a way of spotting if the sc_memmap.py
# script produced properly structured sparse files.
self.assert_increasing_columns = assert_increasing_columns
@@ -198,15 +193,18 @@ def lookup_cell_by_idx(self, idx) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
)
return gene_data, col_idxs, feature_ids

def __getitem__(self, idx: int) -> Item:
"""Performs a lookup and the required transformation for the model""" # noqa: D415
def __getitem__(self, idx: int) -> types.BertSample: # noqa: D105
rng = np.random.default_rng([self._seed, idx])

"""Performs a lookup and the required transformation for the model"""
gene_data, col_idxs, feature_ids = self.lookup_cell_by_idx(idx)
return process_item(
gene_data,
col_idxs,
feature_ids,
self.tokenizer,
gene_median=self.gene_medians,
rng=rng,
max_len=self.max_len,
mask_token_prob=self.mask_token_prob,
mask_prob=self.mask_prob,
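One detail of the new __getitem__ worth noting: rng = np.random.default_rng([self._seed, idx]) seeds a fresh generator from the pair (dataset seed, item index), so the masking applied to a given cell is reproducible while different cells get independent random streams. A standalone check of that property (the numbers are arbitrary):

import numpy as np

seed = 1234  # arbitrary

# Same (seed, idx) pair -> identical stream; a different idx -> a different stream.
a = np.random.default_rng([seed, 7]).integers(0, 100, size=5)
b = np.random.default_rng([seed, 7]).integers(0, 100, size=5)
c = np.random.default_rng([seed, 8]).integers(0, 100, size=5)

assert (a == b).all()
assert not (a == c).all()  # different with overwhelming probability
print(a, c)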
@@ -221,14 +219,15 @@ def process_item( # noqa: D417
feature_ids: np.ndarray,
tokenizer: GeneTokenizer,
gene_median: dict,
rng: np.random.Generator,
max_len: int = 1024,
mask_prob: float = 0.15,
mask_token_prob: float = 0.8,
random_token_prob: float = 0.1,
target_sum: int = 10000,
normalize: bool = True,
prepend_cls_token: bool = True,
) -> Item:
) -> types.BertSample:
"""Process a single item in the dataset.
Optionally performs median normalization and rank ordering. The tokenizer's CLS token is added to the beginning
@@ -240,6 +239,7 @@
feature_ids (list): Feature ids for the full dataset.
tokenizer (Tokenizer): Tokenizer object.
gene_median (optional(dict)): Dictionary of gene medians. Defaults to None. Expects ensembl IDs to be keys.
rng: Random number generator to ensure deterministic results.
max_len (int): Maximum length of the item. Defaults to 1024. Applies padding to any sequence shorter than max_len and truncates any sequence longer than max_len.
mask_prob (float): Probability of masking a token. Defaults to 0.15.
target_sum (int): Target sum for normalization. Defaults to 10000.
@@ -259,14 +259,6 @@
if max_len < 1:
raise ValueError(f"max_len must be greater than 1, {max_len=}")

if random_token_prob + mask_token_prob > 1.0:
raise ValueError(
"Sum of random_token_prob and mask_token_prob must be less than or equal to 1.0, identity_token_prob is any remainder less than 1.0."
)

identity_token_prob = 1.0 - (random_token_prob + mask_token_prob)
assert identity_token_prob >= 0.0

if gene_median is None:
raise ValueError("gene_median must be provided for this tokenizer")

@@ -296,62 +288,34 @@ def process_item( # noqa: D417
token_ids = token_ids[idxs]

# - select max_len subset, set sample to false so it doesnt permute the already rank ordered expression values.
token_ids = sample_or_truncate_plus_pad(
token_ids, max_len, tokenizer.token_to_id(tokenizer.pad_token), sample=False
)
token_ids = sample_or_truncate(token_ids, max_len, sample=False)

masked_tokens, labels, loss_mask = masking.apply_bert_pretraining_mask(
tokenized_sequence=torch.from_numpy(token_ids),
random_seed=random_utils.get_seed_from_rng(rng),
mask_config=masking.BertMaskConfig(
mask_token=tokenizer.token_to_id(tokenizer.mask_token),
random_tokens=range(5, len(tokenizer.vocab)),
mask_prob=mask_prob,
mask_token_prob=mask_token_prob,
random_token_prob=random_token_prob,
),
)

mask = None
mask_tokens_positions = None
random_tokens_positions = None

# - masked tokens
if mask_prob > 0.0:
probs = np.full(token_ids.shape[0], mask_prob)
probs[token_ids == tokenizer.token_to_id(tokenizer.pad_token)] = 0.0
mask = np.random.binomial(1, probs).astype(bool)
mask_tokens_positions = mask & np.random.binomial(1, mask_token_prob, mask.shape).astype(bool)
random_tokens_positions = (
mask & np.random.binomial(1, random_token_prob, mask.shape).astype(bool) & (~mask_tokens_positions)
)
# - ensure [CLS] token is masked from the loss. Note that we're dealing with 1d arrays so flattening isn't a problem here.
if prepend_cls_token:
mask = np.insert(mask, 0, False)
mask_tokens_positions = np.insert(mask_tokens_positions, 0, False)
random_tokens_positions = np.insert(random_tokens_positions, 0, False)

# - add [CLS] token, note that token_ids is a 1d array so flattening isn't a problem here.
if prepend_cls_token:
token_ids = np.insert(token_ids, 0, tokenizer.token_to_id(tokenizer.cls_token))
attention_mask = token_ids != tokenizer.token_to_id(tokenizer.pad_token)

labels = np.ones(len(token_ids)) * -1

if mask is None:
# If prob is set to zero, we get None for our mask, which could have unintended side effects.
# We abuse the scenario where mask == None
labels[mask] = token_ids[mask]
mask = np.zeros(shape=token_ids.shape, dtype=bool)
else:
mask[~attention_mask] = False # make sure that we aren't doing MLM on [PAD] tokens
labels[mask] = token_ids[mask]
if mask_tokens_positions is None:
mask_tokens_positions = np.zeros_like(mask)
if random_tokens_positions is None:
random_tokens_positions = np.zeros_like(mask)
# identity_tokens = mask & (~mask_tokens_positions) & (~random_tokens_positions), not needed because
token_ids[mask_tokens_positions] = tokenizer.token_to_id(tokenizer.mask_token)
# There are 5 special tokens in the tokenizer, so we start from 5. TODO make this a parameter of the tokenizer.
if random_tokens_positions.sum() > 0:
token_ids[random_tokens_positions] = np.random.randint(5, len(tokenizer.vocab), random_tokens_positions.sum())
masked_tokens, labels, loss_mask = masking.add_cls_and_eos_tokens(
sequence=masked_tokens,
labels=labels,
loss_mask=loss_mask,
cls_token=tokenizer.token_to_id(tokenizer.cls_token),
)

# NeMo megatron assumes this return structure.
item = {
"text": token_ids.astype(np.int64),
"types": np.zeros_like(token_ids).astype(np.int64),
"attention_mask": attention_mask.astype(np.int64),
"labels": labels.astype(np.int64),
"loss_mask": mask,
"is_random": np.zeros_like(token_ids).astype(np.int64),
return {
"text": masked_tokens,
"types": torch.zeros_like(masked_tokens, dtype=torch.int64),
"attention_mask": torch.ones_like(masked_tokens, dtype=torch.int64),
"labels": labels,
"loss_mask": loss_mask,
"is_random": torch.zeros_like(masked_tokens, dtype=torch.int64),
}

return item
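To connect the removed block above with the new masking.apply_bert_pretraining_mask call: the sketch below is a self-contained illustration of the corrected scheme, not the BioNeMo implementation. Positions are first selected for the loss with probability mask_prob; each selected position is then replaced by [MASK] with probability mask_token_prob, replaced by a random vocabulary token with probability random_token_prob, or left unchanged otherwise, so the random-token fraction is exactly what the caller asks for rather than the random_token_prob * (1 - mask_token_prob) the old code produced.

from typing import Optional, Tuple

import torch


def illustrative_bert_mask(
    token_ids: torch.Tensor,  # 1-D int64 tensor of token ids
    mask_token: int,
    random_tokens: range,     # e.g. range(5, len(vocab)), skipping the special tokens
    mask_prob: float = 0.15,
    mask_token_prob: float = 0.8,
    random_token_prob: float = 0.02,
    generator: Optional[torch.Generator] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Illustrative corrected BERT masking; returns (masked_tokens, labels, loss_mask)."""
    loss_mask = torch.rand(token_ids.shape, generator=generator) < mask_prob

    # One uniform draw per position splits the selected positions into [MASK] / random /
    # identity outcomes with exactly the requested, mutually exclusive probabilities.
    split = torch.rand(token_ids.shape, generator=generator)
    use_mask = loss_mask & (split < mask_token_prob)
    use_random = loss_mask & (split >= mask_token_prob) & (split < mask_token_prob + random_token_prob)

    labels = torch.full_like(token_ids, -1)  # -1 as the ignore value, as in the old code
    labels[loss_mask] = token_ids[loss_mask]

    masked = token_ids.clone()
    masked[use_mask] = mask_token
    masked[use_random] = torch.randint(
        random_tokens.start, random_tokens.stop, (int(use_random.sum()),), generator=generator
    )
    return masked, labels, loss_mask

With mask_token_prob=0.8 and random_token_prob=0.02, roughly 2% of the selected positions receive a random token, matching the effective behaviour of the old code at 0.1.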
Original file line number Diff line number Diff line change
@@ -12,41 +12,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence

import numpy as np


__all__: Sequence[str] = ("sample_or_truncate_plus_pad",)


def sample_or_truncate_plus_pad(
gene_ids: np.array,
def sample_or_truncate(
gene_ids: np.ndarray,
max_length: int,
pad_token_id: int,
sample: bool = True,
) -> np.array:
) -> np.ndarray:
"""Truncate and pad samples.
Args:
gene_ids (np.ndarray): Array of gene IDs.
max_length (int): Maximum length of the samples.
pad_token_id (int): ID of the padding token.
sample (bool, optional): Whether to sample or truncate the samples. Defaults to True.
Returns:
np.array: Tuple containing the truncated or padded gene IDs.
"""
if len(gene_ids) == max_length:
if len(gene_ids) <= max_length:
return gene_ids

if len(gene_ids) > max_length: # - sample or truncate
if sample:
indices = np.random.permutation(len(gene_ids))[:max_length]
return gene_ids[indices]
else:
return gene_ids[:max_length]
else: # - pad
pad_tokens = np.full((max_length - len(gene_ids)), pad_token_id, dtype=np.int32)
gene_ids = np.concatenate([gene_ids, pad_tokens])
return gene_ids
if sample:
indices = np.random.permutation(len(gene_ids))[:max_length]
return gene_ids[indices]
else:
return gene_ids[:max_length]
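A quick usage sketch for the trimmed-down helper, using the import path shown in the SingleCellDataset diff above; padding is no longer this function's responsibility, since it moved into the batch collate function:

import numpy as np

from bionemo.geneformer.data.singlecell.utils import sample_or_truncate

gene_ids = np.arange(10)

print(sample_or_truncate(gene_ids, max_length=4, sample=False))  # first four ids: [0 1 2 3]
print(sample_or_truncate(gene_ids, max_length=4))                # four ids sampled without replacement
print(sample_or_truncate(gene_ids, max_length=32))               # already short enough: returned unchanged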
