Skip to content

Commit 9340770

Browse files
committed
Modify training loop
1 parent 1bae4bb commit 9340770

File tree

1 file changed

+35
-32
lines changed

1 file changed

+35
-32
lines changed

examples/beat_detect.psh

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ fun training_actor(detection_actor_id) {
429429
let var positive_examples_count = 0;
430430
let var negative_examples_count = 0;
431431

432-
for (let var training_itr = 1;; ++training_itr) {
432+
for (let var training_itr = 1;;) {
433433
let msg = $actor_poll();
434434
if (msg != nil) {
435435
training_data.push(msg);
@@ -441,44 +441,47 @@ fun training_actor(detection_actor_id) {
441441
$println("Positive samples: " + positive_examples_count.to_s() + ", negative samples: " + negative_examples_count.to_s());
442442
}
443443

444-
if (training_data.len > 0) {
445-
// Create a perturbed clone
446-
let perturbed_rnn = rnn.clone();
447-
perturbed_rnn.perturb(PERTURBATION_AMOUNT);
444+
if (training_data.len == 0) {
445+
$actor_sleep(1);
446+
continue;
447+
}
448448

449-
// Pick a random sample
450-
let idx = rand_int(0, training_data.len - 1);
451-
let sample = training_data[idx];
452-
let frames = sample[0];
453-
let label = sample[1];
449+
// Create a perturbed clone
450+
let perturbed_rnn = rnn.clone();
451+
perturbed_rnn.perturb(PERTURBATION_AMOUNT);
454452

455-
// Calculate loss for both models on the same sample
456-
let current_loss = calculate_loss(rnn, frames, label);
457-
let perturbed_loss = calculate_loss(perturbed_rnn, frames, label);
453+
// Pick a random sample
454+
let idx = rand_int(0, training_data.len - 1);
455+
let sample = training_data[idx];
456+
let frames = sample[0];
457+
let label = sample[1];
458458

459-
if (perturbed_loss < current_loss) {
460-
rnn = perturbed_rnn;
461-
best_loss = perturbed_loss;
462-
}
459+
// Calculate loss for both models on the same sample
460+
let current_loss = calculate_loss(rnn, frames, label);
461+
let perturbed_loss = calculate_loss(perturbed_rnn, frames, label);
463462

464-
if (iteration_count % 20 == 0) {
465-
$println(
466-
"Itr#" + training_itr.to_s() +
467-
", loss: " + best_loss.format_decimals(9) +
468-
", pos: " +
469-
positive_examples_count.to_s() +
470-
", neg: " +
471-
negative_examples_count.to_s()
472-
);
473-
}
463+
if (perturbed_loss < current_loss) {
464+
rnn = perturbed_rnn;
465+
best_loss = perturbed_loss;
466+
}
474467

475-
iteration_count = iteration_count + 1;
476-
if (iteration_count % 100 == 0) {
477-
$actor_send(detection_actor_id, rnn);
478-
}
468+
if (iteration_count % 20 == 0) {
469+
$println(
470+
"Itr#" + training_itr.to_s() +
471+
", loss: " + best_loss.format_decimals(1) +
472+
", pos: " +
473+
positive_examples_count.to_s() +
474+
", neg: " +
475+
negative_examples_count.to_s()
476+
);
477+
}
478+
479+
iteration_count = iteration_count + 1;
480+
if (iteration_count % 100 == 0) {
481+
$actor_send(detection_actor_id, rnn);
479482
}
480483

481-
$actor_sleep(1);
484+
++training_itr;
482485
}
483486
}
484487

0 commit comments

Comments
 (0)