The loss used to compute the meta-update gradient comes from only one task. According to the original paper, it should be the sum of the test losses over all sampled tasks. See step 8 in the code below: `test_loss` is computed inside the loop `for i, t in enumerate(random.sample(dataset, len(dataset)))`, so the meta update is applied with the test loss of a single task.
```python
# Step 2: instead of checking for convergence, we train for a number
# of epochs
for _ in range(epochs):
    total_loss = 0
    losses = []
    start = time.time()
    # Step 3 and 4
    for i, t in enumerate(random.sample(dataset, len(dataset))):
        x, y = np_to_tensor(t.batch())
        model.forward(x)  # run forward pass to initialize weights
        with tf.GradientTape() as test_tape:
            # test_tape.watch(model.trainable_variables)
            # Step 5
            with tf.GradientTape() as train_tape:
                train_loss, _ = compute_loss(model, x, y)
            # Step 6
            gradients = train_tape.gradient(train_loss, model.trainable_variables)
            k = 0
            model_copy = copy_model(model, x)
            for j in range(len(model_copy.layers)):
                model_copy.layers[j].kernel = tf.subtract(model.layers[j].kernel,
                                                          tf.multiply(lr_inner, gradients[k]))
                model_copy.layers[j].bias = tf.subtract(model.layers[j].bias,
                                                        tf.multiply(lr_inner, gradients[k+1]))
                k += 2
            # Step 8
            test_loss, logits = compute_loss(model_copy, x, y)
        # Step 8
        gradients = test_tape.gradient(test_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
```
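To make the suggested fix concrete, here is a minimal sketch of the summed-loss meta update on a toy scalar model, so the loop structure is easy to verify without TensorFlow. All names here (`theta`, `tasks`, `lr_inner`, `lr_meta`, `inner_update`, `meta_step`) are illustrative, not from this repository, and the meta gradient uses a first-order approximation for brevity. The key difference from the code above is that each task's test-loss gradient is accumulated, and `theta` is updated once per epoch with the sum:

```python
def inner_update(theta, a, lr_inner):
    # One inner SGD step on the per-task loss L(theta) = (theta - a)^2
    # (steps 5-7 in the snippet above: adapt a copy of the parameters)
    return theta - lr_inner * 2.0 * (theta - a)

def meta_step(theta, tasks, lr_inner, lr_meta):
    grad_sum = 0.0
    loss_sum = 0.0
    for a in tasks:
        theta_i = inner_update(theta, a, lr_inner)
        loss_sum += (theta_i - a) ** 2       # step 8: accumulate test loss per task
        grad_sum += 2.0 * (theta_i - a)      # first-order gradient of this task's test loss
    # Step 8: a SINGLE meta update from the sum over all sampled tasks,
    # instead of one update per task inside the loop.
    theta = theta - lr_meta * grad_sum
    return theta, loss_sum

# Toy "tasks": each task's loss is minimized at theta == a.
tasks = [0.0, 1.0, 2.0]
theta = 5.0
for _ in range(50):
    theta, loss = meta_step(theta, tasks, lr_inner=0.1, lr_meta=0.05)
```

With these toy tasks the summed meta gradient drives `theta` toward the mean of the task optima, which is the behavior a per-task update inside the loop would not produce in one step.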