# Source code for pretrained.demucs

"""Implementation of the Demucs model architecture.

From the paper `Real Time Speech Enhancement in the Waveform Domain
<https://arxiv.org/abs/2006.12847>`_. The paper has a project page
`here <https://github.com/facebookresearch/denoiser>`_.

This model is a relatively straight-forward autoencoder, similar to a UNet but
with an RNN in between. The original model was trained to do denoising, which
makes sense for this particular model since it simply requires removing some
part of the input waveform.
"""

import functools
import math
import time
from typing import cast

import ml.api as ml
import torch
import torch.nn.functional as F
from ml.utils.device.auto import detect_device
from ml.utils.device.base import base_device
from torch import Tensor, nn


def sinc(t: Tensor) -> Tensor:
    """Evaluates the unnormalized sinc function ``sin(t) / t`` elementwise.

    Args:
        t: The input tensor.

    Returns:
        A tensor shaped like ``t`` holding ``sin(t) / t``, with the value 1
        substituted wherever ``t`` is exactly zero.
    """
    one = torch.tensor(1.0, device=t.device, dtype=t.dtype)
    ratio = torch.sin(t) / t
    at_origin = t == 0
    return torch.where(at_origin, one, ratio)
@functools.lru_cache()
def kernel_upsample2(device: torch.device, dtype: torch.dtype, zeros: int = 56) -> Tensor:
    """Builds the windowed-sinc interpolation kernel used by ``upsample2``.

    The result is memoized per ``(device, dtype, zeros)`` combination.

    Args:
        device: Device on which the returned kernel should live.
        dtype: Data type of the returned kernel.
        zeros: Controls the filter length; the kernel has ``2 * zeros`` taps.

    Returns:
        The kernel, with shape ``(1, 1, 2 * zeros)``.
    """
    # Keep only the odd-indexed samples of a symmetric Hann window.
    hann = torch.hann_window(4 * zeros + 1, periodic=False)
    odd_window = hann[1::2]

    # Half-sample offsets, scaled by pi, at which the sinc is evaluated.
    positions = torch.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
    positions *= math.pi

    taps = sinc(positions) * odd_window
    return taps.view(1, 1, -1).to(device, dtype)
def upsample2(x: Tensor, zeros: int = 56) -> Tensor:
    """Upsamples the last dimension of ``x`` by a factor of two.

    The original samples are kept as-is and interleaved with samples
    interpolated by a windowed-sinc filter.

    Args:
        x: Input tensor with shape ``(*, time)``. It does not need to be
            contiguous.
        zeros: Controls the length of the interpolation filter, which has
            ``2 * zeros`` taps.

    Returns:
        The upsampled tensor, with shape ``(*, 2 * time)``.
    """
    *lead, num_samples = x.shape
    kernel = kernel_upsample2(x.device, x.dtype, zeros)

    # ``reshape`` instead of ``view`` so that non-contiguous inputs (for
    # example transposed or strided tensors) are accepted as well.
    flat = x.reshape(-1, 1, num_samples)

    # The convolution yields ``num_samples + 1`` values; the first is dropped.
    interpolated = F.conv1d(flat, kernel, padding=zeros)[..., 1:].reshape(*lead, num_samples)

    # Interleave the original and interpolated samples along time.
    interleaved = torch.stack([x, interpolated], dim=-1)
    return interleaved.reshape(*lead, -1)
@functools.lru_cache()
def kernel_downsample2(device: torch.device, dtype: torch.dtype, zeros: int = 56) -> Tensor:
    """Builds the windowed-sinc filter kernel used by ``downsample2``.

    The result is memoized per ``(device, dtype, zeros)`` combination.

    Args:
        device: Device on which the returned kernel should live.
        dtype: Data type of the returned kernel.
        zeros: Controls the filter length; the kernel has ``2 * zeros`` taps.

    Returns:
        The kernel, with shape ``(1, 1, 2 * zeros)``.
    """
    num_taps = 2 * zeros

    # Sinc evaluation points: half-sample offsets scaled (in place) by pi.
    points = torch.linspace(-zeros + 0.5, zeros - 0.5, num_taps)
    points.mul_(math.pi)

    # Odd-indexed samples of a symmetric Hann window of 4 * zeros + 1 points.
    window = torch.hann_window(4 * zeros + 1, periodic=False)[1::2]

    kernel = (sinc(points) * window).view(1, 1, -1)
    return kernel.to(device, dtype)
