# Source code for cellpose.transforms

```"""
Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
"""

import numpy as np
import warnings
import cv2
import torch

import logging
transforms_logger = logging.getLogger(__name__)

from . import dynamics, utils

bsize = max(224, max(ly, lx))
xm = np.arange(bsize)
xm = np.abs(xm - xm.mean())
mask = 1/(1 + np.exp((xm - (bsize/2-20)) / sig))
bsize//2-lx//2 : bsize//2+lx//2+lx%2]

[docs]def unaugment_tiles(y, unet=False):
""" reverse test-time augmentations for averaging

Parameters
----------

y: float32
array that's ntiles_y x ntiles_x x chan x Ly x Lx where chan = (dY, dX, cell prob)

unet: bool (optional, False)
whether or not unet output or cellpose output

Returns
-------

y: float32

"""
for j in range(y.shape):
for i in range(y.shape):
if j%2==0 and i%2==1:
y[j,i] = y[j,i, :,::-1, :]
if not unet:
y[j,i,0] *= -1
elif j%2==1 and i%2==0:
y[j,i] = y[j,i, :,:, ::-1]
if not unet:
y[j,i,1] *= -1
elif j%2==1 and i%2==1:
y[j,i] = y[j,i, :,::-1, ::-1]
if not unet:
y[j,i,0] *= -1
y[j,i,1] *= -1
return y

[docs]def average_tiles(y, ysub, xsub, Ly, Lx):
""" average results of network over tiles

Parameters
-------------

y: float, [ntiles x nclasses x bsize x bsize]
output of cellpose network for each tile

ysub : list
list of arrays with start and end of tiles in Y of length ntiles

xsub : list
list of arrays with start and end of tiles in X of length ntiles

Ly : int
size of pre-tiled image in Y (may be larger than original image if
image size is less than bsize)

Lx : int
size of pre-tiled image in X (may be larger than original image if
image size is less than bsize)

Returns
-------------

yf: float32, [nclasses x Ly x Lx]
network output averaged over tiles

"""
Navg = np.zeros((Ly,Lx))
yf = np.zeros((y.shape, Ly, Lx), np.float32)
# taper edges of tiles
for j in range(len(ysub)):
yf[:, ysub[j]:ysub[j],  xsub[j]:xsub[j]] += y[j] * mask
yf /= Navg
return yf

[docs]def make_tiles(imgi, bsize=224, augment=False, tile_overlap=0.1):
""" make tiles of image to run at test-time

if augmented, tiles are flipped and tile_overlap=2.
* original
* flipped vertically
* flipped horizontally
* flipped vertically and horizontally

Parameters
----------
imgi : float32
array that's nchan x Ly x Lx

bsize : float (optional, default 224)
size of tiles

augment : bool (optional, default False)
flip tiles and set tile_overlap=2.

tile_overlap: float (optional, default 0.1)
fraction of overlap of tiles

Returns
-------
IMG : float32
array that's ntiles x nchan x bsize x bsize

ysub : list
list of arrays with start and end of tiles in Y of length ntiles

xsub : list
list of arrays with start and end of tiles in X of length ntiles

"""

nchan, Ly, Lx = imgi.shape
if augment:
bsize = np.int32(bsize)
# pad if image smaller than bsize
if Ly<bsize:
imgi = np.concatenate((imgi, np.zeros((nchan, bsize-Ly, Lx))), axis=1)
Ly = bsize
if Lx<bsize:
imgi = np.concatenate((imgi, np.zeros((nchan, Ly, bsize-Lx))), axis=2)
Ly, Lx = imgi.shape[-2:]
# tiles overlap by half of tile size
ny = max(2, int(np.ceil(2. * Ly / bsize)))
nx = max(2, int(np.ceil(2. * Lx / bsize)))
ystart = np.linspace(0, Ly-bsize, ny).astype(int)
xstart = np.linspace(0, Lx-bsize, nx).astype(int)

ysub = []
xsub = []

# flip tiles so that overlapping segments are processed in rotation
IMG = np.zeros((len(ystart), len(xstart), nchan,  bsize, bsize), np.float32)
for j in range(len(ystart)):
for i in range(len(xstart)):
ysub.append([ystart[j], ystart[j]+bsize])
xsub.append([xstart[i], xstart[i]+bsize])
IMG[j, i] = imgi[:, ysub[-1]:ysub[-1],  xsub[-1]:xsub[-1]]
# flip tiles to allow for augmentation of overlapping segments
if j%2==0 and i%2==1:
IMG[j,i] = IMG[j,i, :,::-1, :]
elif j%2==1 and i%2==0:
IMG[j,i] = IMG[j,i, :,:, ::-1]
elif j%2==1 and i%2==1:
IMG[j,i] = IMG[j,i,:, ::-1, ::-1]
else:
tile_overlap = min(0.5, max(0.05, tile_overlap))
bsizeY, bsizeX = min(bsize, Ly), min(bsize, Lx)
bsizeY = np.int32(bsizeY)
bsizeX = np.int32(bsizeX)
# tiles overlap by 10% tile size
ny = 1 if Ly<=bsize else int(np.ceil((1.+2*tile_overlap) * Ly / bsize))
nx = 1 if Lx<=bsize else int(np.ceil((1.+2*tile_overlap) * Lx / bsize))
ystart = np.linspace(0, Ly-bsizeY, ny).astype(int)
xstart = np.linspace(0, Lx-bsizeX, nx).astype(int)

ysub = []
xsub = []
IMG = np.zeros((len(ystart), len(xstart), nchan,  bsizeY, bsizeX), np.float32)
for j in range(len(ystart)):
for i in range(len(xstart)):
ysub.append([ystart[j], ystart[j]+bsizeY])
xsub.append([xstart[i], xstart[i]+bsizeX])
IMG[j, i] = imgi[:, ysub[-1]:ysub[-1],  xsub[-1]:xsub[-1]]

return IMG, ysub, xsub, Ly, Lx

[docs]def normalize99(Y, lower=1,upper=99):
""" normalize image so 0.0 is 1st percentile and 1.0 is 99th percentile """
X = Y.copy()
x01 = np.percentile(X, lower)
x99 = np.percentile(X, upper)
X = (X - x01) / (x99 - x01)
return X

[docs]def move_axis(img, m_axis=-1, first=True):
""" move axis m_axis to first or last position """
if m_axis==-1:
m_axis = img.ndim-1
m_axis = min(img.ndim-1, m_axis)
axes = np.arange(0, img.ndim)
if first:
axes[1:m_axis+1] = axes[:m_axis]
axes = m_axis
else:
axes[m_axis:-1] = axes[m_axis+1:]
axes[-1] = m_axis
img = img.transpose(tuple(axes))
return img

# This was edited to fix a bug where single-channel images of shape (y,x) would be
# transposed to (x,y) if x<y, making the labels no longer correspond to the data.
[docs]def move_min_dim(img, force=False):
""" move minimum dimension last as channels if < 10, or force==True """
if len(img.shape) > 2: #only makese sense to do this if channel axis is already present
min_dim = min(img.shape)
if min_dim < 10 or force:
if img.shape[-1]==min_dim:
channel_axis = -1
else:
channel_axis = (img.shape).index(min_dim)
img = move_axis(img, m_axis=channel_axis, first=False)
return img

def update_axis(m_axis, to_squeeze, ndim):
if m_axis==-1:
m_axis = ndim-1
if (to_squeeze==m_axis).sum() == 1:
m_axis = None
else:
inds = np.ones(ndim, bool)
inds[to_squeeze] = False
m_axis = np.nonzero(np.arange(0, ndim)[inds]==m_axis)
if len(m_axis) > 0:
m_axis = m_axis
else:
m_axis = None
return m_axis

[docs]def convert_image(x, channels, channel_axis=None, z_axis=None,
do_3D=False, normalize=True, invert=False,
nchan=2):
""" return image with z first, channels last and normalized intensities """

# check if image is a torch array instead of numpy array
# converts torch to numpy
if torch.is_tensor(x):
transforms_logger.warning('torch array used as input, converting to numpy')
x = x.cpu().numpy()

# squeeze image, and if channel_axis or z_axis given, transpose image
if x.ndim > 3:
to_squeeze = np.array([int(isq) for isq,s in enumerate(x.shape) if s==1])
# remove channel axis if number of channels is 1
if len(to_squeeze) > 0:
channel_axis = update_axis(channel_axis, to_squeeze, x.ndim) if channel_axis is not None else channel_axis
z_axis = update_axis(z_axis, to_squeeze, x.ndim) if z_axis is not None else z_axis
x = x.squeeze()

# put z axis first
if z_axis is not None and x.ndim > 2:
x = move_axis(x, m_axis=z_axis, first=True)
if channel_axis is not None:
channel_axis += 1
if x.ndim==3:
x = x[...,np.newaxis]

# put channel axis last
if channel_axis is not None and x.ndim > 2:
x = move_axis(x, m_axis=channel_axis, first=False)
elif x.ndim == 2:
x = x[:,:,np.newaxis]

if do_3D :
if x.ndim < 3:
transforms_logger.critical('ERROR: cannot process 2D images in 3D mode')
raise ValueError('ERROR: cannot process 2D images in 3D mode')
elif x.ndim<4:
x = x[...,np.newaxis]

if channel_axis is None:
x = move_min_dim(x)

if x.ndim > 3:
transforms_logger.info('multi-stack tiff read in as having %d planes %d channels'%
(x.shape, x.shape[-1]))

if channels is not None:
channels = channels if len(channels)==1 else channels
if len(channels) < 2:
transforms_logger.critical('ERROR: two channels not specified')
raise ValueError('ERROR: two channels not specified')
x = reshape(x, channels=channels)

else:
# code above put channels last
if x.shape[-1] > nchan:
transforms_logger.warning('WARNING: more than %d channels given, use "channels" input for specifying channels - just using first %d channels to run processing'%(nchan,nchan))
x = x[...,:nchan]

if not do_3D and x.ndim>3:
transforms_logger.critical('ERROR: cannot process 4D images in 2D mode')
raise ValueError('ERROR: cannot process 4D images in 2D mode')

if x.shape[-1] < nchan:
x = np.concatenate((x,
np.tile(np.zeros_like(x), (1,1,nchan-1))),
axis=-1)

if normalize or invert:
x = normalize_img(x, invert=invert)

return x

[docs]def reshape(data, channels=[0,0], chan_first=False):
""" reshape data using channels

Parameters
----------

data : numpy array that's (Z x ) Ly x Lx x nchan
if data.ndim==8 and data.shape<8, assumed to be nchan x Ly x Lx

channels : list of int of length 2 (optional, default [0,0])
First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
For instance, to train on grayscale images, input [0,0]. To train on images with cells
in green and nuclei in blue, input [2,3].

invert : bool
invert intensities

Returns
-------
data : numpy array that's (Z x ) Ly x Lx x nchan (if chan_first==False)

"""
data = data.astype(np.float32)
if data.ndim < 3:
data = data[:,:,np.newaxis]
elif data.shape<8 and data.ndim==3:
data = np.transpose(data, (1,2,0))

# use grayscale image
if data.shape[-1]==1:
data = np.concatenate((data, np.zeros_like(data)), axis=-1)
else:
if channels==0:
data = data.mean(axis=-1, keepdims=True)
data = np.concatenate((data, np.zeros_like(data)), axis=-1)
else:
chanid = [channels-1]
if channels > 0:
chanid.append(channels-1)
data = data[...,chanid]
for i in range(data.shape[-1]):
if np.ptp(data[...,i]) == 0.0:
if i==0:
warnings.warn("chan to seg' has value range of ZERO")
else:
warnings.warn("'chan2 (opt)' has value range of ZERO, can instead set chan2 to 0")
if data.shape[-1]==1:
data = np.concatenate((data, np.zeros_like(data)), axis=-1)
if chan_first:
if data.ndim==4:
data = np.transpose(data, (3,0,1,2))
else:
data = np.transpose(data, (2,0,1))
return data

[docs]def normalize_img(img, axis=-1, invert=False):
""" normalize each channel of the image so that so that 0.0=1st percentile
and 1.0=99th percentile of image intensities

optional inversion

Parameters
------------

img: ND-array (at least 3 dimensions)

axis: channel axis to loop over for normalization

invert: invert image (useful if cells are dark instead of bright)

Returns
---------------

img: ND-array, float32
normalized image of same size

"""
if img.ndim<3:
error_message = 'Image needs to have at least 3 dimensions'
transforms_logger.critical(error_message)
raise ValueError(error_message)

img = img.astype(np.float32)
img = np.moveaxis(img, axis, 0)
for k in range(img.shape):
# ptp can still give nan's with weird images
i99 = np.percentile(img[k],99)
i1 = np.percentile(img[k],1)
if i99 - i1 > +1e-3: #np.ptp(img[k]) > 1e-3:
img[k] = normalize99(img[k])
if invert:
img[k] = -1*img[k] + 1
else:
img[k] = 0
img = np.moveaxis(img, 0, axis)
return img

[docs]def reshape_train_test(train_data, train_labels, test_data, test_labels, channels, normalize=True):
""" check sizes and reshape train and test data for training """
nimg = len(train_data)
# check that arrays are correct size
if nimg != len(train_labels):
error_message = 'train data and labels not same length'
transforms_logger.critical(error_message)
raise ValueError(error_message)
return
if train_labels.ndim < 2 or train_data.ndim < 2:
error_message = 'training data or labels are not at least two-dimensional'
transforms_logger.critical(error_message)
raise ValueError(error_message)
return

if train_data.ndim > 3:
error_message = 'training data is more than three-dimensional (should be 2D or 3D array)'
transforms_logger.critical(error_message)
raise ValueError(error_message)
return

# check if test_data correct length
if not (test_data is not None and test_labels is not None and
len(test_data) > 0 and len(test_data)==len(test_labels)):
test_data = None

# make data correct shape and normalize it so that 0 and 1 are 1st and 99th percentile of data
train_data, test_data, run_test = reshape_and_normalize_data(train_data, test_data=test_data,
channels=channels, normalize=normalize)

if train_data is None:
error_message = 'training data do not all have the same number of channels'
transforms_logger.critical(error_message)
raise ValueError(error_message)
return

if not run_test:
test_data, test_labels = None, None

return train_data, train_labels, test_data, test_labels, run_test

[docs]def reshape_and_normalize_data(train_data, test_data=None, channels=None, normalize=True):
""" inputs converted to correct shapes for *training* and rescaled so that 0.0=1st percentile
and 1.0=99th percentile of image intensities in each channel

Parameters
--------------

train_data: list of ND-arrays, float
list of training images of size [Ly x Lx], [nchan x Ly x Lx], or [Ly x Lx x nchan]

test_data: list of ND-arrays, float (optional, default None)
list of testing images of size [Ly x Lx], [nchan x Ly x Lx], or [Ly x Lx x nchan]

channels: list of int of length 2 (optional, default None)
First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
For instance, to train on grayscale images, input [0,0]. To train on images with cells
in green and nuclei in blue, input [2,3].

normalize: bool (optional, True)
normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel

Returns
-------------

train_data: list of ND-arrays, float
list of training images of size [2 x Ly x Lx]

test_data: list of ND-arrays, float (optional, default None)
list of testing images of size [2 x Ly x Lx]

run_test: bool
whether or not test_data was correct size and is useable during training

"""

# if training data is less than 2D
run_test = False
for test, data in enumerate([train_data, test_data]):
if data is None:
return train_data, test_data, run_test
nimg = len(data)
for i in range(nimg):
if channels is not None:
data[i] = move_min_dim(data[i], force=True)
data[i] = reshape(data[i], channels=channels, chan_first=True)
if data[i].ndim < 3:
data[i] = data[i][np.newaxis,:,:]
if normalize:
data[i] = normalize_img(data[i], axis=0)

nchan = [data[i].shape for i in range(nimg)]
run_test = True
return train_data, test_data, run_test

[docs]def resize_image(img0, Ly=None, Lx=None, rsz=None, interpolation=cv2.INTER_LINEAR, no_channels=False):
""" resize image for computing flows / unresize for computing dynamics

Parameters
-------------

img0: ND-array
image of size [Y x X x nchan] or [Lz x Y x X x nchan] or [Lz x Y x X]

Ly: int, optional

Lx: int, optional

rsz: float, optional
resize coefficient(s) for image; if Ly is None then rsz is used

interpolation: cv2 interp method (optional, default cv2.INTER_LINEAR)

Returns
--------------

imgs: ND-array
image of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan]

"""
if Ly is None and rsz is None:
error_message = 'must give size to resize to or factor to use for resizing'
transforms_logger.critical(error_message)
raise ValueError(error_message)

if Ly is None:
# determine Ly and Lx using rsz
if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray):
rsz = [rsz, rsz]
if no_channels:
Ly = int(img0.shape[-2] * rsz[-2])
Lx = int(img0.shape[-1] * rsz[-1])
else:
Ly = int(img0.shape[-3] * rsz[-2])
Lx = int(img0.shape[-2] * rsz[-1])

# no_channels useful for z-stacks, sot he third dimension is not treated as a channel
# but if this is called for grayscale images, they first become [Ly,Lx,2] so ndim=3 but
if (img0.ndim>2 and no_channels) or (img0.ndim==4 and not no_channels):
if Ly==0 or Lx==0:
raise ValueError('anisotropy too high / low -- not enough pixels to resize to ratio')
if no_channels:
imgs = np.zeros((img0.shape, Ly, Lx), np.float32)
else:
imgs = np.zeros((img0.shape, Ly, Lx, img0.shape[-1]), np.float32)
for i,img in enumerate(img0):
imgs[i] = cv2.resize(img, (Lx, Ly), interpolation=interpolation)
else:
imgs = cv2.resize(img0, (Lx, Ly), interpolation=interpolation)
return imgs

[docs]def pad_image_ND(img0, div=16, extra = 1):
""" pad image for test-time so that its dimensions are a multiple of 16 (2D or 3D)

Parameters
-------------

img0: ND-array
image of size [nchan (x Lz) x Ly x Lx]

div: int (optional, default 16)

Returns
--------------

I: ND-array

ysub: array, int
yrange of pixels in I corresponding to img0

xsub: array, int
xrange of pixels in I corresponding to img0

"""
Lpad = int(div * np.ceil(img0.shape[-2]/div) - img0.shape[-2])
Lpad = int(div * np.ceil(img0.shape[-1]/div) - img0.shape[-1])

if img0.ndim>3:
else:

Ly, Lx = img0.shape[-2:]
return I, ysub, xsub

def normalize_field(mu):
mu /= (1e-20 + (mu**2).sum(axis=0)**0.5)
return mu

def _X2zoom(img, X2=1):
""" zoom in image

Parameters
----------
img : numpy array that's Ly x Lx

Returns
-------
img : numpy array that's Ly x Lx

"""
ny,nx = img.shape[:2]
img = cv2.resize(img, (int(nx * (2**X2)), int(ny * (2**X2))))
return img

def _image_resizer(img, resize=512, to_uint8=False):
""" resize image

Parameters
----------
img : numpy array that's Ly x Lx

resize : int
max size of image returned

to_uint8 : bool
convert image to uint8

Returns
-------
img : numpy array that's Ly x Lx, Ly,Lx<resize

"""
ny,nx = img.shape[:2]
if to_uint8:
if img.max()<=255 and img.min()>=0 and img.max()>1:
img = img.astype(np.uint8)
else:
img = img.astype(np.float32)
img -= img.min()
img /= img.max()
img *= 255
img = img.astype(np.uint8)
if np.array(img.shape).max() > resize:
if ny>nx:
nx = int(nx/ny * resize)
ny = resize
else:
ny = int(ny/nx * resize)
nx = resize
shape = (nx,ny)
img = cv2.resize(img, shape)
img = img.astype(np.uint8)
return img

[docs]def random_rotate_and_resize(X, Y=None, scale_range=1., xy = (224,224),
do_flip=True, rescale=None, unet=False, random_per_image=True):
""" augmentation by random rotation and resizing
X and Y are lists or arrays of length nimg, with dims channels x Ly x Lx (channels optional)
Parameters
----------
X: LIST of ND-arrays, float
list of image arrays of size [nchan x Ly x Lx] or [Ly x Lx]
Y: LIST of ND-arrays, float (optional, default None)
list of image labels of size [nlabels x Ly x Lx] or [Ly x Lx]. The 1st channel
of Y is always nearest-neighbor interpolated (assumed to be masks or 0-1 representation).
If Y.shape==3 and not unet, then the labels are assumed to be [cell probability, Y flow, X flow].
If unet, second channel is dist_to_bound.
scale_range: float (optional, default 1.0)
Range of resizing of images for augmentation. Images are resized by
(1-scale_range/2) + scale_range * np.random.rand()
xy: tuple, int (optional, default (224,224))
size of transformed images to return
do_flip: bool (optional, default True)
whether or not to flip images horizontally
rescale: array, float (optional, default None)
how much to resize images by before performing augmentations
unet: bool (optional, default False)
random_per_image: bool (optional, default True)
different random rotate and resize per image
Returns
-------
imgi: ND-array, float
transformed images in array [nimg x nchan x xy x xy]
lbl: ND-array, float
transformed labels in array [nimg x nchan x xy x xy]
scale: array, float
amount each image was resized by
"""
scale_range = max(0, min(2, float(scale_range)))
nimg = len(X)
if X.ndim>2:
nchan = X.shape
else:
nchan = 1
imgi  = np.zeros((nimg, nchan, xy, xy), np.float32)

lbl = []
if Y is not None:
if Y.ndim>2:
nt = Y.shape
else:
nt = 1
lbl = np.zeros((nimg, nt, xy, xy), np.float32)

scale = np.ones(nimg, np.float32)

for n in range(nimg):
Ly, Lx = X[n].shape[-2:]

if random_per_image or n==0:
# generate random augmentation parameters
flip = np.random.rand()>.5
theta = np.random.rand() * np.pi * 2
scale[n] = (1-scale_range/2) + scale_range * np.random.rand()
if rescale is not None:
scale[n] *= 1. / rescale[n]
dxy = np.maximum(0, np.array([Lx*scale[n]-xy,Ly*scale[n]-xy]))
dxy = (np.random.rand(2,) - .5) * dxy

# create affine transform
cc = np.array([Lx/2, Ly/2])
cc1 = cc - np.array([Lx-xy, Ly-xy])/2 + dxy
pts1 = np.float32([cc,cc + np.array([1,0]), cc + np.array([0,1])])
pts2 = np.float32([cc1,
cc1 + scale[n]*np.array([np.cos(theta), np.sin(theta)]),
cc1 + scale[n]*np.array([np.cos(np.pi/2+theta), np.sin(np.pi/2+theta)])])
M = cv2.getAffineTransform(pts1,pts2)

img = X[n].copy()
if Y is not None:
labels = Y[n].copy()
if labels.ndim<3:
labels = labels[np.newaxis,:,:]

if flip and do_flip:
img = img[..., ::-1]
if Y is not None:
labels = labels[..., ::-1]
if nt > 1 and not unet:
labels = -labels

for k in range(nchan):
I = cv2.warpAffine(img[k], M, (xy,xy), flags=cv2.INTER_LINEAR)
imgi[n,k] = I

if Y is not None:
for k in range(nt):
if k==0:
lbl[n,k] = cv2.warpAffine(labels[k], M, (xy,xy), flags=cv2.INTER_NEAREST)
else:
lbl[n,k] = cv2.warpAffine(labels[k], M, (xy,xy), flags=cv2.INTER_LINEAR)

if nt > 1 and not unet:
v1 = lbl[n,2].copy()
v2 = lbl[n,1].copy()
lbl[n,1] = (-v1 * np.sin(-theta) + v2*np.cos(-theta))
lbl[n,2] = (v1 * np.cos(-theta) + v2*np.sin(-theta))

return imgi, lbl, scale
```