# Source code for cellpose.dynamics

```"""
Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
"""

import time, os
from scipy.ndimage import maximum_filter1d, find_objects
import torch
import numpy as np
import tifffile
from tqdm import trange
from numba import njit, float32, int32, vectorize
import cv2
import fastremap

import logging
dynamics_logger = logging.getLogger(__name__)

from . import utils, metrics, transforms

import torch
from torch import optim, nn
from . import resnet_torch
TORCH_ENABLED = True
torch_GPU = torch.device('cuda')
torch_CPU = torch.device('cpu')

@njit('(float64[:], int32[:], int32[:], int32, int32, int32, int32)', nogil=True)
def _extend_centers(T,y,x,ymed,xmed,Lx, niter):
    """ run diffusion from center of mask (ymed, xmed) on mask pixels (y, x)

    Parameters
    --------------
    T: float64, array
        flattened (1D) array that diffusion is run in; pixel (row, col) is
        addressed as T[row*Lx + col]. It is updated in place.
    y: int32, array
        row index of every mask pixel
    x: int32, array
        column index of every mask pixel
    ymed: int32
        row index of the pixel where heat is injected
    xmed: int32
        column index of the pixel where heat is injected
    Lx: int32
        row stride used to flatten (row, col) into an index of T
    niter: int32
        number of iterations to run diffusion

    Returns
    ---------------
    T: float64, array
        amount of diffused particles at each pixel

    Notes
    ---------------
    Every iteration reads the 8-connected neighbors (y +/- 1, x +/- 1) of each
    mask pixel, so T must contain valid entries for those flattened indices.
    """
    for t in range(niter):
        # inject one unit of heat at the center pixel
        T[ymed*Lx + xmed] += 1
        # replace every mask pixel by the mean of its 3x3 neighborhood
        T[y*Lx + x] = 1/9. * (T[y*Lx + x] + T[(y-1)*Lx + x]   + T[(y+1)*Lx + x] +
                              T[y*Lx + x-1]     + T[y*Lx + x+1] +
                              T[(y-1)*Lx + x-1] + T[(y-1)*Lx + x+1] +
                              T[(y+1)*Lx + x-1] + T[(y+1)*Lx + x+1])
    return T

def _extend_centers_gpu(neighbors, centers, isneighbor, Ly, Lx, n_iter=200, device=torch.device('cuda')):
""" runs diffusion on GPU to generate flows for training images or quality control

neighbors is 9 x pixels in masks,
isneighbor is valid neighbor boolean 9 x pixels

"""
if device is not None:
device = device
nimg = neighbors.shape // 9
pt = torch.from_numpy(neighbors).to(device)

T = torch.zeros((nimg,Ly,Lx), dtype=torch.double, device=device)
meds = torch.from_numpy(centers.astype(int)).to(device).long()
isneigh = torch.from_numpy(isneighbor).to(device)
for i in range(n_iter):
T[:, meds[:,0], meds[:,1]] +=1
Tneigh = T[:, pt[:,:,0], pt[:,:,1]]
Tneigh *= isneigh
T[:, pt[0,:,0], pt[0,:,1]] = Tneigh.mean(axis=1)
del meds, isneigh, Tneigh
T = torch.log(1.+ T)
del pt
mu_torch = np.stack((dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2)
return mu_torch

""" convert masks to flows using diffusion from center pixel
Center of masks where diffusion starts is defined using COM
Parameters
-------------
masks: int, 2D or 3D array
Returns
-------------
mu: float, 3D or 4D array
flows in Y = mu[-2], flows in X = mu[-1].
if masks are 3D, flows in Z = mu[0].
mu_c: float, 2D or 3D array
for each pixel, the distance to the center of the mask
in which it resides
"""
if device is None:
device = torch.device('cuda')

Ly, Lx = Ly0+2, Lx0+2

neighborsY = np.stack((y, y-1, y+1,
y, y, y-1,
y-1, y+1, y+1), axis=0)
neighborsX = np.stack((x, x, x,
x-1, x+1, x-1,
x+1, x-1, x+1), axis=0)
neighbors = np.stack((neighborsY, neighborsX), axis=-1)

for i,si in enumerate(slices):
if si is not None:
sr,sc = si
ly, lx = sr.stop - sr.start + 1, sc.stop - sc.start + 1
yi,xi = np.nonzero(masks[sr, sc] == (i+1))
ymed = np.median(yi)
xmed = np.median(xi)
imin = np.argmin((xi-xmed)**2 + (yi-ymed)**2)
xmed = xi[imin]
ymed = yi[imin]
centers[i,0] = ymed + sr.start
centers[i,1] = xmed + sc.start

# get neighbor validator (not all neighbors are in same mask)
ext = np.array([[sr.stop - sr.start + 1, sc.stop - sc.start + 1] for sr, sc in slices])
n_iter = 2 * (ext.sum(axis=1)).max()
# run diffusion
mu = _extend_centers_gpu(neighbors, centers, isneighbor, Ly, Lx,
n_iter=n_iter, device=device)

# normalize
mu /= (1e-20 + (mu**2).sum(axis=0)**0.5)

# put into original image
mu0 = np.zeros((2, Ly0, Lx0))
mu0[:, y-1, x-1] = mu
mu_c = np.zeros_like(mu0)
return mu0, mu_c

""" convert masks to flows using diffusion from center pixel
Center of masks where diffusion starts is defined to be the
closest pixel to the median of all pixels that is inside the
mask. Result of diffusion is converted into flows by computing
the gradients of the diffusion density map.
Parameters
-------------
Returns
-------------
mu: float, 3D array
flows in Y = mu[-2], flows in X = mu[-1].
if masks are 3D, flows in Z = mu[0].
mu_c: float, 2D array
for each pixel, the distance to the center of the mask
in which it resides
"""

mu = np.zeros((2, Ly, Lx), np.float64)
mu_c = np.zeros((Ly, Lx), np.float64)

s2 = (.15 * dia)**2
for i,si in enumerate(slices):
if si is not None:
sr,sc = si
ly, lx = sr.stop - sr.start + 1, sc.stop - sc.start + 1
y,x = np.nonzero(masks[sr, sc] == (i+1))
y = y.astype(np.int32) + 1
x = x.astype(np.int32) + 1
ymed = np.median(y)
xmed = np.median(x)
imin = np.argmin((x-xmed)**2 + (y-ymed)**2)
xmed = x[imin]
ymed = y[imin]

d2 = (x-xmed)**2 + (y-ymed)**2
mu_c[sr.start+y-1, sc.start+x-1] = np.exp(-d2/s2)

niter = 2*np.int32(np.ptp(x) + np.ptp(y))
T = np.zeros((ly+2)*(lx+2), np.float64)
T = _extend_centers(T, y, x, ymed, xmed, np.int32(lx), np.int32(niter))
T[(y+1)*lx + x+1] = np.log(1.+T[(y+1)*lx + x+1])

dy = T[(y+1)*lx + x] - T[(y-1)*lx + x]
dx = T[y*lx + x+1] - T[y*lx + x-1]
mu[:, sr.start+y-1, sc.start+x-1] = np.stack((dy,dx))

mu /= (1e-20 + (mu**2).sum(axis=0)**0.5)

return mu, mu_c

""" convert masks to flows using diffusion from center pixel

Center of masks where diffusion starts is defined to be the
closest pixel to the median of all pixels that is inside the
mask. Result of diffusion is converted into flows by computing
the gradients of the diffusion density map.

Parameters
-------------

masks: int, 2D or 3D array

Returns
-------------

mu: float, 3D or 4D array
flows in Y = mu[-2], flows in X = mu[-1].
if masks are 3D, flows in Z = mu[0].

mu_c: float, 2D or 3D array
for each pixel, the distance to the center of the mask
in which it resides

"""

if use_gpu:
if use_gpu and device is None:
device = torch_GPU
elif device is None:
device = torch_CPU
else:

mu = np.zeros((3, Lz, Ly, Lx), np.float32)
for z in range(Lz):
mu[[1,2], z] += mu0
for y in range(Ly):
mu[[0,2], :, y] += mu0
for x in range(Lx):
mu[[0,1], :, :, x] += mu0
return mu
return mu

else:
raise ValueError('masks_to_flows only takes 2D or 3D arrays')

def labels_to_flows(labels, files=None, use_gpu=False, device=None, redo_flows=False):
    """ convert labels (list of masks or flows) to flows for training model

    if files is not None, flows are saved to files to be reused

    Parameters
    --------------

    labels: list of ND-arrays
        labels[k] can be 2D or 3D, if [3 x Ly x Lx] then it is assumed that flows were precomputed.
        Otherwise labels[k][0] or labels[k] (if 2D) is used to create flows and cell probabilities.

    files: list of str (optional, default None)
        if given, flows[k] is written to '<files[k] without extension>_flows.tif'

    use_gpu: bool (optional, default False)
        forwarded to masks_to_flows

    device: torch.device (optional, default None)
        forwarded to masks_to_flows

    redo_flows: bool (optional, default False)
        recompute flows even if labels already have more than one channel

    Returns
    --------------

    flows: list of float32 arrays
        when flows are computed here, flows[k] is the concatenation along axis 0 of
        labels[k], the binary mask labels[k] > 0.5 and the output of masks_to_flows;
        otherwise flows[k] is labels[k] cast to float32.

    """
    nimg = len(labels)
    if nimg == 0:
        # nothing to convert; avoids indexing labels[0] below
        return []
    if labels[0].ndim < 3:
        labels = [labels[n][np.newaxis,:,:] for n in range(nimg)]

    if labels[0].shape[0] == 1 or labels[0].ndim < 3 or redo_flows: # flows need to be recomputed

        dynamics_logger.info('computing flows for labels')

        # compute flows; labels are fixed here to be unique, so they need to be passed back
        # make sure labels are unique!
        # fastremap.renumber returns (renumbered_array, mapping); only the array is kept
        labels = [fastremap.renumber(label, in_place=True)[0] for label in labels]
        veci = [masks_to_flows(labels[n][0],use_gpu=use_gpu, device=device) for n in trange(nimg)]

        # concatenate labels, distance transform, vector flows, heat (boundary and mask are computed in augmentations)
        # NOTE(review): with redo_flows=True and multi-channel labels, every input
        # channel is carried into the output here -- confirm this is intended.
        flows = [np.concatenate((labels[n], labels[n]>0.5, veci[n]), axis=0).astype(np.float32)
                 for n in range(nimg)]
        if files is not None:
            for flow, file in zip(flows, files):
                file_name = os.path.splitext(file)[0]
                tifffile.imwrite(file_name+'_flows.tif', flow)
    else:
        dynamics_logger.info('flows precomputed')
        flows = [labels[n].astype(np.float32) for n in range(nimg)]
    return flows

@njit(['(int16[:,:,:], float32[:], float32[:], float32[:,:])',
       '(float32[:,:,:], float32[:], float32[:], float32[:,:])'], cache=True)
def map_coordinates(I, yc, xc, Y):
    """
    bilinear interpolation of image 'I' in-place with ycoordinates yc and xcoordinates xc to Y

    Parameters
    -------------
    I : C x Ly x Lx
        image to sample from
    yc : ni
        new y coordinates
    xc : ni
        new x coordinates
    Y : C x ni
        I sampled at (yc,xc); written in place, nothing is returned
    """
    C,Ly,Lx = I.shape
    # integer part picks the top-left corner, fractional part gives the weights
    yc_floor = yc.astype(np.int32)
    xc_floor = xc.astype(np.int32)
    yc = yc - yc_floor
    xc = xc - xc_floor
    for i in range(yc_floor.shape[0]):
        # clamp both corners so that samples outside the image reuse edge pixels
        yf = min(Ly-1, max(0, yc_floor[i]))
        xf = min(Lx-1, max(0, xc_floor[i]))
        yf1= min(Ly-1, yf+1)
        xf1= min(Lx-1, xf+1)
        y = yc[i]
        x = xc[i]
        for c in range(C):
            Y[c,i] = (np.float32(I[c, yf, xf]) * (1 - y) * (1 - x) +
                      np.float32(I[c, yf, xf1]) * (1 - y) * x +
                      np.float32(I[c, yf1, xf]) * y * (1 - x) +
                      np.float32(I[c, yf1, xf1]) * y * x )

def steps2D_interp(p, dP, niter, use_gpu=False, device=None):
    """ run interpolated 2D dynamics on a set of points

    Parameters
    ----------------

    p: float32, 2D array
        point locations [2 x npoints], row 0 is y and row 1 is x

    dP: float32, 3D array
        flows [2 x Ly x Lx]

    niter: int
        number of Euler steps

    use_gpu: bool (optional, default False)
        if True the steps are run with torch.nn.functional.grid_sample on `device`,
        otherwise with map_coordinates on the CPU

    device: torch.device (optional, default None)
        torch device for the use_gpu branch; torch_GPU is used when None

    Returns
    ---------------

    p: float32, 2D array
        final point locations [2 x npoints]; the CPU branch updates and
        returns the input array itself

    """
    shape = dP.shape[1:]
    if use_gpu:
        if device is None:
            device = torch_GPU
        shape = np.array(shape)[[1,0]].astype('float')-1  # Y and X dimensions (dP is 2.Ly.Lx), flipped X-1, Y-1
        pt = torch.from_numpy(p[[1,0]].T).float().to(device).unsqueeze(0).unsqueeze(0) # p is n_points by 2, so pt is [1 1 2 n_points]
        im = torch.from_numpy(dP[[1,0]]).float().to(device).unsqueeze(0) #convert flow numpy array to tensor on GPU, add dimension
        # normalize pt between  0 and  1, normalize the flow
        for k in range(2):
            im[:,k,:,:] *= 2./shape[k]
            pt[:,:,:,k] /= shape[k]

        # normalize to between -1 and 1
        pt = pt*2-1

        #here is where the stepping happens
        for t in range(niter):
            # align_corners default is False, just added to suppress warning
            dPt = torch.nn.functional.grid_sample(im, pt, align_corners=False)

            for k in range(2): #clamp the final pixel locations
                pt[:,:,:,k] = torch.clamp(pt[:,:,:,k] + dPt[:,k,:,:], -1., 1.)

        #undo the normalization from before, reverse order of operations
        pt = (pt+1)*0.5
        for k in range(2):
            pt[:,:,:,k] *= shape[k]

        p =  pt[:,:,:,[1,0]].cpu().numpy().squeeze().T
        return p

    else:
        dPt = np.zeros(p.shape, np.float32)
        # the cast does not depend on the iteration, so do it once
        flow = dP.astype(np.float32)

        for t in range(niter):
            # map_coordinates takes the y and x coordinates as separate 1D vectors
            map_coordinates(flow, p[0], p[1], dPt)
            for k in range(len(p)):
                p[k] = np.minimum(shape[k]-1, np.maximum(0, p[k] + dPt[k]))
        return p

@njit('(float32[:,:,:,:],float32[:,:,:,:], int32[:,:], int32)', nogil=True)
def steps3D(p, dP, inds, niter):
    """ run dynamics of pixels to recover masks in 3D

    Euler integration of dynamics dP for niter steps

    Parameters
    ----------------

    p: float32, 4D array
        pixel locations [axis x Lz x Ly x Lx] (start at initial meshgrid)

    dP: float32, 4D array
        flows [axis x Lz x Ly x Lx]

    inds: int32, 2D array
        non-zero pixels to run dynamics on [npixels x 3]

    niter: int32
        number of iterations of dynamics to run

    Returns
    ---------------

    p: float32, 4D array
        final locations of each pixel after dynamics (p is updated in place)

    """
    shape = p.shape[1:]
    for t in range(niter):
        for j in range(inds.shape[0]):
            # starting coordinates
            z = inds[j,0]
            y = inds[j,1]
            x = inds[j,2]
            # the flow is read at the integer part of the current position
            p0, p1, p2 = int(p[0,z,y,x]), int(p[1,z,y,x]), int(p[2,z,y,x])
            # step along every axis and clamp to the volume bounds
            p[0,z,y,x] = min(shape[0]-1, max(0, p[0,z,y,x] + dP[0,p0,p1,p2]))
            p[1,z,y,x] = min(shape[1]-1, max(0, p[1,z,y,x] + dP[1,p0,p1,p2]))
            p[2,z,y,x] = min(shape[2]-1, max(0, p[2,z,y,x] + dP[2,p0,p1,p2]))
    return p

@njit('(float32[:,:,:], float32[:,:,:], int32[:,:], int32)', nogil=True)
def steps2D(p, dP, inds, niter):
    """ run dynamics of pixels to recover masks in 2D

    Euler integration of dynamics dP for niter steps

    Parameters
    ----------------

    p: float32, 3D array
        pixel locations [axis x Ly x Lx] (start at initial meshgrid)

    dP: float32, 3D array
        flows [axis x Ly x Lx]

    inds: int32, 2D array
        non-zero pixels to run dynamics on [npixels x 2]

    niter: int32
        number of iterations of dynamics to run

    Returns
    ---------------

    p: float32, 3D array
        final locations of each pixel after dynamics (p is updated in place)

    """
    shape = p.shape[1:]
    for t in range(niter):
        for j in range(inds.shape[0]):
            # starting coordinates
            y = inds[j,0]
            x = inds[j,1]
            # the flow is read at the integer part of the current position
            p0, p1 = int(p[0,y,x]), int(p[1,y,x])
            step = dP[:,p0,p1]
            # step along every axis and clamp to the image bounds
            for k in range(p.shape[0]):
                p[k,y,x] = min(shape[k]-1, max(0, p[k,y,x] + step[k]))
    return p

def follow_flows(dP, mask=None, niter=200, interp=True, use_gpu=True, device=None):
    """ define pixels and run dynamics to recover masks in 2D

    Pixels are meshgrid. Only pixels with non-zero cell-probability
    are used (as defined by inds)

    Parameters
    ----------------

    dP: float32, 3D or 4D array
        flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx]

    mask: (optional, default None)
        not used by this function

    niter: int (optional, default 200)
        number of iterations of dynamics to run

    interp: bool (optional, default True)
        interpolate during 2D dynamics (not available in 3D)
        (in previous versions + paper it was False)

    use_gpu: bool (optional, default True)
        use GPU to run interpolated dynamics (faster than CPU)

    device: torch.device (optional, default None)
        forwarded to steps2D_interp

    Returns
    ---------------

    p: float32, 3D or 4D array
        final locations of each pixel after dynamics; [axis x Ly x Lx] or [axis x Lz x Ly x Lx]

    inds: int32, 2D array
        indices of pixels used for dynamics; [npixels x 2] or [npixels x 3].
        None is returned in 2D when fewer than 5 pixels have a non-zero flow.

    """
    shape = np.array(dP.shape[1:]).astype(np.int32)
    niter = np.uint32(niter)
    if len(shape)>2:
        p = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]),
                        np.arange(shape[2]), indexing='ij')
        p = np.array(p).astype(np.float32)
        # run dynamics on subset of pixels; the first flow component selects them
        # so that inds has one column per spatial axis, as steps3D expects
        inds = np.array(np.nonzero(np.abs(dP[0])>1e-3)).astype(np.int32).T
        p = steps3D(p, dP, inds, niter)
    else:
        p = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
        p = np.array(p).astype(np.float32)

        inds = np.array(np.nonzero(np.abs(dP[0])>1e-3)).astype(np.int32).T

        # too few moving pixels: return the untouched meshgrid
        if inds.ndim < 2 or inds.shape[0] < 5:
            return p, None

        if not interp:
            p = steps2D(p, dP.astype(np.float32), inds, niter)

        else:
            p_interp = steps2D_interp(p[:,inds[:,0], inds[:,1]], dP, niter, use_gpu=use_gpu, device=device)
            p[:,inds[:,0],inds[:,1]] = p_interp
    return p, inds

""" remove masks which have inconsistent flows

Uses metrics.flow_error to compute flows from predicted masks
and compare flows to predicted flows from network. Discards
masks with flow errors greater than the threshold.

Parameters
----------------

masks: int, 2D or 3D array
size [Ly x Lx] or [Lz x Ly x Lx]

flows: float, 3D or 4D array
flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx]

threshold: float (optional, default 0.4)

Returns
---------------

masks: int, 2D or 3D array
size [Ly x Lx] or [Lz x Ly x Lx]

"""
if masks.size > 10000*10000 and use_gpu:

major_version, minor_version, _ = torch.__version__.split(".")

if major_version == "1" and int(minor_version) < 10:
# for PyTorch version lower than 1.10
def mem_info():
total_mem = torch.cuda.get_device_properties(0).total_memory
used_mem = torch.cuda.memory_allocated()
else:
# for PyTorch version 1.10 and above
def mem_info():
total_mem, used_mem = torch.cuda.mem_get_info()

if masks.size * 20 > mem_info():
dynamics_logger.warning('WARNING: image is very large, not using gpu to compute flows from masks for QC step flow_threshold')
dynamics_logger.info('turn off QC step with flow_threshold=0 if too slow')
use_gpu = False

merrors, _ = metrics.flow_error(masks, flows, use_gpu, device)

""" create masks using pixel convergence after running dynamics

Makes a histogram of final pixel locations p, initializes masks
at peaks of histogram and extends the masks from the peaks so that
they include all pixels with more than 2 final pixels p. Discards
masks with flow errors greater than the threshold.
Parameters
----------------
p: float32, 3D or 4D array
final locations of each pixel after dynamics,
size [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
iscell: bool, 2D or 3D array
if iscell is not None, set pixels that are
iscell False to stay in their original location.
threshold: float (optional, default 0.4)
(if flows is not None)
flows: float, 3D or 4D array (optional, default None)
flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx]. If flows
is not None, then masks with inconsistent flows are removed using
Returns
---------------
M0: int, 2D or 3D array
size [Ly x Lx] or [Lz x Ly x Lx]

"""

pflows = []
edges = []
shape0 = p.shape[1:]
dims = len(p)
if iscell is not None:
if dims==3:
inds = np.meshgrid(np.arange(shape0), np.arange(shape0),
np.arange(shape0), indexing='ij')
elif dims==2:
inds = np.meshgrid(np.arange(shape0), np.arange(shape0),
indexing='ij')
for i in range(dims):
p[i, ~iscell] = inds[i][~iscell]

for i in range(dims):
pflows.append(p[i].flatten().astype('int32'))

h,_ = np.histogramdd(tuple(pflows), bins=edges)
hmax = h.copy()
for i in range(dims):
hmax = maximum_filter1d(hmax, 5, axis=i)

seeds = np.nonzero(np.logical_and(h-hmax>-1e-6, h>10))
Nmax = h[seeds]
isort = np.argsort(Nmax)[::-1]
for s in seeds:
s = s[isort]

pix = list(np.array(seeds).T)

shape = h.shape
if dims==3:
expand = np.nonzero(np.ones((3,3,3)))
else:
expand = np.nonzero(np.ones((3,3)))
for e in expand:
e = np.expand_dims(e,1)

for iter in range(5):
for k in range(len(pix)):
if iter==0:
pix[k] = list(pix[k])
newpix = []
iin = []
for i,e in enumerate(expand):
epix = e[:,np.newaxis] + np.expand_dims(pix[k][i], 0) - 1
epix = epix.flatten()
iin.append(np.logical_and(epix>=0, epix<shape[i]))
newpix.append(epix)
iin = np.all(tuple(iin), axis=0)
for p in newpix:
p = p[iin]
newpix = tuple(newpix)
igood = h[newpix]>2
for i in range(dims):
pix[k][i] = newpix[i][igood]
if iter==4:
pix[k] = tuple(pix[k])

M = np.zeros(h.shape, np.uint32)
for k in range(len(pix)):
M[pix[k]] = 1+k

for i in range(dims):
M0 = M[tuple(pflows)]

uniq, counts = fastremap.unique(M0, return_counts=True)
big = np.prod(shape0) * 0.4
bigc = uniq[counts > big]
if len(bigc) > 0 and (len(bigc)>1 or bigc!=0):
fastremap.renumber(M0, in_place=True) #convenient to guarantee non-skipped labels
M0 = np.reshape(M0, shape0)
return M0

cellprob_threshold=0.0,
flow_threshold=0.4, interp=True, do_3D=False,
min_size=15, resize=None,
use_gpu=False,device=None):
""" compute masks using dynamics from dP, cellprob, and boundary """

if np.any(cp_mask): #mask at this point is a cell cluster binary map, not labels
if p is None:
p, inds = follow_flows(dP * cp_mask / 5., niter=niter, interp=interp,
use_gpu=use_gpu, device=device)
if inds is None:
dynamics_logger.info('No cell pixels found.')
shape = resize if resize is not None else cellprob.shape
p = np.zeros((len(shape), *shape), np.uint16)

# flow thresholding factored out of get_masks
if not do_3D:
shape0 = p.shape[1:]
if mask.max()>0 and flow_threshold is not None and flow_threshold > 0:
# make sure labels are unique at output of get_masks

if resize is not None:
#if verbose:
#    dynamics_logger.info(f'resizing output with resize = {resize}')
recast = True
else:
recast = False
if recast:

else: # nothing to compute, just make it compatible
dynamics_logger.info('No cell pixels found.')
shape = resize if resize is not None else cellprob.shape
p = np.zeros((len(shape), *shape), np.uint16)