Homework 3: Bridge Matching (BM), Conditional Bridge Matching (CBM), and Diffusion Schrödinger Bridge Matching (DSBM)

Tasks

  • Bures-Wasserstein Distance (4 pts)
  • Bridge Matching and Conditional Bridge Matching (8 pts)
  • Diffusion Schrödinger Bridge Matching (8 pts)

Training Cells

Bridge Matching Setup

# ===========
# Hyperparameters shared by the BM, CBM and DSBM training cells below.
BATCH_SIZE = 1024    # training batch size; any value that fits in GPU memory
ITERATIONS = 10000   # number of training steps (used as limit_train_batches)
SDE_STEPS = 50       # SDE steps / NFE used by BridgeMatching sampling; keep >= 50
HIDDEN_DIM = 128     # hidden width of every TimeConditionedResidualMLP
TIME_EMB_DIM = 32    # time-embedding width of every TimeConditionedResidualMLP
LR = 3e-4            # Adam learning rate
# ===========

# Single datamodule reused by all three trainers: training pairs are drawn
# from the input/target samplers, validation uses the fixed test sets.
datamodule = BenchmarkDataModule(
    input_sample=input_sampler.sample,
    target_sample=target_sampler.sample,
    testset_x=testset_x,
    testset_y=testset_y,
    batch_size=BATCH_SIZE,
    val_batch_size=256,
)

Bridge Matching

# ===========
# Unconditional BM endpoint predictor: g_theta(x_t, t) -> estimate of x_1.
# Only x_t is fed to the MLP, so the input width is DIM; the CBM variant
# concatenates x_0 as well and therefore needs input_dim = 2 * DIM.
map_net = TimeConditionedResidualMLP(
    hidden_dim=HIDDEN_DIM,
    time_emb_dim=TIME_EMB_DIM,
    input_dim=DIM,
    output_dim=DIM,
).to(DEVICE)
# ===========

# Optimizer factory: BridgeMatching receives the partially-applied Adam
# constructor, not an already-built optimizer instance.
map_opt = partial(
    torch.optim.Adam,
    weight_decay=1e-10,
    lr=LR,
)

module = BridgeMatching(
    map_net=map_net,
    map_opt=map_opt,
    is_cond=False,  # unconditional BM: the net sees only (x_t, t)
    eps=EPS,
    dim=DIM,
    sde_steps=SDE_STEPS,
    conditional_sampler=conditional_sampler,
    num_cond_samples=1000,
    grad_clip_norm=1.0,  # pass float("inf") to disable clipping
)

# Live sample plots for the BM run. The positional arguments (DIM, 1536, 4, 3)
# are forwarded unchanged; see BenchmarkPlotterCallback for their meaning.
plotter_callback = BenchmarkPlotterCallback(
    DIM,
    1536,
    4,
    3,
    benchmark_sample=conditional_sampler.sample,
    notebook_display_id_prefix="cot_bm",
)
# Plots the `map_loss` training metric (every 50 steps) and the
# `cond_bw_uvp` validation metric.
bm_loss_cb = LossCurveCallback(
    train_metrics=["map_loss"],
    val_metrics=["cond_bw_uvp"],
    train_log_every_n_steps=50,
    include_step_zero=True,
    notebook_display_id_prefix="cot_bm",
)
trainer = Trainer(
    accelerator="gpu",
    max_epochs=1,
    # Validate every 1000 steps, capped at ITERATIONS: Lightning raises a
    # ValueError when an integer val_check_interval exceeds the number of
    # training batches, which would otherwise happen for ITERATIONS < 1000.
    val_check_interval=min(1000, ITERATIONS),
    callbacks=[plotter_callback, bm_loss_cb],
    limit_train_batches=ITERATIONS,  # single epoch of at most ITERATIONS steps
    enable_checkpointing=False,
    log_every_n_steps=50,
)
trainer.fit(module, datamodule=datamodule)

Conditional Bridge Matching

# ===========
# CBM endpoint predictor: g_theta(x_t, t, x_0) -> estimate of x_1.
# x_t and x_0 are concatenated before entering the MLP, hence the doubled
# input width; this is the layout BridgeMatching expects when is_cond=True.
cbm_map_net = TimeConditionedResidualMLP(
    hidden_dim=HIDDEN_DIM,
    time_emb_dim=TIME_EMB_DIM,
    input_dim=DIM * 2,
    output_dim=DIM,
).to(DEVICE)
# ===========

# Optimizer factory for the CBM net (same settings as BM).
cbm_opt = partial(torch.optim.Adam, weight_decay=1e-10, lr=LR)

cbm_module = BridgeMatching(
    map_net=cbm_map_net,
    map_opt=cbm_opt,
    is_cond=True,  # conditional variant: the net also receives x_0
    eps=EPS,
    dim=DIM,
    sde_steps=SDE_STEPS,
    conditional_sampler=conditional_sampler,
    num_cond_samples=1000,
    grad_clip_norm=1.0,
)

