Source code for pretrained.hubert

"""Defines a simple API for using Meta's pretrained Hubert model.

.. highlight:: python
.. code-block:: python

    from pretrained.hubert import pretrained_hubert

    model = pretrained_hubert("base")
    predictor = model.predictor()

    # Gets HuBERT embeddings for a waveform.
    predictor.predict(torch.randn(1, 16_000), output_layer=None)

    # Gets HuBERT embeddings for a long waveform, in batches.
    predictor.predict_in_chunks(torch.randn(1, 160_000), 16_000, output_layer=None)

In order to get HuBERT clusters, you can use:

.. highlight:: python
.. code-block:: python

    from pretrained.hubert import pretrained_hubert_with_kmeans

    model, kmeans = pretrained_hubert_with_kmeans("base-l7-c100")
    predictor = model.predictor(kmeans)

    # Get the HuBERT tokens for a waveform.
    predictor.predict(torch.randn(1, 16_000))

The choices for the model key are:

- ``"base"`` - 12 layers, 768 hidden size, 12 attention heads.
- ``"large"`` - 24 layers, 1024 hidden size, 16 attention heads.
- ``"extra_large"`` - 48 layers, 1280 hidden size, 16 attention heads.
"""

import argparse
from pathlib import Path
from typing import Literal, cast, get_args

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio.sox_effects as ta_sox
from ml.models.activations import ActivationType, get_activation
from ml.models.kmeans import KMeans
from ml.models.norms import ConvLayerNorm
from ml.utils.audio import get_audio_props, read_audio
from ml.utils.checkpoint import ensure_downloaded
from ml.utils.device.auto import detect_device
from ml.utils.device.base import base_device
from ml.utils.logging import configure_logging
from ml.utils.timer import Timer
from torch import Tensor, nn
from torch.nn.utils.parametrizations import weight_norm

PretrainedHubertSize = Literal["base", "large", "extra_large"]

# These clusters were generated by sweeping over a number of different
# hyperparameter configurations and selecting the one with the highest
# cross-entropy between the clusters and the speaker IDs (meaning that
# the clusters should be more speaker-independent).
PretrainedHubertKmeansSize = Literal[
    "base-l7-c100",
    "base-l7-c200",
    "base-l7-c500",
    "base-l7-c1000",
    "base-l8-c100",
    "base-l8-c200",
    "base-l8-c500",
    "base-l8-c1000",
    "base-l10-c100",
    "base-l10-c200",
]

DEFAULT_CONV_DIM: tuple[int, ...] = (512, 512, 512, 512, 512, 512, 512)
DEFAULT_CONV_STRIDE: tuple[int, ...] = (5, 2, 2, 2, 2, 2, 2)
DEFAULT_CONV_KERNEL: tuple[int, ...] = (10, 3, 3, 3, 3, 2, 2)


