"""Defines a simple API for using Meta's pretrained Hubert model.
.. highlight:: python
.. code-block:: python
from pretrained.hubert import pretrained_hubert
model = pretrained_hubert("base")
predictor = model.predictor()
# Gets HuBERT embeddings for a waveform.
predictor.predict(torch.randn(1, 16_000), sample_rate=16_000, output_layer=None)
# Gets HuBERT embeddings for a long waveform, in batches.
predictor.predict_in_chunks(torch.randn(1, 160_000), sample_rate=16_000, chunk_size=16_000, output_layer=None)
In order to get HuBERT clusters, you can use:
.. highlight:: python
.. code-block:: python
from pretrained.hubert import pretrained_hubert_with_kmeans
model, kmeans = pretrained_hubert_with_kmeans("base-l7-c100")
predictor = model.predictor(kmeans)
# Get the HuBERT tokens for a waveform.
predictor.predict(torch.randn(1, 16_000), sample_rate=16_000)
The choices for the model key are:
- ``"base"`` - 12 layers, 768 hidden size, 12 attention heads.
- ``"large"`` - 24 layers, 1024 hidden size, 16 attention heads.
- ``"extra_large"`` - 48 layers, 1280 hidden size, 16 attention heads.
"""
import argparse
from pathlib import Path
from typing import Literal, cast, get_args
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio.sox_effects as ta_sox
from ml.models.activations import ActivationType, get_activation
from ml.models.kmeans import KMeans
from ml.models.norms import ConvLayerNorm
from ml.utils.audio import get_audio_props, read_audio
from ml.utils.checkpoint import ensure_downloaded
from ml.utils.device.auto import detect_device
from ml.utils.device.base import base_device
from ml.utils.logging import configure_logging
from ml.utils.timer import Timer
from torch import Tensor, nn
from torch.nn.utils.parametrizations import weight_norm
# Keys of the HuBERT backbones that have pretrained weights available.
PretrainedHubertSize = Literal[
    "base",
    "large",
    "extra_large",
]

# Keys of the pretrained K-means codebooks, following the pattern
# ``<backbone>-l<layer>-c<number of clusters>``.
#
# These clusters were generated by sweeping over a number of different
# hyperparameter configurations and keeping the one with the highest
# cross-entropy between the clusters and the speaker IDs (meaning that the
# clusters should be more speaker-independent).
PretrainedHubertKmeansSize = Literal[
    "base-l7-c100",
    "base-l7-c200",
    "base-l7-c500",
    "base-l7-c1000",
    "base-l8-c100",
    "base-l8-c200",
    "base-l8-c500",
    "base-l8-c1000",
    "base-l10-c100",
    "base-l10-c200",
]

# Default per-layer settings for the convolutional waveform feature extractor;
# the three tuples must have the same length (one entry per conv layer).
DEFAULT_CONV_DIM: tuple[int, ...] = (512,) * 7
DEFAULT_CONV_STRIDE: tuple[int, ...] = (5, 2, 2, 2, 2, 2, 2)
DEFAULT_CONV_KERNEL: tuple[int, ...] = (10, 3, 3, 3, 3, 2, 2)
def cast_pretrained_hubert_size(s: str) -> PretrainedHubertSize:
    """Validates a string as a pretrained HuBERT size key.

    Args:
        s: The candidate key.

    Returns:
        The same string, typed as a ``PretrainedHubertSize``.

    Raises:
        KeyError: If ``s`` is not one of the supported size keys.
    """
    choices = get_args(PretrainedHubertSize)
    if s in choices:
        return cast(PretrainedHubertSize, s)
    raise KeyError(f"Invalid HuBERT key: {s} Expected one of: {choices}")
def cast_pretrained_hubert_kmeans_size(s: str) -> PretrainedHubertKmeansSize:
    """Validates a string as a pretrained HuBERT K-means key.

    Args:
        s: The candidate key.

    Returns:
        The same string, typed as a ``PretrainedHubertKmeansSize``.

    Raises:
        KeyError: If ``s`` is not one of the supported K-means keys.
    """
    choices = get_args(PretrainedHubertKmeansSize)
    if s in choices:
        return cast(PretrainedHubertKmeansSize, s)
    raise KeyError(f"Invalid HuBERT key: {s} Expected one of: {choices}")
[docs]def normalize_output_layer(output_layer: int | float | None, num_layers: int) -> int | None:
if output_layer is not None:
if isinstance(output_layer, float):
output_layer = round(output_layer * num_layers)
if output_layer < 0:
output_layer += num_layers
if not (0 <= output_layer < num_layers):
raise ValueError(f"output_layer={output_layer} is outside the range of available layers")
return output_layer
class HubertSamePadLayer(nn.Module):
    """Drops the trailing frame when the positional conv kernel is even.

    Args:
        num_conv_pos_embeddings: Kernel width of the positional convolution
            whose output this layer trims. One trailing element is removed
            from the last dimension when this is even; nothing is removed
            when it is odd.
    """

    def __init__(self, num_conv_pos_embeddings: int = 128) -> None:
        super().__init__()

        # 1 for an even kernel width, 0 for an odd one.
        self.num_pad_remove = int(num_conv_pos_embeddings % 2 == 0)

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Removes ``num_pad_remove`` trailing elements from the last axis.

        Args:
            hidden_states: A tensor with at least three dimensions.

        Returns:
            The input, trimmed along its third dimension if required.
        """
        if self.num_pad_remove <= 0:
            return hidden_states
        return hidden_states[:, :, : -self.num_pad_remove]
class PositionalConvEmbedding(nn.Module):
    """Convolutional positional embedding applied over the time axis.

    A weight-normalized grouped ``Conv1d`` is followed by a trim of the
    surplus frame (for even kernel widths) and an activation.

    Args:
        hidden_size: Number of input and output channels of the convolution.
        num_conv_pos_embeddings: Kernel width of the convolution.
        num_conv_pos_embedding_groups: Number of convolution groups.
        feat_extract_activation: Activation applied after the convolution.
    """

    def __init__(
        self,
        hidden_size: int,
        num_conv_pos_embeddings: int = 128,
        num_conv_pos_embedding_groups: int = 16,
        feat_extract_activation: ActivationType = "gelu",
    ) -> None:
        super().__init__()

        self.conv = weight_norm(
            nn.Conv1d(
                hidden_size,
                hidden_size,
                kernel_size=num_conv_pos_embeddings,
                padding=num_conv_pos_embeddings // 2,
                groups=num_conv_pos_embedding_groups,
            ),
            dim=2,
        )
        self.padding = HubertSamePadLayer(num_conv_pos_embeddings)
        self.activation = get_activation(feat_extract_activation)

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Computes the positional embedding for a sequence of features.

        Args:
            hidden_states: Features with the channel dimension last; the
                tensor is transposed to channels-first for the convolution.

        Returns:
            The positional embedding, transposed back to channels-last.
        """
        channels_first = hidden_states.transpose(1, 2)
        embedded = self.activation(self.padding(self.conv(channels_first)))
        return embedded.transpose(1, 2)
