Source code for m3sgg.core.models.tempura.transformer_tempura
import copy
import torch
import torch.nn as nn
class TransformerEncoderLayer(nn.Module):
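    """Spatial encoder layer: self-attention over the relation features within a
    single frame, followed by a feed-forward block with post-norm residuals."""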
def __init__(self, embed_dim=1936, nhead=4, dim_feedforward=2048, dropout=0.1):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
self.linear1 = nn.Linear(embed_dim, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
    def forward(self, src, input_key_padding_mask):
        # local (spatial) attention: relations attend to the others in the same frame
src2, local_attention_weights = self.self_attn(
src, src, src, key_padding_mask=input_key_padding_mask
)
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(nn.functional.relu(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src, local_attention_weights
class TransformerDecoderLayer(nn.Module):
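    """Temporal decoder layer: self-attention over a two-frame sliding window,
    with frame-position embeddings added to the queries and keys."""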
def __init__(self, embed_dim=1936, nhead=4, dim_feedforward=2048, dropout=0.1):
super().__init__()
self.multihead2 = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
self.linear1 = nn.Linear(embed_dim, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, embed_dim)
self.norm3 = nn.LayerNorm(embed_dim)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def forward(self, global_input, input_key_padding_mask, position_embed):
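        # frame-position embeddings are added to queries and keys, but not to values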
tgt2, global_attention_weights = self.multihead2(
query=global_input + position_embed,
key=global_input + position_embed,
value=global_input,
key_padding_mask=input_key_padding_mask,
)
tgt = global_input + self.dropout2(tgt2)
tgt = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(nn.functional.relu(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
return tgt, global_attention_weights
class TransformerEncoder(nn.Module):
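    """Stack of spatial encoder layers; also returns the per-layer attention maps."""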
def __init__(self, encoder_layer, num_layers):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
    def forward(self, input, input_key_padding_mask):
        output = input
        # collect one (batch, seq, seq) attention map per layer
        weights = torch.zeros(
            [self.num_layers, output.shape[1], output.shape[0], output.shape[0]],
            device=output.device,
        )
        for i, layer in enumerate(self.layers):
            output, local_attention_weights = layer(output, input_key_padding_mask)
            weights[i] = local_attention_weights
        if self.num_layers > 0:
            return output, weights
        return output, None
class TransformerDecoder(nn.Module):
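    """Stack of temporal decoder layers; also returns the per-layer attention maps."""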
    def __init__(self, decoder_layer, num_layers, embed_dim):  # embed_dim unused; kept for call-site compatibility
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
    def forward(self, global_input, input_key_padding_mask, position_embed):
        output = global_input
        # collect one (batch, seq, seq) attention map per layer
        weights = torch.zeros(
            [self.num_layers, output.shape[1], output.shape[0], output.shape[0]],
            device=output.device,
        )
        for i, layer in enumerate(self.layers):
            output, global_attention_weights = layer(
                output, input_key_padding_mask, position_embed
            )
            weights[i] = global_attention_weights
        if self.num_layers > 0:
            return output, weights
        return output, None
class transformer(nn.Module):
"""Spatial Temporal Transformer.
:param local_attention: spatial encoder
:type local_attention: object
:param global_attention: temporal decoder
:type global_attention: object
:param position_embedding: frame encoding (window_size*dim)
:type position_embedding: object
:param mode: both--use the features from both frames in the window, latter--use the features from the latter frame in the window
:type mode: str
"""
def __init__(
self,
enc_layer_num=1,
dec_layer_num=3,
embed_dim=1936,
nhead=8,
dim_feedforward=2048,
dropout=0.1,
mode=None,
mem_compute=True,
mem_fusion=None,
selection=None,
selection_lambda=0.5,
):
        super().__init__()
self.mode = mode
self.mem_fusion = mem_fusion
self.mem_compute = mem_compute
self.selection = selection
encoder_layer = TransformerEncoderLayer(
embed_dim=embed_dim,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
)
self.local_attention = TransformerEncoder(encoder_layer, enc_layer_num)
if mem_compute:
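            # "seperate" (sic; spelling kept for config compatibility) builds one
            # memory-attention head per relation type, else a single shared head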
if mem_compute == "seperate":
self.mem_attention = nn.ModuleDict()
for rel in ["attention", "contacting", "spatial"]:
self.mem_attention.update(
{rel: nn.MultiheadAttention(embed_dim, 1, 0.0, bias=False)}
)
else:
self.mem_attention = nn.MultiheadAttention(
embed_dim, 1, 0.0, bias=False
)
if selection == "manual":
self.selector = selection_lambda
else:
self.selector = nn.Linear(embed_dim, 1)
decoder_layer = TransformerDecoderLayer(
embed_dim=embed_dim,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
)
self.global_attention = TransformerDecoder(
decoder_layer, dec_layer_num, embed_dim
)
self.position_embedding = nn.Embedding(2, embed_dim) # present and next frame
nn.init.uniform_(self.position_embedding.weight)
def memory_hallucinator(self, memory, feat):
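        """Attend from the current features into the relation memory bank and
        blend the read-out back into the features; returns the features
        unchanged when no memory is available."""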
if len(memory) != 0:
e = self.mem_selection(feat)
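            # queries: (num_feats, 1, dim); memory entries act as keys/values, batch of 1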
q = feat.unsqueeze(1)
if self.mem_compute == "seperate":
mem_features = {}
for rel in ["attention", "contacting", "spatial"]:
k = v = memory[rel].unsqueeze(1)
mem_features[rel], _ = self.mem_attention[rel](q, k, v)
mem_features = torch.cat([v for k, v in mem_features.items()], dim=1)
mem_features = mem_features.mean(dim=1)
else:
memory = torch.cat([v for k, v in memory.items()], dim=0)
k = v = memory.unsqueeze(1)
mem_features, _ = self.mem_attention(q, k, v)
            if e is not None:
                # gate: e weights the original features, (1 - e) the memory read-out
                mem_encoded_features = e * feat + (1 - e) * mem_features.squeeze(1)
            else:
                mem_encoded_features = feat + mem_features.squeeze(1)
        else:
            # no memory available yet: pass the features through unchanged
            mem_encoded_features = feat
return mem_encoded_features
def mem_selection(self, feat):
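        """Return the gating coefficient used to mix original and memory features."""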
if self.selection == "manual":
return self.selector
else:
return self.selector(feat).sigmoid()
    def forward(self, features, im_idx, memory=None):
        """Run the spatial encoder and temporal decoder over the relation features.

        :param features: relation features, shape (num_rels, embed_dim)
        :param im_idx: frame index of each relation feature, shape (num_rels,)
        :param memory: optional memory bank keyed by relation type
        """
        if memory is None:  # avoid a mutable default argument
            memory = []
        rel_idx = torch.arange(im_idx.shape[0]).to(
            features.device
        )  # global index of each relation feature
        # l = the largest number of relation features in any single frame
        l = torch.sum(im_idx == torch.mode(im_idx)[0])
        b = int(im_idx[-1] + 1)  # number of frames
        rel_input = torch.zeros([l, b, features.shape[1]], device=features.device)
        masks = torch.zeros([b, l], dtype=torch.bool, device=features.device)
        # TODO: the padding/masking below may not need a for-loop
        # pad each frame's relation features to length l and mask the padding
        for i in range(b):
            rel_input[: torch.sum(im_idx == i), i, :] = features[im_idx == i]
            masks[i, torch.sum(im_idx == i) :] = True
# spatial encoder
local_output, local_attention_weights = self.local_attention(rel_input, masks)
local_output = (
(local_output.permute(1, 0, 2))
.contiguous()
.view(-1, features.shape[1])[masks.view(-1) == 0]
)
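        # early fusion: blend memory into the features before the temporal decoder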
if self.mem_compute and self.mem_fusion == "early":
mem_encoder_features = self.memory_hallucinator(
memory=memory, feat=local_output
)
else:
mem_encoder_features = local_output
        global_input = torch.zeros(
            [l * 2, b - 1, features.shape[1]], device=features.device
        )
        position_embed = torch.zeros(
            [l * 2, b - 1, features.shape[1]], device=features.device
        )
        # frame index of each window slot (-1 marks padding)
        idx = -torch.ones([l * 2, b - 1], device=features.device)
        # global relation index of each window slot (-1 marks padding)  # TODO
        idx_plus = -torch.ones(
            [l * 2, b - 1], dtype=torch.long, device=features.device
        )
        # sliding window of size 2: window j holds frames j and j+1
        for j in range(b - 1):
            win = (im_idx == j) | (im_idx == j + 1)
            n_win = torch.sum(win)
            global_input[:n_win, j, :] = mem_encoder_features[win]
            idx[:n_win, j] = im_idx[win]
            idx_plus[:n_win, j] = rel_idx[win]  # TODO
            n_j = torch.sum(im_idx == j)
            # position 0 marks the present frame, position 1 the next frame
            position_embed[:n_j, j, :] = self.position_embedding.weight[0]
            position_embed[
                n_j : n_j + torch.sum(im_idx == j + 1), j, :
            ] = self.position_embedding.weight[1]
        # padding slots were left all-zero, so an all-zero row marks padding
        global_masks = (
            (torch.sum(global_input.view(-1, features.shape[1]), dim=1) == 0)
            .view(l * 2, b - 1)
            .permute(1, 0)
        )
# temporal decoder
global_output, global_attention_weights = self.global_attention(
global_input, global_masks, position_embed
)
output = torch.zeros_like(features)
if self.mode == "both":
# both
for j in range(b - 1):
if j == 0:
output[im_idx == j] = global_output[:, j][idx[:, j] == j]
if j == b - 2:
output[im_idx == j + 1] = global_output[:, j][idx[:, j] == j + 1]
else:
output[im_idx == j + 1] = (
global_output[:, j][idx[:, j] == j + 1]
+ global_output[:, j + 1][idx[:, j + 1] == j + 1]
) / 2
elif self.mode == "latter":
# later
for j in range(b - 1):
if j == 0:
output[im_idx == j] = global_output[:, j][idx[:, j] == j]
output[im_idx == j + 1] = global_output[:, j][idx[:, j] == j + 1]
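        # late fusion: blend memory into the decoder output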
if self.mem_compute and self.mem_fusion == "late":
local_output = output
output = self.memory_hallucinator(memory=memory, feat=output)
mem_encoder_features = output
return (
output,
local_output,
mem_encoder_features,
global_attention_weights,
local_attention_weights,
)
def _get_clones(module, N):
    """Return a ModuleList of N independent deep copies of module."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
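

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. The values are
    # assumptions: 1936-dim relation features and a toy im_idx with 3 frames
    # holding 3/2/3 relations; memory fusion is disabled for simplicity.
    model = transformer(
        enc_layer_num=1,
        dec_layer_num=3,
        embed_dim=1936,
        nhead=8,
        mode="latter",
        mem_compute=False,
    )
    feats = torch.randn(8, 1936)
    im_idx = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])
    out, local_out, mem_feats, g_attn, l_attn = model(feats, im_idx)
    print(out.shape)  # expected: torch.Size([8, 1936])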