#include "matcher.h"

using cv::Mat;
using cv::DataType;
using namespace cv::flann;

// Pack the descriptors of `features` into a dense float matrix.
//
// The result has one row per feature (row r holds features[r].descr) and
// FEATURE_MAX_D columns. Descriptor values are copied element-wise into the
// float row, converting if descr is not already float.
Mat
CreateFeatureMatrix(const std::vector<Sifter::Feature>& features)
{
  const int rowCount = static_cast<int>(features.size());
  Mat packed(rowCount, FEATURE_MAX_D, DataType<float>::type);

  for (int row = 0; row < rowCount; ++row) {
    const Sifter::Feature& feature = features[row];
    float* out = packed.ptr<float>(row);
    std::copy(feature.descr, feature.descr + FEATURE_MAX_D, out);
  }

  return packed;
}

// Make every feature of `img` part of the query set used by later calls to
// GetCorrespondence / ComputeXform.
void
Matcher::SetPrimary(const Sifter::Image& img)
{
  // Private copy of the features; row r of _queryMatrix describes _features[r].
  _features = img.features;
  _queryMatrix = CreateFeatureMatrix(_features);
}

// Use only those features of `img` whose id appears in `featureIds` as the
// query set. Selected features keep their relative order from img.features,
// and row r of _queryMatrix describes _features[r].
void
Matcher::SetPrimary(const Sifter::Image& img, const std::set<int>& featureIds)
{
  _features.clear();

  const int total = static_cast<int>(img.features.size());
  for (int idx = 0; idx < total; ++idx) {
    const Sifter::Feature& candidate = img.features[idx];
    if (featureIds.count(candidate.id) == 0) {
      continue;  // not one of the requested ids
    }
    _features.push_back(candidate);
  }

  _queryMatrix = CreateFeatureMatrix(_features);
}

// For each primary feature (set via SetPrimary), find its single nearest
// neighbour among other.features by descriptor, using a FLANN KD-tree index.
//
// Returns one MatchTuple per primary feature, in the same order as _features:
//   first    - id of the primary feature
//   second   - id of its nearest neighbour in `other` (asserted to equal that
//              neighbour's index in other.features)
//   distance - the nearest-neighbour distance, normalized as described below
// Returns an empty list when there are no primary features or `other` has no
// features, rather than handing FLANN an empty index or query set.
Matcher::Matches
Matcher::GetCorrespondence(const Sifter::Image& other) const
{
  Matches ret;

  const int numFeatures = _queryMatrix.size().height;
  if (numFeatures == 0 || other.features.empty()) {
    return ret;
  }

  // build a KD-tree index over the other image's descriptors
  Mat featureMatrix = CreateFeatureMatrix(other.features);
  Index nn(featureMatrix, KDTreeIndexParams());

  // only looking for 1 neighbor per query row
  const int numNeighbors = 1;
  Mat indices(numFeatures, numNeighbors, DataType<int>::type);
  Mat dists(numFeatures, numNeighbors, DataType<float>::type);

  //int numChecks = 64;
  //nn.knnSearch(_queryMatrix, indices, dists, numNeighbors, SearchParams(numChecks));
  nn.knnSearch(_queryMatrix, indices, dists, numNeighbors, SearchParams());

  // compute the mean and maximum nearest-neighbour distance
  float mean = 0.f;
  float maxDist = 0.f;
  const float oneOverN = 1.f / numFeatures;  // numFeatures > 0 here
  for (int i = 0; i < numFeatures; i++) {
    const float dist = dists.at<float>(i, 0);

    maxDist = std::max(maxDist, dist);
    mean += oneOverN * dist;
  }

  // TODO(edluong): really want to throw out outliers in distance.

  // Normalize each distance: subtract the mean, divide by the maximum, then
  // shift by +0.5. If every distance is (near) zero, neither the scaling nor
  // the shift is applied, so the normalized value is just dist - mean.
  const bool hadNonZeroDist = maxDist > 0.01f;
  const float oneOverMaxDist = hadNonZeroDist ? 1.f / maxDist : 1.f;
  for (int i = 0; i < numFeatures; i++) {
    const int nnIndex = indices.at<int>(i, 0);

    MatchTuple match;
    match.first = _features[i].id;
    match.second = other.features[nnIndex].id;
    // Downstream code (ComputeXform) indexes other.features by this id.
    assert (match.second == nnIndex);

    float normDist = (dists.at<float>(i, 0) - mean) * oneOverMaxDist;
    if (hadNonZeroDist) {
      normDist += 0.5f;
    }
    match.distance = normDist;
    ret.push_back(match);
  }

  return ret;
}

