| | """ |
| | Helpers for distributed training. |
| | """ |
| |
|
| | import datetime |
| | import io |
| | import os |
| | import socket |
| |
|
| | import blobfile as bf |
| | from pdb import set_trace as st |
| | |
| | import torch as th |
| | import torch.distributed as dist |
| |
|
| | |
| | |
| | GPUS_PER_NODE = 8 |
| | SETUP_RETRY_COUNT = 3 |
| |
|
| |
|
def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def synchronize():
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    if dist.get_world_size() == 1:
        return
    dist.barrier()


def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def setup_dist(args):
    """
    Set up a distributed process group.
    """
    if dist.is_initialized():
        return

    # Bind this process to its local GPU before initializing NCCL.
    th.cuda.set_device(args.local_rank)
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        timeout=datetime.timedelta(seconds=54000),
    )
    print(f"{args.local_rank=} init complete")
    th.cuda.empty_cache()


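# Usage sketch (the launch command is an assumption, not part of this module):
# with init_method="env://", init_process_group reads MASTER_ADDR, MASTER_PORT,
# RANK, and WORLD_SIZE from the environment. torch.distributed.launch sets
# those variables and passes --local_rank to each worker, e.g.:
#
#   python -m torch.distributed.launch --nproc_per_node=8 train.py

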
def cleanup():
    dist.destroy_process_group()


def dev():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        if get_world_size() > 1:
            return th.device(f"cuda:{get_rank() % GPUS_PER_NODE}")
        return th.device("cuda")
    return th.device("cpu")


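# Typical use (illustrative): move the model and each batch onto this rank's
# device, e.g. `model.to(dev())` and `batch = batch.to(dev())`.

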
def load_state_dict(path, **kwargs):
    """
    Load a PyTorch checkpoint from `path`. Each rank currently reads the
    file independently; see the sketch below for a rank-0-only fetch.
    """
    ckpt = th.load(path, **kwargs)
    return ckpt


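# A minimal sketch of the rank-0-only fetch that the docstring above alludes
# to: rank 0 reads the checkpoint bytes once (blobfile also handles remote
# paths) and broadcasts them to the other ranks. Illustrative only; the name
# `_load_state_dict_broadcast` is hypothetical and nothing here calls it.
def _load_state_dict_broadcast(path, **kwargs):
    if get_world_size() == 1:
        # No process group (or a single rank): just read directly.
        return th.load(path, **kwargs)
    if get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            data = f.read()
    else:
        data = None
    obj = [data]
    # broadcast_object_list overwrites the entries on non-source ranks.
    dist.broadcast_object_list(obj, src=0)
    return th.load(io.BytesIO(obj[0]), **kwargs)

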
def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            try:
                dist.broadcast(p, 0)
            except Exception as e:
                print(f"broadcast failed for tensor of shape {tuple(p.shape)}: {e}")


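# Typical call site (illustrative): after constructing the model identically
# on every rank, broadcast rank 0's state so all replicas start from the same
# weights:
#
#   sync_params(model.parameters())
#   sync_params(model.buffers())

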
def _find_free_port():
    # Create the socket outside the try block so `s` is always defined
    # when the finally clause runs.
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # SO_REUSEADDR must be set before bind() to have any effect.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        return s.getsockname()[1]
    finally:
        s.close()


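# Sketch of the usual pattern (an assumption; nothing in this module calls
# _find_free_port): the launcher picks a port once and exports it before any
# worker starts. Note the race: the port may be taken by another process
# between close() here and the eventual bind by the rendezvous server.
#
#   os.environ["MASTER_PORT"] = str(_find_free_port())

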
# Shared state for multi-process statistics collection
# (configured by init_multiprocessing below).
_num_moments = 3
_reduce_dtype = th.float32
_counter_dtype = th.float64
_rank = 0
_sync_device = None
_sync_called = False
_counters = dict()
_cumulative = dict()


def init_multiprocessing(rank, sync_device):
    r"""Initializes `utils.torch_utils.training_stats` for collecting
    statistics across multiple processes.

    This function must be called after `torch.distributed.init_process_group()`
    and before `Collector.update()`. The call is not necessary if
    multi-process collection is not needed.

    Args:
        rank:        Rank of the current process.
        sync_device: PyTorch device to use for inter-process communication,
                     or None to disable multi-process collection.
                     Typically `torch.device('cuda', rank)`.
    """
    global _rank, _sync_device
    assert not _sync_called
    _rank = rank
    _sync_device = sync_device


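# Typical call (illustrative), immediately after the process group exists:
#
#   rank = dist.get_rank()
#   init_multiprocessing(rank=rank, sync_device=th.device("cuda", rank))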