"""
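Example: using ContentAutoGenerator with a Celery task queue.

This demonstrates how the refactored OOP design integrates with Celery. The
module defines the Celery app, the ``generate_content`` task that turns a PDF
into educational resources, and a helper that records the generated filenames
and task status in ``publication_schema.publications``.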
"""

import os
import json
import logging
import psycopg2
from celery import Celery
from ContentAutoGenerator import ContentAutoGenerator

# Root logging at INFO, plus a module-level logger shared by the helpers and
# tasks defined below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Celery application. Broker and result backend are read from the environment
# and both fall back to Redis database 0 on host "redis".
celery_app = Celery(
    'content_generation',
    broker=os.getenv('CELERY_BROKER_URL', 'redis://redis:6379/0'),
    backend=os.getenv('CELERY_RESULT_BACKEND', 'redis://redis:6379/0'),
)
# Second name for the same app object (presumably so `celery -A <module>`
# can discover it under the conventional `app` attribute).
app = celery_app

# Serialization is JSON-only in both directions; all timestamps are UTC.
celery_app.conf.task_serializer = 'json'
celery_app.conf.accept_content = ['json']
celery_app.conf.result_serializer = 'json'
celery_app.conf.timezone = 'UTC'
celery_app.conf.enable_utc = True

# Database connection settings. Each value comes from the environment when set,
# otherwise from the literal fallback. Note that DB_PASS is read from the
# DB_PASSWORD environment variable.
# NOTE(review): the DB_USER / DB_PASSWORD fallbacks are credentials committed
# to source; consider requiring them to be provided via the environment.
DB_HOST = os.environ.get('DB_HOST', 'db')
DB_PORT = os.environ.get('DB_PORT', '5432')
DB_USER = os.environ.get('DB_USER', 'wuyuxuan')
DB_PASS = os.environ.get('DB_PASSWORD', '1234567890')
DB_NAME = os.environ.get('DB_NAME', 'postgres')

def get_db_connection():
    """Open a new psycopg2 connection from the module-level DB_* settings.

    Returns:
        The new connection, or None when connecting raises any exception.
        The error is logged instead of being propagated, so callers must
        check for None.
    """
    connection = None
    try:
        connect_kwargs = dict(
            host=DB_HOST,
            port=DB_PORT,
            user=DB_USER,
            password=DB_PASS,
            dbname=DB_NAME,
        )
        connection = psycopg2.connect(**connect_kwargs)
    except Exception as exc:
        logger.error(f"Error connecting to database: {exc}")
    return connection

# Resource type (key of the generator's results dict) -> column of
# publication_schema.publications that stores that resource's filename.
# These column names are interpolated into SQL, so they must only ever come
# from this fixed mapping, never from caller input.
_RESOURCE_COLUMNS = {
    'audio': 'audio_url',
    'video': 'video_url',
    'mental_map': 'mental_map_url',
    'report': 'report_url',
    'flashcard': 'flashcard_url',
    'quiz': 'quiz_url',
    'infographic': 'infografica_url',  # Note: DB column is infografica_url
    'presentation': 'presentation_url',
    'datatable': 'datatable_url',
}

# When a result's 'files' entry is a dict, these keys are tried in order
# before falling back to the dict's first value.
_PRIORITY_FILE_KEYS = (
    'audio_path', 'video_path', 'html_path', 'image_path',
    'pdf_path', 'pptx_path', 'xlsx_path',
)


def _basename(value):
    """Return os.path.basename(value) for a non-empty str/PathLike, else None."""
    if value and isinstance(value, (str, os.PathLike)):
        return os.path.basename(value) or None
    return None


def _extract_filename(files):
    """Return the basename of the primary output file described by *files*, or None.

    *files* is either a single path or a dict of paths. For a dict, the
    _PRIORITY_FILE_KEYS are checked first, then the dict's first value is used.
    Values that are not paths (None, lists, nested dicts, ...) are skipped
    instead of raising.
    """
    if not isinstance(files, dict):
        return _basename(files)
    if not files:
        return None
    for key in _PRIORITY_FILE_KEYS:
        name = _basename(files.get(key))
        if name:
            return name
    # Fallback: the first value, provided it is a path.
    return _basename(next(iter(files.values())))


def update_publication_db(pub_id: str, results: dict, task_status: str, error_msg: str = None):
    """Persist generated filenames and the task status for one publication.

    Issues one ``UPDATE publication_schema.publications ... WHERE id = %s``
    that:

    * sets the matching ``*_url`` column to the output file's basename, for
      every entry of *results* whose type is in _RESOURCE_COLUMNS and whose
      status is 'success';
    * always sets ``task_status``;
    * sets ``task_error`` when *error_msg* is truthy.

    Database errors are logged rather than propagated. The transaction is
    rolled back on error and the connection is always closed.

    Args:
        pub_id: Value matched against the ``id`` column.
        results: Mapping of resource type -> result dict, e.g.
            ``{'audio': {'status': 'success', 'files': {...}}}``. May be empty
            or None. Malformed entries are skipped.
        task_status: Value written to ``task_status``.
        error_msg: Optional value written to ``task_error``.
    """
    conn = get_db_connection()
    if not conn:
        logger.error("Could not update DB: Connection failed")
        return

    try:
        update_fields = []
        update_values = []

        for r_type, result in (results or {}).items():
            col_name = _RESOURCE_COLUMNS.get(r_type)
            # Skip unknown types, unsuccessful resources and malformed entries.
            # A malformed entry must not abort the status update below.
            if col_name is None or not isinstance(result, dict) or result.get('status') != 'success':
                continue
            filename = _extract_filename(result.get('files'))
            if filename:
                update_fields.append(f"{col_name} = %s")
                update_values.append(filename)

        # task_status is always written, so the SET clause is never empty.
        update_fields.append("task_status = %s")
        update_values.append(task_status)

        # NOTE(review): task_error is only written when error_msg is truthy, so a
        # later successful run leaves an earlier error in place. Confirm this is intended.
        if error_msg:
            update_fields.append("task_error = %s")
            update_values.append(error_msg)

        update_values.append(pub_id)
        # Only whitelisted column names are interpolated; every value is a bound parameter.
        query = f"UPDATE publication_schema.publications SET {', '.join(update_fields)} WHERE id = %s"

        with conn.cursor() as cur:
            cur.execute(query, tuple(update_values))
            updated_rows = cur.rowcount
        conn.commit()

        if updated_rows == 0:
            logger.warning(f"No publication with id {pub_id}; nothing updated")
        else:
            logger.info(f"Updated DB for publication {pub_id}")

    except Exception as e:
        logger.exception(f"Error updating DB: {e}")
        try:
            conn.rollback()
        except Exception as rollback_error:
            # A broken connection can fail to roll back. Keep this function's
            # log-and-continue contract instead of raising from the error path.
            logger.error(f"Rollback failed: {rollback_error}")
    finally:
        conn.close()

@celery_app.task(bind=True, name='generate_content')
def generate_content_task(
    self,
    pdf_path: str,
    resource_types: list = None,
    output_dir: str = None,
    llm_provider: str = None,
    publication_id: str = None,  # New argument
):
    """Celery task for generating educational content from PDF.

    Progress is reported through custom Celery states:

    * ``PARSING`` while the PDF is parsed;
    * ``GENERATING`` (with the paper title in ``meta``) while resources are
      produced.

    Args:
        pdf_path: Path of the PDF to process.
        resource_types: Passed unchanged to ContentAutoGenerator.
        output_dir: Passed unchanged to ContentAutoGenerator (this should be
            mapped to ./assets in the container).
        llm_provider: Passed unchanged to ContentAutoGenerator.
        publication_id: If given, the publication row that receives the
            generated filenames and the final status.

    Returns:
        dict with ``status`` ('completed', 'partial' or 'failed'), the
        generator's ``summary``, and ``paper_info`` (``title``, ``page_count``).

    Raises:
        Re-raises any exception from parsing or generation. Before that, the
        exception is logged and, when publication_id is set, the publication
        is marked 'failed'.
    """
    try:
        self.update_state(state='PARSING', meta={'status': 'Parsing PDF...'})

        generator = ContentAutoGenerator(
            pdf_path=pdf_path,
            resource_types=resource_types,
            output_dir=output_dir,  # This should be mapped to ./assets in container
            llm_provider=llm_provider,
        )

        paper_info = generator.parse_pdf()

        self.update_state(
            state='GENERATING',
            meta={
                'status': 'Generating resources...',
                'title': paper_info.get('title', 'Unknown'),
            },
        )

        # Generate resources. The return value is not needed: the DB update
        # below reads generator.results.
        generator.generate_all()
        summary = generator.get_summary()

        # Determine the final status:
        # no failures -> completed; some successes -> partial; otherwise failed.
        if summary['failed'] == 0:
            final_status = 'completed'
            error_msg = None
        elif summary['successful'] > 0:
            final_status = 'partial'
            error_msg = f"Failed: {summary['failed']} resources"
        else:
            final_status = 'failed'
            error_msg = "All resources failed"

        if publication_id:
            update_publication_db(publication_id, generator.results, final_status, error_msg)

        return {
            'status': final_status,
            'summary': summary,
            'paper_info': {
                'title': paper_info.get('title'),
                'page_count': paper_info.get('page_count'),
            },
        }

    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception(f"Task failed: {e}")
        if publication_id:
            # Recording the failure is best-effort. An error here must not
            # replace the original exception re-raised below.
            try:
                update_publication_db(publication_id, {}, 'failed', str(e))
            except Exception:
                logger.exception(f"Could not mark publication {publication_id} as failed")
        # No manual update_state(state='FAILURE', ...): Celery records FAILURE
        # with the exception type and message when the exception propagates.
        # A plain-dict FAILURE meta is not a valid exception payload for
        # result readers.
        raise





