import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from adjustText import adjust_text
from scipy.stats import fisher_exact, pearsonr
from scipy.spatial.distance import cosine
from sklearn.metrics import mean_absolute_error
from IPython.display import display
from matplotlib.legend import Legend
from matplotlib.transforms import Bbox
from matplotlib.gridspec import GridSpec
from brokenaxes import brokenaxes
def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=50, figsize=20, threshold=0, save_path=None, x_lim=[-0.5, 0.5], y_lims=[[0, 6], [9, 20]]):
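"""
Broken-axis volcano plot of regression coefficients vs. -log10(p-value), colored by subcellular compartment.
Parameters:
- data_path: Path to a CSV file, or a DataFrame, with 'feature', 'coefficient' and 'p_value' columns.
- metadata_path: Path to a CSV file, or a DataFrame, with 'gene_nr' and the metadata_column.
- metadata_column: Metadata column used to color the points (default 'tagm_location').
- point_size: Marker size for the scatter points.
- figsize: Width and height of the (square) figure in inches.
- threshold: Absolute coefficient cutoff used, together with p <= 0.05, to call hits.
- save_path: Optional path; if provided, the figure is saved there as a PDF.
- x_lim: x-axis limits for the coefficient axis.
- y_lims: [lower, upper] y-axis ranges for the two stacked subplots.
Returns:
- hit_list: List of variables passing the significance and coefficient thresholds.
"""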
# Dictionary mapping compartment to color
colors = {'micronemes':'black',
'rhoptries 1':'darkviolet',
'rhoptries 2':'darkviolet',
'nucleus - chromatin':'blue',
'nucleus - non-chromatin':'blue',
'dense granules':'teal',
'ER 1':'pink',
'ER 2':'pink',
'unknown':'black',
'tubulin cytoskeleton':'slategray',
'IMC':'slategray',
'PM - peripheral 1':'slategray',
'PM - peripheral 2':'slategray',
'cytosol':'turquoise',
'mitochondrion - soluble':'red',
'mitochondrion - membranes':'red',
'apicoplast':'slategray',
'Golgi':'green',
'PM - integral':'slategray',
'apical 1':'orange',
'apical 2':'orange',
'19S proteasome':'slategray',
'20S proteasome':'slategray',
'60S ribosome':'slategray',
'40S ribosome':'slategray',
}
# Increase font size for better readability
fontsize = 18
plt.rcParams.update({'font.size': fontsize})
# --- Load data ---
if isinstance(data_path, pd.DataFrame):
data = data_path
else:
data = pd.read_csv(data_path)
# Extract 'variable' and 'gene_nr' from the feature notation
data['variable'] = data['feature'].str.extract(r'\[(.*?)\]')
data['variable'] = data['variable'].fillna(data['feature'])
data['gene_nr'] = data['variable'].str.split('_').str[0]
data = data[data['variable'] != 'Intercept']
# --- Load metadata ---
if isinstance(metadata_path, pd.DataFrame):
metadata = metadata_path
else:
metadata = pd.read_csv(metadata_path)
metadata['gene_nr'] = metadata['gene_nr'].astype(str)
data['gene_nr'] = data['gene_nr'].astype(str)
# Merge data and metadata
merged_data = pd.merge(data, metadata[['gene_nr', metadata_column]],
on='gene_nr', how='left')
merged_data[metadata_column] = merged_data[metadata_column].fillna('unknown')
# --- Create figure with "upper" and "lower" subplots sharing the x-axis ---
fig = plt.figure(figsize=(figsize, figsize))
gs = GridSpec(2, 1, height_ratios=[1, 3], hspace=0.05)
ax_upper = fig.add_subplot(gs[0])
ax_lower = fig.add_subplot(gs[1], sharex=ax_upper)
# Hide x-axis labels on the upper plot
ax_upper.tick_params(axis='x', which='both', bottom=False, labelbottom=False)
# List to collect the variables (hits) that meet threshold criteria
hit_list = []
# --- Scatter plot on both axes ---
for _, row in merged_data.iterrows():
y_val = -np.log10(row['p_value'])
# Decide which axis to draw on based on the p-value
ax = ax_upper if y_val > y_lims[1][0] else ax_lower
# Color each point by its compartment using the colors dict
ax.scatter(
row['coefficient'],
y_val,
color=colors.get(row[metadata_column], 'gray'),  # gray for unmapped compartments
marker='o',
s=point_size,
edgecolor='black',
alpha=0.6
)
# Check significance thresholds
if (row['p_value'] <= 0.05) and (abs(row['coefficient']) >= abs(threshold)):
hit_list.append(row['variable'])
# --- Adjust axis limits ---
ax_upper.set_ylim(y_lims[1])
ax_lower.set_ylim(y_lims[0])
ax_lower.set_xlim(x_lim)
# Hide top spines
ax_lower.spines['top'].set_visible(False)
ax_upper.spines['top'].set_visible(False)
ax_upper.spines['bottom'].set_visible(False)
# Set x-axis and y-axis labels
ax_lower.set_xlabel('Coefficient')
ax_lower.set_ylabel('-log10(p-value)')
ax_upper.set_ylabel('-log10(p-value)')
for ax in [ax_upper, ax_lower]:
ax.spines['right'].set_visible(False)
# --- Add threshold lines to both axes ---
for ax in [ax_upper, ax_lower]:
ax.axvline(x=-abs(threshold), linestyle='--', color='black')
ax.axvline(x=abs(threshold), linestyle='--', color='black')
ax_lower.axhline(y=-np.log10(0.05), linestyle='--', color='black')
# --- Annotate significant points ---
texts_upper, texts_lower = [], []
for _, row in merged_data.iterrows():
y_val = -np.log10(row['p_value'])
if row['p_value'] > 0.05 or abs(row['coefficient']) < abs(threshold):
continue
ax = ax_upper if y_val > y_lims[1][0] else ax_lower
text = ax.text(
row['coefficient'],
y_val,
row['variable'],
fontsize=fontsize,
ha='center',
va='bottom'
)
if ax == ax_upper:
texts_upper.append(text)
else:
texts_lower.append(text)
# Attempt to keep text labels from overlapping
adjust_text(texts_upper, ax=ax_upper, arrowprops=dict(arrowstyle='-', color='black'))
adjust_text(texts_lower, ax=ax_lower, arrowprops=dict(arrowstyle='-', color='black'))
# --- Add a legend keyed by color (optional) ---
# If you'd like a legend that shows what each compartment color represents:
legend_handles = []
for comp, comp_color in colors.items():
# Create a "dummy" scatter handle for the legend
legend_handles.append(
plt.Line2D([0], [0], marker='o', color=comp_color,
label=comp, linewidth=0, markersize=8)
)
# You can adjust the location and styling of the legend to taste:
ax_lower.legend(
handles=legend_handles,
bbox_to_anchor=(1.05, 1),
loc='upper left',
borderaxespad=0.25,
labelspacing=2,
handletextpad=0.25,
markerscale=1.5,
prop={'size': fontsize}
)
# --- Save and show ---
if save_path:
plt.savefig(save_path, format='pdf', bbox_inches='tight')
plt.show()
return hit_list
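# Example usage (a minimal sketch; the file paths and column names below are placeholders):
# hits = custom_volcano_plot(
#     data_path='regression_results.csv',   # hypothetical results file with 'feature', 'coefficient', 'p_value'
#     metadata_path='gene_metadata.csv',    # hypothetical metadata file with 'gene_nr' and 'tagm_location'
#     metadata_column='tagm_location',
#     threshold=0.1,
#     save_path='volcano.pdf'
# )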
def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=['Computed GO Processes', 'Curated GO Components', 'Curated GO Functions', 'Curated GO Processes']):
"""
Perform GO term enrichment analysis for each GO term column and generate plots.
Parameters:
- significant_df: DataFrame of significant genes from the screen; must contain an 'n_gene' column.
- metadata_path: Path to the metadata file containing GO terms.
- go_term_columns: List of columns in the metadata corresponding to GO terms.
For each GO term column, this function will:
- Split the GO terms by semicolons.
- Count the occurrences of GO terms in the hits and in the background.
- Perform Fisher's exact test for enrichment.
- Plot the enrichment score vs -log10(p-value).
"""
#significant_df['variable'].fillna(significant_df['feature'], inplace=True)
#split_columns = significant_df['variable'].str.split('_', expand=True)
#significant_df['gene_nr'] = split_columns[0]
#gene_list = significant_df['gene_nr'].to_list()
significant_df = significant_df.dropna(subset=['n_gene'])
significant_df = significant_df[significant_df['n_gene'].notna()]
gene_list = significant_df['n_gene'].to_list()
# Load metadata
metadata = pd.read_csv(metadata_path)
split_columns = metadata['Gene ID'].str.split('_', expand=True)
metadata['gene_nr'] = split_columns[1]
# Create a subset of metadata with only the rows that contain genes in gene_list (hits)
hits_metadata = metadata[metadata['gene_nr'].isin(gene_list)].copy()
# Create a list to hold results from all columns
combined_results = []
for go_term_column in go_term_columns:
# Initialize lists to store results
go_terms = []
enrichment_scores = []
p_values = []
# Split the GO terms in the entire metadata and hits
metadata[go_term_column] = metadata[go_term_column].fillna('')
hits_metadata[go_term_column] = hits_metadata[go_term_column].fillna('')
all_go_terms = metadata[go_term_column].str.split(';').explode()
hit_go_terms = hits_metadata[go_term_column].str.split(';').explode()
# Count occurrences of each GO term in hits and total metadata
all_go_term_counts = all_go_terms.value_counts()
hit_go_term_counts = hit_go_terms.value_counts()
# Perform enrichment analysis for each GO term
for go_term in all_go_term_counts.index:
total_with_go_term = all_go_term_counts.get(go_term, 0)
hits_with_go_term = hit_go_term_counts.get(go_term, 0)
# Calculate the total number of genes and hits
total_genes = len(metadata)
total_hits = len(hits_metadata)
# Perform Fisher's exact test on the 2x2 table:
#   rows = hit genes vs. non-hit background genes; columns = with vs. without the GO term
contingency_table = [[hits_with_go_term, total_hits - hits_with_go_term],
[total_with_go_term - hits_with_go_term, total_genes - total_hits - (total_with_go_term - hits_with_go_term)]]
_, p_value = fisher_exact(contingency_table)
# Enrichment score = frequency of the GO term among hits / frequency of the GO term in the full background
if total_with_go_term > 0 and total_hits > 0:
enrichment_score = (hits_with_go_term / total_hits) / (total_with_go_term / total_genes)
else:
enrichment_score = 0.0
# Store the results only if enrichment score is non-zero
if enrichment_score > 0.0:
go_terms.append(go_term)
enrichment_scores.append(enrichment_score)
p_values.append(p_value)
# Create a results DataFrame for this GO term column
results_df = pd.DataFrame({
'GO Term': go_terms,
'Enrichment Score': enrichment_scores,
'P-value': p_values,
'GO Column': go_term_column # Track the GO term column for final combined plot
})
# Sort by enrichment score
results_df = results_df.sort_values(by='Enrichment Score', ascending=False)
# Append this DataFrame to the combined list
combined_results.append(results_df)
# Plot the enrichment results for each individual column
plt.figure(figsize=(10, 6))
# Create a scatter plot of Enrichment Score vs -log10(p-value)
sns.scatterplot(data=results_df, x='Enrichment Score', y=-np.log10(results_df['P-value']), hue='GO Term', size='Enrichment Score', sizes=(50, 200))
# Set plot labels and title
plt.title(f'GO Term Enrichment Analysis for {go_term_column}')
plt.xlabel('Enrichment Score')
plt.ylabel('-log10(P-value)')
# Move the legend to the right of the plot
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
# Show the plot
plt.tight_layout() # Ensure everything fits in the figure area
plt.show()
# Optionally return or save the results for each column
print(f'Results for {go_term_column}')
# Combine results from all columns into a single DataFrame
combined_df = pd.concat(combined_results)
# Plot the combined results with text labels
plt.figure(figsize=(12, 8))
sns.scatterplot(data=combined_df, x='Enrichment Score', y=-np.log10(combined_df['P-value']),
style='GO Column', size='Enrichment Score', sizes=(50, 200))
# Set plot labels and title for the combined graph
plt.title('Combined GO Term Enrichment Analysis')
plt.xlabel('Enrichment Score')
plt.ylabel('-log10(P-value)')
# Annotate the points with labels and connecting lines
texts = []
for i, row in combined_df.iterrows():
texts.append(plt.text(row['Enrichment Score'], -np.log10(row['P-value']), row['GO Term'], fontsize=9))
# Adjust text to avoid overlap
adjust_text(texts, arrowprops=dict(arrowstyle='-', color='black'))
# Show the combined plot
plt.tight_layout()
plt.show()
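# Example usage (a minimal sketch; the path and gene numbers are placeholders):
# hits_df = pd.DataFrame({'n_gene': ['268860', '312420']})    # hypothetical significant genes
# go_term_enrichment_by_column(hits_df, 'toxo_metadata.csv')  # hypothetical metadata CSV with 'Gene ID' and GO columns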
def plot_gene_phenotypes(data, gene_list, x_column='Gene ID', data_column='T.gondii GT1 CRISPR Phenotype - Mean Phenotype', error_column='T.gondii GT1 CRISPR Phenotype - Standard Error', save_path=None):
"""
Plot a line graph for the mean phenotype with standard error shading and highlighted genes.
Args:
data (pd.DataFrame): The input DataFrame containing gene data.
gene_list (list): A list of gene names to highlight on the plot.
x_column (str): Column containing the gene identifiers.
data_column (str): Column with the mean phenotype values.
error_column (str): Column with the standard error of the phenotype.
save_path (str): Optional. If provided, the plot is saved to this path as a PDF.
"""
# Ensure x_column is properly processed
def extract_gene_id(gene):
if isinstance(gene, str) and '_' in gene:
return gene.split('_')[1]
return str(gene)
data = data.copy()  # work on a copy so the caller's DataFrame is not modified
data.loc[:, data_column] = pd.to_numeric(data[data_column], errors='coerce')
data = data.dropna(subset=[data_column])
data.loc[:, error_column] = pd.to_numeric(data[error_column], errors='coerce')
data = data.dropna(subset=[error_column])
data['x'] = data[x_column].apply(extract_gene_id)
# Sort by the data_column and assign ranks
data = data.sort_values(by=data_column).reset_index(drop=True)
data['rank'] = range(1, len(data) + 1)
# Prepare the x, y, and error values for plotting
x = data['rank']
y = data[data_column]
yerr = data[error_column]
# Create the plot
plt.figure(figsize=(10, 10))
# Plot the mean phenotype with standard error shading
plt.plot(x, y, label='Mean Phenotype', color=(0/255, 155/255, 155/255), linewidth=2)
plt.fill_between(
x, y - yerr, y + yerr,
color=(0/255, 155/255, 155/255), alpha=0.1, label='Standard Error'
)
# Prepare for adjustText
texts = [] # Store text objects for adjustment
# Highlight the genes in the gene_list
for gene in gene_list:
gene_id = extract_gene_id(gene)
gene_data = data[data['x'] == gene_id]
if not gene_data.empty:
# Scatter the highlighted points in purple and add labels for adjustment
plt.scatter(
gene_data['rank'],
gene_data[data_column],
color=(155/255, 55/255, 155/255),
s=200,
alpha=0.6,
label=f'Highlighted Gene: {gene}',
zorder=3 # Ensure the points are on top
)
# Add the text label next to the highlighted gene
texts.append(
plt.text(
gene_data['rank'].values[0],
gene_data[data_column].values[0],
gene,
fontsize=18,
ha='right'
)
)
# Adjust text to avoid overlap with lines drawn from points to text
adjust_text(texts, arrowprops=dict(arrowstyle='-', color='gray'))
# Label the plot
plt.xlabel('Rank')
plt.ylabel('Mean Phenotype')
#plt.xticks(rotation=90) # Rotate x-axis labels for readability
plt.legend().remove() # Remove the legend if not needed
plt.tight_layout()
# Save the plot if a path is provided
if save_path:
plt.savefig(save_path, format='pdf', dpi=600, bbox_inches='tight')
print(f"Figure saved to {save_path}")
plt.show()
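# Example usage (a minimal sketch; the CSV path and gene IDs are placeholders):
# phenotype_df = pd.read_csv('TGGT1_phenotypes.csv')  # hypothetical table with the GT1 CRISPR phenotype columns
# plot_gene_phenotypes(phenotype_df, ['TGGT1_268860', 'TGGT1_312420'], save_path='phenotypes.pdf')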
def plot_gene_heatmaps(data, gene_list, columns, x_column='Gene ID', normalize=False, save_path=None):
"""
Generate a heatmap (viridis colormap) of the specified columns for the selected genes.
Args:
data (pd.DataFrame): The input DataFrame containing gene data.
gene_list (list): A list of genes to include in the heatmap.
columns (list): A list of column names to visualize as heatmaps.
normalize (bool): If True, normalize the values for each gene between 0 and 1.
save_path (str): Optional. If provided, the plot will be saved to this path.
"""
# Ensure x_column is properly processed
def extract_gene_id(gene):
if isinstance(gene, str) and '_' in gene:
return gene.split('_')[1]
return str(gene)
data = data.copy()  # work on a copy so the caller's DataFrame is not modified
data['x'] = data[x_column].apply(extract_gene_id)
# Filter the data to only include the specified genes
filtered_data = data[data['x'].isin(gene_list)].set_index('x')[columns]
# Normalize each gene's values between 0 and 1 if normalize=True
if normalize:
filtered_data = filtered_data.apply(lambda x: (x - x.min()) / (x.max() - x.min()), axis=1)
# Define the figure size dynamically based on the number of genes and columns
width = len(columns) * 4
height = len(gene_list) * 1
# Create the heatmap
plt.figure(figsize=(width, height))
cmap = sns.color_palette("viridis", as_cmap=True)
# Plot the heatmap with genes on the y-axis and columns on the x-axis
sns.heatmap(
filtered_data,
cmap=cmap,
cbar=True,
annot=False,
linewidths=0.5,
square=True
)
# Set the labels
plt.xticks(rotation=90, ha='center') # Rotate x-axis labels for better readability
plt.yticks(rotation=0) # Keep y-axis labels horizontal
plt.xlabel('')
plt.ylabel('')
# Adjust layout to ensure the plot fits well
plt.tight_layout()
# Save the plot if a path is provided
if save_path:
plt.savefig(save_path, format='pdf', dpi=600, bbox_inches='tight')
print(f"Figure saved to {save_path}")
plt.show()
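# Example usage (a minimal sketch; the column names are placeholders):
# plot_gene_heatmaps(phenotype_df, ['TGGT1_268860', 'TGGT1_312420'],
#                    columns=['T.gondii GT1 CRISPR Phenotype - Mean Phenotype'],  # hypothetical numeric columns
#                    normalize=True, save_path='heatmap.pdf')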
def generate_score_heatmap(settings):
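"""
Compare per-well classification scores, cross-validation scores and sequencing-derived fractions, and plot them as a heatmap.
Expected keys in `settings` (as used below): 'folders', 'csv_name', 'data_column', 'plateID', 'columnID',
'csv', 'control_sgrnas', 'fraction_grna', 'cv_csv', 'data_column_cv' and 'dst'.
Returns the merged per-well DataFrame.
"""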
def group_cv_score(csv, plate=1, column='c3', data_column='pred'):
df = pd.read_csv(csv)
if 'column_name' in df.columns:
df = df[df['column_name']==column]
elif 'column' in df.columns:
df['column_name'] = df['column']
df = df[df['column_name'] == column]
if plate is not None:
df['plateID'] = f"plate{plate}"
grouped_df = df.groupby(['plateID', 'rowID', 'column_name'])[data_column].mean().reset_index()
grouped_df['prc'] = grouped_df['plateID'].astype(str) + '_' + grouped_df['rowID'].astype(str) + '_' + grouped_df['column_name'].astype(str)
return grouped_df
def calculate_fraction_mixed_condition(csv, plate=1, column='c3', control_sgrnas = ['TGGT1_220950_1', 'TGGT1_233460_4']):
df = pd.read_csv(csv)
df = df[df['column_name']==column]
if 'plateID' not in df.columns:
df['plateID'] = f"plate{plate}"
df = df[df['grna_name'].str.match(f'^{control_sgrnas[0]}$|^{control_sgrnas[1]}$')]
grouped_df = df.groupby(['plateID', 'rowID', 'column_name'])['count'].sum().reset_index()
grouped_df = grouped_df.rename(columns={'count': 'total_count'})
merged_df = pd.merge(df, grouped_df, on=['plateID', 'rowID', 'column_name'])
merged_df['fraction'] = merged_df['count'] / merged_df['total_count']
merged_df['prc'] = merged_df['plateID'].astype(str) + '_' + merged_df['rowID'].astype(str) + '_' + merged_df['column_name'].astype(str)
return merged_df
def plot_multi_channel_heatmap(df, column='c3'):
"""
Plot a heatmap with multiple channels as columns.
Parameters:
- df: DataFrame with scores for different channels.
- column: Column to filter by (default is 'c3').
"""
# Extract row number and convert to integer for sorting
df['row_num'] = df['rowID'].str.extract(r'(\d+)').astype(int)
# Filter and sort by plate, row, and column
df = df[df['column_name'] == column]
df = df.sort_values(by=['plateID', 'row_num', 'column_name'])
# Drop temporary 'row_num' column after sorting
df = df.drop('row_num', axis=1)
# Create a new column combining plate, row, and column for the index
df['plate_row_col'] = df['plateID'] + '-' + df['rowID'] + '-' + df['column_name']
# Set 'plate_row_col' as the index
df.set_index('plate_row_col', inplace=True)
# Extract only numeric data for the heatmap
heatmap_data = df.select_dtypes(include=[float, int])
# Plot heatmap with square boxes, no annotations, and 'viridis' colormap
plt.figure(figsize=(12, 8))
sns.heatmap(
heatmap_data,
cmap="viridis",
cbar=True,
square=True,
annot=False
)
plt.title("Heatmap of Prediction Scores for All Channels")
plt.xlabel("Channels")
plt.ylabel("Plate-Row-Column")
plt.tight_layout()
# Save the figure object and return it
fig = plt.gcf()
plt.show()
return fig
def combine_classification_scores(folders, csv_name, data_column, plate=1, column='c3'):
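"""
Collect `csv_name` from each sub-folder of `folders`, average `data_column` per plate/row/column,
and merge the per-folder averages into one wide DataFrame keyed by 'prc'.
"""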
# Ensure `folders` is a list
if isinstance(folders, str):
folders = [folders]
ls = [] # Initialize ls to store found CSV file paths
# Iterate over the provided folders
for folder in folders:
sub_folders = os.listdir(folder) # Get sub-folder list
for sub_folder in sub_folders: # Iterate through sub-folders
path = os.path.join(folder, sub_folder) # Join the full path
if os.path.isdir(path): # Check if it’s a directory
csv = os.path.join(path, csv_name) # Join path to the CSV file
if os.path.exists(csv): # If CSV exists, add to list
ls.append(csv)
else:
print(f'No such file: {csv}')
# Initialize combined DataFrame
combined_df = None
print(f'Found {len(ls)} CSV files')
# Loop through all collected CSV files and process them
for csv_file in ls:
df = pd.read_csv(csv_file) # Read CSV into DataFrame
df = df[df['column_name']==column]
if plate is not None:
df['plateID'] = f"plate{plate}"
# Group the data by 'plateID', 'rowID', and 'column_name'
grouped_df = df.groupby(['plateID', 'rowID', 'column_name'])[data_column].mean().reset_index()
# Use the name of the containing folder to build the new column name
folder_name = os.path.basename(os.path.dirname(csv_file))
new_column_name = f"{folder_name}_{data_column}"
print(new_column_name)
grouped_df = grouped_df.rename(columns={data_column: new_column_name})
# Merge into the combined DataFrame
if combined_df is None:
combined_df = grouped_df
else:
combined_df = pd.merge(combined_df, grouped_df, on=['plateID', 'rowID', 'column_name'], how='outer')
combined_df['prc'] = combined_df['plateID'].astype(str) + '_' + combined_df['rowID'].astype(str) + '_' + combined_df['column_name'].astype(str)
return combined_df
def calculate_mae(df):
"""
Calculate the MAE between each channel's predictions and the fraction column for all rows.
"""
# Extract numeric columns excluding 'fraction' and 'prc'
channels = df.drop(columns=['fraction', 'prc']).select_dtypes(include=[float, int])
mae_data = []
# Compute MAE for each channel with 'fraction' for all rows
for column in channels.columns:
for index, row in df.iterrows():
mae = mean_absolute_error([row['fraction']], [row[column]])
mae_data.append({'Channel': column, 'MAE': mae, 'Row': row['prc']})
# Convert the list of dictionaries to a DataFrame
mae_df = pd.DataFrame(mae_data)
return mae_df
result_df = combine_classification_scores(settings['folders'], settings['csv_name'], settings['data_column'], settings['plateID'], settings['columnID'])
df = calculate_fraction_mixed_condition(settings['csv'], settings['plateID'], settings['columnID'], settings['control_sgrnas'])
df = df[df['grna_name']==settings['fraction_grna']]
fraction_df = df[['fraction', 'prc']]
merged_df = pd.merge(fraction_df, result_df, on=['prc'])
cv_df = group_cv_score(settings['cv_csv'], settings['plateID'], settings['columnID'], settings['data_column_cv'])
cv_df = cv_df[[settings['data_column_cv'], 'prc']]
merged_df = pd.merge(merged_df, cv_df, on=['prc'])
fig = plot_multi_channel_heatmap(merged_df, settings['columnID'])
if 'row_num' in merged_df.columns:
merged_df = merged_df.drop('row_num', axis=1)
mae_df = calculate_mae(merged_df)
if 'row_num' in mae_df.columns:
mae_df = mae_df.drop('row_num', axis=1)
if settings['dst'] is not None:
mae_dst = os.path.join(settings['dst'], f"mae_scores_comparison_plate_{settings['plateID']}.csv")
merged_dst = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plateID']}_data.csv")
heatmap_save = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plateID']}.pdf")
mae_df.to_csv(mae_dst, index=False)
merged_df.to_csv(merged_dst, index=False)
fig.savefig(heatmap_save, format='pdf', dpi=600, bbox_inches='tight')
return merged_df
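# Example usage (a minimal sketch; every path, plate and column value below is a placeholder):
# settings = {
#     'folders': ['/path/to/classification_runs'],  # hypothetical folders containing per-model sub-folders
#     'csv_name': 'result.csv',                     # hypothetical per-model scores file
#     'data_column': 'pred',
#     'plateID': 1,
#     'columnID': 'c3',
#     'csv': '/path/to/sequencing_counts.csv',      # hypothetical gRNA count table
#     'control_sgrnas': ['TGGT1_220950_1', 'TGGT1_233460_4'],
#     'fraction_grna': 'TGGT1_220950_1',
#     'cv_csv': '/path/to/cv_scores.csv',
#     'data_column_cv': 'pred',
#     'dst': '/path/to/output',
# }
# merged = generate_score_heatmap(settings)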