Core Library

The core library contains the main functionality, including configuration, models, detectors, evaluation, and training components.

Configuration

Unified configuration interface for M3SGG.

This module provides a unified interface that can work with both legacy and modern configuration systems, allowing for gradual migration.

author: M3SGG Team
version: 0.1.0

class m3sgg.core.config.unified.UnifiedConfig(config_path: str | None = None, model_type: str | None = None, use_modern: bool = True, cli_args: List[str] | None = None, overrides: Dict[str, Any] | None = None)[source]

Bases: object

Unified configuration interface supporting both legacy and modern systems.

This class provides a unified interface that can work with both the legacy argparse-based configuration system and the modern OmegaConf-based system. It automatically detects which system to use based on the configuration and provides a consistent interface.

Parameters:
  • config_path (Optional[str]) – Path to configuration file

  • model_type (Optional[str]) – Model type for structured configuration

  • use_modern (bool) – Force use of modern configuration system

  • cli_args (Optional[List[str]]) – Command-line arguments

  • overrides (Optional[Dict[str, Any]]) – Configuration overrides
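
A minimal usage sketch (the override keys and values shown are illustrative, not prescribed defaults):

    from m3sgg.core.config.unified import UnifiedConfig

    # Build a structured config for a model type and apply overrides on top.
    config = UnifiedConfig(model_type="sttran", overrides={"lr": 2e-5})

    lr = config.get("lr", 1e-5)   # lookup with a default
    config.set("nepoch", 20)      # assignment by key
    print(config["mode"])         # dictionary-style access
    print("lr" in config)         # membership test via __contains__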

__init__(config_path: str | None = None, model_type: str | None = None, use_modern: bool = True, cli_args: List[str] | None = None, overrides: Dict[str, Any] | None = None)[source]

Initialize the unified configuration.

Parameters:
  • config_path (Optional[str]) – Path to configuration file

  • model_type (Optional[str]) – Model type for structured configuration

  • use_modern (bool) – Force use of modern configuration system

  • cli_args (Optional[List[str]]) – Command-line arguments

  • overrides (Optional[Dict[str, Any]]) – Configuration overrides

get(key: str, default: Any | None = None) Any[source]

Get a configuration value by key.

Parameters:
  • key (str) – Configuration key (supports dot notation for modern)

  • default (Any) – Default value if key not found

Returns: Configuration value
Return type: Any

set(key: str, value: Any)[source]

Set a configuration value by key.

Parameters:
  • key (str) – Configuration key

  • value (Any) – Value to set

update(key: str, value: Any, merge: bool = True)[source]

Update a configuration value by key.

Parameters:
  • key (str) – Configuration key

  • value (Any) – Value to update

  • merge (bool) – Whether to merge dictionaries/lists (modern only)

save(path: str)[source]

Save configuration to a file.

Parameters:
  • path (str) – Path to save the configuration

to_dict(resolve: bool = True) Dict[str, Any][source]

Convert configuration to a dictionary.

Parameters:
  • resolve (bool) – Whether to resolve interpolations (modern only)

Returns: Configuration as dictionary
Return type: Dict[str, Any]

__getattr__(name: str) Any[source]

Allow attribute-style access to configuration values.

Parameters:
  • name (str) – Attribute name

Returns: Configuration value
Return type: Any

__getitem__(key: str) Any[source]

Allow dictionary-style access to configuration values.

Parameters:
  • key (str) – Configuration key

Returns: Configuration value
Return type: Any

__setitem__(key: str, value: Any)[source]

Allow dictionary-style setting of configuration values.

Parameters:
  • key (str) – Configuration key

  • value (Any) – Value to set

__contains__(key: str) bool[source]

Check if configuration contains a key.

Parameters:
  • key (str) – Configuration key

Returns: True if key exists
Return type: bool

__repr__() str[source]

String representation of the configuration.

Returns: String representation
Return type: str

__str__() str[source]

String representation of the configuration.

Returns: String representation
Return type: str

property is_modern: bool

Check if using modern configuration system.

Returns: True if using modern system
Return type: bool

property is_legacy: bool

Check if using legacy configuration system.

Returns: True if using legacy system
Return type: bool

m3sgg.core.config.unified.create_config(model_type: str, config_path: str | None = None, **overrides) UnifiedConfig[source]

Create a configuration for a specific model type.

Parameters:
  • model_type (str) – Model type identifier

  • config_path (Optional[str]) – Path to configuration file

  • overrides (dict) – Configuration overrides

Returns: Unified configuration instance
Return type: UnifiedConfig

m3sgg.core.config.unified.load_config_from_file(config_path: str) UnifiedConfig[source]

Load configuration from a file.

Parameters:
  • config_path (str) – Path to configuration file

Returns: Unified configuration instance
Return type: UnifiedConfig

m3sgg.core.config.unified.Config

alias of UnifiedConfig
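
A sketch of the module-level helpers (the YAML path is hypothetical):

    from m3sgg.core.config.unified import Config, create_config, load_config_from_file

    # Build a config for a model type, with keyword overrides.
    cfg = create_config("tempura", lr=5e-6, nepoch=20)

    # Or load one from a configuration file (the path is illustrative).
    cfg_from_file = load_config_from_file("configs/tempura.yaml")

    # Config is an alias of UnifiedConfig, so either name works in isinstance checks.
    assert isinstance(cfg, Config)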

Modern configuration system using OmegaConf.

This module provides a modern configuration management system using OmegaConf with support for YAML files, command-line overrides, interpolation, and structured validation.

author: M3SGG Team
version: 0.1.0

class m3sgg.core.config.modern.ConfigManager(config_path: str | None = None, model_type: str | None = None, cli_args: List[str] | None = None, overrides: Dict[str, Any] | None = None)[source]

Bases: object

Modern configuration manager using OmegaConf.

This class provides a unified interface for configuration management that supports YAML files, command-line arguments, interpolation, and structured validation.

Parameters:
  • config_path (Optional[str]) – Path to configuration file

  • model_type (Optional[str]) – Model type for structured configuration

  • cli_args (Optional[List[str]]) – Command-line arguments to override config

  • overrides (Optional[Dict[str, Any]]) – Additional configuration overrides
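
A minimal sketch; the dotted keys and the key=value override syntax for cli_args are assumptions based on OmegaConf conventions:

    from m3sgg.core.config.modern import ConfigManager

    mgr = ConfigManager(
        model_type="sttran",
        cli_args=["lr=2e-5", "mode=sgdet"],  # assumed OmegaConf dotlist form
    )
    mgr.set("training.batch_size", 4)        # dot notation per the API above
    print(mgr.get("training.batch_size"))
    mgr.save("run_config.yaml")              # writes the config as YAML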

__init__(config_path: str | None = None, model_type: str | None = None, cli_args: List[str] | None = None, overrides: Dict[str, Any] | None = None)[source]

Initialize the configuration manager.

Parameters:
  • config_path (Optional[str]) – Path to configuration file

  • model_type (Optional[str]) – Model type for structured configuration

  • cli_args (Optional[List[str]]) – Command-line arguments to override config

  • overrides (Optional[Dict[str, Any]]) – Additional configuration overrides

get(key: str, default: Any | None = None) Any[source]

Get a configuration value by key.

Parameters:
  • key (str) – Configuration key (supports dot notation)

  • default (Any) – Default value if key not found

Returns: Configuration value
Return type: Any

set(key: str, value: Any)[source]

Set a configuration value by key.

Parameters:
  • key (str) – Configuration key (supports dot notation)

  • value (Any) – Value to set

update(key: str, value: Any, merge: bool = True)[source]

Update a configuration value by key.

Parameters:
  • key (str) – Configuration key (supports dot notation)

  • value (Any) – Value to update

  • merge (bool) – Whether to merge dictionaries/lists

save(path: str)[source]

Save configuration to a YAML file.

Parameters:
  • path (str) – Path to save the configuration

to_dict(resolve: bool = True) Dict[str, Any][source]

Convert configuration to a dictionary.

Parameters:
  • resolve (bool) – Whether to resolve interpolations

Returns: Configuration as dictionary
Return type: Dict[str, Any]

__getattr__(name: str) Any[source]

Allow attribute-style access to configuration values.

Parameters:
  • name (str) – Attribute name

Returns: Configuration value
Return type: Any

__getitem__(key: str) Any[source]

Allow dictionary-style access to configuration values.

Parameters:
  • key (str) – Configuration key

Returns: Configuration value
Return type: Any

__setitem__(key: str, value: Any)[source]

Allow dictionary-style setting of configuration values.

Parameters:
  • key (str) – Configuration key

  • value (Any) – Value to set

__contains__(key: str) bool[source]

Check if configuration contains a key.

Parameters:
  • key (str) – Configuration key

Returns: True if key exists
Return type: bool

__repr__() str[source]

String representation of the configuration.

Returns: String representation
Return type: str

__str__() str[source]

String representation of the configuration.

Returns: YAML representation of configuration
Return type: str

m3sgg.core.config.modern.create_config(model_type: str, config_path: str | None = None, **overrides) ConfigManager[source]

Create a configuration for a specific model type.

Parameters:
  • model_type (str) – Model type identifier

  • config_path (Optional[str]) – Path to configuration file

  • overrides (dict) – Configuration overrides

Returns: Configuration manager instance
Return type: ConfigManager

m3sgg.core.config.modern.load_config_from_file(config_path: str) ConfigManager[source]

Load configuration from a YAML file.

Parameters:
  • config_path (str) – Path to configuration file

Returns: Configuration manager instance
Return type: ConfigManager

m3sgg.core.config.modern.merge_configs(*configs: ConfigManager) ConfigManager[source]

Merge multiple configuration managers.

Parameters:
  • configs (ConfigManager) – Configuration managers to merge

Returns: Merged configuration manager
Return type: ConfigManager
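
A minimal sketch; the merge precedence (later arguments winning on conflicts) is an assumption, not documented above:

    from m3sgg.core.config.modern import ConfigManager, merge_configs

    base = ConfigManager(model_type="sttran")
    override = ConfigManager(overrides={"lr": 1e-6})
    merged = merge_configs(base, override)  # assumed: later configs take precedence
    print(merged.get("lr"))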

Legacy configuration system for backward compatibility.

This module provides the original argparse-based configuration system for backward compatibility with existing code that depends on the old Config class interface.

author: M3SGG Team
version: 0.1.0

class m3sgg.core.config.legacy.LegacyConfig(config_path: str | None = None, overrides: Dict[str, Any] | None = None)[source]

Bases: Config

Legacy configuration class for backward compatibility.

This class wraps the original Config class to provide backward compatibility while allowing for future migration to the modern configuration system.

Parameters:
  • config_path (Optional[str]) – Path to configuration file

  • overrides (Optional[Dict[str, Any]]) – Configuration overrides

__init__(config_path: str | None = None, overrides: Dict[str, Any] | None = None)[source]

Initialize the legacy configuration.

Parameters:
  • config_path (Optional[str]) – Path to configuration file

  • overrides (Optional[Dict[str, Any]]) – Configuration overrides
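
A minimal sketch for code that still expects the original Config interface:

    from m3sgg.core.config.legacy import LegacyConfig

    cfg = LegacyConfig(overrides={"mode": "predcls", "lr": 1e-5})
    cfg.update(nepoch=20)              # keyword-style updates
    print(cfg.get("mode", "predcls"))  # lookup with a default
    print(cfg.to_dict()["lr"])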

to_dict() Dict[str, Any][source]

Convert configuration to dictionary.

Returns: Configuration as dictionary
Return type: Dict[str, Any]

update(**kwargs)[source]

Update configuration values.

Parameters:
  • kwargs (Any) – Configuration updates

get(key: str, default: Any | None = None) Any[source]

Get configuration value with default.

Parameters:
  • key (str) – Configuration key

  • default (Any) – Default value if key not found

Returns: Configuration value
Return type: Any

set(key: str, value: Any)[source]

Set configuration value.

Parameters:
  • key (str) – Configuration key

  • value (Any) – Value to set

Structured Configuration

Base structured configuration classes for M3SGG models.

This module provides the base configuration classes and common structures used across all model types in the M3SGG framework.

author: M3SGG Team
version: 0.1.0

class m3sgg.core.config.structured.base.BaseConfig(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'sttran', use_matcher: bool = False, eval: bool = False)[source]

Bases: object

Base configuration class with common parameters for all models.

This class defines the common configuration parameters that are shared across all model types in the M3SGG framework.

Parameters:
  • mode (str) – Training mode (predcls/sgcls/sgdet)

  • save_path (str) – Path to save model outputs

  • model_path (str) – Path to model weights

  • dataset (str) – Dataset name (action_genome/EASG)

  • data_path (str) – Path to dataset

  • datasize (str) – Dataset size (mini/large)

  • fraction (int) – Fraction of dataset to use (1=all, 2=half, etc.)

  • ckpt (Optional[str]) – Checkpoint path

  • optimizer (str) – Optimizer type (adamw/adam/sgd)

  • lr (float) – Learning rate

  • nepoch (float) – Number of epochs

  • niter (Optional[int]) – Number of iterations for iterative training

  • eval_frequency (int) – Evaluation frequency for iterative training

  • enc_layer (int) – Number of encoder layers

  • dec_layer (int) – Number of decoder layers

  • bce_loss (bool) – Use BCE loss instead of multi-label margin loss

  • device (str) – Torch device string (e.g., cuda:0, cpu)

  • seed (int) – Global random seed

  • num_workers (int) – Number of DataLoader workers

  • model_type (str) – Model type identifier

  • use_matcher (bool) – Use Hungarian matcher (for DSG-DETR)

  • eval (bool) – Evaluation mode

mode: str = 'predcls'
save_path: str = 'output'
model_path: str = 'weights/predcls.tar'
dataset: str = 'action_genome'
data_path: str = 'data/action_genome'
datasize: str = 'large'
fraction: int = 1
ckpt: str | None = None
optimizer: str = 'adamw'
lr: float = 1e-05
nepoch: float = 10
niter: int | None = None
eval_frequency: int = 50
enc_layer: int = 1
dec_layer: int = 3
bce_loss: bool = False
device: str = 'cuda:0'
seed: int = 42
num_workers: int = 0
model_type: str = 'sttran'
use_matcher: bool = False
eval: bool = False
__init__(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'sttran', use_matcher: bool = False, eval: bool = False) None
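
Since BaseConfig is a plain dataclass, it can be instantiated directly; unspecified fields keep the defaults listed above:

    from m3sgg.core.config.structured.base import BaseConfig

    cfg = BaseConfig(mode="sgdet", lr=5e-6, device="cpu")
    print(cfg.dataset)       # 'action_genome' (default)
    print(cfg.mode, cfg.lr)  # 'sgdet' 5e-06
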
class m3sgg.core.config.structured.base.TrainingConfig(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 0, early_stopping_patience: int = 10, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1)[source]

Bases: object

Training-specific configuration parameters.

Parameters:
  • batch_size (int) – Training batch size

  • val_batch_size (int) – Validation batch size

  • max_grad_norm (float) – Maximum gradient norm for clipping

  • warmup_epochs (int) – Number of warmup epochs

  • early_stopping_patience (int) – Early stopping patience

  • save_frequency (int) – Model saving frequency (epochs)

  • log_frequency (int) – Logging frequency (iterations)

  • val_frequency (int) – Validation frequency (epochs)

batch_size: int = 1
val_batch_size: int = 1
max_grad_norm: float = 1.0
warmup_epochs: int = 0
early_stopping_patience: int = 10
save_frequency: int = 1
log_frequency: int = 10
val_frequency: int = 1
__init__(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 0, early_stopping_patience: int = 10, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1) None
class m3sgg.core.config.structured.base.DataConfig(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False)[source]

Bases: object

Data-specific configuration parameters.

Parameters:
  • cache_dir (str) – Directory for caching data

  • pin_memory (bool) – Pin memory for DataLoader

  • shuffle (bool) – Shuffle training data

  • drop_last (bool) – Drop last incomplete batch

  • prefetch_factor (int) – DataLoader prefetch factor

  • persistent_workers (bool) – Use persistent workers

cache_dir: str = 'data/cache'
pin_memory: bool = False
shuffle: bool = True
drop_last: bool = False
prefetch_factor: int = 2
persistent_workers: bool = False
__init__(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False) None
class m3sgg.core.config.structured.base.LoggingConfig(log_level: str = 'INFO', log_file: str | None = None, tensorboard_dir: str | None = None, wandb_project: str | None = None, wandb_entity: str | None = None, log_gradients: bool = False, log_weights: bool = False)[source]

Bases: object

Logging configuration parameters.

Parameters:
  • log_level (str) – Logging level (DEBUG, INFO, WARNING, ERROR)

  • log_file (Optional[str]) – Log file path

  • tensorboard_dir (Optional[str]) – TensorBoard log directory

  • wandb_project (Optional[str]) – Weights & Biases project name

  • wandb_entity (Optional[str]) – Weights & Biases entity name

  • log_gradients (bool) – Log gradient norms

  • log_weights (bool) – Log weight histograms

log_level: str = 'INFO'
log_file: str | None = None
tensorboard_dir: str | None = None
wandb_project: str | None = None
wandb_entity: str | None = None
log_gradients: bool = False
log_weights: bool = False
__init__(log_level: str = 'INFO', log_file: str | None = None, tensorboard_dir: str | None = None, wandb_project: str | None = None, wandb_entity: str | None = None, log_gradients: bool = False, log_weights: bool = False) None
class m3sgg.core.config.structured.base.CheckpointConfig(save_dir: str = 'checkpoints', save_best: bool = True, save_last: bool = True, save_frequency: int = 1, max_checkpoints: int = 5, checkpoint_metric: str = 'recall@20', checkpoint_mode: str = 'max')[source]

Bases: object

Checkpoint configuration parameters.

Parameters:
  • save_dir (str) – Directory to save checkpoints

  • save_best (bool) – Save best model only

  • save_last (bool) – Save last model

  • save_frequency (int) – Save frequency (epochs)

  • max_checkpoints (int) – Maximum number of checkpoints to keep

  • checkpoint_metric (str) – Metric to use for best model selection

  • checkpoint_mode (str) – Mode for metric comparison (min/max)

save_dir: str = 'checkpoints'
save_best: bool = True
save_last: bool = True
save_frequency: int = 1
max_checkpoints: int = 5
checkpoint_metric: str = 'recall@20'
checkpoint_mode: str = 'max'
__init__(save_dir: str = 'checkpoints', save_best: bool = True, save_last: bool = True, save_frequency: int = 1, max_checkpoints: int = 5, checkpoint_metric: str = 'recall@20', checkpoint_mode: str = 'max') None
class m3sgg.core.config.structured.base.EvaluationConfig(iou_threshold: float = 0.5, constraint: str = 'with', save_predictions: bool = True, predictions_dir: str = 'predictions', eval_metrics: List[str] = <factory>, eval_frequency: int = 1)[source]

Bases: object

Evaluation configuration parameters.

Parameters:
  • iou_threshold (float) – IoU threshold for evaluation

  • constraint (str) – Constraint type for evaluation

  • save_predictions (bool) – Save predictions to file

  • predictions_dir (str) – Directory to save predictions

  • eval_metrics (List[str]) – List of metrics to compute

  • eval_frequency (int) – Evaluation frequency (epochs)

iou_threshold: float = 0.5
constraint: str = 'with'
save_predictions: bool = True
predictions_dir: str = 'predictions'
eval_metrics: List[str]
eval_frequency: int = 1
__init__(iou_threshold: float = 0.5, constraint: str = 'with', save_predictions: bool = True, predictions_dir: str = 'predictions', eval_metrics: List[str] = <factory>, eval_frequency: int = 1) None
class m3sgg.core.config.structured.base.ModelConfig(hidden_dim: int = 256, num_heads: int = 8, num_layers: int = 6, dropout: float = 0.1, activation: str = 'relu', norm_type: str = 'layer_norm', use_bias: bool = True)[source]

Bases: object

Model-specific configuration parameters.

Parameters:
  • hidden_dim (int) – Hidden dimension size

  • num_heads (int) – Number of attention heads

  • num_layers (int) – Number of transformer layers

  • dropout (float) – Dropout rate

  • activation (str) – Activation function

  • norm_type (str) – Normalization type

  • use_bias (bool) – Use bias in linear layers

hidden_dim: int = 256
num_heads: int = 8
num_layers: int = 6
dropout: float = 0.1
activation: str = 'relu'
norm_type: str = 'layer_norm'
use_bias: bool = True
__init__(hidden_dim: int = 256, num_heads: int = 8, num_layers: int = 6, dropout: float = 0.1, activation: str = 'relu', norm_type: str = 'layer_norm', use_bias: bool = True) None
class m3sgg.core.config.structured.base.LossConfig(loss_weights: Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: Dict[str, float] | None = None)[source]

Bases: object

Loss function configuration parameters.

Parameters:
  • loss_weights (Dict[str, float]) – Weights for different loss components

  • label_smoothing (float) – Label smoothing factor

  • focal_alpha (float) – Focal loss alpha parameter

  • focal_gamma (float) – Focal loss gamma parameter

  • class_weights (Optional[Dict[str, float]]) – Class weights for imbalanced datasets

loss_weights: Dict[str, float]
label_smoothing: float = 0.0
focal_alpha: float = 0.25
focal_gamma: float = 2.0
class_weights: Dict[str, float] | None = None
__init__(loss_weights: Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: Dict[str, float] | None = None) None
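
A minimal sketch; the loss_weights keys are illustrative, since the expected component names are model-specific:

    from m3sgg.core.config.structured.base import LossConfig

    # loss_weights has a dataclass factory default; pass a dict to override it.
    loss_cfg = LossConfig(
        loss_weights={"object": 1.0, "relation": 2.0},
        label_smoothing=0.1,
    )
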
m3sgg.core.config.structured.base.get_config_class(model_type: str)[source]

Get the appropriate configuration class for a model type.

Parameters:
  • model_type (str) – The model type identifier

Returns: The configuration class for the model type
Return type: type
Raises: ValueError – If the model type is not supported
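
A minimal sketch; the exact class returned per model type follows the structured config modules below:

    from m3sgg.core.config.structured.base import get_config_class

    config_cls = get_config_class("sttran")  # e.g. the STTRANConfig dataclass
    cfg = config_cls(lr=1e-5)

    try:
        get_config_class("not_a_model")
    except ValueError as exc:
        print(exc)  # unsupported model types raise ValueError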

STTRAN model configuration classes.

This module provides structured configuration classes specifically for the STTRAN (Spatial-Temporal Transformer) model.

author: M3SGG Team
version: 0.1.0

class m3sgg.core.config.structured.sttran.STTRANConfig(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 2e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'sttran', use_matcher: bool = False, eval: bool = False, hidden_dim: int = 256, num_heads: int = 8, num_layers: int = 6, dropout: float = 0.1, use_spatial_encoding: bool = True, use_temporal_encoding: bool = True, max_seq_len: int = 1000, spatial_dim: int = 64, temporal_dim: int = 64, obj_feat_dim: int = 2048, rel_feat_dim: int = 256, num_obj_classes: int = 35, num_rel_classes: int = 132, use_bbox_encoding: bool = True, bbox_encoding_dim: int = 128, use_attention_weights: bool = False, attention_dropout: float = 0.1, ffn_dim: int = 1024, activation: str = 'relu', norm_type: str = 'layer_norm', use_bias: bool = True, gradient_checkpointing: bool = False, use_memory_efficient_attention: bool = False)[source]

Bases: BaseConfig

Configuration for STTRAN model.

STTRAN (Spatial-Temporal Transformer) is a transformer-based model for video scene graph generation that processes spatial and temporal information through attention mechanisms.

Parameters:
  • model_type (str) – Model type identifier

  • lr (float) – Learning rate (overridden for STTRAN)

  • hidden_dim (int) – Hidden dimension size

  • num_heads (int) – Number of attention heads

  • num_layers (int) – Number of transformer layers

  • dropout (float) – Dropout rate

  • use_spatial_encoding (bool) – Use spatial position encoding

  • use_temporal_encoding (bool) – Use temporal position encoding

  • max_seq_len (int) – Maximum sequence length

  • spatial_dim (int) – Spatial dimension size

  • temporal_dim (int) – Temporal dimension size

  • obj_feat_dim (int) – Object feature dimension

  • rel_feat_dim (int) – Relation feature dimension

  • num_obj_classes (int) – Number of object classes

  • num_rel_classes (int) – Number of relation classes

  • use_bbox_encoding (bool) – Use bounding box encoding

  • bbox_encoding_dim (int) – Bounding box encoding dimension

  • use_attention_weights (bool) – Use attention weights for visualization

  • attention_dropout (float) – Attention dropout rate

  • ffn_dim (int) – Feed-forward network dimension

  • activation (str) – Activation function

  • norm_type (str) – Normalization type

  • use_bias (bool) – Use bias in linear layers

  • gradient_checkpointing (bool) – Use gradient checkpointing

  • use_memory_efficient_attention (bool) – Use memory efficient attention

model_type: str = 'sttran'
lr: float = 2e-05
hidden_dim: int = 256
num_heads: int = 8
num_layers: int = 6
dropout: float = 0.1
use_spatial_encoding: bool = True
use_temporal_encoding: bool = True
max_seq_len: int = 1000
spatial_dim: int = 64
temporal_dim: int = 64
obj_feat_dim: int = 2048
rel_feat_dim: int = 256
num_obj_classes: int = 35
num_rel_classes: int = 132
use_bbox_encoding: bool = True
bbox_encoding_dim: int = 128
use_attention_weights: bool = False
attention_dropout: float = 0.1
ffn_dim: int = 1024
activation: str = 'relu'
norm_type: str = 'layer_norm'
use_bias: bool = True
gradient_checkpointing: bool = False
use_memory_efficient_attention: bool = False
__init__(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 2e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'sttran', use_matcher: bool = False, eval: bool = False, hidden_dim: int = 256, num_heads: int = 8, num_layers: int = 6, dropout: float = 0.1, use_spatial_encoding: bool = True, use_temporal_encoding: bool = True, max_seq_len: int = 1000, spatial_dim: int = 64, temporal_dim: int = 64, obj_feat_dim: int = 2048, rel_feat_dim: int = 256, num_obj_classes: int = 35, num_rel_classes: int = 132, use_bbox_encoding: bool = True, bbox_encoding_dim: int = 128, use_attention_weights: bool = False, attention_dropout: float = 0.1, ffn_dim: int = 1024, activation: str = 'relu', norm_type: str = 'layer_norm', use_bias: bool = True, gradient_checkpointing: bool = False, use_memory_efficient_attention: bool = False) None
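
A minimal sketch; only the overridden fields differ from the defaults listed above:

    from m3sgg.core.config.structured.sttran import STTRANConfig

    cfg = STTRANConfig(mode="sgcls", num_heads=4, gradient_checkpointing=True)
    print(cfg.lr)          # 2e-05, the STTRAN-specific default
    print(cfg.model_type)  # 'sttran'
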
class m3sgg.core.config.structured.sttran.STTRANTrainingConfig(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 2, early_stopping_patience: int = 10, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, scheduler_type: str = 'reduce_on_plateau', scheduler_patience: int = 3, scheduler_factor: float = 0.5, weight_decay: float = 0.0001, clip_grad_norm: float = 1.0, use_amp: bool = False, accumulation_steps: int = 1)[source]

Bases: TrainingConfig

STTRAN-specific training configuration.

Parameters:
  • warmup_epochs (int) – Number of warmup epochs (STTRAN specific)

  • scheduler_type (str) – Learning rate scheduler type

  • scheduler_patience (int) – Scheduler patience

  • scheduler_factor (float) – Scheduler reduction factor

  • weight_decay (float) – Weight decay for regularization

  • clip_grad_norm (float) – Gradient clipping norm

  • use_amp (bool) – Use automatic mixed precision

  • accumulation_steps (int) – Gradient accumulation steps

warmup_epochs: int = 2
scheduler_type: str = 'reduce_on_plateau'
scheduler_patience: int = 3
scheduler_factor: float = 0.5
weight_decay: float = 0.0001
clip_grad_norm: float = 1.0
use_amp: bool = False
accumulation_steps: int = 1
__init__(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 2, early_stopping_patience: int = 10, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, scheduler_type: str = 'reduce_on_plateau', scheduler_patience: int = 3, scheduler_factor: float = 0.5, weight_decay: float = 0.0001, clip_grad_norm: float = 1.0, use_amp: bool = False, accumulation_steps: int = 1) None
class m3sgg.core.config.structured.sttran.STTRANLossConfig(loss_weights: Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: Dict[str, float] | None = None, obj_loss_weight: float = 1.0, rel_loss_weight: float = 2.0, bbox_loss_weight: float = 2.5, giou_loss_weight: float = 1.0, use_focal_loss: bool = False, use_class_weights: bool = False, obj_class_weights: dict | None = None, rel_class_weights: dict | None = None)[source]

Bases: LossConfig

STTRAN-specific loss configuration.

Parameters:
  • obj_loss_weight (float) – Weight for object classification loss

  • rel_loss_weight (float) – Weight for relation classification loss

  • bbox_loss_weight (float) – Weight for bounding box regression loss

  • giou_loss_weight (float) – Weight for GIoU loss

  • use_focal_loss (bool) – Use focal loss for classification

  • focal_alpha (float) – Focal loss alpha parameter

  • focal_gamma (float) – Focal loss gamma parameter

  • label_smoothing (float) – Label smoothing factor

  • use_class_weights (bool) – Use class weights for imbalanced data

  • obj_class_weights (Optional[dict]) – Object class weights

  • rel_class_weights (Optional[dict]) – Relation class weights

obj_loss_weight: float = 1.0
rel_loss_weight: float = 2.0
bbox_loss_weight: float = 2.5
giou_loss_weight: float = 1.0
use_focal_loss: bool = False
focal_alpha: float = 0.25
focal_gamma: float = 2.0
label_smoothing: float = 0.0
use_class_weights: bool = False
obj_class_weights: dict | None = None
rel_class_weights: dict | None = None
__init__(loss_weights: Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: Dict[str, float] | None = None, obj_loss_weight: float = 1.0, rel_loss_weight: float = 2.0, bbox_loss_weight: float = 2.5, giou_loss_weight: float = 1.0, use_focal_loss: bool = False, use_class_weights: bool = False, obj_class_weights: dict | None = None, rel_class_weights: dict | None = None) None
class m3sgg.core.config.structured.sttran.STTRANDataConfig(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, use_temporal_sampling: bool = True, temporal_stride: int = 1, use_augmentation: bool = False, augmentation_prob: float = 0.5, use_negative_sampling: bool = True, negative_ratio: float = 0.5)[source]

Bases: DataConfig

STTRAN-specific data configuration.

Parameters:
  • max_objects (int) – Maximum number of objects per frame

  • max_relations (int) – Maximum number of relations per frame

  • max_frames (int) – Maximum number of frames per video

  • use_temporal_sampling (bool) – Use temporal sampling

  • temporal_stride (int) – Temporal stride for sampling

  • use_augmentation (bool) – Use data augmentation

  • augmentation_prob (float) – Probability of applying augmentation

  • use_negative_sampling (bool) – Use negative sampling for relations

  • negative_ratio (float) – Ratio of negative samples

max_objects: int = 50
max_relations: int = 100
max_frames: int = 30
use_temporal_sampling: bool = True
temporal_stride: int = 1
use_augmentation: bool = False
augmentation_prob: float = 0.5
use_negative_sampling: bool = True
negative_ratio: float = 0.5
__init__(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, use_temporal_sampling: bool = True, temporal_stride: int = 1, use_augmentation: bool = False, augmentation_prob: float = 0.5, use_negative_sampling: bool = True, negative_ratio: float = 0.5) None
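
The model, training, loss, and data dataclasses are independent; a sketch of constructing them side by side (how a trainer consumes them is project-specific, so only construction is shown):

    from m3sgg.core.config.structured.sttran import (
        STTRANConfig,
        STTRANDataConfig,
        STTRANLossConfig,
        STTRANTrainingConfig,
    )

    model_cfg = STTRANConfig()
    train_cfg = STTRANTrainingConfig(use_amp=True, accumulation_steps=2)
    loss_cfg = STTRANLossConfig(rel_loss_weight=1.5)
    data_cfg = STTRANDataConfig(max_frames=16)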

Tempura model configuration classes.

This module provides structured configuration classes specifically for the Tempura model, which includes memory mechanisms and GMM heads.

author: M3SGG Team
version: 0.1.0

class m3sgg.core.config.structured.tempura.TempuraConfig(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'tempura', use_matcher: bool = False, eval: bool = False, obj_head: str = 'gmm', rel_head: str = 'gmm', K: int = 4, gmm_components: int = 4, gmm_covariance_type: str = 'full', gmm_reg_covar: float = 1e-06, rel_mem_compute: str | None = None, obj_mem_compute: bool = False, take_obj_mem_feat: bool = False, obj_mem_weight_type: str = 'simple', rel_mem_weight_type: str = 'simple', mem_feat_selection: str = 'manual', mem_fusion: str = 'early', mem_feat_lambda: float | None = None, mem_size: int = 1000, mem_dim: int = 256, mem_update_rate: float = 0.1, mem_temperature: float = 1.0, use_memory_attention: bool = True, memory_dropout: float = 0.1, pseudo_thresh: int = 7, obj_unc: bool = False, rel_unc: bool = False, obj_loss_weighting: str | None = None, rel_loss_weighting: str | None = None, mlm: bool = False, eos_coef: float = 1, obj_con_loss: str | None = None, lambda_con: float = 1, tracking: bool = True, use_prior: bool = False, prior_weight: float = 0.1)[source]

Bases: BaseConfig

Configuration for Tempura model.

Tempura is a memory-enhanced model for video scene graph generation that uses Gaussian Mixture Model (GMM) heads and memory mechanisms for improved performance.

Parameters:
  • model_type (str) – Model type identifier

  • obj_head (str) – Object classification head type

  • rel_head (str) – Relation classification head type

  • K (int) – Number of mixture models

  • rel_mem_compute (Optional[str]) – Relation memory computation type

  • obj_mem_compute (bool) – Object memory computation

  • take_obj_mem_feat (bool) – Take object memory features

  • obj_mem_weight_type (str) – Object memory weight type

  • rel_mem_weight_type (str) – Relation memory weight type

  • mem_feat_selection (str) – Memory feature selection method

  • mem_fusion (str) – Memory fusion method

  • mem_feat_lambda (Optional[float]) – Memory feature lambda

  • pseudo_thresh (int) – Pseudo label threshold

  • obj_unc (bool) – Object uncertainty

  • rel_unc (bool) – Relation uncertainty

  • obj_loss_weighting (Optional[str]) – Object loss weighting

  • rel_loss_weighting (Optional[str]) – Relation loss weighting

  • mlm (bool) – Masked language modeling

  • eos_coef (float) – End-of-sequence coefficient

  • obj_con_loss (Optional[str]) – Object consistency loss

  • lambda_con (float) – Consistency loss coefficient

  • tracking (bool) – Enable tracking

  • mem_size (int) – Memory size

  • mem_dim (int) – Memory dimension

  • mem_update_rate (float) – Memory update rate

  • mem_temperature (float) – Memory temperature

  • use_memory_attention (bool) – Use memory attention

  • memory_dropout (float) – Memory dropout rate

  • gmm_components (int) – Number of GMM components

  • gmm_covariance_type (str) – GMM covariance type

  • gmm_reg_covar (float) – GMM regularization covariance

  • use_prior (bool) – Use prior knowledge

  • prior_weight (float) – Prior weight

model_type: str = 'tempura'
obj_head: str = 'gmm'
rel_head: str = 'gmm'
K: int = 4
gmm_components: int = 4
gmm_covariance_type: str = 'full'
gmm_reg_covar: float = 1e-06
rel_mem_compute: str | None = None
obj_mem_compute: bool = False
take_obj_mem_feat: bool = False
obj_mem_weight_type: str = 'simple'
rel_mem_weight_type: str = 'simple'
mem_feat_selection: str = 'manual'
mem_fusion: str = 'early'
mem_feat_lambda: float | None = None
mem_size: int = 1000
mem_dim: int = 256
mem_update_rate: float = 0.1
mem_temperature: float = 1.0
use_memory_attention: bool = True
memory_dropout: float = 0.1
pseudo_thresh: int = 7
obj_unc: bool = False
rel_unc: bool = False
obj_loss_weighting: str | None = None
rel_loss_weighting: str | None = None
mlm: bool = False
eos_coef: float = 1
obj_con_loss: str | None = None
lambda_con: float = 1
tracking: bool = True
use_prior: bool = False
prior_weight: float = 0.1
__init__(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'tempura', use_matcher: bool = False, eval: bool = False, obj_head: str = 'gmm', rel_head: str = 'gmm', K: int = 4, gmm_components: int = 4, gmm_covariance_type: str = 'full', gmm_reg_covar: float = 1e-06, rel_mem_compute: str | None = None, obj_mem_compute: bool = False, take_obj_mem_feat: bool = False, obj_mem_weight_type: str = 'simple', rel_mem_weight_type: str = 'simple', mem_feat_selection: str = 'manual', mem_fusion: str = 'early', mem_feat_lambda: float | None = None, mem_size: int = 1000, mem_dim: int = 256, mem_update_rate: float = 0.1, mem_temperature: float = 1.0, use_memory_attention: bool = True, memory_dropout: float = 0.1, pseudo_thresh: int = 7, obj_unc: bool = False, rel_unc: bool = False, obj_loss_weighting: str | None = None, rel_loss_weighting: str | None = None, mlm: bool = False, eos_coef: float = 1, obj_con_loss: str | None = None, lambda_con: float = 1, tracking: bool = True, use_prior: bool = False, prior_weight: float = 0.1) None
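
A minimal sketch; the override values are illustrative:

    from m3sgg.core.config.structured.tempura import TempuraConfig

    cfg = TempuraConfig(K=6, mem_size=2000, obj_unc=True)
    print(cfg.obj_head, cfg.rel_head)  # 'gmm' 'gmm' by default
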
class m3sgg.core.config.structured.tempura.TempuraTrainingConfig(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 0, early_stopping_patience: int = 10, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, memory_warmup_epochs: int = 5, memory_lr: float = 0.0001, gmm_lr: float = 0.001, use_memory_scheduler: bool = True, memory_decay: float = 0.99, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_memory_regularization: bool = True, memory_reg_weight: float = 0.01)[source]

Bases: TrainingConfig

Tempura-specific training configuration.

Parameters:
  • memory_warmup_epochs (int) – Number of epochs for memory warmup

  • memory_lr (float) – Learning rate for memory parameters

  • gmm_lr (float) – Learning rate for GMM parameters

  • use_memory_scheduler (bool) – Use separate scheduler for memory

  • memory_decay (float) – Memory decay rate

  • use_curriculum_learning (bool) – Use curriculum learning

  • curriculum_epochs (int) – Number of epochs for curriculum

  • use_memory_regularization (bool) – Use memory regularization

  • memory_reg_weight (float) – Memory regularization weight

memory_warmup_epochs: int = 5
memory_lr: float = 0.0001
gmm_lr: float = 0.001
use_memory_scheduler: bool = True
memory_decay: float = 0.99
use_curriculum_learning: bool = False
curriculum_epochs: int = 10
use_memory_regularization: bool = True
memory_reg_weight: float = 0.01
__init__(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 0, early_stopping_patience: int = 10, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, memory_warmup_epochs: int = 5, memory_lr: float = 0.0001, gmm_lr: float = 0.001, use_memory_scheduler: bool = True, memory_decay: float = 0.99, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_memory_regularization: bool = True, memory_reg_weight: float = 0.01) None
class m3sgg.core.config.structured.tempura.TempuraLossConfig(loss_weights: Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: Dict[str, float] | None = None, obj_loss_weight: float = 1.0, rel_loss_weight: float = 1.0, memory_loss_weight: float = 0.1, gmm_loss_weight: float = 0.1, consistency_loss_weight: float = 0.1, uncertainty_loss_weight: float = 0.1, use_memory_loss: bool = True, use_gmm_loss: bool = True, use_consistency_loss: bool = True, use_uncertainty_loss: bool = True, memory_loss_type: str = 'mse', gmm_loss_type: str = 'nll', consistency_loss_type: str = 'mse', uncertainty_loss_type: str = 'kl')[source]

Bases: LossConfig

Tempura-specific loss configuration.

Parameters:
  • obj_loss_weight (float) – Weight for object classification loss

  • rel_loss_weight (float) – Weight for relation classification loss

  • memory_loss_weight (float) – Weight for memory loss

  • gmm_loss_weight (float) – Weight for GMM loss

  • consistency_loss_weight (float) – Weight for consistency loss

  • uncertainty_loss_weight (float) – Weight for uncertainty loss

  • use_memory_loss (bool) – Use memory loss

  • use_gmm_loss (bool) – Use GMM loss

  • use_consistency_loss (bool) – Use consistency loss

  • use_uncertainty_loss (bool) – Use uncertainty loss

  • memory_loss_type (str) – Type of memory loss

  • gmm_loss_type (str) – Type of GMM loss

  • consistency_loss_type (str) – Type of consistency loss

  • uncertainty_loss_type (str) – Type of uncertainty loss

obj_loss_weight: float = 1.0
rel_loss_weight: float = 1.0
memory_loss_weight: float = 0.1
gmm_loss_weight: float = 0.1
consistency_loss_weight: float = 0.1
uncertainty_loss_weight: float = 0.1
use_memory_loss: bool = True
use_gmm_loss: bool = True
use_consistency_loss: bool = True
use_uncertainty_loss: bool = True
memory_loss_type: str = 'mse'
gmm_loss_type: str = 'nll'
consistency_loss_type: str = 'mse'
uncertainty_loss_type: str = 'kl'
__init__(loss_weights: Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: Dict[str, float] | None = None, obj_loss_weight: float = 1.0, rel_loss_weight: float = 1.0, memory_loss_weight: float = 0.1, gmm_loss_weight: float = 0.1, consistency_loss_weight: float = 0.1, uncertainty_loss_weight: float = 0.1, use_memory_loss: bool = True, use_gmm_loss: bool = True, use_consistency_loss: bool = True, use_uncertainty_loss: bool = True, memory_loss_type: str = 'mse', gmm_loss_type: str = 'nll', consistency_loss_type: str = 'mse', uncertainty_loss_type: str = 'kl') None
class m3sgg.core.config.structured.tempura.TempuraDataConfig(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, use_temporal_sampling: bool = True, temporal_stride: int = 1, use_memory_sampling: bool = True, memory_sampling_ratio: float = 0.1, use_negative_sampling: bool = True, negative_ratio: float = 0.5, use_pseudo_labels: bool = False, pseudo_label_threshold: float = 0.7, use_uncertainty_sampling: bool = False, uncertainty_threshold: float = 0.5)[source]

Bases: DataConfig

Tempura-specific data configuration.

Parameters:
  • max_objects (int) – Maximum number of objects per frame

  • max_relations (int) – Maximum number of relations per frame

  • max_frames (int) – Maximum number of frames per video

  • use_temporal_sampling (bool) – Use temporal sampling

  • temporal_stride (int) – Temporal stride for sampling

  • use_memory_sampling (bool) – Use memory sampling

  • memory_sampling_ratio (float) – Memory sampling ratio

  • use_negative_sampling (bool) – Use negative sampling

  • negative_ratio (float) – Ratio of negative samples

  • use_pseudo_labels (bool) – Use pseudo labels

  • pseudo_label_threshold (float) – Pseudo label threshold

  • use_uncertainty_sampling (bool) – Use uncertainty sampling

  • uncertainty_threshold (float) – Uncertainty threshold

max_objects: int = 50
max_relations: int = 100
max_frames: int = 30
use_temporal_sampling: bool = True
temporal_stride: int = 1
use_memory_sampling: bool = True
memory_sampling_ratio: float = 0.1
use_negative_sampling: bool = True
negative_ratio: float = 0.5
use_pseudo_labels: bool = False
pseudo_label_threshold: float = 0.7
use_uncertainty_sampling: bool = False
uncertainty_threshold: float = 0.5
__init__(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, use_temporal_sampling: bool = True, temporal_stride: int = 1, use_memory_sampling: bool = True, memory_sampling_ratio: float = 0.1, use_negative_sampling: bool = True, negative_ratio: float = 0.5, use_pseudo_labels: bool = False, pseudo_label_threshold: float = 0.7, use_uncertainty_sampling: bool = False, uncertainty_threshold: float = 0.5) None

SceneLLM model configuration classes.

This module provides structured configuration classes specifically for the SceneLLM model, which combines VQ-VAE with language models.

author: M3SGG Team
version: 0.1.0

class m3sgg.core.config.structured.scenellm.SceneLLMConfig(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'scenellm', use_matcher: bool = False, eval: bool = False, embed_dim: int = 1024, codebook_size: int = 8192, commitment_cost: float = 0.25, vqvae_hidden_dim: int = 512, vqvae_num_layers: int = 4, vqvae_num_resblocks: int = 2, vqvae_dropout: float = 0.1, vqvae_use_attention: bool = True, vqvae_attention_heads: int = 8, llm_name: str = 'google/gemma-2-2b', llm_max_length: int = 512, llm_temperature: float = 0.7, llm_top_p: float = 0.9, llm_top_k: int = 50, llm_repetition_penalty: float = 1.1, llm_do_sample: bool = True, llm_pad_token_id: int = 0, llm_eos_token_id: int = 1, llm_bos_token_id: int = 2, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, ot_step: int = 512, vqvae_epochs: int = 5, stage1_iterations: int = 30000, stage2_iterations: int = 50000, alpha_obj: float = 1.0, alpha_rel: float = 1.0, scenellm_training_stage: str = 'vqvae', disable_checkpoint_saving: bool = False, use_peft: bool = True, peft_config: dict | None = None, use_gradient_checkpointing: bool = True, use_flash_attention: bool = False, use_8bit_optimizer: bool = False, use_4bit_quantization: bool = False)[source]

Bases: BaseConfig

Configuration for SceneLLM model.

SceneLLM combines VQ-VAE (Vector Quantized Variational AutoEncoder) with language models for scene graph generation and text summarization.

Parameters:
  • model_type (str) – Model type identifier

  • embed_dim (int) – Embedding dimension for VQ-VAE

  • codebook_size (int) – Size of VQ-VAE codebook

  • commitment_cost (float) – Commitment cost for VQ-VAE

  • llm_name (str) – LLM model name

  • lora_r (int) – LoRA rank

  • lora_alpha (int) – LoRA alpha

  • lora_dropout (float) – LoRA dropout

  • ot_step (int) – Step size for optimal transport codebook update

  • vqvae_epochs (int) – Epochs for VQ-VAE pretraining

  • stage1_iterations (int) – Iterations for stage 1 training

  • stage2_iterations (int) – Iterations for stage 2 training

  • alpha_obj (float) – Weight for object loss in SceneLLM

  • alpha_rel (float) – Weight for relation loss in SceneLLM

  • scenellm_training_stage (str) – SceneLLM training stage

  • disable_checkpoint_saving (bool) – Disable checkpoint saving

  • vqvae_hidden_dim (int) – VQ-VAE hidden dimension

  • vqvae_num_layers (int) – VQ-VAE number of layers

  • vqvae_num_resblocks (int) – VQ-VAE number of residual blocks

  • vqvae_dropout (float) – VQ-VAE dropout rate

  • vqvae_use_attention (bool) – VQ-VAE use attention

  • vqvae_attention_heads (int) – VQ-VAE attention heads

  • llm_max_length (int) – LLM maximum sequence length

  • llm_temperature (float) – LLM temperature for generation

  • llm_top_p (float) – LLM top-p sampling

  • llm_top_k (int) – LLM top-k sampling

  • llm_repetition_penalty (float) – LLM repetition penalty

  • llm_do_sample (bool) – LLM do sampling

  • llm_pad_token_id (int) – LLM pad token ID

  • llm_eos_token_id (int) – LLM EOS token ID

  • llm_bos_token_id (int) – LLM BOS token ID

  • use_peft (bool) – Use PEFT (Parameter Efficient Fine-Tuning)

  • peft_config (Optional[dict]) – PEFT configuration

  • use_gradient_checkpointing (bool) – Use gradient checkpointing

  • use_flash_attention (bool) – Use flash attention

  • use_8bit_optimizer (bool) – Use 8-bit optimizer

  • use_4bit_quantization (bool) – Use 4-bit quantization

model_type: str = 'scenellm'
embed_dim: int = 1024
codebook_size: int = 8192
commitment_cost: float = 0.25
vqvae_hidden_dim: int = 512
vqvae_num_layers: int = 4
vqvae_num_resblocks: int = 2
vqvae_dropout: float = 0.1
vqvae_use_attention: bool = True
vqvae_attention_heads: int = 8
llm_name: str = 'google/gemma-2-2b'
llm_max_length: int = 512
llm_temperature: float = 0.7
llm_top_p: float = 0.9
llm_top_k: int = 50
llm_repetition_penalty: float = 1.1
llm_do_sample: bool = True
llm_pad_token_id: int = 0
llm_eos_token_id: int = 1
llm_bos_token_id: int = 2
lora_r: int = 16
lora_alpha: int = 32
lora_dropout: float = 0.05
ot_step: int = 512
vqvae_epochs: int = 5
stage1_iterations: int = 30000
stage2_iterations: int = 50000
alpha_obj: float = 1.0
alpha_rel: float = 1.0
scenellm_training_stage: str = 'vqvae'
disable_checkpoint_saving: bool = False
use_peft: bool = True
peft_config: dict | None = None
use_gradient_checkpointing: bool = True
use_flash_attention: bool = False
use_8bit_optimizer: bool = False
use_4bit_quantization: bool = False
__init__(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'scenellm', use_matcher: bool = False, eval: bool = False, embed_dim: int = 1024, codebook_size: int = 8192, commitment_cost: float = 0.25, vqvae_hidden_dim: int = 512, vqvae_num_layers: int = 4, vqvae_num_resblocks: int = 2, vqvae_dropout: float = 0.1, vqvae_use_attention: bool = True, vqvae_attention_heads: int = 8, llm_name: str = 'google/gemma-2-2b', llm_max_length: int = 512, llm_temperature: float = 0.7, llm_top_p: float = 0.9, llm_top_k: int = 50, llm_repetition_penalty: float = 1.1, llm_do_sample: bool = True, llm_pad_token_id: int = 0, llm_eos_token_id: int = 1, llm_bos_token_id: int = 2, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, ot_step: int = 512, vqvae_epochs: int = 5, stage1_iterations: int = 30000, stage2_iterations: int = 50000, alpha_obj: float = 1.0, alpha_rel: float = 1.0, scenellm_training_stage: str = 'vqvae', disable_checkpoint_saving: bool = False, use_peft: bool = True, peft_config: dict | None = None, use_gradient_checkpointing: bool = True, use_flash_attention: bool = False, use_8bit_optimizer: bool = False, use_4bit_quantization: bool = False) None
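
A minimal sketch; valid values of scenellm_training_stage beyond the 'vqvae' default are not listed above, so only the default is shown:

    from m3sgg.core.config.structured.scenellm import SceneLLMConfig

    cfg = SceneLLMConfig(lora_r=8, llm_temperature=0.5)
    print(cfg.llm_name)                 # 'google/gemma-2-2b'
    print(cfg.scenellm_training_stage)  # 'vqvae'
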
class m3sgg.core.config.structured.scenellm.SceneLLMTrainingConfig(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 0, early_stopping_patience: int = 5, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, vqvae_lr: float = 0.0001, llm_lr: float = 2e-05, vqvae_weight_decay: float = 0.0001, llm_weight_decay: float = 0.01, use_warmup: bool = True, warmup_steps: int = 1000, use_cosine_schedule: bool = True, cosine_min_lr: float = 1e-07, use_linear_schedule: bool = False, linear_min_lr: float = 1e-07, use_adafactor: bool = False, adafactor_scale_parameter: bool = True, adafactor_relative_step_size: bool = True, adafactor_warmup_init: bool = False, use_dataloader_pin_memory: bool = True, dataloader_num_workers: int = 4, dataloader_prefetch_factor: int = 2, use_mixed_precision: bool = True, mixed_precision_backend: str = 'apex', mixed_precision_loss_scale: str = 'dynamic', use_gradient_accumulation: bool = True, gradient_accumulation_steps: int = 4, use_gradient_clipping: bool = True, use_ema: bool = False, ema_decay: float = 0.999, use_swa: bool = False, swa_lr: float = 1e-05, swa_epochs: int = 5, use_early_stopping: bool = True, early_stopping_min_delta: float = 0.0001, early_stopping_monitor: str = 'val_loss', early_stopping_mode: str = 'min')[source]

Bases: TrainingConfig

SceneLLM-specific training configuration.

Parameters:
  • vqvae_lr (float) – Learning rate for VQ-VAE

  • llm_lr (float) – Learning rate for LLM

  • vqvae_weight_decay (float) – Weight decay for VQ-VAE

  • llm_weight_decay (float) – Weight decay for LLM

  • use_warmup (bool) – Use learning rate warmup

  • warmup_steps (int) – Number of warmup steps

  • use_cosine_schedule (bool) – Use cosine learning rate schedule

  • cosine_min_lr (float) – Minimum learning rate for cosine schedule

  • use_linear_schedule (bool) – Use linear learning rate schedule

  • linear_min_lr (float) – Minimum learning rate for linear schedule

  • use_adafactor (bool) – Use Adafactor optimizer

  • adafactor_scale_parameter (bool) – Adafactor scale parameter

  • adafactor_relative_step_size (bool) – Adafactor relative step size

  • adafactor_warmup_init (bool) – Adafactor warmup init

  • use_dataloader_pin_memory (bool) – Use DataLoader pin memory

  • dataloader_num_workers (int) – DataLoader number of workers

  • dataloader_prefetch_factor (int) – DataLoader prefetch factor

  • use_mixed_precision (bool) – Use mixed precision training

  • mixed_precision_backend (str) – Mixed precision backend

  • mixed_precision_loss_scale (str) – Mixed precision loss scale

  • use_gradient_accumulation (bool) – Use gradient accumulation

  • gradient_accumulation_steps (int) – Gradient accumulation steps

  • use_gradient_clipping (bool) – Use gradient clipping

  • max_grad_norm (float) – Maximum gradient norm

  • use_ema (bool) – Use exponential moving average

  • ema_decay (float) – EMA decay rate

  • use_swa (bool) – Use stochastic weight averaging

  • swa_lr (float) – SWA learning rate

  • swa_epochs (int) – SWA epochs

  • use_early_stopping (bool) – Use early stopping

  • early_stopping_patience (int) – Early stopping patience

  • early_stopping_min_delta (float) – Early stopping minimum delta

  • early_stopping_monitor (str) – Early stopping monitor metric

  • early_stopping_mode (str) – Early stopping mode

vqvae_lr: float = 0.0001
llm_lr: float = 2e-05
vqvae_weight_decay: float = 0.0001
llm_weight_decay: float = 0.01
use_warmup: bool = True
warmup_steps: int = 1000
use_cosine_schedule: bool = True
cosine_min_lr: float = 1e-07
use_linear_schedule: bool = False
linear_min_lr: float = 1e-07
use_adafactor: bool = False
adafactor_scale_parameter: bool = True
adafactor_relative_step_size: bool = True
adafactor_warmup_init: bool = False
use_dataloader_pin_memory: bool = True
dataloader_num_workers: int = 4
dataloader_prefetch_factor: int = 2
use_mixed_precision: bool = True
mixed_precision_backend: str = 'apex'
mixed_precision_loss_scale: str = 'dynamic'
use_gradient_accumulation: bool = True
gradient_accumulation_steps: int = 4
use_gradient_clipping: bool = True
max_grad_norm: float = 1.0
use_ema: bool = False
ema_decay: float = 0.999
use_swa: bool = False
swa_lr: float = 1e-05
swa_epochs: int = 5
use_early_stopping: bool = True
early_stopping_patience: int = 5
early_stopping_min_delta: float = 0.0001
early_stopping_monitor: str = 'val_loss'
early_stopping_mode: str = 'min'
__init__(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 0, early_stopping_patience: int = 5, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, vqvae_lr: float = 0.0001, llm_lr: float = 2e-05, vqvae_weight_decay: float = 0.0001, llm_weight_decay: float = 0.01, use_warmup: bool = True, warmup_steps: int = 1000, use_cosine_schedule: bool = True, cosine_min_lr: float = 1e-07, use_linear_schedule: bool = False, linear_min_lr: float = 1e-07, use_adafactor: bool = False, adafactor_scale_parameter: bool = True, adafactor_relative_step_size: bool = True, adafactor_warmup_init: bool = False, use_dataloader_pin_memory: bool = True, dataloader_num_workers: int = 4, dataloader_prefetch_factor: int = 2, use_mixed_precision: bool = True, mixed_precision_backend: str = 'apex', mixed_precision_loss_scale: str = 'dynamic', use_gradient_accumulation: bool = True, gradient_accumulation_steps: int = 4, use_gradient_clipping: bool = True, use_ema: bool = False, ema_decay: float = 0.999, use_swa: bool = False, swa_lr: float = 1e-05, swa_epochs: int = 5, use_early_stopping: bool = True, early_stopping_min_delta: float = 0.0001, early_stopping_monitor: str = 'val_loss', early_stopping_mode: str = 'min') None
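
The structured configs appear to be plain dataclasses (note the <factory> default in the loss config below), so they can be instantiated with keyword overrides and copied immutably via dataclasses.replace. A minimal sketch, with field names taken from the signature above:

    from dataclasses import replace

    from m3sgg.core.config.structured.scenellm import SceneLLMTrainingConfig

    # Separate learning rates for the VQ-VAE and LLM branches
    cfg = SceneLLMTrainingConfig(vqvae_lr=5e-4, llm_lr=1e-5)

    # Derive a variant without mutating the original instance
    cfg_no_amp = replace(cfg, use_mixed_precision=False)
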
class m3sgg.core.config.structured.scenellm.SceneLLMLossConfig(loss_weights: ~typing.Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: ~typing.Dict[str, float] | None = None, vqvae_loss_weight: float = 1.0, llm_loss_weight: float = 1.0, obj_loss_weight: float = 1.0, rel_loss_weight: float = 1.0, commitment_loss_weight: float = 0.25, perceptual_loss_weight: float = 0.1, use_commitment_loss: bool = True, use_perceptual_loss: bool = False, use_kl_loss: bool = False, kl_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01, use_auxiliary_loss: bool = False, auxiliary_loss_weight: float = 0.1)[source]

Bases: LossConfig

SceneLLM-specific loss configuration.

Parameters:
  • vqvae_loss_weight (float) – Weight for VQ-VAE loss

  • llm_loss_weight (float) – Weight for LLM loss

  • obj_loss_weight (float) – Weight for object loss

  • rel_loss_weight (float) – Weight for relation loss

  • commitment_loss_weight (float) – Weight for commitment loss

  • perceptual_loss_weight (float) – Weight for perceptual loss

  • use_commitment_loss (bool) – Use commitment loss

  • use_perceptual_loss (bool) – Use perceptual loss

  • use_kl_loss (bool) – Use KL divergence loss

  • kl_loss_weight (float) – Weight for KL divergence loss

  • use_contrastive_loss (bool) – Use contrastive loss

  • contrastive_loss_weight (float) – Weight for contrastive loss

  • contrastive_temperature (float) – Contrastive loss temperature

  • use_consistency_loss (bool) – Use consistency loss

  • consistency_loss_weight (float) – Weight for consistency loss

  • use_regularization_loss (bool) – Use regularization loss

  • regularization_loss_weight (float) – Weight for regularization loss

  • use_auxiliary_loss (bool) – Use auxiliary loss

  • auxiliary_loss_weight (float) – Weight for auxiliary loss

vqvae_loss_weight: float = 1.0
llm_loss_weight: float = 1.0
obj_loss_weight: float = 1.0
rel_loss_weight: float = 1.0
commitment_loss_weight: float = 0.25
perceptual_loss_weight: float = 0.1
use_commitment_loss: bool = True
use_perceptual_loss: bool = False
use_kl_loss: bool = False
kl_loss_weight: float = 0.1
use_contrastive_loss: bool = False
contrastive_loss_weight: float = 0.1
contrastive_temperature: float = 0.07
use_consistency_loss: bool = False
consistency_loss_weight: float = 0.1
use_regularization_loss: bool = False
regularization_loss_weight: float = 0.01
use_auxiliary_loss: bool = False
auxiliary_loss_weight: float = 0.1
__init__(loss_weights: ~typing.Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: ~typing.Dict[str, float] | None = None, vqvae_loss_weight: float = 1.0, llm_loss_weight: float = 1.0, obj_loss_weight: float = 1.0, rel_loss_weight: float = 1.0, commitment_loss_weight: float = 0.25, perceptual_loss_weight: float = 0.1, use_commitment_loss: bool = True, use_perceptual_loss: bool = False, use_kl_loss: bool = False, kl_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01, use_auxiliary_loss: bool = False, auxiliary_loss_weight: float = 0.1) None
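
The weight fields suggest a total objective formed as a weighted sum of the enabled terms. A hedged sketch of such a combination (the loss dictionary and its keys are illustrative assumptions; only the weight and toggle fields come from the config):

    import torch

    from m3sgg.core.config.structured.scenellm import SceneLLMLossConfig

    def combine_losses(losses: dict, cfg: SceneLLMLossConfig) -> torch.Tensor:
        # `losses` maps term names to scalar tensors; the keys are illustrative
        total = (cfg.vqvae_loss_weight * losses["vqvae"]
                 + cfg.llm_loss_weight * losses["llm"]
                 + cfg.obj_loss_weight * losses["obj"]
                 + cfg.rel_loss_weight * losses["rel"])
        if cfg.use_commitment_loss:
            total = total + cfg.commitment_loss_weight * losses["commitment"]
        if cfg.use_kl_loss:
            total = total + cfg.kl_loss_weight * losses["kl"]
        return total
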
class m3sgg.core.config.structured.scenellm.SceneLLMDataConfig(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, max_text_length: int = 512, use_text_augmentation: bool = False, text_augmentation_prob: float = 0.5, use_image_augmentation: bool = False, image_augmentation_prob: float = 0.5, use_temporal_augmentation: bool = False, temporal_augmentation_prob: float = 0.5, use_negative_sampling: bool = True, negative_ratio: float = 0.5, use_hard_negative_mining: bool = False, hard_negative_ratio: float = 0.2, use_contrastive_sampling: bool = False, contrastive_ratio: float = 0.3, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_dynamic_sampling: bool = False, dynamic_sampling_alpha: float = 0.5, use_balanced_sampling: bool = False, balanced_sampling_alpha: float = 0.5)[source]

Bases: DataConfig

SceneLLM-specific data configuration.

Parameters:
  • max_objects (int) – Maximum number of objects per frame

  • max_relations (int) – Maximum number of relations per frame

  • max_frames (int) – Maximum number of frames per video

  • max_text_length (int) – Maximum text length

  • use_text_augmentation (bool) – Use text augmentation

  • text_augmentation_prob (float) – Text augmentation probability

  • use_image_augmentation (bool) – Use image augmentation

  • image_augmentation_prob (float) – Image augmentation probability

  • use_temporal_augmentation (bool) – Use temporal augmentation

  • temporal_augmentation_prob (float) – Temporal augmentation probability

  • use_negative_sampling (bool) – Use negative sampling

  • negative_ratio (float) – Ratio of negative samples

  • use_hard_negative_mining (bool) – Use hard negative mining

  • hard_negative_ratio (float) – Ratio of hard negative samples

  • use_contrastive_sampling (bool) – Use contrastive sampling

  • contrastive_ratio (float) – Ratio of contrastive samples

  • use_curriculum_learning (bool) – Use curriculum learning

  • curriculum_epochs (int) – Number of epochs for curriculum

  • use_dynamic_sampling (bool) – Use dynamic sampling

  • dynamic_sampling_alpha (float) – Dynamic sampling alpha

  • use_balanced_sampling (bool) – Use balanced sampling

  • balanced_sampling_alpha (float) – Balanced sampling alpha

max_objects: int = 50
max_relations: int = 100
max_frames: int = 30
max_text_length: int = 512
use_text_augmentation: bool = False
text_augmentation_prob: float = 0.5
use_image_augmentation: bool = False
image_augmentation_prob: float = 0.5
use_temporal_augmentation: bool = False
temporal_augmentation_prob: float = 0.5
use_negative_sampling: bool = True
negative_ratio: float = 0.5
use_hard_negative_mining: bool = False
hard_negative_ratio: float = 0.2
use_contrastive_sampling: bool = False
contrastive_ratio: float = 0.3
use_curriculum_learning: bool = False
curriculum_epochs: int = 10
use_dynamic_sampling: bool = False
dynamic_sampling_alpha: float = 0.5
use_balanced_sampling: bool = False
balanced_sampling_alpha: float = 0.5
__init__(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, max_text_length: int = 512, use_text_augmentation: bool = False, text_augmentation_prob: float = 0.5, use_image_augmentation: bool = False, image_augmentation_prob: float = 0.5, use_temporal_augmentation: bool = False, temporal_augmentation_prob: float = 0.5, use_negative_sampling: bool = True, negative_ratio: float = 0.5, use_hard_negative_mining: bool = False, hard_negative_ratio: float = 0.2, use_contrastive_sampling: bool = False, contrastive_ratio: float = 0.3, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_dynamic_sampling: bool = False, dynamic_sampling_alpha: float = 0.5, use_balanced_sampling: bool = False, balanced_sampling_alpha: float = 0.5) None
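
For logging or export, a config instance can be flattened to a dictionary with dataclasses.asdict; a short sketch:

    from dataclasses import asdict

    from m3sgg.core.config.structured.scenellm import SceneLLMDataConfig

    data_cfg = SceneLLMDataConfig(max_frames=16, use_text_augmentation=True)
    print(asdict(data_cfg)["max_frames"])  # 16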

OED model configuration classes.

This module provides structured configuration classes specifically for the OED (Object-Event Detection) model.

author:

M3SGG Team

version:

0.1.0

class m3sgg.core.config.structured.oed.OEDConfig(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'oed', use_matcher: bool = False, eval: bool = False, num_queries: int = 100, dec_layers_hopd: int = 6, dec_layers_interaction: int = 6, num_attn_classes: int = 3, num_spatial_classes: int = 6, num_contacting_classes: int = 17, alpha: float = 0.5, oed_use_matching: bool = False, bbox_loss_coef: float = 2.5, giou_loss_coef: float = 1.0, obj_loss_coef: float = 1.0, rel_loss_coef: float = 2.0, oed_eos_coef: float = 0.1, interval1: int = 4, interval2: int = 4, num_ref_frames: int = 2, oed_variant: str = 'multi', fuse_semantic_pos: bool = False, query_temporal_interaction: bool = False, hidden_dim: int = 256, num_heads: int = 8, num_layers: int = 6, dropout: float = 0.1, use_bbox_encoding: bool = True, bbox_encoding_dim: int = 128, use_positional_encoding: bool = True, positional_encoding_dim: int = 128, use_temporal_encoding: bool = True, temporal_encoding_dim: int = 128, use_attention_weights: bool = False, attention_dropout: float = 0.1, ffn_dim: int = 1024, activation: str = 'relu', norm_type: str = 'layer_norm', use_bias: bool = True, gradient_checkpointing: bool = False, use_memory_efficient_attention: bool = False, use_auxiliary_loss: bool = False, aux_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01)[source]

Bases: BaseConfig

Configuration for OED model.

OED (Object-Event Detection) is a transformer-based model for video scene graph generation that uses object queries and attention mechanisms for detection.

Parameters:
  • model_type (str) – Model type identifier

  • num_queries (int) – Number of query slots for OED

  • dec_layers_hopd (int) – Number of hopd decoding layers in OED transformer

  • dec_layers_interaction (int) – Number of interaction decoding layers in OED transformer

  • num_attn_classes (int) – Number of attention classes

  • num_spatial_classes (int) – Number of spatial classes

  • num_contacting_classes (int) – Number of contacting classes

  • alpha (float) – Focal loss alpha for OED

  • oed_use_matching (bool) – Use obj/sub matching 2class loss in OED decoder

  • bbox_loss_coef (float) – L1 box coefficient

  • giou_loss_coef (float) – GIoU box coefficient

  • obj_loss_coef (float) – Object classification coefficient

  • rel_loss_coef (float) – Relation classification coefficient

  • oed_eos_coef (float) – Relative classification weight of no-object class for OED

  • interval1 (int) – Interval for training frame selection

  • interval2 (int) – Interval for test frame selection

  • num_ref_frames (int) – Number of reference frames

  • oed_variant (str) – OED variant (single/multi)

  • fuse_semantic_pos (bool) – Fuse semantic and positional embeddings

  • query_temporal_interaction (bool) – Enable query temporal interaction

  • hidden_dim (int) – Hidden dimension size

  • num_heads (int) – Number of attention heads

  • num_layers (int) – Number of transformer layers

  • dropout (float) – Dropout rate

  • use_bbox_encoding (bool) – Use bounding box encoding

  • bbox_encoding_dim (int) – Bounding box encoding dimension

  • use_positional_encoding (bool) – Use positional encoding

  • positional_encoding_dim (int) – Positional encoding dimension

  • use_temporal_encoding (bool) – Use temporal encoding

  • temporal_encoding_dim (int) – Temporal encoding dimension

  • use_attention_weights (bool) – Use attention weights for visualization

  • attention_dropout (float) – Attention dropout rate

  • ffn_dim (int) – Feed-forward network dimension

  • activation (str) – Activation function

  • norm_type (str) – Normalization type

  • use_bias (bool) – Use bias in linear layers

  • gradient_checkpointing (bool) – Use gradient checkpointing

  • use_memory_efficient_attention (bool) – Use memory efficient attention

  • use_auxiliary_loss (bool) – Use auxiliary loss

  • aux_loss_weight (float) – Auxiliary loss weight

  • use_contrastive_loss (bool) – Use contrastive loss

  • contrastive_loss_weight (float) – Contrastive loss weight

  • contrastive_temperature (float) – Contrastive loss temperature

  • use_consistency_loss (bool) – Use consistency loss

  • consistency_loss_weight (float) – Consistency loss weight

  • use_regularization_loss (bool) – Use regularization loss

  • regularization_loss_weight (float) – Regularization loss weight

model_type: str = 'oed'
num_queries: int = 100
dec_layers_hopd: int = 6
dec_layers_interaction: int = 6
num_attn_classes: int = 3
num_spatial_classes: int = 6
num_contacting_classes: int = 17
alpha: float = 0.5
oed_use_matching: bool = False
bbox_loss_coef: float = 2.5
giou_loss_coef: float = 1.0
obj_loss_coef: float = 1.0
rel_loss_coef: float = 2.0
oed_eos_coef: float = 0.1
interval1: int = 4
interval2: int = 4
num_ref_frames: int = 2
oed_variant: str = 'multi'
fuse_semantic_pos: bool = False
query_temporal_interaction: bool = False
hidden_dim: int = 256
num_heads: int = 8
num_layers: int = 6
dropout: float = 0.1
use_bbox_encoding: bool = True
bbox_encoding_dim: int = 128
use_positional_encoding: bool = True
positional_encoding_dim: int = 128
use_temporal_encoding: bool = True
temporal_encoding_dim: int = 128
use_attention_weights: bool = False
attention_dropout: float = 0.1
ffn_dim: int = 1024
activation: str = 'relu'
norm_type: str = 'layer_norm'
use_bias: bool = True
gradient_checkpointing: bool = False
use_memory_efficient_attention: bool = False
use_auxiliary_loss: bool = False
aux_loss_weight: float = 0.1
use_contrastive_loss: bool = False
contrastive_loss_weight: float = 0.1
contrastive_temperature: float = 0.07
use_consistency_loss: bool = False
consistency_loss_weight: float = 0.1
use_regularization_loss: bool = False
regularization_loss_weight: float = 0.01
__init__(mode: str = 'predcls', save_path: str = 'output', model_path: str = 'weights/predcls.tar', dataset: str = 'action_genome', data_path: str = 'data/action_genome', datasize: str = 'large', fraction: int = 1, ckpt: str | None = None, optimizer: str = 'adamw', lr: float = 1e-05, nepoch: float = 10, niter: int | None = None, eval_frequency: int = 50, enc_layer: int = 1, dec_layer: int = 3, bce_loss: bool = False, device: str = 'cuda:0', seed: int = 42, num_workers: int = 0, model_type: str = 'oed', use_matcher: bool = False, eval: bool = False, num_queries: int = 100, dec_layers_hopd: int = 6, dec_layers_interaction: int = 6, num_attn_classes: int = 3, num_spatial_classes: int = 6, num_contacting_classes: int = 17, alpha: float = 0.5, oed_use_matching: bool = False, bbox_loss_coef: float = 2.5, giou_loss_coef: float = 1.0, obj_loss_coef: float = 1.0, rel_loss_coef: float = 2.0, oed_eos_coef: float = 0.1, interval1: int = 4, interval2: int = 4, num_ref_frames: int = 2, oed_variant: str = 'multi', fuse_semantic_pos: bool = False, query_temporal_interaction: bool = False, hidden_dim: int = 256, num_heads: int = 8, num_layers: int = 6, dropout: float = 0.1, use_bbox_encoding: bool = True, bbox_encoding_dim: int = 128, use_positional_encoding: bool = True, positional_encoding_dim: int = 128, use_temporal_encoding: bool = True, temporal_encoding_dim: int = 128, use_attention_weights: bool = False, attention_dropout: float = 0.1, ffn_dim: int = 1024, activation: str = 'relu', norm_type: str = 'layer_norm', use_bias: bool = True, gradient_checkpointing: bool = False, use_memory_efficient_attention: bool = False, use_auxiliary_loss: bool = False, aux_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01) None
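
A typical override keeps the defaults and adjusts only a few knobs, e.g. the detection mode, query count, and variant; a minimal sketch:

    from m3sgg.core.config.structured.oed import OEDConfig

    cfg = OEDConfig(
        mode="sgdet",          # scene graph detection mode
        num_queries=50,        # fewer query slots than the default 100
        oed_variant="single",  # single-frame variant
    )
    assert cfg.model_type == "oed"
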
class m3sgg.core.config.structured.oed.OEDTrainingConfig(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 2, early_stopping_patience: int = 10, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, scheduler_type: str = 'reduce_on_plateau', scheduler_patience: int = 3, scheduler_factor: float = 0.5, weight_decay: float = 0.0001, clip_grad_norm: float = 1.0, use_amp: bool = False, accumulation_steps: int = 1, use_ema: bool = False, ema_decay: float = 0.999, use_swa: bool = False, swa_lr: float = 1e-05, swa_epochs: int = 5, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_auxiliary_loss: bool = False, aux_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01)[source]

Bases: TrainingConfig

OED-specific training configuration.

Parameters:
  • warmup_epochs (int) – Number of warmup epochs

  • scheduler_type (str) – Learning rate scheduler type

  • scheduler_patience (int) – Scheduler patience

  • scheduler_factor (float) – Scheduler reduction factor

  • weight_decay (float) – Weight decay for regularization

  • clip_grad_norm (float) – Gradient clipping norm

  • use_amp (bool) – Use automatic mixed precision

  • accumulation_steps (int) – Gradient accumulation steps

  • use_ema (bool) – Use exponential moving average

  • ema_decay (float) – EMA decay rate

  • use_swa (bool) – Use stochastic weight averaging

  • swa_lr (float) – SWA learning rate

  • swa_epochs (int) – SWA epochs

  • use_curriculum_learning (bool) – Use curriculum learning

  • curriculum_epochs (int) – Number of epochs for curriculum

  • use_auxiliary_loss (bool) – Use auxiliary loss

  • aux_loss_weight (float) – Auxiliary loss weight

  • use_contrastive_loss (bool) – Use contrastive loss

  • contrastive_loss_weight (float) – Contrastive loss weight

  • contrastive_temperature (float) – Contrastive loss temperature

  • use_consistency_loss (bool) – Use consistency loss

  • consistency_loss_weight (float) – Consistency loss weight

  • use_regularization_loss (bool) – Use regularization loss

  • regularization_loss_weight (float) – Regularization loss weight

warmup_epochs: int = 2
scheduler_type: str = 'reduce_on_plateau'
scheduler_patience: int = 3
scheduler_factor: float = 0.5
weight_decay: float = 0.0001
clip_grad_norm: float = 1.0
use_amp: bool = False
accumulation_steps: int = 1
use_ema: bool = False
ema_decay: float = 0.999
use_swa: bool = False
swa_lr: float = 1e-05
swa_epochs: int = 5
use_curriculum_learning: bool = False
curriculum_epochs: int = 10
use_auxiliary_loss: bool = False
aux_loss_weight: float = 0.1
use_contrastive_loss: bool = False
contrastive_loss_weight: float = 0.1
contrastive_temperature: float = 0.07
use_consistency_loss: bool = False
consistency_loss_weight: float = 0.1
use_regularization_loss: bool = False
regularization_loss_weight: float = 0.01
__init__(batch_size: int = 1, val_batch_size: int = 1, max_grad_norm: float = 1.0, warmup_epochs: int = 2, early_stopping_patience: int = 10, save_frequency: int = 1, log_frequency: int = 10, val_frequency: int = 1, scheduler_type: str = 'reduce_on_plateau', scheduler_patience: int = 3, scheduler_factor: float = 0.5, weight_decay: float = 0.0001, clip_grad_norm: float = 1.0, use_amp: bool = False, accumulation_steps: int = 1, use_ema: bool = False, ema_decay: float = 0.999, use_swa: bool = False, swa_lr: float = 1e-05, swa_epochs: int = 5, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_auxiliary_loss: bool = False, aux_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01) None
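
With scheduler_type='reduce_on_plateau', the scheduler fields map directly onto torch.optim.lr_scheduler.ReduceLROnPlateau; a hedged wiring sketch (the model and learning rate are placeholders):

    import torch

    from m3sgg.core.config.structured.oed import OEDTrainingConfig

    cfg = OEDTrainingConfig()
    model = torch.nn.Linear(8, 8)  # placeholder model
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=1e-5, weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min",
        factor=cfg.scheduler_factor, patience=cfg.scheduler_patience)
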
class m3sgg.core.config.structured.oed.OEDLossConfig(loss_weights: ~typing.Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: ~typing.Dict[str, float] | None = None, obj_loss_weight: float = 1.0, rel_loss_weight: float = 2.0, bbox_loss_weight: float = 2.5, giou_loss_weight: float = 1.0, use_focal_loss: bool = True, use_class_weights: bool = False, obj_class_weights: dict | None = None, rel_class_weights: dict | None = None, use_auxiliary_loss: bool = False, aux_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01)[source]

Bases: LossConfig

OED-specific loss configuration.

Parameters:
  • obj_loss_weight (float) – Weight for object classification loss

  • rel_loss_weight (float) – Weight for relation classification loss

  • bbox_loss_weight (float) – Weight for bounding box regression loss

  • giou_loss_weight (float) – Weight for GIoU loss

  • use_focal_loss (bool) – Use focal loss for classification

  • focal_alpha (float) – Focal loss alpha parameter

  • focal_gamma (float) – Focal loss gamma parameter

  • label_smoothing (float) – Label smoothing factor

  • use_class_weights (bool) – Use class weights for imbalanced data

  • obj_class_weights (Optional[dict]) – Object class weights

  • rel_class_weights (Optional[dict]) – Relation class weights

  • use_auxiliary_loss (bool) – Use auxiliary loss

  • aux_loss_weight (float) – Auxiliary loss weight

  • use_contrastive_loss (bool) – Use contrastive loss

  • contrastive_loss_weight (float) – Contrastive loss weight

  • contrastive_temperature (float) – Contrastive loss temperature

  • use_consistency_loss (bool) – Use consistency loss

  • consistency_loss_weight (float) – Consistency loss weight

  • use_regularization_loss (bool) – Use regularization loss

  • regularization_loss_weight (float) – Regularization loss weight

obj_loss_weight: float = 1.0
rel_loss_weight: float = 2.0
bbox_loss_weight: float = 2.5
giou_loss_weight: float = 1.0
use_focal_loss: bool = True
focal_alpha: float = 0.25
focal_gamma: float = 2.0
label_smoothing: float = 0.0
use_class_weights: bool = False
obj_class_weights: dict | None = None
rel_class_weights: dict | None = None
use_auxiliary_loss: bool = False
aux_loss_weight: float = 0.1
use_contrastive_loss: bool = False
contrastive_loss_weight: float = 0.1
contrastive_temperature: float = 0.07
use_consistency_loss: bool = False
consistency_loss_weight: float = 0.1
use_regularization_loss: bool = False
regularization_loss_weight: float = 0.01
__init__(loss_weights: ~typing.Dict[str, float] = <factory>, label_smoothing: float = 0.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, class_weights: ~typing.Dict[str, float] | None = None, obj_loss_weight: float = 1.0, rel_loss_weight: float = 2.0, bbox_loss_weight: float = 2.5, giou_loss_weight: float = 1.0, use_focal_loss: bool = True, use_class_weights: bool = False, obj_class_weights: dict | None = None, rel_class_weights: dict | None = None, use_auxiliary_loss: bool = False, aux_loss_weight: float = 0.1, use_contrastive_loss: bool = False, contrastive_loss_weight: float = 0.1, contrastive_temperature: float = 0.07, use_consistency_loss: bool = False, consistency_loss_weight: float = 0.1, use_regularization_loss: bool = False, regularization_loss_weight: float = 0.01) None
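
When use_focal_loss is enabled, focal_alpha and focal_gamma parameterize a focal loss of the standard form FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t). A self-contained sketch of that formulation (not necessarily the library's exact implementation):

    import torch
    import torch.nn.functional as F

    def focal_loss(logits: torch.Tensor, targets: torch.Tensor,
                   alpha: float = 0.25, gamma: float = 2.0) -> torch.Tensor:
        # Per-sample cross entropy equals -log(p_t) for the true class
        ce = F.cross_entropy(logits, targets, reduction="none")
        p_t = torch.exp(-ce)  # probability assigned to the true class
        return (alpha * (1.0 - p_t) ** gamma * ce).mean()

    loss = focal_loss(torch.randn(4, 17), torch.randint(0, 17, (4,)))
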
class m3sgg.core.config.structured.oed.OEDDataConfig(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, use_temporal_sampling: bool = True, temporal_stride: int = 1, use_augmentation: bool = False, augmentation_prob: float = 0.5, use_negative_sampling: bool = True, negative_ratio: float = 0.5, use_hard_negative_mining: bool = False, hard_negative_ratio: float = 0.2, use_contrastive_sampling: bool = False, contrastive_ratio: float = 0.3, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_dynamic_sampling: bool = False, dynamic_sampling_alpha: float = 0.5, use_balanced_sampling: bool = False, balanced_sampling_alpha: float = 0.5)[source]

Bases: DataConfig

OED-specific data configuration.

Parameters:
  • max_objects (int) – Maximum number of objects per frame

  • max_relations (int) – Maximum number of relations per frame

  • max_frames (int) – Maximum number of frames per video

  • use_temporal_sampling (bool) – Use temporal sampling

  • temporal_stride (int) – Temporal stride for sampling

  • use_augmentation (bool) – Use data augmentation

  • augmentation_prob (float) – Probability of applying augmentation

  • use_negative_sampling (bool) – Use negative sampling for relations

  • negative_ratio (float) – Ratio of negative samples

  • use_hard_negative_mining (bool) – Use hard negative mining

  • hard_negative_ratio (float) – Ratio of hard negative samples

  • use_contrastive_sampling (bool) – Use contrastive sampling

  • contrastive_ratio (float) – Ratio of contrastive samples

  • use_curriculum_learning (bool) – Use curriculum learning

  • curriculum_epochs (int) – Number of epochs for curriculum

  • use_dynamic_sampling (bool) – Use dynamic sampling

  • dynamic_sampling_alpha (float) – Dynamic sampling alpha

  • use_balanced_sampling (bool) – Use balanced sampling

  • balanced_sampling_alpha (float) – Balanced sampling alpha

max_objects: int = 50
max_relations: int = 100
max_frames: int = 30
use_temporal_sampling: bool = True
temporal_stride: int = 1
use_augmentation: bool = False
augmentation_prob: float = 0.5
use_negative_sampling: bool = True
negative_ratio: float = 0.5
use_hard_negative_mining: bool = False
hard_negative_ratio: float = 0.2
use_contrastive_sampling: bool = False
contrastive_ratio: float = 0.3
use_curriculum_learning: bool = False
curriculum_epochs: int = 10
use_dynamic_sampling: bool = False
dynamic_sampling_alpha: float = 0.5
use_balanced_sampling: bool = False
balanced_sampling_alpha: float = 0.5
__init__(cache_dir: str = 'data/cache', pin_memory: bool = False, shuffle: bool = True, drop_last: bool = False, prefetch_factor: int = 2, persistent_workers: bool = False, max_objects: int = 50, max_relations: int = 100, max_frames: int = 30, use_temporal_sampling: bool = True, temporal_stride: int = 1, use_augmentation: bool = False, augmentation_prob: float = 0.5, use_negative_sampling: bool = True, negative_ratio: float = 0.5, use_hard_negative_mining: bool = False, hard_negative_ratio: float = 0.2, use_contrastive_sampling: bool = False, contrastive_ratio: float = 0.3, use_curriculum_learning: bool = False, curriculum_epochs: int = 10, use_dynamic_sampling: bool = False, dynamic_sampling_alpha: float = 0.5, use_balanced_sampling: bool = False, balanced_sampling_alpha: float = 0.5) None
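
use_temporal_sampling together with temporal_stride and max_frames suggests stride-based frame selection; a small sketch of one plausible reading (the helper is illustrative, not the library's loader):

    from m3sgg.core.config.structured.oed import OEDDataConfig

    def sample_frame_indices(num_frames: int, cfg: OEDDataConfig) -> list:
        # Keep every `temporal_stride`-th frame, capped at `max_frames`
        indices = list(range(0, num_frames, cfg.temporal_stride))
        return indices[: cfg.max_frames]

    print(len(sample_frame_indices(90, OEDDataConfig(temporal_stride=3))))  # 30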

Detectors

class m3sgg.core.detectors.faster_rcnn.detector(train, object_classes, use_SUPPLY, mode='predcls')[source]

Bases: Module

Object detector module for scene graph generation.

Implements object detection using a Faster R-CNN backbone for scene graph generation tasks, covering the predcls, sgcls, and sgdet modes.


__init__(train, object_classes, use_SUPPLY, mode='predcls')[source]

Initialize the object detector.

Parameters:
  • train (bool) – Whether in training mode

  • object_classes (list) – List of object class names

  • use_SUPPLY (bool) – Whether to use SUPPLY relations

  • mode (str, optional) – Detection mode (‘predcls’, ‘sgcls’, ‘sgdet’), defaults to “predcls”

Returns:

None

Return type:

None

forward(im_data, im_info, gt_boxes, num_boxes, gt_annotation, im_all)[source]

Define the computation performed at every call. Should be overridden by all subclasses.

Note

Call the Module instance instead of this method, since the instance call runs the registered hooks while a direct forward() call silently ignores them.
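
A hedged construction sketch; the class list is a placeholder, and the forward call (commented out, since it needs real tensors) follows the signature above:

    from m3sgg.core.detectors.faster_rcnn import detector

    object_classes = ["__background__", "person", "cup"]  # illustrative list
    det = detector(train=False, object_classes=object_classes,
                   use_SUPPLY=True, mode="sgdet")
    # Call the module instance (not .forward) so registered hooks run:
    # entry = det(im_data, im_info, gt_boxes, num_boxes, gt_annotation, im_all)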

Detector factory for creating different object detection models.

This module provides a unified interface for creating various object detection models including Faster R-CNN, ViT-based detectors, DETR, YOLO, and ResNet variants.

class m3sgg.core.detectors.factory.DetectorFactory(config, dataset_train, device, logger: Logger | None = None)[source]

Bases: object

Factory class for creating object detection models based on configuration.

This factory provides a unified interface for creating different object detection models including Faster R-CNN, ViT-based detectors, DETR, YOLO, and ResNet variants. It handles detector instantiation logic and provides a consistent interface across different detector types for improved modularity and maintainability.

Parameters:
  • config (Config) – Configuration object containing detector parameters

  • dataset_train (Dataset) – Training dataset for extracting class information

  • device (torch.device) – Device to place the detector on

  • logger (Optional[logging.Logger]) – Optional logger instance

__init__(config, dataset_train, device, logger: Logger | None = None)[source]

Initialize the detector factory.

Parameters:
  • config (Config) – Configuration object containing detector parameters

  • dataset_train (Dataset) – Training dataset for extracting class information

  • device (torch.device) – Device to place the detector on

  • logger (Optional[logging.Logger]) – Optional logger instance

create_detector() Module[source]

Create a detector instance based on the configuration.

Returns:

Instantiated detector

Return type:

torch.nn.Module

Raises:

ValueError – If detector type is not supported

class m3sgg.core.detectors.factory.BaseDetector[source]

Bases: ABC

Abstract base class for all detector implementations.

This abstract base class ensures consistent interface across different detector types. All detector implementations must inherit from this class and implement the required abstract methods to maintain compatibility with the detector factory and training pipeline.


abstract forward(*args, **kwargs) Dict[str, Any][source]

Forward pass of the detector.

This method performs the forward pass of the detector and returns a dictionary containing detection results including bounding boxes, class predictions, and confidence scores.

Parameters:
  • args (Any) – Variable length argument list

  • kwargs (Any) – Arbitrary keyword arguments

Returns:

Dictionary containing detection results

Return type:

Dict[str, Any]

abstract get_feature_extractor() Module[source]

Get the feature extraction backbone.

Returns the feature extraction backbone module used by the detector. This is typically the CNN or transformer backbone that extracts features from input images.

Returns:

Feature extraction module

Return type:

torch.nn.Module

abstract get_classifier() Module[source]

Get the classification head.

Returns the classification head module used by the detector. This module is responsible for predicting object classes based on the extracted features.

Returns:

Classification module

Return type:

torch.nn.Module

EASG Detectors

class m3sgg.core.detectors.easg.object_detector_EASG.detector(train, object_classes, use_SUPPLY, mode='edgecls')[source]

Bases: Module

Object detector module for EASG (Efficient and Accurate Scene Graph) generation.

Implements object detection functionality specifically designed for EASG scene graph generation with video-based detection capabilities.


__init__(train, object_classes, use_SUPPLY, mode='edgecls')[source]

Initialize the EASG object detector.

Parameters:
  • train (bool) – Whether in training mode

  • object_classes (list) – List of object class names

  • use_SUPPLY (bool) – Whether to use SUPPLY relations

  • mode (str, optional) – Detection mode, defaults to “edgecls”

Returns:

None

Return type:

None

forward(im_data, im_info, gt_grounding, im_all)[source]

Define the computation performed at every call. For this detector, the forward pass collects the detected objects and gathers their relationships for the EASG pipeline.

Note

Call the Module instance instead of this method, since the instance call runs the registered hooks while a direct forward() call silently ignores them.

class m3sgg.core.detectors.easg.sttran_EASG.ObjectClassifier(mode='edgecls', obj_classes=None)[source]

Bases: Module

Module for computing object contexts and edge contexts for EASG.

EASG-specific implementation of object classification and contextual feature extraction for efficient scene graph generation.


__init__(mode='edgecls', obj_classes=None)[source]

Initialize the EASG object classifier.

Parameters:
  • mode (str, optional) – Classification mode, defaults to “edgecls”

  • obj_classes (list, optional) – List of object class names, defaults to None

Returns:

None

Return type:

None

forward(entry)[source]

Define the computation performed at every call. Should be overridden by all subclasses.

Note

Call the Module instance instead of this method, since the instance call runs the registered hooks while a direct forward() call silently ignores them.

class m3sgg.core.detectors.easg.sttran_EASG.ActionClassifier(mode='edgecls', verb_classes=None)[source]

Bases: Module

__init__(mode='edgecls', verb_classes=None)[source]

Initialize the EASG action classifier for the given verb classes.

forward(entry)[source]

Define the computation performed at every call. Should be overridden by all subclasses.

Note

Call the Module instance instead of this method, since the instance call runs the registered hooks while a direct forward() call silently ignores them.

class m3sgg.core.detectors.easg.sttran_EASG.STTran(mode='edgecls', obj_classes=None, verb_classes=None, edge_class_num=None, enc_layer_num=None, dec_layer_num=None, use_visual_features=False)[source]

Bases: Module

__init__(mode='edgecls', obj_classes=None, verb_classes=None, edge_class_num=None, enc_layer_num=None, dec_layer_num=None, use_visual_features=False)[source]

Initialize the EASG spatio-temporal transformer (STTran) with the given object, verb, and edge class configuration.

forward(entry)[source]

Define the computation performed at every call. Should be overridden by all subclasses.

Note

Call the Module instance instead of this method, since the instance call runs the registered hooks while a direct forward() call silently ignores them.

Evaluation

class m3sgg.core.evaluation.metrics.BasicSceneGraphEvaluator(mode, AG_object_classes, AG_all_predicates, AG_attention_predicates, AG_spatial_predicates, AG_contacting_predicates, iou_threshold=0.5, constraint=False, semithreshold=None, logger=None)[source]

Bases: object

Evaluator for scene graph generation tasks.

Computes recall metrics for scene graph generation across different tasks (predcls, sgcls, sgdet) and handles constraint evaluation modes.


__init__(mode, AG_object_classes, AG_all_predicates, AG_attention_predicates, AG_spatial_predicates, AG_contacting_predicates, iou_threshold=0.5, constraint=False, semithreshold=None, logger=None)[source]

Initialize the scene graph evaluator.

Parameters:
  • mode (str) – Evaluation mode (‘predcls’, ‘sgcls’, or ‘sgdet’)

  • AG_object_classes (list) – List of object class names

  • AG_all_predicates (list) – List of all predicate names

  • AG_attention_predicates (list) – List of attention predicate names

  • AG_spatial_predicates (list) – List of spatial predicate names

  • AG_contacting_predicates (list) – List of contacting predicate names

  • iou_threshold (float, optional) – IoU threshold for evaluation, defaults to 0.5

  • constraint (bool, optional) – Whether to use constraint evaluation, defaults to False

  • semithreshold (float, optional) – Semi-constraint threshold, defaults to None

  • logger (logging.Logger, optional) – Logger instance, defaults to None

Returns:

None

Return type:

None

reset_result()[source]

Reset the accumulated evaluation results.

calc_mrecall()[source]

Compute the mean recall over predicate classes.

print_stats()[source]

Print the accumulated recall statistics.

evaluate_scene_graph(gt, pred, return_per_sample=False)[source]

Collect the ground truth and predictions for evaluation, optionally returning per-sample results.
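
A hedged usage sketch; the class and predicate lists are taken from dataset attributes that are assumed here:

    from m3sgg.core.evaluation.metrics import BasicSceneGraphEvaluator

    evaluator = BasicSceneGraphEvaluator(
        mode="predcls",
        AG_object_classes=dataset.object_classes,              # assumed attribute
        AG_all_predicates=dataset.relationship_classes,        # assumed attribute
        AG_attention_predicates=dataset.attention_relationships,
        AG_spatial_predicates=dataset.spatial_relationships,
        AG_contacting_predicates=dataset.contacting_relationships,
        iou_threshold=0.5,
    )
    evaluator.reset_result()
    # per video: evaluator.evaluate_scene_graph(gt, pred)
    evaluator.print_stats()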

m3sgg.core.evaluation.metrics.evaluate_from_dict(gt_entry, pred_entry, mode, result_dict, method=None, threshold=0.9, **kwargs)[source]

Shortcut for running evaluate_recall from result dictionaries.

Parameters:
  • gt_entry – Dictionary containing gt_relations, gt_boxes, gt_classes

  • pred_entry – Dictionary containing pred_rels, pred_boxes (if detection), pred_classes

  • result_dict – Result dictionary to accumulate into

  • kwargs – Additional keyword arguments

m3sgg.core.evaluation.metrics.evaluate_recall(gt_rels, gt_boxes, gt_classes, pred_rels, pred_boxes, pred_classes, rel_scores=None, cls_scores=None, iou_thresh=0.5, phrdet=False)[source]

Evaluates recall metrics for scene graph generation.

Computes recall by matching predicted relations to ground truth relations, using bounding-box IoU to match objects and requiring agreement of the relation class.

Parameters:
  • gt_rels (numpy.ndarray) – Ground truth relations array of shape [#gt_rel, 3]

  • gt_boxes (numpy.ndarray) – Ground truth bounding boxes of shape [#gt_box, 4]

  • gt_classes (numpy.ndarray) – Ground truth object classes of shape [#gt_box]

  • pred_rels (numpy.ndarray) – Predicted relations array of shape [#pred_rel, 3] (id0, id1, rel)

  • pred_boxes (numpy.ndarray) – Predicted bounding boxes of shape [#pred_box, 4]

  • pred_classes (numpy.ndarray) – Predicted object classes of shape [#pred_box]

  • rel_scores (numpy.ndarray, optional) – Relation scores, defaults to None

  • cls_scores (numpy.ndarray, optional) – Classification scores, defaults to None

  • iou_thresh (float, optional) – IoU threshold for matching, defaults to 0.5

  • phrdet (bool, optional) – Whether to use phrase detection mode, defaults to False

Returns:

Tuple containing predicate-to-GT matching, predicted 5-tuples, and relation scores

Return type:

tuple
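
A toy call following the documented shapes and defaults (whether score arrays are required in practice depends on the implementation; this sketch relies on the documented optional defaults):

    import numpy as np

    from m3sgg.core.evaluation.metrics import evaluate_recall

    gt_boxes = np.array([[0., 0., 10., 10.], [20., 20., 30., 30.]])  # [#gt_box, 4]
    gt_classes = np.array([1, 2])                                    # [#gt_box]
    gt_rels = np.array([[0, 1, 3]])       # [#gt_rel, 3]: (subj_idx, obj_idx, rel)

    # A perfect prediction reuses the ground truth arrays
    pred_to_gt, pred_5ples, rel_scores = evaluate_recall(
        gt_rels, gt_boxes, gt_classes,
        gt_rels.copy(), gt_boxes.copy(), gt_classes.copy())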

Training

Trainer class for modularized training loop.

class m3sgg.core.training.trainer.Trainer(model: Module, config: Config, dataloader_train: DataLoader, dataloader_test: DataLoader, optimizer: Optimizer, scheduler: ReduceLROnPlateau, logger: Logger, object_detector: Module | None = None, object_detector_EASG: Module | None = None, matcher: Any | None = None, evaluator: BasicSceneGraphEvaluator | None = None, evaluator2: BasicSceneGraphEvaluator | None = None, dataset_train: Any | None = None, dataset_test: Any | None = None)[source]

Bases: object

Main trainer class for scene graph generation models.

This class encapsulates the training loop, epoch management, and step execution for various scene graph generation models.

__init__(model: Module, config: Config, dataloader_train: DataLoader, dataloader_test: DataLoader, optimizer: Optimizer, scheduler: ReduceLROnPlateau, logger: Logger, object_detector: Module | None = None, object_detector_EASG: Module | None = None, matcher: Any | None = None, evaluator: BasicSceneGraphEvaluator | None = None, evaluator2: BasicSceneGraphEvaluator | None = None, dataset_train: Any | None = None, dataset_test: Any | None = None)[source]

Initialize the Trainer with all necessary components.

Parameters:
  • model – The model to train

  • config – Configuration object

  • dataloader_train – Training data loader

  • dataloader_test – Test data loader

  • optimizer – Optimizer

  • scheduler – Learning rate scheduler

  • logger – Logger instance

  • object_detector – Object detector for Action Genome dataset

  • object_detector_EASG – Object detector for EASG dataset

  • matcher – Hungarian matcher for DSG-DETR

  • evaluator – Scene graph evaluator

  • evaluator2 – Secondary evaluator without constraints

  • dataset_train – Training dataset

  • dataset_test – Test dataset

train_loop() None[source]

Main training loop that orchestrates the entire training process.

This method runs the complete training process including all epochs, evaluation, and checkpoint saving.

train_epoch(epoch: int) None[source]

Train the model for one epoch.

Parameters:

epoch (int) – Current epoch number

train_step(batch_idx: int, train_iter: iter, unc_vals: Any | None = None) Dict[str, Tensor][source]

Execute one training step.

Parameters:
  • batch_idx (int) – Current batch index

  • train_iter (iter) – Training data iterator

  • unc_vals (Optional[Any]) – Uncertainty values for Tempura model

Returns:

Dictionary of loss components

Return type:

Dict[str, torch.Tensor]

train_iter(max_iterations: int | None = None) Iterator[Dict[str, Any]][source]

Iterative training function that yields training progress.

This function provides an iterator interface for training, allowing for more granular control over the training process and real-time monitoring.

Parameters:

max_iterations (Optional[int]) – Maximum number of iterations to run, defaults to None (run until epoch complete)

Yield:

Dictionary containing training progress information

Return type:

Iterator[Dict[str, Any]]

evaluate_epoch(epoch: int) Tuple[float, float][source]

Evaluate the model for one epoch.

Parameters:

epoch (int) – Current epoch number

Returns:

Tuple of (score, mrecall)

Return type:

Tuple[float, float]

save_predictions_csv(predictions_data: List[Dict]) None[source]

Save final predictions as CSV file.

Parameters:

predictions_data (List[Dict]) – List of prediction dictionaries

save_checkpoints(epoch: int, score: float, mrecall: float) None[source]

Save model checkpoints if this is the best model.

Parameters:
  • epoch (int) – Current epoch number

  • score (float) – Current score

  • mrecall (float) – Current mrecall

Evaluation class for modularized evaluation loop.

This module provides a clean separation of evaluation logic from the main training script.

class m3sgg.core.training.evaluation.Evaluator(evaluator: BasicSceneGraphEvaluator, evaluator2: BasicSceneGraphEvaluator | None = None, logger: Logger | None = None)[source]

Bases: object

Evaluation class for scene graph generation models.

This class encapsulates the evaluation loop and metrics computation for various scene graph generation models.

__init__(evaluator: BasicSceneGraphEvaluator, evaluator2: BasicSceneGraphEvaluator | None = None, logger: Logger | None = None)[source]

Initialize the Evaluator with necessary components.

Parameters:
  • evaluator – Primary scene graph evaluator

  • evaluator2 – Secondary evaluator without constraints

  • logger – Logger instance

eval_loop(model: Module, dataloader_test: DataLoader, config: Any, object_detector: Module | None = None, object_detector_EASG: Module | None = None, matcher: Any | None = None, dataset_test: Any | None = None) Tuple[float, float][source]

Run the complete evaluation loop.

Parameters:
  • model (torch.nn.Module) – Model to evaluate

  • dataloader_test (torch.utils.data.DataLoader) – Test data loader

  • config (Any) – Configuration object

  • object_detector (Optional[torch.nn.Module]) – Object detector for Action Genome dataset

  • object_detector_EASG (Optional[torch.nn.Module]) – Object detector for EASG dataset

  • matcher (Optional[Any]) – Hungarian matcher for DSG-DETR

  • dataset_test (Optional[Any]) – Test dataset

Returns:

Tuple of (score, mrecall)

Return type:

Tuple[float, float]
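
For standalone evaluation, the same components feed eval_loop; a sketch assuming the surrounding objects exist:

    from m3sgg.core.training.evaluation import Evaluator

    ev = Evaluator(evaluator, evaluator2=evaluator_no_constraint, logger=logger)
    score, mrecall = ev.eval_loop(
        model, dataloader_test, config,
        object_detector=object_detector, dataset_test=dataset_test)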

Model factory for creating different model types based on configuration.

This module contains the model instantiation logic that was extracted from the monolithic training.py script to improve modularity and maintainability.

class m3sgg.core.training.model_factory.ModelFactory(config, dataset_train, device, logger: Logger | None = None)[source]

Bases: object

Factory class for creating model instances based on configuration.

This factory provides a unified interface for creating different scene graph generation models including STTran, DSG-DETR, STKET, TEMPURA, SceneLLM, OED, and VLM models. It handles model instantiation logic that was extracted from the monolithic training script to improve modularity and maintainability.

Parameters:
  • config (Config) – Configuration object containing model parameters

  • dataset_train (Dataset) – Training dataset for extracting class information

  • device (torch.device) – Device to place the model on

  • logger (Optional[logging.Logger]) – Optional logger instance

__init__(config, dataset_train, device, logger: Logger | None = None)[source]

Initialize the model factory.

Parameters:
  • config (Config) – Configuration object containing model parameters

  • dataset_train (Dataset) – Training dataset for extracting class information

  • device (torch.device) – Device to place the model on

  • logger (Optional[logging.Logger]) – Optional logger instance

create_model() Module[source]

Create a model instance based on the configuration.

Returns:

Instantiated model

Return type:

torch.nn.Module

Raises:

ValueError – If dataset or model type is not supported

Loss factory for creating different loss functions based on model type and configuration.

This module contains the loss function setup logic that was extracted from the monolithic training.py script to improve modularity and maintainability.

class m3sgg.core.training.loss_factory.LossFactory(config, model, device, logger: Logger | None = None)[source]

Bases: object

Factory class for creating loss functions based on configuration.

This factory provides a unified interface for creating different loss functions based on model type and configuration. It handles loss function setup logic that was extracted from the monolithic training script to improve modularity and maintainability. Supports basic losses for all models and model-specific losses for TEMPURA models.

Parameters:
  • config (Config) – Configuration object containing loss parameters

  • model (torch.nn.Module) – Model instance for extracting class information

  • device (torch.device) – Device to place loss functions on

  • logger (Optional[logging.Logger]) – Optional logger instance

__init__(config, model, device, logger: Logger | None = None)[source]

Initialize the loss factory.

Parameters:
  • config (Config) – Configuration object containing loss parameters

  • model (torch.nn.Module) – Model instance for extracting class information

  • device (torch.device) – Device to place loss functions on

  • logger (Optional[logging.Logger]) – Optional logger instance

create_losses() Dict[str, Module | None][source]

Create loss functions based on the configuration.

Returns:

Dictionary containing loss functions

Return type:

Dict[str, Union[nn.Module, None]]
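
The two factories compose: the model factory builds the network and the loss factory derives matching criteria from it. A hedged sketch (the keys of the returned loss dictionary depend on the model type and are not enumerated here):

    import torch

    from m3sgg.core.training.model_factory import ModelFactory
    from m3sgg.core.training.loss_factory import LossFactory

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = ModelFactory(config, dataset_train, device).create_model()  # config, dataset assumed
    losses = LossFactory(config, model, device).create_losses()
    # `losses` maps loss names to nn.Module instances (or None when unused)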