Spaces:
Paused
Paused
| import os | |
| import logging | |
| from typing import Any, Optional, Tuple | |
| import torch | |
| import torch.nn.functional as F | |
| from omegaconf import DictConfig, OmegaConf | |
| from safetensors.torch import load_model, save_model | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import seaborn as sns | |
| from matplotlib.ticker import MaxNLocator | |
# Length that normalize_bbox / normalize_bboxs rescale the reference extent to.
BOUNDING_BOX_MAX_SIZE = 1.925


def normalize_bbox(bounding_box_xyz: Tuple[float, ...]):
    """Scale box extents so the largest one equals ``BOUNDING_BOX_MAX_SIZE``.

    Every extent is multiplied by the same factor, so aspect ratios are kept.

    Args:
        bounding_box_xyz: Box extents, e.g. ``(x, y, z)``; any length is accepted.

    Returns:
        list: The rescaled extents, in input order.

    Raises:
        ValueError: If ``bounding_box_xyz`` is empty (raised by ``max``).
        ZeroDivisionError: If the largest extent is ``0`` (plain Python numbers).
    """
    max_l = max(bounding_box_xyz)
    # Keep the original evaluation order (multiply, then divide) so results are
    # bit-identical to previous versions.
    return [BOUNDING_BOX_MAX_SIZE * elem / max_l for elem in bounding_box_xyz]
def normalize_bboxs(bounding_box_xyz, max_xyz):
    """Rescale a tensor of box extents by reference extents.

    Computes ``BOUNDING_BOX_MAX_SIZE * bounding_box_xyz / max_xyz`` elementwise,
    with ``max_xyz`` broadcast against ``bounding_box_xyz``.

    Args:
        bounding_box_xyz (torch.Tensor): Box extents. The shape is presumably
            (..., 3); TODO confirm with callers.
        max_xyz: Reference extents (sequence, NumPy array or tensor) that are
            broadcastable to ``bounding_box_xyz``.

    Returns:
        torch.Tensor: The rescaled extents, on ``bounding_box_xyz``'s device.
    """
    # torch.as_tensor avoids the extra copy and the "copy construct from a
    # tensor" UserWarning that torch.tensor() emits when max_xyz is already a
    # tensor. It also places max_xyz on the boxes' device. .detach() keeps
    # max_xyz out of the autograd graph, as torch.tensor()'s copy always did.
    reference = torch.as_tensor(max_xyz, device=bounding_box_xyz.device).detach()
    return BOUNDING_BOX_MAX_SIZE * bounding_box_xyz / reference
def load_config(cfg_path: str) -> Any:
    """
    Load a configuration file and resolve its interpolations in place.

    Args:
        cfg_path (str): The path to the configuration file.

    Returns:
        Any: The loaded and resolved configuration, guaranteed to be a DictConfig.

    Raises:
        AssertionError: If the loaded configuration is not an instance of DictConfig.
    """
    cfg = OmegaConf.load(cfg_path)
    OmegaConf.resolve(cfg)
    # Explicit raise instead of an `assert` statement: asserts are stripped
    # under `python -O`, which would silently let a non-DictConfig through.
    # AssertionError is kept so the documented contract is unchanged.
    if not isinstance(cfg, DictConfig):
        raise AssertionError(
            f"Expected a DictConfig from '{cfg_path}', got {type(cfg).__name__}"
        )
    return cfg
def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Turn a plain config mapping into a structured OmegaConf config.

    Each entry of ``cfg`` is passed as a keyword argument to ``cfg_type``.
    The resulting instance is then wrapped with ``OmegaConf.structured``.

    Args:
        cfg_type (Any): Callable that accepts the config keys as keyword
            arguments (presumably a dataclass; confirm with callers).
        cfg (DictConfig): The configuration whose entries become the keyword
            arguments.

    Returns:
        Any: The structured config built from the ``cfg_type`` instance.
    """
    instance = cfg_type(**cfg)
    return OmegaConf.structured(instance)
def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model, updating it in place.

    Args:
        model: PyTorch model to load weights into.
        ckpt_path: Path to the checkpoint file; must end with ``.safetensors``.

    Returns:
        None

    Raises:
        AssertionError: If ``ckpt_path`` does not end with ``.safetensors``.
    """
    # Explicit check instead of `assert` so the validation still runs under
    # `python -O`. AssertionError is kept so callers see the same exception type.
    if not ckpt_path.endswith(".safetensors"):
        raise AssertionError(f"Checkpoint path '{ckpt_path}' is not a safetensors file")
    load_model(model, ckpt_path)
