avista-base / decoder.py
yubo0306's picture
Upload AVHubertForConditionalGeneration
425850c verified
Raw
History Blame Contribute Delete
47.7 kB
from typing import Callable, Optional, Tuple, TypedDict, Union
import numpy as np
import torch
import torch.nn as nn
from transformers.cache_utils import Cache, EncoderDecoderCache, StaticCache
from transformers.modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.hubert.configuration_hubert import HubertConfig
from transformers.models.hubert.modeling_hubert import (
HubertAttnAdapterLayer,
HubertFeedForward,
is_deepspeed_zero3_enabled,
)
from transformers.utils import is_torchdynamo_compiling, logging
from typing_extensions import Unpack
logger = logging.get_logger(__name__)
# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py#L428
class FlashAttentionKwargs(TypedDict, total=False):
    """
    Optional keyword arguments accepted by Flash Attention when used with compile.

    The dictionary is declared with `total=False`, so every key may be left out.

    Attributes:
        cumulative_seqlens_q (`torch.LongTensor`, *optional*):
            Cumulative sequence lengths for the query states.
        cumulative_seqlens_k (`torch.LongTensor`, *optional*):
            Cumulative sequence lengths for the key states.
        max_length_q (`int`, *optional*):
            Maximum sequence length for the query states.
        max_length_k (`int`, *optional*):
            Maximum sequence length for the key states.
    """

    cumulative_seqlens_q: Optional[torch.LongTensor]
    cumulative_seqlens_k: Optional[torch.LongTensor]
    max_length_q: Optional[int]
    max_length_k: Optional[int]
