# -----------------------------------------------------------------------------
#  Copyright (C) 2018 University of Dundee. All rights reserved.
#
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 2 of the License, or
#  (at your option) any later version.
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
#  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# ------------------------------------------------------------------------------

# This Jython script uses TrackMate
# Use this script in the Scripting Dialog of Fiji (File > New > Script).
# Select Python as language in the Scripting Dialog.
# Error handling is kept to a minimum to make the script easier to read.
# Add proper error handling before using it in production
# to make sure the OMERO services are always closed.
# Information can be found at
# https://docs.openmicroscopy.org/latest/omero/developers/Java.html

from java.awt import Color
from java.util import ArrayList
from java.util import Collections
from java.lang import Double
from java.lang import Long
from java.lang import Object


# OMERO Dependencies
from omero.gateway import Gateway
from omero.gateway import LoginCredentials
from omero.gateway import SecurityContext
from omero.gateway.facility import BrowseFacility
from omero.gateway.facility import DataManagerFacility
from omero.gateway.facility import ROIFacility
from omero.gateway.facility import TablesFacility
from omero.gateway.model import EllipseData
from omero.gateway.model import FileAnnotationData
from omero.gateway.model import ImageData
from omero.gateway.model import PolylineData
from omero.gateway.model import ProjectData
from omero.gateway.model import ROIData
from omero.gateway.model import TableData
from omero.gateway.model import TableDataColumn
from omero.log import SimpleLogger
from omero.model import PolylineI
from omero.model import ProjectI
from omero.rtypes import rstring
from omero.rtypes import rlong
from omero.model import ChecksumAlgorithmI
from omero.model import FileAnnotationI
from omero.model import OriginalFileI
from omero.model.enums import ChecksumAlgorithmSHA1160


from ij import IJ
import fiji.plugin.trackmate.Spot as Spot
import fiji.plugin.trackmate.Settings as Settings
import fiji.plugin.trackmate.Model as Model
import fiji.plugin.trackmate.SelectionModel as SelectionModel
import fiji.plugin.trackmate.TrackMate as TrackMate
import fiji.plugin.trackmate.detection.DetectorKeys as DetectorKeys
import fiji.plugin.trackmate.detection.DogDetectorFactory as DogDetectorFactory
import fiji.plugin.trackmate.tracking.sparselap.SparseLAPTrackerFactory as SparseLAPTrackerFactory
import fiji.plugin.trackmate.tracking.LAPUtils as LAPUtils
import fiji.plugin.trackmate.visualization.hyperstack.HyperStackDisplayer as HyperStackDisplayer
import fiji.plugin.trackmate.features.spot.SpotContrastAndSNRAnalyzerFactory as SpotContrastAndSNRAnalyzerFactory
import fiji.plugin.trackmate.features.spot.SpotIntensityAnalyzerFactory as SpotIntensityAnalyzerFactory
import fiji.plugin.trackmate.features.track.TrackSpeedStatisticsAnalyzer as TrackSpeedStatisticsAnalyzer

# Import required to save the results as CSV
import csv
import os
import tempfile

# Setup
# =====

# OMERO Server details
HOST = "outreach.openmicroscopy.org"
PORT = 4064
#  parameters to edit
project_id = 1102
USERNAME = "username"
PASSWORD = "password"


def connect_to_omero():
    """Connect to OMERO and return the connected gateway.

    The server address and the credentials are read from the module-level
    HOST, PORT, USERNAME and PASSWORD values (the user name and password are
    stripped of surrounding whitespace). The group id of the user returned
    by the connection is printed.
    """
    login = LoginCredentials()

    server = login.getServer()
    server.setHostname(HOST)
    server.setPort(PORT)

    account = login.getUser()
    account.setUsername(USERNAME.strip())
    account.setPassword(PASSWORD.strip())

    gateway = Gateway(SimpleLogger())
    logged_in_user = gateway.connect(login)
    print(logged_in_user.getGroupId())
    return gateway


def get_image_ids(gateway, ctx, project_id):
    """Return what the browse facility lists as images of a Project.

    gateway    -- connected OMERO gateway
    ctx        -- security context used for the query
    project_id -- id of the Project to look into

    The result of BrowseFacility.getImagesForProjects is returned as is.
    """
    project_ids = ArrayList(1)
    project_ids.add(Long(project_id))

    browse_facility = gateway.getFacility(BrowseFacility)
    return browse_facility.getImagesForProjects(ctx, project_ids)


