#!/usr/bin/env python3
"""PDF 解析工具 - 支持本地文件/URL、分批处理、图片内嵌/独立模式和翻译"""

import argparse
import requests
import subprocess
import math
import time
import os
import re
import base64
import zipfile
import shutil

# MinerU file-parse endpoint. Set PDF_PARSE_API_URL to point at another deployment.
API_URL = os.environ.get("PDF_PARSE_API_URL", "http://llm.bnuzh.edu.cn:8880/file_parse")
BATCH_SIZE = 50   # default number of pages sent per parse request
MAX_RETRIES = 3   # default number of attempts per parse request
RETRY_DELAY = 5   # seconds; base back-off delay, doubled after every failed attempt

# Translation API configuration (chat-completions style endpoint).
# Both values can be overridden through environment variables of the same name.
TRANSLATION_API_URL = os.environ.get(
    "TRANSLATION_API_URL",
    "http://llm.bnuzh.edu.cn:9180/v1/chat/completions",
)
# SECURITY NOTE(review): a real credential is committed here. It is kept only as
# a fallback so existing setups keep working; rotate it and supply the key via
# the TRANSLATION_API_KEY environment variable instead.
TRANSLATION_API_KEY = os.environ.get(
    "TRANSLATION_API_KEY",
    "sk-AeGYLnhYWUEifWq7fedjCioDZssxEVraGzEcyhej9gQ1um6s",
)

