"""Defines a simple API for using Meta's pretrained Segment Anything model.
.. highlight:: python
.. code-block:: python
from pretrained.sam import pretrained_sam
model = pretrained_sam("ViT-B")
predictor = model.predictor()
image = PIL.Image.open(img_path)
predictor.set_image(np.array(image))
predictions, _, _ = predictor.predict()
single_mask = predictions[0] # Same shape as the original image.
Alternatively, you can run the script directly on an image:
.. code-block:: bash
python -m pretrained.sam ViT-B /path/to/image.jpg
The choices for the model key are:
- ``ViT-H``: ViT with 32 layers and 16 attention heads.
- ``ViT-L``: ViT with 24 layers and 16 attention heads.
- ``ViT-B``: ViT with 12 layers and 12 attention heads.
"""
import argparse
import copy
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, Type, cast, get_args
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from ml.utils.checkpoint import ensure_downloaded
from ml.utils.device.auto import detect_device
from ml.utils.device.base import base_device
from ml.utils.logging import configure_logging
from ml.utils.timer import Timer
from torch import Tensor, nn
from torchvision.transforms.functional import resize, to_pil_image
PretrainedSamSize = Literal["ViT-H", "ViT-L", "ViT-B"]
ImageFormat = Literal["RGB", "BGR"]
DEFAULT_PIXEL_MEAN = (123.675, 116.28, 103.53)
DEFAULT_PIXEL_STD = (58.395, 57.12, 57.375)
def cast_pretrained_sam_size(s: str) -> PretrainedSamSize:
if s not in get_args(PretrainedSamSize):
raise KeyError(f"Invalid SAM size: {s} Expected one of: {get_args(PretrainedSamSize)}")
return cast(PretrainedSamSize, s)
@dataclass
class PretrainedModelConfig:
url: str
encoder_embed_dim: int
encoder_depth: int
encoder_num_heads: int
encoder_global_attn_indices: tuple[int, int, int, int]
PRETRAINED_MODELS: dict[PretrainedSamSize, PretrainedModelConfig] = {
"ViT-H": PretrainedModelConfig(
url="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
encoder_embed_dim=1280,
encoder_depth=32,
encoder_num_heads=16,
encoder_global_attn_indices=(7, 15, 23, 31),
),
"ViT-L": PretrainedModelConfig(
url="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
encoder_embed_dim=1024,
encoder_depth=24,
encoder_num_heads=16,
encoder_global_attn_indices=(5, 11, 17, 23),
),
"ViT-B": PretrainedModelConfig(
url="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indices=(2, 5, 8, 11),
),
}
class MLPBlock(nn.Module):
def __init__(self, embedding_dim: int, mlp_dim: int, act: Type[nn.Module] = nn.GELU) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
    def forward(self, x: Tensor) -> Tensor:
return self.lin2(self.act(self.lin1(x)))
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
    def forward(self, x: Tensor) -> Tensor:
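        # Channel-wise layer norm for (B, C, H, W) inputs: the statistics are
        # computed over the channel dimension at each spatial location.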
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class LayerNormHigherEps(nn.LayerNorm):
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["eps"] = kwargs.pop("eps", 1e-6)
super().__init__(*args, **kwargs)
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
window_size: int = 0,
global_attn_indexes: tuple[int, ...] = (),
) -> None:
"""Image encoder based on Vision Transformer.
Args:
img_size: Input image size.
patch_size: Patch size.
in_chans: Number of input image channels.
embed_dim: Patch embedding dimension.
depth: Depth of ViT.
num_heads: Number of attention heads in each ViT block.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
out_chans: Number of output channels.
qkv_bias: If True, add a learnable bias to query, key, value.
norm_layer: Normalization layer.
act_layer: Activation layer.
use_abs_pos: If True, use absolute positional embeddings.
use_rel_pos: If True, add relative positional embeddings to the attention map.
window_size: Window size for window attention blocks.
global_attn_indexes: Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: nn.Parameter | None = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
    def forward(self, x: Tensor) -> Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
x = self.neck(x.permute(0, 3, 1, 2))
return x
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
window_size: int = 0,
input_size: tuple[int, int] | None = None,
) -> None:
"""Transformer blocks, which support window attention and residual propagation.
Args:
dim: Number of input channels.
num_heads: Number of attention heads in each ViT block.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
norm_layer: Normalization layer.
act_layer: Activation layer.
use_rel_pos: If True, add relative positional embeddings to the
attention map.
window_size: Window size for window attention blocks. If it equals
0, then use global attention.
input_size: Input resolution for calculating the relative positional
parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
    def forward(self, x: Tensor) -> Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
hei, wid = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (hei, wid))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
input_size: tuple[int, int] | None = None,
) -> None:
"""Multi-head attention block with relative position embeddings.
Args:
dim: Number of input channels.
num_heads: Number of attention heads.
qkv_bias: If True, add a learnable bias to query, key, value.
use_rel_pos: If True, add relative positional embeddings to the
attention map.
input_size: Input resolution for calculating the relative positional
parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert input_size is not None, "Input size must be provided if using relative positional encoding."
# Initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
    def forward(self, x: Tensor) -> Tensor:
bsz, hei, wid, _ = x.shape
qkv = self.qkv(x).reshape(bsz, hei * wid, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.reshape(3, bsz * self.num_heads, hei * wid, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (hei, wid), (hei, wid))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(bsz, self.num_heads, hei, wid, -1).permute(0, 2, 3, 1, 4).reshape(bsz, hei, wid, -1)
x = self.proj(x)
return x
def window_partition(x: Tensor, window_size: int) -> tuple[Tensor, tuple[int, int]]:
"""Partition into non-overlapping windows with padding if needed.
Args:
x: Input tokens with shape (B, H, W, C).
window_size: Window size.
Returns:
        Windows after partition, with shape (B * n_win, win_size, win_size, C),
        and the padded height and width (Hp, Wp) before partition.
"""
bsz, hei, wid, chans = x.shape
pad_h = (window_size - hei % window_size) % window_size
pad_w = (window_size - wid % window_size) % window_size
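    # For example, with a 64 x 64 token grid and window_size = 14 (the setting
    # used by ``pretrained_sam`` below), the grid is padded to 70 x 70 and
    # split into 5 x 5 = 25 windows per image.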
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
hei_p, wid_p = hei + pad_h, wid + pad_w
x = x.view(bsz, hei_p // window_size, window_size, wid_p // window_size, window_size, chans)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, chans)
return windows, (hei_p, wid_p)
def window_unpartition(windows: Tensor, window_size: int, pad_hw: tuple[int, int], hw: tuple[int, int]) -> Tensor:
"""Window unpartition into original sequences and removing padding.
Args:
windows: Input tokens with (B * n_win, win_size, win_size, C).
window_size: Window size.
pad_hw: Padded height and width (Hp, Wp).
hw: Original height and width (H, W) before padding.
Returns:
The unpartitioned sequences with shape (B, H, W, C).
"""
hei_p, wid_p = pad_hw
hei, wid = hw
bsz = windows.shape[0] // (hei_p * wid_p // window_size // window_size)
x = windows.view(bsz, hei_p // window_size, wid_p // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(bsz, hei_p, wid_p, -1)
if hei_p > hei or wid_p > wid:
x = x[:, :hei, :wid, :].contiguous()
return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: Tensor) -> Tensor:
"""Get relative positional embeddings.
Args:
q_size: Size of query q.
k_size: Size of key k.
rel_pos: Relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: Tensor,
q: Tensor,
rel_pos_h: Tensor,
rel_pos_w: Tensor,
q_size: tuple[int, int],
k_size: tuple[int, int],
) -> Tensor:
"""Calculate decomposed Relative Positional Embeddings.
https://github.com/facebookresearch/mvit/mvit/models/attention.py
Args:
attn: Attention map.
q: Query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h: Relative position embeddings (Lh, C) for height axis.
rel_pos_w: Relative position embeddings (Lw, C) for width axis.
q_size: Spatial sequence size of query q with (q_h, q_w).
k_size: Spatial sequence size of key k with (k_h, k_w).
Returns:
Attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
rel_h = get_rel_pos(q_h, k_h, rel_pos_h)
rel_w = get_rel_pos(q_w, k_w, rel_pos_w)
bsz, _, dim = q.shape
r_q = q.reshape(bsz, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rel_h)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rel_w)
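    # rel_h[b, h, w, kh] is the dot product of the query at (h, w) with the
    # height-axis embedding for the (h, kh) offset; rel_w is the analogous term
    # along the width axis. Both are broadcast-added to the attention logits,
    # which are temporarily viewed with factored (k_h, k_w) key dimensions.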
attn = attn.view(bsz, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
attn = attn.view(bsz, q_h * q_w, k_h * k_w)
return attn
class PatchEmbed(nn.Module):
def __init__(
self,
kernel_size: tuple[int, int] = (16, 16),
stride: tuple[int, int] = (16, 16),
padding: tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""Image to Patch Embedding.
Args:
kernel_size: Kernel size of the projection layer.
stride: Stride of the projection layer.
padding: Padding size of the projection layer.
in_chans: Number of input image channels.
embed_dim: Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
    def forward(self, x: Tensor) -> Tensor:
x = self.proj(x)
x = x.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C)
return x
class MaskDecoder(nn.Module):
def __init__(
self,
*,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
activation: Type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
) -> None:
"""Predicts masks given an image and prompt embeddings.
Args:
transformer_dim: The channel dimension of the transformer
transformer: The transformer used to predict masks
num_multimask_outputs: The number of masks to predict when
disambiguating masks
activation: The type of activation to use when upscaling masks
iou_head_depth: The depth of the MLP used to predict mask quality
iou_head_hidden_dim: The hidden dimension of the MLP used to
predict mask quality
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
self.num_multimask_outputs = num_multimask_outputs
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
self.output_hypernetworks_mlps = nn.ModuleList(
[MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for i in range(self.num_mask_tokens)]
)
self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
    def forward(
self,
image_embeddings: Tensor,
image_pe: Tensor,
sparse_prompt_embeddings: Tensor,
dense_prompt_embeddings: Tensor,
multimask_output: bool,
) -> tuple[Tensor, Tensor]:
"""Predict masks given image and prompt embeddings.
Args:
image_embeddings: The embeddings from the image encoder
image_pe: Positional encoding with the shape of image_embeddings
sparse_prompt_embeddings: The embeddings of the points and boxes
dense_prompt_embeddings: The embeddings of the mask inputs
multimask_output: Whether to return multiple masks or a single
mask.
Returns:
The batched predicted masks and batched predictions of mask quality
"""
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)
# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
else:
mask_slice = slice(0, 1)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]
# Prepare output
return masks, iou_pred
    def predict_masks(
self,
image_embeddings: Tensor,
image_pe: Tensor,
sparse_prompt_embeddings: Tensor,
dense_prompt_embeddings: Tensor,
) -> tuple[Tensor, Tensor]:
# Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: list[Tensor] = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
class MLP(nn.Module):
__constants__ = ["num_layers", "sigmoid_output"]
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
sigmoid_output: bool = False,
) -> None:
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
self.sigmoid_output = sigmoid_output
    def forward(self, x: Tensor) -> Tensor:
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
if self.sigmoid_output:
x = F.sigmoid(x)
return x
class PromptEncoder(nn.Module):
__constants__ = ["embed_dim", "input_image_size", "input_image_size"]
def __init__(
self,
embed_dim: int,
image_embedding_size: tuple[int, int],
input_image_size: tuple[int, int],
mask_in_chans: int,
activation: Type[nn.Module] = nn.GELU,
) -> None:
"""Encodes prompts for input to SAM's mask decoder.
Args:
embed_dim: The prompts' embedding dimension
image_embedding_size: The spatial size of the image embedding,
as (H, W).
input_image_size: The padded size of the image as input to the
image encoder, as (H, W).
mask_in_chans: The number of hidden channels used for encoding
input masks.
activation: The activation to use when encoding input masks.
"""
super().__init__()
self.embed_dim = embed_dim
self.input_image_size = input_image_size
self.image_embedding_size = image_embedding_size
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
self.point_embeddings = nn.ModuleList(point_embeddings)
self.not_a_point_embed = nn.Embedding(1, embed_dim)
self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
self.mask_downscaling = nn.Sequential(
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4),
activation(),
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans),
activation(),
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
self.no_mask_embed = nn.Embedding(1, embed_dim)
    def get_dense_pe(self) -> Tensor:
"""Returns the positional encoding used to encode point prompts.
        The embedding is applied to a dense set of points matching the shape
        of the image encoding.
Returns:
Positional encoding with shape (1, emb_dim, emb_h, emb_w)
"""
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
def _embed_points(
self,
points: Tensor,
labels: Tensor,
pad: bool,
) -> Tensor:
points = points + 0.5 # Shift to center of pixel
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
return point_embedding
def _embed_boxes(self, boxes: Tensor) -> Tensor:
boxes = boxes + 0.5 # Shift to center of pixel
coords = boxes.reshape(-1, 2, 2)
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
return corner_embedding
def _embed_masks(self, masks: Tensor) -> Tensor:
mask_embedding = self.mask_downscaling(masks)
return mask_embedding
def _get_batch_size(
self,
points: tuple[Tensor, Tensor] | None,
boxes: Tensor | None,
masks: Tensor | None,
) -> int:
if points is not None:
return points[0].shape[0]
elif boxes is not None:
return boxes.shape[0]
elif masks is not None:
return masks.shape[0]
else:
return 1
def _get_device(self) -> torch.device:
return self.point_embeddings[0].weight.device
    def forward(
self,
points: tuple[Tensor, Tensor] | None,
boxes: Tensor | None,
masks: Tensor | None,
) -> tuple[Tensor, Tensor]:
"""Embeds different types of prompts.
This function returns both sparse and dense embeddings.
Args:
points: Point coordinates and labels to embed.
boxes: Boxes to embed
masks: Masks to embed
Returns:
Sparse embeddings for the points and boxes, with shape
(B, N, embed_dim), where N is determined by the number of
input points and boxes, and the dense embeddings for the masks,
with shape (B, emb_dim, emb_h, emb_w)
"""
bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
if points is not None:
coords, labels = points
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
if boxes is not None:
box_embeddings = self._embed_boxes(boxes)
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
return sparse_embeddings, dense_embeddings
class PositionEmbeddingRandom(nn.Module):
positional_encoding_gaussian_matrix: Tensor
def __init__(self, num_pos_feats: int = 64, scale: float | None = None) -> None:
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
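        # Random Fourier features: 2-D coordinates in [0, 1] are projected
        # through this fixed Gaussian matrix and mapped through sin/cos in
        # ``_pe_encoding`` below.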
self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
def _pe_encoding(self, coords: Tensor) -> Tensor:
# Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) # (d_1, ..., d_n, C)
    def forward(self, size: tuple[int, int]) -> Tensor:
h, w = size
grid = self.positional_encoding_gaussian_matrix.new_ones((h, w))
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
y_embed = y_embed / h
x_embed = x_embed / w
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
return pe.permute(2, 0, 1) # (C, H, W)
    def forward_with_coords(self, coords_input: Tensor, image_size: tuple[int, int]) -> Tensor:
coords = coords_input.clone()
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
return self._pe_encoding(coords.to(torch.float)) # (B, N, C)
class TwoWayAttentionBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
num_heads: int,
mlp_dim: int = 2048,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""Defines a mutual cross attention block.
A transformer block with four layers: (1) self-attention of sparse
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
block on sparse inputs, and (4) cross attention of dense inputs to
sparse inputs.
Args:
embedding_dim: The channel dimension of the embeddings
num_heads: The number of heads in the attention layers
mlp_dim: The hidden dimension of the mlp block
activation: The activation of the mlp block
            attention_downsample_rate: The downsample rate for the attention
                blocks; the attention blocks downsample their internal
                embedding dimension by this factor.
skip_first_layer_pe: Skip the PE on the first layer
"""
super().__init__()
self.self_attn = TwoWayAttentionFunction(embedding_dim, num_heads)
self.norm1 = nn.LayerNorm(embedding_dim)
self.cross_attn_token_to_image = TwoWayAttentionFunction(
embedding_dim,
num_heads,
downsample_rate=attention_downsample_rate,
)
self.norm2 = nn.LayerNorm(embedding_dim)
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = TwoWayAttentionFunction(
embedding_dim,
num_heads,
downsample_rate=attention_downsample_rate,
)
self.skip_first_layer_pe = skip_first_layer_pe
    def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> tuple[Tensor, Tensor]:
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(q=queries, k=queries, v=queries)
else:
q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)
# Cross attention block, tokens attending to image embedding
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)
# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)
return queries, keys
class TwoWayAttentionFunction(nn.Module):
def __init__(self, embedding_dim: int, num_heads: int, downsample_rate: int = 1) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide the downsampled embedding dim."
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
b, n, c = x.shape
x = x.reshape(b, n, num_heads, c // num_heads)
return x.transpose(1, 2) # (B, N_heads, N_tokens, C_per_head)
def _recombine_heads(self, x: Tensor) -> Tensor:
b, n_heads, n_tokens, c_per_head = x.shape
x = x.transpose(1, 2)
return x.reshape(b, n_tokens, n_heads * c_per_head) # (B, N_tokens, C)
    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
# Input projections.
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# Separate into heads.
q = self._separate_heads(q, self.num_heads)
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)
# Applies attention.
_, _, _, c_per_head = q.shape
attn = q @ k.permute(0, 1, 3, 2) # (B, N_heads, N_tokens, N_tokens)
attn = attn / math.sqrt(c_per_head)
attn = torch.softmax(attn, dim=-1)
# Gets the output.
out = attn @ v
out = self._recombine_heads(out)
out = self.out_proj(out)
return out
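# NOTE: ``TwoWayTransformer`` is referenced by ``pretrained_sam`` below but its
# definition does not appear in this listing. The class below is a sketch
# reconstructed from the reference ``segment_anything`` implementation, reusing
# this module's ``TwoWayAttentionBlock`` and ``TwoWayAttentionFunction``; the
# original definition may differ in details.
class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """Transformer that attends to image embeddings with query tokens.

        Args:
            depth: Number of TwoWayAttentionBlock layers.
            embedding_dim: The channel dimension of the embeddings.
            num_heads: The number of heads in the attention layers.
            mlp_dim: The hidden dimension of the mlp blocks.
            activation: The activation of the mlp blocks.
            attention_downsample_rate: The downsample rate for the attention blocks.
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()
        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )
        self.final_attn_token_to_image = TwoWayAttentionFunction(
            embedding_dim,
            num_heads,
            downsample_rate=attention_downsample_rate,
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(self, image_embedding: Tensor, image_pe: Tensor, point_embedding: Tensor) -> tuple[Tensor, Tensor]:
        # Flattens the image embedding: (B, C, H, W) -> (B, H * W, C).
        bsz, chans, hei, wid = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)
        # Prepares queries (prompt tokens) and keys (image tokens).
        queries = point_embedding
        keys = image_embedding
        # Applies the two-way attention blocks.
        for layer in self.layers:
            queries, keys = layer(queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe)
        # Applies a final attention layer from the point tokens to the image.
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)
        return queries, keys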
class Sam(nn.Module):
mask_threshold: float = 0.0
image_format: str = "RGB"
pixel_mean: Tensor
pixel_std: Tensor
def __init__(
self,
image_encoder: ImageEncoderViT,
prompt_encoder: PromptEncoder,
mask_decoder: MaskDecoder,
pixel_mean: tuple[float, float, float] = DEFAULT_PIXEL_MEAN,
pixel_std: tuple[float, float, float] = DEFAULT_PIXEL_STD,
) -> None:
"""SAM predicts object masks from an image and input prompts.
Args:
image_encoder: The backbone used to encode the image.
prompt_encoder: Encodes various types of input prompts.
mask_decoder: Predicts masks from the image embeddings
and encoded prompts.
pixel_mean: Mean values for normalizing pixels in the input image.
pixel_std: Std values for normalizing pixels in the input image.
"""
super().__init__()
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
self.mask_decoder = mask_decoder
self.register_buffer("pixel_mean", Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", Tensor(pixel_std).view(-1, 1, 1), False)
@property
def device(self) -> torch.device:
return self.pixel_mean.device
    @torch.no_grad()
def forward(
self,
batched_input: list[dict[str, Any]],
multimask_output: bool,
) -> list[dict[str, Tensor]]:
"""Predicts masks end-to-end from provided images and prompts.
If prompts are not known in advance, using SamPredictor is
recommended over calling the model directly.
Args:
batched_input: A list over input images, each a dictionary with
the following keys. A prompt key can be excluded if it is not
present. The 'image' key expects a tensor with shape (3, H, W),
already transformed for input to the model. The 'original_size'
key expects a tuple of (H, W), the original size of the image
                before transformation. The 'point_coords' key expects a tensor
                with shape (B, N, 2), the batched coordinates of N point
                prompts in the image. The 'point_labels' key expects a tensor
                with shape (B, N), the labels of the point prompts. The 'boxes'
                key expects a tensor with shape (B, 4), the batched box prompts
                in the image. The 'mask_inputs' key expects a tensor with shape
                (B, 1, H, W), the batched mask inputs to the model.
multimask_output: Whether the model should predict multiple
disambiguating masks, or return a single mask.
Returns:
            A list over input images, where each element is a dictionary with
the following keys. The 'masks' key is the batched binary mask
predictions with shape (B, C, H, W), where B is the number of input
prompts, C is determined by multimask_output, and (H, W) is the
original size of the image. The 'iou_predictions' key is the
model's predictions of mask quality, with shape (B, C). The
'low_res_logits' key is the low resolution logits with shape
(B, C, 256, 256). This can be passed as mask input to subsequent
iterations of prediction.
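
        Example of a minimal ``batched_input`` with one image and a single
        point prompt (the tensors and sizes are hypothetical):

        .. code-block:: python

            batched_input = [
                {
                    "image": image_chw,  # (3, H, W) tensor, already transformed.
                    "original_size": (480, 640),
                    "point_coords": torch.tensor([[[100.0, 200.0]]]),  # (B, N, 2)
                    "point_labels": torch.tensor([[1]]),  # (B, N)
                }
            ]
            outputs = sam(batched_input, multimask_output=True)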
"""
input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
image_embeddings = self.image_encoder(input_images)
outputs = []
for image_record, curr_embedding in zip(batched_input, image_embeddings):
if "point_coords" in image_record:
points = (image_record["point_coords"], image_record["point_labels"])
else:
points = None
sparse_embeddings, dense_embeddings = self.prompt_encoder(
points=points,
boxes=image_record.get("boxes", None),
masks=image_record.get("mask_inputs", None),
)
low_res_masks, iou_predictions = self.mask_decoder(
image_embeddings=curr_embedding.unsqueeze(0),
image_pe=self.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
masks = self.postprocess_masks(
low_res_masks,
input_size=image_record["image"].shape[-2:],
original_size=image_record["original_size"],
)
masks = masks > self.mask_threshold
outputs.append(
{
"masks": masks,
"iou_predictions": iou_predictions,
"low_res_logits": low_res_masks,
}
)
return outputs
    def postprocess_masks(
self,
masks: Tensor,
input_size: tuple[int, ...],
original_size: tuple[int, ...],
) -> Tensor:
"""Removes padding and upscale masks to the original image size.
Args:
masks: Batched masks from the mask_decoder, with shape (B, C, H, W)
input_size: The size of the image input to the model, in (H, W)
format. Used to remove padding.
original_size: The original size of the image before resizing for
input to the model, in (H, W) format.
Returns:
Batched masks with shape (B, C, H, W), where (H, W) matches
the original size.
"""
masks = F.interpolate(
masks,
(self.image_encoder.img_size, self.image_encoder.img_size),
mode="bilinear",
align_corners=False,
)
masks = masks[..., : input_size[0], : input_size[1]]
masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
return masks
    def preprocess(self, x: Tensor) -> Tensor:
x = (x - self.pixel_mean) / self.pixel_std
# Pad
h, w = x.shape[-2:]
padh = self.image_encoder.img_size - h
padw = self.image_encoder.img_size - w
x = F.pad(x, (0, padw, 0, padh))
return x
    def predictor(self) -> "SamPredictor":
return SamPredictor(self)
class ResizeLongestSide:
def __init__(self, target_length: int) -> None:
self.target_length = target_length
    def apply_image(self, image: np.ndarray) -> np.ndarray:
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
return np.array(resize(to_pil_image(image), target_size))
    def apply_coords(self, coords: np.ndarray, original_size: tuple[int, ...]) -> np.ndarray:
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
coords = copy.deepcopy(coords).astype(float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
    def apply_boxes(self, boxes: np.ndarray, original_size: tuple[int, ...]) -> np.ndarray:
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)
    def apply_image_torch(self, image: Tensor) -> Tensor:
# Expects an image in BCHW format. May not exactly match apply_image.
target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
    def apply_coords_torch(self, coords: Tensor, original_size: tuple[int, ...]) -> Tensor:
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
coords = copy.deepcopy(coords).to(torch.float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
    def apply_boxes_torch(self, boxes: Tensor, original_size: tuple[int, ...]) -> Tensor:
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)
    @staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> tuple[int, int]:
scale = long_side_length * 1.0 / max(oldh, oldw)
newh, neww = oldh * scale, oldw * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
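        # For example, a 1500 x 2250 image with long_side_length = 1024 maps to
        # (683, 1024): the longer side becomes 1024 and the shorter side is
        # scaled proportionally and rounded.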
return (newh, neww)
class SamPredictor:
def __init__(self, sam_model: Sam, *, device: base_device | None = None) -> None:
"""Provides an API to do repeated mask predictions on an image.
This predictor uses SAM to calculate the image embedding for an image,
        and then allows repeated, efficient mask prediction given prompts.
Args:
sam_model: The model to use for mask prediction.
device: The device to use for prediction. If None, will use the
device returned by detect_device().
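
        Example (mirroring the module-level usage above; the image path is
        hypothetical):

        .. code-block:: python

            predictor = pretrained_sam("ViT-B").predictor()
            image = PIL.Image.open("/path/to/image.jpg")
            predictor.set_image(np.array(image))
            masks, qualities, low_res_logits = predictor.predict(
                point_coords=np.array([[100, 200]]),  # One (X, Y) point, in pixels.
                point_labels=np.array([1]),  # 1 marks a foreground point.
            )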
"""
super().__init__()
self.device = detect_device() if device is None else device
self.model = sam_model.eval()
self.device.module_to(self.model)
self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
self.reset_image()
    def set_image(self, image: np.ndarray, image_format: ImageFormat = "RGB") -> None:
"""Sets a given image for mask prediction.
Calculates the image embeddings for the provided image, allowing
masks to be predicted with the 'predict' method.
Args:
image: The image for calculating masks. Expects an image in HWC
uint8 format, with pixel values in [0, 255].
image_format: The color format of the image, in ['RGB', 'BGR'].
"""
assert image_format in get_args(ImageFormat)
if image_format != self.model.image_format:
image = image[..., ::-1]
# Transform the image to the form expected by the model
input_image = self.transform.apply_image(image)
input_image_torch = self.device.tensor_to(torch.as_tensor(input_image))
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
self.set_torch_image(input_image_torch, image.shape[:2])
    @torch.no_grad()
def set_torch_image(
self,
transformed_image: Tensor,
original_image_size: tuple[int, ...],
) -> None:
"""Sets a given image for mask prediction.
Calculates the image embeddings for the provided image, allowing
masks to be predicted with the 'predict' method. Expects the input
image to be already transformed to the format expected by the model.
Args:
transformed_image: The input image, with shape (1, 3, H, W), which
has been transformed with ResizeLongestSide.
original_image_size: The size of the image before transformation,
in (H, W) format.
"""
assert (
len(transformed_image.shape) == 4
and transformed_image.shape[1] == 3
and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
self.reset_image()
self.original_size = original_image_size
self.input_size = tuple(transformed_image.shape[-2:])
input_image = self.model.preprocess(transformed_image)
self.features = self.model.image_encoder(input_image)
self.is_image_set = True
    def predict(
self,
point_coords: np.ndarray | None = None,
point_labels: np.ndarray | None = None,
box: np.ndarray | None = None,
mask_input: np.ndarray | None = None,
multimask_output: bool = True,
return_logits: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Predict masks for the given input prompts for the current image.
Args:
point_coords: A (N, 2) array of point prompts to the model. Each
point is in (X, Y) in pixels.
point_labels: A length N array of labels for the point prompts. 1
indicates a foreground point and 0 indicates a background point.
            box: A length 4 array giving a box prompt to the model, in
XYXY format.
mask_input: A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form
(1, H, W), where for SAM, H=W=256.
multimask_output: If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will
often produce better masks than a single prediction. If only a
single mask is needed, the model's predicted quality score can
be used to select the best mask. For non-ambiguous prompts, such
as multiple input prompts, multimask_output=False can give
better results.
            return_logits: If true, returns un-thresholded mask logits instead
of a binary mask.
Returns:
The output masks with shape (C, H, W), where C is the number of
masks, and (H, W) is the original image size; an array of length C
containing the model's predictions for the quality of each mask;
and an array of shape (C, H, W), where C is the number of masks and
H=W=256. These low resolution logits can be passed to a subsequent
iteration as mask input.
Raises:
RuntimeError: If an image has not been set yet.
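
        Example of a single foreground point prompt (the coordinate is
        hypothetical):

        .. code-block:: python

            masks, qualities, low_res_logits = predictor.predict(
                point_coords=np.array([[100, 200]]),  # (N, 2) array of (X, Y) pixels.
                point_labels=np.array([1]),  # 1 = foreground, 0 = background.
            )
            best_mask = masks[np.argmax(qualities)]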
"""
if not self.is_image_set:
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
# Transform input prompts
coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
if point_coords is not None:
assert point_labels is not None, "point_labels must be supplied if point_coords is supplied."
point_coords = self.transform.apply_coords(point_coords, self.original_size)
coords_torch = self.device.tensor_to(torch.as_tensor(point_coords, dtype=torch.float))
labels_torch = self.device.tensor_to(torch.as_tensor(point_labels, dtype=torch.int))
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
if box is not None:
box = self.transform.apply_boxes(box, self.original_size)
box_torch = self.device.tensor_to(torch.as_tensor(box, dtype=torch.float))
box_torch = box_torch[None, :]
if mask_input is not None:
mask_input_torch = self.device.tensor_to(torch.as_tensor(mask_input, dtype=torch.float))
mask_input_torch = mask_input_torch[None, :, :, :]
masks, iou_predictions, low_res_masks = self.predict_torch(
coords_torch,
labels_torch,
box_torch,
mask_input_torch,
multimask_output,
return_logits=return_logits,
)
masks_np = masks[0].detach().cpu().numpy()
iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
return masks_np, iou_predictions_np, low_res_masks_np
    @torch.no_grad()
def predict_torch(
self,
point_coords: Tensor | None,
point_labels: Tensor | None,
boxes: Tensor | None = None,
mask_input: Tensor | None = None,
multimask_output: bool = True,
return_logits: bool = False,
) -> tuple[Tensor, Tensor, Tensor]:
"""Predicts masks for the given input prompts.
Predicts masks for the given input prompts, using the currently set
image. Input prompts are batched Tensors and are expected to
already be transformed to the input frame using ResizeLongestSide.
Args:
point_coords: A (B, N, 2) array of point prompts to the
model. Each point is in (X, Y) in pixels.
point_labels: A (B, N) array of labels for the point prompts. 1
indicates a foreground point and 0 indicates a background
point.
            boxes: A (B, 4) array giving box prompts to the model, in
XYXY format.
mask_input: A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form
(B, 1, H, W), where for SAM, H=W=256. Masks returned by a
previous iteration of the predict method do not need further
transformation.
multimask_output: If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will
often produce better masks than a single prediction. If only a
single mask is needed, the model's predicted quality score can
be used to select the best mask. For non-ambiguous prompts,
such as multiple input prompts, multimask_output=False can give
better results.
            return_logits: If true, returns un-thresholded mask logits
instead of a binary mask.
Returns:
            The output masks with shape (B, C, H, W), where C is the number of
            masks and (H, W) is the original image size; a tensor of shape
            (B, C) containing the model's predictions for the quality of each
            mask; and a tensor of shape (B, C, 256, 256) containing the low
            resolution logits, which can be passed to a subsequent iteration
            as mask input.
Raises:
RuntimeError: If an image has not been set yet.
"""
if not self.is_image_set:
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
if point_coords is not None:
points = (point_coords, point_labels)
else:
points = None
# Embed prompts
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
points=points,
boxes=boxes,
masks=mask_input,
)
# Predict masks
low_res_masks, iou_predictions = self.model.mask_decoder(
image_embeddings=self.features,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
# Upscale the masks to the original image resolution
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
if not return_logits:
masks = masks > self.model.mask_threshold
return masks, iou_predictions, low_res_masks
    def get_image_embedding(self) -> Tensor:
if not self.is_image_set:
raise RuntimeError("An image must be set with .set_image(...) to generate an embedding.")
assert self.features is not None, "Features must exist if an image has been set."
return self.features
    def reset_image(self) -> None:
self.is_image_set = False
self.features = None
self.orig_h = None
self.orig_w = None
self.input_h = None
self.input_w = None
def get_pretrained_path(key: PretrainedSamSize) -> Path:
if key not in PRETRAINED_MODELS:
raise KeyError(f"Invalid CLIP model key {key}; choices are {list(PRETRAINED_MODELS.keys())}")
model_url = PRETRAINED_MODELS[key].url
with Timer("downloading checkpoint"):
filepath = ensure_downloaded(model_url, "sam", f"{key}_ckpt.pt")
return filepath
def pretrained_sam(key: PretrainedSamSize, *, skip_weights: bool = False) -> Sam:
config = PRETRAINED_MODELS[key]
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
model = Sam(
image_encoder=ImageEncoderViT(
depth=config.encoder_depth,
embed_dim=config.encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=LayerNormHigherEps,
num_heads=config.encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=config.encoder_global_attn_indices,
window_size=14,
out_chans=prompt_embed_dim,
),
prompt_encoder=PromptEncoder(
embed_dim=prompt_embed_dim,
image_embedding_size=(image_embedding_size, image_embedding_size),
input_image_size=(image_size, image_size),
mask_in_chans=16,
),
mask_decoder=MaskDecoder(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
),
pixel_mean=DEFAULT_PIXEL_MEAN,
pixel_std=DEFAULT_PIXEL_STD,
)
if not skip_weights:
with open(get_pretrained_path(key), "rb") as f:
state_dict = torch.load(f)
model.load_state_dict(state_dict)
return model
def test_pretrained_model() -> None:
parser = argparse.ArgumentParser(description="Tests a pretrained SAM model")
parser.add_argument("key", type=str, choices=get_args(PretrainedSamSize))
parser.add_argument("image_path", type=str, nargs="?", default=None)
parser.add_argument("-o", "--output-path", type=str, default=None)
args = parser.parse_args()
configure_logging()
# Gets an image of a peach from Wikipedia.
if args.image_path is None:
peach_url = "https://upload.wikimedia.org/wikipedia/commons/9/9e/Autumn_Red_peaches.jpg"
img_path = ensure_downloaded(peach_url, "peach.jpg", is_tmp=True)
else:
img_path = Path(args.image_path)
model = pretrained_sam(cast(PretrainedSamSize, args.key))
predictor = model.predictor()
peach_img = PIL.Image.open(img_path)
predictor.set_image(np.array(peach_img))
predictions, _, _ = predictor.predict()
single_mask = predictions[0]
mask = PIL.Image.fromarray(single_mask.astype(np.uint8) * 255)
# Overlays the mask on the original image.
mask.putalpha(128)
peach_img.paste(mask, (0, 0), mask)
if args.output_path is None:
peach_img.show()
else:
peach_img.save(args.output_path)
if __name__ == "__main__":
# python -m pretrained.sam
test_pretrained_model()