@@ -36,6 +36,7 @@ def __init__(
3636 "gpu" : False ,
3737 "ddp" : False ,
3838 "mpi" : False ,
39+ "mps" : False ,
3940 "device" : "cpu" ,
4041 "openpmd_configuration" : {},
4142 "openpmd_granularity" : 1 ,
@@ -80,6 +81,17 @@ def _update_gpu(self, new_gpu):
8081 """
8182 self ._configuration ["gpu" ] = new_gpu
8283
84+ def _update_mps (self , new_mps ):
85+ """
86+ Propagate new Apple silicon GPU setting to parameter subclasses.
87+
88+ Parameters
89+ ----------
90+ new_mps : bool
91+ New GPU setting.
92+ """
93+ self ._configuration ["mps" ] = new_mps
94+
8395 def _update_ddp (self , new_ddp ):
8496 """
8597 Propagate new DDP setting to parameter subclasses.
@@ -1630,9 +1642,14 @@ def __init__(self):
16301642 self .manual_seed = None
16311643
16321644 # Properties
1645+ # Needed for first initialization, or else the resetting of the device
1646+ # fails. This is because we dynamically set the device based on MPS
1647+ # and GPU values.
1648+ self ._use_mps = False
16331649 self .use_gpu = False
16341650 self .use_ddp = False
16351651 self .use_mpi = False
1652+ self .use_mps = False
16361653 self .verbosity = 1
16371654 self .device = "cpu"
16381655 self .openpmd_configuration = {}
@@ -1700,6 +1717,10 @@ def use_gpu(self, value):
17001717 if value is False :
17011718 self ._use_gpu = False
17021719 else :
1720+ # Cannot use CUDA and MPS at the same time.
1721+ # Also don't think anyone would want that.
1722+ self .use_mps = False
1723+
17031724 if torch .cuda .is_available ():
17041725 self ._use_gpu = True
17051726 else :
@@ -1723,6 +1744,36 @@ def use_gpu(self, value):
17231744 self .running ._update_gpu (self .use_gpu )
17241745 self .hyperparameters ._update_gpu (self .use_gpu )
17251746
1747+ @property
1748+ def use_mps (self ):
1749+ """Control whether an Apple silicon GPU is used."""
1750+ return self ._use_mps
1751+
1752+ @use_mps .setter
1753+ def use_mps (self , value ):
1754+ if value is False :
1755+ self ._use_mps = False
1756+ else :
1757+ # Cannot use CUDA and MPS at the same time.
1758+ # Also don't think anyone would want that.
1759+ self .use_gpu = False
1760+ if torch .mps .is_available ():
1761+ self ._use_mps = True
1762+ else :
1763+ parallel_warn (
1764+ "GPU requested, but no GPU found. MALA will "
1765+ "operate with CPU only."
1766+ )
1767+
1768+ # Invalidate, will be updated in setter.
1769+ self .device = None
1770+ self .network ._update_mps (self .use_mps )
1771+ self .descriptors ._update_mps (self .use_mps )
1772+ self .targets ._update_mps (self .use_mps )
1773+ self .data ._update_mps (self .use_mps )
1774+ self .running ._update_mps (self .use_mps )
1775+ self .hyperparameters ._update_mps (self .use_mps )
1776+
17261777 @property
17271778 def use_ddp (self ):
17281779 """Control whether ddp is used for parallel training."""
@@ -1765,6 +1816,9 @@ def device(self, value):
17651816 device_id = get_local_rank ()
17661817 if self .use_gpu :
17671818 self ._device = "cuda:" f"{ device_id } "
1819+ elif self .use_mps :
1820+ if torch .mps .is_available ():
1821+ self ._device = "mps:" f"{ device_id } "
17681822 else :
17691823 self ._device = "cpu"
17701824 self .network ._update_device (self ._device )
0 commit comments