Page Menu · Home · c4science

domain_adaption.py
No OneTemporary

File Metadata

Created
Sun, Feb 23, 07:18

domain_adaption.py

from data import load_domain_adaption_data, load_training_data
import os
import datetime
import argparse
import keras
from misc import *
from scipy.misc import imsave
from data import normalize
class CurrentSegmentation(keras.callbacks.Callback):
    """Keras callback that periodically writes segmentation previews to disk.

    Every ``interval`` training batches (counted across all epochs), the
    callback runs ``self.model`` on one fixed (source, target) image pair and
    writes four PNGs: the two predicted segmentations and the two input
    images. File names are built as
    ``image_out_path + "_" + <completed epochs> + <suffix>``, so several
    dumps taken within the same epoch overwrite one another.

    Args:
        images: Two-element sequence ``[source_image, target_image]``.
            NOTE(review): each image goes through ``expand`` (provided by
            ``misc``) before prediction; presumably this adds a batch axis.
            Confirm.
        image_out_path: String prefix for output file names. It is
            concatenated, not path-joined; e.g. ``"./"`` yields
            ``"./_0_source.png"``.
        interval: Number of batches between dumps.
    """

    def __init__(self, images, image_out_path, interval=50):
        super(CurrentSegmentation, self).__init__()
        self.image_out_path = image_out_path
        self.images = images
        self.count = 0  # completed epochs; embedded in output file names
        # Honour the argument (it was previously hard-coded to 100 and ignored).
        self.interval = interval
        self.batch_idx = 0  # batches seen so far, across all epochs

    def on_epoch_end(self, batch, logs=None):
        # Keras passes the epoch index as the first argument; only a running
        # count of finished epochs is kept.
        self.count += 1

    def on_batch_end(self, batch, logs=None):
        if self.batch_idx % self.interval == 0:
            self._dump_images()
        self.batch_idx += 1

    def _dump_images(self):
        """Predict on the stored image pair and write the four preview PNGs."""
        source_img, target_img = self.images[0], self.images[1]
        outputs = self.model.predict([expand(source_img), expand(target_img)])
        # Squeeze each output on its own instead of stacking them first, so
        # source/target outputs of different shapes do not break the dump.
        source_seg = np.squeeze(outputs[0])
        target_seg = np.squeeze(outputs[1])
        prefix = self.image_out_path + "_" + str(self.count)
        imsave(prefix + "_source_seg.png", source_seg)
        imsave(prefix + "_target_seg.png", target_seg)
        imsave(prefix + "_source.png", np.squeeze(source_img))
        imsave(prefix + "_target.png", np.squeeze(target_img))
def _parse_args():
    """Parse the command-line arguments for a domain-adaptation training run."""
    parser = argparse.ArgumentParser(description='train model.')
    parser.add_argument('n_s', metavar='n_s', type=int,
                        help='number of images source')
    parser.add_argument('n_t', metavar='n_t', type=int,
                        help='number of images target')
    parser.add_argument('s', metavar='s', type=str,
                        help='input folder source')
    parser.add_argument('t', metavar='t', type=str,
                        help='input folder target')
    parser.add_argument('epochs', metavar='e', type=int, help='epochs')
    parser.add_argument('o', metavar='o', type=str, help='output')
    parser.add_argument("-w", "--weights",
                        help="preload weights")
    return parser.parse_args()


def main():
    """Train the joint source/target network from the command line.

    Loads ``n_s`` source and ``n_t`` target samples and normalizes the
    inputs. It then builds the joint network, optionally preloading weights,
    and writes the model architecture to ``<out>/<id>model.json``. Next it
    trains with a ``CurrentSegmentation`` preview callback. Finally it saves
    the trained weights to ``<out>/<id>weights.h5``.
    """
    args = _parse_args()
    n_s = args.n_s
    n_t = args.n_t
    source_path = args.s
    target_path = args.t
    out_path = args.o
    epochs = args.epochs
    weights = args.weights

    source_x_train, target_x_train, source_y_train, target_y_train = \
        load_domain_adaption_data(source_path, target_path,
                                  range(n_s), range(n_t))
    source_x_train = np.array([normalize(x.astype(float)) for x in source_x_train])
    target_x_train = np.array([normalize(x.astype(float)) for x in target_x_train])

    # NOTE(review): takes axes 2 and 3 as (H, W), i.e. assumes channels-first
    # (N, C, H, W) source arrays. Confirm against load_domain_adaption_data.
    image_size = source_x_train.shape[2], source_x_train.shape[3]
    print("weights: {}".format(weights))
    model, source_model, target_model = create_joint_network(image_size, weights)

    # NOTE(review): the run id omits hour/minute/second, so it does not sort
    # chronologically within a day. Kept as-is to preserve file naming.
    now = datetime.datetime.now()
    idx = "{}_{}_{}_{}".format(now.year, now.month, now.day, now.microsecond)
    with open(os.path.join(out_path, "{}model.json".format(idx)), "w") as f:
        f.write(model.to_json())

    # Fixed preview pair. Clamp the indices so small datasets (n_s <= 10 or
    # n_t <= 4) do not raise IndexError.
    preview_source = source_x_train[min(10, len(source_x_train) - 1)]
    preview_target = target_x_train[min(4, len(target_x_train) - 1)]
    image_generator = CurrentSegmentation([preview_source, preview_target], "./")

    model.fit([source_x_train, target_x_train], [source_y_train, target_y_train],
              nb_epoch=epochs, batch_size=10, callbacks=[image_generator])
    # The second parameter of save_weights is the boolean ``overwrite`` flag,
    # not a file mode; pass it explicitly.
    model.save_weights(os.path.join(out_path, "{}weights.h5".format(idx)),
                       overwrite=True)
# Only start training when run as a script, not when the module is imported.
if __name__ == "__main__":
    main()

Event Timeline