def downsample2(x: Tensor, zeros: int = 56) -> Tensor:
    """Downsamples the last dimension of ``x`` by a factor of two.

    The even samples are combined with a windowed-sinc filtered version of
    the odd samples, and the result is halved.

    Args:
        x: Input tensor with shape ``(*, time)``. If ``time`` is odd, a single
            zero is appended first. It does not need to be contiguous.
        zeros: Controls the length of the filter, which has ``2 * zeros``
            taps.

    Returns:
        The downsampled tensor, with shape ``(*, ceil(time / 2))``.
    """
    if x.shape[-1] % 2 != 0:
        x = F.pad(x, (0, 1))

    even = x[..., ::2]
    odd = x[..., 1::2]
    *lead, num_samples = odd.shape
    kernel = kernel_downsample2(x.device, x.dtype, zeros)

    # ``reshape`` instead of ``view`` so that strided / non-contiguous inputs
    # are accepted. The convolution yields one extra trailing value, which is
    # dropped.
    filtered = F.conv1d(odd.reshape(-1, 1, num_samples), kernel, padding=zeros)[..., :-1]
    out = even + filtered.reshape(*lead, num_samples)
    return out.reshape(*lead, -1).mul(0.5)
[docs]def fast_conv(conv: nn.Conv1d | nn.ConvTranspose1d, x: Tensor) -> Tensor: batch, in_channels, length = x.shape weight, bias = conv.weight, conv.bias out_channels, in_channels, kernel = weight.shape assert batch == 1 if bias is None: out = conv(x) elif kernel == 1: x = x.view(in_channels, length) out = torch.addmm(bias.view(-1, 1), weight.view(out_channels, in_channels), x) elif length == kernel: x = x.view(in_channels * kernel, 1) out = torch.addmm(bias.view(-1, 1), weight.view(out_channels, in_channels * kernel), x) else: out = conv(x) return out.view(batch, out_channels, -1)
[docs]def rescale_conv(conv: nn.Conv1d | nn.ConvTranspose1d, reference: float) -> None: std = conv.weight.std().detach() scale = (std / reference) ** 0.5 conv.weight.data /= scale if conv.bias is not None: conv.bias.data /= scale
def rescale_module(module: nn.Module, reference: float) -> None:
    """Applies ``rescale_conv`` to every 1D convolution inside a module.

    Args:
        module: The module whose ``modules()`` are scanned, which includes
            the module itself and all of its descendants.
        reference: The reference standard deviation passed to
            ``rescale_conv``.
    """
    conv_types = (nn.Conv1d, nn.ConvTranspose1d)
    for child in module.modules():
        if not isinstance(child, conv_types):
            continue
        rescale_conv(child, reference)
[docs]class RNN(nn.Module): def __init__(self, dim: int, layers: int = 2, bi: bool = True) -> None: super().__init__() self.lstm = nn.LSTM(bidirectional=bi, num_layers=layers, hidden_size=dim, input_size=dim) self.linear = None if bi: self.linear = nn.Linear(2 * dim, dim)
[docs] def forward(self, x: Tensor, hidden: Tensor | None = None) -> tuple[Tensor, Tensor]: x, hidden = self.lstm(x, hidden) if self.linear: x = self.linear(x) return x, hidden
class Encoder(nn.Module):
    """One encoder stage: strided conv, activation, 1x1 conv, then GLU."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        act: ml.ActivationType = "relu",
    ) -> None:
        """Initializes the encoder stage.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size of the first (strided) convolution.
            stride: Stride of the first convolution.
            act: Activation applied between the two convolutions.
        """
        super().__init__()

        # The 1x1 convolution doubles the channels; the GLU over the channel
        # dimension then halves them back to ``out_channels``.
        doubled = out_channels * 2
        self.conv_a = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride)
        self.conv_b = nn.Conv1d(out_channels, doubled, 1)
        self.act = ml.get_activation(act)
        self.glu = nn.GLU(dim=1)

    def forward(self, x: Tensor) -> Tensor:
        """Applies the stage to ``x``, whose channels are on dimension 1."""
        hidden = self.act(self.conv_a(x))
        return self.glu(self.conv_b(hidden))
class Decoder(nn.Module):
    """One decoder stage: 1x1 conv, GLU, transposed conv, then activation."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        act: ml.ActivationType = "relu",
    ) -> None:
        """Initializes the decoder stage.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size of the transposed convolution.
            stride: Stride of the transposed convolution.
            act: Activation applied after the transposed convolution.
        """
        super().__init__()

        # The 1x1 convolution doubles the channels; the GLU over the channel
        # dimension halves them back to ``in_channels``.
        doubled = in_channels * 2
        self.conv_a = nn.Conv1d(in_channels, doubled, 1, stride=1)
        self.conv_b = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=stride)
        self.act = ml.get_activation(act)
        self.glu = nn.GLU(dim=1)

    def forward(self, x: Tensor) -> Tensor:
        """Applies the stage to ``x``, whose channels are on dimension 1."""
        x = self.conv_a(x)
        x = self.glu(x)
        x = self.conv_b(x)
        x = self.act(x)
        return x