class Attention(nn.Module):
    """Multi-head self-attention with separate Q/K/V and output projections.

    Args:
        embed_dim: Size of the input and output features.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        bias: Whether the four linear projections use a bias term.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim: int, num_heads: int, bias: bool = True) -> None:
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(f"`embed_dim` must be divisible by num_heads (got {self.embed_dim=} and {num_heads=}).")

        # The attribute names are the state-dict keys, so they must not change.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: Tensor, seq_len: int, bsz: int) -> Tensor:
        """Splits the feature axis into heads: (B, T, E) -> (B, H, T, E / H)."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(self, hidden_states: Tensor, causal: bool = False) -> Tensor:
        """Runs the HuBERT attention layer.

        Args:
            hidden_states: Input states for the attention layer, with shape
                (batch, time, embed_dim).
            causal: If set, use causal attention.

        Returns:
            The attention outputs, with the same shape as the input.
        """
        batch_size, num_frames, _ = hidden_states.size()

        queries = self._shape(self.q_proj(hidden_states), num_frames, batch_size)
        # -1 lets ``view`` infer the key / value sequence length.
        keys = self._shape(self.k_proj(hidden_states), -1, batch_size)
        values = self._shape(self.v_proj(hidden_states), -1, batch_size)

        attended = F.scaled_dot_product_attention(queries, keys, values, is_causal=causal)

        # (B, H, T, E / H) -> (B, T, E)
        merged = attended.transpose(1, 2).flatten(2)
        return self.out_proj(merged)
class FeedForward(nn.Module):
    """Two-layer position-wise feed-forward network with dropout.

    Args:
        hidden_size: Size of the input and output features.
        intermediate_size: Size of the hidden (expanded) features.
        hidden_act: Activation applied after the first linear layer.
        hidden_dropout: Dropout probability applied to the output.
        activation_dropout: Dropout probability applied after the activation.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: ActivationType = "gelu",
        hidden_dropout: float = 0.1,
        activation_dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.intermediate_dropout = nn.Dropout(activation_dropout)

        self.intermediate_dense = nn.Linear(hidden_size, intermediate_size)
        self.intermediate_act_fn = get_activation(hidden_act)

        self.output_dense = nn.Linear(intermediate_size, hidden_size)
        self.output_dropout = nn.Dropout(hidden_dropout)

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Applies expand -> activation -> dropout -> project -> dropout.

        Args:
            hidden_states: Features whose last dimension is ``hidden_size``.

        Returns:
            Features with the same shape as the input.
        """
        expanded = self.intermediate_act_fn(self.intermediate_dense(hidden_states))
        expanded = self.intermediate_dropout(expanded)
        return self.output_dropout(self.output_dense(expanded))
class HubertEncoderLayer(nn.Module):
    """Transformer encoder layer that normalizes after each residual sum.

    Args:
        hidden_size: Size of the layer's features.
        intermediate_size: Hidden size of the feed-forward network.
        num_attention_heads: Number of attention heads.
        hidden_act: Activation used in the feed-forward network.
        layer_norm_eps: Epsilon used by both layer norms.
        hidden_dropout: Dropout probability for the attention output and the
            feed-forward output.
        activation_dropout: Dropout probability inside the feed-forward
            network.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_attention_heads: int,
        hidden_act: ActivationType = "gelu",
        layer_norm_eps: float = 1e-5,
        hidden_dropout: float = 0.1,
        activation_dropout: float = 0.1,
    ) -> None:
        super().__init__()

        # Attention sub-block.
        self.attention = Attention(embed_dim=hidden_size, num_heads=num_attention_heads)
        self.dropout = nn.Dropout(hidden_dropout)
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

        # Feed-forward sub-block.
        self.feed_forward = FeedForward(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            hidden_act=hidden_act,
            hidden_dropout=hidden_dropout,
            activation_dropout=activation_dropout,
        )
        self.final_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: Tensor, causal: bool = False) -> Tensor:
        """Runs attention and feed-forward, each followed by a layer norm.

        Args:
            hidden_states: Input features for the layer.
            causal: If set, use causal attention.

        Returns:
            The layer output, with the same shape as the input.
        """
        attn_out = self.dropout(self.attention.forward(hidden_states, causal=causal))
        normed = self.layer_norm(hidden_states + attn_out)
        return self.final_layer_norm(normed + self.feed_forward(normed))
class HubertEncoder(nn.Module):
    """Transformer encoder that normalizes its input before the layer stack.

    Args:
        hidden_size: Size of the encoder's features.
        intermediate_size: Hidden size of each feed-forward network.
        num_attention_heads: Number of attention heads per layer.
        num_hidden_layers: Number of encoder layers.
        num_conv_pos_embeddings: Kernel width of the positional convolution.
        num_conv_pos_embedding_groups: Groups of the positional convolution.
        feat_extract_activation: Activation of the positional convolution.
        hidden_act: Activation used in the feed-forward networks.
        layer_norm_eps: Epsilon used by all layer norms.
        hidden_dropout: Dropout probability for hidden states.
        activation_dropout: Dropout probability inside the feed-forward
            networks.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_attention_heads: int,
        num_hidden_layers: int,
        num_conv_pos_embeddings: int = 128,
        num_conv_pos_embedding_groups: int = 16,
        feat_extract_activation: ActivationType = "gelu",
        hidden_act: ActivationType = "gelu",
        layer_norm_eps: float = 1e-5,
        hidden_dropout: float = 0.1,
        activation_dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.pos_conv_embed = PositionalConvEmbedding(
            hidden_size=hidden_size,
            num_conv_pos_embeddings=num_conv_pos_embeddings,
            num_conv_pos_embedding_groups=num_conv_pos_embedding_groups,
            feat_extract_activation=feat_extract_activation,
        )
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.dropout = nn.Dropout(hidden_dropout)

        stack = nn.ModuleList(
            HubertEncoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                hidden_act=hidden_act,
                layer_norm_eps=layer_norm_eps,
                hidden_dropout=hidden_dropout,
                activation_dropout=activation_dropout,
            )
            for _ in range(num_hidden_layers)
        )
        # The cast only affects static typing; at runtime this is still the
        # ``nn.ModuleList`` (so the layers are registered as submodules).
        self.layers = cast(list[HubertEncoderLayer], stack)

        self.gradient_checkpointing = False

    def forward(self, hidden_states: Tensor, causal: bool = False, output_layer: int | float | None = None) -> Tensor:
        """Encodes projected features, optionally stopping at an early layer.

        Args:
            hidden_states: Projected features, with the channel axis last.
            causal: If set, use causal attention in every layer.
            output_layer: Layer specifier resolved by
                ``normalize_output_layer``. If not ``None``, the output of
                the layer at that (zero-based) index is returned.

        Returns:
            The hidden states after the last layer that was run.
        """
        hidden_states = hidden_states + self.pos_conv_embed.forward(hidden_states)
        hidden_states = self.dropout(self.layer_norm(hidden_states))

        stop_index = normalize_output_layer(output_layer, len(self.layers))
        for index, encoder_layer in enumerate(self.layers):
            hidden_states = encoder_layer.forward(hidden_states, causal=causal)
            if stop_index is not None and index == stop_index:
                break

        return hidden_states
