# ===========
# your code
# Hyperparameters — chosen to satisfy the template's inline constraints;
# tune freely for your hardware/benchmark.
BATCH_SIZE = 512        # any size that fits to gpu
ITERATIONS = 10_000     # total training batches (fed to limit_train_batches)
NUM_LATENTS = 4         # latent samples drawn per input point
NUM_MAP_STEPS = 100     # inner map-network steps; template requires > 50
HIDDEN_DIM = 128        # hidden width of the potential MLP
LR = 1e-4               # Adam learning rate for both networks
# ===========
# Data module wiring: training batches come from the samplers' `sample`
# callables; validation/test use the fixed held-out tensors.
dm_config = dict(
    batch_size=BATCH_SIZE,
    val_batch_size=256,
    input_sample=input_sampler.sample,
    target_sample=target_sampler.sample,
    testset_x=testset_x,
    testset_y=testset_y,
)
datamodule = BenchmarkDataModule(**dm_config)
flows = []
# ===========
# your code
# Conditional flow stack (the template suggests any conditional flow from the
# normflows repo). Each masked autoregressive affine layer is conditioned on
# the input sample via `context_features`, with a random-but-fixed permutation
# between layers so every coordinate gets transformed.
# NOTE(review): assumes the conditioning context has size DIM (same as the
# points themselves) — confirm against what the benchmark passes as context.
NUM_FLOW_LAYERS = 8
for _ in range(NUM_FLOW_LAYERS):
    flows.append(
        nf.flows.MaskedAffineAutoregressive(
            DIM, HIDDEN_DIM, context_features=DIM
        )
    )
    flows.append(nf.flows.Permute(DIM, mode="shuffle"))
# ===========
# Standard-normal base distribution (not trained) + conditional flow model.
base = nf.distributions.DiagGaussian(DIM, trainable=False)
map_net = nf.ConditionalNormalizingFlow(base, flows)
# ===========
# your code
# `potential_net`: the simple MLP requested by the template; scalar output
# per sample (NOTE: output is scalar).
# NOTE(review): assumes the potential consumes DIM-dimensional points only —
# confirm whether the condition is concatenated (input would then be 2 * DIM).
potential_net = torch.nn.Sequential(
    torch.nn.Linear(DIM, HIDDEN_DIM),
    torch.nn.ReLU(),
    torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
    torch.nn.ReLU(),
    torch.nn.Linear(HIDDEN_DIM, 1),
)
# ===========
# Optimizer factories: the training module instantiates each on its network's
# parameters. The tiny weight decay only keeps weights bounded, it is not a
# real regularizer.
map_opt = partial(torch.optim.Adam, lr=LR, weight_decay=1e-10)
potential_opt = partial(torch.optim.Adam, lr=LR, weight_decay=1e-10)
# Training module pairing `map_net` (the stochastic transport map) with
# `potential_net` — presumably a Neural Optimal Transport adversarial scheme,
# given the NOT name and the map/potential pair.
# NOTE(review): the 5th positional argument (1000) looks like an evaluation
# or logging interval — confirm against the NOT constructor's signature.
module = NOT(
    EPS, DIM, NUM_LATENTS, NUM_MAP_STEPS, 1000,
    conditional_sampler,
    map_net, potential_net,
    map_opt, potential_opt
)
# you can change here any variables
# NOTE(review): positional args to the plotter are presumably
# (dim, n_samples, rows, cols) — confirm against BenchmarkPlotterCallback.
plot_args = (DIM, 1536, 4, 3)
plotter_callback = BenchmarkPlotterCallback(
    *plot_args,
    benchmark_sample=conditional_sampler.sample,
)
# Experiment tracking: run name encodes the benchmark dimension and epsilon.
run_name = f"not_dim_{DIM}_eps_{EPS}"
comet_logger = CometLogger(project="cot", name=run_name)
# Single-"epoch" run: ITERATIONS batches total, validating (and plotting)
# every 1000 batches; checkpointing disabled, metrics go to Comet.
trainer_config = dict(
    accelerator="gpu",
    max_epochs=1,
    val_check_interval=1000,
    callbacks=[plotter_callback],
    limit_train_batches=ITERATIONS,
    enable_checkpointing=False,
    logger=comet_logger,
)
trainer = Trainer(**trainer_config)
trainer.fit(module, datamodule=datamodule)