[docs]def cast_pretrained_hubert_size(s: str) -> PretrainedHubertSize: if s not in get_args(PretrainedHubertSize): raise KeyError(f"Invalid HuBERT key: {s} Expected one of: {get_args(PretrainedHubertSize)}") return cast(PretrainedHubertSize, s)
[docs]def cast_pretrained_hubert_kmeans_size(s: str) -> PretrainedHubertKmeansSize: if s not in get_args(PretrainedHubertKmeansSize): raise KeyError(f"Invalid HuBERT key: {s} Expected one of: {get_args(PretrainedHubertKmeansSize)}") return cast(PretrainedHubertKmeansSize, s)
[docs]def normalize_output_layer(output_layer: int | float | None, num_layers: int) -> int | None: if output_layer is not None: if isinstance(output_layer, float): output_layer = round(output_layer * num_layers) if output_layer < 0: output_layer += num_layers if not (0 <= output_layer < num_layers): raise ValueError(f"output_layer={output_layer} is outside the range of available layers") return output_layer
[docs]class HubertSamePadLayer(nn.Module): def __init__(self, num_conv_pos_embeddings: int = 128) -> None: super().__init__() self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
[docs] def forward(self, hidden_states: Tensor) -> Tensor: if self.num_pad_remove > 0: hidden_states = hidden_states[:, :, : -self.num_pad_remove] return hidden_states
[docs]class PositionalConvEmbedding(nn.Module): def __init__( self, hidden_size: int, num_conv_pos_embeddings: int = 128, num_conv_pos_embedding_groups: int = 16, feat_extract_activation: ActivationType = "gelu", ) -> None: super().__init__() conv = nn.Conv1d( hidden_size, hidden_size, kernel_size=num_conv_pos_embeddings, padding=num_conv_pos_embeddings // 2, groups=num_conv_pos_embedding_groups, ) self.conv = weight_norm(conv, dim=2) self.padding = HubertSamePadLayer(num_conv_pos_embeddings) self.activation = get_activation(feat_extract_activation)
[docs] def forward(self, hidden_states: Tensor) -> Tensor: hidden_states = hidden_states.transpose(1, 2) hidden_states = self.conv(hidden_states) hidden_states = self.padding(hidden_states) hidden_states = self.activation(hidden_states) hidden_states = hidden_states.transpose(1, 2) return hidden_states
[docs]class Attention(nn.Module): def __init__(self, embed_dim: int, num_heads: int, bias: bool = True) -> None: super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads if (self.head_dim * num_heads) != self.embed_dim: raise ValueError(f"`embed_dim` must be divisible by num_heads (got {self.embed_dim=} and {num_heads=}).") self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) def _shape(self, tensor: Tensor, seq_len: int, bsz: int) -> Tensor: return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
[docs] def forward(self, hidden_states: Tensor, causal: bool = False) -> Tensor: """Runs the HuBERT attention layer. Args: hidden_states: Input states for the attention layer. causal: If set, use causal attention. Returns: The attention outputs. """ bsz, tgt_len, _ = hidden_states.size() query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz) key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=causal) attn_output = attn_output.transpose(1, 2).flatten(2) final_output = self.out_proj(attn_output) return final_output
[docs]class FeedForward(nn.Module): def __init__( self, hidden_size: int, intermediate_size: int, hidden_act: ActivationType = "gelu", hidden_dropout: float = 0.1, activation_dropout: float = 0.1, ) -> None: super().__init__() self.intermediate_dropout = nn.Dropout(activation_dropout) self.intermediate_dense = nn.Linear(hidden_size, intermediate_size) self.intermediate_act_fn = get_activation(hidden_act) self.output_dense = nn.Linear(intermediate_size, hidden_size) self.output_dropout = nn.Dropout(hidden_dropout)
[docs] def forward(self, hidden_states: Tensor) -> Tensor: hidden_states = self.intermediate_dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_dropout(hidden_states) hidden_states = self.output_dense(hidden_states) hidden_states = self.output_dropout(hidden_states) return hidden_states
[docs]class HubertEncoderLayer(nn.Module): def __init__( self, hidden_size: int, intermediate_size: int, num_attention_heads: int, hidden_act: ActivationType = "gelu", layer_norm_eps: float = 1e-5, hidden_dropout: float = 0.1, activation_dropout: float = 0.1, ) -> None: super().__init__() self.attention = Attention(embed_dim=hidden_size, num_heads=num_attention_heads) self.dropout = nn.Dropout(hidden_dropout) self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) self.feed_forward = FeedForward( hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act, hidden_dropout=hidden_dropout, activation_dropout=activation_dropout, ) self.final_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
[docs] def forward(self, hidden_states: Tensor, causal: bool = False) -> Tensor: attn_residual = hidden_states hidden_states = self.attention.forward(hidden_states, causal=causal) hidden_states = self.dropout(hidden_states) hidden_states = attn_residual + hidden_states hidden_states = self.layer_norm(hidden_states) hidden_states = hidden_states + self.feed_forward(hidden_states) hidden_states = self.final_layer_norm(hidden_states) return hidden_states
[docs]class HubertEncoder(nn.Module): def __init__( self, hidden_size: int, intermediate_size: int, num_attention_heads: int, num_hidden_layers: int, num_conv_pos_embeddings: int = 128, num_conv_pos_embedding_groups: int = 16, feat_extract_activation: ActivationType = "gelu", hidden_act: ActivationType = "gelu", layer_norm_eps: float = 1e-5, hidden_dropout: float = 0.1, activation_dropout: float = 0.1, ) -> None: super().__init__() self.pos_conv_embed = PositionalConvEmbedding( hidden_size=hidden_size, num_conv_pos_embeddings=num_conv_pos_embeddings, num_conv_pos_embedding_groups=num_conv_pos_embedding_groups, feat_extract_activation=feat_extract_activation, ) self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) self.dropout = nn.Dropout(hidden_dropout) layers = nn.ModuleList( [ HubertEncoderLayer( hidden_size=hidden_size, intermediate_size=intermediate_size, num_attention_heads=num_attention_heads, hidden_act=hidden_act, layer_norm_eps=layer_norm_eps, hidden_dropout=hidden_dropout, activation_dropout=activation_dropout, ) for _ in range(num_hidden_layers) ] ) self.layers = cast(list[HubertEncoderLayer], layers) self.gradient_checkpointing = False
[docs] def forward(self, hidden_states: Tensor, causal: bool = False, output_layer: int | float | None = None) -> Tensor: position_embeddings = self.pos_conv_embed.forward(hidden_states) hidden_states = hidden_states + position_embeddings hidden_states = self.layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) output_layer = normalize_output_layer(output_layer, len(self.layers)) for i, layer in enumerate(self.layers): hidden_states = layer.forward(hidden_states, causal=causal) if output_layer is not None and i == output_layer: break return hidden_states
[docs] def extract_all_features( self, hidden_states: Tensor, causal: bool = False, output_layer: int | float | None = None, ) -> list[Tensor]: position_embeddings = self.pos_conv_embed.forward(hidden_states) hidden_states = hidden_states + position_embeddings hidden_states = self.layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) output_layer = normalize_output_layer(output_layer, len(self.layers)) all_layer_hidden_states = [] for i, layer in enumerate(self.layers): hidden_states = layer.forward(hidden_states, causal=causal) all_layer_hidden_states.append(hidden_states) if output_layer is not None and i == output_layer: break return all_layer_hidden_states
[docs]class GroupNormConvLayer(nn.Module): def __init__( self, in_channels: int, out_channels: int, stride: int, kernel: int, bias: bool = True, feat_extract_activation: ActivationType = "gelu", ) -> None: super().__init__() self.in_conv_dim = in_channels self.out_conv_dim = out_channels self.conv = nn.Conv1d(self.in_conv_dim, self.out_conv_dim, kernel_size=kernel, stride=stride, bias=bias) self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) self.activation = get_activation(feat_extract_activation)
[docs] def forward(self, hidden_states: Tensor) -> Tensor: hidden_states = self.conv(hidden_states) hidden_states = self.layer_norm(hidden_states) hidden_states = self.activation(hidden_states) return hidden_states
[docs]class NoLayerNormConvLayer(nn.Module): def __init__( self, in_channels: int, out_channels: int, stride: int, kernel: int, bias: bool = True, feat_extract_activation: ActivationType = "gelu", ) -> None: super().__init__() self.in_conv_dim = in_channels self.out_conv_dim = out_channels self.conv = nn.Conv1d(self.in_conv_dim, self.out_conv_dim, kernel_size=kernel, stride=stride, bias=bias) self.activation = get_activation(feat_extract_activation)
[docs] def forward(self, hidden_states: Tensor) -> Tensor: hidden_states = self.conv(hidden_states) hidden_states = self.activation(hidden_states) return hidden_states
[docs]class LayerNormConvLayer(nn.Module): def __init__( self, in_channels: int, out_channels: int, stride: int, kernel: int, bias: bool = True, feat_extract_activation: ActivationType = "gelu", ) -> None: super().__init__() self.in_conv_dim = in_channels self.out_conv_dim = out_channels self.conv = nn.Conv1d(self.in_conv_dim, self.out_conv_dim, kernel_size=kernel, stride=stride, bias=bias) self.layer_norm = ConvLayerNorm(self.out_conv_dim, dims=1, elementwise_affine=True) self.activation = get_activation(feat_extract_activation)
[docs] def forward(self, hidden_states: Tensor) -> Tensor: hidden_states = self.conv(hidden_states) hidden_states = self.layer_norm(hidden_states) hidden_states = self.activation(hidden_states) return hidden_states
[docs]class HubertFeatureEncoder(nn.Module): def __init__( self, conv_dim: tuple[int, ...] = DEFAULT_CONV_DIM, conv_stride: tuple[int, ...] = DEFAULT_CONV_STRIDE, conv_kernel: tuple[int, ...] = DEFAULT_CONV_KERNEL, conv_bias: bool = True, feat_extract_norm: Literal["group", "layer"] = "layer", feat_extract_activation: ActivationType = "gelu", ) -> None: super().__init__() assert len(conv_dim) == len(conv_stride) == len(conv_kernel) num_feat_extract_layers = len(conv_dim) conv_layers: list[nn.Module] = [] if feat_extract_norm == "group": conv_layers += [ GroupNormConvLayer( in_channels=1, out_channels=conv_dim[0], stride=conv_stride[0], kernel=conv_kernel[0], bias=conv_bias, feat_extract_activation=feat_extract_activation, ) ] for i in range(num_feat_extract_layers - 1): conv_layers += [ NoLayerNormConvLayer( in_channels=conv_dim[i], out_channels=conv_dim[i + 1], stride=conv_stride[i + 1], kernel=conv_kernel[i + 1], bias=conv_bias, feat_extract_activation=feat_extract_activation, ) ] elif feat_extract_norm == "layer": for i in range(num_feat_extract_layers): conv_layers += [ LayerNormConvLayer( in_channels=1 if i == 0 else conv_dim[i - 1], out_channels=conv_dim[i], stride=conv_stride[i], kernel=conv_kernel[i], bias=conv_bias, feat_extract_activation=feat_extract_activation, ) ] else: raise ValueError(f"{feat_extract_norm=}, but has to be one of ['group', 'layer']") self.conv_layers = nn.ModuleList(conv_layers) def _freeze_parameters(self) -> None: for param in self.parameters(): param.requires_grad = False
[docs] def forward(self, input_values: Tensor) -> Tensor: hidden_states = input_values[:, None] for conv_layer in self.conv_layers: hidden_states = conv_layer(hidden_states) return hidden_states
[docs]class HubertFeatureProjection(nn.Module): def __init__( self, input_size: int, hidden_size: int, layer_norm_eps: float = 1e-5, feat_proj_dropout: float = 0.0, feat_proj_layer_norm: bool = True, ) -> None: super().__init__() self.feat_proj_layer_norm = feat_proj_layer_norm if self.feat_proj_layer_norm: self.layer_norm = nn.LayerNorm(input_size, eps=layer_norm_eps) self.projection = nn.Linear(input_size, hidden_size) self.dropout = nn.Dropout(feat_proj_dropout)
[docs] def forward(self, hidden_states: Tensor) -> Tensor: if self.feat_proj_layer_norm: hidden_states = self.layer_norm(hidden_states) hidden_states = self.projection(hidden_states) hidden_states = self.dropout(hidden_states) return hidden_states
[docs]class HubertEncoderLayerStableLayerNorm(nn.Module): def __init__( self, hidden_size: int, intermediate_size: int, num_attention_heads: int, layer_norm_eps: float = 1e-5, hidden_act: ActivationType = "gelu", hidden_dropout: float = 0.1, activation_dropout: float = 0.1, ) -> None: super().__init__() self.attention = Attention(embed_dim=hidden_size, num_heads=num_attention_heads) self.dropout = nn.Dropout(hidden_dropout) self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) self.feed_forward = FeedForward( hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act, hidden_dropout=hidden_dropout, activation_dropout=activation_dropout, ) self.final_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
[docs] def forward(self, hidden_states: Tensor, causal: bool = False) -> Tensor: attn_residual = hidden_states hidden_states = self.layer_norm(hidden_states) hidden_states = self.attention.forward(hidden_states, causal=causal) hidden_states = self.dropout(hidden_states) hidden_states = attn_residual + hidden_states hidden_states = hidden_states + self.feed_forward.forward(self.final_layer_norm(hidden_states)) return hidden_states
[docs]class HubertEncoderStableLayerNorm(nn.Module): def __init__( self, hidden_size: int, intermediate_size: int, num_attention_heads: int, num_hidden_layers: int, num_conv_pos_embeddings: int = 128, num_conv_pos_embedding_groups: int = 16, hidden_act: ActivationType = "gelu", feat_extract_activation: ActivationType = "gelu", layer_norm_eps: float = 1e-5, hidden_dropout: float = 0.1, activation_dropout: float = 0.1, ) -> None: super().__init__() self.pos_conv_embed = PositionalConvEmbedding( hidden_size=hidden_size, num_conv_pos_embeddings=num_conv_pos_embeddings, num_conv_pos_embedding_groups=num_conv_pos_embedding_groups, feat_extract_activation=feat_extract_activation, ) self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) self.dropout = nn.Dropout(hidden_dropout) layers = [ HubertEncoderLayerStableLayerNorm( hidden_size=hidden_size, intermediate_size=intermediate_size, num_attention_heads=num_attention_heads, layer_norm_eps=layer_norm_eps, hidden_act=hidden_act, hidden_dropout=hidden_dropout, activation_dropout=activation_dropout, ) for _ in range(num_hidden_layers) ] self.layers = cast(list[HubertEncoderLayerStableLayerNorm], nn.ModuleList(layers))
[docs] def forward(self, hidden_states: Tensor, causal: bool = False, output_layer: int | float | None = None) -> Tensor: position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings hidden_states = self.dropout(hidden_states) output_layer = normalize_output_layer(output_layer, len(self.layers)) for i, layer in enumerate(self.layers): hidden_states = layer.forward(hidden_states, causal=causal) if output_layer is not None and i == output_layer: break hidden_states = self.layer_norm(hidden_states) return hidden_states
[docs] def extract_all_features( self, hidden_states: Tensor, causal: bool = False, output_layer: int | float | None = None, ) -> list[Tensor]: position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings hidden_states = self.dropout(hidden_states) output_layer = normalize_output_layer(output_layer, len(self.layers)) all_layers_hidden_states = [] for i, layer in enumerate(self.layers): hidden_states = layer.forward(hidden_states, causal=causal) all_layers_hidden_states.append(hidden_states) if output_layer is not None and i == output_layer: break return all_layers_hidden_states
[docs]class Hubert(nn.Module): __constants__ = ["conv_kernel", "conv_stride", "pre_normalize", "hidden_size"] def __init__( self, hidden_size: int, intermediate_size: int, num_hidden_layers: int, num_attention_heads: int, conv_dim: tuple[int, ...] = DEFAULT_CONV_DIM, conv_stride: tuple[int, ...] = DEFAULT_CONV_STRIDE, conv_kernel: tuple[int, ...] = DEFAULT_CONV_KERNEL, conv_bias: bool = True, num_conv_pos_embeddings: int = 128, num_conv_pos_embedding_groups: int = 16, do_stable_layer_norm: bool = True, pre_normalize: bool = True, feat_extract_norm: Literal["group", "layer"] = "layer", feat_extract_activation: ActivationType = "gelu", feat_proj_layer_norm: bool = True, hidden_act: ActivationType = "gelu", layer_norm_eps: float = 1e-5, hidden_dropout: float = 0.1, activation_dropout: float = 0.1, feat_proj_dropout: float = 0.0, ) -> None: super().__init__() self.conv_kernel = conv_kernel self.conv_stride = conv_stride self.pre_normalize = pre_normalize self.hidden_size = hidden_size self.feature_extractor = HubertFeatureEncoder( conv_dim=conv_dim, conv_stride=conv_stride, conv_kernel=conv_kernel, conv_bias=conv_bias, feat_extract_norm=feat_extract_norm, feat_extract_activation=feat_extract_activation, ) self.feature_projection = HubertFeatureProjection( input_size=conv_dim[-1], hidden_size=hidden_size, layer_norm_eps=layer_norm_eps, feat_proj_dropout=feat_proj_dropout, feat_proj_layer_norm=feat_proj_layer_norm, ) self.encoder: HubertEncoderStableLayerNorm | HubertEncoder if do_stable_layer_norm: self.encoder = HubertEncoderStableLayerNorm( hidden_size=hidden_size, intermediate_size=intermediate_size, num_attention_heads=num_attention_heads, num_hidden_layers=num_hidden_layers, num_conv_pos_embeddings=num_conv_pos_embeddings, num_conv_pos_embedding_groups=num_conv_pos_embedding_groups, hidden_act=hidden_act, feat_extract_activation=feat_extract_activation, layer_norm_eps=layer_norm_eps, hidden_dropout=hidden_dropout, activation_dropout=activation_dropout, ) else: self.encoder = HubertEncoder( hidden_size=hidden_size, intermediate_size=intermediate_size, num_attention_heads=num_attention_heads, num_hidden_layers=num_hidden_layers, num_conv_pos_embeddings=num_conv_pos_embeddings, num_conv_pos_embedding_groups=num_conv_pos_embedding_groups, hidden_act=hidden_act, feat_extract_activation=feat_extract_activation, layer_norm_eps=layer_norm_eps, hidden_dropout=hidden_dropout, activation_dropout=activation_dropout, )
[docs] def set_output_layer(self, output_layer: int | float) -> None: output_layer = normalize_output_layer(output_layer, len(self.encoder.layers)) del self.encoder.layers[output_layer:]
[docs] def forward( self, input_values: Tensor, sample_rate: int, causal: bool = False, output_layer: int | float | None = None, ) -> Tensor: if sample_rate != 16_000: raise RuntimeError("HuBERT only supports 16 kHz as input sampling rate") if self.pre_normalize: input_values = F.layer_norm(input_values, input_values.shape[1:]) extract_features = self.feature_extractor(input_values) extract_features = extract_features.transpose(1, 2) hidden_states = self.feature_projection(extract_features) return self.encoder.forward(hidden_states, causal=causal, output_layer=output_layer)
[docs] def extract_all_features( self, input_values: Tensor, sample_rate: int, causal: bool = False, output_layer: int | float | None = None, ) -> list[Tensor]: if sample_rate != 16_000: raise RuntimeError("HuBERT only supports 16 kHz as input sampling rate") if self.pre_normalize: input_values = F.layer_norm(input_values, input_values.shape[1:]) extract_features = self.feature_extractor(input_values) extract_features = extract_features.transpose(1, 2) hidden_states = self.feature_projection(extract_features) return self.encoder.extract_all_features(hidden_states, causal=causal, output_layer=output_layer)
[docs] def predictor( self, kmeans: KMeans | None = None, *, device: base_device | None = None, ) -> "HubertPredictor": return HubertPredictor(self, kmeans, device=device)
[docs]class HubertPredictor: def __init__( self, hubert_model: Hubert, kmeans: KMeans | None = None, *, device: base_device | None = None, ) -> None: """Provides an API for doing predictoins with a HuBERT model. Note that this module is not an `nn.Module`, so you can use it in your module without worrying about storing all the weights on accident. Args: hubert_model: The HuBERT model to use for predictions. kmeans: The kmeans model to use for quantization. If `None`, don't quantize. device: The device to use for predictions. If `None`, will use the device returned by detect_device(). """ super().__init__() self.device = detect_device() if device is None else device self.model = hubert_model.eval() self.kmeans = kmeans.eval() if kmeans is not None else None self.device.module_to(self.model) if self.kmeans is not None: self.device.module_to(self.kmeans) self.sample_rate = 16_000 # True for all HuBERT models.
[docs] def predict( self, waveform: np.ndarray | Tensor, sample_rate: int, output_layer: int | float | None = None, causal: bool = False, ) -> Tensor: """Gets the hidden states for the given waveform. Args: waveform: The waveform to get hidden states for, with shape (B, T) sample_rate: The waveform's sampling rate; this is only used to verify that it is 16 kHz, since it is easy for downstream applications to forget. output_layer: The layer to get hidden states from. If `None`, will return the hidden states from the last layer. If an `int`, will return the hidden states from that layer. If a `float`, will return the hidden states from the layer at that percentage of the model. For example, `0.5` will return the hidden states from the middle layer. Negative values will wrap around. causal: If set, use a causal attention mask. Returns: The hidden states for the given waveform, with shape (B, T, D) """ waveform = self.device.tensor_to(waveform) features = self.model.forward(waveform, sample_rate, causal=causal, output_layer=output_layer) if self.kmeans is not None: features = self.kmeans.forward(features) return features
[docs] def predict_in_chunks( self, waveform: Tensor | np.ndarray, sample_rate: int, chunk_size: int = 16_000 * 10, output_layer: int | float | None = None, causal: bool = False, ) -> Tensor: """Gets the hidden states for the given waveform, in chunks. This is useful for processing very long waveforms, as it allows you to process the waveform in chunks, rather than loading the entire waveform into memory at once. Args: waveform: The waveform to get hidden states for, with shape (B, T) sample_rate: The waveform's sampling rate; this is only used to verify that it is 16 kHz, since it is easy for downstream applications to forget. chunk_size: The size of each chunk to process, in frames. output_layer: The layer to get hidden states from. If `None`, will return the hidden states from the last layer. If an `int`, will return the hidden states from that layer. If a `float`, will return the hidden states from the layer at that percentage of the model. For example, `0.5` will return the hidden states from the middle layer. Negative values will wrap around. causal: If set, use a causal attention mask. Returns: The hidden states for the given waveform, with shape (B, T, D) """ with torch.inference_mode(), self.device.autocast_context(): x = self.device.tensor_to(waveform) # Loads entire waveform into device memory. if self.model.pre_normalize: x = F.layer_norm(x, x.shape) feat = [] for start in range(0, x.size(1), chunk_size): x_chunk = x[:, start : start + chunk_size] feat_chunk = self.model.forward(x_chunk, sample_rate, causal=causal, output_layer=output_layer) if self.kmeans is not None: feat_chunk = self.kmeans.forward(feat_chunk) feat.append(feat_chunk.cpu()) return torch.cat(feat, 1).squeeze(0)
[docs] def predict_file( self, path: str | Path, chunk_length_sec: float = 10.0, output_layer: int | float | None = None, causal: bool = False, ) -> Tensor: """Gets the hidden states for the given audio file, in chunks. Args: path: The path to the audio file to process. sample_rate: The waveform's sampling rate; this is only used to verify that it is 16 kHz, since it is easy for downstream applications to forget. chunk_length_sec: The length of each chunk to process, in seconds. output_layer: The layer to get hidden states from. If `None`, will return the hidden states from the last layer. If an `int`, will return the hidden states from that layer. If a `float`, will return the hidden states from the layer at that percentage of the model. For example, `0.5` will return the hidden states from the middle layer. Negative values will wrap around. causal: If set, use a causal attention mask. Returns: The hidden states for the given waveform, with shape (B, T, D) """ props = get_audio_props(path) effects: list[tuple[str, str]] = [("gain", "-n"), ("channels", "1")] if props.sample_rate != self.sample_rate: effects.append(("rate", str(self.sample_rate))) chunk_length = round(chunk_length_sec * self.sample_rate) with torch.inference_mode(), self.device.autocast_context(): feat = [] for waveform_chunk in read_audio( path, chunk_length=chunk_length, sample_rate=self.sample_rate, ): waveform_tensor = torch.from_numpy(waveform_chunk).to(torch.float32) waveform_tensor, _ = ta_sox.apply_effects_tensor(waveform_tensor, props.sample_rate, effects) chans, _ = waveform_tensor.shape assert chans == 1, f"Expected mono-channel audio, got {chans} channels" x = self.device.tensor_to(waveform_tensor) if self.model.pre_normalize: x = F.layer_norm(x, x.shape[1:]) feat_chunk = self.model.forward(x, self.sample_rate, causal=causal, output_layer=output_layer) if self.kmeans is not None: feat_chunk = self.kmeans.forward(feat_chunk) feat.append(feat_chunk.cpu()) return torch.cat(feat, 1).squeeze(0)
EXCLUDE_KEYS = {"masked_spec_embed", ".weight", ".bias"} def _load_pretrained_hubert( size: PretrainedHubertSize, ckpt_url: str, sha256: str, hidden_size: int, intermediate_size: int, num_hidden_layers: int, num_attention_heads: int, remove_prefix: str | None = None, load_weights: bool = True, conv_dim: tuple[int, ...] = (512, 512, 512, 512, 512, 512, 512), conv_stride: tuple[int, ...] = (5, 2, 2, 2, 2, 2, 2), conv_kernel: tuple[int, ...] = (10, 3, 3, 3, 3, 2, 2), conv_bias: bool = True, num_conv_pos_embeddings: int = 128, num_conv_pos_embedding_groups: int = 16, do_stable_layer_norm: bool = True, pre_normalize: bool = True, feat_extract_norm: Literal["group", "layer"] = "layer", feat_extract_activation: ActivationType = "gelu", feat_proj_layer_norm: bool = True, hidden_act: ActivationType = "gelu", layer_norm_eps: float = 1e-5, hidden_dropout: float = 0.1, activation_dropout: float = 0.1, feat_proj_dropout: float = 0.0, ) -> Hubert: with Timer("building empty model", spinner=True): model = Hubert( hidden_size=hidden_size, intermediate_size=intermediate_size, num_hidden_layers=num_hidden_layers, num_attention_heads=num_attention_heads, conv_dim=conv_dim, conv_stride=conv_stride, conv_kernel=conv_kernel, conv_bias=conv_bias, num_conv_pos_embeddings=num_conv_pos_embeddings, num_conv_pos_embedding_groups=num_conv_pos_embedding_groups, do_stable_layer_norm=do_stable_layer_norm, pre_normalize=pre_normalize, feat_extract_norm=feat_extract_norm, feat_extract_activation=feat_extract_activation, feat_proj_layer_norm=feat_proj_layer_norm, hidden_act=hidden_act, layer_norm_eps=layer_norm_eps, hidden_dropout=hidden_dropout, activation_dropout=activation_dropout, feat_proj_dropout=feat_proj_dropout, ) # Loads the model weights. if load_weights: model_fname = f"{size}.bin" with Timer("downloading checkpoint"): model_path = ensure_downloaded(ckpt_url, "hubert", model_fname, sha256=sha256) with Timer("loading checkpoint", spinner=True): ckpt = torch.load(model_path, map_location="cpu") if remove_prefix: ckpt = {k[len(remove_prefix) :]: v for k, v in ckpt.items()} ckpt = {k: v for k, v in ckpt.items() if k not in EXCLUDE_KEYS} model.load_state_dict(ckpt) return model def _load_pretrained_hubert_kmeans( size: PretrainedHubertKmeansSize, ckpt_url: str, sha256: str, use_triton_if_available: bool = True, ) -> KMeans: centers_fname = f"{size}.npy" with Timer("downloading cluster centers"): centers_path = ensure_downloaded(ckpt_url, "hubert", centers_fname, sha256=sha256) with Timer("loading K-means clusters", spinner=True): centers = np.load(centers_path) return KMeans(centers, use_triton_if_available=use_triton_if_available)
[docs]def pretrained_hubert(size: PretrainedHubertSize, load_weights: bool = True) -> Hubert: match size: case "base": return _load_pretrained_hubert( size, ckpt_url="https://huggingface.co/facebook/hubert-base-ls960/resolve/main/pytorch_model.bin", sha256="062249fffb353eab67547a2fbc129f7c31a2f459faf641b19e8fb007cc5c48ad", hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, load_weights=load_weights, feat_extract_norm="group", conv_bias=False, do_stable_layer_norm=False, pre_normalize=False, ) case "large": return _load_pretrained_hubert( size, ckpt_url="https://huggingface.co/facebook/hubert-large-ls960-ft/resolve/main/pytorch_model.bin", sha256="9cf43abec3f0410ad6854afa4d376c69ccb364b48ddddfd25c4c5aa16398eab0", hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, intermediate_size=4096, remove_prefix="hubert.", load_weights=load_weights, ) case "extra_large": return _load_pretrained_hubert( size, ckpt_url="https://huggingface.co/facebook/hubert-xlarge-ll60k/resolve/main/pytorch_model.bin", sha256="6131dc27f4508595daa1a13fec4aa1f6b4a579b5d93550bae26c13a83221f8a7", hidden_size=1280, num_hidden_layers=48, num_attention_heads=16, intermediate_size=5120, load_weights=load_weights, ) case _: raise NotImplementedError(f"Invalid size: {size}")
[docs]def pretrained_kmeans_clusters(size: PretrainedHubertKmeansSize) -> KMeans: url_base = "https://huggingface.co/codekansas/hubert-quantization/resolve/main" match size: case "base-l7-c100": return _load_pretrained_hubert_kmeans( size, ckpt_url=f"{url_base}/kmeans_base_7_sklearn_100.npy", sha256="e46d1e2a5d6f83805dd336cf22a4228a902e78c3377141b4aa8e8c946af160cb", ) case "base-l7-c200": return _load_pretrained_hubert_kmeans( size, ckpt_url=f"{url_base}/kmeans_base_7_sklearn_200.npy", sha256="5bce95ff25b8e3e07170f73bfcf7a5c72a432a9acd3382e833409a30a41ce062", ) case "base-l7-c500": return _load_pretrained_hubert_kmeans( size, ckpt_url=f"{url_base}/kmeans_base_7_sklearn_500.npy", sha256="ce9855b89955affbf8e939ff274a4938efee730d4fb4fab990070747744b9df0", ) case "base-l7-c1000": return _load_pretrained_hubert_kmeans( size, ckpt_url=f"{url_base}/kmeans_base_7_sklearn_1000.npy", sha256="6a10e5978bac1b84a3b0e03bb72e3015d0cdf6956e301a48971eb3a2493e37c5", ) case "base-l8-c100": return _load_pretrained_hubert_kmeans( size, ckpt_url=f"{url_base}/kmeans_base_8_sklearn_100.npy", sha256="3219a01b5ec21ca173605fe5b2d7b296db1a10ef24e5c593c8076b1b39f96865", ) case "base-l8-c200": return _load_pretrained_hubert_kmeans( size, ckpt_url=f"{url_base}/kmeans_base_8_sklearn_200.npy", sha256="0beab85b59604841da10b3327bedc710e0dbf8e4a2b24bc0d964bf345640e9d7", ) case "base-l8-c500": return _load_pretrained_hubert_kmeans( size, ckpt_url=f"{url_base}/kmeans_base_8_sklearn_500.npy", sha256="4a06731ef6d8aa116ae05ec309ad1ae47b7c030f05bc62137899b17d32fd294a", ) case "base-l8-c1000": return _load_pretrained_hubert_kmeans( size, ckpt_url=f"{url_base}/kmeans_base_8_sklearn_1000.npy", sha256="15be942383cf9e5afc3d6f0d615ab6dc8459364129dc1a02ee00f8c927783aae", ) case "base-l10-c100": return _load_pretrained_hubert_kmeans( size, ckpt_url=f"{url_base}/kmeans_base_10_sklearn_100.npy", sha256="22918f566c2308c14e9514830ddc669a26dfbed2668ce27eb8f57801b78d27b7", ) case "base-l20-c200": return _load_pretrained_hubert_kmeans( size, ckpt_url=f"{url_base}/kmeans_base_10_sklearn_200.npy", sha256="92cff38a19df79303a8d18c98fb776cbed299b92944195a3793eda16a7f6da97", ) case _: raise NotImplementedError(f"Invalid size: {size}")
[docs]def pretrained_hubert_with_kmeans( size: PretrainedHubertKmeansSize, load_weights: bool = True, ) -> tuple[Hubert, KMeans]: kmeans = pretrained_kmeans_clusters(size) match size: case "base-l7-c100": hubert = pretrained_hubert("base", load_weights=load_weights) hubert.set_output_layer(7) return hubert, kmeans case "base-l7-c200": hubert = pretrained_hubert("base", load_weights=load_weights) hubert.set_output_layer(7) return hubert, kmeans case "base-l7-c500": hubert = pretrained_hubert("base", load_weights=load_weights) hubert.set_output_layer(7) return hubert, kmeans case "base-l7-c1000": hubert = pretrained_hubert("base", load_weights=load_weights) hubert.set_output_layer(7) return hubert, kmeans case "base-l8-c100": hubert = pretrained_hubert("base", load_weights=load_weights) hubert.set_output_layer(8) return hubert, kmeans case "base-l8-c200": hubert = pretrained_hubert("base", load_weights=load_weights) hubert.set_output_layer(8) return hubert, kmeans case "base-l8-c500": hubert = pretrained_hubert("base", load_weights=load_weights) hubert.set_output_layer(8) return hubert, kmeans case "base-l8-c1000": hubert = pretrained_hubert("base", load_weights=load_weights) hubert.set_output_layer(8) return hubert, kmeans case "base-l10-c100": hubert = pretrained_hubert("base", load_weights=load_weights) hubert.set_output_layer(10) return hubert, kmeans case "base-l10-c200": hubert = pretrained_hubert("base", load_weights=load_weights) hubert.set_output_layer(10) return hubert, kmeans case _: raise NotImplementedError(f"Invalid size: {size}")
[docs]def test_hubert_adhoc() -> None: configure_logging() parser = argparse.ArgumentParser() size_choices = get_args(PretrainedHubertSize) + get_args(PretrainedHubertKmeansSize) parser.add_argument("size", type=str, choices=size_choices) parser.add_argument("-t", "--tsz", type=int, default=22400) parser.add_argument("-n", "--no-load-weights", default=False, action="store_true") parser.add_argument("-c", "--causal", default=False, action="store_true") args = parser.parse_args() # Loads the model and moves to the right device. kmeans: KMeans | None if args.size in get_args(PretrainedHubertSize): model = pretrained_hubert(size=cast(PretrainedHubertSize, args.size), load_weights=not args.no_load_weights) kmeans = None else: model, kmeans = pretrained_hubert_with_kmeans(size=cast(PretrainedHubertKmeansSize, args.size)) predictor = model.predictor(kmeans) # Test the model on a random waveform. y = predictor.predict(torch.randn(1, args.tsz), sample_rate=16000, causal=args.causal) assert (args.tsz // 320) == y.shape[1] + 1
if __name__ == "__main__": # python -m pretrained.hubert test_hubert_adhoc()