Homework 2: ENOT vs. SCONES for Entropic Optimal Transport (OT)

Tasks

  • Bures-Wasserstein Distance (4 pts)
  • Entropic Neural Optimal Transport (8 pts)
  • SCONES (8 pts)

Training Cells

Entropic Neural Optimal Transport

# ===========
# Hyperparameters: batch size, train steps, EM resolution, map/potential step ratio, LR.
# Every value below must be filled in before this cell can run.
BATCH_SIZE =   # any size that fits on GPU; training batch size of `datamodule`
ITERATIONS =   # number of training batches (Trainer limit_train_batches, single epoch)
SDE_STEPS =   # >= 50; number of Euler-Maruyama steps, passed to ENOT as sde_steps
NUM_MAP_STEPS =   # Ratio of mapping updates to potential updates
INTEGRAL_SCALE =   # passed to ENOT as integral_scale; presumably weights the drift-energy integral term (confirm in ENOT)
HIDDEN_DIM =   # hidden-layer width of both map_net and potential_net
LR =   # Adam learning rate shared by map_opt and potential_opt
# ===========

# ENOT data: source batches drawn from `input_sampler`, target batches from
# `target_sampler`; the fixed test sets are forwarded unchanged
# (presumably used for validation metrics; confirm in BenchmarkDataModule).
datamodule = BenchmarkDataModule(
    input_sample=input_sampler.sample,
    target_sample=target_sampler.sample,
    testset_x=testset_x,
    testset_y=testset_y,
    batch_size=BATCH_SIZE,
    val_batch_size=256,
)

# ===========
# Drift network v_theta(x, t). The state x is concatenated with the time t
# before the first layer, so the input has DIM + 1 features; output is DIM-dim.
map_net = nn.Sequential(
    nn.Linear(DIM + 1, HIDDEN_DIM), nn.ReLU(),
    nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU(),
    nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU(),
    nn.Linear(HIDDEN_DIM, DIM),
)
# ===========

# ===========
# Terminal potential beta_phi(y): maps a DIM-dim point to a single scalar.
potential_net = nn.Sequential(
    nn.Linear(DIM, HIDDEN_DIM), nn.ReLU(),
    nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU(),
    nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU(),
    nn.Linear(HIDDEN_DIM, 1),
)
# ===========

# Optimizer factories (not optimizers yet): presumably ENOT instantiates them
# on the parameters of map_net / potential_net. The weight decay is negligible.
map_opt = partial(torch.optim.Adam, weight_decay=1e-10, lr=LR)
potential_opt = partial(torch.optim.Adam, weight_decay=1e-10, lr=LR)

module = ENOT(
    # Problem definition.
    dim=DIM,
    eps=EPS,
    # Euler-Maruyama discretisation of the SDE.
    sde_steps=SDE_STEPS,
    n_last_steps_without_noise=1,  # omit diffusion on last EM steps (noise-free tail; common ENOT setup)
    integral_scale=INTEGRAL_SCALE,
    # Networks, optimizers and the map/potential update ratio.
    map_net=map_net,
    potential_net=potential_net,
    map_opt=map_opt,
    potential_opt=potential_opt,
    num_map_steps=NUM_MAP_STEPS,
    grad_clip_norm=1.0,  # use float("inf") for no clipping
    # Conditional benchmark sampler and number of conditional samples.
    conditional_sampler=conditional_sampler,
    num_cond_samples=1000,
)

# Adjust plotter / logger / trainer options below if needed.
# Notebook plotting callback fed by `conditional_sampler.sample`.
# TODO(review): confirm the meaning of the positional args (DIM, 1536, 4, 3).
plotter_callback = BenchmarkPlotterCallback(
    DIM, 1536, 4, 3, benchmark_sample=conditional_sampler.sample, notebook_display_id_prefix="cot_enot"
)
# Curves of the map/potential training losses and the validation conditional BW-UVP.
enot_loss_cb = LossCurveCallback(
    train_metrics=["map_loss", "potential_loss"],
    val_metrics=["cond_bw_uvp"],
    # Clamp like the NCSN pretraining cell so very short runs still log.
    train_log_every_n_steps=min(50, ITERATIONS),
    include_step_zero=True,
    notebook_display_id_prefix="cot_enot",
)
comet_logger = CometLogger(
    project="cot",
    name=f"enot_dim_{DIM}_eps_{EPS}",
    mode="create",  # new experiment each run (default get_or_create reuses same experiment)
)
# Single epoch truncated to ITERATIONS training batches.
trainer = Trainer(
    # Fall back to CPU when no GPU is available (same as the NCSN pretraining cell).
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    # Pin to one device: the default devices="auto" would use every visible GPU
    # (multi-process DDP), which also changes the effective batch size.
    devices=1,
    max_epochs=1,
    # Lightning raises ValueError when an int val_check_interval exceeds the
    # number of training batches, so clamp it for runs with ITERATIONS < 1000.
    val_check_interval=min(1000, ITERATIONS),
    callbacks=[plotter_callback, enot_loss_cb],
    limit_train_batches=ITERATIONS,
    enable_checkpointing=False,
    logger=comet_logger,
    log_every_n_steps=min(50, ITERATIONS),
)
trainer.fit(module, datamodule=datamodule)