def download_pdf_from_url(url, output_path=None):
    """Download a PDF from ``url`` and return the path it was saved to.

    Args:
        url: HTTP(S) address of the PDF.
        output_path: Destination file. When omitted, the name is taken from the
            ``Content-Disposition`` header, else from the last URL path segment
            (``download.pdf`` if that segment does not end in ``.pdf``), and the
            file is written to the current working directory.

    Returns:
        The path of the downloaded file.

    Raises:
        Exception: On any network, HTTP-status or file-system failure; the
            original error is attached as ``__cause__``.
    """
    print(f"📥 正在下载: {url}")

    try:
        # The context manager releases the streamed connection even on failure.
        with requests.get(url, timeout=120, stream=True) as response:
            response.raise_for_status()

            if not output_path:
                filename = ''
                disposition = response.headers.get('Content-Disposition') or ''
                match = re.search(r'filename="?([^";]+)"?', disposition)
                if match:
                    # The header is server-controlled: keep only the final path
                    # component so it cannot write outside the working directory.
                    candidate = match.group(1).strip().replace('\\', '/')
                    filename = os.path.basename(candidate)
                if filename in ('', '.', '..'):
                    filename = os.path.basename(url.split('?')[0])
                    if not filename.endswith('.pdf'):
                        filename = 'download.pdf'
                output_path = filename

            with open(output_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
    except Exception as e:
        raise Exception(f"下载失败: {e}") from e

    print(f"✅ 下载完成: {output_path}")
    return output_path

def get_pdf_page_count(file_path):
    """Return the page count reported by the ``pdfinfo`` command-line tool.

    Args:
        file_path: Path of the PDF to inspect.

    Returns:
        The number of pages, or 0 when it cannot be determined (``pdfinfo`` is
        not installed, fails, or prints no parsable ``Pages:`` line).
    """
    try:
        result = subprocess.run(
            ['pdfinfo', file_path],
            capture_output=True,
            text=True,
            # PDF metadata is not guaranteed to be valid UTF-8; never fail on it.
            errors='replace',
        )
    except OSError as e:
        # Typically FileNotFoundError: pdfinfo is not installed.
        print(f"⚠️ 无法运行 pdfinfo: {e}")
        return 0

    for line in result.stdout.splitlines():
        if line.startswith('Pages:'):
            try:
                return int(line.split(':', 1)[1].strip())
            except ValueError:
                return 0
    return 0

def detect_language(text):
    """Classify ``text`` as Chinese or not by its share of CJK ideographs.

    Returns:
        ``'zh'`` when more than 30% of all characters fall in the range
        U+4E00..U+9FFF, otherwise ``'en'`` (also for empty text).
    """
    han_matches = re.findall(r'[\u4e00-\u9fff]', text)
    length = len(text)

    if length == 0:
        return 'en'

    if len(han_matches) / length > 0.3:
        return 'zh'
    return 'en'

def _split_markdown_for_translation(text, chunk_size):
    """Pack blank-line separated paragraphs into chunks of at most ``chunk_size`` characters."""
    chunks = []
    current = ''
    for paragraph in re.split(r'\n\s*\n', text):
        if not paragraph.strip():
            continue
        # Only a single paragraph longer than the limit is cut at a fixed width.
        pieces = [paragraph[i:i + chunk_size]
                  for i in range(0, len(paragraph), chunk_size)]
        for piece in pieces:
            candidate = f"{current}\n\n{piece}" if current else piece
            if len(candidate) <= chunk_size:
                current = candidate
            else:
                chunks.append(current)
                current = piece
    if current:
        chunks.append(current)
    return chunks


def translate_to_chinese(text):
    """Translate English Markdown to Chinese through the translation API.

    The text is split on paragraph boundaries (blank lines) into chunks of at
    most 2000 characters, so words and Markdown constructs are not cut in half
    and re-joining the chunks with blank lines keeps the paragraph structure.
    Translation is best effort: a chunk whose request fails is kept untranslated.

    Args:
        text: Markdown text to translate.

    Returns:
        The translated chunks joined by blank lines; an empty string when
        ``text`` contains no non-blank paragraph.
    """
    print("🔄 正在翻译为中文...")

    chunk_size = 2000
    chunks = _split_markdown_for_translation(text, chunk_size)
    total = len(chunks)
    translated_chunks = []

    for index, chunk in enumerate(chunks, start=1):
        print(f"  翻译段落 {index}/{total}...")

        prompt = f"""请将以下英文内容翻译成中文，保持 Markdown 格式：

{chunk}

要求：
- 保持原有的 Markdown 格式（标题、列表、代码块等）
- 翻译准确、流畅
- 专有名词保持英文但加括号说明中文
- 不要添加解释性内容"""

        try:
            response = requests.post(
                TRANSLATION_API_URL,
                headers={
                    "Authorization": f"Bearer {TRANSLATION_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": "GLM-5",
                    "messages": [{"role": "user", "content": prompt}],
                    "temperature": 0.7
                },
                timeout=120
            )

            if response.status_code == 200:
                translated = response.json()['choices'][0]['message']['content']
                if not isinstance(translated, str):
                    raise ValueError("unexpected translation payload")
                translated_chunks.append(translated)
            else:
                print(f"  ⚠️ 翻译失败: {response.status_code}")
                translated_chunks.append(chunk)

        except (requests.RequestException, ValueError, KeyError,
                IndexError, TypeError) as e:
            # Network errors and malformed responses fall back to the source text.
            print(f"  ⚠️ 翻译错误: {e}")
            translated_chunks.append(chunk)

    return '\n\n'.join(translated_chunks)

def parse_pdf_with_retry(file_path, start_page=None, end_page=None, 
                          return_images=True, retries=MAX_RETRIES):
    """Send a PDF (or a page range of it) to the parse API, retrying on failure.

    Args:
        file_path: Local PDF path, or an http(s) URL that is first downloaded
            into a private temporary directory (removed before returning).
        start_page: Value sent as ``start_page_id``; omitted when ``None``.
        end_page: Value sent as ``end_page_id``; omitted when ``None``.
        return_images: Ask the API for images and return them.
        retries: Maximum number of attempts. Failed attempts are followed by a
            back-off of ``RETRY_DELAY * 2 ** attempt`` seconds.

    Returns:
        ``(md_content, images)``. ``images`` is ``{}`` when ``return_images``
        is false. ``(None, None)`` when the API never reported a usable status
        (or ``retries`` is 0).

    Raises:
        Exception: Whatever error the final attempt raised (network, decoding,
            polling timeout, ...), or the download error for URL inputs.
    """
    import tempfile

    def first_file_result(payload):
        # Only one file is uploaded, so the first entry is the one we sent.
        file_results = list(payload['results'].values())[0]
        md_content = file_results.get('md_content', '')
        images = file_results.get('images', {}) if return_images else {}
        return md_content, images

    is_url = file_path.startswith(('http://', 'https://'))
    temp_dir = None

    try:
        if is_url:
            # Download into a private directory so that neither the download
            # nor the cleanup below can touch a same-named file of the user.
            temp_dir = tempfile.mkdtemp(prefix='pdf_parse_')
            download_name = os.path.basename(file_path.split('?')[0])
            if not download_name.lower().endswith('.pdf'):
                download_name = 'download.pdf'
            file_path = download_pdf_from_url(
                file_path, os.path.join(temp_dir, download_name))

        data = {'return_md': 'true'}
        if start_page is not None:
            data['start_page_id'] = str(start_page)
        if end_page is not None:
            data['end_page_id'] = str(end_page)
        if return_images:
            data['return_images'] = 'true'

        for attempt in range(retries):
            try:
                with open(file_path, 'rb') as f:
                    response = requests.post(
                        API_URL, files={'files': f}, data=data, timeout=600)
                result = response.json()
                status = result.get('status')

                if status == 'completed':
                    return first_file_result(result)

                if status == 'pending':
                    status_url = result.get('status_url')
                    if not status_url:
                        raise Exception("响应缺少 status_url")
                    for _ in range(30):
                        time.sleep(2)
                        # A timeout keeps a stalled server from blocking forever.
                        status_resp = requests.get(status_url, timeout=60).json()
                        if status_resp.get('status') == 'completed':
                            return first_file_result(status_resp)
                    raise Exception("轮询超时")

                # Any other status counts as a failed attempt: report it and
                # back off instead of retrying silently and immediately.
                print(f"  尝试 {attempt + 1}/{retries} 失败: 未知状态 {status!r}")

            except Exception as e:
                print(f"  尝试 {attempt + 1}/{retries} 失败: {e}")
                if attempt >= retries - 1:
                    raise

            if attempt < retries - 1:
                time.sleep(RETRY_DELAY * (2 ** attempt))

        return None, None

    finally:
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)

