| import argparse |
|
|
| import hydra |
| import soundfile |
| import torch |
| from omegaconf import OmegaConf |
|
|
|
|
class SpecScaler(torch.nn.Module):
    """Log-compress its input after clamping it to ``[1e-9, 1e9]``.

    The clamp keeps zeros (and negative values) from turning into ``-inf``/NaN
    under ``torch.log`` and caps very large values at ``log(1e9)``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``log(clamp(x, 1e-9, 1e9))`` without modifying ``x``."""
        # Out-of-place clamp so the caller's tensor is left untouched; an
        # in-place ``clamp_`` would overwrite it and fails on leaf tensors
        # that require grad.
        return torch.log(x.clamp(1e-9, 1e9))
|
|
|
|
| def _parse_args(): |
| parser = argparse.ArgumentParser( |
| description="Run inference using GigaAM checkpoint" |
| ) |
| parser.add_argument("--encoder_config", help="Path to GigaAM config file (.yaml)") |
| parser.add_argument( |
| "--model_weights", help="Path to GigaAM checkpoint file (.ckpt)" |
| ) |
| parser.add_argument("--audio_path", help="Path to audio signal") |
| parser.add_argument("--device", help="Device: cpu / cuda") |
| return parser.parse_args() |
|
|
|
|
def main(encoder_config: str, model_weights: str, device: str, audio_path: str):
    """Encode one audio file with a GigaAM encoder and print the output shape.

    Args:
        encoder_config: Path to an OmegaConf YAML file whose ``encoder`` and
            ``feature_extractor`` sections are instantiated with Hydra.
        model_weights: Path to a checkpoint holding the encoder state dict
            (loaded with ``strict=True``).
        device: Torch device string the encoder and features are moved to.
        audio_path: Path to an audio file readable by ``soundfile``.
    """
    conf = OmegaConf.load(encoder_config)

    encoder = hydra.utils.instantiate(conf.encoder)
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    ckpt = torch.load(model_weights, map_location="cpu")
    encoder.load_state_dict(ckpt, strict=True)
    # Inference only: switch dropout/normalization layers to evaluation mode.
    encoder.to(device).eval()

    feature_extractor = hydra.utils.instantiate(conf.feature_extractor)

    # NOTE(review): the file's sample rate is discarded; presumably the audio
    # must already match the rate the feature extractor expects — confirm.
    audio_signal, _ = soundfile.read(audio_path, dtype="float32")
    if audio_signal.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel files; downmix
        # to mono so the feature extractor receives a 1-D waveform.
        audio_signal = audio_signal.mean(axis=1)

    # No gradients are needed for inference; skip building the autograd graph.
    with torch.no_grad():
        features = feature_extractor(torch.tensor(audio_signal, dtype=torch.float32))
        features = features.to(device)
        # Add a batch dimension of 1; the length is the last (time) axis of the
        # features — presumably (n_features, frames), TODO confirm.
        encoded, _ = encoder(
            audio_signal=features.unsqueeze(0),
            length=torch.tensor([features.shape[-1]], device=device),
        )
    print(f"encoded signal shape: {encoded.shape}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: forward each parsed CLI option to ``main`` by keyword.
    cli = _parse_args()
    main(
        audio_path=cli.audio_path,
        device=cli.device,
        encoder_config=cli.encoder_config,
        model_weights=cli.model_weights,
    )
|
|