Diffusion/Common/Downloader.swift (82 lines of code) (raw):
//
// Downloader.swift
// Diffusion
//
// Created by Pedro Cuenca on December 2022.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import Foundation
import Combine
/// Downloads one remote file to `destination` using a `URLSession` download task and
/// publishes progress and the final outcome through `downloadState`.
///
/// The transfer is scheduled from `init`; callers either observe `downloadState` or block
/// on `waitUntilDone()`.
///
/// NOTE(review): the session is only invalidated by `cancel()`. `URLSession` keeps a strong
/// reference to its delegate until invalidated, so a finished `Downloader` presumably stays
/// alive for the lifetime of the session — confirm whether that is intended.
class Downloader: NSObject, ObservableObject {
    /// Location the downloaded file is moved to once the transfer finishes.
    private(set) var destination: URL

    /// States published through `downloadState`.
    enum DownloadState {
        /// Initial value of the subject, before `init` schedules the transfer.
        case notStarted
        /// Transfer in progress; the payload is the completed fraction.
        case downloading(Double)
        /// Transfer finished; the payload is the final file location.
        case completed(URL)
        /// Transfer failed; the payload is the underlying error.
        case failed(Error)
    }

    /// Current state of the download. Subscribers receive the latest value on subscription.
    private(set) lazy var downloadState: CurrentValueSubject<DownloadState, Never> = CurrentValueSubject(.notStarted)

    private var urlSession: URLSession? = nil

    /// Creates the session and starts downloading `url`, unless the session already has a
    /// task for the same URL.
    ///
    /// - Parameters:
    ///   - url: Remote file to download.
    ///   - destination: Local file URL the download is moved to on success.
    ///   - authToken: Optional token sent as an `Authorization: Bearer` header.
    init(from url: URL, to destination: URL, using authToken: String? = nil) {
        self.destination = destination
        super.init()

        #if os(macOS)
        let config = URLSessionConfiguration.default
        #else
        // .background allows downloads to proceed in the background
        // helpful for devices that may not keep the app in the foreground for the download duration
        let config = URLSessionConfiguration.background(withIdentifier: "net.pcuenca.diffusion.download")
        config.isDiscretionary = false
        config.sessionSendsLaunchEvents = true
        #endif

        // URLSession expects a serial delegate queue so callbacks (progress, finish,
        // completion) are delivered in order; a default OperationQueue is concurrent.
        let delegateQueue = OperationQueue()
        delegateQueue.maxConcurrentOperationCount = 1
        urlSession = URLSession(configuration: config, delegate: self, delegateQueue: delegateQueue)

        downloadState.value = .downloading(0)
        urlSession?.getAllTasks { tasks in
            // If there's an existing pending background task with the same URL, let it proceed.
            let alreadyDownloading = tasks.contains { $0.originalRequest?.url == url }
            guard !alreadyDownloading else {
                print("Already downloading \(url)")
                return
            }

            print("Starting download of \(url)")
            var request = URLRequest(url: url)
            if let authToken = authToken {
                request.setValue("Bearer \(authToken)", forHTTPHeaderField: "Authorization")
            }
            self.urlSession?.downloadTask(with: request).resume()
        }
    }

    /// Blocks the calling thread until the download reaches `.completed` or `.failed`.
    ///
    /// - Returns: The URL carried by the `.completed` state.
    /// - Throws: The error carried by the `.failed` state.
    @discardableResult
    func waitUntilDone() throws -> URL {
        // It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky)
        let semaphore = DispatchSemaphore(value: 0)

        // The subscription is local on purpose: keeping it in a shared property would let an
        // overlapping call replace (and thereby cancel) it, leaving the first caller blocked
        // forever.
        let subscription = downloadState.sink { state in
            switch state {
            case .completed, .failed:
                semaphore.signal()
            case .notStarted, .downloading:
                break
            }
        }
        defer { subscription.cancel() }

        semaphore.wait()

        switch downloadState.value {
        case .completed(let url):
            return url
        case .failed(let error):
            throw error
        case .notStarted, .downloading:
            throw("Should never happen, lol")
        }
    }

    /// Cancels all outstanding tasks and invalidates the session.
    func cancel() {
        urlSession?.invalidateAndCancel()
    }
}
// MARK: - URLSessionDelegate, URLSessionDownloadDelegate

/// Session callbacks that translate `URLSession` download events into `downloadState` values.
extension Downloader: URLSessionDelegate, URLSessionDownloadDelegate {
    /// Publishes the completed fraction of the transfer.
    func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64) {
        // The expected size is NSURLSessionTransferSizeUnknown (-1) when the server does not
        // report a length; dividing by it would publish a negative (or non-finite) fraction.
        guard totalBytesExpectedToWrite > 0 else { return }
        downloadState.value = .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite))
    }

    /// Moves the finished download to `destination` and publishes `.completed`, or publishes
    /// `.failed` when the response or the move is unusable.
    func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
        // A download task also "finishes" for HTTP error responses (401, 404, ...); the file
        // at `location` is then the error body and must not be stored as the download.
        if let response = downloadTask.response as? HTTPURLResponse,
           !(200..<300).contains(response.statusCode) {
            downloadState.value = .failed(URLError(.badServerResponse))
            return
        }

        let fileManager = FileManager.default
        guard fileManager.fileExists(atPath: location.path) else {
            downloadState.value = .failed("Invalid download location received: \(location)")
            return
        }

        do {
            // moveItem throws if something already exists at the destination, so replace a
            // stale file left by an earlier download.
            if fileManager.fileExists(atPath: destination.path) {
                try fileManager.removeItem(at: destination)
            }
            try fileManager.moveItem(at: location, to: destination)
            downloadState.value = .completed(destination)
        } catch {
            downloadState.value = .failed(error)
        }
    }

    /// Publishes `.failed` for transport-level errors; otherwise logs the HTTP status code.
    func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
        if let error = error {
            downloadState.value = .failed(error)
        } else if let response = task.response as? HTTPURLResponse {
            print("HTTP response status code: \(response.statusCode)")
            // let headers = response.allHeaderFields
            // print("HTTP response headers: \(headers)")
        }
    }
}