Core Library
The core library contains the main functionality: configuration, models, detectors, evaluation, and training components.
Configuration
Unified configuration interface for M3SGG.
This module provides a unified interface that can work with both legacy and modern configuration systems, allowing for gradual migration.
- author:
M3SGG Team
- version:
0.1.0
- class m3sgg.core.config.unified.UnifiedConfig(config_path: str | None = None, model_type: str | None = None, use_modern: bool = True, cli_args: List[str] | None = None, overrides: Dict[str, Any] | None = None)
Bases:
object
Unified configuration interface supporting both legacy and modern systems.
The class works with both the legacy argparse-based configuration system and the modern OmegaConf-based system. It detects which system to use from the supplied configuration and exposes the same interface for either one.
- Parameters:
config_path (Optional[str]) – Path to configuration file
model_type (Optional[str]) – Model type for structured configuration
use_modern (bool) – Force use of modern configuration system
cli_args (Optional[List[str]]) – Command-line arguments
overrides (Optional[Dict[str, Any]]) – Configuration overrides
- __init__(config_path: str | None = None, model_type: str | None = None, use_modern: bool = True, cli_args: List[str] | None = None, overrides: Dict[str, Any] | None = None)
Initialize the unified configuration.
- Parameters:
config_path (Optional[str]) – Path to configuration file
model_type (Optional[str]) – Model type for structured configuration
use_modern (bool) – Force use of modern configuration system
cli_args (Optional[List[str]]) – Command-line arguments
overrides (Optional[Dict[str, Any]]) – Configuration overrides
- get(key: str, default: Any | None = None) → Any
Get a configuration value by key.
- Parameters:
key (str) – Configuration key (supports dot notation for modern)
default (Any) – Default value if key not found
- Returns:
Configuration value
- Return type:
Any
- set(key: str, value: Any)
Set a configuration value by key.
- Parameters:
key (str) – Configuration key
value (Any) – Value to set
- save(path: str)
Save configuration to a file.
- Parameters:
path (str) – Path to save the configuration
- __getattr__(name: str) → Any
Allow attribute-style access to configuration values.
- Parameters:
name (str) – Attribute name
- Returns:
Configuration value
- Return type:
Any
- __getitem__(key: str) → Any
Allow dictionary-style access to configuration values.
- Parameters:
key (str) – Configuration key
- Returns:
Configuration value
- Return type:
Any
- __setitem__(key: str, value: Any)
Allow dictionary-style setting of configuration values.
- Parameters:
key (str) – Configuration key
value (Any) – Value to set
- __repr__() → str
String representation of the configuration.
- Returns:
String representation
- Return type:
str
- __str__() → str
String representation of the configuration.
- Returns:
String representation
- Return type:
str
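- m3sgg.core.config.unified.create_config(model_type: str, config_path: str | None = None, **overrides) → UnifiedConfig
Create a configuration for a specific model type.
- Parameters:
model_type (str) – Model type for structured configuration
config_path (Optional[str]) – Path to configuration file
overrides (Any) – Configuration overrides
- Returns:
Unified configuration instance
- Return type:
UnifiedConfig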
- m3sgg.core.config.unified.load_config_from_file(config_path: str) → UnifiedConfig
Load configuration from a file.
- Parameters:
config_path (str) – Path to configuration file
- Returns:
Unified configuration instance
- Return type:
UnifiedConfig
- m3sgg.core.config.unified.Config
alias of
UnifiedConfig
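The access styles documented for UnifiedConfig (get with dot notation, attribute access, and dictionary-style access) can be sketched with a small stand-in class. MiniConfig below is a hypothetical illustration built on a plain dict; it is not the M3SGG implementation, and the real class may resolve keys differently.

```python
from typing import Any, Dict, Optional


class MiniConfig:
    """Hypothetical stand-in mimicking the access styles documented for UnifiedConfig."""

    def __init__(self, values: Optional[Dict[str, Any]] = None):
        # Write through __dict__ so __getattr__ is only consulted for config keys.
        self.__dict__["_values"] = dict(values or {})

    def get(self, key: str, default: Any = None) -> Any:
        # Walk nested dicts one dot-separated segment at a time.
        node: Any = self._values
        for part in key.split("."):
            if not isinstance(node, dict) or part not in node:
                return default
            node = node[part]
        return node

    def set(self, key: str, value: Any) -> None:
        # Create intermediate dicts as needed, then assign the leaf.
        node = self._values
        *parents, leaf = key.split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value

    def __getattr__(self, name: str) -> Any:
        # Called only when normal attribute lookup fails.
        values = self.__dict__.get("_values", {})
        try:
            return values[name]
        except KeyError:
            raise AttributeError(name) from None

    def __getitem__(self, key: str) -> Any:
        return self._values[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self._values[key] = value


cfg = MiniConfig({"lr": 1e-5, "training": {"batch_size": 1}})
cfg.set("training.batch_size", 4)   # dot notation
cfg["mode"] = "sgdet"               # dictionary-style setting
```

After these calls, cfg.lr, cfg["mode"], and cfg.get("training.batch_size") all read from the same underlying mapping, which is the property a unified interface needs.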
Modern configuration system using OmegaConf.
This module provides a modern configuration management system using OmegaConf with support for YAML files, command-line overrides, interpolation, and structured validation.
- author:
M3SGG Team
- version:
0.1.0
- class m3sgg.core.config.modern.ConfigManager(config_path: str | None = None, model_type: str | None = None, cli_args: List[str] | None = None, overrides: Dict[str, Any] | None = None)
Bases:
object
Modern configuration manager using OmegaConf.
This class provides a unified interface for configuration management that supports YAML files, command-line arguments, interpolation, and structured validation.
- Parameters:
config_path (Optional[str]) – Path to configuration file
model_type (Optional[str]) – Model type for structured configuration
cli_args (Optional[List[str]]) – Command-line arguments
overrides (Optional[Dict[str, Any]]) – Configuration overrides
- __init__(config_path: str | None = None, model_type: str | None = None, cli_args: List[str] | None = None, overrides: Dict[str, Any] | None = None)
Initialize the configuration manager.
- get(key: str, default: Any | None = None) → Any
Get a configuration value by key.
- Parameters:
key (str) – Configuration key (supports dot notation)
default (Any) – Default value if key not found
- Returns:
Configuration value
- Return type:
Any
- set(key: str, value: Any)
Set a configuration value by key.
- Parameters:
key (str) – Configuration key (supports dot notation)
value (Any) – Value to set
- save(path: str)
Save configuration to a YAML file.
- Parameters:
path (str) – Path to save the configuration
- __getattr__(name: str) → Any
Allow attribute-style access to configuration values.
- Parameters:
name (str) – Attribute name
- Returns:
Configuration value
- Return type:
Any
- __getitem__(key: str) → Any
Allow dictionary-style access to configuration values.
- Parameters:
key (str) – Configuration key
- Returns:
Configuration value
- Return type:
Any
- __setitem__(key: str, value: Any)
Allow dictionary-style setting of configuration values.
- Parameters:
key (str) – Configuration key
value (Any) – Value to set
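- m3sgg.core.config.modern.create_config(model_type: str, config_path: str | None = None, **overrides) → ConfigManager
Create a configuration for a specific model type.
- Parameters:
model_type (str) – Model type for structured configuration
config_path (Optional[str]) – Path to configuration file
overrides (Any) – Configuration overrides
- Returns:
Configuration manager instance
- Return type:
ConfigManager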
- m3sgg.core.config.modern.load_config_from_file(config_path: str) → ConfigManager
Load configuration from a YAML file.
- Parameters:
config_path (str) – Path to configuration file
- Returns:
Configuration manager instance
- Return type:
ConfigManager
- m3sgg.core.config.modern.merge_configs(*configs: ConfigManager) → ConfigManager
Merge multiple configuration managers.
- Parameters:
configs (ConfigManager) – Configuration managers to merge
- Returns:
Merged configuration manager
- Return type:
ConfigManager
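Layering configuration sources (defaults, then a file, then overrides) amounts to a recursive merge in which later sources win. The sketch below shows that idea on plain dictionaries. deep_merge is a hypothetical helper, and the left-to-right precedence is an assumption, since the documentation does not state how merge_configs resolves conflicting keys.

```python
from typing import Any, Dict


def deep_merge(*configs: Dict[str, Any]) -> Dict[str, Any]:
    """Merge dictionaries left to right; later values override earlier ones (assumed order)."""
    merged: Dict[str, Any] = {}
    for config in configs:
        for key, value in config.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                # Both sides are mappings: merge them key by key.
                merged[key] = deep_merge(merged[key], value)
            elif isinstance(value, dict):
                # Copy nested mappings so the inputs are never mutated.
                merged[key] = deep_merge(value)
            else:
                merged[key] = value
    return merged


defaults = {"lr": 1e-5, "model": {"hidden_dim": 256, "num_heads": 8}}
from_file = {"model": {"hidden_dim": 512}}
overrides = {"lr": 2e-5}

merged = deep_merge(defaults, from_file, overrides)
```

Here the file changes only model.hidden_dim and the override changes only lr, so num_heads keeps its default value.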
Legacy configuration system for backward compatibility.
This module provides the original argparse-based configuration system for backward compatibility with existing code that depends on the old Config class interface.
- author:
M3SGG Team
- version:
0.1.0
- class m3sgg.core.config.legacy.LegacyConfig(config_path: str | None = None, overrides: Dict[str, Any] | None = None)
Bases:
Config
Legacy configuration class for backward compatibility.
This class wraps the original Config class to provide backward compatibility while allowing for future migration to the modern configuration system.
- Parameters:
config_path (Optional[str]) – Path to configuration file
overrides (Optional[Dict[str, Any]]) – Configuration overrides
- __init__(config_path: str | None = None, overrides: Dict[str, Any] | None = None)
Initialize the legacy configuration.
- to_dict() → Dict[str, Any]
Convert configuration to dictionary.
- Returns:
Configuration as dictionary
- Return type:
Dict[str, Any]
- update(**kwargs)
Update configuration values.
- Parameters:
kwargs (Any) – Configuration updates
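As a rough illustration of an argparse-based configuration object with to_dict and update methods, consider the sketch below. ArgparseConfig and its two flags are hypothetical and chosen only for the example; the actual LegacyConfig wraps the original Config class, whose arguments are not listed here.

```python
import argparse
from typing import Any, Dict, List, Optional


class ArgparseConfig:
    """Hypothetical argparse-backed config; not the actual LegacyConfig."""

    def __init__(self, argv: Optional[List[str]] = None):
        parser = argparse.ArgumentParser()
        parser.add_argument("--mode", default="predcls")
        parser.add_argument("--lr", type=float, default=1e-5)
        # Parse an explicit list (empty by default) instead of sys.argv,
        # then copy the parsed values onto the instance as plain attributes.
        for key, value in vars(parser.parse_args(argv or [])).items():
            setattr(self, key, value)

    def to_dict(self) -> Dict[str, Any]:
        # Snapshot of the current attribute values.
        return dict(vars(self))

    def update(self, **kwargs: Any) -> None:
        for key, value in kwargs.items():
            setattr(self, key, value)


legacy = ArgparseConfig(["--mode", "sgcls"])
legacy.update(lr=2e-5)
```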
Structured Configuration
Base structured configuration classes for M3SGG models.
This module provides the base configuration classes and common structures used across all model types in the M3SGG framework.
- author:
M3SGG Team
- version:
0.1.0
- class m3sgg.core.config.structured.base.BaseConfig(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'sttran', use_matcher: bool = False, eval: bool = False)[source]
Bases:
object
Base configuration class with common parameters for all models.
This class defines the common configuration parameters that are shared across all model types in the M3SGG framework.
- Parameters:
mode (str) – Training mode (predcls/sgcls/sgdet)
save_path (str) – Path to save model outputs
model_path (str) – Path to model weights
dataset (str) – Dataset name (action_genome/EASG)
data_path (str) – Path to dataset
datasize (str) – Dataset size (mini/large)
fraction (int) – Divisor selecting the fraction of the dataset to use (1 = all, 2 = half, etc.)
ckpt (Optional[str]) – Checkpoint path
optimizer (str) – Optimizer type (adamw/adam/sgd)
lr (float) – Learning rate
nepoch (float) – Number of epochs
niter (Optional[int]) – Number of iterations for iterative training
eval_frequency (int) – Evaluation frequency for iterative training
enc_layer (int) – Number of encoder layers
dec_layer (int) – Number of decoder layers
bce_loss (bool) – Use BCE loss instead of multi-label margin loss
device (str) – Torch device string (e.g., cuda:0, cpu)
seed (int) – Global random seed
num_workers (int) – Number of DataLoader workers
model_type (str) – Model type identifier
use_matcher (bool) – Use Hungarian matcher (for DSG-DETR)
eval (bool) – Evaluation mode
- __init__(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'sttran', use_matcher: bool = False, eval: bool = False) None
- class m3sgg.core.config.structured.base.TrainingConfig(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 0, early_stopping_patience: int = 10, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1)
Bases:
object
Training-specific configuration parameters.
- Parameters:
batch_size (int) – Training batch size
val_batch_size (int) – Validation batch size
max_grad_norm (float) – Maximum gradient norm for clipping
warmup_epochs (int) – Number of warmup epochs
early_stopping_patience (int) – Early stopping patience
save_frequency (int) – Model saving frequency (epochs)
log_frequency (int) – Logging frequency (iterations)
val_frequency (int) – Validation frequency (epochs)
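To make the role of max_grad_norm concrete, the following sketch rescales a list of gradient values so that their global L2 norm does not exceed the limit. It uses plain Python floats for illustration; clip_by_global_norm is a hypothetical helper, and how M3SGG actually applies clipping is not specified in this reference.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[float], max_grad_norm: float = 1.0) -> List[float]:
    """Rescale gradients so their L2 norm does not exceed `max_grad_norm`."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_grad_norm:
        # Already within the limit: return an unchanged copy.
        return list(grads)
    scale = max_grad_norm / norm
    return [g * scale for g in grads]


clipped = clip_by_global_norm([3.0, 4.0], max_grad_norm=1.0)
```

The direction of the gradient is preserved; only its magnitude is reduced.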
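- class m3sgg.core.config.structured.base.DataConfig(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False)
Bases:
object
Data-specific configuration parameters.
- Parameters:
cache_dir (str) – Directory for cached data
pin_memory (bool) – Use pinned memory in the DataLoader
shuffle (bool) – Shuffle the data
drop_last (bool) – Drop the last incomplete batch
prefetch_factor (int) – Number of batches prefetched per DataLoader worker
persistent_workers (bool) – Keep DataLoader workers alive between epochs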
- class m3sgg.core.config.structured.base.LoggingConfig(log_level: str = 'INFO', log_file: str | None = None, tensorboard_dir: str | None = None, wandb_project: str | None = None, wandb_entity: str | None = None, log_gradients: bool = False, log_weights: bool = False)
Bases:
object
Logging configuration parameters.
- Parameters:
log_level (str) – Logging level (DEBUG, INFO, WARNING, ERROR)
log_file (Optional[str]) – Log file path
tensorboard_dir (Optional[str]) – TensorBoard log directory
wandb_project (Optional[str]) – Weights & Biases project name
wandb_entity (Optional[str]) – Weights & Biases entity name
log_gradients (bool) – Log gradient norms
log_weights (bool) – Log weight histograms
- class m3sgg.core.config.structured.base.CheckpointConfig(save_dir: str = 'checkpoints', save_best: bool = True, save_last: bool = True, save_frequency: int = 1, max_checkpoints: int = 5, checkpoint_metric: str = 'recall@20', checkpoint_mode: str = 'max')
Bases:
object
Checkpoint configuration parameters.
- Parameters:
save_dir (str) – Directory to save checkpoints
save_best (bool) – Save best model only
save_last (bool) – Save last model
save_frequency (int) – Save frequency (epochs)
max_checkpoints (int) – Maximum number of checkpoints to keep
checkpoint_metric (str) – Metric to use for best model selection
checkpoint_mode (str) – Mode for metric comparison (min/max)
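The interaction of checkpoint_mode and max_checkpoints can be sketched as follows. Both helpers are hypothetical, and the pruning policy (keep the most recent entries) is an assumption; the reference does not say which checkpoints are discarded once the limit is reached.

```python
from typing import List, Optional, Tuple


def is_better(value: float, best: Optional[float], mode: str = "max") -> bool:
    """Return True if `value` improves on `best` under the given comparison mode."""
    if mode not in ("min", "max"):
        raise ValueError(f"unsupported checkpoint_mode: {mode!r}")
    if best is None:
        return True
    return value > best if mode == "max" else value < best


def prune(checkpoints: List[Tuple[int, float]], max_checkpoints: int) -> List[Tuple[int, float]]:
    """Keep only the most recent `max_checkpoints` (epoch, metric) entries (assumed policy)."""
    return checkpoints[-max_checkpoints:]


history: List[Tuple[int, float]] = []
best: Optional[float] = None
for epoch, recall in enumerate([0.31, 0.35, 0.34, 0.36, 0.33, 0.37, 0.36]):
    history = prune(history + [(epoch, recall)], max_checkpoints=5)
    if is_better(recall, best, mode="max"):
        best = recall
```

With a metric such as recall@20, mode 'max' is the natural choice; a loss-based metric would use 'min'.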
- class m3sgg.core.config.structured.base.EvaluationConfig(iou_threshold: float = 0.5, constraint: str = 'with', save_predictions: bool = True, predictions_dir: str = 'predictions', eval_metrics: List[str] = <factory>, eval_frequency: int = 1)
Bases:
object
Evaluation configuration parameters.
- Parameters:
iou_threshold (float) – IoU threshold for evaluation
constraint (str) – Constraint type for evaluation
save_predictions (bool) – Save predictions to file
predictions_dir (str) – Directory to save predictions
eval_metrics (List[str]) – List of metrics to compute
eval_frequency (int) – Evaluation frequency (epochs)
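For readers unfamiliar with iou_threshold, the sketch below computes intersection over union for two axis-aligned boxes and applies the threshold. The (x1, y1, x2, y2) box layout and the use of >= for the comparison are assumptions made for illustration, not details taken from the M3SGG evaluator.

```python
from typing import Sequence


def box_iou(a: Sequence[float], b: Sequence[float]) -> float:
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def is_match(pred: Sequence[float], gt: Sequence[float], iou_threshold: float = 0.5) -> bool:
    """A prediction counts as a match when its IoU with the ground truth reaches the threshold."""
    return box_iou(pred, gt) >= iou_threshold
```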
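- class m3sgg.core.config.structured.base.ModelConfig(hidden_dim: int = 256, num_heads: int = 8, num_layers: int = 6, dropout: float = 0.1, activation: str = 'relu', norm_type: str = 'layer_norm', use_bias: bool = True)
Bases:
object
Model-specific configuration parameters.
- Parameters:
hidden_dim (int) – Hidden dimension size
num_heads (int) – Number of attention heads
num_layers (int) – Number of transformer layers
dropout (float) – Dropout rate
activation (str) – Activation function
norm_type (str) – Normalization type
use_bias (bool) – Use bias in linear layers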
- class m3sgg.core.config.structured.base.LossConfig(loss_weights: Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: Dict[str, float] | None = None)
Bases:
object
Loss function configuration parameters.
- Parameters:
loss_weights (Dict[str, float]) – Weights for different loss components
label_smoothing (float) – Label smoothing factor
focal_alpha (float) – Focal loss alpha parameter
focal_gamma (float) – Focal loss gamma parameter
class_weights (Optional[Dict[str, float]]) – Class weights for imbalanced datasets
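The focal_alpha, focal_gamma, and label_smoothing parameters correspond to widely used formulas. The sketch below shows the common binary focal loss, FL = -alpha_t * (1 - p_t)^gamma * log(p_t), and uniform label smoothing, using plain floats. Both functions are illustrative; whether M3SGG uses exactly these formulations is an assumption.

```python
import math
from typing import List, Sequence


def binary_focal_loss(p: float, target: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Focal loss for one binary prediction, where `p` is the predicted positive-class probability."""
    p_t = p if target == 1 else 1.0 - p
    alpha_t = alpha if target == 1 else 1.0 - alpha
    # (1 - p_t) ** gamma down-weights examples the model already classifies well.
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def smooth_labels(one_hot: Sequence[float], smoothing: float = 0.0) -> List[float]:
    """Uniform label smoothing over the classes of a one-hot vector."""
    n = len(one_hot)
    return [(1.0 - smoothing) * y + smoothing / n for y in one_hot]


easy = binary_focal_loss(0.9, target=1)   # confident, correct prediction
hard = binary_focal_loss(0.1, target=1)   # confident, wrong prediction
```

With gamma set to 0 and alpha_t set to 1, the expression reduces to ordinary binary cross-entropy.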
- m3sgg.core.config.structured.base.get_config_class(model_type: str)
Get the appropriate configuration class for a model type.
- Parameters:
model_type (str) – The model type identifier
- Returns:
The configuration class for the model type
- Return type:
- Raises:
ValueError – If the model type is not supported
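A lookup that maps a model type string to a configuration class and raises ValueError for unknown types is often written as a small registry. The sketch below shows one such arrangement with hypothetical stand-in classes and a hypothetical function name; the real get_config_class may be implemented differently.

```python
from dataclasses import dataclass
from typing import Dict, Type


@dataclass
class BaseCfg:
    model_type: str = "base"


@dataclass
class STTRANCfg(BaseCfg):
    model_type: str = "sttran"


@dataclass
class TempuraCfg(BaseCfg):
    model_type: str = "tempura"


# Hypothetical registry; the real mapping may be built differently.
_REGISTRY: Dict[str, Type[BaseCfg]] = {"sttran": STTRANCfg, "tempura": TempuraCfg}


def lookup_config_class(model_type: str) -> Type[BaseCfg]:
    """Return the config class registered for `model_type`, or raise ValueError."""
    try:
        return _REGISTRY[model_type]
    except KeyError:
        supported = ", ".join(sorted(_REGISTRY))
        raise ValueError(
            f"Unsupported model type {model_type!r}; expected one of: {supported}"
        ) from None


config = lookup_config_class("tempura")()
```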
STTRAN model configuration classes.
This module provides structured configuration classes specifically for the STTRAN (Spatial-Temporal Transformer) model.
- author:
M3SGG Team
- version:
0.1.0
- class m3sgg.core.config.structured.sttran.STTRANConfig(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 2e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'sttran', use_matcher: bool = False, eval: bool = False, hidden_dim: int = 256, num_heads: int = 8, num_layers: int = 6, dropout: float = 0.1, use_spatial_encoding: bool = True, use_temporal_encoding: bool = True, max_seq_len: int = 1000, spatial_dim: int = 64, temporal_dim: int = 64, obj_feat_dim: int = 2048, rel_feat_dim: int = 256, num_obj_classes: int = 35, num_rel_classes: int = 132, use_bbox_encoding: bool = True, bbox_encoding_dim: int = 128, use_attention_weights: bool = False, attention_dropout: float = 0.1, ffn_dim: int = 1024, activation: str = 'relu', norm_type: str = 'layer_norm', use_bias: bool = True, gradient_checkpointing: bool = False, use_memory_efficient_attention: bool = False)[source]
Bases:
BaseConfig
Configuration for STTRAN model.
STTRAN (Spatial-Temporal Transformer) is a transformer-based model for video scene graph generation that processes spatial and temporal information through attention mechanisms.
- Parameters:
model_type (str) – Model type identifier
lr (float) – Learning rate (overridden for STTRAN)
hidden_dim (int) – Hidden dimension size
num_heads (int) – Number of attention heads
num_layers (int) – Number of transformer layers
dropout (float) – Dropout rate
use_spatial_encoding (bool) – Use spatial position encoding
use_temporal_encoding (bool) – Use temporal position encoding
max_seq_len (int) – Maximum sequence length
spatial_dim (int) – Spatial dimension size
temporal_dim (int) – Temporal dimension size
obj_feat_dim (int) – Object feature dimension
rel_feat_dim (int) – Relation feature dimension
num_obj_classes (int) – Number of object classes
num_rel_classes (int) – Number of relation classes
use_bbox_encoding (bool) – Use bounding box encoding
bbox_encoding_dim (int) – Bounding box encoding dimension
use_attention_weights (bool) – Use attention weights for visualization
attention_dropout (float) – Attention dropout rate
ffn_dim (int) – Feed-forward network dimension
activation (str) – Activation function
norm_type (str) – Normalization type
use_bias (bool) – Use bias in linear layers
gradient_checkpointing (bool) – Use gradient checkpointing
use_memory_efficient_attention (bool) – Use memory efficient attention
- __init__(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 2e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'sttran', use_matcher: bool = False, eval: bool = False, hidden_dim: int = 256, num_heads: int = 8, num_layers: int = 6, dropout: float = 0.1, use_spatial_encoding: bool = True, use_temporal_encoding: bool = True, max_seq_len: int = 1000, spatial_dim: int = 64, temporal_dim: int = 64, obj_feat_dim: int = 2048, rel_feat_dim: int = 256, num_obj_classes: int = 35, num_rel_classes: int = 132, use_bbox_encoding: bool = True, bbox_encoding_dim: int = 128, use_attention_weights: bool = False, attention_dropout: float = 0.1, ffn_dim: int = 1024, activation: str = 'relu', norm_type: str = 'layer_norm', use_bias: bool = True, gradient_checkpointing: bool = False, use_memory_efficient_attention: bool = False) None
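The STTRANConfig signature repeats every BaseConfig field, changes one default (lr is 2e-05 here versus 1e-05 in the base class), and appends model-specific fields. One way to express that pattern is dataclass inheritance, sketched below with trimmed hypothetical classes that carry only a few of the documented fields. Whether the real classes are dataclasses is an assumption.

```python
from dataclasses import asdict, dataclass, replace


@dataclass
class MiniBaseConfig:
    # Subset of the documented BaseConfig fields and defaults.
    mode: str = "predcls"
    lr: float = 1e-5
    model_type: str = "sttran"


@dataclass
class MiniSTTRANConfig(MiniBaseConfig):
    # Re-declaring a field overrides its default in the subclass.
    lr: float = 2e-5
    hidden_dim: int = 256
    num_heads: int = 8


config = MiniSTTRANConfig(mode="sgdet")
tuned = replace(config, num_heads=4)   # returns a modified copy
as_dict = asdict(tuned)                # plain dict, e.g. for serialization
```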
- class m3sgg.core.config.structured.sttran.STTRANTrainingConfig(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 2, early_stopping_patience: int = 10, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, scheduler_type: str = 'reduce_on_plateau', scheduler_patience: int = 3, scheduler_factor: float = 0.5, weight_decay: float = 0.0001, clip_grad_norm: float = 1.0, use_amp: bool = False, accumulation_steps: int = 1)[source]
Bases:
TrainingConfig
STTRAN-specific training configuration.
- Parameters:
warmup_epochs (int) – Number of warmup epochs (STTRAN specific)
scheduler_type (str) – Learning rate scheduler type
scheduler_patience (int) – Scheduler patience
scheduler_factor (float) – Scheduler reduction factor
weight_decay (float) – Weight decay for regularization
clip_grad_norm (float) – Gradient clipping norm
use_amp (bool) – Use automatic mixed precision
accumulation_steps (int) – Gradient accumulation steps
- __init__(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 2, early_stopping_patience: int = 10, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, scheduler_type: str = 'reduce_on_plateau', scheduler_patience: int = 3, scheduler_factor: float = 0.5, weight_decay: float = 0.0001, clip_grad_norm: float = 1.0, use_amp: bool = False, accumulation_steps: int = 1) None
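The default scheduler_type of 'reduce_on_plateau' together with scheduler_patience and scheduler_factor suggests a rule that multiplies the learning rate by the factor once the monitored value has stopped improving for more than the configured number of epochs. The class below is a simplified, hypothetical sketch of that rule in plain Python; it is not the scheduler M3SGG uses, and the exact counting convention is an assumption.

```python
class PlateauScheduler:
    """Hypothetical reduce-on-plateau rule for a validation loss (lower is better)."""

    def __init__(self, lr: float, patience: int = 3, factor: float = 0.5):
        self.lr = lr
        self.patience = patience
        self.factor = factor
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> float:
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # Reduce only after more than `patience` epochs without improvement.
        if self.bad_epochs > self.patience:
            self.lr *= self.factor
            self.bad_epochs = 0
        return self.lr


scheduler = PlateauScheduler(lr=2e-5, patience=3, factor=0.5)
lrs = [scheduler.step(loss) for loss in [1.0, 0.9, 0.95, 0.93, 0.92, 0.91, 0.8]]
```

In this run the loss fails to beat 0.9 for four consecutive epochs, so the learning rate is halved on the sixth step.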
- class m3sgg.core.config.structured.sttran.STTRANLossConfig(loss_weights: Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: Dict[str, float] | None = None, obj_loss_weight: float = 1.0, rel_loss_weight: float = 2.0, bbox_loss_weight: float = 2.5, giou_loss_weight: float = 1.0, use_focal_loss: bool = False, use_class_weights: bool = False, obj_class_weights: dict | None = None, rel_class_weights: dict | None = None)
Bases:
LossConfig
STTRAN-specific loss configuration.
- Parameters:
obj_loss_weight (float) – Weight for object classification loss
rel_loss_weight (float) – Weight for relation classification loss
bbox_loss_weight (float) – Weight for bounding box regression loss
giou_loss_weight (float) – Weight for GIoU loss
use_focal_loss (bool) – Use focal loss for classification
focal_alpha (float) – Focal loss alpha parameter
focal_gamma (float) – Focal loss gamma parameter
label_smoothing (float) – Label smoothing factor
use_class_weights (bool) – Use class weights for imbalanced data
obj_class_weights (Optional[dict]) – Object class weights
rel_class_weights (Optional[dict]) – Relation class weights
- __init__(loss_weights: Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: Dict[str, float] | None = None, obj_loss_weight: float = 1.0, rel_loss_weight: float = 2.0, bbox_loss_weight: float = 2.5, giou_loss_weight: float = 1.0, use_focal_loss: bool = False, use_class_weights: bool = False, obj_class_weights: dict | None = None, rel_class_weights: dict | None = None) → None
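Per-term loss weights such as obj_loss_weight and rel_loss_weight are typically combined as a weighted sum. The sketch below uses the default weights from the STTRANLossConfig signature; the loss values, the short term names, and the total_loss helper are made up for illustration.

```python
from typing import Dict


def total_loss(losses: Dict[str, float], weights: Dict[str, float]) -> float:
    """Weighted sum of named loss terms; terms without a weight default to 1.0."""
    return sum(weights.get(name, 1.0) * value for name, value in losses.items())


# Default weights from the STTRANLossConfig signature.
weights = {"obj": 1.0, "rel": 2.0, "bbox": 2.5, "giou": 1.0}
# Illustrative per-term loss values (not measured results).
losses = {"obj": 0.4, "rel": 0.3, "bbox": 0.2, "giou": 0.1}

combined = total_loss(losses, weights)
```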
- class m3sgg.core.config.structured.sttran.STTRANDataConfig(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, use_temporal_sampling: bool = True, temporal_stride: int = 1, use_augmentation: bool = False, augmentation_prob: float = 0.5, use_negative_sampling: bool = True, negative_ratio: float = 0.5)[source]
Bases:
DataConfig
STTRAN-specific data configuration.
- Parameters:
max_objects (int) – Maximum number of objects per frame
max_relations (int) – Maximum number of relations per frame
max_frames (int) – Maximum number of frames per video
use_temporal_sampling (bool) – Use temporal sampling
temporal_stride (int) – Temporal stride for sampling
use_augmentation (bool) – Use data augmentation
augmentation_prob (float) – Probability of applying augmentation
use_negative_sampling (bool) – Use negative sampling for relations
negative_ratio (float) – Ratio of negative samples
- __init__(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, use_temporal_sampling: bool = True, temporal_stride: int = 1, use_augmentation: bool = False, augmentation_prob: float = 0.5, use_negative_sampling: bool = True, negative_ratio: float = 0.5) None
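One plausible reading of temporal_stride and max_frames is strided frame selection with a cap on the number of frames. The helper below is a hypothetical sketch of that reading; the actual sampling logic in M3SGG may differ, for example in where sampling starts or how long videos are truncated.

```python
from typing import List


def sample_frames(num_frames: int, temporal_stride: int = 1, max_frames: int = 30) -> List[int]:
    """Pick every `temporal_stride`-th frame index, capped at `max_frames` indices."""
    if temporal_stride < 1:
        raise ValueError("temporal_stride must be >= 1")
    return list(range(0, num_frames, temporal_stride))[:max_frames]


short_clip = sample_frames(10, temporal_stride=3)
long_clip = sample_frames(100, temporal_stride=2, max_frames=30)
```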
Tempura model configuration classes.
This module provides structured configuration classes specifically for the Tempura model, which includes memory mechanisms and GMM heads.
- author:
M3SGG Team
- version:
0.1.0
- class m3sgg.core.config.structured.tempura.TempuraConfig(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'tempura', use_matcher: bool = False, eval: bool = False, obj_head: str = 'gmm', rel_head: str = 'gmm', K: int = 4, gmm_components: int = 4, gmm_covariance_type: str = 'full', gmm_reg_covar: float = 1e-06, rel_mem_compute: str | None = None, obj_mem_compute: bool = False, take_obj_mem_feat: bool = False, obj_mem_weight_type: str = 'simple', rel_mem_weight_type: str = 'simple', mem_feat_selection: str = 'manual', mem_fusion: str = 'early', mem_feat_lambda: float | None = None, mem_size: int = 1000, mem_dim: int = 256, mem_update_rate: float = 0.1, mem_temperature: float = 1.0, use_memory_attention: bool = True, memory_dropout: float = 0.1, pseudo_thresh: int = 7, obj_unc: bool = False, rel_unc: bool = False, obj_loss_weighting: str | None = None, rel_loss_weighting: str | None = None, mlm: bool = False, eos_coef: float = 1, obj_con_loss: str | None = None, lambda_con: float = 1, tracking: bool = True, use_prior: bool = False, prior_weight: float = 0.1)[source]
Bases:
BaseConfig
Configuration for Tempura model.
Tempura is a memory-enhanced model for video scene graph generation that uses Gaussian Mixture Model (GMM) heads and memory mechanisms.
- Parameters:
model_type (str) – Model type identifier
obj_head (str) – Object classification head type
rel_head (str) – Relation classification head type
K (int) – Number of mixture models
rel_mem_compute (Optional[str]) – Relation memory computation type
obj_mem_compute (bool) – Object memory computation
take_obj_mem_feat (bool) – Take object memory features
obj_mem_weight_type (str) – Object memory weight type
rel_mem_weight_type (str) – Relation memory weight type
mem_feat_selection (str) – Memory feature selection method
mem_fusion (str) – Memory fusion method
mem_feat_lambda (Optional[float]) – Memory feature lambda
pseudo_thresh (int) – Pseudo label threshold
obj_unc (bool) – Object uncertainty
rel_unc (bool) – Relation uncertainty
obj_loss_weighting (Optional[str]) – Object loss weighting
rel_loss_weighting (Optional[str]) – Relation loss weighting
mlm (bool) – Masked language modeling
eos_coef (float) – End-of-sequence coefficient
obj_con_loss (Optional[str]) – Object consistency loss
lambda_con (float) – Consistency loss coefficient
tracking (bool) – Enable tracking
mem_size (int) – Memory size
mem_dim (int) – Memory dimension
mem_update_rate (float) – Memory update rate
mem_temperature (float) – Memory temperature
use_memory_attention (bool) – Use memory attention
memory_dropout (float) – Memory dropout rate
gmm_components (int) – Number of GMM components
gmm_covariance_type (str) – GMM covariance type
gmm_reg_covar (float) – GMM regularization covariance
use_prior (bool) – Use prior knowledge
prior_weight (float) – Prior weight
- __init__(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'tempura', use_matcher: bool = False, eval: bool = False, obj_head: str = 'gmm', rel_head: str = 'gmm', K: int = 4, gmm_components: int = 4, gmm_covariance_type: str = 'full', gmm_reg_covar: float = 1e-06, rel_mem_compute: str | None = None, obj_mem_compute: bool = False, take_obj_mem_feat: bool = False, obj_mem_weight_type: str = 'simple', rel_mem_weight_type: str = 'simple', mem_feat_selection: str = 'manual', mem_fusion: str = 'early', mem_feat_lambda: float | None = None, mem_size: int = 1000, mem_dim: int = 256, mem_update_rate: float = 0.1, mem_temperature: float = 1.0, use_memory_attention: bool = True, memory_dropout: float = 0.1, pseudo_thresh: int = 7, obj_unc: bool = False, rel_unc: bool = False, obj_loss_weighting: str | None = None, rel_loss_weighting: str | None = None, mlm: bool = False, eos_coef: float = 1, obj_con_loss: str | None = None, lambda_con: float = 1, tracking: bool = True, use_prior: bool = False, prior_weight: float = 0.1) None
- class m3sgg.core.config.structured.tempura.TempuraTrainingConfig(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 0, early_stopping_patience: int = 10, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, memory_warmup_epochs: int = 5, memory_lr: float = 0.0001, gmm_lr: float = 0.001, use_memory_scheduler: bool = True, memory_decay: float = 0.99, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_memory_regularization: bool = True, memory_reg_weight: float = 0.01)[source]
Bases:
TrainingConfig
Tempura-specific training configuration.
- Parameters:
memory_warmup_epochs (int) – Number of epochs for memory warmup
memory_lr (float) – Learning rate for memory parameters
gmm_lr (float) – Learning rate for GMM parameters
use_memory_scheduler (bool) – Use separate scheduler for memory
memory_decay (float) – Memory decay rate
use_curriculum_learning (bool) – Use curriculum learning
curriculum_epochs (int) – Number of epochs for curriculum
use_memory_regularization (bool) – Use memory regularization
memory_reg_weight (float) – Memory regularization weight
- __init__(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 0, early_stopping_patience: int = 10, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, memory_warmup_epochs: int = 5, memory_lr: float = 0.0001, gmm_lr: float = 0.001, use_memory_scheduler: bool = True, memory_decay: float = 0.99, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_memory_regularization: bool = True, memory_reg_weight: float = 0.01) None
- class m3sgg.core.config.structured.tempura.TempuraLossConfig(loss_weights: Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: Dict[str, float] | None = None, obj_loss_weight: float = 1.0, rel_loss_weight: float = 1.0, memory_loss_weight: float = 0.1, gmm_loss_weight: float = 0.1, consistency_loss_weight: float = 0.1, uncertainty_loss_weight: float = 0.1, use_memory_loss: bool = True, use_gmm_loss: bool = True, use_consistency_loss: bool = True, use_uncertainty_loss: bool = True, memory_loss_type: str = 'mse', gmm_loss_type: str = 'nll', consistency_loss_type: str = 'mse', uncertainty_loss_type: str = 'kl')
Bases:
LossConfig
Tempura-specific loss configuration.
- Parameters:
obj_loss_weight (float) – Weight for object classification loss
rel_loss_weight (float) – Weight for relation classification loss
memory_loss_weight (float) – Weight for memory loss
gmm_loss_weight (float) – Weight for GMM loss
consistency_loss_weight (float) – Weight for consistency loss
uncertainty_loss_weight (float) – Weight for uncertainty loss
use_memory_loss (bool) – Use memory loss
use_gmm_loss (bool) – Use GMM loss
use_consistency_loss (bool) – Use consistency loss
use_uncertainty_loss (bool) – Use uncertainty loss
memory_loss_type (str) – Type of memory loss
gmm_loss_type (str) – Type of GMM loss
consistency_loss_type (str) – Type of consistency loss
uncertainty_loss_type (str) – Type of uncertainty loss
- __init__(loss_weights: Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: Dict[str, float] | None = None, obj_loss_weight: float = 1.0, rel_loss_weight: float = 1.0, memory_loss_weight: float = 0.1, gmm_loss_weight: float = 0.1, consistency_loss_weight: float = 0.1, uncertainty_loss_weight: float = 0.1, use_memory_loss: bool = True, use_gmm_loss: bool = True, use_consistency_loss: bool = True, use_uncertainty_loss: bool = True, memory_loss_type: str = 'mse', gmm_loss_type: str = 'nll', consistency_loss_type: str = 'mse', uncertainty_loss_type: str = 'kl') → None
- class m3sgg.core.config.structured.tempura.TempuraDataConfig(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, use_temporal_sampling: bool = True, temporal_stride: int = 1, use_memory_sampling: bool = True, memory_sampling_ratio: float = 0.1, use_negative_sampling: bool = True, negative_ratio: float = 0.5, use_pseudo_labels: bool = False, pseudo_label_threshold: float = 0.7, use_uncertainty_sampling: bool = False, uncertainty_threshold: float = 0.5)[source]
Bases:
DataConfig
Tempura-specific data configuration.
- Parameters:
max_objects (int) – Maximum number of objects per frame
max_relations (int) – Maximum number of relations per frame
max_frames (int) – Maximum number of frames per video
use_temporal_sampling (bool) – Use temporal sampling
temporal_stride (int) – Temporal stride for sampling
use_memory_sampling (bool) – Use memory sampling
memory_sampling_ratio (float) – Memory sampling ratio
use_negative_sampling (bool) – Use negative sampling
negative_ratio (float) – Ratio of negative samples
use_pseudo_labels (bool) – Use pseudo labels
pseudo_label_threshold (float) – Pseudo label threshold
use_uncertainty_sampling (bool) – Use uncertainty sampling
uncertainty_threshold (float) – Uncertainty threshold
- __init__(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, use_temporal_sampling: bool = True, temporal_stride: int = 1, use_memory_sampling: bool = True, memory_sampling_ratio: float = 0.1, use_negative_sampling: bool = True, negative_ratio: float = 0.5, use_pseudo_labels: bool = False, pseudo_label_threshold: float = 0.7, use_uncertainty_sampling: bool = False, uncertainty_threshold: float = 0.5) None
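One plausible reading of `temporal_stride` and `max_frames` is "keep every n-th frame, then cap the clip length". The sketch below shows that interpretation; `sample_frame_indices` is a hypothetical helper, and the actual sampling logic in m3sgg may differ.

```python
from typing import List


def sample_frame_indices(num_frames: int, temporal_stride: int = 1,
                         max_frames: int = 30) -> List[int]:
    """Keep every `temporal_stride`-th frame, then cap at `max_frames`."""
    if temporal_stride < 1:
        raise ValueError("temporal_stride must be >= 1")
    return list(range(0, num_frames, temporal_stride))[:max_frames]


short_clip = sample_frame_indices(10, temporal_stride=3)  # stride applied
long_clip = sample_frame_indices(100)                     # cap applied
```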
SceneLLM model configuration classes.
This module provides structured configuration classes specifically for the SceneLLM model, which combines VQ-VAE with language models.
- author:
M3SGG Team
- version:
0.1.0
- class m3sgg.core.config.structured.scenellm.SceneLLMConfig(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'scenellm', use_matcher: bool = False, eval: bool = False, embed_dim: int = 1024, codebook_size: int = 8192, commitment_cost: float = 0.25, vqvae_hidden_dim: int = 512, vqvae_num_layers: int = 4, vqvae_num_resblocks: int = 2, vqvae_dropout: float = 0.1, vqvae_use_attention: bool = True, vqvae_attention_heads: int = 8, llm_name: str = 'google/gemma-2-2b', llm_max_length: int = 512, llm_temperature: float = 0.7, llm_top_p: float = 0.9, llm_top_k: int = 50, llm_repetition_penalty: float = 1.1, llm_do_sample: bool = True, llm_pad_token_id: int = 0, llm_eos_token_id: int = 1, llm_bos_token_id: int = 2, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, ot_step: int = 512, vqvae_epochs: int = 5, stage1_iterations: int = 30000, stage2_iterations: int = 50000, alpha_obj: float = 1.0, alpha_rel: float = 1.0, scenellm_training_stage: str = 'vqvae', disable_checkpoint_saving: bool = False, use_peft: bool = True, peft_config: dict | None = None, use_gradient_checkpointing: bool = True, use_flash_attention: bool = False, use_8bit_optimizer: bool = False, use_4bit_quantization: bool = False)[source]
Bases:
BaseConfig
Configuration for SceneLLM model.
SceneLLM combines VQ-VAE (Vector Quantized Variational AutoEncoder) with language models for scene graph generation and text summarization.
- Parameters:
model_type (str) – Model type identifier
embed_dim (int) – Embedding dimension for VQ-VAE
codebook_size (int) – Size of VQ-VAE codebook
commitment_cost (float) – Commitment cost for VQ-VAE
llm_name (str) – LLM model name
lora_r (int) – LoRA rank
lora_alpha (int) – LoRA alpha
lora_dropout (float) – LoRA dropout
ot_step (int) – Step size for optimal transport codebook update
vqvae_epochs (int) – Epochs for VQ-VAE pretraining
stage1_iterations (int) – Iterations for stage 1 training
stage2_iterations (int) – Iterations for stage 2 training
alpha_obj (float) – Weight for object loss in SceneLLM
alpha_rel (float) – Weight for relation loss in SceneLLM
scenellm_training_stage (str) – SceneLLM training stage
disable_checkpoint_saving (bool) – Disable checkpoint saving
vqvae_hidden_dim (int) – VQ-VAE hidden dimension
vqvae_num_layers (int) – VQ-VAE number of layers
vqvae_num_resblocks (int) – VQ-VAE number of residual blocks
vqvae_dropout (float) – VQ-VAE dropout rate
vqvae_use_attention (bool) – Use attention in the VQ-VAE
vqvae_attention_heads (int) – Number of attention heads in the VQ-VAE
llm_max_length (int) – LLM maximum sequence length
llm_temperature (float) – LLM temperature for generation
llm_top_p (float) – LLM top-p sampling
llm_top_k (int) – LLM top-k sampling
llm_repetition_penalty (float) – LLM repetition penalty
llm_do_sample (bool) – Use sampling during LLM generation
llm_pad_token_id (int) – LLM pad token ID
llm_eos_token_id (int) – LLM EOS token ID
llm_bos_token_id (int) – LLM BOS token ID
use_peft (bool) – Use PEFT (Parameter Efficient Fine-Tuning)
peft_config (Optional[dict]) – PEFT configuration
use_gradient_checkpointing (bool) – Use gradient checkpointing
use_flash_attention (bool) – Use flash attention
use_8bit_optimizer (bool) – Use 8-bit optimizer
use_4bit_quantization (bool) – Use 4-bit quantization
- __init__(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'scenellm', use_matcher: bool = False, eval: bool = False, embed_dim: int = 1024, codebook_size: int = 8192, commitment_cost: float = 0.25, vqvae_hidden_dim: int = 512, vqvae_num_layers: int = 4, vqvae_num_resblocks: int = 2, vqvae_dropout: float = 0.1, vqvae_use_attention: bool = True, vqvae_attention_heads: int = 8, llm_name: str = 'google/gemma-2-2b', llm_max_length: int = 512, llm_temperature: float = 0.7, llm_top_p: float = 0.9, llm_top_k: int = 50, llm_repetition_penalty: float = 1.1, llm_do_sample: bool = True, llm_pad_token_id: int = 0, llm_eos_token_id: int = 1, llm_bos_token_id: int = 2, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, ot_step: int = 512, vqvae_epochs: int = 5, stage1_iterations: int = 30000, stage2_iterations: int = 50000, alpha_obj: float = 1.0, alpha_rel: float = 1.0, scenellm_training_stage: str = 'vqvae', disable_checkpoint_saving: bool = False, use_peft: bool = True, peft_config: dict | None = None, use_gradient_checkpointing: bool = True, use_flash_attention: bool = False, use_8bit_optimizer: bool = False, use_4bit_quantization: bool = False) None
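The `llm_*` fields share their names with common Hugging Face generation keyword arguments, and `lora_r` / `lora_alpha` imply the usual LoRA scaling factor `lora_alpha / lora_r`. The sketch below shows how such a mapping could look. `LLMSettings`, `generation_kwargs`, and `lora_scaling` are hypothetical stand-ins, and whether m3sgg forwards these values this way is an assumption.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class LLMSettings:
    """Hypothetical stand-in mirroring a subset of SceneLLMConfig defaults."""
    llm_max_length: int = 512
    llm_temperature: float = 0.7
    llm_top_p: float = 0.9
    llm_top_k: int = 50
    llm_repetition_penalty: float = 1.1
    llm_do_sample: bool = True
    llm_pad_token_id: int = 0
    llm_eos_token_id: int = 1
    llm_bos_token_id: int = 2
    lora_r: int = 16
    lora_alpha: int = 32


def generation_kwargs(cfg: LLMSettings) -> Dict[str, Any]:
    """Strip the `llm_` prefix to obtain generate()-style keyword arguments."""
    return {
        "max_length": cfg.llm_max_length,
        "temperature": cfg.llm_temperature,
        "top_p": cfg.llm_top_p,
        "top_k": cfg.llm_top_k,
        "repetition_penalty": cfg.llm_repetition_penalty,
        "do_sample": cfg.llm_do_sample,
        "pad_token_id": cfg.llm_pad_token_id,
        "eos_token_id": cfg.llm_eos_token_id,
        "bos_token_id": cfg.llm_bos_token_id,
    }


def lora_scaling(cfg: LLMSettings) -> float:
    """Standard LoRA scaling applied to the low-rank update: alpha / r."""
    return cfg.lora_alpha / cfg.lora_r


settings = LLMSettings()
gen_kwargs = generation_kwargs(settings)
scale = lora_scaling(settings)
```

With the defaults, the low-rank update is scaled by 32 / 16 = 2.0.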
- class m3sgg.core.config.structured.scenellm.SceneLLMTrainingConfig(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 0, early_stopping_patience: int = 5, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, vqvae_lr: float = 0.0001, llm_lr: float = 2e-05, vqvae_weight_decay: float = 0.0001, llm_weight_decay: float = 0.01, use_warmup: bool = True, warmup_steps: int = 1000, use_cosine_schedule: bool = True, cosine_min_lr: float = 1e-07, use_linear_schedule: bool = False, linear_min_lr: float = 1e-07, use_adafactor: bool = False, adafactor_scale_parameter: bool = True, adafactor_relative_step_size: bool = True, adafactor_warmup_init: bool = False, use_dataloader_pin_memory: bool = True, dataloader_num_workers: int = 4, dataloader_prefetch_factor: int = 2, use_mixed_precision: bool = True, mixed_precision_backend: str = 'apex', mixed_precision_loss_scale: str = 'dynamic', use_gradient_accumulation: bool = True, gradient_accumulation_steps: int = 4, use_gradient_clipping: bool = True, use_ema: bool = False, ema_decay: float = 0.999, use_swa: bool = False, swa_lr: float = 1e-05, swa_epochs: int = 5, use_early_stopping: bool = True, early_stopping_min_delta: float = 0.0001, early_stopping_monitor: str = 'val_loss', early_stopping_mode: str = 'min')[source]
Bases:
TrainingConfig
SceneLLM-specific training configuration.
- Parameters:
vqvae_lr (float) – Learning rate for VQ-VAE
llm_lr (float) – Learning rate for LLM
vqvae_weight_decay (float) – Weight decay for VQ-VAE
llm_weight_decay (float) – Weight decay for LLM
use_warmup (bool) – Use learning rate warmup
warmup_steps (int) – Number of warmup steps
use_cosine_schedule (bool) – Use cosine learning rate schedule
cosine_min_lr (float) – Minimum learning rate for cosine schedule
use_linear_schedule (bool) – Use linear learning rate schedule
linear_min_lr (float) – Minimum learning rate for linear schedule
use_adafactor (bool) – Use Adafactor optimizer
adafactor_scale_parameter (bool) – Adafactor scale parameter
adafactor_relative_step_size (bool) – Adafactor relative step size
adafactor_warmup_init (bool) – Adafactor warmup init
use_dataloader_pin_memory (bool) – Use DataLoader pin memory
dataloader_num_workers (int) – DataLoader number of workers
dataloader_prefetch_factor (int) – DataLoader prefetch factor
use_mixed_precision (bool) – Use mixed precision training
mixed_precision_backend (str) – Mixed precision backend
mixed_precision_loss_scale (str) – Mixed precision loss scale
use_gradient_accumulation (bool) – Use gradient accumulation
gradient_accumulation_steps (int) – Gradient accumulation steps
use_gradient_clipping (bool) – Use gradient clipping
max_grad_norm (float) – Maximum gradient norm
use_ema (bool) – Use exponential moving average
ema_decay (float) – EMA decay rate
use_swa (bool) – Use stochastic weight averaging
swa_lr (float) – SWA learning rate
swa_epochs (int) – SWA epochs
use_early_stopping (bool) – Use early stopping
early_stopping_patience (int) – Early stopping patience
early_stopping_min_delta (float) – Early stopping minimum delta
early_stopping_monitor (str) – Early stopping monitor metric
early_stopping_mode (str) – Early stopping mode
- __init__(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 0, early_stopping_patience: int = 5, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, vqvae_lr: float = 0.0001, llm_lr: float = 2e-05, vqvae_weight_decay: float = 0.0001, llm_weight_decay: float = 0.01, use_warmup: bool = True, warmup_steps: int = 1000, use_cosine_schedule: bool = True, cosine_min_lr: float = 1e-07, use_linear_schedule: bool = False, linear_min_lr: float = 1e-07, use_adafactor: bool = False, adafactor_scale_parameter: bool = True, adafactor_relative_step_size: bool = True, adafactor_warmup_init: bool = False, use_dataloader_pin_memory: bool = True, dataloader_num_workers: int = 4, dataloader_prefetch_factor: int = 2, use_mixed_precision: bool = True, mixed_precision_backend: str = 'apex', mixed_precision_loss_scale: str = 'dynamic', use_gradient_accumulation: bool = True, gradient_accumulation_steps: int = 4, use_gradient_clipping: bool = True, use_ema: bool = False, ema_decay: float = 0.999, use_swa: bool = False, swa_lr: float = 1e-05, swa_epochs: int = 5, use_early_stopping: bool = True, early_stopping_min_delta: float = 0.0001, early_stopping_monitor: str = 'val_loss', early_stopping_mode: str = 'min') None
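With `use_warmup` and `use_cosine_schedule` both enabled by default, a typical learning-rate curve ramps up linearly for `warmup_steps` and then decays along a cosine down to `cosine_min_lr`. The sketch below illustrates that shape and the effective batch size under gradient accumulation. `lr_at` and the `total_steps` value are hypothetical; the exact schedule used by m3sgg is not shown in the source.

```python
import math


def lr_at(step: int, base_lr: float = 2e-05, warmup_steps: int = 1000,
          total_steps: int = 10000, min_lr: float = 1e-07) -> float:
    """Linear warmup to `base_lr`, then cosine decay to `min_lr`."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def effective_batch_size(batch_size: int = 1,
                         gradient_accumulation_steps: int = 4) -> int:
    """Samples contributing to each optimizer step."""
    return batch_size * gradient_accumulation_steps


lr_start = lr_at(0)
lr_peak = lr_at(1000)
lr_end = lr_at(10000)
eff_batch = effective_batch_size()
```

The defaults here reuse `llm_lr`, `warmup_steps`, and `cosine_min_lr` from the signature above; the same function applies to `vqvae_lr` by passing a different `base_lr`.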
- class m3sgg.core.config.structured.scenellm.SceneLLMLossConfig(loss_weights: ~typing.Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: ~typing.Dict[str, float] | None = None, vqvae_loss_weight: float = 1.0, llm_loss_weight: float = 1.0, obj_loss_weight: float = 1.0, rel_loss_weight: float = 1.0, commitment_loss_weight: float = 0.25, perceptual_loss_weight: float = 0.1, use_commitment_loss: bool = True, use_perceptual_loss: bool = False, use_kl_loss: bool = False, kl_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01, use_auxiliary_loss: bool = False, auxiliary_loss_weight: float = 0.1)[source]
Bases:
LossConfig
SceneLLM-specific loss configuration.
- Parameters:
vqvae_loss_weight (float) – Weight for VQ-VAE loss
llm_loss_weight (float) – Weight for LLM loss
obj_loss_weight (float) – Weight for object loss
rel_loss_weight (float) – Weight for relation loss
commitment_loss_weight (float) – Weight for commitment loss
perceptual_loss_weight (float) – Weight for perceptual loss
use_commitment_loss (bool) – Use commitment loss
use_perceptual_loss (bool) – Use perceptual loss
use_kl_loss (bool) – Use KL divergence loss
kl_loss_weight (float) – Weight for KL divergence loss
use_contrastive_loss (bool) – Use contrastive loss
contrastive_loss_weight (float) – Weight for contrastive loss
contrastive_temperature (float) – Contrastive loss temperature
use_consistency_loss (bool) – Use consistency loss
consistency_loss_weight (float) – Weight for consistency loss
use_regularization_loss (bool) – Use regularization loss
regularization_loss_weight (float) – Weight for regularization loss
use_auxiliary_loss (bool) – Use auxiliary loss
auxiliary_loss_weight (float) – Weight for auxiliary loss
- __init__(loss_weights: ~typing.Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: ~typing.Dict[str, float] | None = None, vqvae_loss_weight: float = 1.0, llm_loss_weight: float = 1.0, obj_loss_weight: float = 1.0, rel_loss_weight: float = 1.0, commitment_loss_weight: float = 0.25, perceptual_loss_weight: float = 0.1, use_commitment_loss: bool = True, use_perceptual_loss: bool = False, use_kl_loss: bool = False, kl_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01, use_auxiliary_loss: bool = False, auxiliary_loss_weight: float = 0.1) None
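`commitment_loss_weight` corresponds to the commitment term of a standard VQ-VAE, which penalises the distance between an encoder output and the codebook entry it is assigned to. The sketch below uses a plain nearest-neighbour lookup to illustrate the term. It is not the optimal-transport codebook update referenced by `ot_step` in `SceneLLMConfig`, and `nearest_code` and `commitment_loss` are hypothetical helpers.

```python
from typing import Sequence


def nearest_code(z: Sequence[float], codebook: Sequence[Sequence[float]]) -> int:
    """Index of the codebook vector closest to `z` (squared Euclidean distance)."""
    def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(range(len(codebook)), key=lambda i: sq_dist(z, codebook[i]))


def commitment_loss(z: Sequence[float], code: Sequence[float],
                    weight: float = 0.25) -> float:
    """Weighted mean squared error between encoder output and its code."""
    mse = sum((x - y) ** 2 for x, y in zip(z, code)) / len(z)
    return weight * mse


codebook = [[0.0, 0.0], [1.0, 1.0]]  # toy two-entry codebook
z = [0.9, 0.8]                       # toy encoder output
idx = nearest_code(z, codebook)
loss = commitment_loss(z, codebook[idx])
```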
- class m3sgg.core.config.structured.scenellm.SceneLLMDataConfig(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, max_text_length: int = 512, use_text_augmentation: bool = False, text_augmentation_prob: float = 0.5, use_image_augmentation: bool = False, image_augmentation_prob: float = 0.5, use_temporal_augmentation: bool = False, temporal_augmentation_prob: float = 0.5, use_negative_sampling: bool = True, negative_ratio: float = 0.5, use_hard_negative_mining: bool = False, hard_negative_ratio: float = 0.2, use_contrastive_sampling: bool = False, contrastive_ratio: float = 0.3, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_dynamic_sampling: bool = False, dynamic_sampling_alpha: float = 0.5, use_balanced_sampling: bool = False, balanced_sampling_alpha: float = 0.5)[source]
Bases:
DataConfig
SceneLLM-specific data configuration.
- Parameters:
max_objects (int) – Maximum number of objects per frame
max_relations (int) – Maximum number of relations per frame
max_frames (int) – Maximum number of frames per video
max_text_length (int) – Maximum text length
use_text_augmentation (bool) – Use text augmentation
text_augmentation_prob (float) – Text augmentation probability
use_image_augmentation (bool) – Use image augmentation
image_augmentation_prob (float) – Image augmentation probability
use_temporal_augmentation (bool) – Use temporal augmentation
temporal_augmentation_prob (float) – Temporal augmentation probability
use_negative_sampling (bool) – Use negative sampling
negative_ratio (float) – Ratio of negative samples
use_hard_negative_mining (bool) – Use hard negative mining
hard_negative_ratio (float) – Ratio of hard negative samples
use_contrastive_sampling (bool) – Use contrastive sampling
contrastive_ratio (float) – Ratio of contrastive samples
use_curriculum_learning (bool) – Use curriculum learning
curriculum_epochs (int) – Number of epochs for curriculum
use_dynamic_sampling (bool) – Use dynamic sampling
dynamic_sampling_alpha (float) – Dynamic sampling alpha
use_balanced_sampling (bool) – Use balanced sampling
balanced_sampling_alpha (float) – Balanced sampling alpha
- __init__(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, max_text_length: int = 512, use_text_augmentation: bool = False, text_augmentation_prob: float = 0.5, use_image_augmentation: bool = False, image_augmentation_prob: float = 0.5, use_temporal_augmentation: bool = False, temporal_augmentation_prob: float = 0.5, use_negative_sampling: bool = True, negative_ratio: float = 0.5, use_hard_negative_mining: bool = False, hard_negative_ratio: float = 0.2, use_contrastive_sampling: bool = False, contrastive_ratio: float = 0.3, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_dynamic_sampling: bool = False, dynamic_sampling_alpha: float = 0.5, use_balanced_sampling: bool = False, balanced_sampling_alpha: float = 0.5) None
OED model configuration classes.
This module provides structured configuration classes specifically for the OED (Object-Event Detection) model.
- author:
M3SGG Team
- version:
0.1.0
- class m3sgg.core.config.structured.oed.OEDConfig(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'oed', use_matcher: bool = False, eval: bool = False, num_queries: int = 100, dec_layers_hopd: int = 6, dec_layers_interaction: int = 6, num_attn_classes: int = 3, num_spatial_classes: int = 6, num_contacting_classes: int = 17, alpha: float = 0.5, oed_use_matching: bool = False, bbox_loss_coef: float = 2.5, giou_loss_coef: float = 1.0, obj_loss_coef: float = 1.0, rel_loss_coef: float = 2.0, oed_eos_coef: float = 0.1, interval1: int = 4, interval2: int = 4, num_ref_frames: int = 2, oed_variant: str = 'multi', fuse_semantic_pos: bool = False, query_temporal_interaction: bool = False, hidden_dim: int = 256, num_heads: int = 8, num_layers: int = 6, dropout: float = 0.1, use_bbox_encoding: bool = True, bbox_encoding_dim: int = 128, use_positional_encoding: bool = True, positional_encoding_dim: int = 128, use_temporal_encoding: bool = True, temporal_encoding_dim: int = 128, use_attention_weights: bool = False, attention_dropout: float = 0.1, ffn_dim: int = 1024, activation: str = 'relu', norm_type: str = 'layer_norm', use_bias: bool = True, gradient_checkpointing: bool = False, use_memory_efficient_attention: bool = False, use_auxiliary_loss: bool = False, aux_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01)[source]
Bases:
BaseConfig
Configuration for OED model.
OED (Object-Event Detection) is a transformer-based model for video scene graph generation that uses object queries and attention mechanisms for detection.
- Parameters:
model_type (str) – Model type identifier
num_queries (int) – Number of query slots for OED
dec_layers_hopd (int) – Number of hopd decoding layers in OED transformer
dec_layers_interaction (int) – Number of interaction decoding layers in OED transformer
num_attn_classes (int) – Number of attention classes
num_spatial_classes (int) – Number of spatial classes
num_contacting_classes (int) – Number of contacting classes
alpha (float) – Focal loss alpha for OED
oed_use_matching (bool) – Use the object/subject matching two-class loss in the OED decoder
bbox_loss_coef (float) – L1 box coefficient
giou_loss_coef (float) – GIoU box coefficient
obj_loss_coef (float) – Object classification coefficient
rel_loss_coef (float) – Relation classification coefficient
oed_eos_coef (float) – Relative classification weight of no-object class for OED
interval1 (int) – Interval for training frame selection
interval2 (int) – Interval for test frame selection
num_ref_frames (int) – Number of reference frames
oed_variant (str) – OED variant (single/multi)
fuse_semantic_pos (bool) – Fuse semantic and positional embeddings
query_temporal_interaction (bool) – Enable query temporal interaction
hidden_dim (int) – Hidden dimension size
num_heads (int) – Number of attention heads
num_layers (int) – Number of transformer layers
dropout (float) – Dropout rate
use_bbox_encoding (bool) – Use bounding box encoding
bbox_encoding_dim (int) – Bounding box encoding dimension
use_positional_encoding (bool) – Use positional encoding
positional_encoding_dim (int) – Positional encoding dimension
use_temporal_encoding (bool) – Use temporal encoding
temporal_encoding_dim (int) – Temporal encoding dimension
use_attention_weights (bool) – Use attention weights for visualization
attention_dropout (float) – Attention dropout rate
ffn_dim (int) – Feed-forward network dimension
activation (str) – Activation function
norm_type (str) – Normalization type
use_bias (bool) – Use bias in linear layers
gradient_checkpointing (bool) – Use gradient checkpointing
use_memory_efficient_attention (bool) – Use memory efficient attention
use_auxiliary_loss (bool) – Use auxiliary loss
aux_loss_weight (float) – Auxiliary loss weight
use_contrastive_loss (bool) – Use contrastive loss
contrastive_loss_weight (float) – Contrastive loss weight
contrastive_temperature (float) – Contrastive loss temperature
use_consistency_loss (bool) – Use consistency loss
consistency_loss_weight (float) – Consistency loss weight
use_regularization_loss (bool) – Use regularization loss
regularization_loss_weight (float) – Regularization loss weight
- __init__(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'oed', use_matcher: bool = False, eval: bool = False, num_queries: int = 100, dec_layers_hopd: int = 6, dec_layers_interaction: int = 6, num_attn_classes: int = 3, num_spatial_classes: int = 6, num_contacting_classes: int = 17, alpha: float = 0.5, oed_use_matching: bool = False, bbox_loss_coef: float = 2.5, giou_loss_coef: float = 1.0, obj_loss_coef: float = 1.0, rel_loss_coef: float = 2.0, oed_eos_coef: float = 0.1, interval1: int = 4, interval2: int = 4, num_ref_frames: int = 2, oed_variant: str = 'multi', fuse_semantic_pos: bool = False, query_temporal_interaction: bool = False, hidden_dim: int = 256, num_heads: int = 8, num_layers: int = 6, dropout: float = 0.1, use_bbox_encoding: bool = True, bbox_encoding_dim: int = 128, use_positional_encoding: bool = True, positional_encoding_dim: int = 128, use_temporal_encoding: bool = True, temporal_encoding_dim: int = 128, use_attention_weights: bool = False, attention_dropout: float = 0.1, ffn_dim: int = 1024, activation: str = 'relu', norm_type: str = 'layer_norm', use_bias: bool = True, gradient_checkpointing: bool = False, use_memory_efficient_attention: bool = False, use_auxiliary_loss: bool = False, aux_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01) None
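Two quantities follow directly from the defaults above: the per-head width of multi-head attention (`hidden_dim` must divide evenly by `num_heads`) and the total number of relation classes across the three predicate groups. The sketch below computes both. The helper names are hypothetical, and whether `OEDConfig` performs such a check itself is not shown in the source.

```python
def head_dim(hidden_dim: int = 256, num_heads: int = 8) -> int:
    """Width of each attention head; multi-head attention needs an even split."""
    if hidden_dim % num_heads != 0:
        raise ValueError("hidden_dim must be divisible by num_heads")
    return hidden_dim // num_heads


def total_relation_classes(num_attn_classes: int = 3,
                           num_spatial_classes: int = 6,
                           num_contacting_classes: int = 17) -> int:
    """Sum of attention, spatial, and contacting predicate classes."""
    return num_attn_classes + num_spatial_classes + num_contacting_classes


per_head = head_dim()
num_relations = total_relation_classes()
```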
- class m3sgg.core.config.structured.oed.OEDTrainingConfig(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 2, early_stopping_patience: int = 10, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, scheduler_type: str = 'reduce_on_plateau', scheduler_patience: int = 3, scheduler_factor: float = 0.5, weight_decay: float = 0.0001, clip_grad_norm: float = 1.0, use_amp: bool = False, accumulation_steps: int = 1, use_ema: bool = False, ema_decay: float = 0.999, use_swa: bool = False, swa_lr: float = 1e-05, swa_epochs: int = 5, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_auxiliary_loss: bool = False, aux_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01)[source]
Bases:
TrainingConfig
OED-specific training configuration.
- Parameters:
warmup_epochs (int) – Number of warmup epochs
scheduler_type (str) – Learning rate scheduler type
scheduler_patience (int) – Scheduler patience
scheduler_factor (float) – Scheduler reduction factor
weight_decay (float) – Weight decay for regularization
clip_grad_norm (float) – Gradient clipping norm
use_amp (bool) – Use automatic mixed precision
accumulation_steps (int) – Gradient accumulation steps
use_ema (bool) – Use exponential moving average
ema_decay (float) – EMA decay rate
use_swa (bool) – Use stochastic weight averaging
swa_lr (float) – SWA learning rate
swa_epochs (int) – SWA epochs
use_curriculum_learning (bool) – Use curriculum learning
curriculum_epochs (int) – Number of epochs for curriculum
use_auxiliary_loss (bool) – Use auxiliary loss
aux_loss_weight (float) – Auxiliary loss weight
use_contrastive_loss (bool) – Use contrastive loss
contrastive_loss_weight (float) – Contrastive loss weight
contrastive_temperature (float) – Contrastive loss temperature
use_consistency_loss (bool) – Use consistency loss
consistency_loss_weight (float) – Consistency loss weight
use_regularization_loss (bool) – Use regularization loss
regularization_loss_weight (float) – Regularization loss weight
- __init__(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 2, early_stopping_patience: int = 10, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, scheduler_type: str = 'reduce_on_plateau', scheduler_patience: int = 3, scheduler_factor: float = 0.5, weight_decay: float = 0.0001, clip_grad_norm: float = 1.0, use_amp: bool = False, accumulation_steps: int = 1, use_ema: bool = False, ema_decay: float = 0.999, use_swa: bool = False, swa_lr: float = 1e-05, swa_epochs: int = 5, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_auxiliary_loss: bool = False, aux_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01) None
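The default `scheduler_type` of `'reduce_on_plateau'`, together with `scheduler_patience` and `scheduler_factor`, points to the familiar plateau rule: multiply the learning rate by the factor once the monitored loss has failed to improve for more than `patience` epochs. The class below is a simplified, hypothetical re-implementation for illustration. It omits the thresholds and cooldown found in full implementations such as PyTorch's `ReduceLROnPlateau`, and its use by m3sgg is an assumption.

```python
class PlateauScheduler:
    """Minimal reduce-on-plateau rule (illustrative, not the m3sgg class)."""

    def __init__(self, lr: float, patience: int = 3, factor: float = 0.5):
        self.lr = lr
        self.patience = patience
        self.factor = factor
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> float:
        """Record one epoch's validation loss and return the updated rate."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        if self.bad_epochs > self.patience:
            self.lr *= self.factor
            self.bad_epochs = 0
        return self.lr


sched = PlateauScheduler(lr=1e-05)
# Validation loss stalls at 1.0 after the first epoch.
history = [sched.step(loss) for loss in [1.0, 1.0, 1.0, 1.0, 1.0]]
```

With `patience=3`, the rate is halved on the fourth consecutive epoch without improvement, which is the fifth call in this example.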
- class m3sgg.core.config.structured.oed.OEDLossConfig(loss_weights: ~typing.Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: ~typing.Dict[str, float] | None = None, obj_loss_weight: float = 1.0, rel_loss_weight: float = 2.0, bbox_loss_weight: float = 2.5, giou_loss_weight: float = 1.0, use_focal_loss: bool = True, use_class_weights: bool = False, obj_class_weights: dict | None = None, rel_class_weights: dict | None = None, use_auxiliary_loss: bool = False, aux_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01)[source]
Bases:
LossConfig
OED-specific loss configuration.
- Parameters:
obj_loss_weight (float) – Weight for object classification loss
rel_loss_weight (float) – Weight for relation classification loss
bbox_loss_weight (float) – Weight for bounding box regression loss
giou_loss_weight (float) – Weight for GIoU loss
use_focal_loss (bool) – Use focal loss for classification
focal_alpha (float) – Focal loss alpha parameter
focal_gamma (float) – Focal loss gamma parameter
label_smoothing (float) – Label smoothing factor
use_class_weights (bool) – Use class weights for imbalanced data
obj_class_weights (Optional[dict]) – Object class weights
rel_class_weights (Optional[dict]) – Relation class weights
use_auxiliary_loss (bool) – Use auxiliary loss
aux_loss_weight (float) – Auxiliary loss weight
use_contrastive_loss (bool) – Use contrastive loss
contrastive_loss_weight (float) – Contrastive loss weight
contrastive_temperature (float) – Contrastive loss temperature
use_consistency_loss (bool) – Use consistency loss
consistency_loss_weight (float) – Consistency loss weight
use_regularization_loss (bool) – Use regularization loss
regularization_loss_weight (float) – Regularization loss weight
- __init__(loss_weights: ~typing.Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: ~typing.Dict[str, float] | None = None, obj_loss_weight: float = 1.0, rel_loss_weight: float = 2.0, bbox_loss_weight: float = 2.5, giou_loss_weight: float = 1.0, use_focal_loss: bool = True, use_class_weights: bool = False, obj_class_weights: dict | None = None, rel_class_weights: dict | None = None, use_auxiliary_loss: bool = False, aux_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01) None
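`focal_alpha` and `focal_gamma` parameterise the focal loss, FL(p_t) = -alpha_t (1 - p_t)^gamma log(p_t), which down-weights examples the model already classifies well. The sketch below evaluates the binary form for a single prediction. `binary_focal_loss` is a hypothetical helper and is not taken from the m3sgg source.

```python
import math


def binary_focal_loss(p: float, target: int, alpha: float = 0.25,
                      gamma: float = 2.0) -> float:
    """Focal loss for one prediction `p` = P(class 1), with target 0 or 1."""
    p_t = p if target == 1 else 1.0 - p              # probability of true class
    alpha_t = alpha if target == 1 else 1.0 - alpha  # class-balancing weight
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


easy = binary_focal_loss(0.9, 1)    # confident and correct: small loss
hard = binary_focal_loss(0.6, 1)    # less confident: larger loss
at_half = binary_focal_loss(0.5, 1)
```

Setting `gamma=0` removes the modulating factor and leaves an alpha-weighted cross-entropy.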
- class m3sgg.core.config.structured.oed.OEDDataConfig(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, use_temporal_sampling: bool = True, temporal_stride: int = 1, use_augmentation: bool = False, augmentation_prob: float = 0.5, use_negative_sampling: bool = True, negative_ratio: float = 0.5, use_hard_negative_mining: bool = False, hard_negative_ratio: float = 0.2, use_contrastive_sampling: bool = False, contrastive_ratio: float = 0.3, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_dynamic_sampling: bool = False, dynamic_sampling_alpha: float = 0.5, use_balanced_sampling: bool = False, balanced_sampling_alpha: float = 0.5)[source]
Bases:
DataConfig
OED-specific data configuration.
- Parameters:
max_objects (int) – Maximum number of objects per frame
max_relations (int) – Maximum number of relations per frame
max_frames (int) – Maximum number of frames per video
use_temporal_sampling (bool) – Use temporal sampling
temporal_stride (int) – Temporal stride for sampling
use_augmentation (bool) – Use data augmentation
augmentation_prob (float) – Probability of applying augmentation
use_negative_sampling (bool) – Use negative sampling for relations
negative_ratio (float) – Ratio of negative samples
use_hard_negative_mining (bool) – Use hard negative mining
hard_negative_ratio (float) – Ratio of hard negative samples
use_contrastive_sampling (bool) – Use contrastive sampling
contrastive_ratio (float) – Ratio of contrastive samples
use_curriculum_learning (bool) – Use curriculum learning
curriculum_epochs (int) – Number of epochs for curriculum
use_dynamic_sampling (bool) – Use dynamic sampling
dynamic_sampling_alpha (float) – Dynamic sampling alpha
use_balanced_sampling (bool) – Use balanced sampling
balanced_sampling_alpha (float) – Balanced sampling alpha
- __init__(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, use_temporal_sampling: bool = True, temporal_stride: int = 1, use_augmentation: bool = False, augmentation_prob: float = 0.5, use_negative_sampling: bool = True, negative_ratio: float = 0.5, use_hard_negative_mining: bool = False, hard_negative_ratio: float = 0.2, use_contrastive_sampling: bool = False, contrastive_ratio: float = 0.3, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_dynamic_sampling: bool = False, dynamic_sampling_alpha: float = 0.5, use_balanced_sampling: bool = False, balanced_sampling_alpha: float = 0.5) None
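As a rough illustration of how `temporal_stride` and `max_frames` might interact, the sketch below subsamples frame indices with a fixed stride and then caps the count. The helper `sample_frame_indices` is hypothetical and not part of m3sgg; the library's actual sampling logic may differ.

```python
from typing import List


def sample_frame_indices(num_frames: int, temporal_stride: int = 1,
                         max_frames: int = 30) -> List[int]:
    """Hypothetical helper: keep every `temporal_stride`-th frame, capped at `max_frames`."""
    if temporal_stride < 1:
        raise ValueError("temporal_stride must be >= 1")
    # Candidate frames at the requested stride, starting from frame 0.
    indices = list(range(0, num_frames, temporal_stride))
    # Truncate so a long video never exceeds the frame budget.
    return indices[:max_frames]


# A 100-frame video at stride 2 has 50 candidates; the cap keeps the first 30.
indices = sample_frame_indices(num_frames=100, temporal_stride=2, max_frames=30)
print(len(indices), indices[:3])
```

The defaults here mirror the documented defaults (`temporal_stride=1`, `max_frames=30`); truncating from the start of the video is an assumption made only to keep the sketch short.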
Detectors
- class m3sgg.core.detectors.faster_rcnn.detector(train, object_classes, use_SUPPLY, mode='predcls')[source]
Bases:
Module
Object detector module for scene graph generation.
Implements object detection functionality using Faster R-CNN backbone for scene graph generation tasks including predcls, sgcls, and sgdet modes.
- Parameters:
nn.Module (class) – Base PyTorch module class
- __init__(train, object_classes, use_SUPPLY, mode='predcls')[source]
Initialize the object detector.
- forward(im_data, im_info, gt_boxes, num_boxes, gt_annotation, im_all)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Detector factory for creating different object detection models.
This module provides a unified interface for creating various object detection models including Faster R-CNN, ViT-based detectors, DETR, YOLO, and ResNet variants.
- class m3sgg.core.detectors.factory.DetectorFactory(config, dataset_train, device, logger: Logger | None = None)[source]
Bases:
object
Factory class for creating object detection models based on configuration.
This factory provides a unified interface for creating different object detection models including Faster R-CNN, ViT-based detectors, DETR, YOLO, and ResNet variants. It handles detector instantiation logic and provides a consistent interface across different detector types for improved modularity and maintainability.
- Parameters:
config (Config) – Configuration object containing detector parameters
dataset_train (Dataset) – Training dataset for extracting class information
device (torch.device) – Device to place the detector on
logger (Optional[logging.Logger]) – Optional logger instance
- __init__(config, dataset_train, device, logger: Logger | None = None)[source]
Initialize the detector factory.
- Parameters:
config (Config) – Configuration object containing detector parameters
dataset_train (Dataset) – Training dataset for extracting class information
device (torch.device) – Device to place the detector on
logger (Optional[logging.Logger]) – Optional logger instance
- create_detector() Module [source]
Create a detector instance based on the configuration.
- Returns:
Instantiated detector
- Return type:
torch.nn.Module
- Raises:
ValueError – If detector type is not supported
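The dispatch-and-raise behavior described for `create_detector` can be sketched with a small registry. This is a minimal stand-in, not the m3sgg implementation: the class `ToyDetectorFactory`, the type names `"faster_rcnn"` and `"detr"`, and the string "detectors" are all hypothetical placeholders.

```python
from typing import Callable, Dict


class ToyDetectorFactory:
    """Hypothetical stand-in for a configuration-driven factory (not the m3sgg class)."""

    def __init__(self, detector_type: str) -> None:
        self.detector_type = detector_type
        # Registry mapping a type name to a zero-argument builder.
        self._builders: Dict[str, Callable[[], object]] = {
            "faster_rcnn": lambda: "FasterRCNNDetector()",
            "detr": lambda: "DETRDetector()",
        }

    def create_detector(self) -> object:
        """Build the configured detector or raise ValueError for unknown types."""
        try:
            builder = self._builders[self.detector_type]
        except KeyError:
            supported = ", ".join(sorted(self._builders))
            raise ValueError(
                f"Unsupported detector type {self.detector_type!r}; "
                f"expected one of: {supported}"
            ) from None
        return builder()


print(ToyDetectorFactory("detr").create_detector())
```

Keeping the mapping in one dictionary means an unsupported type fails early with a message that lists the valid choices.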
- class m3sgg.core.detectors.factory.BaseDetector[source]
Bases:
ABC
Abstract base class for all detector implementations.
This abstract base class ensures consistent interface across different detector types. All detector implementations must inherit from this class and implement the required abstract methods to maintain compatibility with the detector factory and training pipeline.
- Parameters:
ABC (ABC) – Abstract base class from abc module
- abstract forward(*args, **kwargs) Dict[str, Any] [source]
Forward pass of the detector.
This method performs the forward pass of the detector and returns a dictionary containing detection results including bounding boxes, class predictions, and confidence scores.
- Parameters:
args (Any) – Variable length argument list
kwargs (Any) – Arbitrary keyword arguments
- Returns:
Dictionary containing detection results
- Return type:
Dict[str, Any]
- abstract get_feature_extractor() Module [source]
Get the feature extraction backbone.
Returns the feature extraction backbone module used by the detector. This is typically the CNN or transformer backbone that extracts features from input images.
- Returns:
Feature extraction module
- Return type:
torch.nn.Module
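To show what implementing this interface involves, here is a self-contained sketch that mirrors the two abstract methods with a local ABC. `SketchBaseDetector` and `ConstantDetector` are hypothetical names, and the result keys `boxes`, `labels`, and `scores` are assumptions; the source only says the dictionary contains bounding boxes, class predictions, and confidence scores.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class SketchBaseDetector(ABC):
    """Hypothetical mirror of the two abstract methods documented above."""

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """Return detection results as a dictionary."""

    @abstractmethod
    def get_feature_extractor(self) -> Any:
        """Return the backbone used to extract features."""


class ConstantDetector(SketchBaseDetector):
    """Toy detector that always reports one fixed box."""

    def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        # Key names are assumed for illustration only.
        return {"boxes": [[0.0, 0.0, 10.0, 10.0]], "labels": [1], "scores": [0.9]}

    def get_feature_extractor(self) -> Any:
        # Identity function standing in for a real CNN or transformer backbone.
        return lambda image: image


detector = ConstantDetector()
result = detector.forward()
print(sorted(result))
```

Because both methods are abstract, a subclass that omits either one cannot be instantiated, which is how the base class enforces a consistent interface.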
EASG Detectors
- class m3sgg.core.detectors.easg.object_detector_EASG.detector(train, object_classes, use_SUPPLY, mode='edgecls')[source]
Bases:
Module
Object detector module for EASG (Egocentric Action Scene Graph) generation.
Implements object detection functionality specifically designed for EASG scene graph generation with video-based detection capabilities.
- Parameters:
nn.Module (class) – Base PyTorch module class
- __init__(train, object_classes, use_SUPPLY, mode='edgecls')[source]
Initialize the EASG object detector.
- forward(im_data, im_info, gt_grounding, im_all)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Relationship prediction modules for EASG: object classifier, action classifier, and STTran.
- class m3sgg.core.detectors.easg.sttran_EASG.ObjectClassifier(mode='edgecls', obj_classes=None)[source]
Bases:
Module
Module for computing object contexts and edge contexts for EASG.
EASG-specific implementation of object classification and contextual feature extraction for efficient scene graph generation.
- Parameters:
nn.Module (class) – Base PyTorch module class
- __init__(mode='edgecls', obj_classes=None)[source]
Initialize the EASG object classifier.
- forward(entry)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class m3sgg.core.detectors.easg.sttran_EASG.ActionClassifier(mode='edgecls', verb_classes=None)[source]
Bases:
Module
- __init__(mode='edgecls', verb_classes=None)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(entry)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class m3sgg.core.detectors.easg.sttran_EASG.STTran(mode='edgecls', obj_classes=None, verb_classes=None, edge_class_num=None, enc_layer_num=None, dec_layer_num=None, use_visual_features=False)[source]
Bases:
Module
- __init__(mode='edgecls', obj_classes=None, verb_classes=None, edge_class_num=None, enc_layer_num=None, dec_layer_num=None, use_visual_features=False)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(entry)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Evaluation
- class m3sgg.core.evaluation.metrics.BasicSceneGraphEvaluator(mode, AG_object_classes, AG_all_predicates, AG_attention_predicates, AG_spatial_predicates, AG_contacting_predicates, iou_threshold=0.5, constraint=False, semithreshold=None, logger=None)[source]
Bases:
object
Evaluator for scene graph generation tasks.
Computes recall metrics for scene graph generation across different tasks (predcls, sgcls, sgdet) and handles constraint evaluation modes.
- Parameters:
object (class) – Base object class
- __init__(mode, AG_object_classes, AG_all_predicates, AG_attention_predicates, AG_spatial_predicates, AG_contacting_predicates, iou_threshold=0.5, constraint=False, semithreshold=None, logger=None)[source]
Initialize the scene graph evaluator.
- Parameters:
mode (str) – Evaluation mode (‘predcls’, ‘sgcls’, or ‘sgdet’)
AG_object_classes (list) – List of object class names
AG_all_predicates (list) – List of all predicate names
AG_attention_predicates (list) – List of attention predicate names
AG_spatial_predicates (list) – List of spatial predicate names
AG_contacting_predicates (list) – List of contacting predicate names
iou_threshold (float, optional) – IoU threshold for evaluation, defaults to 0.5
constraint (bool, optional) – Whether to use constraint evaluation, defaults to False
semithreshold (float, optional) – Semi-constraint threshold, defaults to None
logger (logging.Logger, optional) – Logger instance, defaults to None
- Returns:
None
- Return type:
None
- m3sgg.core.evaluation.metrics.evaluate_from_dict(gt_entry, pred_entry, mode, result_dict, method=None, threshold=0.9, **kwargs)[source]
Shortcut for calling evaluate_recall on dictionary inputs.
- Parameters:
gt_entry – Dictionary containing gt_relations, gt_boxes, gt_classes
pred_entry – Dictionary containing pred_rels, pred_boxes (if detection), pred_classes
result_dict
kwargs
- m3sgg.core.evaluation.metrics.evaluate_recall(gt_rels, gt_boxes, gt_classes, pred_rels, pred_boxes, pred_classes, rel_scores=None, cls_scores=None, iou_thresh=0.5, phrdet=False)[source]
Evaluates recall metrics for scene graph generation.
Computes recall by matching predicted relations to ground truth relations based on object detection IoU and relation class matching.
- Parameters:
gt_rels (numpy.ndarray) – Ground truth relations array of shape [#gt_rel, 3]
gt_boxes (numpy.ndarray) – Ground truth bounding boxes of shape [#gt_box, 4]
gt_classes (numpy.ndarray) – Ground truth object classes of shape [#gt_box]
pred_rels (numpy.ndarray) – Predicted relations array of shape [#pred_rel, 3] (id0, id1, rel)
pred_boxes (numpy.ndarray) – Predicted bounding boxes of shape [#pred_box, 4]
pred_classes (numpy.ndarray) – Predicted object classes of shape [#pred_box]
rel_scores (numpy.ndarray, optional) – Relation scores, defaults to None
cls_scores (numpy.ndarray, optional) – Classification scores, defaults to None
iou_thresh (float, optional) – IoU threshold for matching, defaults to 0.5
phrdet (bool, optional) – Whether to use phrase detection mode, defaults to False
- Returns:
Tuple containing predicate-to-GT matching, predicted 5-tuples, and relation scores
- Return type:
tuple
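The matching idea behind this function (a predicted relation counts when its predicate and object classes agree with a ground-truth relation and both boxes overlap by at least `iou_thresh`) can be sketched in pure Python. This is a simplified illustration under stated assumptions, not the library's code: `box_iou` and `triplet_recall` are hypothetical helpers, boxes are assumed to be `(x1, y1, x2, y2)`, and the sketch returns a single recall value rather than the documented tuple.

```python
from typing import List, Sequence, Tuple

Box = Sequence[float]       # assumed layout: (x1, y1, x2, y2)
Rel = Tuple[int, int, int]  # (subject index, object index, predicate id)


def box_iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def triplet_recall(gt_rels: List[Rel], gt_boxes: List[Box], gt_classes: List[int],
                   pred_rels: List[Rel], pred_boxes: List[Box],
                   pred_classes: List[int], iou_thresh: float = 0.5) -> float:
    """Fraction of ground-truth relations matched by at least one prediction."""
    if not gt_rels:
        return 0.0
    hits = 0
    for gs, go, gp in gt_rels:
        for ps, po, pp in pred_rels:
            same_labels = (pp == gp
                           and pred_classes[ps] == gt_classes[gs]
                           and pred_classes[po] == gt_classes[go])
            if (same_labels
                    and box_iou(pred_boxes[ps], gt_boxes[gs]) >= iou_thresh
                    and box_iou(pred_boxes[po], gt_boxes[go]) >= iou_thresh):
                hits += 1
                break  # one matching prediction is enough for this GT relation
    return hits / len(gt_rels)


gt_boxes = [[0, 0, 10, 10], [20, 20, 30, 30]]
gt_classes = [1, 2]
gt_rels = [(0, 1, 5), (1, 0, 6)]
pred_boxes = [[1, 1, 10, 10], [20, 20, 30, 30]]
pred_classes = [1, 2]
pred_rels = [(0, 1, 5)]
recall = triplet_recall(gt_rels, gt_boxes, gt_classes,
                        pred_rels, pred_boxes, pred_classes)
print(recall)  # 0.5: the first GT relation is matched, the second is not
```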
Training
Trainer class for modularized training loop.
- class m3sgg.core.training.trainer.Trainer(model: Module, config: Config, dataloader_train: DataLoader, dataloader_test: DataLoader, optimizer: Optimizer, scheduler: ReduceLROnPlateau, logger: Logger, object_detector: Module | None = None, object_detector_EASG: Module | None = None, matcher: Any | None = None, evaluator: BasicSceneGraphEvaluator | None = None, evaluator2: BasicSceneGraphEvaluator | None = None, dataset_train: Any | None = None, dataset_test: Any | None = None)[source]
Bases:
object
Main trainer class for scene graph generation models.
This class encapsulates the training loop, epoch management, and step execution for various scene graph generation models.
- Parameters:
model (torch.nn.Module) – The model to train
config (Config) – Configuration object containing training parameters
dataloader_train (torch.utils.data.DataLoader) – Training data loader
dataloader_test (torch.utils.data.DataLoader) – Test data loader
optimizer (Optimizer) – Optimizer for training
scheduler (ReduceLROnPlateau) – Learning rate scheduler
logger (logging.Logger) – Logger instance
- __init__(model: Module, config: Config, dataloader_train: DataLoader, dataloader_test: DataLoader, optimizer: Optimizer, scheduler: ReduceLROnPlateau, logger: Logger, object_detector: Module | None = None, object_detector_EASG: Module | None = None, matcher: Any | None = None, evaluator: BasicSceneGraphEvaluator | None = None, evaluator2: BasicSceneGraphEvaluator | None = None, dataset_train: Any | None = None, dataset_test: Any | None = None)[source]
Initialize the Trainer with all necessary components.
- Parameters:
model – The model to train
config – Configuration object
dataloader_train – Training data loader
dataloader_test – Test data loader
optimizer – Optimizer
scheduler – Learning rate scheduler
logger – Logger instance
object_detector – Object detector for Action Genome dataset
object_detector_EASG – Object detector for EASG dataset
matcher – Hungarian matcher for DSG-DETR
evaluator – Scene graph evaluator
evaluator2 – Secondary evaluator without constraints
dataset_train – Training dataset
dataset_test – Test dataset
- train_loop() None [source]
Main training loop that orchestrates the entire training process.
This method runs the complete training process including all epochs, evaluation, and checkpoint saving.
- train_epoch(epoch: int) None [source]
Train the model for one epoch.
- Parameters:
epoch (int) – Current epoch number
- train_step(batch_idx: int, train_iter: iter, unc_vals: Any | None = None) Dict[str, Tensor] [source]
Execute one training step.
- Parameters:
batch_idx (int) – Current batch index
train_iter (iter) – Training data iterator
unc_vals (Optional[Any]) – Uncertainty values for Tempura model
- Returns:
Dictionary of loss components
- Return type:
Dict[str, torch.Tensor]
- train_iter(max_iterations: int | None = None) Iterator[Dict[str, Any]] [source]
Iterative training function that yields training progress.
This function provides an iterator interface for training, allowing for more granular control over the training process and real-time monitoring.
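The iterator interface described above can be pictured as a generator that yields one progress record per step. The sketch below is hypothetical: `train_iter_sketch` is not the m3sgg method, the batches are plain floats standing in for real data, and the keys `iteration`, `loss`, and `mean_loss` are assumed names for the yielded dictionary.

```python
from typing import Any, Dict, Iterable, Iterator, Optional


def train_iter_sketch(batches: Iterable[float],
                      max_iterations: Optional[int] = None) -> Iterator[Dict[str, Any]]:
    """Hypothetical generator: yield one progress dict per processed batch."""
    running_loss = 0.0
    for step, batch in enumerate(batches, start=1):
        if max_iterations is not None and step > max_iterations:
            break
        loss = float(batch)  # placeholder for a real forward/backward pass
        running_loss += loss
        yield {"iteration": step, "loss": loss, "mean_loss": running_loss / step}


# The caller decides what to do between steps: log, plot, or stop early.
for progress in train_iter_sketch([0.9, 0.7, 0.5, 0.4], max_iterations=3):
    print(progress["iteration"], round(progress["mean_loss"], 3))
```

Because control returns to the caller after every step, monitoring and early stopping need no callbacks inside the training loop.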
Evaluation class for modularized evaluation loop.
This module provides a clean separation of evaluation logic from the main training script.
- class m3sgg.core.training.evaluation.Evaluator(evaluator: BasicSceneGraphEvaluator, evaluator2: BasicSceneGraphEvaluator | None = None, logger: Logger | None = None)[source]
Bases:
object
Evaluation class for scene graph generation models.
This class encapsulates the evaluation loop and metrics computation for various scene graph generation models.
- Parameters:
evaluator (BasicSceneGraphEvaluator) – Primary scene graph evaluator
evaluator2 (Optional[BasicSceneGraphEvaluator]) – Secondary evaluator without constraints
logger (Optional[logging.Logger]) – Logger instance
- __init__(evaluator: BasicSceneGraphEvaluator, evaluator2: BasicSceneGraphEvaluator | None = None, logger: Logger | None = None)[source]
Initialize the Evaluator with necessary components.
- Parameters:
evaluator – Primary scene graph evaluator
evaluator2 – Secondary evaluator without constraints
logger – Logger instance
- eval_loop(model: Module, dataloader_test: DataLoader, config: Any, object_detector: Module | None = None, object_detector_EASG: Module | None = None, matcher: Any | None = None, dataset_test: Any | None = None) Tuple[float, float] [source]
Run the complete evaluation loop.
- Parameters:
model (torch.nn.Module) – Model to evaluate
dataloader_test (torch.utils.data.DataLoader) – Test data loader
config (Any) – Configuration object
object_detector (Optional[torch.nn.Module]) – Object detector for Action Genome dataset
object_detector_EASG (Optional[torch.nn.Module]) – Object detector for EASG dataset
matcher (Optional[Any]) – Hungarian matcher for DSG-DETR
dataset_test (Optional[Any]) – Test dataset
- Returns:
Tuple of (score, mrecall)
- Return type:
Tuple[float, float]
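The source does not define `score` or `mrecall`. Assuming they correspond to overall recall and class-averaged (mean) recall, the difference between the two can be sketched as follows; `recall_and_mean_recall` and the predicate names are hypothetical and used only for illustration.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def recall_and_mean_recall(outcomes: Iterable[Tuple[str, bool]]) -> Tuple[float, float]:
    """outcomes: (predicate, was_recalled) pairs, one per ground-truth relation."""
    per_class: Dict[str, List[int]] = defaultdict(lambda: [0, 0])  # predicate -> [hits, total]
    for predicate, hit in outcomes:
        per_class[predicate][0] += int(hit)
        per_class[predicate][1] += 1
    total = sum(t for _, t in per_class.values())
    if total == 0:
        return 0.0, 0.0
    # Overall recall weights every relation equally.
    recall = sum(h for h, _ in per_class.values()) / total
    # Mean recall weights every predicate class equally.
    mean_recall = sum(h / t for h, t in per_class.values()) / len(per_class)
    return recall, mean_recall


outcomes = [("holding", True)] * 3 + [("holding", False), ("looking_at", False)]
score, mrecall = recall_and_mean_recall(outcomes)
print(score, mrecall)
```

With a frequent predicate recalled well and a rare one missed, the class-averaged value falls below the overall one, which is why the two numbers are often reported together.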
Model factory for creating different model types based on configuration.
This module contains the model instantiation logic that was extracted from the monolithic training.py script to improve modularity and maintainability.
- class m3sgg.core.training.model_factory.ModelFactory(config, dataset_train, device, logger: Logger | None = None)[source]
Bases:
object
Factory class for creating model instances based on configuration.
This factory provides a unified interface for creating different scene graph generation models including STTran, DSG-DETR, STKET, TEMPURA, SceneLLM, OED, and VLM models. It handles model instantiation logic that was extracted from the monolithic training script to improve modularity and maintainability.
- Parameters:
config (Config) – Configuration object containing model parameters
dataset_train (Dataset) – Training dataset for extracting class information
device (torch.device) – Device to place the model on
logger (Optional[logging.Logger]) – Optional logger instance
- __init__(config, dataset_train, device, logger: Logger | None = None)[source]
Initialize the model factory.
- Parameters:
config (Config) – Configuration object containing model parameters
dataset_train (Dataset) – Training dataset for extracting class information
device (torch.device) – Device to place the model on
logger (Optional[logging.Logger]) – Optional logger instance
- create_model() Module [source]
Create a model instance based on the configuration.
- Returns:
Instantiated model
- Return type:
torch.nn.Module
- Raises:
ValueError – If dataset or model type is not supported
Loss factory for creating different loss functions based on model type and configuration.
This module contains the loss function setup logic that was extracted from the monolithic training.py script to improve modularity and maintainability.
- class m3sgg.core.training.loss_factory.LossFactory(config, model, device, logger: Logger | None = None)[source]
Bases:
object
Factory class for creating loss functions based on configuration.
This factory provides a unified interface for creating different loss functions based on model type and configuration. It handles loss function setup logic that was extracted from the monolithic training script to improve modularity and maintainability. Supports basic losses for all models and model-specific losses for TEMPURA models.
- Parameters:
config (Config) – Configuration object containing loss parameters
model (torch.nn.Module) – Model instance for extracting class information
device (torch.device) – Device to place loss functions on
logger (Optional[logging.Logger]) – Optional logger instance
- __init__(config, model, device, logger: Logger | None = None)[source]
Initialize the loss factory.
- Parameters:
config (Config) – Configuration object containing loss parameters
model (torch.nn.Module) – Model instance for extracting class information
device (torch.device) – Device to place loss functions on
logger (Optional[logging.Logger]) – Optional logger instance