import time
import os
import numpy as np
from cellpose import io, transforms, utils, models, dynamics, metrics, resnet_torch
from cellpose.transforms import normalize_img
from pathlib import Path
import torch
from torch import nn
from tqdm import trange
from numba import prange
import logging
train_logger = logging.getLogger(__name__)
def _loss_fn_seg(lbl, y, device):
"""
Calculates the loss function between true labels lbl and prediction y.
Args:
lbl (numpy.ndarray): True labels (cellprob, flowsY, flowsX).
y (torch.Tensor): Predicted values (flowsY, flowsX, cellprob).
device (torch.device): Device on which the tensors are located.
Returns:
torch.Tensor: Loss value.
"""
criterion = nn.MSELoss(reduction="mean")
criterion2 = nn.BCEWithLogitsLoss(reduction="mean")
veci = 5. * torch.from_numpy(lbl[:, 1:]).to(device)
loss = criterion(y[:, :2], veci)
loss /= 2.
loss2 = criterion2(y[:, -1], torch.from_numpy(lbl[:, 0] > 0.5).to(device).float())
loss = loss + loss2
return loss
def _get_batch(inds, data=None, labels=None, files=None, labels_files=None,
channels=None, channel_axis=None, rgb=False,
normalize_params={"normalize": False}):
"""
Get a batch of images and labels.
Args:
inds (list): List of indices indicating which images and labels to retrieve.
data (list or None): List of image data. If None, images will be loaded from files.
labels (list or None): List of label data. If None, labels will be loaded from files.
files (list or None): List of file paths for images.
labels_files (list or None): List of file paths for labels.
channels (list or None): List of channel indices to extract from images.
channel_axis (int or None): Axis along which the channels are located.
normalize_params (dict): Dictionary of parameters for image normalization (will be faster, if loading from files to pre-normalize).
Returns:
tuple: A tuple containing two lists: the batch of images and the batch of labels.
"""
if data is None:
lbls = None
imgs = [io.imread(files[i]) for i in inds]
imgs = _reshape_norm(imgs, channels=channels, channel_axis=channel_axis, rgb=rgb,
normalize_params=normalize_params)
if labels_files is not None:
lbls = [io.imread(labels_files[i])[1:] for i in inds]
else:
imgs = [data[i] for i in inds]
lbls = [labels[i][1:] for i in inds]
return imgs, lbls
def pad_to_rgb(img):
if img.ndim==2 or np.ptp(img[1]) < 1e-3:
if img.ndim==2:
img = img[np.newaxis,:,:]
img = np.tile(img[:1], (3,1,1))
elif img.shape[0] < 3:
nc, Ly, Lx = img.shape
# randomly flip channels
if np.random.rand() > 0.5:
img = img[::-1]
# randomly insert blank channel
ic = np.random.randint(3)
img = np.insert(img, ic, np.zeros((3-nc, Ly, Lx), dtype=img.dtype), axis=0)
return img
def convert_to_rgb(img):
if img.ndim==2:
img = img[np.newaxis,:,:]
img = np.tile(img, (3,1,1))
elif img.shape[0] < 3:
img = img.mean(axis=0, keepdims=True)
img = transforms.normalize99(img)
img = np.tile(img, (3,1,1))
return img
def _reshape_norm(data, channels=None, channel_axis=None, rgb=False,
normalize_params={"normalize": False}):
"""
Reshapes and normalizes the input data.
Args:
data (list): List of input data.
channels (int or list, optional): Number of channels or list of channel indices to keep. Defaults to None.
channel_axis (int, optional): Axis along which the channels are located. Defaults to None.
normalize_params (dict, optional): Dictionary of normalization parameters. Defaults to {"normalize": False}.
Returns:
list: List of reshaped and normalized data.
"""
if channels is not None or channel_axis is not None:
data = [
transforms.convert_image(td, channels=channels, channel_axis=channel_axis)
for td in data
]
data = [td.transpose(2, 0, 1) for td in data]
if normalize_params["normalize"]:
data = [
transforms.normalize_img(td, normalize=normalize_params, axis=0)
for td in data
]
if rgb:
data = [pad_to_rgb(td) for td in data]
return data
def _reshape_norm_save(files, channels=None, channel_axis=None,
normalize_params={"normalize": False}):
""" not currently used -- normalization happening on each batch if not load_files """
files_new = []
for f in trange(files):
td = io.imread(f)
if channels is not None:
td = transforms.convert_image(td, channels=channels,
channel_axis=channel_axis)
td = td.transpose(2, 0, 1)
if normalize_params["normalize"]:
td = transforms.normalize_img(td, normalize=normalize_params, axis=0)
fnew = os.path.splitext(str(f))[0] + "_cpnorm.tif"
io.imsave(fnew, td)
files_new.append(fnew)
return files_new
# else:
# train_files = reshape_norm_save(train_files, channels=channels,
# channel_axis=channel_axis, normalize_params=normalize_params)
# elif test_files is not None:
# test_files = reshape_norm_save(test_files, channels=channels,
# channel_axis=channel_axis, normalize_params=normalize_params)
def _process_train_test(train_data=None, train_labels=None, train_files=None,
train_labels_files=None, train_probs=None, test_data=None,
test_labels=None, test_files=None, test_labels_files=None,
test_probs=None, load_files=True, min_train_masks=5,
compute_flows=False, channels=None, channel_axis=None,
rgb=False, normalize_params={"normalize": False},
device=torch.device("cuda")):
"""
Process train and test data.
Args:
train_data (list or None): List of training data arrays.
train_labels (list or None): List of training label arrays.
train_files (list or None): List of training file paths.
#train_labels_files (list or None): List of training label file paths.
train_probs (ndarray or None): Array of training probabilities.
test_data (list or None): List of test data arrays.
test_labels (list or None): List of test label arrays.
test_files (list or None): List of test file paths.
#test_labels_files (list or None): List of test label file paths.
test_probs (ndarray or None): Array of test probabilities.
load_files (bool): Whether to load data from files.
min_train_masks (int): Minimum number of masks required for training images.
compute_flows (bool): Whether to compute flows.
channels (list or None): List of channel indices to use.
channel_axis (int or None): Axis of channel dimension.
rgb (bool): Convert training/testing images to RGB.
normalize_params (dict): Dictionary of normalization parameters.
device (torch.device): Device to use for computation.
Returns:
tuple: A tuple containing the processed train and test data and sampling probabilities and diameters.
"""
if train_data is not None and train_labels is not None:
# if data is loaded
nimg = len(train_data)
nimg_test = len(test_data) if test_data is not None else None
else:
# otherwise use files
nimg = len(train_files)
if train_labels_files is None:
train_labels_files = [
os.path.splitext(str(tf))[0] + "_flows.tif" for tf in train_files
]
train_labels_files = [tf for tf in train_labels_files if os.path.exists(tf)]
if (test_data is not None or test_files is not None) and test_labels_files is None:
test_labels_files = [
os.path.splitext(str(tf))[0] + "_flows.tif" for tf in test_files
]
test_labels_files = [tf for tf in test_labels_files if os.path.exists(tf)]
if not load_files:
train_logger.info(">>> using files instead of loading dataset")
else:
# load all images
train_logger.info(">>> loading images and labels")
train_data = [io.imread(train_files[i]) for i in trange(nimg)]
train_labels = [io.imread(train_labels_files[i]) for i in trange(nimg)]
nimg_test = len(test_files) if test_files is not None else None
if load_files and nimg_test:
test_data = [io.imread(test_files[i]) for i in trange(nimg_test)]
test_labels = [io.imread(test_labels_files[i]) for i in trange(nimg_test)]
### check that arrays are correct size
if ((train_labels is not None and nimg != len(train_labels)) or
(train_labels_files is not None and nimg != len(train_labels_files))):
error_message = "train data and labels not same length"
train_logger.critical(error_message)
raise ValueError(error_message)
if ((test_labels is not None and nimg_test != len(test_labels)) or
(test_labels_files is not None and nimg_test != len(test_labels_files))):
train_logger.warning("test data and labels not same length, not using")
test_data, test_files = None, None
if train_labels is not None:
if train_labels[0].ndim < 2 or train_data[0].ndim < 2:
error_message = "training data or labels are not at least two-dimensional"
train_logger.critical(error_message)
raise ValueError(error_message)
if train_data[0].ndim > 3:
error_message = "training data is more than three-dimensional (should be 2D or 3D array)"
train_logger.critical(error_message)
raise ValueError(error_message)
### check that flows are computed
if train_labels is not None:
train_labels = dynamics.labels_to_flows(train_labels, files=train_files,
device=device)
if test_labels is not None:
test_labels = dynamics.labels_to_flows(test_labels, files=test_files,
device=device)
elif compute_flows:
for k in trange(nimg):
tl = dynamics.labels_to_flows(io.imread(train_labels_files),
files=train_files, device=device)
if test_files is not None:
for k in trange(nimg_test):
tl = dynamics.labels_to_flows(io.imread(test_labels_files),
files=test_files, device=device)
### compute diameters
nmasks = np.zeros(nimg)
diam_train = np.zeros(nimg)
train_logger.info(">>> computing diameters")
for k in trange(nimg):
tl = (train_labels[k][0]
if train_labels is not None else io.imread(train_labels_files[k])[0])
diam_train[k], dall = utils.diameters(tl)
nmasks[k] = len(dall)
diam_train[diam_train < 5] = 5.
if test_data is not None:
diam_test = np.array(
[utils.diameters(test_labels[k][0])[0] for k in trange(len(test_labels))])
diam_test[diam_test < 5] = 5.
elif test_labels_files is not None:
diam_test = np.array([
utils.diameters(io.imread(test_labels_files[k])[0])[0]
for k in trange(len(test_labels_files))
])
diam_test[diam_test < 5] = 5.
else:
diam_test = None
### check to remove training images with too few masks
if min_train_masks > 0:
nremove = (nmasks < min_train_masks).sum()
if nremove > 0:
train_logger.warning(
f"{nremove} train images with number of masks less than min_train_masks ({min_train_masks}), removing from train set"
)
ikeep = np.nonzero(nmasks >= min_train_masks)[0]
if train_data is not None:
train_data = [train_data[i] for i in ikeep]
train_labels = [train_labels[i] for i in ikeep]
if train_files is not None:
train_files = [train_files[i] for i in ikeep]
if train_labels_files is not None:
train_labels_files = [train_labels_files[i] for i in ikeep]
if train_probs is not None:
train_probs = train_probs[ikeep]
diam_train = diam_train[ikeep]
### normalize probabilities
train_probs = 1. / nimg * np.ones(nimg,
"float64") if train_probs is None else train_probs
train_probs /= train_probs.sum()
if test_files is not None or test_data is not None:
test_probs = 1. / nimg_test * np.ones(
nimg_test, "float64") if test_probs is None else test_probs
test_probs /= test_probs.sum()
### reshape and normalize train / test data
normed = False
if channels is not None or normalize_params["normalize"]:
if channels:
train_logger.info(f">>> using channels {channels}")
if normalize_params["normalize"]:
train_logger.info(f">>> normalizing {normalize_params}")
if train_data is not None:
train_data = _reshape_norm(train_data, channels=channels,
channel_axis=channel_axis, rgb=rgb,
normalize_params=normalize_params)
normed = True
if test_data is not None:
test_data = _reshape_norm(test_data, channels=channels,
channel_axis=channel_axis, rgb=rgb,
normalize_params=normalize_params)
return (train_data, train_labels, train_files, train_labels_files, train_probs,
diam_train, test_data, test_labels, test_files, test_labels_files,
test_probs, diam_test, normed)
[docs]def train_seg(net, train_data=None, train_labels=None, train_files=None,
train_labels_files=None, train_probs=None, test_data=None,
test_labels=None, test_files=None, test_labels_files=None,
test_probs=None, load_files=True, batch_size=8, learning_rate=0.005,
n_epochs=2000, weight_decay=1e-5, momentum=0.9, SGD=False, channels=None,
channel_axis=None, rgb=False, normalize=True,
compute_flows=False, save_path=None,
save_every=100, nimg_per_epoch=None, nimg_test_per_epoch=None,
rescale=True, scale_range=None, bsize=224,
min_train_masks=5, model_name=None):
"""
Train the network with images for segmentation.
Args:
net (object): The network model to train.
train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for training. Defaults to None.
train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None.
train_files (List[str], optional): List of strings - file names for images in train_data (to save flows for future runs). Defaults to None.
train_labels_files (list or None): List of training label file paths. Defaults to None.
train_probs (List[float], optional): List of floats - probabilities for each image to be selected during training. Defaults to None.
test_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for testing. Defaults to None.
test_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for test_data, where 0=no masks; 1,2,...=mask labels. Defaults to None.
test_files (List[str], optional): List of strings - file names for images in test_data (to save flows for future runs). Defaults to None.
test_labels_files (list or None): List of test label file paths. Defaults to None.
test_probs (List[float], optional): List of floats - probabilities for each image to be selected during testing. Defaults to None.
load_files (bool, optional): Boolean - whether to load images and labels from files. Defaults to True.
batch_size (int, optional): Integer - number of patches to run simultaneously on the GPU. Defaults to 8.
learning_rate (float or List[float], optional): Float or list/np.ndarray - learning rate for training. Defaults to 0.005.
n_epochs (int, optional): Integer - number of times to go through the whole training set during training. Defaults to 2000.
weight_decay (float, optional): Float - weight decay for the optimizer. Defaults to 1e-5.
momentum (float, optional): Float - momentum for the optimizer. Defaults to 0.9.
SGD (bool, optional): Boolean - whether to use SGD as optimization instead of RAdam. Defaults to False.
channels (List[int], optional): List of ints - channels to use for training. Defaults to None.
channel_axis (int, optional): Integer - axis of the channel dimension in the input data. Defaults to None.
normalize (bool or dict, optional): Boolean or dictionary - whether to normalize the data. Defaults to True.
compute_flows (bool, optional): Boolean - whether to compute flows during training. Defaults to False.
save_path (str, optional): String - where to save the trained model. Defaults to None.
save_every (int, optional): Integer - save the network every [save_every] epochs. Defaults to 100.
nimg_per_epoch (int, optional): Integer - minimum number of images to train on per epoch. Defaults to None.
nimg_test_per_epoch (int, optional): Integer - minimum number of images to test on per epoch. Defaults to None.
rescale (bool, optional): Boolean - whether or not to rescale images during training. Defaults to True.
min_train_masks (int, optional): Integer - minimum number of masks an image must have to use in the training set. Defaults to 5.
model_name (str, optional): String - name of the network. Defaults to None.
Returns:
Path: path to saved model weights
"""
device = net.device
scale_range0 = 0.5 if rescale else 1.0
scale_range = scale_range if scale_range is not None else scale_range0
if isinstance(normalize, dict):
normalize_params = {**models.normalize_default, **normalize}
elif not isinstance(normalize, bool):
raise ValueError("normalize parameter must be a bool or a dict")
else:
normalize_params = models.normalize_default
normalize_params["normalize"] = normalize
out = _process_train_test(
train_data=train_data, train_labels=train_labels, train_files=train_files,
train_probs=train_probs,
test_data=test_data, test_labels=test_labels, test_files=test_files,
test_probs=test_probs,
load_files=load_files, min_train_masks=min_train_masks,
compute_flows=compute_flows, channels=channels, channel_axis=channel_axis,
rgb=rgb, normalize_params=normalize_params, device=net.device)
(train_data, train_labels, train_files, train_labels_files, train_probs, diam_train,
test_data, test_labels, test_files, test_labels_files, test_probs, diam_test, normed) = out
# already normalized, do not normalize during training
if normed:
kwargs = {}
else:
kwargs = {"normalize_params": normalize_params, "channels": channels,
"channel_axis": channel_axis, "rgb": rgb}
net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device)
nimg = len(train_data) if train_data is not None else len(train_files)
nimg_test = len(test_data) if test_data is not None else None
nimg_test = len(test_files) if test_files is not None else nimg_test
nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch
# learning rate schedule
LR = np.linspace(0, learning_rate, 10)
LR = np.append(LR, learning_rate * np.ones(max(0, n_epochs - 10)))
if n_epochs > 300:
LR = LR[:-100]
for i in range(10):
LR = np.append(LR, LR[-1] / 2 * np.ones(10))
elif n_epochs > 100:
LR = LR[:-50]
for i in range(10):
LR = np.append(LR, LR[-1] / 2 * np.ones(5))
LR = LR
train_logger.info(f">>> n_epochs={n_epochs}, n_train={nimg}, n_test={nimg_test}")
if not SGD:
train_logger.info(
f">>> AdamW, learning_rate={learning_rate:0.5f}, weight_decay={weight_decay:0.5f}"
)
optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate,
weight_decay=weight_decay)
else:
train_logger.info(
f">>> SGD, learning_rate={learning_rate:0.5f}, weight_decay={weight_decay:0.5f}, momentum={momentum:0.3f}"
)
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate,
weight_decay=weight_decay, momentum=momentum)
t0 = time.time()
model_name = f"cellpose_{t0}" if model_name is None else model_name
save_path = Path.cwd() if save_path is None else Path(save_path)
model_path = save_path / "models" / model_name
(save_path / "models").mkdir(exist_ok=True)
train_logger.info(f">>> saving model to {model_path}")
lavg, nsum = 0, 0
for iepoch in range(n_epochs):
np.random.seed(iepoch)
if nimg != nimg_per_epoch:
rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
p=train_probs)
else:
rperm = np.random.permutation(np.arange(0, nimg))
for param_group in optimizer.param_groups:
param_group["lr"] = LR[iepoch]
net.train()
for k in range(0, nimg_per_epoch, batch_size):
kend = min(k + batch_size, nimg)
inds = rperm[k:kend]
imgs, lbls = _get_batch(inds, data=train_data, labels=train_labels,
files=train_files, labels_files=train_labels_files,
**kwargs)
diams = np.array([diam_train[i] for i in inds])
rsc = diams / net.diam_mean.item() if rescale else np.ones(len(diams), "float32")
# augmentations
imgi, lbl = transforms.random_rotate_and_resize(imgs, Y=lbls, rescale=rsc,
scale_range=scale_range, xy=(bsize, bsize))[:2]
X = torch.from_numpy(imgi).to(device)
y = net(X)[0]
loss = _loss_fn_seg(lbl, y, device)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss = loss.item()
train_loss *= len(imgi)
lavg += train_loss
nsum += len(imgi)
if iepoch == 5 or iepoch % 10 == 0:
lavgt = 0.
if test_data is not None or test_files is not None:
np.random.seed(42)
if nimg_test != nimg_test_per_epoch:
rperm = np.random.choice(np.arange(0, nimg_test),
size=(nimg_test_per_epoch,), p=test_probs)
else:
rperm = np.random.permutation(np.arange(0, nimg_test))
for ibatch in range(0, len(rperm), batch_size):
with torch.no_grad():
net.eval()
inds = rperm[ibatch:ibatch + batch_size]
imgs, lbls = _get_batch(inds, data=test_data,
labels=test_labels, files=test_files,
labels_files=test_labels_files,
**kwargs)
diams = np.array([diam_test[i] for i in inds])
rsc = diams / net.diam_mean.item() if rescale else np.ones(len(diams), "float32")
imgi, lbl = transforms.random_rotate_and_resize(
imgs, Y=lbls, rescale=rsc, scale_range=scale_range,
xy=(bsize, bsize))[:2]
X = torch.from_numpy(imgi).to(device)
y = net(X)[0]
loss = _loss_fn_seg(lbl, y, device)
test_loss = loss.item()
test_loss *= len(imgi)
lavgt += test_loss
lavgt /= len(rperm)
lavg /= nsum
train_logger.info(
f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.4f}, time {time.time()-t0:.2f}s"
)
lavg, nsum = 0, 0
if iepoch > 0 and iepoch % save_every == 0:
net.save_model(model_path)
net.save_model(model_path)
return model_path
[docs]def train_size(net, pretrained_model, train_data=None, train_labels=None,
train_files=None, train_labels_files=None, train_probs=None,
test_data=None, test_labels=None, test_files=None,
test_labels_files=None, test_probs=None, load_files=True,
min_train_masks=5, channels=None, channel_axis=None, rgb=False,
normalize=True, nimg_per_epoch=None, nimg_test_per_epoch=None,
batch_size=64, scale_range=1.0, bsize=512,
l2_regularization=1.0, n_epochs=10):
"""Train the size model.
Args:
net (object): The neural network model.
pretrained_model (str): The path to the pretrained model.
train_data (numpy.ndarray, optional): The training data. Defaults to None.
train_labels (numpy.ndarray, optional): The training labels. Defaults to None.
train_files (list, optional): The training file paths. Defaults to None.
train_labels_files (list, optional): The training label file paths. Defaults to None.
train_probs (numpy.ndarray, optional): The training probabilities. Defaults to None.
test_data (numpy.ndarray, optional): The test data. Defaults to None.
test_labels (numpy.ndarray, optional): The test labels. Defaults to None.
test_files (list, optional): The test file paths. Defaults to None.
test_labels_files (list, optional): The test label file paths. Defaults to None.
test_probs (numpy.ndarray, optional): The test probabilities. Defaults to None.
load_files (bool, optional): Whether to load files. Defaults to True.
min_train_masks (int, optional): The minimum number of training masks. Defaults to 5.
channels (list, optional): The channels. Defaults to None.
channel_axis (int, optional): The channel axis. Defaults to None.
normalize (bool or dict, optional): Whether to normalize the data. Defaults to True.
nimg_per_epoch (int, optional): The number of images per epoch. Defaults to None.
nimg_test_per_epoch (int, optional): The number of test images per epoch. Defaults to None.
batch_size (int, optional): The batch size. Defaults to 64.
l2_regularization (float, optional): The L2 regularization factor. Defaults to 1.0.
n_epochs (int, optional): The number of epochs. Defaults to 10.
Returns:
dict: The trained size model parameters.
"""
if isinstance(normalize, dict):
normalize_params = {**models.normalize_default, **normalize}
elif not isinstance(normalize, bool):
raise ValueError("normalize parameter must be a bool or a dict")
else:
normalize_params = models.normalize_default
normalize_params["normalize"] = normalize
out = _process_train_test(
train_data=train_data, train_labels=train_labels, train_files=train_files,
train_labels_files=train_labels_files, train_probs=train_probs,
test_data=test_data, test_labels=test_labels, test_files=test_files,
test_labels_files=test_labels_files, test_probs=test_probs,
load_files=load_files, min_train_masks=min_train_masks, compute_flows=False,
channels=channels, channel_axis=channel_axis, normalize_params=normalize_params,
device=net.device)
(train_data, train_labels, train_files, train_labels_files, train_probs, diam_train,
test_data, test_labels, test_files, test_labels_files, test_probs, diam_test, normed) = out
# already normalized, do not normalize during training
if normed:
kwargs = {}
else:
kwargs = {"normalize_params": normalize_params, "channels": channels,
"channel_axis": channel_axis, "rgb": rgb}
nimg = len(train_data) if train_data is not None else len(train_files)
nimg_test = len(test_data) if test_data is not None else None
nimg_test = len(test_files) if test_files is not None else nimg_test
nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch
diam_mean = net.diam_mean.item()
device = net.device
net.eval()
styles = np.zeros((n_epochs * nimg_per_epoch, 256), np.float32)
diams = np.zeros((n_epochs * nimg_per_epoch,), np.float32)
tic = time.time()
for iepoch in range(n_epochs):
np.random.seed(iepoch)
if nimg != nimg_per_epoch:
rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
p=train_probs)
else:
rperm = np.random.permutation(np.arange(0, nimg))
for ibatch in range(0, nimg_per_epoch, batch_size):
inds_batch = np.arange(ibatch, min(nimg_per_epoch, ibatch + batch_size))
inds = rperm[inds_batch]
imgs, lbls = _get_batch(inds, data=train_data, labels=train_labels,
files=train_files,
**kwargs)
diami = diam_train[inds].copy()
imgi, lbl, scale = transforms.random_rotate_and_resize(
imgs, scale_range=scale_range, xy=(bsize, bsize))
imgi = torch.from_numpy(imgi).to(device)
with torch.no_grad():
feat = net(imgi)[1]
indsi = inds_batch + nimg_per_epoch * iepoch
styles[indsi] = feat.cpu().numpy()
diams[indsi] = np.log(diami) - np.log(diam_mean) + np.log(scale)
del feat
train_logger.info("ran %d epochs in %0.3f sec" %
(iepoch + 1, time.time() - tic))
l2_regularization = 1.
# create model
smean = styles.copy().mean(axis=0)
X = ((styles.copy() - smean).T).copy()
ymean = diams.copy().mean()
y = diams.copy() - ymean
A = np.linalg.solve(X @ X.T + l2_regularization * np.eye(X.shape[0]), X @ y)
ypred = A @ X
train_logger.info("train correlation: %0.4f" % np.corrcoef(y, ypred)[0, 1])
if nimg_test:
np.random.seed(0)
styles_test = np.zeros((nimg_test_per_epoch, 256), np.float32)
diams_test = np.zeros((nimg_test_per_epoch,), np.float32)
diams_test0 = np.zeros((nimg_test_per_epoch,), np.float32)
if nimg_test != nimg_test_per_epoch:
rperm = np.random.choice(np.arange(0, nimg_test),
size=(nimg_test_per_epoch,), p=test_probs)
else:
rperm = np.random.permutation(np.arange(0, nimg_test))
for ibatch in range(0, nimg_test_per_epoch, batch_size):
inds_batch = np.arange(ibatch, min(nimg_test_per_epoch,
ibatch + batch_size))
inds = rperm[inds_batch]
imgs, lbls = _get_batch(inds, data=test_data, labels=test_labels,
files=test_files, labels_files=test_labels_files,
**kwargs)
diami = diam_test[inds].copy()
imgi, lbl, scale = transforms.random_rotate_and_resize(
imgs, Y=lbls, scale_range=scale_range, xy=(bsize, bsize))
imgi = torch.from_numpy(imgi).to(device)
diamt = np.array([utils.diameters(lbl0[0])[0] for lbl0 in lbl])
diamt = np.maximum(5., diamt)
with torch.no_grad():
feat = net(imgi)[1]
styles_test[inds_batch] = feat.cpu().numpy()
diams_test[inds_batch] = np.log(diami) - np.log(diam_mean) + np.log(scale)
diams_test0[inds_batch] = diamt
diam_test_pred = np.exp(A @ (styles_test - smean).T + np.log(diam_mean) + ymean)
diam_test_pred = np.maximum(5., diam_test_pred)
train_logger.info("test correlation: %0.4f" %
np.corrcoef(diams_test0, diam_test_pred)[0, 1])
pretrained_size = str(pretrained_model) + "_size.npy"
params = {"A": A, "smean": smean, "diam_mean": diam_mean, "ymean": ymean}
np.save(pretrained_size, params)
train_logger.info("model saved to " + pretrained_size)
return params