File size: 2,922 Bytes
b9deace
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import torch
from types import SimpleNamespace
from torch.nn.parallel import DistributedDataParallel as DDP

from models import WatermarkModel, AudioFusionModel
from speechtokenizer import SpeechTokenizer


class ModelWrapper:
    """Thin proxy around a model that may itself wrap an inner ``module``.

    Attribute lookups that fail on the wrapper are forwarded to the wrapped
    model first and then, if the wrapped model exposes a ``module``
    attribute, to that inner object. Calling the wrapper calls the wrapped
    model with the same arguments.
    """

    def __init__(self, model):
        """Store *model* as the object all lookups and calls are forwarded to."""
        self.model = model

    def __getattr__(self, name):
        """Resolve *name* on the wrapped model, then on its ``module``.

        Raises:
            AttributeError: if neither object provides *name*, or if the
                wrapper has no ``model`` set yet.
        """
        # __getattr__ only runs after normal lookup has failed. If "model"
        # itself is missing (instance built via __new__, copy or unpickling),
        # touching self.model below would re-enter this method forever, so
        # fail fast with the error callers such as hasattr() expect.
        if name == "model":
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

        wrapped = self.model
        if hasattr(wrapped, name):
            return getattr(wrapped, name)
        if hasattr(wrapped, "module") and hasattr(wrapped.module, name):
            return getattr(wrapped.module, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped model with the given arguments."""
        return self.model(*args, **kwargs)


class WatermarkBase:
    """Build the watermarking models and load their weights from a checkpoint.

    On construction this loads a frozen SpeechTokenizer (stored as ``vae``),
    creates a ``WatermarkModel`` and an ``AudioFusionModel`` on the configured
    device, optionally wraps both in DDP, and finally wraps both in
    ``ModelWrapper``.
    """

    def __init__(self, cfg: SimpleNamespace):
        """Create all models from *cfg*.

        Args:
            cfg: Configuration namespace. The optional attributes
                ``local_rank`` (default ``None``), ``device`` (default
                ``"cpu"``), ``sample_rate`` (default ``16000``) and ``nbits``
                (default ``16``) are read here; *cfg* is also passed on to
                ``WatermarkModel``.
        """
        self.cfg = cfg
        self.local_rank = getattr(cfg, "local_rank", None)
        self.device = getattr(cfg, "device", "cpu")
        self.sample_rate = getattr(cfg, "sample_rate", 16000)
        self.nbits = getattr(cfg, "nbits", 16)

        # Load SpeechTokenizer (used as VAE here).
        # NOTE: both paths are relative to the current working directory.
        config_path = os.path.join(
            "speechtokenizer", "pretrained_model", "speechtokenizer_hubert_avg_config.json"
        )
        ckpt_path = os.path.join("speechtokenizer", "pretrained_model", "SpeechTokenizer.pt")
        self.vae = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path).to(self.device)
        # Freeze the tokenizer: none of its parameters receive gradients.
        for _, param in self.vae.named_parameters():
            param.requires_grad = False

        # Watermark model and fusion model
        self.model = WatermarkModel(cfg).to(self.device)
        self.fusion_model = AudioFusionModel(
            n_fft=256, hop_length=64, win_length=256, hidden_dim=64, nbits=self.nbits
        ).to(self.device)

        # A configured local_rank selects distributed mode: wrap both models in DDP.
        if self.local_rank is not None:
            self.model = DDP(self.model, device_ids=[self.local_rank], find_unused_parameters=False)
            self.fusion_model = DDP(self.fusion_model, device_ids=[self.local_rank], find_unused_parameters=False)

        self.model = ModelWrapper(self.model)
        self.fusion_model = ModelWrapper(self.fusion_model)

    @staticmethod
    def _strip_ddp_prefix(state_dict):
        """Return a copy of *state_dict* with one leading ``module.`` removed per key."""
        prefix = "module."
        # Only a leading prefix is removed; a blanket str.replace would also
        # mangle legitimate inner names such as "enc.module.weight".
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    @staticmethod
    def _unwrap(wrapper):
        """Return the bare module behind a wrapper's ``model`` and optional DDP layer."""
        module = wrapper.model
        if isinstance(module, DDP):
            module = module.module
        return module

    def load_model(self, checkpoint_path: str):
        """Load watermark (and, if present, fusion) weights from a checkpoint.

        A missing checkpoint file is reported with a message and otherwise
        ignored. Weights are loaded with ``strict=True`` into the underlying
        modules, so the same checkpoint works with and without DDP.

        Args:
            checkpoint_path: Path to a file readable by ``torch.load`` that
                holds ``"model_state_dict"`` and optionally
                ``"fusion_state_dict"``.

        Raises:
            KeyError: if the checkpoint has no ``"model_state_dict"`` entry.
            RuntimeError: if a state dict does not match its model.
        """
        if not os.path.exists(checkpoint_path):
            print(f"Checkpoint not found: {checkpoint_path}")
            return

        print(f"Loading checkpoint from {checkpoint_path} ...")
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=True)

        # load watermark model
        # Keys are normalised to the un-prefixed form, so they must be loaded
        # into the bare module: a DDP wrapper expects "module."-prefixed keys
        # and would reject them under strict=True.
        wm_state = self._strip_ddp_prefix(checkpoint["model_state_dict"])
        self._unwrap(self.model).load_state_dict(wm_state, strict=True)

        # load fusion model if present
        if "fusion_state_dict" in checkpoint:
            fusion_state = self._strip_ddp_prefix(checkpoint["fusion_state_dict"])
            self._unwrap(self.fusion_model).load_state_dict(fusion_state, strict=True)
        else:
            print("No fusion_state_dict in checkpoint; fusion model left unchanged.")

        print("Model weights loaded.")