"""Defines a simple API for using the BLIP model.
This code provides a simple API for interacting with BLIP, referencing the
original repository `here <https://github.com/salesforce/BLIP>`_.
.. highlight:: python
.. code-block:: python
from pretrained.blip import pretrained_blip
model = pretrained_blip("ViT-B")
predictor = model.predictor()
image = PIL.Image.open(image_path)
embedding = predictor.predict(image)
The choices for the model key are:
- ``ViT-B``
- ``ViT-B-CapFilt``
- ``ViT-L``
"""
import argparse
import logging
from dataclasses import dataclass
from typing import Literal, Sequence, Union, cast, get_args
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as V
from ml.core.config import conf_field
from ml.models.activations import ActivationType, get_activation
from ml.models.init import init_
from ml.models.norms import NormType, get_norm_linear
from ml.utils.checkpoint import ensure_downloaded
from ml.utils.device.auto import detect_device
from ml.utils.device.base import base_device
from ml.utils.large_models import init_empty_weights, meta_to_empty_func
from ml.utils.logging import configure_logging
from ml.utils.timer import Timer
from omegaconf import MISSING
from torch import Tensor, nn
logger = logging.getLogger(__name__)
RawImage = Union[PIL.Image.Image, Tensor, np.ndarray]
PretrainedBlipKey = Literal[
# ViT-B models
"ViT-B",
"ViT-B-Coco",
"ViT-B-Flickr30k",
"ViT-B-VQA",
"ViT-B-NLVR2",
# ViT-B-CapFilt models
"ViT-B-CapFilt",
"ViT-B-CapFilt-Coco",
"ViT-B-CapFilt-VQA",
# ViT-L models
"ViT-L",
"ViT-L-Coco-Retrieval",
"ViT-L-Flickr30k",
"ViT-L-Coco-Captioning",
]
def cast_pretrained_blip_key(s: str) -> PretrainedBlipKey:
    if s not in get_args(PretrainedBlipKey):
        raise KeyError(f"Invalid BLIP size: {s}; expected one of: {get_args(PretrainedBlipKey)}")
    return cast(PretrainedBlipKey, s)
@dataclass
class ViTParams:
img_size: int | tuple[int, int] = conf_field(224, help="Total image size")
patch_size: int = conf_field(16, help="Size of image patches")
in_chans: int = conf_field(3, help="Number of input channels")
embed_dim: int = conf_field(MISSING, help="Number of embedding dimensions")
depth: int = conf_field(MISSING, help="Number of transformer layers")
num_heads: int = conf_field(MISSING, help="Number of attention heads")
mlp_ratio: float = conf_field(4.0, help="Ratio of MLP hidden dim to embedding dim")
qkv_bias: bool = conf_field(True, help="If True, use bias for qkv projection")
drop_rate: float = conf_field(0.0, help="Dropout rate")
attn_drop_rate: float = conf_field(0.0, help="Attention dropout rate")
norm_type: NormType = conf_field("layer_affine", help="Normalization type")
@dataclass
class TextParams:
num_tokens: int = conf_field(MISSING, help="Number of tokens")
@dataclass
class ModelParams:
url: str = conf_field(MISSING, help="URL to pre-trained weights")
vit: ViTParams = conf_field(MISSING, help="ViT parameters")
VIT_BASE = ViTParams(
img_size=224,
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
)
VIT_BASE_CAPFILT = ViTParams(
img_size=224,
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
)
VIT_LARGE = ViTParams(
img_size=224,
patch_size=16,
embed_dim=1024,
    depth=24,
    num_heads=16,
)
PRETRAINED_MODEL_SIZES: dict[PretrainedBlipKey, ModelParams] = {
# ViT-B models
"ViT-B": ModelParams(
url="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth",
vit=VIT_BASE,
),
"ViT-B-Coco": ModelParams(
url="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth",
vit=VIT_BASE,
),
"ViT-B-Flickr30k": ModelParams(
url="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth",
vit=VIT_BASE,
),
"ViT-B-VQA": ModelParams(
url="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth",
vit=VIT_BASE,
),
"ViT-B-NLVR2": ModelParams(
url="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth",
vit=VIT_BASE,
),
# ViT-B-CapFilt models
"ViT-B-CapFilt": ModelParams(
url="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth",
vit=VIT_BASE_CAPFILT,
),
"ViT-B-CapFilt-Coco": ModelParams(
url="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth",
vit=VIT_BASE_CAPFILT,
),
"ViT-B-CapFilt-VQA": ModelParams(
url="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth",
vit=VIT_BASE_CAPFILT,
),
# ViT-L models
"ViT-L": ModelParams(
url="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth",
vit=VIT_LARGE,
),
"ViT-L-Coco-Retrieval": ModelParams(
url="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth",
vit=VIT_LARGE,
),
"ViT-L-Flickr30k": ModelParams(
url="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_flickr.pth",
vit=VIT_LARGE,
),
"ViT-L-Coco-Captioning": ModelParams(
url="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
vit=VIT_LARGE,
),
}
class PatchEmbed(nn.Module):
    """Gets patch embeddings for an input image.

    Parameters:
        img_size: The size of the input image.
        patch_size: The size of the patches to extract.
        flatten: Whether to flatten the output.
        in_chans: The number of input channels.
        embed_dim: The embedding dimension.
        norm_type: The type of normalization to use.
        bias: Whether to use a bias term.

    Inputs:
        x: The input image, of shape ``(B, C, H, W)``.

    Outputs:
        The patch embeddings, of shape ``(B, num_patches, embed_dim)``
        if ``flatten=True``, or ``(B, embed_dim, H // patch_size, W // patch_size)``
        otherwise.
    """
__constants__ = ["patch_size", "img_size", "num_patches", "flatten", "in_chans"]
def __init__(
self,
img_size: int | tuple[int, int] = 224,
patch_size: int | tuple[int, int] = 16,
flatten: bool = True,
in_chans: int = 3,
embed_dim: int = 768,
norm_type: NormType = "no_norm",
bias: bool = True,
) -> None:
super().__init__()
self.patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
self.img_size = (img_size, img_size) if isinstance(img_size, int) else img_size
assert self.img_size[0] % self.patch_size[0] == 0 and self.img_size[1] % self.patch_size[1] == 0
self.num_patches = (self.img_size[0] // self.patch_size[0]) * (self.img_size[1] // self.patch_size[1])
self.flatten = flatten
self.in_chans = in_chans
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = get_norm_linear(norm_type, dim=embed_dim)
    def forward(self, x: Tensor) -> Tensor:
_, _, hei, wid = x.shape
assert hei == self.img_size[0] and wid == self.img_size[1]
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # (B, C, H, W) -> (B, H * W, C)
x = self.norm(x)
return x
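# Illustrative shape check (not part of the model): with the default 224x224
# input and 16x16 patches, ``PatchEmbed`` produces (224 // 16) ** 2 == 196 tokens.
#
#     patches = PatchEmbed(img_size=224, patch_size=16, embed_dim=768)
#     out = patches(torch.zeros(1, 3, 224, 224))
#     assert out.shape == (1, 196, 768)  # (B, num_patches, embed_dim)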
class Mlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
act_type: ActivationType = "gelu",
drop: float = 0.0,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = get_activation(act_type)
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
    def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.attn_gradients = None
self.attention_map = None
    def forward(self, x: Tensor) -> Tensor:
bsz, tsz, chans = x.shape
qkv = self.qkv(x).reshape(bsz, tsz, 3, self.num_heads, chans // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
        # Applies attention dropout during training (otherwise ``self.attn_drop`` is never used).
        dropout_p = self.attn_drop.p if self.training else 0.0
        x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=False)
        x = x.transpose(1, 2).reshape(bsz, tsz, chans)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
drop: float = 0.0,
attn_drop: float = 0.0,
act_type: ActivationType = "gelu",
norm_type: NormType = "layer_affine",
) -> None:
super().__init__()
self.norm1 = get_norm_linear(norm_type, dim=dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.norm2 = get_norm_linear(norm_type, dim=dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_type=act_type, drop=drop)
    def forward(self, x: Tensor) -> Tensor:
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class Blip(nn.Module):
def __init__(self, model_args: ModelParams) -> None:
super().__init__()
self.visual_encoder = VisionTransformer(model_args.vit)
    def forward(self, img: Tensor) -> Tensor:
return self.visual_encoder(img)
    def predictor(self, device: base_device | None = None) -> "BlipPredictor":
return BlipPredictor(self, device=device)
class BlipPredictor:
def __init__(self, blip_model: Blip, *, device: base_device | None = None) -> None:
"""Provides an API for sampling from the BLIP model.
Args:
blip_model: The BLIP model to use for sampling.
device: The device to use for sampling. If None, the device will be
automatically detected.
"""
super().__init__()
self.device = detect_device() if device is None else device
self.device.module_to(blip_model)
self.model = blip_model
    @torch.no_grad()
def predict(self, image: RawImage | Sequence[RawImage]) -> Tensor:
"""Gets the embedding for a given image.
Args:
image: The image to get an embedding for.
Returns:
The embedding tensor, with shape ``(embed_dim)``,
or ``(bsz, embed_dim)`` if ``image`` is a sequence.
"""
if isinstance(image, Sequence):
x = torch.stack([self.get_input_image(img) for img in image])
x = self.model(x)[:, 0]
else:
x = self.get_input_image(image)
x = self.model(x.unsqueeze(0)).squeeze(0)[0]
return x
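# ``get_input_image`` is referenced above but not defined in this file. A
# minimal sketch of such a helper, assuming BLIP's usual CLIP-style
# normalization constants and a 224x224 input size (both are assumptions, not
# taken from this file). It handles PIL / uint8 inputs; the result may also
# need ``.half()`` to match the checkpoint dtype loaded below.
#
#     def get_input_image(self, image: RawImage) -> Tensor:
#         if isinstance(image, np.ndarray):
#             image = torch.from_numpy(image)
#         if isinstance(image, PIL.Image.Image):
#             image = V.pil_to_tensor(image)
#         x = V.resize(image, [224, 224]).float() / 255.0
#         x = V.normalize(x, [0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
#         return x.to(self.device.get_device())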
def pretrained_blip(key: PretrainedBlipKey, *, device: base_device | None = None) -> Blip:
device = detect_device() if device is None else device
model_args = PRETRAINED_MODEL_SIZES[key]
with Timer("downloading checkpoint"):
ckpt_path = ensure_downloaded(model_args.url, "blip", f"{key}.pth")
with Timer("loading model checkpoint", spinner=True):
ckpt = torch.load(ckpt_path, map_location="cpu")["model"]
        # Filters just the visual encoder weights.
ckpt = {k: v for k, v in ckpt.items() if k.startswith("visual_encoder.")}
with Timer("building model skeleton", spinner=True), init_empty_weights():
model = Blip(model_args)
# Logs model summary.
total_params = sum(p.numel() for p in model.parameters())
logger.info("Model %s has %s parameters", key, f"{total_params:,}")
    # Builds the transformer and loads the checkpoint.
with Timer("loading state dict", spinner=True):
model._apply(meta_to_empty_func(device.get_device(), torch.half))
model.load_state_dict(ckpt, strict=True)
return model
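# Usage sketch: pass an explicit device to both the loader and the predictor
# rather than relying on automatic detection.
#
#     device = detect_device()
#     model = pretrained_blip("ViT-L-Coco-Retrieval", device=device)
#     predictor = model.predictor(device=device)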
def test_blip_adhoc() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("key", type=str, choices=get_args(PretrainedBlipKey))
args = parser.parse_args()
configure_logging()
# Gets an image of a peach from Wikipedia.
peach_url = "https://upload.wikimedia.org/wikipedia/commons/9/9e/Autumn_Red_peaches.jpg"
img_path = ensure_downloaded(peach_url, "peach.jpg", is_tmp=True)
image = PIL.Image.open(img_path)
model = pretrained_blip(args.key)
predictor = model.predictor()
# Outputs the embedding for a given image.
embedding = predictor.predict(image)
logger.info("Embedding shape: %s", embedding.shape)
if __name__ == "__main__":
    # python -m pretrained.blip <key>
test_blip_adhoc()