def embed_images_in_markdown(md_content, images_dict):
    """Inline images into the Markdown as base64 data URIs (no files are written).

    For every image whose value starts with ``data:image/``, occurrences of
    ``images/<name>`` are replaced by the data URI; if that form does not
    occur, occurrences of the bare ``<name>`` are replaced instead.

    Args:
        md_content: Markdown text.
        images_dict: ``{"name.jpg": "data:image/jpeg;base64,xxx", ...}``

    Returns:
        The Markdown with the matched image references replaced.
    """
    if not images_dict:
        return md_content

    embedded = 0

    for name, data_uri in images_dict.items():
        if not data_uri.startswith('data:image/'):
            continue

        # Prefer the "images/<name>" form, fall back to the bare name.
        prefixed = f"images/{name}"
        if prefixed in md_content:
            target = prefixed
        elif name in md_content:
            target = name
        else:
            continue

        md_content = md_content.replace(target, data_uri)
        embedded += 1

    if embedded > 0:
        print(f"  🔗 已内嵌 {embedded} 张图片到 Markdown")

    return md_content

def extract_and_save_images(md_content, images_dict, output_dir="images"):
    """Write the API's base64 images to ``output_dir`` and repoint the Markdown at them.

    Args:
        md_content: Markdown text.
        images_dict: ``{"name.jpg": "data:image/jpeg;base64,xxx", ...}``.
            Entries that are not base64 ``data:image/`` URIs are skipped.
        output_dir: Directory the image files are written to (created if missing).

    Returns:
        The Markdown with ``images/<name>`` (or, failing that, the bare
        ``<name>``) replaced by ``<output_dir>/<file name>`` for every image
        that was saved.
    """
    if not images_dict:
        return md_content

    os.makedirs(output_dir, exist_ok=True)
    saved_count = 0

    for img_name, img_data in images_dict.items():
        if not isinstance(img_data, str) or not img_data.startswith('data:image/'):
            continue

        # Split on the marker instead of a regex: '.' does not match newlines,
        # which silently truncated line-wrapped base64 payloads.
        _header, marker, b64_data = img_data.partition(';base64,')
        if not marker or not b64_data:
            continue

        # The name comes from the API response: keep only the final path
        # component so a crafted name cannot write outside output_dir.
        safe_name = os.path.basename(img_name.replace('\\', '/'))
        if safe_name in ('', '.', '..'):
            print(f"  保存图片失败: {img_name} - 非法文件名")
            continue

        filepath = os.path.join(output_dir, safe_name)

        try:
            # Decode before opening so a bad payload leaves no empty file behind.
            raw_bytes = base64.b64decode(b64_data)
            with open(filepath, 'wb') as f:
                f.write(raw_bytes)
            saved_count += 1
        except (OSError, ValueError) as e:  # binascii.Error is a ValueError
            print(f"  保存图片失败: {img_name} - {e}")
            continue

        # Repoint the references at the saved file.
        old_ref = f"images/{img_name}"
        new_ref = f"{output_dir}/{safe_name}"

        if old_ref in md_content:
            md_content = md_content.replace(old_ref, new_ref)
        elif img_name in md_content:
            md_content = md_content.replace(img_name, new_ref)

    if saved_count > 0:
        print(f"  📁 已保存 {saved_count} 张图片到 {output_dir}/")

    return md_content

