#pragma once

#include <algorithm>
#include <array>
#include <cmath>
#include <complex>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <numeric>
#include <stdexcept>
#include <vector>

#include <opencv2/opencv.hpp>

namespace math {

  // Sample and signal types used throughout this header.
  using complex = std::complex<double>;               // one complex sample
  using signal  = std::vector<double>;                // real-valued sampled signal
  using csignal = std::vector<std::complex<double>>;  // complex-valued sampled signal

  // Ordered sequence of integer OpenCV points (a polyline / contour).
  using contour = std::vector<cv::Point>;
  /// Compile-time value of pi.
  /// Written as a literal because std::atan is not a constexpr function in
  /// standard C++; atan(1)*4 is only accepted in a constant expression by some
  /// compilers (e.g. GCC) as an extension.
  constexpr double pi() { return 3.14159265358979323846; }

  /// Debug helper: prints "index magnitude" for every sample of @p s, one
  /// sample per line, to std::cout. `inline` because this is a header.
  inline void display_abs(const csignal& s) {
    std::size_t count = 0;
    for (const auto& d : s) {
      std::cout << count++ << ' ' << std::abs(d) << '\n';
    }
    std::cout << std::flush;  // one flush at the end instead of per line
  }

 /// Debug helper: prints "index x y" for every point of @p s, one point per
  /// line, to std::cout. `inline` because this is a header.
  inline void display(const contour& s) {
    std::size_t count = 0;
    for (const auto& d : s) {
      std::cout << count++ << ' ' << d.x << ' ' << d.y << '\n';
    }
    std::cout << std::flush;  // one flush at the end instead of per line
  }

  /// Debug helper: prints "index real imag" for every sample of @p s, one
  /// sample per line, to std::cout. `inline` because this is a header.
  inline void display(const csignal& s) {
    std::size_t count = 0;
    for (const auto& d : s) {
      std::cout << count++ << ' ' << d.real() << ' ' << d.imag() << '\n';
    }
    std::cout << std::flush;  // one flush at the end instead of per line
  }

  /// Thresholds a BGR image into an inverted binary mask.
  ///
  /// A pixel becomes 0 when the mean of its three channels is above 127 and
  /// 255 otherwise, so dark pixels are "on" in the mask.
  ///
  /// @param img     8-bit, 3-channel input image (CV_8UC3); asserted.
  /// @param output  Destination mask. (Re)allocated as CV_8UC1 with img's size
  ///                unless it already has exactly that size and type, in which
  ///                case it is written in place.
  inline void to_binary(const cv::Mat& img, cv::Mat& output) {
    // The stride of 3 bytes per pixel below is only valid for CV_8UC3.
    CV_Assert(img.type() == CV_8UC3);
    output.create(img.rows, img.cols, CV_8UC1);

    // Row-by-row access also handles non-continuous images (e.g. ROIs).
    for (int r = 0; r < img.rows; ++r) {
      const uchar* src = img.ptr<uchar>(r);
      uchar* dst = output.ptr<uchar>(r);
      for (int c = 0; c < img.cols; ++c, src += 3) {
        const int sum = src[0] + src[1] + src[2];
        // mean > 127  <=>  sum > 3 * 127 for integer channel sums.
        dst[c] = (sum > 3 * 127) ? 0 : 255;
      }
    }
  }

