
Commit e4fbe6e
Merge pull request #97 from cpnota/release/0.3.0
Release/0.3.0
2 parents c9c85ef + 28f1e81


74 files changed (+924, -576 lines)

Makefile

Lines changed: 3 additions & 2 deletions

@@ -1,6 +1,7 @@
 install:
-	pip install https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp37-cp37m-linux_x86_64.whl
-	pip install torchvision tensorflow
+	pip install https://download.pytorch.org/whl/cu100/torch-1.1.0-cp37-cp37m-linux_x86_64.whl
+	pip install https://download.pytorch.org/whl/cu100/torchvision-0.3.0-cp37-cp37m-linux_x86_64.whl
+	pip install tensorflow
 	pip install -e .

 lint:

README.md

Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@ We provide out-of-the-box modules for:
 - [x] Generalized Advantage Estimation (GAE)
 - [x] Target networks
 - [x] Polyak averaging
+- [x] Easy parameter and learning rate scheduling
 - [x] An enhanced `nn` module (includes dueling layers, noisy layers, action bounds, and the coveted `nn.Flatten`)
 - [x] `gym` to `pytorch` wrappers
 - [x] Atari wrappers

all/agents/_agent.py

Lines changed: 2 additions & 1 deletion

@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
+from all.optim import Schedulable

-class Agent(ABC):
+class Agent(ABC, Schedulable):
     """
     A reinforcement learning agent.
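The new Schedulable mixin is what backs the "easy parameter and learning rate scheduling" item added to the README above. A minimal sketch of how such a mixin can work, assuming a scheduler object that advances each time its attribute is read; the names and mechanism here are illustrative guesses, not the actual all.optim API:

```python
# Illustrative sketch only: LinearScheduler and the attribute hook below are
# assumptions, not the real all.optim implementation.

class LinearScheduler:
    """Anneals a scalar hyperparameter linearly over a fixed number of reads."""
    def __init__(self, start, end, steps):
        self._start, self._end, self._steps = start, end, steps
        self._i = 0

    def value(self):
        t = min(self._i / self._steps, 1.0)
        self._i += 1
        return self._start + t * (self._end - self._start)

class Schedulable:
    """Mixin: reading an attribute that holds a scheduler yields its current value."""
    def __getattribute__(self, name):
        value = object.__getattribute__(self, name)
        if isinstance(value, LinearScheduler):
            return value.value()
        return value

class Agent(Schedulable):
    def __init__(self):
        self.epsilon = LinearScheduler(1.0, 0.1, steps=3)

agent = Agent()
print([agent.epsilon for _ in range(5)])  # ~[1.0, 0.7, 0.4, 0.1, 0.1]
```

With this in place, any hyperparameter an agent reads as a plain attribute (epsilon, learning rates, and so on) can be annealed without changing the agent's code.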

all/agents/a2c.py

Lines changed: 18 additions & 12 deletions

@@ -1,5 +1,3 @@
-import torch
-from all.environments import State
 from all.memory import NStepAdvantageBuffer
 from ._agent import Agent

@@ -22,26 +20,34 @@ def __init__(
         self.n_envs = n_envs
         self.n_steps = n_steps
         self.discount_factor = discount_factor
+        self._states = None
+        self._actions = None
         self._batch_size = n_envs * n_steps
         self._buffer = self._make_buffer()
         self._features = []

     def act(self, states, rewards):
-        self._buffer.store(states, torch.zeros(self.n_envs), rewards)
-        self._train()
-        features = self.features(states)
-        self._features.append(features)
-        return self.policy(features)
+        self._store_transitions(rewards)
+        self._train(states)
+        self._states = states
+        self._actions = self.policy.eval(self.features.eval(states))
+        return self._actions

-    def _train(self):
+    def _store_transitions(self, rewards):
+        if self._states:
+            self._buffer.store(self._states, self._actions, rewards)
+
+    def _train(self, states):
         if len(self._buffer) >= self._batch_size:
-            states = State.from_list(self._features)
-            _, _, advantages = self._buffer.sample(self._batch_size)
-            self.v(states)
+            states, actions, advantages = self._buffer.advantages(states)
+            # forward pass
+            features = self.features(states)
+            self.v(features)
+            self.policy(features, actions)
+            # backward pass
             self.v.reinforce(advantages)
             self.policy.reinforce(advantages)
             self.features.reinforce()
-            self._features = []

     def _make_buffer(self):
         return NStepAdvantageBuffer(
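The rewritten act defers storing each (state, action) pair until the next reward arrives, so actions and rewards line up correctly in the buffer, and _train now asks the buffer directly for advantages. For intuition, here is a rough single-environment sketch of the n-step advantage computation a buffer like NStepAdvantageBuffer performs; the shapes and bootstrap handling are simplified assumptions, not the library's code:

```python
import torch

def n_step_advantages(rewards, values, next_value, discount_factor=0.99):
    """A_t = sum_k gamma^k * r_{t+k} + gamma^(n-t) * V(s_n) - V(s_t)."""
    n_steps = rewards.shape[0]
    returns = torch.zeros(n_steps)
    bootstrap = next_value                     # V(s_n) from the critic
    for t in reversed(range(n_steps)):
        bootstrap = rewards[t] + discount_factor * bootstrap
        returns[t] = bootstrap
    return returns - values

advantages = n_step_advantages(
    rewards=torch.tensor([1., 0., 1.]),
    values=torch.tensor([0.5, 0.5, 0.5]),
    next_value=0.4,
)
print(advantages)  # tensor([1.8682, 0.8820, 0.8960])
```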

all/agents/ddpg.py

Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ def _train(self):
             # train q function
             td_errors = (
                 rewards +
-                self.discount_factor * self.q.eval(next_states, self.policy.eval(next_states)) -
+                self.discount_factor * self.q.target(next_states, self.policy.target(next_states)) -
                 self.q(states, torch.cat(actions))
             )
             self.q.reinforce(weights * td_errors)
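Switching from eval to target routes the bootstrapped term through the Polyak-averaged target networks instead of the live ones, which keeps the regression target from chasing itself. A minimal sketch of the Polyak update itself, written against plain torch.nn modules rather than the library's Approximation wrappers:

```python
import torch

def polyak_update(target_net, online_net, rho=0.995):
    """Soft update: theta_target <- rho * theta_target + (1 - rho) * theta_online."""
    with torch.no_grad():
        for p_target, p_online in zip(target_net.parameters(), online_net.parameters()):
            p_target.mul_(rho).add_(p_online, alpha=1 - rho)
```

Called after every optimization step, this keeps the target network a slowly trailing copy of the online one.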

all/agents/dqn.py

Lines changed: 1 addition & 1 deletion

@@ -45,7 +45,7 @@ def _train(self):
                 self.minibatch_size)
             td_errors = (
                 rewards +
-                self.discount_factor * torch.max(self.q.eval(next_states), dim=1)[0] -
+                self.discount_factor * torch.max(self.q.target(next_states), dim=1)[0] -
                 self.q(states, actions)
             )
             self.q.reinforce(weights * td_errors)
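The same eval-to-target switch as in DDPG. As a worked example of the TD error being computed here, r + gamma * max_a Q_target(s', a) - Q(s, a), with hand-picked numbers (shapes assumed; not the repo's code):

```python
import torch

rewards = torch.tensor([1.0, 0.0])          # batch of two transitions
q_next = torch.tensor([[0.2, 0.8],          # Q_target(s', .) per action
                       [0.5, 0.1]])
q_taken = torch.tensor([0.6, 0.3])          # Q(s, a) for the actions taken
discount_factor = 0.99

td_errors = rewards + discount_factor * torch.max(q_next, dim=1)[0] - q_taken
print(td_errors)  # tensor([1.1920, 0.1950])
```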

all/agents/evaluation/greedy_agent.py

Whitespace-only changes.

all/agents/ppo.py

Lines changed: 15 additions & 14 deletions

@@ -26,6 +26,8 @@ def __init__(
         self.n_steps = n_steps
         self.discount_factor = discount_factor
         self.lam = lam
+        self._states = None
+        self._actions = None
         self._epsilon = epsilon
         self._epochs = epochs
         self._batch_size = n_envs * n_steps
@@ -34,14 +36,19 @@ def __init__(
         self._features = []

     def act(self, states, rewards):
-        self._train()
-        actions = self.policy.eval(self.features.eval(states))
-        self._buffer.store(states, actions, rewards)
-        return actions
+        self._store_transitions(rewards)
+        self._train(states)
+        self._states = states
+        self._actions = self.policy.eval(self.features.eval(states))
+        return self._actions

-    def _train(self):
+    def _store_transitions(self, rewards):
+        if self._states:
+            self._buffer.store(self._states, self._actions, rewards)
+
+    def _train(self, _states):
         if len(self._buffer) >= self._batch_size:
-            states, actions, advantages = self._buffer.sample(self._batch_size)
+            states, actions, advantages = self._buffer.advantages(_states)
             with torch.no_grad():
                 features = self.features.eval(states)
                 pi_0 = self.policy.eval(features, actions)
@@ -65,18 +72,12 @@ def _train_minibatch(self, states, actions, pi_0, advantages, targets):
         self.v.reinforce(targets - self.v(features))
         self.features.reinforce()

-    def _compute_targets(self, returns, next_states, lengths):
-        return (
-            returns +
-            (self.discount_factor ** lengths)
-            * self.v.eval(self.features.eval(next_states))
-        )
-
     def _compute_policy_loss(self, pi_0, advantages):
         def _policy_loss(pi_i):
             ratios = torch.exp(pi_i - pi_0)
             surr1 = ratios * advantages
-            surr2 = torch.clamp(ratios, 1.0 - self._epsilon, 1.0 + self._epsilon) * advantages
+            epsilon = self._epsilon
+            surr2 = torch.clamp(ratios, 1.0 - epsilon, 1.0 + epsilon) * advantages
             return -torch.min(surr1, surr2).mean()
         return _policy_loss
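The clipping in _compute_policy_loss is the heart of PPO: the probability ratio is clamped to [1 - epsilon, 1 + epsilon], so no single update can move the policy too far from the one that collected the data. A self-contained version of that standard clipped surrogate loss, for reference:

```python
import torch

def ppo_clip_loss(log_pi, log_pi_0, advantages, epsilon=0.2):
    """Standard PPO clipped surrogate objective (returned as a loss to minimize)."""
    ratios = torch.exp(log_pi - log_pi_0)
    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1.0 - epsilon, 1.0 + epsilon) * advantages
    return -torch.min(surr1, surr2).mean()

# A ratio of ~2.0 is clipped to 1.2 when the advantage is positive:
loss = ppo_clip_loss(
    log_pi=torch.tensor([0.0, 0.69]),    # new log-probabilities
    log_pi_0=torch.tensor([0.0, 0.0]),   # log-probabilities at collection time
    advantages=torch.tensor([1.0, 1.0]),
)
print(loss)  # ~ -1.1 = -(1.0 + 1.2) / 2
```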

all/agents/sac.py

Lines changed: 20 additions & 11 deletions

@@ -1,5 +1,5 @@
 import torch
-from all.experiments import DummyWriter
+from all.logging import DummyWriter
 from ._agent import Agent

 class SAC(Agent):
@@ -9,7 +9,9 @@ def __init__(self,
                  q_2,
                  v,
                  replay_buffer,
-                 entropy_regularizer=0.01,
+                 entropy_target=-2.,  # usually -action_space.size[0]
+                 temperature_initial=0.1,
+                 lr_temperature=1e-4,
                  discount_factor=0.99,
                  minibatch_size=32,
                  replay_start_size=5000,
@@ -28,7 +30,10 @@ def __init__(self,
         self.update_frequency = update_frequency
         self.minibatch_size = minibatch_size
         self.discount_factor = discount_factor
-        self.entropy_regularizer = entropy_regularizer
+        # vars for learning the temperature
+        self.entropy_target = entropy_target
+        self.temperature = temperature_initial
+        self.lr_temperature = lr_temperature
         # data
         self.env = None
         self.state = None
@@ -39,8 +44,7 @@ def act(self, state, reward):
         self._store_transition(state, reward)
         self._train()
         self.state = state
-        with torch.no_grad():
-            self.action = self.policy(state)
+        self.action = self.policy.eval(state)
         return self.action

     def _store_transition(self, state, reward):
@@ -58,14 +62,17 @@ def _train(self):
         # compute targets for Q and V
         with torch.no_grad():
             _actions, _log_probs = self.policy(states, log_prob=True)
-            q_targets = rewards + self.discount_factor * self.v.eval(next_states)
+            q_targets = rewards + self.discount_factor * self.v.target(next_states)
             v_targets = torch.min(
-                self.q_1.eval(states, _actions),
-                self.q_2.eval(states, _actions),
-            ) - self.entropy_regularizer * _log_probs
+                self.q_1.target(states, _actions),
+                self.q_2.target(states, _actions),
+            ) - self.temperature * _log_probs
+            temperature_loss = ((_log_probs + self.entropy_target).detach().mean())
             self.writer.add_loss('entropy', -_log_probs.mean())
             self.writer.add_loss('v_mean', v_targets.mean())
             self.writer.add_loss('r_mean', rewards.mean())
+            self.writer.add_loss('temperature_loss', temperature_loss)
+            self.writer.add_loss('temperature', self.temperature)

         # update Q-functions
         q_1_errors = q_targets - self.q_1(states, actions)
@@ -79,15 +86,17 @@ def _train(self):

         # train policy
         _actions, _log_probs = self.policy(states, log_prob=True)
-
         loss = -(
             self.q_1(states, _actions, detach=False)
-            - self.entropy_regularizer * _log_probs
+            - self.temperature * _log_probs
         ).mean()
         loss.backward()
         self.policy.step()
         self.q_1.zero_grad()

+        # adjust temperature
+        self.temperature += self.lr_temperature * temperature_loss
+
     def _should_train(self):
         return (self.frames_seen > self.replay_start_size and
                 self.frames_seen % self.update_frequency == 0)
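The fixed entropy_regularizer is replaced by a learned temperature, resembling SAC's automatic temperature adjustment: when the policy's entropy falls below entropy_target, temperature_loss turns positive and the temperature rises, encouraging exploration; when entropy is above target, it falls. A scalar sketch of that update with made-up numbers:

```python
import torch

entropy_target = -2.0      # usually -action_space.shape[0]
temperature = 0.1
lr_temperature = 1e-4

log_probs = torch.tensor([-1.2, -0.8, -1.5])   # log pi(a|s) over a batch

# Positive when entropy (-mean log_prob) is below the target, negative otherwise.
temperature_loss = (log_probs + entropy_target).mean().item()
temperature += lr_temperature * temperature_loss
print(temperature)  # ~0.09968: entropy is above target here, so temperature drops
```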

all/agents/vac.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ def act(self, state, reward):
         if self._previous_features:
             td_error = (
                 reward
-                + self.gamma * self.v.eval(self.features.eval(state))
+                + self.gamma * self.v.target(self.features.eval(state))
                 - self.v(self._previous_features)
             )
             self.v.reinforce(td_error)