def convert_html_tables_to_markdown(md_content):
    """Convert HTML ``<table>`` blocks embedded in Markdown to Markdown tables.

    PDFs parsed by MinerU sometimes contain HTML tables (``<table>``, ``<tr>``,
    ``<td>`` ...) inside the Markdown. Each ``<table>...</table>`` block that
    parses as well-formed XML is replaced by a pipe table whose first row is
    used as the header. A block that cannot be parsed is left exactly as it
    was, so it still renders as HTML. Stray table tags outside of any
    ``<table>`` block are removed.

    Args:
        md_content: Markdown text that may contain HTML tables.

    Returns:
        The Markdown with convertible HTML tables replaced.
    """
    import xml.etree.ElementTree as ET

    def cell_text(cell):
        # itertext() also collects text inside nested tags (<b>, <sup>, ...),
        # which a plain ``.text`` would drop.
        raw = ''.join(cell.itertext())
        return ' '.join(raw.split()).replace('|', '\\|')

    def cell_span(cell):
        try:
            return max(int(cell.get('colspan', 1)), 1)
        except (TypeError, ValueError):
            return 1

    def tag_of(element):
        return element.tag.lower() if isinstance(element.tag, str) else ''

    def parse_html_table(table_html):
        """Return the Markdown form of one table, or the input unchanged."""
        # NOTE(review): dropping every backslash mirrors the previous behaviour
        # (meant to remove Markdown escapes) but also strips LaTeX commands
        # inside cells - confirm this is intended.
        cleaned = table_html.strip().replace('\\', '')
        try:
            root = ET.fromstring(cleaned)
        except ET.ParseError:
            return table_html

        rows = []
        for element in root.iter():
            if tag_of(element) != 'tr':
                continue
            cells = []
            # Walk the children in document order so mixed th/td rows keep
            # their column order.
            for cell in element:
                if tag_of(cell) not in ('th', 'td'):
                    continue
                cells.append(cell_text(cell))
                # A spanning cell occupies several columns; pad to stay aligned.
                cells.extend([''] * (cell_span(cell) - 1))
            if cells:
                rows.append(cells)

        if not rows:
            return table_html

        # Pad every row to the widest one instead of truncating to the header.
        width = max(len(row) for row in rows)
        padded = [row + [''] * (width - len(row)) for row in rows]

        md_lines = [
            '| ' + ' | '.join(padded[0]) + ' |',
            '| ' + ' | '.join(['---'] * width) + ' |',
        ]
        md_lines.extend('| ' + ' | '.join(row) + ' |' for row in padded[1:])
        return '\n'.join(md_lines)

    # Matches one <table>...</table> block.
    table_pattern = re.compile(r'<table[^>]*>.*?</table>', re.DOTALL | re.IGNORECASE)
    # \b keeps e.g. "th" from also matching other tags that merely start with it.
    stray_tag_pattern = re.compile(
        r'</?(?:tr|th|td|tbody|thead|colgroup|col)\b[^>]*>')
    span_attr_pattern = re.compile(r'\s*(?:colspan|rowspan)="\d+"')

    def clean_outside_tables(text):
        return span_attr_pattern.sub('', stray_tag_pattern.sub('', text))

    # Clean only the text between tables: running the tag cleanup over a table
    # that could not be converted would strip its rows and cells.
    parts = []
    cursor = 0
    for match in table_pattern.finditer(md_content):
        parts.append(clean_outside_tables(md_content[cursor:match.start()]))
        parts.append(parse_html_table(match.group(0)))
        cursor = match.end()
    parts.append(clean_outside_tables(md_content[cursor:]))

    return ''.join(parts)


