Source code for cellpose.resnet_torch


import os, sys, time, shutil, tempfile, datetime, pathlib, subprocess
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import datetime


from . import transforms, io, dynamics, utils

# Module-level size constant (presumably the default 3x3 convolution kernel
# size -- TODO confirm against callers). Every function and class visible in
# this file takes its own `sz` argument, which shadows this name inside them.
sz = 3

def convbatchrelu(in_channels, out_channels, sz):
    """Build a Conv2d -> BatchNorm2d -> ReLU stack.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        sz: square kernel size; the convolution is padded by ``sz // 2``.

    Returns:
        An ``nn.Sequential`` holding the three layers at indices 0, 1 and 2.
    """
    conv = nn.Conv2d(in_channels, out_channels, sz, padding=sz // 2)
    norm = nn.BatchNorm2d(out_channels, eps=1e-5)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)

def batchconv(in_channels, out_channels, sz):
    """Build a pre-activation BatchNorm2d -> ReLU -> Conv2d stack.

    Args:
        in_channels: number of channels normalised and fed to the convolution.
        out_channels: number of channels produced by the convolution.
        sz: square kernel size; the convolution is padded by ``sz // 2``.

    Returns:
        An ``nn.Sequential`` holding the three layers at indices 0, 1 and 2.
    """
    layers = [
        nn.BatchNorm2d(in_channels, eps=1e-5),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels, out_channels, sz, padding=sz // 2),
    ]
    return nn.Sequential(*layers)

