Update app.py
Browse files
app.py
CHANGED
|
@@ -1,8 +1,3 @@
|
|
| 1 |
-
# app.py
|
| 2 |
-
# =============
|
| 3 |
-
# This is a complete app.py file for an Arkanoid game that a neural network learns to play through reinforcement learning.
|
| 4 |
-
# The game is built using pygame, and the neural network is trained using stable-baselines3. Gradio is used for the interface.
|
| 5 |
-
|
| 6 |
import os
|
| 7 |
import numpy as np
|
| 8 |
import pygame
|
|
@@ -12,7 +7,6 @@ from stable_baselines3 import DQN
|
|
| 12 |
from stable_baselines3.common.evaluation import evaluate_policy
|
| 13 |
import gradio as gr
|
| 14 |
import cv2
|
| 15 |
-
import imageio
|
| 16 |
|
| 17 |
# Constants
|
| 18 |
SCREEN_WIDTH = 640
|
|
@@ -68,7 +62,7 @@ class Ball:
|
|
| 68 |
|
| 69 |
class Brick:
|
| 70 |
def __init__(self, x, y):
|
| 71 |
-
self.rect = pygame.Rect(x, y, BRICK_WIDTH, BRICK_HEIGHT)
|
| 72 |
|
| 73 |
class ArkanoidEnv(gym.Env):
|
| 74 |
def __init__(self):
|
|
@@ -85,7 +79,8 @@ class ArkanoidEnv(gym.Env):
|
|
| 85 |
self.seed_value = seed
|
| 86 |
self.paddle = Paddle()
|
| 87 |
self.ball = Ball()
|
| 88 |
-
self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
|
|
|
|
| 89 |
self.done = False
|
| 90 |
self.score = 0
|
| 91 |
return self._get_state(), {}
|
|
@@ -172,27 +167,27 @@ def train_and_play():
|
|
| 172 |
|
| 173 |
for i in range(0, total_timesteps, timesteps_per_update):
|
| 174 |
model.learn(total_timesteps=timesteps_per_update)
|
| 175 |
-
obs = env.reset()
|
| 176 |
done = False
|
| 177 |
truncated = False
|
| 178 |
while not done and not truncated:
|
| 179 |
action, _states = model.predict(obs, deterministic=True)
|
| 180 |
-
obs, reward, done, truncated,
|
| 181 |
env.render()
|
| 182 |
# Capture the current frame
|
| 183 |
frame = pygame.surfarray.array3d(pygame.display.get_surface())
|
|
|
|
| 184 |
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
| 185 |
video_frames.append(frame)
|
| 186 |
|
| 187 |
# Save the video
|
| 188 |
video_path = "arkanoid_training.mp4"
|
| 189 |
-
|
| 190 |
-
video_writer = cv2.VideoWriter(video_path, fourcc, FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
|
| 191 |
for frame in video_frames:
|
| 192 |
video_writer.write(frame)
|
| 193 |
video_writer.release()
|
| 194 |
|
| 195 |
-
#
|
| 196 |
return video_path
|
| 197 |
|
| 198 |
# Main function
|
|
@@ -208,17 +203,3 @@ def main():
|
|
| 208 |
|
| 209 |
if __name__ == "__main__":
|
| 210 |
main()
|
| 211 |
-
|
| 212 |
-
# Dependencies
|
| 213 |
-
# =============
|
| 214 |
-
# The following dependencies are required to run this app:
|
| 215 |
-
# - pygame
|
| 216 |
-
# - stable-baselines3
|
| 217 |
-
# - torch
|
| 218 |
-
# - gradio
|
| 219 |
-
# - gymnasium
|
| 220 |
-
# - opencv-python
|
| 221 |
-
# - imageio
|
| 222 |
-
#
|
| 223 |
-
# You can install these dependencies using pip:
|
| 224 |
-
# pip install pygame stable-baselines3 torch gradio gymnasium opencv-python imageio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import numpy as np
|
| 3 |
import pygame
|
|
|
|
| 7 |
from stable_baselines3.common.evaluation import evaluate_policy
|
| 8 |
import gradio as gr
|
| 9 |
import cv2
|
|
|
|
| 10 |
|
| 11 |
# Constants
|
| 12 |
SCREEN_WIDTH = 640
|
|
|
|
| 62 |
|
| 63 |
class Brick:
|
| 64 |
def __init__(self, x, y):
|
| 65 |
+
self.rect = pygame.Rect(x, y, BRICK_WIDTH - 5, BRICK_HEIGHT - 5)
|
| 66 |
|
| 67 |
class ArkanoidEnv(gym.Env):
|
| 68 |
def __init__(self):
|
|
|
|
| 79 |
self.seed_value = seed
|
| 80 |
self.paddle = Paddle()
|
| 81 |
self.ball = Ball()
|
| 82 |
+
self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
|
| 83 |
+
for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)]
|
| 84 |
self.done = False
|
| 85 |
self.score = 0
|
| 86 |
return self._get_state(), {}
|
|
|
|
| 167 |
|
| 168 |
for i in range(0, total_timesteps, timesteps_per_update):
|
| 169 |
model.learn(total_timesteps=timesteps_per_update)
|
| 170 |
+
obs, _ = env.reset()
|
| 171 |
done = False
|
| 172 |
truncated = False
|
| 173 |
while not done and not truncated:
|
| 174 |
action, _states = model.predict(obs, deterministic=True)
|
| 175 |
+
obs, reward, done, truncated, _ = env.step(action)
|
| 176 |
env.render()
|
| 177 |
# Capture the current frame
|
| 178 |
frame = pygame.surfarray.array3d(pygame.display.get_surface())
|
| 179 |
+
frame = np.rot90(frame) # Fix orientation
|
| 180 |
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
| 181 |
video_frames.append(frame)
|
| 182 |
|
| 183 |
# Save the video
|
| 184 |
video_path = "arkanoid_training.mp4"
|
| 185 |
+
video_writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
|
|
|
|
| 186 |
for frame in video_frames:
|
| 187 |
video_writer.write(frame)
|
| 188 |
video_writer.release()
|
| 189 |
|
| 190 |
+
env.close() # Ensure the environment is properly closed
|
| 191 |
return video_path
|
| 192 |
|
| 193 |
# Main function
|
|
|
|
| 203 |
|
| 204 |
if __name__ == "__main__":
|
| 205 |
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|