@@ -46,15 +46,16 @@ def fit_angle_in_range(angles, min_angle=-np.pi, max_angle=np.pi):
46
46
return output .reshape (output_shape )
47
47
48
48
49
- def update_state_with_Runge_Kutta (state , u , functions , dt = 0.01 ):
49
+ def update_state_with_Runge_Kutta (state , u , functions , dt = 0.01 , batch = True ):
50
50
""" update state in Runge Kutta methods
51
51
Args:
52
52
state (array-like): state of system
53
53
u (array-like): input of system
54
54
functions (list): update function of each state,
55
- each function will be called like func(* state, * u)
55
+ each function will be called like func(state, u)
56
56
We expect that this function returns differential of each state
57
57
dt (float): float in seconds
58
+ batch (bool): state and u is given by batch or not
58
59
59
60
Returns:
60
61
next_state (np.array): next state of system
@@ -68,36 +69,50 @@ def func_x(self, x_1, x_2, u):
68
69
69
70
Note that the function return x_dot.
70
71
"""
71
- state_size = len (state )
72
- assert state_size == len (functions ), \
73
- "Invalid functions length, You need to give the state size functions"
72
+ if not batch :
73
+ state_size = len (state )
74
+ assert state_size == len (functions ), \
75
+ "Invalid functions length, You need to give the state size functions"
74
76
75
- k0 = np .zeros (state_size )
76
- k1 = np .zeros (state_size )
77
- k2 = np .zeros (state_size )
78
- k3 = np .zeros (state_size )
77
+ k0 = np .zeros (state_size )
78
+ k1 = np .zeros (state_size )
79
+ k2 = np .zeros (state_size )
80
+ k3 = np .zeros (state_size )
79
81
80
- inputs = np .concatenate ([state , u ])
82
+ for i , func in enumerate (functions ):
83
+ k0 [i ] = dt * func (state , u )
81
84
82
- for i , func in enumerate (functions ):
83
- k0 [i ] = dt * func (* inputs )
85
+ for i , func in enumerate (functions ):
86
+ k1 [i ] = dt * func (state + k0 / 2. , u )
84
87
85
- add_state = state + k0 / 2.
86
- inputs = np . concatenate ([ add_state , u ] )
88
+ for i , func in enumerate ( functions ):
89
+ k2 [ i ] = dt * func ( state + k1 / 2. , u )
87
90
88
- for i , func in enumerate (functions ):
89
- k1 [i ] = dt * func (* inputs )
91
+ for i , func in enumerate (functions ):
92
+ k3 [i ] = dt * func (state + k2 , u )
90
93
91
- add_state = state + k1 / 2.
92
- inputs = np .concatenate ([add_state , u ])
94
+ return (k0 + 2. * k1 + 2. * k2 + k3 ) / 6.
93
95
94
- for i , func in enumerate (functions ):
95
- k2 [i ] = dt * func (* inputs )
96
+ else :
97
+ batch_size , state_size = state .shape
98
+ assert state_size == len (functions ), \
99
+ "Invalid functions length, You need to give the state size functions"
96
100
97
- add_state = state + k2
98
- inputs = np .concatenate ([add_state , u ])
101
+ k0 = np .zeros (batch_size , state_size )
102
+ k1 = np .zeros (batch_size , state_size )
103
+ k2 = np .zeros (batch_size , state_size )
104
+ k3 = np .zeros (batch_size , state_size )
99
105
100
- for i , func in enumerate (functions ):
101
- k3 [ i ] = dt * func (* inputs )
106
+ for i , func in enumerate (functions ):
107
+ k0 [:, i ] = dt * func (state , u )
102
108
103
- return (k0 + 2. * k1 + 2. * k2 + k3 ) / 6.
109
+ for i , func in enumerate (functions ):
110
+ k1 [:, i ] = dt * func (state + k0 / 2. , u )
111
+
112
+ for i , func in enumerate (functions ):
113
+ k2 [:, i ] = dt * func (state + k1 / 2. , u )
114
+
115
+ for i , func in enumerate (functions ):
116
+ k3 [:, i ] = dt * func (state + k2 , u )
117
+
118
+ return (k0 + 2. * k1 + 2. * k2 + k3 ) / 6.
0 commit comments