Adding PaliGemma2 to KerasHub #1998

Merged
72 changes: 61 additions & 11 deletions keras_hub/src/models/pali_gemma/pali_gemma_backbone.py
@@ -48,22 +48,40 @@ class PaliGemmaBackbone(Backbone):
a two-layer feedforward network for each transformer decoder block.
head_dim: int. The size of each attention head in the mixed decoder.
vit_patch_size: int. The size of each square patch in the input image.
vit_num_heads: int. The number of attention heads for the vision(image)
vit_num_heads: int. The number of attention heads for the vision (image)
transformer encoder.
vit_hidden_dim: int. The size of the transformer hidden state at the end
of each vision transformer layer.
vit_num_layers: int. The number of vision transformer layers.
vit_intermediate_dim: int. The output dimension of the first Dense layer
in a two-layer feedforward network for vision transformer.
vit_pooling: string. The encoded vision embeddings are pooled using the
specified pooling setting. The accepted values are `"map"`, `"gap"`,
`"0"` or `"none"`. Defaults to `"none"`.
in a two-layer feedforward network for vision transformer. Defaults
to `4304`.
vit_pooling: `None` or string. The encoded vision embeddings are pooled
using the specified pooling setting. The accepted values are
`"map"`, `"gap"`, `"0"` or `None`. Defaults to `None`.
vit_classifier_activation: activation function. The activation that
is used for final output classification in the vision transformer.
Defaults to `None`.
vit_name: string. The name used for vision transformer layers.
query_head_dim_normalize: boolean. If `True` normalize the query before
attention with `head_dim`. If `False`, normalize the query with
`hidden_dim / num_query_heads`. Defaults to `True`.
use_post_ffw_norm: boolean. Whether to normalize after the feedforward
block. Defaults to `False`.
use_post_attention_norm: boolean. Whether to normalize after the attention
block. Defaults to `False`.
attention_logit_soft_cap: `None` or int. Soft cap for the attention
logits. Defaults to `None`.
final_logit_soft_cap: `None` or int. Soft cap for the final logits.
Defaults to `None`.
use_sliding_window_attention: boolean. Whether to use sliding local
window attention. Defaults to `False`.
sliding_window_size: int. Size of the sliding local window. Defaults to
`4096`.
layer_norm_epsilon: float. The epsilon value used for every layer norm
in all transformer blocks.
in all transformer blocks. Defaults to `1e-6`.
dropout: float. Dropout probability for the Transformer decoder blocks.
Defaults to `0`.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for the model's computations and weights. Note that some
computations, such as softmax and layer normalization, will always
@@ -119,6 +137,13 @@ def __init__(
vit_pooling=None,
vit_classifier_activation=None,
vit_name=None,
query_head_dim_normalize=True,
use_post_ffw_norm=False,
use_post_attention_norm=False,
attention_logit_soft_cap=None,
final_logit_soft_cap=None,
use_sliding_window_attention=False,
sliding_window_size=4096,
layer_norm_epsilon=1e-6,
dropout=0,
dtype=None,
@@ -136,6 +161,7 @@ def __init__(
seed=None,
),
dtype=dtype,
logit_soft_cap=final_logit_soft_cap,
name="token_embedding",
)
# TODO Remove this. Work around for previous serialization bug.
@@ -155,12 +181,19 @@ def __init__(
)
self.transformer_layers = []
for i in range(num_layers):
sliding_window = use_sliding_window_attention and (i % 2 == 0)
layer = PaliGemmaDecoderBlock(
hidden_dim=hidden_dim,
intermediate_dim=intermediate_dim,
num_query_heads=num_query_heads,
head_dim=head_dim,
num_query_heads=num_query_heads,
num_key_value_heads=num_key_value_heads,
query_head_dim_normalize=query_head_dim_normalize,
use_post_ffw_norm=use_post_ffw_norm,
use_post_attention_norm=use_post_attention_norm,
logit_soft_cap=attention_logit_soft_cap,
use_sliding_window_attention=sliding_window,
sliding_window_size=sliding_window_size,
dropout=dropout,
dtype=dtype,
name=f"decoder_block_{i}",
@@ -173,7 +206,9 @@ def __init__(
)

# === Functional Model ===
image_input = self.vit_encoder.inputs[0]
image_input = keras.Input(
shape=(image_size, image_size, 3), name="images"
)
token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)
@@ -219,7 +254,15 @@ def __init__(
self.head_dim = head_dim
self.layer_norm_epsilon = layer_norm_epsilon
self.dropout = dropout
# VIT Params
# Gemma2 params
self.query_head_dim_normalize = query_head_dim_normalize
self.use_post_ffw_norm = use_post_ffw_norm
self.use_post_attention_norm = use_post_attention_norm
self.attention_logit_soft_cap = attention_logit_soft_cap
self.final_logit_soft_cap = final_logit_soft_cap
self.sliding_window_size = sliding_window_size
self.use_sliding_window_attention = use_sliding_window_attention
# ViT params
self.vit_patch_size = vit_patch_size
self.vit_num_heads = vit_num_heads
self.vit_hidden_dim = vit_hidden_dim
@@ -243,8 +286,6 @@ def get_config(self):
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"head_dim": self.head_dim,
"layer_norm_epsilon": self.layer_norm_epsilon,
"dropout": self.dropout,
"vit_patch_size": self.vit_patch_size,
"vit_num_heads": self.vit_num_heads,
"vit_hidden_dim": self.vit_hidden_dim,
@@ -253,6 +294,15 @@ def get_config(self):
"vit_pooling": self.vit_pooling,
"vit_classifier_activation": self.vit_classifier_activation,
"vit_name": self.vit_name,
"query_head_dim_normalize": self.query_head_dim_normalize,
"use_post_ffw_norm": self.use_post_ffw_norm,
"use_post_attention_norm": self.use_post_attention_norm,
"final_logit_soft_cap": self.final_logit_soft_cap,
"attention_logit_soft_cap": self.attention_logit_soft_cap,
"sliding_window_size": self.sliding_window_size,
"use_sliding_window_attention": self.use_sliding_window_attention,
"layer_norm_epsilon": self.layer_norm_epsilon,
"dropout": self.dropout,
}
)
return config
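
Reviewer note (not part of the diff): for anyone exercising the new Gemma2-style arguments locally, the sketch below instantiates the backbone with the options added above, using the same tiny dimensions as the new `PaliGemma2BackboneTest`. It is a minimal sketch under those assumptions, not an official usage example; note that when `use_sliding_window_attention=True`, only even-indexed decoder blocks receive the local window (`i % 2 == 0` above).

```python
import numpy as np

from keras_hub.src.models.pali_gemma.pali_gemma_backbone import (
    PaliGemmaBackbone,
)

# Tiny configuration mirroring PaliGemma2BackboneTest.init_kwargs.
backbone = PaliGemmaBackbone(
    vocabulary_size=256,
    image_size=16,
    num_layers=2,
    num_query_heads=2,
    num_key_value_heads=1,
    hidden_dim=8,
    intermediate_dim=16,
    head_dim=4,
    vit_patch_size=4,
    vit_num_layers=2,
    vit_num_heads=2,
    vit_hidden_dim=8,
    vit_intermediate_dim=16,
    # Gemma2-style options introduced in this PR.
    query_head_dim_normalize=True,
    use_post_ffw_norm=True,
    use_post_attention_norm=True,
    final_logit_soft_cap=30,
    attention_logit_soft_cap=50,
    use_sliding_window_attention=True,  # applied on even-indexed blocks only
    sliding_window_size=4096,
)

batch_size, seq_len, image_size = 2, 64, 16
outputs = backbone(
    {
        "images": np.random.rand(batch_size, image_size, image_size, 3),
        "token_ids": np.random.randint(0, 256, (batch_size, seq_len)),
        "padding_mask": np.ones((batch_size, seq_len), dtype="int32"),
        "response_mask": np.zeros((batch_size, seq_len), dtype="int32"),
    }
)
# Text tokens plus (image_size / vit_patch_size) ** 2 image tokens, hidden_dim wide.
print(outputs.shape)  # (2, 80, 8)
```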
73 changes: 72 additions & 1 deletion keras_hub/src/models/pali_gemma/pali_gemma_backbone_test.py
@@ -61,7 +61,6 @@ def test_backbone_basics(self):
8,
),
variable_length_data=[self.input_data],
run_mixed_precision_check=False, # TODO: Set to `True`
)

@pytest.mark.large
@@ -98,3 +97,75 @@ def test_all_presets(self):
preset=preset,
input_data=self.input_data,
)


class PaliGemma2BackboneTest(TestCase):
def setUp(self):
self.batch_size = 2
self.vocabulary_size = 256
self.text_sequence_length = 64
self.image_size = 16
self.image_sequence_length = int((self.image_size / 4) ** 2)
self.init_kwargs = {
"vocabulary_size": self.vocabulary_size,
"image_size": self.image_size,
"num_layers": 2,
"num_query_heads": 2,
"num_key_value_heads": 1,
"hidden_dim": 8,
"intermediate_dim": 16,
"head_dim": 4,
"vit_patch_size": 4,
"vit_num_layers": 2,
"vit_num_heads": 2,
"vit_hidden_dim": 8,
"vit_intermediate_dim": 16,
# Gemma2
"query_head_dim_normalize": True,
"use_post_ffw_norm": True,
"use_post_attention_norm": True,
"final_logit_soft_cap": 30,
"attention_logit_soft_cap": 50,
"use_sliding_window_attention": True,
"sliding_window_size": 4096,
}

dummy_images = np.random.rand(
self.batch_size, self.image_size, self.image_size, 3
)
dummy_text_token_ids = np.random.rand(
self.batch_size, self.text_sequence_length
)
self.input_data = {
"token_ids": dummy_text_token_ids,
"images": dummy_images,
"padding_mask": np.ones(
(self.batch_size, self.text_sequence_length),
dtype="int32",
),
"response_mask": np.zeros(
(self.batch_size, self.text_sequence_length),
dtype="int32",
),
}

def test_backbone_basics(self):
self.run_backbone_test(
cls=PaliGemmaBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(
self.batch_size,
self.text_sequence_length + self.image_sequence_length,
8,
),
variable_length_data=[self.input_data],
)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=PaliGemmaBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)
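
Reviewer note (not part of the diff): the `expected_output_shape` in the new `test_backbone_basics` follows from the patch arithmetic — with `image_size=16` and `vit_patch_size=4`, the vision tower contributes `(16 / 4) ** 2 = 16` image tokens, which are concatenated with the 64 text tokens before the decoder stack. A quick sanity check of that arithmetic, assuming only the constants used in the test above:

```python
image_size = 16
vit_patch_size = 4
text_sequence_length = 64
hidden_dim = 8
batch_size = 2

image_sequence_length = (image_size // vit_patch_size) ** 2  # 16 image tokens
expected_output_shape = (
    batch_size,
    text_sequence_length + image_sequence_length,
    hidden_dim,
)
assert expected_output_shape == (2, 80, 8)
```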
44 changes: 21 additions & 23 deletions keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py
@@ -31,33 +31,25 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
the attention layer.
num_key_value_heads: int. The number of heads for the key and value
projections in the attention layer.
query_head_dim_normalize: boolean. If `True` normalize the query before
attention with `head_dim`. If `False`, normalize the query with
`hidden_dim / num_query_heads`. Defaults to `True`.
use_post_ffw_norm: boolean. Whether to normalize after the feedforward
block. Defaults to `False`.
use_post_attention_norm: boolean. Whether to normalize after the
attention block. Defaults to `False`.
logit_soft_cap: `None` or int. Soft cap for the attention logits.
Defaults to `None`.
use_sliding_window_attention: boolean. Whether to use sliding local
window attention. Defaults to `False`.
sliding_window_size: int. Size of the sliding local window. Defaults to
`4096`.
layer_norm_epsilon: float. The epsilon hyperparameter used for layer
normalization.
normalization. Defaults to `1e-6`.
dropout: float. The dropout rate for the transformer attention layer.
Defaults to `0`.
"""

def __init__(
self,
hidden_dim,
intermediate_dim,
head_dim,
num_query_heads,
num_key_value_heads,
layer_norm_epsilon=1e-6,
dropout=0,
**kwargs,
):
super().__init__(
hidden_dim=hidden_dim,
intermediate_dim=intermediate_dim,
head_dim=head_dim,
num_query_heads=num_query_heads,
num_key_value_heads=num_key_value_heads,
layer_norm_epsilon=layer_norm_epsilon,
dropout=dropout,
**kwargs,
)

def call(
self,
x,
@@ -83,6 +75,9 @@ def call(
attention_mask=attention_mask,
)

if self.use_post_attention_norm:
attention = self.post_attention_norm(attention)

if self.dropout:
attention = self.attention_dropout(attention)

@@ -94,6 +89,9 @@ def call(
x = keras.activations.gelu(x1, approximate=True) * x2
x = self.ffw_linear(x)

if self.use_post_ffw_norm:
x = self.post_ffw_norm(x)

x = x + attention_x

if cache is not None:
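
Reviewer note (not part of the diff): the two soft-cap arguments threaded through the backbone are not applied in this block directly — `attention_logit_soft_cap` is forwarded to the attention layer and `final_logit_soft_cap` to the token embedding layer (see the backbone changes above). Gemma2-style soft capping is commonly written as `cap * tanh(x / cap)`, which smoothly bounds values to `(-cap, cap)`. A minimal illustration of that formula (an assumption-level sketch, not the KerasHub implementation):

```python
import numpy as np
from keras import ops


def soft_cap(logits, cap):
    """Gemma2-style soft cap: smoothly bounds values to (-cap, cap)."""
    return cap * ops.tanh(logits / cap)


# With attention_logit_soft_cap=50 (as in the new test), large scores saturate near 50.
scores = np.array([1.0, 10.0, 100.0, 1000.0])
print(ops.convert_to_numpy(soft_cap(scores, 50.0)))  # ~[1.0, 9.87, 48.2, 50.0]
```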