def create_zip_package(md_file, images_dir, output_zip=None):
    """Bundle a Markdown file and, optionally, its image folder into a ZIP.

    Args:
        md_file: Markdown file to add; stored at the archive root.
        images_dir: Image folder to add, or ``None``. It is stored under its
            own folder name with its sub-directory layout preserved. Ignored
            when it is not an existing directory.
        output_zip: Archive path. Defaults to ``md_file`` with its extension
            replaced by ``.zip``.

    Returns:
        The path of the archive that was written.
    """
    if output_zip is None:
        # Swap only the extension; str.replace would rewrite every ".md" in
        # the path and leave a name without ".md" unchanged, so the archive
        # would overwrite the Markdown file itself.
        output_zip = os.path.splitext(md_file)[0] + '.zip'

    with zipfile.ZipFile(output_zip, 'w', zipfile.ZIP_DEFLATED) as zf:
        # Add the Markdown file.
        zf.write(md_file, os.path.basename(md_file))

        # Add the image folder.
        if images_dir and os.path.isdir(images_dir):
            # normpath drops a trailing separator that would make basename ''.
            top_folder = os.path.basename(os.path.normpath(images_dir))
            for root, _dirs, files in os.walk(images_dir):
                for file_name in files:
                    file_path = os.path.join(root, file_name)
                    # Keep the relative path so same-named files in different
                    # sub-folders do not collide inside the archive.
                    relative_path = os.path.relpath(file_path, images_dir)
                    zf.write(file_path, os.path.join(top_folder, relative_path))

    return output_zip

def _derive_base_name(input_source, is_url):
    """Return the input's file name without directory, query string or extension."""
    source_path = input_source.split('?')[0] if is_url else input_source
    return os.path.splitext(os.path.basename(source_path))[0] or 'download'


def _parse_local_pdf(input_source, batch_size, return_images):
    """Parse a local PDF, splitting it into page batches when it is long."""
    page_count = get_pdf_page_count(input_source)
    print(f"PDF 总页数: {page_count}")

    # A page count of 0 means "unknown": parse the whole file in one request.
    if page_count <= batch_size:
        return parse_pdf_with_retry(input_source, return_images=return_images)

    batches = math.ceil(page_count / batch_size)
    print(f"分 {batches} 批处理")

    all_content = []
    all_images = {}

    for i in range(batches):
        # Page ids sent to the API are 0-based and inclusive.
        start = i * batch_size
        end = min(start + batch_size, page_count) - 1
        print(f"处理第 {i+1}/{batches} 批: 页 {start}-{end}")

        md, batch_images = parse_pdf_with_retry(input_source, start, end, return_images)
        if md:
            all_content.append(md)
            if batch_images:
                all_images.update(batch_images)
            print(f"  批次 {i+1} 完成")
        else:
            print(f"  批次 {i+1} 返回空")

    return '\n\n'.join(all_content), all_images


def _write_markdown(md_content, images, embed_mode, images_dir, md_path):
    """Resolve image references, convert HTML tables and write the Markdown file."""
    if images:
        if embed_mode:
            md_content = embed_images_in_markdown(md_content, images)
        else:
            md_content = extract_and_save_images(md_content, images, images_dir)
            # The image folder sits next to the Markdown file, so references
            # must be relative to that file rather than to the working directory.
            relative_dir = os.path.basename(images_dir)
            if relative_dir != images_dir:
                md_content = md_content.replace(f"{images_dir}/", f"{relative_dir}/")

    md_content = convert_html_tables_to_markdown(md_content)
    with open(md_path, 'w', encoding='utf-8') as f:
        f.write(md_content)


def _should_translate(lang, translate, auto_yes):
    """Decide whether a Chinese version is wanted, asking the user if needed."""
    if lang == 'zh':
        # Already Chinese: nothing to translate.
        return False
    if translate or auto_yes:
        return True
    try:
        answer = input("📢 文章是英文，需要翻译成中文吗？(y/n): ").strip().lower()
    except EOFError:
        # Non-interactive run: do not translate.
        return False
    return answer in ('y', 'yes', '是', '1')


def _package_as_zip(md_path, images_dir, embed_mode):
    """Zip the Markdown file (plus its image folder in separate mode)."""
    has_image_folder = not embed_mode and os.path.isdir(images_dir)
    zip_path = create_zip_package(
        md_path,
        images_dir if has_image_folder else None,
        # Explicit archive name: safe for output names that do not end in ".md".
        os.path.splitext(md_path)[0] + '.zip',
    )
    if has_image_folder:
        # The images now live in the archive; drop the loose folder.
        shutil.rmtree(images_dir)
    return zip_path


