# optimum/habana/distributed/strategy.py
# Copyright 2024 The Foundation Model Stack Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file has been modified from its original version.
# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack

from abc import abstractmethod
from typing import List

import torch
import torch.distributed
from torch import nn


class DistributedStrategy:
    def __init__(self, from_meta=False):
        self.from_meta = from_meta

    def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module:
        """
        Optionally, a distributed strategy may distribute modules that are not
        numbered layers.
        """
        return module

    @abstractmethod
    def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module:
        """
        Distribute each layer as appropriate.
        """
        pass


class NotDistributed(DistributedStrategy):
    def __init__(self, from_meta=False):
        super().__init__(from_meta)

    def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module:
        return module

    def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module:
        return block


NoOpStrategy = NotDistributed()
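
# Illustrative sketch (not part of this module): model-building code would typically
# take a DistributedStrategy, route every numbered transformer block through
# distribute_layer and the remaining modules (e.g. the output head) through
# distribute_module, defaulting to NoOpStrategy for single-device runs. The
# `MyDecoderBlock`/`config` names below are hypothetical placeholders.
#
#     strategy = NoOpStrategy  # or a model-parallel / tensor-parallel strategy
#     blocks = nn.ModuleList(
#         strategy.distribute_layer(MyDecoderBlock(config), layer_idx)
#         for layer_idx in range(config.num_hidden_layers)
#     )
#     head = strategy.distribute_module(
#         nn.Linear(config.hidden_size, config.vocab_size), final_layers=True
#     )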


class DeviceMover(nn.Module):
    def __init__(self, module: nn.Module, device):
        super().__init__()
        self.device = device
        # make this wrapper module behave as if it was the wrapped module.
        attr = module.__dict__
        attr["module"] = module.to(device)
        attr["device"] = device
        self.__dict__ = attr

    def forward(self, *args, **kwargs):
        device = self.device
        args = [arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args]
        kwargs = {k: (kwargs[k].to(device) if isinstance(kwargs[k], torch.Tensor) else kwargs[k]) for k in kwargs}
        return self.module(*args, **kwargs)
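
# Illustrative sketch (not part of this module): DeviceMover shares the wrapped
# module's __dict__, so attribute access falls through to the wrapped module, and
# forward() moves tensor arguments to the target device before the call. The example
# below uses "cpu" only so it runs anywhere; in this file the device normally comes
# from UniformModelParallelStrategy.layer_to_device.
#
#     layer = nn.Linear(4, 4)
#     moved = DeviceMover(layer, torch.device("cpu"))
#     out = moved(torch.randn(2, 4))       # inputs are .to()'d onto moved.device first
#     assert moved.weight is layer.weight  # attributes resolve through the shared __dict__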


class UniformModelParallelStrategy(DistributedStrategy):
    def __init__(self, devices: List[int], num_layers: int, from_meta=False):
        super().__init__(from_meta)
        num_dev = len(devices)
        layers_per_dev = num_layers // num_dev
        remainder = num_layers - (layers_per_dev * num_dev)
        self.layer_to_device = [0] * num_layers
        layer_id = 0
        for dev_idx in range(len(devices)):
            for i in range(layers_per_dev):
                self.layer_to_device[layer_id] = devices[dev_idx]
                layer_id = layer_id + 1
            if remainder > 0:
                self.layer_to_device[layer_id] = devices[dev_idx]
                layer_id = layer_id + 1
                remainder -= 1
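        # Illustrative example: with devices=[0, 1] and num_layers=5, layers_per_dev
        # is 2 and remainder is 1, so the first device takes one extra layer and
        # self.layer_to_device ends up as [0, 0, 0, 1, 1].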

    def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module:
        device = self.layer_to_device[layer]
        if self.from_meta:
            block.to_empty(device=device)  # type: ignore[arg-type]
        wrapped = DeviceMover(block, device)
        return wrapped

    def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module:
        if final_layers:
            device = self.layer_to_device[len(self.layer_to_device) - 1]
        else:
            device = self.layer_to_device[0]
        if self.from_meta:
            return module.to_empty(device=device)  # type: ignore[arg-type]
        wrapped = DeviceMover(module, device)
        return wrapped
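
# Illustrative sketch (not part of this module): spreading a hypothetical 24-layer
# decoder over two devices; `MyDecoderBlock` and `config` are placeholders. The
# device entries are handed straight to `.to()`, so use whatever identifiers the
# target backend accepts. With from_meta=True, meta-initialized modules are
# materialized via to_empty() (weights still have to be loaded afterwards).
#
#     strategy = UniformModelParallelStrategy(devices=[0, 1], num_layers=24, from_meta=True)
#     with torch.device("meta"):
#         block = MyDecoderBlock(config)
#         head = nn.Linear(16, 32)
#     block = strategy.distribute_layer(block, layer=0)           # to_empty() onto device 0, then DeviceMover
#     head = strategy.distribute_module(head, final_layers=True)  # to_empty() onto the last mapped device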


class TensorParallelStrategy(DistributedStrategy):
    def __init__(self, group=None, from_meta=False):
        super().__init__(from_meta)
        assert torch.distributed.is_initialized(), "must initialize a process group"
        self.group = group if group is not None else torch.distributed.GroupMember.WORLD

    def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module:
        from . import tp_wrapping

        return tp_wrapping.apply_tp(module, self.group)

    def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module:
        from . import tp_wrapping

        return tp_wrapping.apply_tp(block, layer, self.group)

    def __getstate__(self):
        state = self.__dict__.copy()
        state["group"] = None  # Remove ProcessGroup from state
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.group = None  # Restore to default state or reinitialize
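
# Illustrative sketch (not part of this module): TensorParallelStrategy asserts that
# torch.distributed is already initialized and shards modules through
# tp_wrapping.apply_tp, so it is meant to run under a multi-process launcher.
# __getstate__/__setstate__ drop the ProcessGroup handle (which cannot be pickled),
# so an unpickled strategy holds group=None and needs a group assigned again before
# use. The "hccl" backend below is the usual choice for Gaudi devices; adjust it for
# other backends. `block` and `layer_idx` are placeholders.
#
#     torch.distributed.init_process_group(backend="hccl")
#     strategy = TensorParallelStrategy()                  # defaults to the WORLD process group
#     block = strategy.distribute_layer(block, layer_idx)  # forwards to tp_wrapping.apply_tp(block, layer_idx, group)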