pretrained.wav_codec

Defines a simple API for an audio quantizer model that runs directly on single-channel 16 kHz waveforms.

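import torch

from pretrained.wav_codec import pretrained_wav_codec

model = pretrained_wav_codec("base")
quantizer, dequantizer = model.quantizer(), model.dequantizer()

# One second of single-channel audio at 16 kHz, with shape (B, T).
audio = torch.randn(1, 16000)

# Convert the audio to quantized tokens with shape (N, B, Tq). The second
# return value is the encoder state, used when streaming chunk by chunk.
tokens, encoder_state = quantizer(audio)

# Convert the tokens back to a waveform with shape (B, T), plus decoder state.
reconstructed, decoder_state = dequantizer(tokens)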

pretrained.wav_codec.cast_pretrained_mel_codec_type(s: str | Literal['base']) → Literal['base']

pretrained.wav_codec.split_waveform(waveform: Tensor, stride_length: int, waveform_prev: Tensor | None = None) → tuple[torch.Tensor, torch.Tensor]
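split_waveform has no docstring. Judging only from its signature, it appears to split an incoming chunk into a part whose length is a whole multiple of stride_length plus a leftover that is prepended on the next call via waveform_prev. The function below is a hypothetical sketch of that behavior (split_waveform_sketch is an illustrative name, not part of this module), and the real implementation may differ.

```python
from __future__ import annotations

import torch
from torch import Tensor


def split_waveform_sketch(
    waveform: Tensor,
    stride_length: int,
    waveform_prev: Tensor | None = None,
) -> tuple[Tensor, Tensor]:
    """Hypothetical: return (stride-aligned samples, leftover samples)."""
    # Prepend any samples left over from the previous call.
    if waveform_prev is not None:
        waveform = torch.cat((waveform_prev, waveform), dim=-1)
    # Keep the longest prefix whose length is a multiple of stride_length.
    usable = (waveform.shape[-1] // stride_length) * stride_length
    return waveform[..., :usable], waveform[..., usable:]
```

Carrying the remainder forward keeps every downstream call aligned to whole frames, so no samples are lost at chunk boundaries.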

class pretrained.wav_codec.CBR(in_channels: int, out_channels: int, kernel_size: int)

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x: Tensor, state: tuple[torch.Tensor, int] | None = None) → tuple[torch.Tensor, tuple[torch.Tensor, int]]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
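The name CBR suggests a Conv, BatchNorm, ReLU block, and the state argument suggests streaming. The sketch below assumes exactly that: a causal Conv1d whose left context is cached between calls. CausalCBRSketch is a hypothetical name, and it keeps only a tensor cache; the int half of the real (Tensor, int) state is not modeled because its meaning is not documented here.

```python
from __future__ import annotations

import torch
from torch import Tensor, nn


class CausalCBRSketch(nn.Module):
    """Hypothetical causal Conv1d -> BatchNorm1d -> ReLU with a context cache."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        # No built-in padding: left context comes from the cache instead.
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size)
        self.bn = nn.BatchNorm1d(out_channels)
        self.act = nn.ReLU()

    def forward(self, x: Tensor, state: Tensor | None = None) -> tuple[Tensor, Tensor]:
        # x: (B, C_in, T). On the first call, the left context is all zeros.
        if state is None:
            state = x.new_zeros(x.shape[0], x.shape[1], self.kernel_size - 1)
        x = torch.cat((state, x), dim=-1)
        # Cache the last kernel_size - 1 input frames for the next call.
        new_state = x[..., x.shape[-1] - (self.kernel_size - 1):]
        # Output has the same length T as the new (uncached) input.
        return self.act(self.bn(self.conv(x))), new_state
```

In eval mode BatchNorm is a fixed per-channel affine map, so feeding a signal in chunks while threading the cache gives the same result as one full-length call.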

class pretrained.wav_codec.Encoder(stride_length: int, d_model: int, kernel_size: int = 5)

Bases: Module

forward(waveform: Tensor, state: tuple[torch.Tensor, list[tuple[torch.Tensor, int]]] | None = None) → tuple[torch.Tensor, tuple[torch.Tensor, list[tuple[torch.Tensor, int]]]]

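The Encoder's stride_length presumably sets how many waveform samples each output frame summarizes, so a (B, T) waveform becomes roughly T // stride_length frames of width d_model. The sketch below shows only that shape bookkeeping, using non-overlapping framing and a linear projection. FramingEncoderSketch is hypothetical; the real Encoder also takes a kernel_size and threads a list of per-layer states, so it is likely more elaborate.

```python
from __future__ import annotations

import torch
from torch import Tensor, nn


class FramingEncoderSketch(nn.Module):
    """Hypothetical encoder: one d_model vector per stride_length samples."""

    def __init__(self, stride_length: int, d_model: int) -> None:
        super().__init__()
        self.stride_length = stride_length
        self.proj = nn.Linear(stride_length, d_model)

    def forward(self, waveform: Tensor) -> Tensor:
        # (B, T) -> (B, T // stride_length, stride_length); a trailing partial
        # frame is dropped here rather than carried over to the next call.
        frames = waveform.unfold(-1, self.stride_length, self.stride_length)
        # Project each frame to a d_model-dimensional feature: (B, Tq, d_model).
        return self.proj(frames)
```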
class pretrained.wav_codec.Decoder(stride_length: int, hidden_size: int, num_layers: int)

Bases: Module

forward(toks: Tensor, waveform: Tensor, state: tuple[torch.Tensor, torch.Tensor] | None = None) → tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]

infer(toks: Tensor, state: tuple[torch.Tensor, torch.Tensor] | None = None) → tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]
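Decoder takes hidden_size and num_layers, and its state is a pair of tensors, which matches the (h, c) state of torch.nn.LSTM. Assuming the decoder is an LSTM over frame features (not confirmed by these docs), infer can stream: passing the returned state into the next call continues the sequence. The sketch below is hypothetical (LSTMDecoderSketch is an illustrative name, and it consumes already-embedded frames of width hidden_size rather than raw tokens); each frame is mapped to stride_length output samples.

```python
from __future__ import annotations

import torch
from torch import Tensor, nn


class LSTMDecoderSketch(nn.Module):
    """Hypothetical streaming decoder whose state is the LSTM (h, c) pair."""

    def __init__(self, stride_length: int, hidden_size: int, num_layers: int) -> None:
        super().__init__()
        self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        # Each hidden vector becomes stride_length waveform samples.
        self.out = nn.Linear(hidden_size, stride_length)

    def infer(
        self, frames: Tensor, state: tuple[Tensor, Tensor] | None = None
    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
        # frames: (B, Tq, hidden_size); state: (h, c), each (num_layers, B, hidden_size).
        hidden, state = self.rnn(frames, state)
        # (B, Tq, stride_length) -> (B, Tq * stride_length)
        return self.out(hidden).flatten(1), state
```

Because the LSTM carries all of its context in (h, c), splitting the frames at any point and threading the state reproduces a single full-length call up to floating-point rounding.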

class pretrained.wav_codec.WavCodec(stride_length: int, hidden_size: int, num_layers: int, codebook_size: int, num_quantizers: int)

Bases: Module

forward(waveform: Tensor) → tuple[torch.Tensor, torch.Tensor]

encode(waveform: Tensor, waveform_prev: Tensor | None = None) → tuple[torch.Tensor, torch.Tensor]

decode(toks: Tensor, state: tuple[torch.Tensor, torch.Tensor] | None = None) → tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]

quantizer() → WavCodecQuantizer

dequantizer() → WavCodecDequantizer
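WavCodec is parameterized by codebook_size and num_quantizers, and its tokens have shape (N, B, Tq). That is consistent with residual vector quantization (RVQ), where N = num_quantizers codebooks of codebook_size entries each quantize whatever the previous stages left over. These docs do not confirm the scheme, so the functions below are a generic, hypothetical RVQ sketch rather than WavCodec's actual quantizer.

```python
from __future__ import annotations

import torch
from torch import Tensor


def residual_quantize(x: Tensor, codebooks: Tensor) -> tuple[Tensor, Tensor]:
    """x: (B, Tq, D); codebooks: (N, K, D). Returns tokens (N, B, Tq) and recon (B, Tq, D)."""
    residual = x
    recon = torch.zeros_like(x)
    tokens = []
    for codebook in codebooks:  # one (K, D) codebook per stage
        # Squared distance from each residual vector to each code: (B, Tq, K).
        dists = (residual.unsqueeze(-2) - codebook).pow(2).sum(-1)
        idx = dists.argmin(-1)  # nearest code index, (B, Tq)
        chosen = codebook[idx]  # (B, Tq, D)
        recon = recon + chosen
        residual = residual - chosen  # the next stage quantizes what is left
        tokens.append(idx)
    return torch.stack(tokens), recon


def residual_dequantize(tokens: Tensor, codebooks: Tensor) -> Tensor:
    """Sum the selected code vector from every stage: (B, Tq, D)."""
    out = torch.zeros(
        *tokens.shape[1:], codebooks.shape[-1],
        dtype=codebooks.dtype, device=codebooks.device,
    )
    for codebook, idx in zip(codebooks, tokens):
        out = out + codebook[idx]
    return out
```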

class pretrained.wav_codec.WavCodecQuantizer(model: WavCodec)

Bases: Module

encode(waveform: Tensor, state: tuple[torch.Tensor, list[tuple[torch.Tensor, int]]] | None = None) → tuple[torch.Tensor, tuple[torch.Tensor, list[tuple[torch.Tensor, int]]]]

Converts a waveform into a set of tokens.

Parameters:
  • waveform – The single-channel input waveform, with shape (B, T). It should be sampled at 16000 Hz.

  • state – The encoder state from the previous step.

Returns:

The quantized tokens, with shape (N, B, Tq), along with the encoder state to pass to the next step.
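Assuming each token frame covers stride_length input samples and that N equals num_quantizers, the token rate and bitrate follow from simple arithmetic. The helper below is hypothetical, and the numbers in the example are illustrative, not the settings of any shipped checkpoint.

```python
import math


def token_rate_and_bitrate(
    sample_rate: int, stride_length: int, num_quantizers: int, codebook_size: int
) -> tuple[float, float]:
    # Tq frames per second: one frame per stride_length input samples.
    frames_per_second = sample_rate / stride_length
    # Each frame carries num_quantizers tokens of log2(codebook_size) bits each.
    bits_per_frame = num_quantizers * math.log2(codebook_size)
    return frames_per_second, frames_per_second * bits_per_frame


# Illustrative values only: 16 kHz input, stride 320, 8 codebooks of 1024 entries.
fps, bps = token_rate_and_bitrate(16000, 320, 8, 1024)
```

With those values, one second of audio becomes 50 frames, so a (1, 16000) waveform would yield tokens of shape (8, 1, 50) at 4000 bits per second.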

forward(waveform: Tensor, state: tuple[torch.Tensor, list[tuple[torch.Tensor, int]]] | None = None) → tuple[torch.Tensor, tuple[torch.Tensor, list[tuple[torch.Tensor, int]]]]

class pretrained.wav_codec.WavCodecDequantizer(model: WavCodec)

Bases: Module

decode(toks: Tensor, state: tuple[torch.Tensor, torch.Tensor] | None = None) → tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]

Converts a set of tokens into a waveform.

Parameters:
  • toks – The quantized tokens, with shape (N, B, Tq).

  • state – The previous state of the decoder.

Returns:

The single-channel output waveform, with shape (B, T), along with the new state of the decoder.
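Both encode and decode return a state that can be passed into the next call, which allows a long signal to be processed in chunks. The helper below is a hypothetical sketch of that loop (run_streaming and running_sum_step are illustrative names): it accepts any callable with the same (input, state) -> (output, state) convention and concatenates outputs along the last (time) axis. It is exercised with a toy running-sum step so it runs without model weights.

```python
from __future__ import annotations

from typing import Callable, Iterable, Optional, TypeVar

import torch
from torch import Tensor

S = TypeVar("S")


def run_streaming(
    step: Callable[[Tensor, Optional[S]], tuple[Tensor, S]],
    chunks: Iterable[Tensor],
) -> Tensor:
    """Feed chunks through a stateful step, threading state between calls."""
    state: Optional[S] = None
    outputs = []
    for chunk in chunks:
        out, state = step(chunk, state)
        outputs.append(out)
    # Chunk outputs are contiguous in time, so join them on the last axis.
    return torch.cat(outputs, dim=-1)


def running_sum_step(x: Tensor, state: Tensor | None) -> tuple[Tensor, Tensor]:
    # Toy stateful step: a cumulative sum that continues from the previous total.
    total = x.new_zeros(x.shape[:-1]) if state is None else state
    y = x.cumsum(-1) + total.unsqueeze(-1)
    return y, y[..., -1]
```

With the documented API the same loop would read run_streaming(dequantizer.decode, toks.split(50, dim=-1)), since toks has time as its last axis and decode returns (waveform, state). Whether chunked decoding exactly matches one full-length call depends on the decoder being causal, which these docs do not state.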

forward(toks: Tensor, state: tuple[torch.Tensor, torch.Tensor] | None = None) → tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]

pretrained.wav_codec.pretrained_wav_codec(key: str | Literal['base'], load_weights: bool = True) → WavCodec

pretrained.wav_codec.test_codec_adhoc() → None

pretrained.wav_codec.test_codec_training_adhoc() → None