pretrained.sam
Defines a simple API for using Meta’s pretrained Segment Anything model.
import numpy as np
import PIL.Image

from pretrained.sam import pretrained_sam

model = pretrained_sam("ViT-B")
predictor = model.predictor()
image = PIL.Image.open(img_path)
predictor.set_image(np.array(image))
predictions, _, _ = predictor.predict()
single_mask = predictions[0]  # Same shape as the original image.
Alternatively, you can run the script directly on an image:
python -m pretrained.sam ViT-B /path/to/image.jpg
The choices for the model key are:
ViT-H: ViT with 32 layers and 16 attention heads.
ViT-L: ViT with 24 layers and 16 attention heads.
ViT-B: ViT with 12 layers and 12 attention heads.
- class pretrained.sam.PretrainedModelConfig(url: str, encoder_embed_dim: int, encoder_depth: int, encoder_num_heads: int, encoder_global_attn_indices: tuple[int, int, int, int])[source]
Bases:
object
Configuration for a pretrained Segment Anything checkpoint, specifying the checkpoint URL and the image encoder hyperparameters.
- url: str
- encoder_embed_dim: int
- encoder_depth: int
- encoder_num_heads: int
- encoder_global_attn_indices: tuple[int, int, int, int]
- class pretrained.sam.MLPBlock(embedding_dim: int, mlp_dim: int, act: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.GELU'>)[source]
Bases:
Module
A two-layer MLP block with a configurable hidden dimension and activation function.
- forward(x: Tensor) Tensor [source]
Applies the two linear layers with the activation in between; the output has the same shape as the input.
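For illustration, a minimal shape-level sketch (the embedding and hidden sizes here are arbitrary examples, not fixed by the API):
import torch

from pretrained.sam import MLPBlock

block = MLPBlock(embedding_dim=768, mlp_dim=3072)  # Illustrative 4x hidden expansion.
x = torch.randn(1, 16, 16, 768)                    # Any tensor whose last dimension is embedding_dim.
y = block(x)
assert y.shape == x.shape                          # The block projects back to embedding_dim.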
- class pretrained.sam.LayerNorm2d(num_channels: int, eps: float = 1e-06)[source]
Bases:
Module
Layer normalization applied over the channel dimension of an (N, C, H, W) tensor.
- forward(x: Tensor) Tensor [source]
Normalizes the input over its channel dimension and applies the learned per-channel scale and shift, preserving the (N, C, H, W) shape.
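For illustration, a minimal sketch (the channel count and spatial size are arbitrary):
import torch

from pretrained.sam import LayerNorm2d

x = torch.randn(2, 256, 64, 64)       # An (N, C, H, W) feature map.
norm = LayerNorm2d(num_channels=256)  # num_channels must match C.
y = norm(x)                           # Same shape, normalized across channels at each spatial location.
assert y.shape == x.shape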
- class pretrained.sam.LayerNormHigherEps(*args: Any, **kwargs: Any)[source]
Bases:
LayerNorm
A LayerNorm variant that uses a higher epsilon than the PyTorch default, for improved numerical stability.
- class pretrained.sam.ImageEncoderViT(img_size: int = 1024, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, out_chans: int = 256, qkv_bias: bool = True, norm_layer: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.normalization.LayerNorm'>, act_layer: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.GELU'>, use_abs_pos: bool = True, use_rel_pos: bool = False, window_size: int = 0, global_attn_indexes: tuple[int, ...] = ())[source]
Bases:
Module
Image encoder based on Vision Transformer.
- Parameters:
img_size – Input image size.
patch_size – Patch size.
in_chans – Number of input image channels.
embed_dim – Patch embedding dimension.
depth – Depth of ViT.
num_heads – Number of attention heads in each ViT block.
mlp_ratio – Ratio of mlp hidden dim to embedding dim.
out_chans – Number of output channels.
qkv_bias – If True, add a learnable bias to query, key, value.
norm_layer – Normalization layer.
act_layer – Activation layer.
use_abs_pos – If True, use absolute positional embeddings.
use_rel_pos – If True, add relative positional embeddings to the attention map.
window_size – Window size for window attention blocks.
global_attn_indexes – Indexes for blocks using global attention.
- forward(x: Tensor) Tensor [source]
Encodes a batch of images of shape (B, 3, img_size, img_size) into a feature map of shape (B, out_chans, img_size / patch_size, img_size / patch_size).
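For illustration, a shape-level sketch using a small input size for a quick check (the pretrained checkpoints instead use the configurations given by PretrainedModelConfig):
import torch

from pretrained.sam import ImageEncoderViT

encoder = ImageEncoderViT(img_size=256)  # All other arguments left at their defaults.
x = torch.randn(1, 3, 256, 256)          # (B, 3, img_size, img_size)
features = encoder(x)
print(features.shape)                    # Expected (1, 256, 16, 16): (B, out_chans, img_size / patch_size, img_size / patch_size).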
- class pretrained.sam.Block(dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = True, norm_layer: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.normalization.LayerNorm'>, act_layer: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.GELU'>, use_rel_pos: bool = False, window_size: int = 0, input_size: tuple[int, int] | None = None)[source]
Bases:
Module
Transformer blocks, which support window attention and residual propagation.
- Parameters:
dim – Number of input channels.
num_heads – Number of attention heads in each ViT block.
mlp_ratio – Ratio of mlp hidden dim to embedding dim.
qkv_bias – If True, add a learnable bias to query, key, value.
norm_layer – Normalization layer.
act_layer – Activation layer.
use_rel_pos – If True, add relative positional embeddings to the attention map.
window_size – Window size for window attention blocks. If it equals 0, then use global attention.
input_size – Input resolution for calculating the relative positional parameter size.
- forward(x: Tensor) Tensor [source]
Applies self-attention (windowed or global, depending on window_size) followed by the MLP block, each with a residual connection; the (B, H, W, C) shape is preserved.
- class pretrained.sam.Attention(dim: int, num_heads: int = 8, qkv_bias: bool = True, use_rel_pos: bool = False, input_size: tuple[int, int] | None = None)[source]
Bases:
Module
Multi-head attention block with relative position embeddings.
- Parameters:
dim – Number of input channels.
num_heads – Number of attention heads.
qkv_bias – If True, add a learnable bias to query, key, value.
use_rel_pos – If True, add relative positional embeddings to the attention map.
input_size – Input resolution for calculating the relative positional parameter size.
- forward(x: Tensor) Tensor [source]
Applies multi-head self-attention to the (B, H, W, C) input tokens, optionally adding decomposed relative positional embeddings, and returns a tensor of the same shape.
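For illustration, a minimal sketch with arbitrary sizes (relative positional embeddings disabled, which is the default):
import torch

from pretrained.sam import Attention

attn = Attention(dim=768, num_heads=12)
x = torch.randn(1, 14, 14, 768)  # (B, H, W, C) tokens, e.g. one attention window.
y = attn(x)
assert y.shape == x.shape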
- pretrained.sam.window_partition(x: Tensor, window_size: int) tuple[torch.Tensor, tuple[int, int]] [source]
Partition into non-overlapping windows with padding if needed.
- Parameters:
x – Input tokens with shape (B, H, W, C).
window_size – Window size.
- Returns:
Windows after partition with shape (B * n_win, win_size, win_size, C), and the padded height and width (Hp, Wp) before partition.
- pretrained.sam.window_unpartition(windows: Tensor, window_size: int, pad_hw: tuple[int, int], hw: tuple[int, int]) Tensor [source]
Reverses window partitioning, restoring the original sequences and removing any padding.
- Parameters:
windows – Input tokens with (B * n_win, win_size, win_size, C).
window_size – Window size.
pad_hw – Padded height and width (Hp, Wp).
hw – Original height and width (H, W) before padding.
- Returns:
The unpartitioned sequences with shape (B, H, W, C).
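For illustration, a round trip through the two functions (tensor sizes are arbitrary):
import torch

from pretrained.sam import window_partition, window_unpartition

x = torch.randn(1, 50, 50, 96)               # (B, H, W, C) tokens.
windows, (Hp, Wp) = window_partition(x, 14)  # H and W are padded up to multiples of 14, so Hp = Wp = 56.
print(windows.shape)                         # (16, 14, 14, 96): B * n_win windows of size 14 x 14.
restored = window_unpartition(windows, 14, (Hp, Wp), (50, 50))
assert restored.shape == x.shape             # Padding is removed on the way back.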
- pretrained.sam.get_rel_pos(q_size: int, k_size: int, rel_pos: Tensor) Tensor [source]
Get relative positional embeddings.
- Parameters:
q_size – Size of query q.
k_size – Size of key k.
rel_pos – Relative position embeddings (L, C).
- Returns:
Extracted positional embeddings according to relative positions.
- pretrained.sam.add_decomposed_rel_pos(attn: Tensor, q: Tensor, rel_pos_h: Tensor, rel_pos_w: Tensor, q_size: tuple[int, int], k_size: tuple[int, int]) Tensor [source]
Calculates decomposed relative positional embeddings, following the MViT attention implementation (https://github.com/facebookresearch/mvit/mvit/models/attention.py).
- Parameters:
attn – Attention map.
q – Query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h – Relative position embeddings (Lh, C) for height axis.
rel_pos_w – Relative position embeddings (Lw, C) for width axis.
q_size – Spatial sequence size of query q with (q_h, q_w).
k_size – Spatial sequence size of key k with (k_h, k_w).
- Returns:
Attention map with added relative positional embeddings.
- class pretrained.sam.PatchEmbed(kernel_size: tuple[int, int] = (16, 16), stride: tuple[int, int] = (16, 16), padding: tuple[int, int] = (0, 0), in_chans: int = 3, embed_dim: int = 768)[source]
Bases:
Module
Image to Patch Embedding.
- Parameters:
kernel_size – Kernel size of the projection layer.
stride – Stride of the projection layer.
padding – Padding size of the projection layer.
in_chans – Number of input image channels.
embed_dim – Patch embedding dimension.
- forward(x: Tensor) Tensor [source]
Projects an image batch of shape (B, C, H, W) to patch embeddings of shape (B, H', W', embed_dim), where (H', W') is determined by the kernel size, stride, and padding.
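For illustration, a shape-level sketch with the default 16 x 16 patch projection:
import torch

from pretrained.sam import PatchEmbed

patch_embed = PatchEmbed()         # Defaults: 16 x 16 kernel and stride, 3 input channels, 768-dim embeddings.
x = torch.randn(1, 3, 1024, 1024)  # (B, C, H, W)
y = patch_embed(x)
print(y.shape)                     # Expected (1, 64, 64, 768): (B, H / 16, W / 16, embed_dim).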
- class pretrained.sam.MaskDecoder(*, transformer_dim: int, transformer: ~torch.nn.modules.module.Module, num_multimask_outputs: int = 3, activation: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.GELU'>, iou_head_depth: int = 3, iou_head_hidden_dim: int = 256)[source]
Bases:
Module
Predicts masks given an image and prompt embeddings.
- Parameters:
transformer_dim – The channel dimension of the transformer
transformer – The transformer used to predict masks
num_multimask_outputs – The number of masks to predict when disambiguating masks
activation – The type of activation to use when upscaling masks
iou_head_depth – The depth of the MLP used to predict mask quality
iou_head_hidden_dim – The hidden dimension of the MLP used to predict mask quality
- forward(image_embeddings: Tensor, image_pe: Tensor, sparse_prompt_embeddings: Tensor, dense_prompt_embeddings: Tensor, multimask_output: bool) tuple[torch.Tensor, torch.Tensor] [source]
Predict masks given image and prompt embeddings.
- Parameters:
image_embeddings – The embeddings from the image encoder
image_pe – Positional encoding with the shape of image_embeddings
sparse_prompt_embeddings – The embeddings of the points and boxes
dense_prompt_embeddings – The embeddings of the mask inputs
multimask_output – Whether to return multiple masks or a single mask.
- Returns:
The batched predicted masks and batched predictions of mask quality
- class pretrained.sam.MLP(input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False)[source]
Bases:
Module
A small multi-layer perceptron, used by the mask decoder for its output heads.
- forward(x: Tensor) Tensor [source]
Applies the stacked linear layers with ReLU activations between them, followed by a sigmoid if sigmoid_output is set.
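For illustration, a minimal sketch (all sizes are arbitrary examples):
import torch

from pretrained.sam import MLP

mlp = MLP(input_dim=256, hidden_dim=256, output_dim=32, num_layers=3)
x = torch.randn(4, 256)  # (batch, input_dim)
y = mlp(x)
print(y.shape)           # (4, 32)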
- class pretrained.sam.PromptEncoder(embed_dim: int, image_embedding_size: tuple[int, int], input_image_size: tuple[int, int], mask_in_chans: int, activation: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.GELU'>)[source]
Bases:
Module
Encodes prompts for input to SAM’s mask decoder.
- Parameters:
embed_dim – The prompts’ embedding dimension
image_embedding_size – The spatial size of the image embedding, as (H, W).
input_image_size – The padded size of the image as input to the image encoder, as (H, W).
mask_in_chans – The number of hidden channels used for encoding input masks.
activation – The activation to use when encoding input masks.
- get_dense_pe() Tensor [source]
Returns the positional encoding used to encode point prompts.
The encoding is applied to a dense set of points matching the shape of the image encoding.
- Returns:
Positional encoding with shape (1, emb_dim, emb_h, emb_w)
- forward(points: tuple[torch.Tensor, torch.Tensor] | None, boxes: Tensor | None, masks: Tensor | None) tuple[torch.Tensor, torch.Tensor] [source]
Embeds different types of prompts.
This function returns both sparse and dense embeddings.
- Parameters:
points – Point coordinates and labels to embed.
boxes – Boxes to embed
masks – Masks to embed
- Returns:
Sparse embeddings for the points and boxes, with shape (B, N, embed_dim), where N is determined by the number of input points and boxes, and the dense embeddings for the masks, with shape (B, emb_dim, emb_h, emb_w)
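For illustration, a sketch of encoding a single point prompt. The constructor values below mirror how the pretrained models are typically configured (a 256-dimensional embedding over a 64 x 64 grid for a 1024 x 1024 padded input), and the (B, N, 2) / (B, N) prompt shapes follow the reference SAM implementation; both are assumptions here rather than guarantees of this API:
import torch

from pretrained.sam import PromptEncoder

prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
coords = torch.tensor([[[512.0, 384.0]]])  # Assumed (B, N, 2) point coordinates in the padded input frame.
labels = torch.tensor([[1]])               # Assumed (B, N) labels; 1 marks a foreground point.
sparse, dense = prompt_encoder(points=(coords, labels), boxes=None, masks=None)
print(sparse.shape, dense.shape)           # Sparse point/box embeddings and dense (1, 256, 64, 64) mask embeddings.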
- class pretrained.sam.PositionEmbeddingRandom(num_pos_feats: int = 64, scale: float | None = None)[source]
Bases:
Module
Positional encoding using random spatial frequencies.
- positional_encoding_gaussian_matrix: Tensor
- forward(size: tuple[int, int]) Tensor [source]
Generates a positional encoding for a grid of the given (H, W) size.
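For illustration, a small sketch; the output layout noted in the comment follows the reference SAM implementation and is an assumption here:
from pretrained.sam import PositionEmbeddingRandom

pe = PositionEmbeddingRandom(num_pos_feats=128)
grid_pe = pe((64, 64))  # Encoding for a 64 x 64 grid.
print(grid_pe.shape)    # Expected (2 * num_pos_feats, 64, 64) = (256, 64, 64).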
- class pretrained.sam.TwoWayTransformer(depth: int, embedding_dim: int, num_heads: int, mlp_dim: int, activation: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.ReLU'>, attention_downsample_rate: int = 2)[source]
Bases:
Module
Transformer which does cross-attention in two directions.
This transformer decoder attends to an input image using queries whose positional embedding is supplied.
- Parameters:
depth – Number of layers in the transformer
embedding_dim – The channel dimension for the input embeddings
num_heads – The number of heads for multihead attention. Must divide embedding_dim
mlp_dim – The channel dimension internal to the MLP block
activation – The activation to use in the MLP block
attention_downsample_rate – The downsample rate for the attention blocks; the embedding dimension used inside the attention blocks is reduced by this factor.
- forward(image_embedding: Tensor, image_pe: Tensor, point_embedding: Tensor) tuple[torch.Tensor, torch.Tensor] [source]
Runs the transformer forward pass.
- Parameters:
image_embedding – Image to attend to. Should be shape (B, embedding_dim, h, w).
image_pe – The positional encoding to add to the image. Must have the same shape as image_embedding.
point_embedding – The embedding to add to the query points. Must have shape (B, N_points, embedding_dim).
- Returns:
The processed point and image embeddings.
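For illustration, a shape-level sketch with dummy tensors (the constructor values mirror the pretrained configuration but are assumptions here):
import torch

from pretrained.sam import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 64, 64)  # (B, embedding_dim, h, w)
image_pe = torch.randn(1, 256, 64, 64)         # Same shape as image_embedding.
point_embedding = torch.randn(1, 5, 256)       # (B, N_points, embedding_dim)
queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape, keys.shape)               # Roughly (1, 5, 256) and (1, 64 * 64, 256).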
- class pretrained.sam.TwoWayAttentionBlock(embedding_dim: int, num_heads: int, mlp_dim: int = 2048, activation: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.ReLU'>, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False)[source]
Bases:
Module
Defines a mutual cross attention block.
A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse inputs.
- Parameters:
embedding_dim – The channel dimension of the embeddings
num_heads – The number of heads in the attention layers
mlp_dim – The hidden dimension of the mlp block
activation – The activation of the mlp block
attention_downsample_rate – The downsample rate for the attention blocks; the embedding dimension used inside the attention blocks is reduced by this factor.
skip_first_layer_pe – Skip the PE on the first layer
- forward(queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) tuple[torch.Tensor, torch.Tensor] [source]
Applies the four layers described above to the sparse queries and dense keys (with their positional encodings) and returns the processed queries and keys.
- class pretrained.sam.TwoWayAttentionFunction(embedding_dim: int, num_heads: int, downsample_rate: int = 1)[source]
Bases:
Module
An attention layer that allows downscaling the embedding size after projection to queries, keys, and values.
- forward(q: Tensor, k: Tensor, v: Tensor) Tensor [source]
Computes multi-head attention for the given queries, keys, and values and returns the attended output at the original embedding dimension.
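For illustration, a minimal sketch (sizes are arbitrary examples):
import torch

from pretrained.sam import TwoWayAttentionFunction

attn = TwoWayAttentionFunction(embedding_dim=256, num_heads=8, downsample_rate=2)
q = torch.randn(1, 5, 256)     # Queries, e.g. prompt tokens.
k = torch.randn(1, 1024, 256)  # Keys, e.g. flattened image tokens.
v = torch.randn(1, 1024, 256)  # Values.
out = attn(q, k, v)
print(out.shape)               # Expected (1, 5, 256): one output per query, back at embedding_dim.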
- class pretrained.sam.Sam(image_encoder: ImageEncoderViT, prompt_encoder: PromptEncoder, mask_decoder: MaskDecoder, pixel_mean: tuple[float, float, float] = (123.675, 116.28, 103.53), pixel_std: tuple[float, float, float] = (58.395, 57.12, 57.375))[source]
Bases:
Module
SAM predicts object masks from an image and input prompts.
- Parameters:
image_encoder – The backbone used to encode the image.
prompt_encoder – Encodes various types of input prompts.
mask_decoder – Predicts masks from the image embeddings and encoded prompts.
pixel_mean – Mean values for normalizing pixels in the input image.
pixel_std – Std values for normalizing pixels in the input image.
- mask_threshold: float = 0.0
- image_format: str = 'RGB'
- pixel_mean: Tensor
- pixel_std: Tensor
- property device: device
- forward(batched_input: list[dict[str, Any]], multimask_output: bool) list[dict[str, torch.Tensor]] [source]
Predicts masks end-to-end from provided images and prompts.
If prompts are not known in advance, using SamPredictor is recommended over calling the model directly.
- Parameters:
batched_input – A list over input images, each a dictionary with the following keys (a prompt key can be excluded if that prompt is not present):
‘image’: A tensor with shape (3, H, W), already transformed for input to the model.
‘original_size’: A tuple of (H, W), the original size of the image before transformation.
‘point_coords’: A tensor with shape (N, 2), the coordinates of N point prompts in the image.
‘point_labels’: A tensor with shape (N,), the labels of the N point prompts.
‘boxes’: A tensor with shape (4,), the coordinates of a box prompt in the image.
‘mask_inputs’: A tensor with shape (1, H, W), the mask input to the model.
multimask_output – Whether the model should predict multiple disambiguating masks, or return a single mask.
- Returns:
A list over input images, where each element is a dictionary with the following keys:
‘masks’: The batched binary mask predictions with shape (B, C, H, W), where B is the number of input prompts, C is determined by multimask_output, and (H, W) is the original size of the image.
‘iou_predictions’: The model’s predictions of mask quality, with shape (B, C).
‘low_res_logits’: The low resolution logits with shape (B, C, 256, 256), which can be passed as mask input to subsequent iterations of prediction.
- postprocess_masks(masks: Tensor, input_size: tuple[int, ...], original_size: tuple[int, ...]) Tensor [source]
Removes padding and upscales masks to the original image size.
- Parameters:
masks – Batched masks from the mask_decoder, with shape (B, C, H, W)
input_size – The size of the image input to the model, in (H, W) format. Used to remove padding.
original_size – The original size of the image before resizing for input to the model, in (H, W) format.
- Returns:
Batched masks with shape (B, C, H, W), where (H, W) matches the original size.
- predictor() SamPredictor [source]
Returns a SamPredictor for this model, for repeated mask prediction on a single image.
- class pretrained.sam.SamPredictor(sam_model: Sam, *, device: base_device | None = None)[source]
Bases:
object
Provides an API to do repeated mask predictions on an image.
This predictor uses SAM to calculate the image embedding for an image, and then allows repeated, efficient mask prediction given prompts.
- Parameters:
sam_model – The model to use for mask prediction.
device – The device to use for prediction. If None, will use the device returned by detect_device().
- set_image(image: ndarray, image_format: Literal['RGB', 'BGR'] = 'RGB') None [source]
Sets a given image for mask prediction.
Calculates the image embeddings for the provided image, allowing masks to be predicted with the ‘predict’ method.
- Parameters:
image – The image for calculating masks. Expects an image in HWC uint8 format, with pixel values in [0, 255].
image_format – The color format of the image, in [‘RGB’, ‘BGR’].
- set_torch_image(transformed_image: Tensor, original_image_size: tuple[int, ...]) None [source]
Sets a given image for mask prediction.
Calculates the image embeddings for the provided image, allowing masks to be predicted with the ‘predict’ method. Expects the input image to be already transformed to the format expected by the model.
- Parameters:
transformed_image – The input image, with shape (1, 3, H, W), which has been transformed with ResizeLongestSide.
original_image_size – The size of the image before transformation, in (H, W) format.
- predict(point_coords: ndarray | None = None, point_labels: ndarray | None = None, box: ndarray | None = None, mask_input: ndarray | None = None, multimask_output: bool = True, return_logits: bool = False) tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray] [source]
Predict masks for the given input prompts for the current image.
- Parameters:
point_coords – A (N, 2) array of point prompts to the model. Each point is in (X, Y) in pixels.
point_labels – A length N array of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point.
box – A length 4 array given a box prompt to the model, in XYXY format.
mask_input – A low resolution mask input to the model, typically coming from a previous prediction iteration. Has form (1, H, W), where for SAM, H=W=256.
multimask_output – If true, the model will return three masks. For ambiguous input prompts (such as a single click), this will often produce better masks than a single prediction. If only a single mask is needed, the model’s predicted quality score can be used to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results.
return_logits – If true, returns un-thresholded mask logits instead of a binary mask.
- Returns:
The output masks with shape (C, H, W), where C is the number of masks and (H, W) is the original image size; an array of length C containing the model’s predictions for the quality of each mask; and an array of shape (C, 256, 256) of low resolution logits, which can be passed to a subsequent iteration as mask input.
- Raises:
RuntimeError – If an image has not been set yet.
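For illustration, an end-to-end sketch using a single foreground click (the image path and click coordinates are placeholders):
import numpy as np
import PIL.Image

from pretrained.sam import pretrained_sam

model = pretrained_sam("ViT-B")
predictor = model.predictor()
predictor.set_image(np.array(PIL.Image.open("/path/to/image.jpg")))

point_coords = np.array([[500, 375]])  # One (X, Y) click, in pixels of the original image.
point_labels = np.array([1])           # 1 marks the click as a foreground point.
masks, scores, low_res_logits = predictor.predict(
    point_coords=point_coords,
    point_labels=point_labels,
    multimask_output=True,
)
best_mask = masks[int(scores.argmax())]  # Keep the mask with the highest predicted quality.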
- predict_torch(point_coords: Tensor | None, point_labels: Tensor | None, boxes: Tensor | None = None, mask_input: Tensor | None = None, multimask_output: bool = True, return_logits: bool = False) tuple[torch.Tensor, torch.Tensor, torch.Tensor] [source]
Predicts masks for the given input prompts.
Predicts masks for the given input prompts, using the currently set image. Input prompts are batched Tensors and are expected to already be transformed to the input frame using ResizeLongestSide.
- Parameters:
point_coords – A (B, N, 2) array of point prompts to the model. Each point is in (X, Y) in pixels.
point_labels – A (B, N) array of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point.
boxes – A (B, 4) array given a box prompt to the model, in XYXY format.
mask_input – A low resolution mask input to the model, typically coming from a previous prediction iteration. Has form (B, 1, H, W), where for SAM, H=W=256. Masks returned by a previous iteration of the predict method do not need further transformation.
multimask_output – If true, the model will return three masks. For ambiguous input prompts (such as a single click), this will often produce better masks than a single prediction. If only a single mask is needed, the model’s predicted quality score can be used to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results.
return_logits – If true, returns un-thresholded mask logits instead of a binary mask.
- Returns:
The output masks with shape (B, C, H, W), where C is the number of masks and (H, W) is the original image size; a tensor of shape (B, C) containing the model’s predictions for the quality of each mask; and a tensor of shape (B, C, 256, 256) of low resolution logits, which can be passed to a subsequent iteration as mask input.
- Raises:
RuntimeError – If an image has not been set yet.