#include "blur.h"

#include <core/flow.h>
#include <core/geom.h>
#include <core/graph.h>
#include <core/grabcut.h>
#include <gl/platform.h>
#include <gl/util.h>

// opencv
#include <cv.h>
#include <highgui.h>

// standard c++
#include <iostream>
#include <string>

// standard c
#include <stdio.h>

// Selects DrawImage()'s textured-quad path (1) instead of drawing one quad
// per pixel via DrawImagePixels() (0).
#define USE_TEXTURE 0

using namespace std;
using cv::Mat; 

// The single viewer instance: allocated in main() and used by the free GLUT
// callback functions (Display, KeyDown, KeyUp, ...).
static MyViewer* view;

// Hue produced by a heat value of 1 in HeatToHSV (hue = value * kMaxHeatHue).
static const float kMaxHeatHue = 0.6;
// Default change-mask tolerance, as a fraction of the 0..255 channel range
// (ComputeChangeMask multiplies it by 255).
static const float kMatteThreshhold = 0.03;
// Half-width of the square neighbourhood searched by ComputeChangeMask
// (window is 2*kMaskWindowHalf+1 pixels on a side).
static const int kMaskWindowHalf = 1;

// Maps a heat value (presumably in [0, 1]) to a fully saturated,
// full-brightness HSV color whose hue grows linearly up to kMaxHeatHue.
Vec3f
HeatToHSV(float value)
{
  const float hue = kMaxHeatHue * value;
  return Vec3f(hue, 1.f, 1.f);
}

// HSV -> RGB conversion after Alvy Ray Smith,
// http://www.alvyray.com/Papers/hsv2rgb.htm
// All components are expected in [0, 1]; hue 1.0 wraps to red.  A hue that
// lands outside sectors 0..6 yields black.
Vec3f
HSVtoRGB(const Vec3f& hsv) 
{
  const float scaledHue = hsv[0] * 6.f;
  const float sat = hsv[1];
  const float val = hsv[2];

  const int sector = (int)scaledHue;
  float frac = scaledHue - sector;
  // Even sectors ramp the secondary channel down instead of up.
  if (sector % 2 == 0) {
    frac = 1.f - frac;
  }

  const float lo = val * (1 - sat);
  const float mid = val * (1 - sat * frac);

  if (sector == 0 || sector == 6) return Vec3f(val, mid, lo);
  if (sector == 1) return Vec3f(mid, val, lo);
  if (sector == 2) return Vec3f(lo, val, mid);
  if (sector == 3) return Vec3f(lo, mid, val);
  if (sector == 4) return Vec3f(mid, lo, val);
  if (sector == 5) return Vec3f(val, lo, mid);
  return Vec3f(0.f, 0.f, 0.f);
}

// Creates the viewer for GLUT window `windowId` of the given pixel size.
// The view starts panned to the window centre at zoom 1, with no current
// image (-1), nothing aligned, image/features/unmatched lines shown, and the
// Single view mode.
MyViewer::MyViewer(
    int windowId, 
    int windowWidth, int windowHeight) :
  _windowId(windowId),
  _canPan(false),
  _windowPan(windowWidth/2.f, windowHeight/2.f),
  _zoom(1.f), _lastZoom(_zoom), _lastZoomR(1.f),
  _selectStart(0.f, 0.f),
  _action(None),
  _lastPos(0.f, 0.f),
  _currImg(-1),
  _matteThreshhold(kMatteThreshhold),
  _showXformed(false),
  _aligned(false),
  _showImage(true),
  _showFeatures(true),
  _showUnmatched(true),
  _viewMode(Single)
{
  // NOTE(review): presumably builds the lookup table behind ByteToFloat() —
  // confirm in the header.
  InitByteToFloatLUT();

  _windowSize[0] = windowWidth;
  _windowSize[1] = windowHeight;

}

// Records the window dimensions reported by the GLUT reshape callback;
// Draw() derives its orthographic projection extent from them.
void
MyViewer::SetWindowSize(int windowWidth, int windowHeight)
{
  const int dims[2] = { windowWidth, windowHeight };
  for (int axis = 0; axis < 2; axis++) {
    _windowSize[axis] = dims[axis];
  }
}

