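"""Generate (and optionally annotate) datasets for routing problems.

Supported problems: tsptw, pctsp, pctsptw, cvrp, and cvrptw.

Example invocation (illustrative; the script filename is an assumption):
    python generate_dataset.py --problem tsptw --num_nodes 20 \
        --num_samples 1280000 1000 1000 --annotation --parallel
"""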
from utils.data_utils.tsptw_dataset import TSPTWDataset
from utils.data_utils.pctsp_dataset import PCTSPDataset
from utils.data_utils.pctsptw_dataset import PCTSPTWDataset
from utils.data_utils.cvrp_dataset import CVRPDataset
from utils.data_utils.cvrptw_dataset import CVRPTWDataset
from utils.utils import save_dataset

def generate_dataset(num_samples, args):
    """Build the dataset generator for the requested problem and run it."""
    if args.problem == "tsptw":
        data_generator = TSPTWDataset(coord_dim=args.coord_dim,
                                      num_samples=num_samples,
                                      num_nodes=args.num_nodes,
                                      random_seed=args.random_seed,
                                      solver=args.solver,
                                      classifier=args.classifier,
                                      annotation=args.annotation,
                                      parallel=args.parallel,
                                      num_cpus=args.num_cpus,
                                      distribution=args.distribution)
    elif args.problem == "pctsp":
        data_generator = PCTSPDataset(coord_dim=args.coord_dim,
                                      num_samples=num_samples,
                                      num_nodes=args.num_nodes,
                                      random_seed=args.random_seed,
                                      solver=args.solver,
                                      classifier=args.classifier,
                                      annotation=args.annotation,
                                      parallel=args.parallel,
                                      num_cpus=args.num_cpus,
                                      penalty_factor=args.penalty_factor)
    elif args.problem == "pctsptw":
        data_generator = PCTSPTWDataset(coord_dim=args.coord_dim,
                                        num_samples=num_samples,
                                        num_nodes=args.num_nodes,
                                        random_seed=args.random_seed,
                                        solver=args.solver,
                                        classifier=args.classifier,
                                        annotation=args.annotation,
                                        parallel=args.parallel,
                                        num_cpus=args.num_cpus,
                                        penalty_factor=args.penalty_factor)
    elif args.problem == "cvrp":
        data_generator = CVRPDataset(coord_dim=args.coord_dim,
                                     num_samples=num_samples,
                                     num_nodes=args.num_nodes,
                                     random_seed=args.random_seed,
                                     solver=args.solver,
                                     classifier=args.classifier,
                                     annotation=args.annotation,
                                     parallel=args.parallel,
                                     num_cpus=args.num_cpus)
    elif args.problem == "cvrptw":
        data_generator = CVRPTWDataset(coord_dim=args.coord_dim,
                                       num_samples=num_samples,
                                       num_nodes=args.num_nodes,
                                       random_seed=args.random_seed,
                                       solver=args.solver,
                                       classifier=args.classifier,
                                       annotation=args.annotation,
                                       parallel=args.parallel,
                                       num_cpus=args.num_cpus)
    else:
        raise NotImplementedError(f"problem '{args.problem}' is not supported.")
    return data_generator.generate_dataset()
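
# NOTE: each *Dataset class above is assumed to expose a generate_dataset()
# method returning a sliceable sequence of samples; the train/valid/test
# split in the entry point below relies on that interface.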

if __name__ == "__main__":
    import argparse
    import os

    import numpy as np

    parser = argparse.ArgumentParser(description="Generate (annotated) datasets for routing problems.")
    # common settings
    parser.add_argument("--problem", type=str, default="tsptw")
    parser.add_argument("--random_seed", type=int, default=1234)
    parser.add_argument("--data_type", type=str, nargs="*", default=["all"], help="data type: 'all' or a combination of ['train', 'valid', 'test'].")
    parser.add_argument("--num_samples", type=int, nargs="*", default=[1000, 100, 100])
    parser.add_argument("--num_nodes", type=int, default=20)
    parser.add_argument("--coord_dim", type=int, default=2, help="only coord_dim=2 is supported for now.")
    parser.add_argument("--solver", type=str, default="ortools", help="solver that outputs a tour")
    parser.add_argument("--classifier", type=str, default="ortools", help="classifier for annotation")
    parser.add_argument("--annotation", action="store_true")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--num_cpus", type=int, default=os.cpu_count())
    parser.add_argument("--output_dir", type=str, default="data")
    # for TSPTW
    parser.add_argument("--distribution", type=str, default="da_silva")
    # for PCTSP
    parser.add_argument("--penalty_factor", type=float, default=3.)
    args = parser.parse_args()
    # 3D problems are not supported
    assert args.coord_dim == 2, "only coord_dim=2 is supported for now."

    # total number of samples (train + valid + test)
    if args.data_type[0] == "all":
        assert len(args.num_samples) == 3, "specify the number of samples for each of train/valid/test when data_type is 'all' (e.g., --num_samples 1280000 1000 1000)."
    else:
        assert len(args.data_type) == len(args.num_samples), "the number of entries in --data_type and --num_samples must match."
    num_samples = np.sum(args.num_samples)

    # generate the full dataset in one pass
    dataset = generate_dataset(num_samples, args)
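    # e.g., with the default --num_samples 1000 100 100, the slices below are
    # dataset[0:1000] (train), dataset[1000:1100] (valid), dataset[1100:1200] (test).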
    # split the dataset into the requested subsets and save each one
    if args.data_type[0] == "all":
        types = ["train", "valid", "test"]
    else:
        types = args.data_type
    start = 0
    for type_name, num in zip(types, args.num_samples):
        end = start + num
        divided_dataset = dataset[start:end]
        output_fname = f"{args.output_dir}/{args.problem}/{type_name}_{args.problem}_{args.num_nodes}nodes_{num}samples_seed{args.random_seed}.pkl"
        save_dataset(divided_dataset, output_fname)
        start = end
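    # With the defaults, this writes files such as
    #   data/tsptw/train_tsptw_20nodes_1000samples_seed1234.pkl
    # (creating {output_dir}/{problem}/ is assumed to be handled by save_dataset).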