#include <cerrno>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

#include <omp.h>

// Solves the 1D heat equation dT/dt = alpha * d^2T/dx^2 on x in [0, 1] with
// fixed boundary values T(0) = 0 and T(1) = 1, starting from T = 0 in the
// interior. Time integration is explicit (forward Euler) and the spatial
// derivative uses a second-order central difference; both sweeps over the grid
// are parallelised with OpenMP worksharing loops.
//
// Usage: ./heat1DOpenMP <numberOfPoints> <finalTime> <numberOfThreads>
//   numberOfPoints  - grid points including both boundaries, integer >= 2
//   finalTime       - simulated end time, finite number >= 0
//   numberOfThreads - OpenMP threads to request, integer >= 1
//
// Prints the thread count, the wall-clock time of the time-stepping loop and
// the mean absolute difference between the final field and T(x) = x (the
// steady-state solution for these boundary values).
// Returns EXIT_FAILURE, after printing a message to stderr, on invalid input.
int main(int argc, char* argv[]) {
    // Parses a complete base-10 integer in [minValue, INT_MAX]. Unlike atoi,
    // malformed text ("abc", "12x") and overflow are reported instead of
    // silently becoming 0 or an out-of-range value.
    const auto parseInt = [](const char* text, long minValue, int& out) {
        char* end = nullptr;
        errno = 0;
        const long value = std::strtol(text, &end, 10);
        if (end == text || *end != '\0' || errno == ERANGE ||
            value < minValue || value > std::numeric_limits<int>::max()) {
            return false;
        }
        out = static_cast<int>(value);
        return true;
    };

    // Parses a complete, finite, non-negative floating-point number.
    const auto parseNonNegativeDouble = [](const char* text, double& out) {
        char* end = nullptr;
        const double value = std::strtod(text, &end);
        if (end == text || *end != '\0' || !std::isfinite(value) || value < 0.0) {
            return false;
        }
        out = value;
        return true;
    };

    // input parameters
    int numberOfPoints = 0;
    double finalTime = 0.0;
    int numberOfThreads = 0;
    // argc is checked first so argv[1..3] are only read when they exist.
    if (argc < 4 ||
        !parseInt(argv[1], 2, numberOfPoints) ||
        !parseNonNegativeDouble(argv[2], finalTime) ||
        !parseInt(argv[3], 1, numberOfThreads)) {
        std::cerr << "Usage: ./heat1DOpenMP <numberOfPoints> <finalTime> <numberOfThreads>\n"
                  << "  numberOfPoints: integer >= 2, finalTime: number >= 0, "
                  << "numberOfThreads: integer >= 1\n";
        return EXIT_FAILURE;
    }
    const double leftBoundary = 0.0;
    const double rightBoundary = 1.0;
    const double CFL = 0.25;
    const double heatDiffusionCoefficient = 0.01;
    omp_set_num_threads(numberOfThreads);

    // computed input parameters
    const double dx = 1.0 / (numberOfPoints - 1);
    // dt is chosen so that alpha * dt / dx^2 == CFL; this explicit scheme is
    // stable for values <= 0.5.
    const double dt = CFL * dx * dx / heatDiffusionCoefficient;

    // Converting a double that does not fit in an int is undefined behaviour,
    // so reject runs whose step count would overflow.
    const double exactTimeSteps = finalTime / dt;
    if (exactTimeSteps >= static_cast<double>(std::numeric_limits<int>::max())) {
        std::cerr << "finalTime too large: number of time steps would exceed "
                  << std::numeric_limits<int>::max() << '\n';
        return EXIT_FAILURE;
    }
    const int numberOfTimeSteps = static_cast<int>(exactTimeSteps);

    // allocate field arrays, zero-initialised
    std::vector<double> T0(numberOfPoints, 0.0); // T at time level n
    std::vector<double> T1(numberOfPoints, 0.0); // T at time level n+1
    std::vector<double> x(numberOfPoints);

    // create mesh
    for (int i = 0; i < numberOfPoints; ++i) {
        x[i] = dx * i;
    }

    // Boundary values are written to T1 only: every step starts by copying T1
    // into T0, and the update loop never writes indices 0 and N-1, so these
    // values are carried through all time steps.
    T1[0] = leftBoundary;
    T1[numberOfPoints - 1] = rightBoundary;

    // Loop-invariant diffusion number alpha * dt / dx^2, hoisted out of the
    // update loop (same evaluation order, so results are unchanged).
    const double diffusionNumber = heatDiffusionCoefficient * dt / (dx * dx);

    // steady_clock is monotonic, so the measured interval is not affected by
    // system clock adjustments (high_resolution_clock may alias system_clock).
    const auto startTime = std::chrono::steady_clock::now();

    // All variables declared above default to shared. They are deliberately not
    // listed in a shared() clause: const-qualified variables are predetermined
    // shared in OpenMP <= 3.1, and compilers following that rule (e.g. GCC < 9)
    // reject them in a data-sharing clause.
    #pragma omp parallel
    {
        #pragma omp single
        {
            std::cout << "Running with " << omp_get_num_threads() << " threads\n";
        }
        // Every thread runs the time loop; the grid work within each step is
        // split by the two worksharing loops. The implicit barrier after each
        // 'omp for' guarantees T0 is fully copied before the update reads it,
        // and that the update has finished reading T0 before the next copy.
        for (int t = 0; t < numberOfTimeSteps; ++t) {
            #pragma omp for
            for (int i = 0; i < numberOfPoints; ++i) {
                T0[i] = T1[i];
            }

            // compute solution at next time level T^(n+1) using an explicit
            // update in time and a central difference in space
            #pragma omp for
            for (int i = 1; i < numberOfPoints - 1; ++i) {
                T1[i] = T0[i] + diffusionNumber * (T0[i + 1] - 2 * T0[i] + T0[i - 1]);
            }
        }
    }

    const auto endTime = std::chrono::steady_clock::now();
    const std::chrono::duration<double> elapsed = endTime - startTime;
    // NOTE(review): "Executation" is a typo; the text is left unchanged in case
    // scripts parse this line.
    std::cout << "Executation time: " << elapsed.count() << " seconds" << std::endl;

    // compute mean absolute deviation from the steady-state profile T(x) = x
    double error = 0.0;
    for (int i = 0; i < numberOfPoints; ++i) {
        error += std::abs(T1[i] - x[i]);
    }
    error /= numberOfPoints;
    std::cout << "Error: " << error << std::endl;

    // // output results
    // auto finalTimeString = std::to_string(finalTime);
    // finalTimeString = finalTimeString.substr(0, finalTimeString.find("."));
    // std::string fileName = "results_OpenMP_" + finalTimeString + ".csv";
    // std::ofstream file(fileName);
    // file << "x,T" << std::endl;
    // for (int i = 0; i < numberOfPoints; ++i) {
    //     file << x[i] << ", " << T1[i] << std::endl;
    // }
    // file.close();

    return EXIT_SUCCESS;
}