void
MyViewer::Draw() const
{
  glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

  glMatrixMode(GL_PROJECTION);
  glLoadIdentity();

  float halfView[2] = { 
    _windowSize[0] / (2.f * _zoom), 
    _windowSize[1] / (2.f * _zoom)
  };

  glOrtho(- halfView[0], halfView[0],
    - halfView[1], halfView[1],
    -1.f, 1.f);

  if (_viewMode == Original) {
    int width = _original.size().width;
    int height = _original.size().height;
    const unsigned char* data = _original.data;

    assert (_original.elemSize() == 3);

    glRasterPos2f(0.f, 0.f);
    glPixelZoom(_zoom, -_zoom);
    glBitmap(0, 0, 0.f, 0.f, -_windowPan[0]*_zoom, -_windowPan[1]*_zoom, NULL);
    glDrawPixels(width, height, GL_BGR, GL_UNSIGNED_BYTE, data);
    return;
  }

  if (_viewMode == Blurred) {
    int width = _blurredImage.size().width;
    int height = _blurredImage.size().height;
    const unsigned char* data = _blurredImage.data;

    assert (_blurredImage.elemSize() == 3);

    glRasterPos2f(0.f, 0.f);
    glPixelZoom(_zoom, -_zoom);
    glBitmap(0, 0, 0.f, 0.f, -_windowPan[0]*_zoom, -_windowPan[1]*_zoom, NULL);
    glDrawPixels(width, height, GL_BGR, GL_UNSIGNED_BYTE, data);
    return;
  }

  if (_viewMode == Final) {
    int width = _final.size().width;
    int height = _final.size().height;
    const unsigned char* data = _final.data;

    assert (_final.elemSize() == 3);

    glRasterPos2f(0.f, 0.f);
    glPixelZoom(_zoom, -_zoom);
    glBitmap(0, 0, 0.f, 0.f, -_windowPan[0]*_zoom, -_windowPan[1]*_zoom, NULL);
    glDrawPixels(width, height, GL_BGR, GL_UNSIGNED_BYTE, data);
    return;
  }

  glMatrixMode(GL_MODELVIEW);
  glLoadIdentity();

  glTranslatef(-_windowPan[0], -_windowPan[1], 0.f);

  if (_viewMode == Alpha) {
    DrawImagePixels(_matte, 1.f);
    return;
  }

  if (_currImg < 0 || _currImg >= (int)_images.size()) return;

  // Draw image
  bool overlay = _viewMode == Overlay && _currImg > 0;
  if (overlay) {
    // note the blendfunc
    if (_showImage) {
      DrawImage(0, false, 0.5f, false);
    }

    if (_showFeatures) {
      DrawFeatures(0, false);
    }

    glBlendFunc(GL_DST_ALPHA, GL_DST_ALPHA);
    glEnable(GL_BLEND);
  }

  bool showXformed = _showXformed && _images[_currImg].xform_tex;
  if (_showImage) {
    DrawImage(_currImg, showXformed);
  }
  if (_showFeatures) {
    DrawFeatures(_currImg, showXformed);
  }

  if (overlay) {
    glDisable(GL_BLEND);

    DrawMatches(showXformed);
  }

  bool sideBySide = _viewMode == SideBySide && _currImg > 0;
  if (sideBySide) {
    DrawSideImage(showXformed);
  }

  bool showHeatmap = sideBySide || overlay;
  if (showHeatmap) {
    glPushMatrix();
    {
      static const float kHeatmapHeight = 25.f;
      cv::Size s = _images[_currImg].orig_img.m.size();

      float xOffset = sideBySide ? -0.5f * s.width : 0.f;
      glTranslatef(xOffset, -(kHeatmapHeight + 40.f), 0.f);
      DrawHeatmap(s.width, kHeatmapHeight, 2*kHeatmapIntervals);
    }
    glPopMatrix();
  }

  if (_action == Select || _action == Matte) {
    DrawSelectBox();
  }

}

void
MyViewer::DrawMatches(bool otherXformed) const
{
  glPushMatrix();
  {
    // translate to top left corner of image, and flip y
    glTranslatef(.5f, _images[0].orig_img.m.size().height - 0.5f, -0.4f);
    glScalef(1.f, -1.f, 1.f);

    bool drawUnmatched = _showUnmatched;
    if (drawUnmatched) {
      glLineStipple (3, 0xAAAA);
      glEnable(GL_LINE_STIPPLE);

      glColor4f(1.f, .5f, .5f, .5f);
      DrawMatches(otherXformed, false);

      glDisable(GL_LINE_STIPPLE);
    }

    glColor4f(1.f, 0.f, 1.f, 1.f);
    DrawMatches(otherXformed, true);

  }
  glPopMatrix();
}

// Draws 0th features at origin, and other image to the right
// only draws the matches, or the non-matches
void
MyViewer::DrawMatches(bool otherXformed, bool drawMatches) const
{
  const std::vector<Sifter::Feature>& baseFeatures = _images[0].orig_img.features;
  const std::vector<Sifter::Feature>& otherFeatures = otherXformed ?
      _images[_currImg].xform_img.features : _images[_currImg].orig_img.features;
  const Matcher::Matches& matches = _images[_currImg].matches;

  float xOffset = _viewMode == SideBySide ? _images[0].orig_img.m.size().width : 0.f;

  glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA);
  glEnable(GL_BLEND);
  glEnable(GL_LINE_SMOOTH);
  glBegin(GL_LINES);
  for (int i = 0; i < (int)matches.size(); i++) {
    const Matcher::MatchTuple& match = matches[i];

    const Sifter::Feature& p = baseFeatures[match.first];

    if (drawMatches != match.matched) continue;
    if (match.second >= (int)otherFeatures.size()) continue;

    const Sifter::Feature& q = otherFeatures[match.second];

    Vec3f hsv = HeatToHSV(match.distance);
    Vec3f rgb = HSVtoRGB(hsv);
    glColor3fv(rgb.getValue());
    glVertex2f(p.x, p.y);
    glVertex2f(q.x + xOffset, q.y);
  }
  glEnd();
  glDisable(GL_LINE_SMOOTH);
  glDisable(GL_BLEND);
}

