#!/usr/bin/env python3
"""
Utility functions for pseudolabeling workflow
"""
import json
import argparse
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

import numpy as np
import matplotlib.pyplot as plt
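# Example CLI usage -- a sketch derived from the subcommands defined in main()
# below; the script filename and the JSON filenames are placeholders, not
# names specified anywhere in this file:
#   python pseudolabel_utils.py stats  --input pseudolabels.json --plot stats.png
#   python pseudolabel_utils.py sort   --progress progress.json --annotations pseudolabels.json --output sorted.json
#   python pseudolabel_utils.py merge  --original original.json --pseudo pseudolabels.json --output merged.json --min-score 0.3
#   python pseudolabel_utils.py export --input merged.json --output training_data --train-ratio 0.8 --min-anns 1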


def load_pseudolabels(json_path: str) -> Dict:
    """Load pseudolabeled annotations"""
    with open(json_path, 'r') as f:
        return json.load(f)
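# The functions below expect dicts in a COCO-style layout with top-level
# "info", "licenses", "categories", "images", and "annotations" keys. A sketch
# of the per-record fields this script actually reads (inferred from usage in
# this file, not a formal spec):
#   image:      {"id": ..., "file_name": ...}
#   annotation: {"image_id": ...,
#                "score": ...,           # prediction score, used for filtering/plots
#                "confidence": ...,      # confidence recorded on pseudolabels
#                "source": ...,          # "original" or "predicted"
#                "is_pseudolabel": ...,  # marks machine-generated annotations
#                "verified": ...}        # human-verification flag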


def calculate_statistics(data: Dict) -> Dict:
    """Calculate statistics from pseudolabeled data"""
    stats = {
        'total_images': len(data['images']),
        'total_annotations': len(data['annotations']),
        'original_annotations': 0,
        'pseudolabeled_annotations': 0,
        'verified_annotations': 0,
        'confidence_scores': [],
        'images_with_annotations': 0,
        'images_with_pseudolabels': 0,
        'avg_annotations_per_image': 0,
        # Per-source breakdowns consumed by visualize_statistics()
        'annotations_by_source': defaultdict(int),
        'scores_by_source': defaultdict(list),
    }

    # Count annotations per image
    img_ann_count = defaultdict(int)
    img_pseudo_count = defaultdict(int)

    for ann in data['annotations']:
        # Check if pseudolabel
        is_pseudo = ann.get('is_pseudolabel', False)
        if is_pseudo:
            stats['pseudolabeled_annotations'] += 1
            img_pseudo_count[ann['image_id']] += 1
            if 'confidence' in ann:
                stats['confidence_scores'].append(ann['confidence'])
        else:
            stats['original_annotations'] += 1
            if ann.get('verified', False):
                stats['verified_annotations'] += 1
        img_ann_count[ann['image_id']] += 1

        # Track per-source counts and scores ('original' vs 'predicted')
        source = ann.get('source', 'predicted' if is_pseudo else 'original')
        stats['annotations_by_source'][source] += 1
        if 'score' in ann:
            stats['scores_by_source'][source].append(ann['score'])

    stats['images_with_annotations'] = len(img_ann_count)
    stats['images_with_pseudolabels'] = len(img_pseudo_count)
    if stats['images_with_annotations'] > 0:
        stats['avg_annotations_per_image'] = (
            sum(img_ann_count.values()) / stats['images_with_annotations']
        )

    # Convert defaultdicts to plain dicts for cleaner printing/serialization
    stats['annotations_by_source'] = dict(stats['annotations_by_source'])
    stats['scores_by_source'] = dict(stats['scores_by_source'])
    return stats


def sort_images_by_similarity(progress_file: str, annotations_file: str) -> List:
    """Sort images by their similarity scores"""
    # Load progress file
    with open(progress_file, 'r') as f:
        progress = json.load(f)

    # Load annotations
    with open(annotations_file, 'r') as f:
        data = json.load(f)

    # Group annotations by image once, instead of rescanning the full list per image
    anns_by_image = defaultdict(list)
    for ann in data['annotations']:
        anns_by_image[ann['image_id']].append(ann)

    processed_ids = set(progress.get('processed_images', []))

    # Calculate similarity scores for each image
    image_scores = []
    for img in data['images']:
        img_id = img['id']
        img_anns = anns_by_image[img_id]

        # Separate by source
        original = [ann for ann in img_anns if ann.get('source') == 'original']
        predicted = [ann for ann in img_anns if ann.get('source') == 'predicted']

        # Calculate a simple similarity metric
        if original and predicted:
            # Ratio of predicted to original
            ratio = len(predicted) / len(original)
            # Average score of predictions
            avg_score = np.mean([ann.get('score', 0) for ann in predicted])
            # Combined metric: average prediction score, damped when the
            # predicted/original ratio is below 2
            similarity = avg_score * min(ratio, 2.0) / 2.0
        else:
            similarity = 0.0

        image_scores.append({
            'image_id': img_id,
            'file_name': img['file_name'],
            'similarity': similarity,
            'n_original': len(original),
            'n_predicted': len(predicted),
            'processed': img_id in processed_ids
        })

    # Sort by similarity (highest first)
    image_scores.sort(key=lambda x: x['similarity'], reverse=True)
    return image_scores


def merge_annotations(original_file: str, pseudolabel_file: str, output_file: str,
                      keep_original: bool = True, min_score: float = 0.3):
    """Merge original and pseudolabeled annotations"""
    # Load files
    with open(original_file, 'r') as f:
        original = json.load(f)
    with open(pseudolabel_file, 'r') as f:
        pseudo = json.load(f)

    # Create merged data
    merged = {
        'info': original.get('info', pseudo.get('info', {})),
        'licenses': original.get('licenses', pseudo.get('licenses', [])),
        'categories': original.get('categories', pseudo.get('categories', [])),
        'images': [],
        'annotations': []
    }

    # Get all unique images
    image_ids = set()
    image_map = {}
    for img in original['images'] + pseudo['images']:
        if img['id'] not in image_ids:
            image_ids.add(img['id'])
            image_map[img['id']] = img
            merged['images'].append(img)

    # Merge annotations
    if keep_original:
        # Keep all original annotations
        for ann in original['annotations']:
            ann['source'] = 'original'
            merged['annotations'].append(ann)

    # Add pseudolabeled annotations
    for ann in pseudo['annotations']:
        # Skip if it's an original annotation and we're keeping originals
        if keep_original and ann.get('source') == 'original':
            continue
        # Filter by score
        if ann.get('score', 1.0) >= min_score:
            merged['annotations'].append(ann)

    # Save merged file
    with open(output_file, 'w') as f:
        json.dump(merged, f, indent=2)

    print(f"Merged annotations saved to {output_file}")
    print(f"Total images: {len(merged['images'])}")
    print(f"Total annotations: {len(merged['annotations'])}")


def visualize_statistics(stats: Dict, output_path: str = None):
    """Create visualization of pseudolabeling statistics"""
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    # Annotations by source
    ax = axes[0, 0]
    sources = list(stats['annotations_by_source'].keys())
    counts = list(stats['annotations_by_source'].values())
    ax.bar(sources, counts)
    ax.set_title('Annotations by Source')
    ax.set_xlabel('Source')
    ax.set_ylabel('Count')

    # Score distribution
    ax = axes[0, 1]
    for source, scores in stats['scores_by_source'].items():
        if scores and source == 'predicted':
            ax.hist(scores, bins=20, alpha=0.7, label=source)
    ax.set_title('Score Distribution (Predicted)')
    ax.set_xlabel('Score')
    ax.set_ylabel('Count')
    ax.legend()

    # Summary stats
    ax = axes[1, 0]
    ax.axis('off')
    summary_text = f"""
    Summary Statistics:

    Total Images: {stats['total_images']}
    Images with Annotations: {stats['images_with_annotations']}
    Total Annotations: {stats['total_annotations']}
    Avg Annotations/Image: {stats['avg_annotations_per_image']:.2f}

    Original Annotations: {stats['annotations_by_source'].get('original', 0)}
    Predicted Annotations: {stats['annotations_by_source'].get('predicted', 0)}
    """
    ax.text(0.1, 0.5, summary_text, fontsize=12, verticalalignment='center')

    # Pie chart of sources
    ax = axes[1, 1]
    if counts:
        ax.pie(counts, labels=sources, autopct='%1.1f%%')
        ax.set_title('Annotation Sources')

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path)
        print(f"Statistics plot saved to {output_path}")
    else:
        plt.show()


def export_for_training(pseudolabel_file: str, output_dir: str,
                        train_ratio: float = 0.8, min_annotations: int = 1):
    """Export pseudolabeled data in training format"""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load data
    with open(pseudolabel_file, 'r') as f:
        data = json.load(f)

    # Filter images with minimum annotations
    img_ann_count = defaultdict(int)
    for ann in data['annotations']:
        img_ann_count[ann['image_id']] += 1
    valid_images = [img for img in data['images']
                    if img_ann_count[img['id']] >= min_annotations]

    # Split into train/val
    n_train = int(len(valid_images) * train_ratio)
    np.random.shuffle(valid_images)
    train_images = valid_images[:n_train]
    val_images = valid_images[n_train:]

    train_img_ids = {img['id'] for img in train_images}
    val_img_ids = {img['id'] for img in val_images}

    # Create train and val datasets
    train_data = {
        'info': data.get('info', {}),
        'licenses': data.get('licenses', []),
        'categories': data.get('categories', []),
        'images': train_images,
        'annotations': [ann for ann in data['annotations']
                        if ann['image_id'] in train_img_ids]
    }
    val_data = {
        'info': data.get('info', {}),
        'licenses': data.get('licenses', []),
        'categories': data.get('categories', []),
        'images': val_images,
        'annotations': [ann for ann in data['annotations']
                        if ann['image_id'] in val_img_ids]
    }

    # Save files
    with open(output_dir / 'train_pseudo.json', 'w') as f:
        json.dump(train_data, f, indent=2)
    with open(output_dir / 'val_pseudo.json', 'w') as f:
        json.dump(val_data, f, indent=2)

    print(f"Training data exported to {output_dir}")
    print(f"Train: {len(train_images)} images, {len(train_data['annotations'])} annotations")
    print(f"Val: {len(val_images)} images, {len(val_data['annotations'])} annotations")


def main():
    parser = argparse.ArgumentParser(description="Pseudolabeling utilities")
    subparsers = parser.add_subparsers(dest='command', help='Command to run')

    # Stats command
    stats_parser = subparsers.add_parser('stats', help='Calculate statistics')
    stats_parser.add_argument('--input', required=True, help='Pseudolabeled JSON file')
    stats_parser.add_argument('--plot', help='Output path for statistics plot')

    # Sort command
    sort_parser = subparsers.add_parser('sort', help='Sort images by similarity')
    sort_parser.add_argument('--progress', required=True, help='Progress JSON file')
    sort_parser.add_argument('--annotations', required=True, help='Annotations JSON file')
    sort_parser.add_argument('--output', help='Output file for sorted list')

    # Merge command
    merge_parser = subparsers.add_parser('merge', help='Merge annotations')
    merge_parser.add_argument('--original', required=True, help='Original annotations')
    merge_parser.add_argument('--pseudo', required=True, help='Pseudolabeled annotations')
    merge_parser.add_argument('--output', required=True, help='Output file')
    merge_parser.add_argument('--min-score', type=float, default=0.3, help='Minimum score')
    merge_parser.add_argument('--no-original', action='store_true', help="Don't keep original annotations")

    # Export command
    export_parser = subparsers.add_parser('export', help='Export for training')
    export_parser.add_argument('--input', required=True, help='Pseudolabeled JSON file')
    export_parser.add_argument('--output', required=True, help='Output directory')
    export_parser.add_argument('--train-ratio', type=float, default=0.8, help='Train split ratio')
    export_parser.add_argument('--min-anns', type=int, default=1, help='Min annotations per image')

    args = parser.parse_args()

    if args.command == 'stats':
        data = load_pseudolabels(args.input)
        stats = calculate_statistics(data)
        print("\nPseudolabeling Statistics:")
        print("-" * 40)
        for key, value in stats.items():
            if isinstance(value, dict):
                print(f"{key}:")
                for k, v in value.items():
                    if isinstance(v, list):
                        print(f"  {k}: {len(v)} items")
                    else:
                        print(f"  {k}: {v}")
            elif isinstance(value, list):
                # Summarize long lists (e.g. confidence_scores) instead of dumping them
                print(f"{key}: {len(value)} items")
            else:
                print(f"{key}: {value}")
        if args.plot:
            visualize_statistics(stats, args.plot)

    elif args.command == 'sort':
        sorted_images = sort_images_by_similarity(args.progress, args.annotations)
        print("\nTop 10 images by similarity:")
        print("-" * 60)
        for i, img in enumerate(sorted_images[:10]):
            print(f"{i+1}. {img['file_name']}: "
                  f"similarity={img['similarity']:.3f}, "
                  f"original={img['n_original']}, "
                  f"predicted={img['n_predicted']}, "
                  f"processed={img['processed']}")
        if args.output:
            with open(args.output, 'w') as f:
                json.dump(sorted_images, f, indent=2)
            print(f"\nSorted list saved to {args.output}")

    elif args.command == 'merge':
        merge_annotations(
            args.original,
            args.pseudo,
            args.output,
            keep_original=not args.no_original,
            min_score=args.min_score
        )

    elif args.command == 'export':
        export_for_training(
            args.input,
            args.output,
            train_ratio=args.train_ratio,
            min_annotations=args.min_anns
        )

    else:
        parser.print_help()


if __name__ == "__main__":
    main()