DragMesh-2 Evaluation Checkpoints

This repository contains the PyTorch policy checkpoints used for the main results table of the DragMesh-2 evaluation. The release covers seven policy variants, each evaluated on the same seven object-part tasks, for a total of 49 (7 × 7) checkpoints.

The corresponding hand-object interaction trajectories are available in the DragMesh-2 dataset. The referenced objects and actionable parts originate from GAPartNet; GAPartNet assets are not included in this model repository.

Repository layout

checkpoints/
  <variant>/
    object_<object_id>/
      handle_<part_id>/
        policy.pth
model_manifest.jsonl

Each path identifies the policy variant, the source object, and the manipulated part. For every checkpoint, model_manifest.jsonl records the normalized checkpoint path, the original relative path, the object category, the file size, and the SHA-256 digest.
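The SHA-256 digests in the manifest can be used to check a download. The sketch below is a minimal example that assumes each JSONL record stores the normalized checkpoint path under a key named "path" and the digest under "sha256". Those key names are hypothetical, since the exact field names are not listed here, so check one manifest line before relying on them.

```python
import hashlib
import json
from pathlib import Path

# Hypothetical JSON keys: the card lists what the manifest records,
# not the exact field names. Adjust after inspecting one manifest line.
PATH_KEY = "path"
SHA_KEY = "sha256"


def sha256_of(file_path: Path, chunk_size: int = 1 << 20) -> str:
    """Stream a file through SHA-256 so large checkpoints need not fit in memory."""
    digest = hashlib.sha256()
    with file_path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_manifest(manifest_path: Path, root: Path) -> list[str]:
    """Return manifest paths whose file is missing or whose digest differs."""
    failures = []
    with manifest_path.open() as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines in the JSONL file
            record = json.loads(line)
            target = root / record[PATH_KEY]
            if not target.is_file() or sha256_of(target) != record[SHA_KEY]:
                failures.append(record[PATH_KEY])
    return failures
```

Run from the repository root, verify_manifest(Path("model_manifest.jsonl"), Path(".")) should return an empty list when every listed file is present and intact.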

Policy variants

Name         Description
state        State-only PPO baseline
history      Flat-history PPO baseline
gla          GLA without the auxiliary objective
pica         PICA without the GLA auxiliary objective (v2c)
dragmesh2    DragMesh-2 PICA policy
gru          GRU PPO baseline
transformer  Transformer PPO baseline

Evaluation tasks

Category          Object ID  Part ID
Dishwasher        12583      handle_1
Microwave         7310       handle_1
StorageFurniture  45261      handle_7
StorageFurniture  45661      handle_3
StorageFurniture  45936      handle_1
StorageFurniture  46440      handle_5
StorageFurniture  48513      handle_2
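Because the layout nests one policy.pth per variant, object, and part, the full set of 49 paths can be enumerated from the two tables. This sketch assumes the Part ID column gives the part directory name verbatim, which is consistent with the loading example in the next section (object_45661/handle_3). The helper names are illustrative, not part of the release.

```python
from pathlib import PurePosixPath

# Directory names copied from the policy-variant and task tables.
VARIANTS = ["state", "history", "gla", "pica", "dragmesh2", "gru", "transformer"]
TASKS = [  # (category, object_id, part directory)
    ("Dishwasher", "12583", "handle_1"),
    ("Microwave", "7310", "handle_1"),
    ("StorageFurniture", "45261", "handle_7"),
    ("StorageFurniture", "45661", "handle_3"),
    ("StorageFurniture", "45936", "handle_1"),
    ("StorageFurniture", "46440", "handle_5"),
    ("StorageFurniture", "48513", "handle_2"),
]


def checkpoint_path(variant: str, object_id: str, part_dir: str) -> PurePosixPath:
    """Build checkpoints/<variant>/object_<object_id>/<part_dir>/policy.pth."""
    return PurePosixPath("checkpoints", variant, f"object_{object_id}", part_dir, "policy.pth")


def all_checkpoint_paths() -> list[PurePosixPath]:
    """Cross every variant with every task: 7 x 7 = 49 paths."""
    return [checkpoint_path(v, obj, part) for v in VARIANTS for _, obj, part in TASKS]
```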

Checkpoint format

Each policy.pth is a PyTorch training checkpoint with these top-level fields:

  • model: policy state dictionary;
  • optimizer: optimizer state;
  • running_mean_std: observation normalization state;
  • reward_mean_std: reward normalization state;
  • epoch, frame, and last_mean_rewards: training metadata;
  • env_state: serialized environment state when available.
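The running_mean_std entry holds statistics for standardizing observations before they reach the policy. The sketch below shows the common standardize-then-clip step under stated assumptions: the key names running_mean and running_var, the epsilon, and the clip range are hypothetical placeholders, not values confirmed for DragMesh-2. After loading a checkpoint, inspect checkpoint["running_mean_std"].keys() and the experiment configuration to confirm them.

```python
import numpy as np

# Hypothetical key names; confirm against checkpoint["running_mean_std"].keys().
MEAN_KEY = "running_mean"
VAR_KEY = "running_var"


def normalize_obs(obs, stats, eps: float = 1e-5, clip: float = 5.0) -> np.ndarray:
    """Standardize observations with stored running statistics, then clip.

    eps and clip are illustrative defaults, not values confirmed for DragMesh-2.
    CPU torch tensors in `stats` are accepted because np.asarray converts them.
    """
    mean = np.asarray(stats[MEAN_KEY], dtype=np.float32)
    var = np.asarray(stats[VAR_KEY], dtype=np.float32)
    normalized = (np.asarray(obs, dtype=np.float32) - mean) / np.sqrt(var + eps)
    return np.clip(normalized, -clip, clip)
```

Applying the wrong statistics (for example, those of another variant or task) silently shifts the policy's inputs, so the statistics should come from the same policy.pth as the weights.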

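Load checkpoints only from trusted sources: full unpickling (weights_only=False) can execute arbitrary code. With PyTorch 2.6 or later, where weights_only=True is the default, the checkpoint can be read in weights-only mode by allowlisting the NumPy scalar and dtype types stored in its training metadata: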

from pathlib import Path

import numpy as np
import torch

checkpoint_path = Path(
    "checkpoints/dragmesh2/"
    "object_45661/handle_3/policy.pth"
)

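# NumPy types pickled in the training metadata; weights-only loading
# refuses to reconstruct them unless they are explicitly allowlisted.
safe_globals = [
    np.core.multiarray.scalar,
    np.dtype,
    np.dtypes.Float32DType,
]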

with torch.serialization.safe_globals(safe_globals):
    checkpoint = torch.load(
        checkpoint_path,
        map_location="cpu",
        weights_only=True,
    )

policy_state_dict = checkpoint["model"]

The policy architecture and observation configuration must match the corresponding DragMesh-2 experiment when restoring a checkpoint.
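One way to catch an architecture or observation-size mismatch early is to compare parameter names and shapes before restoring weights. The helper below is a minimal, framework-agnostic sketch that works on plain name-to-shape mappings; the policy model you would construct to obtain the expected shapes is hypothetical here and must come from the matching DragMesh-2 experiment configuration.

```python
from typing import Mapping, Sequence


def state_dict_mismatches(
    expected: Mapping[str, Sequence[int]],
    loaded: Mapping[str, Sequence[int]],
) -> dict[str, list]:
    """Compare two name -> shape mappings and report every difference."""
    missing = sorted(set(expected) - set(loaded))      # model needs, checkpoint lacks
    unexpected = sorted(set(loaded) - set(expected))   # checkpoint has, model lacks
    shape_mismatch = sorted(
        k for k in set(expected) & set(loaded)
        if tuple(expected[k]) != tuple(loaded[k])
    )
    return {"missing": missing, "unexpected": unexpected, "shape_mismatch": shape_mismatch}
```

With PyTorch, the mappings can be built as {k: tuple(v.shape) for k, v in model.state_dict().items()} for a freshly constructed policy model and the same expression over policy_state_dict. model.load_state_dict(policy_state_dict) with the default strict=True also raises a RuntimeError on missing or unexpected keys or mismatched shapes; the helper simply lists all differences in one report before any weights are copied.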
