| from typing import Dict, List, Union | |
| import numpy as np | |
| import torch | |
| def detach_to_numpy(data: Union[List, Dict, torch.Tensor]) -> Union[List, Dict, torch.Tensor]: | |
| """ | |
| Recursively detach elements in data | |
| """ | |
| if isinstance(data, torch.Tensor): | |
| return data.cpu().detach().numpy() # pytype: disable=attribute-error | |
| elif isinstance(data, np.ndarray): | |
| return data | |
| elif isinstance(data, list): | |
| return [detach_to_numpy(d) for d in data] | |
| elif isinstance(data, dict): | |
| for k in data.keys(): | |
| data[k] = detach_to_numpy(data[k]) | |
| return data | |
| else: | |
| raise ValueError("data should be tensor, numpy array, dict, or list.") | |