in workload_generator/generate_deepspeed_stage3_workload.py [0:0]
def __init__(self, args, model) -> None:
    """Initialize the DeepSpeed stage-3 workload generator.

    Args:
        args: Parsed configuration object. Attributes read here:
            ``amp_enabled``, ``dp_num``, ``micro_batch``, ``seq_length``,
            ``computation_enable``, ``reduce_bucket_size``,
            ``prefetch_bucket_size``, ``max_live_parameters``,
            ``param_persistence_threshold`` and
            ``model_persistence_threshold``.
        model: Forwarded unchanged to the base-class constructor. This
            method then reads parameters via ``self.model.parameters()``
            (assumes the base class stores it as ``self.model`` — confirm).

    Side effects:
        Calls ``self._mark_persistent_parameters`` with the two
        persistence thresholds once all state below is in place.
    """
    super().__init__(args, model)
    self.name = "deepspeed_stage3"

    # Run configuration, copied straight from the parsed arguments.
    self.amp_enabled = args.amp_enabled
    self.dp_world_size = args.dp_num
    self.batch_size = args.micro_batch
    self.seq_len = args.seq_length
    self.compute_enable = args.computation_enable

    # Bucket / prefetch / live-parameter limits. The running counters
    # (reduce_bucket, current_live_parameters) both start at zero.
    self.reduce_bucket = 0
    self.reduce_bucket_size = args.reduce_bucket_size
    self.prefetch_bucket_size = args.prefetch_bucket_size
    self.max_live_parameters = args.max_live_parameters
    self.current_live_parameters = 0

    # Parameter-tracking state.
    self.stage = "init"
    self._param_queue = deque()
    self.all_params = list(self.model.parameters())

    # Expected fetch sequence: every parameter in model order, followed by
    # every parameter again in reverse order (presumably forward pass then
    # backward pass — confirm). Each entry is a (param, step_id) pair where
    # step_id is the position in that combined sequence.
    # NOTE: double-underscore attributes are name-mangled by Python, so other
    # methods must use the exact same spelling to reach them.
    fetch_sequence = [*self.all_params, *reversed(self.all_params)]
    self.__param_order = list(zip(fetch_sequence, range(len(fetch_sequence))))

    # Parameters not yet recorded report step id -1.
    self.__most_recent_step_id_param_fetched_for = defaultdict(lambda: -1)

    self._mark_persistent_parameters(
        args.param_persistence_threshold, args.model_persistence_threshold
    )