 import torch
+from all.environments import State
 from .abstract import Agent
 
 class VPG(Agent):
@@ -16,40 +17,52 @@ def __init__(
         self.gamma = gamma
         self.n_episodes = n_episodes
         self._trajectories = []
-        self._states = None
-        self._rewards = None
+        self._features = []
+        self._rewards = []
+
+    def act(self, state, reward):
+        if not self._features:
+            return self._initial(state)
+        if not state.done:
+            return self._act(state, reward)
+        return self._terminal(reward)
 
-    def initial(self, state, info=None):
+    def _initial(self, state):
         features = self.features(state)
-        self._states = [features]
-        self._rewards = []
+        self._features = [features.features]
         return self.policy(features)
 
-    def act(self, state, reward, info=None):
+    def _act(self, state, reward):
         features = self.features(state)
-        self._states.append(features)
+        self._features.append(features.features)
         self._rewards.append(reward)
         return self.policy(features)
 
-    def terminal(self, reward, info=None):
+    def _terminal(self, reward):
         self._rewards.append(reward)
-        states = torch.cat(self._states)
-        rewards = torch.tensor(self._rewards, device=states.device)
-        self._trajectories.append((states, rewards))
+        features = torch.cat(self._features)
+        rewards = torch.tensor(self._rewards, device=features.device)
+        self._trajectories.append((features, rewards))
+        self._features = []
+        self._rewards = []
+
         if len(self._trajectories) >= self.n_episodes:
-            advantages = torch.cat([
-                self._compute_advantages(states, rewards)
-                for (states, rewards)
-                in self._trajectories
-            ])
-            self.v.reinforce(advantages, retain_graph=True)
-            self.policy.reinforce(advantages)
-            self.features.reinforce()
-            self._trajectories = []
+            self._train()
+
+    def _train(self):
+        advantages = torch.cat([
+            self._compute_advantages(features, rewards)
+            for (features, rewards)
+            in self._trajectories
+        ])
+        self.v.reinforce(advantages, retain_graph=True)
+        self.policy.reinforce(advantages)
+        self.features.reinforce()
+        self._trajectories = []
 
     def _compute_advantages(self, features, rewards):
         returns = self._compute_discounted_returns(rewards)
-        values = self.v(features)
+        values = self.v(State(features))
        return returns - values
 
     def _compute_discounted_returns(self, rewards):
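The hunk ends at the _compute_discounted_returns signature; its body is unchanged by this commit and not shown. For reference, the standard computation that name denotes is the reward-to-go G_t = r_t + gamma * G_{t+1}. A minimal sketch under that assumption (not the file's actual body):

import torch

def compute_discounted_returns(rewards, gamma):
    # rewards: 1-D float tensor of per-step rewards for one episode,
    # as built in _terminal. Accumulate G_t = r_t + gamma * G_{t+1}
    # from the last step backwards.
    returns = torch.zeros_like(rewards)
    g = 0.0
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        returns[t] = g
    return returns

The backward accumulation makes this O(n) per episode, rather than the O(n^2) of summing gamma^k * r_{t+k} separately for each timestep.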
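Taken together, the commit collapses the old initial/act/terminal callbacks into a single act entry point that dispatches on episode phase: an empty feature buffer means a fresh episode (_initial), a done state routes to _terminal, and everything else to _act. A hypothetical driver loop against this interface could look like the following; env and its reset/step API are illustrative stand-ins, not part of the commit:

def run_episode(agent, env):
    # env.reset()/env.step() are assumed to return a State whose
    # .done flag matches what agent.act() checks; they are not
    # defined in this commit.
    state = env.reset()
    reward = 0.0
    while True:
        action = agent.act(state, reward)  # dispatches to _initial/_act/_terminal
        if state.done:
            break  # _terminal has stored the trajectory (and trained if due)
        state, reward = env.step(action)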