
There might be an error in the train_maml function. #6

@lishanwu135

The loss used to calculate the gradient for the meta update comes from only one task. According to the original paper, however, it should be the sum over all sampled tasks. Please look at step 8 in the code below: `test_loss` is computed inside the loop `for i, t in enumerate(random.sample(dataset, len(dataset)))`, so it reflects only a single task.

```python
# Step 2: instead of checking for convergence, we train for a number
# of epochs
for _ in range(epochs):
    total_loss = 0
    losses = []
    start = time.time()
    # Step 3 and 4
    for i, t in enumerate(random.sample(dataset, len(dataset))):
        x, y = np_to_tensor(t.batch())
        model.forward(x)  # run forward pass to initialize weights
        with tf.GradientTape() as test_tape:
            # test_tape.watch(model.trainable_variables)
            # Step 5
            with tf.GradientTape() as train_tape:
                train_loss, _ = compute_loss(model, x, y)
            # Step 6
            gradients = train_tape.gradient(train_loss, model.trainable_variables)
            k = 0
            model_copy = copy_model(model, x)
            for j in range(len(model_copy.layers)):
                model_copy.layers[j].kernel = tf.subtract(model.layers[j].kernel,
                            tf.multiply(lr_inner, gradients[k]))
                model_copy.layers[j].bias = tf.subtract(model.layers[j].bias,
                            tf.multiply(lr_inner, gradients[k+1]))
                k += 2
            # Step 8
            test_loss, logits = compute_loss(model_copy, x, y)
        # Step 8
        gradients = test_tape.gradient(test_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
```
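To make the distinction concrete, here is a minimal, framework-free sketch of what summing the post-adaptation loss over all tasks before a single outer update looks like. It uses a scalar linear model `f(x) = w * x` with a hand-derived quadratic-loss gradient; `maml_meta_step`, `lr_inner`, `lr_outer`, and the task format `(x, y)` are illustrative assumptions, not names from this repository.

```python
import numpy as np

def maml_meta_step(w, tasks, lr_inner=0.01, lr_outer=0.001):
    """One MAML meta-update for a scalar linear model f(x) = w * x.

    The test-loss gradient is accumulated over ALL sampled tasks and a
    single outer update is applied afterwards -- the behaviour the paper
    describes, as opposed to calling the optimizer once per task.
    Gradients are hand-derived for the loss 0.5 * (w * x - y) ** 2.
    """
    meta_grad = 0.0
    for x, y in tasks:
        # Inner (task-specific) step on the train loss.
        grad_train = (w * x - y) * x
        w_adapted = w - lr_inner * grad_train
        # Gradient of the test loss w.r.t. the ORIGINAL w,
        # differentiating through the inner update w_adapted(w).
        dw_adapted_dw = 1.0 - lr_inner * x * x
        grad_test = (w_adapted * x - y) * x * dw_adapted_dw
        meta_grad += grad_test        # accumulate across tasks (Step 8)
    return w - lr_outer * meta_grad   # single meta update per task batch
```

In the TensorFlow code above, the analogous change would be to collect each task's `test_loss` into a running sum inside the loop and move the `test_tape.gradient(...)` / `optimizer.apply_gradients(...)` pair outside it, so one outer step uses the summed loss.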
