diff --git a/damus/ContentView.swift b/damus/ContentView.swift index f9411c9f..6d72fbdf 100644 --- a/damus/ContentView.swift +++ b/damus/ContentView.swift @@ -512,6 +512,7 @@ struct ContentView: View { case .background: print("txn: 📙 DAMUS BACKGROUNDED") Task { @MainActor in + await damus_state.nostrNetwork.close() // Close ndb streaming tasks before closing ndb to avoid memory errors damus_state.ndb.close() } break diff --git a/damus/Core/Networking/NostrNetworkManager/NostrNetworkManager.swift b/damus/Core/Networking/NostrNetworkManager/NostrNetworkManager.swift index 0e435f6d..744513b5 100644 --- a/damus/Core/Networking/NostrNetworkManager/NostrNetworkManager.swift +++ b/damus/Core/Networking/NostrNetworkManager/NostrNetworkManager.swift @@ -234,7 +234,8 @@ class NostrNetworkManager { // MARK: - App lifecycle functions - func close() { + func close() async { + await self.reader.cancelAllTasks() pool.close() } } diff --git a/damus/Core/Networking/NostrNetworkManager/SubscriptionManager.swift b/damus/Core/Networking/NostrNetworkManager/SubscriptionManager.swift index 4202c116..08226f4c 100644 --- a/damus/Core/Networking/NostrNetworkManager/SubscriptionManager.swift +++ b/damus/Core/Networking/NostrNetworkManager/SubscriptionManager.swift @@ -4,6 +4,7 @@ // // Created by Daniel D’Aquino on 2025-03-25. // +import Foundation extension NostrNetworkManager { /// Reads or fetches information from RelayPool and NostrDB, and provides an easier and unified higher-level interface. @@ -14,10 +15,12 @@ extension NostrNetworkManager { class SubscriptionManager { private let pool: RelayPool private var ndb: Ndb + private var taskManager: TaskManager init(pool: RelayPool, ndb: Ndb) { self.pool = pool self.ndb = ndb + self.taskManager = TaskManager() } // MARK: - Reading data from Nostr @@ -35,6 +38,7 @@ extension NostrNetworkManager { let ndbStreamTask = Task { do { for await item in try self.ndb.subscribe(filters: try filters.map({ try NdbFilter(from: $0) })) { + try Task.checkCancellation() switch item { case .eose: continuation.yield(.eose) @@ -48,24 +52,71 @@ extension NostrNetworkManager { } lend(unownedNote) } + try Task.checkCancellation() continuation.yield(.event(borrow: lender)) } } } catch { - Log.error("NDB streaming error: %s", for: .ndb, error.localizedDescription) + Log.error("NDB streaming error: %s", for: .subscription_manager, error.localizedDescription) } + continuation.finish() } let streamTask = Task { - for await _ in self.pool.subscribe(filters: filters, to: desiredRelays) { - // NO-OP. Notes will be automatically ingested by NostrDB - // TODO: Improve efficiency of subscriptions? + do { + for await _ in self.pool.subscribe(filters: filters, to: desiredRelays) { + // NO-OP. Notes will be automatically ingested by NostrDB + // TODO: Improve efficiency of subscriptions? + try Task.checkCancellation() + } + } + catch { + Log.error("Network streaming error: %s", for: .subscription_manager, error.localizedDescription) + } + continuation.finish() + } + + Task { + let ndbStreamTaskId = await self.taskManager.add(task: ndbStreamTask) + let streamTaskId = await self.taskManager.add(task: streamTask) + + continuation.onTermination = { @Sendable _ in + Task { + await self.taskManager.cancelAndCleanUp(taskId: ndbStreamTaskId) + await self.taskManager.cancelAndCleanUp(taskId: streamTaskId) + } } } - continuation.onTermination = { @Sendable _ in - streamTask.cancel() // Close the RelayPool stream when caller stops streaming - ndbStreamTask.cancel() + } + } + + func cancelAllTasks() async { + await self.taskManager.cancelAllTasks() + } + + actor TaskManager { + private var tasks: [UUID: Task] = [:] + + func add(task: Task) -> UUID { + let taskId = UUID() + self.tasks[taskId] = task + return taskId + } + + func cancelAndCleanUp(taskId: UUID) async { + self.tasks[taskId]?.cancel() + await self.tasks[taskId]?.value + self.tasks[taskId] = nil + return + } + + func cancelAllTasks() async { + Log.info("Cancelling all SubscriptionManager tasks", for: .subscription_manager) + for (taskId, _) in self.tasks { + Log.info("Cancelling SubscriptionManager task %s", for: .subscription_manager, taskId.uuidString) + await cancelAndCleanUp(taskId: taskId) } + Log.info("Cancelled all SubscriptionManager tasks", for: .subscription_manager) } } } diff --git a/damus/Core/Storage/DamusState.swift b/damus/Core/Storage/DamusState.swift index 1155799e..a1821e6c 100644 --- a/damus/Core/Storage/DamusState.swift +++ b/damus/Core/Storage/DamusState.swift @@ -164,8 +164,10 @@ class DamusState: HeadlessDamusState { try await self.push_notification_client.revoke_token() } wallet.disconnect() - nostrNetwork.close() - ndb.close() + Task { + await nostrNetwork.close() // Close ndb streaming tasks before closing ndb to avoid memory errors + ndb.close() + } } static var empty: DamusState { diff --git a/damus/Shared/Utilities/Log.swift b/damus/Shared/Utilities/Log.swift index f5fbc0e9..79b9955a 100644 --- a/damus/Shared/Utilities/Log.swift +++ b/damus/Shared/Utilities/Log.swift @@ -14,6 +14,7 @@ enum LogCategory: String { case render case storage case networking + case subscription_manager case timeline /// Logs related to Nostr Wallet Connect components case nwc diff --git a/nostrdb/Ndb.swift b/nostrdb/Ndb.swift index d9668483..e1754d56 100644 --- a/nostrdb/Ndb.swift +++ b/nostrdb/Ndb.swift @@ -698,9 +698,13 @@ class Ndb { // Fetch initial results guard let txn = NdbTxn(ndb: self) else { throw .cannotOpenTransaction } + do { try Task.checkCancellation() } catch { throw .cancelled } + // Use our safe wrapper instead of direct C function call let noteIds = try query(with: txn, filters: filters, maxResults: maxSimultaneousResults) + do { try Task.checkCancellation() } catch { throw .cancelled } + // Create a subscription for new events let newEventsStream = ndbSubscribe(filters: filters) @@ -717,6 +721,7 @@ class Ndb { // Create a task to forward events from the subscription stream let forwardingTask = Task { for await item in newEventsStream { + try Task.checkCancellation() continuation.yield(item) } continuation.finish() @@ -876,6 +881,7 @@ extension Ndb { case cannotConvertFilter(any Error) case initialQueryFailed case timeout + case cancelled } /// An error that may happen when looking something up