#!/usr/bin/env python
"""
Nanopore sequencing-run QC plotting utilities.

Created on Wed Jan 24 08:53:37 2024

@author: ian.michael.bollinger@gmail.com/researchconsultants@critical.consulting
"""
import os
import logging

logging.basicConfig(level=logging.INFO)

import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

# BUG FIX: the backend must be selected BEFORE matplotlib.pyplot is imported;
# the original called matplotlib.use('Agg') after the pyplot import, where it
# may have no effect (or warn).
import matplotlib

matplotlib.use('Agg')

import numpy as np
import statsmodels.api as sm
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.ticker import FuncFormatter
from matplotlib.ticker import AutoMinorLocator
import seaborn as sns

sns.set(style="whitegrid")


### TOOLS
def scientific_formatter(x, pos):
    """Format an axis tick value as a multiple of 1e7 in mathtext (e.g. '3x10$^{7}$').

    Args:
        x (float): Tick value.
        pos: Tick position (required by FuncFormatter; unused).

    Returns:
        str: Formatted label, or "0" for a zero tick.
    """
    # BUG FIX: the original f-string "...$^{7}$" treated {7} as a replacement
    # field, silently emitting "$^7$"; the braces are now escaped so the
    # intended literal mathtext "$^{7}$" is produced.
    return f"{x/1e7:.0f}x10$^{{7}}$" if x != 0 else "0"


def log10_major_ticks(min_val, max_val):
    """Generate major tick marks for a log10 scale.

    Returns one tick per decade from floor(log10(min_val)) to
    ceil(log10(max_val)), inclusive.
    """
    lo = np.floor(np.log10(min_val))
    hi = np.ceil(np.log10(max_val))
    major_ticks = np.logspace(lo, hi, num=int(hi - lo + 1))
    return major_ticks


def log10_minor_ticks(major_ticks):
    """Generate minor tick marks for a log10 scale.

    For each decade interval [t, 10t) between consecutive major ticks,
    10 evenly spaced (linear) minor positions are produced.
    """
    minor_ticks = []
    for t in major_ticks[:-1]:
        minor_ticks.extend(np.linspace(t, t * 10, 10, endpoint=False))
    return minor_ticks


def calculate_N50(lengths):
    """Calculate the N50 value from a series of lengths.

    N50 is the length L such that reads of length >= L account for at
    least half of the total bases.

    Args:
        lengths: Iterable/Series of read lengths.

    Returns:
        The N50 length, or 0 when *lengths* is empty.
    """
    if len(lengths) == 0:
        return 0  # or another appropriate value indicating no data
    sorted_lengths = sorted(lengths, reverse=True)
    cumsum_lengths = np.cumsum(sorted_lengths)
    total_length = cumsum_lengths[-1]
    # searchsorted finds the first index where the cumulative sum reaches
    # half the total bases.
    return sorted_lengths[np.searchsorted(cumsum_lengths, total_length / 2)]


### PLOTTING
def plot_length_histogram(data, output_path, plot_format, plot_stat, p1m, q):
    """
    Plot a histogram of read lengths in stacked subplots.

    Top subplot: all reads; bottom subplot: reads passing the Q cutoff.
    Optionally annotates each subplot with its N50.

    Args:
        data (DataFrame): Must contain 'sequence_length_template' and
            'Q_cutoff' ('All reads' / f'Q>={q}') columns.
        output_path (str): Directory for the saved plot.
        plot_format (str): Image format/extension (e.g. 'png').
        plot_stat (bool): Whether to draw N50 markers.
        p1m (float): Global size-scaling factor.
        q (float|int): Q-score cutoff used in labels and subsetting.
    """
    # Create a figure with stacked subplots
    fig, axes = plt.subplots(2, 1, figsize=(p1m * 960/75, p1m * 960/75))

    # Calculate logarithmically spaced bin edges (clamp min to 1 for log10)
    min_length = max(data['sequence_length_template'].min(), 1)
    max_length = data['sequence_length_template'].max()
    bins = np.logspace(np.log10(min_length), np.log10(max_length), num=300)

    # Plot for 'All reads'
    sns.histplot(data[data['Q_cutoff'] == 'All reads']['sequence_length_template'],
                 ax=axes[0], kde=False, bins=bins, color='#3b528bff')
    axes[0].set_xscale('log')
    axes[0].set_xlabel('')
    axes[0].set_ylabel('Number of Reads', fontsize=p1m * 20)
    axes[0].set_title('Read Length Distribution - All Reads', fontsize=p1m * 22)
    axes[0].set_xlim(min_length, max_length)  # Set common x-axis range

    # Plot for reads passing the Q cutoff
    sns.histplot(data[data['Q_cutoff'] == f'Q>={q}']['sequence_length_template'],
                 ax=axes[1], kde=False, bins=bins, color='#5dc862ff')
    axes[1].set_xscale('log')
    axes[1].set_xlabel('Read Length (bases)', fontsize=p1m * 20)
    axes[1].set_ylabel('Number of Reads', fontsize=p1m * 20)
    axes[1].set_title(f'Read Length Distribution - Q>={q}', fontsize=p1m * 22)
    axes[1].set_xlim(min_length, max_length)  # Set common x-axis range

    # Set major and minor ticks for the x-axis
    major_ticks = log10_major_ticks(min_length, max_length)
    minor_ticks = log10_minor_ticks(major_ticks)
    for ax in axes:
        ax.set_xticks(major_ticks, minor=False)
        ax.set_xticks(minor_ticks, minor=True)
        ax.grid(which='minor', axis='x', linestyle=':', linewidth='0.5', color='grey')

    # Add the N50 lines and labels if applicable.
    # BUG FIX: the original selected the N50 subset by matching the axis
    # title suffix ('All Reads') against Q_cutoff values ('All reads');
    # the case mismatch made the first subplot always report N50: 0.
    # Each axis is now paired directly with its Q_cutoff label.
    if plot_stat:
        for ax, cutoff in zip(axes, ['All reads', f'Q>={q}']):
            n50 = calculate_N50(
                data[data['Q_cutoff'] == cutoff]['sequence_length_template'])
            ax.axvline(x=n50, color='black', linestyle='dashed')
            ax.text(x=n50, y=0.95*ax.get_ylim()[1], s=f'N50: {n50}',
                    rotation=90, verticalalignment='top')

    # Add the title to the plot
    fig.suptitle('Sequence Length Histograms', fontsize=p1m * 25)

    # Adjust layout
    plt.tight_layout()

    # Save the plot
    plt.savefig(f"{output_path}/length_histogram.{plot_format}")
    plt.close()