class GroupNormConvLayer(nn.Module):
    """1D convolution followed by group normalization and an activation.

    The group norm uses one group per output channel.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride.
        kernel: Convolution kernel width.
        bias: Whether the convolution has a bias term.
        feat_extract_activation: Activation applied after normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        kernel: int,
        bias: bool = True,
        feat_extract_activation: ActivationType = "gelu",
    ) -> None:
        super().__init__()

        self.in_conv_dim = in_channels
        self.out_conv_dim = out_channels

        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel, stride=stride, bias=bias)
        self.layer_norm = nn.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True)
        self.activation = get_activation(feat_extract_activation)

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Applies conv -> group norm -> activation to channels-first input."""
        return self.activation(self.layer_norm(self.conv(hidden_states)))
class NoLayerNormConvLayer(nn.Module):
    """1D convolution followed directly by an activation (no normalization).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride.
        kernel: Convolution kernel width.
        bias: Whether the convolution has a bias term.
        feat_extract_activation: Activation applied after the convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        kernel: int,
        bias: bool = True,
        feat_extract_activation: ActivationType = "gelu",
    ) -> None:
        super().__init__()

        self.in_conv_dim = in_channels
        self.out_conv_dim = out_channels

        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel, stride=stride, bias=bias)
        self.activation = get_activation(feat_extract_activation)

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Applies conv -> activation to channels-first input."""
        return self.activation(self.conv(hidden_states))
class LayerNormConvLayer(nn.Module):
    """1D convolution followed by ``ConvLayerNorm`` and an activation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride.
        kernel: Convolution kernel width.
        bias: Whether the convolution has a bias term.
        feat_extract_activation: Activation applied after normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        kernel: int,
        bias: bool = True,
        feat_extract_activation: ActivationType = "gelu",
    ) -> None:
        super().__init__()

        self.in_conv_dim = in_channels
        self.out_conv_dim = out_channels

        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel, stride=stride, bias=bias)
        # NOTE(review): presumably normalizes over the channel axis of a
        # channels-first tensor - confirm against ``ConvLayerNorm``.
        self.layer_norm = ConvLayerNorm(out_channels, dims=1, elementwise_affine=True)
        self.activation = get_activation(feat_extract_activation)

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Applies conv -> layer norm -> activation to channels-first input."""
        return self.activation(self.layer_norm(self.conv(hidden_states)))
class HubertFeatureEncoder(nn.Module):
    """Stack of strided 1D convolutions that turns a waveform into features.

    Args:
        conv_dim: Output channels of each convolution layer.
        conv_stride: Stride of each convolution layer.
        conv_kernel: Kernel width of each convolution layer.
        conv_bias: Whether the convolutions have a bias term.
        feat_extract_norm: ``"group"`` uses group norm on the first layer and
            no normalization afterwards; ``"layer"`` uses a layer norm on
            every layer.
        feat_extract_activation: Activation used by every layer.

    Raises:
        ValueError: If ``feat_extract_norm`` is not ``"group"`` or
            ``"layer"``.
    """

    def __init__(
        self,
        conv_dim: tuple[int, ...] = DEFAULT_CONV_DIM,
        conv_stride: tuple[int, ...] = DEFAULT_CONV_STRIDE,
        conv_kernel: tuple[int, ...] = DEFAULT_CONV_KERNEL,
        conv_bias: bool = True,
        feat_extract_norm: Literal["group", "layer"] = "layer",
        feat_extract_activation: ActivationType = "gelu",
    ) -> None:
        super().__init__()

        assert len(conv_dim) == len(conv_stride) == len(conv_kernel)

        stack: list[nn.Module] = []

        if feat_extract_norm == "group":
            # Only the first layer (fed by the single-channel waveform) is
            # normalized in this mode.
            stack.append(
                GroupNormConvLayer(
                    in_channels=1,
                    out_channels=conv_dim[0],
                    stride=conv_stride[0],
                    kernel=conv_kernel[0],
                    bias=conv_bias,
                    feat_extract_activation=feat_extract_activation,
                )
            )
            for in_dim, out_dim, stride, kernel in zip(conv_dim, conv_dim[1:], conv_stride[1:], conv_kernel[1:]):
                stack.append(
                    NoLayerNormConvLayer(
                        in_channels=in_dim,
                        out_channels=out_dim,
                        stride=stride,
                        kernel=kernel,
                        bias=conv_bias,
                        feat_extract_activation=feat_extract_activation,
                    )
                )

        elif feat_extract_norm == "layer":
            # Each layer consumes the previous layer's channels; the first
            # one consumes the single-channel waveform.
            in_dims = (1, *conv_dim[:-1])
            for in_dim, out_dim, stride, kernel in zip(in_dims, conv_dim, conv_stride, conv_kernel):
                stack.append(
                    LayerNormConvLayer(
                        in_channels=in_dim,
                        out_channels=out_dim,
                        stride=stride,
                        kernel=kernel,
                        bias=conv_bias,
                        feat_extract_activation=feat_extract_activation,
                    )
                )

        else:
            raise ValueError(f"{feat_extract_norm=}, but has to be one of ['group', 'layer']")

        self.conv_layers = nn.ModuleList(stack)

    def _freeze_parameters(self) -> None:
        """Disables gradients for every parameter of the feature encoder."""
        self.requires_grad_(False)

    def forward(self, input_values: Tensor) -> Tensor:
        """Extracts convolutional features from a batch of waveforms.

        Args:
            input_values: Waveforms without a channel axis; a singleton
                channel axis is inserted at dimension 1.

        Returns:
            The channels-first output of the last convolution layer.
        """
        features = input_values.unsqueeze(1)
        for conv_layer in self.conv_layers:
            features = conv_layer(features)
        return features
