"""
|
||
Simulation Pipeline v3.1 - Turbo Production Engine
|
||
Supports concurrent students via ThreadPoolExecutor with Thread-Safe I/O.
|
||
"""
|
||
import time
|
||
import os
|
||
import sys
|
||
import threading
|
||
import pandas as pd
|
||
from pathlib import Path
|
||
from typing import List, Dict, Any, cast, Set, Optional, Tuple
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
|
||
# Import services
try:
    from services.data_loader import load_personas, load_questions
    from services.simulator import SimulationEngine
    from services.cognition_simulator import CognitionSimulator
    import config
except ImportError:
    # Linter path fallback
    # NOTE(review): presumably covers running from inside the project folder
    # where the 'services' package prefix is not importable — the modules are
    # then imported flat after appending ./services to sys.path. Confirm this
    # matches how the script is actually invoked.
    sys.path.append(os.path.join(os.getcwd(), "services"))
    from data_loader import load_personas, load_questions
    from simulator import SimulationEngine
    from cognition_simulator import CognitionSimulator
    import config

# Initialize Threading Lock for shared resources (saving files, printing)
save_lock = threading.Lock()
||
def simulate_domain_for_students(
    engine: SimulationEngine,
    students: List[Dict],
    domain: str,
    questions: List[Dict],
    age_group: str,
    output_path: Optional[Path] = None,
    verbose: bool = False
) -> pd.DataFrame:
    """
    Simulate one domain for a list of students using multithreading.

    Resume-aware: when ``output_path`` already exists, its rows are kept and
    students whose CPID appears there are skipped. Each completed student
    triggers an incremental save guarded by the module-level ``save_lock``.

    Args:
        engine: Simulation engine; must provide
            ``simulate_batch(student, chunk, verbose=...)`` returning a
            q_code -> answer mapping.
        students: Persona dicts; 'StudentCPID', 'First Name' and 'Last Name'
            are read.
        domain: Domain name (informational only inside this function).
        questions: Question dicts for this domain; each must carry 'q_code'.
        age_group: Age-group suffix (informational only inside this function).
        output_path: Optional Excel file used for resume + incremental saves.
        verbose: Forwarded to the engine; also enables per-student logging.

    Returns:
        DataFrame with one row per student: identity columns plus one column
        per question code.
    """
    results: List[Dict] = []
    existing_cpids: Set[str] = set()

    # Output columns: one per question code in this domain.
    all_q_codes = [q['q_code'] for q in questions]

    # --- Resume: reload previously saved rows and collect finished CPIDs ---
    if output_path and output_path.exists():
        try:
            df_existing = pd.read_excel(output_path)
            if not df_existing.empty and 'Participant' in df_existing.columns:
                results = df_existing.to_dict('records')
                # Map Student CPID or Participant based on schema
                cpid_col = 'Student CPID' if 'Student CPID' in df_existing.columns else 'Participant'
                # Filter out NaN, empty strings, and 'nan' string values
                for cpid in df_existing[cpid_col].dropna().astype(str):
                    cpid_str = cpid.strip()
                    if cpid_str and cpid_str.lower() != 'nan':
                        existing_cpids.add(cpid_str)
                print(f" 🔄 Resuming: Found {len(existing_cpids)} students already completed in {output_path.name}")
        except Exception as e:
            # Best-effort: an unreadable file simply disables resume.
            print(f" ⚠️ Could not load existing file for resume: {e}")

    # --- Split questions into prompt-sized chunks ---
    chunk_size = int(getattr(config, 'QUESTIONS_PER_PROMPT', 15))
    questions_list = cast(List[Dict[str, Any]], questions)
    question_chunks: List[List[Dict[str, Any]]] = [
        questions_list[i : i + chunk_size]
        for i in range(0, len(questions_list), chunk_size)
    ]

    print(f" [INFO] Splitting {len(questions)} questions into {len(question_chunks)} chunks (size {chunk_size})")

    # Filter out already processed students
    pending_students = [s for s in students if str(s.get('StudentCPID')) not in existing_cpids]

    if not pending_students:
        return pd.DataFrame(results, columns=['Participant', 'First Name', 'Last Name', 'Student CPID'] + all_q_codes)

    def process_student(student: Dict, p_idx: int) -> None:
        """Worker: answer every chunk for one student, then append + save the row."""
        cpid = student.get('StudentCPID', 'UNKNOWN')
        if verbose or (p_idx % 20 == 0):
            with save_lock:
                print(f" [TURBO] Processing Student {p_idx+1}/{len(pending_students)}: {cpid}")

        all_answers: Dict[str, Any] = {}
        for chunk in question_chunks:
            answers = engine.simulate_batch(student, chunk, verbose=verbose)

            # FAIL-SAFE: if the batch response is missing any question codes,
            # retry the whole chunk in sub-chunks of 5 and merge the answers.
            chunk_codes = [q['q_code'] for q in chunk]
            missing = [code for code in chunk_codes if code not in answers]

            if missing:
                sub_chunks = [chunk[i : i + 5] for i in range(0, len(chunk), 5)]
                for sc in sub_chunks:
                    sc_answers = engine.simulate_batch(student, sc, verbose=verbose)
                    if sc_answers:
                        answers.update(sc_answers)
                    time.sleep(config.LLM_DELAY)  # rate-limit between LLM calls

            all_answers.update(answers)
            time.sleep(config.LLM_DELAY)

        # Build final row; questions that never got an answer become ''.
        row = {
            'Participant': f"{student.get('First Name', '')} {student.get('Last Name', '')}".strip(),
            'First Name': student.get('First Name', ''),
            'Last Name': student.get('Last Name', ''),
            'Student CPID': cpid,
            **{q: all_answers.get(q, '') for q in all_q_codes}
        }

        # Thread-safe result update and incremental save (the full sheet is
        # rewritten each time, so a crash loses only in-flight students).
        with save_lock:
            results.append(row)
            if output_path:
                columns = ['Participant', 'First Name', 'Last Name', 'Student CPID'] + all_q_codes
                pd.DataFrame(results, columns=columns).to_excel(output_path, index=False)

    # Execute multithreaded simulation
    max_workers = getattr(config, 'MAX_WORKERS', 5)
    print(f" 🚀 Launching Turbo Simulation with {max_workers} workers...")

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # BUGFIX: keep the futures and inspect them. The original
        # fire-and-forget submit() discarded worker exceptions, so a failed
        # student silently disappeared from the output with no trace.
        futures = [
            executor.submit(process_student, student, i)
            for i, student in enumerate(pending_students)
        ]
        for future in futures:
            try:
                future.result()
            except Exception as e:
                with save_lock:
                    print(f" ❌ Worker failed: {e}")

    columns = ['Participant', 'First Name', 'Last Name', 'Student CPID'] + all_q_codes
    return pd.DataFrame(results, columns=columns)