def save_model_weights(model: torch.nn.Module, save_path: str) -> None:
    """
    Save a PyTorch model's weights in safetensors format.

    Missing parent directories of ``save_path`` are created first.

    Args:
        model: PyTorch model to save.
        save_path: Output path (must end with ``.safetensors``). It may be a
            bare filename in the current directory.

    Raises:
        AssertionError: If ``save_path`` lacks the ``.safetensors`` suffix, or
            if the file does not exist after saving.
    """
    # Explicit raises instead of `assert` so the checks survive `python -O`.
    # AssertionError is kept so callers see the same exception type.
    if not save_path.endswith(".safetensors"):
        raise AssertionError("Path must end with .safetensors")
    parent_dir = os.path.dirname(save_path)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError. Only create a directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    save_model(model, save_path)
    if not os.path.exists(save_path):
        raise AssertionError(f"Failed to save to {save_path}")
def select_device() -> Any:
    """
    Choose the device to allocate tensors on.

    Preference order: CUDA, then Apple MPS, then CPU.

    Returns:
        Any: A ``torch.device`` for the first available backend.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
def mask_cross_entropy(p_st, p_ed, p_max, logits, target, shift):
    """Cross-entropy over the class window [p_st, p_ed) with per-sample class masking.

    For sample ``b``, window classes whose absolute id exceeds
    ``p_st + p_max[b]`` get a logit of ``-inf`` before the loss. Only window
    offsets ``0..p_max[b]`` remain eligible. Predictions from steps ``[:-1]``
    of ``logits`` are scored against the argmax of ``target`` from step
    ``shift`` onward.

    Args:
        p_st: First class index of the window (inclusive).
        p_ed: End class index of the window (exclusive).
        p_max: Per-sample largest allowed window offset, shape (batch,).
        logits: Scores indexed as (batch, steps, classes).
        target: Indexed as (batch, steps', classes). The argmax over the
            window gives each step's label.
        shift: First target step to use. ``F.cross_entropy`` requires the
            remaining target length to equal ``steps - 1``. Presumably
            shift == 1 for next-step prediction; confirm with callers.

    Returns:
        Scalar tensor: the mean cross-entropy. It is ``inf`` if a label
        falls on a masked class. ``logits`` itself is not modified.
    """
    steps = logits.shape[1]
    window_ids = torch.arange(p_st, p_ed, device=logits.device)
    # (batch, P): True where the window class id is within the sample's limit.
    allowed = window_ids[None, :] <= (p_max[:, None] + p_st)
    # The same mask applies at every step: (batch, steps, P).
    allowed = allowed[:, None, :].expand(-1, steps, -1)

    window_logits = logits[:, :, p_st:p_ed].masked_fill(~allowed, float('-inf'))
    # cross_entropy expects the class dimension second: (batch, P, steps - 1).
    predictions = window_logits[:, :-1].permute(0, 2, 1)
    labels = target[:, shift:, p_st:p_ed].argmax(-1)
    return F.cross_entropy(predictions, labels)
def positional_encoding(x, num_freqs):
    """Encode each element of ``x`` with sines and cosines at power-of-two frequencies.

    For an input of shape (..., D), every element ``v`` becomes
    ``[sin(2^0 v), ..., sin(2^(F-1) v), cos(2^0 v), ..., cos(2^(F-1) v)]``
    with ``F = num_freqs``. These per-element vectors are concatenated along
    the last axis.

    Args:
        x (torch.Tensor): Input with at least one dimension.
        num_freqs (int): Number of frequencies ``F``.

    Returns:
        torch.Tensor: Shape (..., D * 2 * num_freqs).
    """
    exponents = torch.arange(num_freqs, device=x.device)
    scales = 2.0 ** exponents                     # (F,)
    phase = x[..., None] * scales                 # (..., D, F)
    features = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)  # (..., D, 2F)
    return features.flatten(start_dim=-2)
def _plot_token_distribution(ax, token_idx, token_probs, num_classes, show_xlabel):
    """Draw one token's class-probability bar chart onto ``ax``."""
    class_indices = np.arange(num_classes)
    bars = ax.bar(class_indices, token_probs, width=0.8, color='skyblue', edgecolor='black')
    # Highlight the most likely class.
    max_prob_idx = np.argmax(token_probs)
    max_prob_value = token_probs[max_prob_idx]
    bars[max_prob_idx].set_color('orange')
    # Annotate every class whose probability exceeds 5%.
    for bar, prob in zip(bars, token_probs):
        if prob > 0.05:
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.01,
                    f'{prob:.2f}', ha='center', va='bottom', fontsize=9)
    ax.set_title(
        f'Token {token_idx} (Max: Class {max_prob_idx} = {max_prob_value:.2f})',
        fontsize=11
    )
    # Only the bottom subplot of a page keeps its x-axis label.
    ax.set_xlabel('Class Index' if show_xlabel else '')
    ax.set_ylabel('Probability')
    ax.set_ylim(0, 1.1)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.grid(True, alpha=0.3, axis='y')


