rubyburger3 / models /mmllama3.py
emplitude's picture
Upload folder using huggingface_hub
e7aa5c6 verified
Raw
History Blame Contribute Delete
5.03 kB
from typing import List
import warnings
import torch
from torch import nn, Tensor
from torchvision import transforms
from torchtune.models.llama3 import lora_llama3_8b, llama3_8b
from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear
from torchtune.modules import TransformerDecoder
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
from imagebind.models import imagebind_model
from models.imagebind_wrapper import get_imagebind_v2, V2_PATH
from models.imagebind_wrapper import ImageBind
# Width of the "ib_embed" vectors; base input size of MMEmbedding.proj_to_llama.
IMAGEBIND_DIM = 1024
# Width of the "clip_embed" vectors; added to the MMEmbedding projector input
# when use_clip is set, and the output size of MMLinear.proj_from_llama.
CLIP_DIM = 768
class MMEmbedding(nn.Embedding):
    """Token embedding that splices projected perception features into its output.

    The layer is configured from an existing ``nn.Embedding`` (``e``): only its
    hyper-parameters are reused, the weights of ``e`` are NOT copied, so the
    new table starts freshly initialised until a checkpoint is loaded.

    Before a forward pass, callers register per-sample perception features
    with :meth:`set_context`. During ``forward`` each registered feature is
    projected to ``perception_tokens`` consecutive token embeddings and written
    over the regular embeddings at the registered sequence position.

    Args:
        e: Embedding layer whose configuration is mirrored.
        perception_tokens: Number of consecutive sequence slots each
            perception feature occupies.
        use_clip: When true, the ``"clip_embed"`` feature is concatenated to
            the ``"ib_embed"`` feature before projection.
    """

    def __init__(self, e, perception_tokens=1, use_clip=False):
        super().__init__(
            num_embeddings=e.num_embeddings,
            embedding_dim=e.embedding_dim,
            padding_idx=e.padding_idx,
            max_norm=e.max_norm,
            norm_type=e.norm_type,
            scale_grad_by_freq=e.scale_grad_by_freq,
            sparse=e.sparse,
        )
        self._perception_tokens = perception_tokens
        self._context = []
        self._use_clip = use_clip

        # One projected vector is later reshaped into `perception_tokens`
        # rows of width `embedding_dim`, hence the multiplied output size.
        dim_in = IMAGEBIND_DIM + (CLIP_DIM if use_clip else 0)
        dim_out = e.embedding_dim * perception_tokens
        self.proj_to_llama = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.LayerNorm(dim_out),
            nn.Linear(dim_out, dim_out),
        )

    def set_context(self, context):
        """Register the perception features used by subsequent forward calls.

        Args:
            context: Sequence indexed by batch position. Each item is a mapping
                from a sequence index to a dict holding an ``"ib_embed"``
                tensor and, when ``use_clip`` is set, a ``"clip_embed"``
                tensor. The context stays in effect until replaced.
        """
        self._context = context

    def forward(self, input: Tensor) -> Tensor:
        """Embed ``input`` and overwrite the registered perception slots.

        Args:
            input: Token ids; indexed below as ``[batch, sequence]``.

        Returns:
            The embedded sequence with rows ``s : s + perception_tokens`` of
            sample ``b`` replaced by the projected feature registered for
            ``(b, s)``.
        """
        r = super().forward(input)

        # Features are typically produced outside the model, so align them with
        # the projector's parameters instead of assuming matching device/dtype.
        reference = self.proj_to_llama[0].weight

        # self._context is first indexed by batch idx ...
        for b, context_dict in enumerate(self._context):
            # ... then by sequence idx.
            for s, embed in context_dict.items():
                features = embed["ib_embed"]
                if self._use_clip:
                    # Concatenation is along dim 0, i.e. the features are
                    # expected to be 1-D vectors here.
                    features = torch.cat([
                        features,
                        embed["clip_embed"].to(features.device),
                    ])
                features = features.to(
                    device=reference.device, dtype=reference.dtype
                )

                # perception feature dim -> perception_tokens * llama3 dim
                llama_embed = self.proj_to_llama(features)
                r[b, s:s + self._perception_tokens] = llama_embed.view(
                    self._perception_tokens, -1
                )
        return r
