#include <chrono>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <mpi.h>

int main(int argc, char* argv[]) {
    // Initialise the MPI environment
    MPI_Init(&argc, &argv);
    
    // Get the number of processes
    int worldSize;
    MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
    
    // Get the current processor ID (rank) of the process
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    // input parameters
    const int numberOfPoints = static_cast<int>(atoi(argv[1]));
    const double leftBoundary = 0.0;
    const double rightBoundary = 1.0;
    const double CFL = 0.25;
    const double heatDiffusionCoefficient = 0.01;
    const double finalTime = static_cast<double>(atof(argv[2]));
    
    // computed input paramters
    const double dx = 1.0 / (numberOfPoints - 1);
    const double dt = CFL * dx * dx / heatDiffusionCoefficient;
    const int numberOfTimeSteps = static_cast<int>(finalTime / dt);

    // Domain decomposition
    int pointsPerProcess = numberOfPoints / worldSize;

    if (rank == 0) {
        std::cout << "Number of processes: " << worldSize << std::endl;
        std::cout << "Number of points per process: " << pointsPerProcess << std::endl;
    }
    MPI_Barrier(MPI_COMM_WORLD);

    // check that we have a valid decomposition. If we don't, then add additional points to the last processor
    int pointsNotUsed = 0;
    if (rank == worldSize - 1) {
        const int totalPoints = worldSize * pointsPerProcess;
        pointsNotUsed = numberOfPoints - totalPoints;
        pointsPerProcess += pointsNotUsed;
        if (pointsNotUsed > 0) {
            std::cout << "Adding " << pointsNotUsed << " point(s) to last process for load balancing" << std::endl;
        }
    }

    // allocate memory for field arrays
    std::vector<double> T0(pointsPerProcess); // T at time n
    std::vector<double> T1(pointsPerProcess); // T at time n+1
    std::vector<double> x(pointsPerProcess);

    // create mesh
    for (int i = 0; i < pointsPerProcess; ++i) {
        x[i] = dx * (i + rank * (pointsPerProcess - pointsNotUsed));
    }

    // initialise field arrays
    for (int i = 0; i < pointsPerProcess; ++i) {
        T0[i] = 0.0;
        T1[i] = 0.0;
    }

    // set boundary conditions (important, set it for T1, not T0)
    // only set it for left-most processor (rank == 0) and right-most processor (rank == worldSize - 1)
    if (rank == 0) {
        T1[0] = leftBoundary;
    }
    if (rank == worldSize - 1) {
        T1[pointsPerProcess - 1] = rightBoundary;
    }

    // helper variables for MPI
    std::vector<MPI_Request> requests;
    double leftBoundaryValue = 0.0;
    double rightBoundaryValue = 0.0;
    
    double startTime = MPI_Wtime();
    // loop over all timesteps
    for (int t = 0; t < numberOfTimeSteps; ++t) {
        T0 = T1;

        // send data to the right processor (except for the last processor, as there is no processor to the right)
        if (rank < worldSize - 1) {
            requests.push_back(MPI_Request{});
            MPI_Isend(&T0[pointsPerProcess - 1], 1, MPI_DOUBLE, rank + 1, 100 + rank, MPI_COMM_WORLD, &requests.back());
        }
        
        // send data to the left processor (except for the first processor, as there is no processor to the left)
        if (rank > 0) {
            requests.push_back(MPI_Request{});
            MPI_Isend(&T0[0], 1, MPI_DOUBLE, rank - 1, 200 + rank, MPI_COMM_WORLD, &requests.back());
        }
        
        // compute solution at next time level T^n+1 using an explicit update in time and e central difference in space
        for (int i = 1; i < pointsPerProcess - 1; ++i) {
            T1[i] = T0[i] + heatDiffusionCoefficient * dt / (dx * dx) * (T0[i + 1] - 2 * T0[i] + T0[i - 1]);
        }
        
        // before we try to receive data, we should make sure that the send requests have completed
        MPI_Waitall(static_cast<int>(requests.size()), requests.data(), MPI_STATUSES_IGNORE);
        
        // receive data from the left processor (except for the first processor, as there is no processor to the left)
        if (rank > 0) {
            MPI_Recv(&leftBoundaryValue, 1, MPI_DOUBLE, rank - 1, 100 + rank - 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        }
        
        // receive data from the right processor (except for the last processor, as there is no processor to the right)
        if (rank < worldSize - 1) {
            MPI_Recv(&rightBoundaryValue, 1, MPI_DOUBLE, rank + 1, 200 + rank + 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        }
        
        // now use the received values to compute the values at the inter-processor boundaries
        if (rank > 0) {
            int i = 0;
            T1[i] = T0[i] + heatDiffusionCoefficient * dt / (dx * dx) * (T0[i + 1] - 2 * T0[i] + leftBoundaryValue);
        }
        
        if (rank < worldSize - 1) {
            int i = pointsPerProcess - 1;
            T1[i] = T0[i] + heatDiffusionCoefficient * dt / (dx * dx) * (rightBoundaryValue - 2 * T0[i] + T0[i - 1]);
        }
    }
    double endTime = MPI_Wtime();
    double duration = endTime - startTime;
    if (rank == 0) {
        std::cout << "Executation time: " << duration << " seconds" << std::endl;
    }

    // compute error
    double error = 0.0;
    for (int i = 0; i < pointsPerProcess; ++i) {
        error += std::abs(T1[i] - x[i]);
    }
    error /= numberOfPoints;

    // reduce error over all processes into processor 0 (the root processor)
    int root = 0;
    double globalError = 0.0;
    MPI_Reduce(&error, &globalError, 1, MPI_DOUBLE, MPI_SUM, root, MPI_COMM_WORLD);

    // print error on processor 0 (the only processor that knows the global error)
    if (rank == root) {
        std::cout << "Error: " << globalError << std::endl;
    }

    // // reconstruct the global temperature and coordinate array so that it can be written to disk by just one processor
    // std::vector<double> globalTemperature(numberOfPoints);
    // std::vector<double> globalX(numberOfPoints);

    // // store the number of points that are send by each processor in an array
    // std::vector<int> pointsPerProcessArray(worldSize);
    // MPI_Gather(&pointsPerProcess, 1, MPI_INT, &pointsPerProcessArray[rank], 1, MPI_INT, root, MPI_COMM_WORLD);

    // // store the offset/displacement, which states the starting point into the global array for each processor
    // // for example, with 3 processors, where processor 0 send 3 elements, processor 1 sends 5 elements, and processor 2
    // // sends 1 element, the offset/displacement array is [0, 3, 8]. Processor 0 will write at location 0 in the global
    // // array, processor 1 will write at location 3 in the global array, and processor 2 will write at location 8 in the
    // // global array
    // std::vector<int> offsets(worldSize);
    // if (rank == root) {
    //     offsets[0] = 0;
    //     for (int i = 1; i < worldSize; ++i) {
    //         offsets[i] = offsets[i - 1] + pointsPerProcessArray[i - 1];
    //     }
    // }

    // // start gathering all data into the root processor
    // MPI_Gatherv(&T1[0], pointsPerProcess, MPI_DOUBLE, &globalTemperature[0], &pointsPerProcessArray[0], &offsets[0], MPI_DOUBLE, root, MPI_COMM_WORLD);
    // MPI_Gatherv(&x[0], pointsPerProcess, MPI_DOUBLE, &globalX[0], &pointsPerProcessArray[0], &offsets[0], MPI_DOUBLE, root, MPI_COMM_WORLD);

    // // output results on the master processor
    // if (rank == root) {
    //     auto finalTimeString = std::to_string(finalTime);
    //     finalTimeString = finalTimeString.substr(0, finalTimeString.find("."));
    //     std::string fileName = "results_MPI_" + finalTimeString + ".csv";
    //     std::ofstream file(fileName);
    //     file << "x,T" << std::endl;
    //     for (int i = 0; i < numberOfPoints; ++i) {
    //         file << globalX[i] << ", " << globalTemperature[i] << std::endl;
    //     }
    //     file.close();
    // }

    // Finalise the MPI environment.
    MPI_Finalize();

    return 0;
}