def visualize_token_probabilities(
    probs,
    cut_idx,
    sample_idx=0,
    tokens_per_page=10,  # maximum number of tokens (rows) per page
    figsize=(12, 20),  # only the width is used; page height is 2 per token
    save_dir=None  # directory to save pages into (None: show them instead)
):
    """
    Plot every valid token's class-probability distribution, one token per row,
    split into pages of at most ``tokens_per_page`` tokens.

    Args:
        probs: Probabilities of shape (batch_size, seq_len, num_classes). A torch
            tensor is detached, moved to CPU and converted to NumPy.
        cut_idx: End (exclusive) of the valid token region. Either a scalar
            (Python/NumPy number or 0-d tensor/array) or a per-sample
            sequence/array/tensor indexed by ``sample_idx``. It is clamped to
            ``[0, seq_len]``.
        sample_idx: Which batch entry to plot.
        tokens_per_page: Maximum number of tokens per page.
        figsize: Only ``figsize[0]`` (width) is used. Each page is
            ``2 * tokens_on_page`` tall.
        save_dir: If given, page ``n`` is saved as ``token_probs_page_<n>.png``
            (300 dpi) in this directory, which is created if needed.
            Otherwise ``plt.show()`` is called per page.

    Returns:
        A list with one matplotlib figure per page (each already closed), or
        ``None`` when there are no valid tokens.
    """
    if isinstance(probs, torch.Tensor):
        probs = probs.detach().cpu().numpy()
    sample_probs = probs[sample_idx]  # (seq_len, num_classes)
    seq_len, num_classes = sample_probs.shape

    if isinstance(cut_idx, torch.Tensor):
        cut_idx = cut_idx.detach().cpu().numpy()
    # np.ndim treats Python/NumPy scalars and 0-d arrays alike. np.isscalar is
    # False for a 0-d array, which would then fail when indexed.
    raw_cut = cut_idx if np.ndim(cut_idx) == 0 else cut_idx[sample_idx]
    # Clamp to [0, seq_len]: a negative cut would otherwise slice from the end.
    valid_length = max(0, min(int(raw_cut), seq_len))
    valid_probs = sample_probs[:valid_length, :]
    num_valid_tokens = valid_probs.shape[0]
    if num_valid_tokens == 0:
        print(f"警告:没有有效token可显示(有效区域长度:{valid_length})")
        return None

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    total_pages = (num_valid_tokens + tokens_per_page - 1) // tokens_per_page
    print(f"共{num_valid_tokens}个有效token,分为{total_pages}页展示")

    figures = []
    for page in range(total_pages):
        start = page * tokens_per_page
        end = min(start + tokens_per_page, num_valid_tokens)
        page_tokens = end - start
        # squeeze=False always yields a 2-D axes array, even for a 1-token page.
        fig, axes = plt.subplots(
            page_tokens, 1, figsize=(figsize[0], 2 * page_tokens), squeeze=False
        )
        fig.suptitle(
            f'Token Probability Distributions (Sample {sample_idx}) - Page {page+1}/{total_pages}',
            fontsize=16,
            y=1.02
        )
        for row, token_idx in enumerate(range(start, end)):
            # Index by the absolute token position so each page shows its own tokens.
            _plot_token_distribution(
                axes[row, 0],
                token_idx,
                valid_probs[token_idx],
                num_classes,
                show_xlabel=(row == page_tokens - 1),
            )
        plt.tight_layout()
        figures.append(fig)

        if save_dir is not None:
            save_path = os.path.join(save_dir, f'token_probs_page_{page+1}.png')
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"已保存第{page+1}页至: {save_path}")
        else:
            plt.show()
        plt.close(fig)  # release the figure's resources once it is saved/shown
    return figures
