#include "FEDataStructures.h"

#include <mpi.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

/* Map sample index idx of n evenly spaced samples onto [0, 1].
 * With a single sample the plain idx/(n-1) would be 0/0 (NaN), so that
 * case maps to 0. */
static double UnitCoordinate(unsigned int idx, unsigned int n) {
  return n > 1 ? (double)idx / ((double)n - 1.0) : 0.0;
}

/* Allocate grid->Points and fill it with the coordinates of the x-planes
 * [startX, endX), ordered slowest in x and fastest in z (3 doubles per point).
 * Sets grid->NumberOfPoints. Returns 0 on allocation failure, 1 otherwise. */
static int BuildLocalPoints(Grid* grid, const unsigned int numPoints[3],
                            unsigned int startX, unsigned int endX) {
  const size_t count = (size_t)(endX - startX) * numPoints[1] * numPoints[2];
  grid->Points = (double*)malloc(3 * sizeof(double) * count);
  if (grid->Points == NULL && count > 0)
    return 0;
  size_t counter = 0;
  for (unsigned int i = startX; i < endX; i++) {
    for (unsigned int j = 0; j < numPoints[1]; j++) {
      for (unsigned int k = 0; k < numPoints[2]; k++) {
        grid->Points[counter]     = UnitCoordinate(i, numPoints[0]);
        grid->Points[counter + 1] = UnitCoordinate(j, numPoints[1]);
        grid->Points[counter + 2] = UnitCoordinate(k, numPoints[2]);
        counter += 3;
      }
    }
  }
  grid->NumberOfPoints = count;
  return 1;
}

/* Allocate grid->Cells and fill it with 8 point ids per hexahedron spanning
 * the numX local x-planes. Ids index this rank's local point array.
 * Sets grid->NumberOfCells. Returns 0 on allocation failure, 1 otherwise.
 * Requires numX, numPoints[1] and numPoints[2] to be >= 1. */
static int BuildLocalCells(Grid* grid, const unsigned int numPoints[3],
                           unsigned int numX) {
  const size_t count =
      (size_t)(numX - 1) * (numPoints[1] - 1) * (numPoints[2] - 1);
  grid->Cells = (int64_t*)malloc(8 * sizeof(int64_t) * count);
  if (grid->Cells == NULL && count > 0)
    return 0;
  /* 64-bit index arithmetic: a plain i * ny * nz in unsigned int would wrap
   * on large grids before being stored into the int64_t connectivity. */
  const int64_t nz = numPoints[2];
  const int64_t plane = (int64_t)numPoints[1] * nz;
  size_t counter = 0;
  for (unsigned int i = 0; i + 1 < numX; i++) {
    for (unsigned int j = 0; j + 1 < numPoints[1]; j++) {
      for (unsigned int k = 0; k + 1 < numPoints[2]; k++) {
        const int64_t base = (int64_t)i * plane + (int64_t)j * nz + k;
        int64_t* cell = grid->Cells + counter;
        cell[0] = base;
        cell[1] = base + plane;
        cell[2] = base + plane + nz;
        cell[3] = base + nz;
        cell[4] = base + 1;
        cell[5] = base + plane + 1;
        cell[6] = base + plane + nz + 1;
        cell[7] = base + nz + 1;
        counter += 8;
      }
    }
  }
  grid->NumberOfCells = count;
  return 1;
}

/* Build this rank's piece of a structured unit-cube grid of hexahedra.
 *
 * grid      - output. It is assumed uninitialized; any arrays it already holds
 *             are overwritten, not freed. Release the result with FinalizeGrid.
 * numPoints - number of points in x, y and z. All three must be non-zero.
 *             Otherwise a message is printed and grid is left empty.
 *
 * The x-direction is split evenly across MPI_COMM_WORLD. Every rank except
 * the last also takes the first x-plane of the next rank, so cells exist
 * across the partition seam. On allocation failure a message is printed and
 * grid is left empty (NULL arrays, zero counts). */