||
|
||
|
||
def run_full(verbose: bool = False, limit_students: Optional[int] = None) -> None:
    """
    Executes the full 3000 student simulation across all domains and cognition.

    For each age group this runs (1) every survey domain via
    simulate_domain_for_students (which handles its own resume), then
    (2) every cognition test, skipping output files that already contain one
    row per student.

    Args:
        verbose: Forwarded to the domain simulation for per-student logging.
        limit_students: If set, truncate both persona lists to this many
            students (used by run_dry_run).
    """
    adolescents, adults = load_personas()

    if limit_students:
        adolescents = adolescents[:limit_students]
        adults = adults[:limit_students]

    print("="*80)
    print(f"🚀 TURBO FULL RUN: {len(adolescents)} Adolescents + {len(adults)} Adults × ALL Domains")
    print("="*80)

    questions_map = load_questions()

    all_students = {'adolescent': adolescents, 'adult': adults}
    engine = SimulationEngine(config.ANTHROPIC_API_KEY)
    output_base = config.OUTPUT_DIR / "full_run"

    # NOTE(review): 'adolescense' is misspelled but forms part of the on-disk
    # output layout — existing result folders depend on it, so leave as-is.
    for age_key, age_label in [('adolescent', 'adolescense'), ('adult', 'adults')]:
        students = all_students[age_key]
        age_suffix = config.AGE_GROUPS[age_key]

        print(f"\n📂 Processing {age_label.upper()} ({len(students)} students)")

        # 1. Survey Domains
        (output_base / age_label / "5_domain").mkdir(parents=True, exist_ok=True)
        for domain in config.DOMAINS:
            # File-name templates may contain an '{age}' placeholder.
            file_name = config.OUTPUT_FILE_NAMES.get(domain, f'{domain}_{age_suffix}.xlsx').replace('{age}', age_suffix)
            output_path = output_base / age_label / "5_domain" / file_name

            print(f"\n 📝 Domain: {domain}")
            questions = questions_map.get(domain, [])
            # Prefer questions tagged for this age group; fall back to all
            # questions when none match.
            age_questions = [q for q in questions if age_suffix in q.get('age_group', '')]
            if not age_questions:
                age_questions = questions

            simulate_domain_for_students(
                engine, students, domain, age_questions, age_suffix,
                output_path=output_path, verbose=verbose
            )

        # 2. Cognition Tests
        cog_sim = CognitionSimulator()
        (output_base / age_label / "cognition").mkdir(parents=True, exist_ok=True)

        for test in config.COGNITION_TESTS:
            file_name = config.COGNITION_FILE_NAMES.get(test, f'{test}_{age_suffix}.xlsx').replace('{age}', age_suffix)
            output_path = output_base / age_label / "cognition" / file_name

            # Check if file exists and is complete (one row per student);
            # incomplete or unreadable files are regenerated from scratch.
            if output_path.exists():
                try:
                    df_existing = pd.read_excel(output_path)
                    expected_rows = len(students)
                    actual_rows = len(df_existing)

                    if actual_rows == expected_rows:
                        print(f" ⏭️ Skipping Cognition: {output_path.name} (already complete: {actual_rows} rows)")
                        continue
                    else:
                        print(f" 🔄 Regenerating Cognition: {output_path.name} (incomplete: {actual_rows}/{expected_rows} rows)")
                except Exception as e:
                    print(f" 🔄 Regenerating Cognition: {output_path.name} (file error: {e})")

            print(f" 🔹 Cognition: {test}")
            results = []
            for student in students:
                results.append(cog_sim.simulate_student_test(student, test, age_suffix))

            pd.DataFrame(results).to_excel(output_path, index=False)
            print(f" 💾 Saved: {output_path.name}")

    print("\n" + "="*80)
    print("✅ TURBO FULL RUN COMPLETE")
    print("="*80)
||
|
||
|
||
def run_dry_run() -> None:
    """Dry run for basic verification (5 students).

    Temporarily shortens config.LLM_DELAY for the smoke test and restores
    the original value afterwards.
    """
    # BUGFIX: the original assignment permanently clobbered config.LLM_DELAY,
    # so any subsequent run_full() in the same process kept the dry-run delay.
    original_delay = config.LLM_DELAY
    config.LLM_DELAY = 1.0
    try:
        run_full(verbose=True, limit_students=5)
    finally:
        config.LLM_DELAY = original_delay
||
|
||
|
||
if __name__ == "__main__":
|
||
if "--full" in sys.argv:
|
||
run_full()
|
||
elif "--dry" in sys.argv:
|
||
run_dry_run()
|
||
else:
|
||
print("💡 Usage: python main.py --full (Production)")
|
||
print("💡 Usage: python main.py --dry (5 Student Test)")
|