Page Menu · Home · c4science

run_v1.py
No One · Temporary

File Metadata

Created
Wed, Apr 17, 00:04

run_v1.py

import torch
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torch.optim as optim
import numpy as np
from net_v1 import UNet3d
from networkTraining.loggerBasic import LoggerBasic
from networkTraining.loggerIoU import LoggerIoU
from networkTraining.loggerF1 import LoggerF1
from networkTraining.loggerComposit import LoggerComposit
from dataset import NeuronsDataset
from datasetCrops import NeuronsTestDataset
import os
import os.path
import torch.nn.functional as F
import sys
from shutil import copyfile
from networkTraining.trainer import trainer
from networkTraining.tester import tester
# Output directory for this run: the loggers write into it and a copy of this
# script is archived there as setup.txt.
log_dir = "log_v1"
def calcClassWeights(count, c=1.02):
    """Compute normalized inverse-frequency class weights for CrossEntropyLoss.

    Each class ``k`` with relative frequency ``p_k = count[k] / sum(count)``
    gets the raw weight ``1 / (c + p_k)``. The raw weights are then rescaled
    so that they sum to one.

    Args:
        count: per-class occurrence counts (array-like, one entry per class).
        c: offset added to every relative frequency before inversion.
            Defaults to 1.02, the value used originally.

    Returns:
        A float32 ``torch.Tensor`` with one weight per class, summing to 1
        (up to float32 rounding).

    Raises:
        ValueError: if ``count`` sums to zero, since relative frequencies are
            undefined (previously this silently produced NaN weights).
    """
    count = np.asarray(count, dtype=np.double)
    total = count.sum()
    if total == 0:
        raise ValueError("class counts sum to zero; cannot compute class weights")
    freq = count / total + c
    # NOTE(review): an earlier version also computed log(freq) and discarded
    # it, which hints that ENet-style 1 / ln(c + p) weighting may have been
    # intended. The weights actually used were 1 / (c + p), and that behavior
    # is kept here.
    w = 1.0 / freq
    return torch.Tensor(w / w.sum())
# Data locations: image and label volumes, each split into train/ and test/.
imgdir = "/cvlabdata2/home/kozinski/experimentsTorch/bbp_neurons/data_npy/img/"
lbldir = "/cvlabdata2/home/kozinski/experimentsTorch/bbp_neurons/data_npy/lbl/"
trainimgdir = os.path.join(imgdir, "train")
trainlbldir = os.path.join(lbldir, "train")
testimgdir = os.path.join(imgdir, "test")
testlbldir = os.path.join(lbldir, "test")

# trainFiles.txt / testFiles.txt are Python snippets executed in this module's
# namespace; they must define `trainFiles` and `testFiles`, which the dataset
# construction later in this script uses.
# SECURITY: exec() runs arbitrary code -- only point this at trusted files.
# Each file is opened with `with` so its handle is closed after reading.
for _file_list_path in ("trainFiles.txt", "testFiles.txt"):
    with open(_file_list_path) as _file_list_fh:
        exec(_file_list_fh.read())
# Create this run's log directory. os.makedirs raises FileExistsError if it
# already exists, so a previous run's logs are never silently overwritten.
os.makedirs(log_dir)
# Archive the exact script that produced this run next to its logs.
copyfile(__file__, os.path.join(log_dir, "setup.txt"))

# Model, moved to the GPU.
net = UNet3d().cuda()
# To warm-start from an earlier run, uncomment these three lines:
# prev_log_dir = "log_v2"
# saved_net = torch.load(os.path.join(prev_log_dir, "net_last.pth"))
# net.load_state_dict(saved_net['state_dict'])
net.train()

# Adam with weight decay (L2 penalty on the parameters).
optimizer = optim.Adam(net.parameters(), lr=0.0001, weight_decay=1e-4)
# Training logger; the meaning of the trailing 100 (presumably a logging
# interval) is defined by LoggerBasic -- confirm there.
logger = LoggerBasic(log_dir, "Basic", 100)
# Training set. np.array([64, 64, 64]) looks like a 3D crop size and 0.0 is a
# NeuronsDataset-specific argument -- confirm both against NeuronsDataset.
train_dataset = NeuronsDataset(
    trainimgdir, trainlbldir, trainFiles, np.array([64] * 3), 0.0)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=6,
    shuffle=True,
    num_workers=6,
    drop_last=True,
)

# Cross-entropy weighted by calcClassWeights(train_dataset.count), presumably
# per-class label counts; targets equal to 255 are excluded from the loss.
loss = torch.nn.CrossEntropyLoss(
    weight=calcClassWeights(train_dataset.count),
    ignore_index=255,
)
# NOTE(review): loss.weight is created on the CPU while the net lives on CUDA;
# confirm that the trainer moves the loss (or its inputs) to a common device.
print("loss.weight", loss.weight)

# Test set, evaluated one volume at a time. The [80]*3 / [22]*3 arguments look
# like a crop size and a margin -- confirm against NeuronsTestDataset.
test_dataset = NeuronsTestDataset(
    testimgdir, testlbldir, testFiles, np.array([80] * 3), [22] * 3)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    drop_last=False,
)
def test_preproc(o,t):
idx=torch.Tensor.mul_(t<=2, t>=1).reshape(t.numel())
o=o[:,1,:,:,:]
oo=o.reshape(t.numel())[idx]
tt=t.reshape(t.numel())[idx]
o=torch.pow(torch.exp(-oo)+1,-1)
return o,tt-1
# Test-set logger. test_preproc maps raw outputs and labels to
# (score, binary label) pairs; nBins and saveBest are LoggerF1 options.
logger_test = LoggerF1(log_dir, "Test", test_preproc, nBins=10000, saveBest=True)
tstr = tester(test_loader, logger_test)


def lr_lambda(e):
    """Return the LambdaLR multiplier 1 / (1 + 1e-5 * e) for step count e."""
    return 1 / (1 + e * 1e-5)


# How often e advances (per iteration or per epoch) depends on when the
# trainer calls lr_scheduler.step().
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# The positional 100 is a trainer-specific setting (presumably a test
# interval) -- confirm in networkTraining.trainer.
trn = trainer(net, train_loader, optimizer, loss, logger, tstr, 100,
              lr_scheduler=lr_scheduler)

# Only the training call is guarded; all setup above runs on import.
if __name__ == '__main__':
    trn.train(50000)

Event Timeline