class HubertFeatureProjection(nn.Module):
    """Projects convolutional features to the encoder's hidden size.

    Args:
        input_size: Size of the incoming features.
        hidden_size: Size of the projected features.
        layer_norm_eps: Epsilon of the optional input layer norm.
        feat_proj_dropout: Dropout probability applied after the projection.
        feat_proj_layer_norm: Whether to layer-normalize the features before
            projecting them. When disabled, no ``layer_norm`` submodule is
            created.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        layer_norm_eps: float = 1e-5,
        feat_proj_dropout: float = 0.0,
        feat_proj_layer_norm: bool = True,
    ) -> None:
        super().__init__()

        self.feat_proj_layer_norm = feat_proj_layer_norm
        if feat_proj_layer_norm:
            self.layer_norm = nn.LayerNorm(input_size, eps=layer_norm_eps)

        self.projection = nn.Linear(input_size, hidden_size)
        self.dropout = nn.Dropout(feat_proj_dropout)

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Applies (optional) layer norm -> linear projection -> dropout.

        Args:
            hidden_states: Features whose last dimension is ``input_size``.

        Returns:
            Features whose last dimension is ``hidden_size``.
        """
        features = self.layer_norm(hidden_states) if self.feat_proj_layer_norm else hidden_states
        return self.dropout(self.projection(features))
class HubertEncoderLayerStableLayerNorm(nn.Module):
    """Transformer encoder layer that normalizes before each sub-block.

    Args:
        hidden_size: Size of the layer's features.
        intermediate_size: Hidden size of the feed-forward network.
        num_attention_heads: Number of attention heads.
        layer_norm_eps: Epsilon used by both layer norms.
        hidden_act: Activation used in the feed-forward network.
        hidden_dropout: Dropout probability for the attention output and the
            feed-forward output.
        activation_dropout: Dropout probability inside the feed-forward
            network.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_attention_heads: int,
        layer_norm_eps: float = 1e-5,
        hidden_act: ActivationType = "gelu",
        hidden_dropout: float = 0.1,
        activation_dropout: float = 0.1,
    ) -> None:
        super().__init__()

        # Attention sub-block.
        self.attention = Attention(embed_dim=hidden_size, num_heads=num_attention_heads)
        self.dropout = nn.Dropout(hidden_dropout)
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

        # Feed-forward sub-block.
        self.feed_forward = FeedForward(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            hidden_act=hidden_act,
            hidden_dropout=hidden_dropout,
            activation_dropout=activation_dropout,
        )
        self.final_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: Tensor, causal: bool = False) -> Tensor:
        """Runs pre-normalized attention and feed-forward residual blocks.

        Args:
            hidden_states: Input features for the layer.
            causal: If set, use causal attention.

        Returns:
            The layer output, with the same shape as the input.
        """
        attn_out = self.attention.forward(self.layer_norm(hidden_states), causal=causal)
        hidden_states = hidden_states + self.dropout(attn_out)
        return hidden_states + self.feed_forward.forward(self.final_layer_norm(hidden_states))
class HubertEncoderStableLayerNorm(nn.Module):
    """Transformer encoder that applies its layer norm after the layer stack.

    Args:
        hidden_size: Size of the encoder's features.
        intermediate_size: Hidden size of each feed-forward network.
        num_attention_heads: Number of attention heads per layer.
        num_hidden_layers: Number of encoder layers.
        num_conv_pos_embeddings: Kernel width of the positional convolution.
        num_conv_pos_embedding_groups: Groups of the positional convolution.
        hidden_act: Activation used in the feed-forward networks.
        feat_extract_activation: Activation of the positional convolution.
        layer_norm_eps: Epsilon used by all layer norms.
        hidden_dropout: Dropout probability for hidden states.
        activation_dropout: Dropout probability inside the feed-forward
            networks.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_attention_heads: int,
        num_hidden_layers: int,
        num_conv_pos_embeddings: int = 128,
        num_conv_pos_embedding_groups: int = 16,
        hidden_act: ActivationType = "gelu",
        feat_extract_activation: ActivationType = "gelu",
        layer_norm_eps: float = 1e-5,
        hidden_dropout: float = 0.1,
        activation_dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.pos_conv_embed = PositionalConvEmbedding(
            hidden_size=hidden_size,
            num_conv_pos_embeddings=num_conv_pos_embeddings,
            num_conv_pos_embedding_groups=num_conv_pos_embedding_groups,
            feat_extract_activation=feat_extract_activation,
        )
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.dropout = nn.Dropout(hidden_dropout)

        stack = nn.ModuleList(
            HubertEncoderLayerStableLayerNorm(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                layer_norm_eps=layer_norm_eps,
                hidden_act=hidden_act,
                hidden_dropout=hidden_dropout,
                activation_dropout=activation_dropout,
            )
            for _ in range(num_hidden_layers)
        )
        # The cast only affects static typing; at runtime this is still the
        # ``nn.ModuleList`` (so the layers are registered as submodules).
        self.layers = cast(list[HubertEncoderLayerStableLayerNorm], stack)

    def forward(self, hidden_states: Tensor, causal: bool = False, output_layer: int | float | None = None) -> Tensor:
        """Encodes projected features, optionally stopping at an early layer.

        Args:
            hidden_states: Projected features, with the channel axis last.
            causal: If set, use causal attention in every layer.
            output_layer: Layer specifier resolved by
                ``normalize_output_layer``. If not ``None``, the stack stops
                after the layer at that (zero-based) index.

        Returns:
            The layer-normalized hidden states of the last layer that was
            run.
        """
        hidden_states = hidden_states + self.pos_conv_embed(hidden_states)
        hidden_states = self.dropout(hidden_states)

        stop_index = normalize_output_layer(output_layer, len(self.layers))
        for index, encoder_layer in enumerate(self.layers):
            hidden_states = encoder_layer.forward(hidden_states, causal=causal)
            if stop_index is not None and index == stop_index:
                break

        # The final norm is applied even when the stack stops early.
        return self.layer_norm(hidden_states)
