#!/usr/bin/env python3
"""
Email Pipeline - Full Workflow (Parse + Embed)
For LXC with GPU

Usage:
    python3 email_pipeline.py --scan /path/to/eml --output /output
"""

import json
import os
import re
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# ============== CONFIG ==============
# Defaults; main() overwrites OUTPUT_DIR, CHROMA_PATH and BATCH_SIZE from the
# command line.
OUTPUT_DIR = "/output"
CHROMA_PATH = "/output/chroma"
BATCH_SIZE = 128  # GPU batch size
EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"

# Category label -> trigger keywords (English and Chinese).
# classify_email() walks this mapping in insertion order and stops at the
# first keyword found as a substring, so earlier categories take precedence
# over later ones that share a keyword.
CATEGORIES = {
    'customer/order': [
        'order', 'invoice', 'payment', 'quotation', 'quote',
        '客户', '订单', '发票', '报价',
    ],
    'customer/inquiry': [
        'inquiry', 'enquiry', 'question', 'request for', 'information',
        '咨询', '询问',
    ],
    'supplier': [
        'supplier', 'vendor', 'procurement', 'vendor contact',
        '供应商', '供货商', '采购',
    ],
    'shipment': [
        'shipment', 'shipping', 'delivery', 'tracking', 'freight', 'cargo',
        '发货', '物流', '货运', '运输',
    ],
    # NOTE(review): every keyword here except 'pricing' is matched earlier by
    # 'customer/order' ('报价单' contains '报价'), so this category is only
    # reachable through 'pricing'.
    'quotation': [
        'quotation', 'quote', 'pricing',
        '报价单', '报价',
    ],
    'contract': [
        'contract', 'agreement', 'terms',
        '条款', '合同', '协议',
    ],
    'complaint': [
        'complaint', 'issue', 'problem', 'defect', 'quality',
        '投诉', '问题', '质量',
    ],
}

def clean_text(text: str) -> str:
    """Normalise whitespace in *text* and cap its length.

    Every run of whitespace becomes a single space, leading and trailing
    whitespace is removed, and the result is cut to at most 8000 characters.
    Falsy input (empty string, None) yields an empty string.
    """
    if text:
        collapsed = re.sub(r'\s+', ' ', text)
        trimmed = collapsed.strip()
        # Truncation happens after trimming, so a cut result may end in a space.
        return trimmed[:8000]
    return ""

def classify_email(subject: str, body: str, sender: str) -> Tuple[str, List[str]]:
    """Assign an email to the first category whose keyword it contains.

    Subject, body and sender are joined and lower-cased, then CATEGORIES is
    scanned in insertion order; matching is plain substring containment.

    Returns:
        (category, [keyword]) for the first hit, or ('uncategorized', [])
        when no keyword occurs.
    """
    haystack = f"{subject} {body} {sender}".lower()
    first_hit = next(
        (
            (label, term)
            for label, terms in CATEGORIES.items()
            for term in terms
            if term.lower() in haystack
        ),
        None,
    )
    if first_hit is None:
        return 'uncategorized', []
    label, term = first_hit
    return label, [term]

def parse_eml(filepath: Path) -> Optional[Dict]:
    """Parse a single .eml file into a flat record.

    The file is read as bytes and decoded, then the lines up to the first
    blank line are scanned for English (From:/To:/Subject:/Date:) and Chinese
    (发件人:/收件人:/主题:/日期:) header prefixes. Everything after that blank
    line is the body. If no blank line exists, the whole text is the body.

    Args:
        filepath: Path of the .eml file. Its stem becomes the record 'id' and
            its parent directory name the 'folder'.

    Returns:
        A dict with keys id, folder, sender, recipient, subject, date, body,
        category, keywords and source, or None when both subject and body are
        empty or the file could not be processed (a warning is then written
        to stderr).
    """
    # Header prefix -> record field. The value is whatever follows the prefix,
    # so the slice offset is always len(prefix).
    header_fields = (
        ('From:', 'sender'),
        ('To:', 'recipient'),
        ('Subject:', 'subject'),
        ('Date:', 'date'),
        ('发件人:', 'sender'),
        ('收件人:', 'recipient'),
        ('主题:', 'subject'),
        ('日期:', 'date'),
    )

    try:
        raw = filepath.read_bytes()

        # Strict UTF-8 first ('utf-8-sig' also drops a leading BOM that would
        # otherwise stick to the first header line). Only when the bytes are
        # not valid UTF-8 do we consider GBK; valid UTF-8 mail with Chinese-only
        # headers must not be re-decoded as GBK.
        try:
            text = raw.decode('utf-8-sig')
        except UnicodeDecodeError:
            text = raw.decode('utf-8-sig', errors='ignore')
            if 'From:' not in text and 'Subject:' not in text:
                text = raw.decode('gbk', errors='ignore')

        email_data = {
            'id': filepath.stem,
            'folder': str(filepath.parent.name),
            'sender': '', 'recipient': '', 'subject': '',
            'date': '', 'body': '',
            'category': 'uncategorized', 'keywords': [],
            'source': str(filepath)
        }

        lines = text.split('\n')
        body_start = 0

        for index, line in enumerate(lines):
            header_line = line.rstrip()  # also removes the '\r' of CRLF files
            if not header_line:
                # First blank line ends the header section.
                body_start = index + 1
                break

            for prefix, field in header_fields:
                if header_line.startswith(prefix):
                    email_data[field] = clean_text(header_line[len(prefix):].strip())
                    break

        email_data['body'] = clean_text('\n'.join(lines[body_start:]))
        email_data['category'], email_data['keywords'] = classify_email(
            email_data['subject'], email_data['body'], email_data['sender']
        )

        if email_data['subject'] or email_data['body']:
            return email_data
        return None

    except Exception as exc:
        # Per-file boundary of a batch job: one bad file must not abort the
        # run, but the failure should be visible.
        print(f"Warning: could not parse {filepath}: {exc}", file=sys.stderr)
        return None

