Page Menu · Home · c4science

segmentTestSet.py
No One · Temporary

File Metadata

Created
Mon, May 13, 09:57

segmentTestSet.py

import cv2
import IPython.display
import importlib
import skimage.io as imgio
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from net_v1 import UNet3d
import os
from networkTraining.forwardOnBigImages import processChunk
# Directory holding the test-set input volumes (loaded with np.load below).
imgdir = "/cvlabdata2/home/kozinski/experimentsTorch/bbp_neurons/data_npy/img/test/"

# testFiles.txt is executed as Python source; presumably it defines the
# `testFiles` list iterated over further down (it is not defined elsewhere in
# this script). The `with` block closes the file handle, which the original
# `exec(open(...).read())` leaked.
# NOTE(review): exec runs arbitrary code from that file - only use it with a
# trusted, locally maintained file list.
with open("testFiles.txt") as fh:
    exec(fh.read())

log_dir = "log_v1"

# Build the network on the GPU and restore the weights of the last checkpoint
# saved in log_dir. The checkpoint is a dict with a 'state_dict' entry.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
net = UNet3d().cuda()
saved_net = torch.load(os.path.join(log_dir, "net_last.pth"))
net.load_state_dict(saved_net['state_dict'])
net.eval()  # switch modules to inference mode before prediction

# Name of the results sub-directory created under log_dir.
out_dir = "test_last"
def process_output(o, channel=1):
    """Map one output channel of the network to per-voxel probabilities.

    Applies the logistic sigmoid 1 / (1 + exp(-x)) element-wise to
    ``o[0, channel, :, :, :]`` (first item of the batch).

    Args:
        o: Network output array, indexed as ``o[item, channel, d, h, w]``.
        channel: Index of the channel converted to probabilities. Defaults to
            1, which is the channel the script has always used.

    Returns:
        Array shaped like ``o[0, channel]`` with values in [0, 1].
    """
    x = o[0, channel, :, :, :]
    # sigmoid(x) == 0.5 * (1 + tanh(x / 2)) exactly. The previous form
    # exp(x) / (exp(x) + 1) overflows for large x (about 88.7 in float32),
    # giving inf / inf = NaN for the most confident voxels. tanh saturates
    # to +-1 instead, so the result stays finite and in [0, 1].
    # NOTE(review): this treats the channel as a raw logit; if the network
    # emits log-softmax scores, np.exp(x) alone would be the probability -
    # confirm against the model definition.
    return 0.5 * (1.0 + np.tanh(0.5 * x))
# Results go to <log_dir>/<out_dir>. Tolerate an existing directory so an
# interrupted or repeated run can be restarted instead of failing on
# os.makedirs; outputs with the same name are overwritten by np.save.
outdir = os.path.join(log_dir, out_dir)
if not os.path.isdir(outdir):
    os.makedirs(outdir)

for f in testFiles:
    # f[0] is the volume's file name, relative to imgdir.
    img = np.load(os.path.join(imgdir, f[0])).astype(np.float32)
    # Prepend two singleton axes so the last three dims of the volume become
    # a 5-D array (1, 1, D, H, W) - presumably (batch, channel, ...) for UNet3d.
    inp = img.reshape(1, 1, img.shape[-3], img.shape[-2], img.shape[-1])
    # Run the network over the volume via processChunk with 104^3 chunks and a
    # (22, 22, 22) setting; the exact meaning of these arguments and of the
    # `2` is defined by processChunk - TODO confirm.
    oup = processChunk(inp, (104, 104, 104), (22, 22, 22), 2, net, outChannels=2)
    prob = process_output(oup)
    # Save under the input's base name in the results directory.
    np.save(os.path.join(outdir, os.path.basename(f[0])), prob)

Event Timeline