#!/usr/bin/env python3
"""
Stage 2: Embed & Index
Reads emails.jsonl → generates embeddings → ChromaDB
Set USE_GPU=true to request CUDA (default: CPU); EMBEDDING_MODEL overrides the
model (default: BAAI/bge-small-en-v1.5)

Usage:
    python3 02_embed.py --input /output/emails.jsonl --output /output
"""

import os
import json
import sys

def _read_jsonl(path: str) -> list:
    """Parse *path* as JSON Lines (one object per line), skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _fetch_all_ids(collection) -> set:
    """Return every id stored in *collection* (ids only, no documents/metadata)."""
    count = collection.count()
    if not count:
        return set()
    # Ask for exactly `count` rows so there is no arbitrary cap, and use
    # include=[] so only ids come back rather than full documents/metadata.
    return set(collection.get(include=[], limit=count)['ids'])


def _select_new_emails(emails: list, existing_ids: set) -> list:
    """Return emails whose id is not yet indexed, keeping the first of any repeated id."""
    seen = set(existing_ids)
    selected = []
    for email in emails:
        email_id = str(email['id'])
        if email_id in seen:
            # Already in the collection, or repeated earlier in this file;
            # a repeated id within one add() call would make Chroma fail.
            continue
        seen.add(email_id)
        selected.append(email)
    return selected


def _email_document(email: dict) -> str:
    """Build the text that is embedded and stored as the Chroma document."""
    body = email['body'] or ''
    return f"From: {email['sender']}\nSubject: {email['subject']}\nBody: {body[:500]}"


def _email_metadata(email: dict) -> dict:
    """Build the Chroma metadata dict for one email."""
    meta = {
        'sender': email['sender'],
        'subject': email['subject'],
        'date': email['date'],
        'category': email['category'],
        'folder': email.get('folder', ''),
    }
    # Chroma metadata values must be str/int/float/bool, so map None to ''.
    return {key: '' if value is None else value for key, value in meta.items()}


def _pick_device() -> str:
    """Return 'cuda' when USE_GPU=true and CUDA is usable, otherwise 'cpu'."""
    if os.environ.get('USE_GPU', 'false').lower() != 'true':
        return 'cpu'
    import torch  # sentence-transformers depends on torch, so it is installed
    if torch.cuda.is_available():
        return 'cuda'
    print("WARNING: USE_GPU=true but CUDA is not available; falling back to CPU",
          file=sys.stderr)
    return 'cpu'


def embed_and_index(input_file: str, chroma_path: str, batch_size: int = 32):
    """Embed the emails in a JSONL file and add them to a ChromaDB collection.

    Emails whose ``id`` is already stored in the ``business_emails``
    collection, or that repeat an earlier id in the file, are skipped, so
    re-running only embeds new emails. The rest are embedded with a
    SentenceTransformer model and added in chunks of ``batch_size``; chunks
    written before an interruption are kept and skipped on the next run.

    Environment variables:
        EMBEDDING_MODEL: model name (default ``BAAI/bge-small-en-v1.5``).
        USE_GPU: ``"true"`` requests CUDA; falls back to CPU with a warning
            if CUDA is unavailable.
        OMP_NUM_THREADS / MKL_NUM_THREADS: defaulted to ``2`` / ``1`` if unset.

    Args:
        input_file: Path to a JSONL file. Each record needs ``id``,
            ``sender``, ``subject``, ``body``, ``date`` and ``category``;
            ``folder`` is optional.
        chroma_path: Directory of the persistent ChromaDB store (created if
            missing).
        batch_size: Emails embedded and written per chunk; must be >= 1.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
        SystemExit: If chromadb or sentence-transformers cannot be imported.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Thread limits are read when numpy/BLAS (pulled in by chromadb) and
    # torch initialise, so they must be set before those imports.
    os.environ.setdefault('OMP_NUM_THREADS', '2')
    os.environ.setdefault('MKL_NUM_THREADS', '1')

    try:
        import chromadb
        from sentence_transformers import SentenceTransformer
    except ImportError as err:
        print(f"ERROR: missing dependency: {err}", file=sys.stderr)
        print("Run: pip install chromadb sentence-transformers torch", file=sys.stderr)
        sys.exit(1)

    os.makedirs(chroma_path, exist_ok=True)
    client = chromadb.PersistentClient(path=chroma_path)
    collection = client.get_or_create_collection(name="business_emails")
    existing_ids = _fetch_all_ids(collection)
    print(f"Collection 'business_emails': {len(existing_ids)} emails already indexed")

    emails = _read_jsonl(input_file)
    print(f"Loaded {len(emails)} emails")

    emails_to_index = _select_new_emails(emails, existing_ids)
    print(f"Need to index: {len(emails_to_index)} emails")
    if not emails_to_index:
        print("All emails already indexed!")
        return

    # Load the model only once we know there is work to do.
    print("Loading embedding model...")
    model_name = os.environ.get('EMBEDDING_MODEL', 'BAAI/bge-small-en-v1.5')
    device = _pick_device()
    print(f"Using device: {device}")
    model = SentenceTransformer(model_name, device=device)

    total = len(emails_to_index)
    for start in range(0, total, batch_size):
        chunk = emails_to_index[start:start + batch_size]
        ids = [str(email['id']) for email in chunk]
        texts = [_email_document(email) for email in chunk]
        metas = [_email_metadata(email) for email in chunk]

        print(f"Embedding {start + 1}-{start + len(chunk)} of {total}...")
        # One chunk is a single encode batch, so a per-call progress bar adds
        # nothing over the line printed above.
        embeddings = model.encode(texts, batch_size=batch_size, show_progress_bar=False)
        collection.add(
            documents=texts,
            metadatas=metas,
            ids=ids,
            embeddings=embeddings.tolist(),
        )

    print(f"✅ Indexed {total} emails to {chroma_path}")

def main(argv=None):
    """Command-line entry point: parse arguments and run embed_and_index.

    The ChromaDB store is written to ``<output>/chroma``.

    Args:
        argv: Arguments to parse. ``None`` (the default) uses ``sys.argv[1:]``,
            matching the previous no-argument behaviour.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Stage 2: Embed & Index')
    parser.add_argument('--input', required=True, help='Input JSONL file')
    parser.add_argument('--output', default='/output',
                        help='Output directory (ChromaDB is stored in <output>/chroma)')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Emails embedded and written per batch (default: 32)')
    args = parser.parse_args(argv)

    # A non-positive batch size cannot make progress; report it as a usage
    # error (exit code 2) instead of failing deep inside the indexing code.
    if args.batch_size < 1:
        parser.error('--batch-size must be a positive integer')

    chroma_path = os.path.join(args.output, 'chroma')
    embed_and_index(args.input, chroma_path, args.batch_size)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, so importing this module
    # does not start indexing.
    main()
