Source code for sankey

import sys

import plotly.graph_objects as go

from graphpack.utils import *
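# NOTE: the wildcard import above is assumed to re-export the remaining names used
# in this module (re, json, os, argparse, pandas as pd, matplotlib.colors as mcolors,
# plus shared constants/helpers such as COLORS, METHODS, CustomArgumentParser,
# LONG_DESCR, and the ANSI color codes); no other explicit imports are made here.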


MARGIN = 50  # Margin for the plot
MAX_MARGIN = 500  # Maximum margin for the plot (increase if last layer nodes' labels are cut off)


def lighten_color(color, amount=0.5):
    """
    Lighten the given color by setting its opacity to the specified amount.

    Args:
        color (str): A color name or RGBA string.
        amount (float, optional): The alpha value to apply (0.0 to 1.0).

    Returns:
        str: The lightened color as an RGBA string.
    """
    # Check if the color is already in RGBA format
    rgba_pattern = re.compile(r'rgba\((\d+),\s*(\d+),\s*(\d+),\s*(\d*\.?\d+)\)')
    match = rgba_pattern.match(color)

    if match:
        # Extract the RGBA components from the string
        r, g, b, a = map(float, match.groups())
        r, g, b = int(r), int(g), int(b)

        # Apply the amount to the alpha component
        return f'rgba({r},{g},{b},{amount})'
    else:
        # Convert a named color to RGBA
        c = mcolors.to_rgba(color)
        return f'rgba({c[0] * 255:.0f},{c[1] * 255:.0f},{c[2] * 255:.0f},{amount})'
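
# Illustrative sketch of how lighten_color behaves (the color values below are
# arbitrary examples, shown doctest-style so the module's behavior is unchanged):
#
#     >>> lighten_color('steelblue', 0.3)                 # named color -> RGBA with alpha 0.3
#     'rgba(70,130,180,0.3)'
#     >>> lighten_color('rgba(70, 130, 180, 1.0)', 0.2)   # RGBA input: only the alpha is replaced
#     'rgba(70,130,180,0.2)'
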
def load_data(input_path, graph, method, parameter, parameters):
    """
    Load compression mappings and group labels for each parameter.

    Args:
        input_path (str): Path to the input files.
        graph (str): Input graph identifier.
        method (str): Clustering method used.
        parameter (str): Clustering parameter name.
        parameters (list): List of parameter values to be analyzed.

    Returns:
        tuple: Two dictionaries, one for compression mappings and one for group labels,
            and the (possibly updated) list of parameter values.
    """
    # Initialize dictionaries to store compression mappings and groups for each parameter
    compression_mappings = {}
    groups = {}

    # Create a copy of parameters to iterate over, so missing values can be removed safely
    parameters_copy = parameters[:]

    # Load compression mappings and group labels for each parameter
    for param in parameters_copy:
        try:
            with open(f'{input_path}/{graph}/{method}_{parameter}_{param}/compression_mapping.json', 'r') as f:
                compression_mappings[param] = json.load(f)

            labels_file = f'{input_path}/{graph}/{method}_{parameter}_{param}/labels_mapping.json'
            if os.path.exists(labels_file):
                with open(labels_file, 'r') as f:
                    groups[param] = json.load(f)
            else:
                # If labels_mapping.json doesn't exist, use keys from compression_mapping.json as group labels
                compression_data = compression_mappings[param]
                groups[param] = {key: key for key in compression_data}
        except FileNotFoundError:
            print(f"Files for parameter {param} not found. Skipping this value.")
            parameters.remove(param)

    return compression_mappings, groups, parameters
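
# Illustrative sketch of the on-disk layout load_data expects (the paths follow the
# open() calls above; the JSON contents are hypothetical examples):
#
#     <input_path>/<graph>/<method>_<parameter>_<value>/compression_mapping.json
#         e.g. {"c0": ["TP53", "BRCA1"], "c1": ["EGFR"]}
#     <input_path>/<graph>/<method>_<parameter>_<value>/labels_mapping.json   (optional)
#         e.g. {"c0": "signaling", "c1": "growth factors"}
#
# When labels_mapping.json is missing, the cluster keys themselves are used as group
# labels; missing folders are skipped and their values removed from the parameter list.
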
def map_transitions(compression_mappings, groups, parameters, min_size):
    """
    Map transitions between consecutive parameters.

    Args:
        compression_mappings (dict): Compression mappings for each parameter.
        groups (dict): Group labels for each parameter.
        parameters (list): List of parameters to be analyzed.
        min_size (int): Minimum cluster size to be considered significant.

    Returns:
        list: List of transitions between clusters.
    """
    # Initialize a list to store the gene transitions between resolutions
    transitions = []
    small_cluster_label = "small clusters"

    # Iterate through each pair of consecutive resolutions to map transitions
    for i in range(len(parameters) - 1):
        param_from = parameters[i]
        param_to = parameters[i + 1]

        if param_from in compression_mappings and param_to in compression_mappings:
            mapping_from = compression_mappings[param_from]
            mapping_to = compression_mappings[param_to]
            group_from = groups[param_from]
            group_to = groups[param_to]

            # Create a reverse mapping from gene to cluster for the target resolution
            reverse_mapping_to = {gene: cluster for cluster, genes in mapping_to.items() for gene in genes}

            # Map transitions from the current resolution to the next
            for cluster_from, genes_from in mapping_from.items():
                source_label = f'{param_from} - {group_from.get(str(cluster_from), small_cluster_label)}'

                if len(genes_from) > min_size:
                    # Main cluster handling
                    for gene in genes_from:
                        if gene in reverse_mapping_to:
                            cluster_to = reverse_mapping_to[gene]
                            target_label = f'{param_to} - {group_to.get(str(cluster_to), small_cluster_label)}'
                            if len(mapping_to[cluster_to]) > min_size:
                                transitions.append([source_label, target_label, gene])
                            else:
                                transitions.append([source_label, f'{param_to} - {small_cluster_label}', gene])
                else:
                    # Handle small clusters at the source resolution
                    for gene in genes_from:
                        if gene in reverse_mapping_to:
                            cluster_to = reverse_mapping_to[gene]
                            target_label = f'{param_to} - {group_to.get(str(cluster_to), small_cluster_label)}'
                            if len(mapping_to[cluster_to]) > min_size:
                                transitions.append([f'{param_from} - {small_cluster_label}', target_label, gene])
                            else:
                                transitions.append(
                                    [f'{param_from} - {small_cluster_label}', f'{param_to} - {small_cluster_label}', gene])
        else:
            print(f"Skipping transition from resolution {param_from} to {param_to}.")

    return transitions
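
# Illustrative sketch of the transition mapping (hypothetical cluster and gene names;
# min_size=1 so only clusters with more than one gene count as "significant"):
#
#     compression_mappings = {'1.25': {'c0': ['TP53', 'BRCA1', 'EGFR']},
#                             '3.0': {'k0': ['TP53', 'BRCA1'], 'k1': ['EGFR']}}
#     groups = {'1.25': {'c0': 'signaling'},
#               '3.0': {'k0': 'signaling-a', 'k1': 'signaling-b'}}
#     map_transitions(compression_mappings, groups, ['1.25', '3.0'], min_size=1)
#     # -> [['1.25 - signaling', '3.0 - signaling-a', 'TP53'],
#     #     ['1.25 - signaling', '3.0 - signaling-a', 'BRCA1'],
#     #     ['1.25 - signaling', '3.0 - small clusters', 'EGFR']]
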
def create_sankey_plot(transitions, min_size, method, input_graph, output_folder):
    """
    Create and save the Sankey plot.

    Args:
        transitions (list): List of transitions between clusters.
        min_size (int): Minimum cluster size to be considered significant.
        method (str): Clustering method used.
        input_graph (str): Knowledge graph identifier.
        output_folder (str): Path to the output folder.

    Returns:
        None
    """
    # Create a DataFrame from the transitions list
    transitions_df = pd.DataFrame(transitions, columns=['Source', 'Target', 'Gene'])

    # Extract unique stages from Source and Target columns
    stages = sorted(list(set(transitions_df['Source'].str.split(' - ').str[0]).union(
        set(transitions_df['Target'].str.split(' - ').str[0]))))

    # Initialize a dictionary to store nodes grouped by stage
    stage_nodes = {stage: [] for stage in stages}

    # Populate stage_nodes dictionary with nodes, sorted alphabetically within each stage
    for stage in stage_nodes:
        stage_nodes[stage].extend(transitions_df[transitions_df['Source'].str.contains(stage)]['Source'].unique())
        stage_nodes[stage].extend(transitions_df[transitions_df['Target'].str.contains(stage)]['Target'].unique())
        stage_nodes[stage] = sorted(list(set(stage_nodes[stage])))

    # Flatten stage_nodes dictionary into a sorted list of nodes
    nodes = [node for stage in stages for node in stage_nodes[stage]]

    # Create a dictionary to map node labels to their indices
    node_indices = {node: idx for idx, node in enumerate(nodes)}

    # Initialize a list to store links with updated indices
    links = []

    # Iterate through each row in transitions_df to create links with updated indices
    for index, row in transitions_df.iterrows():
        source_index = node_indices[row['Source']]
        target_index = node_indices[row['Target']]
        links.append({
            'source': source_index,
            'target': target_index,
            'value': 1
        })

    # Prepare custom data for each node with a limited gene list
    node_customdata = []
    gene_display_limit = 10  # Limit the number of genes displayed in hover info
    for node in nodes:
        node_stage = node.split(' - ')[0]
        group = node.split(' - ')[1]
        genes = transitions_df[(transitions_df['Source'] == node) | (transitions_df['Target'] == node)]['Gene'].unique()
        gene_list = ', '.join(genes[:gene_display_limit])  # Limit the number of genes displayed
        if len(genes) > gene_display_limit:
            gene_list += f', ... (+{len(genes) - gene_display_limit} more)'
        custom_data = f"Parameter: {node_stage}<br>Community: {group}<br>Genes: {gene_list}"
        node_customdata.append(custom_data)

    # Extract unique community names
    communities = sorted(list(set(node.split(' - ')[1] for node in nodes)))

    # Generate a larger color palette using matplotlib
    colors = [mcolors.CSS4_COLORS[name] for name in COLORS]
    color_map = {community: colors[i % len(colors)] for i, community in enumerate(communities)}

    # Assign colors to nodes based on their community
    node_colors = []
    for node in nodes:
        community = node.split(' - ')[1]
        if node.endswith('small clusters'):
            node_colors.append('rgba(0, 0, 0, 0)')  # Transparent color for small clusters
        else:
            node_colors.append(color_map[community])

    # Create a list to store link colors based on whether the target node is a small cluster
    line_colors = []
    for link in links:
        target_node = nodes[link['target']]
        source_color = node_colors[link['source']]
        if target_node.endswith('small clusters'):
            # Lighter gray for links going into small clusters
            line_colors.append(mcolors.CSS4_COLORS['whitesmoke'])
        elif min_size >= 100:
            # Lighter version of the source node color for big communities
            line_colors.append(lighten_color(source_color, 0.2))
        else:
            # Light gray for all other links
            line_colors.append(mcolors.CSS4_COLORS['lightgray'])

    # Create the Sankey plot with custom hover information, node colors, and link colors
    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,  # Padding between nodes
            thickness=20,  # Thickness of the nodes
            line=dict(color=mcolors.CSS4_COLORS['lightgray'], width=0.5),
            label=nodes,  # Node labels
            color=node_colors,  # Assign node colors
            customdata=node_customdata,  # Use the prepared custom data
            hovertemplate='%{customdata}<extra></extra>',  # Use the custom data for hover text
            hoverinfo='all',  # Enable hover information for nodes
        ),
        link=dict(
            source=[link['source'] for link in links],  # Indices of source nodes
            target=[link['target'] for link in links],  # Indices of target nodes
            value=[link['value'] for link in links],  # Link values
            hoverinfo='none',  # Disable hover interaction for links
            line=dict(width=0.0005),
            color=line_colors,  # Specify link colors based on conditions
        )
    )])

    # Update the layout of the plot
    fig.update_layout(
        title_text=f"Genes' community membership for {input_graph} - {method}",
        margin=dict(t=MARGIN, l=MARGIN, r=MAX_MARGIN, b=MARGIN),  # Adjusted margins
        font_size=10,  # Adjusted font size for better visibility
        width=2000,  # Increased width for better visibility
        height=2000,  # Increased height for better visibility
        plot_bgcolor='white',
        paper_bgcolor='white',
        font_color='black'
    )

    # JavaScript for aligning node labels to the right
    js = '''
    const TEXTPAD = 3;  // constant used by Plotly.js

    function sankeyNodeLabelsAlign(position, forcePos) {
        const textAnchor = {left: 'end', right: 'start', center: 'middle'}[position];
        const nodes = gd.getElementsByClassName('sankey-node');

        for (const node of nodes) {
            const d = node.__data__;
            const label = node.getElementsByClassName('node-label').item(0);

            // Ensure to reset any previous modifications
            label.setAttribute('x', 0);
            if (!d.horizontal) continue;

            // This is how Plotly's default text positioning is computed (coordinates
            // are relative to that of the corresponding node).
            const padX = d.nodeLineWidth / 2 + TEXTPAD;
            const posX = padX + d.visibleWidth;
            let x;

            switch (position) {
                case 'left':
                    if (d.left || (d.node.originalLayer === 0 && !forcePos)) continue;
                    x = -posX - padX;
                    break;
                case 'right':
                    if (!d.left || !forcePos) continue;
                    x = posX + padX;
                    break;
                case 'center':
                    if (!forcePos && (d.left || d.node.originalLayer === 0)) continue;
                    x = (d.nodeLineWidth + d.visibleWidth) / 2 + (d.left ? padX : -posX);
                    break;
            }

            // Ensure last layer nodes' labels are inside the plot area
            if (d.node.originalLayer === d.layerLength - 1) {
                x = Math.min(x, gd.layout.width - label.getBBox().width - padX);
            }

            label.setAttribute('x', x);
            label.setAttribute('text-anchor', textAnchor);
        }
    }

    const gd = document.getElementById('{plot_id}');
    const position = 'right';  // Set position to 'right', 'left', or 'center'
    const forcePos = true;

    gd.on('plotly_afterplot', sankeyNodeLabelsAlign.bind(gd, position, forcePos));
    gd.emit('plotly_afterplot');
    '''

    # Create the output folder if it does not exist
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    fig.write_html(f"{output_folder}/{input_graph}_sankey_{method}_s_{min_size}.html", post_script=js)
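
# Illustrative call (hypothetical labels and genes); each transition row is
# [source_label, target_label, gene] with labels formatted as "<value> - <community>":
#
#     transitions = [['1.25 - signaling', '3.0 - signaling-a', 'TP53'],
#                    ['1.25 - signaling', '3.0 - small clusters', 'EGFR']]
#     create_sankey_plot(transitions, min_size=100, method='louvain',
#                        input_graph='simple_graph', output_folder='sankey')
#     # writes sankey/simple_graph_sankey_louvain_s_100.html
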
def produce_sankey(graph, input_path='data/output', output_folder='sankey', min_size=100, method='louvain',
                   parameter='r', values=[1.25, 3.0, 5.0, 10.0, 20.0, 30.0]):
    """
    Main function for the GraphPack tool Sankey plot script.

    This script generates a Sankey plot to visualize gene community membership transitions
    across different clustering resolutions for a given network.

    Args:
        graph (str): Input graph identifier. Required parameter.
        input_path (str, optional): Path to the input files. Default is "data/output".
        output_folder (str, optional): Path to the output folder. Default is "sankey".
        min_size (int, optional): Minimum cluster size to be considered significant. Default is 100.
        method (str, optional): Clustering method used. Default is "louvain".
        parameter (str, optional): Clustering parameter name, as it appears in the subfolders' names. Default is "r".
        values (list of float, optional): List of parameters to be analyzed.
            Default is [1.25, 3.0, 5.0, 10.0, 20.0, 30.0].

    Examples:
        >>> from graphpack.demo.sankey import *
        >>> produce_sankey(input_path="./", output_folder="results", min_size=50, graph='simple_graph',
        ...                method='hclust', parameter='k', values=[10, 50, 100, 250])
    """
    # Print information about the arguments
    print(f"\n{'=' * 80}")
    print(f"{'GraphPack Tool Sankey plot script':^80}")
    print(f"{'=' * 80}")
    print(f"\n▶ Input graph: {graph}")
    print(f"▶ Method: {method}")
    print(f"▶ Parameters: {parameter} in [ {', '.join(map(str, values))} ]")
    print(f"\n▶ Min cluster size: {min_size}")
    print(f"\n▶ Input folder: {input_path}")
    print(f"▶ Output folder: {output_folder}")

    print("\n📑 Loading data...")
    compression_mappings, groups, values = load_data(input_path, graph, method, parameter, values)

    print("⚙️ Mapping transitions...")
    transitions = map_transitions(compression_mappings, groups, values, min_size)

    print("📊 Creating Sankey plot...")
    create_sankey_plot(transitions, min_size, method, graph, output_folder)

    print("\n✅ Done!")
    print(f"Sankey plot for gene community membership transitions has been saved in '{output_folder}'.")

def parse_args():
    """
    Parse command-line arguments for the Sankey plot script.

    Command-line arguments:
        --graph (str): Input graph identifier. Required argument.
        --input-path (str): Path to the input files. Default is "data/output".
        --output-folder (str): Path to the output folder. Default is "sankey".
        --min-size (int): Minimum cluster size to be considered significant. Default is 100.
        --method (str): Clustering method used. Default is "louvain".
        --parameter (str): Clustering parameter name. Default is "r".
        --values (list of float): List of parameter values to be analyzed. Default is [1.25, 3.0, 5.0, 10.0, 20.0, 30.0].

    Returns:
        args (argparse.Namespace): Parsed command-line arguments.
    """
    from graphpack import __version__

    # Create the custom parser
    parser = CustomArgumentParser(
        description="Generate a Sankey plot for gene community membership transitions.",
        epilog="For more information, please refer to the documentation.",
        add_help=False
    )

    parser.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS,
                        help='Show this help message and exit.')
    parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}',
                        help='Show the version of the program.')
    parser.add_argument('-g', '--graph', type=str, required=True,
                        help='Input graph identifier.')
    parser.add_argument('-i', '--input-path', type=str, default='data/output',
                        help='Path to the input files. Default is "data/output".')
    parser.add_argument('-o', '--output-folder', type=str, default='sankey',
                        help='Path to the output folder. Default is "sankey".')
    parser.add_argument('-s', '--min-size', type=int, default=100,
                        help='Minimum cluster size to be considered significant. Default is 100.')
    parser.add_argument('-m', '--method', type=str, default='louvain',
                        help='Clustering method used. Default is "louvain".')
    parser.add_argument('-p', '--parameter', type=str, default='r',
                        help='Clustering parameter name, as it appears in the output subfolder. Default is "r".')
    parser.add_argument('-V', '--values', type=float, nargs='+', default=[1.25, 3.0, 5.0, 10.0, 20.0, 30.0],
                        help='List of parameters to be analyzed. Default is [1.25, 3.0, 5.0, 10.0, 20.0, 30.0].')

    # Show the long description and a warning when the script is run without any arguments
    # (checked before parsing, since --graph is required and parsing would otherwise fail first)
    if len(sys.argv) - 1 == 0:
        print(LONG_DESCR)
        print(f"{ORANGE_BOLD}Warning: no parameters provided. To display the help, run the script with --help{RESET}")
        sys.exit(0)

    args = parser.parse_args()

    # Validate input path
    if not os.path.exists(args.input_path):
        print(f"{RED_BOLD}Error: Input path '{args.input_path}' does not exist.{RESET}")
        sys.exit(1)

    # Validate output folder path
    output_path = os.path.abspath(args.output_folder)
    if not os.path.isdir(output_path):
        print(f"{ORANGE_BOLD}Warning: Output folder '{output_path}' does not exist.")
        print(f"{BOLD}Creating it now.{RESET}")
        try:
            os.makedirs(output_path, exist_ok=True)
        except Exception as e:
            print(f"{RED_BOLD}Error: Failed to create output folder '{output_path}'. {str(e)}{RESET}")
            sys.exit(1)

    # Validate the method argument
    if args.method not in METHODS:
        print(f"{RED_BOLD}Error: Unsupported compression method '{args.method}'.{RESET}")
        print(f"{BOLD}Supported methods are: {', '.join(METHODS)}{RESET}")
        sys.exit(1)

    # Cast the parameter values to integers if the method is not Louvain (k must be an integer)
    if args.method != 'louvain':
        args.values = [int(param) for param in args.values]

    return args
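
# Example invocation (the script/module name is illustrative; only flags defined
# above are used):
#
#     python sankey.py -g simple_graph -m louvain -p r -V 1.25 3.0 5.0 \
#         -s 100 -i data/output -o sankey
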
def main():
    args = parse_args()
    produce_sankey(**vars(args))


if __name__ == "__main__":
    main()