Skip to content

Commit 934a28e

Browse files
committed
code qa
1 parent 9d3321b commit 934a28e

File tree

1 file changed

+4
-13
lines changed

1 file changed

+4
-13
lines changed

river/anomaly/memstream.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -194,15 +194,10 @@ def __process_x__(self, x):
194194
loss_values = np.sort(
195195
norms,
196196
)[: self.k]
197-
loss_value = np.sum(loss_values * self.exp) / (
198-
np.sum(self.exp) + self.eps
199-
)
197+
loss_value = np.sum(loss_values * self.exp) / (np.sum(self.exp) + self.eps)
200198
if self.replace_strategy == ReplaceStrategy.LRU:
201199
memory_indeces = np.argsort(norms)[: self.k]
202-
(
203-
self.__reorder_memory__(memory_index)
204-
for memory_index in memory_indeces
205-
)
200+
(self.__reorder_memory__(memory_index) for memory_index in memory_indeces)
206201
return loss_value, encode_x, x
207202

208203
def score_one(self, x, y=None):
@@ -223,9 +218,7 @@ def __manage_non_encoded__(self, x, y):
223218
if (y is not None and y != 1) or y is None:
224219
self.__update_memory__(0, np.zeros((1, self.out_dim)), x)
225220
elif self.count >= self.grace_period:
226-
self.__define_encoder__(
227-
[(self.mem_data[i], 0) for i in range(len(self.mem_data))]
228-
)
221+
self.__define_encoder__([(self.mem_data[i], 0) for i in range(len(self.mem_data))])
229222
self.initialized = True
230223

231224
def learn_one(self, x, y=None):
@@ -235,9 +228,7 @@ def learn_one(self, x, y=None):
235228
loss_value, encode_x, x = self.__process_x__(x)
236229
if y is not None and y == 1:
237230
return # Do not learn from anomalies
238-
self.__update_memory__(
239-
0 if self.count < self.grace_period else loss_value, encode_x, x
240-
)
231+
self.__update_memory__(0 if self.count < self.grace_period else loss_value, encode_x, x)
241232

242233

243234
class MemStreamPCA(MemStream):

0 commit comments

Comments
 (0)