Transformer Network in Pytorch from scratch

8 minute read

Published:

Step by step implementation of “Attention is all you need” with animated explanations.

This is a supplementary post to the medium article Transformers in Cheminformatics.

Code

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
# import seaborn
from IPython.display import Image
import plotly.express as px
# seaborn.set_context(context="talk")
%matplotlib inline

Single headed dot-scaled attention

\[Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt {d_k}})V\]
Image(filename='./Images/self_attn.png',width=1000)

png

def attention(query, key,value, mask = None,dropout = None):  #query:Q, key: K, value: V
    dk = key.shape[-1]
    score = torch.matmul(query,key.transpose(-1,-2)) #BxLxD
    scaled_score = score/math.sqrt(dk)
    #Masking (optional) 
    #Increase score to very large negative number for tokens that are masked.
    #Such large negative number will have 0 exponentiation and hence their softmax will be 0 as well. 
    if mask is not None:
        scaled_score.masked_fill(mask==0,-1e9)
    attention = F.softmax(scaled_score,dim=-1)
    #Optional: Dropout
    if dropout is not None:
        attention = nn.Dropout(attention,dropout)
    #Z = enriched embedding 
    Z = torch.matmul(attention,value)
    return Z, attention

    
Image(filename='./Images/multiheaded_attn.png',width=850)

png

SegmentLocal

class MultiheadAttention(nn.Module):
    def __init__(self,nheads,dmodel,dropout=0.1):
        super(MultiheadAttention,self).__init__()
        assert dmodel % nheads ==0 
        self.dk = dmodel//nheads
        self.nheads =  nheads
        
        #From the theory Wq linear layer should be (dmodel x dk)
        #But in implementation (we're using dmodel x dmodel) we will breakdown Wq into h heads later.
        #It can we shown that calculating 'nheads' small q_i's of BxLxdk dimension individually by feeding
        #key, query, value of dimension BxLxdk each is equivalent to 
        #calculating 1 big Wq of BxLxdmodel dimension and feeding in large X (BxLxdmodel) to get a large Q (BxLxdmodel)
        #then breaking Q into 'nheads' smaller q_i's of dimension BxLxdk each.
        self.Wq = nn.Linear(dmodel,dmodel)
        self.Wk = nn.Linear(dmodel,dmodel)
        self.Wv = nn.Linear(dmodel,dmodel)
        self.Wo = nn.Linear(dmodel,dmodel)
        
        self.dropout_value = dropout
        self.dropout = nn.Dropout(p= dropout)
        
    def forward(self,query,key,value,mask=None):
        if mask is not None:
            # Same mask applied to all of the nheads
            mask.unsqueeze(1)
       
        #Dim: q=k=v=x : (BxLxdmodel)
        key,query,value = self.Wk(key), self.Wq(query), self.Wv(value)  #k,q,v = (BxLxdmodel)
        
        #Break k,q,v into nheads k_i's, q_i's and v_i's of dim (BxLxdk)
        key = key.view(nbatches,-1,self.nheads,self.dk ) #(B,L,nheads,dk) (view -1: actual value for this dimension will be inferred so that the number of elements in the view matches the original number of elements.)
        query = query.view(nbatches,-1,self.nheads,self.dk)  
        value = value.view(nbatches,-1,self.nheads,self.dk)
        
        key = key.transpose(1,2) # (B,L,nheads,dk) --> (B,nheads,L,dk)
        query = query.transpose(1,2)
        value= value.transpose(1,2)
        
        #Calculate self attention and enriched embedding z_i's. 
        #All z_i's are channeled together in 1 large z matrix below
        z, self.attn = self_attention(query, key,value,mask,self.dropout_value)  #z : (B,nheads,L,dk), attn: (B,nheads,L,L)
        
        #Reshape z:(B,nheads,L,dk) -->z_concat (B,L,nheads*dk) to refelect the affect of concatenation as shown in figure
        z_concat = z.transpose(1,2) #z:(B,nheads,L,dk) --> z_concat: (B,L,nheads,dk)
        z_concat = z_concat.contiguous() #z_concat: (B,L,nheads,dk) --> z_concat: (1,B*L*nheads*dk)
        z_concat = z_concat.view(nbatches, -1, self.nheads * self.dk) #z_concat: (1,B*L*nheads*dk) --> z_concat (B,L,nheads*dk)
        
        #Project z_concat with linear layer (Wo) to get final enriched embedding z_enriched as shown in figure
        #z_concat (B,L,nheads*dk) --> z_enriched(B,L,dmodel)
        z_enriched = self.Wo(z_concat)
        return z_enriched