Score Pretraining (NCSN)

# ===========
# NCSN pretrain: hyperparameters, score MLP (DIM+1 input), Trainer.fit.
# The values below must be filled in before this cell can run.
NCSN_BATCH_SIZE =   # more samples/step => lower-variance score-matching gradients (raise if GPU memory allows)
NCSN_STEPS =     # number of training batches (limit_train_batches); also caps the logging/validation intervals
NCSN_LR =   # Adam learning rate of the score network
NCSN_NOISE_LEVELS =   # passed to NCSN as noise_levels; its resulting `sigmas` are exported as NOISE_LEVELS after training (check NCSN for the expected type)
HIDDEN_DIM_SCORE =   # hidden-layer width of score_mlp

# Optimizer factory handed to NCSN (Adam with negligible weight decay).
score_opt_ncsn = partial(torch.optim.Adam, lr=NCSN_LR, weight_decay=1e-10)
# Target GM as torch.distributions.MixtureSameFamily; NCSN validation calls .log_prob(y) on this object.
reference_data = benchmark_gaussian_mixture_as_mixture_same_family(...)  # TODO: replace the `...` placeholder with the actual arguments
# ===========

# Score network: DIM + 1 input features (the extra feature presumably encodes
# the noise level; confirm in NCSN) -> DIM-dim score estimate.
score_mlp = nn.Sequential(
    nn.Linear(DIM + 1, HIDDEN_DIM_SCORE), nn.ReLU(True),
    nn.Linear(HIDDEN_DIM_SCORE, HIDDEN_DIM_SCORE), nn.ReLU(True),
    nn.Linear(HIDDEN_DIM_SCORE, HIDDEN_DIM_SCORE), nn.ReLU(True),
    nn.Linear(HIDDEN_DIM_SCORE, DIM),
).to(DEVICE)

# Lightning module wrapping the score network; trained on target samples and
# validated against `reference_data`.
ncsn_module = NCSN(
    score_mlp,
    reference_distribution=reference_data,
    noise_levels=NCSN_NOISE_LEVELS,
    score_opt=score_opt_ncsn,
    data_sampler=target_sampler,
).to(DEVICE)

# Score-pretraining data; validation batches are capped at 512 samples.
datamodule_ncsn = BenchmarkDataModule(
    input_sample=input_sampler.sample,
    target_sample=target_sampler.sample,
    testset_x=testset_x,
    testset_y=testset_y,
    batch_size=NCSN_BATCH_SIZE,
    val_batch_size=min(NCSN_BATCH_SIZE, 512),
)

# Logging/validation cadences are clamped to NCSN_STEPS so short runs still report.
ncsn_loss_cb = LossCurveCallback(
    notebook_display_id_prefix="cot_ncsn",
    include_step_zero=True,
    train_metrics=["train_loss"],
    val_metrics=["val_score_ref_mse"],
    train_log_every_n_steps=min(50, NCSN_STEPS),
)

trainer_ncsn = Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    max_epochs=1,
    limit_train_batches=NCSN_STEPS,
    val_check_interval=min(500, NCSN_STEPS),
    log_every_n_steps=min(50, NCSN_STEPS),
    callbacks=[ncsn_loss_cb],
    logger=False,
    enable_checkpointing=False,
)
trainer_ncsn.fit(ncsn_module, datamodule=datamodule_ncsn)

# Freeze the pretrained score model in eval mode and export its sigma schedule
# as a tuple of Python floats for the SCONES cell.
ncsn_module.eval()
NOISE_LEVELS = tuple(map(float, ncsn_module.sigmas.cpu().tolist()))

SCONES

