#!/usr/bin/env python3
"""
Docling Document Converter Script (Enhanced)

Converts documents (PDF, DOCX, PPTX, HTML, etc.) to Markdown/JSON/HTML.
Features:
- Extracts figures as PNG to images directory
- Extracts photos as JPEG to images directory
- Tables output as Markdown format
- Large file handling (script generation or page range processing)
"""

import argparse
import json
import os
import re
import shutil
import sys
from pathlib import Path
from typing import Optional

try:
    from docling.document_converter import DocumentConverter
    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import (
        PdfPipelineOptions,
        EasyOcrOptions,
        TesseractOcrOptions,
        TableFormerMode,
    )
    from docling.document_converter import PdfFormatOption
    from docling_core.types.doc import ImageRefMode
    from docling_core.types.doc.document import DocItemLabel, PictureClassificationClass
except ImportError:
    print("Error: docling is not installed. Install with: pip install docling")
    sys.exit(1)


# Thresholds above which main() treats a source document as "large".
LARGE_FILE_SIZE_MB: int = 50  # compared with get_file_size_mb() (size in MiB)
LARGE_PAGE_COUNT: int = 50  # default page threshold for is_large_file()


def is_large_file(source: str, page_threshold: int = LARGE_PAGE_COUNT) -> tuple[bool, int]:
    """Report whether *source* is a PDF with more than *page_threshold* pages.

    The page count comes from PyMuPDF (``fitz``), an optional dependency.
    This is a best-effort probe. Non-PDF sources, a missing PyMuPDF install,
    or a file that cannot be opened all yield ``(False, 0)``, so the caller
    can fall back to other heuristics (e.g. file size).

    Args:
        source: Path to the document (str or path-like).
        page_threshold: Page count above which the file counts as large.

    Returns:
        ``(is_large, page_count)``. ``page_count`` is 0 when it could not
        be determined.
    """
    # Only PDFs are inspected; checking first avoids importing PyMuPDF needlessly.
    if not str(source).lower().endswith('.pdf'):
        return False, 0

    try:
        import fitz  # PyMuPDF for quick page count
    except ImportError:
        return False, 0

    try:
        # The context manager closes the document even if counting fails.
        with fitz.open(source) as doc:
            page_count = len(doc)
    except Exception:
        # Deliberately best-effort: an unreadable or remote source must not abort
        # the CLI here; the real conversion reports the actual error later.
        return False, 0

    return page_count > page_threshold, page_count


def get_file_size_mb(source: str) -> float:
    """Return the size of the local file *source* in MiB (1024 * 1024 bytes).

    Returns 0.0 when the size cannot be read, so callers treat "unknown" as
    "not large". Examples are a missing file, a URL, a permission error, or
    a path containing a NUL byte. Other errors (e.g. a non-path argument)
    propagate instead of being silently hidden.
    """
    bytes_per_mb = 1024 * 1024
    try:
        return Path(source).stat().st_size / bytes_per_mb
    except (OSError, ValueError):
        # OSError: file missing / not accessible (URLs end up here too).
        # ValueError: embedded NUL byte in the path.
        return 0.0