def plot_qscore_histogram(data, output_path, plot_format, p1m, q):
    """
    Plot a histogram of Q scores in stacked subplots.

    Top subplot: all reads; bottom subplot: reads passing the Q cutoff.
    """
    # Determine the common x-axis range
    min_qscore = min(data['mean_qscore_template'].min(), 0)  # Including 0 for safety
    max_qscore = data['mean_qscore_template'].max()

    # Create a figure with subplots
    fig, axes = plt.subplots(2, 1, figsize=(p1m * 960/75, p1m * 960/75), sharex=True)

    # Plot for 'All reads'
    sns.histplot(data[data['Q_cutoff'] == 'All reads']['mean_qscore_template'],
                 ax=axes[0], kde=False, bins=300, color='#3b528bff')
    axes[0].set_ylabel('Number of Reads', fontsize=p1m * 20)
    axes[0].set_title('Mean Quality (Q) Score Distribution - All Reads', fontsize=p1m * 22)
    axes[0].set_xlim(min_qscore, max_qscore)  # Set common x-axis range

    # Plot for reads passing the Q cutoff
    sns.histplot(data[data['Q_cutoff'] == f'Q>={q}']['mean_qscore_template'],
                 ax=axes[1], kde=False, bins=300, color='#5dc862ff')
    axes[1].set_xlabel('Mean Quality (Q) Score of Read', fontsize=p1m * 20)
    axes[1].set_ylabel('Number of Reads', fontsize=p1m * 20)
    axes[1].set_title(f'Mean Quality (Q) Score Distribution - Q>={q}', fontsize=p1m * 22)
    axes[1].set_xlim(min_qscore, max_qscore)  # Set common x-axis range

    # Apply minor ticks and grids to both histograms
    for ax in axes:
        ax.xaxis.set_minor_locator(AutoMinorLocator())
        ax.grid(which='minor', axis='x', linestyle=':', linewidth='0.5', color='grey')

    # Add the title to the plot
    fig.suptitle('Sequence Quality (Q) Score Histograms', fontsize=p1m * 25)

    # Adjust layout
    plt.tight_layout()

    # Save the plot
    plt.savefig(f"{output_path}/qscore_histogram.{plot_format}")
    plt.close()


def plot_yield_over_time(data, output_path, muxes, plot_format, p1m, q):
    """
    Plot the cumulative base yield over time, split by Q cutoff.

    Args:
        muxes: Iterable of hour marks where mux scans occurred (drawn as
            dashed red vertical lines).
    """
    # Creating the plot
    plt.figure(figsize=(p1m*960/75, p1m*480/75))
    palette = {"All reads": "#3b528bff", f"Q>={q}": "#5dc862ff"}
    sns.lineplot(data=data, x='hour', y='cumulative.bases.time',
                 hue='Q_cutoff', palette=palette)
    plt.xlabel('Hours Into Run', size=p1m * 20)
    plt.ylabel('Total Yield in Gigabases (GB)', size=p1m * 20)
    plt.title('Gigabase (GB) Yield and Quality (Q) Over Time', size=p1m * 25)

    # Add the vertical lines for given intervals
    for interval in muxes:
        plt.axvline(x=interval, color='red', linestyle='dashed', alpha=0.5)

    # Adjust layout
    plt.tight_layout()

    # Save the plot
    plt.savefig(f"{output_path}/yield_over_time.{plot_format}")
    plt.close()


def plot_yield_by_length(data, output_path, plot_format, p1m, q):
    """
    Plot cumulative yield (GB) as a function of minimum read length.
    """
    # Calculate xmax: longest read length still contributing more than 1%
    # of the maximum cumulative yield (trims the long sparse tail).
    xmax = data.loc[data['cumulative.bases'] > 0.01 * data['cumulative.bases'].max(),
                    'sequence_length_template'].max()
    plt.figure(figsize=(p1m*960/75, p1m*480/75))
    sns.lineplot(data=data, x='sequence_length_template',
                 y=data['cumulative.bases']/1e9, hue='Q_cutoff', palette="viridis")
    plt.xlabel('Minimum Read Length (bases)', size=p1m * 20)
    plt.ylabel('Total Yield in Gigabases (GB)', size=p1m * 20)
    plt.title('Read Length (bases) per Gigabase (GB) Generated', size=p1m * 25)
    plt.xlim(0, xmax)
    plt.tight_layout()
    plt.savefig(f"{output_path}/yield_by_length.{plot_format}")
    plt.close()