def batchconv0(in_channels, out_channels, sz):
    """Build a BatchNorm2d -> Conv2d stack with no non-linearity in between.

    Args:
        in_channels: number of channels normalised and fed to the convolution.
        out_channels: number of channels produced by the convolution.
        sz: square kernel size; the convolution is padded by ``sz // 2``.

    Returns:
        An ``nn.Sequential`` holding the two layers at indices 0 and 1.
    """
    block = nn.Sequential()
    # Names '0' and '1' are what nn.Sequential assigns to positional layers.
    block.add_module('0', nn.BatchNorm2d(in_channels, eps=1e-5))
    block.add_module('1', nn.Conv2d(in_channels, out_channels, sz, padding=sz // 2))
    return block

class resdown(nn.Module):
    """Residual encoder block made of four ``batchconv`` layers.

    The layers are applied in two pairs. The first pair is added to a 1x1
    ``batchconv0`` projection of the input; the second pair is added to the
    result of that sum.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of every tensor produced inside the block.
        sz: kernel size handed to each ``batchconv``.
    """

    def __init__(self, in_channels, out_channels, sz):
        super().__init__()
        self.conv = nn.Sequential()
        self.proj = batchconv0(in_channels, out_channels, 1)
        # Only the first convolution sees the input width; the other three
        # map out_channels -> out_channels.
        widths = [in_channels, out_channels, out_channels, out_channels]
        for idx, width in enumerate(widths):
            self.conv.add_module('conv_%d' % idx, batchconv(width, out_channels, sz))

    def forward(self, x):
        """Apply both residual stages to ``x`` and return the result."""
        shortcut = self.proj(x)
        x = shortcut + self.conv[1](self.conv[0](x))
        refined = self.conv[3](self.conv[2](x))
        return x + refined
class convdown(nn.Module):
    """Plain (non-residual) encoder block: two ``batchconv`` layers in series.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        sz: kernel size handed to each ``batchconv``.
    """

    def __init__(self, in_channels, out_channels, sz):
        super().__init__()
        self.conv = nn.Sequential()
        # First layer changes the channel count, second keeps it.
        for idx, width in enumerate((in_channels, out_channels)):
            self.conv.add_module('conv_%d' % idx, batchconv(width, out_channels, sz))

    def forward(self, x):
        """Run ``x`` through conv_0 and then conv_1."""
        return self.conv[1](self.conv[0](x))
class downsample(nn.Module):
    """Encoder: a chain of down blocks with 2x2 max-pooling between them.

    Args:
        nbase: channel widths; block ``n`` maps ``nbase[n]`` to
            ``nbase[n + 1]``, giving ``len(nbase) - 1`` blocks.
        sz: kernel size handed to every block.
        residual_on: use ``resdown`` blocks when true, ``convdown`` otherwise.
    """

    def __init__(self, nbase, sz, residual_on=True):
        super().__init__()
        self.down = nn.Sequential()
        self.maxpool = nn.MaxPool2d(2, 2)
        if residual_on:
            name_fmt, block_cls = 'res_down_%d', resdown
        else:
            name_fmt, block_cls = 'conv_down_%d', convdown
        for n in range(len(nbase) - 1):
            self.down.add_module(name_fmt % n, block_cls(nbase[n], nbase[n + 1], sz))

    def forward(self, x):
        """Return the list of outputs of every block, first block first.

        The first block receives ``x`` directly; each later block receives
        the max-pooled output of the block before it.
        """
        xd = []
        for n, block in enumerate(self.down):
            y = x if n == 0 else self.maxpool(xd[-1])
            xd.append(block(y))
        return xd
class batchconvstyle(nn.Module):
    """``batchconv`` layer whose input is shifted by a projected style vector.

    A linear layer maps the style vector to one value per channel; that
    value is added to every spatial position of the input before the
    ``batchconv`` is applied.

    Args:
        in_channels: channels of ``x`` (doubled internally when
            ``concatenation`` is true).
        out_channels: channels produced by the convolution.
        style_channels: length of the style vector.
        sz: kernel size handed to ``batchconv``.
        concatenation: when true, a skip tensor ``y`` is concatenated to
            ``x`` along dim 1 instead of being summed with it.

    NOTE(review): the style projection has ``out_channels`` (or twice that)
    outputs yet is added to the convolution *input*, so the shapes only line
    up when ``in_channels == out_channels`` -- which is how every call site
    visible in this file constructs it.
    """

    def __init__(self, in_channels, out_channels, style_channels, sz, concatenation=False):
        super().__init__()
        self.concatenation = concatenation
        # Concatenating the skip tensor doubles the width seen by both the
        # convolution and the style projection.
        width_factor = 2 if concatenation else 1
        self.conv = batchconv(in_channels * width_factor, out_channels, sz)
        self.full = nn.Linear(style_channels, out_channels * width_factor)

    def forward(self, style, x, mkldnn=False, y=None):
        """Merge optional skip ``y`` into ``x``, add the style, convolve.

        Args:
            style: style vector fed to the linear projection.
            x: feature map.
            mkldnn: when true, ``x`` is converted to dense for the style
                addition and the sum is converted back with ``to_mkldnn()``.
            y: optional skip tensor, concatenated or summed with ``x``
                depending on ``self.concatenation``.

        Returns:
            Output of the ``batchconv`` layer.
        """
        if y is not None:
            x = torch.cat((y, x), dim=1) if self.concatenation else x + y
        # Two trailing singleton dims let the per-channel value broadcast
        # over the spatial dims.
        shift = self.full(style).unsqueeze(-1).unsqueeze(-1)
        if mkldnn:
            shifted = (x.to_dense() + shift).to_mkldnn()
        else:
            shifted = x + shift
        return self.conv(shifted)
class resup(nn.Module):
    """Residual decoder block with style conditioning.

    Holds one ``batchconv`` (conv_0), three ``batchconvstyle`` layers
    (conv_1..conv_3) and a 1x1 ``batchconv0`` projection used as shortcut.
    Only conv_1 receives the skip tensor and the ``concatenation`` flag.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of every tensor produced inside the block.
        style_channels: length of the style vector.
        sz: kernel size handed to the convolution layers.
        concatenation: forwarded to conv_1 only.
    """

    def __init__(self, in_channels, out_channels, style_channels, sz, concatenation=False):
        super().__init__()
        self.conv = nn.Sequential()
        self.conv.add_module('conv_0', batchconv(in_channels, out_channels, sz))
        for idx in (1, 2, 3):
            # conv_2 and conv_3 never merge a skip tensor, so they always
            # use the summing (non-concatenating) configuration.
            use_concat = concatenation if idx == 1 else False
            self.conv.add_module(
                'conv_%d' % idx,
                batchconvstyle(out_channels, out_channels, style_channels, sz,
                               concatenation=use_concat),
            )
        self.proj = batchconv0(in_channels, out_channels, 1)

    def forward(self, x, y, style, mkldnn=False):
        """Apply both residual stages.

        Args:
            x: input feature map.
            y: skip tensor merged inside conv_1.
            style: style vector given to every ``batchconvstyle`` layer.
            mkldnn: forwarded to every ``batchconvstyle`` layer.
        """
        shortcut = self.proj(x)
        x = shortcut + self.conv[1](style, self.conv[0](x), y=y, mkldnn=mkldnn)
        inner = self.conv[2](style, x, mkldnn=mkldnn)
        x = x + self.conv[3](style, inner, mkldnn=mkldnn)
        return x
class convup(nn.Module):
    """Plain (non-residual) decoder block with style conditioning.

    Holds one ``batchconv`` (conv_0) followed by one ``batchconvstyle``
    (conv_1) that merges the skip tensor and adds the style.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both layers.
        style_channels: length of the style vector.
        sz: kernel size handed to the convolution layers.
        concatenation: forwarded to conv_1; selects concatenating the skip
            tensor instead of summing it.
    """

    def __init__(self, in_channels, out_channels, style_channels, sz, concatenation=False):
        super().__init__()
        self.conv = nn.Sequential()
        self.conv.add_module('conv_0', batchconv(in_channels, out_channels, sz))
        self.conv.add_module(
            'conv_1',
            batchconvstyle(out_channels, out_channels, style_channels, sz,
                           concatenation=concatenation),
        )

    def forward(self, x, y, style, mkldnn=False):
        """Run conv_0 on ``x``, then conv_1 with skip ``y`` and ``style``.

        Args:
            x: input feature map.
            y: skip tensor merged inside conv_1.
            style: style vector.
            mkldnn: forwarded to conv_1 so that it performs the
                dense/MKLDNN conversions around the style addition.

        Returns:
            Output of conv_1.
        """
        # The flag used to be accepted but dropped here, so on the MKLDNN
        # path conv_1 tried to add a dense style tensor to an MKLDNN tensor.
        # Forwarding it matches what resup.forward does.
        x = self.conv[1](style, self.conv[0](x), y=y, mkldnn=mkldnn)
        return x
class make_style(nn.Module):
    """Turn a feature map into a unit-length per-sample style vector."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()

    def forward(self, x0):
        """Average-pool ``x0`` over its last two dims and L2-normalise.

        The pooling window is taken from the input's own last two dims, so
        any spatial size is accepted. No epsilon is added to the norm, so an
        all-zero pooled vector divides by zero.

        Args:
            x0: tensor whose last two dims are pooled away.

        Returns:
            The flattened pooled tensor divided by its Euclidean norm along
            dim 1.
        """
        height, width = x0.shape[-2], x0.shape[-1]
        pooled = F.avg_pool2d(x0, kernel_size=(height, width))
        flat = self.flatten(pooled)
        norm = torch.sum(flat ** 2, axis=1, keepdim=True) ** .5
        return flat / norm
class upsample(nn.Module):
    """Decoder: a chain of up blocks with 2x nearest-neighbour upsampling.

    Args:
        nbase: channel widths; block ``k`` maps ``nbase[k + 1]`` to
            ``nbase[k]`` and uses ``nbase[-1]`` as the style-vector length.
        sz: kernel size handed to every block.
        residual_on: use ``resup`` blocks when true, ``convup`` otherwise.
        concatenation: forwarded to every block.
    """

    def __init__(self, nbase, sz, residual_on=True, concatenation=False):
        super().__init__()
        self.upsampling = nn.Upsample(scale_factor=2, mode='nearest')
        self.up = nn.Sequential()
        if residual_on:
            name_fmt, block_cls = 'res_up_%d', resup
        else:
            name_fmt, block_cls = 'conv_up_%d', convup
        for k in range(len(nbase) - 1):
            self.up.add_module(
                name_fmt % k,
                block_cls(nbase[k + 1], nbase[k], nbase[-1], sz, concatenation),
            )

    def forward(self, style, xd, mkldnn=False):
        """Decode the encoder outputs ``xd`` from last entry to first.

        Args:
            style: style vector handed to every block.
            xd: list of encoder outputs; ``xd[n]`` is the skip tensor of
                block ``n``.
            mkldnn: when true, the upsampling runs on a dense copy and its
                result is converted back with ``to_mkldnn()``.

        Returns:
            Output of block 0.
        """
        # The last block gets the last encoder output as both input and skip.
        x = self.up[-1](xd[-1], xd[-1], style, mkldnn=mkldnn)
        for n in reversed(range(len(self.up) - 1)):
            if mkldnn:
                x = self.upsampling(x.to_dense()).to_mkldnn()
            else:
                x = self.upsampling(x)
            x = self.up[n](x, xd[n], style, mkldnn=mkldnn)
        return x
class CPnet(nn.Module):
    """Encoder/decoder network with a global style vector.

    ``forward`` runs the ``downsample`` encoder, derives a style vector from
    the last encoder output with ``make_style``, runs the ``upsample``
    decoder conditioned on that style, and applies a 1x1 ``batchconv``
    output layer.

    Args:
        nbase: sequence of channel widths. ``nbase[0]`` is the input width;
            the decoder uses ``nbase[1:]`` with its last entry repeated.
        nout: number of output channels.
        sz: kernel size handed to the encoder and decoder blocks.
        residual_on: select residual blocks in encoder and decoder.
        style_on: when false, the style passed to the decoder is zeroed
            (the unzeroed style is still returned).
        concatenation: forwarded to the decoder blocks.
        mkldnn: when true, tensors are converted with ``to_mkldnn()`` /
            ``to_dense()`` around the layers; ``None`` is treated as False.
        diam_mean: initial value of the non-trainable ``diam_mean`` and
            ``diam_labels`` parameters.
    """

    def __init__(self, nbase, nout, sz, residual_on=True, style_on=True,
                 concatenation=False, mkldnn=False, diam_mean=30.):
        super().__init__()
        # The constructor arguments are kept so load_model can rebuild the
        # network on the CPU path.
        self.nbase = nbase
        self.nout = nout
        self.sz = sz
        self.residual_on = residual_on
        self.style_on = style_on
        self.concatenation = concatenation
        self.mkldnn = mkldnn if mkldnn is not None else False
        self.downsample = downsample(nbase, sz, residual_on=residual_on)
        # list(...) makes a private, appendable copy, so tuples are accepted
        # and the caller's sequence is never mutated.
        nbaseup = list(nbase[1:])
        nbaseup.append(nbaseup[-1])
        self.upsample = upsample(nbaseup, sz, residual_on=residual_on,
                                 concatenation=concatenation)
        self.make_style = make_style()
        self.output = batchconv(nbaseup[0], nout, 1)
        # Stored as frozen parameters so they are part of the state_dict.
        self.diam_mean = nn.Parameter(data=torch.ones(1) * diam_mean,
                                      requires_grad=False)
        self.diam_labels = nn.Parameter(data=torch.ones(1) * diam_mean,
                                        requires_grad=False)

    def forward(self, data):
        """Run the network on ``data``.

        Returns:
            Tuple ``(T0, style0)``: the output-layer result and the style
            vector as computed, i.e. before any zeroing for
            ``style_on=False``.
        """
        if self.mkldnn:
            data = data.to_mkldnn()
        T0 = self.downsample(data)
        # make_style always runs on a dense tensor.
        if self.mkldnn:
            style = self.make_style(T0[-1].to_dense())
        else:
            style = self.make_style(T0[-1])
        style0 = style
        if not self.style_on:
            style = style * 0
        T0 = self.upsample(style, T0, self.mkldnn)
        T0 = self.output(T0)
        if self.mkldnn:
            T0 = T0.to_dense()
        return T0, style0

    def save_model(self, filename):
        """Write this module's ``state_dict`` to ``filename``."""
        torch.save(self.state_dict(), filename)

    def load_model(self, filename, device=None):
        """Load weights from ``filename`` into this module.

        Args:
            filename: path of a file written by ``torch.save``.
            device: target ``torch.device``. When it is ``None`` or a CPU
                device, the module is first re-initialised with its stored
                constructor arguments and the file is mapped to the CPU.

        Keys missing from or unexpected in the file are ignored
        (``strict=False``).
        """
        # SECURITY: torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        if (device is not None) and (device.type != 'cpu'):
            state_dict = torch.load(filename, map_location=device)
        else:
            self.__init__(self.nbase, self.nout, self.sz, self.residual_on,
                          self.style_on, self.concatenation, self.mkldnn,
                          self.diam_mean)
            state_dict = torch.load(filename, map_location=torch.device('cpu'))
        # Copy into a plain dict, as before, rather than passing the loaded
        # object straight through.
        weights = {name: param for name, param in state_dict.items()}
        self.load_state_dict(weights, strict=False)