# mypy: disable-error-code="import-not-found"
"""Defines a simple API for using Meta's pretrained LLaMa model.
This code is adapted from the original implementation
`here <https://github.com/facebookresearch/llama>`_, adapted to use
the parallelism primitives in this codebase.
.. highlight:: python
.. code-block:: python
from pretrained.llama import pretrained_llama
model = pretrained_llama("7B")
predictor = model.predictor()
predictor.predict("The quick brown fox jumps over the lazy dog.")
Using the tokenizer requires installing the ``sentencepiece`` library:
.. code-block:: bash
pip install sentencepiece
The choices for the model key are:
- ``"7B"``
- ``"13B"``
- ``"30B"``
- ``"65B"``
"""
import argparse
import functools
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator, Literal, cast, get_args
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from ml.core.config import conf_field
from ml.core.env import get_model_dir
from ml.models.lora import maybe_lora
from ml.models.parallel import ColumnParallelLinear, ParallelEmbedding, RowParallelLinear
from ml.utils.device.auto import detect_device
from ml.utils.device.base import base_device
from ml.utils.large_models import init_empty_weights, meta_to_empty_func
from ml.utils.logging import configure_logging
from ml.utils.parallel import parallel_group_info
from ml.utils.timer import Timer
from ml.utils.torch_distributed import MultiprocessConfig, launch_subprocesses
from omegaconf import MISSING
from torch import Tensor, nn
logger = logging.getLogger(__name__)
PretrainedLlamaKey = Literal["7B", "13B", "30B", "65B"]
def cast_pretrained_llama_key(s: str) -> PretrainedLlamaKey:
    if s not in get_args(PretrainedLlamaKey):
        raise KeyError(f"Invalid LLaMa key: {s}; expected one of {get_args(PretrainedLlamaKey)}")
return cast(PretrainedLlamaKey, s)
@dataclass
class ModelArgs:
dim: int = conf_field(MISSING, help="The model inner dimension size")
n_layers: int = conf_field(MISSING, help="The number of layers in the model")
n_heads: int = conf_field(MISSING, help="The number of transformer heads")
mp_size: int = conf_field(MISSING, help="The expected model parallelism size")
vocab_size: int = conf_field(MISSING, help="The vocabulary size")
multiple_of: int = conf_field(256, help="Make SwiGLU hidden layer size a multiple of large power of two")
norm_eps: float = conf_field(1e-4, help="The normalization epsilon value")
max_seq_len: int = conf_field(2048, help="The maximum sequence length")
use_checkpointing: bool = conf_field(True, help="Whether to use checkpointing")
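
# Model configurations for the released checkpoint sizes. ``vocab_size`` is left
# as MISSING here and filled in from the tokenizer when a model is constructed;
# ``mp_size`` is the model-parallel world size the checkpoint was sharded for.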
PRETRAINED_MODEL_SIZES: dict[PretrainedLlamaKey, ModelArgs] = {
"7B": ModelArgs(dim=4096, n_heads=32, n_layers=32, mp_size=1),
"13B": ModelArgs(dim=5120, n_heads=40, n_layers=40, mp_size=2),
"30B": ModelArgs(dim=6656, n_heads=52, n_layers=60, mp_size=4),
"65B": ModelArgs(dim=8192, n_heads=64, n_layers=80, mp_size=8),
}
class RMSNorm(nn.Module):
__constants__ = ["eps"]
def __init__(self, dim: int, eps: float = 1e-6) -> None:
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(dim))
self.reset_parameters()
    def reset_parameters(self) -> None:
if not self.weight.data.is_meta:
nn.init.ones_(self.weight.data)
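
    # RMS normalization: each vector is scaled by the reciprocal of its
    # root-mean-square (computed in fp32 for numerical stability), and
    # ``forward`` then applies the learned per-dimension gain.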
def _norm(self, x: Tensor) -> Tensor:
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
    def forward(self, x: Tensor) -> Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight
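
# Rotary embedding helpers. ``precompute_freqs_cis(dim, end, theta)`` returns a
# complex tensor of shape ``(end, dim // 2)`` whose ``(t, i)`` entry is
# ``exp(1j * t * theta ** (-2 * i / dim))``; multiplying the complex-viewed
# query/key feature pairs by this table rotates each pair by a
# position-dependent angle.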
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
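
# Reshapes the ``(T, D // 2)`` frequency table to ``(1, T, 1, D // 2)`` so that it
# broadcasts against the complex-viewed queries and keys of shape
# ``(B, T, H, D // 2)``.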
def reshape_for_broadcast(freqs_cis: Tensor, x: Tensor) -> Tensor:
ndim = x.ndim
assert 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i in (1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
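
# Applies rotary position embeddings: the last dimension of the queries and keys
# is viewed as complex pairs, multiplied elementwise by ``freqs_cis``, and then
# flattened back to the original real-valued layout.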
def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class Attention(nn.Module):
def __init__(self, args: ModelArgs, lora_rank: int | None = None) -> None:
super().__init__()
self.n_local_heads = args.n_heads // parallel_group_info().mp.world_size
self.head_dim = args.dim // args.n_heads
wq = ColumnParallelLinear(args.dim, args.n_heads * self.head_dim, bias=False, gather_output=False)
wk = ColumnParallelLinear(args.dim, args.n_heads * self.head_dim, bias=False, gather_output=False)
wv = ColumnParallelLinear(args.dim, args.n_heads * self.head_dim, bias=False, gather_output=False)
wo = RowParallelLinear(args.n_heads * self.head_dim, args.dim, bias=False, input_is_parallel=True)
self.wq = maybe_lora(wq, lora_rank)
self.wk = maybe_lora(wk, lora_rank)
self.wv = maybe_lora(wv, lora_rank)
self.wo = maybe_lora(wo, lora_rank)
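
    # The query/key/value projections are column-parallel, so each model-parallel
    # rank only holds ``n_local_heads`` heads; the output projection is row-parallel
    # and reduces across ranks. ``cache``, when provided, is the ``(keys, values)``
    # pair returned by a previous call, with shape ``(B, T_past, H_local, D_head)``.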
    def forward(
self,
x: Tensor,
freqs_cis: Tensor,
is_causal: bool,
cache: tuple[Tensor, Tensor] | None = None,
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
# Gets the cached keys.
if cache is None:
keys, values = xk, xv
else:
cache_k, cache_v = cache
keys, values = torch.cat((cache_k, xk), dim=1), torch.cat((cache_v, xv), dim=1)
# (B, T, H, D) -> (B, H, T, D)
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
# Attention.
output = F.scaled_dot_product_attention(xq, keys, values, is_causal=is_causal).transpose(1, 2).flatten(2)
return self.wo(output), (keys.transpose(1, 2), values.transpose(1, 2))
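
# SwiGLU feed-forward block: computes ``w2(silu(w1(x)) * w3(x))``, with the hidden
# dimension scaled to two thirds of the requested size and rounded up to a
# multiple of ``multiple_of``.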
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
lora_rank: int | None = None,
) -> None:
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False)
w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True)
w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False)
self.w1 = maybe_lora(w1, lora_rank)
self.w2 = maybe_lora(w2, lora_rank)
self.w3 = maybe_lora(w3, lora_rank)
    def forward(self, x: Tensor) -> Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class Tokenizer:
def __init__(self, model_path: str | Path) -> None:
model_path = Path(model_path).resolve()
assert model_path.is_file(), f"Tokenizer model file not found at {model_path}"
try:
from sentencepiece import SentencePieceProcessor
except ImportError:
raise ImportError("Please install sentencepiece with: `pip install sentencepiece`")
self.sp_model = SentencePieceProcessor(model_file=str(model_path))
logger.info("Loaded sentence piece model from %s", model_path)
# Gets the sentence statistics.
self.n_words: int = self.sp_model.vocab_size()
self.bos_id: int = self.sp_model.bos_id()
self.eos_id: int = self.sp_model.eos_id()
self.pad_id: int = self.sp_model.pad_id()
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
logger.info("Number of words: %d, BOS ID: %d, EOS ID: %d", self.n_words, self.bos_id, self.eos_id)
    def encode(self, s: str, bos: bool, eos: bool) -> list[int]:
assert isinstance(s, str)
t = self.sp_model.encode(s)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
    def decode(self, t: list[int]) -> str:
return self.sp_model.decode(t)
class Llama(nn.Module):
freqs_cis: Tensor
def __init__(self, params: ModelArgs, tokenizer: Tokenizer | None = None, lora_rank: int | None = None) -> None:
super().__init__()
self.params = params
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers
self.tokenizer = tokenizer
tok_embeddings = ParallelEmbedding(params.vocab_size, params.dim)
self.tok_embeddings = maybe_lora(tok_embeddings, lora_rank)
self.layers = nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params, lora_rank=lora_rank))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False)
self.output = maybe_lora(output, lora_rank)
freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
if lora_rank is not None:
self.tok_embeddings.requires_grad_(False)
self.norm.requires_grad_(False)
    @torch.no_grad()
def get_mask(self, seqlen: int, ref_tensor: Tensor) -> Tensor | None:
mask = None
if seqlen > 1:
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=ref_tensor.device)
mask = torch.triu(mask, diagonal=1).type_as(ref_tensor)
return mask
    def forward(self, tokens: Tensor) -> tuple[Tensor, list[tuple[Tensor, Tensor]]]:
_, seqlen = tokens.shape
x = self.tok_embeddings(tokens)
freqs_cis = self.freqs_cis[:seqlen]
# Runs the transformer.
caches: list[tuple[Tensor, Tensor]] = []
for layer in self.layers:
x, cache = layer(x, freqs_cis, seqlen > 1)
caches.append(cache)
x = self.norm(x)
logits = self.output(x)
return logits, caches
    @torch.no_grad()
def infer(
self,
tokens: Tensor,
max_gen_len: int,
temperature: float,
top_p: float,
eos_id: int | None = None,
) -> Iterator[tuple[Tensor, Tensor]]:
"""Runs model inference for a token sequence.
Args:
tokens: The input tokens, with shape (T).
max_gen_len: The maximum number of tokens to generate.
temperature: The softmax temperature.
top_p: The top-p sampling threshold.
eos_id: The EOS token ID; if not provided, generate as many tokens
as possible.

        Yields:
            The most recently sampled token, with shape (1, 1), along with the
            associated logits, with shape (1, 1, vocab_size).
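
        Example (a minimal sketch, assuming ``model.tokenizer`` is set and the
        input tokens are on the model's device)::

            tokens = torch.tensor(model.tokenizer.encode("Hello", bos=True, eos=False))
            for next_token, logits in model.infer(tokens, 32, 0.8, 0.95, model.tokenizer.eos_id):
                print(model.tokenizer.decode(next_token.tolist())[0], end="")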
"""
tokens = tokens[None] # (T) -> (B, T)
_, seqlen = tokens.shape
x = self.tok_embeddings(tokens)
freqs_cis = self.freqs_cis[:seqlen]
# Runs the first step of the transformer.
caches: list[tuple[Tensor, Tensor]] = []
for layer in self.layers:
x, cache = layer(x, freqs_cis, seqlen > 1)
caches.append(cache)
x = self.norm(x)
logits = self.output(x[:, -1:])
for i in range(max_gen_len):
# Samples the next token from the sequence.
if temperature > 0:
probs = torch.softmax(logits / temperature, dim=-1)
next_token = sample_top_p(probs.flatten(), top_p)[None]
else:
next_token = torch.argmax(logits, dim=-1)
tokens = torch.cat((tokens, next_token), dim=1)
if (eos_id is not None and (next_token == eos_id).item()) or i == max_gen_len - 1:
break
# Runs the next step of the transformer.
x = self.tok_embeddings(next_token)
freqs_cis = self.freqs_cis[seqlen + i : seqlen + i + 1]
next_caches: list[tuple[Tensor, Tensor]] = []
for layer, cache in zip(self.layers, caches):
x, cache = layer(x, freqs_cis, False, cache)
next_caches.append(cache)
x = self.norm(x)
logits = self.output(x[:, -1:, :])
caches = next_caches
yield next_token, logits
    def predictor(self) -> "LlamaPredictor":
return LlamaPredictor(self)
class LlamaPredictor:
def __init__(self, llama_model: Llama, *, device: base_device | None = None) -> None:
"""Provides an API for sampling from the LLaMa model.
Args:
llama_model: The LLaMa model.
device: The device to use for sampling. If None, the device will be
automatically detected.
Raises:
ValueError: If the tokenizer is not set.
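
        Example (a minimal sketch, assuming the pretrained weights have been
        downloaded)::

            model = pretrained_llama("7B")
            predictor = LlamaPredictor(model)
            for token in predictor.generate("The meaning of life is"):
                print(token, end="", flush=True)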
"""
super().__init__()
self.device = detect_device() if device is None else device
self.device.module_to(llama_model)
tokenizer = llama_model.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer must be set to use predictor")
self.tokenizer = tokenizer
self.model = llama_model
    def tokenize(self, prompt: str | None) -> Tensor:
if prompt is None:
prompt_tokens = torch.full((1, 1), self.tokenizer.bos_id, dtype=torch.long)
else:
prompt_tokens = torch.tensor(self.tokenizer.encode(prompt, bos=True, eos=False))
prompt_tokens = self.device.tensor_to(prompt_tokens)
return prompt_tokens
    @torch.inference_mode()
def generate_for_tokens(
self,
prompt_tokens: Tensor,
max_gen_len: int = 256,
temperature: float = 0.8,
top_p: float = 0.95,
) -> Iterator[str]:
for pred_token, _ in self.model.infer(
prompt_tokens,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
eos_id=self.tokenizer.eos_id,
):
yield self.tokenizer.decode(pred_token.tolist())[0]
    @torch.inference_mode()
def generate(
self,
prompt: str | None = None,
max_gen_len: int = 256,
temperature: float = 0.8,
top_p: float = 0.95,
) -> Iterator[str]:
prompt_tokens = self.tokenize(prompt)
yield from self.generate_for_tokens(
prompt_tokens,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
)
    @torch.no_grad()
def unit_test_forward_matches_infer(self, prompt: str) -> bool:
"""Ensures that the forward pass matches the inference pass.
This is a simple unit test which does argmax decoding for the inference
pass to get a sequence, then passes the sequence to the forward pass.
The output of the forward pass should match.
Args:
prompt: The prompt to use for the unit test.
Returns:
Whether the forward pass matches the inference pass.
"""
test_tokens = torch.tensor(self.tokenizer.encode(prompt, bos=True, eos=False))
test_tokens = self.device.tensor_to(test_tokens)
seqlen = test_tokens.shape[0]
inferred_tokens = self.model.infer(
test_tokens,
max_gen_len=16,
temperature=0.0,
top_p=0.0,
eos_id=None,
)
pred_tokens_list, pred_logits_list = zip(*inferred_tokens)
pred_tokens, pred_logits = torch.cat(pred_tokens_list, dim=1), torch.cat(pred_logits_list, dim=1)
        # ``pred_tokens`` already carries a batch dimension, so it is passed to the
        # forward pass directly.
        fwd_logits, _ = self.model.forward(pred_tokens)
# Need to check with a high atol because models are in FP16.
return torch.allclose(pred_logits[:, 1:], fwd_logits[:, seqlen:-1], atol=1e-2)
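
# Nucleus (top-p) sampling: tokens are sorted by probability, every token whose
# higher-ranked neighbours already account for more than ``p`` of the mass is
# zeroed out, the remainder is renormalized, and one token is sampled from it.
# For example, with probabilities ``[0.5, 0.3, 0.15, 0.05]`` and ``p = 0.7``, only
# the first two tokens survive and are renormalized to ``[0.625, 0.375]``.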
def sample_top_p(probs: Tensor, p: float) -> Tensor:
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
def get_ckpt_and_tokenizer_path(key: PretrainedLlamaKey) -> tuple[Path, Path]:
root_dir = get_model_dir() / "llama"
ckpt_dir = root_dir / key
tokenizer_path = root_dir / "tokenizer.model"
return ckpt_dir, tokenizer_path
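
# Builds a randomly-initialized model with the same architecture as the requested
# pretrained checkpoint (weights in half precision). This is useful when the
# checkpoint itself is not available; only the tokenizer file is required.
# Checkpoints and the tokenizer are expected to live under
# ``get_model_dir() / "llama"``, as returned by ``get_ckpt_and_tokenizer_path``.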
def empty_llama(key: PretrainedLlamaKey) -> Llama:
_, tokenizer_path = get_ckpt_and_tokenizer_path(key)
if not tokenizer_path.exists():
raise ValueError(f"LLaMa tokenizer not found at {tokenizer_path}; download it first")
model_args = PRETRAINED_MODEL_SIZES[key]
with Timer("loading tokenizer", spinner=True):
tokenizer = Tokenizer(tokenizer_path)
model_args.vocab_size = tokenizer.n_words
with Timer("building empty model", spinner=True), init_empty_weights():
model = Llama(model_args, tokenizer)
with Timer("moving model to device", spinner=True):
device = detect_device().get_device()
model._apply(meta_to_empty_func(device, torch.half))
def reset_params(module: nn.Module) -> None:
if hasattr(module, "reset_parameters") and callable(module.reset_parameters):
module.reset_parameters()
with Timer("resetting parameters", spinner=True):
model.apply(reset_params)
return model
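
# Loads the pretrained weights for the current model-parallel rank. The checkpoint
# directory must contain exactly one ``*.pth`` shard per model-parallel rank, plus
# a ``params.json`` file describing the architecture.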
def pretrained_llama(key: PretrainedLlamaKey, *, lora_rank: int | None = None) -> Llama:
rank, world_size = parallel_group_info().mp.rank, parallel_group_info().mp.world_size
ckpt_dir, tokenizer_path = get_ckpt_and_tokenizer_path(key)
if not ckpt_dir.exists():
raise ValueError(f"LLaMa model {key} not found at {ckpt_dir}; download it first")
if not tokenizer_path.exists():
raise ValueError(f"LLaMa tokenizer not found at {tokenizer_path}; download it first")
# Loads the checkpoint for the current rank.
with Timer("loading checkpoint", spinner=True):
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
if world_size != (num_ckpts := len(checkpoints)):
raise ValueError(f"Loading a checkpoint for {num_ckpts=} but {world_size=}")
ckpt_path = checkpoints[rank]
checkpoint = torch.load(ckpt_path, map_location="cpu")
# Loads the checkpoint parameters from the JSON file.
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args = ModelArgs(**params)
# Builds the tokenizer and updates the vocab size.
with Timer("loading tokenizer", spinner=True):
tokenizer = Tokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words
# Builds the model with empty weights.
with Timer("building model skeleton", spinner=True), init_empty_weights():
model = Llama(model_args, tokenizer, lora_rank=lora_rank)
# Logs model summary.
total_params = sum(p.numel() for p in model.parameters())
logger.info("Model %s has %s parameters", key, f"{total_params:,}")
    # Builds the model weights and loads the checkpoint.
with Timer("loading state dict", spinner=True):
model._apply(meta_to_empty_func(torch.device("cuda"), torch.half))
model.load_state_dict(checkpoint)
return model
def test_worker(
key: PretrainedLlamaKey,
prompt: str | None,
max_gen_len: int,
temperature: float,
top_p: float,
pretrained: bool,
) -> None:
model = pretrained_llama(key) if pretrained else empty_llama(key)
predictor = model.predictor()
# Setting the seed across all processes to make sure that the weights
# initialize to the same values (needed to make the test pass).
torch.manual_seed(1337)
def generate_for_prompt(prompt: str) -> None:
print(prompt, end="")
for token in predictor.generate(
prompt=prompt,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
):
print(token, end="", flush=True)
print()
if prompt:
generate_for_prompt(prompt)
else:
prompt = input("Prompt: ")
while prompt:
generate_for_prompt(prompt)
prompt = input("Prompt: ")
def setup() -> None:
# Hides some nuisance logs.
logging.getLogger("torch.distributed").setLevel(logging.ERROR)
logging.getLogger("torch.nn.parallel.distributed").setLevel(logging.ERROR)
logging.getLogger("ml.utils.torch_distributed").setLevel(logging.ERROR)
# Setting the seed across all processes to make sure that the weights
# initialize to the same values (needed to make the test pass).
torch.manual_seed(1337)
def test_pretrained_model() -> None:
parser = argparse.ArgumentParser(description="Tests a pretrained LLaMA model")
parser.add_argument("key", type=str, choices=get_args(PretrainedLlamaKey))
parser.add_argument("prompt", type=str)
parser.add_argument("-m", "--max-gen-len", type=int, default=256)
parser.add_argument("-t", "--temperature", type=float, default=0.8)
parser.add_argument("-p", "--top-p", type=float, default=0.95)
parser.add_argument("-e", "--empty", default=False, action="store_true")
args = parser.parse_args()
configure_logging()
key = args.key
all_args = args.prompt, args.max_gen_len, args.temperature, args.top_p, not args.empty
world_size = PRETRAINED_MODEL_SIZES[key].mp_size
launch_subprocesses(
functools.partial(test_worker, key, *all_args),
MultiprocessConfig(
world_size=world_size,
model_parallelism=world_size,
),
setup=setup,
)
if __name__ == "__main__":
# python -m pretrained.llama 7B 'The meaning of life is'
test_pretrained_model()