"""
Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
"""
import os, sys, time, shutil, tempfile, datetime, pathlib, subprocess
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from . import transforms, io, dynamics, utils
sz = 3  # default convolution kernel size; odd, so padding=sz//2 preserves shape
def convbatchrelu(in_channels, out_channels, sz):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, sz, padding=sz//2),
nn.BatchNorm2d(out_channels, eps=1e-5),
nn.ReLU(inplace=True),
)
def batchconv(in_channels, out_channels, sz):
return nn.Sequential(
nn.BatchNorm2d(in_channels, eps=1e-5),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels, out_channels, sz, padding=sz//2),
)
def batchconv0(in_channels, out_channels, sz):
return nn.Sequential(
nn.BatchNorm2d(in_channels, eps=1e-5),
nn.Conv2d(in_channels, out_channels, sz, padding=sz//2),
)
# residual downsampling block: two residual stages of batchconv layers;
# a 1x1 projection matches channel counts on the first skip connection
class resdown(nn.Module):
def __init__(self, in_channels, out_channels, sz):
super().__init__()
self.conv = nn.Sequential()
self.proj = batchconv0(in_channels, out_channels, 1)
for t in range(4):
if t==0:
self.conv.add_module('conv_%d'%t, batchconv(in_channels, out_channels, sz))
else:
self.conv.add_module('conv_%d'%t, batchconv(out_channels, out_channels, sz))
def forward(self, x):
x = self.proj(x) + self.conv[1](self.conv[0](x))
x = x + self.conv[3](self.conv[2](x))
return x
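# Illustrative sketch (hypothetical helper): resdown changes the channel
# count while preserving spatial size; the 1x1 projection lets the first
# skip connection bridge the 8 -> 32 channel change.
def _example_resdown():
    block = resdown(8, 32, 3)
    y = block(torch.randn(1, 8, 64, 64))
    assert y.shape == (1, 32, 64, 64)
    return y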
# plain (non-residual) downsampling block: two stacked batchconv layers
class convdown(nn.Module):
def __init__(self, in_channels, out_channels, sz):
super().__init__()
self.conv = nn.Sequential()
for t in range(2):
if t==0:
self.conv.add_module('conv_%d'%t, batchconv(in_channels, out_channels, sz))
else:
self.conv.add_module('conv_%d'%t, batchconv(out_channels, out_channels, sz))
def forward(self, x):
x = self.conv[0](x)
x = self.conv[1](x)
return x
# encoder: one (res)down block per resolution level, with 2x max pooling
# between levels; forward returns the feature maps from every level
class downsample(nn.Module):
def __init__(self, nbase, sz, residual_on=True):
super().__init__()
self.down = nn.Sequential()
self.maxpool = nn.MaxPool2d(2, 2)
for n in range(len(nbase)-1):
if residual_on:
self.down.add_module('res_down_%d'%n, resdown(nbase[n], nbase[n+1], sz))
else:
self.down.add_module('conv_down_%d'%n, convdown(nbase[n], nbase[n+1], sz))
def forward(self, x):
xd = []
for n in range(len(self.down)):
if n>0:
y = self.maxpool(xd[n-1])
else:
y = x
xd.append(self.down[n](y))
return xd
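# Illustrative sketch (hypothetical helper): the encoder returns one feature
# map per level, each at half the resolution of the previous one.
def _example_downsample():
    nbase = [2, 32, 64, 128, 256]
    enc = downsample(nbase, sz=3)
    xd = enc(torch.randn(1, 2, 64, 64))
    # shapes: (1,32,64,64), (1,64,32,32), (1,128,16,16), (1,256,8,8)
    return [tuple(t.shape) for t in xd]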
# batchconv block with style injection: a linear layer maps the style
# vector to per-channel offsets that are broadcast-added to the features;
# the skip input y is summed or concatenated depending on `concatenation`
class batchconvstyle(nn.Module):
def __init__(self, in_channels, out_channels, style_channels, sz, concatenation=False):
super().__init__()
self.concatenation = concatenation
if concatenation:
self.conv = batchconv(in_channels*2, out_channels, sz)
self.full = nn.Linear(style_channels, out_channels*2)
else:
self.conv = batchconv(in_channels, out_channels, sz)
self.full = nn.Linear(style_channels, out_channels)
def forward(self, style, x, mkldnn=False, y=None):
if y is not None:
if self.concatenation:
x = torch.cat((y, x), dim=1)
else:
x = x + y
feat = self.full(style)
if mkldnn:
x = x.to_dense()
y = (x + feat.unsqueeze(-1).unsqueeze(-1)).to_mkldnn()
else:
y = x + feat.unsqueeze(-1).unsqueeze(-1)
y = self.conv(y)
return y
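# Illustrative sketch (hypothetical helper): the style vector is mapped by
# self.full to one offset per output channel, broadcast over space, and
# added to the features before the conv block runs.
def _example_batchconvstyle():
    m = batchconvstyle(16, 16, style_channels=256, sz=3)
    x = torch.randn(1, 16, 32, 32)
    style = torch.randn(1, 256)
    y = m(style, x)
    assert y.shape == (1, 16, 32, 32)
    return y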
# residual upsampling block: two residual stages with style injection
class resup(nn.Module):
def __init__(self, in_channels, out_channels, style_channels, sz, concatenation=False):
super().__init__()
self.conv = nn.Sequential()
self.conv.add_module('conv_0', batchconv(in_channels, out_channels, sz))
self.conv.add_module('conv_1', batchconvstyle(out_channels, out_channels, style_channels, sz, concatenation=concatenation))
self.conv.add_module('conv_2', batchconvstyle(out_channels, out_channels, style_channels, sz))
self.conv.add_module('conv_3', batchconvstyle(out_channels, out_channels, style_channels, sz))
self.proj = batchconv0(in_channels, out_channels, 1)
def forward(self, x, y, style, mkldnn=False):
x = self.proj(x) + self.conv[1](style, self.conv[0](x), y=y, mkldnn=mkldnn)
x = x + self.conv[3](style, self.conv[2](style, x, mkldnn=mkldnn), mkldnn=mkldnn)
return x
# plain (non-residual) upsampling block with style injection
class convup(nn.Module):
def __init__(self, in_channels, out_channels, style_channels, sz, concatenation=False):
super().__init__()
self.conv = nn.Sequential()
self.conv.add_module('conv_0', batchconv(in_channels, out_channels, sz))
self.conv.add_module('conv_1', batchconvstyle(out_channels, out_channels, style_channels, sz, concatenation=concatenation))
def forward(self, x, y, style, mkldnn=False):
x = self.conv[1](style, self.conv[0](x), y=y)
return x
# global average pooling over space followed by L2 normalization,
# producing one unit-norm style vector per image
class make_style(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
    def forward(self, x0):
        style = F.avg_pool2d(x0, kernel_size=(x0.shape[-2], x0.shape[-1]))
        style = self.flatten(style)
        style = style / torch.sum(style**2, dim=1, keepdim=True)**.5
        return style
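# Illustrative sketch (hypothetical helper): each row of the style matrix
# is a unit-norm summary of one image's deepest feature map.
def _example_make_style():
    style = make_style()(torch.randn(2, 256, 8, 8))
    assert style.shape == (2, 256)
    return style.norm(dim=1)  # ~1.0 for each image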
# decoder: one (res)up block per resolution level, combining 2x nearest-
# neighbor upsampling with the corresponding encoder feature map
class upsample(nn.Module):
def __init__(self, nbase, sz, residual_on=True, concatenation=False):
super().__init__()
self.upsampling = nn.Upsample(scale_factor=2, mode='nearest')
self.up = nn.Sequential()
for n in range(1,len(nbase)):
if residual_on:
self.up.add_module('res_up_%d'%(n-1),
resup(nbase[n], nbase[n-1], nbase[-1], sz, concatenation))
else:
self.up.add_module('conv_up_%d'%(n-1),
convup(nbase[n], nbase[n-1], nbase[-1], sz, concatenation))
def forward(self, style, xd, mkldnn=False):
x = self.up[-1](xd[-1], xd[-1], style, mkldnn=mkldnn)
for n in range(len(self.up)-2,-1,-1):
if mkldnn:
x = self.upsampling(x.to_dense()).to_mkldnn()
else:
x = self.upsampling(x)
x = self.up[n](x, xd[n], style, mkldnn=mkldnn)
return x
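# Illustrative sketch (hypothetical helper): the decoder consumes the full
# list of encoder feature maps plus the style vector and restores the input
# resolution with nbaseup[0] channels.
def _example_upsample():
    nbase = [2, 32, 64, 128, 256]
    enc = downsample(nbase, sz=3)
    nbaseup = nbase[1:] + [nbase[-1]]  # mirrors the construction in CPnet
    dec = upsample(nbaseup, sz=3)
    xd = enc(torch.randn(1, 2, 64, 64))
    style = make_style()(xd[-1])
    y = dec(style, xd)
    assert y.shape == (1, 32, 64, 64)
    return y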
# the full Cellpose network: downsampling encoder, style vector, and
# upsampling decoder with skip connections
class CPnet(nn.Module):
def __init__(self, nbase, nout, sz,
residual_on=True, style_on=True,
concatenation=False, mkldnn=False,
diam_mean=30.):
        super().__init__()
self.nbase = nbase
self.nout = nout
self.sz = sz
self.residual_on = residual_on
self.style_on = style_on
self.concatenation = concatenation
self.mkldnn = mkldnn if mkldnn is not None else False
self.downsample = downsample(nbase, sz, residual_on=residual_on)
nbaseup = nbase[1:]
nbaseup.append(nbaseup[-1])
self.upsample = upsample(nbaseup, sz, residual_on=residual_on, concatenation=concatenation)
self.make_style = make_style()
self.output = batchconv(nbaseup[0], nout, 1)
self.diam_mean = nn.Parameter(data=torch.ones(1) * diam_mean, requires_grad=False)
self.diam_labels = nn.Parameter(data=torch.ones(1) * diam_mean, requires_grad=False)
    def forward(self, data):
        if self.mkldnn:
            data = data.to_mkldnn()
        T0 = self.downsample(data)
        # the style vector is computed from the deepest feature map
        if self.mkldnn:
            style = self.make_style(T0[-1].to_dense())
        else:
            style = self.make_style(T0[-1])
        style0 = style
        if not self.style_on:
            style = style * 0
        T0 = self.upsample(style, T0, self.mkldnn)
        T0 = self.output(T0)
        if self.mkldnn:
            T0 = T0.to_dense()
        return T0, style0
def save_model(self, filename):
torch.save(self.state_dict(), filename)
    def load_model(self, filename, device=None):
        if (device is not None) and (device.type != 'cpu'):
            state_dict = torch.load(filename, map_location=device)
        else:
            # rebuild the network on CPU before loading the weights
            self.__init__(self.nbase,
                          self.nout,
                          self.sz,
                          self.residual_on,
                          self.style_on,
                          self.concatenation,
                          self.mkldnn,
                          self.diam_mean)
            state_dict = torch.load(filename, map_location=torch.device('cpu'))
        self.load_state_dict(state_dict, strict=False)
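# Illustrative end-to-end sketch (hypothetical helper; random weights, CPU):
# CPnet maps a 2-channel image to `nout` output maps (in Cellpose, flow
# fields plus a cell probability) and the style vector of the deepest level.
def _example_cpnet():
    net = CPnet(nbase=[2, 32, 64, 128, 256], nout=3, sz=3)
    net.eval()  # use running batch-norm statistics
    with torch.no_grad():
        out, style = net(torch.randn(1, 2, 224, 224))
    assert out.shape == (1, 3, 224, 224) and style.shape == (1, 256)
    return out, style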