#!/usr/bin/env python3
"""
Embed + Index emails.jsonl to ChromaDB
Takes pre-parsed JSONL, outputs vector index

Usage:
    python3 embed_emails.py emails.jsonl /output
"""

import os
import json
import sys

# Config
# Sentence-transformers model used to embed each email's searchable text.
EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"
# Number of items encoded and added to the Chroma collection per call.
BATCH_SIZE = 256
# NOTE(review): not read by main(), which derives the index location from its
# output_dir argument as <output_dir>/chroma; this value matches that path only
# for the default output_dir of "/output".
CHROMA_PATH = "/output/chroma"

def _load_jsonl(path):
    """Read *path* as JSON Lines and return the decoded objects in file order.

    Blank lines are skipped. The file is decoded as UTF-8 explicitly so that
    non-ASCII email content does not depend on the platform's locale encoding.

    Raises:
        ValueError: if a line is not valid JSON or does not decode to a JSON
            object; the message names the file and the 1-based line number.
    """
    items = []
    with open(path, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as e:
                raise ValueError(f"{path}:{lineno}: invalid JSON ({e.msg})") from e
            if not isinstance(obj, dict):
                raise ValueError(
                    f"{path}:{lineno}: expected a JSON object, got {type(obj).__name__}"
                )
            items.append(obj)
    return items


def _as_text(value):
    """Render *value* for the embedding text: None becomes '', non-strings use str()."""
    if value is None:
        return ''
    return value if isinstance(value, str) else str(value)


def _as_meta_value(value, default):
    """Coerce *value* into a type Chroma accepts as a metadata value.

    None falls back to *default*; str/int/float/bool pass through unchanged;
    anything else (lists, nested objects) is stored as a JSON string.
    """
    if value is None:
        return default
    if isinstance(value, (str, int, float, bool)):
        return value
    return json.dumps(value, ensure_ascii=False)


def _build_text(item, max_body_chars=700):
    """Build the searchable document text: subject, sender, then a body prefix.

    Only the first *max_body_chars* characters of a non-empty body are kept.
    """
    text = f"Subject: {_as_text(item.get('subject'))}\nFrom: {_as_text(item.get('sender'))}"
    body = item.get('body')
    if body:
        text += f"\n{_as_text(body)[:max_body_chars]}"
    return text


def _build_metadata(item):
    """Build the metadata dict stored alongside each document in Chroma."""
    return {
        'type': _as_meta_value(item.get('type'), 'email'),
        'category': _as_meta_value(item.get('category'), 'uncategorized'),
        'sender': _as_meta_value(item.get('sender'), ''),
        'subject': _as_meta_value(item.get('subject'), ''),
        'date': _as_meta_value(item.get('date'), ''),
        'source': _as_meta_value(item.get('source'), ''),
    }


def _select_new(items, existing_ids):
    """Pair each item with a string id and keep only ids not already indexed.

    Items without an ``id`` (missing, null or empty) get ``item_<n>``, where n
    is the item's position in *items*. This keeps the fallback stable across
    runs over the same file. Ids are converted to ``str`` because Chroma ids
    are strings. A repeated id keeps its first occurrence, since Chroma
    rejects duplicate ids in one ``add``.

    Returns:
        ``(pairs, duplicates)``: a list of ``(id, item)`` tuples to index, and
        the number of items dropped because their id repeated an earlier one.
    """
    selected = []
    seen = set()
    duplicates = 0
    for n, item in enumerate(items):
        raw_id = item.get('id')
        item_id = f"item_{n}" if raw_id is None or raw_id == '' else str(raw_id)
        if item_id in existing_ids:
            continue
        if item_id in seen:
            duplicates += 1
            continue
        seen.add(item_id)
        selected.append((item_id, item))
    return selected, duplicates


def _fetch_existing_ids(collection, page_size=10000):
    """Return the set of every id already stored in *collection*.

    Pages through the collection with ``include=[]`` so only ids are
    transferred. A single capped ``get`` would silently miss ids past the cap.
    """
    ids = set()
    offset = 0
    while True:
        page = collection.get(include=[], limit=page_size, offset=offset)['ids']
        ids.update(page)
        if len(page) < page_size:
            return ids
        offset += page_size


def main():
    """Embed a JSONL file of pre-parsed emails into a persistent Chroma collection.

    Command line: ``embed_emails.py <emails.jsonl> [output_dir]``. output_dir
    defaults to ``/output``, and the index lives in ``<output_dir>/chroma`` in a
    collection named ``emails``. Items whose ids are already in the collection
    are skipped, so re-running over the same file only indexes new items.
    The embedding model is loaded only when there is something to index.

    Raises:
        SystemExit: with code 1 when the JSONL path argument is missing.
        ValueError: when a JSONL line is not valid JSON or not a JSON object.
    """
    if len(sys.argv) < 2:
        print("Usage: python3 embed_emails.py <emails.jsonl> [output_dir]")
        sys.exit(1)

    from collections import Counter

    jsonl_path = sys.argv[1]
    output_dir = sys.argv[2] if len(sys.argv) > 2 else "/output"
    chroma_path = os.path.join(output_dir, "chroma")
    # Creating chroma_path also creates output_dir as its parent.
    os.makedirs(chroma_path, exist_ok=True)

    print("=" * 60)
    print("EMAIL EMBEDDER")
    print("=" * 60)

    # Load JSONL
    print(f"\n[1/3] Loading {jsonl_path}...")
    items = _load_jsonl(jsonl_path)
    print(f"  Loaded {len(items)} items")
    types = Counter(item.get('type', 'unknown') for item in items)
    print(f"  Types: {dict(types)}")

    # Open the index first so a run with nothing new never loads the model.
    print(f"\n[2/3] Opening index at {chroma_path}...")
    # Set before importing chromadb / sentence_transformers so the thread
    # limit is already in the environment when they load.
    os.environ.setdefault('OMP_NUM_THREADS', '4')
    import chromadb

    client = chromadb.PersistentClient(path=chroma_path)
    collection = client.get_or_create_collection(
        name="emails", metadata={"description": "Business emails"}
    )
    existing_ids = _fetch_existing_ids(collection)
    print(f"  Existing collection: {len(existing_ids)} items")

    to_index, duplicates = _select_new(items, existing_ids)
    if duplicates:
        print(f"  Skipping {duplicates} items with repeated ids")
    print(f"  Indexing {len(to_index)} new items...")
    if not to_index:
        print("  Nothing to index - all done!")
        return

    print(f"  Loading {EMBEDDING_MODEL}...")
    from sentence_transformers import SentenceTransformer

    # device=None: use CUDA when available, otherwise fall back to CPU instead
    # of failing on hosts without a GPU.
    model = SentenceTransformer(EMBEDDING_MODEL, device=None)
    print(f"  Device: {model.device}")

    # Batch embed
    # NOTE(review): embeddings are stored unnormalized in a collection using
    # Chroma's default distance; BGE models are commonly used with normalized
    # vectors + cosine. Changing that requires re-indexing existing vectors.
    print(f"\n[3/3] Embedding (batch size: {BATCH_SIZE})...")
    total_new = len(to_index)
    for start in range(0, total_new, BATCH_SIZE):
        batch = to_index[start:start + BATCH_SIZE]
        print(f"  Batch {start + len(batch)}/{total_new}...")
        ids = [item_id for item_id, _ in batch]
        texts = [_build_text(item) for _, item in batch]
        metas = [_build_metadata(item) for _, item in batch]
        emb = model.encode(texts, batch_size=BATCH_SIZE)
        collection.add(ids=ids, documents=texts, metadatas=metas, embeddings=emb.tolist())

    # Summary
    print("\n" + "=" * 60)
    print("DONE")
    print("=" * 60)
    print(f"Total indexed: {collection.count()}")
    print(f"Index location: {chroma_path}")

if __name__ == "__main__":
    main()