void
MyViewer::DrawSideImage(bool otherXformed) const
{
  glPushMatrix();
  {
    glTranslatef(-_images[0].orig_img.m.size().width, 0.f, 0.f);
    if (_showImage) {
      DrawImage(0, false);
    }

    bool showMatches = _showFeatures;
    if (showMatches) {
      DrawFeatures(0, false);
      DrawMatches(otherXformed);
    }
  }
  glPopMatrix();
}

// Draws image `which` at the origin of the current modelview frame.  It uses
// the warped copy when showXformed, and the given alpha.  When drawText is
// set, a label is drawn one line below the origin.  The label is the image
// index plus, when showing the warped copy, "blended" for image 0 or
// "(xformed)" otherwise.
void
MyViewer::DrawImage(int which, bool showXformed, float alpha, bool drawText) const
{

#if USE_TEXTURE
  glEnable(GL_TEXTURE_2D);

  glPushMatrix();
  {
    glTranslatef(0.f, 0.f, -0.2f);

    int width = _images[which].orig_img.m.size().width;
    int height = _images[which].orig_img.m.size().height;

    glColor4f(1.f, 1.f, 1.f, alpha);
    // Honor the caller's showXformed (exactly as the pixel path below does).
    // Fall back to the original texture when no warped one exists yet.
    bool useXformTex = showXformed && _images[which].xform_tex;
    GLuint whichTex = useXformTex ? _images[which].xform_tex : _images[which].orig_tex;
    glBindTexture(GL_TEXTURE_2D, whichTex);

    glBegin(GL_QUADS);
      glTexCoord2f(0.f, 1.f); glVertex2f(0.5f, 0.5f);
      glTexCoord2f(1.f, 1.f); glVertex2f(width - 0.5f, 0.5f);
      glTexCoord2f(1.f, 0.f); glVertex2f(width - 0.5f, height - 0.5f);
      glTexCoord2f(0.f, 0.f); glVertex2f(0.5f, height - 0.5f);
    glEnd();
  }
  glPopMatrix();
  glDisable(GL_TEXTURE_2D);
#else
  const Mat& imgData = showXformed ? _images[which].xform_img.m : _images[which].orig_img.m; 

  DrawImagePixels(imgData, alpha);
#endif

  if (drawText) {
    glColor4f(1.f, 1.f, 1.f, 1.f);
    char buf[255];
    const char* type;
    if (showXformed) {
      type = which == 0 ? "blended" : "(xformed)";
    }
    else {
      type = "";
    }
    snprintf(buf, 255, "%d %s", which, type);

    const float kFontHeight = 10.f;
    const float kLineHeight = 1.5f*kFontHeight;
    glDrawText(0.f, -kLineHeight, buf, kFontHeight);
  }
}

void
MyViewer::DrawImagePixels(const cv::Mat& imgData, float alpha) const
{
  // TODO(edluong): use glBitmap + glDrawPixels
  int width = imgData.size().width;
  int height = imgData.size().height;

  glPushMatrix();
  {
    glTranslatef(0.f, imgData.size().height, -0.2f);
    glScalef(1.f, -1.f, 1.f);

    glColor4f(1.f, 1.f, 1.f, 1.f);

    if (imgData.type() == CV_8UC3) {
      glBegin(GL_QUADS);
      for (int y = 0; y < height; y++) {

        const unsigned char* data = imgData.ptr<unsigned char>(y);

        for (int x = 0; x < width; x++) {
          float b = ByteToFloat(data[0]);
          float g = ByteToFloat(data[1]);
          float r = ByteToFloat(data[2]);
          glColor4f(r, g, b, alpha);

          glVertex2f(x, y);
          glVertex2f(x+1.f, y);
          glVertex2f(x+1.f, y+1.f);
          glVertex2f(x, y+1.f);

          data += 3;
        }
      }
      glEnd();
    }
    else if (imgData.type() == CV_8U) {
      //glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA);
      //glEnable(GL_BLEND);
      glBegin(GL_QUADS);
      for (int y = 0; y < height; y++) {

        const unsigned char* data = imgData.ptr<unsigned char>(y);

        for (int x = 0; x < width; x++) {
          // these are actually encoded as binary for me
          float a = ByteToFloat(data[0]);
          //glColor4f(1.f, 0.f, 0.f, a);
          glColor4f(a, 0.f, 0.f, 1.f);

          glVertex2f(x, y);
          glVertex2f(x+1.f, y);
          glVertex2f(x+1.f, y+1.f);
          glVertex2f(x, y+1.f);

          data++;
        }
      }
      glEnd();
      //glDisable(GL_BLEND);
    }
  }
  glPopMatrix();
}