  /// Marks strongly red pixels of an image in a single-channel mask.
  ///
  /// For each pixel whose R channel is strictly greater than both G and B, the
  /// mask is set to 255 when R - B >= seuil or R - G >= seuil, and to 0
  /// otherwise. Pixels where R does not dominate are NOT written.
  ///
  /// @param img     8-bit image with at least 3 channels, in B, G, R order.
  /// @param output  CV_8UC1 mask with img's size. If empty, it is allocated
  ///                and zero-filled first.
  /// @param seuil   Threshold ("seuil" is French for threshold) on R-B / R-G.
  inline void filter(const cv::Mat& img, cv::Mat& output, int seuil) {
    CV_Assert(img.depth() == CV_8U && img.channels() >= 3);
    if (output.empty()) {
      output = cv::Mat::zeros(img.rows, img.cols, CV_8UC1);
    }
    CV_Assert(output.rows == img.rows && output.cols == img.cols &&
              output.type() == CV_8UC1);

    const int dim = img.channels();
    // Row-by-row access also handles non-continuous images (e.g. ROIs).
    for (int r = 0; r < img.rows; ++r) {
      const uchar* px = img.ptr<uchar>(r);
      uchar* dst = output.ptr<uchar>(r);
      for (int c = 0; c < img.cols; ++c, px += dim) {
        const int B = px[0];
        const int G = px[1];
        const int R = px[2];
        // NOTE(review): non-red pixels keep whatever output already held;
        // confirm callers pass a zeroed (or intentionally pre-filled) mask.
        if (R > G && R > B) {
          dst[c] = (R - B >= seuil || R - G >= seuil) ? 255 : 0;
        }
      }
    }
  }

  /// Converts a contour into the complex signal z[n] = x[n] + i*y[n].
  inline csignal cont2sig(const contour& cont) {
    csignal sig;
    sig.reserve(cont.size());
    for (const auto& p : cont) {
      sig.emplace_back(p.x, p.y);
    }
    return sig;
  }

  /// Arithmetic mean of a complex signal.
  /// Returns 0 for an empty signal instead of the NaN that 0/0 would produce.
  inline complex mean(const csignal& sig) {
    if (sig.empty()) {
      return complex(0, 0);
    }
    const complex total = std::accumulate(sig.begin(), sig.end(), complex(0, 0));
    return total / static_cast<double>(sig.size());
  }

  /// Returns input[n] - mean for every sample (e.g. to centre a signal).
  inline csignal diff(const csignal& input, complex mean) {
    csignal res;
    res.reserve(input.size());
    for (const auto& x : input) {
      res.push_back(x - mean);
    }
    return res;
  }

  /// Direct O(N^2) discrete Fourier transform:
  ///   X[k] = sum_{n=0}^{N-1} x[n] * exp(-2*pi*i*n*k/N).
  ///
  /// Returned by value so the caller owns the result; returning a reference
  /// to a `new`-allocated vector leaked it on every call.
  inline csignal dft(const csignal& input) {
    const std::size_t size = input.size();
    csignal res;
    res.reserve(size);

    for (std::size_t k = 0; k < size; ++k) {
      complex t = 0;
      for (std::size_t n = 0; n < size; ++n) {
        t += input[n] * std::exp(complex(0, -2 * pi() * n * k / size));
      }
      res.push_back(t);
    }
    return res;
  }

  /// Recursive FFT computing X[k] = sum_n x[n] * exp(-2*pi*i*n*k/N) for any N.
  ///
  /// Even lengths are split into even/odd-indexed halves (radix-2). Odd
  /// lengths cannot be split that way and are evaluated with a direct
  /// O(N^2) DFT. A pure radix-2 combine on odd lengths leaves output entries
  /// unwritten and gives wrong results.
  inline csignal fft_rec(const csignal& input) {
    const std::size_t size = input.size();
    if (size <= 1) {
      return input;
    }

    if (size % 2 != 0) {
      // No radix-2 split possible: evaluate the DFT directly.
      csignal res(size);
      for (std::size_t k = 0; k < size; ++k) {
        complex acc = 0;
        for (std::size_t n = 0; n < size; ++n) {
          acc += input[n] * std::exp(complex(0, -2 * pi() * n * k / size));
        }
        res[k] = acc;
      }
      return res;
    }

    // Even length: separate even- and odd-indexed samples.
    const std::size_t half = size / 2;
    csignal even;
    csignal odd;
    even.reserve(half);
    odd.reserve(half);
    for (std::size_t i = 0; i < size; i += 2) {
      even.push_back(input[i]);
      odd.push_back(input[i + 1]);
    }

    const csignal even_fft = fft_rec(even);
    const csignal odd_fft = fft_rec(odd);

    // Butterfly: X[k] = E[k] + W^k O[k], X[k + N/2] = E[k] - W^k O[k].
    csignal res(size);
    for (std::size_t k = 0; k < half; ++k) {
      const complex t = std::exp(complex(0, -2 * pi() * k / size)) * odd_fft[k];
      res[k] = even_fft[k] + t;
      res[k + half] = even_fft[k] - t;
    }
    return res;
  }