# Live sample plots for the CBM run. The positional arguments (DIM, 1536, 4, 3)
# are forwarded unchanged; see BenchmarkPlotterCallback for their meaning.
cbm_plotter_callback = BenchmarkPlotterCallback(
    DIM,
    1536,
    4,
    3,
    benchmark_sample=conditional_sampler.sample,
    notebook_display_id_prefix="cot_cbm",
)
# Plots the `map_loss` training metric (every 50 steps) and the
# `cond_bw_uvp` validation metric.
cbm_loss_cb = LossCurveCallback(
    train_metrics=["map_loss"],
    val_metrics=["cond_bw_uvp"],
    train_log_every_n_steps=50,
    include_step_zero=True,
    notebook_display_id_prefix="cot_cbm",
)

cbm_trainer = Trainer(
    accelerator="gpu",
    max_epochs=1,
    # Validate every 1000 steps, capped at ITERATIONS: Lightning raises a
    # ValueError when an integer val_check_interval exceeds the number of
    # training batches, which would otherwise happen for ITERATIONS < 1000.
    val_check_interval=min(1000, ITERATIONS),
    callbacks=[cbm_plotter_callback, cbm_loss_cb],
    limit_train_batches=ITERATIONS,  # single epoch of at most ITERATIONS steps
    enable_checkpointing=False,
    log_every_n_steps=50,
)
cbm_trainer.fit(cbm_module, datamodule=datamodule)

Diffusion Schrödinger Bridge Matching

# ===========
# DSBM-specific hyperparameters.
# NFE handed to DSBM as `nfe_train`; presumably the step count for sampling
# during training (e.g. coupling regeneration) -- confirm in DSBM.
DSBM_TRAIN_NFE = 25
# Handed to DSBM as `warmup_steps`.
DSBM_WARMUP_STEPS = 200
# ===========

# ===========
# DSBM uses two architecturally identical but separate networks:
#   forward_net   f_theta(x_t, t) -> x_1   forward Markov projection
#   backward_net  b_phi(x_t, t)   -> x_0   backward Markov projection, used for
#                                          backward sampling and online IMF
#                                          coupling regeneration
# The generator builds them in that order (forward first, then backward).
forward_net, backward_net = (
    TimeConditionedResidualMLP(
        input_dim=DIM,
        output_dim=DIM,
        hidden_dim=HIDDEN_DIM,
        time_emb_dim=TIME_EMB_DIM,
    ).to(DEVICE)
    for _ in range(2)
)
# ===========

# Optimizer factory passed to DSBM as `opt_ctor` (same Adam settings as BM/CBM).
dsbm_opt = partial(torch.optim.Adam, weight_decay=1e-10, lr=LR)

dsbm_module = DSBM(
    forward_net=forward_net,
    backward_net=backward_net,
    opt_ctor=dsbm_opt,
    eps=EPS,
    dim=DIM,
    conditional_sampler=conditional_sampler,
    num_cond_samples=1000,
    nfe_train=DSBM_TRAIN_NFE,
    warmup_steps=DSBM_WARMUP_STEPS,
    grad_clip_norm=1.0,
)


# Live sample plots for the DSBM run. The positional arguments (DIM, 1536, 4, 3)
# are forwarded unchanged; see BenchmarkPlotterCallback for their meaning.
dsbm_plotter_callback = BenchmarkPlotterCallback(
    DIM,
    1536,
    4,
    3,
    benchmark_sample=conditional_sampler.sample,
    notebook_display_id_prefix="cot_dsbm",
)
# Plots the total / forward / backward DSBM training losses (every 50 steps)
# and the `cond_bw_uvp` validation metric.
dsbm_loss_cb = LossCurveCallback(
    train_metrics=["dsbm_total_loss", "dsbm_fw_loss", "dsbm_bw_loss"],
    val_metrics=["cond_bw_uvp"],
    train_log_every_n_steps=50,
    include_step_zero=True,
    notebook_display_id_prefix="cot_dsbm",
)

dsbm_trainer = Trainer(
    accelerator="gpu",
    max_epochs=1,
    # Validate every 1000 steps, capped at ITERATIONS: Lightning raises a
    # ValueError when an integer val_check_interval exceeds the number of
    # training batches, which would otherwise happen for ITERATIONS < 1000.
    val_check_interval=min(1000, ITERATIONS),
    callbacks=[dsbm_plotter_callback, dsbm_loss_cb],
    limit_train_batches=ITERATIONS,  # single epoch of at most ITERATIONS steps
    enable_checkpointing=False,
    log_every_n_steps=50,
)

dsbm_trainer.fit(dsbm_module, datamodule=datamodule)

Back to top