func downloadInfoAndModel(modelName:modelInfoRetriever:downloader:conditions:progressHandler:completion:)

in FirebaseMLModelDownloader/Sources/ModelDownloader.swift [411:523]


  /// Fetches the latest model info for `modelName` and, when the server returns new info,
  /// downloads the corresponding model file.
  ///
  /// - If the server reports that the local model info is current (`.notModified`), the locally
  ///   stored model is returned, or `.internalError` if the local info or file is missing.
  /// - If new model info is returned, the file download is performed by a `ModelDownloadTask`.
  ///   A request for a `modelName` that already has a tracked task in `currentDownloadTask` is
  ///   merged into that task when it reports `canMergeRequests()`; otherwise a new task is created
  ///   and tracked.
  /// - If the file download fails with `.expiredDownloadURL`, the whole flow (model info and file)
  ///   is retried while `numberOfRetries` is positive; each retry decrements it.
  ///
  /// - Parameters:
  ///   - modelName: Name of the model; also the key into `currentDownloadTask`.
  ///   - modelInfoRetriever: Fetches model info from the server.
  ///   - downloader: File downloader handed to a newly created `ModelDownloadTask`.
  ///   - conditions: Download conditions. NOTE(review): in this method they are only forwarded to
  ///     the retry call and are not applied to the download task — confirm this is intended.
  ///   - progressHandler: Optional model-file download progress callback, delivered through
  ///     `asyncOnMainQueue`.
  ///   - completion: Receives the downloaded/local model or a `DownloadError`, delivered through
  ///     `asyncOnMainQueue`.
  func downloadInfoAndModel(modelName: String,
                            modelInfoRetriever: ModelInfoRetriever,
                            downloader: FileDownloader,
                            conditions: ModelDownloadConditions,
                            progressHandler: ((Float) -> Void)? = nil,
                            completion: @escaping (Result<CustomModel, DownloadError>)
                              -> Void) {
    modelInfoRetriever.downloadModelInfo { result in
      switch result {
      case let .success(downloadModelInfoResult):
        switch downloadModelInfoResult {
        // New model info was downloaded from server.
        case let .modelInfo(remoteModelInfo):
          // The download task this request's handlers are attached to: either a newly created task
          // or an existing one this request was merged into. Written and read only on
          // `taskSerialQueue`. Weak so that the completion stored inside the task does not retain
          // the task itself.
          weak var owningTask: ModelDownloadTask?
          // Progress handler for model file download.
          let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
            if let progressHandler = progressHandler {
              self.asyncOnMainQueue(progressHandler(progress))
            }
          }
          // Completion handler for model file download.
          let taskCompletion: ModelDownloadTask.Completion = { result in
            // Stop keeping track of the finished download task. This is enqueued before the result
            // is handled so that, on the retry path below, it is queued ahead of the retried
            // request's registration of a new task (which also goes through `taskSerialQueue`).
            // The identity check prevents removing a different task that has since replaced this
            // one under the same model name.
            self.taskSerialQueue.async {
              if let finishedTask = owningTask,
                 self.currentDownloadTask[modelName] === finishedTask {
                self.currentDownloadTask.removeValue(forKey: modelName)
              }
            }
            switch result {
            case let .success(model):
              self.asyncOnMainQueue(completion(.success(model)))
            case let .failure(error):
              switch error {
              case .notFound:
                self.asyncOnMainQueue(completion(.failure(.notFound)))
              case .invalidArgument:
                self.asyncOnMainQueue(completion(.failure(.invalidArgument)))
              case .permissionDenied:
                self.asyncOnMainQueue(completion(.failure(.permissionDenied)))
              // This is the error returned when model download URL has expired.
              case .expiredDownloadURL:
                // Retry model info and model file download, if allowed.
                guard self.numberOfRetries > 0 else {
                  self
                    .asyncOnMainQueue(
                      completion(.failure(.internalError(description: ModelDownloader
                          .ErrorDescription
                          .expiredModelInfo)))
                    )
                  return
                }
                self.numberOfRetries -= 1
                DeviceLogger.logEvent(level: .debug,
                                      message: ModelDownloader.DebugDescription.retryDownload,
                                      messageCode: .retryDownload)
                self.downloadInfoAndModel(
                  modelName: modelName,
                  modelInfoRetriever: modelInfoRetriever,
                  downloader: downloader,
                  conditions: conditions,
                  progressHandler: progressHandler,
                  completion: completion
                )
              default:
                self.asyncOnMainQueue(completion(.failure(error)))
              }
            }
          }
          self.taskSerialQueue.sync {
            // Merge duplicate requests if there is already a download in progress for the same model.
            if let downloadTask = self.currentDownloadTask[modelName],
               downloadTask.canMergeRequests() {
              // Record ownership before merging so the completion's cleanup can match this task.
              owningTask = downloadTask
              downloadTask.merge(
                newProgressHandler: taskProgressHandler,
                newCompletion: taskCompletion
              )
              DeviceLogger.logEvent(level: .debug,
                                    message: ModelDownloader.DebugDescription.mergingRequests,
                                    messageCode: .mergeRequests)
              if downloadTask.canResume() {
                downloadTask.resume()
              }
              // TODO: Handle the case where the merged task cannot be resumed.
              // NOTE(review): presumably it is already running and will invoke the merged handlers
              // when it finishes — confirm.
            } else {
              // Create download task for model file download.
              let downloadTask = ModelDownloadTask(
                remoteModelInfo: remoteModelInfo,
                appName: self.appName,
                defaults: self.userDefaults,
                downloader: downloader,
                progressHandler: taskProgressHandler,
                completion: taskCompletion,
                telemetryLogger: self.telemetryLogger
              )
              owningTask = downloadTask
              // Keep track of current download task to allow for merging duplicate requests.
              self.currentDownloadTask[modelName] = downloadTask
              downloadTask.resume()
            }
          }
        // Local model info is the latest model info.
        case .notModified:
          guard let localModel = self.getLocalModel(modelName: modelName) else {
            // This can only happen if either local model info or the model file was wiped out
            // after the model info request but before the server response.
            self
              .asyncOnMainQueue(completion(.failure(.internalError(description: ModelDownloader
                  .ErrorDescription.deletedLocalModelInfoOrFile))))
            return
          }
          self.asyncOnMainQueue(completion(.success(localModel)))
        }
      // Error retrieving model info.
      case let .failure(error):
        self.asyncOnMainQueue(completion(.failure(error)))
      }
    }
  }