import torch
from all.memory import GeneralizedAdvantageBuffer
from ._agent import Agent


class PPO(Agent):
    '''
    Proximal Policy Optimization (PPO).

    Args:
        features: Shared feature network used by both the value function and the policy.
        v: The state-value function approximation.
        policy: The stochastic policy approximation.
        epsilon (float): Clipping parameter for the PPO surrogate objective.
        epochs (int): Number of optimization passes over each collected batch.
        minibatches (int): Number of minibatches each batch is split into.
        n_envs (int): Number of parallel environments (must be specified).
        n_steps (int): Steps collected per environment before each update.
        discount_factor (float): Discount rate (gamma) applied to future rewards.
        lam (float): The lambda parameter used by generalized advantage estimation (GAE).
    '''
    def __init__(
            self,
            features,
            v,
            policy,
            epsilon=0.2,
            epochs=4,
            minibatches=4,
            n_envs=None,
            n_steps=4,
            discount_factor=0.99,
            lam=0.95
    ):
        if n_envs is None:
            raise RuntimeError("Must specify n_envs.")
        self.features = features
        self.v = v
        self.policy = policy
        self.n_envs = n_envs
        self.n_steps = n_steps
        self.discount_factor = discount_factor
        self.lam = lam
        self._epsilon = epsilon
        self._epochs = epochs
        self._batch_size = n_envs * n_steps
        self._minibatches = minibatches
        self._buffer = self._make_buffer()
        self._features = []

    def act(self, states, rewards):
        # train (if a full batch has been collected), then choose
        # actions for the current states and store the transition
        self._train()
        actions = self.policy.eval(self.features.eval(states))
        self._buffer.store(states, actions, rewards)
        return actions

    def _train(self):
        if len(self._buffer) >= self._batch_size:
            # sample a full batch along with GAE-computed advantages
            states, actions, advantages = self._buffer.sample(self._batch_size)
            # compute the log-probabilities of the data-collecting policy and
            # the value targets once, before any updates are applied
            with torch.no_grad():
                features = self.features.eval(states)
                pi_0 = self.policy.eval(features, actions)
                targets = self.v.eval(features) + advantages
            for _ in range(self._epochs):
                self._train_epoch(states, actions, pi_0, advantages, targets)

    def _train_epoch(self, states, actions, pi_0, advantages, targets):
        # shuffle the batch and update on each minibatch in turn
        minibatch_size = self._batch_size // self._minibatches
        indexes = torch.randperm(self._batch_size)
        for n in range(self._minibatches):
            first = n * minibatch_size
            last = first + minibatch_size
            i = indexes[first:last]
            self._train_minibatch(states[i], actions[i], pi_0[i], advantages[i], targets[i])

    def _train_minibatch(self, states, actions, pi_0, advantages, targets):
        # update the policy, the value function, and the shared feature network
        features = self.features(states)
        self.policy(features, actions)
        self.policy.reinforce(self._compute_policy_loss(pi_0, advantages))
        self.v.reinforce(targets - self.v(features))
        self.features.reinforce()

    def _compute_targets(self, returns, next_states, lengths):
        # n-step bootstrapped value targets
        return (
            returns +
            (self.discount_factor ** lengths)
            * self.v.eval(self.features.eval(next_states))
        )

    def _compute_policy_loss(self, pi_0, advantages):
        # clipped surrogate objective:
        # -E[min(r * A, clip(r, 1 - epsilon, 1 + epsilon) * A)],
        # where r = exp(pi_i - pi_0) is the probability ratio between the
        # updated policy and the policy that collected the batch
        def _policy_loss(pi_i):
            ratios = torch.exp(pi_i - pi_0)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1.0 - self._epsilon, 1.0 + self._epsilon) * advantages
            return -torch.min(surr1, surr2).mean()
        return _policy_loss

    def _make_buffer(self):
        return GeneralizedAdvantageBuffer(
            self.v,
            self.features,
            self.n_steps,
            self.n_envs,
            discount_factor=self.discount_factor,
            lam=self.lam
        )
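# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library source). The constructors
# below are assumptions about how the surrounding package wires up its
# approximation and policy objects; exact class names and signatures may
# differ, so treat this purely as an orientation aid.
#
#     from all.approximation import FeatureNetwork, VNetwork
#     from all.policies import SoftmaxPolicy
#
#     features = FeatureNetwork(feature_model, feature_optimizer)
#     v = VNetwork(value_model, value_optimizer)
#     policy = SoftmaxPolicy(policy_model, policy_optimizer, n_actions)
#     agent = PPO(features, v, policy, n_envs=8, n_steps=128)
#
# On each environment step, agent.act(states, rewards) stores the transition,
# runs a PPO update once n_envs * n_steps samples have accumulated, and
# returns the actions for the next step.
# ---------------------------------------------------------------------------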