Commit 4e5d671

Training works
1 parent 4a46f67 commit 4e5d671


mala/common/parameters.py

Lines changed: 54 additions & 0 deletions
@@ -36,6 +36,7 @@ def __init__(
             "gpu": False,
             "ddp": False,
             "mpi": False,
+            "mps" : False,
             "device": "cpu",
             "openpmd_configuration": {},
             "openpmd_granularity": 1,
@@ -80,6 +81,17 @@ def _update_gpu(self, new_gpu):
         """
         self._configuration["gpu"] = new_gpu

+    def _update_mps(self, new_mps):
+        """
+        Propagate new Apple silicon GPU setting to parameter subclasses.
+
+        Parameters
+        ----------
+        new_mps : bool
+            New GPU setting.
+        """
+        self._configuration["mps"] = new_mps
+
     def _update_ddp(self, new_ddp):
         """
         Propagate new DDP setting to parameter subclasses.
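The new _update_mps hook mirrors the existing _update_gpu and _update_ddp methods: each parameter subsection keeps a _configuration dict, and the top-level object pushes the flag down into every subsection. A minimal sketch of that pattern (the class below is a simplified stand-in for illustration, not MALA's actual hierarchy):

# Simplified stand-in for a MALA parameter subsection; illustrative only.
class _ParametersSection:
    def __init__(self):
        self._configuration = {"gpu": False, "mps": False, "device": "cpu"}

    def _update_mps(self, new_mps):
        # Only the dict entry changes here; consumers read the dict later.
        self._configuration["mps"] = new_mps


section = _ParametersSection()
section._update_mps(True)
print(section._configuration)  # {'gpu': False, 'mps': True, 'device': 'cpu'}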
@@ -1630,9 +1642,14 @@ def __init__(self):
         self.manual_seed = None

         # Properties
+        # Needed for first initialization, or else the resetting of the device
+        # fails. This is because we dynamically set the device based on MPS
+        # and GPU values.
+        self._use_mps = False
         self.use_gpu = False
         self.use_ddp = False
         self.use_mpi = False
+        self.use_mps = False
         self.verbosity = 1
         self.device = "cpu"
         self.openpmd_configuration = {}
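The extra self._use_mps = False before the property assignments matters because setting use_gpu (or use_mps) re-resolves device, and the device setter reads both flags; without a pre-existing backing attribute, the very first assignment would raise an AttributeError. A stripped-down sketch of that ordering constraint (illustrative names, not the full MALA class):

class _DeviceFlags:
    def __init__(self):
        # Must exist before use_gpu is assigned below: its setter resets
        # the device, and the device setter reads self.use_mps.
        self._use_mps = False
        self.use_gpu = False
        self.use_mps = False

    @property
    def use_gpu(self):
        return self._use_gpu

    @use_gpu.setter
    def use_gpu(self, value):
        self._use_gpu = bool(value)
        self.device = None  # re-derive the device string

    @property
    def use_mps(self):
        return self._use_mps

    @use_mps.setter
    def use_mps(self, value):
        self._use_mps = bool(value)
        self.device = None  # re-derive the device string

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, _value):
        # Derived from the flags, mirroring the logic further down the diff.
        if self.use_gpu:
            self._device = "cuda:0"
        elif self.use_mps:
            self._device = "mps:0"
        else:
            self._device = "cpu"


flags = _DeviceFlags()
print(flags.device)  # "cpu"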
@@ -1700,6 +1717,10 @@ def use_gpu(self, value):
         if value is False:
             self._use_gpu = False
         else:
+            # Cannot use CUDA and MPS at the same time.
+            # Also don't think anyone would want that.
+            self.use_mps = False
+
             if torch.cuda.is_available():
                 self._use_gpu = True
             else:
@@ -1723,6 +1744,36 @@ def use_gpu(self, value):
         self.running._update_gpu(self.use_gpu)
         self.hyperparameters._update_gpu(self.use_gpu)

+    @property
+    def use_mps(self):
+        """Control whether an Apple silicon GPU is used."""
+        return self._use_mps
+
+    @use_mps.setter
+    def use_mps(self, value):
+        if value is False:
+            self._use_mps = False
+        else:
+            # Cannot use CUDA and MPS at the same time.
+            # Also don't think anyone would want that.
+            self.use_gpu = False
+            if torch.mps.is_available():
+                self._use_mps = True
+            else:
+                parallel_warn(
+                    "GPU requested, but no GPU found. MALA will "
+                    "operate with CPU only."
+                )
+
+        # Invalidate, will be updated in setter.
+        self.device = None
+        self.network._update_mps(self.use_mps)
+        self.descriptors._update_mps(self.use_mps)
+        self.targets._update_mps(self.use_mps)
+        self.data._update_mps(self.use_mps)
+        self.running._update_mps(self.use_mps)
+        self.hyperparameters._update_mps(self.use_mps)
+
     @property
     def use_ddp(self):
         """Control whether ddp is used for parallel training."""
@@ -1765,6 +1816,9 @@ def device(self, value):
             device_id = get_local_rank()
         if self.use_gpu:
             self._device = "cuda:" f"{device_id}"
+        elif self.use_mps:
+            if torch.mps.is_available():
+                self._device = "mps:" f"{device_id}"
         else:
             self._device = "cpu"
         self.network._update_device(self._device)
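The new elif reuses the CUDA device-string scheme, so MPS tensors end up on "mps:<rank>". A standalone illustration of the torch calls this branch relies on (no MALA imports; note that the diff uses torch.mps.is_available(), which needs a recent PyTorch, while torch.backends.mps.is_available() is the longer-standing spelling):

import torch

device_id = 0  # MALA derives this from get_local_rank(); 0 outside ddp runs
if torch.cuda.is_available():
    device = f"cuda:{device_id}"
elif torch.backends.mps.is_available():
    device = f"mps:{device_id}"
else:
    device = "cpu"

x = torch.ones(2, 2, device=device)  # the tensor actually lands on that device
print(x.device)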
