#!/usr/bin/env python3
"""
Business RAG Pipeline v2 - With Attachment Support
Parse + Embed + Classify + Index  (emails + all attachments)

Usage:
    python3 business_rag.py --emails /path/to/eml --docs /path/to/docs --output /output
"""

import os
import re
import json
import sys
import email
from pathlib import Path
from typing import List, Dict, Tuple, Optional
from email.policy import default
import base64

# ============== CONFIG ==============
# Defaults; main() overrides OUTPUT_DIR, CHROMA_PATH and BATCH_SIZE from the CLI.
OUTPUT_DIR = "/output"
CHROMA_PATH = "/output/chroma"
BATCH_SIZE = 128  # Smaller batch for attachment processing
EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"

# Categories: category name -> keywords (English and Chinese) looked up in the
# lowercased subject/body/sender text. Each list entry contributes one point
# to its category's score, so a keyword must appear only once per list;
# a repeated entry would silently double its weight.
CATEGORIES = {
    'customer/order': ['order', 'invoice', 'payment', 'quotation', 'quote', '客户', '订单', '发票', '报价', 'PO', 'purchase order', '账单'],
    'customer/inquiry': ['inquiry', 'enquiry', 'question', 'request for', 'information', '咨询', '询问'],
    'supplier': ['supplier', 'vendor', 'procurement', 'vendor contact', '供应商', '供货商', '采购'],
    'shipment': ['shipment', 'shipping', 'delivery', 'tracking', 'freight', 'cargo', '发货', '物流', '货运', '运输', 'bol', 'bill of lading', '提单'],
    'quotation': ['quotation', 'quote', 'pricing', '报价单', '报价'],
    'contract': ['contract', 'agreement', 'terms', '条款', '合同', '协议'],
    'complaint': ['complaint', 'issue', 'problem', 'defect', 'quality', '投诉', '问题', '质量'],
    'financial': ['bank', 'transfer', 'wire', '汇率', 'banking', '财务', '银行', '转账'],
    'shipping_doc': ['packing list', 'packinglist', '装箱单', '提单', 'lading'],
}

# ============== ATTACHMENT PARSING ==============

def parse_pdf(filepath: str) -> str:
    """Return the stripped text of the PDF at *filepath*.

    pdfminer is imported lazily. Any failure (missing library, unreadable
    file, parser error) is reported as a "[PDF parse error: ...]" string
    instead of raising.
    """
    try:
        from pdfminer.high_level import extract_text as pdf_to_text
        extracted = pdf_to_text(filepath).strip()
    except Exception as exc:
        return f"[PDF parse error: {exc}]"
    return extracted

def parse_docx(filepath: str) -> str:
    """Return the paragraph text of a Word document, one paragraph per line.

    python-docx is imported lazily. Any failure is reported as a
    "[DOCX parse error: ...]" string instead of raising.
    """
    try:
        from docx import Document
        paragraphs = Document(filepath).paragraphs
        lines = [para.text for para in paragraphs]
        joined = "\n".join(lines)
    except Exception as exc:
        return f"[DOCX parse error: {exc}]"
    return joined

def parse_xlsx(filepath: str) -> str:
    """Extract text from an Excel workbook.

    Reads at most the first 100 rows of every sheet, renders each row as its
    cell values joined by single spaces (empty cells become empty strings),
    skips rows that are entirely blank, and returns at most 500 rows joined
    by newlines.

    Args:
        filepath: Path of the workbook to read.

    Returns:
        The extracted text, or a "[XLSX parse error: ...]" string if the
        library is missing or the file cannot be read.
    """
    try:
        import openpyxl
        text_parts = []
        wb = openpyxl.load_workbook(filepath, read_only=True, data_only=True)
        try:
            for sheet in wb.sheetnames:
                ws = wb[sheet]
                for row in ws.iter_rows(max_row=100, values_only=True):  # Limit rows
                    # Compare against None rather than truthiness so that
                    # real values such as 0, 0.0 and False are kept.
                    row_text = " ".join("" if c is None else str(c) for c in row)
                    if row_text.strip():
                        text_parts.append(row_text)
        finally:
            # Read-only workbooks keep the file handle open until closed.
            wb.close()
        return "\n".join(text_parts[:500])  # Limit output
    except Exception as e:
        return f"[XLSX parse error: {e}]"

