#Cell 1;

import os
import glob
import shutil
import logging
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.path import Path
from matplotlib.patches import Ellipse, Polygon
from matplotlib.colors import LogNorm
from PIL import Image, ImageChops
from scipy.ndimage import binary_dilation, label
from skimage.measure import label, regionprops
from skimage.morphology import remove_small_objects, remove_small_holes
from scipy.signal import convolve as scipy_convolve
from shapely.geometry import Polygon, Point
from astropy import units, wcs, stats
from astropy.io import ascii, fits
from astropy.nddata import Cutout2D, block_reduce
from astropy.wcs import WCS
from astropy import units as u

from astropy.utils.data import download_file
from astropy.visualization import (astropy_mpl_style, simple_norm, AsinhStretch,
ImageNormalize, MinMaxInterval, SqrtStretch)
from astropy.visualization.wcsaxes import WCSAxes
from astropy.stats import mad_std, sigma_clip, sigma_clipped_stats
from astropy.convolution import Gaussian2DKernel, convolve
from radio_beam import Beam
from pyavm import AVM
from photutils.segmentation import detect_threshold, detect_sources
from photutils.background import MedianBackground
from astrodendro import Dendrogram, pp_catalogue
from astrodendro.analysis import PPStatistic

#Cell 2;
# Base working directory for all inputs and outputs. It must end with '/'
# because the paths below are built by plain string concatenation
# (e.g. pwd+'Images/').
pwd = 'C:/Users/andy-/OneDrive/Python/' # Edit this for current directory
def process_images(directory=pwd+'Images/', output_directory=pwd+'Images/Ratio'):
    """Write an I2/I1 ratio image for every field found in ``directory``.

    Fields are discovered by listing the files whose name ends in
    ``I3.fits``. For each one the matching ``...I1.fits`` and ``...I2.fits``
    files are opened, the primary-HDU data of I2 is divided by that of I1,
    and the result is written to ``output_directory`` as ``...R2_1.fits``
    using the I2 primary header.

    Parameters
    ----------
    directory : str
        Directory containing the input FITS images.
    output_directory : str
        Directory that receives the ratio images; created if missing.

    Notes
    -----
    Fields whose I1 or I2 image is missing are skipped with a message.
    Pixels where the division overflows to +/-inf (zero denominator) are
    stored as NaN so they cannot masquerade as extremely bright ratios.
    """
    os.makedirs(output_directory, exist_ok=True)
    suffix = 'I3.fits'

    # Sorted so that the processing order is reproducible between runs.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(suffix):
            continue

        # Swap only the trailing band tag; str.replace would also rewrite an
        # 'I3' that happens to appear earlier in the field name.
        stem = filename[:-len(suffix)]
        filepath1 = os.path.join(directory, stem + 'I1.fits')
        filepath2 = os.path.join(directory, stem + 'I2.fits')
        if not (os.path.exists(filepath1) and os.path.exists(filepath2)):
            print(f'Skipping {filename}: matching I1/I2 image not found.')
            continue

        # Context managers close both files even if the division or the
        # write raises.
        with fits.open(filepath1) as hdul1, fits.open(filepath2) as hdul2:
            img1, img2 = hdul1[0].data, hdul2[0].data
            with np.errstate(divide='ignore', invalid='ignore'):
                ratio_image = img2 / img1
            ratio_image[np.isinf(ratio_image)] = np.nan
            hdu = fits.PrimaryHDU(ratio_image, header=hdul2[0].header)
            output_path = os.path.join(output_directory, stem + 'R2_1.fits')
            fits.HDUList([hdu]).writeto(output_path, overwrite=True)

# Build the ratio images with the default input/output directories when this
# cell runs.
process_images()

