Page MenuHomec4science

dataset.py
No OneTemporary

File Metadata

Created
Thu, Apr 25, 22:54

dataset.py

import torch
from torch.utils.data import Dataset
import os
from random import randint, random
import numpy as np
import torch
import math
def remapIntensity(inp, max_z):
    """Apply a randomly drawn gamma curve to ``inp`` (intensity augmentation).

    Based on: full-resolution residual networks, appendix A.
    A value ``z = (random() - 0.5) * max_z`` is drawn, i.e. uniformly from
    ``[-max_z/2, max_z/2)`` for positive ``max_z``, and turned into

        gamma = log(0.5 + z/sqrt(2)) / log(0.5 - z/sqrt(2))

    so ``gamma == 1`` for ``z == 0`` and ``gamma(-z) == 1 / gamma(z)``.

    Args:
        inp: array of intensities. NOTE(review): presumably normalized to
            [0, 1]; negative entries become NaN under a non-integer power.
        max_z: width of the sampling interval for ``z``. Must satisfy
            ``abs(max_z) < sqrt(2)``; otherwise a log argument can become
            zero or negative for some draws.

    Returns:
        ``np.power(inp, gamma)``; ``inp`` itself is not modified.

    Raises:
        ValueError: if ``abs(max_z) >= sqrt(2)`` (or ``max_z`` is NaN).
    """
    # |z/sqrt(2)| must stay strictly below 0.5 so both log arguments are
    # positive; the draw reaches -max_z/(2*sqrt(2)) when random() == 0.
    # Checking up front turns an intermittent failure into a deterministic one.
    if not abs(max_z) < math.sqrt(2):
        raise ValueError(
            "max_z must satisfy abs(max_z) < sqrt(2), got %r" % (max_z,))
    z_over_sqrt2 = (random() - 0.5) * max_z / math.sqrt(2)
    gamma = math.log(z_over_sqrt2 + 0.5) / math.log(-z_over_sqrt2 + 0.5)
    return np.power(inp, gamma)
class NeuronsDataset(Dataset):
    """In-memory dataset of image/label volumes served as augmented crops.

    Every file named in ``fileList`` is loaded with ``np.load`` in the
    constructor and kept in memory. ``__getitem__(idx)`` draws a fresh random
    crop of volume ``idx``, randomly flips it along each of the three spatial
    axes and applies a random gamma curve to the image.

    From the indexing below, each label volume is 3-D ``(D1, D2, D3)`` and the
    matching image is 4-D ``(C, D1, D2, D3)``; the leading image axis is
    presumably channels -- TODO confirm against the data files.
    """

    def __init__(self, imgDir, lblDir, fileList, cropSz, inteVal):
        """Load every image/label pair listed in ``fileList``.

        Args:
            imgDir: directory holding the input ``.npy`` files.
            lblDir: directory holding the label ``.npy`` files.
            fileList: sequence of two-element lists; element 0 is the name of
                the input file, element 1 the name of the label file.
            cropSz: crop size along the three spatial axes.
            inteVal: coefficient for image intensity augmentation, passed as
                ``max_z`` to ``remapIntensity``.
        """
        self.cropSz = cropSz
        self.img = []  # image volumes, converted to float32
        self.lbl = []  # label volumes after the remap below
        self.inteVal = inteVal  # coefficient for image intensity augmentation
        # Number of voxels labelled 0 and 1, summed over all volumes
        # (presumably used for class weighting -- verify against caller).
        # Explicit int64 so the totals cannot overflow a 32-bit default int.
        self.count = np.array([0, 0], dtype=np.int64)
        for f in fileList:
            img = np.load(os.path.join(imgDir, f[0]))
            lbl = np.load(os.path.join(lblDir, f[1]))
            # In the label files the "ignore" label (the margins around
            # centerlines) is 0 and every other label is offset by one.
            # Shift the real labels down by one and leave "ignore" at 0,
            # i.e. treat the margins as background. Unlike subtracting one
            # and mapping 255 back to 0, this does not rely on uint8
            # wrap-around, so signed label dtypes work too.
            lbl[lbl != 0] -= 1
            self.img.append(img.astype(np.float32))
            self.lbl.append(lbl)
            self.count[0] += np.count_nonzero(lbl == 0)
            self.count[1] += np.count_nonzero(lbl == 1)

    def __len__(self):
        """Return the number of loaded volumes (not the number of crops)."""
        return len(self.lbl)

    def getAugmentedDataItem(self, cropSz, idx):
        """Return one randomly augmented ``(image, label)`` crop of volume ``idx``.

        Augmentations, in order: a uniformly placed crop of size ``cropSz``,
        an independent 50% flip along each spatial axis, and a random gamma
        remap of the image via ``remapIntensity``.

        Args:
            cropSz: crop size along the three spatial axes.
            idx: index of the volume to sample from.

        Returns:
            ``(img, lbl)`` numpy arrays; ``img`` keeps its leading axis
            uncropped and ``lbl`` is a copy, not a view of the stored volume.

        Raises:
            ValueError: if ``cropSz`` exceeds the volume size along an axis.
        """
        volImg = self.img[idx]
        volLbl = self.lbl[idx]
        # random crop: pick a start index per spatial axis
        starts = []
        for axis in range(3):
            maxStart = volLbl.shape[axis] - cropSz[axis]
            if maxStart < 0:
                raise ValueError(
                    "crop size %r exceeds volume %r of shape %r along axis %d"
                    % (cropSz, idx, volLbl.shape, axis))
            starts.append(randint(0, maxStart))
        window = tuple(slice(s, s + c) for s, c in zip(starts, cropSz))
        img = volImg[(slice(None),) + window]
        lbl = volLbl[window]
        # flip: spatial axis k of the label is axis k + 1 of the image
        for axis in range(3):
            if random() > 0.5:
                img = np.flip(img, axis + 1)
                lbl = np.flip(lbl, axis)
        # intensity augmentation; the copy hands remapIntensity an
        # independent array instead of a (possibly flipped) view
        img = remapIntensity(img.copy(), self.inteVal)
        return img, lbl.copy()

    def __getitem__(self, idx):
        """Return an augmented ``(image, label)`` pair as torch tensors.

        The label tensor is int64; the image tensor's dtype follows the
        array returned by ``remapIntensity``.
        """
        img, lbl = self.getAugmentedDataItem(self.cropSz, idx)
        # np.long was removed in NumPy 1.24 (and re-added in 2.0 as the
        # platform C long, 32-bit on Windows). np.int64 yields the 64-bit
        # (torch.long) label tensor intended here on every platform.
        return torch.from_numpy(img), torch.from_numpy(lbl.astype(np.int64))

Event Timeline