def parse_pptx(filepath: str) -> str:
    """Return the text of up to 500 text-bearing shapes in a presentation.

    Shapes are visited slide by slide; every shape exposing a ``text``
    attribute contributes one entry. python-pptx is imported lazily and any
    failure is reported as a "[PPTX parse error: ...]" string.
    """
    try:
        from pptx import Presentation
        deck = Presentation(filepath)
        shape_texts = [
            shape.text
            for slide in deck.slides
            for shape in slide.shapes
            if hasattr(shape, "text")
        ]
        limited = shape_texts[:500]
        return "\n".join(limited)
    except Exception as exc:
        return f"[PPTX parse error: {exc}]"

def parse_attachment(filepath: str, mime_type: str = "") -> str:
    """Parse any supported attachment and return its text.

    The parser is chosen from the file extension (case-insensitive):
    PDF, Word, Excel and PowerPoint files go to their dedicated parsers;
    plain-text style files are read directly and truncated to 5000
    characters.

    Args:
        filepath: Path of the attachment on disk.
        mime_type: Accepted for interface compatibility; currently unused.

    Returns:
        The extracted text, a "[... error: ...]" marker when reading fails,
        or "[Unsupported attachment: <ext>]" for unknown extensions.
    """
    ext = Path(filepath).suffix.lower()

    parsers = {
        '.pdf': parse_pdf,
        '.docx': parse_docx,
        '.doc': parse_docx,
        '.xlsx': parse_xlsx,
        '.xls': parse_xlsx,
        '.pptx': parse_pptx,
        '.ppt': parse_pptx,
    }

    parser = parsers.get(ext)
    if parser is not None:
        return parser(filepath)

    # Try text files
    if ext in ('.txt', '.csv', '.json', '.xml', '.html', '.htm'):
        try:
            with open(filepath, 'r', errors='ignore') as f:
                # Read only what is kept instead of loading the whole file.
                return f.read(5000)
        except (OSError, ValueError) as e:
            # Report the real failure rather than claiming the (supported)
            # extension is unsupported.
            return f"[Text read error: {e}]"

    return f"[Unsupported attachment: {ext}]"

# ============== EMAIL PARSING ==============

def clean_text(text: str) -> str:
    """Normalise whitespace and cap the length of *text*.

    Every run of whitespace is collapsed to a single space, the result is
    stripped, and at most the first 8000 characters are returned. Falsy
    input (``None``, ``""``) yields an empty string.
    """
    if text:
        collapsed = re.sub(r'\s+', ' ', text)
        trimmed = collapsed.strip()
        return trimmed[:8000]
    return ""

def _keyword_matches(keyword: str, text: str) -> bool:
    """Return True if lowercase *keyword* occurs in lowercase *text*.

    Very short alphanumeric keywords (up to three characters, e.g. "po",
    "bol") must appear as a whole token; as plain substrings they would hit
    unrelated words such as "report", "support" or "symbol". All other
    keywords, including CJK ones (which have no word separators), use
    substring matching.
    """
    if re.fullmatch(r'[a-z0-9]{1,3}', keyword):
        pattern = rf'(?<![a-z0-9]){re.escape(keyword)}(?![a-z0-9])'
        return re.search(pattern, text) is not None
    return keyword in text


def classify_text(subject: str, body: str, sender: str) -> Tuple[str, List[str], float]:
    """Classify with confidence score based on keyword matches.

    The subject, body and sender are concatenated and lowercased, then every
    category in CATEGORIES scores one point per distinct keyword found.

    Args:
        subject: Message subject line.
        body: Message body text.
        sender: Sender header value.

    Returns:
        A ``(category, matched_keywords, confidence)`` tuple. ``category`` is
        the highest-scoring category (the first one in CATEGORIES order on a
        tie), ``matched_keywords`` are that category's hits, and
        ``confidence`` is its share of all points scored across categories.
        When nothing matches, ``('uncategorized', [], 0.0)`` is returned.
    """
    text = f"{subject} {body} {sender}".lower()
    scores = {}

    for category, keywords in CATEGORIES.items():
        # dict.fromkeys de-duplicates while keeping order, so a keyword
        # listed twice cannot be counted twice.
        matched = [
            kw for kw in dict.fromkeys(keywords)
            if _keyword_matches(kw.lower(), text)
        ]
        if matched:
            scores[category] = (len(matched), matched)

    if not scores:
        return 'uncategorized', [], 0.0

    best_cat = max(scores, key=lambda k: scores[k][0])
    total = sum(points for points, _ in scores.values())
    best_points, best_keywords = scores[best_cat]
    return best_cat, best_keywords, best_points / total

