Page Menu · Home · c4science

main.py
No One · Temporary

File Metadata

Created
Sun, Feb 23, 12:50
from u_net import get_unet
from res_u_net import get_res_unet
from reunet import get_unet as get_reunet
from data import load_training_data, prepare_image, load_disparity_training_data, normalize, expand
import os
import datetime
import argparse
import keras
import numpy as np
from scipy.misc import imsave
from augment import ImageDataGenerator
from scipy.misc import imrotate
from scipy.ndimage.morphology import distance_transform_edt
class CurrentSegmentation(keras.callbacks.Callback):
    """Keras callback that saves a snapshot prediction as PNG files after each epoch.

    At the end of every epoch the model predicts on one fixed sample, and the
    prediction (plus the sample and/or its ground truth) is written with
    ``imsave`` under the prefix ``<image_out_path>_<count>``. This lets
    training progress be inspected visually.

    Args:
        image: the input sample to predict on; it is passed through
            ``prepare_image`` before ``model.predict``.
        out: the ground-truth target for ``image``.
        image_out_path: path prefix for every file written by this callback.
    """

    def __init__(self, image, out, image_out_path):
        super(CurrentSegmentation, self).__init__()
        self.image_out_path = image_out_path
        self.image = image
        self.out = out
        # Number of snapshots written so far; used as the file-name index.
        self.count = 0

    def _snapshot_path(self, suffix):
        """Return ``<image_out_path>_<count><suffix>`` for the current snapshot."""
        return "{}_{}{}".format(self.image_out_path, self.count, suffix)

    def on_epoch_end(self, batch, logs=None):
        """Predict on the stored sample and write this epoch's images.

        Args:
            batch: Keras passes the epoch index in this position. It is unused;
                the internal ``count`` is used for file names instead.
            logs: metrics passed by Keras; unused.
        """
        prediction = self.model.predict(prepare_image(self.image))
        if prediction.shape[1] == 2:
            # Two output channels: save each predicted channel (x, y) next to
            # the matching ground-truth channel.
            imsave(self._snapshot_path("_x.png"), np.squeeze(prediction[0][0]))
            imsave(self._snapshot_path("_y.png"), np.squeeze(prediction[0][1]))
            imsave(self._snapshot_path("_realx.png"), np.squeeze(self.out[0]))
            imsave(self._snapshot_path("_realy.png"), np.squeeze(self.out[1]))
        else:
            imsave(self._snapshot_path("_original.png"), np.squeeze(self.image))
            imsave(self._snapshot_path(".png"), np.squeeze(prediction))
            imsave(self._snapshot_path("_truth.png"), np.squeeze(self.out))
        self.count += 1
def _rotate_sample(sample, quarter_turns):
    """Rotate a sample by ``quarter_turns * 90`` degrees counter-clockwise.

    The sample is squeezed, rotated in its first two remaining axes and passed
    back through ``expand``.

    Raises:
        ValueError: for an odd number of quarter turns on a non-square sample,
            because such a rotation would change the sample's shape.
    """
    plane = np.squeeze(sample)
    if quarter_turns % 2 and plane.shape[0] != plane.shape[1]:
        raise ValueError(
            "cannot rotate a non-square sample of shape {} by 90 degrees "
            "without changing its shape".format(plane.shape))
    return expand(np.rot90(plane, quarter_turns))


def add_all_rotations(x_train_in, y_train_in):
    """Augment a training set with its 90, 180 and -90 degree rotations.

    The result holds all original samples first. They are followed, sample by
    sample, by three rotated copies of each: 90, 180 and -90 degrees
    counter-clockwise, in that order. Rotated targets are thresholded with
    ``> 0``, so they come back as boolean masks. Original targets are kept
    unchanged.

    ``np.rot90`` is used rather than ``scipy.misc.imrotate``. imrotate
    byte-scales non-uint8 input to uint8 0-255, which would put rotated
    samples on a different scale than the normalized originals. Quarter turns
    via ``np.rot90`` are exact and keep the dtype and value range.

    Args:
        x_train_in: iterable of input samples.
        y_train_in: iterable of targets aligned with ``x_train_in``.

    Returns:
        ``(x, y)`` numpy arrays with four times as many samples each.

    Raises:
        ValueError: if a sample is not square.
    """
    # Materialize once so generator inputs are not consumed by the first pass.
    x_samples = list(x_train_in)
    y_samples = list(y_train_in)
    x_train_out = list(x_samples)
    y_train_out = list(y_samples)
    for x, y in zip(x_samples, y_samples):
        for quarter_turns in (1, 2, -1):
            x_train_out.append(_rotate_sample(x, quarter_turns))
            y_train_out.append(_rotate_sample(y, quarter_turns) > 0)
    return np.array(x_train_out), np.array(y_train_out)