void
MyViewer::DrawFeatures(int which, bool showXformed) const
{
  glEnable(GL_POINT_SMOOTH);

  glPushMatrix();
  {
    // translate to top left corner of image, and flip y
    glTranslatef(.5f, _images[which].orig_img.m.size().height - 0.5f, -0.4f);
    glScalef(1.f, -1.f, 1.f);

    static const float kPointSize = 1.f;
    glPointSize(kPointSize * _zoom);
    glBegin(GL_POINTS);

    static const Vec3f kSelectedFeatureColor(0.f, .4f, 1.f);
    static const Vec3f kUnSelectedFeatureColor(1.f, 0.f, 1.f);

    const Sifter::Image& img = showXformed ? _images[which].xform_img : _images[which].orig_img;
    const vector<Sifter::Feature>& features = img.features;
    for (int i = 0; i < (int)features.size(); i++) {
      const Sifter::Feature& f = features[i];

      bool selected = which == 0 && _selFeatureIds.find(f.id) != _selFeatureIds.end();

      glColor4fv((selected ? kSelectedFeatureColor : kUnSelectedFeatureColor).getValue());

      glVertex2f(f.x, f.y);
    }
    glEnd();
  }
  glPopMatrix();

  glDisable(GL_POINT_SMOOTH);
}

// Returns the world-space rectangle spanned by the drag start
// (_selectStart) and the most recent mouse position (_lastPos).
BBox2f
MyViewer::GetSelectBox() const
{
  BBox2f box;
  box.AddPoint(screenToWorld(_selectStart));
  box.AddPoint(screenToWorld(_lastPos));
  return box;
}

void
MyViewer::DrawSelectBox() const
{
  // Could easily draw just directly on screen but whatever.

  BBox2f selBox = GetSelectBox();

  glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA);
  glEnable(GL_BLEND);

  glColor4f(1.f, .4f, 0.f, .3f);
  glBoxf(selBox, true);

  glColor4f(1.f, .4f, 0.f, 1.f);
  glBoxf(selBox, false);

  glDisable(GL_BLEND);
}

// Draws a horizontal legend of `numIntervals` equal boxes spanning
// [0, width] x [0, height] in the current frame.  Each box is filled with
// the heat color of its interval's midpoint value and labelled with that
// value ("%.2f").  Nothing is drawn when numIntervals <= 0.
void
MyViewer::DrawHeatmap(float width, float height, int numIntervals) const
{
  // Guards the divisions below.
  if (numIntervals <= 0) return;

  const float boxWidth = width / numIntervals;
  const float boxHeight = height;
  const float valueDelta = 1.f / numIntervals;

  // Label metrics depend only on the legend height; compute them once.
  const float fontHeight = height/4.f;
  const float textY = 0.5f*(boxHeight - fontHeight);

  float value = valueDelta / 2.f;
  float x = 0.f;
  for (int i = 0; i < numIntervals; i++) {
    BBox2f intervalBox;
    intervalBox.AddPoint(Vec2f(x, 0.f));
    intervalBox.AddPoint(Vec2f(x+boxWidth, boxHeight));

    Vec3f hsv = HeatToHSV(value);
    Vec3f rgb = HSVtoRGB(hsv);
    glColor3fv(rgb.getValue());
    glBoxf(intervalBox, true);

    glColor3f(0.f, 0.f, 0.f);
    char valueStr[255];
    snprintf(valueStr, 255, "%.2f", value);

    // Centre the label horizontally within its box.
    float textWidth = ComputeTextStringWidth(valueStr, fontHeight);
    float textX = x + 0.5f*(boxWidth - textWidth);

    glDrawText(textX, textY, valueStr, fontHeight);

    x += boxWidth;
    value += valueDelta;
  }
}

void
MyViewer::PickFeatures(int which, const BBox2f& box)
{
  assert (which >= 0 && which < (int)_images.size());

  const std::vector<Sifter::Feature>& features = _images[which].orig_img.features;
  for (int i = 0; i < (int)features.size(); i++) {
    const Sifter::Feature& f = features[i];

    if (box.Inside(Vec2f(f.x, f.y))) {
      _selFeatureIds.insert(f.id);
    }
  }
}

// Keyboard handler.  Space is a momentary modifier: panning is enabled only
// while it is held.  Every other key acts once, on key-down; the view-mode
// keys toggle between their mode and the default Single view.
void
MyViewer::KeyPress(unsigned char key, bool down, int x, int y)
{
  if (key == ' ') {
    _canPan = down;
    return;
  }

  if (!down) return;

  switch (key) {
    case 9: // tab: toggle showing the warped copies
      _showXformed = !_showXformed;
      break;
    case '`':
      SetCurrent(0);
      break;
    case '1': case '2': case '3': case '4': case '5':
    case '6': case '7': case '8': case '9':
      SetCurrent(key - '0');
      break;

    // Actions.
    case 'b':
      if (_aligned) BlendImages();
      break;
    case 'c':
      ClearBlur(); // undo the alignment
      break;
    case 'm': case 'M':
      AlignImages(key == 'm'); // lowercase: RANSAC
      break;
    case 'q':
      glutDestroyWindow(_windowId); 
      exit(0);
      break;

    // Display toggles.
    case 'f': _showFeatures = !_showFeatures; break;
    case 'h': _showUnmatched = !_showUnmatched; break;
    case 'i': _showImage = !_showImage; break;

    // View-mode toggles.
    case 'a': _viewMode = (_viewMode == Alpha) ? Single : Alpha; break;
    case 'e': _viewMode = (_viewMode == Original) ? Single : Original; break;
    case 'o': _viewMode = (_viewMode == Overlay) ? Single : Overlay; break;
    case 'r': _viewMode = (_viewMode == Blurred) ? Single : Blurred; break;
    case 's': _viewMode = (_viewMode == SideBySide) ? Single : SideBySide; break;
    case 't': _viewMode = (_viewMode == Final) ? Single : Final; break;

    default:
      break;
  }
}

