import os
import sys
import time
import math
import pygame
import pickle
import random
import numpy as np
from datetime import datetime
from copy import deepcopy

# Constants: seconds between simulation ticks, board size in cells, and the
# window size in pixels (every cell is rendered as a 10x10-pixel square).
refresh_rate = 0.05
grid_width = 121
grid_height = 71
screen_width = grid_width * 10
screen_height = grid_height * 10

class Game:
    """Grid-based two-player Pong simulation.

    The board is a ``grid_width`` x ``grid_height`` array of cells indexed
    ``[x, y]``. State is kept in four parallel integer planes:

    * ``grid`` -- element ID per cell (0 empty, 2/3 paddles, 7/8 walls). The
      ball (ID 1) is only rendered into the colour planes, not into ``grid``.
    * ``r``, ``g``, ``b`` -- colour channels, 0-255.

    ``elements`` maps element names to metadata dicts (centre position,
    ``shape_x``/``shape_y`` extents -- half-sizes for the paddles and ball --,
    inclusive hitbox corners, colour and status flags). ``event_log``
    accumulates events until :meth:`update` hands them out and resets it.
    """

    # Colours as (r, g, b).
    _TOP_WALL_RGB = (0, 255, 150)
    _BOTTOM_WALL_RGB = (0, 255, 200)
    _LEFT_PADDLE_RGB = (0, 0, 255)    # blue
    _RIGHT_PADDLE_RGB = (255, 0, 0)   # red
    _BALL_RGB = (255, 255, 255)       # white
    _BLACK = (0, 0, 0)

    def __init__(self):
        self.elements = {}
        self.event_log = []
        self.init_grid()
        self.init_walls()
        self.init_paddles()
        self.init_ball()
        self.event_pending = []  # (not used in this pong implementation)

    # ------------------------------------------------------------------
    # Painting / metadata helpers
    # ------------------------------------------------------------------
    def _fill(self, x0, x1, y0, y1, rgb, element_id=None):
        """Paint cells ``x0 <= x < x1``, ``y0 <= y < y1`` with colour *rgb*.

        Bounds are clipped at 0: a raw negative slice start would make numpy
        count from the far end of the axis, so a sprite partly past the
        left/top edge would silently not be drawn (or be drawn in the wrong
        place). Upper bounds are already clipped by numpy slicing.
        When *element_id* is given, the ID plane ``grid`` is written as well.
        """
        xs = slice(max(x0, 0), max(x1, 0))
        ys = slice(max(y0, 0), max(y1, 0))
        if element_id is not None:
            self.grid[xs, ys] = element_id
        self.r[xs, ys], self.g[xs, ys], self.b[xs, ys] = rgb

    def _fill_centered(self, cx, cy, half_w, half_h, rgb, element_id=None):
        """Paint the inclusive box ``cx-half_w..cx+half_w`` x ``cy-half_h..cy+half_h``."""
        self._fill(cx - half_w, cx + half_w + 1, cy - half_h, cy + half_h + 1,
                   rgb, element_id)

    @staticmethod
    def _make_element(element_id, pos_x, pos_y, shape_x, shape_y, hitbox, rgb):
        """Build an element-metadata dict.

        *hitbox* is ``(tl_x, tl_y, br_x, br_y)`` with inclusive corners; *rgb*
        is the element colour. Key order matches the logged format.
        """
        tl_x, tl_y, br_x, br_y = hitbox
        red, green, blue = rgb
        return {
            'id': element_id,
            'pos_x': pos_x,
            'pos_y': pos_y,
            'shape_x': shape_x,
            'shape_y': shape_y,
            'hitbox_tl_x': tl_x,
            'hitbox_tl_y': tl_y,
            'hitbox_br_x': br_x,
            'hitbox_br_y': br_y,
            'color_r': red,
            'color_g': green,
            'color_b': blue,
            'color_state': 0,
            'never_hit': True,
            'existence': True,
        }

    # ------------------------------------------------------------------
    # Board and walls
    # ------------------------------------------------------------------
    def init_grid(self):
        """Allocate the ID plane and the three colour planes, all zero (empty, black)."""
        shape = (grid_width, grid_height)
        self.grid = np.zeros(shape, dtype=int)
        self.r = np.zeros(shape, dtype=int)
        self.g = np.zeros(shape, dtype=int)
        self.b = np.zeros(shape, dtype=int)

    def init_walls(self):
        """Paint and register the top (ID 7) and bottom (ID 8) walls.

        Both walls span columns 3 .. grid_width - 4. The top wall fills rows
        0-2; the bottom wall fills rows grid_height - 2 .. grid_height - 1.
        """
        self._fill(3, grid_width - 3, 0, 3, self._TOP_WALL_RGB, 7)
        self.elements['wall_top'] = self._make_element(
            7, grid_width // 2, 1, grid_width // 2, 1,
            (3, 0, grid_width - 4, 2), self._TOP_WALL_RGB)

        self._fill(3, grid_width - 3, grid_height - 2, grid_height,
                   self._BOTTOM_WALL_RGB, 8)
        # NOTE(review): this hitbox starts at row grid_height - 3 although only
        # the last two rows are painted, and update_ball lets the ball occupy
        # row grid_height - 3. Kept unchanged to preserve logged values;
        # confirm which one is intended.
        self.elements['wall_bottom'] = self._make_element(
            8, grid_width // 2, grid_height - 2, grid_width // 2, 1,
            (3, grid_height - 3, grid_width - 4, grid_height - 1),
            self._BOTTOM_WALL_RGB)

    # ------------------------------------------------------------------
    # Paddles
    # ------------------------------------------------------------------
    def init_paddles(self):
        """Create both paddles at mid-height, paint them and register them.

        Each paddle covers (2*halfwidth+1) x (2*halfheight+1) cells. The left
        paddle (ID 2, blue, W/S keys) sits at x = 5; the right paddle
        (ID 3, red, arrow keys) at x = grid_width - 6.
        """
        self.paddle_halfwidth = 1
        self.paddle_halfheight = 4
        self.paddle_base_speed = 2  # cells per update for a speed value of +/-1

        self.left_paddle_x = 5
        self.left_paddle_y = grid_height // 2
        self.left_paddle_speed = 0
        self.left_paddle_old_y = self.left_paddle_y

        self.right_paddle_x = grid_width - 6
        self.right_paddle_y = grid_height // 2
        self.right_paddle_speed = 0
        self.right_paddle_old_y = self.right_paddle_y

        self.draw_left_paddle()
        self.draw_right_paddle()

        self.elements['paddle_left'] = self._paddle_element(
            2, self.left_paddle_x, self.left_paddle_y, self._LEFT_PADDLE_RGB)
        self.elements['paddle_right'] = self._paddle_element(
            3, self.right_paddle_x, self.right_paddle_y, self._RIGHT_PADDLE_RGB)

    def _paddle_element(self, element_id, x, y, rgb):
        """Build the metadata dict for a paddle centred at (*x*, *y*)."""
        hw, hh = self.paddle_halfwidth, self.paddle_halfheight
        return self._make_element(element_id, x, y, hw, hh,
                                  (x - hw, y - hh, x + hw, y + hh), rgb)

    def set_left_paddle_speed(self, value):
        """Set the left paddle's vertical speed to *value* x base speed (negative = up)."""
        self.left_paddle_speed = value * self.paddle_base_speed

    def set_right_paddle_speed(self, value):
        """Set the right paddle's vertical speed to *value* x base speed (negative = up)."""
        self.right_paddle_speed = value * self.paddle_base_speed

    def _step_paddle(self, y, speed):
        """Return centre row *y* moved by *speed*, clamped so the whole paddle
        stays within rows 3 .. grid_height - 3 (inclusive).

        Clamping (instead of refusing the move) lets the paddle reach the
        limit whatever the step size.
        """
        top = 3 + self.paddle_halfheight
        bottom = grid_height - 3 - self.paddle_halfheight
        return min(max(y + speed, top), bottom)

    def _sync_paddle_element(self, name, y):
        """Copy a paddle's centre row *y* into its metadata dict."""
        element = self.elements[name]
        element['pos_y'] = y
        element['hitbox_tl_y'] = y - self.paddle_halfheight
        element['hitbox_br_y'] = y + self.paddle_halfheight

    def update_paddles(self):
        """Move both paddles by their current speeds and sync their metadata.

        The previous row is remembered in ``*_paddle_old_y`` so the draw
        methods can erase it.
        """
        self.left_paddle_old_y = self.left_paddle_y
        self.left_paddle_y = self._step_paddle(self.left_paddle_y,
                                               self.left_paddle_speed)
        self._sync_paddle_element('paddle_left', self.left_paddle_y)

        self.right_paddle_old_y = self.right_paddle_y
        self.right_paddle_y = self._step_paddle(self.right_paddle_y,
                                                self.right_paddle_speed)
        self._sync_paddle_element('paddle_right', self.right_paddle_y)

    def _redraw_paddle(self, x, old_y, y, element_id, rgb):
        """Erase a paddle at row *old_y*, then paint it at row *y*."""
        hw, hh = self.paddle_halfwidth, self.paddle_halfheight
        self._fill_centered(x, old_y, hw, hh, self._BLACK, 0)
        self._fill_centered(x, y, hw, hh, rgb, element_id)

    def draw_left_paddle(self):
        """Erase the left paddle at its previous row and paint it (blue, ID 2) at the current one."""
        self._redraw_paddle(self.left_paddle_x, self.left_paddle_old_y,
                            self.left_paddle_y, 2, self._LEFT_PADDLE_RGB)

    def draw_right_paddle(self):
        """Erase the right paddle at its previous row and paint it (red, ID 3) at the current one."""
        self._redraw_paddle(self.right_paddle_x, self.right_paddle_old_y,
                            self.right_paddle_y, 3, self._RIGHT_PADDLE_RGB)

    # ------------------------------------------------------------------
    # Ball
    # ------------------------------------------------------------------
    def init_ball(self):
        """Place the ball near the centre with a random offset and diagonal direction."""
        self.ball_x = grid_width // 2 + random.randint(-3, 3)
        self.ball_y = grid_height // 2 + random.randint(-3, 3)
        self.ball_radius = 1
        self.ball_speed_x = random.choice([-1, 1])
        self.ball_speed_y = random.choice([-1, 1])
        self.ball_old_x = self.ball_x
        self.ball_old_y = self.ball_y
        self.draw_ball()
        r = self.ball_radius
        self.elements['ball'] = self._make_element(
            1, self.ball_x, self.ball_y, r, r,
            (self.ball_x - r, self.ball_y - r, self.ball_x + r, self.ball_y + r),
            self._BALL_RGB)

    def _paddle_contact(self, edge_x, center_y, paddle_x, paddle_y):
        """True if column *edge_x* lies within the paddle's columns and row
        *center_y* lies within the paddle's rows."""
        hw, hh = self.paddle_halfwidth, self.paddle_halfheight
        return (paddle_x - hw <= edge_x <= paddle_x + hw
                and paddle_y - hh <= center_y <= paddle_y + hh)

    def update_ball(self):
        """Advance the ball one step, reflecting off walls and paddles.

        Collisions are tested on the *predicted* position; a colliding axis
        has its speed inverted before the move is applied, so the ball moves
        away instead of entering the obstacle.

        Returns:
            True if the ball has moved past the left or right board edge
            (a ``'ball_out'`` event is logged), else False.
        """
        r = self.ball_radius
        new_x = self.ball_x + self.ball_speed_x
        new_y = self.ball_y + self.ball_speed_y

        # Vertical limits: rows 0-2 belong to the top wall; the ball's lowest
        # allowed row is grid_height - 3.
        invert_speed_y = new_y - r < 3 or new_y + r > grid_height - 3

        # Paddles: the ball's leading edge must be inside the paddle's
        # columns and its centre inside the paddle's rows.
        invert_speed_x = (
            self._paddle_contact(new_x - r, new_y,
                                 self.left_paddle_x, self.left_paddle_y)
            or self._paddle_contact(new_x + r, new_y,
                                    self.right_paddle_x, self.right_paddle_y)
        )

        self.ball_old_x = self.ball_x
        self.ball_old_y = self.ball_y

        if invert_speed_x:
            self.ball_speed_x = -self.ball_speed_x
        if invert_speed_y:
            self.ball_speed_y = -self.ball_speed_y

        self.ball_x += self.ball_speed_x
        self.ball_y += self.ball_speed_y

        ball = self.elements['ball']
        ball['pos_x'] = self.ball_x
        ball['pos_y'] = self.ball_y
        ball['hitbox_tl_x'] = self.ball_x - r
        ball['hitbox_tl_y'] = self.ball_y - r
        ball['hitbox_br_x'] = self.ball_x + r
        ball['hitbox_br_y'] = self.ball_y + r

        # End the game if the ball goes past the left or right boundaries.
        if self.ball_x - r < 0 or self.ball_x + r > grid_width - 1:
            self.event_log.append({'description': 'ball_out', 'subject': 1})
            return True  # signal game end
        return False

    def _erase_ball(self):
        """Blank the ball's previous footprint in the colour planes."""
        r = self.ball_radius
        self._fill_centered(self.ball_old_x, self.ball_old_y, r, r, self._BLACK)

    def _paint_ball(self):
        """Paint the ball (white) at its current position in the colour planes."""
        r = self.ball_radius
        self._fill_centered(self.ball_x, self.ball_y, r, r, self._BALL_RGB)

    def draw_ball(self):
        """Erase the ball at its previous position and paint it at the current one."""
        self._erase_ball()
        self._paint_ball()

    # ------------------------------------------------------------------
    # Frame step and accessors
    # ------------------------------------------------------------------
    def update(self):
        """Advance one tick: move paddles and ball, then redraw everything.

        Returns:
            ``(elements, events, game_end)`` -- the live metadata dict, the
            events raised since the previous call (the internal log is reset),
            and True once the ball has left the board.
        """
        self.update_paddles()
        game_end = self.update_ball()
        # Clear the stale ball before painting paddles, and paint the ball
        # last; erasing the old ball after the paddles were painted could
        # punch a one-frame hole in a paddle it had been touching.
        self._erase_ball()
        self.draw_left_paddle()
        self.draw_right_paddle()
        self._paint_ball()
        events, self.event_log = self.event_log, []
        return self.elements, events, game_end

    def get_log(self):
        """Return the live ``(elements, event_log)`` pair without resetting the log."""
        return self.elements, self.event_log

    def get_grid(self):
        """Return the colour planes stacked as a ``(grid_width, grid_height, 3)`` array."""
        return np.stack([self.r, self.g, self.b], axis=-1)


