Skip to content

Commit 9aed6c8

Browse files
Add jsrun launcher on top of latest DS
1 parent 64d6c5a commit 9aed6c8

File tree

3 files changed

+59
-2
lines changed

3 files changed

+59
-2
lines changed

Diff for: deepspeed/launcher/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Maximum number of concurrent remote connections pdsh will open.
PDSH_MAX_FAN_OUT = 1024

# String identifiers for the supported multi-node launcher backends;
# matched against the --launcher CLI argument in runner.py.
OPENMPI_LAUNCHER = 'openmpi'
JSRUN_LAUNCHER = 'jsrun'
MPICH_LAUNCHER = 'mpich'
SLURM_LAUNCHER = 'slurm'
MVAPICH_LAUNCHER = 'mvapich'

Diff for: deepspeed/launcher/multinode_runner.py

+54
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,60 @@ def get_cmd(self, environment, active_resources):
169169
return mpirun_cmd + export_cmd + python_exec + [self.user_script
170170
] + self.user_arguments
171171

172+
class JSRunner(MultiNodeRunner):
    """Multi-node launcher backend for IBM's ``jsrun`` (LSF / Spectrum MPI
    systems such as OLCF Summit).

    Builds a ``jsrun`` command line that starts one rank per GPU using the
    resource-set geometry below, forwarding exported environment variables
    via ``-E`` and appending the user script and its arguments.
    """

    # Summit-specific resource-set geometry: 6 GPUs visible per node,
    # each resource set gets 7 cores, 1 GPU, and runs 1 task.
    # TODO: make these configurable instead of hard coded for Summit.
    SUMMIT_VISIBLE_DEVICES = '0,1,2,3,4,5'
    CORES_PER_RSET = 7
    GPUS_PER_RSET = 1
    TASKS_PER_RSET = 1

    def __init__(self, args, world_info_base64, resource_pool):
        """Store the host->slot-count resource pool and export the Summit
        GPU visibility mask for every rank."""
        super().__init__(args, world_info_base64)
        self.resource_pool = resource_pool
        # Hard coded for Summit
        self.add_export('CUDA_VISIBLE_DEVICES', self.SUMMIT_VISIBLE_DEVICES)

    def backend_exists(self):
        """Return True when an Open MPI-compatible stack is on PATH.

        Spectrum MPI (which backs jsrun) is derived from Open MPI, so
        probing for ``ompi_info`` also detects it.
        """
        #TODO: if IB is available we should suggest mvapich
        # Return an explicit bool rather than the raw path-or-None from
        # shutil.which; callers only rely on truthiness.
        return shutil.which('ompi_info') is not None

    @property
    def name(self):
        """Backend identifier used in CLI selection and error messages."""
        return "jsrun"

    def validate_args(self):
        """Reject CLI options that the jsrun backend cannot honor."""
        super().validate_args()
        #TODO: Allow for include/exclude at node-level but not gpu-level
        if self.args.include != "" or self.args.exclude != "":
            raise ValueError(
                f"{self.name} backend does not support worker include/exclusion")
        if self.args.num_nodes != -1 or self.args.num_gpus != -1:
            raise ValueError(
                f"{self.name} backend does not support limiting num nodes/gpus")

    def get_cmd(self, environment, active_resources):
        """Assemble the full jsrun command: launcher flags, ``-E`` exports,
        optional ``python -u [-m]`` prefix, then the user script and args."""
        # One process per slot across every host in the pool.
        total_process_count = sum(self.resource_pool.values())

        jsrun_cmd = [
            'jsrun',
            '-n',
            str(total_process_count),
            '-c',
            str(self.CORES_PER_RSET),
            '-g',
            str(self.GPUS_PER_RSET),
            '-a',
            str(self.TASKS_PER_RSET),
        ] + split(self.args.launcher_args)

        # Forward each exported environment variable to the remote ranks.
        export_cmd = []
        for var, val in self.exports.items():
            export_cmd += ['-E', "{}={}".format(var, val)]

        # Prefix with the interpreter unless --no_python was given;
        # --module adds -m so the target is run as a module.
        python_exec = []
        if not self.args.no_python:
            python_exec = [sys.executable, "-u"]
            if self.args.module:
                python_exec.append("-m")

        return jsrun_cmd + export_cmd + python_exec + [self.user_script
                                                      ] + self.user_arguments
172226

173227
class MPICHRunner(MultiNodeRunner):
174228
def __init__(self, args, world_info_base64, resource_pool):

Diff for: deepspeed/launcher/runner.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import signal
1919
import time
2020

21-
from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner
22-
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER
21+
from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner, JSRunner
22+
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER, JSRUN_LAUNCHER
2323
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
2424
from ..nebula.constants import NEBULA_EXPORT_ENVS
2525
from ..utils import logger
@@ -511,6 +511,8 @@ def main(args=None):
511511
runner = PDSHRunner(args, world_info_base64)
512512
elif args.launcher == OPENMPI_LAUNCHER:
513513
runner = OpenMPIRunner(args, world_info_base64, resource_pool)
514+
elif args.launcher == JSRUN_LAUNCHER:
515+
runner = JSRunner(args, world_info_base64, resource_pool)
514516
elif args.launcher == MPICH_LAUNCHER:
515517
runner = MPICHRunner(args, world_info_base64, resource_pool)
516518
elif args.launcher == MVAPICH_LAUNCHER:

0 commit comments

Comments
 (0)