def create_converter(
    ocr: bool = False,
    ocr_engine: str = "easyocr",
    languages: Optional[list[str]] = None,
    table_mode: str = "fast",
    generate_images: bool = True,
) -> DocumentConverter:
    """Build a DocumentConverter with a PDF pipeline configured from the arguments.

    Args:
        ocr: Enable OCR in the PDF pipeline.
        ocr_engine: ``"tesseract"`` selects Tesseract; any other value selects EasyOCR.
        languages: EasyOCR language codes, defaulting to ``["en", "ja"]``.
            They are not passed to Tesseract.
        table_mode: ``"accurate"`` selects ``TableFormerMode.ACCURATE``;
            anything else selects ``FAST``.
        generate_images: Request picture images from the pipeline so they
            can be exported later.

    Returns:
        A DocumentConverter with custom options registered for
        ``InputFormat.PDF`` only.
    """
    ocr_languages = ["en", "ja"] if languages is None else languages

    options = PdfPipelineOptions()
    options.do_ocr = ocr
    options.do_table_structure = True
    options.generate_page_images = False  # full-page renders are not requested
    options.generate_picture_images = generate_images

    # OCR engine choice only matters when OCR is on; otherwise keep the defaults.
    if ocr and ocr_engine == "tesseract":
        options.ocr_options = TesseractOcrOptions()
    elif ocr:
        options.ocr_options = EasyOcrOptions(
            lang=ocr_languages,
            confidence_threshold=0.5,
        )

    table_options = options.table_structure_options
    table_options.mode = (
        TableFormerMode.ACCURATE if table_mode == "accurate" else TableFormerMode.FAST
    )
    table_options.do_cell_matching = True

    pdf_option = PdfFormatOption(pipeline_options=options)
    return DocumentConverter(format_options={InputFormat.PDF: pdf_option})


def classify_picture(picture) -> str:
    """Classify a docling picture item as ``'photo'`` or ``'figure'``.

    Best-effort and duck-typed, so it never raises. It collects any
    classification labels found on *picture* and returns ``'photo'`` when
    one of them names a photograph; otherwise it returns ``'figure'``.

    Two places are inspected:

    - a ``classification`` attribute (the field this script originally read);
    - entries of ``picture.annotations`` exposing ``predicted_classes``, whose
      first element is taken as the top prediction (assumed to be docling's
      picture-classification data; TODO confirm against the installed
      docling-core version).

    NOTE(review): create_converter() does not turn on picture classification,
    so these fields are likely empty and the result is usually ``'figure'``.
    """
    photo_labels = {"photo", "photograph", "natural_image"}

    labels = []
    legacy = getattr(picture, "classification", None)
    if legacy:
        labels.append(legacy)
    for annotation in getattr(picture, "annotations", None) or ():
        predicted = getattr(annotation, "predicted_classes", None)
        if predicted:
            labels.append(predicted[0])

    for label in labels:
        # A label may be a plain string, an enum member (.value / .name) or an
        # object carrying class_name; accept whichever textual form matches.
        candidates = (
            getattr(label, "class_name", None),
            getattr(label, "value", None),
            getattr(label, "name", None),
            label,
        )
        if any(isinstance(c, str) and c.lower() in photo_labels for c in candidates):
            return 'photo'
    return 'figure'


def _extract_pdf_pages(source: str, page_range: tuple[int, int], dest: Path) -> None:
    """Write pages START..END (1-indexed, inclusive) of the PDF *source* to *dest*.

    START is clamped up to 1 and END down to the document's page count.

    Raises:
        ImportError: If PyMuPDF (``fitz``) is not installed.
        ValueError: If the clamped range selects no pages.
    """
    import fitz  # PyMuPDF

    start, end = page_range
    # Context managers close both documents even if copying or saving fails.
    with fitz.open(source) as pdf_doc:
        total = len(pdf_doc)
        first = max(1, start)
        last = min(total, end)
        if first > last:
            raise ValueError(
                f"Page range {start}-{end} selects no pages "
                f"(document has {total} pages)"
            )
        with fitz.open() as new_doc:
            new_doc.insert_pdf(pdf_doc, from_page=first - 1, to_page=last - 1)
            new_doc.save(str(dest))


def _save_photo_as_jpeg(src: Path, dest: Path) -> None:
    """Save image *src* as a JPEG (quality 85) at *dest*, or copy it if Pillow can't."""
    try:
        from PIL import Image

        with Image.open(src) as img:
            # JPEG cannot store alpha channels or palettes.
            if img.mode in ('RGBA', 'LA', 'P'):
                img = img.convert('RGB')
            img.save(dest, 'JPEG', quality=85)
    except Exception:
        # Deliberate best-effort fallback (Pillow missing, unreadable image or a
        # mode JPEG cannot encode): keep the original bytes under the .jpg name.
        shutil.copy2(str(src), str(dest))


