Skip to content

Commit faa3b92

Browse files
author
Shunichi09
committed
Add model of nonlinear sample system
1 parent d64a799 commit faa3b92

File tree

3 files changed

+59
-32
lines changed

3 files changed

+59
-32
lines changed

Environments.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
| Two wheeled System (Constant Goal) | x || 3 | 2 |
77
| Two wheeled System (Moving Goal) (Coming soon) | x || 3 | 2 |
88
| Cartpole (Swing up) | x || 4 | 1 |
9+
| Nonlinear Sample System Env | x || 2 | 1 |
10+
911

1012
## [FistOrderLagEnv](PythonLinearNonlinearControl/envs/first_order_lag.py)
1113

@@ -53,4 +55,14 @@ mc = 1, mp = 0.2, l = 0.5, g = 9.81
5355

5456
### Cost.
5557

56-
<img src="assets/cartpole_score.png" width="300">
58+
<img src="assets/cartpole_score.png" width="300">
59+
60+
## [Nonlinear Sample System Env](PythonLinearNonlinearControl/envs/nonlinear_sample_system.py)
61+
62+
## System equation.
63+
64+
<img src="assets/nonlinear_sample_system.png" width="400">
65+
66+
### Cost.
67+
68+
<img src="assets/nonlinear_sample_system_score.png" width="400">

PythonLinearNonlinearControl/common/utils.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,16 @@ def fit_angle_in_range(angles, min_angle=-np.pi, max_angle=np.pi):
4646
return output.reshape(output_shape)
4747

4848

49-
def update_state_with_Runge_Kutta(state, u, functions, dt=0.01):
49+
def update_state_with_Runge_Kutta(state, u, functions, dt=0.01, batch=True):
5050
""" update state in Runge Kutta methods
5151
Args:
5252
state (array-like): state of system
5353
u (array-like): input of system
5454
functions (list): update function of each state,
55-
each function will be called like func(*state, *u)
55+
each function will be called like func(state, u)
5656
We expect that this function returns differential of each state
5757
dt (float): float in seconds
58+
batch (bool): state and u is given by batch or not
5859
5960
Returns:
6061
next_state (np.array): next state of system
@@ -68,36 +69,50 @@ def func_x(self, x_1, x_2, u):
6869
6970
Note that the function return x_dot.
7071
"""
71-
state_size = len(state)
72-
assert state_size == len(functions), \
73-
"Invalid functions length, You need to give the state size functions"
72+
if not batch:
73+
state_size = len(state)
74+
assert state_size == len(functions), \
75+
"Invalid functions length, You need to give the state size functions"
7476

75-
k0 = np.zeros(state_size)
76-
k1 = np.zeros(state_size)
77-
k2 = np.zeros(state_size)
78-
k3 = np.zeros(state_size)
77+
k0 = np.zeros(state_size)
78+
k1 = np.zeros(state_size)
79+
k2 = np.zeros(state_size)
80+
k3 = np.zeros(state_size)
7981

80-
inputs = np.concatenate([state, u])
82+
for i, func in enumerate(functions):
83+
k0[i] = dt * func(state, u)
8184

82-
for i, func in enumerate(functions):
83-
k0[i] = dt * func(*inputs)
85+
for i, func in enumerate(functions):
86+
k1[i] = dt * func(state + k0 / 2., u)
8487

85-
add_state = state + k0 / 2.
86-
inputs = np.concatenate([add_state, u])
88+
for i, func in enumerate(functions):
89+
k2[i] = dt * func(state + k1 / 2., u)
8790

88-
for i, func in enumerate(functions):
89-
k1[i] = dt * func(*inputs)
91+
for i, func in enumerate(functions):
92+
k3[i] = dt * func(state + k2, u)
9093

91-
add_state = state + k1 / 2.
92-
inputs = np.concatenate([add_state, u])
94+
return (k0 + 2. * k1 + 2. * k2 + k3) / 6.
9395

94-
for i, func in enumerate(functions):
95-
k2[i] = dt * func(*inputs)
96+
else:
97+
batch_size, state_size = state.shape
98+
assert state_size == len(functions), \
99+
"Invalid functions length, You need to give the state size functions"
96100

97-
add_state = state + k2
98-
inputs = np.concatenate([add_state, u])
101+
k0 = np.zeros(batch_size, state_size)
102+
k1 = np.zeros(batch_size, state_size)
103+
k2 = np.zeros(batch_size, state_size)
104+
k3 = np.zeros(batch_size, state_size)
99105

100-
for i, func in enumerate(functions):
101-
k3[i] = dt * func(*inputs)
106+
for i, func in enumerate(functions):
107+
k0[:, i] = dt * func(state, u)
102108

103-
return (k0 + 2. * k1 + 2. * k2 + k3) / 6.
109+
for i, func in enumerate(functions):
110+
k1[:, i] = dt * func(state + k0 / 2., u)
111+
112+
for i, func in enumerate(functions):
113+
k2[:, i] = dt * func(state + k1 / 2., u)
114+
115+
for i, func in enumerate(functions):
116+
k3[:, i] = dt * func(state + k2, u)
117+
118+
return (k0 + 2. * k1 + 2. * k2 + k3) / 6.

PythonLinearNonlinearControl/envs/nonlinear_sample_system.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ def step(self, u):
5959
self.config["input_lower_bound"],
6060
self.config["input_upper_bound"])
6161

62-
funtions = [self._func_x_1, self._func_x_2]
62+
functions = [self._func_x_1, self._func_x_2]
6363

64-
next_x = update_state_with_Runge_Kutta(self._curr_x, u,
64+
next_x = update_state_with_Runge_Kutta(self.curr_x, u,
6565
functions, self.config["dt"])
6666

6767
# cost
@@ -82,16 +82,16 @@ def step(self, u):
8282
self.step_count > self.config["max_step"], \
8383
{"goal_state": self.g_x}
8484

85-
def _func_x_1(self, x_1, x_2, u):
85+
def _func_x_1(self, x, u):
8686
"""
8787
"""
88-
x_dot = x_2
88+
x_dot = x[1]
8989
return x_dot
9090

91-
def _func_x_2(self, x_1, x_2, u):
91+
def _func_x_2(self, x, u):
9292
"""
9393
"""
94-
x_dot = (1. - x_1**2 - x_2**2) * x_2 - x_1 + u
94+
x_dot = (1. - x[0]**2 - x[1]**2) * x[1] - x[0] + u
9595
return x_dot
9696

9797
def plot_func(self, to_plot, i=None, history_x=None, history_g_x=None):

0 commit comments

Comments
 (0)