# Provenance: uploaded by 0xZohar — "Add code/cube3d/training/utils.py" (commit d0d37cd, verified)
import os
import logging
from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_model, save_model
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.ticker import MaxNLocator
BOUNDING_BOX_MAX_SIZE = 1.925


def normalize_bbox(bounding_box_xyz: Tuple[float, ...]):
    """Scale a bounding box so that its longest side equals ``BOUNDING_BOX_MAX_SIZE``.

    Every extent is divided by the largest extent and multiplied by
    ``BOUNDING_BOX_MAX_SIZE``, so the aspect ratio is preserved.

    Args:
        bounding_box_xyz: Extents of the box along each axis (e.g. ``(x, y, z)``).

    Returns:
        list: The rescaled extents, in the same order as the input.

    Raises:
        ValueError: If ``bounding_box_xyz`` is empty (raised by ``max``).
        ZeroDivisionError: If the largest extent is zero (for plain Python numbers).
    """
    max_l = max(bounding_box_xyz)
    return [BOUNDING_BOX_MAX_SIZE * elem / max_l for elem in bounding_box_xyz]


def normalize_bboxs(bounding_box_xyz, max_xyz):
    """Scale a tensor of bounding-box extents by a fixed per-axis maximum.

    Computes ``BOUNDING_BOX_MAX_SIZE * bounding_box_xyz / max_xyz`` elementwise,
    with ``max_xyz`` broadcast against ``bounding_box_xyz`` using the usual
    PyTorch broadcasting rules.

    Args:
        bounding_box_xyz: Tensor of box extents.
        max_xyz: Per-axis divisor (sequence, array or tensor). It is placed on
            ``bounding_box_xyz``'s device.

    Returns:
        torch.Tensor: The normalized extents.
    """
    # torch.as_tensor (unlike torch.tensor) does not copy or emit a
    # "copy construct from a tensor" warning when max_xyz is already a tensor.
    divisor = torch.as_tensor(max_xyz, device=bounding_box_xyz.device)
    return BOUNDING_BOX_MAX_SIZE * bounding_box_xyz / divisor
def load_config(cfg_path: str) -> DictConfig:
    """
    Load and resolve a configuration file.

    The file is loaded with ``OmegaConf.load`` and all interpolations are
    resolved in place with ``OmegaConf.resolve``.

    Args:
        cfg_path (str): The path to the configuration file.

    Returns:
        DictConfig: The loaded and resolved configuration object.

    Raises:
        AssertionError: If the loaded configuration is not an instance of
            DictConfig (e.g. the file's top level is a list).
    """
    cfg = OmegaConf.load(cfg_path)
    OmegaConf.resolve(cfg)
    # Explicit check rather than a bare ``assert`` so validation still runs
    # under ``python -O``; AssertionError is kept for the documented contract.
    if not isinstance(cfg, DictConfig):
        raise AssertionError(
            f"Config at '{cfg_path}' must be a DictConfig, got {type(cfg).__name__}"
        )
    return cfg
def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Build a structured OmegaConf config from a plain config mapping.

    The entries of ``cfg`` are passed as keyword arguments to ``cfg_type``.
    The resulting instance is then wrapped with ``OmegaConf.structured``.

    Args:
        cfg_type (Any): Callable (typically a dataclass) accepting the config
            entries as keyword arguments.
        cfg (DictConfig): The configuration mapping to unpack.

    Returns:
        Any: The structured configuration built from the instance.
    """
    instance = cfg_type(**cfg)
    return OmegaConf.structured(instance)
def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """
    Populate ``model`` in place from a ``.safetensors`` checkpoint.

    Args:
        model: PyTorch module whose parameters are overwritten.
        ckpt_path: Path to the checkpoint; must end with ``.safetensors``.

    Returns:
        None

    Raises:
        AssertionError: If ``ckpt_path`` does not end with ``.safetensors``.
    """
    has_safetensors_suffix = ckpt_path.endswith(".safetensors")
    assert has_safetensors_suffix, f"Checkpoint path '{ckpt_path}' is not a safetensors file"
    load_model(model, ckpt_path)
def save_model_weights(model: torch.nn.Module, save_path: str) -> None:
    """
    Save a PyTorch model in safetensors format.

    The parent directory of ``save_path`` is created if it does not exist.

    Args:
        model: PyTorch model to save.
        save_path: Output path (must end with ``.safetensors``). It may be a
            bare filename, in which case it is written to the current directory.

    Raises:
        AssertionError: If ``save_path`` does not end with ``.safetensors``, or
            if the file does not exist after saving.
    """
    # Explicit checks instead of ``assert`` so they are not stripped by ``-O``.
    if not save_path.endswith(".safetensors"):
        raise AssertionError("Path must end with .safetensors")
    parent_dir = os.path.dirname(save_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually has one (e.g. not for "model.safetensors").
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    save_model(model, save_path)
    if not os.path.exists(save_path):
        raise AssertionError(f"Failed to save to {save_path}")
def select_device() -> Any:
    """
    Choose the preferred PyTorch device for tensor allocation.

    Preference order: CUDA, then Apple MPS, then CPU.

    Returns:
        Any: The `torch.device` object.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
