Concept of Parallel Context#
Parallel Context is distributed process group manager in OSLO. You can easily create and manage distributed process groups by putting the desired parallelization sizes into Parallel Context class.
1. Create Parallel Context object#
There are three methods to create Parallel Context: from_torch
, from_slurm
, from_openmpi
.
1.1. from_torch
#
If you are using PyTorch distributed launcher, you can use from_torch
function to create it.
from oslo.torch import ParallelContext
parallel_context = ParallelContext.from_torch(
data_parallel_size=1,
tensor_parallel_size=1,
pipeline_parallel_size=1,
sequence_parallel_size=1,
expert_parallel_size=1,
)
1.2. from_slurm
#
If you are using Slurm launcher, you can use from_slurm
function to create it.
In this case, you must input host
and port
together.
from oslo.torch import ParallelContext
YOUR_HOST = ...
YOUR_PORT = ...
parallel_context = ParallelContext.from_slurm(
host=YOUR_HOST,
port=YOUR_PORT,
data_parallel_size=1,
tensor_parallel_size=1,
pipeline_parallel_size=1,
sequence_parallel_size=1,
expert_parallel_size=1,
)
1.3. from_openmpi
#
If you are using OpenMPI launcher, you can use from_openmpi
function to create it.
Similar with from_slurm
, you must input host
and port
together.
from oslo.torch import ParallelContext
YOUR_HOST = ...
YOUR_PORT = ...
parallel_context = ParallelContext.from_openmpi(
host=YOUR_HOST,
port=YOUR_PORT,
data_parallel_size=1,
tensor_parallel_size=1,
pipeline_parallel_size=1,
sequence_parallel_size=1,
expert_parallel_size=1,
)
2. Check device ranks and world sizes easily#
There is an enum class named ParallelMode
, you can easily check device ranks and world sizes easily with it.
2.1. Ranks#
from oslo.torch import ParallelMode
# create parallel context object
parallel_context = ...
global_rank = parallel_context.get_local_rank(ParallelMode.GLOBAL)
data_parallel_rank = parallel_context.get_local_rank(ParallelMode.DATA)
tensor_parallel_rank = parallel_context.get_local_rank(ParallelMode.TENSOR)
pipeline_parallel_rank = parallel_context.get_local_rank(ParallelMode.PIPELINE)
sequence_parallel_rank = parallel_context.get_local_rank(ParallelMode.SEQUENCE)
expoert_parallel_rank = parallel_context.get_local_rank(ParallelMode.EXPERT)
2.2. World sizes#
from oslo.torch import ParallelMode
# create parallel context object
parallel_context = ...
global_size = parallel_context.get_world_size(ParallelMode.GLOBAL)
data_parallel_size = parallel_context.get_world_size(ParallelMode.DATA)
tensor_parallel_size = parallel_context.get_world_size(ParallelMode.TENSOR)
pipeline_parallel_size = parallel_context.get_world_size(ParallelMode.PIPELINE)
sequence_parallel_size = parallel_context.get_world_size(ParallelMode.SEQUENCE)
expoert_parallel_size = parallel_context.get_world_size(ParallelMode.EXPERT)