def open_image_plus(HOST, USERNAME, PASSWORD, PORT, group_id, image_id):
    """Open an OMERO image in ImageJ through the Bio-Formats Importer.

    The connection details, the group id and the image id are formatted
    into the option string handed to the loci.plugins.LociImporter plugin.
    """
    # The option string spans several lines; the pieces are joined with
    # newlines before the values are substituted.
    template = "\n".join([
        "location=[OMERO]",
        "open=[omero:server=%s",
        "user=%s",
        "port=%s",
        "pass=%s",
        "groupID=%s",
        "iid=%s]",
        " windowless=true view=Hyperstack ",
    ])
    arguments = (HOST, USERNAME, PORT, PASSWORD, group_id, image_id)
    IJ.runPlugIn("loci.plugins.LociImporter", template % arguments)


def create_omero_results_table_columns(units):
    """Return the seven column definitions of the per-spot results table.

    units -- text inserted in the header of the track mean velocity column

    Each column index matches the position of the column in the list.
    """
    specs = [
        ("Image", ImageData),
        ("ROI ID", Double),
        ("Shape ID", Double),
        ("Track Mean Velocity (" + units + ")", Double),
        ("Spot Quality", Double),
        ("Spot SNR", Double),
        ("Spot Mean Intensity", Double),
    ]
    return [TableDataColumn(header, position, kind)
            for position, (header, kind) in enumerate(specs)]


def create_omero_summary_table_columns(units):
    """Return the five column definitions of the per-image summary table.

    units -- text inserted in the header of the average velocity column

    Each column index matches the position of the column in the list
    (0 to 4), in the same way as the results table columns are numbered.
    """
    specs = [
        ("Image", ImageData),
        ("Average Track Mean Velocity (" + units + ")", Double),
        ("Average Spot Quality", Double),
        ("Average Spot SNR", Double),
        ("Average Spot Mean Intensity", Double),
    ]
    # Number the columns from their position so that the index can never
    # drift from the actual layout of the table.
    return [TableDataColumn(header, position, kind)
            for position, (header, kind) in enumerate(specs)]


def get_spot(track_id, spot_id):
    """Return the spot with the given id found in the given track.

    The track is looked up in the module-level ``model``. None is returned
    when no spot of the track has the requested id.
    """
    candidates = model.getTrackModel().trackSpots(track_id)
    return next((spot for spot in candidates if spot.ID() == spot_id), None)

def upload_csv_to_omero(ctx, file, project_id):
    """Upload a local CSV file to OMERO and attach it to a Project.

    ctx        -- security context used for all the calls
    file       -- object whose ``name`` attribute is the path of the CSV file
    project_id -- id of the Project the file annotation is attached to

    The module-level ``gateway`` is used to reach the server. The file is
    always stored on the server as "trackmate_summary_results.csv" and the
    annotation uses the "training.trackmate.demo" namespace. The id of the
    created original file is printed.
    """
    svc = gateway.getFacility(DataManagerFacility)
    file_size = os.path.getsize(file.name)

    # Describe the file on the server before sending its content
    original_file = OriginalFileI()
    original_file.setName(rstring("trackmate_summary_results.csv"))
    original_file.setPath(rstring(file.name))
    original_file.setSize(rlong(file_size))
    checksum_algorithm = ChecksumAlgorithmI()
    checksum_algorithm.setValue(rstring(ChecksumAlgorithmSHA1160.value))
    original_file.setHasher(checksum_algorithm)
    original_file.setMimetype(rstring("text/csv"))
    original_file = svc.saveAndReturnObject(ctx, original_file)
    file_id = original_file.getId().getValue()
    print(file_id)

    # Every call made on the raw file store sits inside the try block so
    # that the store is closed even when selecting the file fails.
    store = gateway.getRawFileService(ctx)
    try:
        store.setFileId(file_id)
        chunk_size = 10000
        offset = 0
        with open(file.name, 'rb') as stream:
            # Send the content chunk by chunk until the file is exhausted
            while True:
                block = stream.read(chunk_size)
                if not block:
                    break
                store.write(block, offset, len(block))
                offset += len(block)

        original_file = store.save()
    finally:
        store.close()

    # create the file annotation
    namespace = "training.trackmate.demo"
    fa = FileAnnotationI()
    fa.setFile(original_file)
    fa.setNs(rstring(namespace))

    data_object = ProjectData(ProjectI(project_id, False))
    svc.attachAnnotation(ctx, FileAnnotationData(fa), data_object)


def save_summary_as_csv(file, rows, columns):
    """Write the summary rows to a local CSV file.

    file    -- object whose ``name`` attribute is the path to write to
    rows    -- summary rows; the first value of a row is written as the
               result of its getId() call, the other values as they are
    columns -- column definitions; their getName() values form the header
    """
    with open(file.name, 'wb') as handle:
        writer = csv.writer(handle)
        writer.writerow([column.getName() for column in columns])

        for row in rows:
            writer.writerow([
                value.getId() if position == 0 else value
                for position, value in enumerate(row)
            ])


# Connect, then build the security context from the group of the
# logged-in user
gateway = connect_to_omero()
exp = gateway.getLoggedInUser()
exp_id = exp.getId()
group_id = exp.getGroupId()
ctx = SecurityContext(group_id)

