Homework 1: Neural OT vs. Large-Scale OT for Entropic OT

Tasks

  • Bures-Wasserstein Distance (4 pts)
  • Neural Optimal Transport (8 pts)
  • Large-Scale Optimal Transport (8 pts)

Training Cells

Neural Optimal Transport

# ===========
# Hyperparameters for Neural OT. These are untuned starting points; adjust
# them to your GPU memory and to the benchmark's DIM / EPS.
BATCH_SIZE = 128  # any size that fits on the GPU
ITERATIONS = 10_000  # number of training batches (passed as `limit_train_batches`)
NUM_LATENTS = 16  # presumably latent samples drawn per input point — TODO confirm in NOT
NUM_MAP_STEPS = 100  # the template requires a value > 50
HIDDEN_DIM = 256
LR = 1e-4
NUM_FLOWS = 4  # number of (autoregressive layer, permutation) pairs in map_net
EVAL_EVERY = 1000  # desired validation interval, in training batches
# ===========


def make_mlp(in_dim, out_dim, hidden_dim, num_hidden_layers=3):
    """Build a ReLU MLP: in_dim -> hidden_dim (num_hidden_layers times) -> out_dim.

    Args:
        in_dim: Size of the input features.
        out_dim: Size of the output (1 for a scalar potential).
        hidden_dim: Width of every hidden layer.
        num_hidden_layers: Number of hidden (Linear + ReLU) layers.

    Returns:
        A ``torch.nn.Sequential`` module.
    """
    layers = [torch.nn.Linear(in_dim, hidden_dim), torch.nn.ReLU()]
    for _ in range(num_hidden_layers - 1):
        layers += [torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.ReLU()]
    layers.append(torch.nn.Linear(hidden_dim, out_dim))
    return torch.nn.Sequential(*layers)


datamodule = BenchmarkDataModule(
    batch_size=BATCH_SIZE,
    val_batch_size=256,
    input_sample=input_sampler.sample,
    target_sample=target_sampler.sample,
    testset_x=testset_x,
    testset_y=testset_y
)

# ===========
# Conditional masked autoregressive flow, following the conditional-flow
# example in the normflows repository: each MAF layer is conditioned on a
# context vector and followed by a learnable LU permutation.
# Assumes the context passed by NOT is the input point x in R^DIM — TODO
# confirm against how NOT calls map_net.
# NOTE: MAF layers are fast in one direction and sequential (O(DIM) passes)
# in the other; if sampling becomes the bottleneck for large DIM, consider a
# coupling-based flow instead.
flows = []
for _ in range(NUM_FLOWS):
    flows += [
        nf.flows.MaskedAffineAutoregressive(
            DIM, HIDDEN_DIM, context_features=DIM, num_blocks=2
        ),
        nf.flows.LULinearPermute(DIM),
    ]
# ===========
base = nf.distributions.DiagGaussian(DIM, trainable=False)
map_net = nf.ConditionalNormalizingFlow(base, flows)

# ===========
# Scalar-valued potential network (the template requires a scalar output).
potential_net = make_mlp(DIM, 1, HIDDEN_DIM)
# ===========

map_opt = partial(torch.optim.Adam, lr=LR, weight_decay=1e-10)
potential_opt = partial(torch.optim.Adam, lr=LR, weight_decay=1e-10)

module = NOT(
    # The meaning of the literal 1000 is defined by NOT's signature, which is
    # not visible here.
    EPS, DIM, NUM_LATENTS, NUM_MAP_STEPS, 1000,
    conditional_sampler,
    map_net, potential_net,
    map_opt, potential_opt
)

# you can change here any variables
plotter_callback = BenchmarkPlotterCallback(
    DIM, 1536, 4, 3,
    benchmark_sample=conditional_sampler.sample
)
comet_logger = CometLogger(
    project="cot", name=f"not_dim_{DIM}_eps_{EPS}"
)
trainer = Trainer(
    accelerator="gpu", max_epochs=1,
    # Lightning raises a ValueError when an integer val_check_interval exceeds
    # the number of training batches, so clamp it for short runs.
    val_check_interval=min(EVAL_EVERY, ITERATIONS),
    callbacks=[plotter_callback],
    limit_train_batches=ITERATIONS,
    enable_checkpointing=False, logger=comet_logger
)
trainer.fit(module, datamodule=datamodule)
┏━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃   ┃ Name          ┃ Type                       ┃ Params ┃ Mode  ┃ FLOPs ┃
┡━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ cond_bw_uvp   │ CondBWUVP                  │      0 │ train │     0 │
│ 1 │ map_net       │ ConditionalNormalizingFlow │  534 K │ train │     0 │
│ 2 │ potential_net │ Sequential                 │  132 K │ train │     0 │
└───┴───────────────┴────────────────────────────┴────────┴───────┴───────┘
Trainable params: 667 K                                                                                            
Non-trainable params: 0                                                                                            
Total params: 667 K                                                                                                
Total estimated model params size (MB): 2                                                                          
Modules in train mode: 56                                                                                          
Modules in eval mode: 0                                                                                            
Total FLOPs: 0                                                                                                     

