From 690f8b891e2484287267c9af00738f4980f0acb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20D=E2=80=99Aquino?= Date: Mon, 13 Oct 2025 16:56:37 -0700 Subject: [PATCH] Implement timestamp-based network subscription optimization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changelog-Changed: Optimized network bandwidth usage and improved timeline performance Signed-off-by: Daniel D’Aquino --- damus.xcodeproj/project.pbxproj | 8 + .../NostrNetworkManager/ProfilesManager.swift | 2 +- .../SubscriptionManager.swift | 86 +++- damus/Core/Nostr/RelayConnection.swift | 9 + damus/Core/Nostr/RelayPool.swift | 2 + .../Search/Models/SearchHomeModel.swift | 2 +- .../Features/Timeline/Models/HomeModel.swift | 4 +- .../Utilities/StreamPipelineDiagnostics.swift | 27 + devtools/visualize_stream_pipeline.py | 475 ++++++++++++++++++ shell.nix | 2 +- 10 files changed, 591 insertions(+), 26 deletions(-) create mode 100644 damus/Shared/Utilities/StreamPipelineDiagnostics.swift create mode 100644 devtools/visualize_stream_pipeline.py diff --git a/damus.xcodeproj/project.pbxproj b/damus.xcodeproj/project.pbxproj index d1792c39..d9a46136 100644 --- a/damus.xcodeproj/project.pbxproj +++ b/damus.xcodeproj/project.pbxproj @@ -1765,6 +1765,9 @@ D7DF58322DFCF18D00E9AD28 /* SendPaymentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = D7DF58312DFCF18800E9AD28 /* SendPaymentView.swift */; }; D7DF58332DFCF18D00E9AD28 /* SendPaymentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = D7DF58312DFCF18800E9AD28 /* SendPaymentView.swift */; }; D7DF58342DFCF18D00E9AD28 /* SendPaymentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = D7DF58312DFCF18800E9AD28 /* SendPaymentView.swift */; }; + D7E5B2D32EA0188200CF47AC /* StreamPipelineDiagnostics.swift in Sources */ = {isa = PBXBuildFile; fileRef = D7E5B2D22EA0187B00CF47AC /* StreamPipelineDiagnostics.swift */; }; + D7E5B2D42EA0188200CF47AC /* StreamPipelineDiagnostics.swift in Sources */ = {isa = PBXBuildFile; fileRef = D7E5B2D22EA0187B00CF47AC /* StreamPipelineDiagnostics.swift */; }; + D7E5B2D52EA0188200CF47AC /* StreamPipelineDiagnostics.swift in Sources */ = {isa = PBXBuildFile; fileRef = D7E5B2D22EA0187B00CF47AC /* StreamPipelineDiagnostics.swift */; }; D7EB00B02CD59C8D00660C07 /* PresentFullScreenItemNotify.swift in Sources */ = {isa = PBXBuildFile; fileRef = D7EB00AF2CD59C8300660C07 /* PresentFullScreenItemNotify.swift */; }; D7EB00B12CD59C8D00660C07 /* PresentFullScreenItemNotify.swift in Sources */ = {isa = PBXBuildFile; fileRef = D7EB00AF2CD59C8300660C07 /* PresentFullScreenItemNotify.swift */; }; D7EBF8BB2E59022A004EAE29 /* NostrNetworkManagerTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = D7EBF8BA2E5901F7004EAE29 /* NostrNetworkManagerTests.swift */; }; @@ -2707,6 +2710,7 @@ D7DB93092D69485A00DA1EE5 /* NIP65.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NIP65.swift; sourceTree = ""; }; D7DEEF2E2A8C021E00E0C99F /* NostrEventTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NostrEventTests.swift; sourceTree = ""; }; D7DF58312DFCF18800E9AD28 /* SendPaymentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SendPaymentView.swift; sourceTree = ""; }; + D7E5B2D22EA0187B00CF47AC /* StreamPipelineDiagnostics.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = StreamPipelineDiagnostics.swift; sourceTree = ""; }; D7EB00AF2CD59C8300660C07 /* 
PresentFullScreenItemNotify.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PresentFullScreenItemNotify.swift; sourceTree = ""; }; D7EBF8BA2E5901F7004EAE29 /* NostrNetworkManagerTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NostrNetworkManagerTests.swift; sourceTree = ""; }; D7EBF8BD2E594708004EAE29 /* test_notes.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = test_notes.jsonl; sourceTree = ""; }; @@ -4687,6 +4691,7 @@ 5C78A7B82E3047DE00CF177D /* Utilities */ = { isa = PBXGroup; children = ( + D7E5B2D22EA0187B00CF47AC /* StreamPipelineDiagnostics.swift */, D77135D22E7B766300E7639F /* DataExtensions.swift */, 4CF0ABEA29844B2F00D66079 /* AnyCodable */, D73B74E02D8365B40067BDBC /* ExtraFonts.swift */, @@ -5811,6 +5816,7 @@ 4CDA128C29EB19C40006FA5A /* LocalNotification.swift in Sources */, 4C3BEFD6281D995700B3DE84 /* ActionBarModel.swift in Sources */, 4C7D09762A0AF19E00943473 /* FillAndStroke.swift in Sources */, + D7E5B2D42EA0188200CF47AC /* StreamPipelineDiagnostics.swift in Sources */, 4CA927612A290E340098A105 /* EventShell.swift in Sources */, D74EC8502E1856B70091DC51 /* NonCopyableLinkedList.swift in Sources */, 4C363AA428296DEE006E126D /* SearchModel.swift in Sources */, @@ -6421,6 +6427,7 @@ 82D6FC0E2CD99F7900C925F4 /* ProfilePicView.swift in Sources */, 82D6FC0F2CD99F7900C925F4 /* ProfileView.swift in Sources */, 82D6FC102CD99F7900C925F4 /* ProfileNameView.swift in Sources */, + D7E5B2D52EA0188200CF47AC /* StreamPipelineDiagnostics.swift in Sources */, 5CB017212D2D985E00A9ED05 /* CoinosButton.swift in Sources */, 82D6FC112CD99F7900C925F4 /* MaybeAnonPfpView.swift in Sources */, 82D6FC122CD99F7900C925F4 /* EventProfileName.swift in Sources */, @@ -7026,6 +7033,7 @@ D703D7752C670BBF00A400EA /* Constants.swift in Sources */, D73E5E172C6A962A007EB227 /* ImageUploadModel.swift in Sources */, D703D76A2C670B2C00A400EA /* Bech32Object.swift in Sources */, + D7E5B2D32EA0188200CF47AC /* StreamPipelineDiagnostics.swift in Sources */, D73E5E162C6A9619007EB227 /* PostView.swift in Sources */, D703D7872C670C7E00A400EA /* DamusPurpleEnvironment.swift in Sources */, D703D7892C670C8600A400EA /* DeepLPlan.swift in Sources */, diff --git a/damus/Core/Networking/NostrNetworkManager/ProfilesManager.swift b/damus/Core/Networking/NostrNetworkManager/ProfilesManager.swift index 43535efb..573c0245 100644 --- a/damus/Core/Networking/NostrNetworkManager/ProfilesManager.swift +++ b/damus/Core/Networking/NostrNetworkManager/ProfilesManager.swift @@ -81,7 +81,7 @@ extension NostrNetworkManager { guard pubkeys.count > 0 else { return } let profileFilter = NostrFilter(kinds: [.metadata], authors: pubkeys) try Task.checkCancellation() - for await ndbLender in self.subscriptionManager.streamIndefinitely(filters: [profileFilter], streamMode: .ndbFirst) { + for await ndbLender in self.subscriptionManager.streamIndefinitely(filters: [profileFilter], streamMode: .ndbFirst(optimizeNetworkFilter: true)) { try Task.checkCancellation() try? 
ndbLender.borrow { ev in publishProfileUpdates(metadataEvent: ev) diff --git a/damus/Core/Networking/NostrNetworkManager/SubscriptionManager.swift b/damus/Core/Networking/NostrNetworkManager/SubscriptionManager.swift index d08de04f..70c190a3 100644 --- a/damus/Core/Networking/NostrNetworkManager/SubscriptionManager.swift +++ b/damus/Core/Networking/NostrNetworkManager/SubscriptionManager.swift @@ -133,42 +133,75 @@ extension NostrNetworkManager { if canIssueEOSE { Self.logger.debug("Session subscription \(id.uuidString, privacy: .public): Issued EOSE for session. Elapsed: \(CFAbsoluteTimeGetCurrent() - startTime, format: .fixed(precision: 2), privacy: .public) seconds") + logStreamPipelineStats("SubscriptionManager_Advanced_Stream_\(id)", "Consumer_\(id)") continuation.yield(.eose) } } - let streamTask = Task { - while !Task.isCancelled { - for await item in self.multiSessionNetworkStream(filters: filters, to: desiredRelays, streamMode: streamMode, id: id) { - try Task.checkCancellation() - switch item { - case .event(let lender): - continuation.yield(item) - case .eose: - break // Should not happen - case .ndbEose: - break // Should not happen - case .networkEose: - continuation.yield(item) - networkEOSEIssued = true - yieldEOSEIfReady() + var networkStreamTask: Task? = nil + var latestNoteTimestampSeen: UInt32? = nil + + let startNetworkStreamTask = { + networkStreamTask = Task { + while !Task.isCancelled { + let optimizedFilters = filters.map { + var optimizedFilter = $0 + optimizedFilter.since = latestNoteTimestampSeen + return optimizedFilter + } + for await item in self.multiSessionNetworkStream(filters: optimizedFilters, to: desiredRelays, streamMode: streamMode, id: id) { + try Task.checkCancellation() + logStreamPipelineStats("SubscriptionManager_Network_Stream_\(id)", "SubscriptionManager_Advanced_Stream_\(id)") + switch item { + case .event(let lender): + logStreamPipelineStats("SubscriptionManager_Advanced_Stream_\(id)", "Consumer_\(id)") + continuation.yield(item) + case .eose: + break // Should not happen + case .ndbEose: + break // Should not happen + case .networkEose: + logStreamPipelineStats("SubscriptionManager_Advanced_Stream_\(id)", "Consumer_\(id)") + continuation.yield(item) + networkEOSEIssued = true + yieldEOSEIfReady() + } } } } } + if streamMode.optimizeNetworkFilter == false { + // Start streaming from the network straight away + startNetworkStreamTask() + } + let ndbStreamTask = Task { while !Task.isCancelled { for await item in self.multiSessionNdbStream(filters: filters, to: desiredRelays, streamMode: streamMode, id: id) { try Task.checkCancellation() + logStreamPipelineStats("SubscriptionManager_Ndb_MultiSession_Stream_\(id)", "SubscriptionManager_Advanced_Stream_\(id)") switch item { case .event(let lender): + logStreamPipelineStats("SubscriptionManager_Advanced_Stream_\(id)", "Consumer_\(id)") + try? 
lender.borrow({ event in + if let latestTimestamp = latestNoteTimestampSeen { + latestNoteTimestampSeen = max(latestTimestamp, event.createdAt) + } + else { + latestNoteTimestampSeen = event.createdAt + } + }) continuation.yield(item) case .eose: break // Should not happen case .ndbEose: + logStreamPipelineStats("SubscriptionManager_Advanced_Stream_\(id)", "Consumer_\(id)") continuation.yield(item) ndbEOSEIssued = true + if streamMode.optimizeNetworkFilter { + startNetworkStreamTask() + } yieldEOSEIfReady() case .networkEose: break // Should not happen @@ -178,7 +211,7 @@ extension NostrNetworkManager { } continuation.onTermination = { @Sendable _ in - streamTask.cancel() + networkStreamTask?.cancel() ndbStreamTask.cancel() } } @@ -200,9 +233,8 @@ extension NostrNetworkManager { do { for await item in await self.pool.subscribe(filters: filters, to: desiredRelays, id: id) { - // NO-OP. Notes will be automatically ingested by NostrDB - // TODO: Improve efficiency of subscriptions? try Task.checkCancellation() + logStreamPipelineStats("RelayPool_Handler_\(id)", "SubscriptionManager_Network_Stream_\(id)") switch item { case .event(let event): if EXTRA_VERBOSE_LOGGING { @@ -249,6 +281,7 @@ extension NostrNetworkManager { Self.logger.info("\(subscriptionId.uuidString, privacy: .public): Streaming from NDB.") for await item in self.sessionNdbStream(filters: filters, to: desiredRelays, streamMode: streamMode, id: id) { try Task.checkCancellation() + logStreamPipelineStats("SubscriptionManager_Ndb_Session_Stream_\(id?.uuidString ?? "NoID")", "SubscriptionManager_Ndb_MultiSession_Stream_\(id?.uuidString ?? "NoID")") continuation.yield(item) } Self.logger.info("\(subscriptionId.uuidString, privacy: .public): Session subscription ended. Sleeping for 1 second before resuming.") @@ -318,7 +351,7 @@ extension NostrNetworkManager { // MARK: - Utility functions private func defaultStreamMode() -> StreamMode { - self.experimentalLocalRelayModelSupport ? .ndbFirst : .ndbAndNetworkParallel + self.experimentalLocalRelayModelSupport ? .ndbFirst(optimizeNetworkFilter: false) : .ndbAndNetworkParallel(optimizeNetworkFilter: false) } // MARK: - Finding specific data from Nostr @@ -496,8 +529,19 @@ extension NostrNetworkManager { /// The mode of streaming enum StreamMode { /// Returns notes exclusively through NostrDB, treating it as the only channel for information in the pipeline. Generic EOSE is fired when EOSE is received from NostrDB - case ndbFirst + /// `optimizeNetworkFilter`: Returns notes from ndb, then streams from the network with an added "since" filter set to the latest note stored on ndb. + case ndbFirst(optimizeNetworkFilter: Bool) /// Returns notes from both NostrDB and the network, in parallel, treating it with similar importance against the network relays. Generic EOSE is fired when EOSE is received from both the network and NostrDB - case ndbAndNetworkParallel + /// `optimizeNetworkFilter`: Returns notes from ndb, then streams from the network with an added "since" filter set to the latest note stored on ndb. 
+ case ndbAndNetworkParallel(optimizeNetworkFilter: Bool) + + var optimizeNetworkFilter: Bool { + switch self { + case .ndbFirst(optimizeNetworkFilter: let optimizeNetworkFilter): + return optimizeNetworkFilter + case .ndbAndNetworkParallel(optimizeNetworkFilter: let optimizeNetworkFilter): + return optimizeNetworkFilter + } + } } } diff --git a/damus/Core/Nostr/RelayConnection.swift b/damus/Core/Nostr/RelayConnection.swift index 608c9ea7..1581b018 100644 --- a/damus/Core/Nostr/RelayConnection.swift +++ b/damus/Core/Nostr/RelayConnection.swift @@ -35,6 +35,15 @@ enum NostrConnectionEvent { } } } + + var subId: String? { + switch self { + case .ws_connection_event(_): + return nil + case .nostr_event(let event): + return event.subid + } + } } final class RelayConnection: ObservableObject { diff --git a/damus/Core/Nostr/RelayPool.swift b/damus/Core/Nostr/RelayPool.swift index 8cf8e434..be40b3ed 100644 --- a/damus/Core/Nostr/RelayPool.swift +++ b/damus/Core/Nostr/RelayPool.swift @@ -535,6 +535,8 @@ actor RelayPool { } for handler in handlers { + guard handler.sub_id == event.subId else { continue } + logStreamPipelineStats("RelayPool_\(relay_id.absoluteString)", "RelayPool_Handler_\(handler.sub_id)") handler.handler.yield((relay_id, event)) } } diff --git a/damus/Features/Search/Models/SearchHomeModel.swift b/damus/Features/Search/Models/SearchHomeModel.swift index 4d34eba3..a94c42e4 100644 --- a/damus/Features/Search/Models/SearchHomeModel.swift +++ b/damus/Features/Search/Models/SearchHomeModel.swift @@ -19,7 +19,7 @@ class SearchHomeModel: ObservableObject { let base_subid = UUID().description let follow_pack_subid = UUID().description let profiles_subid = UUID().description - let limit: UInt32 = 500 + let limit: UInt32 = 200 //let multiple_events_per_pubkey: Bool = false init(damus_state: DamusState) { diff --git a/damus/Features/Timeline/Models/HomeModel.swift b/damus/Features/Timeline/Models/HomeModel.swift index 0280ab57..65caa003 100644 --- a/damus/Features/Timeline/Models/HomeModel.swift +++ b/damus/Features/Timeline/Models/HomeModel.swift @@ -524,7 +524,7 @@ class HomeModel: ContactsDelegate, ObservableObject { } self.generalHandlerTask?.cancel() self.generalHandlerTask = Task { - for await item in damus_state.nostrNetwork.reader.advancedStream(filters: dms_filters + contacts_filters) { + for await item in damus_state.nostrNetwork.reader.advancedStream(filters: dms_filters + contacts_filters, streamMode: .ndbAndNetworkParallel(optimizeNetworkFilter: true)) { switch item { case .event(let lender): await lender.justUseACopy({ await process_event(ev: $0, context: .other) }) @@ -602,7 +602,7 @@ class HomeModel: ContactsDelegate, ObservableObject { DispatchQueue.main.async { self.loading = true } - for await item in damus_state.nostrNetwork.reader.advancedStream(filters: home_filters, id: id) { + for await item in damus_state.nostrNetwork.reader.advancedStream(filters: home_filters, streamMode: .ndbAndNetworkParallel(optimizeNetworkFilter: true), id: id) { switch item { case .event(let lender): let currentTime = CFAbsoluteTimeGetCurrent() diff --git a/damus/Shared/Utilities/StreamPipelineDiagnostics.swift b/damus/Shared/Utilities/StreamPipelineDiagnostics.swift new file mode 100644 index 00000000..7e64651f --- /dev/null +++ b/damus/Shared/Utilities/StreamPipelineDiagnostics.swift @@ -0,0 +1,27 @@ +// +// StreamPipelineDiagnostics.swift +// damus +// +// Created by Daniel D’Aquino on 2025-10-15. 
+// +import Foundation + +let ENABLE_PIPELINE_DIAGNOSTICS = false + +fileprivate func getTimestamp() -> String { + let d = Date() + let df = DateFormatter() + df.dateFormat = "y-MM-dd H:mm:ss.SSSS" + + return df.string(from: d) +} + +/// Logs stream pipeline data in CSV format that can later be used for plotting and analysis +/// See `devtools/visualize_stream_pipeline.py` +/// +/// Implementation note: This function is inlined for performance purposes. +@inline(__always) func logStreamPipelineStats(_ sourceNode: String, _ destinationNode: String) { + if ENABLE_PIPELINE_DIAGNOSTICS { + print("STREAM_PIPELINE: \(getTimestamp()),\(sourceNode),\(destinationNode)") + } +} diff --git a/devtools/visualize_stream_pipeline.py b/devtools/visualize_stream_pipeline.py new file mode 100644 index 00000000..d260d752 --- /dev/null +++ b/devtools/visualize_stream_pipeline.py @@ -0,0 +1,475 @@ +#!/usr/bin/env python3 +"""Generate interactive Sankey diagram from network CSV data using Plotly.""" + +from __future__ import annotations + +import argparse +import csv +from datetime import datetime +from pathlib import Path +from collections import defaultdict +from typing import Dict, List, Tuple, Optional + +import plotly.graph_objects as go +import plotly.express as px + + +def parse_timestamp(timestamp_str: str) -> float: + """Parse timestamp string and return as milliseconds since epoch.""" + # Strip whitespace + timestamp_str = timestamp_str.strip() + + # Remove any prefix (e.g., "STREAM_PIPELINE: ") + if ": " in timestamp_str: + timestamp_str = timestamp_str.split(": ", 1)[1] + + try: + # Try parsing as ISO format with milliseconds + dt = datetime.fromisoformat(timestamp_str) + return dt.timestamp() * 1000 + except ValueError: + try: + # Try replacing space with 'T' for ISO format (e.g., "2025-10-13 15:36:46.3650") + if " " in timestamp_str and "-" in timestamp_str: + timestamp_str = timestamp_str.replace(" ", "T") + dt = datetime.fromisoformat(timestamp_str) + return dt.timestamp() * 1000 + raise ValueError() + except ValueError: + try: + # Try parsing as float (milliseconds) + return float(timestamp_str) + except ValueError: + raise ValueError(f"Could not parse timestamp: {timestamp_str}") + + +def load_network_data(csv_file: str, start_time: Optional[str] = None, + end_time: Optional[str] = None) -> Dict[Tuple[str, str], int]: + """ + Load network data from CSV and aggregate edge counts. 
+ + Args: + csv_file: Path to CSV file + start_time: Optional start time filter (ISO format) + end_time: Optional end time filter (ISO format) + + Returns: + Dictionary mapping (source, destination) tuples to counts + """ + edge_counts = defaultdict(int) + timestamps = [] + + # Parse time filters if provided + start_ts = parse_timestamp(start_time) if start_time else None + end_ts = parse_timestamp(end_time) if end_time else None + + with open(csv_file, 'r') as f: + reader = csv.reader(f) + + # Skip header if present + first_row = next(reader, None) + if first_row is None: + print("Empty CSV file") + return edge_counts + + # Check if first row is a header + try: + parse_timestamp(first_row[0]) + rows = [first_row] # First row is data + except (ValueError, IndexError): + rows = [] # First row is header, skip it + + # Add remaining rows + rows.extend(reader) + + for row_idx, row in enumerate(rows): + if len(row) < 3: + print(f"Skipping invalid row {row_idx + 1}: {row}") + continue + + try: + timestamp_str = row[0] + source = row[1].strip() + destination = row[2].strip() + + # Parse timestamp + timestamp_ms = parse_timestamp(timestamp_str) + + # Apply time filters + if start_ts and timestamp_ms < start_ts: + continue + if end_ts and timestamp_ms > end_ts: + continue + + timestamps.append(timestamp_ms) + edge_counts[(source, destination)] += 1 + + except (ValueError, IndexError) as e: + print(f"Error processing row {row_idx + 1}: {e}") + continue + + if timestamps: + start_dt = datetime.fromtimestamp(min(timestamps) / 1000.0) + end_dt = datetime.fromtimestamp(max(timestamps) / 1000.0) + print(f"\nLoaded {sum(edge_counts.values())} events") + print(f"Time range: {start_dt} to {end_dt}") + print(f"Unique edges: {len(edge_counts)}") + + return edge_counts + + +def filter_top_edges(edge_counts: Dict[Tuple[str, str], int], + top_n: Optional[int] = None) -> Dict[Tuple[str, str], int]: + """Filter to keep only top N most active edges.""" + if top_n is None or top_n <= 0: + return edge_counts + + # Sort by count and take top N + sorted_edges = sorted(edge_counts.items(), key=lambda x: x[1], reverse=True) + return dict(sorted_edges[:top_n]) + + +def filter_top_nodes(edge_counts: Dict[Tuple[str, str], int], + top_n: Optional[int] = None) -> Dict[Tuple[str, str], int]: + """Filter to keep only edges involving top N most active nodes.""" + if top_n is None or top_n <= 0: + return edge_counts + + # Calculate node activity (both as source and destination) + node_activity = defaultdict(int) + for (source, dest), count in edge_counts.items(): + node_activity[source] += count + node_activity[dest] += count + + # Get top N nodes + top_nodes = set(sorted(node_activity.items(), key=lambda x: x[1], reverse=True)[:top_n]) + top_nodes = {node for node, _ in top_nodes} + + # Filter edges to only include top nodes + filtered = {} + for (source, dest), count in edge_counts.items(): + if source in top_nodes and dest in top_nodes: + filtered[(source, dest)] = count + + return filtered + + +def create_sankey_diagram(edge_counts: Dict[Tuple[str, str], int], + title: str = "Network Flow Sankey Diagram", + color_scheme: str = "Viridis", + show_values: bool = True) -> go.Figure: + """ + Create an interactive Sankey diagram from edge counts. 
+ + Args: + edge_counts: Dictionary mapping (source, destination) to flow count + title: Title for the diagram + color_scheme: Plotly color scheme name + show_values: Whether to show flow values on hover + + Returns: + Plotly Figure object + """ + if not edge_counts: + print("No data to visualize") + return go.Figure() + + # Create node list (unique sources and destinations) + all_nodes = set() + for source, dest in edge_counts.keys(): + all_nodes.add(source) + all_nodes.add(dest) + + # Create node index mapping + node_list = sorted(all_nodes) + node_to_idx = {node: idx for idx, node in enumerate(node_list)} + + # Prepare Sankey data + sources = [] + targets = [] + values = [] + link_colors = [] + + for (source, dest), count in edge_counts.items(): + sources.append(node_to_idx[source]) + targets.append(node_to_idx[dest]) + values.append(count) + + # Calculate node colors based on total flow + node_flow = defaultdict(int) + for (source, dest), count in edge_counts.items(): + node_flow[source] += count + node_flow[dest] += count + + # Get color scale + max_flow = max(node_flow.values()) if node_flow else 1 + colors = px.colors.sample_colorscale( + color_scheme, + [node_flow.get(node, 0) / max_flow for node in node_list] + ) + + # Create link colors (semi-transparent version of source node color) + for source_idx in sources: + color = colors[source_idx] + # Convert to rgba with transparency + if color.startswith('rgb'): + link_colors.append(color.replace('rgb', 'rgba').replace(')', ', 0.4)')) + else: + link_colors.append(color) + + # Create hover text for nodes + node_hover = [] + for node in node_list: + total_flow = node_flow.get(node, 0) + # Calculate in/out flows + inflow = sum(count for (s, d), count in edge_counts.items() if d == node) + outflow = sum(count for (s, d), count in edge_counts.items() if s == node) + hover_text = f"{node}<br>"
+ hover_text += f"Total Flow: {total_flow}<br>" + hover_text += f"Inflow: {inflow}<br>" + hover_text += f"Outflow: {outflow}" + node_hover.append(hover_text) + + # Create hover text for links + link_hover = [] + for i, ((source, dest), count) in enumerate(edge_counts.items()): + hover_text = f"{source} → {dest}<br>" + hover_text += f"Flow: {count} events<br>
" + if sum(values) > 0: + percentage = (count / sum(values)) * 100 + hover_text += f"Percentage: {percentage:.1f}%" + link_hover.append(hover_text) + + # Create the Sankey diagram + fig = go.Figure(data=[go.Sankey( + node=dict( + pad=15, + thickness=20, + line=dict(color="black", width=0.5), + label=node_list, + color=colors, + customdata=node_hover, + hovertemplate='%{customdata}' + ), + link=dict( + source=sources, + target=targets, + value=values, + color=link_colors, + customdata=link_hover, + hovertemplate='%{customdata}' + ) + )]) + + # Update layout + fig.update_layout( + title=dict( + text=title, + font=dict(size=20, color='#333') + ), + font=dict(size=12), + plot_bgcolor='white', + paper_bgcolor='white', + height=800, + margin=dict(l=20, r=20, t=80, b=20) + ) + + return fig + + +def print_summary_statistics(edge_counts: Dict[Tuple[str, str], int]) -> None: + """Print summary statistics about the network flows.""" + if not edge_counts: + print("No data to summarize") + return + + print("\n" + "="*70) + print("SANKEY DIAGRAM SUMMARY") + print("="*70) + + # Calculate statistics + total_events = sum(edge_counts.values()) + unique_edges = len(edge_counts) + + all_sources = {source for source, _ in edge_counts.keys()} + all_destinations = {dest for _, dest in edge_counts.keys()} + all_nodes = all_sources | all_destinations + + print(f"\nTotal Events: {total_events}") + print(f"Unique Edges: {unique_edges}") + print(f"Unique Nodes: {len(all_nodes)}") + print(f" - Source nodes: {len(all_sources)}") + print(f" - Destination nodes: {len(all_destinations)}") + + # Node activity + node_activity = defaultdict(lambda: {'in': 0, 'out': 0, 'total': 0}) + for (source, dest), count in edge_counts.items(): + node_activity[source]['out'] += count + node_activity[source]['total'] += count + node_activity[dest]['in'] += count + node_activity[dest]['total'] += count + + print(f"\nTop 10 Most Active Edges:") + sorted_edges = sorted(edge_counts.items(), key=lambda x: x[1], reverse=True) + for i, ((source, dest), count) in enumerate(sorted_edges[:10], 1): + pct = (count / total_events) * 100 + print(f" {i:2d}. {source:<25s} → {dest:<25s} {count:>6d} ({pct:>5.1f}%)") + + print(f"\nTop 10 Most Active Nodes (by total flow):") + sorted_nodes = sorted(node_activity.items(), key=lambda x: x[1]['total'], reverse=True) + for i, (node, flows) in enumerate(sorted_nodes[:10], 1): + print(f" {i:2d}. 
{node:<30s} Total: {flows['total']:>6d} " + f"(In: {flows['in']:>5d}, Out: {flows['out']:>5d})") + + print("\n" + "="*70 + "\n") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate interactive Sankey diagram from network CSV data.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Generate basic Sankey diagram + %(prog)s data.csv + + # Filter to top 20 edges and save to HTML + %(prog)s data.csv --top-edges 20 --output sankey.html + + # Filter to top 15 nodes with custom title + %(prog)s data.csv --top-nodes 15 --title "My Network Flows" + + # Filter by time range + %(prog)s data.csv --start-time "2025-01-13 10:00:00" --end-time "2025-01-13 12:00:00" + + # Combine filters + %(prog)s data.csv --top-nodes 10 --color-scheme Plasma --output flows.html + """ + ) + + parser.add_argument( + "csv_file", + type=str, + help="Path to CSV file with format: timestamp, source_node, destination_node" + ) + + parser.add_argument( + "--output", + type=str, + default=None, + help="Output HTML file path (if not specified, opens in browser)" + ) + + parser.add_argument( + "--top-edges", + type=int, + default=None, + help="Show only top N most active edges (default: all)" + ) + + parser.add_argument( + "--top-nodes", + type=int, + default=None, + help="Show only edges involving top N most active nodes (default: all)" + ) + + parser.add_argument( + "--start-time", + type=str, + default=None, + help="Start time filter (ISO format, e.g., '2025-01-13 10:00:00')" + ) + + parser.add_argument( + "--end-time", + type=str, + default=None, + help="End time filter (ISO format, e.g., '2025-01-13 12:00:00')" + ) + + parser.add_argument( + "--title", + type=str, + default="Network Flow Sankey Diagram", + help="Title for the diagram (default: 'Network Flow Sankey Diagram')" + ) + + parser.add_argument( + "--color-scheme", + type=str, + default="Viridis", + choices=["Viridis", "Plasma", "Inferno", "Magma", "Cividis", "Turbo", + "Blues", "Greens", "Reds", "Purples", "Rainbow"], + help="Color scheme for nodes (default: Viridis)" + ) + + parser.add_argument( + "--no-summary", + action="store_true", + help="Skip printing summary statistics" + ) + + parser.add_argument( + "--auto-open", + action="store_true", + help="Automatically open in browser (default: True if no output file specified)" + ) + + args = parser.parse_args() + + # Check if CSV file exists + csv_path = Path(args.csv_file) + if not csv_path.exists(): + print(f"Error: CSV file not found: {args.csv_file}") + return + + # Load data + print(f"Loading data from {args.csv_file}...") + edge_counts = load_network_data(args.csv_file, args.start_time, args.end_time) + + if not edge_counts: + print("No data to visualize!") + return + + # Apply filters + if args.top_edges: + print(f"Filtering to top {args.top_edges} edges...") + edge_counts = filter_top_edges(edge_counts, args.top_edges) + + if args.top_nodes: + print(f"Filtering to edges involving top {args.top_nodes} nodes...") + edge_counts = filter_top_nodes(edge_counts, args.top_nodes) + + # Print summary statistics + if not args.no_summary: + print_summary_statistics(edge_counts) + + # Create Sankey diagram + print("Generating Sankey diagram...") + fig = create_sankey_diagram( + edge_counts, + title=args.title, + color_scheme=args.color_scheme + ) + + # Save or show + if args.output: + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.write_html(str(output_path)) + print(f"\nSaved Sankey diagram to: 
{output_path}") + print(f"Open the file in a web browser to view the interactive diagram.") + + if args.auto_open: + import webbrowser + webbrowser.open(f"file://{output_path.absolute()}") + else: + print("\nOpening Sankey diagram in browser...") + fig.show() + + print("\nDone!") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/shell.nix b/shell.nix index a5360642..ae2425cd 100644 --- a/shell.nix +++ b/shell.nix @@ -1,5 +1,5 @@ { pkgs ? import {} }: with pkgs; mkShell { - buildInputs = with python3Packages; [ Mako requests wabt todo-txt-cli pyyaml ]; + buildInputs = with python3Packages; [ Mako requests wabt todo-txt-cli pyyaml plotly numpy ]; }