class Hubert(nn.Module):
    """HuBERT model: conv feature extractor, projection and Transformer encoder.

    Args:
        hidden_size: Size of the encoder's features.
        intermediate_size: Hidden size of each feed-forward network.
        num_hidden_layers: Number of encoder layers.
        num_attention_heads: Number of attention heads per layer.
        conv_dim: Output channels of each feature-extractor layer.
        conv_stride: Stride of each feature-extractor layer.
        conv_kernel: Kernel width of each feature-extractor layer.
        conv_bias: Whether the feature-extractor convolutions have a bias.
        num_conv_pos_embeddings: Kernel width of the positional convolution.
        num_conv_pos_embedding_groups: Groups of the positional convolution.
        do_stable_layer_norm: If set, use ``HubertEncoderStableLayerNorm``;
            otherwise use ``HubertEncoder``.
        pre_normalize: If set, layer-normalize each input waveform.
        feat_extract_norm: Normalization mode of the feature extractor.
        feat_extract_activation: Activation of the feature extractor and the
            positional convolution.
        feat_proj_layer_norm: Whether the feature projection normalizes its
            input.
        hidden_act: Activation used in the feed-forward networks.
        layer_norm_eps: Epsilon used by all layer norms.
        hidden_dropout: Dropout probability for hidden states.
        activation_dropout: Dropout probability inside the feed-forward
            networks.
        feat_proj_dropout: Dropout probability after the feature projection.
    """

    __constants__ = ["conv_kernel", "conv_stride", "pre_normalize", "hidden_size"]

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_hidden_layers: int,
        num_attention_heads: int,
        conv_dim: tuple[int, ...] = DEFAULT_CONV_DIM,
        conv_stride: tuple[int, ...] = DEFAULT_CONV_STRIDE,
        conv_kernel: tuple[int, ...] = DEFAULT_CONV_KERNEL,
        conv_bias: bool = True,
        num_conv_pos_embeddings: int = 128,
        num_conv_pos_embedding_groups: int = 16,
        do_stable_layer_norm: bool = True,
        pre_normalize: bool = True,
        feat_extract_norm: Literal["group", "layer"] = "layer",
        feat_extract_activation: ActivationType = "gelu",
        feat_proj_layer_norm: bool = True,
        hidden_act: ActivationType = "gelu",
        layer_norm_eps: float = 1e-5,
        hidden_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        feat_proj_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.conv_kernel = conv_kernel
        self.conv_stride = conv_stride
        self.pre_normalize = pre_normalize
        self.hidden_size = hidden_size

        self.feature_extractor = HubertFeatureEncoder(
            conv_dim=conv_dim,
            conv_stride=conv_stride,
            conv_kernel=conv_kernel,
            conv_bias=conv_bias,
            feat_extract_norm=feat_extract_norm,
            feat_extract_activation=feat_extract_activation,
        )

        self.feature_projection = HubertFeatureProjection(
            input_size=conv_dim[-1],
            hidden_size=hidden_size,
            layer_norm_eps=layer_norm_eps,
            feat_proj_dropout=feat_proj_dropout,
            feat_proj_layer_norm=feat_proj_layer_norm,
        )

        # Both encoder variants take the same keyword arguments, so only the
        # class differs between the two configurations.
        self.encoder: HubertEncoderStableLayerNorm | HubertEncoder
        encoder_type = HubertEncoderStableLayerNorm if do_stable_layer_norm else HubertEncoder
        self.encoder = encoder_type(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_hidden_layers=num_hidden_layers,
            num_conv_pos_embeddings=num_conv_pos_embeddings,
            num_conv_pos_embedding_groups=num_conv_pos_embedding_groups,
            hidden_act=hidden_act,
            feat_extract_activation=feat_extract_activation,
            layer_norm_eps=layer_norm_eps,
            hidden_dropout=hidden_dropout,
            activation_dropout=activation_dropout,
        )

    def set_output_layer(self, output_layer: int | float) -> None:
        """Permanently truncates the encoder's layer stack.

        After the specifier is resolved to an index ``k``, every layer from
        index ``k`` onwards is deleted, leaving the first ``k`` layers.

        NOTE(review): ``forward(output_layer=k)`` instead runs layers
        ``0..k`` inclusive (``k + 1`` layers), so the two ways of selecting a
        layer differ by one - confirm which convention is intended.

        Args:
            output_layer: Layer specifier resolved by
                ``normalize_output_layer``.
        """
        first_removed = normalize_output_layer(output_layer, len(self.encoder.layers))
        del self.encoder.layers[first_removed:]

    def forward(
        self,
        input_values: Tensor,
        sample_rate: int,
        causal: bool = False,
        output_layer: int | float | None = None,
    ) -> Tensor:
        """Computes HuBERT hidden states for a batch of waveforms.

        Args:
            input_values: The waveforms, with shape (B, T).
            sample_rate: The waveforms' sampling rate; only used to verify
                that it is 16 kHz.
            causal: If set, use causal attention.
            output_layer: Optional layer specifier forwarded to the encoder.

        Returns:
            The encoder's hidden states.

        Raises:
            RuntimeError: If ``sample_rate`` is not 16000.
        """
        if sample_rate != 16_000:
            raise RuntimeError("HuBERT only supports 16 kHz as input sampling rate")

        if self.pre_normalize:
            # Normalizes every waveform in the batch independently.
            input_values = F.layer_norm(input_values, input_values.shape[1:])

        # The extractor is channels-first; the projection and encoder expect
        # the channel axis last.
        features = self.feature_extractor(input_values).transpose(1, 2)
        projected = self.feature_projection(features)
        return self.encoder.forward(projected, causal=causal, output_layer=output_layer)

    def predictor(
        self,
        kmeans: KMeans | None = None,
        *,
        device: base_device | None = None,
    ) -> "HubertPredictor":
        """Wraps this model in a ``HubertPredictor``.

        Args:
            kmeans: Optional K-means model used to quantize the features.
            device: Optional device for the predictor.

        Returns:
            A predictor built from this model, ``kmeans`` and ``device``.
        """
        return HubertPredictor(self, kmeans, device=device)
