| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
|
|
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from .configuration_bert import FlexBertConfig |
| from .activation import get_act_fn |
| from .normalization import get_norm_layer |
| from .initialization import ModuleType, init_weights |
|
|
|
|
class BertResidualGLU(nn.Module):
    """Applies the FFN at the end of each Mosaic BERT layer.

    Compared to the default BERT architecture, this block replaces
    :class:`~transformers.models.bert.modeling_bert.BertIntermediate` and
    :class:`~transformers.models.bert.modeling_bert.BertOutput` with a single module that has
    similar functionality, but introduces Gated Linear Units.

    Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep parameter count
    consistent with that of a standard Hugging Face BERT, scale down `config.intermediate_size` by 2/3.
    For example, a Mosaic BERT constructed with `config.intermediate_size=2048` will have the same
    parameter footprint as its Hugging Face BERT counterpart constructed with
    `config.intermediate_size=3072`.
    However, in most cases it will not be necessary to adjust `config.intermediate_size` since, despite
    the increased parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging
    Face BERT built from the same `config`.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()
        self.config = config
        # A single projection produces both GLU halves: the first `intermediate_size` features go
        # through the activation, the second `intermediate_size` features multiply them.
        self.gated_layers = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=False)
        self.act = get_act_fn(config.hidden_act)
        self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layernorm = get_norm_layer(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.

        Args:
            hidden_states (torch.Tensor): Hidden states from the attention layer whose last
                dimension is ``config.hidden_size``. This is typically the unpadded ``[nnz, dim]``
                layout, but any number of leading dimensions (e.g. padded ``[batch, seq, dim]``)
                is accepted.

        Returns:
            torch.Tensor: ``layernorm(wo(dropout(act(gated) * non_gated)) + hidden_states)``.
        """
        residual_connection = hidden_states

        # Split on the feature (last) axis so the GLU halves are correct for any input rank.
        gated, non_gated = self.gated_layers(hidden_states).split(self.config.intermediate_size, dim=-1)
        hidden_states = self.dropout(self.act(gated) * non_gated)

        hidden_states = self.wo(hidden_states)

        # Post-norm residual connection.
        return self.layernorm(hidden_states + residual_connection)
