#!/usr/bin/env python3
"""
Query indexed emails

Usage:
    python3 query_emails.py "supplier delivery delays"
    python3 query_emails.py --list
    python3 query_emails.py --stats
"""

import argparse
import json
import os
import sys
from collections import Counter
from typing import Optional

# Directory of the persistent ChromaDB store that holds the "emails" collection.
# Override it with the CHROMA_PATH environment variable to point the script at
# another index without editing the source; the default is unchanged.
CHROMA_PATH = os.environ.get("CHROMA_PATH", "/output/chroma")

def _preview(text: Optional[str], width: int = 150) -> str:
    """Return *text* collapsed onto one line and cut to at most *width* chars.

    An ellipsis is appended only when the text was actually truncated; a
    missing document (``None``) yields an empty string.
    """
    flat = " ".join((text or "").split())
    return flat if len(flat) <= width else flat[:width] + "..."


def search(query: str, limit: int = 10, category: Optional[str] = None,
           device: Optional[str] = None):
    """Run a semantic search over the ``emails`` collection and print the hits.

    The query is embedded with the ``BAAI/bge-small-en-v1.5`` SentenceTransformer
    model and matched against the persistent Chroma store at ``CHROMA_PATH``.

    Args:
        query: Free-text search string.
        limit: Number of results requested from Chroma (``n_results``).
        category: If truthy, only items whose ``category`` metadata equals this
            value are considered.
        device: Torch device for the embedding model (e.g. ``"cuda"``,
            ``"cpu"``). ``None`` lets sentence-transformers use a GPU when one
            is available and fall back to CPU otherwise.
    """
    # OMP_NUM_THREADS is picked up when the OpenMP runtime is loaded, which
    # happens while torch (pulled in by sentence_transformers) is imported, so
    # the default has to be in place *before* those imports to have any effect.
    os.environ.setdefault('OMP_NUM_THREADS', '4')

    from sentence_transformers import SentenceTransformer
    import chromadb

    print(f"Query: {query}\n")

    # Load the model and the collection.
    # NOTE(review): BGE models are often used with normalize_embeddings=True
    # and/or a query instruction prefix; whether that applies depends on how
    # the index was built -- confirm against the indexing script.
    model = SentenceTransformer("BAAI/bge-small-en-v1.5", device=device)
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    collection = client.get_collection("emails")

    # Search.
    q_emb = model.encode([query])
    where = {"category": category} if category else None
    results = collection.query(
        query_embeddings=q_emb.tolist(),
        n_results=limit,
        where=where,
    )

    # Display. Chroma returns one result list per query embedding; we sent one.
    docs = results['documents'][0]
    metas = results['metadatas'][0]

    print("=" * 70)
    print(f"Found {len(docs)} results")
    print("=" * 70)

    for rank, (doc, meta) in enumerate(zip(docs, metas), start=1):
        # Items stored without metadata come back as None.
        meta = meta or {}
        print(f"\n[{rank}] {str(meta.get('type', 'email')).upper()}")
        print(f"    Category: {meta.get('category', 'N/A')}")
        if meta.get('sender'):
            print(f"    From: {meta['sender']}")
        if meta.get('date'):
            print(f"    Date: {meta['date']}")
        print(f"    Subject: {meta.get('subject', 'N/A')}")
        print(f"    Preview: {_preview(doc)}")

def list_categories():
    """Print how many indexed items fall into each ``category``, most common first.

    Items whose metadata lacks a ``category`` key, or that have no metadata at
    all, are counted as ``uncategorized``.
    """
    import chromadb
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    collection = client.get_collection("emails")

    # Fetch metadata only: documents are not needed for counting. Omitting
    # `limit` returns every item instead of silently stopping at a fixed cap.
    metadatas = collection.get(include=["metadatas"])['metadatas']

    cats = Counter(
        (meta or {}).get('category', 'uncategorized') for meta in metadatas
    )

    print("\n📊 CATEGORIES")
    print("=" * 40)
    for c, n in cats.most_common():
        print(f"  {c}: {n}")

def stats():
    """Print the total item count plus per-``type`` and per-``category`` counts.

    Missing ``type`` / ``category`` metadata is reported as ``unknown`` /
    ``uncategorized``. Both breakdowns are listed most common first.
    """
    import chromadb
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    collection = client.get_collection("emails")

    # Fetch metadata only, for every item (no `limit` cap), so the total and
    # the breakdowns reflect the whole collection. None metadata -> {}.
    metadatas = [
        meta or {} for meta in collection.get(include=["metadatas"])['metadatas']
    ]

    types = Counter(meta.get('type', 'unknown') for meta in metadatas)
    cats = Counter(meta.get('category', 'uncategorized') for meta in metadatas)

    print("\n📊 COLLECTION STATS")
    print("=" * 40)
    print(f"Total items: {len(metadatas)}")
    print("\nBy type:")
    for t, n in types.most_common():
        print(f"  {t}: {n}")
    print("\nBy category:")
    for c, n in cats.most_common():
        print(f"  {c}: {n}")

def main():
    """Parse command-line arguments and dispatch to stats, listing or search.

    ``--stats`` takes precedence over ``--list``, which takes precedence over a
    search query. With none of them the help text (including examples) is
    printed. Invalid combinations exit via ``parser.error`` (status 2).
    """
    examples = (
        "Examples:\n"
        "  python3 query_emails.py \"supplier delays\"\n"
        "  python3 query_emails.py \"invoice issues\" --limit 20\n"
        "  python3 query_emails.py \"order status\" --category customer/order\n"
        "  python3 query_emails.py --list\n"
        "  python3 query_emails.py --stats"
    )
    parser = argparse.ArgumentParser(
        description="Query indexed emails",
        epilog=examples,
        # Keep the example lines exactly as written in --help output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('query', nargs='?', help='Search query')
    parser.add_argument('--list', action='store_true', help='List categories')
    parser.add_argument('--stats', action='store_true', help='Show stats')
    parser.add_argument('--limit', type=int, default=10, help='Result limit')
    parser.add_argument('--category', help='Filter search results by category')
    args = parser.parse_args()

    if args.limit < 1:
        parser.error("--limit must be a positive integer")

    if args.stats:
        stats()
    elif args.list:
        list_categories()
    elif args.query:
        search(args.query, limit=args.limit, category=args.category)
    elif args.category:
        # A category filter alone has nothing to filter; fail loudly instead
        # of silently falling through to the help text.
        parser.error("--category requires a search query")
    else:
        parser.print_help()

if __name__ == "__main__":
    main()