def _rewrite_image_links(
    md_content: str,
    artifacts_dir: Path,
    source_stem: str,
    rename_map: dict[str, str],
) -> str:
    """Return *md_content* with artifact image links pointed at ``images/<new name>``."""
    rel_artifacts = f"{source_stem}_artifacts"
    for old_name, new_name in rename_map.items():
        # Absolute form: ![alt](/abs/path/<stem>_artifacts/<old_name>)
        abs_pattern = re.escape(str(artifacts_dir / old_name))
        md_content = re.sub(
            rf'!\[([^\]]*)\]\({abs_pattern}\)',
            rf'![\1](images/{new_name})',
            md_content,
        )
        # Relative forms: (<stem>_artifacts/<old>) and (./<stem>_artifacts/<old>)
        for prefix in (f"]({rel_artifacts}/", f"](./{rel_artifacts}/"):
            md_content = md_content.replace(
                f"{prefix}{old_name})",
                f"](images/{new_name})",
            )
    return md_content


def _relocate_markdown_images(doc, md_file: Path, output_dir: Path, source_stem: str) -> None:
    """Move images saved next to *md_file* into ``output_dir/images`` and relink them.

    Figures are copied as ``figure_NNN.png`` and photos re-encoded as
    ``photo_NNN.jpg``. Links in *md_file* are rewritten to the new relative
    paths, and the ``<stem>_artifacts`` directory is removed afterwards.
    """
    artifacts_dir = output_dir / f"{source_stem}_artifacts"
    if not artifacts_dir.exists():
        return

    target_images_dir = output_dir / "images"
    target_images_dir.mkdir(exist_ok=True)

    image_files = sorted(
        f for f in artifacts_dir.iterdir()
        if f.is_file() and f.suffix.lower() in ('.png', '.jpg', '.jpeg', '.gif', '.webp')
    )

    rename_map: dict[str, str] = {}
    figure_count = 0
    photo_count = 0
    for idx, img_file in enumerate(image_files):
        # NOTE(review): assumes name-sorted artifact files line up with
        # doc.pictures order; extra images default to figures.
        pic_type = 'figure'
        if idx < len(doc.pictures):
            pic_type = classify_picture(doc.pictures[idx])

        if pic_type == 'photo':
            photo_count += 1
            new_name = f"photo_{photo_count:03d}.jpg"
            _save_photo_as_jpeg(img_file, target_images_dir / new_name)
        else:
            figure_count += 1
            new_name = f"figure_{figure_count:03d}.png"
            shutil.copy2(str(img_file), str(target_images_dir / new_name))

        rename_map[img_file.name] = new_name

    if rename_map:
        md_content = md_file.read_text(encoding='utf-8')
        md_content = _rewrite_image_links(md_content, artifacts_dir, source_stem, rename_map)
        md_file.write_text(md_content, encoding='utf-8')

    shutil.rmtree(artifacts_dir, ignore_errors=True)