def plot_sequence_length_over_time(data, output_path, muxes, plot_format, p1m, q):
    """Plot mean read length over time (LOWESS-smoothed + hourly means)."""
    # Create the plot
    plt.figure(figsize=(p1m*960/75, p1m*480/75))

    # Filter the data for 'All reads' and f'Q>={q}'
    df_all_reads = data[data['Q_cutoff'] == 'All reads']
    df_q = data[data['Q_cutoff'] == f'Q>={q}']

    # Get the count for each unique hour
    count_hours_all = df_all_reads['hour'].value_counts().to_dict()
    count_hours_q = df_q['hour'].value_counts().to_dict()

    # Keep only hours with more than 5 reads, so sparse hours do not
    # destabilise the smoothing.
    df_all_reads_trimmed = df_all_reads[df_all_reads['hour'].map(count_hours_all) > 5]
    df_q_trimmed = df_q[df_q['hour'].map(count_hours_q) > 5]

    # Calculate lowess smoothed values with a smaller fraction for less smoothing
    lowess_all_reads = sm.nonparametric.lowess(
        df_all_reads_trimmed['sequence_length_template'],
        df_all_reads_trimmed['hour'], frac=0.25)
    lowess_q = sm.nonparametric.lowess(
        df_q_trimmed['sequence_length_template'],
        df_q_trimmed['hour'], frac=0.25)

    # Plotting the smoothed curves
    plt.plot(lowess_all_reads[:, 0], lowess_all_reads[:, 1],
             color='#3b528bff', lw=2, label='All Reads')
    plt.plot(lowess_q[:, 0], lowess_q[:, 1],
             color='#5dc862ff', lw=2, label=f'Q>={q}')

    # Trim the dataframes to the two plotted columns
    df_all_reads_trimmed = df_all_reads_trimmed[['hour', 'sequence_length_template']]
    df_q_trimmed = df_q_trimmed[['hour', 'sequence_length_template']]

    # Plotting with seaborn lineplot using trimmed dataframes
    sns.lineplot(data=df_all_reads_trimmed, x='hour', y='sequence_length_template',
                 label='All Reads Mean', estimator='mean',
                 color='#3b528bff', lw=1, linestyle='dashed')
    sns.lineplot(data=df_q_trimmed, x='hour', y='sequence_length_template',
                 label=f'Q>={q} Mean', estimator='mean',
                 color='#5dc862ff', lw=1, linestyle='dashed')

    # Set the plot labels and title
    plt.xlabel('Hours Into Run', size=p1m * 20)
    plt.ylabel('Mean Read Length (bases)', size=p1m * 20)
    plt.yticks(size=p1m * 7)
    plt.title('Sequence Length Over Time', size=p1m * 25)

    # Add the vertical lines for given intervals
    for interval in muxes:
        plt.axvline(x=interval, color='red', linestyle='dashed', alpha=0.5)

    # Adjust legend position and font size
    legend = plt.legend(title='Reads', loc='right', bbox_to_anchor=(1.225, 0.5))
    legend.get_title().set_fontsize(p1m * 12)
    for label in legend.get_texts():
        label.set_fontsize(p1m * 10)

    # Adjust layout
    plt.tight_layout()

    # Save the plot
    plt.savefig(f"{output_path}/length_by_hour.{plot_format}")
    plt.close()


def plot_qscore_over_time(data, output_path, muxes, plot_format, p1m, q):
    """
    Plot mean Q score over time (LOWESS-smoothed + hourly means).
    """
    # Create the plot
    plt.figure(figsize=(p1m*960/75, p1m*480/75))

    # Filter each subset and drop non-positive Q scores
    df_all_reads = data[(data['Q_cutoff'] == 'All reads') &
                        (data['mean_qscore_template'] > 0)]
    df_q = data[(data['Q_cutoff'] == f'Q>={q}') &
                (data['mean_qscore_template'] > 0)]

    # Get the count for each unique hour
    count_hours_all = df_all_reads['hour'].value_counts().to_dict()
    count_hours_q = df_q['hour'].value_counts().to_dict()

    # Keep only hours with more than 5 reads (stabilises the smoothing)
    df_all_reads_trimmed = df_all_reads[df_all_reads['hour'].map(count_hours_all) > 5]
    df_q_trimmed = df_q[df_q['hour'].map(count_hours_q) > 5]

    # Calculate lowess smoothed values with a smaller fraction for less smoothing
    lowess_all_reads = sm.nonparametric.lowess(
        df_all_reads_trimmed['mean_qscore_template'],
        df_all_reads_trimmed['hour'], frac=0.25)
    lowess_q = sm.nonparametric.lowess(
        df_q_trimmed['mean_qscore_template'],
        df_q_trimmed['hour'], frac=0.25)

    # Plotting the smoothed curves
    plt.plot(lowess_all_reads[:, 0], lowess_all_reads[:, 1],
             color='#3b528bff', lw=2, label='All Reads')
    plt.plot(lowess_q[:, 0], lowess_q[:, 1],
             color='#5dc862ff', lw=2, label=f'Q>={q}')

    # Trim the dataframes to the two plotted columns
    df_all_reads_trimmed = df_all_reads_trimmed[['hour', 'mean_qscore_template']]
    df_q_trimmed = df_q_trimmed[['hour', 'mean_qscore_template']]

    # Plotting with seaborn lineplot using trimmed dataframes
    sns.lineplot(data=df_all_reads_trimmed, x='hour', y='mean_qscore_template',
                 label='All Reads Mean', color='#3b528bff', lw=1, linestyle='dashed')
    sns.lineplot(data=df_q_trimmed, x='hour', y='mean_qscore_template',
                 label=f'Q>={q} Mean', color='#5dc862ff', lw=1, linestyle='dashed')

    # Set the plot labels and title
    plt.xlabel('Hours Into Run', size=p1m * 20)
    plt.ylabel('Mean Quality (Q) Score', size=p1m * 20)
    plt.title('Quality (Q) Scores Over Time', size=p1m * 25)

    # Add the vertical lines for given intervals
    for interval in muxes:
        plt.axvline(x=interval, color='red', linestyle='dashed', alpha=0.5)

    # Adjust legend position and font size
    legend = plt.legend(title='Reads', loc='right', bbox_to_anchor=(1.225, 0.5))
    legend.get_title().set_fontsize(p1m * 12)
    for label in legend.get_texts():
        label.set_fontsize(p1m * 10)

    # Adjust layout
    plt.tight_layout()

    # Save the plot
    plt.savefig(f"{output_path}/q_by_hour.{plot_format}")
    plt.close()


