#include "FEDataStructures.h"

#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

/*
 * Initialize `grid` with this rank's slab of a
 * numPoints[0] x numPoints[1] x numPoints[2] point lattice, partitioned along
 * x across the ranks of MPI_COMM_WORLD.
 *
 * Every rank except the last takes one extra x point, so adjacent slabs share
 * one plane of points. Point coordinates are normalized to [0, 1] along each
 * axis (index / (numPoints[d] - 1)); `spacing` is only stored in
 * grid->Spacing and is not used for the coordinates. grid->Extent holds
 * half-open [begin, end) point-index ranges per axis. Points are stored as
 * (x, y, z) triples with x varying slowest and z fastest.
 *
 * Any previous contents of `grid` are overwritten without being freed; call
 * FinalizeGrid first when re-initializing. If any numPoints[d] is zero, an
 * error is printed and the grid is left empty (no points, no cells, all-zero
 * extent). On allocation failure the grid is likewise left empty.
 *
 * Prints "<rank> <cellsX> <cellsY> <cellsZ>" for the local slab (debug aid).
 */
void InitializeGrid(Grid* grid, const unsigned int numPoints[3], const double spacing[3]) {
  unsigned int dim;
  double denom[3];
  int mpiSize = 1;
  int mpiRank = 0;

  grid->NumberOfPoints = 0;
  grid->Points = NULL;
  grid->NumberOfCells = 0;
  for (dim = 0; dim < 3; dim++) {
    grid->numPoints[dim] = numPoints[dim];
    grid->Spacing[dim] = spacing[dim];
    // A single-point axis only ever has index 0; avoid 0/0 (NaN) there.
    denom[dim] = numPoints[dim] > 1 ? (double)numPoints[dim] - 1.0 : 1.0;
  }
  for (dim = 0; dim < 6; dim++)
    grid->Extent[dim] = 0;

  if (numPoints[0] == 0 || numPoints[1] == 0 || numPoints[2] == 0) {
    printf("Must have a non-zero amount of points in each direction.\n");
    // The unsigned "n - 1" arithmetic below would wrap around; stay empty.
    return;
  }

  MPI_Comm_rank(MPI_COMM_WORLD, &mpiRank);
  MPI_Comm_size(MPI_COMM_WORLD, &mpiSize);

  unsigned int startXPoint = mpiRank * numPoints[0] / mpiSize;
  unsigned int endXPoint = (mpiRank + 1) * numPoints[0] / mpiSize;
  if (mpiSize != mpiRank + 1)
    endXPoint++;  // also own the first x plane of the next rank's slab

  unsigned int numXPoints = endXPoint - startXPoint;  // >= 1 for every rank
  size_t numGridPoints = (size_t)numXPoints * numPoints[1] * numPoints[2];
  printf("%d %u %u %u\n", mpiRank, numXPoints - 1, numPoints[1] - 1, numPoints[2] - 1);

  // calloc checks count * size for overflow, unlike a hand-multiplied malloc.
  grid->Points = calloc(numGridPoints, 3 * sizeof *grid->Points);
  if (grid->Points == NULL) {
    fprintf(stderr, "Failed to allocate %zu grid points.\n", numGridPoints);
    return;
  }

  grid->Extent[0] = startXPoint;  // point indices in x
  grid->Extent[1] = endXPoint;
  grid->Extent[2] = 0;            // point indices in y
  grid->Extent[3] = numPoints[1];
  grid->Extent[4] = 0;            // point indices in z
  grid->Extent[5] = numPoints[2];

  grid->NumberOfPoints = numGridPoints;
  grid->NumberOfCells = (numXPoints - 1) * (numPoints[1] - 1) * (numPoints[2] - 1);

  // create the points -- slowest in the x and fastest in the z directions
  size_t counter = 0;
  unsigned int i, j, k;
  for (i = startXPoint; i < endXPoint; i++) {
    for (j = 0; j < numPoints[1]; j++) {
      for (k = 0; k < numPoints[2]; k++) {
        grid->Points[counter]     = (double)i / denom[0];
        grid->Points[counter + 1] = (double)j / denom[1];
        grid->Points[counter + 2] = (double)k / denom[2];
        counter += 3;
      }
    }
  }
}

