9.1. Gated Recurrent Units (GRU) — Dive Into Deep Learning 0.17.5 Documentation

10.2.5. Concise Implementation

High-level APIs let us instantiate a GRU model directly. This encapsulates all the configuration details that we made explicit above.

# This cell appears in four framework tabs (PyTorch, MXNet, JAX, TensorFlow).
# Each variant below is an alternative definition of the same `GRU` class and
# assumes its own framework's imports (`nn`, `rnn`, `jax`, `tf`) plus `d2l`;
# pick one variant, they are not meant to run together.


# --- PyTorch ---
class GRU(d2l.RNN):
    """Concise GRU wrapping ``nn.GRU``."""

    def __init__(self, num_inputs, num_hiddens):
        # Initialize d2l.Module directly rather than via d2l.RNN.__init__
        # (presumably so the parent does not build its own recurrent layer).
        d2l.Module.__init__(self)
        self.save_hyperparameters()
        # The layer is given both the input size and the hidden size.
        self.rnn = nn.GRU(num_inputs, num_hiddens)


# --- MXNet ---
class GRU(d2l.RNN):
    """Concise GRU wrapping Gluon's ``rnn.GRU``."""

    def __init__(self, num_inputs, num_hiddens):
        d2l.Module.__init__(self)
        self.save_hyperparameters()
        # Only the hidden size is passed to the layer; num_inputs is recorded
        # by save_hyperparameters but not handed to rnn.GRU.
        self.rnn = rnn.GRU(num_hiddens)


# --- JAX (Flax) ---
class GRU(d2l.RNN):
    """Concise GRU built by scanning ``nn.GRUCell`` over the time axis."""

    num_hiddens: int

    @nn.compact
    def __call__(self, inputs, H=None, training=False):
        """Run the GRU over ``inputs`` and return ``(outputs, H)``.

        ``inputs`` is scanned along axis 0 and the batch size is read from
        axis 1, i.e. the layout is (num_steps, batch_size, ...).
        """
        if H is None:
            batch_size = inputs.shape[1]
            # NOTE(review): this static initialize_carry(rng, batch_dims, size)
            # call is the older Flax signature; newer Flax releases make it an
            # instance method taking the input shape -- confirm against the
            # installed Flax version.
            H = nn.GRUCell.initialize_carry(jax.random.PRNGKey(0),
                                            (batch_size,), self.num_hiddens)

        # Lift GRUCell over the time axis; variable_broadcast="params" shares
        # one set of cell parameters across all time steps.
        GRU = nn.scan(nn.GRUCell, variable_broadcast="params",
                      in_axes=0, out_axes=0, split_rngs={"params": False})

        H, outputs = GRU()(H, inputs)
        return outputs, H


# --- TensorFlow ---
class GRU(d2l.RNN):
    """Concise GRU wrapping ``tf.keras.layers.GRU``."""

    def __init__(self, num_inputs, num_hiddens):
        d2l.Module.__init__(self)
        self.save_hyperparameters()
        # Return the output at every time step as well as the final state.
        self.rnn = tf.keras.layers.GRU(num_hiddens, return_sequences=True,
                                       return_state=True)

This code trains significantly faster because it relies on compiled operators rather than step-by-step Python code.

# Train the concise GRU language model; one variant per framework tab.
# `data`, `trainer` and `d2l` are defined outside this excerpt.

# --- PyTorch ---
gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
# Output figure: ../_images/output_gru_b77a34_78_0.svg

# --- MXNet ---
gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
# Output figure: ../_images/output_gru_b77a34_81_0.svg

# --- JAX ---
# The Flax GRU only declares `num_hiddens`, so no input size is passed.
gru = GRU(num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
# Output figure: ../_images/output_gru_b77a34_84_0.svg

# --- TensorFlow ---
gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
with d2l.try_gpu():
    model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
# Output figure: ../_images/output_gru_b77a34_87_0.svg

After training, we print the sequence the model predicts to follow the provided prefix (the training perplexity is reported in the plots produced by `trainer.fit` above).

# Predict 20 more steps after the prefix 'it has' (characters, judging from
# the outputs); one variant per framework tab.

# --- PyTorch ---
model.predict('it has', 20, data.vocab, d2l.try_gpu())
# Output: 'it has so it and the time '

# --- MXNet ---
model.predict('it has', 20, data.vocab, d2l.try_gpu())
# Output: 'it has i have the time tra'

# --- JAX ---
# Passes the trained parameters instead of a device.
model.predict('it has', 20, data.vocab, trainer.state.params)
# Output: 'it has is a move and a mov'

# --- TensorFlow ---
model.predict('it has', 20, data.vocab)
# Output: 'it has t t t t t t t t t t'

Tag » What Is Gated Recurrent Unit