diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..cc7d977
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,82 @@
+import os
+import math
+from random import randint, random
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+def remapIntensity(inp,max_z):
+    # gamma augmentation based on: full-resolution residual networks,
+    # appendix A; max_z=0 makes this the identity mapping
+    z_over_sqrt2=(random()-0.5)*max_z/math.sqrt(2)
+    gamma=math.log(z_over_sqrt2+0.5)/math.log(-z_over_sqrt2+0.5)
+    return np.power(inp,gamma)
+
+class NeuronsDataset(Dataset):
+
+    def __init__(self, imgDir, lblDir, fileList, cropSz, inteVal):
+        #fileList is a list of two-element lists:
+        # the first element is the name of the input file,
+        # the second element is the name of the label file
+        self.cropSz=cropSz
+        self.img=[]
+        self.lbl=[]
+        self.inteVal=inteVal # coefficient for image intensity augmentation
+        self.count=np.array([0,0]) # per-class voxel counts, for loss weighting
+        for f in fileList:
+            img=np.load(os.path.join(imgDir,f[0]))
+            lbl=np.load(os.path.join(lblDir,f[1]))
+            lbl-=1 #in the label files the "ignore" label is zero,
+                   #so subtracting one gives 255 for unsigned byte
+            lbl[lbl==255]=0 # treat margins around centerlines as background
+            self.img.append(img.astype(np.float32))
+            self.lbl.append(lbl)
+            self.count[0]+=np.equal(lbl,0).sum()
+            self.count[1]+=np.equal(lbl,1).sum()
+
+    def __len__(self):
+        return len(self.lbl)
+
+    def getAugmentedDataItem(self,cropSz,idx):
+        # random crop
+        maxstartind1=self.lbl[idx].shape[0]-cropSz[0]
+        maxstartind2=self.lbl[idx].shape[1]-cropSz[1]
+        maxstartind3=self.lbl[idx].shape[2]-cropSz[2]
+        startind1=randint(0,maxstartind1)
+        startind2=randint(0,maxstartind2)
+        startind3=randint(0,maxstartind3)
+        img=self.img[idx][:,
+                          startind1:startind1+cropSz[0],
+                          startind2:startind2+cropSz[1],
+                          startind3:startind3+cropSz[2]]
+        lbl=self.lbl[idx][startind1:startind1+cropSz[0],
+                          startind2:startind2+cropSz[1],
+                          startind3:startind3+cropSz[2]]
+
+        # random flips; the image has a leading channel axis, the label does not
+        if random()>0.5:
+            img=np.flip(img,1)
+            lbl=np.flip(lbl,0)
+        if random()>0.5:
+            img=np.flip(img,2)
+            lbl=np.flip(lbl,1)
+        if random()>0.5:
+            img=np.flip(img,3)
+            lbl=np.flip(lbl,2)
+        # intensity augmentation; .copy() because np.flip returns views
+        img=remapIntensity(img.copy(),self.inteVal)
+        return img,lbl.copy()
+
+    def __getitem__(self, idx):
+        i,l=self.getAugmentedDataItem(self.cropSz,idx)
+        it=torch.from_numpy(i)
+        lt=torch.from_numpy(l.astype(np.int64)) # CrossEntropyLoss expects long
+        return it,lt
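+
+# A minimal usage sketch (hypothetical arguments, following the file-list
+# convention of trainFiles.txt; inteVal=0.0 disables the intensity remapping):
+#
+#   ds=NeuronsDataset(trainimgdir,trainlbldir,[["0.t7.npy","0.t7.npy"]],
+#                     np.array([64,64,64]),0.0)
+#   img,lbl=ds[0] # img: float32, (C,64,64,64); lbl: int64, (64,64,64)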
diff --git a/datasetCrops.py b/datasetCrops.py
new file mode 100644
index 0000000..9551914
--- /dev/null
+++ b/datasetCrops.py
@@ -0,0 +1,67 @@
+import os
+import bisect
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+import networkTraining.cropRoutines as cropRoutines
+
+class NeuronsTestDataset(Dataset):
+
+    def __init__(self, imgDir, lblDir, fileList, cropSz, margSz, ignoreInd=255):
+        #fileList is a list of two-element lists:
+        # the first element is the name of the input volume,
+        # the second element is the name of the ground-truth file
+        self.cropSz=cropSz
+        self.margSz=margSz
+        self.img=[]
+        self.lbl=[]
+        self.no_crops=[0] # cumulative crop counts, one entry per volume
+        self.ignoreInd=ignoreInd
+
+        for f in fileList:
+            img=np.load(os.path.join(imgDir,f[0]))
+            lbl=np.load(os.path.join(lblDir,f[1]))
+            self.img.append(img.astype(np.float32))
+            self.lbl.append(lbl)
+            tot_no_crops=cropRoutines.noCrops(lbl.shape,cropSz,margSz,0)+\
+                         self.no_crops[-1]
+            self.no_crops.append(tot_no_crops)
+
+    def __len__(self):
+        return self.no_crops[-1]
+
+    def getCrop(self,idx):
+        # map the flat index to a (volume, crop) pair
+        ind=bisect.bisect_right(self.no_crops,idx)-1
+        lbl=self.lbl[ind]
+        img=self.img[ind]
+        cropInd=idx-self.no_crops[ind]
+        cc,vc=cropRoutines.cropCoords(cropInd,self.cropSz,self.margSz,lbl.shape,0)
+        cimg=img[tuple([slice(0,img.shape[0])]+cc)].copy()
+        clbl=lbl[tuple(cc)].copy()
+        # use vc to inpaint the crop margins of clbl with the ignore index,
+        # so that only the valid part of each overlapping crop is scored
+        sl=[]
+        for i in range(len(vc)):
+            sl.append(slice(0,vc[i].start))
+            clbl[tuple(sl)]=self.ignoreInd
+            del sl[-1]
+            sl.append(slice(vc[i].stop,clbl.shape[i]))
+            clbl[tuple(sl)]=self.ignoreInd
+            del sl[-1]
+            sl.append(slice(0,clbl.shape[i]))
+
+        return cimg,clbl
+
+    def getAugmentedDataItem(self,cropSz,idx):
+        # deterministic crop; no augmentation at test time
+        img,lbl=self.getCrop(idx)
+        return img.copy(),lbl.copy()
+
+    def __getitem__(self, idx):
+        i,l=self.getAugmentedDataItem(self.cropSz,idx)
+        lbl=torch.from_numpy(l)
+        img=torch.from_numpy(i)
+        return (img,lbl)
diff --git a/net_v1.py b/net_v1.py
new file mode 100644
index 0000000..5816a81
--- /dev/null
+++ b/net_v1.py
@@ -0,0 +1,118 @@
+# todo:
+# multi-scale for better receptive field (decoder only, bilinear input subsampling)
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def conv_bn_relu(in_channels, out_channels, kernel_size=3):
+    stride=1
+    padding=kernel_size//2
+    return nn.Sequential(
+        nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
+        nn.BatchNorm3d(out_channels),
+        nn.ReLU(inplace=True),
+        nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
+        nn.BatchNorm3d(out_channels),
+        nn.ReLU(inplace=True),
+        nn.Dropout3d(p=0.1),
+    )
+
+def upSkipConvolution(s):
+    # 1x1x1 skip projection from s to 2*s channels, initialized to
+    # duplicate every input channel into two output channels
+    c=nn.Conv3d(s,2*s, kernel_size=1, stride=1, padding=0, bias=False)
+    c.weight.data.zero_()
+    for i in range(0,s):
+        c.weight.data[2*i][i]=1
+        c.weight.data[2*i+1][i]=1
+    return c
+
+def dwSkipConvolution(s):
+    # 1x1x1 skip projection from s to s//2 channels, initialized to
+    # average every consecutive pair of input channels
+    c=nn.Conv3d(s,s//2, kernel_size=1, stride=1, padding=0, bias=False)
+    c.weight.data.zero_()
+    for i in range(0,s//2):
+        c.weight.data[i][2*i]=0.5
+        c.weight.data[i][2*i+1]=0.5
+    return c
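+
+# Worked example of this initialization: upSkipConvolution(2) initially
+# maps channels (a,b) to (a,a,b,b), and dwSkipConvolution(4) maps
+# (a,b,c,d) to ((a+b)/2,(c+d)/2); the skip paths thus start as channel
+# duplication / pairwise averaging, but remain trainable.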
+
+class upResBlock(nn.Module):
+    # residual block doubling the channel count
+    def __init__(self, in_channels, kernel_size=3):
+        super().__init__()
+        self.cbr=conv_bn_relu(in_channels,2*in_channels,kernel_size)
+        self.uc=upSkipConvolution(in_channels)
+    def forward(self, x):
+        return self.cbr(x)+self.uc(x)
+
+class dwResBlock(nn.Module):
+    # residual block halving the channel count
+    def __init__(self, in_channels, kernel_size=3):
+        super().__init__()
+        self.cbr=conv_bn_relu(in_channels,in_channels//2,kernel_size)
+        self.dc=dwSkipConvolution(in_channels)
+    def forward(self, x):
+        return self.cbr(x)+self.dc(x)
+
+def up_pooling(in_channels, out_channels, kernel_size=2, stride=2):
+    return nn.Sequential(
+        nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
+        nn.BatchNorm3d(out_channels),
+        nn.ReLU(inplace=True)
+    )
+
+class UNet3d(nn.Module):
+    def __init__(self):
+        super().__init__()
+        input_channels = 1
+        # go down
+        self.conv0 = conv_bn_relu(input_channels, 64)
+        self.conv1 = upResBlock(64)
+        self.conv2 = upResBlock(128)
+
+        self.down_pooling = nn.MaxPool3d(2)
+
+        # go up
+        self.up_pool6 = up_pooling(256, 128)
+        self.conv7 = dwResBlock(256)
+        self.up_pool8 = up_pooling(128, 64)
+        self.conv9 = dwResBlock(128)
+
+        # a single logit per voxel
+        self.conv10 = nn.Conv3d(64, 1, 1)
+
+    def forward(self, x):
+        # go down
+        x0 = self.conv0(x)
+        p0 = self.down_pooling(x0)
+        x1 = self.conv1(p0)
+        p1 = self.down_pooling(x1)
+        x2 = self.conv2(p1)
+
+        x6 = x2 # numbering kept from a deeper variant of this network
+
+        # go up
+        p7 = self.up_pool6(x6)
+        x7 = torch.cat([p7, x1], dim=1)
+        x7 = self.conv7(x7)
+
+        p8 = self.up_pool8(x7)
+        x8 = torch.cat([p8, x0], dim=1)
+        x8 = self.conv9(x8)
+
+        # prepend a constant zero channel to the logit, so that the output
+        # works with a two-class CrossEntropyLoss: softmax([0,z]) equals
+        # sigmoid(z) for the foreground class
+        output = F.pad(self.conv10(x8), [0,0, 0,0, 0,0, 1,0])
+        return output
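+
+# Shape sanity check (illustrative): the two max-poolings and the two
+# transpose convolutions cancel out, so for inputs with sides divisible
+# by 4 the spatial size is preserved, e.g.
+#   UNet3d()(torch.zeros(1,1,64,64,64)).shape  # -> (1,2,64,64,64)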
diff --git a/run_v1.py b/run_v1.py
new file mode 100644
index 0000000..e4342f6
--- /dev/null
+++ b/run_v1.py
@@ -0,0 +1,87 @@
+import os
+import os.path
+from shutil import copyfile
+
+import numpy as np
+import torch
+import torch.optim as optim
+
+from net_v1 import UNet3d
+from dataset import NeuronsDataset
+from datasetCrops import NeuronsTestDataset
+from networkTraining.loggerBasic import LoggerBasic
+from networkTraining.loggerF1 import LoggerF1
+from networkTraining.trainer import trainer
+from networkTraining.tester import tester
+
+log_dir="log_v1"
+
+def calcClassWeights(count):
+    # inverse-frequency class weights, normalized to sum to one
+    freq=count.astype(np.double)/count.sum()
+    w=np.power(freq,-1)
+    w=w/w.sum()
+    return torch.Tensor(w)
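+
+# Worked example (hypothetical counts): count=[90,10] gives freq=[0.9,0.1],
+# then 1/freq=[1.11,10.0], and after normalization w=[0.1,0.9], so the
+# rarer foreground class receives the larger weight.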
+
+imgdir="/cvlabdata2/home/kozinski/experimentsTorch/bbp_neurons/data_npy/img/"
+lbldir="/cvlabdata2/home/kozinski/experimentsTorch/bbp_neurons/data_npy/lbl/"
+trainimgdir=os.path.join(imgdir,"train")
+trainlbldir=os.path.join(lbldir,"train")
+testimgdir=os.path.join(imgdir,"test")
+testlbldir=os.path.join(lbldir,"test")
+# the two .txt files define the trainFiles and testFiles lists
+exec(open("trainFiles.txt").read())
+exec(open("testFiles.txt").read())
+
+os.makedirs(log_dir)
+copyfile(__file__,os.path.join(log_dir,"setup.txt"))
+
+net = UNet3d().cuda()
+# uncomment to resume training from a previous run:
+#prev_log_dir="log_v2"
+#saved_net=torch.load(os.path.join(prev_log_dir,"net_last.pth"))
+#net.load_state_dict(saved_net['state_dict'])
+net.train()
+
+optimizer = optim.Adam(net.parameters(), lr=0.0001, weight_decay=1e-4)
+logger=LoggerBasic(log_dir,"Basic",100)
+train_dataset = NeuronsDataset(
+    trainimgdir,trainlbldir,trainFiles,np.array([64,64,64]),0.0)
+train_loader = torch.utils.data.DataLoader(train_dataset,
+                                           batch_size=6, shuffle=True,
+                                           num_workers=6, drop_last=True)
+loss = torch.nn.CrossEntropyLoss(weight=calcClassWeights(train_dataset.count),
+                                 ignore_index=255)
+print("loss.weight",loss.weight)
+
+test_dataset = NeuronsTestDataset(
+    testimgdir,testlbldir,testFiles,np.array([80,80,80]),[22,22,22])
+
+test_loader = torch.utils.data.DataLoader(test_dataset,
+                                          batch_size=1, shuffle=False,
+                                          num_workers=1, drop_last=False)
+
+def test_preproc(o,t):
+    # keep only voxels labeled 1 or 2; label 0 and the inpainted
+    # crop margins (255) are excluded from evaluation
+    mask=((t>=1)&(t<=2)).reshape(t.numel())
+    oo=o[:,1,:,:,:].reshape(t.numel())[mask] # foreground logit
+    tt=t.reshape(t.numel())[mask]
+    # sigmoid of the logit; remap labels from {1,2} to {0,1}
+    return torch.sigmoid(oo),tt-1
+
+logger_test=LoggerF1(log_dir,"Test",test_preproc, nBins=10000, saveBest=True)
+tstr=tester(test_loader,logger_test)
+
+# slow inverse decay of the learning rate
+lr_lambda=lambda e: 1/(1+e*1e-5)
+lr_scheduler=torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+trn=trainer(net, train_loader, optimizer, loss, logger, tstr, 100,
+            lr_scheduler=lr_scheduler)
+
+if __name__ == '__main__':
+    trn.train(50000)
diff --git a/segmentTestSet.py b/segmentTestSet.py
new file mode 100644
index 0000000..6920841
--- /dev/null
+++ b/segmentTestSet.py
@@ -0,0 +1,37 @@
+import os
+
+import numpy as np
+import torch
+
+from net_v1 import UNet3d
+from networkTraining.forwardOnBigImages import processChunk
+
+imgdir="/cvlabdata2/home/kozinski/experimentsTorch/bbp_neurons/data_npy/img/test/"
+# testFiles.txt defines the testFiles list
+exec(open("testFiles.txt").read())
+
+log_dir="log_v1"
+net = UNet3d().cuda()
+saved_net=torch.load(os.path.join(log_dir,"net_last.pth"))
+net.load_state_dict(saved_net['state_dict'])
+net.eval()
+
+out_dir="test_last"
+
+def process_output(o):
+    # sigmoid of the foreground logit: exp(z)/(exp(z)+1)
+    e=np.exp(o[0,1,:,:,:])
+    prob=e/(e+1)
+    return prob
+
+outdir=os.path.join(log_dir,out_dir)
+os.makedirs(outdir)
+
+for f in testFiles:
+    img=np.load(os.path.join(imgdir,f[0])).astype(np.float32)
+    # add the batch and channel dimensions expected by the network
+    inp=img.reshape(1,1,img.shape[-3],img.shape[-2],img.shape[-1])
+    # run the net on overlapping 104^3 chunks with 22-voxel margins
+    oup=processChunk(inp,(104,104,104),(22,22,22),2,net,outChannels=2)
+    prob=process_output(oup)
+    np.save(os.path.join(outdir,os.path.basename(f[0])),prob)
diff --git a/testFiles.txt b/testFiles.txt
new file mode 100644
index 0000000..ef1437b
--- /dev/null
+++ b/testFiles.txt
@@ -0,0 +1,6 @@
+testFiles=[
+["10.t7.npy", "10.t7.npy"],
+["11.t7.npy", "11.t7.npy"],
+["16.t7.npy", "16.t7.npy"],
+["4.t7.npy", "4.t7.npy"],
+]
diff --git a/trainFiles.txt b/trainFiles.txt
new file mode 100644
index 0000000..47d9fca
--- /dev/null
+++ b/trainFiles.txt
@@ -0,0 +1,12 @@
+trainFiles=[
+["0.t7.npy", "0.t7.npy"],
+["12.t7.npy", "12.t7.npy"],
+["13.t7.npy", "13.t7.npy"],
+["14.t7.npy", "14.t7.npy"],
+["17.t7.npy", "17.t7.npy"],
+["1.t7.npy", "1.t7.npy"],
+["3.t7.npy", "3.t7.npy"],
+["5.t7.npy", "5.t7.npy"],
+["6.t7.npy", "6.t7.npy"],
+["8.t7.npy", "8.t7.npy"],
+]
diff --git a/viewSegmentations.ipynb b/viewSegmentations.ipynb
new file mode 100644
index 0000000..babf078
--- /dev/null
+++ b/viewSegmentations.ipynb
@@ -0,0 +1,98 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import cv2\n",
+    "import IPython.display\n",
+    "import numpy as np\n",
+    "import os\n",
+    "\n",
+    "def imshow(img):\n",
+    "    _,ret = cv2.imencode('.jpg', img)\n",
+    "    i = IPython.display.Image(data=ret)\n",
+    "    IPython.display.display(i)\n",
+    "\n",
+    "def showCube(vol):\n",
+    "    # maximum-intensity projections along the three axes\n",
+    "    v1=np.amax(vol,axis=0)\n",
+    "    v2=np.amax(vol,axis=1)\n",
+    "    v3=np.amax(vol,axis=2)\n",
+    "    imshow(v1*255)\n",
+    "    imshow(v2*255)\n",
+    "    imshow(v3*255)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "imgdir=\"/cvlabdata2/home/kozinski/experimentsTorch/bbp_neurons/data_npy/img/test\"\n",
+    "lbldir=\"/cvlabdata2/home/kozinski/experimentsTorch/bbp_neurons/data_npy/lbl/test\"\n",
+    "segmdir=\"log_v1/test_last\"\n",
+    "exec(open(\"testFiles.txt\").read())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "for f in testFiles:\n",
+    "    img=np.load(os.path.join(imgdir,f[0])).astype(np.float32)\n",
+    "    lbl=np.load(os.path.join(lbldir,f[1])).astype(np.float32)\n",
+    "    segm=np.load(os.path.join(segmdir,os.path.basename(f[0])))\n",
+    "    showCube(img[0]*100)\n",
+    "    showCube(lbl/2)\n",
+    "    showCube(segm)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lbl.min()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "segm.min()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}