#include "particlemap.h"

#include <limits.h>
#include <stdint.h>


// Allocates an empty ParticleMap with initial room for `maxParticles`
// particles; StoreParticle grows the array on demand.  Particles are stored
// 1-based (the first one goes to particles[1]), so one extra slot is reserved.
// Exits the program if memory cannot be allocated.
ParticleMap* CreateParticleMap(int maxParticles) {
   ParticleMap* map;
   int axis;

   // Keep at least one slot: StoreParticle grows the array by doubling,
   // which makes no progress from a capacity of zero.
   if (maxParticles < 1) {
      maxParticles = 1;
   }

   map = malloc(sizeof(ParticleMap));
   if (map == NULL) {
      fprintf(stderr, "Ran out of memory when creating particle map.\n");
      exit(-1);
   }
   map->storedParticles = 0;
   map->maxParticles = maxParticles;

   // Allocate space for the particles (+1 for the unused slot 0)
   map->particles = malloc(sizeof(Particle) * ((size_t)maxParticles + 1));
   if (map->particles == NULL) {
      fprintf(stderr, "Ran out of memory when creating particle map.\n");
      free(map);
      exit(-1);
   }

   // Start the bounding box inverted (min = 1e8, max = -1e8) so stored
   // particles shrink-wrap it.  NOTE(review): assumes coordinates lie
   // within +/-1e8; points outside that range do not move the box.
   for (axis = 0; axis < 3; axis++) {
      map->boundingBoxMin[axis] = 1e8f;
      map->boundingBoxMax[axis] = -1e8f;
   }

   return map;
}

// Appends a particle at `pos` and widens the map's bounding box to contain it.
// Particles are stored 1-based: the first particle lands in particles[1].
// When the array is full its capacity is doubled; if it cannot grow, the
// particle is dropped and "Particle map full" is printed (only the first time).
void StoreParticle(ParticleMap* map, const float pos[3]) {
   int i;
   Particle* node;

   if (map->storedParticles >= map->maxParticles) {
      // We've grown too big: double the capacity.  A non-positive capacity
      // restarts at one slot, since doubling zero would never make room.
      int newMax = 0;
      Particle* grown = NULL;

      if (map->maxParticles < 1) {
         newMax = 1;
      } else if (map->maxParticles <= (INT_MAX - 1) / 2) {
         newMax = 2 * map->maxParticles;
      }
      // +1 for the unused slot 0; refuse sizes whose byte count would overflow.
      if (newMax > 0 && (size_t)newMax + 1 <= SIZE_MAX / sizeof(Particle)) {
         grown = realloc(map->particles, sizeof(Particle) * ((size_t)newMax + 1));
      }
      if (grown == NULL) {
         // map->particles is untouched and still valid; drop this particle.
         static int done = 0;
         if (!done) {
            fprintf(stderr, "Particle map full\n");
         }
         done = 1;
         return;
      }

      map->particles = grown;
      map->maxParticles = newMax;
   }

   // Particles are 1-based: bump the count first, then fill that slot.
   map->storedParticles++;
   node = &map->particles[map->storedParticles];

   for (i = 0; i < 3; i++) {
      node->pos[i] = pos[i];

      // Grow the bounding box to include the new point.
      if (node->pos[i] < map->boundingBoxMin[i]) {
         map->boundingBoxMin[i] = node->pos[i];
      }
      if (node->pos[i] > map->boundingBoxMax[i]) {
         map->boundingBoxMax[i] = node->pos[i];
      }
   }
}

// Exchanges the particle pointers at positions a and b of p.
static inline void SwapParticlePointers(Particle** p, const int a, const int b) {
   Particle* const t = p[a];
   p[a] = p[b];
   p[b] = t;
}

// MedianSplit partially orders p[start..end] along `axis` (quickselect with
// Hoare-style partitioning around the last element of the current range).
// Afterwards p[median] holds the particle that would be there if the range
// were sorted by pos[axis], p[start..median-1] are all <= it and
// p[median+1..end] are all >= it (assuming no NaN coordinates).
// Only the pointer array is permuted; the particles themselves do not move.
static void MedianSplit(Particle** p, const int start, const int end, const int median, const int axis) {
   int left = start;
   int right = end;

   // Shrink [left, right] around `median` until it holds a single element.
   while (right > left) {
      const float v = p[right]->pos[axis];
      int i = left - 1;
      int j = right;
      for (;;) {
         // p[right] == v stops this scan, so i never passes right.
         while (p[++i]->pos[axis] < v) {
         }
         while (p[--j]->pos[axis] > v && j > left) {
         }

         if (i >= j) {
            break;
         }
         SwapParticlePointers(p, i, j);
      }

      // Put the pivot in its final position i, then keep only the side
      // that still contains `median`.
      SwapParticlePointers(p, i, right);
      if (i >= median) {
         right = i - 1;
      }
      if (i <= median) {
         left = i + 1;
      }
   }
}