def _extract_email_body(msg) -> str:
    """Return the cleaned body text of a parsed email message.

    For multipart messages the first non-attachment ``text/plain`` part wins;
    an earlier ``text/html`` part is used only if no plain part is found.
    For single-part text messages the decoded content is used.
    """
    if not msg.is_multipart():
        try:
            if msg.get_content_maintype() == 'text':
                # get_content() undoes base64 / quoted-printable and applies
                # the declared charset; get_payload() would return the raw
                # encoded text.
                return clean_text(msg.get_content())
        except Exception:
            # Best effort: undecodable content falls back to the raw payload.
            pass
        return clean_text(str(msg.get_payload()))

    body = ''
    for part in msg.walk():
        disposition = part.get_content_disposition()
        if disposition and 'attachment' in disposition:
            # An attached .txt/.html file is not the message body.
            continue
        content_type = part.get_content_type()
        if content_type == 'text/plain':
            try:
                body = clean_text(part.get_content())
            except Exception:
                body = clean_text(str(part.get_payload()))
            break
        if content_type == 'text/html' and not body:
            try:
                body = clean_text(part.get_content())
            except Exception:
                # Best effort: keep looking for another usable part.
                continue
    return body


def _save_attachments(msg, stem: str, temp_dir: str) -> List[Dict]:
    """Write every named attachment of *msg* into *temp_dir*.

    Returns one dict per saved file with keys ``id``, ``filename`` (as
    declared in the email), ``path`` (where it was written) and ``size``.
    """
    attachments = []
    for part in msg.walk():
        content_disposition = part.get_content_disposition()
        if not (content_disposition and 'attachment' in content_disposition):
            continue
        filename = part.get_filename()
        if not filename:
            continue

        att_id = f"{stem}_{len(attachments)}"
        # The filename comes from the (untrusted) email: keep only its last
        # path component so it cannot point outside temp_dir.
        safe_name = os.path.basename(filename.replace('\\', '/')).replace('\x00', '')
        att_filename = f"{att_id}_{safe_name or 'attachment'}"
        att_path = os.path.join(temp_dir, att_filename)

        try:
            payload = part.get_payload(decode=True)
            if payload:
                with open(att_path, 'wb') as out:
                    out.write(payload)
                attachments.append({
                    'id': att_id,
                    'filename': filename,
                    'path': att_path,
                    'size': len(payload)
                })
        except Exception as e:
            print(f"    Failed to extract {filename}: {e}")
    return attachments


def parse_eml_with_attachments(filepath: Path, temp_dir: str) -> List[Dict]:
    """Parse .eml file and extract attachments.

    Args:
        filepath: Path of the ``.eml`` file.
        temp_dir: Existing directory where attachment payloads are written.

    Returns:
        A list whose first element is the email item, followed by one item
        per saved attachment. Attachment items carry the parsed attachment
        text as ``body`` and inherit sender, recipient, date, category,
        keywords and confidence from the email. If the file cannot be
        parsed, the error is printed and an empty list is returned.
    """
    items = []

    try:
        with open(filepath, 'rb') as f:
            msg = email.message_from_binary_file(f, policy=default)

        # Parse main email
        email_data = {
            'id': filepath.stem,
            'type': 'email',
            'folder': str(filepath.parent.name),
            'sender': clean_text(msg.get('From', '')),
            'recipient': clean_text(msg.get('To', '')),
            'subject': clean_text(msg.get('Subject', '')),
            'date': clean_text(msg.get('Date', '')),
            'body': _extract_email_body(msg),
            'category': 'uncategorized',
            'keywords': [],
            'confidence': 0.0,
            'source': str(filepath),
            'attachments': []
        }

        # Extract attachments
        attachments = _save_attachments(msg, filepath.stem, temp_dir)
        email_data['attachments'] = attachments

        # Classify
        email_data['category'], email_data['keywords'], email_data['confidence'] = classify_text(
            email_data['subject'], email_data['body'], email_data['sender']
        )

        items.append(email_data)

        # Create separate items for parsed attachments
        for att in attachments:
            att_text = parse_attachment(att['path'])
            att_data = {
                'id': att['id'],
                'type': f'attachment_{Path(att["filename"]).suffix.lower().replace(".", "")}',
                'folder': str(filepath.parent.name),
                'sender': email_data['sender'],
                'recipient': email_data['recipient'],
                'subject': f"[{att['filename']}] {email_data['subject']}",
                'date': email_data['date'],
                'body': att_text,
                'category': email_data['category'],  # Inherit email category
                'keywords': email_data['keywords'],
                'confidence': email_data['confidence'],
                'source': str(filepath),
                'parent_email': filepath.stem,
                'attachments': []
            }
            items.append(att_data)

        return items

    except Exception as e:
        print(f"Error parsing {filepath}: {e}")
        return []