def mask_cross_entropy(p_st, p_ed, p_max, logits, target, shift):
    """Cross-entropy over a slice of the class axis, with a per-sample class cap.

    Only classes ``p_st .. p_ed - 1`` of ``logits`` / ``target`` take part. For
    sample ``b``, class ``c`` in that slice is kept only if
    ``c <= p_max[b] + p_st``; all other classes are set to ``-inf`` before the
    softmax.

    Args:
        p_st: First class index (inclusive) of the slice.
        p_ed: End class index (exclusive) of the slice.
        p_max: Per-sample tensor whose first dimension is the batch size; entry
            ``b`` is the largest allowed offset from ``p_st`` for sample ``b``.
        logits: Tensor indexed as ``(batch, time, class)``.
        target: Tensor indexed as ``(batch, time, class)``; the label at each
            step is the argmax over the class slice.
        shift: Number of leading time steps dropped from ``target``.

    Returns:
        The mean cross-entropy loss (``F.cross_entropy`` default reduction).
    """
    class_ids = torch.arange(p_st, p_ed, device=logits.device)
    # (batch, n_slice_classes): True where the class is allowed for that sample.
    allowed = class_ids.unsqueeze(0) <= (p_max.unsqueeze(1) + p_st)
    # NOTE(review): logits always drop exactly the last step (``:-1``) while
    # target drops ``shift`` leading steps; the lengths only line up for
    # particular caller shapes (e.g. shift=1 with equal lengths) — confirm.
    window = logits[:, :-1, p_st:p_ed]
    window = window.masked_fill(~allowed.unsqueeze(1), float('-inf'))
    labels = target[:, shift:, p_st:p_ed].argmax(-1)
    return F.cross_entropy(window.permute(0, 2, 1), labels)
def positional_encoding(x, num_freqs):
    """Sinusoidal encoding of ``x`` at ``num_freqs`` power-of-two frequencies.

    Each value ``v`` of ``x`` is expanded to
    ``[sin(v*2^0), ..., sin(v*2^(F-1)), cos(v*2^0), ..., cos(v*2^(F-1))]`` with
    ``F = num_freqs``; no factor of pi is applied. These per-value features are
    then merged with the last input dimension via ``flatten(-2)``.

    Shapes: ``x`` of shape ``(..., D)`` gives ``(..., D * 2 * num_freqs)``.
    A 1-D ``x`` of shape ``(N,)`` therefore gives ``(N * 2 * num_freqs,)``.
    """
    octaves = torch.pow(2.0, torch.arange(num_freqs, device=x.device))  # (num_freqs,)
    phase = x[..., None] * octaves  # (..., D, num_freqs)
    features = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)  # (..., D, 2*num_freqs)
    return features.flatten(-2)
def _plot_token_distribution(ax, token_idx, token_probs, class_indices, show_xlabel):
    """Draw one token's class-probability bar chart onto ``ax``."""
    bars = ax.bar(class_indices, token_probs, width=0.8, color='skyblue', edgecolor='black')
    # Highlight the most probable class.
    max_prob_idx = np.argmax(token_probs)
    max_prob_value = token_probs[max_prob_idx]
    bars[max_prob_idx].set_color('orange')
    # Annotate every class whose probability exceeds 5%.
    for bar, prob in zip(bars, token_probs):
        if prob > 0.05:
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.01,
                    f'{prob:.2f}', ha='center', va='bottom', fontsize=9)
    ax.set_title(
        f'Token {token_idx} (Max: Class {max_prob_idx} = {max_prob_value:.2f})',
        fontsize=11
    )
    # Only the bottom subplot of a page carries the x-axis label.
    ax.set_xlabel('Class Index' if show_xlabel else '')
    ax.set_ylabel('Probability')
    ax.set_ylim(0, 1.1)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.grid(True, alpha=0.3, axis='y')


