Skip to content

Commit 20b8739

Browse files
committed
Use ThreadPoolExecutor to eliminate CUDA re-initialization during data modification in training.
1 parent 432be58 commit 20b8739

File tree

6 files changed

+43
-20
lines changed

6 files changed

+43
-20
lines changed

deepmd/pd/utils/dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
Dataset,
66
)
77

8+
from deepmd.pd.utils.env import (
9+
NUM_WORKERS,
10+
)
811
from deepmd.utils.data import (
912
DataRequirementItem,
1013
DeepmdData,
@@ -32,7 +35,7 @@ def __len__(self):
3235

3336
def __getitem__(self, index):
3437
"""Get a frame from the selected system."""
35-
b_data = self._data_system.get_item_paddle(index)
38+
b_data = self._data_system.get_item_paddle(index, NUM_WORKERS)
3639
b_data["natoms"] = self._natoms_vec
3740
return b_data
3841

deepmd/pt/train/training.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,9 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp:
339339
if validation_data is not None:
340340
validation_data.add_data_requirement(data_requirement)
341341
# Preload and apply modifiers to all data before computing statistics
342-
training_data.preload_and_modify_all_data()
342+
training_data.preload_and_modify_all_data_torch()
343343
if validation_data is not None:
344-
validation_data.preload_and_modify_all_data()
344+
validation_data.preload_and_modify_all_data_torch()
345345
self.get_sample_func = single_model_stat(
346346
self.model,
347347
model_params.get("data_stat_nbatch", 10),
@@ -385,9 +385,9 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp:
385385
if validation_data[model_key] is not None:
386386
validation_data[model_key].add_data_requirement(data_requirement)
387387
# Preload and apply modifiers to all data before computing statistics
388-
training_data[model_key].preload_and_modify_all_data()
388+
training_data[model_key].preload_and_modify_all_data_torch()
389389
if validation_data[model_key] is not None:
390-
validation_data[model_key].preload_and_modify_all_data()
390+
validation_data[model_key].preload_and_modify_all_data_torch()
391391
self.get_sample_func[model_key] = single_model_stat(
392392
self.model[model_key],
393393
model_params["model_dict"][model_key].get("data_stat_nbatch", 10),

deepmd/pt/utils/dataloader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,9 @@ def print_summary(
238238
[ss._data_system.pbc for ss in self.systems],
239239
)
240240

241-
def preload_and_modify_all_data(self) -> None:
241+
def preload_and_modify_all_data_torch(self) -> None:
242242
for system in self.systems:
243-
system.preload_and_modify_all_data()
243+
system.preload_and_modify_all_data_torch()
244244

245245

246246
def collate_batch(batch: list[dict[str, Any]]) -> dict[str, Any]:

deepmd/pt/utils/dataset.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from deepmd.pt.modifier import (
1313
BaseModifier,
1414
)
15+
from deepmd.pt.utils.env import (
16+
NUM_WORKERS,
17+
)
1518
from deepmd.utils.data import (
1619
DataRequirementItem,
1720
DeepmdData,
@@ -48,7 +51,7 @@ def __len__(self) -> int:
4851

4952
def __getitem__(self, index: int) -> dict[str, Any]:
5053
"""Get a frame from the selected system."""
51-
b_data = self._data_system.get_item_torch(index)
54+
b_data = self._data_system.get_item_torch(index, NUM_WORKERS)
5255
b_data["natoms"] = self._natoms_vec
5356
return b_data
5457

@@ -68,5 +71,5 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N
6871
output_natoms_for_type_sel=data_item["output_natoms_for_type_sel"],
6972
)
7073

71-
def preload_and_modify_all_data(self) -> None:
72-
self._data_system.preload_and_modify_all_data()
74+
def preload_and_modify_all_data_torch(self) -> None:
75+
self._data_system.preload_and_modify_all_data_torch(NUM_WORKERS)

deepmd/utils/data.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -253,26 +253,38 @@ def check_test_size(self, test_size: int) -> bool:
253253
"""Check if the system can get a test dataset with `test_size` frames."""
254254
return self.check_batch_size(test_size)
255255

256-
def get_item_torch(self, index: int) -> dict:
256+
def get_item_torch(
257+
self,
258+
index: int,
259+
num_worker: int,
260+
) -> dict:
257261
"""Get a single frame data . The frame is picked from the data system by index. The index is coded across all the sets.
258262
259263
Parameters
260264
----------
261265
index
262266
index of the frame
267+
num_worker
268+
number of workers for parallel data modification
263269
"""
264-
return self.get_single_frame(index)
270+
return self.get_single_frame(index, num_worker)
265271

266-
def get_item_paddle(self, index: int) -> dict:
272+
def get_item_paddle(
273+
self,
274+
index: int,
275+
num_worker: int,
276+
) -> dict:
267277
"""Get a single frame data . The frame is picked from the data system by index. The index is coded across all the sets.
268278
Same with PyTorch backend.
269279
270280
Parameters
271281
----------
272282
index
273283
index of the frame
284+
num_worker
285+
number of workers for parallel data modification
274286
"""
275-
return self.get_single_frame(index)
287+
return self.get_single_frame(index, num_worker)
276288

277289
def get_batch(self, batch_size: int) -> dict:
278290
"""Get a batch of data with `batch_size` frames. The frames are randomly picked from the data system.
@@ -383,7 +395,7 @@ def get_natoms_vec(self, ntypes: int) -> np.ndarray:
383395
tmp = np.append(tmp, natoms_vec)
384396
return tmp.astype(np.int32)
385397

386-
def get_single_frame(self, index: int) -> dict:
398+
def get_single_frame(self, index: int, num_worker: int) -> dict:
387399
"""Orchestrates loading a single frame efficiently using memmap."""
388400
# Check if we have a cached modified frame and use_modifier_cache is True
389401
if (
@@ -488,14 +500,19 @@ def get_single_frame(self, index: int) -> dict:
488500
frame_data["fid"] = index
489501

490502
if self.modifier is not None:
491-
# Apply modifier if it exists
492-
self.modifier.modify_data(frame_data, self)
503+
with ThreadPoolExecutor(max_workers=num_worker) as executor:
504+
# Apply modifier if it exists
505+
executor.submit(
506+
self.modifier.modify_data,
507+
frame_data,
508+
self,
509+
)
493510
if self.use_modifier_cache:
494511
# Cache the modified frame to avoid recomputation
495512
self._modified_frame_cache[index] = copy.deepcopy(frame_data)
496513
return frame_data
497514

498-
def preload_and_modify_all_data(self) -> None:
515+
def preload_and_modify_all_data_torch(self, num_worker: int) -> None:
499516
"""Preload all frames and apply modifier to cache them.
500517
501518
This method is useful when use_modifier_cache is True and you want to
@@ -507,7 +524,7 @@ def preload_and_modify_all_data(self) -> None:
507524
log.info("Preloading and modifying all data frames...")
508525
for i in range(self.nframes):
509526
if i not in self._modified_frame_cache:
510-
self.get_single_frame(i)
527+
self.get_single_frame(i, num_worker)
511528
if (i + 1) % 100 == 0:
512529
log.info(f"Processed {i + 1}/{self.nframes} frames")
513530
log.info("All frames preloaded and modified.")

source/tests/pt/test_data_modifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def test_inference(self):
325325
"type": "scaling_tester",
326326
"model_name": "frozen_model_dm.pth",
327327
"sfactor": sfactor,
328-
"use_cache": True,
328+
"use_cache": self.param[2],
329329
}
330330

331331
trainer = get_trainer(tmp_config)

0 commit comments

Comments
 (0)