===================
== Nathan Matare ==
===================

Stabilizing transformers for deep reinforcement learning

Being the forever resourceful hacker, I based this code heavily on the Ray, Google Brain, and DeepMind (Sonnet) implementations.

So if you haven’t heard, the world of AI is converging onto one architecture to rule them all: the transformer.

What’s interesting is that there’s a bit of discussion among neuroscientists proposing that evolution has converged onto an optimal “architecture”: the neocortex. And so, it would seem plausible that synthetic architectures might also converge onto an optimal structure. Given that our intuitions about the learning mechanism, the mind, and artificial intelligence get eviscerated every year, I wouldn’t bet my horse on this proposition. But it is interesting. Whatever the case, we can all enjoy the zoo of transformers taking over every field for now.

Fun aside: Tesla’s latest self-driving system appears to feed post-processed, encoded feature maps into a variant of the transformer (right around the 13:04 mark).

Anyway, for my purposes, I needed an implementation of the transformer stabilized for reinforcement learning (Parisotto et al., 2019).


So a standard transformer block (here, Transformer-XL) consists of the following components:

[Figure: Transformer-XL block (Dai et al., 2019)]

A block stabilized for reinforcement learning, however, replaces the residual connections with gating layers, moves layer normalization to the input of each sub-module (the top of the block), and adds some additional magic (e.g., defaults and activation functions).

[Figure: Gated Transformer-XL block (Parisotto et al., 2019)]
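To make the reordering concrete, here is a minimal NumPy sketch of one block's forward pass in both orderings. It is an illustration, not the implementation below: `attention` and `feed_forward` are hypothetical stand-ins (single-head self-attention with no mask, memory, or relative positions, and a two-layer MLP), and layer norm has no learned scale or shift.

```python
import numpy as np


def layer_norm(x, eps=1e-12):
    # Normalize over the feature axis (no learned gain or bias in this sketch).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def relu(x):
    return np.maximum(x, 0.0)


def trxl_block(x, attention, feed_forward):
    # Transformer-XL ordering: sub-module, residual add, then layer norm.
    x = layer_norm(x + attention(x))
    return layer_norm(x + feed_forward(x))


def gtrxl_block(x, attention, feed_forward, gate):
    # GTrXL ordering: layer norm on each sub-module's input, ReLU on its
    # output, and a gating layer gate(x, y) in place of the residual add.
    y = relu(attention(layer_norm(x)))
    x = gate(x, y)
    y = relu(feed_forward(layer_norm(x)))
    return gate(x, y)


rng = np.random.default_rng(0)
d = 8
w_q, w_k, w_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
w_1 = rng.normal(size=(d, 2 * d)) / np.sqrt(d)
w_2 = rng.normal(size=(2 * d, d)) / np.sqrt(2 * d)


def attention(x):
    # Hypothetical single-head self-attention: no mask, memory, or
    # relative positions.
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def feed_forward(x):
    # Hypothetical two-layer position-wise MLP.
    return relu(x @ w_1) @ w_2


def residual_gate(x, y):
    # Plain residual add: the reordered block the paper calls TrXL-I.
    return x + y


x = rng.normal(size=(2, 5, d))  # [batch, sequence, width]
print(trxl_block(x, attention, feed_forward).shape)  # (2, 5, 8)
print(gtrxl_block(x, attention, feed_forward, residual_gate).shape)  # (2, 5, 8)
```

Passing a GRU-style gate instead of `residual_gate` gives the full gated variant. Note that in the reordered block the stream `x` itself is never normalized, so there is an untouched path from the block's input to its output.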

The original paper has all the details. Also, if you haven’t read Jay’s explanation, I highly recommend it.

OK, I’m going to reuse a ton of pieces from the original implementation, and simply swap the order of some of the components to make it all shine and twinkle. Here, I’ll build the transformer block, and I’ll test it out on a few DL benchmarks in a follow-up post…

The stabilized transformer block:


import tensorflow as tf
from official.nlp.modeling.layers import relative_attention, position_embedding

class GatedTransformerXLBlock(tf.keras.layers.Layer):
    r""" Gated Transformer XL block

    We use the Google "production" implementation as a baseline with changes
    as discussed in Stabilizing Transformers for RL and a few other minor
    changes discussed in the model code.

    ref: https://github.com/kimiyoung/transformer-xl
    ref: https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_xl.py
    ref: https://arxiv.org/pdf/1910.06764.pdf

    ...
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        head_size: int,
        inner_size: int,
        attention_dropout_rate: float = 0.20,
        dropout_rate: float = 0.20,
        norm_epsilon: float = 1e-12,
        inner_activation: str = 'relu',
        post_attention_activation: str = 'relu',
        kernel_initializer: str = 'variance_scaling',
        gru_bias_initializer: float = 2.,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Pretty much the same as the official implementation but adds default
        # options for the GRU gating mechanism.

        self._num_heads = num_attention_heads
        self._head_size = head_size
        self._hidden_size = hidden_size
        self._inner_size = inner_size
        self._dropout_rate = dropout_rate
        self._attention_dropout_rate = attention_dropout_rate
        self._post_attention_activation = post_attention_activation
        self._inner_activation = inner_activation
        self._norm_epsilon = norm_epsilon
        self._kernel_initializer = kernel_initializer
        self._gru_bias_initializer = gru_bias_initializer
        self._attention_layer_type = (
            relative_attention.MultiHeadRelativeAttention
        )

    def build(self, input_shape: tf.TensorShape):
        input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape

        input_tensor_shape = tf.TensorShape(input_tensor)

        if len(input_tensor_shape.as_list()) != 3:
            raise ValueError(
                'GatedTransformerXLBlock expects a three-dimensional input of '
                'shape [batch, sequence, width].'
            )

        batch_size, sequence_length, hidden_size = input_tensor_shape

        if len(input_shape) == 2:
            mask_tensor_shape = tf.TensorShape(input_shape[1])
            expected_mask_tensor_shape = tf.TensorShape(
                [batch_size, sequence_length, sequence_length]
            )

            # Only validate the mask when one is actually passed in.
            if not expected_mask_tensor_shape.is_compatible_with(
                mask_tensor_shape,
            ):
                raise ValueError(
                    'When passing a mask tensor to GatedTransformerXLBlock, '
                    'the mask tensor must be of shape [batch, '
                    'sequence_length, sequence_length] (here %s). Got a '
                    'mask tensor of shape %s. '
                    % (expected_mask_tensor_shape, mask_tensor_shape)
                )

        if hidden_size % self._num_heads != 0:
            raise ValueError(
                'The input size (%d) is not a multiple of the number '
                'of attention heads (%d)' % (hidden_size, self._num_heads)
            )

        # Now following Parisotto et al, we stabilize the transformer by moving
        # the layer normalization and adding gated layers.

        self._attention_layer_norm = tf.keras.layers.LayerNormalization(
            name='layer_norm_content_stream',
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype=tf.float32,
        )

        self._attention_layer = self._attention_layer_type(
            num_heads=self._num_heads,
            key_dim=self._head_size,
            value_dim=self._head_size,
            use_bias=False,
            kernel_initializer=self._kernel_initializer,
            dropout=self._attention_dropout_rate,
            name='RMHA',
        )

        # [*D]
        # The official Google implementation adds a second dropout here
        # while the two research models (XL and XLNet) do not. We do not apply
        # post RMHA dropout.

        # [*D]
        # Unlike the TrXL paper, the RL version applies an activation here:
        # "Because the layer norm reordering causes a path where two linear
        # layers are applied in sequence, we apply a ReLU activation to each
        # sub-module output before the residual connection."
        self._attention_activation = tf.keras.layers.Activation(
            self._post_attention_activation
        )

        self._attention_gating = GRUGating(
            bias_initializer=self._gru_bias_initializer
        )

        # Position-wise feed-forward network (MLP)
        self._output_layer_norm = tf.keras.layers.LayerNormalization(
            name='layer_norm_position_wise',
            axis=-1,
            epsilon=self._norm_epsilon,
        )

        self._position_wise_ff_1 = tf.keras.layers.Dense(
            self._inner_size,
            activation=self._inner_activation,
            kernel_initializer=self._kernel_initializer,
        )

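        self._position_wise_ff_1_dropout = tf.keras.layers.Dropout(
            rate=self._dropout_rate
        )

        # Apply inner_activation (ReLU by default) to the sub-module output
        # before gating, per the note above.
        self._position_wise_ff_2 = tf.keras.layers.Dense(
            self._hidden_size,
            activation=self._inner_activation,
            kernel_initializer=self._kernel_initializer,
        )
        self._position_wise_ff_2_dropout = tf.keras.layers.Dropout(
            rate=self._dropout_rate
        )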

        self._position_wise_ff_gru_gating = GRUGating(
            bias_initializer=self._gru_bias_initializer
        )

        super().build(input_shape)

We’ll need a GRU gating mechanism to gate the residual stream after the attention sub-module and after the position-wise feed-forward sub-module.
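For reference, these are the gating equations from Parisotto et al. (2019), written with $x$ as the residual stream entering the gate and $y$ as the sub-module output. In `GRUGating.call` below, `h` plays the role of $x$, `X` plays the role of $y$, and `_bias_z` is $b_g$; the code uses the row-vector convention (`X @ W`) rather than $W y$.

```latex
\begin{aligned}
r &= \sigma\left(W_r y + U_r x\right) \\
z &= \sigma\left(W_z y + U_z x - b_g\right) \\
\hat{h} &= \tanh\left(W_g y + U_g (r \odot x)\right) \\
g(x, y) &= (1 - z) \odot x + z \odot \hat{h}
\end{aligned}
```

Because $b_g$ is subtracted inside the sigmoid, initializing it to 2 gives $z \approx \sigma(-2) \approx 0.12$ while the weight terms are small, so each gate initially passes most of $x$ through unchanged.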


class GRUGating(tf.keras.layers.Layer):
    r""" Gating mechanism from Stabilizing Transformers
    """

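    def __init__(self, bias_initializer: float = 2.0, **kwargs):
        super().__init__(**kwargs)
        # Initial value of the update-gate bias b_g; a positive value starts
        # the gate close to the identity.
        self._bias_initializer = bias_initializer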

    def build(self, input_shape):
        h_shape, x_shape = input_shape

        if x_shape[-1] != h_shape[-1]:
            raise ValueError(
                'Both inputs to GRUGate must have equal size in last axis!'
            )

        dim = int(h_shape[-1])

        self._w_r = self.add_weight(shape=(dim, dim))
        self._u_r = self.add_weight(shape=(dim, dim))

        self._w_z = self.add_weight(shape=(dim, dim))
        self._u_z = self.add_weight(shape=(dim, dim))

        self._w_h = self.add_weight(shape=(dim, dim))
        self._u_h = self.add_weight(shape=(dim, dim))

        def bias_initializer(shape, dtype=tf.float32):
            # Constant fill with the configured update-gate bias.
            return tf.fill(shape, tf.cast(self._bias_initializer, dtype=dtype))

        self._bias_z = self.add_weight(
            shape=(dim,), initializer=bias_initializer,
        )

    def call(self, inputs, **kwargs):
        # h: residual stream (x in the paper); X: sub-module output (y).
        h, X = inputs

        # Reset gate.
        r = tf.tensordot(X, self._w_r, axes=1) + tf.tensordot(
            h, self._u_r, axes=1
        )
        r = tf.nn.sigmoid(r)

        # Update gate; subtracting the positive bias keeps z small at init.
        z = (
            tf.tensordot(X, self._w_z, axes=1)
            + tf.tensordot(h, self._u_z, axes=1)
            - self._bias_z
        )
        z = tf.nn.sigmoid(z)

        # Candidate activation.
        h_next = tf.tensordot(X, self._w_h, axes=1) + tf.tensordot(
            (h * r), self._u_h, axes=1
        )
        h_next = tf.nn.tanh(h_next)

        return (1 - z) * h + z * h_next
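To check the gate's initial behavior without TensorFlow, here is a NumPy mirror of `GRUGating.call`. It is a sketch under assumptions: the weights are small random normals (a stand-in, not Keras' default initializer), and it only looks at how the update-gate bias shapes the output at initialization.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def gru_gate(h, x, weights, bias_z):
    # h: residual stream, x: sub-module output; both shaped [..., dim].
    # Row-vector convention (x @ W), matching tf.tensordot(X, W, axes=1).
    w_r, u_r, w_z, u_z, w_h, u_h = weights
    r = sigmoid(x @ w_r + h @ u_r)
    z = sigmoid(x @ w_z + h @ u_z - bias_z)
    h_next = np.tanh(x @ w_h + (h * r) @ u_h)
    return (1.0 - z) * h + z * h_next, z


rng = np.random.default_rng(0)
dim = 16
# Assumption: small random weights, so the bias dominates the update gate.
weights = [rng.normal(scale=0.01, size=(dim, dim)) for _ in range(6)]
h = rng.normal(size=(4, 10, dim))  # [batch, sequence, width]
x = rng.normal(size=(4, 10, dim))

for b in (0.0, 2.0):
    out, z = gru_gate(h, x, weights, np.full(dim, b))
    print(f"bias={b}: mean z = {z.mean():.3f}")
```

With `bias=0.0` the mean of `z` should sit near 0.5; with `bias=2.0` it should sit near σ(-2) ≈ 0.12, so the gated output stays much closer to `h`.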

Let me start a unit test here; I’ll follow up in another post by testing this against the DeepMind benchmarks.

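import numpy as np
import pytest
import tensorflow as tf

import transformers  # local module expected to define GatedTransformerXL


@pytest.mark.skip()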
def test_transformer_xl():

    inputs = np.random.random((20, 10, 50)).astype(np.float32)
    state_0 = np.zeros((20, 8, 2)).astype(np.float32)

    # See here for hyperparameters: We implement the "thin" version
    # ref: https://arxiv.org/pdf/1910.06764.pdf
    default_params = dict(
        num_layers=12,
        num_attention_heads=4,
        head_size=64,  # head dim
        inner_size=128,  # hidden dim
        hidden_size=128,  # aka "thin version"
        dropout_rate=0.25,
        attention_dropout_rate=0.20,
        memory_length=512,  # memory size
        reuse_length=10,
        initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
    )

    test_params = default_params.copy()
    test_params['num_layers'] = 2

    model = transformers.GatedTransformerXL(**test_params)
    assert model
    # TODO: test input shapes, ensure masking works, and gated is working
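For the "ensure masking works" item in the TODO, one option is to build a causal mask with the `[batch, sequence_length, sequence_length]` shape that `build` validates and check that blocked positions receive zero attention weight. This NumPy sketch assumes the Keras convention where 1 means "may attend" and 0 means "blocked"; confirm the convention of the attention layer you actually use. It also ignores Transformer-XL memory, which would widen the key axis beyond `sequence_length`.

```python
import numpy as np


def causal_mask(batch_size, seq_len):
    # 1 where query i may attend to key j (j <= i), 0 otherwise.
    mask = np.tril(np.ones((seq_len, seq_len), dtype=np.float32))
    return np.broadcast_to(mask, (batch_size, seq_len, seq_len))


def masked_softmax(scores, mask):
    # Push blocked logits to a large negative value before normalizing.
    scores = np.where(mask > 0, scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


batch_size, seq_len = 20, 10  # same batch and sequence sizes as the test inputs
mask = causal_mask(batch_size, seq_len)
scores = np.random.default_rng(0).normal(size=(batch_size, seq_len, seq_len))
weights = masked_softmax(scores, mask)

assert mask.shape == (batch_size, seq_len, seq_len)
assert np.allclose(weights.sum(axis=-1), 1.0)
assert np.all(weights[mask == 0] == 0.0)
```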

To be continued…