def visualize_token_probabilities(
    probs,
    cut_idx,
    sample_idx=0,
    tokens_per_page=10,  # number of tokens shown per page
    figsize=(12, 20),  # page size; only the width is used
    save_dir=None  # directory to save pages into (None -> show interactively)
):
    """
    Plot the class-probability distribution of every valid token, paginated.

    Each page shows up to ``tokens_per_page`` tokens, one subplot row per token.
    Only the first ``cut_idx`` tokens of sample ``sample_idx`` are plotted.

    Args:
        probs: Probabilities of shape ``(batch_size, seq_len, num_classes)``
            (NumPy array or torch tensor).
        cut_idx: Exclusive end of the valid token range. This is either a
            scalar (including a 0-d array/tensor) or a per-sample
            sequence/array/tensor indexed by ``sample_idx``. The value is
            clamped to ``[0, seq_len]``.
        sample_idx: Index of the batch sample to visualize.
        tokens_per_page: Maximum number of tokens per page.
        figsize: Page size; only ``figsize[0]`` (width) is used, and the height
            is ``2 * <tokens on the page>``.
        save_dir: If given, each page is saved as a PNG in this directory
            (created if missing); otherwise each page is shown with ``plt.show()``.

    Returns:
        A list with one (already closed) matplotlib Figure per page, or ``None``
        when there are no valid tokens.
    """
    # Convert to NumPy.
    if isinstance(probs, torch.Tensor):
        probs = probs.cpu().detach().numpy()
    sample_probs = probs[sample_idx]  # (seq_len, num_classes)
    seq_len, num_classes = sample_probs.shape

    # Determine the valid region for this sample.
    if isinstance(cut_idx, torch.Tensor):
        cut_idx = cut_idx.cpu().detach().numpy()
    # np.ndim() == 0 covers Python/NumPy scalars AND 0-d arrays; np.isscalar()
    # rejects 0-d arrays, which cannot be indexed by sample_idx.
    sample_cut = cut_idx if np.ndim(cut_idx) == 0 else cut_idx[sample_idx]
    # Clamp into [0, seq_len] so a negative cut does not slice from the end.
    valid_length = max(0, min(int(sample_cut), seq_len))
    valid_probs = sample_probs[:valid_length, :]
    num_valid_tokens = valid_probs.shape[0]
    if num_valid_tokens == 0:
        print(f"警告:没有有效token可显示(有效区域长度:{valid_length})")
        return None

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    total_pages = (num_valid_tokens + tokens_per_page - 1) // tokens_per_page
    print(f"共{num_valid_tokens}个有效token,分为{total_pages}页展示")

    class_indices = np.arange(num_classes)
    figures = []
    for page in range(total_pages):
        start = page * tokens_per_page
        end = min(start + tokens_per_page, num_valid_tokens)
        page_tokens = end - start
        # squeeze=False: `axes` is always 2-D, even for a single-token page.
        fig, axes = plt.subplots(
            page_tokens, 1, figsize=(figsize[0], 2 * page_tokens), squeeze=False
        )
        fig.suptitle(
            f'Token Probability Distributions (Sample {sample_idx}) - Page {page+1}/{total_pages}',
            fontsize=16,
            y=1.02
        )
        for row, token_idx in enumerate(range(start, end)):
            # Index by the absolute token position (not the page-relative row)
            # so every page shows its own tokens.
            _plot_token_distribution(
                axes[row, 0],
                token_idx,
                valid_probs[token_idx],
                class_indices,
                show_xlabel=(row == page_tokens - 1),
            )
        plt.tight_layout()
        figures.append(fig)
        # Save or display the page.
        if save_dir is not None:
            save_path = os.path.join(save_dir, f'token_probs_page_{page+1}.png')
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"已保存第{page+1}页至: {save_path}")
        else:
            plt.show()
        plt.close(fig)  # release the page's figure memory
    return figures