def visualize_max_prob_distribution(
    probs,
    cut_idx=None,  # unused; kept for call compatibility (inputs are expected pre-filtered)
    sample_idx=0,
    bins=20,
    figsize=(12, 6)
):
    """
    Plot a histogram of each token's highest class probability for one sample.

    Args:
        probs: Array/tensor indexed as (batch, tokens, classes). A torch tensor
            is moved to CPU and converted to NumPy. Every token of the chosen
            sample is included.
        cut_idx: Ignored.
        sample_idx: Which batch entry to plot.
        bins: Passed to ``ax.hist``; the histogram covers the range [0, 1].
        figsize: Figure size.

    Returns:
        The matplotlib figure. It is neither shown nor closed here.

    NOTE(review): the title text hard-codes "from 5 Iterations". Confirm that
    it matches how callers build ``probs``.
    """
    if isinstance(probs, torch.Tensor):
        probs = probs.cpu().detach().numpy()
    # Highest class probability of every token in the chosen sample.
    token_peaks = np.max(probs[sample_idx], axis=1)

    fig, ax = plt.subplots(figsize=figsize)
    counts, _edges, patches = ax.hist(
        token_peaks,
        bins=bins,
        range=(0, 1),
        edgecolor='black',
        alpha=0.7,
        color='skyblue'
    )

    # Write the count above each non-empty bar.
    for count, patch in zip(counts, patches):
        bar_height = patch.get_height()
        if bar_height <= 0:
            continue
        ax.text(
            patch.get_x() + patch.get_width() / 2.,
            bar_height + 0.5,
            f'{int(count)}',
            ha='center',
            va='bottom',
            fontsize=9
        )

    mean_prob = np.mean(token_peaks)
    median_prob = np.median(token_peaks)
    tallest = int(np.max(counts)) if len(counts) > 0 else 0

    ax.set_title(
        f'Distribution of Maximum Probabilities (All Valid Tokens from 5 Iterations)\n'
        f'Total tokens: {len(token_peaks)} | Mean: {mean_prob:.2f} | Median: {median_prob:.2f}',
        fontsize=14
    )
    ax.set_xlabel('Maximum Probability Value (0-1)')
    ax.set_ylabel('Number of Tokens (Frequency)')
    ax.set_xlim(0, 1)
    # Leave headroom above the tallest bar for its count label.
    ax.set_ylim(0, tallest + 2)
    ax.xaxis.set_major_locator(MaxNLocator(nbins=11))
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    return fig
def top_k_prob_mask(probs, cut_idx, top_percent=0.15, visualize=False):
    """Split positions before each sample's cut into high- and low-confidence masks.

    A position's confidence is its maximum class probability. Positions with
    index ``< cut_idx[b]`` are valid. All valid positions of the whole batch are
    ranked together (not per sample). The top ``top_percent`` fraction of them
    is selected, with at least one and at most all valid positions.

    Args:
        probs (torch.Tensor): Probabilities of shape (batch_size, seq_len, num_classes).
        cut_idx: Exclusive end of the valid region. A scalar (Python/NumPy number
            or 0-d tensor) applies to every sample. Otherwise pass a
            sequence/tensor of shape (batch_size,); it is moved to ``probs``'
            device.
        top_percent (float): Fraction of valid positions to select.
        visualize (bool): Unused; kept for call compatibility.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(original_mask, reverse_mask)``. Both
        are bool tensors of shape (batch_size, seq_len). ``original_mask`` marks
        the selected top positions. ``reverse_mask`` marks the remaining valid
        positions. Everything at or after the cut is False in both. The two
        tensors never share storage.
    """
    # Confidence of each position = its highest class probability: (batch_size, seq_len).
    max_probs = probs.max(dim=-1).values
    batch_size, seq_len = max_probs.shape
    device = max_probs.device

    # 1. Valid region: positions strictly before each sample's cut index.
    cut = torch.as_tensor(cut_idx, device=device)
    if cut.ndim == 0:
        cut = cut.expand(batch_size)
    base_mask = torch.arange(seq_len, device=device)[None, :] < cut[:, None]

    valid_probs = max_probs[base_mask]  # flattened in row-major order
    total_valid = valid_probs.numel()
    if total_valid == 0:
        # Two distinct tensors, so an in-place edit to one never affects the other.
        return (torch.zeros_like(max_probs, dtype=torch.bool),
                torch.zeros_like(max_probs, dtype=torch.bool))

    # 2. Top-k over all valid positions of the batch, with 1 <= k <= total_valid.
    k = max(min(int(total_valid * top_percent), total_valid), 1)
    top_valid_indices = torch.topk(valid_probs, k).indices
    selected = torch.zeros(total_valid, dtype=torch.bool, device=device)
    selected[top_valid_indices] = True

    original_mask = torch.zeros_like(max_probs, dtype=torch.bool)
    # Boolean-mask assignment uses the same row-major order as max_probs[base_mask].
    original_mask[base_mask] = selected

    # 3. Complement inside the valid region; positions at/after the cut stay False.
    reverse_mask = base_mask & ~original_mask
    return original_mask, reverse_mask