// Special-key handler: the left/right arrows step to the previous/next image.
void
MyViewer::SpecialKeyPress(int key, int x, int y)
{
  if (key == GLUT_KEY_LEFT) {
    SetCurrent(_currImg - 1);
  }
  else if (key == GLUT_KEY_RIGHT) {
    SetCurrent(_currImg + 1);
  }

  // Disabled: up/down would nudge the matte threshold in Alpha mode and
  // recompute the matte.
  /*
  if ((key == GLUT_KEY_UP || key == GLUT_KEY_DOWN) && _viewMode == Alpha) {
    float delta = 0.2;
    if (key == GLUT_KEY_DOWN) delta *= -1.f;

    _matteThreshhold += delta;
    _matte = ComputeChangeMask(_original, _blurredImage, _matteThreshhold);
  }
  */
}

void
MyViewer::MouseButton(int button, int state, int x, int y)
{
  int mod = glutGetModifiers();

  if (state == GLUT_UP) {
    switch (_action) {
      case Zoom:
        _lastZoom = _zoom;
        break;
      case Select:
        if (_currImg >= 0) {
          BBox2f selBox = GetSelectBox();

          BBox2f imgBox; // convert to image space
          imgBox.AddPoint(worldToImage(selBox.bmin, _currImg));
          imgBox.AddPoint(worldToImage(selBox.bmax, _currImg));

          // TODO(edluong): make a "selected features" set for each image
          PickFeatures(_currImg, imgBox);
        }
        break;
        /*
      case Matte:
        if (_currImg >= 0) {
          BBox2f selBox = GetSelectBox();

          BBox2f imgBox; // convert to image space
          imgBox.AddPoint(worldToImage(selBox.bmin, _currImg));
          imgBox.AddPoint(worldToImage(selBox.bmax, _currImg));

          const cv::Mat& img = _images[0].orig_img.m;
          _matte = GrabCut(img, imgBox);
        }
        break;
        */
      default: break;
    }

    _action = None;
    return;
  }

  if (button == GLUT_LEFT_BUTTON) {
    if (_canPan) {
      _action = Pan;
    }
    else {
      if (!(mod & GLUT_ACTIVE_SHIFT)) {
        _selFeatureIds.clear();
      }

      _selectStart.p[0] = x;
      _selectStart.p[1] = y;
      _action = Select;
    }
  }
  else if (button == GLUT_RIGHT_BUTTON) {
    if (mod & GLUT_ACTIVE_ALT) {
      _lastZoomR = CenterDistance(x, y);
      _lastZoom = _zoom;
      _action = Zoom;
    }
  }
}

// Pointer-motion handler.  It applies the active drag and always records the
// position for the next event.  A pan drag moves the view by the pointer
// delta in world units; a zoom drag rescales relative to the drag start.
// Returns true while a drag is active, i.e. when a redraw is needed.
bool
MyViewer::MouseMotion(int x, int y)
{
  // Screen y grows downward while world y grows upward, hence the negation.
  const float dx = x - _lastPos[0];
  const float dy = -(y - _lastPos[1]);

  if (_action == Pan) {
    _windowPan[0] -= dx / _zoom;
    _windowPan[1] -= dy / _zoom;
  }
  else if (_action == Zoom) {
    // Zoom follows the ratio of CenterDistance now vs. at drag start.
    _zoom = _lastZoom * (CenterDistance(x, y) / _lastZoomR);
  }

  _lastPos[0] = x;
  _lastPos[1] = y;

  return _action != None;
}

