pretrained.llama

Defines a simple API for using Meta’s pretrained LLaMa model.

This code is adapted from Meta's original LLaMa implementation and modified to use the parallelism primitives in this codebase.

from pretrained.llama import pretrained_llama

model = pretrained_llama("7B")
predictor = model.predictor()

# generate() returns an iterator of strings
for text in predictor.generate("The quick brown fox jumps over the lazy dog."):
    print(text)

Using the tokenizer requires installing the sentencepiece library:

pip install sentencepiece

The choices for the model key are:

  • "7B"

  • "13B"

  • "30B"

  • "65B"
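Because the key is typed as a Literal, a string from user input (for example, a command-line argument) has to be validated before use, which appears to be the job of cast_pretrained_llama_key. A minimal sketch of such a check, assuming it simply tests membership in the four keys above; the PretrainedLlamaKey alias and the choice of ValueError are assumptions, not taken from the module:

```python
from typing import Literal, cast, get_args

# Hypothetical alias mirroring the Literal type in the documented signatures.
PretrainedLlamaKey = Literal["7B", "13B", "30B", "65B"]


def cast_pretrained_llama_key(s: str) -> PretrainedLlamaKey:
    """Sketch only: the real function's error type and message may differ."""
    choices = get_args(PretrainedLlamaKey)  # ("7B", "13B", "30B", "65B")
    if s not in choices:
        raise ValueError(f"Invalid LLaMa key {s!r}; expected one of {choices}")
    return cast(PretrainedLlamaKey, s)
```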

pretrained.llama.cast_pretrained_llama_key(s: str) → Literal['7B', '13B', '30B', '65B']
class pretrained.llama.ModelArgs(dim: int = '???', n_layers: int = '???', n_heads: int = '???', mp_size: int = '???', vocab_size: int = '???', multiple_of: int = 256, norm_eps: float = 0.0001, max_seq_len: int = 2048, use_checkpointing: bool = True)

Bases: object

dim: int = '???'
n_layers: int = '???'
n_heads: int = '???'
mp_size: int = '???'
vocab_size: int = '???'
multiple_of: int = 256
norm_eps: float = 0.0001
max_seq_len: int = 2048
use_checkpointing: bool = True
class pretrained.llama.RMSNorm(dim: int, eps: float = 1e-06)

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

reset_parameters() → None
forward(x: Tensor) → Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
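RMSNorm rescales each feature vector by the reciprocal of its root-mean-square and then applies a learned per-channel gain; unlike LayerNorm, it subtracts no mean and adds no bias. A NumPy sketch (NumPy rather than PyTorch, for readability) of the forward computation as written in Meta's reference implementation; the weight argument stands in for the module's learned parameter, which reset_parameters presumably initializes to ones (an assumption):

```python
import numpy as np


def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Root-mean-square over the last (feature) axis; eps guards against division by zero.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    # Normalize, then apply the per-channel gain (shape: (dim,)).
    return (x / rms) * weight
```

With weight set to ones, every output row has a mean square of (almost exactly) 1.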

pretrained.llama.precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) → Tensor
pretrained.llama.reshape_for_broadcast(freqs_cis: Tensor, x: Tensor) → Tensor
pretrained.llama.apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) → tuple[torch.Tensor, torch.Tensor]
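These three helpers implement rotary position embeddings (RoPE). In Meta's reference code, precompute_freqs_cis builds a table of unit complex numbers e^(i·t·θ_k), with θ_k = theta^(−2k/dim), of shape (end, dim // 2); apply_rotary_emb views each adjacent pair of query/key features as one complex number and multiplies it by the table entry for its position; reshape_for_broadcast reshapes the table so it broadcasts over the batch and head axes. A NumPy sketch of the same math, assuming this port keeps the reference behaviour and a (batch, seqlen, n_heads, head_dim) layout:

```python
import numpy as np


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> np.ndarray:
    # One rotation frequency per feature pair: theta ** (-2k / dim), k = 0 .. dim/2 - 1.
    freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: dim // 2] / dim))
    angles = np.outer(np.arange(end), freqs)  # (end, dim // 2): position * frequency
    return np.exp(1j * angles)                # unit complex numbers e^(i * angle)


def apply_rotary_emb(xq: np.ndarray, xk: np.ndarray, freqs_cis: np.ndarray):
    # xq, xk: (batch, seqlen, n_heads, head_dim); freqs_cis: (seqlen, head_dim // 2).
    def rotate(x: np.ndarray) -> np.ndarray:
        assert freqs_cis.shape == (x.shape[1], x.shape[-1] // 2)
        pairs = x.reshape(*x.shape[:-1], -1, 2)
        xc = pairs[..., 0] + 1j * pairs[..., 1]   # adjacent features -> one complex number
        xc = xc * freqs_cis[None, :, None, :]     # broadcast over batch and heads
        return np.stack([xc.real, xc.imag], axis=-1).reshape(x.shape)

    return rotate(xq), rotate(xk)
```

Because each feature pair is only rotated, vector norms are unchanged, position 0 is left as-is, and the dot product between a rotated query at position m and a rotated key at position n depends only on m − n.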
class pretrained.llama.Attention(args: ModelArgs, lora_rank: int | None = None)

Bases: Module

forward(x: Tensor, freqs_cis: Tensor, is_causal: bool, cache: tuple[torch.Tensor, torch.Tensor] | None = None) → tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]

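Attention.forward takes an is_causal flag and an optional cache, and returns the attention output together with a (key, value) tuple. A minimal single-head NumPy sketch of that contract, assuming the cache holds the keys and values of all earlier positions and is extended by concatenation along the sequence axis; the query/key/value/output projections, rotary embeddings, multiple heads, and model parallelism are all omitted:

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def attend(q, k, v, cache=None, is_causal=True):
    """Single-head attention sketch. q, k, v: (new_len, head_dim)."""
    if cache is not None:
        past_k, past_v = cache                  # assumed: keys/values of earlier positions
        k = np.concatenate([past_k, k], axis=0)
        v = np.concatenate([past_v, v], axis=0)
    scores = q @ k.T / np.sqrt(q.shape[-1])    # (new_len, total_len)
    if is_causal:
        new_len, total_len = scores.shape
        offset = total_len - new_len            # number of cached positions
        # Query i sits at absolute position offset + i and may not see later keys.
        future = np.triu(np.ones((new_len, total_len), dtype=bool), k=1 + offset)
        scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ v, (k, v)          # output and the updated cache
```

Processing a prompt in one call and then feeding one token at a time with the returned cache gives the same outputs as running causal attention over the whole sequence at once, which is the property a key/value cache relies on.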

class pretrained.llama.FeedForward(dim: int, hidden_dim: int, multiple_of: int, lora_rank: int | None = None)

Bases: Module

forward(x: Tensor) → Tensor

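FeedForward takes a requested hidden_dim and a multiple_of rounding factor. In Meta's reference implementation the block is a SwiGLU feed-forward layer, w2(silu(w1·x) * w3·x), and the requested hidden size (4 * dim when called from a transformer block) is first shrunk to two thirds and then rounded up to a multiple of multiple_of. A NumPy sketch, assuming this port keeps that rule; the weights are plain arrays here rather than the module's (possibly parallel, possibly LoRA-augmented) linear layers:

```python
import numpy as np


def ffn_hidden_dim(hidden_dim: int, multiple_of: int) -> int:
    # Three weight matrices instead of two: shrink to 2/3 to keep the parameter count similar.
    hidden_dim = int(2 * hidden_dim / 3)
    # Round up to the next multiple of multiple_of.
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


def silu(x: np.ndarray) -> np.ndarray:
    return x / (1.0 + np.exp(-x))


def swiglu_ffn(x: np.ndarray, w1: np.ndarray, w2: np.ndarray, w3: np.ndarray) -> np.ndarray:
    # x: (..., dim); w1, w3: (dim, hidden); w2: (hidden, dim)
    return (silu(x @ w1) * (x @ w3)) @ w2
```

For the 7B configuration (dim = 4096, multiple_of = 256), ffn_hidden_dim(4 * 4096, 256) gives 11008.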

class pretrained.llama.TransformerBlock(layer_id: int, args: ModelArgs, lora_rank: int | None = None)

Bases: Module

run_attn(x: Tensor, freqs_cis: Tensor, is_causal: bool, cache: tuple[torch.Tensor, torch.Tensor] | None = None) → tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]
run_ffn(x: Tensor) → Tensor
forward(x: Tensor, freqs_cis: Tensor, is_causal: bool, cache: tuple[torch.Tensor, torch.Tensor] | None = None) → tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]

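TransformerBlock splits its work into run_attn and run_ffn. In Meta's reference implementation the block is pre-norm: each sublayer sees a normalized input and its output is added back to the residual stream. A sketch of that data flow, with the sublayers passed in as plain callables (a hypothetical signature chosen for illustration). One plausible reason for the two-method split is so each half can be activation-checkpointed when ModelArgs.use_checkpointing is set, but that is an assumption:

```python
from typing import Any, Callable

import numpy as np


def transformer_block(
    x: np.ndarray,
    attn: Callable[[np.ndarray], tuple[np.ndarray, Any]],
    ffn: Callable[[np.ndarray], np.ndarray],
    attention_norm: Callable[[np.ndarray], np.ndarray],
    ffn_norm: Callable[[np.ndarray], np.ndarray],
) -> tuple[np.ndarray, Any]:
    # Attention half (what run_attn is assumed to cover): normalize, attend, add residual.
    attn_out, cache = attn(attention_norm(x))
    h = x + attn_out
    # Feed-forward half (what run_ffn is assumed to cover): normalize, FFN, add residual.
    out = h + ffn(ffn_norm(h))
    return out, cache
```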

class pretrained.llama.Tokenizer(model_path: str | Path)

Bases: object

encode(s: str, bos: bool, eos: bool) → list[int]
decode(t: list[int]) → str
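Tokenizer wraps a SentencePiece model loaded from model_path. In Meta's reference tokenizer, encode converts the string to piece IDs, then prepends the BOS ID if bos is true and appends the EOS ID if eos is true; decode turns IDs back into text. The sketch below reproduces only that bos/eos framing, with a hypothetical whitespace vocabulary standing in for SentencePiece so it runs without a model file; the class name and the special IDs are made up:

```python
class ToyTokenizer:
    """Hypothetical stand-in for Tokenizer: a whitespace vocabulary instead of SentencePiece."""

    def __init__(self, words: list[str]) -> None:
        self.bos_id, self.eos_id = 1, 2  # hypothetical special-token IDs
        self.word_to_id = {w: i + 3 for i, w in enumerate(words)}
        self.id_to_word = {i: w for w, i in self.word_to_id.items()}

    def encode(self, s: str, bos: bool, eos: bool) -> list[int]:
        t = [self.word_to_id[w] for w in s.split()]
        if bos:
            t = [self.bos_id] + t  # mark the start of the sequence
        if eos:
            t = t + [self.eos_id]  # mark the end of the sequence
        return t

    def decode(self, t: list[int]) -> str:
        # Special IDs have no text, so they are skipped in this sketch.
        return " ".join(self.id_to_word[i] for i in t if i in self.id_to_word)
```

In Meta's reference generation code, prompts are encoded with bos=True and eos=False, since the model should continue the text rather than end it.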
class pretrained.llama.Llama(params: ModelArgs, tokenizer: Tokenizer | None = None, lora_rank: int | None = None)

Bases: Module

freqs_cis: Tensor
get_mask(seqlen: int, ref_tensor: Tensor) → Tensor | None
forward(tokens: Tensor) → tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]


infer(tokens: Tensor, max_gen_len: int, temperature: float, top_p: float, eos_id: int | None = None) → Iterator[tuple[torch.Tensor, torch.Tensor]]

Runs model inference for a token sequence.

Parameters:
  • tokens – The input tokens, with shape (T).

  • max_gen_len – The maximum number of tokens to generate.

  • temperature – The softmax temperature.

  • top_p – The top-p sampling threshold.

  • eos_id – The EOS token ID; if not provided, generate as many tokens as possible.

Yields:

The generated token sequence, with shape (T + N), along with the associated logits, with shape (N, V), where N is the number of generated tokens and V is the vocabulary size.
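The temperature and top_p parameters follow the usual recipe from Meta's reference generation code: the logits are divided by the temperature and passed through a softmax, and the next token is drawn from the smallest set of most probable tokens whose cumulative probability reaches top_p (nucleus sampling), with argmax decoding when the temperature is zero. A NumPy sketch of one sampling step, assuming this module follows the same recipe; the explicit rng argument and the next_token helper are additions for illustration, and this sample_top_p returns a single int rather than operating on Tensors like the documented function:

```python
import numpy as np


def sample_top_p(probs: np.ndarray, p: float, rng: np.random.Generator) -> int:
    order = np.argsort(probs)[::-1]         # token IDs, most probable first
    sorted_probs = probs[order]
    mass_before = np.cumsum(sorted_probs) - sorted_probs
    # Keep a token only if the tokens ranked above it hold at most p of the mass;
    # the top token is therefore always kept.
    kept = np.where(mass_before <= p, sorted_probs, 0.0)
    kept = kept / kept.sum()
    return int(order[rng.choice(len(order), p=kept)])


def next_token(logits: np.ndarray, temperature: float, top_p: float,
               rng: np.random.Generator) -> int:
    if temperature <= 0:
        return int(np.argmax(logits))       # greedy (argmax) decoding
    z = logits / temperature                # < 1 sharpens, > 1 flattens the distribution
    probs = np.exp(z - z.max())             # numerically stable softmax
    probs /= probs.sum()
    return sample_top_p(probs, top_p, rng)
```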

predictor() → LlamaPredictor
class pretrained.llama.LlamaPredictor(llama_model: Llama, *, device: base_device | None = None)

Bases: object

Provides an API for sampling from the LLaMa model.

Parameters:
  • llama_model – The LLaMa model.

  • device – The device to use for sampling. If None, the device will be automatically detected.

Raises:

ValueError – If the tokenizer is not set.

tokenize(prompt: str | None) → Tensor
generate_for_tokens(prompt_tokens: Tensor, max_gen_len: int = 256, temperature: float = 0.8, top_p: float = 0.95) → Iterator[str]
generate(prompt: str | None = None, max_gen_len: int = 256, temperature: float = 0.8, top_p: float = 0.95) → Iterator[str]
unit_test_forward_matches_infer(prompt: str) → bool

Ensures that the forward pass matches the inference pass.

This is a simple unit test: it runs the inference pass with argmax (greedy) decoding to produce a sequence, then feeds that full sequence through the forward pass. The output of the forward pass should match that of the inference pass.

Parameters:

prompt – The prompt to use for the unit test.

Returns:

Whether the forward pass matches the inference pass.

pretrained.llama.sample_top_p(probs: Tensor, p: float) → Tensor
pretrained.llama.get_ckpt_and_tokenizer_path(key: Literal['7B', '13B', '30B', '65B']) → tuple[pathlib.Path, pathlib.Path]
pretrained.llama.empty_llama(key: Literal['7B', '13B', '30B', '65B']) → Llama
pretrained.llama.pretrained_llama(key: Literal['7B', '13B', '30B', '65B'], *, lora_rank: int | None = None) → Llama
pretrained.llama.test_worker(key: Literal['7B', '13B', '30B', '65B'], prompt: str | None, max_gen_len: int, temperature: float, top_p: float, pretrained: bool) → None
pretrained.llama.setup() → None
pretrained.llama.test_pretrained_model() → None