import os, sys, time, shutil, tempfile, datetime, pathlib, subprocess
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import datetime
from . import transforms, io, dynamics, utils
# Module-level default convolution kernel size. Every function and class
# defined below takes its own ``sz`` argument (shadowing this name), so none
# of the definitions in this section read this value.
sz = 3
def convbatchrelu(in_channels, out_channels, sz):
    """Build a post-activation ``Conv2d -> BatchNorm2d -> ReLU`` stack.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        sz: square kernel size; padding is ``sz // 2`` (size-preserving for odd sz).

    Returns:
        ``nn.Sequential`` whose children are indexed ``0`` (conv), ``1`` (norm), ``2`` (ReLU).
    """
    pad = sz // 2
    conv = nn.Conv2d(in_channels, out_channels, sz, padding=pad)
    norm = nn.BatchNorm2d(out_channels, eps=1e-5)
    act = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, act)
def batchconv(in_channels, out_channels, sz):
    """Build a pre-activation ``BatchNorm2d -> ReLU -> Conv2d`` stack.

    Normalisation and activation act on the *input* channels; the convolution
    maps ``in_channels`` to ``out_channels`` with padding ``sz // 2``.
    """
    layers = [
        nn.BatchNorm2d(in_channels, eps=1e-5),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels, out_channels, sz, padding=sz // 2),
    ]
    return nn.Sequential(*layers)