  /// FFT of @p input, optionally zero-padded.
  ///
  /// @param input  Signal to transform.
  /// @param N      <= 0 (default 0): no padding; the transform length is
  ///               input.size().
  ///               > 0: the input is zero-padded to the smallest power of two
  ///               that is >= max(N, input.size()).
  /// @return       The (padded) spectrum.
  inline csignal fft(const csignal& input, int N=0) {
    std::size_t opt_size = input.size();
    if (N > 0) {
      const std::size_t target = std::max(opt_size, static_cast<std::size_t>(N));
      // Integer doubling avoids the rounding errors of log(x)/log(2).
      opt_size = 1;
      while (opt_size < target) {
        opt_size <<= 1;
      }
    }

    csignal sig(input);
    sig.resize(opt_size, complex(0, 0));  // zero padding (no-op if N <= 0)
    return fft_rec(sig);
  }

  void operator*=(csignal& sig, complex& m) {
    for(auto x: sig) {
      x *= m;
    }
  }

  void operator*=(csignal& sig, complex&& m) {
    for(auto x: sig) {
      x *= m;
    }
  }

  void operator/=(csignal& sig, complex& m) {
    for(auto x: sig) {
      x /= m;
    }
  }

  void operator/=(csignal& sig, complex&& m) {
    for(auto x: sig) {
      x /= m;
    }
  }

  /// Extracts the DFT coefficients of frequencies cmin..cmax (inclusive), in
  /// increasing frequency order.
  ///
  /// The spectrum is treated as N-periodic: frequency f is read from index
  /// ((f mod N) + N) mod N. In the usual case cmin <= 0 <= cmax with
  /// -cmin <= N and cmax < N, this is tfd[N+cmin .. N-1] followed by
  /// tfd[0 .. cmax]. Every other range uses the same indexing instead of
  /// reading out of bounds.
  ///
  /// @return 2-sided coefficient vector (cmax - cmin + 1 entries), or an empty
  ///         vector when @p tfd is empty or cmin > cmax.
  inline csignal extract(const csignal& tfd, int cmin, int cmax) {
    csignal res;
    const long long n = static_cast<long long>(tfd.size());
    if (n == 0 || cmin > cmax) {
      return res;
    }

    res.reserve(static_cast<std::size_t>(static_cast<long long>(cmax) - cmin + 1));
    for (long long f = cmin; f <= cmax; ++f) {
      const long long idx = ((f % n) + n) % n;
      res.push_back(tfd[static_cast<std::size_t>(idx)]);
    }
    return res;
  }

  /// Converts a complex signal back into a closed contour.
  ///
  /// Each sample becomes a point (real, imag) rounded to the nearest integer
  /// pixel. The first point is appended again at the end to close the
  /// polygon. An empty signal yields an empty contour.
  inline contour sig2cont(const csignal& sig) {
    contour res;
    if (sig.empty()) {
      return res;
    }

    res.reserve(sig.size() + 1);
    for (const auto& x : sig) {
      res.emplace_back(static_cast<int>(std::lround(x.real())),
                       static_cast<int>(std::lround(x.imag())));
    }
    res.push_back(res.front());  // safe: capacity reserved above
    return res;
  }