// Recursively builds the left-balanced kd-tree for porg[start..end], following
// "Realistic Image Synthesis Using Photon Mapping" (Jensen), chapter 6.
//
// The node chosen for heap slot `index` is written to pbal[index]; its
// subtrees are written to slots 2*index and 2*index+1.  The split axis is the
// longest side of map's bounding box.  While a half is being balanced, the box
// is narrowed to that half, and it is restored afterwards.
static void BalanceSegment(ParticleMap* map, Particle** pbal, Particle** porg, const int index, const int start, const int end) {
   const int count = end - start + 1;
   int half = 1;
   int median;
   int axis;
   int k;
   float extent[3];
   float split;
   Particle* node;

   // Choose the median so the tree comes out left-balanced: `half` is the
   // largest power of two with 4*half <= count.
   while (4 * half <= count) {
      half *= 2;
   }
   if (3 * half <= count) {
      median = start + 2 * half - 1;
   } else {
      median = end - half + 1;
   }

   // Split along the widest bounding-box axis (ties go to the higher axis).
   for (k = 0; k < 3; k++) {
      extent[k] = map->boundingBoxMax[k] - map->boundingBoxMin[k];
   }
   if (extent[0] > extent[1] && extent[0] > extent[2]) {
      axis = 0;
   } else if (extent[1] > extent[2]) {
      axis = 1;
   } else {
      axis = 2;
   }

   // Bring the median into place and make it the root of this subtree.
   MedianSplit(porg, start, end, median, axis);
   node = porg[median];
   node->plane = axis;
   pbal[index] = node;
   split = node->pos[axis];

   // Lower half porg[start..median-1] becomes the left subtree.
   if (start < median) {
      if (median - start == 1) {
         pbal[2 * index] = porg[start];
      } else {
         const float savedMax = map->boundingBoxMax[axis];
         map->boundingBoxMax[axis] = split;
         BalanceSegment(map, pbal, porg, 2 * index, start, median - 1);
         map->boundingBoxMax[axis] = savedMax;
      }
   }

   // Upper half porg[median+1..end] becomes the right subtree.
   if (median < end) {
      if (end - median == 1) {
         pbal[2 * index + 1] = porg[end];
      } else {
         const float savedMin = map->boundingBoxMin[axis];
         map->boundingBoxMin[axis] = split;
         BalanceSegment(map, pbal, porg, 2 * index + 1, median + 1, end);
         map->boundingBoxMin[axis] = savedMin;
      }
   }
}

// Converts the flat, 1-based particle array of `map` into a left-balanced
// kd-tree stored as an implicit heap (node i's children live at 2i and 2i+1;
// slot 0 is unused) and returns it as a BalancedParticleMap.
//
// Ownership: `map` is consumed.  Its particle array moves into the returned
// map and the ParticleMap struct itself is freed, so callers must not use
// `map` afterwards.  Exits the program if an allocation fails.
BalancedParticleMap* BalanceParticleMap(ParticleMap* map) {
   BalancedParticleMap* bmap;

   if (map->storedParticles > 1) {
      int i;
      int d, j, tmp;
      Particle tmpParticle;
      // pa2 points at the particles in storage order and is reordered in
      // place by BalanceSegment; pa1 receives the tree in heap order.
      Particle** pa1 = malloc(sizeof(Particle*) * ((size_t)map->storedParticles + 1));
      Particle** pa2 = malloc(sizeof(Particle*) * ((size_t)map->storedParticles + 1));

      if (pa1 == NULL || pa2 == NULL) {
         fprintf(stderr, "Ran out of memory when balancing particle map.\n");
         exit(-1);
      }

      for (i = 0; i <= map->storedParticles; i++) {
         pa2[i] = &map->particles[i];
      }

      BalanceSegment(map, pa1, pa2, 1, 1, map->storedParticles);
      free(pa2);

      // Permute map->particles in place so slot j receives the particle pa1[j]
      // pointed at.  Each permutation cycle is followed from `tmp`; tmp's
      // original contents wait in tmpParticle until the cycle closes.
      // Finished slots are marked by clearing pa1[j].
      j = 1;
      tmp = 1;
      tmpParticle = map->particles[j];

      for (i = 1; i <= map->storedParticles; i++) {
         d = pa1[j] - map->particles;
         pa1[j] = NULL;
         if (d != tmp) {
            map->particles[j] = map->particles[d];
         } else {
            // Cycle closed: drop in the parked particle, then start the
            // next cycle at the first slot not yet filled.
            map->particles[j] = tmpParticle;

            if (i < map->storedParticles) {
               for (; tmp <= map->storedParticles; tmp++) {
                  if (pa1[tmp] != NULL) {
                     break;
                  }
               }
               tmpParticle = map->particles[tmp];
               j = tmp;
            }
            continue;
         }
         j = d;
      }
      free(pa1);
   }

   bmap = malloc(sizeof(BalancedParticleMap));
   if (bmap == NULL) {
      fprintf(stderr, "Ran out of memory when balancing particle map.\n");
      exit(-1);
   }
   bmap->storedParticles = map->storedParticles;
   // Same formula LoadParticleMap uses.
   bmap->halfStoredParticles = map->storedParticles / 2 - 1;
   bmap->particles = map->particles;
   free(map);

   return bmap;
}

