# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import traceback
import matplotlib.pyplot as plt
import matplotlib
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
import imageio
import plotly
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from IPython.display import HTML, display


def _apply_rc_params(font_size=None, sci_limits=None):
    """Update the global matplotlib rcParams for any option that is not None.

    Args:
        font_size: Value for ``font.size``; skipped when None.
        sci_limits: Value for ``axes.formatter.limits``; skipped when None.
    """
    if font_size is not None:
        plt.rcParams.update({"font.size": font_size})
    if sci_limits is not None:
        plt.rcParams.update({"axes.formatter.limits": sci_limits})


def plot_spectra_mhd(
    k,
    pred_spectra_kin,
    true_spectra_kin,
    pred_spectra_mag,
    true_spectra_mag,
    index_t=-1,
    name="Re100",
    save_path=None,
    save_suffix=None,
    font_size=None,
    sci_limits=None,
    style_kin_pred="b-",
    style_kin_true="k-",
    style_mag_pred="b--",
    style_mag_true="k--",
    xmin=0,
    xmax=200,
    ymin=1e-10,
    ymax=None,
):
    """Plot predicted and true kinetic/magnetic spectra on log-log axes.

    Args:
        k: Values used as the x-coordinates of every curve.
        pred_spectra_kin, true_spectra_kin: Predicted / true kinetic spectra;
            entry ``[index_t]`` of each is plotted.
        pred_spectra_mag, true_spectra_mag: Predicted / true magnetic spectra;
            entry ``[index_t]`` of each is plotted.
        index_t: Index selecting which entry of each spectrum to plot.
        name: Text placed inside ``$...$`` in the title.
        save_path: When given, the figure is saved to
            ``{save_path}_spectra.png`` or ``{save_path}_spectra_{save_suffix}.png``.
        save_suffix: Optional suffix for the saved file name.
        font_size: When given, sets the global ``font.size`` rcParam.
        sci_limits: When given, sets the global ``axes.formatter.limits`` rcParam.
        style_kin_pred, style_kin_true, style_mag_pred, style_mag_true:
            Matplotlib format strings for the four curves.
        xmin, xmax, ymin, ymax: Axis limits passed to ``plt.axis``.

    Returns:
        The created matplotlib figure; it is left open for the caller.
    """
    _apply_rc_params(font_size, sci_limits)

    E_kin_pred = pred_spectra_kin[index_t]
    E_mag_pred = pred_spectra_mag[index_t]
    E_kin_true = true_spectra_kin[index_t]
    E_mag_true = true_spectra_mag[index_t]

    fig = plt.figure(figsize=(6, 5))
    plt.loglog(k, E_kin_pred, style_kin_pred, label="$E_{kin}$ Pred")
    plt.loglog(k, E_kin_true, style_kin_true, label="$E_{kin}$ True")
    plt.loglog(k, E_mag_pred, style_mag_pred, label="$E_{mag}$ Pred")
    plt.loglog(k, E_mag_true, style_mag_true, label="$E_{mag}$ True")
    plt.xlabel("k")
    plt.ylabel("E(k)")
    plt.axis([xmin, xmax, ymin, ymax])
    plt.title(f"Spectra ${name}$")
    plt.legend(loc="upper right")

    if save_path is not None:
        suffix = f"_{save_suffix}" if save_suffix is not None else ""
        fig.savefig(f"{save_path}_spectra{suffix}.png", bbox_inches="tight")

    return fig


