Source code for UM2N.model.transformer_model
import torch
import torch.nn as nn
from einops import rearrange
class MLP_model(torch.nn.Module):
    def __init__(
        self,
        input_channels,
        output_channels,
        list_hiddens=[128, 128],
        hidden_act="LeakyReLU",
        output_act="LeakyReLU",
        input_norm=None,
        dropout_prob=0.0,
    ):
        """
        Note that list_hiddens should be a list of hidden channels, one per MLP
        layer, e.g. [64, 128, 64].

        Args:
            input_channels (int): number of input features.
            output_channels (int): number of output features.
            list_hiddens (list of int): hidden channels of the intermediate layers.
            hidden_act (str, optional): name of a torch.nn activation applied after
                each hidden layer. Defaults to "LeakyReLU".
            output_act (str, optional): name of a torch.nn activation applied after
                the output layer. Defaults to "LeakyReLU".
            dropout_prob (float, optional): dropout probability applied after the
                MLP. Defaults to 0.0 (no dropout).
            input_norm (str, optional): one of ["batch", "layer"], selecting
                nn.BatchNorm1d or nn.LayerNorm on the input. Defaults to None.
        """
        super(MLP_model, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.hidden_channels = list_hiddens
        self.hidden_act = getattr(nn, hidden_act)()
        self.output_act = getattr(nn, output_act)()
        list_in_channels = [input_channels] + list_hiddens
        list_out_channels = list_hiddens + [output_channels]
        list_layers = []
        for i, (in_channels, out_channels) in enumerate(
            zip(list_in_channels, list_out_channels)
        ):
            list_layers.append(nn.Linear(in_channels, out_channels))
            # apply the output activation after the last layer,
            # the hidden activation otherwise
            if i == len(list_in_channels) - 1:
                list_layers.append(self.output_act)
            else:
                list_layers.append(self.hidden_act)
        self.layers = nn.ModuleList(list_layers)
        if dropout_prob > 0.0:
            self.dropout = nn.Dropout(dropout_prob)
        if input_norm is not None:
            if input_norm == "batch":
                self.input_norm = nn.BatchNorm1d(input_channels)
            elif input_norm == "layer":
                self.input_norm = nn.LayerNorm(input_channels)
            else:
                raise NotImplementedError
    def forward(self, x):
        if hasattr(self, "input_norm"):
            x = self.input_norm(x)
        for layer in self.layers:
            x = layer(x)
        if hasattr(self, "dropout"):
            x = self.dropout(x)
        return x
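

# Illustrative usage sketch (the shapes and argument values below are assumptions
# for demonstration, not part of the original module): MLP_model maps per-point
# features of size input_channels to output_channels through the widths given in
# list_hiddens, optionally normalising the input and applying dropout.
#
#   mlp = MLP_model(input_channels=4, output_channels=16, list_hiddens=[64, 64],
#                   input_norm="layer", dropout_prob=0.1)
#   y = mlp(torch.randn(8, 100, 4))   # -> [8, 100, 16]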
class TransformerBlock(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        dense_ratio=4,
        list_dropout=[0.1, 0.1, 0.1],
        activation="GELU",
    ) -> None:
        super(TransformerBlock, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dense_dim = embed_dim * dense_ratio
        self.pre_attn_norm = nn.LayerNorm(embed_dim)
        self.attn_layer = nn.MultiheadAttention(
            embed_dim,
            num_heads,
            dropout=list_dropout[0],
            add_bias_kv=False,
            batch_first=True,
        )
        self.post_attn_norm = nn.LayerNorm(embed_dim)
        self.post_attn_dropout = nn.Dropout(list_dropout[1])
        self.pre_dense_norm = nn.LayerNorm(embed_dim)
        self.dense_1 = nn.Linear(embed_dim, self.dense_dim)
        self.activation = getattr(nn, activation)()
        self.act_dropout = nn.Dropout(list_dropout[2])
        self.post_dense_norm = nn.LayerNorm(self.dense_dim)
        self.dense_2 = nn.Linear(self.dense_dim, embed_dim)
        self.c_attn = nn.Parameter(torch.ones(num_heads), requires_grad=True)
        self.residual_weight = nn.Parameter(torch.ones(embed_dim), requires_grad=True)
    def forward(
        self,
        x,
        k,
        v,
        x_cls=None,
        key_padding_mask=None,
        attn_mask=None,
        return_attn=False,
    ):
        # NOTE: in nn.MultiheadAttention, a True entry in key_padding_mask marks a
        # key that should be ignored by the attention.
        if x_cls is not None:
            # class-token attention: to be implemented later
            pass
        else:
            residual = x
            x = self.pre_attn_norm(x)
            # nn.MultiheadAttention is constructed with batch_first=True, so x is
            # [batch_size, num_points, embed_dim]; self-attention is applied on x
            # (the k and v arguments are currently unused).
            x, attn_scores = self.attn_layer(
                x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask
            )
            if self.c_attn is not None:
                # re-weight the per-head outputs with the learnable gains c_attn
                num_points = x.shape[1]
                x = x.view(-1, num_points, self.num_heads, self.head_dim)
                x = torch.einsum("b n h d, h -> b n d h", x, self.c_attn)
                x = x.reshape(-1, num_points, self.embed_dim)
            if self.post_attn_norm is not None:
                x = self.post_attn_norm(x)
            x = self.post_attn_dropout(x)
            x = x + residual
        # dense (feed-forward) sub-block with a weighted residual connection
        residual = x
        x = self.pre_dense_norm(x)
        x = self.activation(self.dense_1(x))
        x = self.act_dropout(x)
        if self.post_dense_norm is not None:
            x = self.post_dense_norm(x)
        x = self.dense_2(x)
        x = self.post_attn_dropout(x)
        if self.residual_weight is not None:
            residual = torch.mul(self.residual_weight, residual)
        x = x + residual
        if not return_attn:
            return x
        else:
            return x, attn_scores
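

# Illustrative usage sketch (assumed shapes, not part of the original module):
# TransformerBlock applies pre-norm self-attention on x (k and v are accepted but
# the attention currently runs on x alone), followed by a weighted residual and a
# dense feed-forward sub-block.
#
#   block = TransformerBlock(embed_dim=64, num_heads=4)
#   x = torch.randn(8, 100, 64)                   # [batch, num_points, embed_dim]
#   out = block(x, x, x)                          # -> [8, 100, 64]
#   out, attn = block(x, x, x, return_attn=True)  # attn: [8, 100, 100]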
class TransformerModel(nn.Module):
    def __init__(
        self, *, input_dim, embed_dim, output_dim, num_heads=4, num_layers=3
    ) -> None:
        super(TransformerModel, self).__init__()
        # save the torch module kwargs - the lightning ckpt is too cumbersome to use
        self.kwargs = {
            "input_dim": input_dim,
            "embed_dim": embed_dim,
            "output_dim": output_dim,
            "num_heads": num_heads,
            "num_layers": num_layers,
        }
        self.num_heads = num_heads
        list_attn_layers = []
        for _ in range(num_layers):
            list_attn_layers.append(
                TransformerBlock(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    list_dropout=[0.1, 0.1, 0.1],
                )
            )
        self.attn_layers = nn.ModuleList(list_attn_layers)
        self.mlp_in = MLP_model(
            input_dim, embed_dim, [embed_dim], hidden_act="GELU", output_act="GELU"
        )
        self.mlp_out = MLP_model(
            embed_dim, output_dim, [embed_dim], hidden_act="GELU", output_act="GELU"
        )
    def forward(self, x, k, v, key_padding_mask=None, attention_mask=None):
        x = self.mlp_in(x)
        # k = self.mlp_in(k)
        # v = self.mlp_in(v)
        for layer in self.attn_layers:
            x = layer(
                x, k, v, key_padding_mask=key_padding_mask, attn_mask=attention_mask
            )
        x = self.mlp_out(x)
        return x
    def get_attention_scores(self, x, key_padding_mask=None, attn_mask=None):
        list_attn_scores = []
        x = self.mlp_in(x)
        for layer in self.attn_layers:
            # TransformerBlock.forward expects k and v positionally; the block
            # performs self-attention on x, so x is passed for all three.
            x, attn_scores = layer(
                x,
                x,
                x,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
                return_attn=True,
            )
            if key_padding_mask is not None:
                # keep only the rows/columns selected by the padding mask
                mask_mat = rearrange(key_padding_mask, "b i -> b i 1") * rearrange(
                    key_padding_mask, "b j -> b 1 j"
                )
                num_points = int(key_padding_mask.sum().item())
                attn_mat = (
                    attn_scores.detach()
                    .cpu()
                    .numpy()
                    .squeeze()[mask_mat.cpu().numpy().squeeze()]
                    .reshape(num_points, num_points)
                )
            else:
                # torch.squeeze does not accept a tuple of dims before torch 2.0
                attn_mat = torch.squeeze(attn_scores, dim=0).detach().cpu().numpy()
            list_attn_scores.append(attn_mat)
        return list_attn_scores
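

# Illustrative usage sketch (assumed dimensions, not part of the original module):
# TransformerModel embeds per-point features with mlp_in, stacks num_layers
# TransformerBlocks, and projects back with mlp_out, e.g. mapping per-node input
# features to 2D output coordinates.
#
#   model = TransformerModel(input_dim=6, embed_dim=64, output_dim=2,
#                            num_heads=4, num_layers=3)
#   x = torch.randn(1, 500, 6)        # [batch, num_nodes, input_dim]
#   out = model(x, x, x)              # -> [1, 500, 2]
#   attn_per_layer = model.get_attention_scores(x)  # list of [500, 500] arrays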