class HubertPredictor:
    """Provides an API for doing predictions with a HuBERT model.

    Note that this class is not an `nn.Module`, so you can use it in your
    module without worrying about storing all the weights by accident.
    """

    def __init__(
        self,
        hubert_model: Hubert,
        kmeans: KMeans | None = None,
        *,
        device: base_device | None = None,
    ) -> None:
        """Puts the models in eval mode and moves them to the device.

        Args:
            hubert_model: The HuBERT model to use for predictions.
            kmeans: The kmeans model to use for quantization. If `None`, don't
                quantize.
            device: The device to use for predictions. If `None`, will use the
                device returned by detect_device().
        """
        super().__init__()

        self.device = detect_device() if device is None else device
        self.model = hubert_model.eval()
        self.kmeans = kmeans.eval() if kmeans is not None else None
        self.device.module_to(self.model)
        if self.kmeans is not None:
            self.device.module_to(self.kmeans)
        self.sample_rate = 16_000  # True for all HuBERT models.

    def predict(
        self,
        waveform: np.ndarray | Tensor,
        sample_rate: int,
        output_layer: int | float | None = None,
        causal: bool = False,
    ) -> Tensor:
        """Gets the hidden states for the given waveform.

        Args:
            waveform: The waveform to get hidden states for, with shape (B, T)
            sample_rate: The waveform's sampling rate; this is only used to
                verify that it is 16 kHz, since it is easy for downstream
                applications to forget.
            output_layer: The layer to get hidden states from. If `None`, will
                return the hidden states from the last layer. If an `int`, will
                return the hidden states from that layer. If a `float`, will
                return the hidden states from the layer at that percentage of
                the model. For example, `0.5` will return the hidden states
                from the middle layer. Negative values will wrap around.
            causal: If set, use a causal attention mask.

        Returns:
            The hidden states for the given waveform, with shape (B, T, D), or
            the output of the kmeans model applied to them if one was given.
        """
        waveform = self.device.tensor_to(waveform)
        features = self.model.forward(waveform, sample_rate, causal=causal, output_layer=output_layer)
        if self.kmeans is not None:
            features = self.kmeans.forward(features)
        return features

    def predict_in_chunks(
        self,
        waveform: Tensor | np.ndarray,
        sample_rate: int,
        chunk_size: int = 16_000 * 10,
        output_layer: int | float | None = None,
        causal: bool = False,
    ) -> Tensor:
        """Gets the hidden states for the given waveform, in chunks.

        This is useful for processing very long waveforms, as the model only
        ever sees one chunk at a time. The waveform itself is still moved to
        the device in one piece; the per-chunk outputs are moved to the CPU
        as they are produced.

        Args:
            waveform: The waveform to get hidden states for, with shape (B, T)
            sample_rate: The waveform's sampling rate; this is only used to
                verify that it is 16 kHz, since it is easy for downstream
                applications to forget.
            chunk_size: The size of each chunk to process, in waveform
                samples.
            output_layer: The layer to get hidden states from. If `None`, will
                return the hidden states from the last layer. If an `int`, will
                return the hidden states from that layer. If a `float`, will
                return the hidden states from the layer at that percentage of
                the model. For example, `0.5` will return the hidden states
                from the middle layer. Negative values will wrap around.
            causal: If set, use a causal attention mask.

        Returns:
            The per-chunk outputs concatenated along the time axis, on the
            CPU. A leading batch dimension of size one is squeezed away.
        """
        with torch.inference_mode(), self.device.autocast_context():
            x = self.device.tensor_to(waveform)  # Loads entire waveform into device memory.

            if self.model.pre_normalize:
                # Normalize each waveform with its own statistics, matching
                # `Hubert.forward` and `predict_file`; using the full shape
                # would mix statistics across the batch.
                x = F.layer_norm(x, x.shape[1:])

            feat = []
            for start in range(0, x.size(1), chunk_size):
                x_chunk = x[:, start : start + chunk_size]
                feat_chunk = self.model.forward(x_chunk, sample_rate, causal=causal, output_layer=output_layer)
                if self.kmeans is not None:
                    feat_chunk = self.kmeans.forward(feat_chunk)
                feat.append(feat_chunk.cpu())

        return torch.cat(feat, 1).squeeze(0)

    def predict_file(
        self,
        path: str | Path,
        chunk_length_sec: float = 10.0,
        output_layer: int | float | None = None,
        causal: bool = False,
    ) -> Tensor:
        """Gets the hidden states for the given audio file, in chunks.

        Args:
            path: The path to the audio file to process.
            chunk_length_sec: The length of each chunk to process, in seconds.
            output_layer: The layer to get hidden states from. If `None`, will
                return the hidden states from the last layer. If an `int`, will
                return the hidden states from that layer. If a `float`, will
                return the hidden states from the layer at that percentage of
                the model. For example, `0.5` will return the hidden states
                from the middle layer. Negative values will wrap around.
            causal: If set, use a causal attention mask.

        Returns:
            The per-chunk outputs concatenated along the time axis, on the
            CPU, with the leading (size one) batch dimension squeezed away.
        """
        props = get_audio_props(path)

        # Sox effects: normalize the gain, mix down to one channel and, if
        # the file is not already at 16 kHz, resample.
        effects: list[tuple[str, str]] = [("gain", "-n"), ("channels", "1")]
        if props.sample_rate != self.sample_rate:
            effects.append(("rate", str(self.sample_rate)))

        chunk_length = round(chunk_length_sec * self.sample_rate)

        with torch.inference_mode(), self.device.autocast_context():
            feat = []
            for waveform_chunk in read_audio(
                path,
                chunk_length=chunk_length,
                sample_rate=self.sample_rate,
            ):
                waveform_tensor = torch.from_numpy(waveform_chunk).to(torch.float32)
                # NOTE(review): `read_audio` is already given the target
                # sample rate, yet the effects treat the chunk as being at
                # the file's original rate - confirm this is intended.
                waveform_tensor, _ = ta_sox.apply_effects_tensor(waveform_tensor, props.sample_rate, effects)
                chans, _ = waveform_tensor.shape
                assert chans == 1, f"Expected mono-channel audio, got {chans} channels"

                x = self.device.tensor_to(waveform_tensor)

                if self.model.pre_normalize:
                    x = F.layer_norm(x, x.shape[1:])

                feat_chunk = self.model.forward(x, self.sample_rate, causal=causal, output_layer=output_layer)
                if self.kmeans is not None:
                    feat_chunk = self.kmeans.forward(feat_chunk)
                feat.append(feat_chunk.cpu())

        return torch.cat(feat, 1).squeeze(0)
# Checkpoint entries that have no counterpart in `Hubert` and are dropped
# before loading. The ".weight" / ".bias" entries are presumably what a
# non-prefixed head such as "lm_head.*" used to turn into when the prefix
# length was sliced off every key; they are kept for compatibility.
EXCLUDE_KEYS = {"masked_spec_embed", ".weight", ".bias"}