#Cell 3;
def check_color_requirements(region, i2_data, i3_data, i4_data, i2_filepath=None):
    """Return the mean fluxes and flux ratios of a region in three images.

    Parameters
    ----------
    region : array_like
        Mask selecting the pixels of the region. Any non-zero entry counts
        as "inside"; the mask is converted to boolean before indexing,
        because indexing with an integer 0/1 array would pick rows 0 and 1
        of the image instead of the masked pixels.
    i2_data, i3_data, i4_data : numpy.ndarray
        Image arrays with the same shape as ``region``.
    i2_filepath : str, optional
        Unused; kept so existing callers keep working.

    Returns
    -------
    tuple
        ``(mean_i2, mean_i3, mean_i4, ratio_2_3, ratio_2_4)``. A ratio is
        NaN when its denominator is zero, and every value is NaN when the
        region selects no pixels.
    """
    region = np.asarray(region, dtype=bool)
    if not region.any():
        # Nothing selected: same NaN result np.mean would give, minus the
        # "mean of empty slice" warning.
        return np.nan, np.nan, np.nan, np.nan, np.nan

    # Calculate the mean flux for the region in each image
    mean_i2 = np.mean(i2_data[region])
    mean_i3 = np.mean(i3_data[region])
    mean_i4 = np.mean(i4_data[region])

    # Calculate the flux ratios
    ratio_2_3 = mean_i2 / mean_i3 if mean_i3 != 0 else np.nan
    ratio_2_4 = mean_i2 / mean_i4 if mean_i4 != 0 else np.nan
    return mean_i2, mean_i3, mean_i4, ratio_2_3, ratio_2_4

    # Define the background calculation and subtraction functions

def calculate_background(region, data):
    """Estimate the local background around a region as a median.

    The region mask is dilated twice with ``binary_dilation`` and the
    original region is removed, leaving a ring of pixels around it. The
    median of ``data`` over that ring is returned.

    Parameters
    ----------
    region : array_like
        Mask of the region; non-zero entries are inside.
    data : numpy.ndarray
        Image array with the same shape as ``region``.

    Returns
    -------
    float
        Median of the ring pixels. NaN if the ring is empty.
    """
    region = np.asarray(region, dtype=bool)

    # Ring = dilated region minus the region itself.
    dilated = binary_dilation(region, iterations=2)
    surrounding_region = dilated & ~region

    # The background is the median (50th percentile) of the ring pixels.
    return np.median(data[surrounding_region])

def subtract_background(region, data):
    """Return ``data`` with the local background of ``region`` removed.

    The background is the single value given by
    ``calculate_background(region, data)``; it is subtracted from every
    pixel, so the result is a new array of the same shape as ``data``
    (the background-subtracted image, not the background itself).
    """
    return data - calculate_background(region, data)

    # Define the significant detection function

def check_significant_detection(region, i2_data, sigma=0.252):
    """Return the signal-to-noise ratio of a region.

    The function does not apply a threshold itself; it returns
    ``mean(i2_data[region]) / sigma`` and leaves the comparison (e.g.
    against 5) to the caller.

    Parameters
    ----------
    region : array_like
        Mask of the region; non-zero entries are inside. Converted to
        boolean before indexing, because an integer 0/1 array would index
        rows 0 and 1 of the image rather than the masked pixels.
    i2_data : numpy.ndarray
        Image array with the same shape as ``region``.
    sigma : float
        Noise level the mean is divided by (same units as ``i2_data``).

    Returns
    -------
    float
        Mean value inside the region divided by ``sigma``.
    """
    region = np.asarray(region, dtype=bool)
    mean_value = np.mean(i2_data[region])
    return mean_value / sigma

    # Function for displaying data

def display_data(data, header, fig, title=None):
    """Show an image on WCS axes with the colour scale clipped at the 97th percentile.

    Parameters
    ----------
    data : numpy.ndarray
        Image to display.
    header : FITS header
        Header used to build the WCS projection of the axes.
    fig : matplotlib figure
        Not used; the axes are created with ``plt.subplot`` on the current
        figure. Kept so existing callers keep working.
    title : str, optional
        Title for the axes.
    """
    projection = WCS(header)
    ax = plt.subplot(projection=projection)

    # Upper colour limit from the finite pixels only, so NaN/inf values do
    # not break the percentile.
    finite_data = data[np.isfinite(data)]
    if finite_data.size > 0:
        vmax = np.percentile(finite_data, 97)
    else:
        print("Warning: All data values are non-finite.")
        vmax = 1

    # vmax was previously computed but never handed to imshow.
    ax.imshow(data, origin='lower', vmax=vmax)

    lon = ax.coords[0]
    lat = ax.coords[1]
    lon.set_axislabel('Galactic Longitude')
    lat.set_axislabel('Galactic Latitude')
    if title:
        ax.set_title(title)

    # Function to generate mask

def generate_mask(d, idx):
    """Return the mask of structure ``d[idx]`` as a 'short' integer array.

    The structure's ``get_mask()`` result is cast to ``'short'``. Because
    the result is an integer array, it must be converted to boolean before
    being used to select pixels with ``array[mask]``.
    """
    return d[idx].get_mask().astype('short')

    # Function to generate grid for mask