// Recursively searches the kd-tree in map->particles for the particles nearest
// to np->pos, starting at node `index` (1 = root).  The tree is an implicit
// heap: node i has children 2*i and 2*i+1, valid indices are
// 1..storedParticles.
//
// Search state lives in *np:
//   np->dist2[0]  squared search radius; only strictly closer particles are
//                 accepted.  After np->max particles are held, each further
//                 accepted one replaces the farthest, and dist2[0] shrinks to
//                 the largest squared distance kept.
//   np->dist2[1..found] / np->index[1..found]
//                 the accepted particles; once np->gotHeap is set they form a
//                 max-heap keyed on dist2.
// np->dist2 and np->index must have room for np->max + 1 entries.
void LocateParticles(BalancedParticleMap* map, NearestParticles* const np, const int index) {
   const Particle* p;
   int hasLeft;
   int hasRight;
   float dist1;
   float dist2;

   // Ignore indices outside the tree (this also makes an empty map a no-op).
   if (index < 1 || index > map->storedParticles) {
      return;
   }
   p = &map->particles[index];

   // Which children exist, written so 2*index is never computed out of range.
   hasLeft = index <= map->storedParticles / 2;
   hasRight = index <= (map->storedParticles - 1) / 2;

   if (hasLeft) {
      dist1 = np->pos[p->plane] - p->pos[p->plane];

      // Search the side of the splitting plane containing np->pos first
      // (right if dist1 is positive, left otherwise).  Visit the other side
      // only if the plane is closer than the current search radius.
      if (dist1 > 0.0) {
         if (hasRight) {
            LocateParticles(map, np, 2 * index + 1);
         }
         if (dist1 * dist1 < np->dist2[0]) {
            LocateParticles(map, np, 2 * index);
         }
      } else {
         LocateParticles(map, np, 2 * index);
         if (hasRight && dist1 * dist1 < np->dist2[0]) {
            LocateParticles(map, np, 2 * index + 1);
         }
      }
   }

   // Compute the squared distance between current particle and np->pos
   dist1 = p->pos[0] - np->pos[0];
   dist2 = dist1 * dist1;
   dist1 = p->pos[1] - np->pos[1];
   dist2 += dist1 * dist1;
   dist1 = p->pos[2] - np->pos[2];
   dist2 += dist1 * dist1;

   if (dist2 < np->dist2[0]) {
      // Particle found.  Insert it in the list
      if (np->found < np->max) {
         // Still room: append.  dist2[0] is not tightened here.
         np->found++;
         np->dist2[np->found] = dist2;
         np->index[np->found] = p;
      } else {
         int j, parent;

         if (np->gotHeap == 0) {
            // First overflow: turn entries 1..found into a max-heap on dist2
            // by sifting down every internal node, last to first.
            float dst2;
            int k;
            const Particle* phot;
            int halfFound = np->found >> 1;
            for (k = halfFound; k >= 1; k--) {
               parent = k;
               phot = np->index[k];
               dst2 = np->dist2[k];
               while (parent <= halfFound) {
                  j = parent + parent;
                  if (j < np->found && np->dist2[j] < np->dist2[j + 1]) {
                     j++;
                  }
                  if (dst2 >= np->dist2[j]) {
                     break;
                  }
                  np->dist2[parent] = np->dist2[j];
                  np->index[parent] = np->index[j];
                  parent = j;
               }
               np->dist2[parent] = dst2;
               np->index[parent] = phot;
            }
            np->gotHeap = 1;
         }

         // Replace the farthest particle (the root) with the new one and
         // sift it down; the new root's distance becomes the search radius.
         parent = 1;
         j = 2;
         while (j <= np->found) {
            if (j < np->found && np->dist2[j] < np->dist2[j + 1]) {
               j++;
            }
            if (dist2 > np->dist2[j]) {
               break;
            }
            np->dist2[parent] = np->dist2[j];
            np->index[parent] = np->index[j];
            parent = j;
            j += j;
         }
         np->index[parent] = p;
         np->dist2[parent] = dist2;
         np->dist2[0] = np->dist2[1];
      }
   }
}

