Source code for m3sgg.core.models.oed.oed_multi

"""
OED Multi-frame Model Implementation

This module implements the OED architecture for multi-frame dynamic scene graph generation
with the Progressively Refined Module (PRM) for temporal context aggregation.
"""

import copy
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from .criterion import SetCriterionOED
from .postprocess import PostProcessOED
from .transformer import build_transformer

# Set up logger
logger = logging.getLogger(__name__)


class OEDMulti(nn.Module):
    """OED Multi-frame model for dynamic scene graph generation.

    Implements the one-stage end-to-end framework with cascaded decoders and
    the Progressively Refined Module (PRM) for temporal context aggregation.
    """
    def __init__(self, conf, dataset):
        """Initialize the OED Multi-frame model.

        :param conf: Configuration object containing model parameters
        :type conf: Config
        :param dataset: Dataset information for model setup
        :type dataset: object
        :return: None
        :rtype: None
        """
        super().__init__()

        # Store configuration and dataset
        self.conf = conf
        self.dataset = dataset

        # Model parameters
        self.num_queries = conf.num_queries
        self.hidden_dim = 256  # Default hidden dimension
        self.num_obj_classes = len(dataset.object_classes)
        self.num_attn_classes = conf.num_attn_classes
        self.num_spatial_classes = conf.num_spatial_classes
        self.num_contacting_classes = conf.num_contacting_classes

        # Build transformer (remove backbone dependency)
        # self.backbone = build_backbone(conf)  # Remove backbone dependency
        self.transformer = build_transformer(conf)

        # Query embedding
        self.query_embed = nn.Embedding(self.num_queries, self.hidden_dim)

        # Classification heads
        self.obj_class_embed = nn.Linear(self.hidden_dim, self.num_obj_classes + 1)
        self.attn_class_embed = nn.Linear(self.hidden_dim, self.num_attn_classes)
        self.spatial_class_embed = nn.Linear(self.hidden_dim, self.num_spatial_classes)
        self.contacting_class_embed = nn.Linear(
            self.hidden_dim, self.num_contacting_classes
        )

        # Bounding box heads
        self.sub_bbox_embed = MLP(self.hidden_dim, self.hidden_dim, 4, 3)
        self.obj_bbox_embed = MLP(self.hidden_dim, self.hidden_dim, 4, 3)

        # Projection layer from object-detector features to the hidden dimension
        self.feature_proj = nn.Linear(2048, self.hidden_dim)

        # Hyperparameters
        self.aux_loss = getattr(conf, "aux_loss", False)
        self.dec_layers_hopd = conf.dec_layers_hopd
        self.dec_layers_interaction = conf.dec_layers_interaction

        # Semantic and positional fusion
        self.fuse_semantic_pos = conf.fuse_semantic_pos
        if self.fuse_semantic_pos:
            # TODO: Initialize semantic embeddings
            pass

        # Temporal interaction heads
        self.ins_temporal_embed_head = conf.query_temporal_interaction
        self.rel_temporal_embed_head = conf.query_temporal_interaction
        if self.ins_temporal_embed_head:
            self.temporal_obj_class_embed = copy.deepcopy(self.obj_class_embed)
            self.temporal_sub_bbox_embed = copy.deepcopy(self.sub_bbox_embed)
            self.temporal_obj_bbox_embed = copy.deepcopy(self.obj_bbox_embed)
        if self.rel_temporal_embed_head:
            self.temporal_attn_class_embed = copy.deepcopy(self.attn_class_embed)
            self.temporal_spatial_class_embed = copy.deepcopy(self.spatial_class_embed)
            self.temporal_contacting_class_embed = copy.deepcopy(
                self.contacting_class_embed
            )
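    # Illustrative sketch (not executed here): the constructor above only reads a handful
    # of attributes from `conf`, so a minimal configuration could look roughly like the
    # following. The concrete values are assumptions for illustration, not project defaults.
    #
    #     conf.num_queries = 100
    #     conf.num_attn_classes = 3          # attention relation classes (assumed)
    #     conf.num_spatial_classes = 6       # spatial relation classes (assumed)
    #     conf.num_contacting_classes = 17   # contacting relation classes (assumed)
    #     conf.dec_layers_hopd = 3
    #     conf.dec_layers_interaction = 3
    #     conf.fuse_semantic_pos = False
    #     conf.query_temporal_interaction = True
    #     model = OEDMulti(conf, dataset)    # dataset must expose `object_classes`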
    def forward(self, entry, targets=None):
        """Forward pass through the OED model.

        :param entry: Input data containing images and features
        :type entry: dict
        :param targets: Ground truth targets for training, defaults to None
        :type targets: dict, optional
        :return: Model predictions and outputs
        :rtype: dict
        """
        # Extract features from the object detector entry
        if "features" not in entry:
            raise ValueError("Entry must contain 'features' from object detector")

        # Use the features from the object detector
        features = entry["features"]  # Shape: (num_objects, feature_dim)

        # Project features to the hidden dimension
        features = self.feature_proj(features)  # (num_objects, hidden_dim)

        # For now, create a simple transformer input
        # TODO: Implement proper feature processing for OED
        batch_size = 1  # Assume single batch for now
        num_objects = features.size(0)
        hidden_dim = self.hidden_dim

        # Create inputs for transformer
        src = features.unsqueeze(0)  # Add batch dimension: (1, num_objects, hidden_dim)
        mask = torch.zeros(
            batch_size, num_objects, dtype=torch.bool, device=features.device
        )
        pos = torch.zeros(batch_size, num_objects, hidden_dim, device=features.device)

        # Prepare embedding dictionary for transformer
        embed_dict = {
            "obj_class_embed": self.obj_class_embed,
            "attn_class_embed": self.attn_class_embed,
            "spatial_class_embed": self.spatial_class_embed,
            "contacting_class_embed": self.contacting_class_embed,
            "sub_bbox_embed": self.sub_bbox_embed,
            "obj_bbox_embed": self.obj_bbox_embed,
        }

        # Process through transformer
        hopd_out, interaction_decoder_out = self.transformer(
            src, mask, self.query_embed.weight, pos, embed_dict, targets, cur_idx=0
        )[:2]

        # Transformer outputs have shape (N, B, C) where N is num_queries, B is
        # batch_size, and C is hidden_dim; transpose to (B, N, C) for the heads
        hopd_out = hopd_out.transpose(0, 1)  # (B, N, C)
        interaction_decoder_out = interaction_decoder_out.transpose(0, 1)  # (B, N, C)

        # Generate outputs based on temporal heads
        if self.ins_temporal_embed_head:
            outputs_sub_coord = self.temporal_sub_bbox_embed(hopd_out).sigmoid()
            outputs_obj_coord = self.temporal_obj_bbox_embed(hopd_out).sigmoid()
            outputs_obj_class = self.temporal_obj_class_embed(hopd_out)
        else:
            outputs_sub_coord = self.sub_bbox_embed(hopd_out).sigmoid()
            outputs_obj_coord = self.obj_bbox_embed(hopd_out).sigmoid()
            outputs_obj_class = self.obj_class_embed(hopd_out)

        # Semantic and positional fusion
        if self.fuse_semantic_pos:
            # TODO: Implement semantic fusion
            pass

        # Generate relation outputs
        if self.rel_temporal_embed_head:
            outputs_attn_class = self.temporal_attn_class_embed(interaction_decoder_out)
            outputs_spatial_class = self.temporal_spatial_class_embed(
                interaction_decoder_out
            )
            outputs_contacting_class = self.temporal_contacting_class_embed(
                interaction_decoder_out
            )
        else:
            outputs_attn_class = self.attn_class_embed(interaction_decoder_out)
            outputs_spatial_class = self.spatial_class_embed(interaction_decoder_out)
            outputs_contacting_class = self.contacting_class_embed(
                interaction_decoder_out
            )

        # The outputs have shape (B, N, C) where B is batch_size and N is num_queries;
        # reshape to (B*N, C) for the loss computation
        batch_size, num_queries, hidden_dim = outputs_obj_class.shape
        outputs_obj_class = outputs_obj_class.view(-1, hidden_dim)  # (B*N, C)
        outputs_sub_coord = outputs_sub_coord.view(-1, 4)  # (B*N, 4)
        outputs_obj_coord = outputs_obj_coord.view(-1, 4)  # (B*N, 4)
        outputs_attn_class = outputs_attn_class.view(
            -1, outputs_attn_class.size(-1)
        )  # (B*N, num_attn_classes)
        outputs_spatial_class = outputs_spatial_class.view(
            -1, outputs_spatial_class.size(-1)
        )  # (B*N, num_spatial_classes)
        outputs_contacting_class = outputs_contacting_class.view(
            -1, outputs_contacting_class.size(-1)
        )  # (B*N, num_contacting_classes)

        # Extract ground truth from entry if available
        attention_gt = entry.get("attention_gt", [])
        spatial_gt = entry.get("spatial_gt", [])
        contact_gt = entry.get("contact_gt", [])
        labels = entry.get("labels", [])

        # Convert to tensors if they're lists
        if isinstance(attention_gt, list) and len(attention_gt) > 0:
            attention_gt = torch.tensor(
                attention_gt, dtype=torch.long, device=outputs_attn_class.device
            )
        if isinstance(spatial_gt, list) and len(spatial_gt) > 0:
            spatial_gt = [
                torch.tensor(sg, dtype=torch.long, device=outputs_spatial_class.device)
                for sg in spatial_gt
            ]
        if isinstance(contact_gt, list) and len(contact_gt) > 0:
            contact_gt = [
                torch.tensor(
                    cg, dtype=torch.long, device=outputs_contacting_class.device
                )
                for cg in contact_gt
            ]
        if isinstance(labels, list) and len(labels) > 0:
            labels = torch.tensor(
                labels, dtype=torch.long, device=outputs_obj_class.device
            )

        # Get the actual batch size from ground truth
        if isinstance(attention_gt, torch.Tensor):
            actual_batch_size = attention_gt.size(0)
        elif isinstance(spatial_gt, list) and len(spatial_gt) > 0:
            actual_batch_size = len(spatial_gt)
        elif isinstance(contact_gt, list) and len(contact_gt) > 0:
            actual_batch_size = len(contact_gt)
        elif isinstance(labels, torch.Tensor):
            actual_batch_size = labels.size(0)
        else:
            # Fallback: use the number of objects from features
            actual_batch_size = features.size(0)

        # Handle batch size mismatch by aligning outputs and ground truth
        model_output_size = outputs_obj_class.size(0)
        if actual_batch_size > model_output_size:
            # Ground truth has more samples than model outputs:
            # truncate ground truth to match model outputs
            logger.warning(
                f"Ground truth batch size ({actual_batch_size}) exceeds model outputs "
                f"({model_output_size}). Truncating ground truth."
            )
            if isinstance(attention_gt, torch.Tensor):
                attention_gt = attention_gt[:model_output_size]
            if isinstance(spatial_gt, list):
                spatial_gt = spatial_gt[:model_output_size]
            if isinstance(contact_gt, list):
                contact_gt = contact_gt[:model_output_size]
            if isinstance(labels, torch.Tensor):
                labels = labels[:model_output_size]
            actual_batch_size = model_output_size
        elif actual_batch_size < model_output_size:
            # Model outputs have more samples than ground truth:
            # truncate model outputs to match ground truth
            outputs_obj_class = outputs_obj_class[:actual_batch_size]
            outputs_sub_coord = outputs_sub_coord[:actual_batch_size]
            outputs_obj_coord = outputs_obj_coord[:actual_batch_size]
            outputs_attn_class = outputs_attn_class[:actual_batch_size]
            outputs_spatial_class = outputs_spatial_class[:actual_batch_size]
            outputs_contacting_class = outputs_contacting_class[:actual_batch_size]

        # CRITICAL: Create pair indices that match the actual batch size.
        # In predcls mode, pair indices must reference valid objects in the prediction,
        # so for simplicity pairs are formed between the first object (human) and all others.
        num_pairs = actual_batch_size
        pair_idx = torch.stack(
            [
                torch.zeros(num_pairs, dtype=torch.long),
                torch.arange(num_pairs, dtype=torch.long),
            ],
            dim=1,
        ).to(device=outputs_obj_class.device)

        # Create image indices for all pairs
        if "im_idx" in entry and len(entry["im_idx"]) >= actual_batch_size:
            # Use the provided image indices, but truncate to our batch size
            im_idx = entry["im_idx"][:actual_batch_size]
        else:
            # Default to image 0 for all pairs
            im_idx = torch.zeros(
                actual_batch_size, dtype=torch.long, device=outputs_obj_class.device
            )

        # Prepare output dictionary in the format expected by the training loop and evaluator
        out = {
            # Standard format expected by training loop
            "distribution": outputs_obj_class,
            # The evaluator expects softmax-normalized relation distributions
            "attention_distribution": torch.softmax(outputs_attn_class, dim=1),
            "spatial_distribution": torch.softmax(outputs_spatial_class, dim=1),
            "contact_distribution": torch.softmax(outputs_contacting_class, dim=1),
            # Ground truth labels
            "attention_gt": attention_gt,
            "spatial_gt": spatial_gt,
            "contact_gt": contact_gt,
            "labels": labels,
            # Critical fields required by the Action Genome evaluator
            "pair_idx": pair_idx,
            "im_idx": im_idx,
            # OED format for compatibility
            "pred_obj_logits": outputs_obj_class,
            "pred_sub_boxes": outputs_sub_coord,
            "pred_obj_boxes": outputs_obj_coord,
            "pred_attn_logits": outputs_attn_class,
            "pred_spatial_logits": outputs_spatial_class,
            "pred_contacting_logits": outputs_contacting_class,
            # Additional fields for evaluation
            "pred_scores": torch.ones(
                actual_batch_size, device=outputs_obj_class.device
            ),  # Default confidence scores
            "scores": torch.ones(
                actual_batch_size, device=outputs_obj_class.device
            ),  # Object confidence scores
            "pred_labels": torch.argmax(
                outputs_obj_class, dim=1
            ),  # Predicted object labels
        }

        # CRITICAL: The evaluator expects individual object boxes in the format
        # [batch_idx, x1, y1, x2, y2] for each object.
        if "boxes" in entry:
            # Use the original object boxes from the detector
            # (the correct approach for predcls)
            out["boxes"] = entry["boxes"]  # Keep the full format with batch indices
        else:
            # Fallback: create boxes from OED predictions (should not happen in predcls mode)
            # by prepending a batch index column to the predicted boxes
            batch_indices = torch.zeros(
                actual_batch_size, 1, device=outputs_obj_coord.device
            )
            out["boxes"] = torch.cat([batch_indices, outputs_obj_coord], dim=1)

        # OED is a relation prediction model that works with the evaluator in predcls mode.
        # The evaluator has all the information it needs:
        # 1. Individual object boxes (from the detector input via entry["boxes"])
        # 2. Relationship predictions (from OED outputs)
        # 3. Proper frame organization (via pair_idx and im_idx)
        return out
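    # Illustrative sketch (not executed here): forward() expects an object-detector entry
    # roughly shaped as below. Field names follow the code above; the tensor sizes and
    # label values are assumptions chosen only to make the shapes concrete.
    #
    #     entry = {
    #         "features": torch.randn(8, 2048),            # 8 detected objects, 2048-d ROI features
    #         "boxes": torch.rand(8, 5),                   # [batch_idx, x1, y1, x2, y2] per object
    #         "labels": [1, 5, 5, 7, 2, 2, 9, 4],          # object class ids
    #         "attention_gt": [0, 1, 2, 0, 1, 2, 0, 1],    # one attention label per pair
    #         "spatial_gt": [[0], [1], [2], [0], [1], [2], [0], [1]],
    #         "contact_gt": [[0], [1], [2], [0], [1], [2], [0], [1]],
    #         "im_idx": torch.zeros(8, dtype=torch.long),  # all pairs from frame 0
    #     }
    #     out = model(entry)
    #     # out["attention_distribution"], out["pair_idx"], out["boxes"], ... feed the evaluator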
    @torch.jit.unused
    def _set_aux_loss(
        self,
        outputs_obj_class,
        outputs_sub_coord,
        outputs_obj_coord,
        outputs_attn_class,
        outputs_spatial_class,
        outputs_contacting_class,
    ):
        """Set auxiliary outputs for loss computation.

        :param outputs_obj_class: Object classification outputs
        :type outputs_obj_class: torch.Tensor
        :param outputs_sub_coord: Subject bounding box outputs
        :type outputs_sub_coord: torch.Tensor
        :param outputs_obj_coord: Object bounding box outputs
        :type outputs_obj_coord: torch.Tensor
        :param outputs_attn_class: Attention classification outputs
        :type outputs_attn_class: torch.Tensor
        :param outputs_spatial_class: Spatial classification outputs
        :type outputs_spatial_class: torch.Tensor
        :param outputs_contacting_class: Contacting classification outputs
        :type outputs_contacting_class: torch.Tensor
        :return: List of auxiliary outputs
        :rtype: list
        """
        min_dec_layers_num = min(self.dec_layers_hopd, self.dec_layers_interaction)
        return [
            {
                "pred_obj_logits": a,
                "pred_sub_boxes": b,
                "pred_obj_boxes": c,
                "pred_attn_logits": d,
                "pred_spatial_logits": e,
                "pred_contacting_logits": f,
            }
            for a, b, c, d, e, f in zip(
                outputs_obj_class[-min_dec_layers_num:-1],
                outputs_sub_coord[-min_dec_layers_num:-1],
                outputs_obj_coord[-min_dec_layers_num:-1],
                outputs_attn_class[-min_dec_layers_num:-1],
                outputs_spatial_class[-min_dec_layers_num:-1],
                outputs_contacting_class[-min_dec_layers_num:-1],
            )
        ]
class MLP(nn.Module):
    """Multi-layer perceptron for bounding box prediction."""
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        """Initialize the MLP.

        :param input_dim: Input dimension
        :type input_dim: int
        :param hidden_dim: Hidden dimension
        :type hidden_dim: int
        :param output_dim: Output dimension
        :type output_dim: int
        :param num_layers: Number of layers
        :type num_layers: int
        :return: None
        :rtype: None
        """
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
    def forward(self, x):
        """Forward pass through the MLP.

        :param x: Input tensor
        :type x: torch.Tensor
        :return: Output tensor
        :rtype: torch.Tensor
        """
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
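# Illustrative sketch (not executed here): the bounding-box heads above are 3-layer MLPs
# mapping decoder features to 4 box coordinates, e.g.
#
#     bbox_head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
#     boxes = bbox_head(torch.randn(1, 100, 256)).sigmoid()   # -> (1, 100, 4) in [0, 1]
#
# The (1, 100, 256) input shape is an assumption matching num_queries=100 and hidden_dim=256.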
def build_oed_multi(conf, dataset):
    """Build OED multi-frame model.

    :param conf: Configuration object
    :type conf: Config
    :param dataset: Dataset object
    :type dataset: object
    :return: Tuple of (model, criterion, postprocessors)
    :rtype: tuple
    """
    model = OEDMulti(conf, dataset)

    # Build criterion
    criterion = SetCriterionOED(
        num_obj_classes=len(dataset.object_classes),
        num_queries=conf.num_queries,
        matcher=None,  # Will be set during training
        weight_dict={},  # Will be set during training
        eos_coef=conf.oed_eos_coef,
        losses=["obj_labels", "relation_labels", "sub_obj_boxes"],
        conf=conf,
    )

    # Build postprocessors
    postprocessors = {"oed": PostProcessOED(conf)}

    return model, criterion, postprocessors
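# Illustrative sketch (not executed here): typical wiring of the builder, assuming a
# configuration object and dataset as described above. The `conf`, `dataset`, and `entry`
# names and the surrounding loop are assumptions for illustration, not part of this module.
#
#     model, criterion, postprocessors = build_oed_multi(conf, dataset)
#     out = model(entry)                       # entry comes from the object detector
#     # criterion and postprocessors["oed"] are wired up by the training/evaluation loop,
#     # which also supplies the matcher and weight_dict left as placeholders above.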