def visualize_max_prob_distribution(
    probs,
    cut_idx=None,  # unused: probs are expected to be pre-filtered to valid tokens
    sample_idx=0,
    bins=20,
    figsize=(12, 6)
):
    """Histogram of each token's maximum class probability for one sample.

    Args:
        probs: Probabilities indexed as ``[sample, token, class]`` (NumPy array
            or torch tensor). Every token is treated as valid.
        cut_idx: Ignored.
        sample_idx: Index of the sample to plot.
        bins: Number of histogram bins over ``[0, 1]``.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        The matplotlib Figure (neither shown nor closed).
    """
    array = probs.cpu().detach().numpy() if isinstance(probs, torch.Tensor) else probs
    token_max = np.max(array[sample_idx], axis=1)

    fig, ax = plt.subplots(figsize=figsize)
    counts, _edges, patches = ax.hist(
        token_max,
        bins=bins,
        range=(0, 1),
        edgecolor='black',
        alpha=0.7,
        color='skyblue'
    )

    # Write the count above every non-empty bar.
    for count, patch in zip(counts, patches):
        bar_height = patch.get_height()
        if bar_height > 0:
            ax.text(
                patch.get_x() + patch.get_width() / 2.,
                bar_height + 0.5,
                f'{int(count)}',
                ha='center',
                va='bottom',
                fontsize=9
            )

    mean_prob = np.mean(token_max)
    median_prob = np.median(token_max)
    max_count = int(np.max(counts)) if len(counts) > 0 else 0

    ax.set_title(
        f'Distribution of Maximum Probabilities (All Valid Tokens from 5 Iterations)\n'
        f'Total tokens: {len(token_max)} | Mean: {mean_prob:.2f} | Median: {median_prob:.2f}',
        fontsize=14
    )
    ax.set_xlabel('Maximum Probability Value (0-1)')
    ax.set_ylabel('Number of Tokens (Frequency)')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, max_count + 2)
    ax.xaxis.set_major_locator(MaxNLocator(nbins=11))
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    return fig
def top_k_prob_mask(probs, cut_idx, top_percent=0.15, visualize=False):
    """Split valid positions into the most confident ``top_percent`` and the rest.

    The confidence of a position is its maximum probability over the class
    axis. Position ``t`` of sample ``b`` is valid when ``t < cut_idx[b]``.
    All valid positions of the whole batch are ranked jointly (not per sample),
    and the top ``k = clamp(int(n_valid * top_percent), 1, n_valid)`` are selected.

    Args:
        probs: Tensor of shape ``(batch_size, seq_len, num_classes)``.
        cut_idx: Exclusive end of the valid region per sample. This is either a
            scalar (Python number, NumPy scalar or 0-d tensor) shared by all
            samples, or a length-``batch_size`` sequence/array/tensor. It is
            moved to ``probs``'s device.
        top_percent: Fraction of valid positions to select.
        visualize: Unused; kept for interface compatibility.

    Returns:
        ``(original_mask, reverse_mask)``: independent bool tensors of shape
        ``(batch_size, seq_len)``. ``original_mask`` is True at the selected
        valid positions and ``reverse_mask`` at the remaining valid ones. Both
        are False outside the valid region, and both are all-False when there
        is no valid position.
    """
    max_probs = probs.max(dim=-1).values  # (batch_size, seq_len)
    batch_size, seq_len = max_probs.shape
    device = max_probs.device

    # 1. Base mask: True before each sample's cut index, False after.
    cut = torch.as_tensor(cut_idx, device=device)
    if cut.ndim == 0:
        cut = cut.expand(batch_size)
    base_mask = torch.arange(seq_len, device=device)[None, :] < cut[:, None]

    valid_probs = max_probs[base_mask]  # valid positions, row-major order
    total_valid = valid_probs.numel()
    if total_valid == 0:
        empty_mask = torch.zeros_like(max_probs, dtype=torch.bool)
        # Return distinct tensors so in-place edits to one do not affect the other.
        return empty_mask, empty_mask.clone()

    # 2. Top-k most confident valid positions across the whole batch.
    k = max(min(int(total_valid * top_percent), total_valid), 1)
    _, top_valid_indices = torch.topk(valid_probs, k)
    selected = torch.zeros(total_valid, dtype=torch.bool, device=device)
    selected[top_valid_indices] = True
    original_mask = torch.zeros_like(max_probs, dtype=torch.bool)
    original_mask[base_mask] = selected

    # 3. Reverse mask: valid but not selected; positions past the cut stay False.
    reverse_mask = base_mask & ~original_mask
    return original_mask, reverse_mask