def plot_reads_per_hour(data, output_path, muxes, plot_format, p1m, q):
    """
    Plot number of reads per hour, with missing hours shown as zero.
    """
    plt.figure(figsize=(p1m*960/75, p1m*480/75))
    palette = {"All reads": "#3b528bff", f"Q>={q}": "#5dc862ff"}
    min_x_value = int(data['hour'].min())  # Convert to integer
    max_x_value = int(data['hour'].max())  # Convert to integer

    # Ensure all hours are represented in the data
    all_hours = pd.DataFrame({'hour': range(min_x_value, max_x_value + 1)})
    data = pd.merge(all_hours, data, on='hour', how='left')
    # BUG FIX: column-wise fillna(..., inplace=True) is chained assignment
    # (deprecated; a silent no-op under pandas copy-on-write) — assign instead.
    data['reads_per_hour'] = data['reads_per_hour'].fillna(0)

    sns.pointplot(data=data, x='hour', y='reads_per_hour',
                  hue='Q_cutoff', palette=palette)
    plt.xlabel('Hours Into Run')
    plt.ylabel('Number of Reads per Hour')
    plt.title('Reads Generated per Hour', size=p1m * 25)

    # Add the vertical lines for given intervals
    for interval in muxes:
        plt.axvline(x=interval, color='red', linestyle='dashed', alpha=0.5)
    plt.xlim(0, max_x_value*1.05)

    # Set X-ticks as the x_list
    plt.xticks(range(min_x_value, max_x_value + 1))
    plt.tight_layout()
    plt.savefig(f"{output_path}/reads_per_hour.{plot_format}")
    plt.close()


# Function for channel_summary.png
def plot_channel_summary_histograms(df_all_reads, df_q_cutoff, output_dir,
                                    plot_format, p1m, q):
    """Plot per-channel summary histograms (2 rows x 4 variables).

    Row 0: 'All reads'; row 1: reads passing the Q cutoff. Expects
    long-format frames with 'variable' and 'value' columns.
    """
    # BUG FIX: work on copies so the caller's DataFrames are not mutated
    # by the in-place 'variable' renames below.
    df_all_reads = df_all_reads.copy()
    df_q_cutoff = df_q_cutoff.copy()

    # Rename the columns as per the R script
    rename_dict = {
        'total_bases': 'Number of Bases per Channel',
        'total_reads': 'Number of Reads per Channel',
        'mean_read_length': 'Mean Read Length per Channel',
        'median_read_length': 'Median Read Length per Channel'
    }
    df_all_reads['variable'] = (df_all_reads['variable'].map(rename_dict)
                                .fillna(df_all_reads['variable']))
    df_q_cutoff['variable'] = (df_q_cutoff['variable'].map(rename_dict)
                               .fillna(df_q_cutoff['variable']))

    # Specify the order of variables
    ordered_variables = [
        'Mean Read Length per Channel',
        'Median Read Length per Channel',
        'Number of Bases per Channel',
        'Number of Reads per Channel'
    ]

    # Create a 2x4 grid of subplots (2 dataframes, 4 variables)
    fig, axes = plt.subplots(nrows=2, ncols=4,
                             figsize=(p1m*960/75, p1m*480/75), sharex='col')

    # Plot histograms and find the maximum value for y-axis
    max_value = 0
    for i, variable in enumerate(ordered_variables):
        # Plot for "All reads" dataframe
        df_var_all = df_all_reads[df_all_reads['variable'] == variable]
        ax_all = axes[0, i]
        sns.histplot(df_var_all['value'], ax=ax_all, color='#3b528bff', bins=30)
        max_value = max(max_value, ax_all.get_ylim()[1])

        # Plot for "Q-cutoff" dataframe
        df_var_cutoff = df_q_cutoff[df_q_cutoff['variable'] == variable]
        ax_cutoff = axes[1, i]
        sns.histplot(df_var_cutoff['value'], ax=ax_cutoff, color='#5dc862ff', bins=30)
        max_value = max(max_value, ax_cutoff.get_ylim()[1])

        if variable == "Number of Bases per Channel":
            ax_all.xaxis.set_major_formatter(FuncFormatter(scientific_formatter))
            ax_cutoff.xaxis.set_major_formatter(FuncFormatter(scientific_formatter))

    # Adjust y-axis limits (5% headroom, shared across all panels)
    max_value *= 1.05
    for ax in axes.flatten():
        ax.set_ylim(0, max_value)

    # Set titles and labels
    for i, variable in enumerate(ordered_variables):
        axes[0, i].set_title(variable)
        axes[1, i].set_xlabel('')
        if i > 0:
            axes[0, i].set_ylabel("")
            axes[1, i].set_ylabel("")
        else:
            axes[0, i].set_ylabel("Count")
            axes[1, i].set_ylabel("Count")

    # Add the title to the plot
    fig.suptitle('Flowcell Summary Histograms', fontsize=p1m * 25)

    # Adjust layout and save the plot
    plt.tight_layout()
    plt.savefig(f'{output_dir}/channel_summary.{plot_format}', bbox_inches='tight')
    plt.close()