# Main loop and event handling
save_log = True
save_frames = False

pygame.init()
window = pygame.display.set_mode((screen_width, screen_height))
pygame.display.set_caption("Basic Pong")

game = Game()


def render_frame():
    """Blit the game's RGB grid onto the window, scaled to the window size.

    Returns the scaled surface so it can also be written to disk.
    """
    surface = pygame.surfarray.make_surface(game.get_grid())
    scaled = pygame.transform.scale(surface, (screen_width, screen_height))
    window.blit(scaled, (0, 0))
    return scaled


screen = render_frame()

t = time.time()
keys_down = []  # keys pressed since the last tick
keys_up = []    # keys released since the last tick

# Held-key state for paddle control:
# right paddle uses UP/DOWN arrow keys, left paddle uses W/S keys.
upArrowHeld = False
downArrowHeld = False
wHeld = False
sHeld = False

first_time_run = True   # the game waits for the first key press to start
screen_running = True
game_running = False

element_log, event_log = game.get_log()

frame_id = 0
frames = [{
    'frame_id': frame_id,
    'commands': [],
    'elements': deepcopy(element_log),
    'events': [{'description': 'game_start', 'subject': 0}],
}]

date_time = datetime.now().strftime("_%Y_%m_%d_%H_%M_%S")
frame_folder = f"saved_frames/episode{date_time}"
if save_frames:
    os.makedirs(frame_folder, exist_ok=True)
    # Image files are named after the frame record they depict.
    pygame.image.save(screen, f"{frame_folder}/frame_{frame_id}.png")