class MMLinear(nn.Linear):
    """Output projection with an auxiliary (currently unused) CLIP head.

    The layer is configured from an existing linear layer (``o``): only its
    ``in_features``, ``out_features`` and the presence of a bias are reused,
    the weights of ``o`` are NOT copied. ``forward`` behaves exactly like
    ``nn.Linear.forward``; ``proj_from_llama`` is created (and therefore part
    of the state dict) but is not applied yet.

    Args:
        o: Linear layer whose configuration is mirrored.
    """

    def __init__(self, o):
        super().__init__(
            in_features=o.in_features,
            out_features=o.out_features,
            # Identity check: comparing a tensor to None with `!=` only works
            # through the comparison fallback.
            bias=o.bias is not None,
        )
        self._context = []
        # Defined here so the attribute exists before the first forward call.
        self._clip_projections = []

        dim_out = CLIP_DIM
        dim_in = o.in_features
        self.proj_from_llama = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.LayerNorm(dim_out),
            nn.Linear(dim_out, dim_out),
        )

    def set_context(self, context):
        """Register the perception context (batch-indexed, then sequence-indexed).

        The context is stored but not consumed by ``forward`` at the moment.
        """
        self._context = context

    def forward(self, input_bsd: Tensor) -> Tensor:
        """Apply the linear output projection to ``input_bsd``.

        Also resets ``self._clip_projections`` to an empty list on every call.

        Args:
            input_bsd: Hidden states fed to the output projection.

        Returns:
            The result of ``nn.Linear.forward(input_bsd)``.
        """
        # TODO: CLIP reconstruction is disabled. The intended behaviour is, for
        # every (batch idx b, sequence idx s) registered in self._context, to
        # append (self.proj_from_llama(input_bsd[b, s]), embed.get("clip_embed"))
        # to self._clip_projections so a llama3 dim -> clip dim objective can
        # be computed by the caller.
        self._clip_projections = []
        return super().forward(input_bsd)
def lora_mmllama3_8b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    quantize_base: bool = False,
    perception_tokens: int = 2,
    use_clip: bool = False
) -> TransformerDecoder:
    """Build a LoRA Llama3-8B whose embedding/output layers are multimodal.

    The model is created with ``lora_llama3_8b`` and then its
    ``tok_embeddings`` and ``output`` modules are replaced by ``MMEmbedding``
    and ``MMLinear`` instances configured from the originals.

    Args:
        lora_attn_modules: Attention projections that receive LoRA adapters.
        apply_lora_to_mlp: Forwarded to ``lora_llama3_8b``.
        apply_lora_to_output: Forwarded to ``lora_llama3_8b``.
        lora_rank: Forwarded to ``lora_llama3_8b``.
        lora_alpha: Forwarded to ``lora_llama3_8b``.
        quantize_base: Forwarded to ``lora_llama3_8b``.
        perception_tokens: Number of sequence slots per perception feature.
        use_clip: Whether CLIP features are concatenated to the perception
            features before projection.

    Returns:
        The patched ``TransformerDecoder``.
    """
    # Forward by keyword: the builder's positional order is not stable across
    # torchtune releases (extra LoRA options were added before quantize_base),
    # so positional forwarding can silently bind values to the wrong parameter.
    llama3 = lora_llama3_8b(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        quantize_base=quantize_base,
    )
    # NOTE(review): MMLinear reads in_features/out_features/bias from the
    # module it replaces and builds a plain linear layer -- confirm this is
    # what is wanted when apply_lora_to_output=True.
    llama3.tok_embeddings = MMEmbedding(llama3.tok_embeddings, perception_tokens, use_clip)
    llama3.output = MMLinear(llama3.output)
    return llama3
def mmllama3_8b(
    perception_tokens: int = 2,
    use_clip: bool = False
) -> TransformerDecoder:
    """Build a Llama3-8B whose embedding and output layers are multimodal.

    Starts from ``llama3_8b()`` and swaps ``tok_embeddings`` for an
    ``MMEmbedding`` and ``output`` for an ``MMLinear``, each configured from
    the module it replaces.

    Args:
        perception_tokens: Number of sequence slots per perception feature.
        use_clip: Whether CLIP features are concatenated to the perception
            features before projection.

    Returns:
        The patched ``TransformerDecoder``.
    """
    model = llama3_8b()

    # Wrap the stock layers; the replacements mirror their configuration.
    multimodal_embedding = MMEmbedding(
        model.tok_embeddings,
        perception_tokens,
        use_clip,
    )
    model.tok_embeddings = multimodal_embedding

    multimodal_head = MMLinear(model.output)
    model.output = multimodal_head

    return model
def imagebind_huge(use_v2: bool = True):
    """Create an ImageBind model with a PIL preprocessing pipeline attached.

    Args:
        use_v2: When true, build the project's ``ImageBind`` wrapper with
            ``v2=True``; otherwise load the pretrained
            ``imagebind_model.imagebind_huge``.

    Returns:
        The model, with a ``transform_from_pil`` attribute holding a
        torchvision ``Compose`` (bicubic resize to 224, center crop 224,
        tensor conversion, per-channel normalisation).
    """
    model = (
        ImageBind(v2=True)
        if use_v2
        else imagebind_model.imagebind_huge(pretrained=True)
    )

    # Per-channel statistics used for normalisation
    # (presumably the CLIP image statistics -- confirm).
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    preprocessing_steps = [
        transforms.Resize(
            224, interpolation=transforms.InterpolationMode.BICUBIC
        ),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    model.transform_from_pil = transforms.Compose(preprocessing_steps)
    return model