Skip to content

Commit b149523

Browse files
committed
Added example of single shooting
1 parent 5625283 commit b149523

File tree

2 files changed

+210
-0
lines changed

2 files changed

+210
-0
lines changed

bindings/python/examples/lipm_ms.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pyopensot as pysot
55
import rclpy
66
from collections import deque
7+
from ttictoc import tic, toc
78

89
def plot_trajectory(Ns, x_value, u_value, zmp_refs, dt):
910
plt.figure(figsize=(8, 4))
@@ -120,7 +121,10 @@ def min_x(x, Q=np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
120121
stack.update()
121122
solver = pysot.iHQP(stack)
122123

124+
tic()
123125
w = solver.solve()
126+
elapsed = toc() # End timer and print elapsed time
127+
print(f"Elapsed time: {elapsed:.3f} seconds")
124128

125129
x_value = np.zeros((nx, Ns+1))
126130
u_value = np.zeros((nu, Ns))
@@ -155,6 +159,8 @@ def min_x(x, Q=np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
155159
# --- Main Loop ---
156160
t = 0
157161

162+
t_mpc = 0.
163+
158164
x0 = np.zeros((nx, 1))
159165
try:
160166
while rclpy.ok():
@@ -170,9 +176,11 @@ def min_x(x, Q=np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
170176
zry.append(zmp_tasks[0].getb()[1])
171177
zryplot = zry.popleft()
172178

179+
tic()
173180
stack.update()
174181

175182
w = solver.solve()
183+
t_mpc += toc()
176184

177185
for i in range(Ns):
178186
x_value[:, i] = variables.getVariable(f"x{i}").getValue(w)
@@ -201,6 +209,7 @@ def min_x(x, Q=np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
201209
pass
202210
finally:
203211
print("Stopping the node.")
212+
print("Average mpt time: {:.3f} seconds".format(t_mpc / t))
204213

205214
if rclpy.ok():
206215
rclpy.shutdown()
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
from pyopensot import AffineHelper, OptvarHelper, GenericTask, AggregatedTask
4+
import pyopensot as pysot
5+
import rclpy
6+
from collections import deque
7+
from ttictoc import tic, toc
8+
9+
def plot_trajectory(Ns, x_value, u_value, zmp_refs, dt):
10+
plt.figure(figsize=(8, 4))
11+
plt.plot(np.arange(Ns + 1) * dt, x_value[1, :], label='$r_y$ (CoM position)')
12+
plt.plot(np.arange(Ns) * dt, u_value[1, :], label='$z_y$ (ZMP position)')
13+
plt.plot(np.arange(Ns) * dt, zmp_refs[1, :], label='$z_y$ reference')
14+
plt.xlabel('Time [s]')
15+
plt.ylabel('Position [m]')
16+
plt.grid(True)
17+
plt.legend()
18+
plt.tight_layout()
19+
plt.show()
20+
def zmp_pattern(ns):
21+
zref = np.zeros((2, ns))
22+
for i in range(ns):
23+
zref[:, i] = np.zeros((2, 1)).flatten()
24+
if i >= 10 and i < 20:
25+
zref[:, i] = np.array([0.0, 0.1])
26+
elif i >= 20 and i < 30:
27+
zref[:, i] = np.array([0.0, -0.1])
28+
elif i >= 30:
29+
zref[:, i] = np.array([0.0, 0.])
30+
return zref
31+
32+
rclpy.init()
33+
34+
nx = 4 # com position and velocity [x, y, xdot, ydot]
35+
nu = 2 # zmp position [zmp_x, zmp_y]
36+
37+
h = 0.83 # height of the CoM
38+
w = np.sqrt(9.81 / h)
39+
40+
Ns = 40 # number of nodes
41+
tf = 3.0 # final time
42+
43+
dt = tf / Ns # time step
44+
45+
vars = list()
46+
vars.append((f"x0", nx))
47+
for i in range(Ns):
48+
vars.append((f"u{i}", nu))
49+
50+
variables = OptvarHelper(vars)
51+
print(f"variables.getSize(): {variables.getSize()}")
52+
53+
def lipm(r, z, h):
54+
w = np.sqrt(9.81 / h)
55+
return w*w*(r - z)
56+
57+
def euler(x, xdot, dt):
58+
return x + dt * xdot # x1 = x0 + dt * xdot0
59+
60+
61+
def get_state(ns, variables, h, dt):
62+
x = variables.getVariable("x0")
63+
if ns == 0:
64+
return x
65+
else:
66+
for i in range(ns):
67+
r = x[0:2]
68+
rdot = x[2:]
69+
rddot = lipm(r, variables.getVariable(f"u{i}"), h)
70+
71+
xdot = AffineHelper.pile(rdot, rddot)
72+
x = euler(x, xdot, dt)
73+
return x
74+
75+
def initial_state_constraint(x0, value):
76+
tmp = x0 + value
77+
return GenericTask("initial_state", tmp.getM(), tmp.getq())
78+
79+
def min_u(u, R=np.array([[1, 0], [0, 1]])):
80+
T = GenericTask("zmp_tracking", u.getM(), u.getq())
81+
T.setWeight(R)
82+
return T
83+
84+
initial_state = initial_state_constraint(variables.getVariable("x0"), np.array([0., 0., 0, 0.]))
85+
final_state = get_state(Ns, variables, h, dt)
86+
min_rdot_final = GenericTask("min_rdot_final", final_state[2:].getM(), final_state[2:].getq())
87+
88+
zmp_tasks = list()
89+
for i in range(Ns):
90+
zmp_tasks.append(min_u(variables.getVariable(f"u{i}"), R=1e1 * np.array([[1, 0], [0, 1]])))
91+
zmp_tracking_task = AggregatedTask(zmp_tasks, variables.getSize())
92+
93+
# Create the stack
94+
cost = zmp_tracking_task + min_rdot_final
95+
constraints = initial_state
96+
97+
zmp_refs = zmp_pattern(Ns)
98+
for i in range(Ns):
99+
zmp_tasks[i].setb(zmp_refs[:, i])
100+
zmp_tasks[i].update()
101+
102+
# 1. Trajectory Optimization
103+
stack = pysot.AutoStack(cost) << constraints
104+
stack.update()
105+
solver = pysot.iHQP(stack)
106+
107+
tic()
108+
w = solver.solve()
109+
elapsed = toc() # End timer and print elapsed time
110+
print(f"Elapsed time: {elapsed:.3f} seconds")
111+
112+
x_value = np.zeros((nx, Ns+1))
113+
u_value = np.zeros((nu, Ns))
114+
115+
for i in range(Ns):
116+
u_value[:, i] = variables.getVariable(f"u{i}").getValue(w)
117+
x_value[:, i] = get_state(i+1, variables, h, dt).getValue(w)
118+
x_value[:, Ns] = get_state(Ns, variables, h, dt).getValue(w)
119+
120+
# Plot
121+
plot_trajectory(Ns, x_value, u_value, zmp_refs, dt)
122+
123+
# 2. MPC
124+
# zeroing references
125+
for i in range(Ns):
126+
zmp_tasks[i].setb(np.zeros((2, 1)))
127+
128+
stack.update()
129+
130+
# --- Prepare plot ---
131+
ry = deque([0.]*100)
132+
zy = deque([0.]*100)
133+
zry = deque([0.]*100)
134+
135+
plt.ion()
136+
line1, = plt.plot(ry, label='$r_y$ (CoM position)')
137+
line2, = plt.plot(zy, label='$z_y$ (ZMP position)')
138+
line3, = plt.plot(zry, label='$zr_y$ (ZMP ref)')
139+
plt.ylim([-0.5, 0.5])
140+
plt.show()
141+
# --- Main Loop ---
142+
t = 0
143+
144+
t_mpc = 0.
145+
146+
x0 = np.zeros((nx, 1))
147+
try:
148+
while rclpy.ok():
149+
ry.append(x0[1,0])
150+
ryplot = ry.popleft()
151+
152+
initial_state.setb(x0)
153+
154+
# shift reference to left
155+
for j in range(1, Ns):
156+
zmp_tasks[j-1].setb(zmp_tasks[j].getb())
157+
zmp_tasks[Ns-1].setb(zmp_refs[:, t % Ns])
158+
zry.append(zmp_tasks[0].getb()[1])
159+
zryplot = zry.popleft()
160+
161+
tic()
162+
stack.update()
163+
164+
w = solver.solve()
165+
t_mpc += toc()
166+
167+
for i in range(Ns):
168+
u_value[:, i] = variables.getVariable(f"u{i}").getValue(w)
169+
x_value[:, i] = get_state(i + 1, variables, h, dt).getValue(w)
170+
x_value[:, Ns] = get_state(Ns, variables, h, dt).getValue(w)
171+
172+
rdot = x_value[2:4, 0].flatten()
173+
rddot = lipm(x_value[0:2, 0], u_value[:, 0], h)
174+
175+
xdot = np.hstack((rdot, rddot)).reshape((nx, 1))
176+
x0 += dt * xdot
177+
178+
t += 1
179+
zy.append(u_value[1, 0])
180+
zyplot = zy.popleft()
181+
182+
line1.set_ydata(ry)
183+
line2.set_ydata(zy)
184+
line3.set_ydata(zry)
185+
plt.draw()
186+
187+
plt.pause(dt)
188+
189+
except KeyboardInterrupt:
190+
print("KeyboardInterrupt: Stopping the node.")
191+
pass
192+
finally:
193+
print("Stopping the node.")
194+
print("Average mpt time: {:.3f} seconds".format(t_mpc / t))
195+
196+
if rclpy.ok():
197+
rclpy.shutdown()
198+
199+
200+
201+

0 commit comments

Comments
 (0)