"""Source code for cellpose.train: training utilities for Cellpose segmentation networks."""

import time
import os
import numpy as np
from cellpose import io, utils, models, dynamics
from cellpose.transforms import normalize_img, random_rotate_and_resize
from pathlib import Path
import torch
from torch import nn
from tqdm import trange

import logging

train_logger = logging.getLogger(__name__)

def _loss_fn_class(lbl, y, class_weights=None):
    """
    Calculates the loss function between true labels lbl and prediction y.

    Args:
        lbl (numpy.ndarray): True labels (cellprob, flowsY, flowsX).
        y (torch.Tensor): Predicted values (flowsY, flowsX, cellprob).
        
    Returns:
        torch.Tensor: Loss value.

    """

    criterion3 = nn.CrossEntropyLoss(reduction="mean", weight=class_weights)
    loss3 = criterion3(y[:, :-3], lbl[:, 0].long())
    
    return loss3

def _loss_fn_seg(lbl, y, device):
    """
    Calculates the loss function between true labels lbl and prediction y.

    Args:
        lbl (numpy.ndarray): True labels (cellprob, flowsY, flowsX).
        y (torch.Tensor): Predicted values (flowsY, flowsX, cellprob).
        device (torch.device): Device on which the tensors are located.

    Returns:
        torch.Tensor: Loss value.

    """
    criterion = nn.MSELoss(reduction="mean")
    criterion2 = nn.BCEWithLogitsLoss(reduction="mean")
    veci = 5. * lbl[:, -2:]
    loss = criterion(y[:, -3:-1], veci)
    loss /= 2.
    loss2 = criterion2(y[:, -1], (lbl[:, -3] > 0.5).to(y.dtype))
    loss = loss + loss2
    return loss

def _reshape_norm(data, channel_axis=None, normalize_params={"normalize": False}):
    """
    Reshapes and normalizes the input data.

    Args:
        data (list): List of input data, with channels axis first or last.
        normalize_params (dict, optional): Dictionary of normalization parameters. Defaults to {"normalize": False}.

    Returns:
        list: List of reshaped and normalized data.
    """
    if (np.array([td.ndim!=3 for td in data]).sum() > 0 or
        np.array([td.shape[0]!=3 for td in data]).sum() > 0):
        data_new = []
        for td in data:
            if td.ndim == 3:
                channel_axis0 = channel_axis if channel_axis is not None else np.array(td.shape).argmin()
                # put channel axis first 
                td = np.moveaxis(td, channel_axis0, 0)
                td = td[:3] # keep at most 3 channels
            if td.ndim == 2:
                td = np.stack((td, 0*td, 0*td), axis=0)
            elif td.ndim == 3 and td.shape[0] < 3:
                td = np.concatenate((td, np.zeros((3 - td.shape[0], *td.shape[1:]), dtype=td.dtype)), axis=0)
            data_new.append(td)
        data = data_new
    if normalize_params["normalize"]:
        data = [
            normalize_img(td, normalize=normalize_params, axis=0)
            for td in data
        ]
    return data

def _get_batch(inds, data=None, labels=None, files=None, labels_files=None,
               normalize_params={"normalize": False}):
    """
    Get a batch of images and labels.

    Args:
        inds (list): List of indices indicating which images and labels to retrieve.
        data (list or None): List of image data. If None, images will be loaded from files.
        labels (list or None): List of label data. If None, labels will be loaded from files.
        files (list or None): List of file paths for images.
        labels_files (list or None): List of file paths for labels.
        normalize_params (dict): Dictionary of parameters for image normalization (will be faster, if loading from files to pre-normalize).

    Returns:
        tuple: A tuple containing two lists: the batch of images and the batch of labels.
    """
    if data is None:
        lbls = None
        imgs = [io.imread(files[i]) for i in inds]
        imgs = _reshape_norm(imgs, normalize_params=normalize_params)
        if labels_files is not None:
            lbls = [io.imread(labels_files[i])[1:] for i in inds]
    else:
        imgs = [data[i] for i in inds]
        lbls = [labels[i][1:] for i in inds]
    return imgs, lbls

