import os

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from physicsnemo.distributed import DistributedManager
from physicsnemo.launch.logging import LaunchLogger, PythonLogger
from physicsnemo.launch.utils import load_checkpoint, save_checkpoint
from physicsnemo.sym.hydra import to_absolute_path
from torch.nn.parallel import DistributedDataParallel
from torch.optim import AdamW

from dataloaders import Dedalus2DDataset, MHDDataloaderVecPot
from losses import LossMHDVecPot_PhysicsNeMo
from tfno import TFNO
from utils.plot_utils import plot_predictions_mhd, plot_predictions_mhd_plotly

dtype = torch.float
torch.set_default_dtype(dtype)


@hydra.main(
    version_base="1.3", config_path="config", config_name="train_mhd_vec_pot_tfno.yaml"
)
def main(cfg: DictConfig) -> None:
    DistributedManager.initialize()  # Only call this once in the entire script!
    dist = DistributedManager()  # call if required elsewhere

    cfg = OmegaConf.to_container(cfg, resolve=True)

    # initialize monitoring
    log = PythonLogger(name="mhd_pino")
    log.file_logging()

    log_params = cfg["log_params"]

    # Load config file parameters
    model_params = cfg["model_params"]
    dataset_params = cfg["dataset_params"]
    train_loader_params = cfg["train_loader_params"]
    val_loader_params = cfg["val_loader_params"]
    loss_params = cfg["loss_params"]
    optimizer_params = cfg["optimizer_params"]
    train_params = cfg["train_params"]
    load_ckpt = cfg["load_ckpt"]
    output_dir = cfg["output_dir"]
    output_dir = to_absolute_path(output_dir)
    os.makedirs(output_dir, exist_ok=True)
    data_dir = dataset_params["data_dir"]
    ckpt_path = train_params["ckpt_path"]

    # Construct dataloaders
    dataset_train = Dedalus2DDataset(
        data_dir,
        output_names=dataset_params["output_names"],
        field_names=dataset_params["field_names"],
        num_train=dataset_params["num_train"],
        num_test=dataset_params["num_test"],
        num=dataset_params["num"],
        use_train=True,
    )
    dataset_val = Dedalus2DDataset(
        data_dir,
        output_names=dataset_params["output_names"],
        field_names=dataset_params["field_names"],
        num_train=dataset_params["num_train"],
        num_test=dataset_params["num_test"],
        num=dataset_params["num"],
        use_train=False,
    )
    mhd_dataloader_train = MHDDataloaderVecPot(
        dataset_train,
        sub_x=dataset_params["sub_x"],
        sub_t=dataset_params["sub_t"],
        ind_x=dataset_params["ind_x"],
        ind_t=dataset_params["ind_t"],
    )
    mhd_dataloader_val = MHDDataloaderVecPot(
        dataset_val,
        sub_x=dataset_params["sub_x"],
        sub_t=dataset_params["sub_t"],
        ind_x=dataset_params["ind_x"],
        ind_t=dataset_params["ind_t"],
    )
    dataloader_train, sampler_train = mhd_dataloader_train.create_dataloader(
        batch_size=train_loader_params["batch_size"],
        shuffle=train_loader_params["shuffle"],
        num_workers=train_loader_params["num_workers"],
        pin_memory=train_loader_params["pin_memory"],
        distributed=dist.distributed,
    )
    dataloader_val, sampler_val = mhd_dataloader_val.create_dataloader(
        batch_size=val_loader_params["batch_size"],
        shuffle=val_loader_params["shuffle"],
        num_workers=val_loader_params["num_workers"],
        pin_memory=val_loader_params["pin_memory"],
        distributed=dist.distributed,
    )

    # define FNO model
    model = TFNO(
        in_channels=model_params["in_dim"],
        out_channels=model_params["out_dim"],
        decoder_layers=model_params["decoder_layers"],
        decoder_layer_size=model_params["fc_dim"],
        dimension=model_params["dimension"],
        latent_channels=model_params["layers"],
        num_fno_layers=model_params["num_fno_layers"],
        num_fno_modes=model_params["modes"],
        padding=[model_params["pad_z"], model_params["pad_y"], model_params["pad_x"]],
        rank=model_params["rank"],
        factorization=model_params["factorization"],
        fixed_rank_modes=model_params["fixed_rank_modes"],
        decomposition_kwargs=model_params["decomposition_kwargs"],
    ).to(dist.device)

    # Set up DistributedDataParallel if using more than a single process.
    # The `distributed` property of DistributedManager can be used to
    # check this.
    if dist.distributed:
        ddps = torch.cuda.Stream()
        with torch.cuda.stream(ddps):
            model = DistributedDataParallel(
                model,
                device_ids=[dist.local_rank],  # Set device_ids to the local rank
                # of this process on this node
                output_device=dist.device,
                broadcast_buffers=dist.broadcast_buffers,
                find_unused_parameters=dist.find_unused_parameters,
            )
        torch.cuda.current_stream().wait_stream(ddps)

    # Construct optimizer and scheduler
    optimizer = AdamW(
        model.parameters(),
        betas=optimizer_params["betas"],
        lr=optimizer_params["lr"],
        weight_decay=0.1,
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=optimizer_params["milestones"],
        gamma=optimizer_params["gamma"],
    )

    # Construct Loss class
    mhd_loss = LossMHDVecPot_PhysicsNeMo(**loss_params)

    # Load model from checkpoint (if exists)
    loaded_epoch = 0
    if load_ckpt:
        loaded_epoch = load_checkpoint(
            ckpt_path, model, optimizer, scheduler, device=dist.device
        )

    # Training Loop
    epochs = train_params["epochs"]
    ckpt_freq = train_params["ckpt_freq"]
    names = dataset_params["fields"]
    input_norm = torch.tensor(model_params["input_norm"]).to(dist.device)
    output_norm = torch.tensor(model_params["output_norm"]).to(dist.device)

    for epoch in range(max(1, loaded_epoch + 1), epochs + 1):
        with LaunchLogger(
            "train",
            epoch=epoch,
            num_mini_batch=len(dataloader_train),
            epoch_alert_freq=1,
        ) as log:
            if dist.distributed:
                sampler_train.set_epoch(epoch)

            # Train Loop
            model.train()
            for i, (inputs, outputs) in enumerate(dataloader_train):
                inputs = inputs.type(torch.FloatTensor).to(dist.device)
                outputs = outputs.type(torch.FloatTensor).to(dist.device)

                # Zero Gradients
                optimizer.zero_grad()

                # Compute Predictions
                pred = (
                    model((inputs / input_norm).permute(0, 4, 1, 2, 3)).permute(
                        0, 2, 3, 4, 1
                    )
                    * output_norm
                )

                # Compute Loss
                loss, loss_dict = mhd_loss(pred, outputs, inputs, return_loss_dict=True)

                # Compute Gradients for Back Propagation
                loss.backward()

                # Update Weights
                optimizer.step()

                log.log_minibatch(loss_dict)

            log.log_epoch({"Learning Rate": optimizer.param_groups[0]["lr"]})

        scheduler.step()

        with LaunchLogger("valid", epoch=epoch) as log:
            # Val loop
            model.eval()
            plot_count = 0
            with torch.no_grad():
                for i, (inputs, outputs) in enumerate(dataloader_val):
                    inputs = inputs.type(dtype).to(dist.device)
                    outputs = outputs.type(dtype).to(dist.device)

                    # Compute Predictions
                    pred = (
                        model((inputs / input_norm).permute(0, 4, 1, 2, 3)).permute(
                            0, 2, 3, 4, 1
                        )
                        * output_norm
                    )

                    # Compute Loss
                    loss, loss_dict = mhd_loss(
                        pred, outputs, inputs, return_loss_dict=True
                    )

                    log.log_minibatch(loss_dict)

                    # Get prediction plots to log
                    # Do for number of batches specified in the config file
                    if (i < log_params["log_num_plots"]) and (
                        epoch % log_params["log_plot_freq"] == 0
                    ):
                        # Add all predictions in batch
                        for j, _ in enumerate(pred):
                            # Make plots for each field
                            for index, name in enumerate(names):
                                # Generate figure
                                _ = plot_predictions_mhd_plotly(
                                    pred[j].cpu(),
                                    outputs[j].cpu(),
                                    inputs[j].cpu(),
                                    index=index,
                                    name=name,
                                )
                                plot_count += 1

                    # Get prediction plots and save images locally
                    if (i < 2) and (epoch % log_params["log_plot_freq"] == 0):
                        # Add all predictions in batch
                        for j, _ in enumerate(pred):
                            # Generate figure
                            plot_predictions_mhd(
                                pred[j].cpu(),
                                outputs[j].cpu(),
                                inputs[j].cpu(),
                                names=names,
                                save_path=os.path.join(
                                    output_dir,
                                    "MHD_physicsnemo" + "_" + str(dist.rank),
                                ),
                                save_suffix=i,
                            )

        if epoch % ckpt_freq == 0 and dist.rank == 0:
            save_checkpoint(ckpt_path, model, optimizer, scheduler, epoch=epoch)


if __name__ == "__main__":
    main()