#!/usr/bin/env python3
"""
Stage 3: Query
Search indexed emails with natural language

Usage:
    python3 03_query.py --query "supplier delays" --output /output
    python3 03_query.py --list-categories --output /output
"""

import os
import json
import sys

def query_emails(query_text: str, chroma_path: str, limit: int = 20):
    """Run a semantic search over the indexed emails and print the hits.

    The query text is embedded with a SentenceTransformer model and the
    closest entries of the ``business_emails`` ChromaDB collection are
    printed to stdout, one block per result.

    Environment variables read:
        EMBEDDING_MODEL: model name (default ``BAAI/bge-small-en-v1.5``).
        USE_GPU: ``true`` (case-insensitive) selects ``cuda``, else ``cpu``.
        OMP_NUM_THREADS / MKL_NUM_THREADS: defaulted to ``2`` / ``1`` when
            they are not already set.

    Args:
        query_text: Natural-language search text.
        chroma_path: Directory of the persistent ChromaDB store.
        limit: Maximum number of results to request; must be >= 1.

    Raises:
        ValueError: If ``limit`` is smaller than 1.
        SystemExit: If sentence-transformers is not installed.
    """
    # Reject a useless result count before paying for imports/model load.
    if limit < 1:
        raise ValueError(f"limit must be a positive integer, got {limit!r}")

    # Thread caps are applied before the heavy imports below so that native
    # libraries which read them while loading can see them.
    # NOTE(review): exact timing depends on the BLAS/OpenMP runtime in use.
    os.environ.setdefault('OMP_NUM_THREADS', '2')
    os.environ.setdefault('MKL_NUM_THREADS', '1')

    import chromadb

    try:
        from sentence_transformers import SentenceTransformer
    except ImportError:
        print("ERROR: sentence-transformers not installed")
        print("Run: pip install sentence-transformers torch")
        sys.exit(1)

    print(f"Query: {query_text}")

    # Open the collection first: if it is missing there is no point in
    # loading the embedding model. The exception type raised for a missing
    # collection differs between chromadb releases, hence the broad catch
    # (same message as list_categories).
    client = chromadb.PersistentClient(path=chroma_path)
    try:
        collection = client.get_collection("business_emails")
    except Exception:
        print("No collection found. Run --scan first.")
        return

    # Load model
    model_name = os.environ.get('EMBEDDING_MODEL', 'BAAI/bge-small-en-v1.5')
    use_gpu = os.environ.get('USE_GPU', 'false').lower() == 'true'
    model = SentenceTransformer(model_name, device='cuda' if use_gpu else 'cpu')

    # Embed query (a one-element batch) and search
    query_embedding = model.encode([query_text])
    results = collection.query(
        query_embeddings=query_embedding.tolist(),
        n_results=limit,
    )

    # Only one query was sent, so the hits are in the first result row.
    documents = results['documents'][0]
    metadatas = results['metadatas'][0]

    # Display
    print(f"\n{'='*60}")
    print(f"Found {len(documents)} results:")
    print(f"{'='*60}\n")

    snippet_chars = 150
    for rank, (doc, meta) in enumerate(zip(documents, metadatas), start=1):
        # Entries stored without metadata or document text come back as None.
        meta = meta or {}
        doc = doc or ''

        # Only mark the preview with an ellipsis when text was cut off.
        snippet = doc[:snippet_chars]
        if len(doc) > snippet_chars:
            snippet += '...'

        # str() guards against non-string metadata values.
        print(f"[{rank}] {str(meta.get('category', 'N/A')).upper()}")
        print(f"    From: {meta.get('sender', 'N/A')}")
        print(f"    Date: {meta.get('date', 'N/A')}")
        print(f"    Subject: {meta.get('subject', 'N/A')}")
        print(f"    {snippet}")
        print()

def list_categories(chroma_path: str):
    """Print how many indexed emails fall into each category.

    Reads the metadata of every entry in the ``business_emails`` ChromaDB
    collection and prints one ``category: count`` line per category, most
    frequent first, followed by the overall total. Entries without a
    ``category`` metadata key are counted as ``uncategorized``.

    Args:
        chroma_path: Directory of the persistent ChromaDB store.
    """
    import chromadb
    from collections import Counter

    client = chromadb.PersistentClient(path=chroma_path)

    # The exception type raised for a missing collection differs between
    # chromadb releases; catch Exception (not a bare except) so Ctrl-C and
    # SystemExit still propagate.
    try:
        collection = client.get_collection("business_emails")
    except Exception:
        print("No collection found. Run --scan first.")
        return

    # Page through the collection instead of a single capped get(): this
    # keeps the counts correct beyond 100000 entries, and requesting only
    # metadata avoids loading every document body.
    page_size = 10000
    offset = 0
    cats = Counter()
    while True:
        page = collection.get(
            limit=page_size,
            offset=offset,
            include=['metadatas'],
        )
        ids = page.get('ids') or []
        if not ids:
            break
        # Entries stored without metadata come back as None.
        cats.update(
            (meta or {}).get('category', 'uncategorized')
            for meta in (page.get('metadatas') or [])
        )
        if len(ids) < page_size:
            break
        offset += len(ids)

    print("\n📊 Email Categories:")
    print("="*40)
    # most_common() sorts by descending count and keeps first-seen order
    # among ties.
    for cat, count in cats.most_common():
        print(f"  {cat}: {count}")
    print(f"\nTotal: {sum(cats.values())}")

def main():
    """Parse the command line and dispatch to the requested action.

    ``--query`` runs a search (taking precedence when both actions are
    given), ``--list-categories`` prints the per-category counts, and with
    neither option the usage help is shown. The ChromaDB store is expected
    in the ``chroma`` sub-directory of ``--output``.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Stage 3: Query')
    cli.add_argument('--query', help='Search query')
    cli.add_argument('--output', default='/output', help='Output directory')
    cli.add_argument('--list-categories', action='store_true')
    cli.add_argument('--limit', type=int, default=20)
    opts = cli.parse_args()

    store_dir = os.path.join(opts.output, 'chroma')

    # A non-empty query wins over --list-categories.
    if opts.query:
        query_emails(opts.query, store_dir, opts.limit)
        return

    if opts.list_categories:
        list_categories(store_dir)
        return

    # Nothing to do: show usage.
    cli.print_help()

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