def _load_existing_ids(collection, page_size: int = 10000) -> set:
    """Return all IDs stored in *collection*, fetched page by page (IDs only)."""
    known = set()
    offset = 0
    while True:
        # include=[] asks for IDs only, skipping documents and metadata.
        page = collection.get(include=[], limit=page_size, offset=offset)['ids']
        if not page:
            return known
        known.update(page)
        offset += len(page)


def _embed_and_store(model, collection, batch: List[Dict], batch_size: int) -> None:
    """Embed one batch of parsed emails and add it to the Chroma collection."""
    texts = [
        f"From: {email['sender']}\nSubject: {email['subject']}\nBody: {email['body'][:500]}"
        for email in batch
    ]
    metas = [
        {
            'sender': email['sender'],
            'subject': email['subject'],
            'date': email['date'],
            'category': email['category']
        }
        for email in batch
    ]
    ids = [email['id'] for email in batch]
    emb = model.encode(texts, batch_size=batch_size)
    collection.add(documents=texts, metadatas=metas, ids=ids, embeddings=emb.tolist())


def main():
    """Command-line entry point: parse .eml files, then embed and index them.

    Phase 1 walks --scan recursively, parses every .eml file with parse_eml()
    and writes the records to <output>/emails.jsonl. Phase 2 loads the
    embedding model on CUDA, opens (or creates) the "emails" collection in a
    persistent ChromaDB under <output>/chroma and adds every email whose ID is
    not stored yet.

    Overwrites the module globals BATCH_SIZE, OUTPUT_DIR and CHROMA_PATH from
    the command-line arguments.
    """
    import argparse

    global BATCH_SIZE, OUTPUT_DIR, CHROMA_PATH

    parser = argparse.ArgumentParser(description='Email Pipeline')
    parser.add_argument('--scan', required=True, help='Path to .eml files')
    parser.add_argument('--output', default='/output', help='Output directory')
    parser.add_argument('--batch-size', type=int, default=128, help='Embedding batch size')
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error('--batch-size must be a positive integer')

    BATCH_SIZE = args.batch_size
    OUTPUT_DIR = args.output
    CHROMA_PATH = os.path.join(OUTPUT_DIR, 'chroma')

    print(f"Scanning: {args.scan}")

    # Find .eml files
    eml_files = list(Path(args.scan).rglob('*.eml'))
    print(f"Found {len(eml_files)} .eml files")

    if not eml_files:
        print("No .eml files found!")
        return

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # ========== PHASE 1: Parse ==========
    print("\n=== Phase 1: Parsing ===")
    emails = []
    for i, eml in enumerate(eml_files):
        if i % 1000 == 0:
            print(f"[{i}/{len(eml_files)}]")
        email = parse_eml(eml)
        if email:
            emails.append(email)

    print(f"Parsed {len(emails)} emails")

    # Save JSONL. Explicit UTF-8: ensure_ascii=False emits raw non-ASCII text,
    # which would fail to encode under a non-UTF-8 locale default.
    jsonl_path = os.path.join(OUTPUT_DIR, "emails.jsonl")
    with open(jsonl_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(email, ensure_ascii=False) + '\n' for email in emails)
    print(f"Saved to {jsonl_path}")

    # ========== PHASE 2: Embed ==========
    print("\n=== Phase 2: Embedding ===")

    # Limit threads for CPU parts (set before the heavy imports below)
    os.environ.setdefault('OMP_NUM_THREADS', '4')
    os.environ.setdefault('MKL_NUM_THREADS', '4')

    # Load model
    from sentence_transformers import SentenceTransformer
    print(f"Loading {EMBEDDING_MODEL}...")
    model = SentenceTransformer(EMBEDDING_MODEL, device='cuda')

    # Load ChromaDB
    import chromadb
    os.makedirs(CHROMA_PATH, exist_ok=True)
    client = chromadb.PersistentClient(path=CHROMA_PATH)

    # get_or_create avoids a catch-all except that would hide real failures
    # and then try to create a collection that already exists.
    collection = client.get_or_create_collection(name="emails")
    existing_ids = _load_existing_ids(collection)
    if existing_ids:
        print(f"Existing collection: {len(existing_ids)} emails")

    # Filter out existing. IDs come from file stems, so files with the same
    # name in different folders collide; keep the first occurrence because
    # duplicate IDs inside one add() call are rejected by ChromaDB.
    pending = {}
    new_count = 0
    for email in emails:
        if email['id'] in existing_ids:
            continue
        new_count += 1
        pending.setdefault(email['id'], email)
    to_index = list(pending.values())
    if new_count > len(to_index):
        print(f"Skipped {new_count - len(to_index)} emails with duplicate IDs")
    print(f"Need to index: {len(to_index)} emails")

    if not to_index:
        print("All already indexed!")
        return

    # Batch embed
    for start in range(0, len(to_index), BATCH_SIZE):
        batch = to_index[start:start + BATCH_SIZE]
        if len(batch) >= BATCH_SIZE:
            print(f"Embedding {len(batch)}...")
        else:
            # Only the last chunk can be shorter than BATCH_SIZE.
            print(f"Final: {len(batch)}")
        _embed_and_store(model, collection, batch, BATCH_SIZE)

    print(f"\n✅ Done! {len(to_index)} emails indexed to {CHROMA_PATH}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
