#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include <cuda_runtime.h>

// Copies the first numberOfPoints entries of d_T1 into d_T0, one element per thread.
//   d_T0            destination array (written)
//   d_T1            source array (read)
//   numberOfPoints  number of entries to copy; threads with a larger index do nothing
// Indexing uses only the x dimension of the grid and block.
__global__ void updateTemperature(double *d_T0, double *d_T1, int numberOfPoints) {
    // flatten the block and thread coordinates into a single global index
    const int globalIndex = blockIdx.x * blockDim.x + threadIdx.x;

    // threads that land past the end of the arrays have nothing to copy
    if (globalIndex >= numberOfPoints) {
        return;
    }

    d_T0[globalIndex] = d_T1[globalIndex];
}

// Advances the interior points by one explicit time step, one point per thread:
//   d_T1[i] = d_T0[i] + CFL * (d_T0[i + 1] - 2 * d_T0[i] + d_T0[i - 1])
//   d_T0            values at the current time level (read)
//   d_T1            values at the next time level (written for 0 < i < numberOfPoints - 1 only)
//   CFL             coefficient multiplying the central second difference
//   numberOfPoints  number of entries in each array
// The first and last entries of d_T1 are never written by this kernel.
// Indexing uses only the x dimension of the grid and block.
__global__ void computeTemperatureAtNextTimeLevel(double *d_T0, double *d_T1, double CFL, int numberOfPoints) {

    // flatten the block and thread coordinates into a single global index
    const int globalIndex = blockIdx.x * blockDim.x + threadIdx.x;

    // only interior points have both neighbours; end points and out-of-range threads are skipped
    const bool isInterior = globalIndex > 0 && globalIndex < numberOfPoints - 1;
    if (!isInterior) {
        return;
    }

    // gather the three-point stencil
    const double left = d_T0[globalIndex - 1];
    const double centre = d_T0[globalIndex];
    const double right = d_T0[globalIndex + 1];

    // explicit update in time with a central difference in space
    const double secondDifference = right - 2 * centre + left;
    d_T1[globalIndex] = centre + CFL * secondDifference;
}

namespace {

// Turns a failed CUDA runtime status into an exception that names the failing operation.
void throwOnCudaError(cudaError_t status, const char *operation) {
    if (status != cudaSuccess) {
        throw std::runtime_error(std::string(operation) + " failed: " + cudaGetErrorString(status));
    }
}

// Owns a device allocation of `count` doubles and frees it on scope exit,
// so that error paths cannot leak GPU memory.
class DeviceDoubles {
public:
    explicit DeviceDoubles(size_t count) {
        throwOnCudaError(cudaMalloc(&data_, count * sizeof(double)), "cudaMalloc");
    }
    ~DeviceDoubles() { cudaFree(data_); }

    DeviceDoubles(const DeviceDoubles &) = delete;
    DeviceDoubles &operator=(const DeviceDoubles &) = delete;

    double *get() const { return data_; }

private:
    double *data_ = nullptr;
};

} // namespace

// Runs numberOfTimeSteps explicit time steps of the 1D heat equation on the GPU.
//
//   T0, T1                    host arrays holding at least numberOfPoints values each; both are
//                             uploaded before the loop and overwritten with the device results after it
//   heatDiffusionCoefficient  diffusion coefficient used to form CFL = coefficient * dt / dx^2
//   dt, dx                    time step and grid spacing
//   numberOfPoints            number of grid points; a value <= 0 makes the call a no-op
//   numberOfTimeSteps         number of time steps to run
//
// Every step first copies d_T1 into d_T0 and then computes the interior of d_T1 from d_T0.
// Consequently the contents of T1 on entry are what the first step advances from, and the
// first and last entries of T1 are carried through unchanged.
//
// Throws std::invalid_argument if T0 or T1 is shorter than numberOfPoints, and
// std::runtime_error if any CUDA runtime call or kernel launch reports an error.
void computeTemperature(std::vector<double> &T0, std::vector<double> &T1, double heatDiffusionCoefficient, double dt,
    double dx, int numberOfPoints, int numberOfTimeSteps) {

    // nothing to compute; also keeps a negative count from turning into a huge byte size below
    if (numberOfPoints <= 0) {
        return;
    }

    // the copies below read and write numberOfPoints entries of each host array
    const size_t pointCount = static_cast<size_t>(numberOfPoints);
    if (T0.size() < pointCount || T1.size() < pointCount) {
        throw std::invalid_argument("computeTemperature: T0 and T1 must each hold at least numberOfPoints values");
    }
    const size_t numberOfBytes = pointCount * sizeof(double);

    // calculate the number of threads and blocks required to launch a CUDA kernel
    // (ceiling division, written so that it cannot overflow for large numberOfPoints)
    const int threadsPerBlock = 256;
    const int numberOfBlocks = (numberOfPoints - 1) / threadsPerBlock + 1;

    // compute the CFL number, to pass less data to the kernel
    const double CFL = heatDiffusionCoefficient * dt / (dx * dx);

    // allocate memory on the GPU that mirrors the data on the CPU; released automatically on scope exit
    DeviceDoubles d_T0(pointCount);
    DeviceDoubles d_T1(pointCount);

    // copy data from the CPU to the GPU
    throwOnCudaError(cudaMemcpy(d_T0.get(), T0.data(), numberOfBytes, cudaMemcpyHostToDevice),
        "cudaMemcpy (T0 host to device)");
    throwOnCudaError(cudaMemcpy(d_T1.get(), T1.data(), numberOfBytes, cudaMemcpyHostToDevice),
        "cudaMemcpy (T1 host to device)");

    // loop over time
    for (int t = 0; t < numberOfTimeSteps; ++t) {
        // update temperature array
        updateTemperature<<<numberOfBlocks, threadsPerBlock>>>(d_T0.get(), d_T1.get(), numberOfPoints);
        throwOnCudaError(cudaGetLastError(), "updateTemperature launch");

        // compute the solution at T^n+1
        computeTemperatureAtNextTimeLevel<<<numberOfBlocks, threadsPerBlock>>>
            (d_T0.get(), d_T1.get(), CFL, numberOfPoints);
        throwOnCudaError(cudaGetLastError(), "computeTemperatureAtNextTimeLevel launch");
    }

    // copy temperature data back into the temperature arrays on the host (CPU memory);
    // these blocking copies also report any error raised while the kernels were executing
    throwOnCudaError(cudaMemcpy(T0.data(), d_T0.get(), numberOfBytes, cudaMemcpyDeviceToHost),
        "cudaMemcpy (T0 device to host)");
    throwOnCudaError(cudaMemcpy(T1.data(), d_T1.get(), numberOfBytes, cudaMemcpyDeviceToHost),
        "cudaMemcpy (T1 device to host)");
}