def _parse_args(default_batch_size):
    """Build the training command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='train model.')
    parser.add_argument('n', metavar='n', type=int, help='number of images')
    parser.add_argument('f', metavar='f', type=str, help='input folder')
    parser.add_argument('epochs', metavar='e', type=int, help='epochs')
    parser.add_argument('o', metavar='o', type=str, help='output')
    parser.add_argument("-w", "--weights",
                        help="preload weights")
    parser.add_argument("--disparity", action="store_true",
                        help="train for disparity map")
    parser.add_argument("--batch-size", type=int, default=default_batch_size,
                        help="mini-batch size (default: %(default)s)")
    return parser.parse_args()


def _load_data(input_folder, n, disparity):
    """Load samples ``range(n)`` from *input_folder* and normalize the inputs.

    Uses ``load_disparity_training_data`` when *disparity* is true, otherwise
    ``load_training_data``. Each input is cast to float and passed through
    ``normalize``; targets are returned unchanged.
    """
    loader = load_disparity_training_data if disparity else load_training_data
    x_train, y_train = loader(input_folder, range(n))
    x_train = np.array([normalize(x.astype(float)) for x in x_train])
    return x_train, y_train


def _save_model(model, out_path, idx):
    """Write ``<idx>model.json`` (architecture) and ``<idx>weights.h5`` into *out_path*."""
    with open(os.path.join(out_path, "{}model.json".format(idx)), "w") as f:
        f.write(model.to_json())
    # save_weights takes (filepath, overwrite=True); it has no file-mode argument.
    model.save_weights(os.path.join(out_path, "{}weights.h5".format(idx)))


def main(batch_size=1):
    """Train a U-Net from command-line arguments and save the result.

    Usage: ``n f e o [-w WEIGHTS] [--disparity] [--batch-size B]``

    - ``n``: number of images to load.
    - ``f``: input folder.
    - ``e``: number of epochs.
    - ``o``: output folder, created if missing.

    During training, ``CurrentSegmentation`` writes per-epoch snapshots of the
    first sample into the output folder. Afterwards the model architecture
    (JSON) and weights are saved there, all prefixed with a timestamp id.

    Args:
        batch_size: default mini-batch size, overridable with ``--batch-size``.
    """
    args = _parse_args(batch_size)
    out_path = args.o
    # Create the output directory before training. The per-epoch snapshots
    # and the final model files are all written there, and finding out it is
    # missing only after training would lose the trained model.
    if not os.path.isdir(out_path):
        os.makedirs(out_path)

    x_train, y_train = _load_data(args.f, args.n, args.disparity)
    # Augmentation is currently disabled; add_all_rotations(x_train, y_train)
    # can be applied here to add rotated copies of every sample.
    print("Max: {}".format(x_train.max()))
    print("augmented size: {}".format(x_train.shape))

    # Spatial size is read from axes 2 and 3, i.e. presumably channels-first
    # (N, C, H, W) data -- TODO confirm against data.py.
    # get_unet returns (model, inputs, x, center); only the model is used here.
    model, _, _, _ = get_unet(x_train.shape[2], x_train.shape[3],
                              classification=not args.disparity, k=64,
                              conv_per_level=4, batch_normalization=False)
    if args.weights is not None:
        model.load_weights(args.weights)

    now = datetime.datetime.now()
    idx = "{}_{}_{}_{}".format(now.year, now.month, now.day, now.microsecond)
    model.fit(x_train, y_train, nb_epoch=args.epochs, batch_size=args.batch_size,
              validation_split=0.05,
              callbacks=[CurrentSegmentation(x_train[0], y_train[0],
                                             os.path.join(out_path, idx))])
    _save_model(model, out_path, idx)
# Only start training when run as a script, not when this module is imported
# (e.g. to reuse CurrentSegmentation or add_all_rotations).
if __name__ == "__main__":
    main()

Event Timeline