# Function for length_vs_q.png scatterplot
def plot_length_vs_q(data, output_dir, plot_format, p1m, q):
    """Scatter (MinION) or hexbin (PromethION) of read length vs Q score.

    The device type is inferred from the maximum channel number
    (<= 512 is treated as MinION).
    """
    # Filter for 'All reads'
    df_filtered = data[data['Q_cutoff'] == 'All reads']

    # Create figure and axes for the scatter plot
    fig, ax = plt.subplots(figsize=(p1m*960/75, p1m*960/75))

    # Check if the data is from MinION or PromethION
    if df_filtered['channel'].max() <= 512:  # MinION
        # Define the normalization range based on 'events_per_base'
        norm = mcolors.LogNorm(vmin=df_filtered['events_per_base'].min(),
                               vmax=df_filtered['events_per_base'].max())

        # Normalize 'events_per_base' values and apply the 'rocket' color map
        cmap = sns.color_palette("rocket", as_cmap=True)
        colors = cmap(norm(df_filtered['events_per_base']))

        # Create a scatter plot with normalized color values
        ax.scatter(df_filtered['sequence_length_template'],
                   df_filtered['mean_qscore_template'],
                   color=colors, alpha=0.05, s=0.4)

        # Create colorbar in the figure, attached to the axes.
        # (Named 'mappable' — the original used 'sm', shadowing the
        # statsmodels.api alias.)
        mappable = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        cbar = fig.colorbar(mappable, ax=ax)
        cbar.set_label('Events per Base (log scale)', rotation=270, labelpad=20)
    else:  # PromethION
        # Use hexbin for 2D histogram with 'rocket' color map
        hexbin = ax.hexbin(df_filtered['sequence_length_template'],
                           df_filtered['mean_qscore_template'],
                           gridsize=50, cmap="rocket", bins='log')

        # Create colorbar in the figure, attached to the axes
        cbar = fig.colorbar(hexbin, ax=ax)
        cbar.set_label('Counts in bin', rotation=270, labelpad=20)

    # Set axis scales and labels
    ax.set_xscale('log')
    ax.set_xlabel('Read Length (bases)', size=p1m * 15)
    ax.set_ylabel('Mean Quality (Q) Score of Read', size=p1m * 15)
    ax.set_title('Read Length (bases) vs Quality (Q) Score', size=p1m * 25)

    # Set the y-axis to start at 0
    ax.set_ylim(bottom=0)
    # Set the x-axis to start at 1 (log scale cannot include 0)
    ax.set_xlim(left=1)

    # Enable minor gridlines on the x-axis
    ax.xaxis.grid(True, which='minor', linewidth=0.5)

    # Adjust layout to prevent overlap
    plt.tight_layout(rect=[0, 0, 1, 1])
    plt.savefig(f'{output_dir}/length_vs_q.{plot_format}')
    plt.close()