def generate_grid(nx, ny):
    """Return the (x, y) integer coordinates of every pixel of an nx-by-ny grid.

    The result has shape ``(nx * ny, 2)``; x varies fastest, so the rows run
    ``(0, 0), (1, 0), ..., (nx - 1, 0), (0, 1), ...``.
    """
    xs, ys = np.meshgrid(np.arange(nx), np.arange(ny))
    return np.column_stack((xs.ravel(), ys.ravel()))

    # Function to apply mask to data

def apply_mask(mask, data):
    """Return the element-wise product ``mask * data``.

    With a 0/1 mask this zeroes every pixel outside the mask and leaves the
    pixels inside unchanged.
    """
    # The original line read "mask\*data": a markdown escape that is a
    # syntax error in Python.
    return mask * data

    # Function to save masked image

def save_masked_image(data, header, output_filepath):
    """Write ``data`` and ``header`` as the primary HDU of a FITS file.

    An existing file at ``output_filepath`` is overwritten.
    """
    fits.PrimaryHDU(data=data, header=header).writeto(output_filepath, overwrite=True)

def _read_primary_data(filepath):
    """Return an in-memory copy of the primary-HDU data of a FITS file."""
    with fits.open(filepath) as hdul:
        # Copy so the array stays valid (and writable) after the file closes.
        return np.array(hdul[0].data)


def process_image(i2_filepath, i3_filepath, i4_filepath, r21_filepath, dendro_dir, metadata):
    """Find structures in a ratio image and classify them as EGOs or rejects.

    A dendrogram is computed on the ratio image and saved to ``dendro_dir``.
    For every structure, the local background is subtracted from the I2, I3
    and I4 images, and the structure is written to the EGO catalogue if its
    I2 signal-to-noise and its I2/I3 and I2/I4 flux ratios all reach the
    limits below; otherwise it goes to the reject catalogue.

    Parameters
    ----------
    i2_filepath, i3_filepath, i4_filepath : str
        FITS images of the three bands. NaN pixels are treated as 0.
    r21_filepath : str
        FITS ratio image on which the dendrogram is computed.
    dendro_dir : str
        Output directory for the dendrogram and the two catalogues; created
        if missing.
    metadata : dict
        Passed unchanged to ``PPStatistic`` for every structure.

    Returns
    -------
    None
    """
    # Algorithm parameters
    # Dendrogram
    min_value = 2.0
    min_npix = 3
    # Colour requirements
    r23Lim = 0.4
    r24Lim = 0.45
    # Signal to noise limit
    SNRLim = 5

    # Load the images; the files are closed again straight away.
    r21_data = _read_primary_data(r21_filepath)
    i2_data = _read_primary_data(i2_filepath)
    i3_data = _read_primary_data(i3_filepath)
    i4_data = _read_primary_data(i4_filepath)

    # Replace NaN values in the band data with 0
    for band in (i2_data, i3_data, i4_data):
        band[np.isnan(band)] = 0

    # Generate the dendrogram
    d = Dendrogram.compute(r21_data, min_value=min_value, min_npix=min_npix)

    # Save dendrogram
    os.makedirs(dendro_dir, exist_ok=True)
    # NOTE(review): splitting on '/' only is platform dependent. If the path
    # was built with os.path.join on Windows, the last component still holds
    # a 'dir\\file' prefix and the outputs land in a sub-directory of
    # dendro_dir. Left unchanged so existing output locations do not move;
    # confirm before switching to os.path.basename.
    fieldName = i2_filepath.split('/')[-1].replace('I2', '').replace('.fits', '')
    dendroFile = dendro_dir + '/' + fieldName + 'dendro2mv.fits'
    d.save_to(dendroFile, format='fits')
    print(f'Dendrogram saved to {dendroFile}')

    egoFile = dendro_dir + '/' + fieldName + 'egos2mv.dat'
    rejectFile = dendro_dir + '/' + fieldName + 'rejects2mv.dat'

    # Parameter values used, followed by the column names
    hdr0 = '# Parameter values\n' + f'# {r23Lim=} {r24Lim=} {SNRLim=}' + '\n' + \
        f'# Dendrogram values {min_value=} {min_npix=}' + '\n'
    hdr1 = '# idx SNR ratio_2_3 ratio_2_4 egoStat.x_cen(deg) egoStat.y_cen(deg) egoStat.area_exact(arcsec^2)\n'

    # The with-block closes both catalogues on every exit path, including
    # the early return below and any exception.
    with open(egoFile, 'w') as egoF, open(rejectFile, 'w') as rejectF:
        print(f'EGO catalogue {egoFile}')
        print(f'Reject catalogue {rejectFile}')
        for catalogue in (egoF, rejectF):
            catalogue.write(hdr0)
            catalogue.write(hdr1)

        # Ordered by idx so the catalogue rows come out in idx order.
        structures = sorted(d.all_structures, key=lambda s: s.idx)
        if len(structures) == 0:
            print(f"No structures found in {r21_filepath}.")
            return None
        print(f"{len(structures)} structures found in {r21_filepath}.")

        for structure in structures:
            # Use the structure's own idx: the position in the list is not
            # guaranteed to be a valid dendrogram index.
            idx = structure.idx
            # generate_mask returns 0/1 integers; a boolean mask is needed
            # to select pixels.
            mask = generate_mask(d, idx).astype(bool)

            # Check for NaN values in region before calculating background
            for label_, band in (('i2', i2_data), ('i3', i3_data), ('i4', i4_data)):
                if np.any(np.isnan(band[mask])):
                    print(f"Warning: NaN values found in {label_} data for structure {idx}.")

            # Background-subtracted images for this structure
            i2_bck = subtract_background(mask, i2_data)
            i3_bck = subtract_background(mask, i3_data)
            i4_bck = subtract_background(mask, i4_data)

            # Check for NaN values after background subtraction
            for label_, band in (('i2_bck', i2_bck), ('i3_bck', i3_bck), ('i4_bck', i4_bck)):
                if np.any(np.isnan(band[mask])):
                    print(f"Warning: NaN values found in {label_} for structure {idx}.")

            # Measure on the background-subtracted images themselves.
            # Passing "data - i2_bck" (as before) cancels the image and
            # leaves only the constant background value.
            mean_i2, mean_i3, mean_i4, ratio_2_3, ratio_2_4 = check_color_requirements(
                mask, i2_bck, i3_bck, i4_bck, i2_filepath)
            SNR = check_significant_detection(mask, i2_bck)
            egoStat = PPStatistic(structure, metadata)

            row = (f'{idx} {SNR} {ratio_2_3} {ratio_2_4} {egoStat.x_cen} '
                   f'{egoStat.y_cen} {egoStat.area_exact.value}' + '\n')
            # NaN ratios compare False, so such structures are rejected.
            if SNR >= SNRLim and ratio_2_3 >= r23Lim and ratio_2_4 >= r24Lim:
                # It's an EGO, log it.
                egoF.write(row)
            else:
                rejectF.write(row)

    return None