def convert_document(
    source: str,
    output_dir: str | Path,
    output_format: str = "markdown",
    ocr: bool = False,
    ocr_engine: str = "easyocr",
    languages: Optional[list[str]] = None,
    table_mode: str = "fast",
    page_range: Optional[tuple[int, int]] = None,
) -> dict:
    """
    Convert a document to the specified format with enhanced output.

    For Markdown output, images are moved from docling's ``<stem>_artifacts``
    directory into ``<output_dir>/images``. Figures become
    ``figure_NNN.png``, photos become ``photo_NNN.jpg``, and the Markdown
    links are rewritten to match. JSON and HTML are saved with docling's
    REFERENCED image mode as-is; their images are not moved into ``images/``.

    Args:
        source: Path or URL to the source document
        output_dir: Directory to save output files (created if missing)
        output_format: Output format (markdown, json, html, all)
        ocr: Enable OCR for scanned documents
        ocr_engine: OCR engine to use (easyocr, tesseract)
        languages: List of language codes for OCR
        table_mode: Table extraction mode (fast, accurate)
        page_range: Optional (start, end) page range, 1-indexed and inclusive.
            It applies to PDF sources only; for other sources it is ignored
            with a warning. If PyMuPDF is missing or page extraction fails
            for another reason, the full document is processed with a warning.

    Returns:
        Dictionary with ``source``, ``output_dir``, ``files``, one path key
        per written format (``markdown``/``json``/``html``), ``images``,
        ``image_count`` and ``table_count``.

    Raises:
        ValueError: If *page_range* ends before it starts, or selects no
            pages of the PDF.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if page_range:
        start, end = page_range
        if end < start:
            raise ValueError(f"Invalid page range {start}-{end}: end is before start")

    converter = create_converter(
        ocr=ocr,
        ocr_engine=ocr_engine,
        languages=languages,
        table_mode=table_mode,
        generate_images=True,
    )

    source_path = Path(source)
    actual_source = source
    temp_pdf = output_dir / "_temp_pages.pdf"
    use_temp = bool(page_range) and source.lower().endswith('.pdf')
    if page_range and not use_temp:
        print("Warning: page ranges are only supported for PDF sources, "
              "processing full document", file=sys.stderr)

    # try/finally guarantees the temporary PDF is removed even if conversion fails.
    try:
        if use_temp:
            try:
                _extract_pdf_pages(source, page_range, temp_pdf)
                actual_source = str(temp_pdf)
            except ImportError:
                print("Warning: PyMuPDF not installed, processing full document", file=sys.stderr)
            except ValueError:
                raise  # the requested range is unusable: report it, don't convert everything
            except Exception as e:
                print(f"Warning: Page range extraction failed: {e}", file=sys.stderr)

        result = converter.convert(actual_source)
    finally:
        if use_temp:
            temp_pdf.unlink(missing_ok=True)

    doc = result.document

    # Output files are named after the original source, not the temp PDF.
    source_stem = source_path.stem

    results = {
        'source': str(source),
        'output_dir': str(output_dir),
        'files': [],
    }

    if output_format in ("markdown", "all"):
        md_file = output_dir / f"{source_stem}.md"
        # REFERENCED mode writes images to <stem>_artifacts and links them.
        doc.save_as_markdown(md_file, image_mode=ImageRefMode.REFERENCED)
        results['files'].append(str(md_file))
        results['markdown'] = str(md_file)
        _relocate_markdown_images(doc, md_file, output_dir, source_stem)

    if output_format in ("json", "all"):
        json_file = output_dir / f"{source_stem}.json"
        doc.save_as_json(json_file, image_mode=ImageRefMode.REFERENCED)
        results['files'].append(str(json_file))
        results['json'] = str(json_file)

    if output_format in ("html", "all"):
        html_file = output_dir / f"{source_stem}.html"
        doc.save_as_html(html_file, image_mode=ImageRefMode.REFERENCED)
        results['files'].append(str(html_file))
        results['html'] = str(html_file)

    images_dir = output_dir / "images"
    image_files = list(images_dir.glob("*.*")) if images_dir.exists() else []
    results['images'] = [str(f) for f in image_files]
    results['image_count'] = len(image_files)

    results['table_count'] = len(list(doc.tables))

    return results


def generate_batch_script(
    source: str,
    output_dir: str,
    page_count: int,
    batch_size: int = 20,
    **kwargs
) -> str:
    """Return the source of a Python script that converts *source* in page batches.

    The generated script adds this module's directory to ``sys.path``,
    imports ``convert_document`` from it, and writes each batch of
    *batch_size* pages to ``<output_dir>/pages_SSS_EEE`` in all formats.
    It exits with status 1 if any batch fails. Every user-supplied value is
    embedded with ``repr()``, so paths containing backslashes (Windows),
    quotes or non-ASCII characters still produce valid Python.

    Args:
        source: Document path/URL passed through to ``convert_document``.
        output_dir: Base directory for the per-batch sub-directories; relative
            paths resolve against the working directory of the generated script.
        page_count: Total number of pages; must be >= 1.
        batch_size: Pages per batch; must be >= 1.
        **kwargs: Optional ``ocr``, ``ocr_engine``, ``languages`` and
            ``table_mode`` forwarded to ``convert_document``.

    Returns:
        The script text.

    Raises:
        ValueError: If *page_count* or *batch_size* is less than 1.
    """
    if page_count < 1:
        raise ValueError(f"page_count must be >= 1, got {page_count}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    this_file = Path(__file__).resolve()
    # The generated script imports this module by file name; fall back to the
    # historical name if the file name is not a valid Python identifier.
    module_name = this_file.stem if this_file.stem.isidentifier() else "convert_document"

    ocr = kwargs.get('ocr', False)
    ocr_engine = kwargs.get('ocr_engine', 'easyocr')
    languages = kwargs.get('languages', ['en', 'ja'])
    table_mode = kwargs.get('table_mode', 'fast')

    script_content = f'''#!/usr/bin/env python3
# Auto-generated batch processing script for: {source!r}
# Total pages: {page_count}
# Batch size: {batch_size}

import sys
from pathlib import Path

sys.path.insert(0, {str(this_file.parent)!r})

from {module_name} import convert_document

source = {source!r}
output_base = Path({str(output_dir)!r})
total_pages = {page_count}
batch_size = {batch_size}

# Processing parameters
ocr = {ocr!r}
ocr_engine = {ocr_engine!r}
languages = {languages!r}
table_mode = {table_mode!r}


def main():
    all_results = []
    failed = 0

    for start in range(1, total_pages + 1, batch_size):
        end = min(start + batch_size - 1, total_pages)
        batch_dir = output_base / f"pages_{{start:03d}}_{{end:03d}}"

        print(f"Processing pages {{start}}-{{end}}...")

        try:
            result = convert_document(
                source=source,
                output_dir=str(batch_dir),
                output_format="all",
                ocr=ocr,
                ocr_engine=ocr_engine,
                languages=languages,
                table_mode=table_mode,
                page_range=(start, end),
            )
            all_results.append(result)
            print(f"  Saved to: {{batch_dir}}")
        except Exception as e:
            failed += 1
            print(f"  Error: {{e}}", file=sys.stderr)

    print(f"\\nProcessing complete. {{len(all_results)}} batches processed.")
    if failed:
        print(f"{{failed}} batch(es) failed.", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
'''
    return script_content


def _parse_page_range(text: str) -> tuple[int, int]:
    """Parse ``"START-END"`` or a single page ``"N"`` into a 1-indexed ``(start, end)``.

    Raises:
        ValueError: If *text* is malformed, a page number is < 1, or END < START.
    """
    start_text, sep, end_text = text.strip().partition('-')
    start = int(start_text)
    end = int(end_text) if sep else start
    if start < 1 or end < start:
        raise ValueError(f"invalid page range: {text}")
    return start, end


def main():
    """Command-line entry point.

    Parses arguments and reports large inputs; for those it can write a
    batch-processing script instead of converting. Otherwise it converts the
    document and prints a summary. Exits with status 1 on invalid input or
    conversion errors.
    """
    parser = argparse.ArgumentParser(
        description="Convert documents using Docling (Enhanced)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic conversion
  %(prog)s document.pdf -o ./output

  # With OCR for scanned documents
  %(prog)s scanned.pdf -o ./output --ocr

  # Markdown, JSON, and HTML output
  %(prog)s document.pdf -o ./output -f all

  # Process specific page range (a single page also works: --pages 5)
  %(prog)s large.pdf -o ./output --pages 1-20

  # Generate batch script for large file
  %(prog)s large.pdf -o ./output --generate-script
        """,
    )

    parser.add_argument("source", help="Path or URL to the source document")
    parser.add_argument(
        "-o", "--output",
        required=True,
        help="Output directory path",
    )
    parser.add_argument(
        "-f", "--format",
        choices=["markdown", "json", "html", "all"],
        default="markdown",
        help="Output format (default: markdown)",
    )
    parser.add_argument("--ocr", action="store_true", help="Enable OCR for scanned documents")
    parser.add_argument(
        "--ocr-engine",
        choices=["easyocr", "tesseract"],
        default="easyocr",
        help="OCR engine to use (default: easyocr)",
    )
    parser.add_argument(
        "--languages",
        nargs="+",
        default=["en", "ja"],
        help="Languages for OCR (default: en ja)",
    )
    parser.add_argument(
        "--table-mode",
        choices=["fast", "accurate"],
        default="fast",
        help="Table extraction mode (default: fast)",
    )
    parser.add_argument(
        "--pages",
        help="Page range to process, 1-indexed and inclusive (e.g., 1-20 or 5)",
    )
    parser.add_argument(
        "--generate-script",
        action="store_true",
        help="Generate batch processing script for large files",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=20,
        help="Pages per batch for script generation (default: 20)",
    )

    args = parser.parse_args()

    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")

    # Check if file is large (page count is 0 when it cannot be determined)
    is_large, page_count = is_large_file(args.source)
    file_size_mb = get_file_size_mb(args.source)

    if is_large or file_size_mb > LARGE_FILE_SIZE_MB:
        pages_text = f"{page_count} pages" if page_count else "unknown page count"
        print(f"Large file detected: {pages_text}, {file_size_mb:.1f} MB")

        if args.generate_script:
            # A script built for 0 pages would silently do nothing.
            if page_count < 1:
                print(
                    "Error: cannot generate a batch script because the page count "
                    "could not be determined (requires PyMuPDF and a readable PDF).",
                    file=sys.stderr,
                )
                sys.exit(1)

            script = generate_batch_script(
                source=args.source,
                output_dir=args.output,
                page_count=page_count,
                batch_size=args.batch_size,
                ocr=args.ocr,
                ocr_engine=args.ocr_engine,
                languages=args.languages,
                table_mode=args.table_mode,
            )
            script_path = Path(args.output) / "batch_process.py"
            script_path.parent.mkdir(parents=True, exist_ok=True)
            script_path.write_text(script, encoding="utf-8")
            print(f"Batch script generated: {script_path}")
            print(f"Run with: python {script_path}")
            return

        if not args.pages:
            print("Tip: Use --pages 1-20 to process specific pages,")
            print("     or --generate-script to create a batch processing script.")
    elif args.generate_script:
        print("Note: --generate-script is only used for large files; converting directly.")

    # Parse page range
    page_range = None
    if args.pages:
        try:
            page_range = _parse_page_range(args.pages)
        except ValueError:
            print(
                f"Error: Invalid page range: {args.pages} "
                "(expected START-END or a single page N, 1-indexed)",
                file=sys.stderr,
            )
            sys.exit(1)

    try:
        result = convert_document(
            source=args.source,
            output_dir=args.output,
            output_format=args.format,
            ocr=args.ocr,
            ocr_engine=args.ocr_engine,
            languages=args.languages,
            table_mode=args.table_mode,
            page_range=page_range,
        )

        print("\n=== Conversion Complete ===")
        print(f"Source: {result['source']}")
        print(f"Output directory: {result['output_dir']}")
        print("Files created:")
        for f in result['files']:
            print(f"  - {f}")
        print(f"Images extracted: {result['image_count']}")
        print(f"Tables found: {result['table_count']}")

    except Exception as e:
        # Top-level boundary: report, show the traceback, and exit non-zero.
        print(f"Error: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    # Run the CLI only when executed directly; generated batch scripts import
    # this module for convert_document() and must not trigger argument parsing.
    sys.exit(main())