class SinusoidalPositionalEmbedding(nn.Module):
def __init__(self, config) -> None:
super().__init__()
weight = torch.empty(
(
config.max_position_embeddings,
config.hidden_size,
),
requires_grad=False,
)
self._init_sinusoidal_embedding(weight)
self.register_buffer("position_embeddings", weight)
def _init_sinusoidal_embedding(self, embeddings: torch.Tensor) -> None:
T, D = embeddings.size()
position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / D) for j in range(D)] for pos in range(T)])
embeddings[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
embeddings[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
def forward(
self,
inputs: torch.Tensor,
past_key_values_length: int = 0, # Offset
position_ids: torch.LongTensor | None = None,
) -> torch.Tensor:
if position_ids is None:
bsz, seq_len = inputs.shape[:2]
position_ids = torch.arange(
past_key_values_length,
past_key_values_length + seq_len,
dtype=torch.long,
device=self.position_embeddings.device,
).expand(bsz, -1)
else:
position_ids = position_ids.unsqueeze(0)
return self.position_embeddings[position_ids]
# Copied from https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/bart/modeling_bart.py#L116
class LearnedPositionalEmbedding(nn.Embedding):
    """
    Trainable absolute position embeddings for a fixed maximum number of positions.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Unlike the BART module this was copied from, no offset of 2 is added to the position ids:
        # the table holds exactly `num_embeddings` rows.
        super().__init__(num_embeddings, embedding_dim)

    def forward(
        self,
        input_ids: torch.Tensor,
        past_key_values_length: int = 0,
        position_ids: torch.LongTensor = None,
    ):
        """`input_ids' shape is expected to be [bsz x seqlen]."""
        if position_ids is not None:
            # Explicit positions: prepend a singleton dimension and look them up directly.
            return nn.Embedding.forward(self, position_ids.unsqueeze(0))

        # Otherwise derive consecutive positions from the input shape, shifted by the cached length.
        batch, length = input_ids.shape[:2]
        first = past_key_values_length
        positions = torch.arange(first, first + length, dtype=torch.long, device=self.weight.device)
        return nn.Embedding.forward(self, positions.expand(batch, -1))
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """
    Plain matmul/softmax scaled dot-product attention.

    Args:
        module: Module whose `training` flag decides whether dropout is active.
        query, key, value: 4D tensors; `key` is transposed on dims 2 and 3 before the score matmul.
        attention_mask: Additive mask added to the raw scores, or `None`.
        scaling: Score multiplier; defaults to `query.size(-1) ** -0.5`.
        dropout: Dropout probability applied to the attention probabilities.
        head_mask: Optional per-head multiplier, reshaped to `(1, -1, 1, 1)`.
        **kwargs: Accepted and ignored.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has dims 1 and 2 swapped relative to the
        probabilities-times-values product and is contiguous.
    """
    scale = query.size(-1) ** -0.5 if scaling is None else scaling

    scores = torch.matmul(query, key.transpose(2, 3)) * scale
    if attention_mask is not None:
        scores = scores + attention_mask

    probs = nn.functional.softmax(scores, dim=-1)
    if head_mask is not None:
        probs = probs * head_mask.view(1, -1, 1, 1)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
class AVHubertAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    The same module serves as decoder self-attention (no `key_value_states`) and as cross-attention
    (`key_value_states` given). Attention is always computed with `eager_attention_forward`.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[HubertConfig] = None,
        layer_idx: Optional[int] = None,
    ):
        """
        Args:
            embed_dim: Model width; must be divisible by `num_heads`.
            num_heads: Number of attention heads.
            dropout: Dropout probability on the attention probabilities (training only).
            is_decoder: Whether the module is part of a decoder (a warning is logged if `layer_idx` is missing).
            bias: Whether the four linear projections use a bias.
            is_causal: Whether self-attention must be causal when no mask is supplied.
            config: Model configuration, stored on the module.
            layer_idx: Index of the layer inside the key/value cache.

        Raises:
            ValueError: If `embed_dim` is not divisible by `num_heads`.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: Query source, `(bsz, tgt_len, embed_dim)`.
            key_value_states: When given, the layer acts as cross-attention over these states.
            past_key_value: Cache object. With an `EncoderDecoderCache`, self- and cross-attention use their
                own sub-caches. With any other cache, only self-attention is cached; cross-attention keys and
                values are recomputed on every call because they cannot be stored separately.
            attention_mask: Additive mask added to the attention scores.
            layer_head_mask: Optional per-head multiplier for the attention probabilities.
            output_attentions: Forwarded to the attention function.
            cache_position: Positions being written into the self-attention cache.

        Returns:
            `(attn_output, attn_weights, past_key_value)`; `past_key_value` is the object that was passed in.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        # determine input shapes
        bsz, tgt_len = hidden_states.shape[:-1]
        src_len = key_value_states.shape[1] if is_cross_attention else tgt_len

        q_input_shape = (bsz, tgt_len, -1, self.head_dim)
        kv_input_shape = (bsz, src_len, -1, self.head_dim)

        # get query proj
        query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

        # Resolve which cache (if any) this call reads from / writes to. `is_updated` defaults to False so it
        # is always bound, including when a cache without an `is_updated` map is supplied.
        is_updated = False
        curr_past_key_value = None
        tracks_cross_attention = isinstance(past_key_value, EncoderDecoderCache)
        if tracks_cross_attention:
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_value = past_key_value.cross_attention_cache
            else:
                curr_past_key_value = past_key_value.self_attention_cache
        elif past_key_value is not None and not is_cross_attention:
            curr_past_key_value = past_key_value

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and curr_past_key_value is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.key_cache[self.layer_idx]
            value_states = curr_past_key_value.value_cache[self.layer_idx]
        else:
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            key_states = key_states.view(*kv_input_shape).transpose(1, 2)
            value_states = value_states.view(*kv_input_shape).transpose(1, 2)

            if curr_past_key_value is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states,
                    value_states,
                    self.layer_idx,
                    {"cache_position": cache_position},
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and tracks_cross_attention:
                    past_key_value.is_updated[self.layer_idx] = True

        # The eager kernel has no `is_causal` switch, so a causal self-attention layer that receives no mask
        # would otherwise attend to future positions. Build the causal mask here, aligned to the most recent
        # keys so previously cached keys stay visible.
        if attention_mask is None and self.is_causal and not is_cross_attention and tgt_len > 1:
            kv_len = key_states.shape[-2]
            visible = torch.ones(tgt_len, kv_len, dtype=torch.bool, device=query_states.device).tril(
                diagonal=kv_len - tgt_len
            )
            attention_mask = torch.zeros(
                tgt_len, kv_len, dtype=query_states.dtype, device=query_states.device
            ).masked_fill(~visible, torch.finfo(query_states.dtype).min)

        attention_interface: Callable = eager_attention_forward
        # TODO: attn implementation other than eager attention
        # if self.config._attn_implementation != "eager":
        #     attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            output_attentions=output_attentions,
            head_mask=layer_head_mask,
            **kwargs,
        )

        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights, past_key_value
class AVHubertDecoderLayer(nn.Module):
    """
    Decoder block with layer normalisation applied *after* each residual connection:
    causal self-attention, optional cross-attention over the encoder output, feed-forward,
    and an optional adapter.
    """

    def __init__(self, config: HubertConfig, layer_idx: Optional[int] = None):
        super().__init__()
        # Arguments common to the self-attention and cross-attention modules.
        shared_attn_args = dict(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
            layer_idx=layer_idx,
        )
        norm_eps = config.layer_norm_eps

        self.attention = AVHubertAttention(is_causal=True, **shared_attn_args)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.encoder_attn = AVHubertAttention(**shared_attn_args)
        self.encoder_layer_norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.feed_forward = HubertFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)

        has_adapter = getattr(config, "adapter_attn_dim", None) is not None
        self.adapter_layer = HubertAttnAdapterLayer(config) if has_adapter else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run one decoder block.

        Returns:
            A tuple starting with the new hidden states, followed by
            `(self_attn_weights, cross_attn_weights)` when `output_attentions` is set and by the cache
            object when `use_cache` is set.
        """
        # Self-attention, then residual add, then layer norm.
        self_attn_out, self_attn_weights, past_key_value = self.attention(
            hidden_states=hidden_states,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = self.layer_norm(hidden_states + self.dropout(self_attn_out))

        # Cross-attention over the encoder output; skipped when no encoder states are supplied.
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            cross_attn_out, cross_attn_weights, _ = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            hidden_states = self.encoder_layer_norm(self.dropout(cross_attn_out) + hidden_states)

        # Feed-forward with residual, then the final layer norm.
        hidden_states = self.final_layer_norm(hidden_states + self.feed_forward(hidden_states))

        if self.adapter_layer is not None:
            hidden_states = hidden_states + self.adapter_layer(hidden_states)

        extras = []
        if output_attentions:
            extras.extend((self_attn_weights, cross_attn_weights))
        if use_cache:
            extras.append(past_key_value)
        return (hidden_states, *extras)
class AVHubertDecoderLayerStableLayerNorm(nn.Module):
    """
    Decoder block with layer normalisation applied *before* each sub-layer (inside the residual branch):
    causal self-attention, optional cross-attention over the encoder output, feed-forward,
    and an optional adapter.
    """

    def __init__(self, config: HubertConfig, layer_idx: Optional[int] = None):
        super().__init__()
        # Arguments common to the self-attention and cross-attention modules.
        shared_attn_args = dict(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
            layer_idx=layer_idx,
        )
        norm_eps = config.layer_norm_eps

        self.attention = AVHubertAttention(is_causal=True, **shared_attn_args)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.encoder_attn = AVHubertAttention(**shared_attn_args)
        self.encoder_layer_norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.feed_forward = HubertFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)

        has_adapter = getattr(config, "adapter_attn_dim", None) is not None
        self.adapter_layer = HubertAttnAdapterLayer(config) if has_adapter else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run one decoder block.

        Returns:
            A tuple starting with the new hidden states, followed by
            `(self_attn_weights, cross_attn_weights)` when `output_attentions` is set and by the cache
            object when `use_cache` is set.
        """
        # Self-attention on the normalised input; the residual uses the un-normalised input.
        self_attn_out, self_attn_weights, past_key_value = self.attention(
            hidden_states=self.layer_norm(hidden_states),
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = hidden_states + self.dropout(self_attn_out)

        # Cross-attention over the encoder output; skipped when no encoder states are supplied.
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            cross_attn_out, cross_attn_weights, _ = self.encoder_attn(
                hidden_states=self.encoder_layer_norm(hidden_states),
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            hidden_states = self.dropout(cross_attn_out) + hidden_states

        # Feed-forward on the normalised states, added back onto the residual stream.
        normed_for_ffn = self.final_layer_norm(hidden_states)
        hidden_states = hidden_states + self.feed_forward(normed_for_ffn)

        if self.adapter_layer is not None:
            hidden_states = hidden_states + self.adapter_layer(hidden_states)

        extras = []
        if output_attentions:
            extras.extend((self_attn_weights, cross_attn_weights))
        if use_cache:
            extras.append(past_key_value)
        return (hidden_states, *extras)
class AVHubertDecoder(nn.Module):
    """
    Stack of `AVHubertDecoderLayer` blocks over already-embedded target tokens.

    Position embeddings (learned or sinusoidal, chosen by `config.learned_pos`) are added to the input
    embeddings, dropout is applied, the layers are run with LayerDrop, and a final layer norm is applied.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        if config.learned_pos:
            self.pos_embed = LearnedPositionalEmbedding(
                num_embeddings=config.max_position_embeddings,
                embedding_dim=config.hidden_size,
            )
        else:
            self.pos_embed = SinusoidalPositionalEmbedding(config=config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        # Every layer needs its own `layer_idx`: the attention modules use it to address their slot in the
        # key/value cache.
        self.layers = nn.ModuleList(
            [AVHubertDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        head_mask: torch.Tensor | None = None,
        cross_attn_head_mask: torch.Tensor | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        use_cache: bool | None = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        cache_position: torch.LongTensor | None = None,
    ):
        """
        Decode `inputs_embeds`, optionally attending to `encoder_hidden_states`.

        Args:
            inputs_embeds: Embedded target tokens, `(batch, seq_len, hidden)`. Required.
            attention_mask: 2D padding mask over past + current target positions, or a 4D additive mask.
                Defaults to all ones (unless running under torch dynamo compilation).
            encoder_hidden_states: Encoder output used by the cross-attention layers.
            encoder_attention_mask: Padding mask over the encoder positions.
            head_mask / cross_attn_head_mask: Per-layer head masks; their first dimension must equal the
                number of layers.
            past_key_values: Cache to read from and update. Created when `use_cache` is set and none is given.
            use_cache: Whether to return the cache.
            output_attentions / output_hidden_states: Whether to collect attention weights / hidden states.
            return_dict: Return a `BaseModelOutputWithPastAndCrossAttentions` instead of a tuple.
            cache_position: Absolute positions of the current tokens; derived from the cache length if omitted.

        Raises:
            ValueError: If `inputs_embeds` is missing or a head mask covers the wrong number of layers.
            NotImplementedError: If gradient checkpointing is enabled during training.
        """
        if inputs_embeds is None:
            raise ValueError("`inputs_embeds` must be provided.")

        input_shape = inputs_embeds.shape[:-1]

        if use_cache and past_key_values is None:
            past_key_values = EncoderDecoderCache.from_legacy_cache()

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        batch_size, seq_length = inputs_embeds.size()[:-1]
        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length,
                past_key_values_length + seq_length,
                device=inputs_embeds.device,
            )

        if attention_mask is None and not is_torchdynamo_compiling():
            # required mask seq length can be calculated via length of past cache
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        self_attn_cache = (
            past_key_values.self_attention_cache
            if isinstance(past_key_values, EncoderDecoderCache)
            else past_key_values
        )

        attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, self_attn_cache)
        encoder_attention_mask = self._update_cross_attn_mask(
            encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds
        )

        # embed positions
        position_embeddings = self.pos_embed(inputs_embeds, past_key_values_length, position_ids=cache_position)
        hidden_states = inputs_embeds + position_embeddings
        hidden_states = self.dropout(hidden_states)

        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None and attn_mask.size()[0] != len(self.layers):
                # Report the size of the mask being checked (not always `head_mask`, which may be None).
                raise ValueError(
                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                    f" {attn_mask.size()[0]}."
                )

        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = bool(self.training and (dropout_probability < self.config.layerdrop))
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # under deepspeed zero3 all gpus must run in sync
                # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
                if self.gradient_checkpointing and self.training:
                    # Fail before doing any work instead of after a checkpointed forward pass.
                    raise NotImplementedError("Currently, gradient checkpointing is not supported.")
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None, None, None)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        # add hidden states from the last decoder layer (exactly once)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # The layers hand back the very cache object they were given, so return it directly; this also keeps
        # the cache when the last layer happened to be dropped by LayerDrop.
        next_cache = past_key_values if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )

    def _update_causal_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
    ):
        """
        Build the 4D additive causal mask for decoder self-attention.

        Returns `None` when `config._attn_implementation` is `"sdpa"`, no static cache is used, and
        `AttentionMaskConverter._ignore_causal_mask_sdpa` reports that the mask can be dropped.

        Raises:
            NotImplementedError: For the `"flex_attention"` and `"flash_attention_2"` implementations.
        """
        if self.config._attn_implementation == "flex_attention":
            raise NotImplementedError
        if self.config._attn_implementation == "flash_attention_2":
            raise NotImplementedError

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_compilable_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        # NOTE(review): returning None here assumes the attention layer enforces causality on its own when it
        # receives no mask - confirm against the attention implementation in use.
        if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        if using_compilable_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=cache_position.device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Keep the "masked" value only for key positions strictly after each query's cache position.
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # Positions that are causally visible (0) but padded (0 in the 2D mask) sum to 0 and get masked.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask

    def _update_cross_attn_mask(
        self,
        encoder_hidden_states: Union[torch.Tensor, None],
        encoder_attention_mask: Union[torch.Tensor, None],
        input_shape: torch.Size,
        inputs_embeds: torch.Tensor,
    ):
        """
        Convert the encoder padding mask into the form expected by cross-attention.

        The mask is returned unchanged when either `encoder_hidden_states` or `encoder_attention_mask` is
        `None`; otherwise the conversion depends on `config._attn_implementation`.

        Raises:
            NotImplementedError: For the `"flex_attention"` implementation.
        """
        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if self.config._attn_implementation == "flash_attention_2":
                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
            elif self.config._attn_implementation == "sdpa":
                # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
                # the manual implementation that requires a 4D causal mask in all cases.
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    encoder_attention_mask,
                    inputs_embeds.dtype,
                    tgt_len=input_shape[-1],
                )
            elif self.config._attn_implementation == "flex_attention":
                raise NotImplementedError
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        return encoder_attention_mask
class AVHubertDecoderStableLayerNorm(nn.Module):
def __init__(self, config):
    """Build the position embedding, final layer norm, dropout and the stack of pre-norm decoder layers."""
    super().__init__()
    self.config = config

    # `config.learned_pos` selects a trainable table; otherwise the fixed sinusoidal table is used.
    self.pos_embed = (
        LearnedPositionalEmbedding(
            num_embeddings=config.max_position_embeddings,
            embedding_dim=config.hidden_size,
        )
        if config.learned_pos
        else SinusoidalPositionalEmbedding(config=config)
    )
    self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.dropout = nn.Dropout(config.hidden_dropout)

    # Each layer receives its own index.
    decoder_layers = []
    for layer_idx in range(config.num_hidden_layers):
        decoder_layers.append(AVHubertDecoderLayerStableLayerNorm(config, layer_idx=layer_idx))
    self.layers = nn.ModuleList(decoder_layers)

    self.gradient_checkpointing = False
def forward(
    self,
    inputs_embeds: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    encoder_hidden_states: torch.Tensor | None = None,
    encoder_attention_mask: torch.Tensor | None = None,
    head_mask: torch.Tensor | None = None,
    cross_attn_head_mask: torch.Tensor | None = None,
    past_key_values: EncoderDecoderCache | None = None,
    use_cache: bool | None = None,
    output_attentions: bool = False,
    output_hidden_states: bool = False,
    return_dict: bool = True,
    cache_position: torch.LongTensor | None = None,
):
    """
    Decode `inputs_embeds`, optionally attending to `encoder_hidden_states`.

    Args:
        inputs_embeds: Embedded target tokens, `(batch, seq_len, hidden)`. Required.
        attention_mask: 2D padding mask over past + current target positions, or a 4D additive mask.
            Defaults to all ones (unless running under torch dynamo compilation).
        encoder_hidden_states: Encoder output used by the cross-attention layers.
        encoder_attention_mask: Padding mask over the encoder positions.
        head_mask / cross_attn_head_mask: Per-layer head masks; their first dimension must equal the
            number of layers.
        past_key_values: Cache to read from and update. Created when `use_cache` is set and none is given.
        use_cache: Whether to return the cache.
        output_attentions / output_hidden_states: Whether to collect attention weights / hidden states.
        return_dict: Return a `BaseModelOutputWithPastAndCrossAttentions` instead of a tuple.
        cache_position: Absolute positions of the current tokens; derived from the cache length if omitted.

    Raises:
        ValueError: If `inputs_embeds` is missing or a head mask covers the wrong number of layers.
        NotImplementedError: If gradient checkpointing is enabled during training.
    """
    if inputs_embeds is None:
        raise ValueError("`inputs_embeds` must be provided.")

    input_shape = inputs_embeds.shape[:-1]

    if use_cache and past_key_values is None:
        past_key_values = EncoderDecoderCache.from_legacy_cache()

    past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
    batch_size, seq_length = inputs_embeds.size()[:-1]
    if cache_position is None:
        cache_position = torch.arange(
            past_key_values_length,
            past_key_values_length + seq_length,
            device=inputs_embeds.device,
        )

    if attention_mask is None and not is_torchdynamo_compiling():
        # required mask seq length can be calculated via length of past cache
        mask_seq_length = past_key_values_length + seq_length
        attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

    self_attn_cache = (
        past_key_values.self_attention_cache
        if isinstance(past_key_values, EncoderDecoderCache)
        else past_key_values
    )

    attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, self_attn_cache)
    encoder_attention_mask = self._update_cross_attn_mask(
        encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds
    )

    # embed positions
    position_embeddings = self.pos_embed(inputs_embeds, past_key_values_length, position_ids=cache_position)
    hidden_states = inputs_embeds + position_embeddings
    hidden_states = self.dropout(hidden_states)

    deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

    # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
    for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
        if attn_mask is not None and attn_mask.size()[0] != len(self.layers):
            # Report the size of the mask being checked (not always `head_mask`, which may be None).
            raise ValueError(
                f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                f" {attn_mask.size()[0]}."
            )

    for idx, layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        dropout_probability = torch.rand([])

        skip_the_layer = bool(self.training and (dropout_probability < self.config.layerdrop))
        if not skip_the_layer or deepspeed_zero3_is_enabled:
            # under deepspeed zero3 all gpus must run in sync
            # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
            if self.gradient_checkpointing and self.training:
                # Fail before doing any work instead of after a checkpointed forward pass.
                raise NotImplementedError("Currently, gradient checkpointing is not supported.")
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                cross_attn_layer_head_mask=(
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                ),
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )
            hidden_states = layer_outputs[0]

        if skip_the_layer:
            layer_outputs = (None, None, None, None)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

            if encoder_hidden_states is not None:
                all_cross_attentions += (layer_outputs[2],)

    hidden_states = self.layer_norm(hidden_states)

    # add hidden states from the last decoder layer (exactly once)
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    # `past_key_values` is passed to every layer and updated by them, so return it directly; this also
    # keeps the cache when the last layer happened to be dropped by LayerDrop.
    next_cache = past_key_values if use_cache else None

    if not return_dict:
        return tuple(
            v
            for v in [
                hidden_states,
                next_cache,
                all_hidden_states,
                all_self_attns,
                all_cross_attentions,
            ]
            if v is not None
        )
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
        cross_attentions=all_cross_attentions,
    )
def _update_causal_mask(
    self,
    attention_mask: Optional[torch.Tensor],
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
):
    """
    Build the causal attention mask handed to the decoder self-attention layers.

    Args:
        attention_mask (`torch.Tensor`, *optional*):
            Mask forwarded to `_prepare_4d_causal_attention_mask_with_cache_position`; may be `None`.
        input_tensor (`torch.Tensor`):
            Decoder inputs; only its dtype, `shape[0]` (batch) and `shape[1]` (sequence length) are read.
        cache_position (`torch.Tensor`):
            Positions of the current tokens, forwarded to the 4D mask builder.
        past_key_values (`Cache`, *optional*):
            Cache used to obtain the number of already-seen tokens and, for a `StaticCache`, the maximum
            cache length.

    Returns:
        The 4D mask produced by `_prepare_4d_causal_attention_mask_with_cache_position`, or `None` when the
        SDPA backend is used without a static cache and `AttentionMaskConverter._ignore_causal_mask_sdpa`
        reports that the explicit mask can be dropped.

    Raises:
        NotImplementedError: if the configured attention implementation is `"flex_attention"` or
            `"flash_attention_2"`.
    """
    attn_impl = self.config._attn_implementation
    if attn_impl in ("flex_attention", "flash_attention_2"):
        raise NotImplementedError

    uses_sdpa = attn_impl == "sdpa"
    past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
    using_compilable_cache = isinstance(past_key_values, StaticCache)

    # For SDPA, when possible, rely on its `is_causal` argument instead of an explicit `attn_mask`. This is not
    # compatible with a static cache, as SDPA would fail to infer the attention mask.
    # NOTE(review): no `output_attentions` flag is consulted here; if the SDPA attention layers fall back to the
    # eager path when attentions are requested, dropping the mask may be unsafe - confirm against the layers.
    if (
        uses_sdpa
        and not using_compilable_cache
        and AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=past_seen_tokens,
            is_training=self.training,
        )
    ):
        return None

    dtype = input_tensor.dtype
    batch_size, sequence_length = input_tensor.shape[0], input_tensor.shape[1]

    # Key/value length the mask has to cover.
    if using_compilable_cache:
        target_length = past_key_values.get_max_cache_shape()
    elif isinstance(attention_mask, torch.Tensor):
        target_length = attention_mask.shape[-1]
    else:
        target_length = past_seen_tokens + sequence_length + 1

    # In case the provided attention mask is 2D, a causal 4D mask is generated here.
    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=sequence_length,
        target_length=target_length,
        dtype=dtype,
        cache_position=cache_position,
        batch_size=batch_size,
    )

    needs_unmasking = (
        uses_sdpa and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"]
    )
    if needs_unmasking:
        # Attend to all tokens in fully masked rows of the causal mask (e.g. the first rows when using left
        # padding). Required by the memory-efficient path of F.scaled_dot_product_attention.
        # Details: https://github.com/pytorch/pytorch/issues/110213
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, torch.finfo(dtype).min)

    return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Create a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`. A mask that is already 4D is returned untouched.

    Args:
        attention_mask (`torch.Tensor`, *optional*):
            A 2D attention mask of shape `(batch_size, key_value_length)`, a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`, or `None` for a purely causal mask.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static
            cache, to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.

    Returns:
        `torch.Tensor`: additive mask holding `torch.finfo(dtype).min` at blocked positions and zero elsewhere
        (or the given `attention_mask` itself when it is 4D).
    """
    # A 4D mask is assumed to be already inverted and sliced, so it is passed through as is.
    if attention_mask is not None and attention_mask.dim() == 4:
        return attention_mask

    device = cache_position.device
    min_dtype = torch.finfo(dtype).min

    # Start fully blocked, then open the causal (lower-triangular) part.
    mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal=1)

    # Only key positions strictly after each query's cache position stay blocked.
    key_positions = torch.arange(target_length, device=device)
    mask *= key_positions > cache_position.reshape(-1, 1)
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    if attention_mask is None:
        return mask

    # `expand` returns a broadcast view; clone so the slice assignment below writes to real memory.
    mask = mask.clone()
    mask_length = attention_mask.shape[-1]
    visible = mask[..., :mask_length]
    # A position is padding when it is causally visible (0) and its attention-mask entry is 0.
    is_padding = (visible + attention_mask[:, None, None, :].to(mask.device)) == 0
    mask[..., :mask_length] = visible.masked_fill(is_padding, min_dtype)
    return mask
def _update_cross_attn_mask(
    self,
    encoder_hidden_states: Union[torch.Tensor, None],
    encoder_attention_mask: Union[torch.Tensor, None],
    input_shape: torch.Size,
    inputs_embeds: torch.Tensor,
):
    """
    Convert the encoder attention mask into the form expected by the configured attention backend.

    Args:
        encoder_hidden_states (`torch.Tensor`, *optional*):
            Encoder outputs; only checked for being `None`.
        encoder_attention_mask (`torch.Tensor`, *optional*):
            Mask over encoder positions. Returned unchanged when it or `encoder_hidden_states` is `None`.
        input_shape (`torch.Size`):
            Shape of the decoder inputs; its last entry is used as the target length.
        inputs_embeds (`torch.Tensor`):
            Decoder input embeddings; only the dtype is read.

    Returns:
        The mask for cross-attention: for `"flash_attention_2"` the original mask if it contains a 0,
        otherwise `None`; for `"sdpa"` and the default backend the result of the matching
        `_prepare_4d_attention_mask*` helper.

    Raises:
        NotImplementedError: if both inputs are given and the backend is `"flex_attention"`.
    """
    # Nothing to expand unless both the encoder states and their mask are present.
    if encoder_hidden_states is None or encoder_attention_mask is None:
        return encoder_attention_mask

    attn_impl = self.config._attn_implementation
    if attn_impl == "flex_attention":
        raise NotImplementedError

    if attn_impl == "flash_attention_2":
        # Keep the mask only when it actually masks something.
        return encoder_attention_mask if 0 in encoder_attention_mask else None

    tgt_len = input_shape[-1]
    if attn_impl == "sdpa":
        # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back
        # on the manual implementation that requires a 4D causal mask in all cases.
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        return _prepare_4d_attention_mask_for_sdpa(
            encoder_attention_mask,
            inputs_embeds.dtype,
            tgt_len=tgt_len,
        )

    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    return _prepare_4d_attention_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=tgt_len)