def parse_doc(filepath: Path) -> Optional[Dict]:
    """Parse standalone document (PDF, DOCX, etc.).

    Office/PDF files are routed to their dedicated parser by extension;
    ``.txt``, ``.md`` and ``.csv`` files are read as text.

    Args:
        filepath: Path of the document.

    Returns:
        An item dict with the cleaned text as ``body`` and the file name as
        ``subject`` (category is left as 'uncategorized'), or ``None`` when
        the document yields no text or reading it fails. Failures are
        printed so skipped files are visible in the run log.
    """
    try:
        ext = filepath.suffix.lower()

        parsers = {
            '.pdf': parse_pdf,
            '.docx': parse_docx,
            '.doc': parse_docx,
            '.xlsx': parse_xlsx,
            '.xls': parse_xlsx,
            '.pptx': parse_pptx,
            '.ppt': parse_pptx,
        }

        parser = parsers.get(ext)
        text = parser(str(filepath)) if parser is not None else ""

        if not text and ext in ('.txt', '.md', '.csv'):
            with open(filepath, 'r', errors='ignore') as f:
                text = f.read()

        if not text.strip():
            return None

        return {
            'id': filepath.stem,
            'type': f'document_{ext.replace(".", "")}',
            'folder': str(filepath.parent.name),
            'sender': '',
            'recipient': '',
            'subject': filepath.name,
            'date': '',
            'body': clean_text(text),
            'category': 'uncategorized',
            'keywords': [],
            'confidence': 0.0,
            'source': str(filepath),
            'attachments': []
        }
    except Exception as e:
        # Same reporting style as the .eml parser; previously the file was
        # dropped without any trace.
        print(f"Error parsing {filepath}: {e}")
        return None

# ============== MAIN ==============