def parse_pdf(input_source, return_images=True, batch_size=BATCH_SIZE, 
              embed_mode=True, create_zip=False, translate=False, auto_yes=False,
              output_path=None):
    """Parse a PDF (local file or URL) into Markdown files.

    Writes the original-language Markdown and, when translation is requested
    or confirmed, a ``*_zh`` Chinese version; optionally zips each of them.

    Args:
        input_source: PDF file path or http(s) URL.
        return_images: Ignored; images are always requested.
        batch_size: Pages per request for local files (must be >= 1).
        embed_mode: True = inline images as base64; False = save them to a
            ``*_images`` folder next to the Markdown file.
        create_zip: Also pack each Markdown file (and its image folder) as ZIP.
        translate: Translate non-Chinese documents without asking.
        auto_yes: Answer the translation prompt with yes.
        output_path: Path of the original-language Markdown file. Defaults to
            ``<input base name>.md`` in the current directory. The image
            folders and the ``_zh`` file are named after it.

    Returns:
        The raw Markdown returned by the parse API (before image handling,
        table conversion and translation).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        Exception: If parsing produced no content, or a request failed.
    """
    # Images are always extracted, whatever the caller passed.
    return_images = True

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    is_url = input_source.startswith(('http://', 'https://'))

    if output_path:
        out_base, out_ext = os.path.splitext(output_path)
        original_md_output = output_path
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    else:
        out_base, out_ext = _derive_base_name(input_source, is_url), '.md'
        original_md_output = f"{out_base}.md"

    images_dir = f"{out_base}_images"
    zh_images_dir = f"{out_base}_zh_images"
    zh_md_output = f"{out_base}_zh{out_ext or '.md'}"

    # Parse the PDF.
    if is_url:
        print(f"🌐 解析 URL: {input_source}")
        md_content, images = parse_pdf_with_retry(input_source, return_images=return_images)
    else:
        md_content, images = _parse_local_pdf(input_source, batch_size, return_images)

    if not md_content:
        raise Exception("解析失败，未获取到内容")

    # Detect the language.
    lang = detect_language(md_content)
    print(f"📝 检测语言: {'中文' if lang == 'zh' else '英文/其他'}")

    # Save the original-language version first.
    _write_markdown(md_content, images, embed_mode, images_dir, original_md_output)
    print(f"✅ 原语言版本: {original_md_output}")

    # Chinese version. translate_to_chinese prints its own progress line.
    should_translate = _should_translate(lang, translate, auto_yes)
    if should_translate:
        md_content_translated = translate_to_chinese(md_content)
        _write_markdown(md_content_translated, images, embed_mode,
                        zh_images_dir, zh_md_output)
        print(f"✅ 中文版本: {zh_md_output}")

    # ZIP packaging. An archive is produced even when separate mode found no
    # images to save.
    if create_zip:
        zip_path = _package_as_zip(original_md_output, images_dir, embed_mode)
        print(f"📦 打包完成: {zip_path}")

        if should_translate:
            zh_zip = _package_as_zip(zh_md_output, zh_images_dir, embed_mode)
            print(f"📦 中文版打包完成: {zh_zip}")

    return md_content

def main():
    """Command-line entry point: parse the arguments and run ``parse_pdf``."""
    parser = argparse.ArgumentParser(description='PDF 解析工具 - MinerU API')
    parser.add_argument('input', help='PDF 文件路径或 URL')
    parser.add_argument('-o', '--output', help='输出文件路径')
    parser.add_argument('-b', '--batch-size', type=int, default=BATCH_SIZE, help='每批页数')
    
    # Image mode selection.
    img_group = parser.add_mutually_exclusive_group()
    img_group.add_argument('-e', '--embed', action='store_true', default=True,
                          help='内嵌模式（默认）：图片 base64 内嵌到 MD，不打包')
    img_group.add_argument('-s', '--separate', action='store_true',
                          help='独立模式：图片保存到文件夹，需要 -z 打包')
    
    parser.add_argument('-z', '--zip', action='store_true', help='打包为 ZIP')
    parser.add_argument('-t', '--translate', action='store_true', help='自动翻译为中文')
    parser.add_argument('-y', '--yes', action='store_true', help='自动确认翻译')
    
    args = parser.parse_args()

    if args.batch_size < 1:
        parser.error('-b/--batch-size 必须大于 0')
    
    # Embedding is the default; -s switches to separate image files.
    embed_mode = not args.separate
    
    parse_pdf(args.input, True, args.batch_size, 
              embed_mode=embed_mode,
              create_zip=args.zip, translate=args.translate, auto_yes=args.yes,
              output_path=args.output)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
