Skip to content

Commit d52c815

Browse files
committed
Simplify training loop
1 parent 711299e commit d52c815

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

alignit/models/alignnet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,5 @@ def forward(self, rgb_images, vector_inputs=None, depth_images=None):
127127
features.append(vec_feats)
128128

129129
fused = torch.cat(features, dim=1)
130-
print("Fused shape:", fused.shape)
131130

132131
return self.head(fused) # (B, output_dim)

alignit/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def main(cfg: TrainConfig):
5454
net.train()
5555

5656
for epoch in range(cfg.epochs):
57+
total_loss = 0
5758
for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
5859
images = batch["images"]
5960
depth_images_pil = batch["depth_images"]
@@ -106,8 +107,9 @@ def main(cfg: TrainConfig):
106107
loss.backward()
107108
optimizer.step()
108109

109-
tqdm.write(f"Loss: {loss.item():.4f}")
110+
total_loss += loss.item()
110111

112+
tqdm.write(f"Loss: {total_loss / len(train_loader):.4f}")
111113
torch.save(net.state_dict(), cfg.model.path)
112114
tqdm.write(f"Model saved as {cfg.model.path}")
113115

0 commit comments

Comments
 (0)