def _load_pretrained_hubert(
    size: PretrainedHubertSize,
    ckpt_url: str,
    sha256: str,
    hidden_size: int,
    intermediate_size: int,
    num_hidden_layers: int,
    num_attention_heads: int,
    remove_prefix: str | None = None,
    load_weights: bool = True,
    conv_dim: tuple[int, ...] = (512, 512, 512, 512, 512, 512, 512),
    conv_stride: tuple[int, ...] = (5, 2, 2, 2, 2, 2, 2),
    conv_kernel: tuple[int, ...] = (10, 3, 3, 3, 3, 2, 2),
    conv_bias: bool = True,
    num_conv_pos_embeddings: int = 128,
    num_conv_pos_embedding_groups: int = 16,
    do_stable_layer_norm: bool = True,
    pre_normalize: bool = True,
    feat_extract_norm: Literal["group", "layer"] = "layer",
    feat_extract_activation: ActivationType = "gelu",
    feat_proj_layer_norm: bool = True,
    hidden_act: ActivationType = "gelu",
    layer_norm_eps: float = 1e-5,
    hidden_dropout: float = 0.1,
    activation_dropout: float = 0.1,
    feat_proj_dropout: float = 0.0,
) -> Hubert:
    """Builds a `Hubert` model and optionally loads a pretrained checkpoint.

    Args:
        size: The model key; also used to name the cached checkpoint file
            (``"{size}.bin"``).
        ckpt_url: URL of the checkpoint to download.
        sha256: Expected SHA-256 digest, passed to `ensure_downloaded`.
        hidden_size: Size of the encoder's features.
        intermediate_size: Hidden size of each feed-forward network.
        num_hidden_layers: Number of encoder layers.
        num_attention_heads: Number of attention heads per layer.
        remove_prefix: If set, only checkpoint entries whose key starts with
            this prefix are kept, and the prefix is stripped from them.
        load_weights: If unset, return the randomly initialized model without
            downloading anything.
        conv_dim: See `Hubert`.
        conv_stride: See `Hubert`.
        conv_kernel: See `Hubert`.
        conv_bias: See `Hubert`.
        num_conv_pos_embeddings: See `Hubert`.
        num_conv_pos_embedding_groups: See `Hubert`.
        do_stable_layer_norm: See `Hubert`.
        pre_normalize: See `Hubert`.
        feat_extract_norm: See `Hubert`.
        feat_extract_activation: See `Hubert`.
        feat_proj_layer_norm: See `Hubert`.
        hidden_act: See `Hubert`.
        layer_norm_eps: See `Hubert`.
        hidden_dropout: See `Hubert`.
        activation_dropout: See `Hubert`.
        feat_proj_dropout: See `Hubert`.

    Returns:
        The constructed model, with pretrained weights if requested.
    """
    with Timer("building empty model", spinner=True):
        model = Hubert(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            conv_dim=conv_dim,
            conv_stride=conv_stride,
            conv_kernel=conv_kernel,
            conv_bias=conv_bias,
            num_conv_pos_embeddings=num_conv_pos_embeddings,
            num_conv_pos_embedding_groups=num_conv_pos_embedding_groups,
            do_stable_layer_norm=do_stable_layer_norm,
            pre_normalize=pre_normalize,
            feat_extract_norm=feat_extract_norm,
            feat_extract_activation=feat_extract_activation,
            feat_proj_layer_norm=feat_proj_layer_norm,
            hidden_act=hidden_act,
            layer_norm_eps=layer_norm_eps,
            hidden_dropout=hidden_dropout,
            activation_dropout=activation_dropout,
            feat_proj_dropout=feat_proj_dropout,
        )

    # Loads the model weights.
    if load_weights:
        model_fname = f"{size}.bin"

        with Timer("downloading checkpoint"):
            model_path = ensure_downloaded(ckpt_url, "hubert", model_fname, sha256=sha256)

        with Timer("loading checkpoint", spinner=True):
            # The file comes from the network, so restrict unpickling to
            # tensors and plain containers.
            ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
            if remove_prefix:
                # Strip the prefix only where it is actually present. Entries
                # without it (e.g. task heads stored next to the backbone)
                # do not belong to this model and are dropped instead of
                # having their names mangled.
                ckpt = {k.removeprefix(remove_prefix): v for k, v in ckpt.items() if k.startswith(remove_prefix)}
            ckpt = {k: v for k, v in ckpt.items() if k not in EXCLUDE_KEYS}
            model.load_state_dict(ckpt)

    return model
def _load_pretrained_hubert_kmeans(
    size: PretrainedHubertKmeansSize,
    ckpt_url: str,
    sha256: str,
    use_triton_if_available: bool = True,
) -> KMeans:
    """Downloads a set of cluster centers and wraps them in a `KMeans` model.

    Args:
        size: The K-means key; also used to name the cached file
            (``"{size}.npy"``).
        ckpt_url: URL of the ``.npy`` file holding the cluster centers.
        sha256: Expected SHA-256 digest, passed to `ensure_downloaded`.
        use_triton_if_available: Forwarded to `KMeans`.

    Returns:
        A `KMeans` model built from the downloaded centers.
    """
    with Timer("downloading cluster centers"):
        npy_path = ensure_downloaded(ckpt_url, "hubert", f"{size}.npy", sha256=sha256)

    with Timer("loading K-means clusters", spinner=True):
        cluster_centers = np.load(npy_path)
        return KMeans(cluster_centers, use_triton_if_available=use_triton_if_available)
def pretrained_hubert(size: PretrainedHubertSize, load_weights: bool = True) -> Hubert:
    """Builds one of the pretrained HuBERT models.

    Args:
        size: The model key: ``"base"``, ``"large"`` or ``"extra_large"``.
        load_weights: If unset, skip downloading and loading the checkpoint
            and return the randomly initialized architecture.

    Returns:
        The requested HuBERT model.

    Raises:
        NotImplementedError: If ``size`` is not a known model key.
    """
    if size == "base":
        # The base model is the only one that uses group norm in the feature
        # extractor, post-norm encoder layers and no input normalization.
        return _load_pretrained_hubert(
            size,
            ckpt_url="https://huggingface.co/facebook/hubert-base-ls960/resolve/main/pytorch_model.bin",
            sha256="062249fffb353eab67547a2fbc129f7c31a2f459faf641b19e8fb007cc5c48ad",
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            load_weights=load_weights,
            feat_extract_norm="group",
            conv_bias=False,
            do_stable_layer_norm=False,
            pre_normalize=False,
        )

    if size == "large":
        return _load_pretrained_hubert(
            size,
            ckpt_url="https://huggingface.co/facebook/hubert-large-ls960-ft/resolve/main/pytorch_model.bin",
            sha256="9cf43abec3f0410ad6854afa4d376c69ccb364b48ddddfd25c4c5aa16398eab0",
            hidden_size=1024,
            num_hidden_layers=24,
            num_attention_heads=16,
            intermediate_size=4096,
            remove_prefix="hubert.",
            load_weights=load_weights,
        )

    if size == "extra_large":
        return _load_pretrained_hubert(
            size,
            ckpt_url="https://huggingface.co/facebook/hubert-xlarge-ll60k/resolve/main/pytorch_model.bin",
            sha256="6131dc27f4508595daa1a13fec4aa1f6b4a579b5d93550bae26c13a83221f8a7",
            hidden_size=1280,
            num_hidden_layers=48,
            num_attention_heads=16,
            intermediate_size=5120,
            load_weights=load_weights,
        )

    raise NotImplementedError(f"Invalid size: {size}")