#Cell 4;
# Metadata handed to PPStatistic for every structure
metadata = {}
metadata['spatial_scale'] = 1.2 * u.arcsec
metadata['data_value'] = u.Jy

# List of image types
image_types = ["I2", "I3", "I4"]
# Get a list of ratioed image file paths
r21_filepaths = glob.glob(pwd+'Images/Ratio/*_R2_1.fits')
# Define the base output directory
base_output_dir = pwd+'Images'

# Output directories are the same for every field, so create them once.
dendro_dir = f'{base_output_dir}/Dendrograms'
dendro_img = f'{dendro_dir}/Images'
os.makedirs(dendro_dir, exist_ok=True)
os.makedirs(dendro_img, exist_ok=True)

# Each ratio image is processed exactly once. The earlier loop over
# image_types repeated the identical process_image call three times and
# rewrote the same output files.
for r21_filepath in r21_filepaths:
    print(f'--------------------------')
    print(f'{r21_filepath = }')

    # Construct the corresponding band image file paths
    base_name = os.path.basename(r21_filepath).rsplit('_R2_1.fits', 1)[0]
    i45_filepath = os.path.join(pwd+'Images', base_name + '_I2.fits')
    i58_filepath = os.path.join(pwd+'Images', base_name + '_I3.fits')
    i80_filepath = os.path.join(pwd+'Images', base_name + '_I4.fits')

    missing = [p for p in (i45_filepath, i58_filepath, i80_filepath)
               if not os.path.exists(p)]
    if missing:
        print(f'Skipping {base_name}: missing input file(s) {missing}')
        continue

    # Save the wcs so that we can get the position from the dendrogram in
    # astro coordinates. getheader closes the file after reading.
    wcs = WCS(fits.getheader(i45_filepath))
    metadata['wcs'] = wcs
    process_image(i45_filepath, i58_filepath, i80_filepath, r21_filepath, dendro_dir, metadata)