def plot_both_per_channel(df1, df2, output_dir, plot_format, p1m, q):
    """Side-by-side per-channel yield heatmaps ('All Reads' vs Q cutoff).

    Expects long-format frames with 'row', 'col' and 'value' columns.
    """
    # Function to process and pivot data into a row x col matrix of means
    def process_data(df):
        aggregated_df = df.groupby(['row', 'col']).agg({'value': 'mean'}).reset_index()
        return aggregated_df.pivot(index='row', columns='col', values='value')

    # Process both dataframes
    pivot_table1 = process_data(df1)
    pivot_table2 = process_data(df2)

    # Determine the common color scale for both heatmaps
    min_value = min(pivot_table1.min().min(), pivot_table2.min().min())
    max_value = max(pivot_table1.max().max(), pivot_table2.max().max())

    # Set up the plot with adjusted figsize and width ratios
    fig, axes = plt.subplots(nrows=1, ncols=2,
                             figsize=(p1m * 960/75, p1m * 960/75),
                             gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.01},
                             sharey=True)

    # Plot heatmap for df1
    sns.heatmap(pivot_table1, annot=False, cmap="viridis", square=True,
                ax=axes[0], vmin=min_value, vmax=max_value, cbar=False)
    axes[0].set_ylabel('Channel Row', fontsize=p1m * 25)

    # Plot heatmap for df2 with a colorbar
    cbar_ax = fig.add_axes([.93, .3, .02, .4])  # Position for the colorbar
    sns.heatmap(pivot_table2, annot=False, cmap="viridis", square=True,
                ax=axes[1], vmin=min_value, vmax=max_value, cbar_ax=cbar_ax)

    # Calculate the positions based on min and max values
    tick_positions = [min_value + 1/3 * (max_value - min_value),
                      min_value + 2/3 * (max_value - min_value)]
    cbar_ax.set_yticks(tick_positions)
    # NOTE(review): these labels are hard-coded and do not reflect the actual
    # data range — confirm against the intended output before relying on them.
    cbar_ax.set_yticklabels(['0.01', '0.02'])

    # Add title "GB/Channel" to Legend (Colorbar)
    cbar_ax.set_title("GB per\nChannel", fontsize=p1m * 15)

    # Set colorbar text size
    cbar_ax.tick_params(labelsize=p1m * 10)

    # Adjust tick label size and set ticks for both plots
    for ax in axes:
        ax.set_xlabel('')
        # Make sure there is no background grid
        ax.grid(False)

        # Find the number of rows and columns in the data
        nrows, ncols = pivot_table1.shape if ax == axes[0] else pivot_table2.shape

        # Set ticks at specific intervals:
        # for x-axis, every 4 units; for y-axis, every 10 units
        xticks = np.arange(3.5, ncols, 4)
        yticks = np.arange(9.5, nrows, 10)
        ax.set_xticks(xticks, minor=False)
        ax.set_yticks(yticks, minor=False)

        # Setting new labels as simple integers
        ax.set_xticklabels([int(round(x, 0)) for x in xticks], fontsize=p1m * 20)
        ax.set_yticklabels([int(round(y, 0)) for y in yticks], fontsize=p1m * 20)
        ax.tick_params(axis='x', rotation=0, labelsize=p1m * 22)
        ax.tick_params(axis='y', labelsize=p1m * 20)

        # Draw vertical white separator lines between channel columns
        for x in range(1, 16):
            ax.axvline(x=x, color='white', linestyle='-', lw=2)
        # Draw horizontal white separator lines between channel rows
        for y in range(1, 32):
            ax.axhline(y=y, color='white', linestyle='-', lw=2)

    # Set a centralized x-axis label for the entire figure
    fig.text(0.5, 0.04, 'Channel Column', ha='center', va='center',
             fontsize=p1m * 22)

    # Add light grey box behind each heatmap title
    for ax, title in zip(axes, ['All Reads', f'Q>={q}']):
        ax.text(0.5, 1.05, title, transform=ax.transAxes, fontsize=p1m * 22,
                horizontalalignment='center', verticalalignment='center',
                bbox=dict(facecolor='lightgrey', edgecolor='none',
                          boxstyle='round,pad=0.5'))

    # Ensure output directory exists
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Add the title to the plot
    fig.suptitle('Gigabases (GB) per Channel Overview', fontsize=p1m * 25)

    plt.savefig(os.path.join(output_dir, f"gb_per_channel_overview.{plot_format}"))
    plt.close()


