Skip to content

Commit 0a80e06

Browse files
committed
AI with alpha beta pruning + iterative deepening added
1 parent 52905a9 commit 0a80e06

File tree

5 files changed

+302
-20
lines changed

5 files changed

+302
-20
lines changed

src/main/java/Engine.java

+149-11
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,24 @@
1+
import commons.Color;
12
import game.Board;
23
import game.Move;
34

4-
import java.util.Arrays;
5+
import java.util.Comparator;
56
import java.util.List;
7+
import java.util.concurrent.CompletableFuture;
8+
import java.util.concurrent.ExecutionException;
9+
import java.util.concurrent.TimeUnit;
10+
import java.util.concurrent.TimeoutException;
611
import java.util.stream.Collectors;
712

813
public class Engine {
14+
public static int nodesEvaluated;
915

1016
public int countAllMoves(final Board board, final int depth) {
1117
return countAllMoves(board, depth, 1000);
1218
}
1319

1420
public int countAllMoves(final Board board, final int depth, final int printAt) {
1521
final List<Move> legalMoves = board.getLegalMoves();
16-
// if (board.fenRepresentation().equals("8/2p5/K2p4/1P4kr/1R6/6p1/4P3/8 w - -")) {
17-
// System.out.println("HERE");
18-
// System.out.println(board);
19-
// System.out.println(Arrays.deepToString(board.canCastle));
20-
// System.out.println(legalMoves.stream().map(Move::toString).collect(Collectors.joining("\n")));
21-
// final List<Move> legalMovezzz = board.getLegalMoves();
22-
// System.out.println(legalMovezzz);
23-
// }
2422
if (legalMoves.isEmpty()) {
2523
return 0;
2624
}
@@ -31,20 +29,160 @@ public int countAllMoves(final Board board, final int depth, final int printAt)
3129
final Board copy = board.copy();
3230
copy.makeMove(move);
3331
final int countAllMoves = countAllMoves(copy, depth - 1, printAt);
34-
if (depth == printAt) {// && board.fenRepresentation().equals("8/8/2pp4/1P4kr/K4pP1/8/4P3/1R6 b - -")) {
32+
if (depth == printAt) {
3533
System.out.println(getString(move) + ": " + countAllMoves + " " + move +
36-
// Arrays.deepToString(copy.canCastle) + " \n" + copy.getKing(Color.BLACK).getMoveList(copy) + " \n" +
3734
" " + copy.fenRepresentation());
3835
}
3936
return countAllMoves;
4037
}).sum();
4138
}
4239

