"""
Production Batch Image Restyler with Dynamic Parameters
- Accepts prompt, quality, CSV path, and timestamp as parameters
- Parallel execution (configurable)
- Automatic retry on errors
- Auto crop + resize to 1000×1130
- Respects API rate limits
"""

import os, csv, base64, io, time, threading, json, argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
import backoff
from openai import OpenAI
from PIL import Image

# ----------------------------
# CONFIGURATION
# ----------------------------
API_KEY   = "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"
MAX_WORKERS = 1               # parallel API calls; 3–5 is usually safe (the rate limiter below still caps the request rate)
SIZE = "1024x1024"            # square output size supported by gpt-image-1
RATE_LIMIT = 12               # minimum seconds between API calls (60 / 12 ≈ 5 requests per minute)

client = OpenAI(api_key=API_KEY)

# ----------------------------
# RATE LIMITER (thread-safe)
# ----------------------------
last_call = 0
lock = threading.Lock()

def rate_limited_call():
    """Ensure we send ≤5 requests per minute total"""
    global last_call
    with lock:
        elapsed = time.time() - last_call
        wait = max(0, RATE_LIMIT - elapsed)
        if wait > 0:
            time.sleep(wait)
        last_call = time.time()
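
# Note: because the sleep happens while the lock is held, request *starts* are spaced
# at least RATE_LIMIT seconds apart across all threads, so a batch of N images needs
# roughly N * RATE_LIMIT seconds (e.g. ~100 images ≈ 20 minutes at the default 12 s);
# extra workers only help by overlapping the API's response latency with this spacing.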

# ----------------------------
# CROPPING & RESIZING
# ----------------------------
def crop_resize_to_target(img_bytes, path_out):
    """Crop and resize image to 1000×1130 while preserving aspect"""
    img = Image.open(io.BytesIO(img_bytes))
    w, h = img.size
    target_w, target_h = 1000, 1130

    # Center-crop to roughly match aspect ratio
    aspect_target = target_w / target_h
    aspect_img = w / h
    if aspect_img > aspect_target:
        new_w = int(h * aspect_target)
        left = (w - new_w) // 2
        box = (left, 0, left + new_w, h)
    else:
        new_h = int(w / aspect_target)
        top = (h - new_h) // 2
        box = (0, top, w, top + new_h)
    cropped = img.crop(box)
    resized = cropped.resize((target_w, target_h), Image.LANCZOS)
    # JPEG cannot store an alpha channel, so convert to RGB before saving
    # (the decoded image may be a PNG with transparency)
    resized.convert("RGB").save(path_out, quality=95)
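
# Worked example with the script's 1024x1024 API output: the target aspect is
# 1000/1130 ≈ 0.885 and the square source has aspect 1.0, so the crop keeps the full
# 1024 px height and a 906 px wide centre strip (left offset 59) before resizing.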

# ----------------------------
# RESTYLE FUNCTION WITH RETRIES
# ----------------------------
def restyle_image(path_in, path_out, prompt, quality):
    """Send one image to OpenAI and save the restyled output"""
    rate_limited_call()
    with open(path_in, "rb") as img:
        result = client.images.edit(
            model="gpt-image-1",
            image=img,
            prompt=prompt,
            size=SIZE,
            n=1,
            quality=quality
        )
    b64 = result.data[0].b64_json
    img_bytes = base64.b64decode(b64)
    crop_resize_to_target(img_bytes, path_out)

@backoff.on_exception(backoff.expo, Exception, max_tries=3, jitter=None)
def restyle_image_with_retry(path_in, path_out, prompt, quality):
    """Wrapper with retry logic"""
    restyle_image(path_in, path_out, prompt, quality)
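
# With backoff.expo and max_tries=3 (jitter disabled), a failing image is attempted up
# to three times, waiting roughly 1 s and then 2 s between attempts, before the last
# exception propagates to process_row() and is reported as an ERR result.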

# ----------------------------
# PER-IMAGE WRAPPER
# ----------------------------
def process_row(row, input_dir, output_dir, prompt, quality):
    """Process a single image row"""
    filename = row["filename"].strip()
    base_name = os.path.splitext(filename)[0]

    inp = os.path.join(input_dir, filename)
    out = os.path.join(output_dir, f"{base_name}-restyled.jpg")

    if not os.path.exists(inp):
        return f"{filename}: MISSING INPUT"
    if os.path.exists(out):
        return f"{filename}: SKIP"

    try:
        restyle_image_with_retry(inp, out, prompt, quality)
        return f"{filename}: OK"
    except Exception as e:
        return f"{filename}: ERR {e}"

# ----------------------------
# STATUS FILE MANAGEMENT
# ----------------------------
def update_status(status_file, status, processed=0, total=0, message=""):
    """Update job status file"""
    status_data = {
        "status": status,
        "processed": processed,
        "total": total,
        "message": message,
        "updated_at": time.strftime("%Y-%m-%d %H:%M:%S")
    }
    with open(status_file, 'w') as f:
        json.dump(status_data, f)
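
# The resulting job file has this shape (values illustrative):
# {
#     "status": "processing",
#     "processed": 12,
#     "total": 40,
#     "message": "Processed 12/40 images",
#     "updated_at": "2024-05-01 12:34:56"
# }
# A separate process (e.g. a web front-end) can poll jobs/job_<timestamp>.json for progress.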

# ----------------------------
# MAIN EXECUTION
# ----------------------------
def main(prompt, quality, csv_path, timestamp):
    """Main processing function"""
    # Initialise progress counters and the status file up front so the except
    # block below can always record a failure, even if setup fails early
    processed = 0
    total = 0
    os.makedirs("jobs", exist_ok=True)
    status_file = os.path.join("jobs", f"job_{timestamp}.json")

    try:
        # Setup paths
        input_dir = os.path.join("images_in", timestamp)
        output_dir = os.path.join("images_out", timestamp)

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

        # Update status: started
        update_status(status_file, "processing", 0, 0, "Job started")

        # Read CSV file
        with open(csv_path) as f:
            rows = list(csv.DictReader(f))

        total = len(rows)
        print(f"Processing {total} images using {MAX_WORKERS} threads...")
        print(f"Input: {input_dir}")
        print(f"Output: {output_dir}")
        print(f"Prompt: {prompt[:50]}...")
        print(f"Quality: {quality}\n")

        # Update status: processing
        update_status(status_file, "processing", 0, total, f"Processing {total} images")

        start = time.time()
        processed = 0

        # Process images
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            futures = {
                executor.submit(process_row, row, input_dir, output_dir, prompt, quality): row
                for row in rows
            }

            for i, future in enumerate(as_completed(futures), 1):
                result = future.result()
                print(f"[{i}/{total}] {result}")
                processed = i

                # Update status periodically
                if processed % 5 == 0 or processed == total:
                    update_status(status_file, "processing", processed, total,
                                f"Processed {processed}/{total} images")

        mins = (time.time() - start) / 60
        completion_msg = f"Completed {total} images in {mins:.1f} minutes"
        print(f"\n✅ {completion_msg}")

        # Update status: completed
        update_status(status_file, "completed", total, total, completion_msg)

    except Exception as e:
        error_msg = f"Error: {str(e)}"
        print(f"\n❌ {error_msg}")
        update_status(status_file, "failed", processed, total, error_msg)
        raise

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Process images with dynamic parameters')
    parser.add_argument('--prompt', required=True, help='Restyling prompt')
    parser.add_argument('--quality', required=True, choices=['high', 'medium'], help='Image quality')
    parser.add_argument('--csv', required=True, help='CSV file path')
    parser.add_argument('--timestamp', required=True, help='Job timestamp')

    args = parser.parse_args()

    main(args.prompt, args.quality, args.csv, args.timestamp)
