#!/usr/bin/env python3
"""
Life Coach v1 - Phi-4 Fine-tuned Life Coaching Assistant

A simple command-line life coaching assistant using Microsoft's Phi-4 model.
Fine-tunes on life coaching conversations and provides interactive chat sessions.
"""
import argparse
import gc
import gzip
import json
import logging
import os
import random
import shutil
from pathlib import Path
from typing import Dict, List, Optional

import torch

# Set PyTorch CUDA memory allocation config to reduce fragmentation
# (must be in the environment before the first CUDA allocation)
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq
)
from datasets import Dataset, load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def cleanup_gpu_memory():
    """
    Clean up GPU memory before starting the program.
    Clears the PyTorch CUDA cache and runs garbage collection.
    """
    logger.info("=" * 80)
    logger.info("GPU MEMORY CLEANUP")
    logger.info("=" * 80)
    if torch.cuda.is_available():
        # Clear the PyTorch CUDA cache and run garbage collection
        torch.cuda.empty_cache()
        gc.collect()
        # Log memory stats for each GPU
        for i in range(torch.cuda.device_count()):
            total = torch.cuda.get_device_properties(i).total_memory / 1024**3
            reserved = torch.cuda.memory_reserved(i) / 1024**3
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            free = total - reserved
            logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
            logger.info(f"  Total memory: {total:.2f} GB")
            logger.info(f"  Reserved: {reserved:.2f} GB")
            logger.info(f"  Allocated: {allocated:.2f} GB")
            logger.info(f"  Free: {free:.2f} GB")
            if reserved > 1.0:  # More than 1 GB reserved
                logger.warning(f"  ⚠️ GPU {i} has {reserved:.2f} GB reserved!")
                logger.warning("  ⚠️ This might be from a previous run.")
                logger.warning("  ⚠️ If you encounter OOM errors, find the offending processes with:")
                logger.warning("  ⚠️ nvidia-smi | grep python")
    else:
        logger.warning("No CUDA GPUs available! Running on CPU (very slow).")
    logger.info("=" * 80)

def clear_hf_cache():
    """Clear the Hugging Face datasets cache to save disk space."""
    try:
        from datasets import config
        cache_dir = config.HF_DATASETS_CACHE
        if os.path.exists(cache_dir):
            # Measure size before clearing
            size_mb = sum(os.path.getsize(os.path.join(dirpath, filename))
                          for dirpath, _, filenames in os.walk(cache_dir)
                          for filename in filenames) / (1024 * 1024)
            logger.info(f"Clearing HF cache ({size_mb:.1f} MB)...")
            shutil.rmtree(cache_dir, ignore_errors=True)
            os.makedirs(cache_dir, exist_ok=True)
            logger.info("✓ Cache cleared")
    except Exception as e:
        logger.warning(f"Failed to clear cache: {e}")

def load_mental_health_counseling() -> List[Dict]:
    """Load the Amod/mental_health_counseling_conversations dataset - ALL samples."""
    logger.info("Loading mental health counseling dataset...")
    try:
        dataset = load_dataset("Amod/mental_health_counseling_conversations", split="train")
        logger.info(f"  Dataset has {len(dataset)} samples available")
        conversations = []
        for item in dataset:
            # Format: Context (user) -> Response (assistant)
            conversations.append({
                "messages": [
                    {"role": "user", "content": item.get("Context", "").strip()},
                    {"role": "assistant", "content": item.get("Response", "").strip()}
                ]
            })
        logger.info(f"✓ Loaded {len(conversations)} mental health counseling conversations")
        return conversations
    except Exception as e:
        logger.warning(f"Failed to load mental health counseling dataset: {e}")
        return []

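# Every loader in this module returns conversations in one shared shape, e.g.:
#   {"messages": [{"role": "user", "content": "..."},
#                 {"role": "assistant", "content": "..."}]}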
def load_counsel_chat() -> List[Dict]:
    """Load the nbertagnolli/counsel-chat dataset - ALL samples."""
    logger.info("Loading CounselChat (nbertagnolli) dataset...")
    try:
        dataset = load_dataset("nbertagnolli/counsel-chat", split="train")
        logger.info(f"  Dataset has {len(dataset)} samples available")
        conversations = []
        for item in dataset:
            # Try the common field-name patterns for question and answer
            question = None
            answer = None
            for q_field in ["questionText", "question", "query", "input", "user_message"]:
                if q_field in item and item.get(q_field):
                    question = item[q_field].strip()
                    break
            for a_field in ["answerText", "answer", "response", "output", "counselor_message"]:
                if a_field in item and item.get(a_field):
                    answer = item[a_field].strip()
                    break
            if question and answer:
                conversations.append({
                    "messages": [
                        {"role": "user", "content": question},
                        {"role": "assistant", "content": answer}
                    ]
                })
        logger.info(f"✓ Loaded {len(conversations)} CounselChat conversations")
        return conversations
    except Exception as e:
        logger.warning(f"Failed to load CounselChat dataset: {e}")
        return []

def load_cbt_cognitive_distortions() -> List[Dict]:
    """Load the epsilon3/cbt-cognitive-distortions-analysis dataset - ALL samples."""
    logger.info("Loading CBT Cognitive Distortions dataset...")
    try:
        dataset = load_dataset("epsilon3/cbt-cognitive-distortions-analysis", split="train")
        logger.info(f"  Dataset has {len(dataset)} samples available")
        conversations = []
        for item in dataset:
            # Try the common field-name patterns
            user_msg = None
            assistant_msg = None
            for u_field in ["input", "text", "thought", "statement", "user_input"]:
                if u_field in item and item.get(u_field):
                    user_msg = item[u_field].strip()
                    break
            for a_field in ["output", "analysis", "reframe", "response", "cbt_response"]:
                if a_field in item and item.get(a_field):
                    assistant_msg = item[a_field].strip()
                    break
            if user_msg and assistant_msg:
                conversations.append({
                    "messages": [
                        {"role": "user", "content": user_msg},
                        {"role": "assistant", "content": assistant_msg}
                    ]
                })
        logger.info(f"✓ Loaded {len(conversations)} CBT Cognitive Distortions conversations")
        return conversations
    except Exception as e:
        logger.warning(f"Failed to load CBT Cognitive Distortions dataset: {e}")
        return []

def load_peer_counseling_reflections() -> List[Dict]:
    """Load the emoneil/reflections-in-peer-counseling dataset - ALL samples."""
    logger.info("Loading Peer Counseling Reflections dataset...")
    try:
        dataset = load_dataset("emoneil/reflections-in-peer-counseling", split="train")
        logger.info(f"  Dataset has {len(dataset)} samples available")
        conversations = []
        for item in dataset:
            # Try the common field-name patterns
            user_msg = None
            assistant_msg = None
            for u_field in ["question", "statement", "input", "user_message", "counselee"]:
                if u_field in item and item.get(u_field):
                    user_msg = item[u_field].strip()
                    break
            for a_field in ["reflection", "response", "output", "counselor_response", "counselor"]:
                if a_field in item and item.get(a_field):
                    assistant_msg = item[a_field].strip()
                    break
            if user_msg and assistant_msg:
                conversations.append({
                    "messages": [
                        {"role": "user", "content": user_msg},
                        {"role": "assistant", "content": assistant_msg}
                    ]
                })
        logger.info(f"✓ Loaded {len(conversations)} Peer Counseling Reflections conversations")
        return conversations
    except Exception as e:
        logger.warning(f"Failed to load Peer Counseling Reflections dataset: {e}")
        return []

def load_dolly_dataset() -> List[Dict]:
    """Load the databricks-dolly-15k dataset (instruction-following) - ALL relevant samples."""
    logger.info("Loading Dolly instruction dataset...")
    try:
        dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
        logger.info(f"  Dataset has {len(dataset)} samples available")
        # Keep only the relevant instruction categories
        relevant_categories = {"brainstorming", "open_qa", "creative_writing", "general_qa"}
        conversations = []
        for item in dataset:
            if item.get("category", "") in relevant_categories:
                instruction = item.get("instruction", "").strip()
                context = item.get("context", "").strip()
                response = item.get("response", "").strip()
                # Combine instruction and context when both exist
                user_message = f"{instruction}\n\n{context}" if context else instruction
                if user_message and response:
                    conversations.append({
                        "messages": [
                            {"role": "user", "content": user_message},
                            {"role": "assistant", "content": response}
                        ]
                    })
        logger.info(f"✓ Loaded {len(conversations)} Dolly instruction conversations (filtered from {len(dataset)} total)")
        return conversations
    except Exception as e:
        logger.warning(f"Failed to load Dolly dataset: {e}")
        return []

def load_mentalchat16k() -> List[Dict]:
    """Load the ShenLab/MentalChat16K dataset - ALL samples."""
    logger.info("Loading MentalChat16K dataset...")
    try:
        dataset = load_dataset("ShenLab/MentalChat16K", split="train")
        logger.info(f"  Dataset has {len(dataset)} samples available")
        conversations = []
        for item in dataset:
            # Try the common field-name patterns
            user_msg = None
            assistant_msg = None
            for user_field in ["query", "question", "input", "user", "prompt", "instruction"]:
                if user_field in item and item.get(user_field):
                    user_msg = item[user_field].strip()
                    break
            for assistant_field in ["response", "answer", "output", "assistant", "reply"]:
                if assistant_field in item and item.get(assistant_field):
                    assistant_msg = item[assistant_field].strip()
                    break
            if user_msg and assistant_msg:
                conversations.append({
                    "messages": [
                        {"role": "user", "content": user_msg},
                        {"role": "assistant", "content": assistant_msg}
                    ]
                })
        logger.info(f"✓ Loaded {len(conversations)} MentalChat16K conversations")
        return conversations
    except Exception as e:
        logger.warning(f"Failed to load MentalChat16K dataset: {e}")
        return []

def load_additional_mental_health_datasets() -> List[Dict]:
    """Load additional mental health datasets - ALL samples."""
    logger.info("Loading additional mental health datasets...")
    all_conversations = []
    # (dataset name, candidate user fields, candidate assistant fields)
    additional_datasets = [
        ("heliosbrahma/mental_health_chatbot_dataset", ["prompt", "question"], ["response", "answer"]),
        ("mpingale/mental-health-chat-dataset", ["question", "query"], ["answer", "response"]),
        ("sauravjoshi23/psychology-dataset", ["input", "question"], ["output", "answer"]),
    ]
    for dataset_name, user_fields, assistant_fields in additional_datasets:
        try:
            logger.info(f"  Loading {dataset_name}...")
            dataset = load_dataset(dataset_name, split="train")
            logger.info(f"    Has {len(dataset)} samples available")
            count_before = len(all_conversations)
            for item in dataset:
                # Try the candidate field names in order
                user_msg = None
                assistant_msg = None
                for field in user_fields:
                    if field in item and item.get(field):
                        user_msg = item[field].strip()
                        break
                for field in assistant_fields:
                    if field in item and item.get(field):
                        assistant_msg = item[field].strip()
                        break
                if user_msg and assistant_msg:
                    all_conversations.append({
                        "messages": [
                            {"role": "user", "content": user_msg},
                            {"role": "assistant", "content": assistant_msg}
                        ]
                    })
            # Count only the conversations added by this dataset
            logger.info(f"    ✓ Loaded {len(all_conversations) - count_before} from this dataset")
        except Exception as e:
            logger.warning(f"    Failed: {e}")
            continue
    logger.info(f"✓ Loaded {len(all_conversations)} additional mental health conversations total")
    return all_conversations

def quality_filter_conversation(conv: Dict, min_response_length: int = 50, max_total_length: int = 2048) -> bool:
    """Filter a conversation against basic quality criteria."""
    try:
        messages = conv.get("messages", [])
        if len(messages) < 2:
            return False
        # Check response length
        assistant_msg = [m for m in messages if m.get("role") == "assistant"]
        if not assistant_msg:
            return False
        response = assistant_msg[0].get("content", "")
        if len(response) < min_response_length:
            return False
        # Check total length
        total_length = sum(len(m.get("content", "")) for m in messages)
        if total_length > max_total_length:
            return False
        # Reject empty messages
        if any(not m.get("content", "").strip() for m in messages):
            return False
        return True
    except Exception:
        return False

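# Example: a conversation whose assistant reply is shorter than
# min_response_length (default 50 characters) is rejected:
#   quality_filter_conversation({"messages": [
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello."},
#   ]})  # -> False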
def load_mixed_dataset(
    total_samples: int = 100000,
    cache_file: str = "mixed_lifecoach_dataset_100k.jsonl.gz",  # Compressed by default
    use_cache: bool = True
) -> List[Dict]:
    """
    Load and mix multiple datasets for comprehensive life coaching training.
    Saves a compressed cache to conserve disk space.

    Datasets loaded (ALL available samples):
    1. Mental Health Counseling (Amod/mental_health_counseling_conversations)
    2. CounselChat (nbertagnolli/counsel-chat)
    3. CBT Cognitive Distortions (epsilon3/cbt-cognitive-distortions-analysis)
    4. Peer Counseling Reflections (emoneil/reflections-in-peer-counseling)
    5. MentalChat16K (ShenLab/MentalChat16K)
    6. Dolly Instructions (databricks/databricks-dolly-15k - filtered categories)
    7-8. Additional mental health datasets (heliosbrahma, mpingale, sauravjoshi23)
    """
    cache_path = Path(cache_file)
    cache_path_uncompressed = Path(cache_file.replace('.gz', ''))

    # Try the compressed cache first
    if use_cache and cache_path.exists():
        logger.info(f"Loading cached dataset from {cache_file} (compressed)...")
        try:
            conversations = []
            with gzip.open(cache_path, 'rt', encoding='utf-8') as f:
                for line in f:
                    conversations.append(json.loads(line.strip()))
            logger.info(f"✓ Loaded {len(conversations)} conversations from compressed cache")
            return conversations
        except Exception as e:
            logger.warning(f"Failed to load compressed cache: {e}. Trying uncompressed...")

    # Fall back to an uncompressed cache (backward compatibility)
    if use_cache and cache_path_uncompressed.exists():
        logger.info(f"Loading cached dataset from {cache_path_uncompressed} (uncompressed)...")
        try:
            conversations = []
            with open(cache_path_uncompressed, 'r', encoding='utf-8') as f:
                for line in f:
                    conversations.append(json.loads(line.strip()))
            logger.info(f"✓ Loaded {len(conversations)} conversations from uncompressed cache")
            return conversations
        except Exception as e:
            logger.warning(f"Failed to load cache: {e}. Rebuilding dataset...")

    logger.info("=" * 80)
    logger.info(f"LOADING MIXED DATASET (Target: ~{total_samples} samples)")
    logger.info("Loading ALL available samples from each dataset")
    logger.info("=" * 80)
    all_conversations = []

    # Load the datasets ONE AT A TIME, clearing the HF cache after each so we
    # never keep all downloads on disk simultaneously, and stop early once the
    # target sample count is reached.
    loaders = [
        ("Dataset 1/8: Mental Health Counseling (Amod)", load_mental_health_counseling),
        ("Dataset 2/8: CounselChat (nbertagnolli)", load_counsel_chat),
        ("Dataset 3/8: CBT Cognitive Distortions (epsilon3)", load_cbt_cognitive_distortions),
        ("Dataset 4/8: Peer Counseling Reflections (emoneil)", load_peer_counseling_reflections),
        ("Dataset 5/8: MentalChat16K (ShenLab)", load_mentalchat16k),
        ("Dataset 6/8: Dolly Instructions (databricks)", load_dolly_dataset),
        ("Datasets 7-8: Additional Mental Health Datasets", load_additional_mental_health_datasets),
    ]
    for label, loader in loaders:
        if len(all_conversations) >= total_samples:
            logger.info(f"✓ Reached target of {total_samples} samples, stopping dataset loading")
            break
        logger.info(label)
        all_conversations.extend(loader())
        logger.info(f"  Running total: {len(all_conversations)} conversations")
        clear_hf_cache()
        gc.collect()

    logger.info("=" * 80)
    logger.info(f"Total conversations loaded: {len(all_conversations)}")

    # Apply quality filtering
    logger.info("Applying quality filters...")
    filtered_conversations = [conv for conv in all_conversations if quality_filter_conversation(conv)]
    logger.info(f"✓ After filtering: {len(filtered_conversations)} conversations")

    # Shuffle to mix the datasets, then trim to the target size
    random.shuffle(filtered_conversations)
    if len(filtered_conversations) > total_samples:
        filtered_conversations = filtered_conversations[:total_samples]
    logger.info(f"Final dataset size: {len(filtered_conversations)} conversations")

    # Save a compressed cache to conserve disk space
    if use_cache:
        logger.info(f"Saving compressed cache to {cache_file}...")
        try:
            with gzip.open(cache_path, 'wt', encoding='utf-8') as f:
                for conv in filtered_conversations:
                    f.write(json.dumps(conv, ensure_ascii=False) + '\n')
            compressed_size_mb = cache_path.stat().st_size / (1024 * 1024)
            logger.info(f"✓ Compressed cache saved successfully ({compressed_size_mb:.1f} MB)")
        except Exception as e:
            logger.warning(f"Failed to save compressed cache: {e}")
    logger.info("=" * 80)
    return filtered_conversations

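# Note: LifeCoachModel.load_training_data() derives the cache filename as
# f"mixed_lifecoach_dataset_{num_samples}.jsonl.gz", so caches are keyed by
# the requested sample count.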
class LifeCoachModel:
    """Life coaching assistant using the Phi-4 model."""

    def __init__(
        self,
        model_name: str = "microsoft/Phi-4",
        model_save_path: str = "/data/life_coach_model",
        train_file: str = "cbt_life_coach_improved_50000.jsonl",
        max_length: int = 2048
    ):
        """
        Initialize the Life Coach model.

        Args:
            model_name: Hugging Face model identifier
            model_save_path: Path to save/load the fine-tuned model
            train_file: Path to the training data file (JSONL format)
            max_length: Maximum sequence length for training
        """
        self.model_name = model_name
        # Check whether /data is writable; otherwise fall back to a local directory
        save_path = Path(model_save_path)
        if str(save_path).startswith("/data"):
            try:
                Path("/data").mkdir(parents=True, exist_ok=True)
                # Test write permissions
                test_file = Path("/data/.test_write")
                test_file.touch()
                test_file.unlink()
                self.model_save_path = save_path
                logger.info(f"Using /data directory for model storage: {save_path}")
            except (PermissionError, OSError) as e:
                local_path = Path("./data/life_coach_model")
                logger.warning(f"/data directory not writable ({e}), using local directory: {local_path}")
                self.model_save_path = local_path
        else:
            self.model_save_path = save_path
        self.train_file = Path(train_file)
        self.max_length = max_length
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Device: {self.device}")
        logger.info(f"Model: {model_name}")
        logger.info(f"Save path: {self.model_save_path}")
        logger.info(f"Training file: {self.train_file}")
        self.tokenizer = None
        self.model = None

    def load_tokenizer(self):
        """Load the tokenizer from /data/hf_cache (persistent) or download it once."""
        logger.info("Loading tokenizer...")
        cache_dir = "/data/hf_cache"
        os.makedirs(cache_dir, exist_ok=True)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=cache_dir,
                local_files_only=False,  # Allow download only when not cached
                trust_remote_code=True,
                use_fast=True
            )
            # Defensive: some causal-LM tokenizers ship without a pad token,
            # which would break padding in the data collator and generate()
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.info(f"Tokenizer loaded (cache: {cache_dir})")
        except Exception as e:
            logger.error(f"Critical error while loading tokenizer: {e}")
            raise

    def load_model(self, fine_tuned=True):
        """Load the base or fine-tuned model with safe settings for HF Spaces."""
        logger.info(f"Loading {'fine-tuned' if fine_tuned else 'base'} model from {self.model_save_path}")
        from peft import PeftModel

        # Load the base model with device_map and CPU/disk offload
        base_model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            offload_folder="/tmp/offload",  # Use /tmp for offload
            cache_dir="/data/hf_cache"
        )
        if fine_tuned:
            logger.info(f"Loading adapter from {self.model_save_path}")
            self.model = PeftModel.from_pretrained(
                base_model,
                self.model_save_path,
                device_map="auto",
                offload_folder="/tmp/offload",
                torch_dtype=torch.float16
            )
        else:
            self.model = base_model
        self.model.eval()
        logger.info("Model loaded successfully!")

    def load_training_data(self, num_samples: Optional[int] = None) -> Dataset:
        """
        Load training data from the mixed datasets or a JSONL file.

        Args:
            num_samples: Number of samples to load (None for the 100,000 default)
        Returns:
            Dataset object
        """
        # Prefer the mixed datasets; fall back to a JSONL file only when
        # train_file exists and is not the legacy single-dataset file
        use_mixed_datasets = True
        if self.train_file.exists():
            if "cbt_life_coach" in str(self.train_file):
                logger.info("Found old training file. Using new mixed datasets instead...")
                use_mixed_datasets = True
            else:
                # It might be a cached mixed dataset
                logger.info(f"Found training file at {self.train_file}")
                use_mixed_datasets = False
        if use_mixed_datasets:
            # Load mixed datasets from Hugging Face
            logger.info("Loading mixed datasets from Hugging Face...")
            if num_samples is None:
                num_samples = 100000  # Default to 100k samples
            # Will reuse the compressed cache if available
            cache_file = f"mixed_lifecoach_dataset_{num_samples}.jsonl.gz"
            data = load_mixed_dataset(
                total_samples=num_samples,
                cache_file=cache_file,
                use_cache=True
            )
        else:
            # Fall back to loading from the JSONL file
            logger.info(f"Loading training data from {self.train_file}")
            data = []
            with open(self.train_file, 'r', encoding='utf-8') as f:
                for i, line in enumerate(f):
                    if num_samples and i >= num_samples:
                        break
                    try:
                        data.append(json.loads(line.strip()))
                    except json.JSONDecodeError:
                        logger.warning(f"Skipping invalid JSON at line {i+1}")
        logger.info(f"Loaded {len(data)} training examples")
        # Convert to a Hugging Face Dataset
        dataset = Dataset.from_list(data)
        # Preprocess into the chat format used for training
        logger.info("Preprocessing data for Phi-4 format...")
        dataset = dataset.map(
            self._preprocess_function,
            batched=True,
            remove_columns=dataset.column_names,
            desc="Tokenizing"
        )
        return dataset

    def _preprocess_function(self, examples):
        """
        Preprocess data into the chat format used for training:

        <|system|>
        {system message}<|end|>
        <|user|>
        {user message}<|end|>
        <|assistant|>
        {assistant response}<|end|>

        (Note: these are Phi-3-style markers; Phi-4's official chat template
        uses <|im_start|>/<|im_sep|>/<|im_end|>. What matters here is that the
        same markers are applied consistently at training and inference time.)
        """
        texts = []
        # Handle both 'conversations' (our format) and 'messages' (standard format)
        conversations_key = 'conversations' if 'conversations' in examples else 'messages'
        for conversation in examples[conversations_key]:
            text = ""
            for message in conversation:
                # Handle both 'from'/'value' and 'role'/'content' schemas
                if 'from' in message:
                    role = message['from']
                    content = message['value']
                else:
                    role = message['role']
                    content = message['content']
                if role == 'system':
                    text += f"<|system|>\n{content}<|end|>\n"
                elif role == 'user':
                    text += f"<|user|>\n{content}<|end|>\n"
                elif role == 'assistant':
                    text += f"<|assistant|>\n{content}<|end|>\n"
            texts.append(text)
        # Tokenize without padding (as on the quantum server):
        # DataCollatorForSeq2Seq pads dynamically per batch, which saves
        # a large amount of memory
        model_inputs = self.tokenizer(
            texts,
            max_length=self.max_length,
            truncation=True,
            padding=False,
            return_tensors=None  # Keep Python lists; no tensors yet
        )
        # For causal language modeling, labels = input_ids
        # (.copy() rather than .clone() because these are lists, not tensors)
        model_inputs["labels"] = model_inputs["input_ids"].copy()
        return model_inputs

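    # Illustrative output of _preprocess_function's serialization step for a
    # two-turn conversation (before tokenization):
    #   <|user|>
    #   I feel stuck in my career.<|end|>
    #   <|assistant|>
    #   Let's explore what "stuck" means for you.<|end|>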
    def setup_lora(self):
        """Set up LoRA (Low-Rank Adaptation) for efficient fine-tuning."""
        logger.info("Setting up LoRA adapters...")
        # Prepare the model for k-bit-style training (freezes base weights and
        # upcasts selected parameters for numerical stability). This helper is
        # aimed at quantized loads; note that the model here is loaded in fp16,
        # not 8-bit.
        logger.info("Preparing model for training...")
        self.model = prepare_model_for_kbit_training(self.model)
        # Enable gradient checkpointing, trading a little compute for a large
        # reduction in activation memory
        if hasattr(self.model, 'gradient_checkpointing_enable'):
            self.model.gradient_checkpointing_enable()
            logger.info("✓ Gradient checkpointing enabled")
        # LoRA configuration
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=16,  # Rank
            lora_alpha=32,
            lora_dropout=0.1,
            bias="none",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]  # Attention layers
        )
        # Apply LoRA
        self.model = get_peft_model(self.model, lora_config)
        # Log trainable parameters
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in self.model.parameters())
        logger.info(f"Trainable parameters: {trainable_params:,} / {total_params:,} "
                    f"({100 * trainable_params / total_params:.2f}%)")

    def fine_tune(
        self,
        num_samples: Optional[int] = 5000,
        epochs: int = 3,
        batch_size: int = 8,
        learning_rate: float = 5e-5,
        gradient_accumulation_steps: int = 2
    ):
        """
        Fine-tune the model on life coaching data.

        Args:
            num_samples: Number of training samples (None for all)
            epochs: Number of training epochs
            batch_size: Training batch size
            learning_rate: Learning rate
            gradient_accumulation_steps: Gradient accumulation steps (for memory efficiency)
        """
        logger.info("=" * 80)
        logger.info("STARTING FINE-TUNING")
        logger.info("=" * 80)
        # Load data
        dataset = self.load_training_data(num_samples)
        # Setup LoRA
        self.setup_lora()
        # Training arguments
        training_args = TrainingArguments(
            output_dir="./training_output",
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            fp16=True,  # Mixed-precision training
            logging_steps=10,
            save_strategy="epoch",
            save_total_limit=2,
            warmup_steps=100,
            weight_decay=0.01,
            report_to="none",  # Disable wandb/tensorboard
        )
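        # Effective batch size = per_device_train_batch_size *
        # gradient_accumulation_steps (the CLI defaults of 4 and 4 give 16)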
        # Data collator: pads input_ids and labels dynamically per batch
        data_collator = DataCollatorForSeq2Seq(
            tokenizer=self.tokenizer,
            model=self.model,
            padding=True
        )
        # Trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset,
            data_collator=data_collator,
        )
        # Train
        logger.info("Training started...")
        trainer.train()
        logger.info("=" * 80)
        logger.info("TRAINING COMPLETED")
        logger.info("=" * 80)
        # Save model
        self.save_model()

    def save_model(self):
        """Save the fine-tuned model to disk."""
        logger.info(f"Saving model to {self.model_save_path}")
        self.model_save_path.mkdir(parents=True, exist_ok=True)
        # Save model and tokenizer
        self.model.save_pretrained(str(self.model_save_path))
        self.tokenizer.save_pretrained(str(self.model_save_path))
        logger.info("Model saved successfully")

    def generate_response(self, prompt: str, max_new_tokens: int = 128, conversation_history: list = None) -> str:
        """
        Generate a response to a user prompt.

        Args:
            prompt: User's input message
            max_new_tokens: Maximum tokens to generate
            conversation_history: List of previous messages for context
        Returns:
            Generated response
        """
        # Build the full conversation context, starting with the system prompt
        formatted_prompt = ""
        # System prompt that defines the coach persona and behavior
        system_prompt = """You are Robert, a friendly and experienced life coach. Here's your background:

About You:
- Name: Robert (Bob to friends)
- Age: 42 years old
- Experience: 15 years as a certified life coach and motivational speaker
- Education: Master's degree in Psychology from UC Berkeley
- Specialties: Personal growth, career transitions, work-life balance, goal setting, stress management
- Personal: Married with two kids, enjoy hiking and meditation in your free time
- Approach: Warm, empathetic, practical, and solution-focused

Your Coaching Style:
- Respond ONLY to what the user actually tells you - never make assumptions about their problems
- Start conversations in a welcoming, open manner
- Ask clarifying questions to understand their situation better
- Provide practical, actionable advice based on what they share
- Be encouraging and positive, but also honest and realistic
- Keep responses concise and focused (2-4 sentences usually)
- Share brief personal insights when relevant, but keep the focus on the client

Important: Never assume clients have problems they haven't mentioned. Let them guide the conversation and share what's on their mind."""
        formatted_prompt += f"<|system|>\n{system_prompt}<|end|>\n"
        # Add conversation history if provided
        if conversation_history:
            for msg in conversation_history:
                if msg["role"] == "user":
                    formatted_prompt += f"<|user|>\n{msg['content']}<|end|>\n"
                elif msg["role"] == "assistant":
                    formatted_prompt += f"<|assistant|>\n{msg['content']}<|end|>\n"
        # Add the current prompt
        formatted_prompt += f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
        # Debug aid: log the full prompt sent to the model
        logger.info("=" * 80)
        logger.info("FULL PROMPT SENT TO MODEL:")
        logger.info(formatted_prompt)
        logger.info("=" * 80)
        # Tokenize
        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length
        ).to(self.device)
        # Record the input length so we can decode only the new tokens
        input_length = inputs['input_ids'].shape[1]
        # Use <|end|> as an additional stopping token if the tokenizer knows it
        end_token_id = self.tokenizer.convert_tokens_to_ids("<|end|>")
        eos_token_ids = [self.tokenizer.eos_token_id]
        if end_token_id is not None and end_token_id != self.tokenizer.unk_token_id:
            eos_token_ids.append(end_token_id)
        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=0.7,  # Balanced - coherent but still creative
                top_p=0.9,  # Standard setting for focused responses
                top_k=50,  # Top-k sampling
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=eos_token_ids,  # Stop at <|end|> or EOS
                repetition_penalty=1.15  # Stronger penalty to prevent repetition
            )
        # Decode ONLY the newly generated tokens (not the input)
        generated_tokens = outputs[0][input_length:]
        # Keep special tokens in the decode so we can locate the end marker
        response_with_tokens = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)
        # Keep only the text up to the first <|end|> (the model may continue
        # with extra conversation turns)
        if "<|end|>" in response_with_tokens:
            response_text = response_with_tokens.split("<|end|>")[0]
        else:
            response_text = response_with_tokens
        # Strip any remaining role markers
        response_text = response_text.replace("<|assistant|>", "").replace("<|user|>", "").replace("<|system|>", "")
        return response_text.strip()

    def interactive_chat(self):
        """Start an interactive chat session."""
        logger.info("=" * 80)
        logger.info("LIFE COACH V1 - Interactive Chat Session")
        logger.info("=" * 80)
        print("\nWelcome to Life Coach v1!")
        print("I'm here to help you with life coaching, goal setting, motivation, and personal growth.")
        print("\nCommands:")
        print("  - Type your question or concern to get coaching advice")
        print("  - Type 'quit' or 'exit' to end the session")
        print("  - Type 'clear' to clear conversation history")
        print("=" * 80)
        print()
        conversation_history = []
        while True:
            try:
                # Get user input
                user_input = input("\n🧑 You: ").strip()
                if not user_input:
                    continue
                # Exit commands
                if user_input.lower() in ['quit', 'exit', 'q']:
                    print("\n👋 Thank you for using Life Coach v1. Take care!")
                    break
                # Clear command
                if user_input.lower() == 'clear':
                    conversation_history = []
                    print("✅ Conversation history cleared.")
                    continue
                # Generate a response with conversation context
                print("\n🤖 Life Coach: ", end="", flush=True)
                response = self.generate_response(user_input, conversation_history=conversation_history)
                print(response)
                # Update conversation history
                conversation_history.append({"role": "user", "content": user_input})
                conversation_history.append({"role": "assistant", "content": response})
            except KeyboardInterrupt:
                print("\n\n👋 Session interrupted. Goodbye!")
                break
            except Exception as e:
                logger.error(f"Error during chat: {e}")
                print(f"\n❌ Error: {e}")

def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(
        description="Life Coach v1 - Phi-4 based life coaching assistant"
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["train", "chat", "both"],
        default="both",
        help="Mode: train (fine-tune only), chat (chat only), both (train then chat)"
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="microsoft/Phi-4",
        help="Hugging Face model name"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="/data/life_coach_model",
        help="Path to save/load the fine-tuned model"
    )
    parser.add_argument(
        "--train-file",
        type=str,
        default="cbt_life_coach_improved_50000.jsonl",
        help="Path to the training data file (JSONL format)"
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=-1,
        help="Number of training samples (default: -1 for all 100,000 from mixed datasets)"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="Number of training epochs"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Training batch size (default: 4 for memory safety)"
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=5e-5,
        help="Learning rate (default: 5e-5, matching the quantum server)"
    )
    parser.add_argument(
        "--gradient-accumulation",
        type=int,
        default=4,
        help="Gradient accumulation steps (default: 4, effective batch=16)"
    )
    parser.add_argument(
        "--force-retrain",
        action="store_true",
        help="Force retraining even if a fine-tuned model exists"
    )
    args = parser.parse_args()

    # Clean up GPU memory before starting
    cleanup_gpu_memory()

    # Initialize the model wrapper
    coach = LifeCoachModel(
        model_name=args.model_name,
        model_save_path=args.model_path,
        train_file=args.train_file
    )
    # Load tokenizer
    coach.load_tokenizer()

    # Check whether a fine-tuned model (LoRA adapter) already exists
    model_exists = coach.model_save_path.exists() and (coach.model_save_path / "adapter_model.safetensors").exists()

    # Training mode
    if args.mode in ["train", "both"]:
        if model_exists and not args.force_retrain:
            logger.info("=" * 80)
            logger.info("FINE-TUNED MODEL ALREADY EXISTS")
            logger.info("=" * 80)
            logger.info(f"Found existing model at: {coach.model_save_path}")
            logger.info("Skipping training. Loading existing model...")
            logger.info("(Use --force-retrain to retrain from scratch)")
            logger.info("=" * 80)
            # Load the existing fine-tuned model
            coach.load_model(fine_tuned=True)
        else:
            if args.force_retrain and model_exists:
                logger.info("=" * 80)
                logger.info("FORCING RETRAINING (--force-retrain flag set)")
                logger.info("=" * 80)
            # Load the base model for training
            coach.load_model(fine_tuned=False)
            # Fine-tune
            num_samples = None if args.num_samples == -1 else args.num_samples
            coach.fine_tune(
                num_samples=num_samples,
                epochs=args.epochs,
                batch_size=args.batch_size,
                learning_rate=args.learning_rate,
                gradient_accumulation_steps=args.gradient_accumulation
            )
            # For "both" mode, reload the fine-tuned model for chat
            if args.mode == "both":
                logger.info("Reloading fine-tuned model for chat...")
                coach.load_model(fine_tuned=True)
        # If only training was requested, exit here
        if args.mode == "train":
            logger.info("Training complete. Use --mode chat to start chatting.")
            return
    # Chat-only mode
    elif args.mode == "chat":
        if not model_exists:
            logger.error("=" * 80)
            logger.error("ERROR: No fine-tuned model found!")
            logger.error("=" * 80)
            logger.error(f"Expected location: {coach.model_save_path}")
            logger.error("Please train the model first using:")
            logger.error("  python3 life_coach_v1.py --mode train")
            logger.error("=" * 80)
            return
        # Load the fine-tuned model
        logger.info(f"Loading fine-tuned model from {coach.model_save_path}")
        coach.load_model(fine_tuned=True)

    # Start the interactive chat (reached in "chat" and "both" modes)
    coach.interactive_chat()


if __name__ == "__main__":
    main()