horovod/common/mpi/mpi_context.h (41 lines of code) (raw):
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
// Modifications copyright (C) 2019 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef HOROVOD_MPI_CONTEXT_H
#define HOROVOD_MPI_CONTEXT_H
#include <iostream>
#include <memory>
#include <vector>
#include "../common.h"
#include "../half.h"
#include "../logging.h"
namespace horovod {
namespace common {
// Base class for managing the MPI environment. It can be derived from when
// another framework (like DDL) is able to manage the MPI environment itself.
class MPIContextManager {
 public:
  // Virtual because this class is designed to be derived from: deleting a
  // derived manager through a pointer to this base class must run the
  // derived destructor (otherwise it is undefined behavior).
  virtual ~MPIContextManager() = default;

  // Initialize the MPI environment with the required multi-threading
  // support level.
  virtual void EnvInitialize(int mpi_threads_required);

  // Finalize the MPI environment.
  virtual void EnvFinalize();
};
struct MPIContext {
  // Mark MPI as enabled for this context and emit a debug log line.
  void Enable() {
    enabled_ = true;
    LOG(DEBUG) << "MPI context enabled.";
  }

  // Return the current value of the enabled_ flag (set to true by Enable()).
  bool IsEnabled() const { return enabled_; }

  // Take a reference to the context manager that will take care of
  // initialization of the MPI environment.
  void Initialize(const std::vector<int>& ranks,
                  MPIContextManager& ctx_manager);

  // Take a reference to the context manager that will take care of
  // finalization of the MPI environment.
  void Finalize(MPIContextManager& ctx_manager);

  // Helpers mapping Horovod types to their MPI counterparts. They are
  // defined out of line (presumably in mpi_context.cc — not visible here).
  MPI_Datatype GetMPIDataType(std::shared_ptr<Tensor> tensor);

  MPI_Datatype GetMPIDataType(DataType dtype);

  MPI_Op GetMPISumOp(DataType dtype);

  MPI_Comm GetMPICommunicator(Communicator comm);

  int GetMPITypeSize(DataType dtype);

  // Flag indicating whether MPI is enabled.
  bool enabled_ = false;

  // Every MPI handle below starts out as the MPI null handle for its type,
  // so its value is well-defined even before it is assigned (e.g. when
  // Initialize() has not run). This also lets null-handle checks on these
  // members avoid reading an indeterminate value.

  // MPI custom data type for float16.
  MPI_Datatype mpi_float16_t = MPI_DATATYPE_NULL;

  // MPI custom sum operation for float16.
  MPI_Op mpi_float16_sum = MPI_OP_NULL;

  // Private MPI communicator for Horovod to ensure no collisions with other
  // threads using MPI.
  MPI_Comm mpi_comm = MPI_COMM_NULL;

  // Node-local communicator.
  MPI_Comm local_comm = MPI_COMM_NULL;

  // Cross-node communicator for hierarchical allreduce.
  MPI_Comm cross_comm = MPI_COMM_NULL;

  // MPI Window used for shared memory allgather.
  MPI_Win window = MPI_WIN_NULL;

  // Whether the MPI context should be finalized.
  bool should_finalize = false;
};
} // namespace common
} // namespace horovod
#endif // HOROVOD_MPI_CONTEXT_H