Page MenuHomec4science

datasetCrops.py
No OneTemporary

File Metadata

Created
Fri, Apr 26, 06:24

datasetCrops.py

import torch
from torch.utils.data import Dataset
import os
from random import randint, random
import numpy as np
import torch
import math
import skimage.io as io
import networkTraining.cropRoutines as cropRoutines
import bisect
class NeuronsTestDataset(Dataset):
    """Deterministic crop dataset over a list of (image, label) volumes.

    Every volume is tiled into crops using ``cropRoutines``. The crops of all
    volumes are concatenated into one flat index space, so ``dataset[k]``
    returns the k-th crop overall. Label voxels outside the valid (non-margin)
    region of a crop are overwritten with ``ignoreInd`` (presumably so that a
    loss/metric with an ignore index skips them -- confirm with the caller).
    """

    def __init__(self, imgDir, lblDir, fileList, cropSz, margSz, ignoreInd=255):
        """Load all volumes into memory and precompute the crop index table.

        Args:
            imgDir: directory containing the input volumes (read with ``np.load``).
            lblDir: directory containing the label volumes (read with ``np.load``).
            fileList: sequence of entries whose element ``[0]`` is the input
                volume's file name (joined with ``imgDir``) and element ``[1]``
                is the label volume's file name (joined with ``lblDir``).
            cropSz: crop size, forwarded to ``cropRoutines``.
            margSz: margin size, forwarded to ``cropRoutines``.
            ignoreInd: value written into label voxels that lie outside the
                valid region of a crop.
        """
        self.cropSz = cropSz
        self.margSz = margSz
        self.ignoreInd = ignoreInd
        self.img = []  # input volumes, converted to float32
        self.lbl = []  # label volumes, dtype as stored on disk
        # Cumulative crop counts: no_crops[k] is the global index of the first
        # crop of volume k, and no_crops[-1] is the total number of crops.
        self.no_crops = [0]
        for f in fileList:
            img = np.load(os.path.join(imgDir, f[0]))
            lbl = np.load(os.path.join(lblDir, f[1]))
            self.img.append(img.astype(np.float32))
            self.lbl.append(lbl)
            # The crop grid is computed from the label's shape. The trailing
            # ``0`` is forwarded exactly as in getCrop's cropCoords call; its
            # meaning is defined in cropRoutines (TODO confirm).
            nCrops = cropRoutines.noCrops(lbl.shape, cropSz, margSz, 0)
            self.no_crops.append(nCrops + self.no_crops[-1])

    def __len__(self):
        """Return the total number of crops over all volumes."""
        return self.no_crops[-1]

    def _locate(self, idx):
        """Map a global crop index to ``(volumeIndex, cropIndexWithinVolume)``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = self.no_crops[-1]
        k = idx + total if idx < 0 else idx
        if not 0 <= k < total:
            raise IndexError(f"crop index {idx} out of range for {total} crops")
        # bisect_right skips volumes that contribute zero crops (repeated values
        # in no_crops), landing on the volume whose range actually contains k.
        ind = bisect.bisect_right(self.no_crops, k) - 1
        return ind, k - self.no_crops[ind]

    def _maskMargins(self, clbl, vc):
        """Set every voxel of ``clbl`` outside the valid region ``vc`` to ``ignoreInd``.

        ``vc`` holds one slice per leading axis of ``clbl`` (as returned by
        ``cropRoutines.cropCoords``), expressed in crop coordinates. Along axis
        ``i``, voxels before ``vc[i].start`` and from ``vc[i].stop`` onward are
        overwritten. ``clbl`` is modified in place and also returned.
        """
        for axis, valid in enumerate(vc):
            # Keep all earlier axes whole; later axes are implicitly whole.
            lead = (slice(None),) * axis
            clbl[lead + (slice(0, valid.start),)] = self.ignoreInd
            clbl[lead + (slice(valid.stop, clbl.shape[axis]),)] = self.ignoreInd
        return clbl

    def getCrop(self, idx):
        """Return ``(imageCrop, labelCrop)`` for global crop index ``idx``.

        Both arrays are fresh copies. The image crop keeps the image's leading
        axis whole (presumably channels -- the image is assumed to have one more
        leading axis than the label) and applies the crop coordinates to the
        remaining axes. Label voxels in the crop margins are set to
        ``self.ignoreInd``.

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        ind, cropInd = self._locate(idx)
        lbl = self.lbl[ind]
        img = self.img[ind]
        cc, vc = cropRoutines.cropCoords(cropInd, self.cropSz, self.margSz, lbl.shape, 0)
        cimg = img[(slice(None),) + tuple(cc)].copy()
        clbl = lbl[tuple(cc)].copy()
        self._maskMargins(clbl, vc)
        return cimg, clbl

    def getAugmentedDataItem(self, cropSz, idx):
        """Return the ``idx``-th crop as ``(image, label)`` numpy arrays.

        No augmentation is applied and the crop is deterministic. ``cropSz`` is
        ignored; the crop size given at construction (``self.cropSz``) is used.
        """
        # getCrop already returns fresh copies; copying again is unnecessary.
        return self.getCrop(idx)

    def __getitem__(self, idx):
        """Return the ``idx``-th crop as an ``(img, lbl)`` pair of torch tensors.

        The tensors are created with ``torch.from_numpy`` and so share memory
        with the freshly copied numpy crops.
        """
        imgArr, lblArr = self.getAugmentedDataItem(self.cropSz, idx)
        return (torch.from_numpy(imgArr), torch.from_numpy(lblArr))

Event Timeline