def main():
    """Run the pipeline: parse inputs, dump JSONL, embed and index into ChromaDB.

    Command-line options: ``--emails`` (folder searched recursively for
    ``.eml`` files), ``--docs`` (folder searched recursively for standalone
    documents), ``--output`` (output directory, default ``/output``) and
    ``--batch-size`` (embedding batch size, default 128).

    Side effects: rebinds the module globals BATCH_SIZE, OUTPUT_DIR and
    CHROMA_PATH; writes ``business_data.jsonl``, extracted attachments under
    ``temp_attachments`` and the Chroma store under ``chroma`` inside the
    output directory; prints progress and a summary.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Business RAG Pipeline v2')
    parser.add_argument('--emails', help='Path to .eml folder')
    parser.add_argument('--docs', help='Path to document folder')
    parser.add_argument('--output', default='/output', help='Output directory')
    parser.add_argument('--batch-size', type=int, default=128)
    args = parser.parse_args()

    if args.batch_size < 1:
        parser.error('--batch-size must be a positive integer')

    global BATCH_SIZE, OUTPUT_DIR, CHROMA_PATH
    BATCH_SIZE = args.batch_size
    OUTPUT_DIR = args.output
    CHROMA_PATH = os.path.join(OUTPUT_DIR, 'chroma')

    print("=" * 60)
    print("BUSINESS RAG PIPELINE v2 - WITH ATTACHMENTS")
    print("=" * 60)

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    temp_dir = os.path.join(OUTPUT_DIR, 'temp_attachments')
    os.makedirs(temp_dir, exist_ok=True)

    # ========== PARSE ==========
    print("\n[1/4] Parsing files...")
    items = []
    stats = {'emails': 0, 'attachments': 0, 'documents': 0}

    if args.emails:
        eml_files = list(Path(args.emails).rglob('*.eml'))
        print(f"  Found {len(eml_files)} .eml files")

        for i, f in enumerate(eml_files):
            if i % 1000 == 0:
                print(f"  Processing [{i}/{len(eml_files)}]...")

            parsed = parse_eml_with_attachments(f, temp_dir)
            for item in parsed:
                if item['type'] == 'email':
                    stats['emails'] += 1
                else:
                    stats['attachments'] += 1
            items.extend(parsed)

    if args.docs:
        doc_files = []
        for ext in ['.pdf', '.docx', '.doc', '.xlsx', '.xls', '.pptx', '.ppt', '.txt', '.md', '.csv']:
            doc_files.extend(Path(args.docs).rglob(f'*{ext}'))
        print(f"  Found {len(doc_files)} standalone documents")

        for f in doc_files:
            item = parse_doc(f)
            if item:
                stats['documents'] += 1
                items.append(item)

    print(f"\n  Parsed: {stats['emails']} emails, {stats['attachments']} attachments, {stats['documents']} docs")
    print(f"  Total items: {len(items)}")

    # Save JSONL
    jsonl_path = os.path.join(OUTPUT_DIR, "business_data.jsonl")
    # ensure_ascii=False emits raw non-ASCII text (e.g. Chinese), so the file
    # encoding must be pinned instead of depending on the locale default.
    with open(jsonl_path, 'w', encoding='utf-8') as f:
        for item in items:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
    print(f"  Saved to {jsonl_path}")

    # ========== EMBED + INDEX ==========
    print("\n[2/4] Loading embedding model...")
    os.environ.setdefault('OMP_NUM_THREADS', '4')
    os.environ.setdefault('MKL_NUM_THREADS', '4')

    from sentence_transformers import SentenceTransformer
    import chromadb
    import torch  # already required by sentence_transformers

    # Use the GPU when one is present, but do not crash on CPU-only hosts.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = SentenceTransformer(EMBEDDING_MODEL, device=device)
    print(f"  Model: {EMBEDDING_MODEL}")
    print("  Device: CUDA (GPU)" if device == 'cuda' else "  Device: CPU")

    # ChromaDB
    os.makedirs(CHROMA_PATH, exist_ok=True)
    client = chromadb.PersistentClient(path=CHROMA_PATH)

    try:
        collection = client.get_collection("business")
    except Exception:
        # The "collection not found" exception type differs between chromadb
        # releases, so any lookup failure is treated as "create it".
        collection = client.create_collection(name="business", metadata={"description": "Business emails, attachments, documents"})
        existing_ids = set()
        print("  Created new collection")
    else:
        # include=[] returns only the ids, with no row limit, so large
        # collections are neither truncated nor loaded with their documents.
        existing_ids = set(collection.get(include=[])['ids'])
        print(f"  Existing collection: {len(existing_ids)} items")

    # Skip items already stored, and keep only the first item for each id:
    # ids are file stems, so equal names in different folders can collide and
    # duplicate ids within one add() call are rejected by Chroma.
    to_index = []
    seen_ids = set()
    duplicate_count = 0
    for item in items:
        item_id = item['id']
        if item_id in existing_ids:
            continue
        if item_id in seen_ids:
            duplicate_count += 1
            continue
        seen_ids.add(item_id)
        to_index.append(item)
    if duplicate_count:
        print(f"  Warning: skipped {duplicate_count} items with duplicate ids")
    print(f"  Indexing {len(to_index)} new items...")

    # ========== BATCH EMBED ==========
    print("\n[3/4] Embedding & indexing...")

    total = len(to_index)
    for start in range(0, total, BATCH_SIZE):
        batch = to_index[start:start + BATCH_SIZE]
        batch_ids, batch_texts, batch_metas = [], [], []

        for item in batch:
            # Build searchable text
            if item['type'] == 'email':
                text = f"From: {item['sender']}\nTo: {item['recipient']}\nSubject: {item['subject']}\n{item['body'][:700]}"
            else:
                text = f"Document: {item['subject']}\nType: {item['type']}\n{item['body'][:1000]}"

            batch_ids.append(item['id'])
            batch_texts.append(text)
            batch_metas.append({
                'type': item['type'],
                'category': item['category'],
                'sender': item['sender'],
                'subject': item['subject'],
                'date': item['date'],
                'confidence': item['confidence'],
                'source': item['source']
            })

        if len(batch) == BATCH_SIZE:
            print(f"  Batch {start + len(batch)}/{total}...")
        else:
            # Only the last slice can be shorter than BATCH_SIZE.
            print(f"  Final batch: {len(batch_ids)}")
        emb = model.encode(batch_texts, batch_size=BATCH_SIZE)
        collection.add(documents=batch_texts, metadatas=batch_metas, ids=batch_ids, embeddings=emb.tolist())

    # ========== SUMMARY ==========
    print("\n[4/4] Done!")

    cats = {}
    for item in items:
        cats[item['category']] = cats.get(item['category'], 0) + 1

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    # Parsed and newly indexed counts differ when items were already stored.
    print(f"Total items parsed: {len(items)}")
    print(f"Newly indexed: {len(to_index)}")
    print(f"  Emails: {stats['emails']}")
    print(f"  Attachments extracted: {stats['attachments']}")
    print(f"  Standalone docs: {stats['documents']}")
    print(f"\nCategories:")
    for cat, count in sorted(cats.items(), key=lambda x: -x[1]):
        print(f"  {cat}: {count}")
    print(f"\nIndexed to: {CHROMA_PATH}")
    print(f"Temp attachments: {temp_dir}")

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