def _reshape_norm_save(files, channels=None, channel_axis=None,
                       normalize_params=None):
    """
    Reshape/normalize image files and save each result as ``<stem>_cpnorm.tif``.

    Not currently used -- normalization happens on each batch when files are not
    loaded into memory.

    Args:
        files (list): Image file paths.
        channels (list or None): When given, each image is passed through
            ``convert_image`` and then ``transpose(2, 0, 1)`` (last axis to front).
        channel_axis (int or None): Forwarded to ``convert_image``.
        normalize_params (dict or None): When ``normalize_params["normalize"]`` is
            truthy, images are normalized with ``normalize_img`` along axis 0.
            None means ``{"normalize": False}``.

    Returns:
        list: Paths of the written files.
    """
    if normalize_params is None:
        normalize_params = {"normalize": False}
    files_new = []
    # trange takes a count, not the list itself (trange(files) raises TypeError)
    for i in trange(len(files)):
        f = files[i]
        td = io.imread(f)
        if channels is not None:
            # NOTE(review): `convert_image` is not imported in this module, so this
            # branch raises NameError as written. TODO import it (presumably from
            # cellpose.transforms) and confirm its signature accepts `channels`.
            td = convert_image(td, channels=channels, channel_axis=channel_axis)
            td = td.transpose(2, 0, 1)
        if normalize_params["normalize"]:
            td = normalize_img(td, normalize=normalize_params, axis=0)
        fnew = os.path.splitext(str(f))[0] + "_cpnorm.tif"
        io.imsave(fnew, td)
        files_new.append(fnew)
    return files_new


def _lengths_match(n, *seqs):
    """Return True if every sequence in ``seqs`` that is not None has exactly ``n`` items."""
    return all(s is None or len(s) == n for s in seqs)


def _require_train_lengths(nimg, *seqs):
    """Log and raise ValueError unless every non-None training label sequence has ``nimg`` items."""
    if not _lengths_match(nimg, *seqs):
        error_message = "train data and labels not same length"
        train_logger.critical(error_message)
        raise ValueError(error_message)


