# ===========
# Global hyperparameters shared by the BM, CBM and DSBM runs:
#   BATCH_SIZE   - training batch size; any size that fits in GPU memory.
#   ITERATIONS   - training-step budget (used as the trainers' limit_train_batches).
#   SDE_STEPS    - SDE discretisation steps given to BridgeMatching
#                  (posterior-sampling NFE); keep it >= 50.
#   HIDDEN_DIM   - hidden width of the TimeConditionedResidualMLP networks.
#   TIME_EMB_DIM - time-embedding width of those networks.
#   LR           - Adam learning rate.
BATCH_SIZE = 1024  # any size that fits on GPU
ITERATIONS = 10_000
SDE_STEPS = 50  # >= 50
HIDDEN_DIM = 128
TIME_EMB_DIM = 32
LR = 3e-4
# ===========
datamodule = BenchmarkDataModule(
batch_size=BATCH_SIZE,
val_batch_size=256,
input_sample=input_sampler.sample,
target_sample=target_sampler.sample,
testset_x=testset_x,
testset_y=testset_y,
)Homework 3: BM, CBM, and DSBM
Tasks
- Bures-Wasserstein Distance (4 pts)
- Bridge Matching and Conditional Bridge Matching (8 pts)
- Diffusion Schrödinger Bridge Matching (8 pts)
Training Cells
Bridge Matching Setup
Bridge Matching
# ===========
# Unconditional BM endpoint predictor g_theta(x_t, t) -> x_1.
# It is fed only x_t (plus time), hence input_dim = DIM; the conditional
# variant (CBM) concatenates x_0 and therefore needs input_dim = 2 * DIM.
map_net = TimeConditionedResidualMLP(
    input_dim=DIM,
    output_dim=DIM,
    hidden_dim=HIDDEN_DIM,
    time_emb_dim=TIME_EMB_DIM,
)
map_net = map_net.to(DEVICE)
# ===========
# Optimizer factory: a partial over Adam, handed to BridgeMatching as map_opt
# (presumably called there with map_net's parameters). weight_decay=1e-10 is
# effectively no decay.
map_opt = partial(torch.optim.Adam, lr=LR, weight_decay=1e-10)
module = BridgeMatching(
    eps=EPS,
    dim=DIM,
    sde_steps=SDE_STEPS,
    num_cond_samples=1000,
    conditional_sampler=conditional_sampler,
    map_net=map_net,
    map_opt=map_opt,
    grad_clip_norm=1.0,  # use float("inf") for no clipping
    is_cond=False,  # unconditional BM: map_net sees x_t only (input_dim = DIM)
)
# NOTE(review): the positional arguments after DIM (1536, 4, 3) are interpreted
# by BenchmarkPlotterCallback; consider passing them by keyword for clarity.
plotter_callback = BenchmarkPlotterCallback(
    DIM,
    1536,
    4,
    3,
    benchmark_sample=conditional_sampler.sample,
    notebook_display_id_prefix="cot_bm",
)
bm_loss_cb = LossCurveCallback(
    train_metrics=["map_loss"],
    val_metrics=["cond_bw_uvp"],
    train_log_every_n_steps=50,
    include_step_zero=True,
    notebook_display_id_prefix="cot_bm",
)
trainer = Trainer(
    accelerator="gpu",
    max_epochs=1,
    # Lightning rejects an int val_check_interval larger than the number of
    # training batches, so clamp it when ITERATIONS is lowered for quick runs.
    val_check_interval=min(1000, ITERATIONS),
    callbacks=[plotter_callback, bm_loss_cb],
    # With max_epochs=1 this caps training at ITERATIONS batches.
    limit_train_batches=ITERATIONS,
    enable_checkpointing=False,
    log_every_n_steps=50,
)
trainer.fit(module, datamodule=datamodule)


