Skip to content

Commit 2688afb

Browse files
committed
Fix NN things
1 parent 7f85de8 commit 2688afb

File tree

2 files changed

+141
-112
lines changed

2 files changed

+141
-112
lines changed

LuteBot/UI/NeuralNetworkForm.Designer.cs

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

LuteBot/UI/NeuralNetworkForm.cs

Lines changed: 140 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ public NeuralNetworkForm(TrackSelectionManager tsm, PartitionsForm pf)
2525
this.tsm = tsm;
2626
this.pf = pf;
2727
InitializeComponent();
28+
29+
if (File.Exists(savePath) && File.Exists(savePath + ".config"))
30+
{
31+
var sizes = JsonConvert.DeserializeObject<int[]>(File.ReadAllText(NeuralNetworkForm.savePath + ".config"));
32+
textBoxSizes.Enabled = false;
33+
textBoxSizes.Text = string.Join(",", sizes);
34+
}
2835
}
2936

3037
public static string savePath = Path.Combine(LuteBotForm.lutebotPath, "CustomNetwork");
@@ -33,20 +40,30 @@ private async void buttonTrain_Click(object sender, EventArgs e)
3340
{
3441
try
3542
{
36-
//Invoke((MethodInvoker)delegate
37-
//{
38-
textBoxSizes.Enabled = false;
39-
buttonTrain.Enabled = false;
40-
textBox1.Enabled = false;
41-
textBox2.Enabled = false;
42-
textBoxParallel.Enabled = false;
43-
richTextBox1.AppendText("\n\n\n\nLoading song data for training...");
44-
//});
45-
var sizes = textBoxSizes.Text.Split(',').Select(s => int.Parse(s)).ToArray();
46-
int numPerfect = int.Parse(textBox1.Text);
47-
float percentForSuccess = float.Parse(textBox2.Text);
48-
int parallel = int.Parse(textBoxParallel.Text);
49-
await TrainNetwork(numPerfect, percentForSuccess, parallel, sizes);
43+
if (buttonTrain.Text == "Train")
44+
{
45+
//Invoke((MethodInvoker)delegate
46+
//{
47+
buttonTrain.Text = "Pause";
48+
49+
textBoxSizes.Enabled = false;
50+
buttonTrain.Enabled = false;
51+
textBox1.Enabled = false;
52+
textBox2.Enabled = false;
53+
textBoxParallel.Enabled = false;
54+
richTextBox1.AppendText("\n\n\n\nLoading song data for training...");
55+
//});
56+
var sizes = textBoxSizes.Text.Split(',').Select(s => int.Parse(s)).ToArray();
57+
int numPerfect = int.Parse(textBox1.Text);
58+
float percentForSuccess = float.Parse(textBox2.Text);
59+
int parallel = int.Parse(textBoxParallel.Text);
60+
cancelled = false;
61+
await TrainNetwork(numPerfect, percentForSuccess, parallel, sizes).ConfigureAwait(false);
62+
}
63+
else
64+
{
65+
cancelled = true;
66+
}
5067
}
5168
catch (Exception ex)
5269
{
@@ -277,7 +294,7 @@ private async Task TestMusicMaker()
277294
Console.WriteLine("Done");
278295
}
279296

280-
297+
private bool cancelled = false;
281298
private async Task TrainNetwork(int numPerfect, float percentForSuccess, int parallelism, params int[] sizes)
282299
{
283300
try
@@ -297,7 +314,12 @@ private async Task TrainNetwork(int numPerfect, float percentForSuccess, int par
297314
for (int n = 1; n < parameters.Length - 1; n++)
298315
parameters[n] = sizes[n - 1];
299316
parameters[parameters.Length - 1] = 1;
317+
300318
tsm.neural = new NeuralNetwork(parameters, activation);
319+
if (File.Exists(NeuralNetworkForm.savePath) && File.Exists(NeuralNetworkForm.savePath + ".config"))
320+
{
321+
tsm.neural.Load(savePath);
322+
}
301323

302324
// These below work great and are the settings for 'v2Neural'
303325
//string[] activation = new string[] { "tanh", "softmax" };
@@ -364,6 +386,7 @@ private async Task TrainNetwork(int numPerfect, float percentForSuccess, int par
364386
}
365387
}
366388
}
389+
buttonTrain.Enabled = true;
367390
Console.WriteLine($"Training with {candidates.Count} candidates...");
368391
var trainingTarget = new TrainingTarget<MidiChannelItem>(candidates, ((a, b) => a.Id == b.Id));
369392

@@ -413,7 +436,7 @@ private async Task TrainNetwork(int numPerfect, float percentForSuccess, int par
413436

414437
int i = 0;
415438
int numSuccesses = 0;
416-
while (numSuccesses < numPerfect)
439+
while (numSuccesses < numPerfect && !cancelled)
417440
//while(costTotal > 0.1f && i < trainCount)
418441
//for (int i = 0; i < trainCount; i++)
419442
{
@@ -482,126 +505,132 @@ private async Task TrainNetwork(int numPerfect, float percentForSuccess, int par
482505
numActualTestsCorrect = 0;
483506
Parallel.ForEach(neuralTrainingCandidates.OrderBy(n => random.NextDouble()), new ParallelOptions() { MaxDegreeOfParallelism = parallelism }, song =>
484507
{
485-
float maxAvgNoteLength = song.Max(c => c.Id == 9 ? 0 : c.avgNoteLength);
486-
float maxNoteLength = song.Max(c => c.Id == 9 ? 0 : c.totalNoteLength);
487-
float maxNumNotes = song.Max(c => c.Id == 9 ? 0 : c.numNotes);
508+
if (!cancelled)
509+
{
510+
float maxAvgNoteLength = song.Max(c => c.Id == 9 ? 0 : c.avgNoteLength);
511+
float maxNoteLength = song.Max(c => c.Id == 9 ? 0 : c.totalNoteLength);
512+
float maxNumNotes = song.Max(c => c.Id == 9 ? 0 : c.numNotes);
488513

489-
float maxTickNumber = song.SelectMany(c => c.tickNotes).SelectMany(kvp => kvp.Value).Max(n => n.tickNumber);
514+
float maxTickNumber = song.SelectMany(c => c.tickNotes).SelectMany(kvp => kvp.Value).Max(n => n.tickNumber);
490515

491-
Dictionary<MidiChannelItem, float> channelResults = new Dictionary<MidiChannelItem, float>();
492-
foreach (var channel in song)
493-
{
494-
var inputs = channel.GetNeuralInputs(maxAvgNoteLength, maxNumNotes, maxNoteLength);
495-
//var inputs = channel.GetRecurrentInput(noteParams, maxTickNumber);
496-
//var neuralResults = tsm.neural.FeedForwardRecurrent(inputs);
497-
var neuralResults = tsm.neural.FeedForward(inputs);
498-
channelResults[channel] = neuralResults[0];
499-
}
516+
Dictionary<MidiChannelItem, float> channelResults = new Dictionary<MidiChannelItem, float>();
517+
foreach (var channel in song)
518+
{
519+
var inputs = channel.GetNeuralInputs(maxAvgNoteLength, maxNumNotes, maxNoteLength);
520+
//var inputs = channel.GetRecurrentInput(noteParams, maxTickNumber);
521+
//var neuralResults = tsm.neural.FeedForwardRecurrent(inputs);
522+
var neuralResults = tsm.neural.FeedForward(inputs);
523+
channelResults[channel] = neuralResults[0];
524+
}
500525

501-
var orderedResults = channelResults.OrderByDescending(kvp => kvp.Value);
526+
var orderedResults = channelResults.OrderByDescending(kvp => kvp.Value);
502527

503528

504-
/*
505-
var inputs = song.GetNeuralInput();
529+
/*
530+
var inputs = song.GetNeuralInput();
506531
507-
var neuralResults = tsm.neural.FeedForward(inputs);
508-
// The output here is a channel ID to confidence map...
509-
// I need to find the ID of the best one... and rank the rest...
510-
// So let's just build a quick dictionary I guess
511-
var results = new Dictionary<int, float>();
532+
var neuralResults = tsm.neural.FeedForward(inputs);
533+
// The output here is a channel ID to confidence map...
534+
// I need to find the ID of the best one... and rank the rest...
535+
// So let's just build a quick dictionary I guess
536+
var results = new Dictionary<int, float>();
512537
513-
for (int j = 0; j < neuralResults.Length; j++)
514-
{
515-
results[j] = neuralResults[j];
516-
}
538+
for (int j = 0; j < neuralResults.Length; j++)
539+
{
540+
results[j] = neuralResults[j];
541+
}
517542
518-
var orderedResults = results.OrderByDescending(kvp => kvp.Value);
519-
*/
520-
// Check the number of active flute channels...
521-
var numFlute = song.Where(s => s.Active).Count();
522-
bool correct = true;
543+
var orderedResults = results.OrderByDescending(kvp => kvp.Value);
544+
*/
545+
// Check the number of active flute channels...
546+
var numFlute = song.Where(s => s.Active).Count();
547+
bool correct = true;
523548

524-
for (int j = 0; j < numFlute; j++)
525-
{
526-
bool? existsAndCorrect = song.Where(s => s.Id == orderedResults.ElementAt(j).Key.Id).SingleOrDefault()?.Active;
527-
if (!existsAndCorrect.HasValue || !existsAndCorrect.Value)
528-
correct = false;
529-
}
549+
for (int j = 0; j < numFlute; j++)
550+
{
551+
bool? existsAndCorrect = song.Where(s => s.Id == orderedResults.ElementAt(j).Key.Id).SingleOrDefault()?.Active;
552+
if (!existsAndCorrect.HasValue || !existsAndCorrect.Value)
553+
correct = false;
554+
}
530555

531-
if (correct)
532-
Interlocked.Increment(ref numTestsCorrect);
556+
if (correct)
557+
Interlocked.Increment(ref numTestsCorrect);
533558

559+
}
534560
//await Task.Delay(0); // Let the form live between iterations
535561
});
536562
Parallel.ForEach(neuralTestCandidates.OrderBy(n => random.NextDouble()), new ParallelOptions() { MaxDegreeOfParallelism = parallelism }, song =>
537563
{
538-
float maxAvgNoteLength = song.Max(c => c.Id == 9 ? 0 : c.avgNoteLength);
539-
float maxNoteLength = song.Max(c => c.Id == 9 ? 0 : c.totalNoteLength);
540-
float maxNumNotes = song.Max(c => c.Id == 9 ? 0 : c.numNotes);
541-
542-
float maxTickNumber = song.SelectMany(c => c.tickNotes).SelectMany(kvp => kvp.Value).Max(n => n.tickNumber);
543-
544-
Dictionary<MidiChannelItem, float> channelResults = new Dictionary<MidiChannelItem, float>();
545-
foreach (var channel in song)
564+
if (!cancelled)
546565
{
547-
//var inputs = channel.GetRecurrentInput(noteParams, maxTickNumber);
548-
//var neuralResults = tsm.neural.FeedForwardRecurrent(inputs);
549-
var inputs = channel.GetNeuralInputs(maxAvgNoteLength, maxNumNotes, maxNoteLength);
550-
var neuralResults = tsm.neural.FeedForward(inputs);
551-
channelResults[channel] = neuralResults[0];
552-
}
566+
float maxAvgNoteLength = song.Max(c => c.Id == 9 ? 0 : c.avgNoteLength);
567+
float maxNoteLength = song.Max(c => c.Id == 9 ? 0 : c.totalNoteLength);
568+
float maxNumNotes = song.Max(c => c.Id == 9 ? 0 : c.numNotes);
553569

554-
var orderedResults = channelResults.OrderByDescending(kvp => kvp.Value);
570+
float maxTickNumber = song.SelectMany(c => c.tickNotes).SelectMany(kvp => kvp.Value).Max(n => n.tickNumber);
555571

572+
Dictionary<MidiChannelItem, float> channelResults = new Dictionary<MidiChannelItem, float>();
573+
foreach (var channel in song)
574+
{
575+
//var inputs = channel.GetRecurrentInput(noteParams, maxTickNumber);
576+
//var neuralResults = tsm.neural.FeedForwardRecurrent(inputs);
577+
var inputs = channel.GetNeuralInputs(maxAvgNoteLength, maxNumNotes, maxNoteLength);
578+
var neuralResults = tsm.neural.FeedForward(inputs);
579+
channelResults[channel] = neuralResults[0];
580+
}
556581

557-
/*
558-
var inputs = song.GetNeuralInput();
582+
var orderedResults = channelResults.OrderByDescending(kvp => kvp.Value);
559583

560-
var neuralResults = tsm.neural.FeedForward(inputs);
561-
// The output here is a channel ID to confidence map...
562-
// I need to find the ID of the best one... and rank the rest...
563-
// So let's just build a quick dictionary I guess
564-
var results = new Dictionary<int, float>();
565584

566-
for (int j = 0; j < neuralResults.Length; j++)
567-
{
568-
results[j] = neuralResults[j];
569-
}
585+
/*
586+
var inputs = song.GetNeuralInput();
570587
571-
var orderedResults = results.OrderByDescending(kvp => kvp.Value);
572-
*/
573-
// Check the number of active flute channels...
574-
var numFlute = song.Where(s => s.Active).Count();
575-
bool correct = true;
588+
var neuralResults = tsm.neural.FeedForward(inputs);
589+
// The output here is a channel ID to confidence map...
590+
// I need to find the ID of the best one... and rank the rest...
591+
// So let's just build a quick dictionary I guess
592+
var results = new Dictionary<int, float>();
576593
577-
for (int j = 0; j < numFlute; j++)
578-
{
579-
bool? existsAndCorrect = song.Where(s => s.Id == orderedResults.ElementAt(j).Key.Id).SingleOrDefault()?.Active;
580-
if (!existsAndCorrect.HasValue || !existsAndCorrect.Value)
581-
correct = false;
582-
}
594+
for (int j = 0; j < neuralResults.Length; j++)
595+
{
596+
results[j] = neuralResults[j];
597+
}
583598
584-
if (correct)
585-
{
586-
Interlocked.Increment(ref numTestsCorrect);
587-
Interlocked.Increment(ref numActualTestsCorrect);
588-
}
599+
var orderedResults = results.OrderByDescending(kvp => kvp.Value);
600+
*/
601+
// Check the number of active flute channels...
602+
var numFlute = song.Where(s => s.Active).Count();
603+
bool correct = true;
589604

590-
//int fluteCount = 0;
591-
//foreach (var channel in channelResults.Keys)
592-
////foreach (var channel in activeChannels)
593-
//{
594-
// //var channel = activeChannels.Where(c => c.Id == orderedResults.ElementAt(i).Key).SingleOrDefault();
595-
//
596-
// if (channel != null)
597-
// {
598-
// Console.WriteLine($"{channel.Name} ({channel.Id}) - Neural Score: {channelResults[channel]}");
599-
// //Console.WriteLine($"{channel.Name} ({channel.Id}) - Neural Score: {neuralResults[channel.Id]}");
600-
// //channel.Name += $"(Flute Rank {++fluteCount} - {Math.Round(channelResults[channel], 2)}%)";
601-
// }
602-
//}
605+
for (int j = 0; j < numFlute; j++)
606+
{
607+
bool? existsAndCorrect = song.Where(s => s.Id == orderedResults.ElementAt(j).Key.Id).SingleOrDefault()?.Active;
608+
if (!existsAndCorrect.HasValue || !existsAndCorrect.Value)
609+
correct = false;
610+
}
603611

604-
//await Task.Delay(0); // Let the form live between iterations
612+
if (correct)
613+
{
614+
Interlocked.Increment(ref numTestsCorrect);
615+
Interlocked.Increment(ref numActualTestsCorrect);
616+
}
617+
618+
//int fluteCount = 0;
619+
//foreach (var channel in channelResults.Keys)
620+
////foreach (var channel in activeChannels)
621+
//{
622+
// //var channel = activeChannels.Where(c => c.Id == orderedResults.ElementAt(i).Key).SingleOrDefault();
623+
//
624+
// if (channel != null)
625+
// {
626+
// Console.WriteLine($"{channel.Name} ({channel.Id}) - Neural Score: {channelResults[channel]}");
627+
// //Console.WriteLine($"{channel.Name} ({channel.Id}) - Neural Score: {neuralResults[channel.Id]}");
628+
// //channel.Name += $"(Flute Rank {++fluteCount} - {Math.Round(channelResults[channel], 2)}%)";
629+
// }
630+
//}
631+
632+
//await Task.Delay(0); // Let the form live between iterations
633+
}
605634
});
606635
BeginInvoke((MethodInvoker)delegate
607636
{
@@ -626,11 +655,11 @@ private async Task TrainNetwork(int numPerfect, float percentForSuccess, int par
626655
{
627656
richTextBox1.AppendText($"\n\n----Training Complete----\n\nNetwork saved to CustomNetwork file to {savePath}\nYou may now load midis and check it for yourself, and this will be your new default network. \n\nYou can revert at any time by deleting this file, or train again to replace it");
628657
richTextBox1.ScrollToCaret();
629-
textBoxSizes.Enabled = true;
630658
buttonTrain.Enabled = true;
631659
textBox1.Enabled = true;
632660
textBoxParallel.Enabled = true;
633661
textBox2.Enabled = true;
662+
buttonTrain.Text = "Train";
634663
});
635664
if (File.Exists(savePath))
636665
File.Delete(savePath);

0 commit comments

Comments
 (0)