|
|
|
|
class FlexBertMLPBase(nn.Module):
    """Common parent of the FlexBERT MLP variants, primarily used for type hints.

    Concrete subclasses are expected to override ``_init_weights`` and ``forward``;
    both raise ``NotImplementedError`` here.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__()
        # Kept on the instance so subclasses can read sizes / hyperparameters and the layer index.
        self.config, self.layer_id = config, layer_id

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("This is a base class and should not be used directly.")

    def reset_parameters(self):
        # Delegates to the subclass-specific initializer, flagging it as a reset.
        self._init_weights(reset_params=True)

    def _init_weights(self, reset_params: bool = False):
        raise NotImplementedError("This is a base class and should not be used directly.")
|
|
|
|
class FlexBertMLP(FlexBertMLPBase):
    """Applies a plain (non-gated) MLP at the end of each FlexBERT layer.

    Compared to the default BERT architecture, this block replaces
    :class:`~transformers.models.bert.modeling_bert.BertIntermediate` and
    :class:`~transformers.models.bert.modeling_bert.BertOutput` with a single module that has
    similar functionality: ``Wo(drop(act(Wi(x))))``.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        hidden, inner = config.hidden_size, config.intermediate_size
        self.Wi = nn.Linear(hidden, inner, bias=config.mlp_in_bias)
        self.act = get_act_fn(config.hidden_act)
        # Skip the dropout module entirely when the probability is not positive.
        if config.mlp_dropout_prob > 0.0:
            self.drop = nn.Dropout(config.mlp_dropout_prob)
        else:
            self.drop = nn.Identity()
        self.Wo = nn.Linear(inner, hidden, bias=config.mlp_out_bias)

    def _init_weights(self, reset_params: bool = False):
        # (module, fan dimension, layer index, module role); the input projection is initialized
        # without a layer index, the output projection with this layer's index.
        specs = (
            (self.Wi, self.config.hidden_size, None, ModuleType.in_module),
            (self.Wo, self.config.intermediate_size, self.layer_id, ModuleType.out_module),
        )
        for module, dim, index, role in specs:
            init_weights(self.config, module, layer_dim=dim, layer_id=index, type_of_module=role)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.

        Args:
            hidden_states (torch.Tensor): The (unpadded) hidden states from
                the attention layer [nnz, dim].
        """
        expanded = self.Wi(hidden_states)
        activated = self.drop(self.act(expanded))
        return self.Wo(activated)
|
|
|
|
class FlexBertGLU(FlexBertMLPBase):
    """Applies a Gated Linear Unit (GLU) MLP at the end of each FlexBERT layer.

    Compared to the default BERT architecture, this block replaces
    :class:`~transformers.models.bert.modeling_bert.BertIntermediate` and
    :class:`~transformers.models.bert.modeling_bert.BertOutput` with a single module that has
    similar functionality: ``Wo(drop(act(x) * gate))`` where ``x`` and ``gate`` are the two halves
    of ``Wi(hidden_states)``.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        # nn.Linear requires integer sizes; coerce once so Wi and Wo always agree on the width.
        intermediate_size = int(config.intermediate_size)
        # Wi produces the activated half and the gate half in a single projection.
        self.Wi = nn.Linear(config.hidden_size, intermediate_size * 2, bias=config.mlp_in_bias)
        self.act = get_act_fn(config.hidden_act)
        self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity()
        self.Wo = nn.Linear(intermediate_size, config.hidden_size, bias=config.mlp_out_bias)

    def _init_weights(self, reset_params: bool = False):
        # Input projection: initialized without a layer index.
        init_weights(
            self.config,
            self.Wi,
            layer_dim=self.config.hidden_size,
            layer_id=None,
            type_of_module=ModuleType.in_module,
        )
        # Output projection: initialized with this layer's index.
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.intermediate_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.

        Args:
            hidden_states (torch.Tensor): Hidden states whose last dimension is
                ``config.hidden_size``.

        Returns:
            torch.Tensor: ``Wo(drop(act(x) * gate))``, where ``x`` (first half) and ``gate``
            (second half) are split from ``Wi(hidden_states)`` along the last dimension.
        """
        x, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        return self.Wo(self.drop(self.act(x) * gate))
|
|
|
|
class FlexBertParallelGLU(FlexBertMLPBase):
    """Applies the GLU at the end of each FlexBERT layer, using an ``intermediate_ff`` tensor
    that was computed in parallel with the attention (so this module owns only ``Wo``).

    Compared to the default BERT architecture, this block replaces
    :class:`~transformers.models.bert.modeling_bert.BertIntermediate` and
    :class:`~transformers.models.bert.modeling_bert.BertOutput` with a single module that has
    similar functionality.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        self.act = get_act_fn(config.hidden_act)
        if config.mlp_dropout_prob > 0.0:
            self.drop = nn.Dropout(config.mlp_dropout_prob)
        else:
            self.drop = nn.Identity()
        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias)

    def _init_weights(self, reset_params: bool = False):
        # Only the output projection lives here; the input projection is computed elsewhere.
        init_weights(
            self.config,
            self.Wo,
            type_of_module=ModuleType.out_module,
            layer_id=self.layer_id,
            layer_dim=self.config.intermediate_size,
        )

    def forward(self, intermediate_ff: torch.Tensor) -> torch.Tensor:
        """Gate the precomputed intermediate features and project them back to hidden size.

        Args:
            intermediate_ff (torch.Tensor): Concatenation of the activated half and the gate half
                along the last dimension; that dimension must be ``2 * config.intermediate_size``
                for the halves to match ``Wo``'s input width.
        """
        activated_half, gate = intermediate_ff.chunk(2, dim=-1)
        gated = self.act(activated_half) * gate
        return self.Wo(self.drop(gated))
|
|
|
|
# Registry mapping the `config.mlp_layer` / `config.initial_mlp_layer` string to its MLP class.
MLP2CLS = dict(
    mlp=FlexBertMLP,
    glu=FlexBertGLU,
    parallel_glu=FlexBertParallelGLU,
)
|
|
|
|
def get_mlp_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertMLPBase:
    """Instantiate the MLP module configured for layer ``layer_id``.

    Layers whose index is below ``config.num_initial_layers`` use ``config.initial_mlp_layer``
    when that attribute is present and not ``None``. All other layers, and calls with
    ``layer_id=None``, use ``config.mlp_layer``.

    Args:
        config: Model configuration providing ``mlp_layer`` and, optionally,
            ``initial_mlp_layer`` / ``num_initial_layers``.
        layer_id: Index of the layer being built, or ``None``.

    Returns:
        FlexBertMLPBase: An instance of the class registered in ``MLP2CLS`` for the selected type.

    Raises:
        ValueError: If the selected layer type is not a key of ``MLP2CLS``.
    """
    initial_mlp_layer = getattr(config, "initial_mlp_layer", None)
    # Check the cheap/None-safe conditions first so `layer_id=None` and configs without
    # `num_initial_layers` (when no initial layer type is set) do not raise.
    use_initial = (
        initial_mlp_layer is not None
        and layer_id is not None
        and layer_id < config.num_initial_layers
    )
    if use_initial:
        mlp_layer, field_name = initial_mlp_layer, "config.initial_mlp_layer"
    else:
        mlp_layer, field_name = config.mlp_layer, "config.mlp_layer"

    # Keep the try minimal: a KeyError raised inside a layer's constructor must not be
    # misreported as an invalid layer type.
    try:
        mlp_cls = MLP2CLS[mlp_layer]
    except KeyError as e:
        raise ValueError(
            f"Invalid MLP layer type: {field_name}={mlp_layer!r}, must be one of {MLP2CLS.keys()}. {e}"
        ) from e
    return mlp_cls(config, layer_id=layer_id)
|
|