Conditional Bridge Matching
# ===========
# Conditional endpoint predictor g_theta(x_t, t, x_0) -> x_1.
# (x_t, x_0) are concatenated at the input, so input_dim is 2 * DIM; this
# pairs with the is_cond=True branch of BridgeMatching.training_step.
cbm_map_net = TimeConditionedResidualMLP(
    input_dim=2 * DIM,
    output_dim=DIM,
    hidden_dim=HIDDEN_DIM,
    time_emb_dim=TIME_EMB_DIM,
)
cbm_map_net = cbm_map_net.to(DEVICE)
# ===========
# Optimizer factory for the conditional model (same settings as the BM run);
# weight_decay=1e-10 is effectively no decay.
cbm_opt = partial(torch.optim.Adam, lr=LR, weight_decay=1e-10)
cbm_module = BridgeMatching(
    eps=EPS,
    dim=DIM,
    sde_steps=SDE_STEPS,
    num_cond_samples=1000,
    conditional_sampler=conditional_sampler,
    map_net=cbm_map_net,
    map_opt=cbm_opt,
    grad_clip_norm=1.0,  # use float("inf") for no clipping
    is_cond=True,  # conditional BM: map_net receives (x_t, x_0) concatenated
)
# NOTE(review): the positional arguments after DIM (1536, 4, 3) are interpreted
# by BenchmarkPlotterCallback; consider passing them by keyword for clarity.
cbm_plotter_callback = BenchmarkPlotterCallback(
    DIM,
    1536,
    4,
    3,
    benchmark_sample=conditional_sampler.sample,
    notebook_display_id_prefix="cot_cbm",
)
cbm_loss_cb = LossCurveCallback(
    train_metrics=["map_loss"],
    val_metrics=["cond_bw_uvp"],
    train_log_every_n_steps=50,
    include_step_zero=True,
    notebook_display_id_prefix="cot_cbm",
)
cbm_trainer = Trainer(
    accelerator="gpu",
    max_epochs=1,
    # Lightning rejects an int val_check_interval larger than the number of
    # training batches, so clamp it when ITERATIONS is lowered for quick runs.
    val_check_interval=min(1000, ITERATIONS),
    callbacks=[cbm_plotter_callback, cbm_loss_cb],
    # With max_epochs=1 this caps training at ITERATIONS batches.
    limit_train_batches=ITERATIONS,
    enable_checkpointing=False,
    log_every_n_steps=50,
)
cbm_trainer.fit(cbm_module, datamodule=datamodule)


Diffusion Schrödinger Bridge Matching
# ===========
# DSBM-specific hyperparameters, both passed to the DSBM constructor:
#   DSBM_WARMUP_STEPS - given as `warmup_steps`; its exact meaning lives inside
#                       DSBM. NOTE(review): presumably the number of initial
#                       steps before online coupling regeneration starts; confirm.
#   DSBM_TRAIN_NFE    - given as `nfe_train`: the posterior-sampling NFE used
#                       during training.
DSBM_WARMUP_STEPS = 200
DSBM_TRAIN_NFE = 25
# ===========
# ===========
# Forward Markov projection f_theta(x_t, t) -> x_1: predicts the target
# endpoint from the current state and time (input and output both DIM-wide).
forward_net = TimeConditionedResidualMLP(
    input_dim=DIM,
    output_dim=DIM,
    hidden_dim=HIDDEN_DIM,
    time_emb_dim=TIME_EMB_DIM,
)
forward_net = forward_net.to(DEVICE)
# ===========
# ===========
# Backward Markov projection b_phi(x_t, t) -> x_0: predicts the source
# endpoint from the current state and time. Per the DSBM setup it drives
# backward sampling and the online IMF coupling regeneration.
backward_net = TimeConditionedResidualMLP(
    input_dim=DIM,
    output_dim=DIM,
    hidden_dim=HIDDEN_DIM,
    time_emb_dim=TIME_EMB_DIM,
)
backward_net = backward_net.to(DEVICE)
# ===========
# Optimizer factory handed to DSBM as opt_ctor (presumably used for both the
# forward and backward nets); weight_decay=1e-10 is effectively no decay.
dsbm_opt = partial(torch.optim.Adam, lr=LR, weight_decay=1e-10)
dsbm_module = DSBM(
    eps=EPS,
    dim=DIM,
    num_cond_samples=1000,
    conditional_sampler=conditional_sampler,
    forward_net=forward_net,
    backward_net=backward_net,
    opt_ctor=dsbm_opt,
    nfe_train=DSBM_TRAIN_NFE,
    warmup_steps=DSBM_WARMUP_STEPS,
    grad_clip_norm=1.0,  # use float("inf") for no clipping
)
# NOTE(review): the positional arguments after DIM (1536, 4, 3) are interpreted
# by BenchmarkPlotterCallback; consider passing them by keyword for clarity.
dsbm_plotter_callback = BenchmarkPlotterCallback(
    DIM,
    1536,
    4,
    3,
    benchmark_sample=conditional_sampler.sample,
    notebook_display_id_prefix="cot_dsbm",
)
dsbm_loss_cb = LossCurveCallback(
    train_metrics=["dsbm_total_loss", "dsbm_fw_loss", "dsbm_bw_loss"],
    val_metrics=["cond_bw_uvp"],
    train_log_every_n_steps=50,
    include_step_zero=True,
    notebook_display_id_prefix="cot_dsbm",
)
dsbm_trainer = Trainer(
    accelerator="gpu",
    max_epochs=1,
    # Lightning rejects an int val_check_interval larger than the number of
    # training batches, so clamp it when ITERATIONS is lowered for quick runs.
    val_check_interval=min(1000, ITERATIONS),
    callbacks=[dsbm_plotter_callback, dsbm_loss_cb],
    # With max_epochs=1 this caps training at ITERATIONS batches.
    limit_train_batches=ITERATIONS,
    enable_checkpointing=False,
    log_every_n_steps=50,
)
dsbm_trainer.fit(dsbm_module, datamodule=datamodule)

