"""Defines a pre-trained WaveGlow vocoder model.

This vocoder can be used with TTS models that output mel spectrograms to
synthesize audio.

.. code-block:: python

    from pretrained.vocoder import pretrained_vocoder

    vocoder = pretrained_vocoder("waveglow")
"""

from dataclasses import dataclass
from typing import cast

import torch
import torch.nn.functional as F
from ml.core.config import conf_field
from ml.models.lora import maybe_lora
from ml.utils.checkpoint import ensure_downloaded, get_state_dict_prefix
from ml.utils.timer import Timer
from torch import Tensor, nn

WAVEGLOW_CKPT_FP16 = "https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ckpt_amp/versions/19.09.0/files/nvidia_waveglowpyt_fp16_20190427"
WAVEGLOW_CKPT_FP32 = "https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ckpt_fp32/versions/19.09.0/files/nvidia_waveglowpyt_fp32_20190427"


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a: Tensor, input_b: Tensor, n_channels: int) -> Tensor:
    # Gated activation: the first `n_channels` channels pass through a tanh,
    # the remaining channels through a sigmoid, and the halves are multiplied.
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels, :])
    s_act = torch.sigmoid(in_act[:, n_channels:, :])
    return t_act * s_act
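

# Both inputs carry 2 * n_channels channels: the first half drives the tanh
# gate and the second half the sigmoid gate. A shape sketch with the default
# WaveNet width (n_channels=512, so the inputs have 1024 channels):
#
#     a = torch.randn(1, 1024, 100)  # e.g. output of an in_layer
#     b = torch.randn(1, 1024, 100)  # e.g. output of a cond_layer
#     out = fused_add_tanh_sigmoid_multiply(a, b, 512)
#     assert out.shape == (1, 512, 100)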


class WaveGlowLoss(nn.Module):
    def __init__(self, sigma: float = 1.0) -> None:
        super().__init__()
        self.sigma = sigma

    def forward(self, model_output: tuple[Tensor, list[Tensor], list[Tensor]]) -> Tensor:
        z, log_s_list, log_det_w_list = model_output
        for i, log_s in enumerate(log_s_list):
            if i == 0:
                log_s_total = torch.sum(log_s)
                log_det_w_total = log_det_w_list[i]
            else:
                log_s_total = log_s_total + torch.sum(log_s)
                log_det_w_total += log_det_w_list[i]
        loss = torch.sum(z * z) / (2 * self.sigma * self.sigma) - log_s_total - log_det_w_total
        return loss / (z.size(0) * z.size(1) * z.size(2))
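

# A minimal training-step sketch with this loss, assuming `model` is a
# WaveGlow instance and `spect` / `audio` are batched mel spectrograms and
# waveforms:
#
#     criterion = WaveGlowLoss(sigma=1.0)
#     outputs = model((spect, audio))
#     loss = criterion(outputs)
#     loss.backward()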


class Invertible1x1Conv(nn.Module):
    weight_inv: Tensor

    def __init__(self, c: int) -> None:
        super().__init__()
        self.conv = nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, bias=False)

        # Sample a random orthogonal matrix to initialize the weights.
        weight, _ = torch.linalg.qr(torch.randn(c, c), "reduced")

        # Ensure the determinant is 1.0, not -1.0.
        if torch.det(weight) < 0:
            weight[:, 0] = -1 * weight[:, 0]
        weight = weight.view(c, c, 1)
        self.conv.weight.data = weight

        self.register_buffer("weight_inv", torch.zeros_like(weight), persistent=False)

    def forward(self, z: Tensor) -> tuple[Tensor, Tensor]:
        batch_size, _, n_of_groups = z.size()
        weight = self.conv.weight.squeeze()

        # Forward computation.
        log_det_w = batch_size * n_of_groups * torch.logdet(weight)
        z = self.conv(z)
        return z, log_det_w

    def infer(self, z: Tensor) -> Tensor:
        self._invert()
        return F.conv1d(z, self.weight_inv, bias=None, stride=1, padding=0)

    def _invert(self) -> None:
        weight = self.conv.weight.squeeze()
        self.weight_inv.copy_(weight.float().inverse().unsqueeze(-1).to(self.weight_inv))
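

# Because the weight is initialized as an orthogonal matrix and `infer`
# applies its exact inverse, `infer` undoes `forward`; a round-trip sketch on
# random data:
#
#     conv = Invertible1x1Conv(8)
#     z = torch.randn(2, 8, 100)
#     y, log_det_w = conv(z)
#     assert torch.allclose(conv.infer(y), z, atol=1e-4)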


@dataclass
class WaveNetConfig:
    n_layers: int = conf_field(8, help="Number of layers")
    kernel_size: int = conf_field(3, help="Kernel size")
    n_channels: int = conf_field(512, help="Number of channels")


class WaveNet(nn.Module):
    def __init__(
        self,
        n_in_channels: int,
        n_mel_channels: int,
        config: WaveNetConfig,
        lora_rank: int | None = None,
    ) -> None:
        super().__init__()

        assert config.kernel_size % 2 == 1
        assert config.n_channels % 2 == 0

        self.n_layers = config.n_layers
        self.n_channels = config.n_channels
        self.in_layers = nn.ModuleList()
        self.res_skip_layers = nn.ModuleList()
        self.cond_layers = nn.ModuleList()

        start = nn.Conv1d(n_in_channels, config.n_channels, 1)
        start = nn.utils.weight_norm(start, name="weight")
        self.start = maybe_lora(start, lora_rank)

        # Initializing the last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability.
        end = nn.Conv1d(config.n_channels, 2 * n_in_channels, 1)
        end.weight.data.zero_()
        if end.bias is not None:
            end.bias.data.zero_()
        self.end = maybe_lora(end, lora_rank)

        for i in range(config.n_layers):
            dilation = 2**i
            padding = int((config.kernel_size * dilation - dilation) / 2)
            in_layer = nn.Conv1d(
                config.n_channels,
                2 * config.n_channels,
                config.kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = nn.utils.weight_norm(in_layer, name="weight")
            in_layer = maybe_lora(in_layer, lora_rank)
            self.in_layers.append(in_layer)

            cond_layer = nn.Conv1d(n_mel_channels, 2 * config.n_channels, 1)
            cond_layer = nn.utils.weight_norm(cond_layer, name="weight")
            cond_layer = maybe_lora(cond_layer, lora_rank)
            self.cond_layers.append(cond_layer)

            # The last residual connection is not necessary.
            if i < config.n_layers - 1:
                res_skip_channels = 2 * config.n_channels
            else:
                res_skip_channels = config.n_channels
            res_skip_layer = nn.Conv1d(config.n_channels, res_skip_channels, 1)
            res_skip_layer = nn.utils.weight_norm(res_skip_layer, name="weight")
            res_skip_layer = maybe_lora(res_skip_layer, lora_rank)
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, audio: Tensor, spect: Tensor) -> Tensor:
        audio = self.start(audio)
        output = 0

        layers = zip(self.in_layers, self.cond_layers, self.res_skip_layers)
        for i, (in_layer, cond_layer, res_skip_layer) in enumerate(layers):
            acts = fused_add_tanh_sigmoid_multiply(in_layer(audio), cond_layer(spect), self.n_channels)

            res_skip_acts = res_skip_layer(acts)
            if i < self.n_layers - 1:
                audio = res_skip_acts[:, : self.n_channels, :] + audio
                skip_acts = res_skip_acts[:, self.n_channels :, :]
            else:
                skip_acts = res_skip_acts
            output += skip_acts
        return self.end(output)
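

# A shape sketch matching how WaveGlow calls the coupling network: `audio`
# carries `n_in_channels` channels, while `spect` carries
# `n_mel_channels * n_group` channels (640 with the default config):
#
#     wn = WaveNet(n_in_channels=4, n_mel_channels=640, config=WaveNetConfig())
#     out = wn(torch.randn(2, 4, 100), torch.randn(2, 640, 100))
#     assert out.shape == (2, 8, 100)  # log_s and b, 4 channels each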


@dataclass
class WaveGlowConfig:
    n_mel_channels: int = conf_field(80, help="Number of mel channels")
    n_flows: int = conf_field(12, help="Number of flows")
    n_group: int = conf_field(8, help="Number of groups in a flow")
    n_early_every: int = conf_field(4, help="Number of layers between early layers")
    n_early_size: int = conf_field(2, help="Number of channels in early layers")
    sampling_rate: int = conf_field(22050, help="Sampling rate of model.")
    wavenet: WaveNetConfig = conf_field(WaveNetConfig(), help="WaveNet configuration")
    lora_rank: int | None = conf_field(None, help="LoRA rank")
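

# The configs are plain dataclasses, so they can be overridden before
# constructing the model; a sketch with non-default (and therefore
# non-pretrained) sizes:
#
#     config = WaveGlowConfig(n_flows=6, wavenet=WaveNetConfig(n_channels=256))
#     model = WaveGlow(config)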


class WaveGlow(nn.Module):
    def __init__(self, config: WaveGlowConfig) -> None:
        super().__init__()

        self.sampling_rate = config.sampling_rate
        self.upsample = nn.ConvTranspose1d(config.n_mel_channels, config.n_mel_channels, 1024, stride=256)
        assert config.n_group % 2 == 0
        self.n_flows = config.n_flows
        self.n_group = config.n_group
        self.n_early_every = config.n_early_every
        self.n_early_size = config.n_early_size
        self.WN = cast(list[WaveNet], nn.ModuleList())
        self.convinv = nn.ModuleList()

        n_half = config.n_group // 2

        # Set up layers with the right sizes based on how many dimensions
        # have been output already.
        n_remaining_channels = config.n_group
        for k in range(config.n_flows):
            if k % self.n_early_every == 0 and k > 0:
                n_half = n_half - int(self.n_early_size / 2)
                n_remaining_channels = n_remaining_channels - self.n_early_size
            self.convinv.append(Invertible1x1Conv(n_remaining_channels))
            self.WN.append(WaveNet(n_half, config.n_mel_channels * config.n_group, config.wavenet, config.lora_rank))
        self.n_remaining_channels = n_remaining_channels

    def forward(self, forward_input: tuple[Tensor, Tensor]) -> tuple[Tensor, list[Tensor], list[Tensor]]:
        spect, audio = forward_input

        # Upsample the spectrogram to the size of the audio.
        spect = self.upsample(spect)
        assert spect.size(2) >= audio.size(1)
        if spect.size(2) > audio.size(1):
            spect = spect[:, :, : audio.size(1)]

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1)
        spect = spect.permute(0, 2, 1)

        audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
        output_audio = []
        log_s_list = []
        log_det_w_list = []

        for k in range(self.n_flows):
            if k % self.n_early_every == 0 and k > 0:
                output_audio.append(audio[:, : self.n_early_size, :])
                audio = audio[:, self.n_early_size :, :]

            audio, log_det_w = self.convinv[k](audio)
            log_det_w_list.append(log_det_w)

            n_half = audio.size(1) // 2
            audio_0 = audio[:, :n_half, :]
            audio_1 = audio[:, n_half:, :]

            output = self.WN[k](audio_0, spect)
            log_s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = torch.exp(log_s) * audio_1 + b
            log_s_list.append(log_s)

            audio = torch.cat([audio_0, audio_1], 1)

        output_audio.append(audio)
        return torch.cat(output_audio, 1), log_s_list, log_det_w_list

    def infer(self, spect: Tensor, sigma: float = 1.0) -> Tensor:
        spect = self.upsample(spect)

        # Trim the transposed-convolution artifacts from the end.
        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
        spect = spect[:, :, :-time_cutoff]

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1)
        spect = spect.permute(0, 2, 1)

        audio = spect.new_empty(spect.size(0), self.n_remaining_channels, spect.size(2)).normal_(std=sigma)

        for k in reversed(range(self.n_flows)):
            n_half = audio.size(1) // 2
            audio_0 = audio[:, :n_half, :]
            audio_1 = audio[:, n_half:, :]

            output = self.WN[k](audio_0, spect)
            s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = (audio_1 - b) / torch.exp(s)
            audio = torch.cat([audio_0, audio_1], 1)

            audio = self.convinv[k].infer(audio)

            if k % self.n_early_every == 0 and k > 0:
                z = torch.randn(spect.size(0), self.n_early_size, spect.size(2), device=spect.device).to(spect.dtype)
                audio = torch.cat((sigma * z, audio), 1)

        audio = audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1).data
        return audio
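
    # A sampling sketch, assuming `model` is a WaveGlow instance and `mels` is
    # a (batch, n_mel_channels, n_frames) mel spectrogram; sigma is often set
    # below 1.0 at inference time to reduce audible noise:
    #
    #     with torch.no_grad():
    #         audio = model.infer(mels, sigma=0.6)  # (batch, n_samples)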

    def remove_weight_norm(self) -> None:
        """Removes weight normalization from all of the WaveGlow modules."""

        def remove(conv_list: nn.ModuleList) -> nn.ModuleList:
            new_conv_list = nn.ModuleList()
            for old_conv in conv_list:
                old_conv = nn.utils.remove_weight_norm(old_conv)
                new_conv_list.append(old_conv)
            return new_conv_list

        for wave_net in self.WN:
            wave_net.start = nn.utils.remove_weight_norm(wave_net.start)
            wave_net.in_layers = remove(wave_net.in_layers)
            wave_net.cond_layers = remove(wave_net.cond_layers)
            wave_net.res_skip_layers = remove(wave_net.res_skip_layers)


def pretrained_waveglow(
    *,
    fp16: bool = True,
    pretrained: bool = True,
    lora_rank: int | None = None,
) -> WaveGlow:
    """Loads the pretrained WaveGlow model.

    Reference:
        https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/waveglow/entrypoints.py

    Args:
        fp16: When True, returns a model with half-precision float16 weights.
        pretrained: When True, returns a model pre-trained on the LJ Speech dataset.
        lora_rank: The LoRA rank to use, if LoRA is desired.

    Returns:
        The WaveGlow model.
    """
    config = WaveGlowConfig(lora_rank=lora_rank)
    model = WaveGlow(config)
    if pretrained:
        weights_name = f"weights_fp{16 if fp16 else 32}.pth"
        with Timer("downloading checkpoint"):
            fpath = ensure_downloaded(WAVEGLOW_CKPT_FP16 if fp16 else WAVEGLOW_CKPT_FP32, "waveglow", weights_name)
        ckpt = torch.load(fpath, map_location="cpu")
        model.load_state_dict(get_state_dict_prefix(ckpt["state_dict"], "module."))
    return model
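

# End-to-end usage sketch: load the pretrained vocoder, strip weight norm for
# inference, and synthesize audio from a mel spectrogram. `mels` is a
# hypothetical (1, 80, n_frames) tensor:
#
#     model = pretrained_waveglow(fp16=False)
#     model.eval()
#     model.remove_weight_norm()
#     with torch.no_grad():
#         audio = model.infer(mels)  # (1, n_samples) at 22050 Hz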