# Retrieve the images contained in the OMERO project
images = get_image_ids(gateway, ctx, project_id)
print(images)

# Rows gathered over all the images: one row per saved spot for the results
# table and one row per image for the summary table.
table_rows = []
summary_table_rows = []
# Units reported by the TrackMate model. They keep the values of the last
# image that was tracked and are used after the loop to label the velocity
# columns.
space_units = ''
time_units = ''


def _close_without_saving(image_plus):
    """Mark the ImageJ image as unchanged, then close it."""
    image_plus.changes = False
    image_plus.close()


for image in images:
    # Open the image
    image_id = image.getId()
    print(image_id)
    open_image_plus(HOST, USERNAME, PASSWORD, PORT, group_id, image_id)
    imp = IJ.getImage()
    dx = imp.getCalibration().pixelWidth
    dy = imp.getCalibration().pixelHeight

    # Instantiate model object. The name must stay at module level:
    # get_spot() reads it when the table rows are created below.
    model = Model()

    # Prepare settings object
    settings = Settings()
    settings.setFrom(imp)
    # Configure detector
    settings.detectorFactory = DogDetectorFactory()
    settings.detectorSettings = {
        DetectorKeys.KEY_DO_SUBPIXEL_LOCALIZATION: True,
        DetectorKeys.KEY_RADIUS: 2.5,
        DetectorKeys.KEY_TARGET_CHANNEL: 1,
        DetectorKeys.KEY_THRESHOLD: 5.,
        DetectorKeys.KEY_DO_MEDIAN_FILTERING: False,
    }

    # Configure tracker
    settings.trackerFactory = SparseLAPTrackerFactory()
    settings.trackerSettings = LAPUtils.getDefaultLAPSettingsMap()
    settings.trackerSettings['LINKING_MAX_DISTANCE'] = 10.0
    settings.trackerSettings['GAP_CLOSING_MAX_DISTANCE'] = 10.0
    settings.trackerSettings['MAX_FRAME_GAP'] = 3

    # Add the analyzers for some spot features
    settings.addSpotAnalyzerFactory(SpotIntensityAnalyzerFactory())
    settings.addSpotAnalyzerFactory(SpotContrastAndSNRAnalyzerFactory())

    # Add an analyzer for some track features, such as the track mean speed.
    settings.addTrackAnalyzer(TrackSpeedStatisticsAnalyzer())
    settings.initialSpotFilterValue = 1
    print(str(settings))

    # Instantiate trackmate. When a step fails, report the error, close
    # the image so that windows do not pile up, and move to the next image.
    trackmate = TrackMate(model, settings)
    if not trackmate.checkInput():
        print(str(trackmate.getErrorMessage()))
        _close_without_saving(imp)
        continue

    if not trackmate.process():
        print(str(trackmate.getErrorMessage()))
        _close_without_saving(imp)
        continue

    # Display Results
    selectionModel = SelectionModel(model)
    displayer = HyperStackDisplayer(model, selectionModel, imp)
    displayer.render()
    displayer.refresh()
    # The feature model, that stores edge and track features.
    fm = model.getFeatureModel()
    space_units = model.getSpaceUnits()
    time_units = model.getTimeUnits()

    # Convert each track into one OMERO ROI: an ellipse per spot plus a
    # polyline linking the spots. The totals feed the per-image averages.
    rois = []
    total_track_mean = 0
    total_spot_quality = 0
    total_spot_snr = 0
    total_spot_mean = 0
    track_count = 0
    spot_count = 0
    track_model = model.getTrackModel()
    for track_id in track_model.trackIDs(True):
        track_count += 1

        # Fetch the track feature from the feature model.
        v = fm.getTrackFeature(track_id, 'TRACK_MEAN_SPEED')
        total_track_mean += v
        sorted_track = list(track_model.trackSpots(track_id))
        Collections.sort(sorted_track, Spot.frameComparator)
        roi = ROIData()
        rois.append(roi)
        points = []
        for spot in sorted_track:
            spot_count += 1
            sid = spot.ID()
            # Fetch spot features directly from spot.
            # NOTE(review): x and y are divided by the pixel size, but
            # RADIUS, POSITION_Z and POSITION_T are used as returned by
            # TrackMate -- confirm they are not expressed in calibrated
            # units before relying on the ellipse size and its Z/T planes.
            x = spot.getFeature('POSITION_X')/dx
            y = spot.getFeature('POSITION_Y')/dy
            r = spot.getFeature('RADIUS')
            z = spot.getFeature('POSITION_Z')
            t = spot.getFeature('POSITION_T')
            # Save spot as an ellipse in OMERO
            ellipse = EllipseData(x, y, r, r)
            ellipse.setZ(int(z))
            ellipse.setT(int(t))
            # set trackmate track ID and spot ID for later: the text is
            # parsed again once the saved shapes come back from the server
            ellipse.setText(str(track_id)+':'+str(sid))
            # set a default color
            shape_settings = ellipse.getShapeSettings()
            shape_settings.setStroke(Color.RED)
            roi.addShapeData(ellipse)
            points.append(str(x) + ',' + str(y))
            total_spot_quality += spot.getFeature('QUALITY')
            total_spot_snr += spot.getFeature('SNR')
            total_spot_mean += spot.getFeature('MEAN_INTENSITY')

        # Save the track as a polyline going through the spots
        polyline = PolylineI()
        polyline.setPoints(rstring(' '.join(points)))
        pl = PolylineData(polyline)
        # set a default color
        shape_settings = pl.getShapeSettings()
        shape_settings.setStroke(Color.YELLOW)
        roi.addShapeData(pl)

    # Without any track or spot there is nothing to save and the averages
    # below would divide by zero: skip the image.
    if track_count == 0 or spot_count == 0:
        print("No tracks found for image " + str(image_id))
        _close_without_saving(imp)
        continue

    # Save the tracks back to OMERO
    roi_facility = gateway.getFacility(ROIFacility)
    results = roi_facility.saveROIs(ctx, image_id, exp_id, rois)
    # Create a table row for each saved ellipse
    for roi in results:
        shapes = roi.getIterator()
        while shapes.hasNext():
            for shape in shapes.next():
                if shape.getClass().getName() != "omero.gateway.model.EllipseData":
                    continue
                shape_id = shape.getId()
                # The text holds "<track id>:<spot id>" as set above
                ids = shape.getText().split(":")
                track_id = int(ids[0])
                spot_id = int(ids[1])
                v = fm.getTrackFeature(track_id, 'TRACK_MEAN_SPEED')
                spot = get_spot(track_id, spot_id)
                # create a row
                row = []
                row.append(image)
                row.append(Double(roi.getId()))
                row.append(Double(shape_id))
                row.append(Double(v))
                row.append(Double(spot.getFeature('QUALITY')))
                row.append(Double(spot.getFeature('SNR')))
                row.append(Double(spot.getFeature('MEAN_INTENSITY')))
                table_rows.append(row)
    # Close
    _close_without_saving(imp)

    # Create the summary row holding the averages for the image
    row = []
    row.append(image)
    row.append(Double(total_track_mean/track_count))
    row.append(Double(total_spot_quality/spot_count))
    row.append(Double(total_spot_snr/spot_count))
    row.append(Double(total_spot_mean/spot_count))
    summary_table_rows.append(row)