// Estimate a homography mapping matched points in `other` onto the primary
// features.
//
// Every entry of `matches` is updated in place: matched is set to
// (distance < threshhold). Only matched entries contribute a point pair.
// Entry i of `matches` is assumed to correspond to _features[i], as produced
// by GetCorrespondence. With useRansac the homography is estimated with
// CV_RANSAC and the given reprojection threshold; otherwise CV_LMEDS is used.
cv::Mat
Matcher::ComputeXform(const Sifter::Image& other, Matches& matches, float threshhold, bool useRansac, float ransacThreshhold) const
{
  const int matchCount = static_cast<int>(matches.size());

  // Pass 1: flag matches under the threshold and count them.
  int numFeatures = 0;
  for (int i = 0; i < matchCount; ++i) {
    matches[i].matched = matches[i].distance < threshhold;
    numFeatures += matches[i].matched ? 1 : 0;
  }

  // Pass 2: gather (x, y) pairs for the flagged matches, one row per pair.
  Mat srcFeatures(numFeatures, 2, DataType<float>::type);
  Mat destFeatures(numFeatures, 2, DataType<float>::type);

  int featureRow = 0;
  for (int i = 0; i < matchCount; ++i) {
    const MatchTuple& match = matches[i];
    if (match.matched) {
      // match.second is used directly as an index into other.features
      // (assumes ids in `other` equal their indices).
      const Sifter::Feature& srcFeature = other.features[match.second];

      // match.first is a feature id, i.e. an index into the ORIGINAL primary
      // image's features. _features may be a subset of those, so the match's
      // position i is used to locate the primary feature instead.
      assert (match.first == _features[i].id);
      const Sifter::Feature& destFeature = _features[i];

      float* srcRow = srcFeatures.ptr<float>(featureRow);
      float* destRow = destFeatures.ptr<float>(featureRow);
      srcRow[0] = srcFeature.x;
      srcRow[1] = srcFeature.y;
      destRow[0] = destFeature.x;
      destRow[1] = destFeature.y;

      ++featureRow;
    }
  }

  assert (numFeatures == featureRow);

  if (!useRansac) {
    return findHomography(srcFeatures, destFeatures, CV_LMEDS);
  }
  return findHomography(srcFeatures, destFeatures, CV_RANSAC, ransacThreshhold);
}

// Warp `src` by the 3x3 perspective transform `xform` into `dest`.
//
// dest.m receives the warped image, with the same size as src.m.
// dest.features is replaced with one entry per feature of src.features, in
// the same order. Each entry's x/y are the transformed coordinates, and its id
// is copied from the source feature so warped points can still be related back
// to the originals. All other feature fields are value-initialized;
// descriptors are not recomputed.
void
Matcher::Warp(Sifter::Image& dest, const Sifter::Image& src, const cv::Mat& xform) const
{
  // warp the image
  cv::warpPerspective(src.m, dest.m, xform, src.m.size());

  const int numFeatures = static_cast<int>(src.features.size());

  // Build the result locally and swap it in at the end, so reading
  // src.features stays valid even if it is the same vector as dest.features.
  std::vector<Sifter::Feature> warped(numFeatures);

  // Nothing to transform for an empty point set; skip perspectiveTransform.
  if (numFeatures > 0) {
    // make featurePts matrix: one 2-channel (x, y) point per row
    cv::Mat featurePts(numFeatures, 1, DataType<cv::Vec2f>::type);
    for (int i = 0; i < numFeatures; i++) {
      const Sifter::Feature& f = src.features[i];
      featurePts.at<cv::Vec2f>(i, 0) = cv::Vec2f(f.x, f.y);
    }

    // warp the feature points
    cv::Mat xformedPts;
    cv::perspectiveTransform(featurePts, xformedPts, xform);

    // copy back to feature format, keeping each source feature's id
    for (int i = 0; i < numFeatures; i++) {
      const cv::Vec2f& p = xformedPts.at<cv::Vec2f>(i, 0);
      warped[i].id = src.features[i].id;
      warped[i].x = p[0];
      warped[i].y = p[1];
    }
  }

  dest.features.swap(warped);
}

