in src/accumulator.cc [45:128]
// Registers the AccumulatorService RPC handlers on `rpc`.
//
// All three handlers share the same shape: look the resource up by name in
// `resources`, lock the resource's mutex, and return false when the resource
// is not found or the caller's syncId differs from the resource's syncId.
// On success they record the incoming data on the resource and return true.
void setup() {
  // requestModel: remember that `peerName` asked for a model update.
  rpc->define<bool(std::string_view, uint32_t, std::string_view)>(
      "AccumulatorService::requestModel",
      [this](std::string_view resName, uint32_t syncId, std::string_view peerName) {
        std::shared_ptr<AccumulatorResource> h = resources.findPointer(resName);
        if (!h) {
          return false;
        }
        std::unique_lock l(h->mutex);
        if (syncId != h->syncId) {
          // Worded as a "request" so this failure is distinguishable from the
          // syncId mismatch logged by the modelUpdate handler.
          log.debug("Got model request for wrong syncId (%#x, should be %#x)\n", syncId, h->syncId);
          return false;
        }
        // Each peer is recorded at most once.
        auto& pending = h->requestedModelUpdate;
        if (std::find(pending.begin(), pending.end(), peerName) == pending.end()) {
          pending.push_back(std::string(peerName));
        }
        return true;
      });

  // modelUpdate: stage a full model (parameters, buffers, user state) on the
  // resource. Rejected when the parameter/buffer counts do not match the
  // resource's current model, or when a regular update carries a model
  // version different from the resource's.
  rpc->define<bool(
      std::string_view, uint32_t, bool, int64_t, std::vector<rpc::Tensor>, std::vector<rpc::Tensor>,
      GilWrapper<py::object>)>(
      "AccumulatorService::modelUpdate",
      [this](
          std::string_view resName, uint32_t syncId, bool isRegularUpdate, int64_t modelVersion,
          std::vector<rpc::Tensor> parameters, std::vector<rpc::Tensor> buffers, GilWrapper<py::object> userState) {
        std::shared_ptr<AccumulatorResource> h = resources.findPointer(resName);
        if (!h) {
          return false;
        }
        std::unique_lock l(h->mutex);
        if (syncId != h->syncId) {
          log.debug("Got model update for wrong syncId (%#x, should be %#x)\n", syncId, h->syncId);
          return false;
        }
        // Only regular updates are version-checked; other updates may carry
        // any model version.
        if (isRegularUpdate && modelVersion != h->modelVersion) {
          log.debug(
              "Got regular model update for wrong modelVersion (%#x, should be %#x)\n", modelVersion,
              h->modelVersion);
          return false;
        }
        if (parameters.size() != h->modelParameters.size()) {
          log.debug(
              "Got model update for wrong number of parameters (%d, should be %d)\n", parameters.size(),
              h->modelParameters.size());
          return false;
        }
        if (buffers.size() != h->modelBuffers.size()) {
          log.debug(
              "Got model update for wrong number of buffers (%d, should be %d)\n", buffers.size(),
              h->modelBuffers.size());
          return false;
        }
        log.debug("got modelUpdate %d\n", modelVersion);
        // Stage the update in the new* fields and raise the flag; the live
        // model fields are not touched here.
        h->haveNewParameters = true;
        h->newModelVersion = modelVersion;
        h->newParameters = std::move(parameters);
        h->newBuffers = std::move(buffers);
        h->newUserState = std::move(*userState);
        return true;
      });

  // buffersUpdate: stage a buffers-only update on the resource.
  // NOTE(review): unlike modelUpdate, the buffer count is not compared with
  // h->modelBuffers.size() here -- confirm that this is intended.
  rpc->define<bool(std::string_view, uint32_t, std::vector<rpc::Tensor>)>(
      "AccumulatorService::buffersUpdate",
      [this](std::string_view resName, uint32_t syncId, std::vector<rpc::Tensor> buffers) {
        std::shared_ptr<AccumulatorResource> h = resources.findPointer(resName);
        if (!h) {
          return false;
        }
        std::unique_lock l(h->mutex);
        if (syncId != h->syncId) {
          log.debug("Got buffers update for wrong syncId (%#x, should be %#x)\n", syncId, h->syncId);
          return false;
        }
        log.debug("Got buffers\n");
        h->haveNewBuffers = true;
        h->newBuffers = std::move(buffers);
        return true;
      });
}