# Source code for m3sgg.core.models.scenellm.ot

"""
Optimal Transport Codebook Updater implementation for SceneLLM.
Credit to the authors of the original code: https://doi.org/10.1016/j.patcog.2025.111992.
"""

import torch
import torch.nn.functional as F

try:
    import ot

    OT_AVAILABLE = True
except ImportError as e:
    OT_AVAILABLE = False
    print(
        f"Warning: POT not available ({e}). Install with 'pip install pot'. OT codebook update will use simple merging."
    )


class OTCodebookUpdater:
    """Entropy-guided codebook shrinking via optimal transport.

    Starting from a base ``nn.Embedding`` codebook, :meth:`update` proposes
    progressively smaller codebooks (shrinking by ``step`` codewords per
    proposal), maps the old codewords onto the smaller set with a Sinkhorn
    optimal-transport plan, and accepts the first proposal whose usage
    entropy drops by more than a fixed threshold.
    """

    def __init__(self, base_codebook, step=512, max_iterations=10):
        """Initialize the updater.

        Args:
            base_codebook: ``nn.Embedding`` holding the current codewords.
            step: Number of codewords removed per shrink proposal.
            max_iterations: Maximum number of shrink proposals to evaluate.
        """
        self.step = step
        self.C = base_codebook  # nn.Embedding
        self.max_iterations = max_iterations
        self.reg = 0.01  # Regularization parameter for Sinkhorn

    def update(self, usage_hist):
        """Update codebook using Optimal Transport scheme.

        Args:
            usage_hist: tensor of shape [codebook_size] with usage frequencies

        Returns:
            updated_codebook: new embedding weight matrix; a clone of the
            current weights when no proposal lowers entropy enough.
        """
        codebook_size = self.C.weight.size(0)
        embed_dim = self.C.weight.size(1)

        # NOTE: softmax over raw counts (not count/total normalization);
        # the 1e-8 shift is a no-op for softmax but kept for parity with
        # the original formulation.
        usage_prob = F.softmax(usage_hist + 1e-8, dim=0)
        current_entropy = -torch.sum(usage_prob * torch.log(usage_prob + 1e-8))

        # Fallback result if no proposal wins: unchanged codebook.
        best_codebook = self.C.weight.clone()

        for k in range(1, self.max_iterations + 1):
            # Reduce size gradually, never below one step's worth.
            new_size = max(codebook_size - k * self.step, self.step)
            if new_size >= codebook_size:
                continue  # proposal would not actually shrink the codebook

            c = self._compute_cost_matrix(usage_prob, new_size)
            transport_plan = self._sinkhorn_solver(usage_prob, new_size, c)
            new_codebook = self._generate_new_codebook(
                transport_plan, new_size, embed_dim
            )

            # Entropy of the transported usage mass (simplified proxy).
            new_usage = torch.sum(transport_plan, dim=0)
            new_usage_prob = F.softmax(new_usage + 1e-8, dim=0)
            new_entropy = -torch.sum(
                new_usage_prob * torch.log(new_usage_prob + 1e-8)
            )

            # Accept the first proposal with a significant entropy drop.
            entropy_decrease = current_entropy - new_entropy
            if entropy_decrease > 0.01:
                best_codebook = new_codebook
                break

        return best_codebook

    def _compute_cost_matrix(self, usage_prob, new_size):
        """Compute the [old_size, new_size] transport cost matrix.

        Cost is the absolute distance between normalized index positions,
        computed with a single broadcasted op instead of a Python double
        loop over every (i, j) pair.

        TODO: Improve cost matrix with semantic similarity.
        """
        device = usage_prob.device
        old_size = usage_prob.size(0)
        # Normalized positions of old and new codeword indices in [0, 1).
        old_pos = torch.arange(old_size, device=device, dtype=torch.float32)
        old_pos = old_pos / old_size
        new_pos = torch.arange(new_size, device=device, dtype=torch.float32)
        new_pos = new_pos / new_size
        # Broadcast to [old_size, new_size]: |i/old_size - j/new_size|.
        return torch.abs(old_pos.unsqueeze(1) - new_pos.unsqueeze(0))

    def _sinkhorn_solver(self, source_prob, target_size, cost_matrix):
        """Solve the OT problem with Sinkhorn; uniform-plan fallback.

        Args:
            source_prob: source marginal of shape [old_size].
            target_size: number of target bins (uniform target marginal).
            cost_matrix: [old_size, target_size] transport costs.

        Returns:
            Transport plan of shape [old_size, target_size].
        """
        device = source_prob.device
        source_size = source_prob.size(0)
        target_prob = torch.ones(target_size, device=device) / target_size

        # POT operates on numpy arrays.
        source_np = source_prob.detach().cpu().numpy()
        target_np = target_prob.detach().cpu().numpy()
        cost_np = cost_matrix.detach().cpu().numpy()

        if OT_AVAILABLE:
            try:
                transport_plan_np = ot.sinkhorn(
                    source_np, target_np, cost_np, self.reg, numItermax=100
                )
                transport_plan = torch.from_numpy(transport_plan_np).to(device)
            except Exception:
                # Fallback: uniform transport plan (rows sum to 1).
                transport_plan = torch.ones(
                    source_size, target_size, device=device
                )
                transport_plan = transport_plan / transport_plan.sum(
                    dim=1, keepdim=True
                )
        else:
            # Simple fallback when the POT library is not available.
            transport_plan = torch.ones(source_size, target_size, device=device)
            transport_plan = transport_plan / transport_plan.sum(
                dim=1, keepdim=True
            )

        return transport_plan

    def _generate_new_codebook(self, transport_plan, new_size, embed_dim):
        """Generate the new codebook from the transport plan.

        Each new codeword is the weighted average of the old codewords,
        with weights given by its (column-normalized) transport mass.
        Implemented as one matmul instead of a per-codeword loop.
        """
        device = self.C.weight.device
        # Normalize each column so its weights sum to 1.
        col_weights = transport_plan / (
            transport_plan.sum(dim=0, keepdim=True) + 1e-8
        )
        # [new_size, old_size] @ [old_size, embed_dim] -> [new_size, embed_dim]
        new_codebook = col_weights.t().to(device) @ self.C.weight
        return new_codebook