Pointwise Feedforward Neural Network

$FFN(x) = max(0,xW_1+b_1)W_2 + b_2$

SegmentLocal

class PointwiseFeedForward(nn.Module):
    def __init__(self,dmodel, d_linear,dropout = 0.2):
        super(PointwiseFeedForward,self).__init__()
        self.W1 = nn.Linear(dmodel,d_linear)
        self.W2 = nn.Linear(d_linear, dmodel)
        self.dropout = nn.Dropout(dropout)
        
    def forward(x):
        relu = F.relu(self.W1(x))
        output = self.dropout(self.W2(x))
        return output

LayerNorm

\(X=\{x_1,x_2......x_m\} \\ % Y =\{y_1,y_2......y_m\} \\ x_i = \{x_{i,1},x_{i,2},.........,x_{i,K}\}\)

LayerNorm normalizes each $x_i$ across all its (K) features such that each sample $x_i$ has 0 mean and unit variance.

\[\mu_i = \frac{1}{K}\sum_{k=1}^K x_{i,k} \\ \sigma_i^2 = \frac{1}{K}\sum_{k=1}^K (x_{i,k}-\mu_i)^2\] \[\hat{x}_{i,k}= \frac{x_{i,k}-\mu_i}{\sqrt{\sigma_i^2 + \epsilon}}\]

And finally the output of a Layernorm layer $y_i = LN_{\beta,\gamma}(x_i)$

\[y_i = \gamma \hat{x_i} + \beta\]
class LayerNorm(nn.Module):
    def __init__(self,features,epsilon = 1e-9):
        'features = number of features along which to normalize \
        in the given input vector/matrix = dmodel'
        super(LayerNorm,self).__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.epsilon = epsilon
        
    def forward(x):
        #calculate mean and std across the last dimension.
        #this will enforce that mean and std are calculated across
        #all features of a fed in example.
        mean = x.mean(-1)
        std = x.std(-1)
        x_hat = x-mean/(std+self.epsilon) #for numerical stability, we skip sqrt in denominator
        output = self.gamma*x_hat + self.beta
        return output

Residual Connection (Add & Norm)

Residual Connection followed by layerNorm

\[Add\_and\_Norm(Sublayer(x)) = LayerNorm(x+Dropout(Sublayer(x)))\]

With the Residual connection and LayerNorm, the complete Transformer block looks as following SegmentLocal

class AddandNorm(nn.Module):
    def __init__(self,features,dropout=0.2,epsilon = 1e-9):
        super(AddandNorm,self).__init__()
        self.layernorm = LayerNorm(features,epsilon)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self,x,sublayer_output):
        return self.layernorm(x+self.dropout(sublayer_output))

Positional Embedding

Input representation: $\hspace{23mm}\textbf{X_input} \in R^{L \times dmodel}$
Positional Embedding: $\hspace{23mm}\textbf{P} \in R^{L \times dmodel}$
Updated Embedding: $\hspace{23mm}\textbf{X} = \textbf{X_input} + \textbf{P}$
Elements in $\textbf{P}$ are calculated as following \(p_{i,2j}= sin\bigg(\frac{i}{10000^{\frac{2j}{dmodel}}}\bigg) = sin(\theta)\)

\[p_{i,2j+1}= cos\bigg(\frac{i}{10000^{\frac{2j}{dmodel}}}\bigg) = cos(\theta)\]

