Commit f5a716d

Merge branch 'master' into lowering_precision
2 parents: b5c608b + 4c894b7

File tree: 20 files changed, +2308 / -2021 lines

.github/workflows/python-app.yml

Lines changed: 10 additions & 3 deletions
@@ -16,13 +16,20 @@ jobs:
 
     steps:
     - uses: actions/checkout@v3
-    - name: Set up Python 3.9
+    - name: Set up Python 3.13
       uses: actions/setup-python@v4
       with:
-        python-version: 3.9
+<<<<<<< HEAD
+        python-version: 3.10
     - name: Install Poetry
       env:
-        POETRY_VERSION: 1.5.1
+        POETRY_VERSION: 2.0.0
+=======
+        python-version: 3.13
+    - name: Install Poetry
+      env:
+        POETRY_VERSION: 2.1.2
+>>>>>>> origin
       run: |
         curl -sSL https://install.python-poetry.org | python - -y &&\
         poetry config virtualenvs.create false

.github/workflows/pythonpackage.yml

Lines changed: 10 additions & 2 deletions
@@ -9,7 +9,11 @@ jobs:
     strategy:
       max-parallel: 4
       matrix:
-        python-version: ["3.9", "3.10", "3.11"]
+<<<<<<< HEAD
+        python-version: ["3.10", "3.11", "3.12"]
+=======
+        python-version: ["3.10", "3.11", "3.12", "3.13"]
+>>>>>>> origin
 
     steps:
     - uses: actions/checkout@v3
@@ -19,7 +23,11 @@ jobs:
         python-version: ${{ matrix.python-version }}
     - name: Install Poetry
       env:
-        POETRY_VERSION: 1.5.1
+<<<<<<< HEAD
+        POETRY_VERSION: 2.0.0
+=======
+        POETRY_VERSION: 2.1.2
+>>>>>>> origin
       run: |
         curl -sSL https://install.python-poetry.org | python - -y &&\
         poetry config virtualenvs.create false

README.md

Lines changed: 2 additions & 2 deletions
@@ -9,9 +9,9 @@ This package is used as part of ongoing research on applying SNNs, machine learn
 
 Check out the [BindsNET examples](https://github.com/BindsNET/bindsnet/tree/master/examples) for a collection of experiments, functions for the analysis of results, plots of experiment outcomes, and more. Documentation for the package can be found [here](https://bindsnet-docs.readthedocs.io).
 
-![Build Status](https://github.com/BindsNET/bindsnet/actions/workflows/python-app.yml/badge.svg?branch=master)
+[![CodeQL](https://github.com/BindsNET/bindsnet/actions/workflows/github-code-scanning/codeql/badge.svg)](https://github.com/BindsNET/bindsnet/actions/workflows/github-code-scanning/codeql)
 [![Documentation Status](https://readthedocs.org/projects/bindsnet-docs/badge/?version=latest)](https://bindsnet-docs.readthedocs.io/?badge=latest)
-[![Gitter chat](https://badges.gitter.im/gitterHQ/gitter.png)](https://gitter.im/bindsnet_/community)
+[![Neuromorphic Computing](https://img.shields.io/badge/Collaboration_Network-Open_Neuromorphic-blue)](https://open-neuromorphic.org/neuromorphic-computing/)
 
 ## Requirements

bindsnet/evaluation/evaluation.py

Lines changed: 8 additions & 1 deletion
@@ -44,8 +44,11 @@ def assign_labels(
         indices = torch.nonzero(labels == i).view(-1)
 
         # Compute average firing rates for this label.
+        selected_spikes = torch.index_select(
+            spikes, dim=0, index=torch.tensor(indices)
+        )
         rates[:, i] = alpha * rates[:, i] + (
-            torch.sum(spikes[indices], 0) / n_labeled
+            torch.sum(selected_spikes, 0) / n_labeled
         )
 
     # Compute proportions of spike activity per class.
@@ -111,6 +114,8 @@ def all_activity(
 
     # Sum over time dimension (spike ordering doesn't matter).
     spikes = spikes.sum(1)
+    if spikes.is_sparse:
+        spikes = spikes.to_dense()
 
     rates = torch.zeros((n_samples, n_labels), device=spikes.device)
     for i in range(n_labels):
@@ -152,6 +157,8 @@ def proportion_weighting(
 
     # Sum over time dimension (spike ordering doesn't matter).
    spikes = spikes.sum(1)
+    if spikes.is_sparse:
+        spikes = spikes.to_dense()
 
     rates = torch.zeros((n_samples, n_labels), device=spikes.device)
     for i in range(n_labels):

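The evaluation changes only matter when spike recordings arrive as sparse tensors: the time-summed activity is densified before per-label indexing, and label rows are gathered with torch.index_select. A minimal standalone sketch of that pattern, with hypothetical shapes and labels rather than the BindsNET functions themselves:

    import torch

    # Hypothetical spike data: 6 samples, 50 timesteps, 10 neurons, stored sparse.
    n_samples, time, n_neurons, n_labels = 6, 50, 10, 2
    spikes = (torch.rand(n_samples, time, n_neurons) < 0.05).float().to_sparse()
    labels = torch.tensor([0, 1, 0, 1, 0, 1])

    # Sum over the time dimension; the result is still sparse, so densify it
    # before per-label slicing, which sparse COO tensors do not support.
    spikes = torch.sparse.sum(spikes, dim=1)
    if spikes.is_sparse:
        spikes = spikes.to_dense()

    for i in range(n_labels):
        indices = torch.nonzero(labels == i).view(-1)
        selected = torch.index_select(spikes, dim=0, index=indices)
        mean_rate = selected.sum(0) / max(indices.numel(), 1)  # per-neuron rate for label i
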
bindsnet/learning/MCC_learning.py

Lines changed: 42 additions & 25 deletions
@@ -102,7 +102,12 @@ def update(self, **kwargs) -> None:
         if ((self.min is not None) or (self.max is not None)) and not isinstance(
             self, NoOp
         ):
-            self.feature_value.clamp_(self.min, self.max)
+            if self.feature_value.is_sparse:
+                self.feature_value = (
+                    self.feature_value.to_dense().clamp_(self.min, self.max).to_sparse()
+                )
+            else:
+                self.feature_value.clamp_(self.min, self.max)
 
     @abstractmethod
     def reset_state_variables(self) -> None:
@@ -247,10 +252,15 @@ def _connection_update(self, **kwargs) -> None:
                 torch.mean(self.average_buffer_pre, dim=0) * self.connection.dt
             )
         else:
-            self.feature_value -= (
-                self.reduction(torch.bmm(source_s, target_x), dim=0)
-                * self.connection.dt
-            )
+            if self.feature_value.is_sparse:
+                self.feature_value -= (
+                    torch.bmm(source_s, target_x) * self.connection.dt
+                ).to_sparse()
+            else:
+                self.feature_value -= (
+                    self.reduction(torch.bmm(source_s, target_x), dim=0)
+                    * self.connection.dt
+                )
         del source_s, target_x
 
         # Post-synaptic update.
@@ -278,10 +288,15 @@ def _connection_update(self, **kwargs) -> None:
                 torch.mean(self.average_buffer_post, dim=0) * self.connection.dt
             )
         else:
-            self.feature_value += (
-                self.reduction(torch.bmm(source_x, target_s), dim=0)
-                * self.connection.dt
-            )
+            if self.feature_value.is_sparse:
+                self.feature_value += (
+                    torch.bmm(source_x, target_s) * self.connection.dt
+                ).to_sparse()
+            else:
+                self.feature_value += (
+                    self.reduction(torch.bmm(source_x, target_s), dim=0)
+                    * self.connection.dt
+                )
         del source_x, target_s
 
         super().update()
@@ -508,16 +523,16 @@ def _connection_update(self, **kwargs) -> None:
                 self.average_buffer_index + 1
             ) % self.average_update
 
-            if self.continues_update:
-                self.feature_value += self.nu[0] * torch.mean(
-                    self.average_buffer, dim=0
-                )
-            elif self.average_buffer_index == 0:
-                self.feature_value += self.nu[0] * torch.mean(
-                    self.average_buffer, dim=0
-                )
+            if self.continues_update or self.average_buffer_index == 0:
+                update = self.nu[0] * torch.mean(self.average_buffer, dim=0)
+                if self.feature_value.is_sparse:
+                    update = update.to_sparse()
+                self.feature_value += update
         else:
-            self.feature_value += self.nu[0] * self.reduction(update, dim=0)
+            update = self.nu[0] * self.reduction(update, dim=0)
+            if self.feature_value.is_sparse:
+                update = update.to_sparse()
+            self.feature_value += update
 
         # Update P^+ and P^- values.
         self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
@@ -686,14 +701,16 @@ def _connection_update(self, **kwargs) -> None:
                 self.average_buffer_index + 1
             ) % self.average_update
 
-            if self.continues_update:
-                self.feature_value += torch.mean(self.average_buffer, dim=0)
-            elif self.average_buffer_index == 0:
-                self.feature_value += torch.mean(self.average_buffer, dim=0)
+            if self.continues_update or self.average_buffer_index == 0:
+                update = torch.mean(self.average_buffer, dim=0)
+                if self.feature_value.is_sparse:
+                    update = update.to_sparse()
+                self.feature_value += update
         else:
-            self.feature_value += (
-                self.nu[0] * self.connection.dt * reward * self.eligibility_trace
-            )
+            update = self.nu[0] * self.connection.dt * reward * self.eligibility_trace
+            if self.feature_value.is_sparse:
+                update = update.to_sparse()
+            self.feature_value += update
 
         # Update P^+ and P^- values.
         self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)  # Decay

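All of these edits apply the same two work-arounds for sparse feature tensors: dense-only in-place operations such as clamp_ run on a to_dense() copy that is re-sparsified afterwards, and dense updates are converted with to_sparse() before being accumulated. A standalone sketch of the idea, using made-up tensors rather than the MCC learning-rule classes:

    import torch

    # Made-up sparse "feature" tensor and a dense update from some learning rule.
    feature_value = torch.randn(5, 5).to_sparse()
    update = 0.1 * torch.randn(5, 5)

    # Match layouts before accumulating so the feature tensor stays sparse.
    if feature_value.is_sparse:
        update = update.to_sparse()
    feature_value += update

    # clamp_ has no sparse implementation, so round-trip through a dense copy.
    if feature_value.is_sparse:
        feature_value = feature_value.to_dense().clamp_(-1.0, 1.0).to_sparse()
    else:
        feature_value.clamp_(-1.0, 1.0)
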
bindsnet/learning/learning.py

Lines changed: 28 additions & 9 deletions
@@ -98,7 +98,10 @@ def update(self) -> None:
             (self.connection.wmin != -np.inf).any()
             or (self.connection.wmax != np.inf).any()
         ) and not isinstance(self, NoOp):
-            self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)
+            if self.connection.w.is_sparse:
+                raise Exception("SparseConnection isn't supported for wmin\\wmax")
+            else:
+                self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)
 
 
 class NoOp(LearningRule):
@@ -396,7 +399,10 @@ def _connection_update(self, **kwargs) -> None:
         if self.nu[0].any():
             source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float()
             target_x = self.target.x.view(batch_size, -1).unsqueeze(1) * self.nu[0]
-            self.connection.w -= self.reduction(torch.bmm(source_s, target_x), dim=0)
+            update = self.reduction(torch.bmm(source_s, target_x), dim=0)
+            if self.connection.w.is_sparse:
+                update = update.to_sparse()
+            self.connection.w -= update
             del source_s, target_x
 
         # Post-synaptic update.
@@ -405,7 +411,10 @@ def _connection_update(self, **kwargs) -> None:
                 self.target.s.view(batch_size, -1).unsqueeze(1).float() * self.nu[1]
             )
             source_x = self.source.x.view(batch_size, -1).unsqueeze(2)
-            self.connection.w += self.reduction(torch.bmm(source_x, target_s), dim=0)
+            update = self.reduction(torch.bmm(source_x, target_s), dim=0)
+            if self.connection.w.is_sparse:
+                update = update.to_sparse()
+            self.connection.w += update
             del source_x, target_s
 
         super().update()
@@ -1113,10 +1122,14 @@ def _connection_update(self, **kwargs) -> None:
 
         # Pre-synaptic update.
         update = self.reduction(torch.bmm(source_s, target_x), dim=0)
+        if self.connection.w.is_sparse:
+            update = update.to_sparse()
        self.connection.w += self.nu[0] * update
 
         # Post-synaptic update.
         update = self.reduction(torch.bmm(source_x, target_s), dim=0)
+        if self.connection.w.is_sparse:
+            update = update.to_sparse()
         self.connection.w += self.nu[1] * update
 
         super().update()
@@ -1542,8 +1555,10 @@ def _connection_update(self, **kwargs) -> None:
         a_minus = torch.tensor(a_minus, device=self.connection.w.device)
 
         # Compute weight update based on the eligibility value of the past timestep.
-        update = reward * self.eligibility
-        self.connection.w += self.nu[0] * self.reduction(update, dim=0)
+        update = self.reduction(reward * self.eligibility, dim=0)
+        if self.connection.w.is_sparse:
+            update = update.to_sparse()
+        self.connection.w += self.nu[0] * update
 
         # Update P^+ and P^- values.
         self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
@@ -2214,10 +2229,11 @@ def _connection_update(self, **kwargs) -> None:
         self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace)
         self.eligibility_trace += self.eligibility / self.tc_e_trace
 
+        update = self.nu[0] * self.connection.dt * reward * self.eligibility_trace
+        if self.connection.w.is_sparse:
+            update = update.to_sparse()
         # Compute weight update.
-        self.connection.w += (
-            self.nu[0] * self.connection.dt * reward * self.eligibility_trace
-        )
+        self.connection.w += update
 
         # Update P^+ and P^- values.
         self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
@@ -2936,6 +2952,9 @@ def _connection_update(self, **kwargs) -> None:
         ) * source_x[:, None]
 
         # Compute weight update.
-        self.connection.w += self.nu[0] * reward * self.eligibility_trace
+        update = self.nu[0] * reward * self.eligibility_trace
+        if self.connection.w.is_sparse:
+            update = update.to_sparse()
+        self.connection.w += update
 
         super().update()

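The learning.py changes mirror the MCC ones: each two-factor rule reduces a batched outer product to a dense update and, when the connection weights are sparse, converts it with to_sparse() before applying it; clamping to [wmin, wmax] is refused for sparse weights because clamp_ has no sparse path. A rough standalone sketch of one pre-synaptic update, with hypothetical sizes and learning rate rather than the PostPre class itself:

    import torch

    batch_size, n_source, n_target = 4, 8, 6
    nu = 0.01                                      # hypothetical learning rate
    w = torch.zeros(n_source, n_target).to_sparse()

    source_s = (torch.rand(batch_size, n_source, 1) < 0.2).float()  # spikes
    target_x = torch.rand(batch_size, 1, n_target)                  # traces

    # Batched outer product, reduced over the batch dimension.
    update = torch.sum(torch.bmm(source_s, target_x), dim=0)
    if w.is_sparse:
        update = update.to_sparse()   # keep the weight tensor sparse
    w -= nu * update                  # pre-synaptic (depression) term
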
bindsnet/models/models.py

Lines changed: 11 additions & 2 deletions
@@ -4,6 +4,7 @@
 import torch
 from scipy.spatial.distance import euclidean
 from torch.nn.modules.utils import _pair
+from torch import device
 
 from bindsnet.learning import PostPre
 from bindsnet.learning.MCC_learning import PostPre as MMCPostPre
@@ -101,6 +102,8 @@ def __init__(
         self,
         n_inpt: int,
         device: str = "cpu",
+        batch_size: int = None,
+        sparse: bool = False,
         n_neurons: int = 100,
         exc: float = 22.5,
         inh: float = 17.5,
@@ -193,25 +196,31 @@ def __init__(
                     reduction=reduction,
                     nu=nu,
                     learning_rule=MMCPostPre,
+                    sparse=sparse,
+                    batch_size=batch_size,
                 )
             ],
         )
         w = self.exc * torch.diag(torch.ones(self.n_neurons))
+        if sparse:
+            w = w.unsqueeze(0).expand(batch_size, -1, -1)
         exc_inh_conn = MulticompartmentConnection(
             source=exc_layer,
             target=inh_layer,
             device=device,
-            pipeline=[Weight("weight", w, value_dtype=w_dtype, range=[0, self.exc])],
+            pipeline=[Weight("weight", w, value_dtype=w_dtype, range=[0, self.exc], sparse=sparse)],
         )
         w = -self.inh * (
             torch.ones(self.n_neurons, self.n_neurons)
             - torch.diag(torch.ones(self.n_neurons))
         )
+        if sparse:
+            w = w.unsqueeze(0).expand(batch_size, -1, -1)
         inh_exc_conn = MulticompartmentConnection(
             source=inh_layer,
             target=exc_layer,
             device=device,
-            pipeline=[Weight("weight", w, value_dtype=w_dtype, range=[-self.inh, 0])],
+            pipeline=[Weight("weight", w, value_dtype=w_dtype, range=[-self.inh, 0], sparse=sparse)],
         )
 
         # Add to network

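In the DiehlAndCook2015 model the new sparse path wants per-batch weight tensors, so the fixed excitatory-to-inhibitory and inhibitory-to-excitatory matrices are broadcast to a leading batch dimension before being wrapped in Weight features. A standalone sketch of just the weight preparation, with hypothetical sizes and without constructing the MulticompartmentConnection or Weight objects:

    import torch

    n_neurons, batch_size, exc, inh, sparse = 100, 16, 22.5, 17.5, True

    # Diagonal exc->inh weights and off-diagonal inh->exc weights.
    w_exc_inh = exc * torch.diag(torch.ones(n_neurons))
    w_inh_exc = -inh * (
        torch.ones(n_neurons, n_neurons) - torch.diag(torch.ones(n_neurons))
    )

    if sparse:
        # expand() adds a broadcast batch dimension without copying the data.
        w_exc_inh = w_exc_inh.unsqueeze(0).expand(batch_size, -1, -1)
        w_inh_exc = w_inh_exc.unsqueeze(0).expand(batch_size, -1, -1)

    assert w_exc_inh.shape == (batch_size, n_neurons, n_neurons)
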
bindsnet/network/monitors.py

Lines changed: 8 additions & 5 deletions
@@ -45,6 +45,7 @@ def __init__(
         time: Optional[int] = None,
         batch_size: int = 1,
         device: str = "cpu",
+        sparse: Optional[bool] = False,
     ):
         # language=rst
         """
@@ -62,6 +63,7 @@ def __init__(
         self.time = time
         self.batch_size = batch_size
         self.device = device
+        self.sparse = sparse
 
         # if time is not specified the monitor variable accumulate the logs
         if self.time is None:
@@ -98,11 +100,12 @@ def record(self) -> None:
         for v in self.state_vars:
             data = getattr(self.obj, v).unsqueeze(0)
             # self.recording[v].append(data.detach().clone().to(self.device))
-            self.recording[v].append(
-                torch.empty_like(data, device=self.device, requires_grad=False).copy_(
-                    data, non_blocking=True
-                )
-            )
+            record = torch.empty_like(
+                data, device=self.device, requires_grad=False
+            ).copy_(data, non_blocking=True)
+            if self.sparse:
+                record = record.to_sparse()
+            self.recording[v].append(record)
             # remove the oldest element (first in the list)
             if self.time is not None:
                 self.recording[v].pop(0)

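With the new sparse flag, each recorded state variable is still copied off the autograd graph onto the monitor's device, but the copy can be stored as a sparse tensor, which is cheaper for mostly-zero spike trains. A standalone sketch of one record() step, using hypothetical data and no Monitor object:

    import torch

    sparse = True
    recording = []                                  # stands in for self.recording[v]
    data = (torch.rand(1, 32, 100) < 0.05).float()  # one timestep of spike data

    # Detached copy, as in Monitor.record(); here the target device is the CPU.
    record = torch.empty_like(data, requires_grad=False).copy_(data, non_blocking=True)
    if sparse:
        record = record.to_sparse()                 # store sparsely to save memory
    recording.append(record)
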
bindsnet/network/network.py

Lines changed: 4 additions & 1 deletion
@@ -19,7 +19,9 @@ def load(file_name: str, map_location: str = "cpu", learning: bool = None) -> "N
     :param learning: Whether to load with learning enabled. Default loads value from
         disk.
     """
-    network = torch.load(open(file_name, "rb"), map_location=map_location)
+    network = torch.load(
+        open(file_name, "rb"), map_location=map_location, weights_only=False
+    )
     if learning is not None and "learning" in vars(network):
         network.learning = learning
 
@@ -191,6 +193,7 @@ def save(self, file_name: str) -> None:
             # Save the network to disk.
             network.save(str(Path.home()) + '/network.pt')
         """
+        torch.serialization.add_safe_globals([self])
         torch.save(self, open(file_name, "wb"))
 
     def clone(self) -> "Network":

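The load() change is needed because newer PyTorch releases default torch.load to weights_only=True, which refuses to unpickle full Network objects; passing weights_only=False restores the previous behaviour and should only be used with checkpoints you trust. A brief usage sketch, assuming the module-level import path bindsnet.network.network and a file name chosen for illustration:

    import torch
    from bindsnet.network.network import Network, load

    net = Network(dt=1.0)
    net.save("network.pt")          # pickles the whole Network via torch.save

    # load() calls torch.load(..., weights_only=False) internally, so the full
    # object graph is restored; only do this for files you created yourself.
    restored = load("network.pt", map_location="cpu")
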