@@ -166,23 +166,18 @@ export class Link {
166166 accErrorDer = 0
167167 /** Number of accumulated derivatives since the last update. */
168168 numAccumulatedDers = 0
169- regularization : RegularizationFunction
170169
171170 /**
172171 * Constructs a link in the neural network initialized with random weight.
173172 *
174173 * @param source The source node.
175174 * @param dest The destination node.
176- * @param regularization The regularization function that computes the
177- * penalty for this weight. If null, there will be no regularization.
178175 * @param initZero
179176 */
180- constructor ( source : Node , dest : Node ,
181- regularization : RegularizationFunction , initZero ?: boolean ) {
177+ constructor ( source : Node , dest : Node , initZero ?: boolean ) {
182178 this . id = source . id + '-' + dest . id
183179 this . source = source
184180 this . dest = dest
185- this . regularization = regularization
186181 if ( initZero ) {
187182 this . weight = 0
188183 }
@@ -197,16 +192,12 @@ export class Link {
197192 * 3 nodes in second hidden layer and 1 output node.
198193 * @param activation The activation function of every hidden node.
199194 * @param outputActivation The activation function for the output nodes.
200- * @param regularization The regularization function that computes a penalty
201- * for a given weight (parameter) in the network. If null, there will be
202- * no regularization.
203195 * @param inputIds List of ids for the input nodes.
204196 * @param initZero
205197 */
206198export function buildNetwork (
207199 networkShape : number [ ] , activation : ActivationFunction ,
208200 outputActivation : ActivationFunction ,
209- regularization : RegularizationFunction ,
210201 inputIds : string [ ] , initZero ?: boolean ) : Node [ ] [ ] {
211202 let numLayers = networkShape . length
212203 let id = 1
@@ -232,7 +223,7 @@ export function buildNetwork (
232223 // Add links from nodes in the previous layer to this node.
233224 for ( let j = 0 ; j < network [ layerIdx - 1 ] . length ; j ++ ) {
234225 let prevNode = network [ layerIdx - 1 ] [ j ]
235- let link = new Link ( prevNode , node , regularization , initZero )
226+ let link = new Link ( prevNode , node , initZero )
236227 prevNode . outputs . push ( link )
237228 node . inputLinks . push ( link )
238229 }
@@ -330,12 +321,25 @@ export function backProp (network: Node[][], target: number,
330321 }
331322}
332323
324+ type UpdateWeights = {
325+ network : Node [ ] [ ] ,
326+ learningRate : number ,
327+ regularization : RegularizationFunction ,
328+ regularizationRate : number ,
329+ }
330+
333331/**
334332 * Updates the weights of the network using the previously accumulated error
335333 * derivatives.
336334 */
337- export function updateWeights ( network : Node [ ] [ ] , learningRate : number ,
338- regularizationRate : number ) {
335+ export function updateWeights (
336+ {
337+ network,
338+ learningRate,
339+ regularization,
340+ regularizationRate,
341+ } : UpdateWeights ,
342+ ) {
339343 for ( let layerIdx = 1 ; layerIdx < network . length ; layerIdx ++ ) {
340344 let currentLayer = network [ layerIdx ]
341345 for ( let i = 0 ; i < currentLayer . length ; i ++ ) {
@@ -352,16 +356,16 @@ export function updateWeights (network: Node[][], learningRate: number,
352356 if ( link . isDead ) {
353357 continue
354358 }
355- let regulDer = link . regularization ?
356- link . regularization . der ( link . weight ) : 0
359+ let regulDer = regularization ?
360+ regularization . der ( link . weight ) : 0
357361 if ( link . numAccumulatedDers > 0 ) {
358362 // Update the weight based on dE/dw.
359363 link . weight = link . weight -
360364 ( learningRate / link . numAccumulatedDers ) * link . accErrorDer
361365 // Further update the weight based on regularization.
362366 let newLinkWeight = link . weight -
363367 ( learningRate * regularizationRate ) * regulDer
364- if ( link . regularization === RegularizationFunction . L1 &&
368+ if ( regularization === RegularizationFunction . L1 &&
365369 link . weight * newLinkWeight < 0 ) {
366370 // The weight crossed 0 due to the regularization term. Set it to 0.
367371 link . weight = 0
0 commit comments