def plot_predictions_mhd(
    pred,
    true,
    inputs,
    index_t=-1,
    names=None,
    save_path=None,
    save_suffix=None,
    font_size=None,
    sci_limits=None,
    shading="auto",
    cmap="jet",
):
    """Plot initial condition, exact, predicted and error images for each field.

    One row of four panels is drawn per entry of ``names``. The code reads the
    coordinates from ``inputs`` as follows: t from ``inputs[index_t, 0, 0, 0]``,
    x from ``inputs[0, :, 0, 1]``, y from ``inputs[0, 0, :, 2]``, and the
    initial fields from channels ``3:`` of ``inputs[0]``.

    Args:
        pred, true: Predicted / true fields indexed as ``[t, ..., field]``.
        inputs: Model inputs laid out as described above.
        index_t: Time index of the frame to plot.
        names: Field names, one per row; field ``i`` of ``pred``/``true`` is
            used for ``names[i]``. Defaults to no fields.
        save_path: When given, the figure is saved to ``{save_path}.png`` or
            ``{save_path}_{save_suffix}.png``.
        save_suffix: Optional suffix for the saved file name.
        font_size, sci_limits: Optional global rcParam overrides.
        shading, cmap: Passed to ``plt.pcolormesh``.

    The figure is closed before returning; nothing is returned.
    """
    # None instead of a mutable [] default; behaves identically for callers.
    names = [] if names is None else names
    _apply_rc_params(font_size, sci_limits)

    n_rows = len(names)
    fig = plt.figure(figsize=(24, 5 * n_rows))

    # Coordinates and the time stamp do not depend on the field, so compute
    # them once rather than once per row.
    x = inputs[0, :, 0, 1]
    y = inputs[0, 0, :, 2]
    X, Y = torch.meshgrid(x, y, indexing="ij")
    t = inputs[index_t, 0, 0, 0]
    initial_data = inputs[0, ..., 3:]

    for row, name in enumerate(names):
        u0 = initial_data[..., row]
        u_true = true[index_t, ..., row]
        u_pred = pred[index_t, ..., row]
        # NOTE(review): this is the signed difference pred - true, although
        # the panel title says "Absolute Error".
        u_err = u_pred - u_true

        panels = (
            (u0, f"Initial Condition ${name}_0(x,y)$"),
            (u_true, f"Exact ${name}(x,y,t={t:.2f})$"),
            (u_pred, f"Predict ${name}(x,y,t={t:.2f})$"),
            (u_err, f"Absolute Error ${name}(x,y,t={t:.2f})$"),
        )
        for col, (data, title) in enumerate(panels, start=1):
            plt.subplot(n_rows, 4, row * 4 + col)
            plt.pcolormesh(X, Y, data, cmap=cmap, shading=shading)
            plt.colorbar()
            plt.title(title)
            plt.axis("square")
            plt.axis("off")

    # A single layout pass once every panel exists, instead of one per subplot.
    plt.tight_layout()

    if save_path is not None:
        suffix = f"_{save_suffix}" if save_suffix is not None else ""
        fig.savefig(f"{save_path}{suffix}.png", bbox_inches="tight")
    plt.close(fig)


def generate_movie_2D(
    preds_y,
    test_y,
    test_x,
    key=0,
    plot_title="",
    field=0,
    val_cbar_index=-1,
    err_cbar_index=-1,
    val_clim=None,
    err_clim=None,
    font_size=None,
    movie_dir="",
    movie_name="movie.gif",
    frame_basename="movie",
    frame_ext="jpg",
    cmap="jet",
    shading="gouraud",
    remove_frames=True,
):
    """Render exact, predicted and error fields frame by frame and build a movie.

    ``preds_y[key][..., field]`` and ``test_y[key][..., field]`` must be 3-D
    (frames first). t, x and y are read from ``test_x[key]`` at
    ``[:, 0, 0, 0]``, ``[0, :, 0, 1]`` and ``[0, 0, :, 2]``.

    Args:
        preds_y, test_y, test_x: Containers indexed by ``key``.
        key: Which sample to render.
        plot_title: Text inserted into every panel title.
        field: Field (last-axis index) to render.
        val_cbar_index, err_cbar_index: Frames used to initialise the
            value/error colorbars (and colour limits when not given).
        val_clim, err_clim: Colour limits for the value panels / error panel;
            when None they are taken from the reference frames.
        font_size: Optional global ``font.size`` override.
        movie_dir: Directory for frames and the movie. When empty, frames are
            drawn but nothing is written.
        movie_name: File name of the movie inside ``movie_dir``.
        frame_basename, frame_ext: Frame files are named
            ``{frame_basename}-{i:03}.{frame_ext}``.
        cmap, shading: Passed to ``pcolormesh``.
        remove_frames: Delete the frame images after the movie is written.
    """
    frame_files = []
    if movie_dir:
        os.makedirs(movie_dir, exist_ok=True)
    _apply_rc_params(font_size)

    pred = preds_y[key][..., field]
    true = test_y[key][..., field]
    inputs = test_x[key]
    error = pred - true
    # Unpacking also enforces that the selected field is 3-D.
    num_frames, _, _ = pred.shape

    t = inputs[:, 0, 0, 0]
    x = inputs[0, :, 0, 1]
    y = inputs[0, 0, :, 2]
    X, Y = torch.meshgrid(x, y, indexing="ij")

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

    # Draw one reference frame per panel. Colour limits default to those of
    # the reference frames, and the colorbars created here are not redrawn
    # per frame.
    pcm_true = ax1.pcolormesh(
        X, Y, true[val_cbar_index], cmap=cmap, label="true", shading=shading
    )
    pcm_pred = ax2.pcolormesh(
        X, Y, pred[val_cbar_index], cmap=cmap, label="pred", shading=shading
    )
    pcm_err = ax3.pcolormesh(
        X, Y, error[err_cbar_index], cmap=cmap, label="error", shading=shading
    )
    if val_clim is None:
        val_clim = pcm_true.get_clim()
    if err_clim is None:
        err_clim = pcm_err.get_clim()

    for ax, pcm, clim in (
        (ax1, pcm_true, val_clim),
        (ax2, pcm_pred, val_clim),
        (ax3, pcm_err, err_clim),
    ):
        pcm.set_clim(clim)
        fig.colorbar(pcm, ax=ax)
        ax.axis("square")
        ax.set_axis_off()
    fig.tight_layout()

    # (axis, per-frame data, pcolormesh label, title prefix, colour limits)
    panels = (
        (ax1, true, "true", "Exact", val_clim),
        (ax2, pred, "pred", "Predict", val_clim),
        (ax3, error, "error", "Error", err_clim),
    )
    for i in range(num_frames):
        time_label = f"$t={t[i]:.2f}$"
        for ax, data, label, prefix, clim in panels:
            ax.clear()
            pcm = ax.pcolormesh(
                X, Y, data[i], cmap=cmap, label=label, shading=shading
            )
            pcm.set_clim(clim)
            ax.set_title(f"{prefix} {plot_title}: {time_label}")
            ax.axis("square")
            ax.set_axis_off()
        fig.canvas.draw()

        if movie_dir:
            frame_path = os.path.join(
                movie_dir, f"{frame_basename}-{i:03}.{frame_ext}"
            )
            frame_files.append(frame_path)
            fig.savefig(frame_path, bbox_inches="tight")

    # All frames are on disk at this point; release the figure so repeated
    # calls do not accumulate open figures.
    plt.close(fig)

    if movie_dir:
        movie_path = os.path.join(movie_dir, movie_name)
        with imageio.get_writer(movie_path, mode="I") as writer:
            for frame_path in frame_files:
                writer.append_data(imageio.imread(frame_path))

        if remove_frames:
            for frame_path in frame_files:
                try:
                    os.remove(frame_path)
                except OSError:
                    # Best-effort cleanup: a frame that cannot be deleted must
                    # not fail the call. Unlike a bare except, this no longer
                    # swallows KeyboardInterrupt or SystemExit.
                    pass


