# mypy: disable-error-code="import"
"""Defines a simple API for using OpenAI's pretrained CLIP model.
.. highlight:: python
.. code-block:: python
from pretrained.clip import pretrained_clip
full_model = pretrained_clip("RN50", mode="all")
visual_model = pretrained_clip("RN50", mode="visual")
linguistic_model = pretrained_clip("RN50", mode="linguistic")
image = PIL.Image.open(image_path)
image_tensorizer = visual_model.get_preprocess()
image_tensor = image_tensorizer(image) # (3, 224, 224)
tokenizer = linguistic_model.get_tokenizer()
token_tensor = tokenizer.tokenizer(["A photo of a cat", "A photo of a dog"])
visual_model.encode_image(imgs) # (N, C)
linguistic_model.encode_text(token_tensor) # (N, C)
The choices for the model key are:
- ``RN50``: ResNet50 + Transformer
- ``RN101``: ResNet101 + Transformer
- ``RN50x4``: 4x ResNet50 + Transformer
- ``RN50x16``: 16x ResNet50 + Transformer
- ``RN50x64``: 64x ResNet50 + Transformer
- ``ViT_B_32``: ViT-B/32 + Transformer
- ``ViT_B_16``: ViT-B/16 + Transformer
- ``ViT_L_14``: ViT-L/14 + Transformer
- ``ViT_L_14_336px``: ViT-L/14 + Transformer (336px)
"""
import argparse
import functools
import gzip
import html
import logging
import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Any, Callable, Literal, cast, get_args, overload
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import torchvision
from ml.utils.checkpoint import ensure_downloaded
from ml.utils.device.auto import detect_device
from ml.utils.device.base import base_device
from ml.utils.logging import configure_logging
from ml.utils.timer import Timer
from torch import Tensor, nn
logger = logging.getLogger(__name__)
URL_PREFIX = "https://openaipublic.azureedge.net/clip/models"
PretrainedClipSize = Literal[
"RN50",
"RN101",
"RN50x4",
"RN50x16",
"RN50x64",
"ViT_B_32",
"ViT_B_16",
"ViT_L_14",
"ViT_L_14_336px",
]
def cast_pretrained_clip_key(s: str) -> PretrainedClipSize:
    if s not in get_args(PretrainedClipSize):
        raise KeyError(f"Invalid CLIP size: {s}; expected one of {get_args(PretrainedClipSize)}")
    return cast(PretrainedClipSize, s)
PRETRAINED_MODELS: dict[PretrainedClipSize, str] = {
"RN50": f"{URL_PREFIX}/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
"RN101": f"{URL_PREFIX}/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
"RN50x4": f"{URL_PREFIX}/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
"RN50x16": f"{URL_PREFIX}/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
"RN50x64": f"{URL_PREFIX}/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
"ViT_B_32": f"{URL_PREFIX}/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
"ViT_B_16": f"{URL_PREFIX}/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
"ViT_L_14": f"{URL_PREFIX}/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
"ViT_L_14_336px": f"{URL_PREFIX}/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", # noqa: E501
}
CLIP_VOCABULARY = "https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz"
MEAN, STD = (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
def _convert_image_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
return image.convert("RGB")
def pil_preprocess(n_px: int) -> Callable[[PIL.Image.Image], Tensor]:
return torchvision.transforms.Compose(
[
torchvision.transforms.Resize(n_px, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
torchvision.transforms.CenterCrop(n_px),
_convert_image_to_rgb,
torchvision.transforms.PILToTensor(),
torchvision.transforms.ConvertImageDtype(torch.float),
torchvision.transforms.Normalize(MEAN, STD),
],
)
def tensor_preprocess(n_px: int) -> Callable[[Tensor], Tensor]:
return torchvision.transforms.Compose(
[
torchvision.transforms.ConvertImageDtype(torch.float),
torchvision.transforms.Normalize(MEAN, STD),
],
)
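
# A minimal usage sketch (the image path below is a placeholder): `pil_preprocess`
# resizes, center-crops, and normalizes a PIL image, while `tensor_preprocess`
# only converts the dtype and normalizes a tensor that is already the right size.
#
#     to_tensor = pil_preprocess(224)
#     image_tensor = to_tensor(PIL.Image.open("some_image.jpg"))  # (3, 224, 224), float32
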
@functools.lru_cache()
def default_bpe() -> str:
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
@functools.lru_cache()
def bytes_to_unicode() -> dict[int, str]:
"""Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings. This means you need a
large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around
5K for decent coverage. This is a signficant percentage of your normal,
say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8
bytes and unicode strings. And avoids mapping to whitespace/control
characters the BPE code barfs on.
Returns:
Mapping from UTF-8 byte to unicode string.
"""
bs = list(range(ord("!"), ord("~") + 1))
bs += list(range(ord("¡"), ord("¬") + 1))
bs += list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
css = [chr(n) for n in cs]
return dict(zip(bs, css))
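
# Sketch of the mapping's behavior: printable ASCII bytes map to themselves, while
# bytes that would be whitespace or control characters are shifted to otherwise
# unused unicode code points so the BPE merges never see them.
#
#     b2u = bytes_to_unicode()
#     assert b2u[ord("a")] == "a"
#     assert b2u[0] == chr(256)  # Byte 0 is remapped rather than kept as a control char.
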
def get_pairs(word: tuple[str, ...]) -> set[tuple[str, str]]:
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
@functools.lru_cache()
def text_clean_func(lower: bool = True) -> Callable[[str], str]:
try:
import ftfy
ftfy_fix: Callable[[str], str] = ftfy.fix_text
except ImportError:
logger.warning("Please install ftfy: pip install ftfy")
def ftfy_fix(x: str) -> str:
return x
def _clean(text: str) -> str:
text = ftfy_fix(text)
text = html.unescape(html.unescape(text))
text = re.sub(r"\s+", " ", text)
text = text.strip()
if lower:
text = text.lower()
return text
return _clean
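
# Sketch of the cleaning behavior: whitespace is collapsed, HTML entities are
# unescaped, and (by default) the text is lowercased.
#
#     clean = text_clean_func()
#     clean("  A   photo of a CAT ")  # -> "a photo of a cat"
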
class ClipTokenizer:
def __init__(self) -> None:
bpe_path = ensure_downloaded(CLIP_VOCABULARY, "clip", "CLIP_vocabulary.txt.gz")
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges_unzipped = gzip.open(bpe_path).read().decode("utf-8").split("\n")
merges_unzipped = merges_unzipped[1 : 49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges_unzipped]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + "</w>" for v in vocab]
for merge in merges:
vocab.append("".join(merge))
vocab.extend(["<|startoftext|>", "<|endoftext|>"])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
self.pat = re.compile(
r"<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|\w+|\d|[^\s\w\d]+",
re.IGNORECASE,
)
    def bpe(self, token: str) -> str:
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + "</w>",)
pairs = get_pairs(word)
if not pairs:
return token + "</w>"
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word_list: list[str] = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word_list.extend(word[i:j])
i = j
                except ValueError:
new_word_list.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word_list.append(first + second)
i += 2
else:
new_word_list.append(word[i])
i += 1
new_word = tuple(new_word_list)
word = new_word
if len(word) == 1:
break
pairs = get_pairs(word)
word_str = " ".join(word)
self.cache[token] = word_str
return word_str
    def encode(self, text: str) -> list[int]:
bpe_tokens: list[int] = []
text = text_clean_func()(text)
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
    def decode(self, tokens: list[int]) -> str:
text = "".join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("</w>", " ")
return text
    def tokenize(
self,
texts: str | list[str],
context_length: int = 77,
truncate: bool = False,
) -> Tensor:
if isinstance(texts, str):
texts = [texts]
sot_token = self.encoder["<|startoftext|>"]
eot_token = self.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
if truncate:
tokens = tokens[:context_length]
tokens[-1] = eot_token
else:
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
result[i, : len(tokens)] = torch.tensor(tokens)
return result
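
# Sketch of the tokenizer API: `encode`/`decode` round-trip a cleaned caption, and
# `tokenize` pads each caption to the model's 77-token context window.
#
#     tokenizer = ClipTokenizer()
#     ids = tokenizer.encode("A photo of a cat")
#     tokenizer.decode(ids)                     # "a photo of a cat " (cleaned and lowercased)
#     tokenizer.tokenize(["A photo of a cat"])  # int tensor of shape (1, 77)
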
class Bottleneck(nn.Module):
expansion = 4
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False, device=device, dtype=dtype)
self.bn1 = nn.BatchNorm2d(planes, device=device, dtype=dtype)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False, device=device, dtype=dtype)
self.bn2 = nn.BatchNorm2d(planes, device=device, dtype=dtype)
self.relu2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False, device=device, dtype=dtype)
self.bn3 = nn.BatchNorm2d(planes * self.expansion, device=device, dtype=dtype)
self.relu3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# Downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
pool = nn.AvgPool2d(stride)
conv = nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False, device=device, dtype=dtype)
bn = nn.BatchNorm2d(planes * self.expansion, device=device, dtype=dtype)
self.downsample = nn.Sequential(OrderedDict([("-1", pool), ("0", conv), ("1", bn)]))
    def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.relu1(self.bn1(self.conv1(x)))
out = self.relu2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu3(out)
return out
class AttentionPool2d(nn.Module):
def __init__(
self,
spacial_dim: int,
embed_dim: int,
num_heads: int,
output_dim: int | None = None,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__()
if dtype is None:
pos_emb = torch.randn(spacial_dim**2 + 1, embed_dim, device=device) / embed_dim**0.5
else:
pos_emb = torch.randn(spacial_dim**2 + 1, embed_dim, device=device, dtype=dtype) / embed_dim**0.5
self.positional_embedding = nn.Parameter(pos_emb)
self.k_proj = nn.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.q_proj = nn.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.v_proj = nn.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim, device=device, dtype=dtype)
self.num_heads = num_heads
    def forward(self, x: Tensor) -> Tensor:
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x[:1],
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False,
)
return x.squeeze(0)
class ModifiedResNet(nn.Module):
__constants__ = ["input_resolution", "output_dim"]
def __init__(
self,
layers: tuple[int, int, int, int],
output_dim: int,
heads: int,
input_resolution: int = 224,
width: int = 64,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
"""ResNet class that is similar to TorchVision's but with some changes.
- There are now 3 "stem" convolutions as opposed to 1, with an average
pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is
prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
Args:
layers: Layer counts for the four parts of the ResNet
output_dim: Number of final output dimensions
heads: Number of attention heads
input_resolution: Number of pixels in width and height directions
width: Hidden channel count
device: Default PyTorch device to use
dtype: Default PyTorch dtype to use
"""
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# The 3-layer stem
self.conv1 = nn.Conv2d(
3,
width // 2,
kernel_size=3,
stride=2,
padding=1,
bias=False,
device=device,
dtype=dtype,
)
self.bn1 = nn.BatchNorm2d(width // 2, device=device, dtype=dtype)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(
width // 2,
width // 2,
kernel_size=3,
padding=1,
bias=False,
device=device,
dtype=dtype,
)
self.bn2 = nn.BatchNorm2d(width // 2, device=device, dtype=dtype)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False, device=device, dtype=dtype)
self.bn3 = nn.BatchNorm2d(width, device=device, dtype=dtype)
self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# Residual layers
self._inplanes = width # This is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # The ResNet feature dimension
self.attnpool = AttentionPool2d(
input_resolution // 32,
embed_dim,
heads,
output_dim,
device=device,
dtype=dtype,
)
    def initialize_parameters(self) -> None:
if self.attnpool is not None:
std = self.attnpool.c_proj.in_features**-0.5
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
def _make_layer(self, planes: int, blocks: int, stride: int = 1) -> nn.Module:
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
    def stem(self, x: Tensor) -> Tensor:
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
    def forward(self, x: Tensor) -> Tensor:
x = x.type(self.conv1.weight.dtype)
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
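
# Shape sketch for the RN50 configuration (layers=(3, 4, 6, 3), width=64,
# input_resolution=224): the stem and the four stages downsample by 32x in total,
# so the attention pool sees a 7x7 grid and returns one embedding per image.
#
#     resnet = ModifiedResNet((3, 4, 6, 3), output_dim=1024, heads=32)
#     resnet(torch.randn(1, 3, 224, 224)).shape  # torch.Size([1, 1024])
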
class QuickGELU(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
attn_mask: Tensor | None = None,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head, device=device, dtype=dtype)
self.ln_1 = nn.LayerNorm(d_model, device=device, dtype=dtype)
self.mlp = nn.Sequential(
OrderedDict(
[
("c_fc", nn.Linear(d_model, d_model * 4, device=device, dtype=dtype)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model, device=device, dtype=dtype)),
]
)
)
self.ln_2 = nn.LayerNorm(d_model, device=device, dtype=dtype)
self.attn_mask = attn_mask
    def attention(self, x: Tensor) -> Tensor:
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
    def forward(self, x: Tensor) -> Tensor:
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
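

# NOTE: The `Transformer` and `VisionTransformer` modules referenced by `TextModel`
# and `Clip` below were missing from this file. The definitions here are a sketch
# that follows OpenAI's reference CLIP implementation, adapted to the conventions
# used elsewhere in this module (optional `device`/`dtype` keyword arguments).
class Transformer(nn.Module):
    __constants__ = ["width", "layers"]

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        attn_mask: Tensor | None = None,
        *,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(
            *[ResidualAttentionBlock(width, heads, attn_mask, device=device, dtype=dtype) for _ in range(layers)]
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    __constants__ = ["input_resolution", "output_dim"]

    def __init__(
        self,
        input_resolution: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        output_dim: int,
        *,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            3,
            width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
            device=device,
            dtype=dtype,
        )
        scale = width**-0.5
        num_patches = (input_resolution // patch_size) ** 2
        if dtype is None:
            cls_emb = scale * torch.randn(width, device=device)
            pos_emb = scale * torch.randn(num_patches + 1, width, device=device)
            proj = scale * torch.randn(width, output_dim, device=device)
        else:
            cls_emb = scale * torch.randn(width, device=device, dtype=dtype)
            pos_emb = scale * torch.randn(num_patches + 1, width, device=device, dtype=dtype)
            proj = scale * torch.randn(width, output_dim, device=device, dtype=dtype)
        self.class_embedding = nn.Parameter(cls_emb)
        self.positional_embedding = nn.Parameter(pos_emb)
        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.transformer = Transformer(width, layers, heads, device=device, dtype=dtype)
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.proj = nn.Parameter(proj)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)  # (N, width, grid, grid)
        x = x.flatten(start_dim=2).permute(0, 2, 1)  # (N, grid ** 2, width)
        cls_tok = self.class_embedding.to(x.dtype).expand(x.shape[0], 1, -1)
        x = torch.cat([cls_tok, x], dim=1)  # (N, grid ** 2 + 1, width)
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x[:, 0, :])  # Keep only the class token embedding.
        return x @ self.proj

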
class TextModel(nn.Module):
def __init__(
self,
embed_dim: int,
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__()
self.context_length = context_length
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask(),
device=device,
dtype=dtype,
)
if dtype is None:
pos_emb = torch.empty(self.context_length, transformer_width, device=device)
text_proj = torch.empty(transformer_width, embed_dim, device=device)
else:
pos_emb = torch.empty(self.context_length, transformer_width, device=device, dtype=dtype)
text_proj = torch.empty(transformer_width, embed_dim, device=device, dtype=dtype)
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width, device=device, dtype=dtype)
self.positional_embedding = nn.Parameter(pos_emb)
self.ln_final = nn.LayerNorm(transformer_width, device=device, dtype=dtype)
self.text_projection = nn.Parameter(text_proj)
    def get_tokenizer(self) -> ClipTokenizer:
return ClipTokenizer()
    def initialize_parameters(self) -> None:
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
proj_std = (self.transformer.width**-0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)
    def build_attention_mask(self) -> Tensor:
        # Lazily creates the causal attention mask for the text transformer.
        # PyTorch uses an additive attention mask, so fill with -inf.
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # Zeros out the diagonal and below.
return mask
    def forward(self, text: Tensor) -> Tensor:
x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
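
# Sketch of the text branch on its own: it consumes tokenized captions and returns
# one embedding per caption, taken at the <|endoftext|> position.
#
#     text_model = pretrained_clip("RN50", mode="linguistic")
#     tokens = text_model.get_tokenizer().tokenize(["a photo of a cat"])
#     text_model(tokens).shape  # (1, 1024) for RN50
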
class Clip(nn.Module):
def __init__(
self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: tuple[int, int, int, int] | int,
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__()
self.context_length = context_length
self.visual: ModifiedResNet | VisionTransformer
if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width,
device=device,
dtype=dtype,
)
else:
vision_heads = vision_width // 64
self.visual = VisionTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim,
device=device,
dtype=dtype,
)
self.linguistic = TextModel(
embed_dim=embed_dim,
context_length=context_length,
vocab_size=vocab_size,
transformer_width=transformer_width,
transformer_heads=transformer_heads,
transformer_layers=transformer_layers,
device=device,
dtype=dtype,
)
if dtype is None:
logit_scale = torch.ones([], device=device) * np.log(1 / 0.07)
else:
logit_scale = torch.ones([], device=device, dtype=dtype) * np.log(1 / 0.07)
self.logit_scale = nn.Parameter(logit_scale)
self.initialize_parameters()
    @torch.jit.ignore
def get_pil_preprocess(self) -> Callable[[PIL.Image.Image], Tensor]:
return pil_preprocess(self.visual.input_resolution)
    @torch.jit.ignore
def get_tensor_preprocess(self) -> Callable[[Tensor], Tensor]:
return tensor_preprocess(self.visual.input_resolution)
    def initialize_parameters(self) -> None:
if isinstance(self.visual, ModifiedResNet):
self.visual.initialize_parameters()
self.linguistic.initialize_parameters()
    def build_attention_mask(self) -> Tensor:
        # Lazily creates the causal attention mask for the text transformer.
        # PyTorch uses an additive attention mask, so fill with -inf.
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # Zeros out the diagonal and below.
return mask
    def encode_image(self, image: Tensor) -> Tensor:
return self.visual(image)
    def encode_text(self, text: Tensor) -> Tensor:
return self.linguistic(text)
    def forward(self, image: Tensor, text: Tensor) -> tuple[Tensor, Tensor]:
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
    def predictor(self, *, device: base_device | None = None) -> "ClipPredictor":
return ClipPredictor(self, device=device)
class ClipPredictor:
def __init__(self, clip_model: Clip, *, device: base_device | None = None) -> None:
"""Provides an API for doing predictions with a CLIP model.
Note that this module is not an `nn.Module`, so you can use it in your
module without worrying about storing all the weights on accident.
Args:
clip_model: The CLIP model to use for predictions
device: The device to use for predictions. If None, will use the
device returned by detect_device().
"""
super().__init__()
self.device = detect_device() if device is None else device
self.model = clip_model.eval()
self.device.module_to(self.model)
self.tokenizer = self.model.linguistic.get_tokenizer()
self.pil_preprocess = self.model.get_pil_preprocess()
self.tensor_preprocess = self.model.get_tensor_preprocess()
    def predict_text(self, text: str | Tensor) -> Tensor:
tokens = text if isinstance(text, Tensor) else self.device.tensor_to(self.tokenizer.tokenize([text]))
return self.model.encode_text(tokens)
    def predict_image(self, image: np.ndarray | PIL.Image.Image | Tensor) -> Tensor:
if isinstance(image, np.ndarray):
image_tensor = self.tensor_preprocess(self.device.tensor_to(torch.from_numpy(image)))
elif isinstance(image, PIL.Image.Image):
image_tensor = self.device.tensor_to(self.pil_preprocess(image))
elif isinstance(image, Tensor):
image_tensor = self.tensor_preprocess(self.device.tensor_to(image))
else:
raise NotImplementedError(f"Unsupported image type: {type(image)}")
return self.model.encode_image(image_tensor)
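
# Sketch of the predictor API (the image path below is a placeholder): the predictor
# handles device placement and preprocessing, so raw PIL images and strings can be
# embedded directly.
#
#     model = pretrained_clip("ViT_B_16", mode="all")
#     predictor = model.predictor()
#     image_emb = predictor.predict_image(PIL.Image.open("some_image.jpg"))
#     text_emb = predictor.predict_text("a photo of a cat")
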
def convert_weights(model: nn.Module) -> None:
"""Convert applicable model parameters to fp16.
Args:
model: The model to convert
"""
def _convert_weights_to_fp16(mod: nn.Module) -> None:
if isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Linear)):
mod.weight.data = mod.weight.data.half()
if mod.bias is not None:
mod.bias.data = mod.bias.data.half()
if isinstance(mod, nn.MultiheadAttention):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(mod, attr)
if tensor is not None:
tensor.data = tensor.data.half()
for name in ["text_projection", "proj"]:
if hasattr(mod, name):
attr = getattr(mod, name)
if attr is not None:
attr.data = attr.data.half()
model.apply(_convert_weights_to_fp16)
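
# Sketch of how this is used by `pretrained_clip` below: the model is constructed in
# fp32 and then converted, so its floating-point parameters line up with the
# half-precision weights stored in the released checkpoints.
#
#     model = Clip(...)       # Constructor arguments elided.
#     convert_weights(model)  # Conv/Linear/attention weights are now fp16.
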
@overload
def pretrained_clip(
key: PretrainedClipSize | nn.Module,
mode: Literal["visual"],
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> ModifiedResNet | VisionTransformer:
...
@overload
def pretrained_clip(
key: PretrainedClipSize | nn.Module,
mode: Literal["linguistic"],
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> TextModel:
...
@overload
def pretrained_clip(
key: PretrainedClipSize | nn.Module,
mode: Literal["all"],
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Clip:
...
def pretrained_clip(
key: PretrainedClipSize | nn.Module,
mode: str,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Clip | ModifiedResNet | VisionTransformer | TextModel:
"""Builds the CLIP model from a state dictionary.
Args:
key: The model key to load, or another model to load weights from
mode: Default is to return all models, but can optionally return just
the visual or linguistic part of the model
device: The device for the model
dtype: The dtype for the model
Returns:
The constructed clip model, or just the visual or text branch
"""
assert mode in ("all", "visual", "linguistic")
if isinstance(key, nn.Module):
ckpt = key.state_dict()
else:
filepath = get_pretrained_path(key)
ckpt = torch.jit.load(filepath, map_location="cpu").state_dict()
vit = "visual.proj" in ckpt
vision_layers: tuple[int, int, int, int] | int
if vit:
vision_width = ckpt["visual.conv1.weight"].shape[0]
vision_layers = sum(1 for k in ckpt.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight"))
vision_patch_size = ckpt["visual.conv1.weight"].shape[-1]
grid_size = round((ckpt["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size
else:
vision_layers = cast(
tuple[int, int, int, int],
tuple(len(set(k.split(".")[2] for k in ckpt if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]),
)
vision_width = ckpt["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((ckpt["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width**2 + 1 == ckpt["visual.attnpool.positional_embedding"].shape[0]
image_resolution = output_width * 32
embed_dim = ckpt["text_projection"].shape[1]
context_length = ckpt["positional_embedding"].shape[0]
vocab_size = ckpt["token_embedding.weight"].shape[0]
transformer_width = ckpt["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in ckpt if k.startswith("transformer.resblocks")))
model = Clip(
embed_dim,
image_resolution,
vision_layers,
vision_width,
vision_patch_size,
context_length,
vocab_size,
transformer_width,
transformer_heads,
transformer_layers,
device=device,
dtype=dtype,
)
for ckpt_key in ["input_resolution", "context_length", "vocab_size"]:
if ckpt_key in ckpt:
del ckpt[ckpt_key]
# Prepends `linguistic.` prefix to linguistic weights.
non_visual_keys = {k: v for k, v in ckpt.items() if not k.startswith("visual.")}
non_visual_keys.pop("logit_scale")
for k, v in non_visual_keys.items():
del ckpt[k]
ckpt[f"linguistic.{k}"] = v
convert_weights(model)
def get_ckpt_part(ckpt: dict[str, Any], prefix: str) -> dict[str, Any]:
return {k[len(prefix) :]: v for k, v in ckpt.items() if k.startswith(prefix)}
if mode == "visual":
model.visual.load_state_dict(get_ckpt_part(ckpt, "visual."))
return model.visual
elif mode == "linguistic":
model.linguistic.load_state_dict(get_ckpt_part(ckpt, "linguistic."))
return model.linguistic
else:
model.load_state_dict(ckpt)
return model
def get_pretrained_path(key: PretrainedClipSize) -> Path:
if key not in PRETRAINED_MODELS:
raise KeyError(f"Invalid CLIP model key {key}; choices are {list(PRETRAINED_MODELS.keys())}")
model_url = PRETRAINED_MODELS[key]
with Timer("downloading checkpoint"):
return ensure_downloaded(model_url, "clip", f"{key}_ckpt.pt")
def test_pretrained_model() -> None:
    parser = argparse.ArgumentParser(description="Tests a pretrained CLIP model")
parser.add_argument("key", type=str, choices=get_args(PretrainedClipSize))
args = parser.parse_args()
configure_logging()
# Gets an image of a peach from Wikipedia.
peach_url = "https://upload.wikimedia.org/wikipedia/commons/9/9e/Autumn_Red_peaches.jpg"
img_path = ensure_downloaded(peach_url, "peach.jpg", is_tmp=True)
peach_img = PIL.Image.open(img_path)
pos_desc = "A picture of an Autumn Red peach"
neg_desc = "An Instagram photo of a cute puppy"
# Loads the JIT'd model and the regular model.
auto_device = detect_device()
jit_model = cast(Clip, torch.jit.load(get_pretrained_path(cast(PretrainedClipSize, args.key)), map_location="cpu"))
model = pretrained_clip(jit_model, "all")
# Moves to the correct device.
auto_device.module_to(jit_model)
auto_device.module_to(model)
# Puts in eval mode.
jit_model.eval()
model.eval()
# Converts raw inputs to tensors.
img_tensorizer = model.get_pil_preprocess()
tokenizer = model.linguistic.get_tokenizer()
imgs = auto_device.tensor_to(img_tensorizer(peach_img)).unsqueeze(0).repeat_interleave(2, dim=0)
texts = auto_device.tensor_to(tokenizer.tokenize([pos_desc, neg_desc]))
with torch.no_grad():
img_out = model.encode_image(imgs)
text_out = model.encode_text(texts)
img_out_preds, text_out_preds = model(imgs, texts)
if auto_device.get_device().type == "cuda":
img_ref = jit_model.encode_image(imgs)
text_ref = jit_model.encode_text(texts)
img_ref_preds, text_ref_preds = jit_model(imgs, texts)
else:
logger.warning("JIT'd model requires CUDA to run")
img_ref, text_ref, img_ref_preds, text_ref_preds = img_out, text_out, img_out_preds, text_out_preds
# Checks the tensors against each other.
assert ((img_out - img_ref).abs().mean() / img_ref.abs().mean()).item() < 5e-2
assert ((text_out - text_ref).abs().mean() / text_ref.abs().mean()).item() < 5e-2
# Checks that the positive description scores more highly than the negative description.
assert (img_ref_preds[:, 0] > img_ref_preds[:, 1]).all().item()
# Checks model against JIT'd model.
assert ((img_out_preds - img_ref_preds).abs().max() / img_ref_preds.abs().mean()).item() < 5e-2
assert ((text_out_preds - text_ref_preds).abs().max() / text_ref_preds.abs().mean()).item() < 5e-2
logger.info("All adhoc tests passed!")
if __name__ == "__main__":
# python -m pretrained.clip
test_pretrained_model()