pretrained.llama
Defines a simple API for using Meta’s pretrained LLaMa model.
This code is adapted from Meta's original implementation, modified to use the parallelism primitives in this codebase.
from pretrained.llama import pretrained_llama
model = pretrained_llama("7B")
predictor = model.predictor()
predictor.predict("The quick brown fox jumps over the lazy dog.")
Using the tokenizer requires installing the sentencepiece library:
pip install sentencepiece
The choices for the model key are:
"7B"
"13B"
"30B"
"65B"
- class pretrained.llama.ModelArgs(dim: int = '???', n_layers: int = '???', n_heads: int = '???', mp_size: int = '???', vocab_size: int = '???', multiple_of: int = 256, norm_eps: float = 0.0001, max_seq_len: int = 2048, use_checkpointing: bool = True)[source]
Bases: object
- dim: int = '???'
- n_layers: int = '???'
- n_heads: int = '???'
- mp_size: int = '???'
- vocab_size: int = '???'
- multiple_of: int = 256
- norm_eps: float = 0.0001
- max_seq_len: int = 2048
- use_checkpointing: bool = True
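For reference, a 7B-scale configuration could be constructed roughly as follows. The dim / n_layers / n_heads / vocab_size values below are the published LLaMA 7B hyperparameters rather than values read from this codebase, and mp_size=1 assumes no model parallelism:

from pretrained.llama import ModelArgs

# Hypothetical 7B-scale configuration; the specific values follow the
# published LLaMA 7B hyperparameters, not this codebase.
args = ModelArgs(
    dim=4096,
    n_layers=32,
    n_heads=32,
    mp_size=1,  # assumed: no model parallelism
    vocab_size=32_000,
    multiple_of=256,
    norm_eps=1e-4,
    max_seq_len=2048,
)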
- class pretrained.llama.RMSNorm(dim: int, eps: float = 1e-06)[source]
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor) Tensor [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
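RMSNorm rescales each vector by its root-mean-square instead of subtracting a mean; a minimal sketch of the standard formulation (not necessarily this module's exact code) looks like:

import torch
from torch import Tensor, nn

class RMSNormSketch(nn.Module):
    # Reference sketch of RMS normalization with a learned per-channel gain.
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        # Normalize by the RMS over the last dimension, then rescale.
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight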
- pretrained.llama.apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) tuple[torch.Tensor, torch.Tensor] [source]
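Rotary embeddings rotate each adjacent pair of query/key channels by a position-dependent angle encoded in freqs_cis. The standard complex-valued formulation, sketched here under the assumption that freqs_cis already broadcasts against the complex-viewed tensors, is:

import torch
from torch import Tensor

def apply_rotary_emb_sketch(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # View channel pairs as complex numbers, rotate by the unit-magnitude
    # frequencies, then flatten back to real-valued tensors.
    xq_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_c = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_c * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_c * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)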
- class pretrained.llama.Attention(args: ModelArgs, lora_rank: int | None = None)[source]
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, freqs_cis: Tensor, is_causal: bool, cache: tuple[torch.Tensor, torch.Tensor] | None = None) tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]] [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
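The cache argument and the second element of the return value carry the keys and values from earlier decoding steps, so each new token can attend over the whole prefix without recomputing it. A sketch of the append-to-cache pattern, assuming keys and values are concatenated along the sequence dimension (here dim=1):

import torch
from torch import Tensor

def append_kv_cache(k: Tensor, v: Tensor, cache: tuple[Tensor, Tensor] | None) -> tuple[Tensor, Tensor]:
    # Prefill pass: no cache, the full prompt's keys/values are returned.
    # Decode passes: concatenate the new step onto the cached tensors.
    if cache is not None:
        prev_k, prev_v = cache
        k = torch.cat([prev_k, k], dim=1)
        v = torch.cat([prev_v, v], dim=1)
    return k, v  # also handed back to the caller as the updated cache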
- class pretrained.llama.FeedForward(dim: int, hidden_dim: int, multiple_of: int, lora_rank: int | None = None)[source]
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor) Tensor [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
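The feed-forward block is a gated (SwiGLU-style) MLP whose hidden width is rounded up to a multiple of multiple_of; a minimal sketch following the original LLaMA formulation (not necessarily this module's exact code):

import torch.nn.functional as F
from torch import Tensor, nn

class FeedForwardSketch(nn.Module):
    # Gated feed-forward block in the style used by LLaMA.
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int) -> None:
        super().__init__()
        # Shrink the nominal hidden width to 2/3, then round up to a
        # multiple of `multiple_of` (as in the original LLaMA code).
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # SwiGLU: SiLU-gated up-projection, elementwise gate, down-projection.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))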
- class pretrained.llama.TransformerBlock(layer_id: int, args: ModelArgs, lora_rank: int | None = None)[source]
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- run_attn(x: Tensor, freqs_cis: Tensor, is_causal: bool, cache: tuple[torch.Tensor, torch.Tensor] | None = None) tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]] [source]
- forward(x: Tensor, freqs_cis: Tensor, is_causal: bool, cache: tuple[torch.Tensor, torch.Tensor] | None = None) tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]] [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
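Each block wires attention and the feed-forward network in the usual pre-norm residual pattern; ignoring the rotary frequencies, cache, and checkpointing details, the data flow is roughly:

from torch import Tensor, nn

class TransformerBlockSketch(nn.Module):
    # Pre-norm residual wiring only; the attention, feed-forward, and norm
    # submodules are assumed to be constructed elsewhere.
    def __init__(self, attn: nn.Module, ffn: nn.Module, attn_norm: nn.Module, ffn_norm: nn.Module) -> None:
        super().__init__()
        self.attn, self.ffn = attn, ffn
        self.attn_norm, self.ffn_norm = attn_norm, ffn_norm

    def forward(self, x: Tensor) -> Tensor:
        h = x + self.attn(self.attn_norm(x))
        return h + self.ffn(self.ffn_norm(h))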
- class pretrained.llama.Llama(params: ModelArgs, tokenizer: Tokenizer | None = None, lora_rank: int | None = None)[source]
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- freqs_cis: Tensor
- forward(tokens: Tensor) tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]] [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- infer(tokens: Tensor, max_gen_len: int, temperature: float, top_p: float, eos_id: int | None = None) Iterator[tuple[torch.Tensor, torch.Tensor]] [source]
Runs model inference for a token sequence.
- Parameters:
tokens – The input tokens, with shape (T).
max_gen_len – The maximum number of tokens to generate.
temperature – The softmax temperature.
top_p – The top-p sampling threshold.
eos_id – The EOS token ID; if not provided, generate as many tokens as possible.
- Yields:
The generated token sequence, with shape (T + N), along with the associated logits, with shape (N, V).
- predictor() LlamaPredictor [source]
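For lower-level control than the predictor, the infer iterator can be driven directly on raw token IDs. A sketch, where the token values are placeholders and the assumption is that each yield pairs the running sequence with the logits for the tokens generated so far:

import torch

from pretrained.llama import pretrained_llama

model = pretrained_llama("7B")

# Prompt as raw token IDs, shape (T); the values here are placeholders.
prompt_tokens = torch.tensor([1, 450, 4996, 17354, 1701, 29916], dtype=torch.long)

for tokens, logits in model.infer(prompt_tokens, max_gen_len=32, temperature=0.8, top_p=0.95):
    pass  # `tokens` has shape (T + N), `logits` has shape (N, V)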
- class pretrained.llama.LlamaPredictor(llama_model: Llama, *, device: base_device | None = None)[source]
Bases: object
Provides an API for sampling from the LLaMa model.
- Parameters:
llama_model – The LLaMa model.
device – The device to use for sampling. If None, the device will be automatically detected.
- Raises:
ValueError – If the tokenizer is not set.
- generate_for_tokens(prompt_tokens: Tensor, max_gen_len: int = 256, temperature: float = 0.8, top_p: float = 0.95) Iterator[str] [source]
- generate(prompt: str | None = None, max_gen_len: int = 256, temperature: float = 0.8, top_p: float = 0.95) Iterator[str] [source]
- unit_test_forward_matches_infer(prompt: str) bool [source]
Ensures that the forward pass matches the inference pass.
This is a simple unit test which does argmax decoding for the inference pass to get a sequence, then passes the sequence to the forward pass. The output of the forward pass should match.
- Parameters:
prompt – The prompt to use for the unit test.
- Returns:
Whether the forward pass matches the inference pass.
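Generation can be streamed through generate; a usage sketch, assuming each yielded string is an incremental piece of the output:

from pretrained.llama import pretrained_llama

model = pretrained_llama("7B")
predictor = model.predictor()

# Stream generated text as it is sampled.
for piece in predictor.generate("The quick brown fox", max_gen_len=128, temperature=0.8, top_p=0.95):
    print(piece, end="", flush=True)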
- pretrained.llama.get_ckpt_and_tokenizer_path(key: Literal['7B', '13B', '30B', '65B']) tuple[pathlib.Path, pathlib.Path] [source]
- pretrained.llama.pretrained_llama(key: Literal['7B', '13B', '30B', '65B'], *, lora_rank: int | None = None) Llama [source]
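The lora_rank argument is threaded through to the attention and feed-forward layers, presumably adding low-rank adapters for parameter-efficient fine-tuning; a hedged usage sketch (the rank value is illustrative):

from pretrained.llama import pretrained_llama

# Load the 7B checkpoint with rank-8 LoRA adapters (rank chosen for illustration).
model = pretrained_llama("7B", lora_rank=8)

# Inspect which parameters are trainable; with LoRA one would typically train
# only the adapter weights, but this codebase's exact policy is not documented here.
trainable = [n for n, p in model.named_parameters() if p.requires_grad]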