@@ -25,6 +25,13 @@ public NeuralNetworkForm(TrackSelectionManager tsm, PartitionsForm pf)
2525 this . tsm = tsm ;
2626 this . pf = pf ;
2727 InitializeComponent ( ) ;
28+
29+ if ( File . Exists ( savePath ) && File . Exists ( savePath + ".config" ) )
30+ {
31+ var sizes = JsonConvert . DeserializeObject < int [ ] > ( File . ReadAllText ( NeuralNetworkForm . savePath + ".config" ) ) ;
32+ textBoxSizes . Enabled = false ;
33+ textBoxSizes . Text = string . Join ( "," , sizes ) ;
34+ }
2835 }
2936
3037 public static string savePath = Path . Combine ( LuteBotForm . lutebotPath , "CustomNetwork" ) ;
@@ -33,20 +40,30 @@ private async void buttonTrain_Click(object sender, EventArgs e)
3340 {
3441 try
3542 {
36- //Invoke((MethodInvoker)delegate
37- //{
38- textBoxSizes . Enabled = false ;
39- buttonTrain . Enabled = false ;
40- textBox1 . Enabled = false ;
41- textBox2 . Enabled = false ;
42- textBoxParallel . Enabled = false ;
43- richTextBox1 . AppendText ( "\n \n \n \n Loading song data for training..." ) ;
44- //});
45- var sizes = textBoxSizes . Text . Split ( ',' ) . Select ( s => int . Parse ( s ) ) . ToArray ( ) ;
46- int numPerfect = int . Parse ( textBox1 . Text ) ;
47- float percentForSuccess = float . Parse ( textBox2 . Text ) ;
48- int parallel = int . Parse ( textBoxParallel . Text ) ;
49- await TrainNetwork ( numPerfect , percentForSuccess , parallel , sizes ) ;
43+ if ( buttonTrain . Text == "Train" )
44+ {
45+ //Invoke((MethodInvoker)delegate
46+ //{
47+ buttonTrain . Text = "Pause" ;
48+
49+ textBoxSizes . Enabled = false ;
50+ buttonTrain . Enabled = false ;
51+ textBox1 . Enabled = false ;
52+ textBox2 . Enabled = false ;
53+ textBoxParallel . Enabled = false ;
54+ richTextBox1 . AppendText ( "\n \n \n \n Loading song data for training..." ) ;
55+ //});
56+ var sizes = textBoxSizes . Text . Split ( ',' ) . Select ( s => int . Parse ( s ) ) . ToArray ( ) ;
57+ int numPerfect = int . Parse ( textBox1 . Text ) ;
58+ float percentForSuccess = float . Parse ( textBox2 . Text ) ;
59+ int parallel = int . Parse ( textBoxParallel . Text ) ;
60+ cancelled = false ;
61+ await TrainNetwork ( numPerfect , percentForSuccess , parallel , sizes ) . ConfigureAwait ( false ) ;
62+ }
63+ else
64+ {
65+ cancelled = true ;
66+ }
5067 }
5168 catch ( Exception ex )
5269 {
@@ -277,7 +294,7 @@ private async Task TestMusicMaker()
277294 Console . WriteLine ( "Done" ) ;
278295 }
279296
280-
297+ private bool cancelled = false ;
281298 private async Task TrainNetwork ( int numPerfect , float percentForSuccess , int parallelism , params int [ ] sizes )
282299 {
283300 try
@@ -297,7 +314,12 @@ private async Task TrainNetwork(int numPerfect, float percentForSuccess, int par
297314 for ( int n = 1 ; n < parameters . Length - 1 ; n ++ )
298315 parameters [ n ] = sizes [ n - 1 ] ;
299316 parameters [ parameters . Length - 1 ] = 1 ;
317+
300318 tsm . neural = new NeuralNetwork ( parameters , activation ) ;
319+ if ( File . Exists ( NeuralNetworkForm . savePath ) && File . Exists ( NeuralNetworkForm . savePath + ".config" ) )
320+ {
321+ tsm . neural . Load ( savePath ) ;
322+ }
301323
302324 // These below work great and are the settings for 'v2Neural'
303325 //string[] activation = new string[] { "tanh", "softmax" };
@@ -364,6 +386,7 @@ private async Task TrainNetwork(int numPerfect, float percentForSuccess, int par
364386 }
365387 }
366388 }
389+ buttonTrain . Enabled = true ;
367390 Console . WriteLine ( $ "Training with { candidates . Count } candidates...") ;
368391 var trainingTarget = new TrainingTarget < MidiChannelItem > ( candidates , ( ( a , b ) => a . Id == b . Id ) ) ;
369392
@@ -413,7 +436,7 @@ private async Task TrainNetwork(int numPerfect, float percentForSuccess, int par
413436
414437 int i = 0 ;
415438 int numSuccesses = 0 ;
416- while ( numSuccesses < numPerfect )
439+ while ( numSuccesses < numPerfect && ! cancelled )
417440 //while(costTotal > 0.1f && i < trainCount)
418441 //for (int i = 0; i < trainCount; i++)
419442 {
@@ -482,126 +505,132 @@ private async Task TrainNetwork(int numPerfect, float percentForSuccess, int par
482505 numActualTestsCorrect = 0 ;
483506 Parallel . ForEach ( neuralTrainingCandidates . OrderBy ( n => random . NextDouble ( ) ) , new ParallelOptions ( ) { MaxDegreeOfParallelism = parallelism } , song =>
484507 {
485- float maxAvgNoteLength = song . Max ( c => c . Id == 9 ? 0 : c . avgNoteLength ) ;
486- float maxNoteLength = song . Max ( c => c . Id == 9 ? 0 : c . totalNoteLength ) ;
487- float maxNumNotes = song . Max ( c => c . Id == 9 ? 0 : c . numNotes ) ;
508+ if ( ! cancelled )
509+ {
510+ float maxAvgNoteLength = song . Max ( c => c . Id == 9 ? 0 : c . avgNoteLength ) ;
511+ float maxNoteLength = song . Max ( c => c . Id == 9 ? 0 : c . totalNoteLength ) ;
512+ float maxNumNotes = song . Max ( c => c . Id == 9 ? 0 : c . numNotes ) ;
488513
489- float maxTickNumber = song . SelectMany ( c => c . tickNotes ) . SelectMany ( kvp => kvp . Value ) . Max ( n => n . tickNumber ) ;
514+ float maxTickNumber = song . SelectMany ( c => c . tickNotes ) . SelectMany ( kvp => kvp . Value ) . Max ( n => n . tickNumber ) ;
490515
491- Dictionary < MidiChannelItem , float > channelResults = new Dictionary < MidiChannelItem , float > ( ) ;
492- foreach ( var channel in song )
493- {
494- var inputs = channel . GetNeuralInputs ( maxAvgNoteLength , maxNumNotes , maxNoteLength ) ;
495- //var inputs = channel.GetRecurrentInput(noteParams, maxTickNumber);
496- //var neuralResults = tsm.neural.FeedForwardRecurrent(inputs);
497- var neuralResults = tsm . neural . FeedForward ( inputs ) ;
498- channelResults [ channel ] = neuralResults [ 0 ] ;
499- }
516+ Dictionary < MidiChannelItem , float > channelResults = new Dictionary < MidiChannelItem , float > ( ) ;
517+ foreach ( var channel in song )
518+ {
519+ var inputs = channel . GetNeuralInputs ( maxAvgNoteLength , maxNumNotes , maxNoteLength ) ;
520+ //var inputs = channel.GetRecurrentInput(noteParams, maxTickNumber);
521+ //var neuralResults = tsm.neural.FeedForwardRecurrent(inputs);
522+ var neuralResults = tsm . neural . FeedForward ( inputs ) ;
523+ channelResults [ channel ] = neuralResults [ 0 ] ;
524+ }
500525
501- var orderedResults = channelResults . OrderByDescending ( kvp => kvp . Value ) ;
526+ var orderedResults = channelResults . OrderByDescending ( kvp => kvp . Value ) ;
502527
503528
504- /*
505- var inputs = song.GetNeuralInput();
529+ /*
530+ var inputs = song.GetNeuralInput();
506531
507- var neuralResults = tsm.neural.FeedForward(inputs);
508- // The output here is a channel ID to confidence map...
509- // I need to find the ID of the best one... and rank the rest...
510- // So let's just build a quick dictionary I guess
511- var results = new Dictionary<int, float>();
532+ var neuralResults = tsm.neural.FeedForward(inputs);
533+ // The output here is a channel ID to confidence map...
534+ // I need to find the ID of the best one... and rank the rest...
535+ // So let's just build a quick dictionary I guess
536+ var results = new Dictionary<int, float>();
512537
513- for (int j = 0; j < neuralResults.Length; j++)
514- {
515- results[j] = neuralResults[j];
516- }
538+ for (int j = 0; j < neuralResults.Length; j++)
539+ {
540+ results[j] = neuralResults[j];
541+ }
517542
518- var orderedResults = results.OrderByDescending(kvp => kvp.Value);
519- */
520- // Check the number of active flute channels...
521- var numFlute = song . Where ( s => s . Active ) . Count ( ) ;
522- bool correct = true ;
543+ var orderedResults = results.OrderByDescending(kvp => kvp.Value);
544+ */
545+ // Check the number of active flute channels...
546+ var numFlute = song . Where ( s => s . Active ) . Count ( ) ;
547+ bool correct = true ;
523548
524- for ( int j = 0 ; j < numFlute ; j ++ )
525- {
526- bool ? existsAndCorrect = song . Where ( s => s . Id == orderedResults . ElementAt ( j ) . Key . Id ) . SingleOrDefault ( ) ? . Active ;
527- if ( ! existsAndCorrect . HasValue || ! existsAndCorrect . Value )
528- correct = false ;
529- }
549+ for ( int j = 0 ; j < numFlute ; j ++ )
550+ {
551+ bool ? existsAndCorrect = song . Where ( s => s . Id == orderedResults . ElementAt ( j ) . Key . Id ) . SingleOrDefault ( ) ? . Active ;
552+ if ( ! existsAndCorrect . HasValue || ! existsAndCorrect . Value )
553+ correct = false ;
554+ }
530555
531- if ( correct )
532- Interlocked . Increment ( ref numTestsCorrect ) ;
556+ if ( correct )
557+ Interlocked . Increment ( ref numTestsCorrect ) ;
533558
559+ }
534560 //await Task.Delay(0); // Let the form live between iterations
535561 } ) ;
536562 Parallel . ForEach ( neuralTestCandidates . OrderBy ( n => random . NextDouble ( ) ) , new ParallelOptions ( ) { MaxDegreeOfParallelism = parallelism } , song =>
537563 {
538- float maxAvgNoteLength = song . Max ( c => c . Id == 9 ? 0 : c . avgNoteLength ) ;
539- float maxNoteLength = song . Max ( c => c . Id == 9 ? 0 : c . totalNoteLength ) ;
540- float maxNumNotes = song . Max ( c => c . Id == 9 ? 0 : c . numNotes ) ;
541-
542- float maxTickNumber = song . SelectMany ( c => c . tickNotes ) . SelectMany ( kvp => kvp . Value ) . Max ( n => n . tickNumber ) ;
543-
544- Dictionary < MidiChannelItem , float > channelResults = new Dictionary < MidiChannelItem , float > ( ) ;
545- foreach ( var channel in song )
564+ if ( ! cancelled )
546565 {
547- //var inputs = channel.GetRecurrentInput(noteParams, maxTickNumber);
548- //var neuralResults = tsm.neural.FeedForwardRecurrent(inputs);
549- var inputs = channel . GetNeuralInputs ( maxAvgNoteLength , maxNumNotes , maxNoteLength ) ;
550- var neuralResults = tsm . neural . FeedForward ( inputs ) ;
551- channelResults [ channel ] = neuralResults [ 0 ] ;
552- }
566+ float maxAvgNoteLength = song . Max ( c => c . Id == 9 ? 0 : c . avgNoteLength ) ;
567+ float maxNoteLength = song . Max ( c => c . Id == 9 ? 0 : c . totalNoteLength ) ;
568+ float maxNumNotes = song . Max ( c => c . Id == 9 ? 0 : c . numNotes ) ;
553569
554- var orderedResults = channelResults . OrderByDescending ( kvp => kvp . Value ) ;
570+ float maxTickNumber = song . SelectMany ( c => c . tickNotes ) . SelectMany ( kvp => kvp . Value ) . Max ( n => n . tickNumber ) ;
555571
572+ Dictionary < MidiChannelItem , float > channelResults = new Dictionary < MidiChannelItem , float > ( ) ;
573+ foreach ( var channel in song )
574+ {
575+ //var inputs = channel.GetRecurrentInput(noteParams, maxTickNumber);
576+ //var neuralResults = tsm.neural.FeedForwardRecurrent(inputs);
577+ var inputs = channel . GetNeuralInputs ( maxAvgNoteLength , maxNumNotes , maxNoteLength ) ;
578+ var neuralResults = tsm . neural . FeedForward ( inputs ) ;
579+ channelResults [ channel ] = neuralResults [ 0 ] ;
580+ }
556581
557- /*
558- var inputs = song.GetNeuralInput();
582+ var orderedResults = channelResults . OrderByDescending ( kvp => kvp . Value ) ;
559583
560- var neuralResults = tsm.neural.FeedForward(inputs);
561- // The output here is a channel ID to confidence map...
562- // I need to find the ID of the best one... and rank the rest...
563- // So let's just build a quick dictionary I guess
564- var results = new Dictionary<int, float>();
565584
566- for (int j = 0; j < neuralResults.Length; j++)
567- {
568- results[j] = neuralResults[j];
569- }
585+ /*
586+ var inputs = song.GetNeuralInput();
570587
571- var orderedResults = results.OrderByDescending(kvp => kvp.Value );
572- */
573- // Check the number of active flute channels ...
574- var numFlute = song . Where ( s => s . Active ) . Count ( ) ;
575- bool correct = true ;
588+ var neuralResults = tsm.neural.FeedForward(inputs );
589+ // The output here is a channel ID to confidence map...
590+ // I need to find the ID of the best one... and rank the rest ...
591+ // So let's just build a quick dictionary I guess
592+ var results = new Dictionary<int, float>() ;
576593
577- for ( int j = 0 ; j < numFlute ; j ++ )
578- {
579- bool ? existsAndCorrect = song . Where ( s => s . Id == orderedResults . ElementAt ( j ) . Key . Id ) . SingleOrDefault ( ) ? . Active ;
580- if ( ! existsAndCorrect . HasValue || ! existsAndCorrect . Value )
581- correct = false ;
582- }
594+ for (int j = 0; j < neuralResults.Length; j++)
595+ {
596+ results[j] = neuralResults[j];
597+ }
583598
584- if ( correct )
585- {
586- Interlocked . Increment ( ref numTestsCorrect ) ;
587- Interlocked . Increment ( ref numActualTestsCorrect ) ;
588- }
599+ var orderedResults = results.OrderByDescending(kvp => kvp.Value);
600+ */
601+ // Check the number of active flute channels...
602+ var numFlute = song . Where ( s => s . Active ) . Count ( ) ;
603+ bool correct = true ;
589604
590- //int fluteCount = 0;
591- //foreach (var channel in channelResults.Keys)
592- ////foreach (var channel in activeChannels)
593- //{
594- // //var channel = activeChannels.Where(c => c.Id == orderedResults.ElementAt(i).Key).SingleOrDefault();
595- //
596- // if (channel != null)
597- // {
598- // Console.WriteLine($"{channel.Name} ({channel.Id}) - Neural Score: {channelResults[channel]}");
599- // //Console.WriteLine($"{channel.Name} ({channel.Id}) - Neural Score: {neuralResults[channel.Id]}");
600- // //channel.Name += $"(Flute Rank {++fluteCount} - {Math.Round(channelResults[channel], 2)}%)";
601- // }
602- //}
605+ for ( int j = 0 ; j < numFlute ; j ++ )
606+ {
607+ bool ? existsAndCorrect = song . Where ( s => s . Id == orderedResults . ElementAt ( j ) . Key . Id ) . SingleOrDefault ( ) ? . Active ;
608+ if ( ! existsAndCorrect . HasValue || ! existsAndCorrect . Value )
609+ correct = false ;
610+ }
603611
604- //await Task.Delay(0); // Let the form live between iterations
612+ if ( correct )
613+ {
614+ Interlocked . Increment ( ref numTestsCorrect ) ;
615+ Interlocked . Increment ( ref numActualTestsCorrect ) ;
616+ }
617+
618+ //int fluteCount = 0;
619+ //foreach (var channel in channelResults.Keys)
620+ ////foreach (var channel in activeChannels)
621+ //{
622+ // //var channel = activeChannels.Where(c => c.Id == orderedResults.ElementAt(i).Key).SingleOrDefault();
623+ //
624+ // if (channel != null)
625+ // {
626+ // Console.WriteLine($"{channel.Name} ({channel.Id}) - Neural Score: {channelResults[channel]}");
627+ // //Console.WriteLine($"{channel.Name} ({channel.Id}) - Neural Score: {neuralResults[channel.Id]}");
628+ // //channel.Name += $"(Flute Rank {++fluteCount} - {Math.Round(channelResults[channel], 2)}%)";
629+ // }
630+ //}
631+
632+ //await Task.Delay(0); // Let the form live between iterations
633+ }
605634 } ) ;
606635 BeginInvoke ( ( MethodInvoker ) delegate
607636 {
@@ -626,11 +655,11 @@ private async Task TrainNetwork(int numPerfect, float percentForSuccess, int par
626655 {
627656 richTextBox1 . AppendText ( $ "\n \n ----Training Complete----\n \n Network saved to CustomNetwork file to { savePath } \n You may now load midis and check it for yourself, and this will be your new default network. \n \n You can revert at any time by deleting this file, or train again to replace it") ;
628657 richTextBox1 . ScrollToCaret ( ) ;
629- textBoxSizes . Enabled = true ;
630658 buttonTrain . Enabled = true ;
631659 textBox1 . Enabled = true ;
632660 textBoxParallel . Enabled = true ;
633661 textBox2 . Enabled = true ;
662+ buttonTrain . Text = "Train" ;
634663 } ) ;
635664 if ( File . Exists ( savePath ) )
636665 File . Delete ( savePath ) ;
0 commit comments