# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os

import h5py
from torch.utils import data


class Dedalus2DDataset(data.Dataset):
    """Dataset for MHD 2D Dataset.

    Each sample is one ``*.h5`` file found under
    ``<data_path>/output-<wildcards>/``. The sorted file list is optionally
    truncated to ``num`` entries and then split into a leading "train" part
    and a following "test" part; ``use_train`` selects which part the
    dataset serves through ``__len__`` / ``__getitem__``.

    Args:
        data_path: Directory that contains the ``output-*`` sub-directories.
        output_names: Accepted for interface compatibility. NOTE: the value
            passed in is not consulted; the glob pattern is always rebuilt
            as ``"output-"`` followed by one ``?`` per decimal digit of the
            number of entries in ``data_path``.
        field_names: Names of the datasets to read from the ``tasks`` group
            of every file. ``None`` means "every key in ``tasks``".
        num_train: Size of the training split. Values ``<= 1.0`` are treated
            as a fraction of the available files (so ``1`` means 100 %, not
            one file); larger values are a file count, capped at the number
            of available files. ``None`` uses every file.
        num_test: Size of the test split, interpreted like ``num_train``
            (fractions are relative to all available files) and capped at
            the number of files left after the training split. ``None``
            uses every remaining file.
        num: If positive, keep only the first ``num`` files before
            splitting. ``None`` or a non-positive value (e.g. ``-1``) keeps
            the full file list.
        use_train: Serve the training split when true, otherwise the test
            split.
    """

    def __init__(
        self,
        data_path,
        output_names="output-",
        field_names=("magnetic field", "velocity"),
        num_train=None,
        num_test=None,
        num=None,
        use_train=True,
    ):
        self.data_path = data_path

        # The incoming ``output_names`` is deliberately left unused here to
        # preserve existing behavior: the directory pattern is derived from
        # how many entries ``data_path`` holds (one "?" per decimal digit).
        num_entries = len(os.listdir(data_path))
        output_names = "output-" + "?" * len(str(num_entries))
        self.output_names = output_names

        files_raw = sorted(glob.glob(os.path.join(data_path, output_names, "*.h5")))
        num_files_raw = len(files_raw)

        # Optional truncation: only a positive ``num`` limits the file list.
        if num is not None and num > 0:
            num_files_raw = min(num, num_files_raw)
            files_raw = files_raw[:num_files_raw]

        self.files_raw = files_raw
        self.num_files_raw = num_files_raw

        # Store a private copy so later mutation of the caller's sequence
        # (or of a shared default) cannot change what this dataset reads.
        self.field_names = None if field_names is None else list(field_names)
        self.use_train = use_train

        # Train split takes the leading files; test split takes what follows.
        num_train = self._resolve_count(num_train, num_files_raw, num_files_raw)
        num_test = self._resolve_count(
            num_test, num_files_raw, num_files_raw - num_train
        )

        self.num_train = num_train
        self.train_files = self.files_raw[:num_train]

        self.num_test = num_test
        self.test_end = num_train + num_test
        self.test_files = self.files_raw[num_train : self.test_end]

        self.files = self.train_files if self.use_train else self.test_files
        self.num_files = len(self.files)

    @staticmethod
    def _resolve_count(requested, total, cap):
        """Turn a split request (``None``, fraction or count) into a file count.

        Args:
            requested: ``None``, a fraction (any value ``<= 1.0``) of
                ``total``, or an absolute number of files.
            total: Number of files that fractions are relative to.
            cap: Largest count that may be returned.

        Returns:
            ``cap`` when ``requested`` is ``None``; otherwise the requested
            amount (fractions truncated toward zero), limited to ``cap``.
        """
        if requested is None:
            return cap
        if requested <= 1.0:
            # Fractional request: convert to a whole number of files.
            requested = int(requested * total)
        return min(requested, cap)

    def __len__(self):
        """Return the number of files in the active (train or test) split."""
        return len(self.files)

    def __getitem__(self, index):
        """Gets item for dataloader.

        Opens the ``index``-th file of the active split read-only and reads
        each requested field from its ``tasks`` group in full.

        Args:
            index: Position of the file within the active split.

        Returns:
            Dict mapping each field name that exists in the file to the data
            read from ``tasks/<field name>``. Requested names that are not
            present are reported with ``print`` and left out of the dict.
        """
        file = self.files[index]
        fields = {}

        with h5py.File(file, mode="r") as h5file:
            tasks = h5file["tasks"]

            # ``None`` means "read everything the file offers".
            field_names = self.field_names
            if field_names is None:
                field_names = list(tasks.keys())

            for field_name in field_names:
                if field_name in tasks:
                    fields[field_name] = tasks[field_name][:]
                else:
                    print(f"field name {field_name} not found")

        return fields

    def get_coords(self, index):
        """Gets coordinates of t, x, y for dataloader.

        The coordinates are taken from the dimension scales attached to the
        first dataset (in key order) of the file's ``tasks`` group.

        Args:
            index: Position of the file within the active split.

        Returns:
            Tuple ``(t, x, y)``: ``t`` is the ``"sim_time"`` scale of the
            first axis; ``x`` and ``y`` are the first scales attached to the
            second-to-last and last axes respectively.
        """
        file = self.files[index]

        with h5py.File(file, mode="r") as h5file:
            tasks = h5file["tasks"]
            first_key = list(tasks.keys())[0]
            dims = tasks[first_key].dims
            ndims = len(dims)

            # Slicing with [:] copies the values out before the file closes.
            t = dims[0]["sim_time"][:]
            x = dims[ndims - 2][0][:]
            y = dims[ndims - 1][0][:]

        return t, x, y