40+
public Evaluation minMax(final Board board, final int depth, final int printAt) {
41+
final List<Move> legalMoves = board.getLegalMoves();
42+
nodesEvaluated++;
43+
if (legalMoves.isEmpty() || depth == 0) {
44+
return new Evaluation(null, -board.evaluation(legalMoves.size()));
45+
}
46+
Evaluation bestMove = null;
47+
for (Move move : legalMoves) {
48+
final Board copy = board.copy();
49+
copy.makeMove(move);
50+
final Evaluation eval = minMax(copy, depth - 1, printAt);
51+
if (depth == printAt) {
52+
System.out.println(getString(move) + ": " + eval.getScore() + " " + move +
53+
" " + copy.fenRepresentation());
54+
}
55+
if (bestMove == null || bestMove.getScore() > -eval.getScore()) {
56+
bestMove = new Evaluation(move, -eval.getScore());
57+
if (bestMove.getScore() < 0 && bestMove.getScore() + Integer.MAX_VALUE < 0.0001) {
58+
return bestMove;
59+
}
60+
}
61+
}
62+
return bestMove;
63+
}
64+
65+
public OutCome alphaBeta(final Board board, final int depth, double alpha, double beta, final int printAt) {
66+
final List<Move> legalMoves = board.getLegalMoves();
67+
nodesEvaluated++;
68+
if (legalMoves.isEmpty() || depth == 0) {
69+
return new OutCome(board, null, -board.evaluation(legalMoves.size()));
70+
}
71+
final List<OutCome> outComes = legalMoves.stream().map(move -> {
72+
final Board changedBoard = board.copy();
73+
changedBoard.makeMove(move);
74+
return new OutCome(changedBoard, move, changedBoard.evaluation());
75+
}).sorted(Comparator.comparingDouble(outCome -> -outCome.getScore())).collect(Collectors.toList());
76+
OutCome bestMove = null;
77+
for (final OutCome outCome : outComes) {
78+
final OutCome eval = alphaBeta(outCome.getBoard(), depth - 1, alpha, beta, printAt);
79+
if (depth == printAt) {
80+
System.out.println(getString(outCome.getMove()) + ": " + eval.getScore() + " " + outCome +
81+
" " + outCome.getBoard().fenRepresentation());
82+
}
83+
if (bestMove == null || bestMove.getScore() > -eval.getScore()) {
84+
bestMove = new OutCome(board, outCome.getMove(), -eval.getScore());
85+
if (bestMove.getScore() < 0 && bestMove.getScore() + Integer.MAX_VALUE < 0.0001) {
86+
return bestMove;
87+
}
88+
}
89+
if (board.playerToMove.equals(Color.WHITE)) {
90+
if (alpha < -bestMove.getScore()) {
91+
alpha = -bestMove.getScore();
92+
}
93+
} else {
94+
if (beta > bestMove.getScore()) {
95+
beta = bestMove.getScore();
96+
}
97+
}
98+
if (alpha > beta) {
99+
break;
100+
}
101+
}
102+
return bestMove;
103+
}
104+
105+
public OutCome iterativeDeepening(final Board board, final long time) {
106+
final long start = System.currentTimeMillis();
107+
int depth = 1;
108+
OutCome evaluation = alphaBeta(board, depth, Integer.MIN_VALUE, Integer.MAX_VALUE, 1000);
109+
while (System.currentTimeMillis() - start < time * 1000 && Math.abs(evaluation.getScore()) - Integer.MAX_VALUE < 0.0001) {
110+
depth++;
111+
System.out.println("DEPTH: " + depth + " EVAL: " + evaluation + " TIME: " + ((System.currentTimeMillis() - start) / 1000));
112+
try {
113+
int finalDepth = depth;
114+
evaluation = CompletableFuture.supplyAsync(() -> alphaBeta(board, finalDepth, Integer.MIN_VALUE, Integer.MAX_VALUE, 1000)).get(time * 1000 - (System.currentTimeMillis() - start), TimeUnit.MILLISECONDS);
115+
} catch (InterruptedException | ExecutionException | TimeoutException e) {
116+
System.out.println("TIMEOUT: " + depth);
117+
break;
118+
}
119+
}
120+
return evaluation;
121+
}
122+
43123
private String getString(Move move) {
44124
return move.piece.position.notation() + move.target.notation();
45125
}
46126
}
47127

