diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cropRoutines.py b/cropRoutines.py
new file mode 100644
index 0000000..6f3e0c2
--- /dev/null
+++ b/cropRoutines.py
@@ -0,0 +1,104 @@
+# at times it may be beneficial to process a big block of data (volume, image)
+# in little crops, one crop at a time
+# since context is important when using a convolutional neural network,
+# we want to ensure that each element of the output has been generated
+# with enough context
+# thus, when breaking down a large volume into little crops, we should
+# generate crops that overlap, and only retain, for each crop, the output
+# elements that have been generated with enough context
+# the first crop          XXXXXXXXXX
+# the second crop             OOOOOOOOOO
+# the outputs combined    XXXXXXXOOOOOOO
+
+import math
+
+def noCrops(inSize, cropSize, marginSize, startDim=0):
+    # inSize     - size of the whole data block
+    # cropSize   - can be shorter than inSize, if not all dims are cropped;
+    #              in this case startDim > 0
+    # marginSize - same length as cropSize; stores the size of a single margin;
+    #              the resulting overlap between crops is 2*marginSize
+    # startDim   - all dimensions starting from this one are cropped;
+    #              for example, if dim 0 indexes batches and dim 1 indexes
+    #              channels, startDim would typically equal 2
+    nCrops=1
+    for dim in range(startDim, len(inSize)):
+        relDim=dim-startDim
+        nCropsPerDim=(inSize[dim]-2*marginSize[relDim])/ \
+                     (cropSize[relDim]-2*marginSize[relDim])
+        if nCropsPerDim<=0:
+            nCropsPerDim=1
+        nCrops*=math.ceil(nCropsPerDim)
+    return nCrops
+
+def noCropsPerDim(inSize,cropSize,marginSize,startDim=0):
+    # nCropsPerDim    - number of crops per dimension, starting from startDim
+    # cumNCropsPerDim - number of crops for one index step along a dimension,
+    #                   starting from startDim-1; i.e. it has one more element
+    #                   than nCropsPerDim, and is misaligned by an index
+    #                   difference of 1
+    nCropsPerDim=[]
+    cumNCropsPerDim=[1]
+    for dim in reversed(range(startDim,len(inSize))):
+        relDim=dim-startDim
+        nCrops=(inSize[dim]-2*marginSize[relDim])/ \
+               (cropSize[relDim]-2*marginSize[relDim])
+        if nCrops<=0:
+            nCrops=1
+        nCrops=math.ceil(nCrops)
+        nCropsPerDim.append(nCrops)
+        cumNCropsPerDim.append(nCrops*cumNCropsPerDim[len(inSize)-dim-1])
+    nCropsPerDim.reverse()
+    cumNCropsPerDim.reverse()
+    return nCropsPerDim, cumNCropsPerDim
+
+def cropInds(cropInd,cumNCropsPerDim):
+    # given a single index into the crops of a given data chunk,
+    # this function returns the indexes of the crop along all its dimensions
+    assert cropInd<cumNCropsPerDim[0]
+    inds=[]
+    for i in range(1,len(cumNCropsPerDim)):
+        inds.append(cropInd//cumNCropsPerDim[i])
+        cropInd=cropInd%cumNCropsPerDim[i]
+    return inds
+
+def coord(cropInd,cropSize,marg,inSize):
+    # maps the index of a crop along a single dimension to two slices:
+    # the extent of the crop within the big volume, and the part of the
+    # crop that is valid, i.e. computed with enough context
+    assert inSize>=cropSize
+    startind=cropInd*(cropSize-2*marg) # starting coord of the crop in the big vol
+    startValidInd=marg                 # starting coord of valid stuff in the crop
+    endValidInd=cropSize-marg
+    if startind >= inSize-cropSize:
+        # the last crop along a dimension is shifted back to fit the volume
+        startValidInd=cropSize+startind-inSize+marg
+        startind=inSize-cropSize
+        endValidInd=cropSize
+    if cropInd==0:
+        startValidInd=0
+    return slice(int(startind),int(startind+cropSize)), \
+           slice(int(startValidInd),int(endValidInd))
+
+def coords(cropInds,cropSizes,margs,inSizes,startDim):
+    # this function maps a table of crop indices
+    # to the starting and end coordinates of the crop
+    cropCoords=[]
+    validCoords=[]
+    for i in range(startDim):
+        cropCoords.append(slice(0,inSizes[i]))
+        validCoords.append(slice(0,inSizes[i]))
+    for i in range(startDim,len(inSizes)):
+        reli=i-startDim
+        c,d=coord(cropInds[reli],cropSizes[reli],margs[reli],inSizes[i])
+        cropCoords.append(c)
+        validCoords.append(d)
+    return cropCoords, validCoords
+
+def cropCoords(cropInd,cropSize,marg,inSize,startDim):
+    # a single index in, a table of crop coordinates out
+    nCropsPerDim,cumNCropsPerDim=noCropsPerDim(inSize,cropSize,marg,startDim)
+    cropIdx=cropInds(cropInd,cumNCropsPerDim)
+    cc,vc=coords(cropIdx,cropSize,marg,inSize,startDim)
+    return cc,vc
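To make the intended use of these routines concrete, here is a minimal sketch of tiling a block with overlapping crops and reassembling the valid parts. The array sizes are invented, and the identity mapping stands in for a network forward pass:

import numpy as np
import cropRoutines

inp=np.random.rand(1,1,100,100)    # batch and channel dims are not cropped
out=np.zeros_like(inp)
cropSize,margin=[40,40],[8,8]
nCrops=cropRoutines.noCrops(inp.shape,cropSize,margin,startDim=2)
for i in range(nCrops):
    cc,vc=cropRoutines.cropCoords(i,cropSize,margin,inp.shape,2)
    res=inp[tuple(cc)]             # placeholder for e.g. net(inp[tuple(cc)])
    # map the valid part of the crop back to big-volume coordinates
    dst=tuple(slice(c.start+v.start,c.start+v.stop) for c,v in zip(cc,vc))
    out[dst]=res[tuple(vc)]
assert (out==inp).all()            # the valid regions tile the whole volume

With these sizes the stride along each cropped dim is 40-2*8=24, so (100-16)/24 rounds up to 4 crops per dim, and the last crop is shifted back so that it ends exactly at the volume border.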
diff --git a/f1.py b/f1.py
new file mode 100644
index 0000000..52e2554
--- /dev/null
+++ b/f1.py
@@ -0,0 +1,43 @@
+import torch
+import numpy as np
+
+def reverse(t):
+    # flips a tensor along dim 0
+    idx = [i for i in range(t.size(0)-1, -1, -1)]
+    idx = torch.LongTensor(idx)
+    it = t.index_select(0, idx)
+    return it
+
+def PRFromHistograms(hPos,hNeg):
+    # hPos, hNeg - histograms of the scores of positive and negative elements;
+    # thresholding at each bin boundary yields one (precision, recall) pair
+    positives=hPos.sum()
+    truepositives =reverse(hPos.clone()).long().cumsum(dim=0)
+    falsepositives=reverse(hNeg.clone()).long().cumsum(dim=0)
+    predpositives=torch.add(truepositives,falsepositives)
+    # protect against division by zero
+    predpositives[predpositives==0]=1
+    precision=truepositives.float().div(predpositives.float())
+    recall   =truepositives.float().div(positives.float())
+    precision[precision<=0]=1e-12
+    recall   [recall   <=0]=1e-12
+    return precision,recall
+
+def PRFromOutGt(outps,targs,nbins=10000):
+    # outps, targs - iterables of numpy arrays holding network outputs in [0,1]
+    # and binary ground-truth labels; returns the best F1 over all thresholds
+    hPos=torch.zeros(nbins)
+    hNeg=torch.zeros(nbins)
+    for o,t in zip(outps,targs):
+        pos=o[t==1]
+        neg=o[t==0]
+        hPos+=torch.from_numpy(pos.astype(np.float32)).histc(nbins,0,1)
+        hNeg+=torch.from_numpy(neg.astype(np.float32)).histc(nbins,0,1)
+    precision,recall=PRFromHistograms(hPos,hNeg)
+    f1s=F1FromPR(precision,recall)
+    f=f1s.max()
+    return f
+
+def F1FromPR(p,r):
+    # F1 is the harmonic mean of precision and recall
+    suminv=torch.pow(p,-1)+torch.pow(r,-1)
+    f1s=torch.pow(suminv,-1).mul(2)
+    return f1s
diff --git a/loggerBasic.py b/loggerBasic.py
new file mode 100644
index 0000000..49ee006
--- /dev/null
+++ b/loggerBasic.py
@@ -0,0 +1,34 @@
+import os
+import torch
+
+class LoggerBasic:
+    def __init__(self, log_dir, name, saveNetEvery=500):
+        self.log_dir=log_dir
+        self.log_file=os.path.join(self.log_dir,"log"+name+".txt")
+
+        text_file = open(self.log_file, "w")
+        text_file.close()
+        self.loss=0
+        self.count=0
+        self.saveNetEvery=saveNetEvery
+        self.epoch=0
+
+    def add(self,l,output,target):
+        self.loss+=l
+        self.count+=1
+
+    def logEpoch(self,net):
+        text_file = open(self.log_file, "a")
+        text_file.write(str(self.loss/self.count))
+        text_file.write('\n')
+        text_file.close()
+        lastLoss=self.loss
+        self.loss=0
+        self.count=0
+        self.epoch+=1
+        if self.epoch % self.saveNetEvery == 0:
+            torch.save({'epoch': self.epoch,
+                        'state_dict': net.state_dict()},
+                       os.path.join(self.log_dir,
+                                    'net_last.pth'))
+        return lastLoss
diff --git a/loggerComposit.py b/loggerComposit.py
new file mode 100644
index 0000000..0e798ee
--- /dev/null
+++ b/loggerComposit.py
@@ -0,0 +1,13 @@
+class LoggerComposit:
+    def __init__(self, loggers):
+        self.loggers=loggers
+
+    def add(self,l,output,target):
+        for lgr in self.loggers:
+            lgr.add(l,output,target)
+
+    def logEpoch(self,net):
+        lastLoss=self.loggers[0].logEpoch(net)
+        for k in range(1,len(self.loggers)):
+            self.loggers[k].logEpoch(net)
+        return lastLoss
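LoggerComposit fans one add/logEpoch stream out to several loggers; only the first logger's epoch loss is propagated back to the caller. A small hypothetical wiring (the log directory and names are placeholders):

import os
from loggerBasic import LoggerBasic
from loggerComposit import LoggerComposit

os.makedirs('./logs',exist_ok=True)
# both loggers see every batch; the epoch loss returned to the trainer
# comes from the first logger in the list
logger=LoggerComposit([LoggerBasic('./logs','trainA'),
                       LoggerBasic('./logs','trainB',saveNetEvery=100)])
# per batch: logger.add(lossValue,output,target)
# per epoch: lastLoss=logger.logEpoch(net)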
diff --git a/loggerF1.py b/loggerF1.py
new file mode 100644
index 0000000..1829082
--- /dev/null
+++ b/loggerF1.py
@@ -0,0 +1,49 @@
+import os
+import torch
+from .f1 import PRFromHistograms, F1FromPR
+
+class LoggerF1:
+
+    def __init__(self,logdir,fname,transform,nBins=10000,saveBest=False):
+        self.log_dir=logdir
+        self.name=fname
+        self.log_file=os.path.join(self.log_dir,"logF1"+self.name+".txt")
+        text_file = open(self.log_file, "w")
+        text_file.close()
+        self.preproc=transform
+        self.nBins=nBins
+        self.hPos=torch.zeros(self.nBins)
+        self.hNeg=torch.zeros(self.nBins)
+        self.bestF1=0
+        self.saveBest=saveBest
+
+    def add(self,l,output,target):
+        # accumulate histograms of the scores of positive and negative elements
+        o,t=self.preproc(output.cpu().data,target.cpu().data)
+        pos=o[t==1]
+        neg=o[t==0]
+        self.hPos+=pos.histc(self.nBins,0,1)
+        self.hNeg+=neg.histc(self.nBins,0,1)
+
+    def logEpoch(self,net):
+        precision,recall=PRFromHistograms(self.hPos,self.hNeg)
+        f1s=F1FromPR(precision,recall)
+        f=f1s.max()
+        text_file=open(self.log_file, "a")
+        text_file.write('{}\n'.format(f))
+        text_file.close()
+        if self.saveBest and f > self.bestF1:
+            self.bestF1=f
+            torch.save({'state_dict': net.state_dict()},
+                       os.path.join(self.log_dir,
+                                    'net_'+self.name+'_bestF1.pth'))
+        self.hPos.zero_()
+        self.hNeg.zero_()
diff --git a/loggerIoU.py b/loggerIoU.py
new file mode 100644
index 0000000..1c8f255
--- /dev/null
+++ b/loggerIoU.py
@@ -0,0 +1,52 @@
+import os
+import torch
+import numpy as np
+import sklearn.metrics
+
+class LoggerIoU:
+    def __init__(self, log_dir, name, nClasses, ignoredIdx, saveBest=False,
+                 preproc=lambda o,t: (o,t)):
+        self.log_dir=log_dir
+        self.name=name
+        self.log_file=os.path.join(self.log_dir,"logIou"+self.name+".txt")
+        text_file = open(self.log_file, "w")
+        text_file.close()
+
+        self.nClasses=nClasses
+        self.confMat=np.zeros((nClasses,nClasses))
+        self.ignoredIdx=ignoredIdx
+        self.saveBest=saveBest
+        self.bestIoU=0
+        self.preproc=preproc
+
+    def add(self,l,output,target):
+        output=output.cpu().data
+        output,target=self.preproc(output,target)
+        output=output.numpy()
+
+        outputClass=np.argmax(output, axis=1)
+        oc=outputClass.flatten()
+        tc=target.cpu().data.numpy().flatten()
+        # elements labeled with the ignored index do not enter the confusion matrix
+        oc_valid=oc[tc!=self.ignoredIdx]
+        tc_valid=tc[tc!=self.ignoredIdx]
+        self.confMat+=sklearn.metrics.confusion_matrix(tc_valid,oc_valid,
+                                                       labels=np.arange(self.nClasses))
+
+    def logEpoch(self,net):
+        # per-class IoU = diagonal / (column sum + row sum - diagonal)
+        sums1=np.sum(self.confMat,axis=0)
+        sums2=np.sum(self.confMat,axis=1)
+        dg=np.diag(self.confMat)
+        iou=np.zeros(dg.shape)
+        iou=np.divide(dg.astype(np.float64),(sums1+sums2-dg).astype(np.float64),
+                      out=iou, where=(dg!=0))
+        text_file = open(self.log_file, "a")
+        for i in range(self.nClasses):
+            text_file.write('{}\t'.format(iou[i]))
+        mean_iou=np.mean(iou)
+        text_file.write('{}\n'.format(mean_iou))
+        text_file.close()
+        self.confMat.fill(0)
+        if mean_iou > self.bestIoU:
+            self.bestIoU=mean_iou
+            if self.saveBest:
+                torch.save({'state_dict': net.state_dict()},
+                           os.path.join(self.log_dir,
+                                        'net_'+self.name+'_best.pth'))
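As a sanity check on the IoU computation in logEpoch, here is the same formula on a toy two-class confusion matrix (numbers invented):

import numpy as np

confMat=np.array([[8., 2.],    # rows: ground truth, columns: prediction
                  [1., 9.]])
dg=np.diag(confMat)                                  # per-class true positives
iou=dg/(confMat.sum(axis=0)+confMat.sum(axis=1)-dg)  # intersection over union
print(iou)                     # [0.7272..., 0.75], mean IoU about 0.739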
diff --git a/tester.py b/tester.py
new file mode 100644
index 0000000..c8d9d19
--- /dev/null
+++ b/tester.py
@@ -0,0 +1,35 @@
+import sys
+import time
+
+class tester:
+
+    def __init__(self, test_loader, logger):
+        self.dataLoader=test_loader
+        self.logger=logger
+
+    def test(self, net):
+        net.eval()
+        self.di=iter(self.dataLoader)
+        local_iter=0
+        t0=time.time()
+        while True:
+            try:
+                data=next(self.di)
+                img, lbl = data
+                img, lbl = img.cuda(), lbl.long().cuda()
+                out=net.forward(img)
+                # the loss argument is unused at test time, hence 0
+                self.logger.add(0,out,lbl)
+                local_iter+=1
+                t1=time.time()
+                if t1-t0>3:
+                    sys.stdout.write('\rTest iter: %8d' % (local_iter))
+                    t0=t1
+            except StopIteration:
+                self.logger.logEpoch(net)
+                break
+        net.train()
diff --git a/testerComposite.py b/testerComposite.py
new file mode 100644
index 0000000..f83fb6c
--- /dev/null
+++ b/testerComposite.py
@@ -0,0 +1,15 @@
+class testerComposite:
+
+    def __init__(self, testers):
+        self.testers=testers
+
+    def test(self, net):
+        for t in self.testers:
+            t.test(net)
diff --git a/trainer.py b/trainer.py
new file mode 100644
index 0000000..4bb6328
--- /dev/null
+++ b/trainer.py
@@ -0,0 +1,57 @@
+import torch
+import sys
+import time
+
+class trainer:
+
+    def __init__(self, net, train_loader, optimizer, loss_function, logger,
+                 tester, test_every, lr_scheduler=None, lrStepPer='batch'):
+        self.net=net
+        self.dataLoader=train_loader
+        self.optimizer=optimizer
+        self.crit=loss_function.cuda()
+        self.logger=logger
+        self.di=iter(self.dataLoader)
+        self.epoch=0
+        self.tot_iter=0
+        self.prev_iter=self.tot_iter
+        self.test_every=test_every
+        self.tester=tester
+        self.lr_scheduler=lr_scheduler
+        self.lrStepPer=lrStepPer
+
+    def train(self, numiter):
+        self.net.train()
+        local_iter=0
+        t0=time.time()
+        while local_iter<numiter:
+            try:
+                data=next(self.di)
+                img, lbl = data
+                img, lbl = img.cuda(), lbl.cuda()
+                self.optimizer.zero_grad()
+                out=self.net.forward(img)
+                loss=self.crit(out,lbl)
+                loss.backward()
+                self.optimizer.step()
+                if self.lr_scheduler and self.lrStepPer=='batch':
+                    self.lr_scheduler.step()
+                self.logger.add(loss.item(),out,lbl)
+                local_iter+=1
+                self.tot_iter+=1
+                t1=time.time()
+                if t1-t0>3:
+                    sys.stdout.write('\rIter: %8d\tEpoch: %6d\tTime/iter: %6f' % (self.tot_iter, self.epoch, (t1-t0)/(self.tot_iter-self.prev_iter)))
+                    t0=t1
+                    self.prev_iter=self.tot_iter
+            except StopIteration:
+                lastLoss=self.logger.logEpoch(self.net)
+                self.epoch+=1
+                self.di=iter(self.dataLoader)
+                if self.test_every and self.epoch%self.test_every==0:
+                    self.tester.test(self.net)
+                if self.lr_scheduler and self.lrStepPer=='epoch':
+                    self.lr_scheduler.step(lastLoss)
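Finally, a sketch of how the pieces fit together. The toy network and random data below are invented; only the wiring between trainer, tester, and loggers comes from the code above. It assumes a CUDA device, since the trainer and tester hard-code .cuda(), and that the modules are importable from the working directory:

import os
import torch
import torch.utils.data as tud
from loggerBasic import LoggerBasic
from loggerIoU import LoggerIoU
from tester import tester
from trainer import trainer

os.makedirs('./logs',exist_ok=True)

# toy data: 8 one-channel 16x16 "images" with 2-class masks
def loader():
    imgs=torch.randn(8,1,16,16)
    lbls=torch.randint(0,2,(8,16,16))
    return tud.DataLoader(tud.TensorDataset(imgs,lbls),batch_size=4)

net=torch.nn.Conv2d(1,2,3,padding=1).cuda()   # stand-in "segmentation net"
valLogger=LoggerIoU('./logs','val',nClasses=2,ignoredIdx=255,saveBest=True)
trn=trainer(net,loader(),
            torch.optim.SGD(net.parameters(),lr=0.01),
            torch.nn.CrossEntropyLoss(),
            LoggerBasic('./logs','train'),
            tester(loader(),valLogger),test_every=1)
trn.train(100)   # 100 iterations; the tester runs after every epoch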