提交 0b93f490 编写于 作者: G guosheng

Add random input for seq2seq to test.

上级 21f50136
......@@ -223,7 +223,8 @@ class Seq2Seq(Model):
# encoder
encoder_output, encoder_final_state = self.encoder(src, src_length)
# decoder initial states
# decoder initial states: use input_feed and the structure is
# [[h,c] * num_layers, input_feed]
decoder_initial_states = [
encoder_final_state,
self.decoder.lstm_attention.cell.get_initial_states(
......
......@@ -80,6 +80,37 @@ def do_train(args):
Input([None, None, 1], "int64", name="label"),
]
model = Seq2Seq(args.src_vocab_size, args.trg_vocab_size, args.embed_dim,
args.hidden_size, args.num_layers, args.dropout)
model.prepare(fluid.optimizer.Adam(learning_rate=args.learning_rate,
parameter_list=model.parameters()),
CrossEntropyCriterion(),
inputs=inputs,
labels=labels)
batch_size = 32
src_seq_len = 10
trg_seq_len = 12
iter_num = 10
def random_generator():
for i in range(iter_num):
src = np.random.randint(2, args.src_vocab_size,
(batch_size, src_seq_len)).astype("int64")
src_length = np.random.randint(
1, src_seq_len, (batch_size, )).astype("int64")
trg = np.random.randint(2, args.trg_vocab_size,
(batch_size, trg_seq_len)).astype("int64")
trg_length = np.random.randint(1, trg_seq_len,
(batch_size, )).astype("int64")
label = np.random.randint(1, trg_seq_len,
(batch_size, trg_seq_len, 1)).astype("int64")
yield src, src_length, trg, trg_length, label
model.fit(train_data=random_generator, log_freq=1)
exit(0)
dataset = Seq2SeqDataset(fpattern=args.training_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
......@@ -107,15 +138,6 @@ def do_train(args):
num_workers=0,
return_list=True)
model = Seq2Seq(args.src_vocab_size, args.trg_vocab_size, args.embed_dim,
args.hidden_size, args.num_layers, args.dropout)
model.prepare(fluid.optimizer.Adam(learning_rate=args.learning_rate,
parameter_list=model.parameters()),
CrossEntropyCriterion(),
inputs=inputs,
labels=labels)
model.fit(train_data=train_loader,
eval_data=None,
epochs=1,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册