128+
class Evaluation {
129+
private final Move move;
130+
private final double score;
131+
132+
public Evaluation(Move move, double score) {
133+
this.move = move;
134+
this.score = score;
135+
}
136+
137+
public double getScore() {
138+
return score;
139+
}
140+
141+
public Move getMove() {
142+
return move;
143+
}
144+
145+
@Override
146+
public String toString() {
147+
return "{" +
148+
"move=" + move +
149+
", score=" + score +
150+
'}';
151+
}
152+
}
153+
154+
class OutCome {
155+
private final Board board;
156+
private final Move move;
157+
private final double score;
158+
159+
public OutCome(Board board, Move move, double score) {
160+
this.move = move;
161+
this.score = score;
162+
this.board = board;
163+
}
164+
165+
public double getScore() {
166+
return score;
167+
}
168+
169+
public Move getMove() {
170+
return move;
171+
}
172+
173+
public Board getBoard() {
174+
return board;
175+
}
176+
177+
@Override
178+
public String toString() {
179+
return "{" +
180+
"board=\n" + board +
181+
", move=" + move +
182+
", score=" + score +
183+
'}';
184+
}
185+
}
48186
/*
49187
50188

src/main/java/game/Board.java

+9-5
Original file line numberDiff line numberDiff line change
@@ -639,10 +639,14 @@ public Board copy() {
639639
return board;
640640
}
641641

642+
public double evaluation() {
643+
return evaluation(getLegalMoves().size());
644+
}
645+
642646
//todo: idea: Should we compare a list of positions and choose the best? We do not evaluate positions in isolation,
643647
// but rather rank them by comparing the top 20 positions possible
644-
public int evaluation() {
645-
if (getLegalMoves().size() == 0) {
648+
public double evaluation(int availableMoves) {
649+
if (availableMoves == 0) {
646650
if (inCheck) {
647651
return Integer.MIN_VALUE;
648652
} else {
@@ -652,12 +656,12 @@ public int evaluation() {
652656
if (isDraw()) {
653657
return 0;
654658
}
655-
return heuristic();
659+
return heuristic() + availableMoves / 100.0;
656660
}
657661

658662
private int heuristic() {
659-
return playerPieces.get(Color.WHITE).stream().mapToInt(piece -> approxValue.get(piece.pieceType)).sum()
660-
- playerPieces.get(Color.BLACK).stream().mapToInt(piece -> approxValue.get(piece.pieceType)).sum();
663+
return playerPieces.get(playerToMove).stream().mapToInt(piece -> approxValue.get(piece.pieceType)).sum()
664+
- playerPieces.get(Color.opponent(playerToMove)).stream().mapToInt(piece -> approxValue.get(piece.pieceType)).sum();
661665
}
662666

663667
public boolean isDraw() {

src/test/java/BoardTest.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public void captureAndReturnCheck() {
4343
System.out.println(board.getLegalMoves());
4444
Assert.assertEquals(6, board.getLegalMoves().size());
4545
Assert.assertFalse(board.isDraw());
46-
Assert.assertEquals(-2, board.evaluation());
46+
Assert.assertEquals(-2, board.evaluation(board.getLegalMoves().size()));
4747
board.makeMove(board.getLegalMoves().stream().filter(c -> c.captureMove).findAny().get());
4848
System.out.println(board);
4949
System.out.println(board.getLegalMoves());
@@ -64,7 +64,7 @@ public void checkMate() {
6464
System.out.println(board);
6565
System.out.println(board.getLegalMoves());
6666
Assert.assertTrue(board.getLegalMoves().isEmpty());
67-
Assert.assertEquals(Integer.MIN_VALUE, board.evaluation());
67+
Assert.assertEquals(Integer.MIN_VALUE, board.evaluation(board.getLegalMoves().size()));
6868
}
6969

7070
@Test
@@ -77,7 +77,7 @@ public void staleMate() {
7777
System.out.println(board);
7878
System.out.println(board.getLegalMoves());
7979
Assert.assertTrue(board.getLegalMoves().isEmpty());
80-
Assert.assertEquals(0, board.evaluation());
80+
Assert.assertEquals(0, board.evaluation(board.getLegalMoves().size()));
8181
}
8282

8383
@Test
@@ -91,7 +91,7 @@ public void checkMateIn1() {
9191
System.out.println(board);
9292
System.out.println(board.getLegalMoves());
9393
Assert.assertEquals(21, board.getLegalMoves().size());
94-
Assert.assertEquals(-1, board.evaluation());
94+
Assert.assertEquals(-1, board.evaluation(board.getLegalMoves().size()));
9595
}
9696

9797
@Test(expected = NoSuchElementException.class)

src/test/java/IntenseTest.java

+22
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,28 @@ public void countMovesAtPosition4() {
1717
Assert.assertEquals(706045033, engine.countAllMoves(board, 6));
1818
}
1919

20+
@Test
21+
public void countMovesAtPosition4MinMax() {
22+
Board board = Board.getBoard("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1");
23+
board.inCheck = true;
24+
System.out.println(board);
25+
Engine engine = new Engine();
26+
Engine.nodesEvaluated = 0;
27+
final Evaluation evaluation = engine.minMax(board, 4, 4);
28+
System.out.println(evaluation + " \nNODES: " + Engine.nodesEvaluated);
29+
}
30+
31+
@Test
32+
public void countMovesAtPosition4AlphaBeta() {
33+
Board board = Board.getBoard("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1");
34+
board.inCheck = true;
35+
System.out.println(board);
36+
Engine engine = new Engine();
37+
Engine.nodesEvaluated = 0;
38+
final OutCome evaluation = engine.alphaBeta(board, 4, Integer.MIN_VALUE, Integer.MAX_VALUE, 4);
39+
System.out.println(evaluation + " \nNODES: " + Engine.nodesEvaluated);
40+
}
41+
2042
@Test
2143
public void countMovesAtPosition5() {
2244
Board board = Board.getBoard("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8");

0 commit comments

Comments
 (0)