# TODO: add forward through JAX/Flax when PR is merged
# This is currently here to make flake8 happy !
ifencoded_sequence_fastisNone:
raiseValueError("Cannot convert list to numpy tensor on encode_plus() (fast)")
ifbatch_encoded_sequence_fastisNone:
raiseValueError("Cannot convert list to numpy tensor on batch_encode_plus() (fast)")
@require_torch
deftest_prepare_seq2seq_batch(self):
ifnotself.test_seq2seq:
return
tokenizer=self.get_tokenizer()
# Longer text that will definitely require truncation.
src_text=[
" UN Chief Says There Is No Military Solution in Syria",
" Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.",
]
tgt_text=[
"Şeful ONU declară că nu există o soluţie militară în Siria",
"Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei "
'pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu '
"vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
]
try:
batch=tokenizer.prepare_seq2seq_batch(
src_texts=src_text,
tgt_texts=tgt_text,
max_length=3,
max_target_length=10,
return_tensors="pt",
src_lang="en_XX",# this should be ignored (for all but mbart) but not cause an error
)
exceptNotImplementedError:
return
self.assertEqual(batch.input_ids.shape[1],3)
self.assertEqual(batch.labels.shape[1],10)
# max_target_length will default to max_length if not specified