// Uploads a 3-byte-per-pixel (BGR) cv::Mat as a new GL_TEXTURE_2D with
// automatically generated mipmaps.  Returns the new texture name; it is left
// bound to GL_TEXTURE_2D.
GLuint
glLoadCvMat(const Mat& m)
{
  assert (m.elemSize() == 3); // 3 bytes per pixel

  GLuint tex;
  glGenTextures(1, &tex);

  // initialize gl texture
  glBindTexture(GL_TEXTURE_2D, tex);
  glTexEnvf(GL_TEXTURE_ENV, GL_TEXTURE_ENV_MODE, GL_MODULATE);
  glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR_MIPMAP_LINEAR);
  glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
  glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP);
  glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP);

  int imgWidth = m.size().width;
  int imgHeight = m.size().height;
  GLenum imgFmt = GL_BGR; // OpenCV stores color pixels as B,G,R
  const unsigned char* imgData = m.data;

  // GL assumes 4-byte-aligned rows by default, but OpenCV rows are m.step
  // bytes long (width*3 when continuous).  Without describing the real
  // layout, GL reads past the end of the buffer whenever width*3 % 4 != 0.
  // That is a likely cause of the segfault previously observed here.
  glPushClientAttrib(GL_CLIENT_PIXEL_STORE_BIT);
  glPixelStorei(GL_UNPACK_ALIGNMENT, 1);
  glPixelStorei(GL_UNPACK_ROW_LENGTH, (GLint)(m.step1() / m.channels()));

  // Must be set before the upload so the mip chain is built from it.
  glTexParameteri(GL_TEXTURE_2D, GL_GENERATE_MIPMAP_SGIS, GL_TRUE);
  glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB, imgWidth, imgHeight, 0, imgFmt, GL_UNSIGNED_BYTE, imgData);

  glPopClientAttrib();

  return tex;
}


// Appends `img` to the image list and makes image 0 current.  When
// USE_TEXTURE is enabled the image is also uploaded as a GL texture.
// Otherwise orig_tex is set to the non-zero placeholder 1.
void
MyViewer::AddImage(const Sifter::Image& img)
{
  ImageRecord imgRec;
  imgRec.orig_img = img;

  // Test the same macro DrawImage() does.  The previous spelling,
  // USE_TEXTURES, is never defined, so the upload could never be enabled.
#if USE_TEXTURE
  imgRec.orig_tex = glLoadCvMat(img.m);
#else
  imgRec.orig_tex = 1;
#endif

  _images.push_back(imgRec);
  SetCurrent(0);
}

// Aligns every image i >= 1 to the reference image 0, then blends the
// results via BlendImages().  For each image it:
//  - matches features against image 0 (only the selected image-0 features
//    when any are selected);
//  - estimates a transform, using RANSAC when useRansac and LMedS otherwise
//    (per the log message);
//  - warps the image into xform_img.
// Does nothing if already aligned (ClearBlur() resets) or if there are no
// images.
void
MyViewer::AlignImages(bool useRansac)
{
  if (_aligned) return;
  // Image 0 is the reference; nothing can be aligned without it.
  if (_images.empty()) return;

  Matcher matcher;
  if (_selFeatureIds.empty()) {
    matcher.SetPrimary(_images[0].orig_img);
  }
  else {
    matcher.SetPrimary(_images[0].orig_img, _selFeatureIds);
  }

  cout << "Aligning images (" << (useRansac ? "ransac" : "lmeds") << ")";
  cout.flush();

  for (size_t i = 1; i < _images.size(); i++) {

    ImageRecord& imgRec = _images[i];
    imgRec.matches = matcher.GetCorrespondence(imgRec.orig_img);

    // NOTE(edluong): imperically this looks OK.  A better approach is to probably
    // compute the variance and threshhold based on that.  That will be
    // more of a "take some percentage" of the small matches
    float threshhold = 0.4f; 
    imgRec.xform = matcher.ComputeXform(imgRec.orig_img, imgRec.matches, threshhold, useRansac);

    matcher.Warp(imgRec.xform_img, imgRec.orig_img, imgRec.xform);

    // Same macro as DrawImage()/AddImage(); "USE_TEXTURES" is never defined.
#if USE_TEXTURE
    imgRec.xform_tex = glLoadCvMat(imgRec.xform_img.m);
#else
    imgRec.xform_tex = 1;
#endif
    cout << ".";
    cout.flush();
  }
  cout << endl;

  _aligned = true;
  _showXformed = true;

  BlendImages();
}

// Averages the warped copies of images 1..N-1 into image 0's xform slot
// (accumulating in double precision).  It then resizes that average to the
// original's resolution (_blurredImage), recomputes the change matte against
// _original, and rebuilds _final.  No-op until AlignImages() has run or when
// fewer than two images are loaded.
void
MyViewer::BlendImages()
{
  if (!_aligned) return;
  // With only the reference image, the averaging factor would be 1/0.
  if (_images.size() < 2) return;

  const float factor = 1.f / (float)(_images.size() - 1);

  cv::Size s = _images[0].orig_img.m.size();

  // Accumulate in Vec3d so summing 8-bit images cannot overflow.
  int matType = cv::DataType<cv::Vec3d>::type;
  cv::Mat blended = cv::Mat::zeros(s, matType);
  for (size_t i = 1; i < _images.size(); i++) {
    ImageRecord& imgRec = _images[i];
    cv::Mat m; // convertTo allocates it with the right size and type
    imgRec.xform_img.m.convertTo(m, matType);

    blended += m;
  }
  blended *= factor;

  int blendedType = cv::DataType<cv::Vec3b>::type;
  _images[0].xform_img.m = cv::Scalar_<unsigned char>(0, 0, 0);
  blended.convertTo(_images[0].xform_img.m, blendedType);
  _images[0].xform_tex = 1;

  // copy to final.  just resize instead of upsampling
  cv::Size origSize = _original.size();
  cv::resize(_images[0].xform_img.m, _blurredImage, origSize, 0, 0, cv::INTER_CUBIC);

  _matte = ComputeChangeMask(_original, _blurredImage, _matteThreshhold);

  MakeFinal(_final, _original, _blurredImage, _matte);
}

