F60909749: sample.py
File Metadata
Created: Fri, May 3, 07:54
Size: 3 KB
Mime Type: text/x-python
Expires: Sun, May 5, 07:54 (1 d, 23 h)
Storage Engine: blob
Storage Format: Raw Data
Storage Handle: 17438491
Attached To: R11149 PDM-Nicola-Oulu
import tensorflow as tf

import model


def top_k_logits(logits, k):
    if k == 0:
        # no truncation
        return logits

    def _top_k():
        # Mask out every logit below the k-th largest by setting it to -1e10,
        # so that a subsequent softmax/sampling step effectively ignores it.
        values, _ = tf.nn.top_k(logits, k=k)
        min_values = values[:, -1, tf.newaxis]
        return tf.where(
            logits < min_values,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )
    return tf.cond(
        tf.equal(k, 0),
        lambda: logits,
        lambda: _top_k(),
    )
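
# --- Usage sketch (editor's note, not part of the original file) ---
# A minimal, assumed example of how top_k_logits could be exercised in TF1
# graph mode; the toy tensor shape [1, 4] and k=2 are illustrative only:
#
#     toy_logits = tf.constant([[1.0, 4.0, 2.0, 3.0]])
#     kept = top_k_logits(toy_logits, k=2)
#     with tf.Session() as sess:
#         print(sess.run(kept))   # entries below the 2nd-largest logit -> -1e10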


def top_p_logits(logits, p):
    """Nucleus sampling"""
    batch, _ = logits.shape.as_list()
    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
    indices = tf.stack([
        tf.range(0, batch),
        # number of indices to include
        tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
    ], axis=-1)
    # Smallest logit that is still inside the nucleus; everything below it is masked.
    min_values = tf.gather_nd(sorted_logits, indices)
    return tf.where(
        logits < min_values,
        tf.ones_like(logits) * -1e10,
        logits,
    )
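
# --- Usage sketch (editor's note, not part of the original file) ---
# An assumed toy example of nucleus (top-p) filtering; the values and p=0.9 are
# illustrative only. With these numbers the two largest logits survive and the
# rest are pushed to -1e10:
#
#     toy_logits = tf.constant([[2.0, 1.5, 0.2, 0.1]])
#     kept = top_p_logits(toy_logits, p=0.9)
#     with tf.Session() as sess:
#         print(sess.run(kept))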


def sample_sequence(*, hparams, length, start_token=None, batch_size=None,
                    context=None, temperature=1, top_k=0, top_p=1):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        # Run the transformer for one (incremental) forward pass and return the
        # vocabulary logits together with the attention 'presents' (key/value cache).
        lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        def body(past, prev, output):
            next_outputs = step(hparams, prev, past=past)
            # Take the logits for the last position, apply temperature, then
            # top-k and nucleus filtering before sampling the next token.
            logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
            logits = top_k_logits(logits, k=top_k)
            logits = top_p_logits(logits, p=top_p)
            samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
            return [
                next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
                samples,
                tf.concat([output, samples], axis=1)
            ]

        past, prev, output = body(None, context, context)

        def cond(*args):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond, body=body,
            maximum_iterations=length - 1,
            loop_vars=[
                past,
                prev,
                output
            ],
            shape_invariants=[
                tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        return tokens
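

# --- Usage sketch (editor's note, not part of the original file) ---
# How sample_sequence is wired up depends on the accompanying model.py and a
# trained checkpoint; the outline below is an assumption, including the use of
# model.default_hparams() (as in the reference GPT-2 code) and start_token=0:
#
#     hparams = model.default_hparams()
#     output = sample_sequence(
#         hparams=hparams, length=40,
#         start_token=0,          # or pass context=<int32 [batch, seq] tensor> instead
#         batch_size=1, temperature=1.0, top_k=40, top_p=1.0,
#     )
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())   # or restore a checkpoint
#         tokens = sess.run(output)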