class PositionalEmbedding(nn.Module):
    def __init__(self,dmodel,device,maxlen=10000,dropout=0.2):
        super(PositionalEmbeddingitionalEmbedding,self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.device = device
        
        #i is a max_len dimensional vector, so that we can store a positional embedding
        #value corresponding to each token in sequence (Character in SMILES)
        theta_numerator = torch.arange(max_len,dtype = torch.float32)
        theta_denominator = torch.pow(10000,torch.arange(0,dmodel,2,dtype=torch.float32))/dmodel
        theta = theta_numerator/theta_denominator
        
        #Create a large P tensor to hold position embedding value for each token in the sequence
        self.P = torch.zeros((maxlen,dmodel))
        #Update even column ids in P with sin(theta) and odd column ids with cos(theta)
        self.P[:,0::2] = sin(theta)
        self.P[:,1::2] = cos(theta)
        
    def forward(self,x):
        # x.shape[1] gives the length of input sequence
        x = x+self.P[:,x.shape[1],:]  
        return self.dropout(x)
        
        

Encoder Layer

We are ready to build a single encoder layer by combining Residual connection,
Layernorm, multiheaded attention block and position-wise feedforward network

class EncoderLayer(nn.Module):
    def __init__(self,mask=None,nhead=8,dmodel=512,dlinear=1024,dropout=0.2):
        super(EncoderLayer,self).__init__()
        self.multihead_attn = MultiheadAttention(nheads,dmodel,dropout)
        self.add_norm1 = AddandNorm(dmodel,dropout)
        self.pw_ffn = PointwiseFeedForward(dmodel,dlinear,dropout)
        self.add_norm2 = AddandNorm(dmodel,dropout)
        
    def forward(self,x,mask=None):
        'The input to the encoderlayer is either the embedding for first encoder layer \
         or representions from previous layer. We use key=query=value = input (x) to feed \
         into the multiheaded attention block within encoder layer'
        multihead_attn_output = self.multihead_attn(x,x,x,mask)
        addnorm1_out = self.add_norm1(x,multihead_attn_output)
        pwffn_outputs = self.pw_ffn(addnorm1_out)
        addnorm2_out = self.add_norm2(addnorm1_out,pwffn_outputs)
        return addnorm2_out
        

Encoder (Stack of encoder layers)

Encoder stack is a sequential model with N encoder layers. Vaswani et al. used N=6 in their original paper.

class Encoder(nn.Module):
    def __init__(self,device,src_vocab,nlayers=6,nhead=8,dmodel=512,
                 dlinear=1024,dropout=0.2):
        super(Encoder,self).__init__()
        self.dmodel = dmodel
        self.encoder_stack = nn.Sequential()
        self.embed = nn.Embedding(len(src_vocab),dmodel)
        self.pos_embed = PositionalEmbedding(dmodel,device)
        for i in range(nlayers):
            self.encoder_stack.add_module("Encoder_Layer_"+str(i),
                                          EncoderLayer(nhead,dmodel,dlinear,dropout))
            
    def forward(self,x):
        embedding = self.embed(x)
        pos_embedding = self.pos_embed(embedding*math.sqrt(self.dmodel))
        self.attention_weights = []
        for layer in self.encoder_stack:
            x = layer(x,mask)
            self.attention_weights.append(layer.multihead_attn.attn)
        return x
                            
            
        

Decoder Layer

View 1

View 2

Autoregression

Mask out subsequent positions for decoder use

def autoregression_mask(nbatch,size):
    "Mask out subsequent positions."
    attn_shape = (nbatch, size, size)
    autoregression_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(autoregression_mask) == 0

print(autoregression_mask(1,20).shape)
plt.imshow(autoregression_mask(1,20).squeeze(0))
torch.Size([1, 20, 20])
<matplotlib.image.AxesImage at 0x7f2eba48b208>

png

Decoder layer

class DecoderLayer(nn.Module):
    def __init__(self,i,nheads=8,dmodel=512,dropout=0.2,dlinear=1024):
        super(DecoderLayer,self).__init__()
        self.i = i #identifier to distinguish decoder layers in stacked decoder
        self.multihead_attn1 = MultiheadAttention(nheads,dmodel,dropout)
        self.add_norm1 = AddandNorm(dmodel,dropout)
        
        self.multihead_en_de = MultiheadAttention(nheads,dmodel,dropout)
        self.add_norm2 = AddandNorm(dmodel,dropout)
        
        self.pw_ffn = PointwiseFeedForward(dmodel,dlinear,dropout)
        self.add_norm3 = AddandNorm(dmodel,dropout)
        
    def forward(self,x,encoder_output,encoder_mask,decoder_state,time_step=0):
        nbatch = x.shape[0]
        
        if self.training:
            decoder_mask = autoregression_mask(nbatch,self.dmodel) #(B,L,L)
        else:
            decoder_mask = None
        
        if self.training:
            #when training, decoder has access to the entire output sequence
            #hence can take entire x (target sequence) as input. Obviously proper masking
            #of target sequence is needed to stop the decoder from seeing future tokens
            mulithead_attn1_output = self.multihead_attn1(x,x,x,decoder_mask)
        else:
            #when validating, decoder hasn't produced the word beyond the time step it is 
            #processing. Decoding happens one word at a time during validation.
            #at t=0, input to multihead_attn1 
            #q,k,v = <bos> token, 
            #at t=1 and beyond, input to ith decoder block (from prev. decoder side) is whatever
            #was predicted at prev. time step and the ith decoder block's state at t-1 timestep. 
            #See figure above
            if time_step ==0:
                mulithead_attn1_output = self.multihead_attn1(x,x,x,decoder_mask)
                #update decoder state with current time step's state
                decoder_state[self.i]= x
            else:
                decoder_query = x
                decoder_key_value = torch.cat((decoder_state[self.i],x),dim=1)
                mulithead_attn1_output = self.multihead_attn1(x,decoder_key_value,
                                                              decoder_key_value,decoder_mask)
                #update decoder state with current time step's state
                decoder_state[self.i]= decoder_key_value
            
            
        addnorm1_out = self.add_norm1(x,mulithead_attn1_output)
    
        key,value,query = encoder_output,encoder_output,addnorm1_out
        multihead_en_de_output = self.multihead_en_de(query,key,value,encoder_mask)
        addnorm2_out = self.add_norm2(addnorm1_out,multihead_en_de_output)
        
        pwffn_outputs = self.pw_ffn(addnorm2_out)
        addnorm3_out = self.add_norm2(addnorm2_out,pwffn_outputs)
        return addnorm3_out,decoder_state
        

Decoder

Decoder stack is also a sequential model like encoder with N layers. Vaswani et al. used N=6 in their original paper.

class Decoder(nn.Module):
    def __init__(self,device,tgt_vocab,nlayers=6,nhead=8,dmodel=512,
                 dlinear=1024,dropout=0.2):
        super(Decoder,self).__init__()
        self.dmodel = dmodel
        self.nlayers = nlayers
        self.embed = nn.Embedding(len(tgt_vocab),dmodel)
        self.pos_embed = PositionalEmbedding(dmodel,device)
        
        self.decoder_stack = nn.Sequential()
        
        for i in range(nlayers):
            self.decoder_stack.add_module("Decoder_Layer_"+str(i),
                                          DecoderLayer(i,nhead,dmodel,dropout,dlinear))
        self.dense = nn.Linear(dmodel,len(tgt_vocab))
            
    def forward(self,x,encoder_output,encoder_mask,decoder_state,time_step=0):
        embedding = self.embed(x)
        pos_embedding = self.pos_embed(embedding*math.sqrt(self.dmodel))
        #To save attention weights from both multiheaded attention layers
        #for later visualization
        self.att_wts_de,self.att_wts_en_de = [],[]
        for layer in self.encoder_stack:
            x,decoder_state = layer(x,encoder_output,encoder_mask,decoder_state,time_step)
            self.att_wts_de.append(layer.multihead_attn1.attn)
            self.att_wts_en_de.append(layer.multihead_en_de.attn)
        return dense(x),decoder_state
                            
            
        

Transformer Network

class Transformer(nn.Module):
    def __init__(self,encoder,decoder):
        super(Transformer,self).__init__()
        self.encoder = encoder
        self.decoder = decoder
    def forward(self,src_input,src_mask,tgt_input,decoder_state,time_step):
        encoder_output = encoder(src_input)
        decoder_output,decoder_state = decoder(tgt_input,encoder_output,
                                               encoder_mask,decoder_state,time_step)
        
encoder = Encoder(device,src_vocab,nlayers=6,nhead=8,dmodel=512,
                 dlinear=1024,dropout=0.2)
decoder = Decoder(device,tgt_vocab,nlayers=6,nhead=8,dmodel=512,
                 dlinear=1024,dropout=0.2)

trfm_network = Transformer(encoder,decoder)

Leave a Comment