frame_id += 1

try:
    while screen_running:
        for e in pygame.event.get():
            # Compare the event *type*; the Event object itself never equals QUIT.
            if e.type == pygame.QUIT:
                screen_running = False
            elif e.type == pygame.KEYDOWN:
                keys_down.append(e.key)
                if e.key == pygame.K_UP:
                    upArrowHeld = True
                elif e.key == pygame.K_DOWN:
                    downArrowHeld = True
                elif e.key == pygame.K_w:
                    wHeld = True
                elif e.key == pygame.K_s:
                    sHeld = True
            elif e.type == pygame.KEYUP:
                keys_up.append(e.key)
                if e.key == pygame.K_UP:
                    upArrowHeld = False
                elif e.key == pygame.K_DOWN:
                    downArrowHeld = False
                elif e.key == pygame.K_w:
                    wHeld = False
                elif e.key == pygame.K_s:
                    sHeld = False

        new_t = time.time()
        if new_t - t > refresh_rate:
            command_log = []

            # Any key press starts the game the first time.
            started_this_tick = False
            if first_time_run and not game_running and keys_down:
                game_running = True
                first_time_run = False
                started_this_tick = True

            if pygame.K_q in keys_down:
                screen_running = False

            # SPACE toggles pause/play, except on the tick where it started
            # the game (otherwise the first SPACE press would start and
            # immediately pause).
            if pygame.K_SPACE in keys_down and not started_this_tick:
                game_running = not game_running

            # Right paddle speed (UP/DOWN arrows); opposing keys cancel out.
            if upArrowHeld and not downArrowHeld:
                game.set_right_paddle_speed(-1)
                command_log.append('up_arrow_pressed')
            elif downArrowHeld and not upArrowHeld:
                game.set_right_paddle_speed(1)
                command_log.append('down_arrow_pressed')
            else:
                game.set_right_paddle_speed(0)

            # Left paddle speed (W/S keys); opposing keys cancel out.
            if wHeld and not sHeld:
                game.set_left_paddle_speed(-1)
                command_log.append('w_key_pressed')
            elif sHeld and not wHeld:
                game.set_left_paddle_speed(1)
                command_log.append('s_key_pressed')
            else:
                game.set_left_paddle_speed(0)

            if game_running:
                element_log, event_log, end_game = game.update()
                # The commands applied this tick produced the new frame from
                # the previous one, so they are recorded on the previous frame.
                frames[-1]['commands'].extend(command_log)
                frames.append({
                    'frame_id': frame_id,
                    'commands': [],
                    'elements': deepcopy(element_log),
                    'events': event_log,
                })

                screen = render_frame()
                if save_frames:
                    # Save under the same id as the record just appended.
                    pygame.image.save(screen, f"{frame_folder}/frame_{frame_id}.png")
                frame_id += 1

                if end_game:
                    game_running = False
                    screen_running = False

            t = new_t
            keys_down = []
            keys_up = []

        pygame.display.flip()
        # Yield the CPU briefly; otherwise the loop busy-spins between ticks.
        time.sleep(0.001)
finally:
    # Persist the log even if the loop is interrupted (Ctrl+C or an error).
    if save_log:
        os.makedirs('logs/pong_logs', exist_ok=True)
        with open(f'logs/pong_logs/pong_log{date_time}.pkl', 'wb') as logfile:
            pickle.dump(frames, logfile)
    pygame.quit()

sys.exit()
