in FirebaseMLModelDownloader/Sources/ModelDownloader.swift [411:523]
/// Requests the latest model info for `modelName` and, when the server returns new info,
/// downloads the model file.
///
/// - New model info: a `ModelDownloadTask` is created and tracked in `currentDownloadTask`
///   under `modelName`. If a task for the same model is already tracked and
///   `canMergeRequests()`, this request's handlers are merged into it instead.
/// - `.notModified`: the model from `getLocalModel(modelName:)` is returned, or an
///   `internalError` if the local model info or file is missing.
/// - Expired download URL: the whole info + file download is retried while
///   `numberOfRetries > 0` (decrementing it on each retry); otherwise an `internalError`
///   is reported.
///
/// - Parameters:
///   - modelName: Name of the model; also the key under which its download task is tracked.
///   - modelInfoRetriever: Used to request the model info.
///   - downloader: File downloader handed to a newly created `ModelDownloadTask`.
///   - conditions: Download conditions; in this method they are only forwarded to a retry.
///   - progressHandler: Optional file-download progress callback, invoked via `asyncOnMainQueue`.
///   - completion: Receives the model or a `DownloadError`; always invoked via `asyncOnMainQueue`.
func downloadInfoAndModel(modelName: String,
                          modelInfoRetriever: ModelInfoRetriever,
                          downloader: FileDownloader,
                          conditions: ModelDownloadConditions,
                          progressHandler: ((Float) -> Void)? = nil,
                          completion: @escaping (Result<CustomModel, DownloadError>)
                            -> Void) {
  modelInfoRetriever.downloadModelInfo { result in
    switch result {
    case let .success(downloadModelInfoResult):
      switch downloadModelInfoResult {
      // New model info was downloaded from server.
      case let .modelInfo(remoteModelInfo):
        // Progress handler for model file download.
        let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
          if let progressHandler = progressHandler {
            self.asyncOnMainQueue(progressHandler(progress))
          }
        }
        // Completion handler for model file download.
        let taskCompletion: ModelDownloadTask.Completion = { result in
          // Stop keeping track of the current download task once it reports a result.
          // `defer` makes this run on every exit path, including the early `return` taken
          // when retries are exhausted (which previously skipped this cleanup and left the
          // failed task registered in `currentDownloadTask`).
          defer {
            self.taskSerialQueue.async {
              self.currentDownloadTask.removeValue(forKey: modelName)
            }
          }
          switch result {
          case let .success(model):
            self.asyncOnMainQueue(completion(.success(model)))
          // This is the error returned when model download URL has expired.
          case .failure(.expiredDownloadURL):
            // Retry model info and model file download, if allowed.
            guard self.numberOfRetries > 0 else {
              self.asyncOnMainQueue(
                completion(.failure(.internalError(description: ModelDownloader
                    .ErrorDescription
                    .expiredModelInfo)))
              )
              return
            }
            self.numberOfRetries -= 1
            DeviceLogger.logEvent(level: .debug,
                                  message: ModelDownloader.DebugDescription.retryDownload,
                                  messageCode: .retryDownload)
            self.downloadInfoAndModel(
              modelName: modelName,
              modelInfoRetriever: modelInfoRetriever,
              downloader: downloader,
              conditions: conditions,
              progressHandler: progressHandler,
              completion: completion
            )
          // Every other error (e.g. notFound, invalidArgument, permissionDenied) is forwarded
          // to the caller unchanged.
          case let .failure(error):
            self.asyncOnMainQueue(completion(.failure(error)))
          }
        }
        self.taskSerialQueue.sync {
          // Merge duplicate requests if there is already a download in progress for the same model.
          if let downloadTask = self.currentDownloadTask[modelName],
             downloadTask.canMergeRequests() {
            downloadTask.merge(
              newProgressHandler: taskProgressHandler,
              newCompletion: taskCompletion
            )
            DeviceLogger.logEvent(level: .debug,
                                  message: ModelDownloader.DebugDescription.mergingRequests,
                                  messageCode: .mergeRequests)
            if downloadTask.canResume() {
              downloadTask.resume()
            }
            // TODO: Handle else (a mergeable task that cannot be resumed is currently left as is).
          } else {
            // Create download task for model file download.
            let downloadTask = ModelDownloadTask(
              remoteModelInfo: remoteModelInfo,
              appName: self.appName,
              defaults: self.userDefaults,
              downloader: downloader,
              progressHandler: taskProgressHandler,
              completion: taskCompletion,
              telemetryLogger: self.telemetryLogger
            )
            // Keep track of current download task to allow for merging duplicate requests.
            self.currentDownloadTask[modelName] = downloadTask
            downloadTask.resume()
          }
        }
      // Local model info is the latest model info.
      case .notModified:
        guard let localModel = self.getLocalModel(modelName: modelName) else {
          // This can only happen if either local model info or the model file was wiped out
          // after model info request but before server response.
          self.asyncOnMainQueue(completion(.failure(.internalError(description: ModelDownloader
              .ErrorDescription.deletedLocalModelInfoOrFile))))
          return
        }
        self.asyncOnMainQueue(completion(.success(localModel)))
      }
    // Error retrieving model info.
    case let .failure(error):
      self.asyncOnMainQueue(completion(.failure(error)))
    }
  }
}