def plot_flowcell_overview(data, output_dir, plot_format, p1m, q):
    """
    Generate a flowcell overview plot: one mini scatter (length vs time,
    coloured by Q score) per physical channel in a 32x16 grid.

    Args:
        data (DataFrame): The DataFrame containing the data to plot.
        output_dir (str): Directory to save the output plot.
        plot_format (str): Format of the saved plot (e.g. 'png').
        p1m (float): Global size-scaling factor.
        q: Q-score cutoff (unused here; kept for a uniform signature).
    """
    # Create a large figure
    fig = plt.figure(figsize=(p1m*80, p1m*76.4))

    # Define grid size
    num_columns = 16
    num_rows = 32

    # Add axes for subplots and colorbar
    grid_size = (num_rows, num_columns + 1)  # +1 to account for the colorbar
    axs = [plt.subplot2grid(grid_size, (i, j))
           for i in range(num_rows) for j in range(num_columns)]

    # Filter for 'All reads' and create a copy to avoid SettingWithCopyWarning
    all_reads_data = data[data['Q_cutoff'] == 'All reads'].copy()
    # Convert start_time (seconds) to hours
    all_reads_data['start_time_hours'] = all_reads_data['start_time'] / 3600

    for i, ax in enumerate(axs):
        # Determine the (1-based) row and column number for this subplot
        row_number = i // num_columns + 1
        col_number = i % num_columns + 1

        # Filter data for this specific subplot based on row and col
        subplot_data = all_reads_data[(all_reads_data['row'] == row_number) &
                                      (all_reads_data['col'] == col_number)]

        if not subplot_data.empty and 'mean_qscore_template' in subplot_data.columns:
            valid_scores = subplot_data['mean_qscore_template'].dropna()
            if not valid_scores.empty:
                # Generate mini plot
                sns.scatterplot(x='start_time_hours', y='sequence_length_template',
                                hue='mean_qscore_template', data=subplot_data,
                                ax=ax, palette="viridis", alpha=0.35)
                # Remove y-axis and x-axis labels of mini-plot if present
                ax.set_xlabel('')
                ax.set_ylabel('')
                # Remove legend from mini plot
                if ax.get_legend():
                    ax.get_legend().remove()
            else:
                logging.info(f"No valid 'mean_qscore_template' data for subplot {i}")
        else:
            logging.info(f"Empty data or missing 'mean_qscore_template' column for subplot {i}")

        # Set y-axis to log scale
        ax.set_yscale('log')

        # Set y-ticks and label size (labels only in the first column)
        ax.set_yticks([1e+01, 1e+02, 1e+03, 1e+04, 1e+05])
        if i % num_columns != 0:
            ax.set_yticklabels([])
        else:
            ax.tick_params(axis='y', labelsize=p1m * 25)

        # Set x-ticks and label size (labels only in the bottom row)
        ax.set_xticks([0, 10, 20, 30, 40])
        if i < num_columns * (num_rows - 1):
            ax.set_xticklabels([])
        else:
            ax.tick_params(axis='x', labelsize=p1m * 25)

    # Add overall Title, y-axis label, and x-axis label
    fig.text(0.5, 0.95, 'Individual Flowcell Read Length & Quality (Q) Over Time',
             ha='center', fontsize=p1m * 175)
    fig.text(0.05, 0.5, 'Read Length (bases)', va='center',
             rotation='vertical', fontsize=p1m * 150)
    fig.text(0.5, 0.05, 'Hours Into Run', ha='center', fontsize=p1m * 150)

    # Create a colorbar for the Viridis palette
    viridis = plt.get_cmap('viridis')
    mappable = plt.cm.ScalarMappable(cmap=viridis, norm=plt.Normalize(vmin=0, vmax=16))
    mappable.set_array([])

    # Create a new axis for the colorbar with desired position
    cbar_ax = fig.add_axes([0.92, 0.1, 0.02, 0.8])
    cbar = fig.colorbar(mappable, cax=cbar_ax, orientation='vertical')
    cbar.ax.set_title('Q', size=p1m * 100)
    cbar.ax.tick_params(labelsize=p1m * 75)

    # Add light grey box behind each column heatmap title
    first_row_axes = axs[:num_columns]
    column_label_list = [str(i) for i in range(1, num_columns + 1)]
    for ax, title in zip(first_row_axes, column_label_list):
        ax.text(0.5, 1.7, title, transform=ax.transAxes, fontsize=p1m * 75,
                horizontalalignment='center', verticalalignment='center',
                bbox=dict(facecolor='lightgrey', edgecolor='none',
                          boxstyle='round,pad=0.5'))

    # Calculate x position for the row labels
    # (to the right of the subplots, but left of the colorbar)
    label_x_pos = (num_columns) / (num_columns + 1) * 0.9325  # Adjust this as needed

    # Create a list of integers for row titles
    row_label_list = [str(i) for i in range(1, num_rows + 1)]

    # Add row titles directly to the figure
    for i, ax in zip(range(num_rows), axs[::num_columns]):
        # Get the bounding box of the subplot in figure coordinates
        bbox = ax.get_window_extent().transformed(fig.transFigure.inverted())
        y_pos = bbox.y0 + bbox.height / 2  # Vertical center of the subplot
        fig.text(label_x_pos, y_pos, row_label_list[i], fontsize=p1m * 70,
                 horizontalalignment='center', verticalalignment='center',
                 bbox=dict(facecolor='lightgrey', edgecolor='none',
                           boxstyle='round,pad=0.5'),
                 transform=fig.transFigure)

    # Adjust spacing between subplots
    plt.subplots_adjust(wspace=0.1, hspace=0.1)

    # Save the plot
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    plt.savefig(os.path.join(output_dir, f"flowcell_overview.{plot_format}"))
    plt.close()


def plot_multi_flowcell(multi_data, output_dir, plot_format, p1m, q):
    """Generate the full set of comparison plots for multiple flowcells.

    Raises:
        ValueError: If *multi_data* lacks a 'flowcell' column.
    """
    # Ensure the DataFrame has a 'flowcell' column
    if 'flowcell' not in multi_data.columns:
        raise ValueError("DataFrame must contain a 'flowcell' column.")

    # Set mux intervals (every 8 hours; modify as needed)
    mux_intervals = np.arange(0, multi_data['hour'].max() + 1, 8)

    # Call functions for each plot type
    plot_length_distributions(multi_data, output_dir, plot_format, p1m, q)
    plot_qscore_distributions(multi_data, output_dir, plot_format, p1m, q)
    plot_yield_over_time_multi(multi_data, output_dir, mux_intervals,
                               plot_format, p1m, q)
    plot_sequence_length_over_time_multi(multi_data, output_dir, mux_intervals,
                                         plot_format, p1m, q)
    plot_qscore_over_time_multi(multi_data, output_dir, mux_intervals,
                                plot_format, p1m, q)
    plot_yield_by_length_multi(multi_data, output_dir, plot_format, p1m, q)


def plot_length_distributions(multi_data, output_dir, plot_format, p1m, q):
    """Density plot of read lengths for each flowcell (log-scaled x)."""
    plt.figure(figsize=(p1m*960/75, p1m*480/75))
    for flowcell in multi_data['flowcell'].unique():
        flowcell_data = multi_data[multi_data['flowcell'] == flowcell]
        sns.kdeplot(flowcell_data['sequence_length_template'], label=flowcell)
    plt.xscale('log')
    plt.xlabel('Read length (bases)', size=p1m * 15)
    plt.ylabel('Density', size=p1m * 15)
    plt.title('Length Distributions Across Flowcells', size=p1m * 25)
    plt.legend(title='Flowcells', fontsize=p1m * 12)
    plt.savefig(f"{output_dir}/length_distributions.{plot_format}")
    plt.close()


