
Commit cc3e48e

Author: Lojze Žust
Message: SLR training script
1 parent: 1036ef8

File tree

2 files changed: +83 -2 lines changed


README.md

Lines changed: 32 additions & 2 deletions
@@ -60,6 +60,25 @@ For additional details please refer to the [paper](https://arxiv.org/abs/2108.00
### SLR Training

Use the utility script `tools/train_slr.sh` to train a model using the entire SLR pipeline.

```bash
chmod +x tools/train_slr.sh
tools/train_slr.sh
```

The script contains the following variables, which can be changed to achieve the desired results:

- `MASTR_DIR`: Location of the dataset used for training.
- `ARCHITECTURE`: Which architecture to use (use `python tools/train.py warmup --help` for more info).
- `MODEL_NAME`: Name of the model. Used for saving logs and weights.
- `BATCH_SIZE`: Batch size per GPU.
- `WARMUP_EPOCHS`: Number of epochs for the warm-up phase.
- `FINETUNE_EPOCHS`: Number of epochs for the fine-tuning phase.
- `NUM_ITER`: Number of iterations of SLR pseudo-label estimation and fine-tuning.

Individual steps of the SLR pipeline can also be executed separately with the following Python scripts.

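With the defaults above (`MODEL_NAME=wasr_slr_v2`, `NUM_ITER=2`), `tools/train_slr.sh` trains one model per iteration: the warm-up model gets the `_it0` suffix and each SLR iteration `i` gets `_it${i}`. A minimal sketch of the naming scheme (illustration only, not part of the repository):

```bash
# Names generated by tools/train_slr.sh: the warm-up model is _it0,
# then one model per SLR iteration (_it1 .. _itNUM_ITER).
MODEL_NAME=wasr_slr_v2
NUM_ITER=2

model_names() {
    for i in $(seq 0 "$NUM_ITER"); do
        echo "${MODEL_NAME}_it${i}"
    done
}

model_names
```

Logs and checkpoints for each of these names land under `output/logs/<name>/version_N/`, which is how the script locates the previous iteration's weights.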
#### Step I: Feature warm-up
Train an initial model on partial labels generated from weak annotations and IMU. Uses additional object-wise losses.
@@ -71,6 +90,7 @@ python tools/train.py warmup \
    --batch-size 4
```

Use the `--help` switch for more details on all possible arguments and settings.

#### Step II: Generate pseudo labels

@@ -85,6 +105,9 @@ python tools/generate_pseudo_labels.py \
This creates the pseudo-labels and stores them into `output/pseudo_labels/wasr_slr_warmup_v0`.

Use the `--help` switch for more details on all possible arguments and settings.

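A quick way to sanity-check this step is to confirm that every input image received a pseudo-label mask. A minimal sketch (the helper and paths are illustrative, not part of the repository):

```bash
# count_files DIR: number of regular files under DIR (recursive).
count_files() {
    find "$1" -type f | wc -l
}

# Illustrative check: image and pseudo-label counts should match, e.g.
#   [ "$(count_files data/mastr1325/images)" -eq \
#     "$(count_files output/pseudo_labels/wasr_slr_warmup_v0)" ] || echo "mismatch" >&2
```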
#### Step III: Fine-tune model
Fine-tune the initial model on the estimated pseudo-labels from the previous step.
@@ -99,6 +122,9 @@ python tools/train.py finetune \
    --pretrained-weights output/logs/wasr_slr_warmup/version_0/checkpoints/last.ckpt \
    --mask-dir output/pseudo_labels/wasr_slr_warmup_v0
```

Use the `--help` switch for more details on all possible arguments and settings.

### Inference
#### General inference
@@ -109,19 +135,23 @@ Run inference using a trained model. `tools/general_inference.py` script is able
export CUDA_VISIBLE_DEVICES=0,1
python tools/general_inference.py \
    --architecture wasr_resnet101 \
-    --weights-file output/logs/wasr_slr/version_0/checkpoints/last.ckpt \
+    --weights-file output/logs/wasr_slr_v2_it1/version_0/checkpoints/last.ckpt \
    --image-dir data/example_dir \
    --output-dir output/predictions/test_predictions
```

Additionally, `--imu-dir` can be used to supply a directory with the corresponding IMU horizon masks. The directory structure should match that of the image directory.

-**NOTE**: The IMU dir has to be provided for models architectures relying on IMU data (i.e. WaSR).
+**NOTE**: The IMU dir has to be provided for model architectures relying on IMU data (i.e. WaSR with IMU).

Use the `--help` switch for more details on all possible arguments and settings.

#### MODS inference
`tools/mods_inference.py` can be used in a similar fashion to run inference on the MODS benchmark.
Use the `--help` switch for more details on all possible arguments and settings.

## Pretrained models
Currently available pretrained model weights. All models are trained on the MaSTr1325 dataset using SLR and weak annotations.

tools/train_slr.sh

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
#!/bin/bash

# Arguments:
MASTR_DIR=data/mastr1325
ARCHITECTURE=wasr_resnet101_imu
MODEL_NAME=wasr_slr_v2
BATCH_SIZE=3
WARMUP_EPOCHS=25
FINETUNE_EPOCHS=50
NUM_ITER=2

# 1. Warm-up model
echo "1. Warm-up model"
python tools/train.py warmup \
    --architecture $ARCHITECTURE \
    --model-name ${MODEL_NAME}_it0 \
    --batch-size $BATCH_SIZE \
    --epochs $WARMUP_EPOCHS

for i in $(seq 1 $NUM_ITER)
do
    let "PREV_I = $i - 1"
    CUR_I=$i
    # Numeric suffix of the most recent version_N log dir of the previous iteration
    PREV_VERSION=$(ls -t1 output/logs/${MODEL_NAME}_it${PREV_I} | head -n 1 | cut -d _ -f 2)

    PREV_MODEL_FILENAME="${MODEL_NAME}_it${PREV_I}_v${PREV_VERSION}"
    PREV_MODEL_WEIGHTS=output/logs/${MODEL_NAME}_it${PREV_I}/version_${PREV_VERSION}/checkpoints/last.ckpt
    FILLED_MASKS_DIR=output/pseudo_labels/${PREV_MODEL_FILENAME}

    echo "-------------------------"
    echo "Fine-tuning, iteration $CUR_I"
    echo "-------------------------"

    # 2. Estimate pseudo labels
    echo "2. Estimate pseudo labels"
    python tools/generate_pseudo_labels.py \
        --architecture $ARCHITECTURE \
        --weights-file $PREV_MODEL_WEIGHTS \
        --output-dir $FILLED_MASKS_DIR

    # 3. Re-train model
    echo "3. Re-train model"
    python tools/train.py finetune \
        --architecture $ARCHITECTURE \
        --model-name ${MODEL_NAME}_it${CUR_I} \
        --batch-size $BATCH_SIZE \
        --pretrained-weights $PREV_MODEL_WEIGHTS \
        --mask-dir $FILLED_MASKS_DIR \
        --epochs $FINETUNE_EPOCHS

done
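The `PREV_VERSION` lookup in the script relies on the training logger creating `version_N` directories under `output/logs/<model>`; it takes the most recently modified entry and strips the `version_` prefix. That parsing can be exercised in isolation (the directory below is mocked):

```bash
# latest_version LOG_DIR: numeric suffix of the most recently modified
# version_N entry, mirroring the PREV_VERSION lookup in tools/train_slr.sh.
latest_version() {
    ls -t1 "$1" | head -n 1 | cut -d _ -f 2
}
```

Note that this picks the newest entry by modification time, so a stale version directory touched after the latest run would win; if that is a concern, the version number can be pinned explicitly instead.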
