main.py
Attached to: R964 lineseg
from u_net import get_unet
from res_u_net import get_res_unet
from reunet import get_unet as get_reunet
from data import load_training_data, prepare_image, load_disparity_training_data, normalize, expand
import os
import datetime
import argparse
import keras
import numpy as np
from scipy.misc import imsave
from augment import ImageDataGenerator
from scipy.misc import imrotate
from scipy.ndimage.morphology import distance_transform_edt


class CurrentSegmentation(keras.callbacks.Callback):
    """Keras callback that writes the model's prediction for one reference
    image (and the corresponding ground truth) to PNG files after each epoch."""

    def __init__(self, image, out, image_out_path):
        keras.callbacks.Callback.__init__(self)
        self.image_out_path = image_out_path
        self.image = image
        self.out = out
        self.count = 0

    def on_epoch_end(self, batch, logs={}):
        image = self.model.predict(prepare_image(self.image))
        if image.shape[1] == 2:
            # Two-channel output: save predicted and true x/y channels separately.
            imsave(self.image_out_path + "_" + str(self.count) + "_x.png", np.squeeze(image[0][0]))
            imsave(self.image_out_path + "_" + str(self.count) + "_y.png", np.squeeze(image[0][1]))
            imsave(self.image_out_path + "_" + str(self.count) + "_realx.png", np.squeeze(self.out[0]))
            imsave(self.image_out_path + "_" + str(self.count) + "_realy.png", np.squeeze(self.out[1]))
        else:
            # Single-channel output: save the input, the prediction and the ground truth.
            imsave(self.image_out_path + "_" + str(self.count) + "_original.png", np.squeeze(self.image))
            imsave(self.image_out_path + "_" + str(self.count) + ".png", np.squeeze(image))
            imsave(self.image_out_path + "_" + str(self.count) + "_truth.png", np.squeeze(self.out))
        self.count += 1


def add_all_rotations(x_train_in, y_train_in):
    """Augment the training set with 90, 180 and -90 degree rotations of every
    image/label pair; labels are re-binarized after rotation."""
    x_train_out = []
    y_train_out = []
    x_train_out.extend(x_train_in)
    y_train_out.extend(y_train_in)
    f = lambda v, angle: expand(imrotate(np.squeeze(v), angle))
    for x, y in zip(x_train_in, y_train_in):
        x_train_out.append(f(x, 90))
        x_train_out.append(f(x, 180))
        x_train_out.append(f(x, -90))
        y_train_out.append(f(y, 90) > 0)
        y_train_out.append(f(y, 180) > 0)
        y_train_out.append(f(y, -90) > 0)
    # fig, axes = plt.subplots(2, 3)
    # n = len(x_train_out)
    # for i in range(3):
    #     axes[0, i].imshow(np.squeeze(x_train_out[n - i - 1]))
    #     # axes[1, i].imshow(np.squeeze(y_train_out[n - i - 1]))
    #     axes[1, i].imshow(
    #         np.exp(-0.1 * distance_transform_edt(np.squeeze(y < 1))) if np.sum(y.flatten()) > 0 else np.squeeze(y),
    #         cmap='gray')
    # plt.show()
    return np.array(x_train_out), np.array(y_train_out)


def main(batch_size=1):
    """Parse command-line arguments, load the training data, build a U-Net,
    train it, and save the resulting architecture and weights."""
    parser = argparse.ArgumentParser(description='train model.')
    parser.add_argument('n', metavar='n', type=int, nargs=1, help='number of images')
    parser.add_argument('f', metavar='f', type=str, nargs=1, help='input folder')
    parser.add_argument('epochs', metavar='e', type=int, nargs=1, help='epochs')
    parser.add_argument('o', metavar='o', type=str, nargs=1, help='output')
    parser.add_argument("-w", "--weights", help="preload weights")
    parser.add_argument("--disparity", action="store_true", help="train for disparity map")
    args = parser.parse_args()
    n = args.n[0]
    input_folder = args.f[0]
    out_path = args.o[0]
    epochs = args.epochs[0]
    weights = args.weights
    disparity = args.disparity

    if disparity:
        x_train, y_train = load_disparity_training_data(input_folder, range(n))
    else:
        x_train, y_train = load_training_data(input_folder, range(n))
    x_train = np.array([normalize(x.astype(float)) for x in x_train])
    # x_train, y_train = add_all_rotations(x_train, y_train)
    print("Max: {}".format(x_train.flatten().max()))
    print("augmented size: {}".format(x_train.shape))

    model, inputs, x, center = get_unet(x_train.shape[2], x_train.shape[3],
                                        classification=not disparity, k=64,
                                        conv_per_level=4, batch_normalization=False)
    if weights is not None:
        model.load_weights(weights)

    now = datetime.datetime.now()
    idx = "{}_{}_{}_{}".format(now.year, now.month, now.day, now.microsecond)
    model.fit(x_train, y_train, nb_epoch=epochs, batch_size=batch_size,
              validation_split=0.05,
              callbacks=[CurrentSegmentation(x_train[0], y_train[0],
                                             os.path.join(out_path, idx))])

    # Serialize the architecture as JSON and the weights as HDF5.
    with open(os.path.join(out_path, "{}model.json".format(idx)), "w") as f:
        f.write(model.to_json())
    model.save_weights(os.path.join(out_path, "{}weights.h5".format(idx)), overwrite=True)


if __name__ == "__main__":
    main()
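The script above writes two artifacts per run, "{idx}model.json" and "{idx}weights.h5", into the output folder. A minimal sketch of how those artifacts could be reloaded for prediction follows; it assumes the same Keras 1.x-era API used in the script, and the file paths, the load_trained_model helper and the example input array are hypothetical placeholders that would need to match an actual training run.

# Sketch: reload a model saved by main() for inference (paths are placeholders).
import numpy as np
from keras.models import model_from_json

from data import prepare_image  # same preprocessing helper used during training


def load_trained_model(json_path, weights_path):
    # Rebuild the architecture from the serialized JSON, then restore the weights.
    with open(json_path) as f:
        model = model_from_json(f.read())
    model.load_weights(weights_path)
    return model


if __name__ == "__main__":
    model = load_trained_model("out/2018_1_1_123456model.json",
                               "out/2018_1_1_123456weights.h5")
    # `image` stands in for one image in the layout produced by load_training_data().
    image = np.load("example_image.npy")
    prediction = model.predict(prepare_image(image))
    print(prediction.shape)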