def plot_qscore_distributions(multi_data, output_dir, plot_format, p1m, q):
    """
    Plot Q score distributions for each flowcell.

    Args:
        multi_data (DataFrame): The DataFrame containing the data to plot.
        output_dir (str): Directory to save the output plot.
        plot_format (str): Format of the saved plot.
        p1m (float): Scaling factor for plot size.
        q (float): Q score cutoff.

    Raises:
        ValueError: If required columns are missing.
    """
    plt.figure(figsize=(p1m * 960/75, p1m * 480/75))

    # Ensure 'flowcell' and 'mean_qscore_template' columns are present
    if 'flowcell' not in multi_data.columns or 'mean_qscore_template' not in multi_data.columns:
        raise ValueError("DataFrame must contain 'flowcell' and 'mean_qscore_template' columns.")

    # Iterate over each flowcell and plot Q score distributions
    for flowcell in multi_data['flowcell'].unique():
        flowcell_data = multi_data[multi_data['flowcell'] == flowcell]
        sns.kdeplot(flowcell_data['mean_qscore_template'], label=flowcell)

    plt.xlabel('Mean Q Score of Read', fontsize=p1m * 15)
    plt.ylabel('Density', fontsize=p1m * 15)
    plt.title('Q Score Distributions Across Flowcells', fontsize=p1m * 25)
    plt.legend(title='Flowcell', fontsize=p1m * 12)

    # Save the plot
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"qscore_distributions.{plot_format}"))
    plt.close()


def plot_yield_over_time_multi(multi_data, output_dir, mux_intervals,
                               plot_format, p1m, q):
    """Cumulative yield over time, one line per flowcell."""
    plt.figure(figsize=(p1m*960/75, p1m*480/75))
    for flowcell in multi_data['flowcell'].unique():
        df_flowcell = multi_data[multi_data['flowcell'] == flowcell]
        sns.lineplot(data=df_flowcell, x='hour', y='cumulative.bases.time',
                     label=flowcell)

    # Add vertical lines for mux intervals
    for interval in mux_intervals:
        plt.axvline(x=interval, color='red', linestyle='dashed', alpha=0.5)

    plt.xlabel('Hours Into Run', size=p1m*15)
    plt.ylabel('Total Yield in Gigabases (GB)', size=p1m*15)
    plt.title('Yield Over Time - Multiple Flowcells', size=p1m*25)
    plt.legend(title='Flowcell', fontsize=p1m*12)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"yield_over_time_multi.{plot_format}"))
    plt.close()


def plot_yield_by_length_multi(multi_data, output_dir, plot_format, p1m, q):
    """Cumulative yield (GB) vs minimum read length, one line per flowcell."""
    plt.figure(figsize=(p1m*960/75, p1m*480/75))
    for flowcell in multi_data['flowcell'].unique():
        df_flowcell = multi_data[multi_data['flowcell'] == flowcell]
        sns.lineplot(data=df_flowcell, x='sequence_length_template',
                     y=df_flowcell['cumulative.bases']/1e9, label=flowcell)

    plt.xlabel('Minimum Read Length (bases)', size=p1m*15)
    plt.ylabel('Total Yield in Gigabases (GB)', size=p1m*15)
    plt.title('Yield by Length - Multiple Flowcells', size=p1m*25)
    plt.legend(title='Flowcell', fontsize=p1m*12)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"yield_by_length_multi.{plot_format}"))
    plt.close()


def plot_sequence_length_over_time_multi(multi_data, output_dir, mux_intervals,
                                         plot_format, p1m, q):
    """Read length over time, one line per flowcell."""
    plt.figure(figsize=(p1m*960/75, p1m*480/75))
    for flowcell in multi_data['flowcell'].unique():
        df_flowcell = multi_data[multi_data['flowcell'] == flowcell]
        sns.lineplot(data=df_flowcell, x='hour', y='sequence_length_template',
                     label=flowcell)

    # Add vertical lines for mux intervals
    for interval in mux_intervals:
        plt.axvline(x=interval, color='red', linestyle='dashed', alpha=0.5)

    plt.xlabel('Hours Into Run', size=p1m*15)
    plt.ylabel('Sequence Length (bases)', size=p1m*15)
    plt.title('Sequence Length Over Time - Multiple Flowcells', size=p1m*25)
    plt.legend(title='Flowcell', fontsize=p1m*12)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"length_over_time_multi.{plot_format}"))
    plt.close()


def plot_qscore_over_time_multi(multi_data, output_dir, mux_intervals,
                                plot_format, p1m, q):
    """Mean Q score over time, one line per flowcell."""
    plt.figure(figsize=(p1m*960/75, p1m*480/75))
    for flowcell in multi_data['flowcell'].unique():
        df_flowcell = multi_data[multi_data['flowcell'] == flowcell]
        sns.lineplot(data=df_flowcell, x='hour', y='mean_qscore_template',
                     label=flowcell)

    # Add vertical lines for mux intervals
    for interval in mux_intervals:
        plt.axvline(x=interval, color='red', linestyle='dashed', alpha=0.5)

    plt.xlabel('Hours Into Run', size=p1m*15)
    plt.ylabel('Mean Q Score', size=p1m*15)
    plt.title('Q Score Over Time - Multiple Flowcells', size=p1m*25)
    plt.legend(title='Flowcell', fontsize=p1m*12)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"qscore_over_time_multi.{plot_format}"))
    plt.close()