""" Data loading utilities for the Chess Challenge. This module provides functions to load and process chess game data from the Lichess dataset on Hugging Face. """ from __future__ import annotations from typing import Dict, Iterator, List, Optional import torch from torch.utils.data import Dataset class ChessDataset(Dataset): """ PyTorch Dataset for chess games. This dataset loads games from a Hugging Face dataset and prepares them for language modeling training. Each game is tokenized and truncated/padded to max_length. The labels are shifted by one position for next-token prediction. Example: >>> from tokenizer import ChessTokenizer >>> tokenizer = ChessTokenizer.build_vocab_from_dataset() >>> dataset = ChessDataset(tokenizer, max_length=256) >>> sample = dataset[0] >>> print(sample["input_ids"].shape) # (256,) """ def __init__( self, tokenizer, dataset_name: str = "dlouapre/lichess_2025-01_1M", split: str = "train", column: str = "text", max_length: int = 256, max_samples: Optional[int] = None, ): """ Initialize the chess dataset. Args: tokenizer: The chess tokenizer to use. dataset_name: Name of the dataset on Hugging Face Hub. split: Dataset split to use. column: Column containing the game strings. max_length: Maximum sequence length. max_samples: Maximum number of samples to load. """ from datasets import load_dataset self.tokenizer = tokenizer self.max_length = max_length self.column = column # Load dataset dataset = load_dataset(dataset_name, split=split) if max_samples is not None: dataset = dataset.select(range(min(max_samples, len(dataset)))) self.data = dataset def __len__(self) -> int: return len(self.data) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: game = self.data[idx][self.column] # Prepend BOS token for proper language modeling game_with_bos = self.tokenizer.bos_token + " " + game # Tokenize encoding = self.tokenizer( game_with_bos, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt", ) # Squeeze batch dimension input_ids = encoding["input_ids"].squeeze(0) attention_mask = encoding["attention_mask"].squeeze(0) # Labels are the same as input_ids (model will shift internally) labels = input_ids.clone() # Set padding tokens to -100 to ignore in loss labels[attention_mask == 0] = -100 return { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, } class ChessDataCollator: """ Data collator for chess games. This collator pads sequences to the same length within a batch and creates the appropriate attention masks. """ def __init__(self, tokenizer, max_length: int = 256): self.tokenizer = tokenizer self.max_length = max_length def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]: # Stack tensors input_ids = torch.stack([f["input_ids"] for f in features]) attention_mask = torch.stack([f["attention_mask"] for f in features]) labels = torch.stack([f["labels"] for f in features]) return { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, } def create_train_val_datasets( tokenizer, dataset_name: str = "dlouapre/lichess_2025-01_1M", max_length: int = 256, train_samples: Optional[int] = None, val_samples: int = 5000, val_ratio: float = 0.05, ): """ Create training and validation datasets. Args: tokenizer: The chess tokenizer. dataset_name: Name of the dataset. max_length: Maximum sequence length. train_samples: Maximum training samples (None for all). val_samples: Number of validation samples. val_ratio: Ratio of validation samples (used if train_samples is None). Returns: Tuple of (train_dataset, val_dataset). """ from datasets import load_dataset # Load full dataset full_dataset = load_dataset(dataset_name, split="train") # Determine split sizes total = len(full_dataset) if train_samples is not None: n_train = min(train_samples, total - val_samples) else: n_train = int(total * (1 - val_ratio)) n_val = min(val_samples, total - n_train) # Split dataset train_data = full_dataset.select(range(n_train)) val_data = full_dataset.select(range(n_train, n_train + n_val)) # Create dataset objects train_dataset = ChessDataset( tokenizer=tokenizer, dataset_name=dataset_name, max_length=max_length, ) train_dataset.data = train_data val_dataset = ChessDataset( tokenizer=tokenizer, dataset_name=dataset_name, max_length=max_length, ) val_dataset.data = val_data return train_dataset, val_dataset def stream_games( dataset_name: str = "dlouapre/lichess_2025-01_1M", split: str = "train", column: str = "text", ) -> Iterator[str]: """ Stream games from the dataset for memory-efficient processing. Args: dataset_name: Name of the dataset on Hugging Face Hub. split: Dataset split to use. column: Column containing the game strings. Yields: Game strings one at a time. """ from datasets import load_dataset dataset = load_dataset(dataset_name, split=split, streaming=True) for example in dataset: yield example[column] def analyze_dataset_statistics( dataset_name: str = "dlouapre/lichess_2025-01_1M", max_samples: int = 10000, ) -> Dict: """ Analyze statistics of the chess dataset. Args: dataset_name: Name of the dataset. max_samples: Maximum number of samples to analyze. Returns: Dictionary containing dataset statistics. """ from collections import Counter from datasets import load_dataset dataset = load_dataset(dataset_name, split="train") dataset = dataset.select(range(min(max_samples, len(dataset)))) game_lengths = [] move_counts = Counter() opening_moves = Counter() for example in dataset: moves = example["text"].strip().split() game_lengths.append(len(moves)) move_counts.update(moves) # Track common openings (first 4 moves) if len(moves) >= 4: opening = " ".join(moves[:4]) opening_moves[opening] += 1 return { "total_games": len(dataset), "avg_game_length": sum(game_lengths) / len(game_lengths), "min_game_length": min(game_lengths), "max_game_length": max(game_lengths), "unique_moves": len(move_counts), "most_common_moves": move_counts.most_common(20), "most_common_openings": opening_moves.most_common(10), }