"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
"""
import torch
from segment_anything import sam_model_registry
from torch import nn
import torch.nn.functional as F
from pathlib import Path
# Module-level side effect on import: allow TF32 for CUDA float32 matmuls,
# trading some matmul precision for speed on GPUs that support TF32.
torch.backends.cuda.matmul.allow_tf32 = True
import logging
vit_logger = logging.getLogger(__name__)
# dinov3 is an optional dependency that only CPDINO needs. Import failures are
# reported but must not prevent the rest of this module (e.g. CPSAM) from loading.
# `except Exception` (not a bare except) still lets KeyboardInterrupt/SystemExit through.
try:
    from dinov3.hub.backbones import dinov3_vitl16, dinov3_vitb16
except Exception as err:
    vit_logger.warning("Could not import CPDINO, run `pip install git+https://github.com/facebookresearch/dinov3` to use CPDINO model")
    vit_logger.debug("dinov3 import failed: %r", err)
class BaseModel(nn.Module):
    """Common base class for the Cellpose ViT models in this module.

    Stores the diameter statistics shared by every model, records the
    requested model dtype, and provides checkpoint save/load helpers.

    Args:
        dtype (torch.dtype, optional): Data type the model should use.
            Defaults to ``torch.float32``. This base class only records it;
            the parameters it creates are float32 until the model is converted
            (e.g. by ``load_model`` or the ``dtype`` setter).
    """

    def __init__(self, dtype=torch.float32):
        super().__init__()
        self._dtype = dtype
        # average diameter of ROIs from training images from fine-tuning
        self.diam_labels = nn.Parameter(torch.tensor([30.]), requires_grad=False)
        # average diameter of ROIs during main training
        self.diam_mean = nn.Parameter(torch.tensor([30.]), requires_grad=False)

    def load_model(self, PATH, device, strict=False):
        """Load weights from a Cellpose 4 (CP4) checkpoint file.

        Keys prefixed with ``module.`` (as written by DataParallel /
        DistributedDataParallel wrappers) are unwrapped *before* the checkpoint
        is validated, so wrapped checkpoints load like plain ones.

        Args:
            PATH (str or Path): File written with ``torch.save`` that holds a
                state dict.
            device (torch.device or str): ``map_location`` used while loading.
            strict (bool, optional): Forwarded to ``load_state_dict``.
                Defaults to False.

        Raises:
            ValueError: If the checkpoint has no ``W2`` entry, i.e. it does not
                look like a CP4 model (CP3 checkpoints are incompatible).
        """
        state_dict = torch.load(PATH, map_location=device, weights_only=True)
        prefix = "module."  # added by DataParallel/DistributedDataParallel
        if any(key.startswith(prefix) for key in state_dict):
            from collections import OrderedDict
            state_dict = OrderedDict(
                (key[len(prefix):] if key.startswith(prefix) else key, value)
                for key, value in state_dict.items()
            )
        # loudly fail on attempt to load a non-CP4 model; checked after
        # unwrapping so that a wrapped 'module.W2' is recognised.
        if state_dict.get("W2") is None:
            raise ValueError('This model does not appear to be a CP4 model. CP3 models are not compatible with CP4.')
        # models are always saved as float32: load, then convert if requested
        self.load_state_dict(state_dict, strict=strict)
        if self.dtype != torch.float32:
            # nn.Module.to converts in place; there is no need to rebind `self`.
            self.to(self.dtype)

    @property
    def dtype(self):
        """
        Get the data type of the model.

        Returns:
            torch.dtype: The data type recorded for the model.
        """
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        """
        Set the data type of the model, converting its parameters if it changes.

        Args:
            value (torch.dtype): The data type to set for the model.
        """
        if self._dtype != value:
            self.to(value)
            self._dtype = value

    @property
    def device(self):
        """
        Get the device of the model.

        Returns:
            torch.device: Device of the model's first parameter.
        """
        return next(self.parameters()).device

    def save_model(self, filename):
        """
        Save the model's state dict to a file.

        Args:
            filename (str): The path to the file where the model will be saved.
        """
        torch.save(self.state_dict(), filename)
class CPSAM(BaseModel):
    """Cellpose-SAM model built on the SAM ViT-L image encoder.

    The SAM encoder is adapted for dense prediction. Its patch embedding is
    reduced to ``ps x ps`` tokens, its position embeddings are subsampled for a
    ``bsize``-pixel input tile, and every block uses global attention
    (``window_size = 0``). A 1x1 convolution reads out ``nout * ps**2`` values
    per token, and a fixed transposed convolution (``W2``) scatters them back to
    a ``ps x ps`` pixel patch with ``nout`` channels.

    Args:
        ps (int, optional): Token size in pixels; expected to divide 16 (the
            pretrained 16x16 kernel is subsampled by ``16 // ps``). Defaults to 8.
        nout (int, optional): Number of output channels. Defaults to 3.
        bsize (int, optional): Input tile size in pixels that the position
            embeddings are sized for. Defaults to 256.
        rdrop (float, optional): During training, block ``i`` is skipped for a
            sample with probability ``linspace(0, rdrop, n_blocks)[i]``.
            Defaults to 0.4.
        dtype (torch.dtype, optional): Model dtype. Defaults to torch.float32.
    """

    def __init__(self, ps=8, nout=3, bsize=256, rdrop=0.4,
                 dtype=torch.float32):
        super().__init__(dtype=dtype)
        self.rdrop = rdrop
        self.backbone = "sam_vitl"
        self.nout = nout
        self.ps = ps
        # kept so that load_pretrained can re-apply the same adaptation
        self.bsize = bsize
        encoder = sam_model_registry["vit_l"](None).image_encoder.to(self._dtype)
        self.encoder = self._adapt_encoder(encoder)
        feat_dim = self.encoder.neck[-2].weight.shape[0]
        # readout weights for nout output channels
        self.out = nn.Conv2d(feat_dim, nout * ps**2, kernel_size=1, dtype=self._dtype)
        # W2 reshapes token space to pixel space, not trainable
        self.W2 = nn.Parameter(
            torch.eye(nout * ps**2, dtype=self._dtype).reshape(nout * ps**2, nout, ps, ps),
            requires_grad=False)

    def _adapt_encoder(self, encoder):
        """Adapt a stock SAM image encoder to ``self.ps`` / ``self.bsize`` in place and return it."""
        ps = self.ps
        w = encoder.patch_embed.proj.weight.detach()
        nchan = w.shape[0]
        # change token size to ps x ps (the bias is freshly initialized here)
        encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps, dtype=w.dtype)
        encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps]
        # adjust position embeddings for new bsize and new token size
        ds = (1024 // 16) // (self.bsize // ps)
        encoder.pos_embed = nn.Parameter(encoder.pos_embed[:, ::ds, ::ds], requires_grad=True)
        # set attention to global in every layer
        for blk in encoder.blocks:
            blk.window_size = 0
        return encoder

    def forward(self, x):
        """Run the model on a batch of images.

        Args:
            x (torch.Tensor): Images of shape (B, C, H, W) with C <= 3; only the
                first C input channels of the patch embedding are used. The
                resulting token grid must match ``pos_embed`` (presumably
                H = W = bsize; confirm against callers).

        Returns:
            tuple: ``(y, style)``, where ``y`` has shape
            (B, nout, h * ps, w * ps) and ``style`` is a (B, 256) zero tensor
            kept only for backwards compatibility.
        """
        proj = self.encoder.patch_embed.proj
        # NOTE(review): `.data` detaches the patch-embedding weights from
        # autograd, so they receive no gradients during training; confirm intended.
        x = F.conv2d(x, proj.weight.data[:, :x.shape[1]],
                     bias=proj.bias.data, stride=self.ps)
        x = x.permute(0, 2, 3, 1)  # (B, h, w, C) layout used by the SAM blocks
        if self.encoder.pos_embed is not None:
            x = x + self.encoder.pos_embed
        if self.training and self.rdrop > 0:
            nlay = len(self.encoder.blocks)
            # skip[b, i] == 1 -> block i is bypassed for sample b
            skip = (torch.rand((len(x), nlay), device=x.device) <
                    torch.linspace(0, self.rdrop, nlay, device=x.device)).to(x.dtype)
            for i, blk in enumerate(self.encoder.blocks):
                mask = skip[:, i, None, None, None]
                x = x * mask + blk(x) * (1 - mask)
        else:
            for blk in self.encoder.blocks:
                x = blk(x)
        x = self.encoder.neck(x.permute(0, 3, 1, 2))
        x1 = self.out(x)
        x1 = F.conv_transpose2d(x1, self.W2, stride=self.ps, padding=0)
        # maintain the second output of feature size 256 for backwards compatibility
        return x1, torch.zeros((x.shape[0], 256), device=x.device)

    def load_pretrained(self, root=None):
        """Replace the encoder with pretrained SAM ViT-L weights.

        The freshly built encoder gets the same adaptation as in ``__init__``
        (token size, position embeddings, global attention), so it stays
        compatible with ``forward``. It is then moved to the model's current
        device and dtype.

        Args:
            root (str or Path, optional): Directory that contains
                ``sam_vit_l_0b3195.pth``. Defaults to ``models/`` next to this file.
        """
        root = Path(__file__).parent / "models" if root is None else Path(root)
        device, dtype = self.device, self.dtype
        encoder = sam_model_registry["vit_l"](root / "sam_vit_l_0b3195.pth").image_encoder
        self.encoder = self._adapt_encoder(encoder).to(device, dtype=dtype)
class CPDINO(BaseModel):
    """Cellpose model built on a DINOv3 ViT backbone (ViT-L/16 or ViT-B/16).

    The patch-embedding stride is reduced to ``ps`` (with ``ps // 2`` padding
    when ``ps < 16``) to increase token resolution. A linear readout produces
    ``nout * ps**2`` values per token, and a fixed transposed convolution
    (``W2``) scatters them back to pixels.

    Args:
        model_name (str, optional): ``'vitl'`` selects ViT-L/16; any other
            value selects ViT-B/16. Defaults to ``'vitl'``.
        ps (int, optional): Patch stride in pixels. Defaults to 8.
        nout (int, optional): Number of output channels. Defaults to 3.
        bsize (int, optional): Accepted for signature parity with CPSAM; not
            used by this class. Defaults to 256.
        rdrop (float, optional): During training, block ``i`` is skipped for a
            sample with probability ``linspace(0, rdrop, n_blocks)[i]``.
            Defaults to 0.4.
        dtype (torch.dtype, optional): Model dtype. Defaults to torch.float32.

    Raises:
        ImportError: If the optional ``dinov3`` package could not be imported.
    """

    def __init__(self, model_name="vitl", ps=8, nout=3, bsize=256, rdrop=0.4,
                 dtype=torch.float32):
        super().__init__(dtype=dtype)
        self.rdrop = rdrop
        self.backbone = "dino_" + model_name
        self.model_name = model_name
        try:
            vit_model = dinov3_vitl16 if self.model_name == 'vitl' else dinov3_vitb16
        except NameError as err:
            raise ImportError(
                "CPDINO requires dinov3: run `pip install "
                "git+https://github.com/facebookresearch/dinov3`") from err
        self.encoder = vit_model(pretrained=False).to(self._dtype)
        # decrease stride and add padding to increase resolution
        self.encoder.patch_embed.proj.stride = (ps, ps)
        self.encoder.patch_embed.proj.padding = (ps // 2, ps // 2) if ps < 16 else (0, 0)
        feat_dim = self.encoder.patch_embed.proj.weight.shape[0]
        # readout weights for nout output channels
        self.out = nn.Linear(feat_dim, nout * ps**2, dtype=self._dtype)
        # W2 reshapes token space to pixel space, not trainable
        self.W2 = nn.Parameter(
            torch.eye(nout * ps**2, dtype=self._dtype).reshape(nout * ps**2, nout, ps, ps),
            requires_grad=False)
        self.nout = nout
        self.ps = ps

    def load_pretrained(self, root=None):
        """Load pretrained DINOv3 weights into the encoder.

        The stride/padding changes made in ``__init__`` are module attributes,
        not state-dict entries, so they are preserved by this load.

        Args:
            root (str or Path, optional): Directory that contains the DINOv3
                checkpoint. Defaults to ``models/`` next to this file.
        """
        root = Path(__file__).parent / 'models' if root is None else Path(root)
        if self.model_name == 'vitl':
            fname = 'dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth'
        else:
            fname = 'dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth'
        state_dict = torch.load(root / fname, map_location=self.device, weights_only=True)
        self.encoder.load_state_dict(state_dict, strict=True)

    def forward(self, x):
        """Run the model on a batch of images.

        Args:
            x (torch.Tensor): Images of shape (B, C, H, W) with C no larger than
                the patch embedding's input channels; only the first C are used.

        Returns:
            tuple: ``(y, style)``, where ``y`` has shape
            (B, nout, h * ps, w * ps) and ``style`` is a (B, 256) zero tensor
            kept only for backwards compatibility.
        """
        proj = self.encoder.patch_embed.proj
        # NOTE(review): `.data` detaches the patch-embedding weights from
        # autograd, so they receive no gradients during training; confirm intended.
        x = F.conv2d(x, proj.weight.data[:, :x.shape[1]],
                     bias=proj.bias.data,
                     stride=proj.stride,
                     padding=proj.padding)
        h, w = x.shape[-2], x.shape[-1]
        x = x.flatten(2).transpose(1, 2)  # (B, h*w, C) token sequence
        x = self.encoder.patch_embed.norm(x)
        B = x.shape[0]
        cls_tok = self.encoder.cls_token.expand(B, -1, -1)
        storage_tok = self.encoder.storage_tokens.expand(B, -1, -1)
        # number of non-patch tokens prepended; they are removed after the blocks
        n_prefix = cls_tok.shape[1] + storage_tok.shape[1]
        x = torch.cat([cls_tok, storage_tok, x], dim=1)
        # rope_embed is recomputed for every block on purpose: it may depend on
        # training mode (possibly stochastic); do not hoist without checking DINOv3.
        if self.training and self.rdrop > 0:
            nlay = len(self.encoder.blocks)
            # skip[b, i] == 1 -> block i is bypassed for sample b
            skip = (torch.rand((len(x), nlay), device=x.device) <
                    torch.linspace(0, self.rdrop, nlay, device=x.device)).to(x.dtype)
            for i, blk in enumerate(self.encoder.blocks):
                mask = skip[:, i, None, None]
                rope_sincos = self.encoder.rope_embed(H=h, W=w)
                x = x * mask + blk(x, rope_sincos) * (1 - mask)
        else:
            for blk in self.encoder.blocks:
                rope_sincos = self.encoder.rope_embed(H=h, W=w)
                x = blk(x, rope_sincos)
        x = self.encoder.norm(x)
        x = x[:, n_prefix:]  # remove cls and storage tokens
        x1 = self.out(x)
        x1 = x1.reshape((x.shape[0], h, w, self.nout * self.ps**2))
        x1 = x1.permute(0, 3, 1, 2)
        x1 = F.conv_transpose2d(x1, self.W2, stride=self.ps, padding=0)
        # maintain the second output of feature size 256 for backwards
        # compatibility; zeros (not random noise) keep the output deterministic
        return x1, torch.zeros((x.shape[0], 256), device=x.device, dtype=self._dtype)
class CPnetBioImageIO(CPSAM):
    """
    A subclass of the CP-SAM model compatible with the BioImage.IO Spec.

    BioImage.IO expects ``forward`` to return a flat tuple of tensors and
    ``load_state_dict`` to be callable with just a state dict. This subclass
    adapts CPSAM to both, so that weights uploaded to the BioImage.IO Model Zoo
    can be used.
    """

    def forward(self, x):
        """
        Perform a forward pass of the CP-SAM model and return unpacked tensors.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            tuple: The output tensor and style tensor, followed by the unpacked
            downsampled tensors when the parent ``forward`` provides them
            (CPSAM itself returns only the first two).
        """
        output_tensor, style_tensor, *rest = super().forward(x)
        downsampled_tensors = rest[0] if rest else ()
        return (output_tensor, style_tensor, *downsampled_tensors)

    def load_model(self, filename, device=None):
        """
        Load the model from a file.

        Args:
            filename (str): The path to the file where the model is saved.
            device (torch.device or str, optional): The device to load the model
                on. If None or a CPU device, the model is first re-initialized
                on CPU, keeping its ps/nout/rdrop (and bsize if recorded).
                Defaults to None.
        """
        if device is not None:
            device = torch.device(device)
        if device is not None and device.type != "cpu":
            state_dict = torch.load(filename, map_location=device, weights_only=True)
        else:
            # Re-initialize by keyword: CPSAM's first positional parameter is
            # `ps`, so passing nout positionally would change the patch size.
            self.__init__(ps=self.ps, nout=self.nout,
                          bsize=getattr(self, "bsize", 256), rdrop=self.rdrop)
            state_dict = torch.load(filename, map_location=torch.device("cpu"),
                                    weights_only=True)
        self.load_state_dict(state_dict)

    def load_state_dict(self, state_dict, strict=False):
        """
        Load the state dictionary into the model.

        This method overrides the default `load_state_dict` to handle Cellpose's
        custom loading mechanism and to stay compatible with BioImage.IO Core.
        If the checkpoint's readout (``out.weight``) was trained for a different
        number of outputs than this model, this model's readout (``out.*``) and
        ``W2``, whose shapes both depend on ``nout``, are kept, and only the
        remaining weights are loaded.

        Args:
            state_dict (Mapping[str, Any]): A state dictionary to load into the model.
            strict (bool, optional): Forwarded to ``nn.Module.load_state_dict``
                when readout shapes match. Defaults to False.

        Returns:
            NamedTuple: ``missing_keys`` / ``unexpected_keys`` as reported by
            ``nn.Module.load_state_dict``.
        """
        readout = state_dict.get("out.weight")
        if readout is not None and readout.shape[0] != self.out.weight.shape[0]:
            filtered = {name: value for name, value in state_dict.items()
                        if not (name.startswith("out.") or name == "W2")}
            return super().load_state_dict(filtered, strict=False)
        return super().load_state_dict(state_dict, strict=strict)