@@ -1,5 +1,5 @@
 import torch
-from all.experiments import DummyWriter
+from all.logging import DummyWriter
 from ._agent import Agent
 
 class SAC(Agent):
@@ -9,7 +9,9 @@ def __init__(self,
             q_2,
             v,
             replay_buffer,
-            entropy_regularizer=0.01,
+            entropy_target=-2.,  # usually -action_space.size[0]
+            temperature_initial=0.1,
+            lr_temperature=1e-4,
             discount_factor=0.99,
             minibatch_size=32,
             replay_start_size=5000,
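
The default entropy_target of -2. follows the usual SAC heuristic of the negative action dimensionality, as the inline comment notes. A hedged sketch of deriving it from a Gym-style environment (the env name is illustrative, and with a standard Gym Box space the equivalent of the comment's action_space.size[0] is action_space.shape[0]):

    import gym

    env = gym.make('LunarLanderContinuous-v2')   # any env with a 2-D continuous action space
    entropy_target = -env.action_space.shape[0]  # -2.0, matching the new default
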
@@ -28,7 +30,10 @@ def __init__(self,
         self.update_frequency = update_frequency
         self.minibatch_size = minibatch_size
         self.discount_factor = discount_factor
-        self.entropy_regularizer = entropy_regularizer
+        # vars for learning the temperature
+        self.entropy_target = entropy_target
+        self.temperature = temperature_initial
+        self.lr_temperature = lr_temperature
         # data
         self.env = None
         self.state = None
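
With these fields in place, the temperature starts as a plain Python float that the agent tunes itself during training. A hedged construction sketch (the approximation objects are placeholders for the library's policy/Q/V wrappers, not defined in this commit; env reuses the sketch above):

    agent = SAC(
        policy, q_1, q_2, v, replay_buffer,
        entropy_target=-env.action_space.shape[0],
        temperature_initial=0.1,
        lr_temperature=1e-4,
    )
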
@@ -39,8 +44,7 @@ def act(self, state, reward):
         self._store_transition(state, reward)
         self._train()
         self.state = state
-        with torch.no_grad():
-            self.action = self.policy(state)
+        self.action = self.policy.eval(state)
         return self.action
 
     def _store_transition(self, state, reward):
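
The explicit torch.no_grad() block is replaced by the approximation's .eval() helper, which presumably performs a gradient-free forward pass. A hedged sketch of what such a helper might look like (an assumption for illustration, not code from this commit):

    def eval(self, *inputs):
        # forward pass without building the autograd graph
        with torch.no_grad():
            return self.model(*inputs)
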
@@ -58,14 +62,17 @@ def _train(self):
         # compute targets for Q and V
         with torch.no_grad():
             _actions, _log_probs = self.policy(states, log_prob=True)
-            q_targets = rewards + self.discount_factor * self.v.eval(next_states)
+            q_targets = rewards + self.discount_factor * self.v.target(next_states)
             v_targets = torch.min(
-                self.q_1.eval(states, _actions),
-                self.q_2.eval(states, _actions),
-            ) - self.entropy_regularizer * _log_probs
+                self.q_1.target(states, _actions),
+                self.q_2.target(states, _actions),
+            ) - self.temperature * _log_probs
+            temperature_loss = (_log_probs + self.entropy_target).detach().mean()
             self.writer.add_loss('entropy', -_log_probs.mean())
             self.writer.add_loss('v_mean', v_targets.mean())
             self.writer.add_loss('r_mean', rewards.mean())
+            self.writer.add_loss('temperature_loss', temperature_loss)
+            self.writer.add_loss('temperature', self.temperature)
 
         # update Q-functions
         q_1_errors = q_targets - self.q_1(states, actions)
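
The temperature loss is mean(log_probs) + entropy_target, i.e. the entropy target minus the policy's entropy estimate. A toy sign check (numbers are illustrative, not from the commit):

    import torch

    _log_probs = torch.tensor([-3.0, -2.5])  # entropy estimate = -mean(log_probs) = 2.75
    entropy_target = -2.0
    temperature_loss = (_log_probs + entropy_target).mean()  # -4.75

    # Entropy (2.75) exceeds the target (-2.0), so the loss is negative and the
    # update at the end of _train() will shrink the temperature, weakening the
    # entropy bonus until entropy falls back toward the target.
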
@@ -79,15 +86,17 @@ def _train(self):
 
         # train policy
         _actions, _log_probs = self.policy(states, log_prob=True)
-
         loss = -(
             self.q_1(states, _actions, detach=False)
-            - self.entropy_regularizer * _log_probs
+            - self.temperature * _log_probs
         ).mean()
         loss.backward()
         self.policy.step()
         self.q_1.zero_grad()
 
+        # adjust temperature
+        self.temperature += self.lr_temperature * temperature_loss
+
     def _should_train(self):
         return (self.frames_seen > self.replay_start_size and
                 self.frames_seen % self.update_frequency == 0)
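
The update above is a direct first-order step on a scalar, so no optimizer state is needed. The SAC paper's automatic temperature tuning instead optimizes log(alpha) with a gradient-based optimizer; a hedged sketch of that alternative (names and the Adam learning rate are illustrative):

    import torch

    log_temperature = torch.zeros(1, requires_grad=True)
    temperature_optimizer = torch.optim.Adam([log_temperature], lr=1e-4)

    def update_temperature(_log_probs, entropy_target):
        # J(alpha) = E[-alpha * (log pi(a|s) + target entropy)]
        loss = -(log_temperature.exp() * (_log_probs + entropy_target).detach()).mean()
        temperature_optimizer.zero_grad()
        loss.backward()
        temperature_optimizer.step()
        return log_temperature.exp().item()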