Large-Scale Optimal Transport

# ===========
# Hyperparameters for Large-Scale OT. These are untuned starting points;
# adjust them to your GPU memory and to the benchmark's DIM / EPS.
BATCH_SIZE = 1024  # any size that fits on the GPU
ITERATIONS = 20_000  # total training batches; LSOT also receives ITERATIONS // 2
HIDDEN_DIM = 256
LR = 1e-4
EVAL_EVERY = 1000  # desired validation interval, in training batches
# ===========


def make_mlp(in_dim, out_dim, hidden_dim, num_hidden_layers=3):
    """Build a ReLU MLP: in_dim -> hidden_dim (num_hidden_layers times) -> out_dim.

    Args:
        in_dim: Size of the input features.
        out_dim: Size of the output (1 for a scalar potential, DIM for a map).
        hidden_dim: Width of every hidden layer.
        num_hidden_layers: Number of hidden (Linear + ReLU) layers.

    Returns:
        A ``torch.nn.Sequential`` module.
    """
    layers = [torch.nn.Linear(in_dim, hidden_dim), torch.nn.ReLU()]
    for _ in range(num_hidden_layers - 1):
        layers += [torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.ReLU()]
    layers.append(torch.nn.Linear(hidden_dim, out_dim))
    return torch.nn.Sequential(*layers)


datamodule = BenchmarkDataModule(
    batch_size=BATCH_SIZE,
    val_batch_size=BATCH_SIZE,
    input_sample=input_sampler.sample,
    target_sample=target_sampler.sample,
    testset_x=testset_x,
    testset_y=testset_y
)

# ===========
# Two scalar-valued potentials (the template requires scalar outputs).
potential_start_net = make_mlp(DIM, 1, HIDDEN_DIM)
potential_end_net = make_mlp(DIM, 1, HIDDEN_DIM)
# ===========

# ===========
# Map network R^DIM -> R^DIM (the template requires a DIM-sized output).
map_net = make_mlp(DIM, DIM, HIDDEN_DIM)
# ===========

potentials_opt = partial(torch.optim.Adam, lr=LR)
map_opt = partial(torch.optim.Adam, lr=LR)

module = LSOT(
    # ITERATIONS // 2 is presumably the step at which LSOT switches from
    # fitting the potentials to fitting the map — TODO confirm against LSOT.
    EPS, DIM, ITERATIONS // 2, 1000, conditional_sampler,
    potential_start_net, potential_end_net, map_net,
    potentials_opt, map_opt
)

# you can change here any variables
plotter_callback = BenchmarkPlotterCallback(
    DIM, 1536, 4, 3,
    benchmark_sample=conditional_sampler.sample
)
comet_logger = CometLogger(
    project="cot", name=f"lsot_dim_{DIM}_eps_{EPS}"
)
trainer = Trainer(
    accelerator="gpu", max_epochs=1,
    # Lightning raises a ValueError when an integer val_check_interval exceeds
    # the number of training batches, so clamp it for short runs.
    val_check_interval=min(EVAL_EVERY, ITERATIONS),
    callbacks=[plotter_callback],
    limit_train_batches=ITERATIONS,
    enable_checkpointing=False, logger=comet_logger
)
trainer.fit(module, datamodule=datamodule)
┏━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃   ┃ Name                ┃ Type       ┃ Params ┃ Mode  ┃ FLOPs ┃
┡━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ cond_bw_uvp         │ CondBWUVP  │      0 │ train │     0 │
│ 1 │ potential_start_net │ Sequential │  132 K │ train │     0 │
│ 2 │ potential_end_net   │ Sequential │  132 K │ train │     0 │
│ 3 │ map_net             │ Sequential │  132 K │ train │     0 │
└───┴─────────────────────┴────────────┴────────┴───────┴───────┘
Trainable params: 398 K                                                                                            
Non-trainable params: 0                                                                                            
Total params: 398 K                                                                                                
Total estimated model params size (MB): 1                                                                          
Modules in train mode: 25                                                                                          
Modules in eval mode: 0                                                                                            
Total FLOPs: 0                                                                                                     

Back to top