  /// Rebuilds N samples from truncated Fourier coefficients:
  ///   z[m] = mean + sum_k desc[k] * exp(2*pi*i*(k + cmin)*m/N),  m = 0..N-1,
  /// where desc[k] holds the coefficient of frequency k + cmin.
  ///
  /// @param desc  Coefficients, lowest frequency (cmin) first.
  /// @param mean  Offset added to every sample.
  /// @param N     Number of samples to generate (none if N <= 0).
  /// @param cmin  Frequency of desc[0].
  /// @param cmax  Unused (the range is implied by desc.size()); kept for
  ///              compatibility. TODO: remove cmax from the arguments.
  inline csignal desc2sig(const csignal& desc, complex mean, int N, int cmin, int cmax) {
    static_cast<void>(cmax);  // intentionally unused, see above

    csignal cont;
    if (N > 0) {
      cont.reserve(static_cast<std::size_t>(N));
    }
    for (int m = 0; m < N; ++m) {
      complex sum = 0;
      for (std::size_t k = 0; k < desc.size(); ++k) {
        const int freq = static_cast<int>(k) + cmin;
        sum += desc[k] * std::exp(complex(0, 2 * pi() * freq * m / N));
      }
      cont.push_back(mean + sum);
    }
    return cont;
  }

  std::array<int, 4> bounds(const contour& cont) {
    std::array<int, 4> res = {cont[0].x, cont[0].y, cont[0].x, cont[0].y};

    for (auto p: cont) {
      if (res[0] > p.x) {
        res[0] = p.x;
      }
      if (res[1] > p.y) {
        res[1] = p.y;
      }
      if (res[2] < p.x) {
        res[2] = p.x;
      }
      if (res[3] < p.y) {
        res[3] = p.y;
      }
    }

    return res;
  }

  /// Linearly maps x from [xmin, xmax] onto [0.1*width, 0.9*width] (a 10%
  /// margin on each side) and truncates the result to int.
  /// If xmin == xmax the range is degenerate; the value is mapped to the
  /// centre (width/2) instead of dividing by zero.
  inline int x_to_cv(double x, int xmin, int xmax, int width) {
    if (xmax == xmin) {
      return static_cast<int>(0.5 * width);
    }
    const double a = 0.8 * width / (static_cast<double>(xmax) - xmin);
    const double b = 0.1 * width - a * xmin;
    return static_cast<int>(a * x + b);
  }

  /// Linearly maps y from [ymin, ymax] onto [0.9*width, 0.1*width] and
  /// truncates to int. The axis is flipped, so ymax lands near the top (small
  /// row index) and ymin near the bottom.
  /// If ymin == ymax the range is degenerate; the value is mapped to the
  /// centre (width/2) instead of dividing by zero.
  inline int y_to_cv(double y, int ymin, int ymax, int width) {
    if (ymax == ymin) {
      return static_cast<int>(0.5 * width);
    }
    const double a = 0.8 * width / (static_cast<double>(ymin) - ymax);
    const double b = 0.1 * width - a * ymax;
    return static_cast<int>(a * y + b);
  }

  /// Maps every point of @p cont into a size x size drawing area.
  ///
  /// @param cont    Points to map (not modified).
  /// @param bounds  Source box {xmin, ymin, xmax, ymax}, e.g. from bounds().
  /// @param size    Side length of the target area in pixels.
  ///
  /// Parameters are const references, so temporaries such as
  /// `transform(c, bounds(c), 500)` are accepted.
  /// NOTE(review): y is mapped with x_to_cv, which keeps image orientation
  /// (y pointing down); y_to_cv would flip the shape vertically. Confirm which
  /// one is intended.
  inline contour transform(const contour& cont, const std::array<int, 4>& bounds, int size) {
    contour res;
    res.reserve(cont.size());
    for (const auto& p : cont) {
      res.emplace_back(x_to_cv(p.x, bounds[0], bounds[2], size),
                       x_to_cv(p.y, bounds[1], bounds[3], size));
    }
    return res;
  }

