# mypy: disable-error-code="import, override"
r"""Defines a simple API for using the RWKV model.

This code is adapted from the minimal implementation
`here <https://johanwind.github.io/2023/03/23/rwkv_details.html>`_, modified
to be fine-tunable.

.. highlight:: python
.. code-block:: python

    from pretrained.rwkv import pretrained_rwkv

    model = pretrained_rwkv("7b")
    predictor = model.predictor()

    for token in predictor.generate("The quick brown fox jumped over the"):
        print(token)

Using the tokenizer requires installing the ``tokenizers`` library:

.. code-block:: bash

    pip install tokenizers

Additionally, using the training-mode CUDA kernel requires installing ``triton``:

.. code-block:: bash

    pip install triton

The choices for the model key are:

- ``"169m"``
- ``"430m"``
- ``"1.5b"``
- ``"3b"``
- ``"7b"``
- ``"14b"``
"""
import argparse
import functools
import logging
import math
import os
import time
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Iterator, Literal, Sequence, cast, get_args
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from ml.models.lora import maybe_lora, reset_lora_weights_
from ml.utils.checkpoint import ensure_downloaded
from ml.utils.device.auto import detect_device
from ml.utils.device.base import base_device
from ml.utils.large_models import init_empty_weights, meta_to_empty_func
from ml.utils.logging import configure_logging
from ml.utils.timer import Timer
from torch import Tensor, nn
from torch.autograd.function import Function, FunctionCtx, once_differentiable
logger = logging.getLogger(__name__)
PretrainedRwkvKey = Literal["169m", "430m", "1.5b", "3b", "7b", "14b"]
WkvFnKey = Literal["eps", "log"]
AttentionState = tuple[Tensor, Tensor]
FeedForwardState = Tensor
State = tuple[AttentionState, FeedForwardState]
EPS = 1e-4
def cast_pretrained_rwkv_key(s: str) -> PretrainedRwkvKey:
if s not in get_args(PretrainedRwkvKey):
        raise KeyError(f"Invalid RWKV size: {s}; expected one of: {get_args(PretrainedRwkvKey)}")
return cast(PretrainedRwkvKey, s)
@dataclass
class ModelArgs:
url: str
sha256: str
emb_dim: int
num_layers: int
PRETRAINED_MODEL_SIZES: dict[PretrainedRwkvKey, ModelArgs] = {
"169m": ModelArgs(
url="https://huggingface.co/BlinkDL/rwkv-4-pile-169m/resolve/main/RWKV-4-Pile-169M-20220807-8023.pth",
sha256="713c6f6137a08d3a86ab57df4f09ea03563329beb3bbabc23509d6c57aa0f9e2",
emb_dim=768,
num_layers=12,
),
"430m": ModelArgs(
url="https://huggingface.co/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth",
sha256="261e6b8fef1c7c9e08a4dde31bf5caf8e79c4da38126d77977a4707de82a7f64",
emb_dim=1024,
num_layers=24,
),
"1.5b": ModelArgs(
url="https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-20220929-ctx4096.pth",
sha256="6c97043e1bb0867368249290c97a2fe8ffc5ec12ceb1b5251f4ee911f9982c23",
emb_dim=2048,
num_layers=24,
),
"3b": ModelArgs(
url="https://huggingface.co/BlinkDL/rwkv-4-pile-3b/resolve/main/RWKV-4-Pile-3B-20221110-ctx4096.pth",
sha256="9500633f23d86fbae3cb3cbe7908b97b971e9561edf583c2c5c60b10b02bcc27",
emb_dim=2560,
num_layers=32,
),
"7b": ModelArgs(
url="https://huggingface.co/BlinkDL/rwkv-4-pile-7b/resolve/main/RWKV-4-Pile-7B-20230109-ctx4096.pth",
sha256="9ea1271b25deb6c72bd29f629147d5013cc7d7c69f9715192f6b6b92fca08f64",
emb_dim=4096,
num_layers=32,
),
"14b": ModelArgs(
url="https://huggingface.co/BlinkDL/rwkv-4-pile-14b/resolve/main/RWKV-4-Pile-14B-20230313-ctx8192-test1050.pth",
sha256="9e1b9b44f2a98124d86fe35e298f230e3a4fa7b60431962da282817ae1b0bf32",
emb_dim=5120,
num_layers=40,
),
}
TOKENIZER_URL = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/20B_tokenizer.json"
@functools.lru_cache
def supports_triton() -> bool:
if "USE_TRITON" in os.environ:
return os.environ["USE_TRITON"] == "1"
if not torch.cuda.is_available():
return False
try:
import triton
assert triton is not None
return True
except (ImportError, ModuleNotFoundError):
if torch.cuda.is_available():
warnings.warn("Triton is not installed, but CUDA is available; install with `pip install triton`")
return False
@torch.jit.script
def wkv_with_eps_forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]:
bsz, tsz, chans = k.shape
assert w.shape == u.shape == (chans,)
assert v.shape == (bsz, tsz, chans)
assert state.shape == (bsz, 3, 1, chans)
alpha, beta, eps = state[:, :, -1].chunk(3, dim=1) # (B, 1, D), (B, 1, D), (B, 1, D)
_, tsz, _ = k.shape
wkvs = []
alphas = [alpha]
betas = [beta]
epss = [eps]
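    # alpha and beta are the running numerator and denominator of the WKV
    # weighted average, stored relative to a running maximum exponent eps to
    # avoid overflow in the exponentials.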
for t in range(tsz):
kt, vt = k[:, t : t + 1], v[:, t : t + 1]
ukt = u + kt
tau = torch.maximum(ukt, eps)
e1 = torch.exp(eps - tau)
e2 = torch.exp(ukt - tau)
wkv = (e1 * alpha + e2 * vt) / (e1 * beta + e2)
wkvs.append(wkv)
w_eps = eps - w
eps = torch.maximum(w_eps, kt)
e1 = torch.exp(w_eps - eps)
e2 = torch.exp(kt - eps)
alpha = e1 * alpha + e2 * vt
beta = e1 * beta + e2
alphas.append(alpha)
betas.append(beta)
epss.append(eps)
alpha = torch.stack(alphas, dim=2)
beta = torch.stack(betas, dim=2)
eps = torch.stack(epss, dim=2)
return torch.cat(wkvs, 1), torch.cat((alpha, beta, eps), dim=1)
@torch.jit.script
def wkv_with_eps_backward(
w: Tensor,
u: Tensor,
k: Tensor,
v: Tensor,
state: Tensor,
grad_wkv: Tensor,
grad_state: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
bsz, tsz, chans = k.shape
assert w.shape == u.shape == (chans,)
assert v.shape == (bsz, tsz, chans)
assert state.shape == (bsz, 3, tsz + 1, chans)
assert grad_wkv.shape == (bsz, tsz, chans)
assert grad_state.shape == (bsz, 3, 1, chans)
alpha, beta, eps = state.chunk(3, dim=1) # (B, 1, T + 1, D), (B, 1, T + 1, D), (B, 1, T + 1, D)
grad_alpha, grad_beta, grad_eps = grad_state[:, :, 0].chunk(3, dim=1) # (B, 1, D), (B, 1, D), (B, 1, D)
grad_eps = grad_eps.clone()
grad_w = torch.zeros_like(w)
grad_u = torch.zeros_like(u)
grad_k = torch.zeros_like(k)
grad_v = torch.zeros_like(v)
for t in range(tsz - 1, -1, -1):
kt, vt = k[:, t : t + 1], v[:, t : t + 1]
alpha_prev, beta_prev, eps_prev = alpha[:, :, t], beta[:, :, t], eps[:, :, t]
alpha_curr, beta_curr, eps_curr = alpha[:, :, t + 1], beta[:, :, t + 1], eps[:, :, t + 1]
ukt = u + kt
tau = torch.maximum(ukt, eps_prev)
e1 = torch.exp(eps_prev - tau)
e2 = torch.exp(ukt - tau)
euke = torch.exp(ukt + eps_prev - 2 * tau)
denom = e1 * beta_prev + e2
denom_sq = denom * denom
grad_wkvt = grad_wkv[:, t : t + 1]
# Backpropagates wkv gradients.
grad_uk = grad_wkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq
grad_u += grad_uk.flatten(0, -2).sum(0)
grad_k[:, t : t + 1] += grad_uk
grad_v[:, t : t + 1] += grad_wkvt * e2 / denom
grad_alpha_wkv = grad_wkvt * e1 / denom
grad_beta_wkv = -grad_wkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq
grad_eps_wkv = grad_wkvt * euke * (alpha_prev - vt * beta_prev) / (e1 * beta_prev + e2) ** 2
e1 = torch.exp(eps_prev - eps_curr - w)
e2 = torch.exp(kt - eps_curr)
# Backpropagates alpha gradients.
grad_alpha_we = grad_alpha * e1 * alpha_prev
grad_w -= grad_alpha_we.flatten(0, -2).sum(0)
grad_k[:, t : t + 1] += grad_alpha * e2 * vt
grad_v[:, t : t + 1] += grad_alpha * e2
grad_eps += grad_alpha * -alpha_curr
# Backpropagates beta gradients.
grad_beta_we = grad_beta * e1 * beta_prev
grad_w -= grad_beta_we.flatten(0, -2).sum(0)
grad_k[:, t : t + 1] += grad_beta * e2
grad_eps += grad_beta * -beta_curr
# Backpropagates epsilon gradients.
eps_grad_mask = eps_prev - w > kt
grad_eps_we = torch.where(eps_grad_mask, grad_eps, torch.zeros_like(grad_eps))
grad_w -= grad_eps_we.flatten(0, -2).sum(0)
grad_k[:, t : t + 1] += torch.where(eps_grad_mask, torch.zeros_like(grad_eps), grad_eps)
# Computes gradients for alpha, beta and epsilon.
grad_alpha = grad_alpha * e1 + grad_alpha_wkv
grad_beta = grad_beta * e1 + grad_beta_wkv
grad_eps = grad_alpha_we + grad_beta_we + grad_eps_we + grad_eps_wkv
return grad_w, grad_u, grad_k, grad_v, torch.stack((grad_alpha, grad_beta, grad_eps), dim=1)
class WkvWithEps(Function):
    @staticmethod
def forward(
ctx: FunctionCtx,
w: Tensor,
u: Tensor,
k: Tensor,
v: Tensor,
state: Tensor,
) -> tuple[Tensor, Tensor]:
wkv, state_out = wkv_with_eps_forward(w, u, k, v, state)
ctx.save_for_backward(w, u, k, v, state_out)
return wkv, state_out[:, :, -1:]
    @staticmethod
@once_differentiable
def backward(
ctx: FunctionCtx,
grad_wkv: Tensor,
grad_state: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
w, u, k, v, state = cast(tuple[Tensor, ...], ctx.saved_tensors)
return wkv_with_eps_backward(w, u, k, v, state, grad_wkv, grad_state)
def initial_state_with_eps(emb_dim: int) -> Tensor:
return torch.zeros(1, 3, 1, emb_dim)
def wkv_with_eps(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]:
    """Runs the core WKV computation.

    Args:
        w: The decay tensor, with shape (D)
        u: The output multiplier tensor, with shape (D)
        k: The K tensor, with shape (B, T, D)
        v: The V tensor, with shape (B, T, D)
        state: The state tensor, with shape (B, 3, 1, D), consisting of the
            alpha, beta and eps tensors, each with shape (B, 1, 1, D)

    Returns:
        The WKV tensor, with shape (B, T, D), and the next state, with shape
        (B, 3, 1, D), consisting of the next alpha, beta and eps tensors, each
        with shape (B, 1, 1, D)
    """
return WkvWithEps.apply(w, u, k, v, state)
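# A minimal shape-checking sketch of the eps-based WKV op (not part of the
# public API); the batch size, sequence length and width below are arbitrary.
def _example_wkv_with_eps_shapes() -> None:
    bsz, tsz, emb_dim = 2, 5, 8
    w, u = torch.rand(emb_dim), torch.rand(emb_dim)
    k, v = torch.randn(bsz, tsz, emb_dim), torch.randn(bsz, tsz, emb_dim)
    # The initial state is (1, 3, 1, D); repeat it to match the batch size.
    state = initial_state_with_eps(emb_dim).repeat(bsz, 1, 1, 1)
    wkv, next_state = wkv_with_eps(w, u, k, v, state)
    assert wkv.shape == (bsz, tsz, emb_dim)
    assert next_state.shape == (bsz, 3, 1, emb_dim)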
@torch.jit.script
def logaddexp(a: Tensor, b: Tensor) -> Tensor:
max_ab = torch.maximum(a, b)
return max_ab + torch.log(torch.exp(a - max_ab) + torch.exp(b - max_ab))
@torch.jit.script
def logsubexp(a: Tensor, b: Tensor, log_eps: float) -> Tensor:
max_ab = torch.clamp_min(torch.maximum(a, b), log_eps)
return max_ab + torch.log(torch.exp(a - max_ab) - torch.exp(b - max_ab))
@torch.jit.script
def wkv_log_space_forward(
w: Tensor,
u: Tensor,
k: Tensor,
v: Tensor,
state: Tensor,
eps: float = EPS,
normalize: bool = False,
) -> tuple[Tensor, Tensor]:
bsz, tsz, chans = k.shape
assert w.shape == u.shape == (chans,)
assert v.shape == (bsz, tsz, chans)
assert state.shape == (bsz, 3, 1, chans)
ln_alpha_p, ln_alpha_m, ln_beta = state[:, :, -1].chunk(3, dim=1)
log_eps = math.log(eps)
wkvs = []
ln_alpha_ps = [ln_alpha_p]
ln_alpha_ms = [ln_alpha_m]
ln_betas = [ln_beta]
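    # Since v can be negative, the numerator is split into positive and
    # negative parts (ln_alpha_p and ln_alpha_m), and both are accumulated in
    # log space along with the log-space denominator ln_beta.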
for t in range(tsz):
kt, vt = k[:, t : t + 1], v[:, t : t + 1]
vt_p, vt_m = torch.clamp_min(vt, 0) + eps, torch.clamp_min(-vt, 0) + eps
ln_v_p, ln_v_m = torch.log(vt_p), torch.log(vt_m)
if normalize:
ln_alpha_pm = torch.minimum(ln_alpha_p, ln_alpha_m) - eps
ln_alpha_p = logsubexp(ln_alpha_p, ln_alpha_pm, log_eps)
ln_alpha_m = logsubexp(ln_alpha_m, ln_alpha_pm, log_eps)
ln_wkv_p = logaddexp(u + kt + ln_v_p, ln_alpha_p) - logaddexp(u + kt, ln_beta)
ln_wkv_m = logaddexp(u + kt + ln_v_m, ln_alpha_m) - logaddexp(u + kt, ln_beta)
wkv = torch.exp(ln_wkv_p) - torch.exp(ln_wkv_m)
wkvs.append(wkv)
ln_alpha_p = logaddexp(ln_alpha_p - w, kt + ln_v_p)
ln_alpha_m = logaddexp(ln_alpha_m - w, kt + ln_v_m)
ln_beta = logaddexp(ln_beta - w, kt)
ln_alpha_ps.append(ln_alpha_p)
ln_alpha_ms.append(ln_alpha_m)
ln_betas.append(ln_beta)
ln_alpha_p = torch.stack(ln_alpha_ps, dim=2)
ln_alpha_m = torch.stack(ln_alpha_ms, dim=2)
ln_beta = torch.stack(ln_betas, dim=2)
return torch.cat(wkvs, 1), torch.cat((ln_alpha_p, ln_alpha_m, ln_beta), dim=1)
@torch.jit.script
def wkv_log_space_backward(
w: Tensor,
u: Tensor,
k: Tensor,
v: Tensor,
state: Tensor,
grad_wkv: Tensor,
grad_state: Tensor,
eps: float = EPS,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
bsz, tsz, chans = k.shape
assert w.shape == u.shape == (chans,)
assert v.shape == (bsz, tsz, chans)
assert state.shape == (bsz, 3, tsz, chans)
assert grad_wkv.shape == (bsz, tsz, chans)
assert grad_state.shape == (bsz, 3, 1, chans)
grad_ln_alpha_p, grad_ln_alpha_m, grad_ln_beta = grad_state[:, :, 0].chunk(3, dim=1)
grad_w = torch.zeros_like(w)
grad_u = torch.zeros_like(u)
grad_k = torch.zeros_like(k)
grad_v = torch.zeros_like(v)
for t in range(tsz - 1, -1, -1):
kt, vt = k[:, t : t + 1], v[:, t : t + 1]
vt_p, vt_m = torch.clamp_min(vt, 0) + eps, torch.clamp_min(-vt, 0) + eps
ln_v_p, ln_v_m = torch.log(vt_p), torch.log(vt_m)
ln_alpha_p_prev, ln_alpha_m_prev, ln_beta_prev = state[:, :, t].chunk(3, dim=1)
uk = u + kt
ukv_p, ukv_m = uk + ln_v_p, uk + ln_v_m
ukb = logaddexp(uk, ln_beta_prev)
wkv_p = torch.exp(logaddexp(ukv_p, ln_alpha_p_prev) - ukb)
wkv_m = torch.exp(logaddexp(ukv_m, ln_alpha_m_prev) - ukb)
grad_wkvt = grad_wkv[:, t : t + 1]
grad_ln_wkv_p, grad_ln_wkv_m = grad_wkvt * wkv_p, grad_wkvt * -wkv_m
# Backpropagates wkv gradients.
e_num_p = torch.exp(ln_alpha_p_prev - ukv_p)
e_num_m = torch.exp(ln_alpha_m_prev - ukv_m)
e_den = torch.exp(ln_beta_prev - uk)
grad_wkv_den_p = grad_ln_wkv_p / (1 + e_den)
grad_wkv_den_m = grad_ln_wkv_m / (1 + e_den)
grad_kv_p = grad_ln_wkv_p / (1 + e_num_p)
grad_kv_m = grad_ln_wkv_m / (1 + e_num_m)
grad_uk = grad_kv_p + grad_kv_m - grad_wkv_den_p - grad_wkv_den_m
grad_u += grad_uk.flatten(0, -2).sum(0)
grad_k[:, t : t + 1] += grad_uk
grad_v[:, t : t + 1] += torch.where(vt > 0, grad_kv_p / vt_p, grad_kv_m / -vt_m)
grad_ln_alpha_wkv_p = grad_ln_wkv_p / (1 + (1 / e_num_p))
grad_ln_alpha_wkv_m = grad_ln_wkv_m / (1 + (1 / e_num_m))
grad_ln_beta_wkv = -grad_ln_wkv_p / (1 + (1 / e_den)) - grad_ln_wkv_m / (1 + (1 / e_den))
# Backpropagates alpha gradients.
e_alpha_p = torch.exp(kt + ln_v_p + w - ln_alpha_p_prev)
e_alpha_m = torch.exp(kt + ln_v_m + w - ln_alpha_m_prev)
grad_wa_p = grad_ln_alpha_p / (1 + e_alpha_p)
grad_wa_m = grad_ln_alpha_m / (1 + e_alpha_m)
grad_w -= (grad_wa_p + grad_wa_m).flatten(0, -2).sum(0)
grad_kv_p = grad_ln_alpha_p / (1 + (1 / e_alpha_p))
grad_kv_m = grad_ln_alpha_m / (1 + (1 / e_alpha_m))
grad_k[:, t : t + 1] += grad_kv_p + grad_kv_m
grad_v[:, t : t + 1] += torch.where(vt > 0, grad_kv_p / vt_p, -grad_kv_m / vt_m)
# Backpropagates beta gradients.
e_beta = torch.exp(kt + w - ln_beta_prev)
grad_wb = grad_ln_beta / (1 + e_beta)
grad_w -= grad_wb.flatten(0, -2).sum(0)
grad_k[:, t : t + 1] += grad_ln_beta / (1 + (1 / e_beta))
# Compute gradients for log alpha and log beta.
grad_ln_alpha_p = grad_wa_p + grad_ln_alpha_wkv_p
grad_ln_alpha_m = grad_wa_m + grad_ln_alpha_wkv_m
grad_ln_beta = grad_wb + grad_ln_beta_wkv
return grad_w, grad_u, grad_k, grad_v, torch.stack((grad_ln_alpha_p, grad_ln_alpha_m, grad_ln_beta), dim=1)
class WkvLogSpace(Function):
    @staticmethod
def forward(
ctx: FunctionCtx,
w: Tensor,
u: Tensor,
k: Tensor,
v: Tensor,
state: Tensor,
) -> tuple[Tensor, Tensor]:
wkv, state_out = wkv_log_space_forward(w, u, k, v, state)
ctx.save_for_backward(w, u, k, v, state_out[:, :, :-1])
return wkv, state_out[:, :, -1:]
    @staticmethod
@once_differentiable
def backward(
ctx: FunctionCtx,
grad_wkv: Tensor,
grad_state: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
w, u, k, v, state = cast(tuple[Tensor, ...], ctx.saved_tensors)
return wkv_log_space_backward(w, u, k, v, state, grad_wkv, grad_state)
def initial_state_log_space(emb_dim: int) -> Tensor:
return torch.full((1, 3, 1, emb_dim), float("-inf"))
def wkv_log_space(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]:
    """Runs the core WKV computation in log space.

    Args:
        w: The decay tensor, with shape (D)
        u: The output multiplier tensor, with shape (D)
        k: The K tensor, with shape (B, T, D)
        v: The V tensor, with shape (B, T, D)
        state: The state tensor, with shape (B, 3, 1, D), consisting of the
            log alpha plus, log alpha minus and log beta tensors, each with
            shape (B, 1, 1, D)

    Returns:
        The WKV tensor, with shape (B, T, D), and the next state, with shape
        (B, 3, 1, D), consisting of the next log alpha plus, log alpha minus
        and log beta tensors, each with shape (B, 1, 1, D)
    """
return WkvLogSpace.apply(w, u, k, v, state)
def get_wkv_fn(key: WkvFnKey) -> Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], tuple[Tensor, Tensor]]:
match key:
case "eps":
return wkv_with_eps
case "log":
return wkv_log_space
case _:
raise ValueError(f"Unsupported key: {key}")
def get_wkv_fn_cuda(key: WkvFnKey) -> Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], tuple[Tensor, Tensor]]:
if not supports_triton():
return get_wkv_fn(key)
from pretrained.triton.rwkv_kernel import wkv_triton_log_space, wkv_triton_with_eps
match key:
case "eps":
return wkv_triton_with_eps
case "log":
return wkv_triton_log_space
case _:
raise ValueError(f"Unsupported key: {key}")
def get_default_wkv_fn_key() -> WkvFnKey:
if "WKV_FN" in os.environ:
assert (wkv_fn_str := os.environ["WKV_FN"]) in get_args(WkvFnKey), f"Unsupported WKV_FN: {wkv_fn_str}"
return cast(WkvFnKey, wkv_fn_str)
warnings.warn("Using default WKV_FN: eps")
return "eps"
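# A small sketch of selecting the WKV implementation explicitly instead of via
# the WKV_FN environment variable; the "log" variant is mathematically
# equivalent to "eps" but computes in log space for numerical stability. The
# sizes below are illustrative only.
def _example_wkv_fn_choice() -> None:
    emb_dim = 16
    wkv_fn = get_wkv_fn("log")
    state = initial_state_log_space(emb_dim)
    w, u = torch.rand(emb_dim), torch.rand(emb_dim)
    k, v = torch.randn(1, 4, emb_dim), torch.randn(1, 4, emb_dim)
    wkv, next_state = wkv_fn(w, u, k, v, state)
    assert wkv.shape == (1, 4, emb_dim)
    assert next_state.shape == (1, 3, 1, emb_dim)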
class Attention(nn.Module):
init_x: Tensor
init_state: Tensor
def __init__(
self,
dim: int,
lora_rank: int | None = None,
lora_alpha: float = 1.0,
lora_dropout: float = 0.0,
freeze: bool = False,
wkv_key: WkvFnKey | None = None,
) -> None:
super().__init__()
self.time_decay = nn.Parameter(torch.ones(dim))
self.time_first = nn.Parameter(torch.ones(dim))
self.time_mix_k = nn.Parameter(torch.ones(1, 1, dim))
self.time_mix_v = nn.Parameter(torch.ones(1, 1, dim))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, dim))
if freeze:
self.time_decay.requires_grad_(False)
self.time_first.requires_grad_(False)
self.time_mix_k.requires_grad_(False)
self.time_mix_v.requires_grad_(False)
self.time_mix_r.requires_grad_(False)
self.key = maybe_lora(nn.Linear(dim, dim, False), lora_rank, lora_alpha, lora_dropout, freeze=freeze)
self.value = maybe_lora(nn.Linear(dim, dim, False), lora_rank, lora_alpha, lora_dropout, freeze=freeze)
self.receptance = maybe_lora(nn.Linear(dim, dim, False), lora_rank, lora_alpha, lora_dropout, freeze=freeze)
self.output = maybe_lora(nn.Linear(dim, dim, False), lora_rank, lora_alpha, lora_dropout, freeze=freeze)
if wkv_key is None:
wkv_key = get_default_wkv_fn_key()
self.wkv_fn = get_wkv_fn(wkv_key)
self.wkv_fn_cuda = get_wkv_fn_cuda(wkv_key)
self.register_buffer("init_x", torch.zeros(1, 1, dim), persistent=False)
self.register_buffer("init_state", initial_state_with_eps(dim), persistent=False)
    def time_shift(self, last_x: Tensor, x: Tensor) -> Tensor:
_, tsz, _ = x.shape
if tsz > 1:
last_x = torch.cat((last_x, x[..., :-1, :]), dim=-2)
return last_x
    def forward(self, x: Tensor, state: AttentionState | None) -> tuple[Tensor, AttentionState]:
bsz, _, _ = x.shape
if state is None:
last_x = self.init_x.repeat_interleave(bsz, dim=0)
last_state = self.init_state.repeat_interleave(bsz, dim=0)
else:
last_x, last_state = state
last_x = self.time_shift(last_x, x)
k = self.key(x * self.time_mix_k + last_x * (1 - self.time_mix_k))
v = self.value(x * self.time_mix_v + last_x * (1 - self.time_mix_v))
r = self.receptance(x * self.time_mix_r + last_x * (1 - self.time_mix_r))
sr = torch.sigmoid(r)
w, u = self.time_decay, self.time_first
w = torch.exp(w)
wkv_fn = self.wkv_fn_cuda if x.is_cuda else self.wkv_fn
wkv, next_state = wkv_fn(w, u, k, v, last_state)
rwkv = wkv * sr
return self.output(rwkv), (x[..., -1:, :], next_state)
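# A rough sketch (with made-up sizes) of how the attention state threads
# through incremental calls: feeding the sequence token by token while carrying
# the returned state should match processing it in a single call.
def _example_attention_streaming() -> None:
    attn = Attention(8, wkv_key="eps")
    x = torch.randn(1, 4, 8)
    full_out, _ = attn.forward(x, None)
    state: AttentionState | None = None
    outs = []
    for t in range(4):
        out_t, state = attn.forward(x[:, t : t + 1], state)
        outs.append(out_t)
    assert torch.allclose(full_out, torch.cat(outs, dim=1), atol=1e-4)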
class FeedForward(nn.Module):
init_state: Tensor
def __init__(
self,
dim: int,
ffn_dim: int,
lora_rank: int | None = None,
lora_alpha: float = 1.0,
lora_dropout: float = 0.0,
freeze: bool = False,
) -> None:
super().__init__()
self.time_mix_k = nn.Parameter(torch.ones(1, 1, dim))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, dim))
if freeze:
self.time_mix_k.requires_grad_(False)
self.time_mix_r.requires_grad_(False)
self.key = maybe_lora(nn.Linear(dim, ffn_dim, False), lora_rank, lora_alpha, lora_dropout, freeze=freeze)
self.receptance = maybe_lora(nn.Linear(dim, dim, False), lora_rank, lora_alpha, lora_dropout, freeze=freeze)
self.value = maybe_lora(nn.Linear(ffn_dim, dim, False), lora_rank, lora_alpha, lora_dropout, freeze=freeze)
self.register_buffer("init_state", torch.zeros(1, 1, dim), persistent=False)
    def time_shift(self, last_x: Tensor, x: Tensor) -> Tensor:
_, tsz, _ = x.shape
if tsz > 1:
last_x = torch.cat((last_x, x[..., :-1, :]), dim=-2)
return last_x
    def forward(self, x: Tensor, state: FeedForwardState | None = None) -> tuple[Tensor, FeedForwardState]:
bsz = x.shape[0]
last_x = self.time_shift(self.init_state.repeat(bsz, 1, 1) if state is None else state, x)
k = self.key(x * self.time_mix_k + last_x * (1 - self.time_mix_k))
r = self.receptance(x * self.time_mix_r + last_x * (1 - self.time_mix_r))
vk = self.value(F.relu(k) ** 2)
return torch.sigmoid(r) * vk, x[..., -1:, :]
class Block(nn.Module):
def __init__(
self,
emb_dim: int,
pre_norm: bool,
lora_rank: int | None = None,
lora_alpha: float = 1.0,
lora_dropout: float = 0.0,
lora_attn: bool = True,
lora_ffn: bool = True,
freeze_layer_norm: bool = False,
freeze_attn: bool = False,
freeze_ffn: bool = False,
use_checkpointing: bool = False,
wkv_key: WkvFnKey | None = None,
) -> None:
super().__init__()
self.ln0 = nn.LayerNorm(emb_dim) if pre_norm else None
self.ln1 = nn.LayerNorm(emb_dim)
self.ln2 = nn.LayerNorm(emb_dim)
self.use_checkpointing = use_checkpointing
if freeze_layer_norm:
if self.ln0 is not None:
self.ln0.requires_grad_(False)
self.ln1.requires_grad_(False)
self.ln2.requires_grad_(False)
self.att = Attention(
emb_dim,
lora_rank=lora_rank if lora_attn else None,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
freeze=freeze_attn,
wkv_key=wkv_key,
)
self.ffn = FeedForward(
emb_dim,
emb_dim * 4,
lora_rank=lora_rank if lora_ffn else None,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
freeze=freeze_ffn,
)
    def run_attn(self, x: Tensor, state: State | None = None) -> tuple[Tensor, AttentionState]:
        return self.att.forward(self.ln1(x), None if state is None else state[0])
    def run_ffn(self, x: Tensor, state: State | None = None) -> tuple[Tensor, FeedForwardState]:
        return self.ffn.forward(self.ln2(x), None if state is None else state[1])
    def forward(self, x: Tensor, state: State | None = None) -> tuple[Tensor, State]:
if self.ln0 is not None:
x = self.ln0(x)
if self.use_checkpointing:
dx, att_state_out = torch.utils.checkpoint.checkpoint(self.run_attn, x, state)
x = x + dx
dx, ffn_state_out = torch.utils.checkpoint.checkpoint(self.run_ffn, x, state)
x = x + dx
else:
dx, att_state_out = self.run_attn(x, state)
x = x + dx
dx, ffn_state_out = self.run_ffn(x, state)
x = x + dx
return x, (att_state_out, ffn_state_out)
class RwkvStack(nn.Module):
    """Defines a stack of RWKV modules.

    Parameters:
        emb_dim: The number of embedding dimensions in each block
        num_layers: The number of layers in the stack
        use_checkpointing: Whether to use checkpointing
        wkv_key: The WKV algorithm to use

    Inputs:
        x: The input tensor, with shape ``(B, T, D)``
        state: The previous state

    Outputs:
        The output tensor, with shape ``(B, T, D)``, and the next state
    """
def __init__(
self,
emb_dim: int,
num_layers: int,
use_checkpointing: bool = False,
wkv_key: WkvFnKey | None = None,
) -> None:
super().__init__()
self.blocks = nn.ModuleList(
[
Block(
emb_dim,
pre_norm=i == 0,
use_checkpointing=use_checkpointing,
wkv_key=wkv_key,
)
for i in range(num_layers)
]
)
    def forward(self, x: Tensor, state: list[State] | None = None) -> tuple[Tensor, list[State]]:
state_out: list[State] = []
for i, block in enumerate(self.blocks):
x, state_out_i = block(x, None if state is None else state[i])
state_out.append(state_out_i)
return x, state_out
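# An illustrative sketch of incremental decoding through a small RwkvStack; the
# layer count and width here are arbitrary, not one of the pretrained sizes.
def _example_rwkv_stack_streaming() -> None:
    stack = RwkvStack(emb_dim=8, num_layers=2)
    x = torch.randn(1, 3, 8)
    state: list[State] | None = None
    outs = []
    for t in range(3):
        out_t, state = stack.forward(x[:, t : t + 1], state)
        outs.append(out_t)
    assert torch.cat(outs, dim=1).shape == (1, 3, 8)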
class Rwkv(nn.Module):
def __init__(
self,
emb_dim: int,
num_tokens: int,
num_layers: int,
lora_rank: int | None = None,
lora_alpha: float = 1.0,
lora_dropout: float = 0.0,
lora_embeddings: bool = True,
lora_linear: bool = True,
lora_top_k_blocks: int | None = None,
lora_attn: bool = True,
lora_ffn: bool = True,
freeze_non_lora: bool = False,
freeze_layer_norm: bool | None = None,
freeze_attn: bool | None = None,
freeze_ffn: bool | None = None,
use_checkpointing: bool = False,
wkv_key: WkvFnKey | None = None,
) -> None:
super().__init__()
if lora_rank is None:
freeze_non_lora = False
if freeze_layer_norm is None:
freeze_layer_norm = freeze_non_lora
if freeze_attn is None:
freeze_attn = freeze_non_lora
if freeze_ffn is None:
freeze_ffn = freeze_non_lora
if lora_top_k_blocks is None:
min_block = 0
else:
min_block = num_layers - lora_top_k_blocks
self.emb = maybe_lora(
nn.Embedding(num_tokens, emb_dim),
lora_rank if lora_embeddings else None,
lora_alpha,
lora_dropout,
freeze=freeze_non_lora,
)
blocks = [
Block(
emb_dim,
i == 0,
lora_rank=lora_rank if i >= min_block else None,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
lora_attn=lora_attn,
lora_ffn=lora_ffn,
freeze_layer_norm=freeze_layer_norm,
freeze_attn=freeze_attn,
freeze_ffn=freeze_ffn,
use_checkpointing=use_checkpointing,
wkv_key=wkv_key,
)
for i in range(num_layers)
]
self.blocks = nn.ModuleList(blocks)
self.ln_out = nn.LayerNorm(emb_dim)
if freeze_layer_norm:
self.ln_out.requires_grad_(False)
self.head = maybe_lora(
nn.Linear(emb_dim, num_tokens, bias=False),
lora_rank if lora_linear else None,
lora_alpha,
lora_dropout,
freeze=freeze_non_lora,
)
    def tensor_to(self, x: Tensor) -> Tensor:
ref_tensor = self.head.weight
if x.is_floating_point():
return x.to(ref_tensor)
return x.to(ref_tensor.device)
    def forward(
self,
tokens: Tensor,
states_in: list[State] | None = None,
return_logits: bool = False,
) -> tuple[Tensor, list[State]]:
x = self.emb(tokens)
states_out: list[State] = []
for i, block in enumerate(self.blocks):
x, state_out = block(x, None if states_in is None else states_in[i])
states_out.append(state_out)
x = self.head(self.ln_out(x))
if return_logits:
return x, states_out
e_x = torch.exp(x - torch.max(x))
probs = e_x / e_x.sum()
return probs, states_out
    def predictor(self) -> "RwkvPredictor":
return RwkvPredictor(self)
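# A toy sketch of calling the bare Rwkv module with integer tokens; the sizes
# here are arbitrary and much smaller than any pretrained configuration.
def _example_rwkv_forward() -> None:
    model = Rwkv(emb_dim=8, num_tokens=100, num_layers=2)
    tokens = torch.randint(0, 100, (1, 5))
    logits, states = model.forward(tokens, return_logits=True)
    assert logits.shape == (1, 5, 100)
    # Passing the returned states back in continues the sequence.
    next_logits, states = model.forward(tokens[:, -1:], states, return_logits=True)
    assert next_logits.shape == (1, 1, 100)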
def get_tokenizer() -> Any:
try:
from tokenizers import Tokenizer
except (ModuleNotFoundError, ImportError):
raise ModuleNotFoundError("Install the `tokenizers` package: `pip install tokenizers`")
with Timer("downloading tokenizer"):
tokenizer_path = ensure_downloaded(TOKENIZER_URL, "rwkv", "tokenizer.json")
return Tokenizer.from_file(str(tokenizer_path))
class RwkvPredictor:
def __init__(self, rwkv_model: Rwkv) -> None:
"""Provides an API for sampling from the RWKV model.
Args:
rwkv_model: The RWKV model to use for sampling.
"""
super().__init__()
self.tokenizer = get_tokenizer()
self.model = rwkv_model
    def sample_probs(self, probs: Tensor, temperature: float = 1.0, top_p: float = 0.85) -> Tensor:
try:
probs = probs ** (1 / temperature)
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > top_p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True) + 1e-6)
next_token = torch.multinomial(probs_sort.squeeze(-3), num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token[..., None, :, :]).squeeze(-1)
return next_token
except Exception:
logger.exception("Error sampling from probabilities.")
return probs.new_zeros(probs.shape[:-1], dtype=torch.long)
    @torch.no_grad()
def generate(
self,
prompt: str | Tensor,
max_len: int = 256,
temperature: float = 1.0,
top_p: float = 0.85,
end_toks: Sequence[int] | None = None,
end_strs: Sequence[str] | None = None,
) -> Iterator[str]:
if isinstance(prompt, str):
prompt = torch.tensor([self.tokenizer.encode(prompt).ids])
assert prompt.dim() == 2 and prompt.shape[0] == 1
probs, state = self.model.forward(self.model.tensor_to(prompt))
probs = probs[:, -1:]
end_toks_set = set() if end_toks is None else set(end_toks)
end_strs_set = [] if end_strs is None else list(end_strs)
for i in range(max_len):
token = self.sample_probs(probs, temperature=temperature, top_p=top_p)
            if token.item() in end_toks_set:
break
token_str = self.tokenizer.decode([token.item()])
yield token_str
if any(e in token_str for e in end_strs_set):
break
if i < max_len - 1:
probs, state = self.model(self.model.tensor_to(torch.tensor([[token]])), state)
def pretrained_rwkv(
key: PretrainedRwkvKey,
*,
device: base_device | None = None,
lora_rank: int | None = None,
lora_alpha: float = 1.0,
lora_dropout: float = 0.0,
lora_embeddings: bool = True,
lora_linear: bool = True,
lora_top_k_blocks: int | None = None,
lora_attn: bool = True,
lora_ffn: bool = True,
freeze_non_lora: bool = False,
freeze_layer_norm: bool | None = None,
freeze_attn: bool | None = None,
freeze_ffn: bool | None = None,
use_checkpointing: bool = False,
empty: bool = False,
wkv_key: WkvFnKey | None = None,
) -> Rwkv:
    """Returns a pretrained RWKV model.

    Args:
        key: The key of the pretrained model to load.
        device: The device to load the model onto. If None, the model will be
            loaded onto the device returned by ``detect_device()``.
        lora_rank: The rank of the LoRA decomposition to use.
        lora_alpha: The alpha parameter of the LoRA decomposition.
        lora_dropout: The dropout rate to use in the LoRA decomposition.
        lora_embeddings: Whether to use LoRA for the embedding matrices.
        lora_linear: Whether to use LoRA for the linear layers.
        lora_top_k_blocks: The number of blocks, counting from the top of the
            stack, to apply LoRA to; if None, LoRA is applied to all blocks.
        lora_attn: Whether to use LoRA for the attention layers.
        lora_ffn: Whether to use LoRA for the feed-forward layers.
        freeze_non_lora: Whether to freeze the non-LoRA parameters. This value
            will override the other freeze parameters if they are None.
        freeze_layer_norm: Whether to freeze the layer normalization parameters.
        freeze_attn: Whether to freeze the attention parameters.
        freeze_ffn: Whether to freeze the feed-forward parameters.
        use_checkpointing: Whether to use checkpointing to reduce memory usage.
        empty: Whether to return an empty model with the same structure as the
            pretrained model, without loading the pretrained weights.
        wkv_key: The choice of WKV function to use. The options are
            mathematically equivalent but differ in numerical stability.

    Returns:
        The pretrained RWKV model.
    """
device = detect_device() if device is None else device
model_args = PRETRAINED_MODEL_SIZES[key]
with Timer("building model skeleton", spinner=True), init_empty_weights():
model = Rwkv(
emb_dim=model_args.emb_dim,
num_tokens=50277,
num_layers=model_args.num_layers,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
lora_embeddings=lora_embeddings,
lora_linear=lora_linear,
lora_top_k_blocks=lora_top_k_blocks,
lora_attn=lora_attn,
lora_ffn=lora_ffn,
freeze_non_lora=freeze_non_lora,
freeze_layer_norm=freeze_layer_norm,
freeze_attn=freeze_attn,
freeze_ffn=freeze_ffn,
use_checkpointing=use_checkpointing,
wkv_key=wkv_key,
)
if empty:
model._apply(meta_to_empty_func(torch.device("cpu"), torch.bfloat16))
device.module_to(model)
reset_lora_weights_(model)
return model
with Timer("downloading checkpoint"):
ckpt_path = ensure_downloaded(model_args.url, "rwkv", f"{key}.pth", sha256=model_args.sha256)
with Timer("loading model checkpoint", spinner=True):
ckpt = torch.load(ckpt_path, map_location="cpu")
    # Materializes the model weights and loads the checkpoint into them.
with Timer("loading state dict", spinner=True):
model._apply(meta_to_empty_func(torch.device("cpu"), torch.bfloat16))
model.load_state_dict(ckpt)
device.module_to(model)
reset_lora_weights_(model)
return model
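# A hedged usage sketch for LoRA fine-tuning, mirroring the module docstring; it
# is not executed on import, since it downloads the 169m checkpoint and needs
# the ``tokenizers`` package. The rank and alpha values are arbitrary examples.
def _example_pretrained_lora() -> None:
    model = pretrained_rwkv("169m", lora_rank=8, lora_alpha=16.0, freeze_non_lora=True)
    predictor = model.predictor()
    for token in predictor.generate("The quick brown fox jumped over the", max_len=32):
        print(token, end="")
    print()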
def test_rwkv_adhoc() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("size", type=str, choices=get_args(PretrainedRwkvKey))
parser.add_argument("prompt", type=str, nargs="?")
parser.add_argument("-t", "--tsz", type=int, default=128)
parser.add_argument("-m", "--temperature", type=float, default=1.0)
parser.add_argument("-p", "--top-p", type=float, default=0.85)
parser.add_argument("-e", "--end-tok", type=str, nargs="+", default=[])
parser.add_argument("-s", "--sep", type=str, default="")
parser.add_argument("-y", "--empty", action="store_true")
args = parser.parse_args()
configure_logging()
model = pretrained_rwkv(args.size, empty=args.empty)
predictor = model.predictor()
def generate_for_prompt(prompt: str) -> None:
print(prompt, end="")
start_time: float | None = None
num_tokens = 0
for token in predictor.generate(
prompt,
max_len=args.tsz,
temperature=args.temperature,
top_p=args.top_p,
end_strs=args.end_tok,
):
print(token, end=args.sep, flush=True)
if start_time is None:
start_time = time.time()
num_tokens += 1
print()
end_time = time.time()
if start_time is not None:
time_delta = end_time - start_time
print(f"Time taken: {num_tokens} / {time_delta:.2f}s = {num_tokens / time_delta:.2f} tokens per second")
if args.prompt:
if Path(args.prompt).exists():
with open(args.prompt, "r") as f:
generate_for_prompt(f.read().strip())
else:
generate_for_prompt(args.prompt)
else:
prompt = input("Prompt: ")
while prompt:
generate_for_prompt(prompt)
prompt = input("Prompt: ")
if __name__ == "__main__":
# python -m pretrained.rwkv
test_rwkv_adhoc()