pretrained.clip
Defines a simple API for using OpenAI’s pretrained CLIP model.
import PIL.Image

from pretrained.clip import pretrained_clip

full_model = pretrained_clip("RN50", mode="all")
visual_model = pretrained_clip("RN50", mode="visual")
linguistic_model = pretrained_clip("RN50", mode="linguistic")

# Preprocess an image into a tensor.
image = PIL.Image.open(image_path)
image_tensorizer = visual_model.get_preprocess()
image_tensor = image_tensorizer(image)  # (3, 224, 224)

# Tokenize a batch of text prompts.
tokenizer = linguistic_model.get_tokenizer()
token_tensor = tokenizer.tokenize(["A photo of a cat", "A photo of a dog"])

# Encode images and text into the shared embedding space.
visual_model.encode_image(image_tensor.unsqueeze(0))  # (N, C)
linguistic_model.encode_text(token_tensor)  # (N, C)
The choices for the model key are:
- RN50: ResNet50 + Transformer
- RN101: ResNet101 + Transformer
- RN50x4: 4x ResNet50 + Transformer
- RN50x16: 16x ResNet50 + Transformer
- RN50x64: 64x ResNet50 + Transformer
- ViT_B_32: ViT-B/32 + Transformer
- ViT_B_16: ViT-B/16 + Transformer
- ViT_L_14: ViT-L/14 + Transformer
- ViT_L_14_336px: ViT-L/14 + Transformer (336px)
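Building on the example above, the snippet below sketches how the two embeddings are typically turned into image-text similarity scores. The normalization and softmax are standard CLIP practice rather than something this API performs for you:

import torch

# Continuing from the example above (visual_model, linguistic_model, image_tensor, token_tensor).
with torch.no_grad():
    image_features = visual_model.encode_image(image_tensor.unsqueeze(0))  # (1, C)
    text_features = linguistic_model.encode_text(token_tensor)  # (2, C)

# Cosine similarity between the image and each prompt.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
similarity = (image_features @ text_features.T).softmax(dim=-1)  # (1, 2)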
- pretrained.clip.cast_pretrained_clip_key(s: str) Literal['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT_B_32', 'ViT_B_16', 'ViT_L_14', 'ViT_L_14_336px'] [source]
- pretrained.clip.bytes_to_unicode() dict[int, str] [source]
Returns a mapping from UTF-8 bytes to unicode strings.
The reversible BPE codes operate on unicode strings, so the vocabulary needs a large number of unicode characters if you want to avoid UNKs. At something like a 10B-token dataset you end up needing around 5K characters for decent coverage, which is a significant percentage of a normal, say, 32K BPE vocab. To avoid that, this function provides lookup tables between UTF-8 bytes and unicode strings, while avoiding the whitespace and control characters that the BPE code cannot handle.
- Returns:
Mapping from each UTF-8 byte to a corresponding unicode string.
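For illustration, a short sketch of how the returned table can be used to turn raw UTF-8 bytes into the unicode characters that the BPE vocabulary is built on; the round-trip below is an assumption about typical usage rather than a documented recipe:

from pretrained.clip import bytes_to_unicode

byte_encoder = bytes_to_unicode()  # dict[int, str]
byte_decoder = {v: k for k, v in byte_encoder.items()}

text = "hello, world"
encoded = "".join(byte_encoder[b] for b in text.encode("utf-8"))
decoded = bytes(byte_decoder[ch] for ch in encoded).decode("utf-8")
assert decoded == text  # the mapping is reversible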
- class pretrained.clip.Bottleneck(inplanes: int, planes: int, stride: int = 1, *, device: device | None = None, dtype: dtype | None = None)[source]
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- expansion = 4
- forward(x: Tensor) Tensor [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
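A minimal shape-check sketch, assuming the standard ResNet bottleneck behaviour in which the output channel count is planes * expansion and a stride greater than 1 halves the spatial resolution:

import torch

from pretrained.clip import Bottleneck

block = Bottleneck(inplanes=64, planes=64, stride=2)
x = torch.randn(1, 64, 56, 56)
y = block(x)  # assumed output shape: (1, 256, 28, 28)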
- class pretrained.clip.AttentionPool2d(spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int | None = None, *, device: device | None = None, dtype: dtype | None = None)[source]
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor) Tensor [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
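A hedged sketch, assuming (as in the reference CLIP implementation) that the input is a feature map of shape (N, embed_dim, spacial_dim, spacial_dim) and the output is a pooled vector of shape (N, output_dim):

import torch

from pretrained.clip import AttentionPool2d

pool = AttentionPool2d(spacial_dim=7, embed_dim=2048, num_heads=32, output_dim=1024)
x = torch.randn(2, 2048, 7, 7)
y = pool(x)  # assumed output shape: (2, 1024)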
- class pretrained.clip.ModifiedResNet(layers: tuple[int, int, int, int], output_dim: int, heads: int, input_resolution: int = 224, width: int = 64, *, device: device | None = None, dtype: dtype | None = None)[source]
Bases: Module
ResNet class that is similar to TorchVision’s, but with the following changes:
- There are now 3 “stem” convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1.
- The final pooling layer is a QKV attention instead of an average pool.
- Parameters:
layers – Layer counts for the four parts of the ResNet
output_dim – Number of final output dimensions
heads – Number of attention heads
input_resolution – Number of pixels in width and height directions
width – Hidden channel count
device – Default PyTorch device to use
dtype – Default PyTorch dtype to use
- forward(x: Tensor) Tensor [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
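A minimal usage sketch with randomly initialized weights. The layer counts, width, and head count below follow the standard CLIP ResNet-50 configuration, which is an assumption rather than something stated here:

import torch

from pretrained.clip import ModifiedResNet

backbone = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32, input_resolution=224, width=64)
images = torch.randn(2, 3, 224, 224)
features = backbone(images)  # assumed output shape: (2, 1024)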
- class pretrained.clip.QuickGELU(*args, **kwargs)[source]
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor) Tensor [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
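QuickGELU is the fast GELU approximation used by CLIP, which computes x * sigmoid(1.702 * x); the sketch below assumes that formula:

import torch

from pretrained.clip import QuickGELU

act = QuickGELU()
x = torch.randn(4)
y = act(x)
reference = x * torch.sigmoid(1.702 * x)  # assumed QuickGELU formula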
- class pretrained.clip.ResidualAttentionBlock(d_model: int, n_head: int, attn_mask: Tensor | None = None, *, device: device | None = None, dtype: dtype | None = None)[source]
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor) Tensor [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
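A hedged sketch, assuming the block follows the reference CLIP layout of sequence-first inputs with shape (seq_len, batch, d_model), which the block returns unchanged in shape:

import torch

from pretrained.clip import ResidualAttentionBlock

block = ResidualAttentionBlock(d_model=512, n_head=8)
x = torch.randn(77, 2, 512)  # assumed (seq_len, batch, d_model) layout
y = block(x)  # same shape as the input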
- class pretrained.clip.Transformer(width: int, layers: int, heads: int, attn_mask: Tensor | None = None, *, device: device | None = None, dtype: dtype | None = None)[source]
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor) Tensor [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
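Similarly, a hedged sketch of the full Transformer stack, again assuming sequence-first inputs of shape (seq_len, batch, width):

import torch

from pretrained.clip import Transformer

encoder = Transformer(width=512, layers=12, heads=8)
x = torch.randn(77, 2, 512)  # assumed (seq_len, batch, width) layout
y = encoder(x)  # same shape as the input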
- class pretrained.clip.VisionTransformer(input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, *, device: device | None = None, dtype: dtype | None = None)[source]
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor) Tensor [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
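A minimal usage sketch with randomly initialized weights. The hyperparameters below mirror the standard ViT-B/32 CLIP configuration, which is an assumption rather than something documented here:

import torch

from pretrained.clip import VisionTransformer

vit = VisionTransformer(input_resolution=224, patch_size=32, width=768, layers=12, heads=12, output_dim=512)
images = torch.randn(2, 3, 224, 224)
embeddings = vit(images)  # assumed output shape: (2, 512)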
- class pretrained.clip.TextModel(embed_dim: int, context_length: int, vocab_size: int, transformer_width: int, transformer_heads: int, transformer_layers: int, *, device: device | None = None, dtype: dtype | None = None)[source]
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- get_tokenizer() ClipTokenizer [source]
- forward(text: Tensor) Tensor [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
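A hedged sketch of the text branch using random token ids in place of real tokenizer output; the (batch, context_length) input layout and (batch, embed_dim) output shape are assumptions based on the reference CLIP text encoder:

import torch

from pretrained.clip import TextModel

text_model = TextModel(embed_dim=512, context_length=77, vocab_size=49408, transformer_width=512, transformer_heads=8, transformer_layers=12)
tokens = torch.randint(0, 49408, (2, 77))  # placeholder token ids, assumed (batch, context_length)
embeddings = text_model(tokens)  # assumed output shape: (2, 512)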
- class pretrained.clip.Clip(embed_dim: int, image_resolution: int, vision_layers: tuple[int, int, int, int] | int, vision_width: int, vision_patch_size: int, context_length: int, vocab_size: int, transformer_width: int, transformer_heads: int, transformer_layers: int, *, device: device | None = None, dtype: dtype | None = None)[source]
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(image: Tensor, text: Tensor) tuple[torch.Tensor, torch.Tensor] [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- predictor(*, device: base_device | None = None) ClipPredictor [source]
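A hedged sketch of the combined model’s forward pass with a randomly initialized ViT-B/32-style configuration (the hyperparameters are assumptions). In the reference CLIP implementation the returned tuple holds the image-to-text and text-to-image logits, but treat that interpretation as an assumption here as well:

import torch

from pretrained.clip import Clip

clip_model = Clip(embed_dim=512, image_resolution=224, vision_layers=12, vision_width=768, vision_patch_size=32, context_length=77, vocab_size=49408, transformer_width=512, transformer_heads=8, transformer_layers=12)
images = torch.randn(2, 3, 224, 224)
tokens = torch.randint(0, 49408, (2, 77))  # placeholder token ids
image_out, text_out = clip_model(images, tokens)  # tuple of two tensors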
- class pretrained.clip.ClipPredictor(clip_model: Clip, *, device: base_device | None = None)[source]
Bases: object
Provides an API for doing predictions with a CLIP model.
Note that this class is not an nn.Module, so you can keep it inside your own modules without accidentally registering all of the CLIP weights.
- Parameters:
clip_model – The CLIP model to use for predictions
device – The device to use for predictions. If None, will use the device returned by detect_device().
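A minimal construction sketch; the prediction methods themselves are not listed here, so only the two documented ways of obtaining a predictor are shown:

from pretrained.clip import ClipPredictor, pretrained_clip

clip_model = pretrained_clip("RN50", mode="all")
predictor = ClipPredictor(clip_model)  # direct construction
predictor = clip_model.predictor()  # or via the convenience method on Clip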
- pretrained.clip.convert_weights(model: Module) None [source]
Convert applicable model parameters to fp16.
- Parameters:
model – The model to convert
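A short sketch of halving the weights after loading; whether fp16 is appropriate depends on your hardware, so treat this as illustrative:

from pretrained.clip import convert_weights, pretrained_clip

model = pretrained_clip("RN50", mode="all")
convert_weights(model)  # casts the applicable parameters to fp16 in place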
- pretrained.clip.pretrained_clip(key: Literal['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT_B_32', 'ViT_B_16', 'ViT_L_14', 'ViT_L_14_336px'] | Module, mode: Literal['visual'], *, device: device | None = None, dtype: dtype | None = None) ModifiedResNet | VisionTransformer [source]
- pretrained.clip.pretrained_clip(key: Literal['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT_B_32', 'ViT_B_16', 'ViT_L_14', 'ViT_L_14_336px'] | Module, mode: Literal['linguistic'], *, device: device | None = None, dtype: dtype | None = None) TextModel
- pretrained.clip.pretrained_clip(key: Literal['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT_B_32', 'ViT_B_16', 'ViT_L_14', 'ViT_L_14_336px'] | Module, mode: Literal['all'], *, device: device | None = None, dtype: dtype | None = None) Clip
Builds the CLIP model from a state dictionary.
- Parameters:
key – The model key to load, or another model to load weights from
mode – Defaults to returning the full model, but can optionally return just the visual or linguistic branch
device – The device for the model
dtype – The dtype for the model
- Returns:
The constructed CLIP model, or just its visual or text branch
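Since the key argument also accepts an existing module, an already-loaded model can be reused to build a single branch without reloading the weights; the snippet below is a sketch of that documented option:

from pretrained.clip import pretrained_clip

full_model = pretrained_clip("ViT_B_32", mode="all")
visual_only = pretrained_clip(full_model, mode="visual")  # reuses the loaded weights
linguistic_only = pretrained_clip(full_model, mode="linguistic")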