#!/usr/bin/env python3
"""
Business RAG Query Interface
Search emails & documents with natural language.

Usage:
    python3 query.py "supplier delivery delays last year"
    python3 query.py "invoice issues" --limit 10 --min-confidence 0.3
    python3 query.py --list-categories
    python3 query.py --stats          # same output as --list-categories
"""

import os
import json
import argparse

def query(query_text: str, limit: int = 20, min_confidence: float = 0.0,
          *, chroma_path: str = "/output/chroma", device: "str | None" = None):
    """Run a semantic search over the ``business`` Chroma collection and print the hits.

    Args:
        query_text: Natural-language search string.
        limit: Maximum number of results requested from Chroma (``n_results``).
        min_confidence: When > 0, only items whose ``confidence`` metadata is
            >= this value are returned (items without a ``confidence`` key are
            therefore excluded); 0 or less disables the filter.
        chroma_path: Directory of the persistent Chroma database.
        device: Torch device for the embedding model (e.g. ``"cuda"``, ``"cpu"``).
            ``None`` lets sentence-transformers choose (CUDA when available).

    Returns:
        The raw result dict from ``collection.query``. A single query was
        issued, so every list-valued field is nested one level deep
        (``results['documents'][0]``, ``results['metadatas'][0]``, ...).
    """
    # OpenMP reads OMP_NUM_THREADS when its runtime initialises, i.e. when torch
    # is first imported, so the default has to be in place *before* the imports.
    os.environ.setdefault('OMP_NUM_THREADS', '4')

    import chromadb
    from sentence_transformers import SentenceTransformer

    print(f"Query: {query_text}\n")

    # Load model & ChromaDB
    model = SentenceTransformer("BAAI/bge-small-en-v1.5", device=device)
    client = chromadb.PersistentClient(path=chroma_path)
    collection = client.get_collection("business")

    # Embed query
    q_emb = model.encode([query_text])

    # Search
    results = collection.query(
        query_embeddings=q_emb.tolist(),
        n_results=limit,
        where={"confidence": {"$gte": min_confidence}} if min_confidence > 0 else None,
    )

    docs = results['documents'][0]
    metas = results['metadatas'][0]

    # Display
    print("=" * 70)
    print(f"Found {len(docs)} results")
    print("=" * 70)

    for i, (doc, meta) in enumerate(zip(docs, metas), start=1):
        # Chroma may hand back None for records stored without metadata or
        # document text; treat those as empty instead of crashing.
        meta = meta or {}
        doc = doc or ""
        print(f"\n[{i}] {meta.get('type', 'unknown').upper()}")
        print(f"    Category: {meta.get('category', 'N/A')}")
        print(f"    Confidence: {meta.get('confidence', 0):.0%}")
        if meta.get('sender'):
            print(f"    From: {meta['sender']}")
        if meta.get('date'):
            print(f"    Date: {meta['date']}")
        print(f"    Subject: {meta.get('subject', 'N/A')}")
        # Only signal truncation when text was actually cut off.
        preview = doc[:200] + ("..." if len(doc) > 200 else "")
        print(f"    Preview: {preview}")

    return results

def list_categories(*, chroma_path: str = "/output/chroma", batch_size: int = 10000):
    """Print statistics for the ``business`` collection.

    The output is the total item count plus counts per ``type`` metadata
    value (first-seen order) and per ``category`` value (most frequent first).
    Metadata is fetched page by page, so collections of any size are counted
    in full.

    Args:
        chroma_path: Directory of the persistent Chroma database.
        batch_size: Number of records fetched per ``collection.get`` call.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    from collections import Counter

    import chromadb

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    client = chromadb.PersistentClient(path=chroma_path)
    collection = client.get_collection("business")

    cats = Counter()
    types = Counter()
    total = 0
    offset = 0
    while True:
        # Only metadata is needed for the tallies, so documents are not fetched.
        # Paging replaces the former single get(limit=100000), which silently
        # truncated larger collections and under-reported every count.
        page = collection.get(include=["metadatas"], limit=batch_size, offset=offset)
        ids = page['ids']
        for meta in page['metadatas'] or []:
            meta = meta or {}  # records stored without metadata may come back as None
            cats[meta.get('category', 'uncategorized')] += 1
            types[meta.get('type', 'unknown')] += 1
        total += len(ids)
        if len(ids) < batch_size:
            break
        offset += batch_size

    print("\n📊 COLLECTION STATS")
    print("=" * 40)
    print(f"\nTotal items: {total}")
    print("\nBy type:")
    for t, c in types.items():
        print(f"  {t}: {c}")
    print("\nBy category:")
    # most_common() sorts by count descending; ties keep first-seen order.
    for c, n in cats.most_common():
        print(f"  {c}: {n}")

def _positive_int(text: str) -> int:
    """argparse ``type=`` converter: parse *text* as an int and require it to be >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value

def main(argv=None):
    """Parse command-line arguments and dispatch to the stats listing or a search.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) parses ``sys.argv[1:]`` as usual.
    """
    parser = argparse.ArgumentParser(description='Business RAG Query')
    parser.add_argument('query', nargs='?', help='Search query')
    parser.add_argument('--list-categories', action='store_true', help='Show category stats')
    parser.add_argument('--stats', action='store_true', help='Show collection stats')
    # Validate up front: Chroma rejects n_results < 1, and without this check
    # the failure would only surface as a traceback after the model has loaded.
    parser.add_argument('--limit', type=_positive_int, default=20, help='Result limit')
    parser.add_argument('--min-confidence', type=float, default=0.0, help='Min confidence filter')
    args = parser.parse_args(argv)

    # --stats and --list-categories are aliases for the same report.
    if args.stats or args.list_categories:
        list_categories()
    elif args.query:
        query(args.query, limit=args.limit, min_confidence=args.min_confidence)
    else:
        parser.print_help()
        print("\nExamples:")
        print("  python3 query.py \"supplier delays\"")
        print("  python3 query.py --list-categories")
        print("  python3 query.py \"invoice issues\" --min-confidence 0.3")

if __name__ == "__main__":
    # Run the CLI only when executed as a script; main()'s return value
    # (None, i.e. exit status 0) becomes the process exit status.
    raise SystemExit(main())