// Estimates particle density at `pos` from up to `count` nearest particles
// within `maxDistance`.  Returns found / (4/3 * pi * r^3), where r^2 is the
// final np.dist2[0].  That stays maxDistance^2 unless more than `count`
// particles fell inside the radius.  Returns 0 if nothing was found, if
// count <= 0, or if the map is empty.
float DensityEstimate(BalancedParticleMap* map, const float pos[3], const float maxDistance, const int count) {
   NearestParticles np;

   // count <= 0 would also make LocateParticles' heap path write index[1] /
   // dist2[1] past the count+1 sized buffers below.
   if (count <= 0 || map->storedParticles < 1) {
      return 0;
   }

   np.dist2 = (float*)alloca(sizeof(float) * ((size_t)count + 1));
   np.index = (const Particle**)alloca(sizeof(Particle*) * ((size_t)count + 1));

   np.pos[0] = pos[0];
   np.pos[1] = pos[1];
   np.pos[2] = pos[2];
   np.max = count;
   np.found = 0;
   np.gotHeap = 0;
   np.dist2[0] = maxDistance * maxDistance;

   // Locate the nearest particles
   LocateParticles(map, &np, 1);
   if (np.found == 0) {
      return 0;
   }

   const float volume = 4.0 / 3.0 * M_PI * pow(np.dist2[0], 3.0 / 2.0);
   return ((float)np.found) / volume;
}

// Frees a BalancedParticleMap and its particle array.  NULL is ignored, like free().
void DestroyParticleMap(BalancedParticleMap* map) {
   if (map == NULL) {
      return;
   }
   free(map->particles);
   free(map);
}

// On-disk layout: raw Particle records with no header.  Record 0 is a zeroed
// placeholder for the unused heap slot, and records 1..n are tree nodes
// 1..n, so the file is an image of particles[0..n].  Files from the previous
// implementation (which wrote particles[0..n-1]) load as the same tree
// minus its last leaf.
//
// Writes `bmap` to `filename`.  On failure an error goes to stderr and the
// in-memory map is left untouched.
void SaveParticleMap(BalancedParticleMap* bmap, char* filename) {
   static const Particle placeholder = {0};
   const size_t stored = bmap->storedParticles > 0 ? (size_t)bmap->storedParticles : 0;
   size_t count;
   FILE* file;

   printf("Saving %i particles\n", bmap->storedParticles);
   file = fopen(filename, "wb");
   if (file == NULL) {
      fprintf(stderr, "Could not open %s for writing particle map.\n", filename);
      return;
   }

   count = fwrite(&placeholder, sizeof(Particle), 1, file);
   if (stored > 0) {
      count += fwrite(&bmap->particles[1], sizeof(Particle), stored, file);
   }
   // fclose flushes; its failure means the data may not have reached disk.
   if (fclose(file) != 0 || count != stored + 1) {
      fprintf(stderr, "Failed to write particle map to %s.\n", filename);
   }
}

// Reads a map written by SaveParticleMap (layout described above) and returns
// a newly allocated BalancedParticleMap with particles[0..storedParticles]
// allocated.  Exits the program if the file cannot be opened or read, or if
// memory runs out.
BalancedParticleMap* LoadParticleMap(char* filename) {
   struct stat sbuf;
   size_t records;
   size_t count;
   FILE* file;
   BalancedParticleMap* bmap;

   file = fopen(filename, "rb");
   if (file == NULL) {
      fprintf(stderr, "Could not open particle map file %s.\n", filename);
      exit(-1);
   }
   if (stat(filename, &sbuf) != 0) {
      fprintf(stderr, "Could not stat particle map file %s.\n", filename);
      fclose(file);
      exit(-1);
   }
   records = (size_t)sbuf.st_size / sizeof(Particle);

   bmap = malloc(sizeof(BalancedParticleMap));
   if (bmap == NULL) {
      fprintf(stderr, "Ran out of memory while initializing particle map.\n");
      exit(-1);
   }
   // Always allocate at least slot 0, even for an empty file.
   bmap->particles = malloc(sizeof(Particle) * (records > 0 ? records : 1));
   if (bmap->particles == NULL) {
      fprintf(stderr, "Ran out of memory while initializing particle map.\n");
      exit(-1);
   }

   count = fread(bmap->particles, sizeof(Particle), records, file);
   fclose(file);
   if (count != records) {
      fprintf(stderr, "Could not read particle map file %s.\n", filename);
      exit(-1);
   }

   // Record 0 is the unused slot, so the file holds records - 1 nodes.
   bmap->storedParticles = records > 0 ? (int)(records - 1) : 0;
   bmap->halfStoredParticles = bmap->storedParticles / 2 - 1;

   return bmap;
}

