#include "grabcut.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>

#include <third_party/maxflow/graph.h>
#include <ml.h>

// Labels stored in the 8-bit alpha and matte images. BG and FG are the two
// segmentation labels written into the returned matte; Unknown marks pixels
// whose final label is left to the graph cut.
enum MatteLabel {
  BG      = 0,
  Unknown = 15,
  FG      = 255
};

// Number of mixture components configured for each color GMM.
static const int kNumClusters = 5;

// Clamps |box| in place so that both of its corners lie inside an image of
// size |s|: every x coordinate ends up in [0, s.width - 1] and every y
// coordinate in [0, s.height - 1]. bmax is clamped to the last valid pixel
// index, i.e. it is treated as an inclusive bound here.
//
// Both corners are clamped on both sides. Otherwise a box lying partially or
// entirely outside the image could keep a bmin beyond the far edge, or a
// negative bmax. Loops bounded by those values would then index outside the
// image. For an empty image every coordinate collapses to 0.
void
CropBox(BBox2i& box, const cv::Size& s)
{
  const int maxX = s.width - 1;
  const int maxY = s.height - 1;

  box.bmin[0] = std::max(0, std::min(maxX, box.bmin[0]));
  box.bmin[1] = std::max(0, std::min(maxY, box.bmin[1]));

  box.bmax[0] = std::max(0, std::min(maxX, box.bmax[0]));
  box.bmax[1] = std::max(0, std::min(maxY, box.bmax[1]));
}

// Trains |em_model| as a Gaussian mixture over the colors of every pixel whose
// label in |alpha| equals |which|.
//
//   em_model  model to train.
//   image     per-pixel colors, read as cv::Vec3f (three float channels).
//   alpha     8-bit label image, indexed with the same (y, x) as |image|.
//   which     label selecting the training pixels (e.g. BG or FG).
//
// The mixture uses kNumClusters components with spherical covariances. It is
// trained for at most 10 EM iterations, or until the change drops below 0.1.
void
ComputeGMM(CvEM& em_model, const cv::Mat& image, const cv::Mat& alpha, MatteLabel which)
{
  static const int kDim = 3; // 3 color channels
  const cv::Size s = image.size();

  // Pass 1: count the matching pixels so |samples| is allocated exactly once.
  // TODO(edluong): this is highly inefficient; the count could be inferred
  // from the boxes instead.
  int nsamples = 0;
  for (int y = 0; y < s.height; y++) {
    const unsigned char* labels = alpha.ptr<unsigned char>(y);
    for (int x = 0; x < s.width; x++) {
      if (labels[x] == which)
        nsamples++;
    }
  }

  // Pass 2: copy each matching pixel's color into its own row of |samples|.
  // |samples| is single-channel float (nsamples x kDim). Each row must be
  // addressed as kDim consecutive floats. Addressing it as cv::Vec3f elements
  // would use a 12-byte stride per column and write past the end of the row.
  cv::Mat samples(nsamples, kDim, CV_32FC1);
  int sampleIndex = 0;
  for (int y = 0; y < s.height; y++) {
    const unsigned char* labels = alpha.ptr<unsigned char>(y);
    const cv::Vec3f* colors = image.ptr<cv::Vec3f>(y);
    for (int x = 0; x < s.width; x++) {
      if (labels[x] != which)
        continue;
      float* row = samples.ptr<float>(sampleIndex);
      for (int d = 0; d < kDim; d++)
        row[d] = colors[x][d];
      sampleIndex++;
    }
  }
  assert(sampleIndex == nsamples);
  // NOTE(review): nsamples can be 0 (e.g. when the selection covers the whole
  // image no BG pixels exist); confirm how CvEM::train behaves in that case.

  CvEMParams params;

  // initialize model's parameters
  params.covs      = NULL;
  params.means     = NULL;
  params.weights   = NULL;
  params.probs     = NULL;
  params.nclusters = kNumClusters;
  params.cov_mat_type       = CvEM::COV_MAT_SPHERICAL;
  params.start_step         = CvEM::START_AUTO_STEP;
  params.term_crit.max_iter = 10;
  params.term_crit.epsilon  = 0.1;
  params.term_crit.type     = CV_TERMCRIT_ITER|CV_TERMCRIT_EPS;

  // Cluster the data. |params| must be passed explicitly. Otherwise train()
  // uses its default CvEMParams() argument and the configuration above is
  // silently ignored, including nclusters = kNumClusters.
  em_model.train(samples, cv::Mat(), params);
}