# ===========
# SCONES train: batch size, steps, nets, Langevin schedule (reuse NCSN noise levels).
# The values below must be filled in before this cell can run; note that they
# overwrite the ENOT cell's BATCH_SIZE / ITERATIONS / HIDDEN_DIM / LR.
BATCH_SIZE =   # train and validation batch size of the SCONES datamodule
ITERATIONS =   # number of training batches (limit_train_batches)
HIDDEN_DIM =   # hidden-layer width of phi_net and psi_net
LR =   # Adam learning rate of dual_opt
SAMPLING_LR =   # Langevin step scaling in SCONES.sample (annealed schedule)
LANGEVIN_INNER =   # Langevin steps per noise level (langevin_steps_per_level)
# Use the same noise levels as the pretrained NCSN (``NOISE_LEVELS`` from the pretrain cell).
# Recomputed from ncsn_module, so the pretraining cell must have run first.
NOISE_LEVELS = tuple(float(x) for x in ncsn_module.sigmas.cpu().tolist())
# ===========

# SCONES data: same samplers and test sets as ENOT; validation uses the
# training batch size.
datamodule = BenchmarkDataModule(
    input_sample=input_sampler.sample,
    target_sample=target_sampler.sample,
    testset_x=testset_x,
    testset_y=testset_y,
    batch_size=BATCH_SIZE,
    val_batch_size=BATCH_SIZE,
)

# Two scalar-output MLPs (DIM -> 1, three ReLU hidden layers each) that SCONES
# optimises with dual_opt.
phi_net = nn.Sequential(
    nn.Linear(DIM, HIDDEN_DIM), nn.ReLU(True),
    nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU(True),
    nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU(True),
    nn.Linear(HIDDEN_DIM, 1),
)
psi_net = nn.Sequential(
    nn.Linear(DIM, HIDDEN_DIM), nn.ReLU(True),
    nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU(True),
    nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU(True),
    nn.Linear(HIDDEN_DIM, 1),
)

dual_opt = partial(torch.optim.Adam, lr=LR)

# The leading arguments are positional and must follow SCONES' signature order.
# TODO(review): the literal 1000 presumably mirrors ENOT's num_cond_samples; confirm.
module = SCONES(
    EPS, DIM, 1000, conditional_sampler,
    phi_net, psi_net,
    ncsn_module,  # pretrained score model from the NCSN cell
    dual_opt,
    dual_sched=None,
    # Annealed Langevin sampling over the NCSN noise levels.
    noise_levels=NOISE_LEVELS,
    sampling_lr=SAMPLING_LR,
    langevin_steps_per_level=LANGEVIN_INNER,
    tweedie_denoise=True,
)

# Notebook plotting callback fed by `conditional_sampler.sample`.
# TODO(review): confirm the meaning of the positional args (DIM, 1536, 4, 3).
plotter_callback = BenchmarkPlotterCallback(
    DIM, 1536, 4, 3, benchmark_sample=conditional_sampler.sample, notebook_display_id_prefix="cot_scones"
)
# Curves of the dual objective / h_star_mean (train) and conditional BW-UVP (val).
loss_curve_callback = LossCurveCallback(
    train_metrics=["dual", "h_star_mean"],
    val_metrics=["cond_bw_uvp"],
    # Clamp like the NCSN pretraining cell so very short runs still log.
    train_log_every_n_steps=min(50, ITERATIONS),
    include_step_zero=True,
    notebook_display_id_prefix="cot_scones",
)
comet_logger = CometLogger(
    project="cot",
    name=f"scones_dim_{DIM}_eps_{EPS}",
    mode="create",  # new experiment each run (default get_or_create reuses same experiment)
)
# Single epoch truncated to ITERATIONS training batches.
trainer = Trainer(
    # Fall back to CPU when no GPU is available (same as the NCSN pretraining cell).
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    # Pin to one device: the default devices="auto" would use every visible GPU
    # (multi-process DDP), which also changes the effective batch size.
    devices=1,
    max_epochs=1,
    # Lightning raises ValueError when an int val_check_interval exceeds the
    # number of training batches, so clamp it for runs with ITERATIONS < 1000.
    val_check_interval=min(1000, ITERATIONS),
    callbacks=[plotter_callback, loss_curve_callback],
    limit_train_batches=ITERATIONS,
    enable_checkpointing=False,
    logger=comet_logger,
    log_every_n_steps=min(50, ITERATIONS),
)
trainer.fit(module, datamodule=datamodule)

Back to top