Add BERT-style masking function #55

Merged: 1 commit, Aug 9, 2024
12 changes: 9 additions & 3 deletions .vscode/settings.json
@@ -1,8 +1,14 @@
{
"cSpell.words": [
// List of words to be added to the spell-checker dictionary.
// In vscode, use the "Add <word> to workspace settings" in the quick-fix menu to add words to this list.
"bionemo"
"allclose",
"bionemo",
"dtype",
"nemo",
"pretraining",
"rampup",
"resamplers",
"singlecell",
"uniref"
],
"editor.rulers": [
120
2 changes: 1 addition & 1 deletion scripts/singlecell/geneformer/pretrain.py
@@ -169,7 +169,7 @@ def main(
train_dataset_path=train_data_path,
val_dataset_path=val_data_path,
test_dataset_path=test_data_path,
random_token_prob=0.1, # this is the incorrect setting we originally used.
random_token_prob=0.02, # changed to represent the incorrect setting we originally used.
median_dict=median_dict,
micro_batch_size=micro_batch_size,
global_batch_size=micro_batch_size * int(num_nodes * devices / pipeline_model_parallel_size),
@@ -44,3 +44,11 @@ def random_numpy_context(seed: int = 42) -> Iterator[None]:
yield
finally:
np.random.set_state(state)


def get_seed_from_rng(rng: np.random.Generator) -> int:
"""Generates a deterministic random seed from an existing random generator.

Used to seed a torch random generator from a numpy random generator.
"""
return rng.integers(np.iinfo(np.int64).max)
Collaborator: Q: Is the seed int64 so we can get more range?

Collaborator (Author): rng.integers returns an int64 by default; we're just asking for any random int64.
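As a usage sketch: the torch generator below is illustrative and not part of this diff, and the import path is assumed from the datamodule imports further down.

```python
import numpy as np
import torch

from bionemo.core.utils.random_utils import get_seed_from_rng  # module path assumed from the imports below

# One numpy generator drives all downstream seeding, so the whole pipeline is
# reproducible from a single integer seed.
np_rng = np.random.default_rng(seed=1337)

# Draw an int64 seed for torch; repeated calls give different but reproducible seeds.
torch_gen = torch.Generator().manual_seed(int(get_seed_from_rng(np_rng)))
noise = torch.randn(4, generator=torch_gen)
```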

@@ -14,9 +14,11 @@
# limitations under the License.


import functools
from pathlib import Path
from typing import List, Optional, Sequence

import numpy as np
import pytorch_lightning as pl
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.utils import logging
@@ -25,7 +27,10 @@
from torch.utils.data import DataLoader

from bionemo.core.data.resamplers import PRNGDatasetShuffler
from bionemo.core.utils import random_utils
from bionemo.geneformer.data.singlecell.dataset import SingleCellDataset
from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer
from bionemo.llm.data import collate


__all__: Sequence[str] = ("SingleCellDataModule",)
@@ -90,6 +95,8 @@ def __init__( # noqa: D107
self.persistent_workers = persistent_workers
self.pin_memory = pin_memory
self.index_mapping_dir = index_mapping_dir or str(Path(self.data_path_train).parent)

rng = np.random.default_rng(seed)
self._train_dataset_ori = SingleCellDataset(
self.data_path_train,
self.tokenizer,
@@ -98,6 +105,7 @@ def __init__( # noqa: D107
mask_prob=self.mask_prob,
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
)
self._val_dataset_ori = SingleCellDataset(
self.data_path_val,
@@ -107,6 +115,7 @@ def __init__( # noqa: D107
mask_prob=self.mask_prob,
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
)
self._test_dataset_ori = SingleCellDataset(
self.data_path_test,
@@ -116,6 +125,7 @@ def __init__( # noqa: D107
mask_prob=self.mask_prob,
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
)

# This is needed here, or you need to specify it in the megatron adapter thing TODO name?
@@ -169,7 +179,12 @@ def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
num_workers=self.num_workers,
pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers,
# collate_fn=dataset.collate_fn, No special work happens in this dataloader outside of getitem
collate_fn=functools.partial(
collate.bert_padding_collate_fn,
padding_value=self.tokenizer.token_to_id(GeneTokenizer.pad_token),
min_length=None,
max_length=self.max_len,
@jstjohn (Collaborator), Aug 7, 2024: Can you add min_length=None here to self-document that we're allowing fewer than max_len tokens when all elements of a batch are shorter?

Collaborator (Author): done!

),
**kwargs,
)

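For readers unfamiliar with bert_padding_collate_fn, here is a rough sketch of what a padding collate with a max_length cap and min_length=None (pad only up to the longest item in the batch) can look like. The function below is a hypothetical stand-in, not the bionemo implementation.

```python
from typing import Dict, List, Optional

import torch


def padding_collate_sketch(
    batch: List[Dict[str, torch.Tensor]],
    padding_value: int,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
    """Hypothetical padding collate: truncate to max_length, pad to the longest sample."""
    texts = [item["text"][:max_length] for item in batch]  # slicing with None is a no-op
    target = max(len(t) for t in texts)
    if min_length is not None:
        target = max(target, min_length)  # min_length=None means "no floor": pad only to the longest item
    padded = torch.full((len(texts), target), padding_value, dtype=texts[0].dtype)
    for i, t in enumerate(texts):
        padded[i, : len(t)] = t
    return {"text": padded}
```

With min_length=None, batches in which every sequence is shorter than max_len stay smaller than max_len, which is exactly the self-documenting behavior requested in the thread above.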
@@ -16,32 +16,25 @@

import json
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Tuple, TypedDict
from typing import Any, Dict, Optional, Sequence, Tuple

import numpy as np
import torch
from nemo.utils import logging
from torch.utils.data import Dataset

from bionemo.geneformer.data.singlecell.utils import sample_or_truncate_plus_pad
from bionemo.core.utils import random_utils
from bionemo.geneformer.data.singlecell.utils import sample_or_truncate
from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer
from bionemo.llm.data import masking, types


__all__: Sequence[str] = (
"SingleCellDataset",
"Item",
"process_item",
)


class Item(TypedDict): # noqa: D101
text: np.ndarray
types: np.ndarray
padding_mask: np.ndarray
labels: np.ndarray
loss_mask: np.ndarray
is_random: np.ndarray


class SingleCellDataset(Dataset):
"""A dataset class for single-cell pre-training. These can be generated using the sc_memmap.py script. Future
updates will contain more comprehensive workflows for generating a Sparse Memmap from scRNA-seq.
@@ -94,6 +87,7 @@ def __init__( # noqa: D107
random_token_prob: float = 0.1,
prepend_cls_token: bool = True,
assert_increasing_columns: bool = True,
seed: int = np.random.SeedSequence().entropy, # type: ignore
):
super().__init__()
self.data_path = data_path
@@ -102,6 +96,7 @@ def __init__( # noqa: D107
self.mask_token_prob = mask_token_prob
self.mask_prob = mask_prob
self.prepend_cls_token = prepend_cls_token
self._seed = seed
# check if column indices are increasing for looking up genes. This is a way of spotting if the sc_memmap.py
# script produced properly structured sparse files.
self.assert_increasing_columns = assert_increasing_columns
@@ -198,15 +193,18 @@ def lookup_cell_by_idx(self, idx) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
)
return gene_data, col_idxs, feature_ids

def __getitem__(self, idx: int) -> Item:
"""Performs a lookup and the required transformation for the model""" # noqa: D415
def __getitem__(self, idx: int) -> types.BertSample: # noqa: D105
rng = np.random.default_rng([self._seed, idx])

"""Performs a lookup and the required transformation for the model"""
gene_data, col_idxs, feature_ids = self.lookup_cell_by_idx(idx)
return process_item(
gene_data,
col_idxs,
feature_ids,
self.tokenizer,
gene_median=self.gene_medians,
rng=rng,
Collaborator: Is this seed fixed? Comment?

Collaborator (Author): rng is defined on line 197 -- rng = np.random.default_rng([self._seed, idx]) -- so it's deterministic as a function of the class seed and the item index. Megatron's sampler assumes that datasets are deterministic, so that when you call __getitem__(i) on each GPU rank, they all get the same data. I don't think we were obeying that constraint previously.

max_len=self.max_len,
mask_token_prob=self.mask_token_prob,
mask_prob=self.mask_prob,
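To make the determinism point from the discussion above concrete: seeding a generator with the pair [seed, idx] yields the same stream every time that item is fetched, on any rank or in any epoch. This is an illustrative check, not part of the diff.

```python
import numpy as np

class_seed, idx = 1337, 42

# Two independently constructed generators with the same (seed, idx) pair
# draw identical values, which is what Megatron's deterministic-dataset
# assumption relies on.
rng_a = np.random.default_rng([class_seed, idx])
rng_b = np.random.default_rng([class_seed, idx])
assert rng_a.integers(np.iinfo(np.int64).max) == rng_b.integers(np.iinfo(np.int64).max)

# A different item index produces an independent stream.
rng_other = np.random.default_rng([class_seed, idx + 1])
```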
@@ -221,14 +219,15 @@ def process_item( # noqa: D417
feature_ids: np.ndarray,
tokenizer: GeneTokenizer,
gene_median: dict,
rng: np.random.Generator,
max_len: int = 1024,
mask_prob: float = 0.15,
mask_token_prob: float = 0.8,
random_token_prob: float = 0.1,
target_sum: int = 10000,
normalize: bool = True,
prepend_cls_token: bool = True,
) -> Item:
) -> types.BertSample:
"""Process a single item in the dataset.

Optionally performs median normalization and rank ordering. The tokenizers CLS token is added to the beginning
@@ -240,6 +239,7 @@
feature_ids (list): Feature ids for the full dataset.
tokenizer (Tokenizer): Tokenizer object.
gene_median (optional(dict)): Dictionary of gene medians. Defaults to None. Expects ensembl IDs to be keys.
rng: Random number generator to ensure deterministic results.
max_len (int): Maximum length of the item. Defaults to 1024. Applies padding to any sequence shorter than max_len and truncates any sequence longer than max_len.
mask_prob (float): Probability of masking a token. Defaults to 0.15.
target_sum (int): Target sum for normalization. Defaults to 10000.
@@ -259,14 +259,6 @@
if max_len < 1:
raise ValueError(f"max_len must be greater than 1, {max_len=}")

if random_token_prob + mask_token_prob > 1.0:
raise ValueError(
"Sum of random_token_prob and mask_token_prob must be less than or equal to 1.0, identity_token_prob is any remainder less than 1.0."
)

identity_token_prob = 1.0 - (random_token_prob + mask_token_prob)
assert identity_token_prob >= 0.0

if gene_median is None:
raise ValueError("gene_median must be provided for this tokenizer")

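The normalization itself is elided from this hunk; roughly, Geneformer-style preprocessing rescales each cell's counts to target_sum, divides by the per-gene median, and then orders tokens by the resulting value so the sequence is a rank ordering. Below is a sketch under those assumptions; the names are illustrative, not the exact code in this file.

```python
import numpy as np


def median_rank_order_sketch(
    gene_data: np.ndarray,   # raw counts for one cell
    token_ids: np.ndarray,   # tokenizer ids aligned with gene_data
    medians: np.ndarray,     # per-gene median expression, aligned with gene_data
    target_sum: int = 10_000,
) -> np.ndarray:
    """Hypothetical median-normalization and rank-ordering step."""
    values = gene_data / gene_data.sum() * target_sum  # library-size normalization
    values = values / medians                          # median normalization
    order = np.argsort(-values)                        # highest normalized expression first
    return token_ids[order]
```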
@@ -296,62 +288,34 @@ def process_item( # noqa: D417
token_ids = token_ids[idxs]

# - select max_len subset, set sample to false so it doesnt permute the already rank ordered expression values.
token_ids = sample_or_truncate_plus_pad(
token_ids, max_len, tokenizer.token_to_id(tokenizer.pad_token), sample=False
token_ids = sample_or_truncate(token_ids, max_len, sample=False)

masked_tokens, labels, loss_mask = masking.apply_bert_pretraining_mask(
tokenized_sequence=torch.from_numpy(token_ids),
random_seed=random_utils.get_seed_from_rng(rng),
mask_config=masking.BertMaskConfig(
mask_token=tokenizer.token_to_id(tokenizer.mask_token),
random_tokens=range(5, len(tokenizer.vocab)),
mask_prob=mask_prob,
mask_token_prob=mask_token_prob,
random_token_prob=random_token_prob,
),
)

mask = None
mask_tokens_positions = None
random_tokens_positions = None

# - masked tokens
if mask_prob > 0.0:
probs = np.full(token_ids.shape[0], mask_prob)
probs[token_ids == tokenizer.token_to_id(tokenizer.pad_token)] = 0.0
mask = np.random.binomial(1, probs).astype(bool)
mask_tokens_positions = mask & np.random.binomial(1, mask_token_prob, mask.shape).astype(bool)
random_tokens_positions = (
mask & np.random.binomial(1, random_token_prob, mask.shape).astype(bool) & (~mask_tokens_positions)
)
# - ensure [CLS] token is masked from the loss. Note that we're dealing with 1d arrays so flattening isn't a problem here.
if prepend_cls_token:
mask = np.insert(mask, 0, False)
mask_tokens_positions = np.insert(mask_tokens_positions, 0, False)
random_tokens_positions = np.insert(random_tokens_positions, 0, False)

# - add [CLS] token, note that token_ids is a 1d array so flattening isn't a problem here.
if prepend_cls_token:
token_ids = np.insert(token_ids, 0, tokenizer.token_to_id(tokenizer.cls_token))
attention_mask = token_ids != tokenizer.token_to_id(tokenizer.pad_token)

labels = np.ones(len(token_ids)) * -1

if mask is None:
# If prob is set to zero, we get None for our mask, which could have unintended side effects.
# We abuse the scenario where mask == None
labels[mask] = token_ids[mask]
mask = np.zeros(shape=token_ids.shape, dtype=bool)
else:
mask[~attention_mask] = False # make sure that we aren't doing MLM on [PAD] tokens
labels[mask] = token_ids[mask]
if mask_tokens_positions is None:
mask_tokens_positions = np.zeros_like(mask)
if random_tokens_positions is None:
random_tokens_positions = np.zeros_like(mask)
# identity_tokens = mask & (~mask_tokens_positions) & (~random_tokens_positions), not needed because
token_ids[mask_tokens_positions] = tokenizer.token_to_id(tokenizer.mask_token)
# There are 5 special tokens in the tokenizer, so we start from 5. TODO make this a parameter of the tokenizer.
if random_tokens_positions.sum() > 0:
token_ids[random_tokens_positions] = np.random.randint(5, len(tokenizer.vocab), random_tokens_positions.sum())
masked_tokens, labels, loss_mask = masking.add_cls_and_eos_tokens(
sequence=masked_tokens,
labels=labels,
loss_mask=loss_mask,
cls_token=tokenizer.token_to_id(tokenizer.cls_token),
)

# NeMo megatron assumes this return structure.
item = {
"text": token_ids.astype(np.int64),
"types": np.zeros_like(token_ids).astype(np.int64),
"attention_mask": attention_mask.astype(np.int64),
"labels": labels.astype(np.int64),
"loss_mask": mask,
"is_random": np.zeros_like(token_ids).astype(np.int64),
return {
"text": masked_tokens,
"types": torch.zeros_like(masked_tokens, dtype=torch.int64),
"attention_mask": torch.ones_like(masked_tokens, dtype=torch.int64),
"labels": labels,
"loss_mask": loss_mask,
"is_random": torch.zeros_like(masked_tokens, dtype=torch.int64),
}

return item
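The hand-rolled masking removed above is replaced by masking.apply_bert_pretraining_mask with a BertMaskConfig. As a rough sketch of the BERT-style scheme it implements (mask_prob of positions enter the loss; of those, mask_token_prob become [MASK], random_token_prob become a random non-special token, and the remainder keep their identity), using hypothetical names rather than the bionemo API:

```python
from typing import Optional, Tuple

import torch


def bert_mask_sketch(
    token_ids: torch.Tensor,
    mask_token: int,
    vocab_size: int,
    mask_prob: float = 0.15,
    mask_token_prob: float = 0.8,
    random_token_prob: float = 0.1,
    generator: Optional[torch.Generator] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical BERT-style masking; not the bionemo.llm.data.masking implementation."""
    select = torch.rand(token_ids.shape, generator=generator) < mask_prob  # positions in the MLM loss
    labels = torch.where(select, token_ids, torch.full_like(token_ids, -1))  # -1 matches the removed code's ignore value

    masked = token_ids.clone()
    decide = torch.rand(token_ids.shape, generator=generator)
    # Of the selected positions: a mask_token_prob share becomes [MASK] ...
    masked[select & (decide < mask_token_prob)] = mask_token
    # ... a random_token_prob share becomes a random token (ids >= 5 skip the special tokens) ...
    rand_pos = select & (decide >= mask_token_prob) & (decide < mask_token_prob + random_token_prob)
    masked[rand_pos] = torch.randint(5, vocab_size, (int(rand_pos.sum()),), generator=generator)
    # ... and the rest keep their original (identity) token.
    return masked, labels, select
```

The diff's BertMaskConfig exposes the same knobs: mask_token, random_tokens=range(5, len(tokenizer.vocab)), mask_prob, mask_token_prob, and random_token_prob.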
@@ -12,41 +12,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence

import numpy as np


__all__: Sequence[str] = ("sample_or_truncate_plus_pad",)


def sample_or_truncate_plus_pad(
gene_ids: np.array,
def sample_or_truncate(
gene_ids: np.ndarray,
max_length: int,
pad_token_id: int,
sample: bool = True,
) -> np.array:
) -> np.ndarray:
"""Truncate and pad samples.

Args:
gene_ids (np.ndarray): Array of gene IDs.
max_length (int): Maximum length of the samples.
pad_token_id (int): ID of the padding token.
sample (bool, optional): Whether to sample or truncate the samples. Defaults to True.

Returns:
np.ndarray: The sampled or truncated gene IDs.
"""
if len(gene_ids) == max_length:
if len(gene_ids) <= max_length:
return gene_ids

if len(gene_ids) > max_length: # - sample or truncate
if sample:
indices = np.random.permutation(len(gene_ids))[:max_length]
return gene_ids[indices]
else:
return gene_ids[:max_length]
else: # - pad
pad_tokens = np.full((max_length - len(gene_ids)), pad_token_id, dtype=np.int32)
gene_ids = np.concatenate([gene_ids, pad_tokens])
return gene_ids
if sample:
indices = np.random.permutation(len(gene_ids))[:max_length]
return gene_ids[indices]
else:
return gene_ids[:max_length]
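Illustrative behavior of the simplified helper now that padding is left to the collate function; the import path is taken from the dataset diff above.

```python
import numpy as np

from bionemo.geneformer.data.singlecell.utils import sample_or_truncate

gene_ids = np.arange(8)

# At or under max_length the input is returned unchanged (no padding here any more).
assert np.array_equal(sample_or_truncate(gene_ids, max_length=16), gene_ids)

# Over max_length with sample=False: keep the first max_length entries,
# preserving the rank ordering established upstream.
assert np.array_equal(sample_or_truncate(gene_ids, max_length=4, sample=False), gene_ids[:4])

# With sample=True (the default), a random subset of max_length entries is kept instead.
assert len(sample_or_truncate(gene_ids, max_length=4)) == 4
```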