void InitializeGrid(Grid* grid, const unsigned int numPoints[3]) {
  grid->NumberOfPoints = 0;
  grid->Points = NULL;
  grid->NumberOfCells = 0;
  grid->Cells = NULL;
  for (unsigned int dim = 0; dim < 3; dim++)
    grid->numPoints[dim] = numPoints[dim];
  if (numPoints[0] == 0 || numPoints[1] == 0 || numPoints[2] == 0) {
    printf("Must have a non-zero amount of points in each direction.\n");
    return;  /* the unsigned n - 1 terms below would otherwise wrap */
  }

  // in parallel, we do a simple partitioning in the x-direction.
  int mpiSize = 1;
  int mpiRank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &mpiRank);
  MPI_Comm_size(MPI_COMM_WORLD, &mpiSize);

  /* Widen to 64 bits so rank * numPoints[0] cannot wrap. */
  const unsigned int startXPoint =
      (unsigned int)((uint64_t)mpiRank * numPoints[0] / (uint64_t)mpiSize);
  unsigned int endXPoint =
      (unsigned int)((uint64_t)(mpiRank + 1) * numPoints[0] / (uint64_t)mpiSize);
  if (mpiRank + 1 != mpiSize)
    endXPoint++;  /* include the next rank's first plane (shared seam) */

  if (!BuildLocalPoints(grid, numPoints, startXPoint, endXPoint) ||
      !BuildLocalCells(grid, numPoints, endXPoint - startXPoint)) {
    printf("Failed to allocate the grid.\n");
    free(grid->Points);
    free(grid->Cells);
    grid->Points = NULL;
    grid->Cells = NULL;
    grid->NumberOfPoints = 0;
    grid->NumberOfCells = 0;
  }
}

/* Release the point and cell arrays owned by grid and reset it to the empty
 * state (NULL arrays, zero counts). Safe to call again on an already
 * finalized grid, because free(NULL) is a no-op. */
void FinalizeGrid(Grid* grid) {
  free(grid->Points);
  grid->Points = NULL;
  free(grid->Cells);
  grid->Cells = NULL;
  grid->NumberOfCells = 0;
  grid->NumberOfPoints = 0;
}

/* Attach attributes to grid and start with no field arrays allocated.
 * UpdateFields (re)allocates Velocity and Density. */
void InitializeAttributes(Attributes* attributes, Grid* grid) {
  attributes->Velocity = NULL;
  attributes->Density = NULL;
  attributes->GridPtr = grid;
}

/* Recompute the point velocity field and the cell density field at `time`.
 *
 * Velocity (3 doubles per point) is stored component-major: all x components,
 * then all y, then all z. Its y component is the point's y coordinate times
 * `time`; the other components are 0.
 * Density (1 float per cell) is an analytic field evaluated at the cell
 * centre (mean of its 8 vertices). It is shifted along x by time/10 and
 * wrapped periodically into [0, 1) along x.
 *
 * Previous arrays are freed and replaced. If an allocation fails, a message is
 * printed, the failed field is left NULL, and the function returns early. */
void UpdateFields(Attributes* attributes, double time) {
  const Grid* grid = attributes->GridPtr;
  const size_t numPoints = grid->NumberOfPoints;
  const size_t numCells = grid->NumberOfCells;

  free(attributes->Velocity);
  attributes->Velocity = (double*)malloc(sizeof(double) * numPoints * 3);
  if (attributes->Velocity == NULL && numPoints > 0) {
    printf("Failed to allocate the velocity field.\n");
    return;
  }
  for (size_t i = 0; i < numPoints; i++) {
    attributes->Velocity[i] = 0;
    attributes->Velocity[i + numPoints] = grid->Points[i * 3 + 1] * time;
    attributes->Velocity[i + 2 * numPoints] = 0;
  }

  free(attributes->Density);
  attributes->Density = (float*)malloc(sizeof(float) * numCells);
  if (attributes->Density == NULL && numCells > 0) {
    printf("Failed to allocate the density field.\n");
    return;
  }
  for (size_t i = 0; i < numCells; i++) {
    const int64_t* cell = grid->Cells + 8 * i;
    double x[3] = {0., 0., 0.};
    for (int vertex = 0; vertex < 8; vertex++) {
      // compute x,y,z coordinates of the cell centre
      const double* p = grid->Points + 3 * (size_t)cell[vertex];
      for (int dim = 0; dim < 3; dim++)
        x[dim] += p[dim];
    }
    for (int dim = 0; dim < 3; dim++)
      x[dim] /= 8;
    x[0] -= time / 10.;  // animation along the x-axis
    /* Periodic wrap into [0, 1). A single "+= 1.0" only handles offsets
     * down to -1; subtracting floor() handles any time and gives the same
     * result in that range. */
    if (x[0] < 0.0)
      x[0] -= floor(x[0]);
    for (int dim = 0; dim < 3; dim++)
      x[dim] = 15 * (x[dim] - 0.5);
    const double t1 = x[1] * x[1] + x[0] * x[0];
    const double t2 = x[2] * x[2] + x[1] * x[1];
    attributes->Density[i] =
        (pow(sin(sqrt(t1)), 2) - 0.5) / pow((0.001 * (t1) + 1.), 2) +
        (pow(sin(sqrt(t2)), 2) - 0.5) / pow((0.001 * (t2) + 1.), 2) + 1;
  }
}