  /// Normalised Fourier descriptors of a contour for frequencies -cmax..cmax.
  ///
  /// Steps:
  ///  1. z[n] = x[n] + i*y[n], centred on its mean.
  ///  2. DFT divided by N, truncated to frequencies -cmax..cmax.
  ///  3. If |c(-1)| > |c(1)| the coefficient order is reversed (same as
  ///     traversing the contour in the opposite direction).
  ///  4. All coefficients are multiplied by exp(-i*phi),
  ///     phi = arg(c(-1) * c(1)) / 2.
  ///  5. The coefficient of frequency f is multiplied by exp(-i*theta*f),
  ///     theta = arg(c(1)), which makes c(1) real and non-negative.
  ///  6. All coefficients are divided by |c(1)| when it is non-zero.
  ///
  /// Scaling is done with explicit loops, so the result does not depend on the
  /// csignal compound operators.
  ///
  /// @return 2*cmax+1 coefficients; entry k holds frequency k - cmax.
  /// @throws std::invalid_argument if cmax < 1 or the contour has no more
  ///         than cmax points.
  inline csignal descriptors(const contour& cont, int cmax) {
    if (cmax < 1) {
      throw std::invalid_argument("descriptors: cmax must be at least 1");
    }
    if (cont.size() <= static_cast<std::size_t>(cmax)) {
      throw std::invalid_argument("descriptors: contour has too few points for cmax");
    }

    const csignal z = cont2sig(cont);
    const complex zm = mean(z);
    csignal tfd = dft(diff(z, zm));
    const double n = static_cast<double>(z.size());
    for (auto& c : tfd) {
      c /= n;
    }

    const int cmin = -cmax;
    csignal desc = extract(tfd, cmin, cmax);
    const std::size_t mid = desc.size() / 2;  // index of frequency 0

    if (std::abs(desc[mid - 1]) > std::abs(desc[mid + 1])) {
      std::reverse(desc.begin(), desc.end());
    }

    const double phy = std::arg(desc[mid - 1] * desc[mid + 1]) / 2;
    const complex phase = std::exp(complex(0, -phy));
    for (auto& c : desc) {
      c *= phase;
    }

    // Use the frequency k + cmin; k - cmin would add a spurious global
    // rotation by -2*cmax*theta and leave c(1) complex.
    const double theta = std::arg(desc[mid + 1]);
    for (std::size_t k = 0; k < desc.size(); ++k) {
      const int freq = static_cast<int>(k) + cmin;
      desc[k] *= std::exp(complex(0, -theta * freq));
    }

    const double scale = std::abs(desc[mid + 1]);
    if (scale > 0) {
      for (auto& c : desc) {
        c /= scale;
      }
    }
    return desc;
  }

  /// Low-pass simplification of a contour using its Fourier coefficients.
  ///
  /// Steps:
  ///  - The contour is centred on its mean and transformed with a DFT
  ///    (divided by N).
  ///  - Only frequencies -c..c are kept.
  ///  - The shape is rebuilt with desc2sig at its original position, scale,
  ///    orientation and starting point.
  ///
  /// c is cmax clamped to [0, (N-1)/2], so no more than N distinct
  /// frequencies are used. With the maximum c and odd N, the input is
  /// reproduced up to rounding. The result is closed (first point repeated).
  ///
  /// The invariant normalisation of descriptors() is deliberately not applied.
  /// It rescales the shape to unit size, and the integer contour returned here
  /// cannot represent that.
  inline contour simplify_contour(const contour& cont, int cmax) {
    if (cont.empty()) {
      return contour();
    }

    const csignal z = cont2sig(cont);
    const complex zm = mean(z);
    const int n = static_cast<int>(z.size());
    const int c = std::max(0, std::min(cmax, (n - 1) / 2));

    csignal tfd = dft(diff(z, zm));
    for (auto& coeff : tfd) {
      coeff /= static_cast<double>(n);
    }

    const csignal desc = extract(tfd, -c, c);
    const csignal sig = desc2sig(desc, zm, n, -c, c);
    return sig2cont(sig);
  }

  int max_cont(const std::vector<contour>& contours) {
    int max = 0;
    int id = 0;
    for (int i=0; i<contours.size(); ++i) {
      if (contours[i].size() > max) {
        max = contours[i].size();
        id = i;
      }
    }
    return id;
  };
}