void
ComputeDataTerms(cv::Mat& dataTerms, const CvEM& gmm, const cv::Mat& image)
{
  cv::Size s = image.size();

  cv::Mat sample(1, 3, cv::DataType<float>::type);
  cv::Mat probs(1, kNumClusters, cv::DataType<float>::type);
  for (int y = 0; y < s.height; y++) {
    for (int x = 0; x < s.width; x++) {
      const cv::Vec3f& z = image.at<cv::Vec3f>(y, x);
      sample.at<float>(0, 0) = z[0];
      sample.at<float>(0, 1) = z[1];
      sample.at<float>(0, 2) = z[2];

      int k = cvRound(gmm.predict(sample, &probs));

      float prob = probs.at<float>(0, k);
      float pi = ((cv::Mat)gmm.get_weights()).at<float>(0, k);
      dataTerms.at<float>(y, x) = -log(prob) - log(pi);
    }
  }
}

// Converts a data cost to an integer t-link capacity in [0, cap].
// Costs >= cap, +inf and NaN all map to |cap|, and costs <= 0 map to 0, so the
// float->int conversion is always well defined. Positive costs below |cap|
// are truncated.
static int
CostToCapacity(float cost, int cap)
{
  if (!(static_cast<double>(cost) < cap))
    return cap;
  if (!(cost > 0.f))
    return 0;
  return static_cast<int>(cost);
}