def _process_train_test(train_data=None, train_labels=None, train_files=None,
                        train_labels_files=None, train_probs=None, test_data=None,
                        test_labels=None, test_files=None, test_labels_files=None,
                        test_probs=None, load_files=True, min_train_masks=5,
                        compute_flows=False, normalize_params=None,
                        channel_axis=None, device=None):
    """
    Process train and test data.

    Steps: resolve label files and (optionally) load data from files, validate that
    images and labels have matching lengths, convert labels to flows, compute
    per-image diameters, drop training images with fewer than ``min_train_masks``
    masks, normalize the sampling probabilities, and reshape/normalize in-memory
    images with ``_reshape_norm``.

    Args:
        train_data (list or None): List of training data arrays.
        train_labels (list or None): List of training label arrays.
        train_files (list or None): List of training file paths.
        train_labels_files (list or None): List of training label file paths. When
            None (and data come from files), ``<stem>_flows.tif`` files that exist
            next to ``train_files`` are used.
        train_probs (array-like or None): Sampling weights for training images
            (copied and normalized to sum to 1; uniform when None).
        test_data (list or None): List of test data arrays.
        test_labels (list or None): List of test label arrays.
        test_files (list or None): List of test file paths.
        test_labels_files (list or None): List of test label file paths.
        test_probs (array-like or None): Sampling weights for test images.
        load_files (bool): Whether to load data from files into memory.
        min_train_masks (int): Minimum number of masks required for training images.
        compute_flows (bool): Whether to compute flows from label files when labels
            are not in memory.
        normalize_params (dict or None): Normalization parameters; must contain a
            "normalize" key. None means ``{"normalize": False}``.
        channel_axis (int or None): Axis of channel dimension, forwarded to ``_reshape_norm``.
        device (torch.device or None): Device used for flow computation. When None,
            CUDA or MPS is picked if available; otherwise None is passed on.

    Returns:
        tuple: ``(train_data, train_labels, train_files, train_labels_files,
        train_probs, diam_train, test_data, test_labels, test_files,
        test_labels_files, test_probs, diam_test, normed)``; ``normed`` is True
        when in-memory ``train_data`` was reshaped/normalized here.

    Raises:
        ValueError: if training images and labels differ in number, training data or
            labels are not at least 2D, training data is more than 3D, or no
            training image has at least ``min_train_masks`` masks.
    """
    if normalize_params is None:
        normalize_params = {"normalize": False}
    if device is None:
        if torch.cuda.is_available():
            device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            device = torch.device('mps')
        # otherwise device stays None (same as before)

    # float64 copies: never modify the caller's arrays in place, and accept lists
    if train_probs is not None:
        train_probs = np.array(train_probs, dtype="float64")
    if test_probs is not None:
        test_probs = np.array(test_probs, dtype="float64")

    if train_data is not None and train_labels is not None:
        # if data is loaded
        nimg = len(train_data)
        nimg_test = len(test_data) if test_data is not None else None
    else:
        # otherwise use files
        nimg = len(train_files)
        if train_labels_files is None:
            train_labels_files = [
                os.path.splitext(str(tf))[0] + "_flows.tif" for tf in train_files
            ]
            train_labels_files = [tf for tf in train_labels_files if os.path.exists(tf)]
        # only test_files can be used to derive label file names
        if test_files is not None and test_labels_files is None:
            test_labels_files = [
                os.path.splitext(str(tf))[0] + "_flows.tif" for tf in test_files
            ]
            test_labels_files = [tf for tf in test_labels_files if os.path.exists(tf)]
        nimg_test = len(test_files) if test_files is not None else None

        # validate label-file lists before reading anything: a missing label file
        # would otherwise surface as an IndexError while loading
        _require_train_lengths(nimg, train_labels_files)
        if not _lengths_match(nimg_test, test_labels_files):
            train_logger.warning("test data and labels not same length, not using")
            test_data = test_labels = test_files = test_labels_files = None

        if not load_files:
            train_logger.info(">>> using files instead of loading dataset")
        else:
            # load all images
            train_logger.info(">>> loading images and labels")
            train_data = [io.imread(train_files[i]) for i in trange(nimg)]
            train_labels = [io.imread(train_labels_files[i]) for i in trange(nimg)]
            if test_files is not None and nimg_test:
                test_data = [io.imread(test_files[i]) for i in trange(nimg_test)]
                test_labels = [io.imread(test_labels_files[i]) for i in trange(nimg_test)]

    ### check that arrays are correct size
    _require_train_lengths(nimg, train_labels, train_labels_files)
    if not _lengths_match(nimg_test, test_labels, test_labels_files):
        train_logger.warning("test data and labels not same length, not using")
        # drop the labels too, so no later step (flows, diameters) uses them
        test_data = test_labels = test_files = test_labels_files = None
    if train_labels is not None:
        if train_labels[0].ndim < 2 or train_data[0].ndim < 2:
            error_message = "training data or labels are not at least two-dimensional"
            train_logger.critical(error_message)
            raise ValueError(error_message)
        if train_data[0].ndim > 3:
            error_message = "training data is more than three-dimensional (should be 2D or 3D array)"
            train_logger.critical(error_message)
            raise ValueError(error_message)

    ### check that flows are computed
    if train_labels is not None:
        train_labels = dynamics.labels_to_flows(train_labels, files=train_files,
                                                device=device)
        if test_labels is not None:
            test_labels = dynamics.labels_to_flows(test_labels, files=test_files,
                                                   device=device)
    elif compute_flows:
        # one image per call: labels_to_flows takes a list of label arrays with a
        # parallel list of files (as in the in-memory call above). Previously the
        # whole file list was passed to io.imread.
        # NOTE(review): train_labels_files still point at the original label files
        # afterwards and the diameter step reads channel 0 from them -- confirm
        # they are flow files, or redirect them to the flows written here.
        for k in trange(nimg):
            dynamics.labels_to_flows([io.imread(train_labels_files[k])],
                                     files=[train_files[k]], device=device)
        if test_files is not None and test_labels_files is not None:
            for k in trange(nimg_test):
                dynamics.labels_to_flows([io.imread(test_labels_files[k])],
                                         files=[test_files[k]], device=device)

    ### compute diameters
    nmasks = np.zeros(nimg)
    diam_train = np.zeros(nimg)
    train_logger.info(">>> computing diameters")
    for k in trange(nimg):
        tl = (train_labels[k][0]
              if train_labels is not None else io.imread(train_labels_files[k])[0])
        diam_train[k], dall = utils.diameters(tl)
        nmasks[k] = len(dall)
    diam_train[diam_train < 5] = 5.
    if test_data is not None:
        diam_test = np.array(
            [utils.diameters(test_labels[k][0])[0] for k in trange(len(test_labels))])
        diam_test[diam_test < 5] = 5.
    elif test_labels_files is not None:
        diam_test = np.array([
            utils.diameters(io.imread(test_labels_files[k])[0])[0]
            for k in trange(len(test_labels_files))
        ])
        diam_test[diam_test < 5] = 5.
    else:
        diam_test = None

    ### check to remove training images with too few masks
    if min_train_masks > 0:
        nremove = (nmasks < min_train_masks).sum()
        if nremove > 0:
            train_logger.warning(
                f"{nremove} train images with number of masks less than min_train_masks ({min_train_masks}), removing from train set"
            )
            ikeep = np.nonzero(nmasks >= min_train_masks)[0]
            if len(ikeep) == 0:
                # nothing left to train on; fail clearly instead of dividing by zero below
                error_message = f"no train images have at least min_train_masks ({min_train_masks}) masks"
                train_logger.critical(error_message)
                raise ValueError(error_message)
            if train_data is not None:
                train_data = [train_data[i] for i in ikeep]
            if train_labels is not None:
                train_labels = [train_labels[i] for i in ikeep]
            if train_files is not None:
                train_files = [train_files[i] for i in ikeep]
            if train_labels_files is not None:
                train_labels_files = [train_labels_files[i] for i in ikeep]
            if train_probs is not None:
                train_probs = train_probs[ikeep]
            diam_train = diam_train[ikeep]
            # train_data is None when training from files without loading
            nimg = len(ikeep)

    ### normalize probabilities
    if train_probs is None:
        train_probs = 1. / nimg * np.ones(nimg, "float64")
    train_probs /= train_probs.sum()
    if test_files is not None or test_data is not None:
        if test_probs is None:
            test_probs = 1. / nimg_test * np.ones(nimg_test, "float64")
        test_probs /= test_probs.sum()

    ### reshape and normalize train / test data
    normed = False
    if normalize_params["normalize"]:
        train_logger.info(f">>> normalizing {normalize_params}")
    if train_data is not None:
        train_data = _reshape_norm(train_data, channel_axis=channel_axis,
                                   normalize_params=normalize_params)
        normed = True
    if test_data is not None:
        test_data = _reshape_norm(test_data, channel_axis=channel_axis,
                                  normalize_params=normalize_params)

    return (train_data, train_labels, train_files, train_labels_files, train_probs,
            diam_train, test_data, test_labels, test_files, test_labels_files,
            test_probs, diam_test, normed)


