"""Defines a pre-trained HiFi-GAN vocoder model.
This vocoder can be used with TTS models that output mel spectrograms to
synthesize audio.
.. code-block:: python
from pretrained.vocoder import pretrained_vocoder
vocoder = pretrained_vocoder("hifigan")
"""
import argparse
import logging
import math
from typing import Literal, cast, get_args
import numpy as np
import safetensors.torch as st
import torch
import torch.nn.functional as F
import torchaudio
from ml.models.modules import StreamingConv1d, StreamingConvTranspose1d, streaming_add
from ml.utils.checkpoint import ensure_downloaded
from ml.utils.device.auto import detect_device
from ml.utils.logging import configure_logging
from ml.utils.timer import Timer
from torch import Tensor, nn
from torch.nn.utils import remove_weight_norm, weight_norm
logger = logging.getLogger(__name__)
PretrainedHiFiGANType = Literal["16000hz", "22050hz"]
def cast_pretrained_hifigan_type(s: str) -> PretrainedHiFiGANType:
    if s not in get_args(PretrainedHiFiGANType):
        raise KeyError(f"Invalid HiFi-GAN type: {s}. Expected one of: {get_args(PretrainedHiFiGANType)}")
    return cast(PretrainedHiFiGANType, s)
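# Indexes into a possibly-None container, returning None if the container is None.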
get = lambda x, i: None if x is None else x[i] # noqa: E731
def init_hifigan_weights(m: nn.Module, mean: float = 0.0, std: float = 0.01) -> None:
if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, StreamingConv1d, StreamingConvTranspose1d)):
m.weight.data.normal_(mean, std)
StreamingConvState = tuple[Tensor, int]
StreamingAddState = tuple[Tensor, Tensor]
ResBlockState = list[tuple[StreamingConvState, StreamingConvState, StreamingAddState]]
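# The full generator state mirrors the forward pass: the pre-conv state, one
# state per upsampling layer, the ResBlock states for each (upsample, kernel)
# pair, the streaming-add states used to sum the ResBlock outputs, and the
# post-conv state.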
HiFiGANState = tuple[
StreamingConvState,
list[StreamingConvState],
list[list[ResBlockState]],
list[list[StreamingAddState]],
StreamingConvState,
]
class ResBlock(nn.Module):
__constants__ = ["lrelu_slope"]
def __init__(
self,
channels: int,
kernel_size: int = 3,
dilation: tuple[int, int, int] = (1, 3, 5),
lrelu_slope: float = 0.1,
) -> None:
super().__init__()
def get_padding(kernel_size: int, dilation: int = 1) -> int:
return (kernel_size * dilation - dilation) // 2
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    StreamingConv1d(
                        channels,
                        channels,
                        kernel_size=kernel_size,
                        stride=1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
self.convs1.apply(init_hifigan_weights)
        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    StreamingConv1d(
                        channels,
                        channels,
                        kernel_size=kernel_size,
                        stride=1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )
self.convs2.apply(init_hifigan_weights)
self.lrelu_slope = lrelu_slope
    def forward(self, x: Tensor, state: ResBlockState | None) -> tuple[Tensor, ResBlockState]:
state_out: ResBlockState = []
for i, (c1, c2) in enumerate(zip(self.convs1, self.convs2)):
state_in_i = get(state, i)
xt = F.leaky_relu(x, self.lrelu_slope)
xt, s1 = c1(xt, get(state_in_i, 0))
xt = F.leaky_relu(xt, self.lrelu_slope)
xt, s2 = c2(xt, get(state_in_i, 1))
x, sa = streaming_add(xt, x, get(state_in_i, 2))
state_out.append((s1, s2, sa))
return x, state_out
    def remove_weight_norm(self) -> None:
for layer in self.convs1:
remove_weight_norm(layer)
for layer in self.convs2:
remove_weight_norm(layer)
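def _resblock_streaming_example() -> None:
    """Minimal sketch of threading state through a ResBlock over chunks.

    The state returned by one call is fed into the next, so a long sequence
    can be processed chunk-by-chunk. The channel count and chunk length here
    are illustrative assumptions, not requirements of the module.
    """
    block = ResBlock(channels=64)
    state: ResBlockState | None = None
    for chunk in torch.randn(4, 1, 64, 32).unbind(0):  # four (1, 64, 32) chunks
        out, state = block(chunk, state)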
class HiFiGAN(nn.Module):
    """Defines a HiFi-GAN model.

    Parameters:
        sampling_rate: The sampling rate of the model.
        model_in_dim: The input dimension of the model.
        upsample_kernel_sizes: The kernel sizes of the upsampling layers.
        upsample_rates: The upsample rate of each layer.
        resblock_kernel_sizes: The kernel sizes of the ResBlocks.
        resblock_dilation_sizes: The dilation sizes of the ResBlocks.
        upsample_initial_channel: The initial channel count of the upsampling layers.
        lrelu_slope: The slope of the leaky ReLU.
    """
def __init__(
self,
sampling_rate: int,
model_in_dim: int,
upsample_kernel_sizes: list[int],
upsample_rates: list[int],
resblock_kernel_sizes: list[int] = [3, 7, 11],
resblock_dilation_sizes: list[tuple[int, int, int]] = [(1, 3, 5), (1, 3, 5), (1, 3, 5)],
upsample_initial_channel: int = 512,
lrelu_slope: float = 0.1,
) -> None:
super().__init__()
self.model_in_dim = model_in_dim
self.sampling_rate = sampling_rate
self.hop_size = math.prod(upsample_rates)
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.lrelu_slope = lrelu_slope
conv_pre = StreamingConv1d(model_in_dim, upsample_initial_channel, kernel_size=7, stride=1, padding=3)
self.conv_pre = weight_norm(conv_pre)
assert len(upsample_rates) == len(upsample_kernel_sizes)
self.ups = cast(list[StreamingConvTranspose1d], nn.ModuleList())
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
module = StreamingConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
kernel_size=k,
stride=u,
# padding=(k - u) // 2,
)
self.ups.append(weight_norm(module))
self.resblocks = cast(list[ResBlock], nn.ModuleList())
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
self.resblocks.append(ResBlock(ch, k, d, lrelu_slope))
self.conv_post = weight_norm(StreamingConv1d(ch, 1, 7, 1, padding=3))
cast(nn.ModuleList, self.ups).apply(init_hifigan_weights)
self.conv_post.apply(init_hifigan_weights)
    def forward(self, x: Tensor, state: HiFiGANState | None = None) -> tuple[Tensor, HiFiGANState]:
x, pre_s = self.conv_pre(x, get(state, 0))
up_s_in = get(state, 1)
up_s_out: list[StreamingConvState] = []
down_s_in = get(state, 2)
down_s_out: list[list[ResBlockState]] = []
sa_in = get(state, 3)
sa_out: list[list[StreamingAddState]] = []
for i, up in enumerate(self.ups):
x = F.leaky_relu(x, self.lrelu_slope)
x, up_s = up(x, get(up_s_in, i))
up_s_out.append(up_s)
xs = None
down_s_in_i = get(down_s_in, i)
down_s_out_i: list[ResBlockState] = []
sa_in_i = get(sa_in, i)
sa_out_i: list[StreamingAddState] = []
for j in range(self.num_kernels):
down_s_in_ij = get(down_s_in_i, j)
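                # For j == 0 there is no running sum yet, so this state entry is unused.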
sa_in_ij = get(sa_in_i, j - 1)
xs_i, down_s_out_ij = self.resblocks[i * self.num_kernels + j](x, down_s_in_ij)
if xs is None:
xs = xs_i
else:
xs, sa_i = streaming_add(xs, xs_i, sa_in_ij)
sa_out_i.append(sa_i)
down_s_out_i.append(down_s_out_ij)
down_s_out.append(down_s_out_i)
sa_out.append(sa_out_i)
assert xs is not None
x = xs / self.num_kernels
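        # Note: the reference implementation uses the default leaky ReLU slope here.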
x = F.leaky_relu(x)
x, post_s = self.conv_post(x, get(state, 4))
x = torch.tanh(x)
return x, (pre_s, up_s_out, down_s_out, sa_out, post_s)
    def infer(self, x: Tensor) -> Tensor:
y, _ = self(x)
return y
    def remove_weight_norm(self) -> None:
for layer in self.ups:
remove_weight_norm(layer)
for layer in self.resblocks:
layer.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
    def audio_to_mels(self) -> "AudioToHifiGanMels":
return AudioToHifiGanMels(
sampling_rate=self.sampling_rate,
num_mels=self.model_in_dim,
n_fft=1024,
win_size=1024,
hop_size=self.hop_size,
fmin=0,
fmax=8000,
)
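def _hifigan_streaming_example() -> None:
    """Minimal sketch of running HiFi-GAN over mel chunks with streaming state.

    The configuration matches the "22050hz" checkpoint below, but the weights
    are randomly initialized and the chunk shapes are illustrative assumptions.
    """
    model = HiFiGAN(
        sampling_rate=22050,
        model_in_dim=80,
        upsample_kernel_sizes=[16, 16, 4, 4],
        upsample_rates=[8, 8, 2, 2],
    )
    state: HiFiGANState | None = None
    for mel_chunk in torch.randn(4, 1, 80, 32).unbind(0):  # (batch, mels, frames)
        audio_chunk, state = model(mel_chunk, state)  # roughly frames * hop_size samples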
def _load_hifigan_weights(
key: PretrainedHiFiGANType,
model: HiFiGAN,
url: str,
sha256: str,
load_weights: bool = True,
device: torch.device | None = None,
) -> HiFiGAN:
if not load_weights:
return model
with Timer("downloading checkpoint"):
model_path = ensure_downloaded(url, "hifigan", f"{key}.bin", sha256=sha256)
with Timer("loading checkpoint", spinner=True):
if device is None:
device = torch.device("cpu")
ckpt = st.load_file(model_path)
model.to(device)
model.load_state_dict(ckpt)
return model
def pretrained_hifigan(
key: str | PretrainedHiFiGANType,
*,
pretrained: bool = True,
keep_weight_norm: bool = False,
) -> HiFiGAN:
"""Loads the pretrained HiFi-GAN model.
Args:
key: The key of the pretrained model.
pretrained: Whether to load the pretrained weights.
keep_weight_norm: Whether to keep the weight norm.
Returns:
The pretrained HiFi-GAN model.
"""
key = cast_pretrained_hifigan_type(key)
with Timer("initializing model", spinner=True):
match key:
case "16000hz":
model = HiFiGAN(
sampling_rate=16000,
model_in_dim=128,
upsample_kernel_sizes=[20, 8, 4, 4],
upsample_rates=[10, 4, 2, 2],
)
url = "https://huggingface.co/codekansas/hifigan/resolve/main/hifigan_16000hz.bin"
sha256 = "4693bd59cb1653635d902c8a34064c7628d9472637c71a71898911c59a06aa51"
case "22050hz":
model = HiFiGAN(
sampling_rate=22050,
model_in_dim=80,
upsample_kernel_sizes=[16, 16, 4, 4],
upsample_rates=[8, 8, 2, 2],
)
url = "https://huggingface.co/codekansas/hifigan/resolve/main/hifigan_22050hz.bin"
sha256 = "79cbede45d1be8e5700f0326a3c796c311ee7b04cf1fd8994a35418eecddf941"
case _:
raise ValueError(f"Invalid HiFi-GAN type: {key}")
model = _load_hifigan_weights(key, model, url, sha256, pretrained)
if not keep_weight_norm:
model.remove_weight_norm()
return model
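def _pretrained_hifigan_example() -> None:
    """Minimal sketch of synthesizing audio with a pretrained vocoder.

    The mel tensor here is random noise, purely to illustrate the shapes;
    real mels should come from ``AudioToHifiGanMels`` or a TTS model.
    """
    vocoder = pretrained_hifigan("22050hz")
    mels = torch.randn(1, vocoder.model_in_dim, 64)  # (batch, mels, frames)
    audio = vocoder.infer(mels)  # (batch, 1, frames * hop_size)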
class AudioToHifiGanMels(nn.Module):
    """Defines a module to convert from a waveform to the mels used by HiFi-GAN.

    The default parameters should be kept the same for pre-trained models.

    Parameters:
        sampling_rate: The sampling rate of the audio.
        num_mels: The number of mel bins.
        n_fft: The number of FFT bins.
        win_size: The window size.
        hop_size: The hop size between successive frames.
        fmin: The minimum frequency.
        fmax: The maximum frequency.
    """
__constants__ = ["sampling_rate", "num_mels", "n_fft", "win_size", "hop_size", "fmin", "fmax"]
def __init__(
self,
sampling_rate: int,
num_mels: int,
n_fft: int,
win_size: int,
hop_size: int,
fmin: int = 0,
fmax: int = 8000,
) -> None:
super().__init__()
self.sampling_rate = sampling_rate
self.num_mels = num_mels
self.n_fft = n_fft
self.win_size = win_size
self.hop_size = hop_size
self.fmin = fmin
self.fmax = fmax
# try:
# from librosa.filters import mel as librosa_mel_fn
# except ImportError:
# raise ImportError("Please install librosa to use AudioToHifiGanMels")
# mel_librosa = librosa_mel_fn(
# sr=sampling_rate,
# n_fft=n_fft,
# n_mels=num_mels,
# fmin=fmin,
# fmax=fmax,
# )
# mel = torch.from_numpy(mel_librosa).float().T
mel = torchaudio.functional.melscale_fbanks(
n_freqs=n_fft // 2 + 1,
f_min=fmin,
f_max=fmax,
n_mels=num_mels,
sample_rate=sampling_rate,
norm="slaney",
mel_scale="slaney",
)
self.register_buffer("mel_basis", mel)
self.register_buffer("hann_window", torch.hann_window(win_size))
def _dynamic_range_compression(self, x: np.ndarray, c: float = 1.0, clip_val: float = 1e-5) -> np.ndarray:
return np.log(np.clip(x, a_min=clip_val, a_max=None) * c)
def _dynamic_range_decompression(self, x: np.ndarray, c: float = 1.0) -> np.ndarray:
return np.exp(x) / c
def _dynamic_range_compression_torch(self, x: Tensor, c: float = 1.0, clip_val: float = 1e-5) -> Tensor:
return torch.log(torch.clamp(x, min=clip_val) * c)
def _dynamic_range_decompression_torch(self, x: Tensor, c: float = 1.0) -> Tensor:
return torch.exp(x) / c
def _spectral_normalize_torch(self, magnitudes: Tensor) -> Tensor:
output = self._dynamic_range_compression_torch(magnitudes)
return output
def _spectral_de_normalize_torch(self, magnitudes: Tensor) -> Tensor:
output = self._dynamic_range_decompression_torch(magnitudes)
return output
mel_basis: Tensor
hann_window: Tensor
    def wav_to_mels(self, y: Tensor, center: bool = False) -> Tensor:
ymin, ymax = torch.min(y), torch.max(y)
if ymin < -1.0:
logger.warning("min value is %.2g", ymin)
if ymax > 1.0:
logger.warning("max value is %.2g", ymax)
pad = int((self.n_fft - self.hop_size) / 2)
y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
y = y.squeeze(1)
spec = torch.stft(
y,
self.n_fft,
hop_length=self.hop_size,
win_length=self.win_size,
window=self.hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-9)
spec = torch.einsum("bct,cm->bmt", spec, self.mel_basis)
spec = self._spectral_normalize_torch(spec)
return spec
    def forward(self, y: Tensor, center: bool = False) -> Tensor:
return self.wav_to_mels(y, center)
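def _audio_to_mels_example() -> None:
    """Minimal sketch of a waveform-to-mels-to-waveform round trip.

    The one-second random waveform is an illustrative stand-in for real audio.
    """
    vocoder = pretrained_hifigan("22050hz")
    audio_to_mels = vocoder.audio_to_mels()
    wav = torch.rand(1, 22050) * 2 - 1  # (batch, samples), in [-1, 1]
    mels = audio_to_mels.wav_to_mels(wav)  # (batch, num_mels, frames)
    wav_hat = vocoder.infer(mels)  # (batch, 1, frames * hop_size)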
def test_mel_to_audio_adhoc() -> None:
configure_logging()
parser = argparse.ArgumentParser(description="Runs adhoc test of mel to audio conversion")
parser.add_argument("key", choices=get_args(PretrainedHiFiGANType), help="The key of the pretrained model")
parser.add_argument("input_file", type=str, help="Path to input audio file")
parser.add_argument("output_file", type=str, help="Path to output audio file")
args = parser.parse_args()
dev = detect_device()
# Loads the HiFi-GAN model.
model = pretrained_hifigan(args.key, pretrained=True)
dev.module_to(model)
    # Loads the audio file, keeping only the first channel and the first ten seconds.
    audio, sr = torchaudio.load(args.input_file)
    audio = audio[:1]
    audio = audio[:, : sr * 10]
if sr != model.sampling_rate:
audio = torchaudio.functional.resample(audio, sr, model.sampling_rate)
# Note: This normalizes the audio to the range [-1, 1], which may increase
# the volume of the audio if it is quiet.
audio = audio / audio.abs().max() * 0.999
audio = dev.tensor_to(audio)
# Converts the audio to mels.
audio_to_mels = model.audio_to_mels()
dev.module_to(audio_to_mels)
mels = audio_to_mels.wav_to_mels(audio)
# Converts the mels back to audio.
audio = model.infer(mels).squeeze(0)
# Saves the audio.
torchaudio.save(args.output_file, audio.cpu(), model.sampling_rate)
logger.info("Saved %s", args.output_file)
if __name__ == "__main__":
# python -m pretrained.vocoder.hifigan
test_mel_to_audio_adhoc()