/*
 * Release the point array owned by `grid` and reset its point and cell
 * counts to zero. Points is left NULL, so calling this again is harmless
 * (free(NULL) is a no-op).
 */
void FinalizeGrid(Grid* grid) {
  free(grid->Points);
  grid->Points = NULL;
  grid->NumberOfCells = 0;
  grid->NumberOfPoints = 0;
}

/*
 * Bind `attributes` to `grid` and mark it as having no Density array yet
 * (Density = NULL). An existing Density pointer is overwritten, not freed.
 */
void InitializeAttributes(Attributes* attributes, Grid* grid) {
  attributes->Density = NULL;
  attributes->GridPtr = grid;
}

/*
 * Recompute the cell-centered Density field of `attributes` for `time`.
 *
 * Any previous Density array is freed and replaced by a newly allocated one
 * with GridPtr->NumberOfCells entries, filled with the x cell index varying
 * fastest, then y, then z, over the cells spanned by GridPtr->Extent
 * (half-open point-index ranges). Each cell center is normalized to [0, 1]
 * per axis, shifted along x by time / 10 with periodic wrap-around into
 * [0, 1), scaled to [-7.5, 7.5] and fed to a sum of two radial sin^2 waves
 * (one in the x-y plane, one in the y-z plane).
 *
 * If the allocation fails (or there are no cells and malloc(0) returns NULL),
 * Density is left NULL; a failure with a non-zero cell count is reported on
 * stderr.
 */
void UpdateFields(Attributes* attributes, double time) {
  unsigned int dim;
  unsigned int i, j, k;
  double x[3];
  size_t counter = 0;
  // Loop-invariant normalization denominators (cells exist only if numPoints[d] >= 2).
  const double denom[3] = {
    (double)attributes->GridPtr->numPoints[0] - 1,
    (double)attributes->GridPtr->numPoints[1] - 1,
    (double)attributes->GridPtr->numPoints[2] - 1
  };

  free(attributes->Density);
  attributes->Density = malloc(sizeof *attributes->Density * attributes->GridPtr->NumberOfCells);
  if (attributes->Density == NULL) {
    if (attributes->GridPtr->NumberOfCells > 0)
      fprintf(stderr, "Failed to allocate the density field.\n");
    return;
  }

  // "idx + 1 < end" instead of "idx < end - 1": no unsigned wrap when end == 0.
  for (k = attributes->GridPtr->Extent[4]; k + 1 < attributes->GridPtr->Extent[5]; k++) {
    for (j = attributes->GridPtr->Extent[2]; j + 1 < attributes->GridPtr->Extent[3]; j++) {
      for (i = attributes->GridPtr->Extent[0]; i + 1 < attributes->GridPtr->Extent[1]; i++) {
        x[0] = ((double)i + 0.5) / denom[0];
        x[1] = ((double)j + 0.5) / denom[1];
        x[2] = ((double)k + 0.5) / denom[2];
        x[0] -= time / 10.0;   // animation along the x-axis (no float narrowing)
        x[0] -= floor(x[0]);   // periodic BCs: wrap into [0, 1) for any time
        for (dim = 0; dim < 3; dim++)
          x[dim] = 15 * (x[dim] - 0.5);  // map [0, 1] onto [-7.5, 7.5]
        double t1 = x[1] * x[1] + x[0] * x[0];
        double t2 = x[2] * x[2] + x[1] * x[1];
        attributes->Density[counter] =
            (pow(sin(sqrt(t1)), 2) - 0.5) / pow(0.001 * t1 + 1., 2) +
            (pow(sin(sqrt(t2)), 2) - 0.5) / pow(0.001 * t2 + 1., 2) + 1;
        counter++;
      }
    }
  }
}
