"""
Let's get the relationships yo
"""
import numpy as np
import os
import sys
import torch
import torch.nn as nn
# Add project root to path for fasterRCNN imports
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")
)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
from fasterRCNN.lib.model.roi_layers import ROIAlign, nms
from m3sgg.utils.fpn.box_utils import center_size
from m3sgg.utils.transformer import transformer
from m3sgg.utils.word_vectors import obj_edge_vectors, verb_edge_vectors
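
# NOTE: obj_edge_vectors / verb_edge_vectors load pre-trained GloVe
# embeddings (glove.6B, 200-d) from the local "data" directory; the
# vector files are assumed to have been downloaded there ahead of time.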
class ObjectClassifier(nn.Module):
    """Module for computing object contexts and edge contexts for EASG.

    EASG-specific implementation of object classification and contextual
    feature extraction for efficient scene graph generation.
    """
def __init__(self, mode="edgecls", obj_classes=None):
"""Initialize the EASG object classifier.
:param mode: Classification mode, defaults to "edgecls"
:type mode: str, optional
:param obj_classes: List of object class names, defaults to None
:type obj_classes: list, optional
:return: None
:rtype: None
"""
super(ObjectClassifier, self).__init__()
self.mode = mode
self.obj_classes = obj_classes
embed_vecs = obj_edge_vectors(
self.obj_classes, wv_type="glove.6B", wv_dir="data", wv_dim=200
)
self.obj_embed = nn.Embedding(len(self.obj_classes), 200)
self.obj_embed.weight.data = embed_vecs.clone()
# This probably doesn't help it much
self.pos_embed = nn.Sequential(
nn.BatchNorm1d(4, momentum=0.01 / 10.0),
nn.Linear(4, 128),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
)
self.obj_dim = 2048
self.decoder_lin = nn.Sequential(
nn.Linear(self.obj_dim + 200 + 128, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(),
nn.Linear(1024, len(self.obj_classes)),
)
    def forward(self, entry):
        if self.mode == "edgecls":
            # Ground-truth object labels are used directly in edgecls mode
            entry["pred_labels"] = entry["labels"]
            entry["distribution"] = torch.nn.functional.one_hot(
                entry["pred_labels"], len(self.obj_classes)
            ).to(torch.float32)
        else:
            # Soft semantic embedding weighted by the detector's class distribution
            obj_embed = entry["distribution"] @ self.obj_embed.weight
            pos_embed = self.pos_embed(center_size(entry["boxes"][:, 1:]))
            obj_features = torch.cat((entry["features"], obj_embed, pos_embed), 1)
            entry["distribution"] = self.decoder_lin(obj_features)
            if self.training:
                # Keep raw logits for the loss; labels supervise downstream modules
                entry["pred_labels"] = entry["labels"]
            else:
                entry["distribution"] = torch.softmax(entry["distribution"], dim=1)
                entry["pred_labels"] = torch.max(entry["distribution"], dim=1)[1]
        return entry
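

# A minimal usage sketch for ObjectClassifier (a hypothetical helper, not
# part of the original module). It assumes GloVe vectors are available
# under "data" and shows the entry-dict keys the edgecls path expects.
def _demo_object_classifier():
    classifier = ObjectClassifier(mode="edgecls", obj_classes=["person", "cup"])
    classifier.eval()
    # In edgecls mode only ground-truth labels are needed
    entry = {"labels": torch.tensor([0, 1])}
    entry = classifier(entry)
    # One-hot class distribution of shape (num_objects, num_classes)
    assert entry["distribution"].shape == (2, 2)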
class ActionClassifier(nn.Module):
    """Classify the verb (action) of each clip for EASG."""
def __init__(self, mode="edgecls", verb_classes=None):
super(ActionClassifier, self).__init__()
self.mode = mode
self.verb_classes = verb_classes
# Verb features have dimension 2048
self.decoder_lin = nn.Sequential(
nn.Linear(2048, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(),
nn.Linear(1024, len(self.verb_classes)),
)
    def forward(self, entry):
        if self.mode != "easgcls":
            # Ground-truth verb labels are used directly outside easgcls mode
            entry["pred_labels_verb"] = entry["labels_verb"]
            entry["distribution_verb"] = torch.nn.functional.one_hot(
                entry["pred_labels_verb"], len(self.verb_classes)
            ).to(torch.float32)
        else:
            entry["distribution_verb"] = self.decoder_lin(entry["features_verb"])
            if self.training:
                # Keep raw logits for the loss; labels supervise downstream modules
                entry["pred_labels_verb"] = entry["labels_verb"]
            else:
                entry["distribution_verb"] = torch.softmax(
                    entry["distribution_verb"], dim=1
                )
                entry["pred_labels_verb"] = torch.max(
                    entry["distribution_verb"], dim=1
                )[1]
        return entry
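

# A minimal usage sketch for ActionClassifier (a hypothetical helper, not
# part of the original module), assuming verb ROI features are 2048-d as
# required by decoder_lin above.
def _demo_action_classifier():
    classifier = ActionClassifier(mode="easgcls", verb_classes=["take", "put"])
    classifier.eval()
    entry = {"features_verb": torch.randn(3, 2048)}
    with torch.no_grad():
        entry = classifier(entry)
    # Softmax distribution over verbs plus argmax predictions
    assert entry["distribution_verb"].shape == (3, 2)
    assert entry["pred_labels_verb"].shape == (3,)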
class STTran(nn.Module):
    """Spatial-temporal transformer that scores edge (relationship) classes for EASG."""
    def __init__(
        self,
        mode="edgecls",
        obj_classes=None,
        verb_classes=None,
        edge_class_num=None,
        enc_layer_num=None,
        dec_layer_num=None,
        use_visual_features=False,
    ):
        super(STTran, self).__init__()
        self.obj_classes = [cls for cls in obj_classes if cls != "__background__"]
        self.verb_classes = verb_classes
        self.edge_class_num = edge_class_num
        assert mode in ("easgcls", "sgcls", "edgecls")
        self.mode = mode

        self.object_classifier = ObjectClassifier(
            mode=self.mode, obj_classes=self.obj_classes
        )
        self.action_classifier = ActionClassifier(
            mode=self.mode, verb_classes=self.verb_classes
        )

        # Visual branch
        self.conv = nn.Sequential(
            nn.Conv2d(2, 256 // 2, kernel_size=7, stride=2, padding=3, bias=True),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256 // 2, momentum=0.01),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(256 // 2, 256, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256, momentum=0.01),
        )
        self.obj_fc = nn.Linear(2048, 512)
        self.verb_fc = nn.Linear(2048, 512)

        # Semantic branch: GloVe embeddings for object and verb classes
        embed_vecs = obj_edge_vectors(
            self.obj_classes, wv_type="glove.6B", wv_dir="data", wv_dim=200
        )
        self.obj_embed = nn.Embedding(len(self.obj_classes), 200)
        self.obj_embed.weight.data = embed_vecs.clone()

        embed_vecs = verb_edge_vectors(
            self.verb_classes, wv_type="glove.6B", wv_dir="data", wv_dim=200
        )
        self.verb_embed = nn.Embedding(len(self.verb_classes), 200)
        self.verb_embed.weight.data = embed_vecs.clone()

        # Global-local ("glocal") spatial-temporal transformer over
        # relationship features (512 + 512 visual + 200 + 200 semantic = 1424)
        self.glocal_transformer = transformer(
            enc_layer_num=enc_layer_num,
            dec_layer_num=dec_layer_num,
            embed_dim=1424,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            mode="latter",
        )
        self.rel_compress = nn.Linear(1424, self.edge_class_num)
    def forward(self, entry):
        entry = self.object_classifier(entry)
        entry = self.action_classifier(entry)

        # Visual part: project object and verb ROI features to 512-d each
        obj_rep = self.obj_fc(entry["features"])
        verb_rep = self.verb_fc(entry["features_verb"])
        x_visual = torch.cat((obj_rep, verb_rep), 1)

        # Semantic part: embed predicted object and verb classes
        obj_emb = self.obj_embed(entry["pred_labels"])
        verb_emb = self.verb_embed(entry["pred_labels_verb"])
        x_semantic = torch.cat((obj_emb, verb_emb), 1)

        rel_features = torch.cat((x_visual, x_semantic), dim=1)

        # Spatial-Temporal Transformer
        global_output, global_attention_weights, local_attention_weights = (
            self.glocal_transformer(features=rel_features, im_idx=entry["im_idx"])
        )

        # Multi-label edge scores via per-class sigmoid
        entry["edge_distribution"] = torch.sigmoid(self.rel_compress(global_output))
        return entry
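

# Rough end-to-end smoke test (a sketch, not part of the original module).
# It assumes GloVe vectors exist under "data", that the glocal transformer
# returns per-relationship features of the same length as its input, and
# that the entry dict is pre-populated by the detector with the keys below.
if __name__ == "__main__":
    model = STTran(
        mode="edgecls",
        obj_classes=["__background__", "person", "cup"],
        verb_classes=["take", "put"],
        edge_class_num=5,
        enc_layer_num=1,
        dec_layer_num=1,
    )
    model.eval()
    num_objs = 4
    entry = {
        "labels": torch.randint(0, 2, (num_objs,)),
        "labels_verb": torch.randint(0, 2, (num_objs,)),
        "features": torch.randn(num_objs, 2048),
        "features_verb": torch.randn(num_objs, 2048),
        "boxes": torch.rand(num_objs, 5),  # (batch_idx, x1, y1, x2, y2)
        "im_idx": torch.zeros(num_objs, dtype=torch.long),
    }
    with torch.no_grad():
        out = model(entry)
    # Per-relationship sigmoid scores over edge classes
    print(out["edge_distribution"].shape)  # expected: (num_objs, 5)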