def plot_predictions_mhd_plotly(
    pred,
    true,
    inputs,
    index=0,
    index_t=-1,
    name="u",
    save_path=None,
    font_size=None,
    shading="auto",
    cmap="jet",
):
    """Build plotly heatmaps (e.g. for logging to wandb) for one field.

    Args:
        pred, true: Predicted / true fields indexed as ``[t, ..., field]``.
        inputs: Model inputs. The time stamp is read from
            ``inputs[index_t, 0, 0, 0]`` and the initial field from channels
            ``3:`` of ``inputs[0]``.
        index: Field index to plot.
        index_t: Time index of the frame to plot.
        name: Field name used in titles and the colour-axis label.
        save_path, font_size, shading: Accepted for signature compatibility;
            not used by this function.
        cmap: Plotly continuous colour scale.

    Returns:
        Tuple ``(fig_ic, fig_pred, fig_true, fig_err)`` of plotly figures for
        the initial condition, prediction, ground truth and the signed error
        ``pred - true``.
    """
    u_pred = pred[index_t, ..., index]
    u_true = true[index_t, ..., index]
    # Two-step slicing (rather than 3 + index) keeps negative ``index``
    # values counting from the last initial-condition channel.
    u_ic = inputs[0, ..., 3:][..., index]
    u_err = u_pred - u_true
    t = inputs[index_t, 0, 0, 0]

    labels = {"color": name}

    def _heatmap(data, title):
        # Heatmap with tick labels hidden on both axes.
        fig = px.imshow(
            data,
            binary_string=False,
            color_continuous_scale=cmap,
            labels=labels,
            title=title,
        )
        fig.update_xaxes(showticklabels=False)
        fig.update_yaxes(showticklabels=False)
        return fig

    fig_ic = _heatmap(u_ic, f"{name}0")
    fig_pred = _heatmap(u_pred, f"Predict {name}: t={t:.2f}")
    fig_true = _heatmap(u_true, f"Exact {name}: t={t:.2f}")
    fig_err = _heatmap(u_err, f"Error {name}: t={t:.2f}")
    return fig_ic, fig_pred, fig_true, fig_err