AlignMark / base.py
2026anonymous
Initial commit
b9deace
Raw
History Blame Contribute Delete
2.92 kB
import os
import torch
from types import SimpleNamespace
from torch.nn.parallel import DistributedDataParallel as DDP
from models import WatermarkModel, AudioFusionModel
from speechtokenizer import SpeechTokenizer
class ModelWrapper:
def __init__(self, model):
self.model = model
def __getattr__(self, name):
if hasattr(self.model, name):
return getattr(self.model, name)
elif hasattr(self.model, "module") and hasattr(self.model.module, name):
return getattr(self.model.module, name)
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)
class WatermarkBase:
def __init__(self, cfg: SimpleNamespace):
self.cfg = cfg
self.local_rank = getattr(cfg, "local_rank", None)
self.device = getattr(cfg, "device", "cpu")
self.sample_rate = getattr(cfg, "sample_rate", 16000)
self.nbits = getattr(cfg, "nbits", 16)
# Load SpeechTokenizer (used as VAE here)
config_path = os.path.join(
"speechtokenizer", "pretrained_model", "speechtokenizer_hubert_avg_config.json"
)
ckpt_path = os.path.join("speechtokenizer", "pretrained_model", "SpeechTokenizer.pt")
self.vae = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path).to(self.device)
for _, param in self.vae.named_parameters():
param.requires_grad = False
# Watermark model and fusion model
self.model = WatermarkModel(cfg).to(self.device)
self.fusion_model = AudioFusionModel(
n_fft=256, hop_length=64, win_length=256, hidden_dim=64, nbits=self.nbits
).to(self.device)
if self.local_rank is not None:
self.model = DDP(self.model, device_ids=[self.local_rank], find_unused_parameters=False)
self.fusion_model = DDP(self.fusion_model, device_ids=[self.local_rank], find_unused_parameters=False)
self.model = ModelWrapper(self.model)
self.fusion_model = ModelWrapper(self.fusion_model)
def load_model(self, checkpoint_path: str):
if not os.path.exists(checkpoint_path):
print(f"Checkpoint not found: {checkpoint_path}")
return
print(f"Loading checkpoint from {checkpoint_path} ...")
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=True)
# load watermark model
new_wm_dict = {k.replace("module.", ""): v for k, v in checkpoint["model_state_dict"].items()}
self.model.load_state_dict(new_wm_dict, strict=True)
# load fusion model if present
new_fusion_dict = {k.replace("module.", ""): v for k, v in checkpoint["fusion_state_dict"].items()}
self.fusion_model.load_state_dict(new_fusion_dict, strict=True)
print("Model weights loaded.")