class Demucs(nn.Module):
    """Demucs model: a stack of encoders, an RNN, and a stack of decoders.

    Each decoder receives the sum of the previous stage's output and the
    matching encoder output (a skip connection).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden: int = 48,
        depth: int = 5,
        kernel_size: int = 8,
        stride: int = 4,
        causal: bool = True,
        resample: int = 4,
        growth: float = 2,
        max_hidden: int = 10_000,
        normalize: bool = True,
        rescale: float = 0.1,
        floor: float = 1e-3,
        sample_rate: int = 16_000,
    ) -> None:
        """Demucs speech enhancement model.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            hidden: Number of initial hidden channels.
            depth: Number of layers.
            kernel_size: Kernel size for each layer.
            stride: Stride for each layer.
            causal: If false, uses BiLSTM instead of LSTM.
            resample: Amount of resampling to apply to the input/output.
                Can be one of 1, 2 or 4.
            growth: Number of channels is multiplied by this for every layer.
            max_hidden: Maximum number of channels. Can be useful to control
                the size/speed of the model.
            normalize: If true, normalize the input.
            rescale: Controls custom weight initialization. A falsy value
                disables the rescaling.
            floor: Floor value for normalization.
            sample_rate: Sample rate used for training the model.

        Raises:
            ValueError: If ``resample`` is not one of 1, 2 or 4.
        """
        super().__init__()

        if resample not in [1, 2, 4]:
            raise ValueError("Resample must be one of 1, 2 or 4")

        # Model parameters.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden = hidden
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride
        self.causal = causal
        self.growth = growth
        self.max_hidden = max_hidden
        self.rescale = rescale

        # Used during training.
        self.resample = resample
        self.normalize = normalize
        self.floor = floor
        self.sample_rate = sample_rate

        encoders: list[Encoder] = []
        decoders: list[Decoder] = []

        for index in range(depth):
            encoder = Encoder(in_channels, hidden, kernel_size, stride)
            # The outermost decoder (index 0) produces the final output and
            # therefore has no activation.
            decoder = Decoder(hidden, out_channels, kernel_size, stride, act="relu" if index > 0 else "no_act")
            encoders.append(encoder)
            decoders.append(decoder)
            in_channels = hidden
            out_channels = hidden
            hidden = min(int(growth * hidden), max_hidden)

        # After the loop, ``in_channels`` is the width of the innermost layer.
        self.lstm = RNN(in_channels, bi=not causal)

        # Decoders are stored innermost-first, i.e. in the order they run.
        self.encoders = cast(list[Encoder], nn.ModuleList(encoders))
        self.decoders = cast(list[Decoder], nn.ModuleList(decoders[::-1]))

        # This must run after the encoders and decoders are registered as
        # submodules, otherwise ``self.modules()`` contains no convolution
        # and the rescaling silently does nothing.
        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length: int) -> int:
        """Returns the nearest valid length to use with the model.

        Return the nearest valid length to use with the model so that there
        are no time steps left over in the convolutions, e.g. for all layers,
        (size of the input - kernel_size) % stride = 0.

        If the mixture has a valid length, the estimated sources will have
        exactly the same length.

        Args:
            length: Length of the input.

        Returns:
            The nearest valid length.
        """
        length = math.ceil(length * self.resample)

        # Walk down the encoder stack, rounding up at every layer...
        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(length, 1)

        # ...then back up through the transposed convolutions.
        for _ in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size

        length = int(math.ceil(length / self.resample))
        return int(length)

    @property
    def total_stride(self) -> int:
        """Combined stride of all layers, divided by the resampling factor."""
        return self.stride**self.depth // self.resample

    def forward(self, mix: Tensor) -> Tensor:
        """Runs the model on a batch of waveforms.

        Args:
            mix: Input waveform with shape ``(batch, channels, time)``. A 2D
                input is treated as ``(batch, time)`` and gets a channel
                dimension inserted.

        Returns:
            The output waveform with shape ``(batch, out_channels, time)``,
            trimmed to the input length.
        """
        if mix.dim() == 2:
            mix = mix.unsqueeze(1)

        std: Tensor | None = None
        if self.normalize:
            # Normalize by the standard deviation of the channel average.
            mono = mix.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            mix = mix / (self.floor + std)

        length = mix.shape[-1]
        x = mix
        x = F.pad(x, (0, self.valid_length(length) - length))

        if self.resample == 2:
            x = upsample2(x)
        elif self.resample == 4:
            x = upsample2(x)
            x = upsample2(x)

        skips = []
        for encode in self.encoders:
            x = encode(x)
            skips.append(x)

        # (batch, channels, time) -> (time, batch, channels) for the LSTM.
        x = x.permute(2, 0, 1)
        x, _ = self.lstm.forward(x)
        x = x.permute(1, 2, 0)

        for decode in self.decoders:
            skip = skips.pop(-1)
            x = x + skip[..., : x.shape[-1]]
            x = decode(x)

        if self.resample == 2:
            x = downsample2(x)
        elif self.resample == 4:
            x = downsample2(x)
            x = downsample2(x)

        # Drop the samples that correspond to the padding added above.
        x = x[..., :length]

        if std is not None:
            x = x * std
        return x

    def streamer(
        self,
        *,
        dry: float = 0.0,
        num_frames: int = 1,
        resample_lookahead: int = 64,
        resample_buffer: int = 256,
        device: base_device | None = None,
    ) -> "DemucsStreamer":
        """Gets a streamer for the current model.

        Args:
            dry: Percentage of the unaltered signal to preserve (0 to 1).
            num_frames: Number of frames to process at once. Higher values
                will increase overall latency but improve the real time
                factor.
            resample_lookahead: Extra lookahead used for the resampling.
            resample_buffer: Size of the buffer of previous inputs/outputs
                kept for resampling.
            device: The device to use for predictions. If `None`, will use the
                device returned by detect_device().

        Returns:
            A streamer for streaming from the current model.
        """
        return DemucsStreamer(
            self,
            dry=dry,
            num_frames=num_frames,
            resample_lookahead=resample_lookahead,
            resample_buffer=resample_buffer,
            device=device,
        )
class DemucsStreamer:
    """Runs a ``Demucs`` model incrementally on a stream of audio chunks.

    Audio is passed in through ``feed``, which buffers it and processes it
    frame by frame, carrying LSTM and convolution state between frames.
    ``flush`` pads with zeros to retrieve the output for the buffered tail.
    """

    def __init__(
        self,
        demucs: Demucs,
        dry: float = 0.0,
        num_frames: int = 1,
        resample_lookahead: int = 64,
        resample_buffer: int = 256,
        device: base_device | None = None,
    ) -> None:
        """Initializes the streamer.

        Args:
            demucs: The model to stream from. It is moved to ``device``.
            dry: Fraction of the unaltered input mixed into the output.
            num_frames: Number of frames to process at once.
            resample_lookahead: Extra lookahead used for the resampling.
            resample_buffer: Size of the buffer of previous inputs/outputs
                kept for resampling, capped at the model's total stride.
            device: The device to use. If ``None``, uses the device returned
                by ``detect_device()``.
        """
        self.device = detect_device() if device is None else device
        self.demucs = demucs
        self.device.module_to(self.demucs)

        # State carried from one frame to the next; the LSTM state is the
        # (h, c) pair returned by the model's recurrent block.
        self.lstm_state: tuple[Tensor, Tensor] | None = None
        self.conv_state: list[Tensor] | None = None

        self.dry = dry
        self.resample_lookahead = resample_lookahead
        resample_buffer = min(demucs.total_stride, resample_buffer)
        self.resample_buffer = resample_buffer
        self.frame_length = demucs.valid_length(1) + demucs.total_stride * (num_frames - 1)
        self.total_length = self.frame_length + self.resample_lookahead
        self.stride = demucs.total_stride * num_frames

        # ``resample_in`` holds past *input* samples and ``resample_out``
        # holds past *output* samples, so the latter must be sized with the
        # model's output channel count.
        self.resample_in = self.device.tensor_to(torch.zeros(demucs.in_channels, resample_buffer))
        self.resample_out = self.device.tensor_to(torch.zeros(demucs.out_channels, resample_buffer))

        self.frames = 0
        self.total_time = 0.0
        # Running mean of the per-frame signal power; becomes a 0-dim tensor
        # after the first normalized frame.
        self.variance = 0.0
        self.pending = self.device.tensor_to(torch.zeros(demucs.in_channels, 0))

    def reset_time_per_frame(self) -> None:
        """Resets the accumulated processing time and the frame counter."""
        self.total_time = 0
        self.frames = 0

    @property
    def time_per_frame(self) -> float:
        """Average processing time per frame, in seconds.

        Raises:
            ZeroDivisionError: If no frame has been processed yet.
        """
        return self.total_time / self.frames

    def flush(self) -> Tensor:
        """Resets the streaming state and returns the remaining output.

        The LSTM and convolution states are cleared, then ``total_length``
        zeros are fed so that the still-buffered input gets processed.

        Returns:
            The output for the pending samples, with shape
            ``(out_channels, pending_length)``.
        """
        self.lstm_state = None
        self.conv_state = None
        pending_length = self.pending.shape[1]
        padding = torch.zeros(self.demucs.in_channels, self.total_length, device=self.pending.device)
        out = self.feed(padding)
        return out[:, :pending_length]

    def feed(self, wav: Tensor) -> Tensor:
        """Feeds new audio to the streamer and returns any available output.

        Args:
            wav: New audio samples with shape ``(in_channels, time)``.

        Returns:
            The newly produced output with shape ``(out_channels, time_out)``.
            ``time_out`` is zero while too few samples are buffered.

        Raises:
            ValueError: If ``wav`` is not two dimensional or has the wrong
                number of channels.
        """
        begin = time.time()

        if wav.dim() != 2:
            raise ValueError("input wav should be two dimensional.")
        in_channels, _ = wav.shape
        if in_channels != self.demucs.in_channels:
            raise ValueError(f"Expected {self.demucs.in_channels} channels, got {in_channels}")

        demucs = self.demucs
        resample = demucs.resample

        self.pending = torch.cat([self.pending, wav], dim=1)
        outs: list[Tensor] = []

        while self.pending.shape[1] >= self.total_length:
            self.frames += 1
            frame = self.pending[:, : self.total_length]
            dry_signal = frame[:, : self.stride]

            if demucs.normalize:
                # Update the running mean of the frame power and normalize.
                mono = frame.mean(0)
                variance = (mono**2).mean()
                self.variance = variance / self.frames + (1 - 1 / self.frames) * self.variance
                frame = frame / (demucs.floor + math.sqrt(self.variance))

            # Prepend past input as left context for the upsampling filter.
            padded_frame = torch.cat([self.resample_in, frame], dim=-1)
            self.resample_in[:] = frame[:, self.stride - self.resample_buffer : self.stride]
            frame = padded_frame

            if resample == 4:
                frame = upsample2(upsample2(frame))
            elif resample == 2:
                frame = upsample2(frame)
            # Remove the resampling context, then the lookahead samples.
            frame = frame[:, resample * self.resample_buffer :]
            frame = frame[:, : resample * self.frame_length]

            out, extra = self._separate_frame(frame)

            # Past output on the left and ``extra`` on the right serve as
            # context for the downsampling filter.
            padded_out = torch.cat([self.resample_out, out, extra], 1)
            self.resample_out[:] = out[:, -self.resample_buffer :]
            if resample == 4:
                out = downsample2(downsample2(padded_out))
            elif resample == 2:
                out = downsample2(padded_out)
            else:
                out = padded_out

            out = out[:, self.resample_buffer // resample :]
            out = out[:, : self.stride]

            if demucs.normalize:
                out *= math.sqrt(self.variance)

            # Mixing in the dry signal only makes sense when the input and
            # output channel layouts are compatible, so skip it entirely when
            # no dry signal is requested.
            if self.dry != 0:
                out = self.dry * dry_signal + (1 - self.dry) * out

            outs.append(out)
            self.pending = self.pending[:, self.stride :]

        self.total_time += time.time() - begin

        if outs:
            out = torch.cat(outs, 1)
        else:
            # No full frame yet: return an empty tensor with the channel
            # count of the model output.
            out = torch.zeros(demucs.out_channels, 0, device=wav.device)
        return out

    def _separate_frame(self, frame: Tensor) -> tuple[Tensor, Tensor]:
        """Runs the model on one frame, reusing state from previous frames.

        Args:
            frame: The (already upsampled) frame, with shape
                ``(in_channels, time)``.

        Returns:
            A pair ``(out, extra)``: the output samples for this frame, and
            extra samples to their right that are only used as padding for
            the online resampling.
        """
        demucs = self.demucs
        skips: list[Tensor] = []
        next_state: list[Tensor] = []
        stride = self.stride * demucs.resample
        x = frame[None]

        for idx, encode in enumerate(demucs.encoders):
            stride //= demucs.stride
            length = x.shape[2]
            if idx == demucs.depth - 1:
                # The last encoder is evaluated in full, with the fast path
                # for both convolutions.
                x = fast_conv(encode.conv_a, x)
                x = encode.act(x)
                x = fast_conv(encode.conv_b, x)
                x = encode.glu(x)
            else:
                prev: Tensor | None = None
                if self.conv_state is not None:
                    # Reuse the previous frame's output for this layer and
                    # only compute the time steps that are still missing.
                    prev = self.conv_state.pop(0)
                    prev = prev[..., stride:]
                    tgt = (length - demucs.kernel_size) // demucs.stride + 1
                    missing = tgt - prev.shape[-1]
                    offset = length - demucs.kernel_size - demucs.stride * (missing - 1)
                    x = x[..., offset:]
                x = encode.act(encode.conv_a(x))
                x = fast_conv(encode.conv_b, x)
                x = encode.glu(x)
                if prev is not None:
                    x = torch.cat([prev, x], -1)
                next_state.append(x)
            skips.append(x)

        # (batch, channels, time) -> (time, batch, channels) for the LSTM.
        x = x.permute(2, 0, 1)
        x, self.lstm_state = demucs.lstm.forward(x, self.lstm_state)
        x = x.permute(1, 2, 0)

        # In the following, x contains only correct samples, i.e. the one for
        # which each time position is covered by two window of the upper layer.
        # extra contains extra samples to the right, and is used only as a
        # better padding for the online resampling.
        extra: Tensor | None = None
        for idx, decode in enumerate(demucs.decoders):
            skip = skips.pop(-1)
            x += skip[..., : x.shape[-1]]
            x = fast_conv(decode.conv_a, x)
            x = decode.glu(x)

            if extra is not None:
                skip = skip[..., x.shape[-1] :]
                extra += skip[..., : extra.shape[-1]]
                extra = decode.conv_b(decode.glu(decode.conv_a(extra)))
            x = decode.conv_b(x)

            # Save the last ``stride`` samples without the bias; they are
            # added to the start of the next frame's output below.
            next_state.append(x[..., -demucs.stride :] - cast(Tensor, decode.conv_b.bias).view(-1, 1))
            if extra is None:
                extra = x[..., -demucs.stride :]
            else:
                extra[..., : demucs.stride] += next_state[-1]
            x = x[..., : -demucs.stride]

            if self.conv_state is not None:
                prev_tail = self.conv_state.pop(0)
                x[..., : demucs.stride] += prev_tail
            if idx != demucs.depth - 1:
                x = decode.act(x)
                extra = decode.act(extra)

        self.conv_state = next_state
        assert extra is not None, "Extra is None!"
        return x[0], extra[0]