// True when every channel of `a` and `b` differs by strictly less than
// `threshhold` (in 0..255 units).
bool
IsCloseColor(const cv::Vec3b& a, const cv::Vec3b& b, float threshhold)
{
  for (int c = 0; c < 3; ++c) {
    // Widen to int before subtracting so the difference keeps its sign.
    const int delta = (int)a[c] - (int)b[c];
    const int magnitude = delta < 0 ? -delta : delta;
    if (magnitude >= threshhold) {
      return false;
    }
  }
  return true;
}

// Builds an 8-bit mask the size of `ref`.  A pixel is 255 when some pixel of
// `changed` within the (2*kMaskWindowHalf+1)^2 neighbourhood matches it on
// every channel to within `threshhold` (a fraction of 255); otherwise it is 0.
// The mask is then Gaussian-blurred (15x15) to soften its edges.  Both
// images are read as cv::Vec3b.  If `changed` differs in size from `ref` it
// is first resized to match, so neighbourhood lookups stay in bounds.
cv::Mat
MyViewer::ComputeChangeMask(
    const cv::Mat& ref,
    const cv::Mat& changed,
    float threshhold)
{
  threshhold *= 255;

  fprintf(stderr, "computing mask.\n");
  cv::Size refSize = ref.size();
  cv::Mat ret = cv::Mat::zeros(refSize, cv::DataType<unsigned char>::type);

  // `changed` is indexed with ref's coordinates below.
  cv::Mat resized;
  const cv::Mat* changedPtr = &changed;
  if (changed.size().width != refSize.width || changed.size().height != refSize.height) {
    cv::resize(changed, resized, refSize);
    changedPtr = &resized;
  }
  const cv::Mat& b = *changedPtr;

  for (int y = 0; y < refSize.height; y++) {
    for (int x = 0; x < refSize.width; x++) {
      bool maskSet = false;

      const cv::Vec3b& ref_rgb = ref.at<cv::Vec3b>(y, x);

      for (int dy = -kMaskWindowHalf; dy <= kMaskWindowHalf; dy++) {
        int wy = y + dy;
        if (wy < 0 || wy >= refSize.height) continue;

        for (int dx = -kMaskWindowHalf; dx <= kMaskWindowHalf; dx++) {
          int wx = x + dx;
          if (wx < 0 || wx >= refSize.width) continue;

          const cv::Vec3b& w_rgb = b.at<cv::Vec3b>(wy, wx);

          if (IsCloseColor(ref_rgb, w_rgb, threshhold)) {
            // One close neighbour is enough; stop searching this pixel.
            ret.at<unsigned char>(y, x) = 255;
            maskSet = true;
            break;
          }
        }
        if (maskSet) break;
      }
    }
  }

  cv::GaussianBlur(ret, ret, cv::Size2i(15, 15), 0, 0);

  fprintf(stderr, "done mask.\n");
  return ret;
}

// Per-channel blend a*v1 + (1-a)*v2: a == 1 yields v1, a == 0 yields v2.
// Each channel goes through a plain float-to-unsigned-char conversion
// (truncation, not rounding).
cv::Vec3b 
lerp(const cv::Vec3b& v1, const cv::Vec3b& v2, float a)
{
  const float weight2 = 1.f - a;
  cv::Vec3b out;
  for (int c = 0; c < 3; c++) {
    out[c] = a*v1[c] + weight2*v2[c];
  }
  return out;
}

// Composites `final` pixel by pixel as lerp(orig, blurred, matte/255).
// Where the matte is 255 the original shows through; where it is 0 the
// blurred image does.  `blurred` and `matte` are expected to match
// orig.size().  Writes blurred.png and final.png to the working directory as
// a side effect.
void
MyViewer::MakeFinal(cv::Mat& final, const cv::Mat& orig, const cv::Mat& blurred, const cv::Mat& matte)
{
  cv::Size s = orig.size();

  // Ensure the output is allocated before writing through at<>().  This is
  // a no-op when `final` already has this size and type.
  final.create(s, cv::DataType<cv::Vec3b>::type);

  for (int y = 0; y < s.height; y++) {
    for (int x = 0; x < s.width; x++) {
      float a = ByteToFloat(matte.at<unsigned char>(y, x));

      const cv::Vec3b& orig_rgb = orig.at<cv::Vec3b>(y, x);
      const cv::Vec3b& blurred_rgb = blurred.at<cv::Vec3b>(y, x);

      final.at<cv::Vec3b>(y, x) = lerp(orig_rgb, blurred_rgb, a);
    }
  }

  cv::imwrite("blurred.png", blurred);
  cv::imwrite("final.png", final);
}

