import { Matrix, SingularValueDecomposition, determinant } from 'ml-matrix';
import {
  distance_between_points,
  order_points_along_line,
  project_point_to_line,
} from './Geometry';
import { ransac } from './RANSAC';

// https://stats.stackexchange.com/a/332668
class LinearLeastSquares2D {
  /*
  Weighted least-squares line fit, solved through an SVD of the weighted
  covariance matrix of the points.

  Original author's note: 2D linear least squares using the hesse normal form
      d = x*sin(theta) + y*cos(theta)
  which allows you to have vertical lines.

  A model is the object { mean, vec } returned by fit():
    mean - column vector holding the weighted centroid of the data
    vec  - column vector holding the fitted line direction
  */

  /**
   * Fit a line model to weighted input data.
   *
   * @param {number[][]} data - One point per row; every row has the same
   *   number of coordinates.
   * @param {number[]} weights - One weight per row of `data`.
   * @returns {{vec: Matrix, mean: Matrix}} Direction and weighted centroid,
   *   both as column vectors with one entry per coordinate.
   */
  fit(data, weights) {
    const weight_sum = weights.reduce((a, b) => a + b, 0.0);

    // Weighted centroid: scale every row by its weight, sum each column and
    // normalise by the total weight.
    const weighted = new Matrix(data).mulColumnVector(weights);
    const mean = Matrix.columnVector(
      weighted.sum('column').map((v) => v / weight_sum)
    );

    // Centre the data on the weighted centroid.
    const A = new Matrix(data).subRowVector(mean);

    // Weighted covariance A^T * W * A / sum(w). Scaling the columns of A^T by
    // the weights is the same product as multiplying by diag(weights), but
    // does not allocate a dense n x n weight matrix.
    const AT_W_A = Matrix.div(
      A.transpose().mulRowVector(weights).mmul(A),
      weight_sum
    );
    const svd = new SingularValueDecomposition(AT_W_A, {
      autoTranspose: true,
    });

    // The line direction is the (negated) first right-singular vector, i.e.
    // the dominant direction of the covariance, assuming the decomposition
    // orders its singular values in descending order.
    const vec = Matrix.mul(Matrix.columnVector(svd.V.getColumn(0)), -1);

    return { vec: vec, mean: mean };
  }

  /**
   * Calculate the residual error between data and model.
   *
   * @param {{vec: Matrix, mean: Matrix}} model - Model as returned by fit().
   * @param {number[][]} data - One point per row, with as many coordinates
   *   as the model has entries.
   * @returns {number[]} One signed residual per row of `data`.
   */
  residuals(model, data) {
    // Work on a copy so the model is not modified. Rotate the first two
    // components of the direction by 90 degrees, (x, y) -> (y, -x), to get
    // the in-plane normal; any further component is left unchanged.
    const normal = new Matrix(model.vec.to2DArray());
    const dir_x = normal.get(0, 0);
    normal.set(0, 0, normal.get(1, 0));
    normal.set(1, 0, -dir_x);

    // Project the centred points onto that vector. The centroid is
    // subtracted per column, so the dimension follows the data instead of
    // being fixed to three coordinates.
    return new Matrix(data)
      .subRowVector(model.mean)
      .mmul(normal)
      .getColumn(0);
  }
}

/**
 * Fit a line through a set of points and return its two end points together
 * with the fitted model.
 *
 * With fewer than 10 points the model is fitted directly to all points;
 * otherwise the fit is delegated to `ransac` (called with 5 as its fourth
 * argument).
 *
 * Side effect: when the first point has two coordinates, `0.0` is appended to
 * every point of the caller's `points` array in place.
 *
 * @param {number[][]} points - Points to fit, one coordinate array per point.
 * @param {number[]} [weights] - One weight per point; defaults to 1.0 each
 *   when missing or empty.
 * @returns {{p1: *, p2: *, model: {mean: Matrix, vec: Matrix}} | null}
 *   `p1`/`p2` are the projections onto the fitted line of the first and last
 *   point in along-line order, arranged so that `p1` is the one closer to
 *   `points[0]`; `model.vec` is negated whenever that arrangement required
 *   swapping them. Returns `null` when fewer than two points are given or no
 *   model was found.
 */
export function line_fit(points, weights) {
  if (points == null || points.length < 2) {
    return null;
  }

  const point_weights =
    weights == null || weights.length === 0
      ? Array(points.length).fill(1.0)
      : weights;

  // Promote 2D input to 3D (z = 0). Done in place on the caller's points,
  // as before, because callers may rely on the padded coordinates.
  if (points[0].length === 2) {
    points.forEach((p) => p.push(0.0));
  }

  const ls = new LinearLeastSquares2D();
  const min_samples = 10;
  let fit;
  if (points.length < min_samples) {
    // Too few samples for RANSAC: fit all points directly.
    fit = { model: ls.fit(points, point_weights) };
  } else {
    fit = ransac(points, point_weights, ls, 5);
  }

  // NOTE(review): assumes ransac signals failure through a null result or a
  // null model - confirm against the RANSAC module.
  if (fit == null || fit.model == null) {
    return null;
  }

  const mean = fit.model.mean;
  let vec = fit.model.vec;

  // Two points on the fitted line, one unit direction step either side of
  // the centroid.
  const line = [
    Matrix.add(mean, Matrix.mul(vec, -1)).getColumn(0),
    Matrix.add(mean, vec).getColumn(0),
  ];

  const ordered_points = order_points_along_line(points, line).points;
  const p_start = ordered_points[0];
  const p_end = ordered_points[ordered_points.length - 1];
  let [point1] = project_point_to_line(p_start, line);
  let [point2] = project_point_to_line(p_end, line);

  // Swap the points and flip the vector if the projected start point is not
  // the one near the first input point.
  if (
    distance_between_points(point1, points[0]) >
    distance_between_points(point2, points[0])
  ) {
    [point1, point2] = [point2, point1];
    // Negate into a new matrix so the fitted model object is left untouched.
    vec = Matrix.mul(vec, -1);
  }

  return { p1: point1, p2: point2, model: { mean: mean, vec: vec } };
}
