1
+ import copy
2
+ import test_function
3
+ import matplotlib .pyplot as plt
4
+ import random
5
+
6
+ class GA ():
7
+
8
+ def __init__ (self ,iterations ,n_dim ,lb ,ub ,pop_size ,target_function ,retain_rate ,random_select_rate ,mutation_probability ):
9
+ self .iterations = iterations
10
+ self .n_dim = n_dim
11
+ self .lb = lb
12
+ self .ub = ub
13
+ self .pop_size = pop_size
14
+ self .target_function = target_function
15
+ self .retain_rate = retain_rate
16
+ self .random_select_rate = random_select_rate
17
+ self .mutation_probability = mutation_probability
18
+
19
+ def particle_init (self ):
20
+ self .particle = [[] for i in range (self .pop_size )]
21
+ self .g_best = [0 for i in range (self .n_dim )]
22
+ self .g_best .append (float ('inf' ))
23
+ for i in range (self .pop_size ):
24
+ for j in range (self .n_dim ):
25
+ self .particle [i ].append (random .uniform (self .lb ,self .ub ))
26
+ self .particle [i ].append (self .target_function (self .particle [i ]))
27
+ if self .g_best [- 1 ]> self .particle [i ][- 1 ]:
28
+ self .g_best = copy .deepcopy (self .particle [i ])
29
+ self .chosen_probability = [0 for i in range (self .pop_size )]
30
+ self .calculate_chosen_probability ()
31
+
32
+ def calculate_chosen_probability (self ):
33
+ fitness = 0
34
+
35
+ for i in range (self .pop_size ):
36
+ fitness += self .particle [i ][- 1 ]
37
+ for i in range (self .pop_size ):
38
+ self .chosen_probability [i ]= self .particle [i ][- 1 ]/ fitness
39
+
40
+ def selction (self ):
41
+ fitness = []
42
+ for i in self .particle :
43
+ fitness .append (i [- 1 ])
44
+ fitness .sort (reverse = True )
45
+ retain_criteria = fitness [int (self .pop_size * self .retain_rate )]
46
+ parents = []
47
+ for i in range (self .pop_size ):
48
+ if self .particle [i ][- 1 ]<= retain_criteria or random .random ()< self .random_select_rate :
49
+ parents .append (self .particle [i ])
50
+ self .particle = copy .deepcopy (parents )
51
+
52
+ def cross (self ):
53
+ count_parents = len (self .particle )
54
+ count_cross = self .pop_size - count_parents
55
+ for i in range (count_cross ):
56
+ parent_1 = self .particle [random .randint (0 , count_parents - 1 )]
57
+ parent_2 = self .particle [random .randint (0 , count_parents - 1 )]
58
+ child = []
59
+ for j in range (self .n_dim ):
60
+ child .append (random .uniform (min (parent_1 [j ], parent_2 [j ]), max (parent_1 [j ], parent_2 [j ])))
61
+ child .append (self .target_function (child ))
62
+ self .particle .append (child )
63
+ if self .g_best [- 1 ]> self .particle [- 1 ][- 1 ]:
64
+ self .g_best = copy .deepcopy (self .particle [- 1 ])
65
+
66
+
67
+ def mutate (self ,iteration ):
68
+ for i in range (int (self .pop_size * self .retain_rate ),self .pop_size ):
69
+ if random .random ()< self .mutation_probability :
70
+ for j in range (self .n_dim ):
71
+ self .particle [i ][j ]+= random .uniform (- 0.1 ,0.1 )* (self .ub - self .lb )* (1 - iteration / self .iterations )
72
+ if self .particle [i ][j ]> self .ub :
73
+ self .particle [i ][j ]= self .ub
74
+ if self .particle [i ][j ]< self .lb :
75
+ self .particle [i ][j ]= self .lb
76
+ self .particle [i ][- 1 ]= self .target_function (self .particle [i ][:self .n_dim ])
77
+ if self .g_best [- 1 ]> self .particle [i ][- 1 ]:
78
+ self .g_best = copy .deepcopy (self .particle [i ])
79
+
80
+ def run (self ):
81
+ self .particle_init ()
82
+ self .g_best_hist = [self .g_best [- 1 ]]
83
+ for i in range (self .iterations ):
84
+ self .calculate_chosen_probability ()
85
+ self .selction ()
86
+ self .cross ()
87
+ self .mutate (i )
88
+ self .g_best_hist .append (self .g_best [- 1 ])
89
+
90
+ def result (self ):
91
+ return self .g_best
92
+
93
+ def convergence_curve (self ):
94
+ plt .plot (self .g_best_hist )
95
+ plt .yscale ('log' )
96
+ plt .show ()
97
+
98
+ if __name__ == '__main__' :
99
+
100
+ test = GA (iterations = 1000 ,n_dim = 1 ,lb = - 500 ,ub = 500 ,pop_size = 50 ,target_function = function ,retain_rate = 0.3 ,random_select_rate = 0.2 ,mutation_probability = 0.8 )
101
+ test .run ()
102
+ result = test .result ()
103
+ test .convergence_curve ()
104
+ print (result )
0 commit comments