# The velocity unit is only known once the images have been tracked, so the
# columns of the OMERO tables are created at this point
velocity_units = space_units + ' per ' + time_units
columns = create_omero_results_table_columns(velocity_units)
average_columns = create_omero_summary_table_columns(velocity_units)

table_facility = gateway.getFacility(TablesFacility)
data_object = ProjectData(ProjectI(project_id, False))

# Now create the results table. The values are laid out column by column,
# so the collected rows are transposed.
data = [[row[c] for row in table_rows] for c in range(len(columns))]
table_data = TableData(columns, data)
data = table_facility.addTable(ctx, data_object, "Results_from_TrackMate", table_data)

# Now create the summary table, transposed in the same way
summary_grid = [[row[c] for row in summary_table_rows]
                for c in range(len(average_columns))]
summary_table = TableData(average_columns, summary_grid)
saved_summary = table_facility.addTable(ctx, data_object, "Summary_from_TrackMate", summary_table)

# Retrieve the annotation and set the namespace (due to a limitation of JavaGateway)
oid = saved_summary.getOriginalFileId()
for ann in table_facility.getAvailableTables(ctx, data_object):
    if ann.getFileID() == oid:
        ann.setNameSpace(FileAnnotationData.BULK_ANNOTATIONS_NS)
        gateway.getUpdateService(ctx).saveAndReturnObject(ann.asIObject())
        break

# Create the Summary CSV.
# The file is reopened by name when it is written and again when it is
# uploaded, so a named temporary file is needed. It is created with
# delete=False and its handle is closed straight away: the script then owns
# the removal of the file and no second handle stays open on it.
summary_csv = tempfile.NamedTemporaryFile(mode='wb',
                                          prefix='trackmate_summary_results_',
                                          suffix='.csv',
                                          dir=tempfile.gettempdir(),
                                          delete=False)
summary_csv.close()

try:
    try:
        save_summary_as_csv(summary_csv, summary_table_rows, average_columns)
        upload_csv_to_omero(ctx, summary_csv, project_id)
    finally:
        # delete local copy of the file, even when the upload failed
        os.remove(summary_csv.name)
finally:
    # Close the connection to the server
    gateway.disconnect()

print("done")