-
Notifications
You must be signed in to change notification settings - Fork 28
Expand file tree
/
Copy path0b_preprocess_training_data.py
More file actions
95 lines (75 loc) · 3.02 KB
/
0b_preprocess_training_data.py
File metadata and controls
95 lines (75 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""Translate data from HuMoR-style npz format to an hdf5-based one.
Due to AMASS licensing, we unfortunately can't re-distribute our preprocessed dataset. If you have questions
or run into issues, please reach out.
"""
import queue
import threading
import time
from pathlib import Path
import h5py
import torch
import torch.cuda
import tyro
from egoallo import fncsmpl
from egoallo.data.amass import EgoTrainingData
def main(
smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz"),
data_npz_dir: Path = Path("./data/processed_30fps_no_skating/"),
output_file: Path = Path("./data/egoalgo_no_skating_dataset.hdf5"),
output_list_file: Path = Path("./data/egoalgo_no_skating_dataset_files.txt"),
include_hands: bool = True,
) -> None:
body_model = fncsmpl.SmplhModel.load(smplh_npz_path)
assert torch.cuda.is_available()
task_queue = queue.Queue[Path]()
for path in list(data_npz_dir.glob("**/*.npz")):
task_queue.put_nowait(path)
total_count = task_queue.qsize()
start_time = time.time()
output_hdf5 = h5py.File(output_file, "w")
file_list: list[str] = []
def worker(device_idx: int) -> None:
device_body_model = body_model.to("cuda:" + str(device_idx))
while True:
try:
npz_path = task_queue.get_nowait()
except queue.Empty:
break
print(f"Processing {npz_path} on device {device_idx}...")
train_data = EgoTrainingData.load_from_npz(
device_body_model, npz_path, include_hands=include_hands
)
assert "neutral" in str(npz_path)
group_name = str(npz_path).rpartition("neutral/")[2]
print(f"Writing to group {group_name} on {device_idx}...")
group = output_hdf5.create_group(group_name)
file_list.append(group_name)
for k, v in vars(train_data).items():
# No need to write the mask, which will always be ones when we
# load from the npz file!
if k == "mask":
continue
# Chunk into 32 timesteps at a time.
assert v.dtype == torch.float32
if v.shape[0] == train_data.T_world_cpf.shape[0]:
chunks = (min(32, v.shape[0]),) + v.shape[1:]
else:
assert v.shape[0] == 1
chunks = v.shape
group.create_dataset(k, data=v.numpy(force=True), chunks=chunks)
print(
f"Finished ~{total_count - task_queue.qsize()}/{total_count},",
f"{(total_count - task_queue.qsize()) / total_count * 100:.2f}% in",
f"{time.time() - start_time} seconds",
)
workers = [
threading.Thread(target=worker, args=(i,))
for i in range(torch.cuda.device_count())
]
for w in workers:
w.start()
for w in workers:
w.join()
output_list_file.write_text("\n".join(file_list))
if __name__ == "__main__":
tyro.cli(main)