Page Menu Home c4science

display_layer.py
No OneTemporary

File Metadata

Created
Sun, Feb 23, 11:33

display_layer.py

import os.path as path
import numpy as np
import argparse
from scipy.misc import imread, imsave
from u_net import get_unet
import theano
import math
import matplotlib.pyplot as plt
def expand(x):
    """Prepend a singleton axis to *x*.

    Equivalent to ``np.expand_dims(x, axis=0)``: an input of shape ``s``
    comes back with shape ``(1,) + s``.
    """
    new_axis_position = 0
    return np.expand_dims(x, axis=new_axis_position)
def normalize(x):
    """Min-max rescale *x* into the range [0, 1].

    The minimum maps to 0 and the maximum to 1. The input is not modified.
    Integer input is promoted to float64; floating input keeps its dtype.
    An input whose values are all equal maps to all zeros instead of
    dividing by zero. An empty input is returned (as float) unchanged.

    Returns:
        A new floating-point ndarray with the same shape as *x*.
    """
    x_out = np.asarray(x)
    # In-place or true division into an integer array would fail or truncate.
    if not np.issubdtype(x_out.dtype, np.floating):
        x_out = x_out.astype(np.float64)
    if x_out.size == 0:
        return x_out
    x_out = x_out - x_out.min()
    # After the shift, the max equals the original range (max - min). Dividing
    # by the original max, as before, does not reach 1 unless min == 0.
    span = x_out.max()
    if span > 0:
        x_out = x_out / span
    return x_out
def _tile_channels(act):
    """Lay the channels of a ``(C, H, W)`` array out on one 2-D grid.

    The grid has ``ceil(sqrt(C))`` rows and ``round(sqrt(C))`` columns, which
    always gives at least ``C`` cells. Channel ``k`` goes to row ``k // cols``
    and column ``k % cols``; unused cells stay zero.

    Returns:
        float64 array of shape ``(rows * H, cols * W)``.
    """
    n_channels, height, width = act.shape
    rows = int(math.ceil(math.sqrt(n_channels)))
    cols = int(np.round(math.sqrt(n_channels)))
    # Pad the channel axis so it fills the grid, then interleave grid and pixel
    # axes: (rows, cols, H, W) -> (rows, H, cols, W) -> (rows*H, cols*W).
    cells = np.zeros((rows * cols, height, width))
    cells[:n_channels] = act
    return (cells.reshape(rows, cols, height, width)
            .swapaxes(1, 2)
            .reshape(rows * height, cols * width))


def make_collage(model, image):
    """Return one tiled activation map for every layer of *model* after the first.

    A single Theano function maps the input of ``model.layers[0]`` to the
    outputs of all later layers. The network is therefore compiled and run
    once, not once per layer. For each layer, the activations of the first
    sample in the batch are tiled channel by channel into one 2-D array.
    Each layer object is printed as it is processed.

    NOTE(review): the first axis of each per-sample activation is treated as
    the channel axis, i.e. channels-first ``(C, H, W)`` outputs (Theano dim
    ordering) are assumed -- confirm against ``get_unet``.

    Args:
        model: Keras model exposing ``layers`` with ``.input`` / ``.output``.
        image: input batch fed to ``model.layers[0].input``.

    Returns:
        List of 2-D float arrays; ``images[i]`` belongs to ``model.layers[i + 1]``.
    """
    layers = model.layers[1:]
    if not layers:
        return []
    get_activations = theano.function(
        [model.layers[0].input],
        [layer.output for layer in layers],
        allow_input_downcast=True)
    all_outputs = get_activations(image)

    images = []
    for layer, batch_activations in zip(layers, all_outputs):
        print(layer)
        act = np.asarray(batch_activations[0])
        # Lower-rank outputs (e.g. dense/reshape layers) previously crashed on
        # shape[1]/shape[2]. Lift them to (1, H, W) so every layer still yields
        # exactly one map and output indices stay aligned with layers.
        while act.ndim < 3:
            act = act[np.newaxis]
        images.append(_tile_channels(act))
    return images
def main():
    """Write one colour-mapped activation collage per U-Net layer for an image.

    Command line: ``WEIGHTS IMAGE [--width N]``. The script:

    - builds the network with ``get_unet(width, width)`` (width defaults to
      128, as before);
    - loads WEIGHTS and runs IMAGE through it;
    - saves one PNG per layer to the current directory, named
      ``<image basename><weights basename><layer index>.png``.
    """
    parser = argparse.ArgumentParser(description='Test model on image.')
    parser.add_argument('w', metavar='w', type=str, nargs=1,
                        help='weights')
    parser.add_argument('i', metavar='i', type=str, nargs=1,
                        help='image')
    parser.add_argument('--width', type=int, default=128,
                        help='input side length passed to get_unet '
                             '(default: 128)')
    args = parser.parse_args()
    weights_path = args.w[0]
    image_path = args.i[0]
    # parser.error exits with a usage message; unlike assert it is not
    # stripped when Python runs with -O.
    if not path.isfile(image_path):
        parser.error('image file not found: {}'.format(image_path))
    if not path.isfile(weights_path):
        parser.error('weights file not found: {}'.format(weights_path))

    # Two leading singleton axes turn a 2-D image into (1, 1, H, W).
    # NOTE(review): this assumes a single-channel (grayscale) image of size
    # width x width; an RGB file would become (1, 1, H, W, 3) -- confirm.
    input_batch = expand(expand(np.array(imread(image_path))))
    model, _inputs, _x, _center = get_unet(args.width, args.width)
    model.load_weights(weights_path)

    # Only the index goes through str.format, so basenames containing '{' or
    # '}' cannot break the formatting. The file names are unchanged otherwise.
    prefix = path.basename(image_path) + path.basename(weights_path)
    cmap = plt.cm.inferno
    for idx, activation_map in enumerate(make_collage(model, input_batch)):
        norm = plt.Normalize(vmin=activation_map.min(),
                             vmax=activation_map.max())
        # Map normalized activations to an RGBA array of shape
        # activation_map.shape + (4,).
        rgba = cmap(norm(activation_map))
        out_path = path.join("./", prefix + "{}.png".format(idx))
        plt.imsave(out_path, rgba)
# Run only when executed as a script, so importing this module (e.g. to reuse
# make_collage) does not parse sys.argv or start loading a model.
if __name__ == "__main__":
    main()

Event Timeline