// Segments |image_b| into foreground and background from the selection |box|,
// using per-label GMM color models and a single s-t min-cut.
//
//   image_b  8-bit, 3-channel image (read as cv::Vec3b).
//   box      selection rectangle in pixel coordinates (truncated to ints).
//
// A ring of kRingMargin pixels just outside the selection is fixed to BG and
// trains the BG model. Every other pixel starts as FG, trains the FG model and
// is relabelled by the cut.
//
// Returns a CV_8U matte the size of |image_b| holding BG (0) or FG (255) for
// every pixel.
cv::Mat
GrabCut(const cv::Mat& image_b, const BBox2f& box)
{
  // n-link weight between 4-connected neighbours whose current labels agree.
  static const int kSmoothnessWeight = 50;
  // Width, in pixels, of the fixed-background ring around the selection.
  static const int kRingMargin = 10;

  const cv::Size s = image_b.size();

  // Scale the 8-bit colors to floats in [0, 1].
  cv::Mat image(s, cv::DataType<cv::Vec3f>::type);
  for (int y = 0; y < s.height; y++) {
    const cv::Vec3b* src = image_b.ptr<cv::Vec3b>(y);
    cv::Vec3f* dst = image.ptr<cv::Vec3f>(y);
    for (int x = 0; x < s.width; x++) {
      const cv::Vec3b& b = src[x];
      dst[x] = cv::Vec3f(b[0] / 255.f, b[1] / 255.f, b[2] / 255.f);
    }
  }

  BBox2i selBox;
  selBox.bmin[0] = static_cast<int>(box.bmin[0]);
  selBox.bmin[1] = static_cast<int>(box.bmin[1]);
  selBox.bmax[0] = static_cast<int>(box.bmax[0]);
  selBox.bmax[1] = static_cast<int>(box.bmax[1]);

  BBox2i biggerBox;
  biggerBox.bmin[0] = selBox.bmin[0] - kRingMargin;
  biggerBox.bmin[1] = selBox.bmin[1] - kRingMargin;
  biggerBox.bmax[0] = selBox.bmax[0] + kRingMargin;
  biggerBox.bmax[1] = selBox.bmax[1] + kRingMargin;

  CropBox(selBox, s);
  CropBox(biggerBox, s);

  // alpha: current label of every pixel; it is also the returned matte.
  // matte: Unknown where the cut decides the label, BG where it is fixed.
  // NOTE(review): pixels outside biggerBox start as FG/Unknown, so they train
  // the FG model and are relabelled by the cut. Classic GrabCut treats the
  // area outside the selection as background; confirm this is intended.
  cv::Mat alpha(s, cv::DataType<unsigned char>::type);
  alpha = FG;

  cv::Mat matte(s, cv::DataType<unsigned char>::type);
  matte = Unknown;

  // Mark the ring between selBox and biggerBox as fixed background.
  for (int y = biggerBox.bmin[1]; y < biggerBox.bmax[1]; y++) {
    if (y < selBox.bmin[1] || y >= selBox.bmax[1]) {
      for (int x = biggerBox.bmin[0]; x < biggerBox.bmax[0]; x++) {
        alpha.at<unsigned char>(y, x) = BG;
        matte.at<unsigned char>(y, x) = BG;
      }
    }
    else {
      for (int x = biggerBox.bmin[0]; x < selBox.bmin[0]; x++) {
        alpha.at<unsigned char>(y, x) = BG;
        matte.at<unsigned char>(y, x) = BG;
      }
      for (int x = selBox.bmax[0]; x < biggerBox.bmax[0]; x++) {
        alpha.at<unsigned char>(y, x) = BG;
        matte.at<unsigned char>(y, x) = BG;
      }
    }
  }

  CvEM gmm[2];
  // generate GMM for bg and fg.  use alpha with labels to get the correct gmm
  ComputeGMM(gmm[0], image, alpha, BG);
  ComputeGMM(gmm[1], image, alpha, FG);

  // Min-cut graph: one node per pixel, with at most one edge to the right and
  // one edge down per node.
  const int numNodes = s.height * s.width;
  const int numEdges = 2 * numNodes;

  typedef Graph<int, int, int> GraphType;
  GraphType g(numNodes, numEdges);

  g.add_node(numNodes);

  cv::Mat dataTermsBG(s, cv::DataType<float>::type);
  cv::Mat dataTermsFG(s, cv::DataType<float>::type);

  ComputeDataTerms(dataTermsBG, gmm[0], image);
  ComputeDataTerms(dataTermsFG, gmm[1], image);

  // Capacity that pins a fixed pixel to its side of the cut. It must exceed
  // the total n-link capacity of one pixel (4 neighbours x kSmoothnessWeight).
  // Otherwise the smoothness term could drag a fixed pixel across the cut,
  // which can happen on tiny images where numNodes alone is too small.
  const int hardCapacity = std::max(numNodes, 4 * kSmoothnessWeight + 1);

  for (int y = 0; y < s.height; y++) {
    for (int x = 0; x < s.width; x++) {
      const int whichNode = y*s.width + x;
      const unsigned char myAlpha = alpha.at<unsigned char>(y, x);

      // Cost of giving this pixel the BG label / the FG label.
      int costBG = 0;
      int costFG = 0;

      if (matte.at<unsigned char>(y, x) == Unknown) {
        costBG = CostToCapacity(dataTermsBG.at<float>(y, x), hardCapacity);
        costFG = CostToCapacity(dataTermsFG.at<float>(y, x), hardCapacity);
      }
      else {
        // Fixed pixel: keeping its own label is free; any other label is
        // prohibitively expensive.
        costBG = myAlpha == BG ? 0 : hardCapacity;
        costFG = myAlpha == FG ? 0 : hardCapacity;
      }

      // The SOURCE side becomes BG and the SINK side FG (see the relabelling
      // loop below). A node ending on the SINK side has its SOURCE t-link cut,
      // so that link carries the FG cost and the SINK link the BG cost. Both
      // capacities are non-negative, which keeps the library's int flow
      // accumulator from summing large negative values.
      g.add_tweights(whichNode, costFG, costBG);

      // Smoothness term.
      // NOTE(review): standard GrabCut uses a contrast-sensitive weight
      // (gamma * exp(-beta * ||z_m - z_n||^2)) that does not depend on the
      // current labels; here the weight depends only on whether the current
      // labels agree. Confirm this is intended.
      if (x < s.width - 1) {
        const unsigned char otherAlpha = alpha.at<unsigned char>(y, x+1);
        const int weight = myAlpha == otherAlpha ? kSmoothnessWeight : 0;
        g.add_edge(whichNode, whichNode + 1, weight, weight);
      }
      if (y < s.height - 1) {
        const unsigned char otherAlpha = alpha.at<unsigned char>(y+1, x);
        const int weight = myAlpha == otherAlpha ? kSmoothnessWeight : 0;
        g.add_edge(whichNode, whichNode + s.width, weight, weight);
      }
    }
  }

  g.maxflow();

  // Only pixels whose label was left to the cut are relabelled.
  for (int y = 0; y < s.height; y++) {
    for (int x = 0; x < s.width; x++) {
      if (matte.at<unsigned char>(y, x) != Unknown)
        continue;
      const int whichNode = y*s.width + x;
      const bool onSourceSide = g.what_segment(whichNode) == GraphType::SOURCE;
      alpha.at<unsigned char>(y, x) =
          static_cast<unsigned char>(onSourceSide ? BG : FG);
    }
  }

  return alpha;
}