# Maps each K-means key to the file name of its cluster centers (relative to
# the quantization repository) and that file's expected SHA-256 digest.
_KMEANS_CHECKPOINTS: dict[str, tuple[str, str]] = {
    "base-l7-c100": (
        "kmeans_base_7_sklearn_100.npy",
        "e46d1e2a5d6f83805dd336cf22a4228a902e78c3377141b4aa8e8c946af160cb",
    ),
    "base-l7-c200": (
        "kmeans_base_7_sklearn_200.npy",
        "5bce95ff25b8e3e07170f73bfcf7a5c72a432a9acd3382e833409a30a41ce062",
    ),
    "base-l7-c500": (
        "kmeans_base_7_sklearn_500.npy",
        "ce9855b89955affbf8e939ff274a4938efee730d4fb4fab990070747744b9df0",
    ),
    "base-l7-c1000": (
        "kmeans_base_7_sklearn_1000.npy",
        "6a10e5978bac1b84a3b0e03bb72e3015d0cdf6956e301a48971eb3a2493e37c5",
    ),
    "base-l8-c100": (
        "kmeans_base_8_sklearn_100.npy",
        "3219a01b5ec21ca173605fe5b2d7b296db1a10ef24e5c593c8076b1b39f96865",
    ),
    "base-l8-c200": (
        "kmeans_base_8_sklearn_200.npy",
        "0beab85b59604841da10b3327bedc710e0dbf8e4a2b24bc0d964bf345640e9d7",
    ),
    "base-l8-c500": (
        "kmeans_base_8_sklearn_500.npy",
        "4a06731ef6d8aa116ae05ec309ad1ae47b7c030f05bc62137899b17d32fd294a",
    ),
    "base-l8-c1000": (
        "kmeans_base_8_sklearn_1000.npy",
        "15be942383cf9e5afc3d6f0d615ab6dc8459364129dc1a02ee00f8c927783aae",
    ),
    "base-l10-c100": (
        "kmeans_base_10_sklearn_100.npy",
        "22918f566c2308c14e9514830ddc669a26dfbed2668ce27eb8f57801b78d27b7",
    ),
    # This entry used to be registered under the misspelled key
    # "base-l20-c200", which made the documented key unreachable.
    "base-l10-c200": (
        "kmeans_base_10_sklearn_200.npy",
        "92cff38a19df79303a8d18c98fb776cbed299b92944195a3793eda16a7f6da97",
    ),
}


def pretrained_kmeans_clusters(size: PretrainedHubertKmeansSize) -> KMeans:
    """Loads the pretrained K-means cluster centers for the given key.

    Args:
        size: The K-means key, of the form
            ``<backbone>-l<layer>-c<number of clusters>``.

    Returns:
        The `KMeans` model built from the downloaded cluster centers.

    Raises:
        NotImplementedError: If ``size`` is not a known K-means key.
    """
    url_base = "https://huggingface.co/codekansas/hubert-quantization/resolve/main"

    checkpoint = _KMEANS_CHECKPOINTS.get(size)
    if checkpoint is None:
        raise NotImplementedError(f"Invalid size: {size}")

    filename, sha256 = checkpoint
    return _load_pretrained_hubert_kmeans(
        size,
        ckpt_url=f"{url_base}/{filename}",
        sha256=sha256,
    )
def pretrained_hubert_with_kmeans(
    size: PretrainedHubertKmeansSize,
    load_weights: bool = True,
) -> tuple[Hubert, KMeans]:
    """Loads a K-means model together with its matching HuBERT model.

    The cluster centers are loaded first; the "base" HuBERT model is then
    built and truncated with ``set_output_layer`` to the layer named in the
    key.

    Args:
        size: The K-means key, of the form
            ``<backbone>-l<layer>-c<number of clusters>``.
        load_weights: Whether to load the pretrained HuBERT weights.

    Returns:
        The truncated HuBERT model and the K-means model.

    Raises:
        NotImplementedError: If ``size`` is not a known K-means key.
    """
    kmeans = pretrained_kmeans_clusters(size)

    # Value passed to `set_output_layer` for every supported key.
    layer_for_size = {
        "base-l7-c100": 7,
        "base-l7-c200": 7,
        "base-l7-c500": 7,
        "base-l7-c1000": 7,
        "base-l8-c100": 8,
        "base-l8-c200": 8,
        "base-l8-c500": 8,
        "base-l8-c1000": 8,
        "base-l10-c100": 10,
        "base-l10-c200": 10,
    }
    if size not in layer_for_size:
        raise NotImplementedError(f"Invalid size: {size}")

    hubert = pretrained_hubert("base", load_weights=load_weights)
    hubert.set_output_layer(layer_for_size[size])
    return hubert, kmeans
def test_hubert_adhoc() -> None:
    """Runs an ad-hoc smoke test of a pretrained model from the command line.

    Parses the model key and options from ``sys.argv``, builds the model
    (with a K-means quantizer if the key names one), runs it on a random
    waveform and checks the number of output frames.
    """
    configure_logging()

    parser = argparse.ArgumentParser()
    size_choices = get_args(PretrainedHubertSize) + get_args(PretrainedHubertKmeansSize)
    parser.add_argument("size", type=str, choices=size_choices)
    parser.add_argument("-t", "--tsz", type=int, default=22400)
    parser.add_argument("-n", "--no-load-weights", default=False, action="store_true")
    parser.add_argument("-c", "--causal", default=False, action="store_true")
    args = parser.parse_args()

    load_weights = not args.no_load_weights

    # Loads the model and moves to the right device.
    kmeans: KMeans | None
    if args.size in get_args(PretrainedHubertSize):
        model = pretrained_hubert(size=cast(PretrainedHubertSize, args.size), load_weights=load_weights)
        kmeans = None
    else:
        # `--no-load-weights` must be honored here too; it used to be
        # silently ignored for the K-means sizes.
        model, kmeans = pretrained_hubert_with_kmeans(
            size=cast(PretrainedHubertKmeansSize, args.size),
            load_weights=load_weights,
        )
    predictor = model.predictor(kmeans)

    # Test the model on a random waveform.
    y = predictor.predict(torch.randn(1, args.tsz), sample_rate=16000, causal=args.causal)

    # The output has one frame fewer than `tsz // 320` for the default
    # convolution settings.
    assert (args.tsz // 320) == y.shape[1] + 1
if __name__ == "__main__":
    # Runs the ad-hoc smoke test when executed as a script, e.g.:
    # python -m pretrained.hubert <size>
    test_hubert_adhoc()