def batchconv0(in_channels, out_channels, sz):
    """Build a ``BatchNorm2d -> Conv2d`` stack (like ``batchconv`` but without ReLU).

    Padding is ``sz // 2``; with ``sz == 1`` this is a normalised 1x1 projection.
    """
    norm = nn.BatchNorm2d(in_channels, eps=1e-5)
    conv = nn.Conv2d(in_channels, out_channels, sz, padding=sz // 2)
    return nn.Sequential(norm, conv)
class resdown(nn.Module):
    """Residual encoder block built from four pre-activation convolutions.

    The first residual pair is ``proj(x) + conv_1(conv_0(x))``, where ``proj``
    is a 1x1 ``batchconv0`` that matches channel counts. A second pair
    ``conv_3(conv_2(h))`` is then added on top of that result ``h``.
    """

    def __init__(self, in_channels, out_channels, sz):
        super().__init__()
        # Register the (still empty) container before the projection so the
        # child order, state_dict key order and parameter-init order are kept.
        self.conv = nn.Sequential()
        self.proj = batchconv0(in_channels, out_channels, 1)
        widths_in = [in_channels] + [out_channels] * 3
        for idx, width in enumerate(widths_in):
            self.conv.add_module('conv_%d' % idx, batchconv(width, out_channels, sz))

    def forward(self, x):
        shortcut = self.proj(x)
        first = self.conv[1](self.conv[0](x))
        hidden = shortcut + first
        second = self.conv[3](self.conv[2](hidden))
        return hidden + second
class convdown(nn.Module):
    """Plain (non-residual) encoder block: two stacked ``batchconv`` layers."""

    def __init__(self, in_channels, out_channels, sz):
        super().__init__()
        self.conv = nn.Sequential()
        # first conv changes the width, the second keeps it
        self.conv.add_module('conv_0', batchconv(in_channels, out_channels, sz))
        self.conv.add_module('conv_1', batchconv(out_channels, out_channels, sz))

    def forward(self, x):
        return self.conv[1](self.conv[0](x))
[docs]class downsample(nn.Module):
def __init__(self, nbase, sz, residual_on=True):
super().__init__()
self.down = nn.Sequential()
self.maxpool = nn.MaxPool2d(2, 2)
for n in range(len(nbase)-1):
if residual_on:
self.down.add_module('res_down_%d'%n, resdown(nbase[n], nbase[n+1], sz))
else:
self.down.add_module('conv_down_%d'%n, convdown(nbase[n], nbase[n+1], sz))
[docs] def forward(self, x):
xd = []
for n in range(len(self.down)):
if n>0:
y = self.maxpool(xd[n-1])
else:
y = x
xd.append(self.down[n](y))
return xd
class batchconvstyle(nn.Module):
    """``batchconv`` whose input is shifted by a learned projection of a style vector.

    In ``forward``, an optional skip tensor ``y`` is first merged with ``x``,
    either by channel concatenation ``cat((y, x))`` or by addition. Then
    ``full(style)`` is broadcast over the spatial dims and added before the
    convolution. With ``mkldnn=True`` the addition is done on a dense copy and
    the result is converted back with ``to_mkldnn()``.
    """

    def __init__(self, in_channels, out_channels, style_channels, sz, concatenation=False):
        super().__init__()
        self.concatenation = concatenation
        factor = 2 if concatenation else 1
        self.conv = batchconv(in_channels * factor, out_channels, sz)
        self.full = nn.Linear(style_channels, out_channels * factor)

    def forward(self, style, x, mkldnn=False, y=None):
        if y is not None:
            x = torch.cat((y, x), dim=1) if self.concatenation else x + y
        bias = self.full(style)[..., None, None]
        if mkldnn:
            x = x.to_dense()
        shifted = x + bias
        if mkldnn:
            shifted = shifted.to_mkldnn()
        return self.conv(shifted)
class resup(nn.Module):
    """Residual decoder block conditioned on a style vector.

    ``conv_0`` is a plain ``batchconv``. ``conv_1`` to ``conv_3`` are
    ``batchconvstyle`` layers, and only ``conv_1`` merges the skip tensor ``y``.
    A 1x1 ``batchconv0`` projection forms the first shortcut.
    """

    def __init__(self, in_channels, out_channels, style_channels, sz, concatenation=False):
        super().__init__()
        self.conv = nn.Sequential()
        self.conv.add_module('conv_0', batchconv(in_channels, out_channels, sz))
        self.conv.add_module('conv_1', batchconvstyle(out_channels, out_channels, style_channels, sz,
                                                      concatenation=concatenation))
        for name in ('conv_2', 'conv_3'):
            self.conv.add_module(name, batchconvstyle(out_channels, out_channels, style_channels, sz))
        self.proj = batchconv0(in_channels, out_channels, 1)

    def forward(self, x, y, style, mkldnn=False):
        shortcut = self.proj(x)
        first = self.conv[1](style, self.conv[0](x), y=y, mkldnn=mkldnn)
        hidden = shortcut + first
        inner = self.conv[2](style, hidden, mkldnn=mkldnn)
        return hidden + self.conv[3](style, inner, mkldnn=mkldnn)
class convup(nn.Module):
    """Plain (non-residual) decoder block: a ``batchconv`` then a style-conditioned conv.

    Args:
        in_channels: channels of the incoming (upsampled) feature map.
        out_channels: channels produced by both convolutions.
        style_channels: length of the style vector fed to ``batchconvstyle``.
        sz: kernel size.
        concatenation: whether ``conv_1`` concatenates the skip tensor instead of adding it.
    """

    def __init__(self, in_channels, out_channels, style_channels, sz, concatenation=False):
        super().__init__()
        self.conv = nn.Sequential()
        self.conv.add_module('conv_0', batchconv(in_channels, out_channels, sz))
        self.conv.add_module('conv_1', batchconvstyle(out_channels, out_channels, style_channels, sz,
                                                      concatenation=concatenation))

    def forward(self, x, y, style, mkldnn=False):
        # Forward the mkldnn flag (previously dropped): batchconvstyle must
        # densify MKLDNN input before adding the dense style bias.
        x = self.conv[1](style, self.conv[0](x), y=y, mkldnn=mkldnn)
        return x
class make_style(nn.Module):
    """Compute a per-sample style vector from a feature map.

    Each channel is averaged over the full spatial extent (the last two dims),
    flattened, and L2-normalised along dim 1.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()

    def forward(self, x0):
        style = F.avg_pool2d(x0, kernel_size=(x0.shape[-2], x0.shape[-1]))
        style = self.flatten(style)
        # F.normalize clamps the norm (eps=1e-12), so an all-zero feature
        # vector yields zeros instead of the NaN produced by x / sqrt(0).
        return F.normalize(style, p=2.0, dim=1)
[docs]class upsample(nn.Module):
def __init__(self, nbase, sz, residual_on=True, concatenation=False):
super().__init__()
self.upsampling = nn.Upsample(scale_factor=2, mode='nearest')
self.up = nn.Sequential()
for n in range(1,len(nbase)):
if residual_on:
self.up.add_module('res_up_%d'%(n-1),
resup(nbase[n], nbase[n-1], nbase[-1], sz, concatenation))
else:
self.up.add_module('conv_up_%d'%(n-1),
convup(nbase[n], nbase[n-1], nbase[-1], sz, concatenation))
[docs] def forward(self, style, xd, mkldnn=False):
x = self.up[-1](xd[-1], xd[-1], style, mkldnn=mkldnn)
for n in range(len(self.up)-2,-1,-1):
if mkldnn:
x = self.upsampling(x.to_dense()).to_mkldnn()
else:
x = self.upsampling(x)
x = self.up[n](x, xd[n], style, mkldnn=mkldnn)
return x
class CPnet(nn.Module):
    """U-Net-style network with a global style vector.

    Args:
        nbase: encoder channel widths (``nbase[0]`` is the input channel count).
        nout: number of output channels.
        sz: convolution kernel size.
        residual_on: use residual (``resdown``/``resup``) blocks instead of plain ones.
        style_on: if false, the style vector fed to the decoder is zeroed
            (the returned style is still the computed one).
        concatenation: concatenate skip connections instead of adding them.
        mkldnn: run the encoder/decoder on MKLDNN tensors (``None`` is treated as False).
        diam_mean: stored as the non-trainable parameters ``diam_mean`` and
            ``diam_labels`` (both initialised to this value).

    ``forward(data)`` returns ``(output, style)``.
    """

    def __init__(self, nbase, nout, sz,
                 residual_on=True, style_on=True,
                 concatenation=False, mkldnn=False,
                 diam_mean=30.):
        super().__init__()
        self.nbase = nbase
        self.nout = nout
        self.sz = sz
        self.residual_on = residual_on
        self.style_on = style_on
        self.concatenation = concatenation
        self.mkldnn = mkldnn if mkldnn is not None else False
        self.downsample = downsample(nbase, sz, residual_on=residual_on)
        # Decoder widths: encoder widths without the input channels, with the
        # deepest width repeated. list() also accepts a tuple for nbase.
        nbaseup = list(nbase[1:])
        nbaseup.append(nbaseup[-1])
        self.upsample = upsample(nbaseup, sz, residual_on=residual_on, concatenation=concatenation)
        self.make_style = make_style()
        self.output = batchconv(nbaseup[0], nout, 1)
        self.diam_mean = nn.Parameter(data=torch.ones(1) * diam_mean, requires_grad=False)
        self.diam_labels = nn.Parameter(data=torch.ones(1) * diam_mean, requires_grad=False)

    def forward(self, data):
        if self.mkldnn:
            data = data.to_mkldnn()
        T0 = self.downsample(data)
        # make_style works on a dense copy of the deepest encoder feature map
        deepest = T0[-1].to_dense() if self.mkldnn else T0[-1]
        style = self.make_style(deepest)
        style0 = style
        if not self.style_on:
            style = style * 0
        T0 = self.upsample(style, T0, self.mkldnn)
        T0 = self.output(T0)
        if self.mkldnn:
            T0 = T0.to_dense()
        return T0, style0

    def save_model(self, filename):
        """Save this module's ``state_dict`` to ``filename``."""
        torch.save(self.state_dict(), filename)

    def load_model(self, filename, device=None):
        """Load weights from ``filename`` (non-strict: missing/unexpected keys are ignored).

        With a non-CPU ``device``, the checkpoint is mapped onto that device
        and loaded into the existing modules. Otherwise the module is first
        rebuilt with its stored hyper-parameters (fresh CPU parameters) and the
        checkpoint is mapped onto the CPU.
        """
        if device is not None and device.type != 'cpu':
            state_dict = torch.load(filename, map_location=device)
        else:
            # Pass diam_mean as a Python float: the stored Parameter may live on
            # a GPU, and torch.ones(1) * <cuda tensor of shape (1,)> raises a
            # device-mismatch error inside __init__.
            self.__init__(self.nbase,
                          self.nout,
                          self.sz,
                          self.residual_on,
                          self.style_on,
                          self.concatenation,
                          self.mkldnn,
                          self.diam_mean.item())
            state_dict = torch.load(filename, map_location=torch.device('cpu'))
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        self.load_state_dict(dict(state_dict), strict=False)