Spaces:
Paused
Paused
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-FileCopyrightText: All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import glob | |
| import os | |
| import h5py | |
| from torch.utils import data | |
| class Dedalus2DDataset(data.Dataset): | |
| "Dataset for MHD 2D Dataset" | |
| def __init__( | |
| self, | |
| data_path, | |
| output_names="output-", | |
| field_names=["magnetic field", "velocity"], | |
| num_train=None, | |
| num_test=None, | |
| num=None, | |
| use_train=True, | |
| ): | |
| self.data_path = data_path | |
| output_names = "output-" + "?"*len(str(len(os.listdir(data_path)))) | |
| self.output_names = output_names | |
| raw_path = os.path.join(data_path, output_names, "*.h5") | |
| files_raw = sorted(glob.glob(raw_path)) | |
| self.files_raw = files_raw | |
| self.num_files_raw = num_files_raw = len(files_raw) | |
| self.field_names = field_names | |
| self.use_train = use_train | |
| # Handle num parameter: -1 means use full dataset, otherwise limit to specified number | |
| if num is not None and num > 0: | |
| num_files_raw = min(num, num_files_raw) | |
| files_raw = files_raw[:num_files_raw] | |
| self.files_raw = files_raw | |
| self.num_files_raw = num_files_raw | |
| # Handle percentage-based splits | |
| if num_train is not None and num_train <= 1.0: | |
| # num_train is a percentage | |
| num_train = int(num_train * num_files_raw) | |
| elif num_train is None or num_train > num_files_raw: | |
| num_train = num_files_raw | |
| if num_test is not None and num_test <= 1.0: | |
| # num_test is a percentage | |
| num_test = int(num_test * num_files_raw) | |
| elif num_test is None or num_test > (num_files_raw - num_train): | |
| num_test = num_files_raw - num_train | |
| self.num_train = num_train | |
| self.train_files = self.files_raw[:num_train] | |
| self.num_test = num_test | |
| self.test_end = test_end = num_train + num_test | |
| self.test_files = self.files_raw[num_train:test_end] | |
| if (self.use_train) or (self.test_files is None): | |
| files = self.train_files | |
| else: | |
| files = self.test_files | |
| self.files = files | |
| self.num_files = num_files = len(files) | |
| def __len__(self): | |
| length = len(self.files) | |
| return length | |
| def __getitem__(self, index): | |
| "Gets item for dataloader" | |
| file = self.files[index] | |
| field_names = self.field_names | |
| fields = {} | |
| coords = [] | |
| with h5py.File(file, mode="r") as h5file: | |
| data_file = h5file["tasks"] | |
| keys = list(data_file.keys()) | |
| if field_names is None: | |
| field_names = keys | |
| for field_name in field_names: | |
| if field_name in data_file: | |
| field = data_file[field_name][:] | |
| fields[field_name] = field | |
| else: | |
| print(f"field name {field_name} not found") | |
| dataset = fields | |
| return dataset | |
| def get_coords(self, index): | |
| "Gets coordinates of t, x, y for dataloader" | |
| file = self.files[index] | |
| with h5py.File(file, mode="r") as h5file: | |
| data_file = h5file["tasks"] | |
| keys = list(data_file.keys()) | |
| dims = data_file[keys[0]].dims | |
| ndims = len(dims) | |
| t = dims[0]["sim_time"][:] | |
| x = dims[ndims - 2][0][:] | |
| y = dims[ndims - 1][0][:] | |
| return t, x, y | |