Commit 4c4a533: Initial commit (0 parents)

File tree: 7 files changed, +383 -0 lines changed


.gitignore

Lines changed: 3 additions & 0 deletions
/__pycache__
arial.ttf
demo.mp4

agent.py

Lines changed: 144 additions & 0 deletions
import torch
import random
import numpy as np
from game import SnakeGameAI, Direction, Point
from collections import deque
from model import Linear_QNet, QTrainer
from helper import plot

MAX_MEMORY = 100_000
BATCH_SIZE = 1000
LR = 0.001

class Agent:
    def __init__(self):
        self.n_games = 0
        self.epsilon = 0  # randomness
        self.gamma = 0.9  # discount rate, must be < 1
        self.memory = deque(maxlen=MAX_MEMORY)
        self.model = Linear_QNet(11, 256, 3)
        # load the pretrained weights; a forward slash keeps the path portable
        self.model.load_state_dict(torch.load("model/model.pth"))
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)

    def get_state(self, game):
        head = game.snake[0]
        point_l = Point(head.x - 20, head.y)
        point_r = Point(head.x + 20, head.y)
        point_u = Point(head.x, head.y - 20)
        point_d = Point(head.x, head.y + 20)

        dir_l = game.direction == Direction.LEFT
        dir_r = game.direction == Direction.RIGHT
        dir_u = game.direction == Direction.UP
        dir_d = game.direction == Direction.DOWN

        state = [
            # Danger straight
            (dir_r and game.is_collision(point_r)) or
            (dir_l and game.is_collision(point_l)) or
            (dir_u and game.is_collision(point_u)) or
            (dir_d and game.is_collision(point_d)),

            # Danger right
            (dir_u and game.is_collision(point_r)) or
            (dir_d and game.is_collision(point_l)) or
            (dir_l and game.is_collision(point_u)) or
            (dir_r and game.is_collision(point_d)),

            # Danger left
            (dir_d and game.is_collision(point_r)) or
            (dir_u and game.is_collision(point_l)) or
            (dir_r and game.is_collision(point_u)) or
            (dir_l and game.is_collision(point_d)),

            # Move direction
            dir_l,
            dir_r,
            dir_u,
            dir_d,

            # Food location
            game.food.x < game.head.x,  # food left
            game.food.x > game.head.x,  # food right
            game.food.y < game.head.y,  # food up
            game.food.y > game.head.y   # food down
        ]

        return np.array(state, dtype=int)

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def train_long_memory(self):
        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, BATCH_SIZE)
        else:
            mini_sample = self.memory

        states, actions, rewards, next_states, dones = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, dones)

    def train_short_memory(self, state, action, reward, next_state, done):
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        # random moves: tradeoff exploration / exploitation
        # (the epsilon-greedy branch is left commented out, so the agent always
        #  acts greedily on the Q values of the loaded model)
        # self.epsilon = 85 - self.n_games
        final_move = [0, 0, 0]
        # if random.randint(0, 200) < self.epsilon:
        #     move = random.randint(0, 2)
        #     final_move[move] = 1
        # else:
        state0 = torch.tensor(state, dtype=torch.float)
        prediction = self.model(state0)
        move = torch.argmax(prediction).item()
        final_move[move] = 1

        return final_move


def train():
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    record = 0  # best score so far
    agent = Agent()
    game = SnakeGameAI()
    while True:
        # get old state
        state_old = agent.get_state(game)

        # get move
        final_move = agent.get_action(state_old)

        # perform move and get new state
        reward, done, score = game.play_step(final_move)
        state_new = agent.get_state(game)

        # train short memory
        agent.train_short_memory(state_old, final_move, reward, state_new, done)

        # remember
        agent.remember(state_old, final_move, reward, state_new, done)

        if done:
            # train long memory (experience replay over a sampled batch)
            game.reset()
            agent.n_games += 1
            agent.train_long_memory()

            if score > record:
                record = score
                agent.model.save()

            print('Game', agent.n_games, 'Score', score, 'Record:', record)

            plot_scores.append(score)
            total_score += score
            mean_score = total_score / agent.n_games
            plot_mean_scores.append(mean_score)
            plot(plot_scores, plot_mean_scores)


if __name__ == '__main__':
    train()

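For context on the exploration step that agent.py leaves commented out: a minimal sketch, reusing the decay (85 - n_games) and the random.randint(0, 200) threshold from the commented lines, of how action selection would mix random and greedy moves when training from scratch. The helper name get_action_with_exploration is illustrative, not part of the commit.

import random
import torch

def get_action_with_exploration(agent, state):
    # epsilon shrinks as more games are played; selection is purely greedy once it reaches 0
    agent.epsilon = 85 - agent.n_games
    final_move = [0, 0, 0]
    if random.randint(0, 200) < agent.epsilon:
        move = random.randint(0, 2)               # explore: random turn
    else:
        state0 = torch.tensor(state, dtype=torch.float)
        prediction = agent.model(state0)          # exploit: highest predicted Q value
        move = torch.argmax(prediction).item()
    final_move[move] = 1
    return final_move
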
game.py

Lines changed: 153 additions & 0 deletions
import pygame
import random
from enum import Enum
from collections import namedtuple
import numpy as np

pygame.init()
font = pygame.font.Font('arial.ttf', 25)
#font = pygame.font.SysFont('arial', 25)


# reset
# reward
# play(action) -> direction
# game_iteration
# is_collision
class Direction(Enum):
    RIGHT = 1
    LEFT = 2
    UP = 3
    DOWN = 4

Point = namedtuple('Point', 'x, y')

# rgb colors
WHITE = (255, 255, 255)
RED = (200, 0, 0)
BLUE1 = (0, 0, 255)
BLUE2 = (0, 100, 255)
BLACK = (0, 0, 0)

BLOCK_SIZE = 20
SPEED = 120

class SnakeGameAI:

    def __init__(self, w=640, h=480):
        self.w = w
        self.h = h
        # init display
        self.display = pygame.display.set_mode((self.w, self.h))
        pygame.display.set_caption('Snake')
        self.clock = pygame.time.Clock()
        self.reset()

    def reset(self):
        # init game state
        self.direction = Direction.RIGHT

        self.head = Point(self.w/2, self.h/2)
        self.snake = [self.head,
                      Point(self.head.x-BLOCK_SIZE, self.head.y),
                      Point(self.head.x-(2*BLOCK_SIZE), self.head.y)]

        self.score = 0
        self.food = None
        self._place_food()
        self.frame_iteration = 0

    def _place_food(self):
        x = random.randint(0, (self.w-BLOCK_SIZE)//BLOCK_SIZE)*BLOCK_SIZE
        y = random.randint(0, (self.h-BLOCK_SIZE)//BLOCK_SIZE)*BLOCK_SIZE
        self.food = Point(x, y)
        if self.food in self.snake:
            self._place_food()

    def play_step(self, action):
        self.frame_iteration += 1
        # 1. collect user input
        for event in pygame.event.get(pump=False):
            if event.type == pygame.QUIT:
                pygame.quit()
                quit()

        # 2. move
        self._move(action)  # update the head
        self.snake.insert(0, self.head)

        # 3. check if game over
        reward = 0
        game_over = False
        if self.is_collision() or self.frame_iteration > 100*len(self.snake):
            game_over = True
            reward = -20
            return reward, game_over, self.score

        # 4. place new food or just move
        if self.head == self.food:
            self.score += 1
            reward = 30
            self._place_food()
        else:
            self.snake.pop()

        # 5. update ui and clock
        self._update_ui()
        self.clock.tick(SPEED)
        # 6. return game over and score
        return reward, game_over, self.score

    def is_collision(self, pt=None):
        if pt is None:
            pt = self.head
        # hits boundary
        if pt.x > self.w - BLOCK_SIZE or pt.x < 0 or pt.y > self.h - BLOCK_SIZE or pt.y < 0:
            return True
        # hits itself
        if pt in self.snake[1:]:
            return True

        return False

    def _update_ui(self):
        self.display.fill(BLACK)

        for pt in self.snake:
            pygame.draw.rect(self.display, BLUE1, pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE))
            pygame.draw.rect(self.display, BLUE2, pygame.Rect(pt.x+4, pt.y+4, 12, 12))

        pygame.draw.rect(self.display, RED, pygame.Rect(self.food.x, self.food.y, BLOCK_SIZE, BLOCK_SIZE))

        text = font.render("Score: " + str(self.score), True, WHITE)
        self.display.blit(text, [0, 0])
        pygame.display.flip()

    def _move(self, action):
        # action is one-hot: [straight, right turn, left turn]

        clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
        idx = clock_wise.index(self.direction)

        if np.array_equal(action, [1, 0, 0]):
            new_dir = clock_wise[idx]        # no change
        elif np.array_equal(action, [0, 1, 0]):
            next_idx = (idx + 1) % 4
            new_dir = clock_wise[next_idx]   # right turn r -> d -> l -> u
        else:
            next_idx = (idx - 1) % 4
            new_dir = clock_wise[next_idx]   # left turn r -> u -> l -> d
        self.direction = new_dir

        x = self.head.x
        y = self.head.y
        if self.direction == Direction.RIGHT:
            x += BLOCK_SIZE
        elif self.direction == Direction.LEFT:
            x -= BLOCK_SIZE
        elif self.direction == Direction.DOWN:
            y += BLOCK_SIZE
        elif self.direction == Direction.UP:
            y -= BLOCK_SIZE

        self.head = Point(x, y)

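A quick sanity-check sketch (assumed usage, not one of the committed files) showing how SnakeGameAI.play_step consumes the one-hot [straight, right turn, left turn] action and how an episode ends, either by collision or by the 100 * len(snake) frame cap:

import random
from game import SnakeGameAI

game = SnakeGameAI()
done = False
while not done:
    action = [0, 0, 0]
    action[random.randint(0, 2)] = 1   # pick straight / right turn / left turn at random
    reward, done, score = game.play_step(action)
print('Episode finished with score', score)
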
helper.py

Lines changed: 19 additions & 0 deletions
import matplotlib.pyplot as plt
from IPython import display

plt.ion()

def plot(scores, mean_scores):
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.clf()
    plt.title('Training...')
    plt.xlabel('Number of Games')
    plt.ylabel('Score')
    plt.plot(scores)
    plt.plot(mean_scores)
    plt.ylim(bottom=0)  # lower bound at 0; bottom= replaces the removed ymin= keyword
    plt.text(len(scores)-1, scores[-1], str(scores[-1]))
    plt.text(len(mean_scores)-1, mean_scores[-1], str(mean_scores[-1]))
    plt.show(block=False)
    plt.pause(.1)

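A small usage sketch (assumed, not part of the commit): plot expects two parallel lists, the per-game scores and the running mean, exactly as train() in agent.py builds them. Since it relies on IPython.display, it is intended for an interactive or notebook session.

from helper import plot

scores = [1, 0, 3, 2]                   # score of each finished game
mean_scores = [1.0, 0.5, 1.33, 1.5]     # running average after each game (rounded)
plot(scores, mean_scores)
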
model.py

Lines changed: 64 additions & 0 deletions
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os

class Linear_QNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x

    def save(self, file_name="model2.pth"):
        model_folder_path = "./model"
        if not os.path.exists(model_folder_path):
            os.makedirs(model_folder_path)

        file_name = os.path.join(model_folder_path, file_name)
        torch.save(self.state_dict(), file_name)


class QTrainer:
    def __init__(self, model, lr, gamma):
        self.lr = lr
        self.gamma = gamma
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, done):
        state = torch.tensor(state, dtype=torch.float)
        next_state = torch.tensor(next_state, dtype=torch.float)
        action = torch.tensor(action, dtype=torch.long)
        reward = torch.tensor(reward, dtype=torch.float)

        if len(state.shape) == 1:
            # single sample: add a batch dimension of 1
            state = torch.unsqueeze(state, 0)
            next_state = torch.unsqueeze(next_state, 0)
            action = torch.unsqueeze(action, 0)
            reward = torch.unsqueeze(reward, 0)
            done = (done, )

        # 1: get predicted Q values with current state
        pred = self.model(state)

        # 2: Bellman target for the chosen action: Q_new = r + gamma * max(Q(next_state))
        target = pred.clone()
        for idx in range(len(done)):
            Q_new = reward[idx]
            if not done[idx]:
                Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))

            # index the action of this sample, not of the whole batch
            target[idx][torch.argmax(action[idx]).item()] = Q_new

        self.optimizer.zero_grad()
        loss = self.criterion(target, pred)
        loss.backward()

        self.optimizer.step()

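A minimal smoke-test sketch (assumed usage, not one of the committed files) wiring Linear_QNet and QTrainer together on one fabricated transition, using the 11-feature state and 3-way one-hot action layout from agent.py. Inside train_step the target for the chosen action is the Bellman estimate reward + gamma * max Q(next_state), unless the episode is done:

import numpy as np
from model import Linear_QNet, QTrainer

model = Linear_QNet(11, 256, 3)            # 11 state features -> 3 actions
trainer = QTrainer(model, lr=0.001, gamma=0.9)

state = np.random.randint(0, 2, 11)        # fabricated binary state vector
next_state = np.random.randint(0, 2, 11)
action = [0, 1, 0]                         # one-hot: right turn
reward = 30                                # e.g. the food reward from game.py
done = False

trainer.train_step(state, action, reward, next_state, done)
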
model/model.pth

16.4 KB
Binary file not shown.

model/model2.pth

16.4 KB
Binary file not shown.
