in distributed/rpc/parameter_server/rpc_parameter_server.py [0:0]
def __init__(self, num_gpus=0):
    """Create the network layers and place them on the selected devices.

    The two convolution layers are placed on ``cuda:0`` when CUDA is
    available and ``num_gpus`` is positive, otherwise on the CPU. The
    dropout and linear layers are placed on ``cuda:1`` only when more than
    one GPU was requested *and* a second CUDA device is actually visible;
    otherwise they stay on the same device as the convolution layers.

    Args:
        num_gpus: Number of GPUs the caller wants this model to use.
            Stored on the instance as ``self.num_gpus``.
    """
    super().__init__()
    print(f"Using {num_gpus} GPUs to train")
    self.num_gpus = num_gpus

    use_cuda = torch.cuda.is_available() and num_gpus > 0
    device = torch.device("cuda:0" if use_cuda else "cpu")
    print(f"Putting first 2 convs on {device}")
    # Put conv layers on the first cuda device (or the CPU).
    self.conv1 = nn.Conv2d(1, 32, 3, 1).to(device)
    self.conv2 = nn.Conv2d(32, 64, 3, 1).to(device)

    # Put rest of the network on the 2nd cuda device, if there is one.
    # The requested GPU count alone is not enough: also check that a second
    # device is really visible, so layers are never moved to a device that
    # does not exist on this host.
    has_second_gpu = (
        device.type == "cuda"
        and num_gpus > 1
        and torch.cuda.device_count() > 1
    )
    if has_second_gpu:
        device = torch.device("cuda:1")

    print(f"Putting rest of layers on {device}")
    self.dropout1 = nn.Dropout2d(0.25).to(device)
    self.dropout2 = nn.Dropout2d(0.5).to(device)
    self.fc1 = nn.Linear(9216, 128).to(device)
    self.fc2 = nn.Linear(128, 10).to(device)