def _learning_rate_schedule(learning_rate, n_epochs):
    """Per-epoch learning rates: 10-epoch linear warmup, a constant plateau, then 10 halvings for runs of 100+ epochs."""
    LR = np.linspace(0, learning_rate, 10)
    LR = np.append(LR, learning_rate * np.ones(max(0, n_epochs - 10)))
    if n_epochs > 300:
        LR = LR[:-100]
        for _ in range(10):
            LR = np.append(LR, LR[-1] / 2 * np.ones(10))
    elif n_epochs > 99:
        LR = LR[:-50]
        for _ in range(10):
            LR = np.append(LR, LR[-1] / 2 * np.ones(5))
    return LR


def _forward_loss(net, X, lbl, device, class_weights=None):
    """Run ``net`` on ``X`` under autocast and return the segmentation loss (+ class loss when present)."""
    with torch.autocast(device_type=device.type, dtype=net.dtype):
        y = net(X)[0]
        loss = _loss_fn_seg(lbl, y, device)
        # extra output channels beyond flows + cellprob are class logits
        if y.shape[1] > 3:
            loss = loss + _loss_fn_class(lbl, y, class_weights=class_weights)
    return loss


def train_seg(net, train_data=None, train_labels=None, train_files=None,
              train_labels_files=None, train_probs=None, test_data=None,
              test_labels=None, test_files=None, test_labels_files=None,
              test_probs=None, channel_axis=None, load_files=True, batch_size=1,
              learning_rate=1e-5, SGD=False, n_epochs=100, weight_decay=0.1,
              normalize=True, compute_flows=False, save_path=None, save_every=100,
              save_each=False, nimg_per_epoch=None, nimg_test_per_epoch=None,
              rescale=False, scale_range=None, bsize=None, min_train_masks=5,
              model_name=None, class_weights=None):
    """
    Train the network with images for segmentation.

    Args:
        net (object): The network model to train. If `net` is a bfloat16 model it
            will be converted to float32 for training. The saved models will be in
            float32, but the original model will be returned as the original dtype
            for consistency.
        train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images
            for training. Defaults to None.
        train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels
            for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None.
        train_files (List[str], optional): File names for images in train_data (to
            save flows for future runs). Defaults to None.
        train_labels_files (list or None): Training label file paths. Defaults to None.
        train_probs (List[float], optional): Probability of each image being selected
            during training. Defaults to None (uniform).
        test_data (List[np.ndarray], optional): Images for testing. Defaults to None.
        test_labels (List[np.ndarray], optional): Labels for test_data, where
            0=no masks; 1,2,...=mask labels. Defaults to None.
        test_files (List[str], optional): File names for images in test_data.
            Defaults to None.
        test_labels_files (list or None): Test label file paths. Defaults to None.
        test_probs (List[float], optional): Probability of each image being selected
            during testing. Defaults to None (uniform).
        channel_axis (int, optional): Channel axis of the images; None lets the
            reshaping step infer it. Defaults to None.
        load_files (bool, optional): Whether to load images and labels from files.
            Defaults to True.
        batch_size (int, optional): Number of patches to run simultaneously on the
            GPU. Defaults to 1.
        learning_rate (float, optional): Peak learning rate. Defaults to 1e-5.
        SGD (bool, optional): Deprecated in v4.0.1+ - AdamW always used.
        n_epochs (int, optional): Number of passes through the training set.
            Defaults to 100.
        weight_decay (float, optional): Weight decay for AdamW. Defaults to 0.1.
        normalize (bool or dict, optional): Whether to normalize the data, or a dict
            of normalization parameters merged over ``models.normalize_default``.
            Defaults to True.
        compute_flows (bool, optional): Whether to compute flows during training.
            Defaults to False.
        save_path (str, optional): Directory under which ``models/<model_name>`` is
            saved. Defaults to the current working directory.
        save_every (int, optional): Save the network every [save_every] epochs.
            Defaults to 100.
        save_each (bool, optional): Save to a new ``_epoch_XXXX`` filename at every
            save. Defaults to False.
        nimg_per_epoch (int, optional): Number of images to train on per epoch.
            Defaults to None (all training images).
        nimg_test_per_epoch (int, optional): Number of images to test on per
            evaluation. Defaults to None (all test images).
        rescale (bool, optional): Whether to rescale images by their diameter during
            training. Defaults to False.
        scale_range (float, optional): Scale-augmentation range passed to
            ``random_rotate_and_resize``. Defaults to None (0.5).
        bsize (int, optional): Image size to use for training, must be 256 for
            cpsam. Defaults to None (256 for cpsam, else 384).
        min_train_masks (int, optional): Minimum number of masks an image must have
            to be used in the training set. Defaults to 5.
        model_name (str, optional): Name of the network. Defaults to None
            (``cellpose_<timestamp>``).
        class_weights (list, tuple, np.ndarray or torch.Tensor, optional): Per-class
            weights for the class loss; converted to a float32 tensor on the net's
            device. Defaults to None.

    Returns:
        tuple: The path to the saved model weights, training losses, and test losses.

    Raises:
        ValueError: if ``bsize != 256`` for the cpsam backbone, or ``normalize`` is
            neither a bool nor a dict.
    """
    if SGD:
        train_logger.warning("SGD is deprecated, using AdamW instead")

    device = net.device

    if bsize is not None and net.backbone == "sam_vitl" and bsize != 256:
        raise ValueError("bsize != 256 is not supported for cpsam, please set bsize to 256")
    if bsize is None:
        bsize = 256 if net.backbone == "sam_vitl" else 384

    original_net_dtype = net.dtype
    if net.dtype == torch.bfloat16:
        # NOTE: this produces a side effect of returning a network that is not of a guaranteed dtype
        train_logger.info(">>> converting bfloat16 network to float32 for training")
        net.dtype = torch.float32

    scale_range = 0.5 if scale_range is None else scale_range

    if isinstance(normalize, dict):
        normalize_params = {**models.normalize_default, **normalize}
    elif not isinstance(normalize, bool):
        raise ValueError("normalize parameter must be a bool or a dict")
    else:
        # copy: assigning into models.normalize_default would change it for every later call
        normalize_params = {**models.normalize_default, "normalize": normalize}

    out = _process_train_test(train_data=train_data, train_labels=train_labels,
                              train_files=train_files,
                              train_labels_files=train_labels_files,
                              train_probs=train_probs, test_data=test_data,
                              test_labels=test_labels, test_files=test_files,
                              test_labels_files=test_labels_files,
                              test_probs=test_probs, load_files=load_files,
                              min_train_masks=min_train_masks,
                              compute_flows=compute_flows, channel_axis=channel_axis,
                              normalize_params=normalize_params, device=net.device)
    (train_data, train_labels, train_files, train_labels_files, train_probs,
     diam_train, test_data, test_labels, test_files, test_labels_files, test_probs,
     diam_test, normed) = out

    # already normalized, do not normalize during training
    if normed:
        kwargs = {}
    else:
        kwargs = {"normalize_params": normalize_params, "channel_axis": channel_axis}

    net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device)

    if class_weights is not None:
        # as_tensor accepts lists/tuples/arrays/tensors (from_numpy rejects lists/tuples)
        class_weights = torch.as_tensor(class_weights, dtype=torch.float32, device=device)
        train_logger.info(f">>> class_weights={class_weights}")

    nimg = len(train_data) if train_data is not None else len(train_files)
    nimg_test = len(test_data) if test_data is not None else None
    nimg_test = len(test_files) if test_files is not None else nimg_test
    nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
    nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch

    LR = _learning_rate_schedule(learning_rate, n_epochs)

    train_logger.info(f">>> n_epochs={n_epochs}, n_train={nimg}, n_test={nimg_test}, bsize={bsize}, batch_size={batch_size}, nimg_per_epoch={nimg_per_epoch}")
    train_logger.info(
        f">>> AdamW, learning_rate={learning_rate:0.5f}, weight_decay={weight_decay:0.5f}"
    )
    optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate,
                                  weight_decay=weight_decay)

    t0 = time.time()
    model_name = f"cellpose_{t0}" if model_name is None else model_name
    save_path = Path.cwd() if save_path is None else Path(save_path)
    filename = save_path / "models" / model_name
    # parents=True so a not-yet-existing save_path does not fail here
    (save_path / "models").mkdir(parents=True, exist_ok=True)
    train_logger.info(f">>> saving model to {filename}")

    lavg, nsum = 0, 0
    train_losses, test_losses = np.zeros(n_epochs), np.zeros(n_epochs)
    for iepoch in range(n_epochs):
        np.random.seed(iepoch)
        if nimg != nimg_per_epoch:
            # choose random images for epoch with probability train_probs
            rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
                                     p=train_probs)
        else:
            # otherwise use all images
            rperm = np.random.permutation(np.arange(0, nimg))
        for param_group in optimizer.param_groups:
            param_group["lr"] = LR[iepoch]  # set learning rate
        net.train()
        for k in range(0, nimg_per_epoch, batch_size):
            kend = min(k + batch_size, nimg_per_epoch)
            inds = rperm[k:kend]
            imgs, lbls = _get_batch(inds, data=train_data, labels=train_labels,
                                    files=train_files, labels_files=train_labels_files,
                                    **kwargs)
            diams = np.array([diam_train[i] for i in inds])
            rsc = diams / net.diam_mean.item() if rescale else np.ones(
                len(diams), "float32")
            # augmentations
            X, lbl = random_rotate_and_resize(imgs, lbls=lbls, rescale=rsc,
                                              bsize=bsize, scale_range=scale_range,
                                              device=device)[:2]
            loss = _forward_loss(net, X, lbl, device, class_weights=class_weights)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss = loss.item() * len(X)

            # keep track of average training loss across epochs
            lavg += train_loss
            nsum += len(X)
            # per epoch training loss
            train_losses[iepoch] += train_loss
        train_losses[iepoch] /= nimg_per_epoch

        if iepoch == 5 or iepoch % 10 == 0:
            lavgt = 0.
            if test_data is not None or test_files is not None:
                np.random.seed(42)
                if nimg_test != nimg_test_per_epoch:
                    rperm = np.random.choice(np.arange(0, nimg_test),
                                             size=(nimg_test_per_epoch,), p=test_probs)
                else:
                    rperm = np.random.permutation(np.arange(0, nimg_test))
                net.eval()
                with torch.no_grad():
                    for ibatch in range(0, len(rperm), batch_size):
                        inds = rperm[ibatch:ibatch + batch_size]
                        imgs, lbls = _get_batch(inds, data=test_data,
                                                labels=test_labels, files=test_files,
                                                labels_files=test_labels_files,
                                                **kwargs)
                        diams = np.array([diam_test[i] for i in inds])
                        rsc = diams / net.diam_mean.item() if rescale else np.ones(
                            len(diams), "float32")
                        X, lbl = random_rotate_and_resize(
                            imgs, lbls=lbls, rescale=rsc, scale_range=scale_range,
                            bsize=bsize, device=device)[:2]
                        loss = _forward_loss(net, X, lbl, device,
                                             class_weights=class_weights)
                        lavgt += loss.item() * len(X)
                lavgt /= len(rperm)
                test_losses[iepoch] = lavgt
            lavg /= nsum
            train_logger.info(
                f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s"
            )
            lavg, nsum = 0, 0

        if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
            if save_each and iepoch != n_epochs - 1:  # separate files as model progresses
                filename0 = str(filename) + f"_epoch_{iepoch:04d}"
            else:
                filename0 = filename
            train_logger.info(f"saving network parameters to {filename0}")
            net.save_model(filename0)

    net.save_model(filename)

    if original_net_dtype != torch.float32:
        train_logger.info(f">>> converting network back to {original_net_dtype} after training")
        net.dtype = original_net_dtype

    return filename, train_losses, test_losses