void
MyViewer::ClearBlur()
{
  cout << "Clearing" << endl;
  _aligned = false;
  for (size_t i = 0; i < _images.size(); i++) {
    ImageRecord& imgRec = _images[i];
    // TODO(edluong): make this a fn on the image record?
    imgRec.xform_img.m = cv::Scalar_<unsigned char>(0, 0, 0);
    imgRec.xform_img.features.clear();
    imgRec.xform_tex = 0;
    imgRec.matches.clear();
  }
}

// GLUT callback trampolines: each forwards to the global viewer and, where
// the view may have changed, requests a redraw.

void Display()
{
  view->Draw();
  glutSwapBuffers();
}

void KeyDown(unsigned char key, int x, int y)
{
  view->KeyPress(key, true, x, y);
  glutPostRedisplay();
}

void KeyUp(unsigned char key, int x, int y)
{
  view->KeyPress(key, false, x, y);
  glutPostRedisplay();
}

void SpecialKeyPress(int key, int x, int y)
{
  view->SpecialKeyPress(key, x, y);
  glutPostRedisplay();
}

void Reshape(int w, int h)
{
  glViewport(0, 0, w, h);
  view->SetWindowSize(w, h);
  glutPostRedisplay();
}

void MouseButton(int button, int state, int x, int y)
{
  view->MouseButton(button, state, x, y);
  glutPostRedisplay();
}

// Only redraws while a drag (pan/zoom/select) is in progress.
void MouseMotion(int x, int y)
{
  if (view->MouseMotion(x, y)) {
    glutPostRedisplay();
  }
}

void
TestMinCut()
{
  // a b c d e x y
  // 0 1 2 3 4 5 6
  
  std::vector<Vertex> v(7); 
  std::vector<Edge> e;
  e.push_back(Edge(0, 2, 3)); // a - c
  e.push_back(Edge(1, 2, 5)); // b - c
  e.push_back(Edge(1, 3, 4)); // b - d
  e.push_back(Edge(2, 6, 2)); // c - y
  e.push_back(Edge(3, 4, 2)); // d - e
  e.push_back(Edge(4, 6, 3)); // e - y
  e.push_back(Edge(5, 0, 3)); // x - a
  e.push_back(Edge(5, 1, 1)); // x - b

  Graph g(v, e);

  FlowSolver fs(g, 5, 6);
  std::vector<int> mc = fs.MinCut();

  for (int i = 0; i < (int)mc.size(); i++) {
    fprintf(stderr, "%d ", mc[i]);
  }
  fprintf(stderr, "\n");
}


// Usage: blur <original> <image> [<image> ...]
// The first image is kept at full size as the "original" for the final
// composite.  Every image (including the first) is resized to the size of
// the second for feature extraction and alignment.
int main(int argc, char* argv[])
{
  //TestMinCut(); return 0;

  glutInit(&argc, argv);

  if (argc < 3) {
    fprintf(stderr, "Must provide at least 2 images.\n");
    exit(1);
  }

  // Working resolution: that of the second image.  A failed read yields an
  // empty Mat, which would otherwise produce a 0x0 window.
  cv::Size imageSize = cv::imread(argv[2]).size();
  if (imageSize.width <= 0 || imageSize.height <= 0) {
    fprintf(stderr, "Could not read image '%s'.\n", argv[2]);
    exit(1);
  }

  static const int kWindowBorder = 0;
  int windowSize[2] = {imageSize.width + 2*kWindowBorder, imageSize.height + 2*kWindowBorder};

  glutInitWindowSize(windowSize[0], windowSize[1]);
  glutInitDisplayMode(GLUT_RGBA | GLUT_DOUBLE);
  int windowId = glutCreateWindow("blur");

  view = new MyViewer(windowId, windowSize[0], windowSize[1]);

  cout << "Extracting features";
  cout.flush();

  for (int i = 1; i < argc; i++) {
    Sifter::Image img;
    img.m = cv::imread(argv[i]);

    // cv::resize and feature extraction cannot cope with an empty image.
    if (img.m.empty()) {
      fprintf(stderr, "\nCould not read image '%s'.\n", argv[i]);
      exit(1);
    }

    if (i == 1) {
      // Keep the first image at full resolution.
      view->SetOriginal(img.m);
    }

    if (img.m.size().width != imageSize.width || img.m.size().height != imageSize.height) {
      cv::Mat dest;
      cv::resize(img.m, dest, imageSize);
      img.m = dest;
    }

    Sifter s;
    s.ExtractFeatures(img);
    view->AddImage(img);
    cout << ".";
    cout.flush();
  }
  cout << endl;

  glutDisplayFunc(Display);
  glutIgnoreKeyRepeat(1); // need a non-zero value here
  glutKeyboardFunc(KeyDown);
  glutKeyboardUpFunc(KeyUp);
  glutSpecialFunc(SpecialKeyPress);
  glutReshapeFunc(Reshape);
  glutMouseFunc(MouseButton);
  glutPassiveMotionFunc(MouseMotion);
  glutMotionFunc(MouseMotion);
  glutMainLoop();

  return 0;
}

