# mypy: disable-error-code="import-not-found,import-untyped"
"""Defines a pre-trained Tacotron2 model.
This combines a Tacotron2 model with a HiFiGAN vocoder to produce an
end-to-end TTS model, adapted to be fine-tunable.
.. highlight:: python
.. code-block:: python
from pretrained.tacotron2 import pretrained_tacotron2_tts
tts = pretrained_tacotron2_tts()
audio, states = tts.generate("Hello, world!")
write_audio([audio])
You can also interact with this model directly through the command line:
.. highlight:: python
.. code-block:: python
python -m pretrained.tacotron2 'Hello, world!'
The two parts of the model can be trained separately, including using LoRA
fine-tuning.
Using this model requires the following additional dependencies:
- ``inflect``
- ``ftfy``
Additionally, to generate STFTs for training the model, you will need
to install ``librosa``. If you want to play audio for the demo, you should
also install ``sounddevice``.
"""
import argparse
import functools
import html
import logging
import re
from dataclasses import dataclass
from math import sqrt
from numbers import Number
from pathlib import Path
from typing import Callable, Iterable, NamedTuple, cast
import numpy as np
import torch
import torch.nn.functional as F
from ml.core.config import conf_field
from ml.models.base import BaseModel, BaseModelConfig
from ml.models.lora import freeze_non_lora_, maybe_lora, reset_lora_weights_
from ml.utils.audio import write_audio
from ml.utils.checkpoint import ensure_downloaded
from ml.utils.device.auto import detect_device
from ml.utils.device.base import base_device
from ml.utils.large_models import init_empty_weights, meta_to_empty_func
from ml.utils.logging import configure_logging
from ml.utils.timer import Timer
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from pretrained.vocoder import Vocoder, VocoderType, pretrained_vocoder
logger = logging.getLogger(__name__)
TACOTRON_CKPT_URL = "https://drive.google.com/open?id=1c5ZTuT7J08wLUoVZ2KkUs_VdZuJ86ZqA"
class Normalizer:
def __init__(self) -> None:
super().__init__()
try:
import inflect
except ImportError:
raise ImportError("Number normalization requires the inflect package; pip install inflect")
self.inflect_engine = inflect.engine()
self.comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
self.decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
self.pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
self.dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
self.ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
self.number_re = re.compile(r"[0-9]+")
self.abbr_re = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
]
self.whitespace_re = re.compile(r"\s+")
def _remove_commas(self, m: re.Match) -> str:
return m.group(1).replace(",", "")
def _expand_decimal_point(self, m: re.Match) -> str:
return m.group(1).replace(".", " point ")
def _expand_dollars(self, m: re.Match) -> str:
match = m.group(1)
parts = match.split(".")
if len(parts) > 2:
return match + " dollars" # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = "dollar" if dollars == 1 else "dollars"
return "%s %s" % (dollars, dollar_unit)
elif cents:
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s" % (cents, cent_unit)
else:
return "zero dollars"
def _expand_ordinal(self, m: re.Match) -> str:
return cast(str, self.inflect_engine.number_to_words(m.group(0)))
def _expand_number(self, m: re.Match) -> str:
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return "two thousand"
elif num > 2000 and num < 2010:
return "two thousand " + cast(str, self.inflect_engine.number_to_words(cast(Number, num % 100)))
elif num % 100 == 0:
return cast(str, self.inflect_engine.number_to_words(cast(Number, num // 100))) + " hundred"
else:
out = cast(str, self.inflect_engine.number_to_words(cast(Number, num), andword="", zero="oh", group=2))
return out.replace(", ", " ")
return cast(str, self.inflect_engine.number_to_words(cast(Number, num), andword=""))
def __call__(self, text: str) -> str:
text = re.sub(self.comma_number_re, self._remove_commas, text)
text = re.sub(self.pounds_re, r"\1 pounds", text)
text = re.sub(self.dollars_re, self._expand_dollars, text)
text = re.sub(self.decimal_number_re, self._expand_decimal_point, text)
text = re.sub(self.ordinal_re, self._expand_ordinal, text)
text = re.sub(self.number_re, self._expand_number, text)
for regex, replacement in self.abbr_re:
text = re.sub(regex, replacement, text)
text = re.sub(self.whitespace_re, " ", text)
return text
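# For example (illustrative): Normalizer()("$3.50 for Dr. Smith") returns
# "three dollars, fifty cents for doctor Smith".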
@functools.lru_cache()
def text_clean_func(lower: bool = True) -> Callable[[str], str]:
try:
import ftfy
ftfy_fix: Callable[[str], str] = ftfy.fix_text
except ImportError:
logger.warning("Please install ftfy: pip install ftfy")
def ftfy_fix(x: str) -> str:
return x
try:
normalizer: Callable[[str], str] = Normalizer()
except ImportError:
logger.warning("Please install inflect and make sure it can be imported: pip install inflect")
def normalizer(x: str) -> str:
return x
def _clean(text: str) -> str:
text = ftfy_fix(text)
text = html.unescape(html.unescape(text))
text = re.sub(r"\s+", " ", text)
text = text.strip()
if lower:
text = text.lower()
text = normalizer(text)
return text
return _clean
def get_mask_from_lengths(lengths: Tensor) -> Tensor:
max_len = torch.max(lengths).item()
ids = torch.arange(0, max_len, dtype=torch.int32, device=lengths.device)
mask = (ids < lengths.unsqueeze(1)).byte()
mask = mask <= 0
return mask
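# For example (illustrative):
#
#     get_mask_from_lengths(torch.tensor([3, 1]))
#     # tensor([[False, False, False],
#     #         [False,  True,  True]])
#
# i.e. the returned mask is True at padded positions.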
class LinearNorm(nn.Module):
def __init__(
self,
in_dim: int,
out_dim: int,
bias: bool = True,
w_init_gain: str = "linear",
lora_rank: int | None = None,
lora_alpha: float = 1.0,
lora_dropout: float = 0.0,
) -> None:
super().__init__()
linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
nn.init.xavier_uniform_(linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain))
self.linear_layer = maybe_lora(linear_layer, r=lora_rank, alpha=lora_alpha, dropout=lora_dropout)
def forward(self, x: Tensor) -> Tensor:
return self.linear_layer(x)
class ConvNorm(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 1,
stride: int = 1,
padding: int | None = None,
dilation: int = 1,
bias: bool = True,
w_init_gain: str = "linear",
lora_rank: int | None = None,
lora_alpha: float = 1.0,
lora_dropout: float = 0.0,
) -> None:
super().__init__()
if padding is None:
assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2)
conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
)
self.conv = maybe_lora(conv, r=lora_rank, alpha=lora_alpha, dropout=lora_dropout)
nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain))
def forward(self, signal: Tensor) -> Tensor:
return self.conv(signal)
class LocationLayer(nn.Module):
def __init__(
self,
attention_n_filters: int,
attention_kernel_size: int,
attention_dim: int,
lora_rank: int | None = None,
lora_alpha: float = 1.0,
lora_dropout: float = 0.0,
) -> None:
super().__init__()
padding = int((attention_kernel_size - 1) / 2)
self.location_conv = ConvNorm(
2,
attention_n_filters,
kernel_size=attention_kernel_size,
padding=padding,
bias=False,
stride=1,
dilation=1,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.location_dense = LinearNorm(
attention_n_filters,
attention_dim,
bias=False,
w_init_gain="tanh",
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
def forward(self, attention_weights_cat: Tensor) -> Tensor:
processed_attention = self.location_conv(attention_weights_cat)
processed_attention = processed_attention.transpose(1, 2)
processed_attention = self.location_dense(processed_attention)
return processed_attention
class Attention(nn.Module):
def __init__(
self,
attention_rnn_dim: int,
embedding_dim: int,
attention_dim: int,
attention_location_n_filters: int,
attention_location_kernel_size: int,
lora_rank: int | None = None,
lora_alpha: float = 1.0,
lora_dropout: float = 0.0,
) -> None:
super().__init__()
self.query_layer = LinearNorm(
attention_rnn_dim,
attention_dim,
bias=False,
w_init_gain="tanh",
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.memory_layer = LinearNorm(
embedding_dim,
attention_dim,
bias=False,
w_init_gain="tanh",
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.v = LinearNorm(
attention_dim,
1,
bias=False,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.location_layer = LocationLayer(
attention_location_n_filters,
attention_location_kernel_size,
attention_dim,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.score_mask_value = -float("inf")
def get_alignment_energies(self, query: Tensor, processed_memory: Tensor, attention_weights_cat: Tensor) -> Tensor:
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat)
energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_memory))
energies = energies.squeeze(-1)
return energies
def forward(
self,
attn_hid_state: Tensor,
memory: Tensor,
proc_memory: Tensor,
attn_weights_cat: Tensor,
mask: Tensor | None,
) -> tuple[Tensor, Tensor]:
alignment = self.get_alignment_energies(attn_hid_state, proc_memory, attn_weights_cat)
if mask is not None:
alignment = alignment.masked_fill(mask, self.score_mask_value)
attention_weights = F.softmax(alignment, dim=1)
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
attention_context = attention_context.squeeze(1)
return attention_context, attention_weights
class Prenet(nn.Module):
__constants__ = ["dropout", "dropout_always_on"]
def __init__(
self,
in_dim: int = 80,
sizes: list[int] = [256, 256],
dropout: float = 0.5,
lora_rank: int | None = None,
lora_alpha: float = 1.0,
lora_dropout: float = 0.0,
dropout_always_on: bool = True,
) -> None:
super().__init__()
in_sizes = [in_dim] + sizes[:-1]
layers = [
LinearNorm(
in_size,
out_size,
bias=False,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
for (in_size, out_size) in zip(in_sizes, sizes)
]
self.layers = nn.ModuleList(layers)
self.dropout = dropout
self.dropout_always_on = dropout_always_on
def forward(self, x: Tensor) -> Tensor:
for linear in self.layers:
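# Note that when `dropout_always_on` is set, prenet dropout stays active
# even in eval mode; the Tacotron2 paper applies prenet dropout at
# inference time to introduce output variation.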
x = F.dropout(F.relu(linear(x)), p=self.dropout, training=self.training or self.dropout_always_on)
return x
@dataclass
class PostnetConfig:
n_mel_channels: int = conf_field(80, help="Number of mel channels")
emb_dim: int = conf_field(512, help="Postnet embedding dimension")
kernel_size: int = conf_field(5, help="Postnet kernel size")
n_convolutions: int = conf_field(5, help="Number of postnet convolutions")
lora_rank: int | None = conf_field(None, help="LoRA rank")
lora_alpha: float = conf_field(1.0, help="LoRA alpha")
lora_dropout: float = conf_field(0.0, help="LoRA dropout")
class Postnet(nn.Module):
def __init__(self, config: PostnetConfig) -> None:
super().__init__()
self.convolutions = nn.ModuleList()
self.convolutions.append(
nn.Sequential(
ConvNorm(
config.n_mel_channels,
config.emb_dim,
kernel_size=config.kernel_size,
stride=1,
padding=(config.kernel_size - 1) // 2,
dilation=1,
w_init_gain="tanh",
lora_rank=config.lora_rank,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
),
nn.BatchNorm1d(config.emb_dim),
)
)
for _ in range(1, config.n_convolutions - 1):
self.convolutions.append(
nn.Sequential(
ConvNorm(
config.emb_dim,
config.emb_dim,
kernel_size=config.kernel_size,
stride=1,
padding=int((config.kernel_size - 1) / 2),
dilation=1,
w_init_gain="tanh",
lora_rank=config.lora_rank,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
),
nn.BatchNorm1d(config.emb_dim),
)
)
self.convolutions.append(
nn.Sequential(
ConvNorm(
config.emb_dim,
config.n_mel_channels,
kernel_size=config.kernel_size,
stride=1,
padding=int((config.kernel_size - 1) / 2),
dilation=1,
w_init_gain="linear",
lora_rank=config.lora_rank,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
),
nn.BatchNorm1d(config.n_mel_channels),
)
)
def forward(self, x: Tensor) -> Tensor:
for i, conv in enumerate(self.convolutions):
if i < len(self.convolutions) - 1:
x = F.dropout(torch.tanh(conv(x)), 0.5, training=self.training)
else:
x = F.dropout(conv(x), 0.5, training=self.training)
return x
@dataclass
class EncoderConfig:
emb_dim: int = conf_field(512, help="Encoder embedding dimension")
kernel_size: int = conf_field(5, help="Encoder kernel size")
n_convolutions: int = conf_field(3, help="Number of encoder convolutions")
lora_rank: int | None = conf_field(None, help="LoRA rank")
lora_alpha: float = conf_field(1.0, help="LoRA alpha")
lora_dropout: float = conf_field(0.0, help="LoRA dropout")
freeze_bn: bool = conf_field(False, help="Freeze batch normalization")
speaker_emb_dim: int | None = conf_field(None, help="Speaker embedding dimension")
class Encoder(nn.Module):
def __init__(self, config: EncoderConfig) -> None:
super().__init__()
convolutions = []
for _ in range(config.n_convolutions):
conv_norm = ConvNorm(
config.emb_dim,
config.emb_dim,
kernel_size=config.kernel_size,
stride=1,
padding=int((config.kernel_size - 1) / 2),
dilation=1,
w_init_gain="relu",
lora_rank=config.lora_rank,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
)
batch_norm = nn.BatchNorm1d(config.emb_dim)
if config.freeze_bn:
batch_norm.requires_grad_(False)
conv_layer = nn.Sequential(conv_norm, batch_norm)
convolutions.append(conv_layer)
self.convolutions = nn.ModuleList(convolutions)
lstm = nn.LSTM(
config.emb_dim,
int(config.emb_dim / 2),
1,
batch_first=True,
bidirectional=True,
)
self.lstm = maybe_lora(lstm, r=config.lora_rank, alpha=config.lora_alpha, dropout=config.lora_dropout)
def forward(self, x: Tensor, input_lengths: Tensor, speaker_emb: Tensor | None = None) -> Tensor:
for conv in self.convolutions:
x = F.dropout(F.relu(conv(x)), 0.5, self.training)
x = x.transpose(1, 2)
input_lengths = input_lengths.cpu().numpy()
x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)
self.lstm.flatten_parameters()
outputs, _ = self.lstm(x)
outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
if speaker_emb is not None:
outputs += speaker_emb.unsqueeze(1)
return outputs
def infer(self, x: Tensor, input_lengths: Tensor, speaker_emb: Tensor | None = None) -> Tensor:
for conv in self.convolutions:
x = F.dropout(F.relu(conv(x)), 0.5, self.training)
x = x.transpose(1, 2)
input_lengths = input_lengths.cpu().numpy()
x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)
outputs, _ = self.lstm(x)
outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
if speaker_emb is not None:
outputs += speaker_emb.unsqueeze(1)
return outputs
@dataclass
class DecoderConfig:
n_mel_channels: int = conf_field(80, help="Number of mel channels")
n_frames_per_step: int = conf_field(1, help="Number of frames processed per step")
encoder_emb_dim: int = conf_field(512, help="Encoder embedding dimension")
attention_dim: int = conf_field(128, help="Attention dimension")
attention_location_n_filters: int = conf_field(32, help="Number of filters for location-sensitive attention")
attention_location_kernel_size: int = conf_field(31, help="Kernel size for location-sensitive attention")
attention_rnn_dim: int = conf_field(1024, help="Attention RNN dimension")
decoder_rnn_dim: int = conf_field(1024, help="Decoder RNN dimension")
prenet_dim: int = conf_field(256, help="Prenet dimension")
prenet_dropout: float = conf_field(0.5, help="Dropout probability for prenet layers")
max_decoder_steps: int = conf_field(1000, help="Maximum decoder steps during inference")
gate_threshold: float = conf_field(0.5, help="Probability threshold for stop token")
p_attention_dropout: float = conf_field(0.1, help="Dropout probability for attention LSTM")
p_decoder_dropout: float = conf_field(0.1, help="Dropout probability for decoder LSTM")
prenet_dropout_always_on: bool = conf_field(True, help="If set, prenet dropout is always on")
lora_rank: int | None = conf_field(None, help="LoRA rank")
lora_alpha: float = conf_field(1.0, help="LoRA alpha")
lora_dropout: float = conf_field(0.0, help="LoRA dropout")
class DecoderStates(NamedTuple):
attn_h: Tensor
attn_c: Tensor
dec_h: Tensor
dec_c: Tensor
attn_weights: Tensor
attn_weights_cum: Tensor
attn_ctx: Tensor
memory: Tensor
processed_memory: Tensor
mask: Tensor | None
class Decoder(nn.Module):
def __init__(self, config: DecoderConfig) -> None:
super().__init__()
self.n_mel_channels = config.n_mel_channels
self.n_frames_per_step = config.n_frames_per_step
self.encoder_embedding_dim = config.encoder_emb_dim
self.attention_rnn_dim = config.attention_rnn_dim
self.decoder_rnn_dim = config.decoder_rnn_dim
self.prenet_dim = config.prenet_dim
self.max_decoder_steps = config.max_decoder_steps
self.gate_threshold = config.gate_threshold
self.p_attention_dropout = config.p_attention_dropout
self.p_decoder_dropout = config.p_decoder_dropout
self.prenet = Prenet(
config.n_mel_channels * config.n_frames_per_step,
[config.prenet_dim, config.prenet_dim],
config.prenet_dropout,
lora_rank=config.lora_rank,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
dropout_always_on=config.prenet_dropout_always_on,
)
attention_rnn = nn.LSTMCell(
config.prenet_dim + config.encoder_emb_dim,
config.attention_rnn_dim,
)
self.attention_rnn = maybe_lora(
attention_rnn,
r=config.lora_rank,
alpha=config.lora_alpha,
dropout=config.lora_dropout,
)
self.attention_layer = Attention(
config.attention_rnn_dim,
config.encoder_emb_dim,
config.attention_dim,
config.attention_location_n_filters,
config.attention_location_kernel_size,
lora_rank=config.lora_rank,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
)
decoder_rnn = nn.LSTMCell(
config.attention_rnn_dim + config.encoder_emb_dim,
config.decoder_rnn_dim,
bias=True,
)
self.decoder_rnn = maybe_lora(
decoder_rnn,
r=config.lora_rank,
alpha=config.lora_alpha,
dropout=config.lora_dropout,
)
self.linear_projection = LinearNorm(
config.decoder_rnn_dim + config.encoder_emb_dim,
config.n_mel_channels * config.n_frames_per_step,
lora_rank=config.lora_rank,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
)
self.gate_layer = LinearNorm(
config.decoder_rnn_dim + config.encoder_emb_dim,
1,
bias=True,
w_init_gain="sigmoid",
lora_rank=config.lora_rank,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
)
def get_go_frame(self, memory: Tensor) -> Tensor:
return memory.new_zeros(memory.shape[0], self.n_mel_channels * self.n_frames_per_step)
def initialize_decoder_states(self, memory: Tensor, mask: Tensor | None) -> DecoderStates:
bsz, max_tsz, *_ = memory.size()
attn_h = memory.new_zeros(bsz, self.attention_rnn_dim)
attn_c = memory.new_zeros(bsz, self.attention_rnn_dim)
dec_hid = memory.new_zeros(bsz, self.decoder_rnn_dim)
dec_cell = memory.new_zeros(bsz, self.decoder_rnn_dim)
attn_weights = memory.new_zeros(bsz, max_tsz)
attn_weights_cum = memory.new_zeros(bsz, max_tsz)
attn_ctx = memory.new_zeros(bsz, self.encoder_embedding_dim)
processed_memory = self.attention_layer.memory_layer(memory)
return DecoderStates(
attn_h=attn_h,
attn_c=attn_c,
dec_h=dec_hid,
dec_c=dec_cell,
attn_weights=attn_weights,
attn_weights_cum=attn_weights_cum,
attn_ctx=attn_ctx,
memory=memory,
processed_memory=processed_memory,
mask=mask,
)
def parse_decoder_outputs(
self,
mel_outputs: list[Tensor],
gate_outputs: list[Tensor],
alignments: list[Tensor],
states: DecoderStates,
) -> tuple[Tensor, Tensor, Tensor, DecoderStates]:
alignments = torch.stack(alignments, dim=1)
gate_outputs = torch.stack(gate_outputs, dim=1)
mel_outputs = torch.stack(mel_outputs, dim=-1)
return mel_outputs, gate_outputs, alignments, states
def decode(self, decoder_input: Tensor, states: DecoderStates) -> tuple[Tensor, Tensor, Tensor, DecoderStates]:
attn_h, attn_c, dec_h, dec_c, attn_weights, attn_weights_cum, attn_ctx, memory, processed_memory, mask = states
cell_input = torch.cat((decoder_input, attn_ctx), -1)
attn_h, attn_c = self.attention_rnn(cell_input, (attn_h, attn_c))
attn_h = F.dropout(attn_h, self.p_attention_dropout, self.training)
attn_weights_cat = torch.cat((attn_weights.unsqueeze(1), attn_weights_cum.unsqueeze(1)), dim=1)
attn_ctx, attn_weights = self.attention_layer(attn_h, memory, processed_memory, attn_weights_cat, mask)
attn_weights_cum = attn_weights_cum + attn_weights
decoder_input = torch.cat((attn_h, attn_ctx), -1)
dec_h, dec_c = self.decoder_rnn(decoder_input, (dec_h, dec_c))
dec_h = F.dropout(dec_h, self.p_decoder_dropout, self.training)
dec_h_attn_ctx = torch.cat((dec_h, attn_ctx), dim=1)
dec_out = self.linear_projection(dec_h_attn_ctx)
gate_pred = self.gate_layer(dec_h_attn_ctx)
new_states = DecoderStates(
attn_h=attn_h,
attn_c=attn_c,
dec_h=dec_h,
dec_c=dec_c,
attn_weights=attn_weights,
attn_weights_cum=attn_weights_cum,
attn_ctx=attn_ctx,
memory=memory,
processed_memory=processed_memory,
mask=mask,
)
return dec_out, gate_pred, attn_weights, new_states
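# NOTE: `forward` below calls `parse_decoder_inputs`, which was missing from
# this listing. The implementation here is assumed to follow the reference
# NVIDIA Tacotron2 decoder, reshaping teacher-forcing mel inputs from
# (B, n_mel_channels, T_out) to (T_out, B, n_mel_channels).
def parse_decoder_inputs(self, decoder_inputs: Tensor) -> Tensor:
    # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
    decoder_inputs = decoder_inputs.transpose(1, 2)
    decoder_inputs = decoder_inputs.view(
        decoder_inputs.size(0),
        int(decoder_inputs.size(1) / self.n_frames_per_step),
        -1,
    )
    # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
    return decoder_inputs.transpose(0, 1)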
def forward(
self,
memory: Tensor,
dec_ins: Tensor,
memory_lengths: Tensor,
states: DecoderStates | None = None,
) -> tuple[Tensor, Tensor, Tensor, DecoderStates]:
dec_in = self.get_go_frame(memory).unsqueeze(0)
dec_ins = self.parse_decoder_inputs(dec_ins)
dec_ins = torch.cat((dec_in, dec_ins), dim=0)
prenet_ins = self.prenet(dec_ins)
if states is None:
    states = self.initialize_decoder_states(memory, mask=get_mask_from_lengths(memory_lengths))
mel_outs: list[Tensor] = []
gate_outs: list[Tensor] = []
alignments: list[Tensor] = []
while len(mel_outs) < prenet_ins.size(0) - 1:
prenet_in = prenet_ins[len(mel_outs)]
mel_out, gate_out, attn_weights, states = self.decode(prenet_in, states)
mel_outs += [mel_out.squeeze(1)]
gate_outs += [gate_out.squeeze(1)]
alignments += [attn_weights]
return self.parse_decoder_outputs(mel_outs, gate_outs, alignments, states)
def infer(
self,
memory: Tensor,
memory_lengths: Tensor,
states: DecoderStates | None = None,
) -> tuple[Tensor, Tensor, Tensor, DecoderStates]:
dec_in = self.get_go_frame(memory)
if states is None:
states = self.initialize_decoder_states(memory, mask=get_mask_from_lengths(memory_lengths))
mel_outs: list[Tensor] = []
gate_outs: list[Tensor] = []
alignments: list[Tensor] = []
while True:
prenet_in = self.prenet(dec_in)
mel_out, gate_out, alignment, states = self.decode(prenet_in, states)
mel_outs += [mel_out.squeeze(1)]
gate_outs += [gate_out]
alignments += [alignment]
if (torch.sigmoid(gate_out) > self.gate_threshold).all():
break
elif len(mel_outs) == self.max_decoder_steps:
logger.warning("Warning! Reached max decoder steps %d", self.max_decoder_steps)
break
dec_in = mel_out
return self.parse_decoder_outputs(mel_outs, gate_outs, alignments, states)
def window_sumsquare(
window: str | float,
n_frames: int,
hop_length: int = 200,
win_length: int | None = 800,
n_fft: int = 800,
dtype: type = np.float32,
norm: float | None = None,
) -> np.ndarray:
try:
from scipy.signal import get_window
except ImportError:
raise ImportError("Please install scipy to use this module: pip install scipy")
try:
import librosa.util
except ImportError:
raise ImportError("Please install librosa to use this module: pip install librosa")
if win_length is None:
win_length = n_fft
n = n_fft + hop_length * (n_frames - 1)
x: np.ndarray = np.zeros(n, dtype=dtype)
# Compute the squared window at the desired length
win_sq = get_window(window, win_length, fftbins=True)
win_sq = librosa.util.normalize(win_sq, norm=norm) ** 2
win_sq = librosa.util.pad_center(win_sq, size=n_fft)
# Fill the envelope
for i in range(n_frames):
sample = i * hop_length
x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
return x
def griffin_lim(magnitudes: Tensor, stft_fn: "STFT", n_iters: int = 30) -> Tensor:
angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
angles = angles.astype(np.float32)
angles = torch.from_numpy(angles)
signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
for _ in range(n_iters):
_, angles_tensor = stft_fn.transform(signal)
signal = stft_fn.inverse(magnitudes, angles_tensor).squeeze(1)
return signal
def dynamic_range_compression(x: Tensor, c: int | float = 1, clip_val: float = 1e-5) -> Tensor:
return torch.log(torch.clamp(x, min=clip_val) * c)
def dynamic_range_decompression(x: Tensor, c: int | float = 1) -> Tensor:
return torch.exp(x) / c
class STFT(nn.Module):
forward_basis: Tensor
inverse_basis: Tensor
def __init__(
self,
filter_length: int = 800,
hop_length: int = 200,
win_length: int = 800,
window: str = "hann",
) -> None:
try:
from scipy.signal import get_window
except ImportError:
raise ImportError("Please install scipy to use this module: pip install scipy")
try:
import librosa.util
except ImportError:
raise ImportError("Please install librosa to use this module: pip install librosa")
super().__init__()
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.forward_transform = None
scale = self.filter_length / self.hop_length
with Timer("getting fourier basis"):
fourier_basis = np.fft.fft(np.eye(self.filter_length))
cutoff = int((self.filter_length / 2 + 1))
fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])])
with Timer("getting forward and inverse basis"):
forward_basis: Tensor = torch.FloatTensor(fourier_basis[:, None, :])
inverse_basis: Tensor = torch.FloatTensor(np.linalg.pinv(scale * fourier_basis).T[:, None, :])
if window is not None:
with Timer("applying window"):
assert filter_length >= win_length
# get window and zero center pad it to filter_length
fft_window = get_window(window, win_length, fftbins=True)
fft_window = librosa.util.pad_center(fft_window, size=filter_length)
fft_window = torch.from_numpy(fft_window).float()
# window the bases
forward_basis *= fft_window[None, None, :]
inverse_basis *= fft_window[None, None, :]
self.register_buffer("forward_basis", forward_basis.float())
self.register_buffer("inverse_basis", inverse_basis.float())
def inverse(self, magnitude: Tensor, phase: Tensor) -> Tensor:
try:
import librosa.util
except ImportError:
raise ImportError("Please install librosa to use this module: pip install librosa")
recombine_magnitude_phase = torch.cat([magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1)
inverse_transform = F.conv_transpose1d(
recombine_magnitude_phase,
self.inverse_basis,
stride=self.hop_length,
padding=0,
)
if self.window is not None:
window_sum_np = window_sumsquare(
self.window,
magnitude.size(-1),
hop_length=self.hop_length,
win_length=self.win_length,
n_fft=self.filter_length,
dtype=np.float32,
)
# Remove modulation effects.
approx_nonzero_indices = torch.from_numpy(np.where(window_sum_np > librosa.util.tiny(window_sum_np))[0])
window_sum = torch.from_numpy(window_sum_np)
window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
# Scale by hop ratio.
inverse_transform *= float(self.filter_length) / self.hop_length
inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2)]
return inverse_transform
def forward(self, input_data: Tensor) -> Tensor:
self.magnitude, self.phase = self.transform(input_data)
reconstruction = self.inverse(self.magnitude, self.phase)
return reconstruction
class TacotronSTFT(nn.Module):
mel_basis: Tensor
def __init__(
self,
filter_length: int = 1024,
hop_length: int = 256,
win_length: int = 1024,
n_mel_channels: int = 80,
sampling_rate: int = 16000,
mel_fmin: float = 0.0,
mel_fmax: float = 8000.0,
) -> None:
super().__init__()
self.n_mel_channels = n_mel_channels
self.sampling_rate = sampling_rate
self.stft_fn = STFT(filter_length, hop_length, win_length)
try:
from librosa.filters import mel as librosa_mel_fn
except ImportError:
raise ImportError("Please install librosa: pip install librosa")
mel_basis_np = librosa_mel_fn(
sr=sampling_rate,
n_fft=filter_length,
n_mels=n_mel_channels,
fmin=mel_fmin,
fmax=mel_fmax,
)
mel_basis = torch.from_numpy(mel_basis_np).float()
self.register_buffer("mel_basis", mel_basis)
@torch.no_grad()
def spectral_normalize(self, magnitudes: Tensor) -> Tensor:
output = dynamic_range_compression(magnitudes)
return output
@torch.no_grad()
def spectral_de_normalize(self, magnitudes: Tensor) -> Tensor:
output = dynamic_range_decompression(magnitudes)
return output
@torch.no_grad()
def mel_spectrogram(self, y: Tensor) -> Tensor:
assert torch.min(y.data) >= -1
assert torch.max(y.data) <= 1
magnitudes, _ = self.stft_fn.transform(y)
magnitudes = magnitudes.data
mel_output = torch.matmul(self.mel_basis, magnitudes)
mel_output = self.spectral_normalize(mel_output)
return mel_output
@dataclass
class TacotronConfig(BaseModelConfig):
mask_padding: bool = conf_field(False, help="Mask padding in loss computation")
n_mel_channels: int = conf_field(80, help="Number of bins in mel-spectrograms")
n_symbols: int = conf_field(148, help="Number of symbols in dictionary")
symbols_emb_dim: int = conf_field(512, help="Input embedding dimension")
n_frames_per_step: int = conf_field(1, help="Number of frames processed per step")
symbols_emb_dropout: float = conf_field(0.1, help="Dropout rate for symbol embeddings")
encoder: EncoderConfig = conf_field(EncoderConfig(), help="Encoder configuration")
decoder: DecoderConfig = conf_field(DecoderConfig(), help="Decoder configuration")
postnet: PostnetConfig = conf_field(PostnetConfig(), help="Postnet configuration")
class Tacotron(BaseModel):
def __init__(self, config: TacotronConfig) -> None:
super().__init__(config)
self.mask_padding = config.mask_padding
self.n_mel_channels = config.n_mel_channels
self.n_frames_per_step = config.n_frames_per_step
self.embedding = nn.Embedding(config.n_symbols, config.symbols_emb_dim)
std = sqrt(2.0 / (config.n_symbols + config.symbols_emb_dim))
val = sqrt(3.0) * std # uniform bounds for std
self.embedding.weight.data.uniform_(-val, val)
self.encoder = Encoder(config.encoder)
self.decoder = Decoder(config.decoder)
self.postnet = Postnet(config.postnet)
def parse_output(
self,
outputs: tuple[Tensor, Tensor, Tensor, Tensor, DecoderStates],
output_lengths: Tensor | None = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor, DecoderStates]:
if self.mask_padding and output_lengths is not None:
mask = ~get_mask_from_lengths(output_lengths)
mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
mask = mask.permute(1, 0, 2)
outputs[0].data.masked_fill_(mask, 0.0)
outputs[1].data.masked_fill_(mask, 0.0)
outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies
return outputs
def forward(
self,
inputs: tuple[Tensor, Tensor, Tensor, Tensor],
states: DecoderStates | None = None,
speaker_emb: Tensor | None = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor, DecoderStates]:
text_inputs, text_lengths, mels, output_lengths = inputs
embedded_inputs = self.embedding.forward(text_inputs).transpose(1, 2)
encoder_outputs = self.encoder.forward(embedded_inputs, text_lengths, speaker_emb)
mel_outputs, gate_outputs, alignments, states = self.decoder(encoder_outputs, mels, text_lengths, states)
mel_outputs_postnet = mel_outputs + self.postnet(mel_outputs)
return self.parse_output((mel_outputs, mel_outputs_postnet, gate_outputs, alignments, states), output_lengths)
def infer(
self,
inputs: Tensor,
input_lengths: Tensor,
states: DecoderStates | None = None,
speaker_emb: Tensor | None = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor, DecoderStates]:
embedded_inputs = self.embedding(inputs).transpose(1, 2)
encoder_outputs = self.encoder.infer(embedded_inputs, input_lengths, speaker_emb)
mel_outputs, gate_outputs, alignments, states = self.decoder.infer(encoder_outputs, input_lengths, states)
mel_outputs_postnet = mel_outputs + self.postnet(mel_outputs)
return self.parse_output((mel_outputs, mel_outputs_postnet, gate_outputs, alignments, states))
class Tokenizer:
def __init__(self) -> None:
super().__init__()
pad = "_"
punctuation = "!'(),.:;? "
specials = "-"
letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
# Prepend "@" to ARPAbet symbols to ensure uniqueness.
valid_symbols = (
"AA:AA0:AA1:AA2:AE:AE0:AE1:AE2:AH:AH0:AH1:AH2:AO:AO0:AO1:AO2:AW"
":AW0:AW1:AW2:AY:AY0:AY1:AY2:B:CH:D:DH:EH:EH0:EH1:EH2:ER:ER0:ER1"
":ER2:EY:EY0:EY1:EY2:F:G:HH:IH:IH0:IH1:IH2:IY:IY0:IY1:IY2:JH:K:L"
":M:N:NG:OW:OW0:OW1:OW2:OY:OY0:OY1:OY2:P:R:S:SH:T:TH:UH:UH0:UH1"
":UH2:UW:UW0:UW1:UW2:V:W:Y:Z:ZH"
)
arpabet = ["@" + s for s in valid_symbols.split(":")]
# Gets the symbol conversion dictionary.
self.symbols = [pad] + list(specials) + list(punctuation) + list(letters) + arpabet
self.symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
self.id_to_symbol = {v: k for k, v in self.symbol_to_id.items()}
self.curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
def __call__(self, text: str) -> Tensor:
clean_func = text_clean_func()
sequence: list[int] = []
def _should_keep_symbol(s: str) -> bool:
return s in self.symbol_to_id and s != "_" and s != "~"
def _symbols_to_sequence(symbols: str | list[str]) -> list[int]:
return [self.symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
def _arpabet_to_sequence(text: str) -> list[int]:
return _symbols_to_sequence(["@" + s for s in text.split()])
while len(text):
m = self.curly_re.match(text)
if not m:
sequence += _symbols_to_sequence(clean_func(text))
break
sequence += _symbols_to_sequence(clean_func(m.group(1)))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
return torch.tensor(sequence, dtype=torch.int32)
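# Example (illustrative): plain graphemes, with an ARPAbet span marked by
# curly braces.
#
#     tokenizer = Tokenizer()
#     ids = tokenizer("Turn left on {HH W AY1 T} street.")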
def ensure_tacotron_downloaded() -> Path:
with Timer("downloading checkpoint"):
return ensure_downloaded(TACOTRON_CKPT_URL, "tacotron2", "weights_tacotron.pth")
def pretrained_tacotron2(
*,
pretrained: bool = True,
lora_rank: int | None = None,
lora_alpha: float = 1.0,
lora_dropout: float = 0.0,
lora_encoder: bool = True,
lora_decoder: bool = True,
lora_postnet: bool = True,
device: torch.device | None = None,
prenet_dropout: bool = True,
num_tokens: int | None = None,
) -> Tacotron:
"""Loads the pretrained Tacotron2 model.
Args:
pretrained: Whether to load the pretrained weights.
lora_rank: The LoRA rank to use, if LoRA is desired.
lora_alpha: The LoRA alpha to use, if LoRA is desired.
lora_dropout: The LoRA dropout to use, if LoRA is desired.
lora_encoder: Whether to use LoRA in the encoder.
lora_decoder: Whether to use LoRA in the decoder.
lora_postnet: Whether to use LoRA in the postnet.
device: The device to load the weights onto.
prenet_dropout: Whether to use always apply dropout in the PreNet.
num_tokens: The number of tokens in the vocabulary.
Returns:
The pretrained Tacotron model.
"""
config = TacotronConfig()
if lora_encoder:
config.encoder.lora_rank = lora_rank
config.encoder.lora_alpha = lora_alpha
config.encoder.lora_dropout = lora_dropout
if lora_decoder:
config.decoder.lora_rank = lora_rank
config.decoder.lora_alpha = lora_alpha
config.decoder.lora_dropout = lora_dropout
if lora_postnet:
config.postnet.lora_rank = lora_rank
config.postnet.lora_alpha = lora_alpha
config.postnet.lora_dropout = lora_dropout
config.decoder.prenet_dropout_always_on = prenet_dropout
if num_tokens is not None:
config.n_symbols = num_tokens
if not pretrained:
return Tacotron(config)
with Timer("initializing model", spinner=True), init_empty_weights():
model = Tacotron(config)
with Timer("downloading checkpoint", spinner=True):
filepath = ensure_tacotron_downloaded()
with Timer("loading checkpoint", spinner=True):
if device is None:
device = torch.device("cpu")
ckpt = torch.load(filepath, map_location=device)
model._apply(meta_to_empty_func(device))
model.load_state_dict(ckpt["state_dict"])
reset_lora_weights_(model)
with Timer("freezing weights", spinner=True):
if lora_rank is not None:
freeze_non_lora_(model)
if num_tokens is not None:
model.embedding.requires_grad_(True)
return model
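# Example sketch of a LoRA fine-tuning setup; the rank, alpha, and learning
# rate below are arbitrary illustrative values.
#
#     model = pretrained_tacotron2(lora_rank=8, lora_alpha=16.0)
#     params = [p for p in model.parameters() if p.requires_grad]
#     optimizer = torch.optim.AdamW(params, lr=1e-4)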
def tacotron_stft(
filter_length: int = 1024,
hop_length: int = 256,
win_length: int = 1024,
n_mel_channels: int = 80,
sampling_rate: int = 16000,
mel_fmin: float = 0.0,
mel_fmax: float = 8000.0,
) -> TacotronSTFT:
"""Returns an STFT module for training the Tacotron model.
Args:
filter_length: The length of the filters used for the STFT.
hop_length: The hop length of the STFT.
win_length: The window length of the STFT.
n_mel_channels: The number of mel channels.
sampling_rate: The sampling rate of the audio.
mel_fmin: The minimum frequency of the mel filterbank.
mel_fmax: The maximum frequency of the mel filterbank.
Returns:
The STFT module.
"""
return TacotronSTFT(
filter_length=filter_length,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
sampling_rate=sampling_rate,
mel_fmin=mel_fmin,
mel_fmax=mel_fmax,
)
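# Example sketch: computing mel-spectrogram targets for training, assuming
# `waveform` is a (batch, num_samples) tensor normalized to [-1, 1].
#
#     stft = tacotron_stft()
#     mels = stft.mel_spectrogram(waveform)  # (batch, n_mel_channels, frames)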
def tacotron_tokenizer() -> Tokenizer:
return Tokenizer()
class TTS:
def __init__(
self,
tacotron: Tacotron,
vocoder: Vocoder,
*,
device: base_device | None = None,
) -> None:
"""Provides an API for doing text-to-speech.
Note that this module is not an `nn.Module`, so you can use it in your
module without worrying aobut storing all the weights on accident.
Args:
tacotron: The Tacotron model.
vocoder: The vocoder model.
device: The device to load the weights onto.
"""
super().__init__()
self.device = detect_device() if device is None else device
self.tacotron = tacotron.eval()
self.vocoder = vocoder.eval()
self.vocoder.remove_weight_norm()
self.sampling_rate = self.vocoder.sampling_rate
self.device.module_to(self.tacotron)
self.device.module_to(self.vocoder)
self.tokenizer = Tokenizer()
@torch.inference_mode()
def generate_mels(
self,
text: str | list[str],
postnet: bool = True,
states: DecoderStates | None = None,
) -> tuple[Tensor, DecoderStates]:
if isinstance(text, str):
tokens = self.tokenizer(text).unsqueeze(0)
token_lengths = tokens.new_full((1,), tokens.shape[1], dtype=torch.int32)
else:
token_list = [self.tokenizer(t) for t in text]
tokens = pad_sequence(token_list, batch_first=True, padding_value=0)
token_lengths = tokens.new_empty((tokens.shape[0],), dtype=torch.int32)
for i, t in enumerate(token_list):
token_lengths[i] = t.shape[0]
tokens, token_lengths = self.device.tensor_to(tokens), self.device.tensor_to(token_lengths)
mel_outputs, mel_outputs_postnet, _, _, states = self.tacotron.infer(tokens, token_lengths, states)
return mel_outputs_postnet if postnet else mel_outputs, states
@torch.inference_mode()
def generate_wave(self, mels: Tensor) -> Tensor:
return self.vocoder.infer(mels)
@torch.inference_mode()
def generate(
self,
text: str | list[str],
postnet: bool = True,
states: DecoderStates | None = None,
) -> tuple[Tensor, DecoderStates]:
mels, states = self.generate_mels(text, postnet=postnet, states=states)
audio = self.generate_wave(mels).squeeze(0)
return audio, states
def pretrained_tacotron2_tts(vocoder_type: VocoderType = "hifigan", *, device: base_device | None = None) -> TTS:
tacotron = pretrained_tacotron2()
vocoder = pretrained_vocoder(vocoder_type)
tts = TTS(tacotron, vocoder, device=device)
return tts
def test_tacotron_adhoc() -> None:
configure_logging()
parser = argparse.ArgumentParser()
parser.add_argument("text", type=str, nargs="?", help="The text to synthesize.")
parser.add_argument("-o", "--out-file", type=str, default=None, help="The output file.")
parser.add_argument("-p", "--prompt", default="Text: ", help="The prompt to use.")
args = parser.parse_args()
tts = pretrained_tacotron2_tts()
def generate_for_text(texts: Iterable[str]) -> None:
for text in texts:
logger.info("Generating audio for '%s'", text)
audio, _ = tts.generate(text, postnet=True)
if args.out_file is None:
try:
import sounddevice as sd
except ImportError:
raise ImportError("Please install sounddevice to use this module: pip install sounddevice")
audio = audio.cpu().numpy().T
sd.play(audio, tts.sampling_rate, blocking=True)
else:
out_path = Path(args.out_file)
out_path.parent.mkdir(parents=True, exist_ok=True)
write_audio(iter([audio]), out_path, tts.sampling_rate)
if args.text:
generate_for_text([args.text])
else:
def gen_texts() -> Iterable[str]:
try:
text = input(args.prompt)
while text:
yield text
text = input(args.prompt)
except EOFError:
pass
generate_for_text(gen_texts())
if __name__ == "__main__":
# python -m pretrained.tacotron2
test_tacotron_adhoc()