import * as math from "mathjs";
import { PoseLandmarkMap, PoseLandmarkType } from "../models";
import { alignYawToTarget } from "./alignment";
import { convertPoseLandmarksToRelativeVectors } from "./relative-positioning";

/**
 * Compute the overall pose similarity from per-body-part similarities.
 * The score is ranged from 0 to 1.
 *
 * The weights are normalized so that they sum to 1, each body part's similarity
 * is multiplied by its normalized weight, and the products are summed. A fixed
 * penalty is then subtracted for every entry of `poseBodyPartSimilarities`
 * (weighted or not) whose similarity is strictly below `mismatchThreshold`.
 * Finally the result is clamped to [0, 1].
 *
 * @param {Map<PoseLandmarkType, number>} poseBodyPartSimilarities - Map of pose body part similarities.
 *     Must contain an entry for every landmark type present in `options.weights`.
 * @param {Object} [options] - Object containing weights, mismatch threshold, and mismatch penalty per body part.
 *     Every field is optional and falls back to its default.
 * @param {Map<PoseLandmarkType, number>} [options.weights=DEFAULT_WEIGHTS] - Non-negative weight for each pose landmark type
 * @param {number} [options.mismatchThreshold=0.7] - Similarities strictly below this value count as a mismatch
 * @param {number} [options.mismatchPenaltyPerBodyPart=0.1] - Penalty subtracted for each mismatched body part
 * @return {number} The computed overall pose similarity value, in [0, 1]
 * @throws {Error} If a weighted landmark type has no entry in `poseBodyPartSimilarities`,
 *     or if the weights are rejected by `convertWeightsToProbabilities`.
 */
export function computeOverallPoseSimilarity(
  poseBodyPartSimilarities: Map<PoseLandmarkType, number>,
  {
    weights = DEFAULT_WEIGHTS,
    mismatchThreshold = 0.7,
    mismatchPenaltyPerBodyPart = 0.1,
  }: {
    weights?: Map<PoseLandmarkType, number>;
    mismatchThreshold?: number;
    mismatchPenaltyPerBodyPart?: number;
  } = {},
): number {
  // Normalize the weights so that they sum to 1.
  const probabilities = convertWeightsToProbabilities(weights);

  // Weighted sum over the weighted body parts. Assuming every similarity lies
  // in [-1, 1] (as cosine similarities do), this sum also lies in [-1, 1].
  let similarity = 0;
  probabilities.forEach((probability, poseLandmarkType) => {
    const bodyPartSimilarity = poseBodyPartSimilarities.get(poseLandmarkType);
    if (bodyPartSimilarity === undefined) {
      // A missing entry would otherwise turn the whole score into NaN.
      throw new Error(
        `Missing similarity for pose landmark type: ${String(poseLandmarkType)}`,
      );
    }
    similarity += bodyPartSimilarity * probability;
  });

  // Penalty for mismatched body parts.
  similarity -= computeMismatchPenalty(poseBodyPartSimilarities, {
    mismatchThreshold,
    mismatchPenaltyPerBodyPart,
  });

  // Theoretically the value now lies in
  // [-1 - (number of mismatched parts) * penalty, 1]; clamp it to [0, 1].
  return clamp(similarity, 0, 1);
}

/**
 * Computes the similarities between the body parts of two poses.
 *
 * The current pose is first aligned to the target with `alignYawToTarget`. Both
 * poses are then converted to per-body-part relative vectors with
 * `convertPoseLandmarksToRelativeVectors`. For every body part of the current
 * pose, the cosine similarity with the matching target vector is computed and
 * clamped to [-1, 1]. A body part whose current or target vector does not have
 * a positive norm (e.g. a zero-length vector) gets a similarity of -1.
 *
 * @param {PoseLandmarkMap} curPoseWorldLandmarkMap - The current pose's landmark map.
 * @param {PoseLandmarkMap} tgtPoseWorldLandmarkMap - The target pose's landmark map.
 * @return {Map<PoseLandmarkType, number>} - A map of pose landmark types (those present in the
 *     current pose's relative vectors) and their corresponding similarities in [-1, 1].
 * @throws {Error} If the target pose has no relative vector for a body part of the current pose.
 */
export function computePoseBodyPartSimilarities(
  curPoseWorldLandmarkMap: PoseLandmarkMap,
  tgtPoseWorldLandmarkMap: PoseLandmarkMap,
): Map<PoseLandmarkType, number> {
  // Align the current pose with the target pose (yaw only, per the helper's name).
  const alignedCurPoseWorldLandmarkMap = alignYawToTarget(
    curPoseWorldLandmarkMap,
    tgtPoseWorldLandmarkMap,
  );

  // Get the relative vectors of body parts
  const curRelativeVectors = convertPoseLandmarksToRelativeVectors(
    alignedCurPoseWorldLandmarkMap,
  );
  const tgtRelativeVectors = convertPoseLandmarksToRelativeVectors(
    tgtPoseWorldLandmarkMap,
  );

  const poseBodyPartSimilarities = new Map<PoseLandmarkType, number>();
  curRelativeVectors.forEach((curVector, poseLandmarkType) => {
    const tgtVector = tgtRelativeVectors.get(poseLandmarkType);
    if (tgtVector === undefined) {
      // Fail loudly instead of handing `undefined` to math.dot.
      throw new Error(
        `Missing target vector for pose landmark type: ${String(poseLandmarkType)}`,
      );
    }

    const curVectorNorm = math.norm(curVector) as number;
    const tgtVectorNorm = math.norm(tgtVector) as number;

    let similarity: number;
    if (curVectorNorm > 0 && tgtVectorNorm > 0) {
      // Cosine similarity. Clamp to absorb floating-point error that can push
      // the quotient slightly outside [-1, 1].
      similarity = clamp(
        math.dot(curVector, tgtVector) / (curVectorNorm * tgtVectorNorm),
        -1,
        1,
      );
    } else {
      // A vector without a positive norm has no direction: treat the body part
      // as the worst possible match.
      similarity = -1;
    }

    poseBodyPartSimilarities.set(poseLandmarkType, similarity);
  });

  return poseBodyPartSimilarities;
}

/**
 * Builds the two weight entries for a left/right landmark pair that share one weight.
 *
 * @param {PoseLandmarkType} left - The left-side landmark type.
 * @param {PoseLandmarkType} right - The right-side landmark type.
 * @param {number} weight - The weight assigned to both sides.
 * @return {Array<[PoseLandmarkType, number]>} `[left, weight]` followed by `[right, weight]`.
 */
function leftRightWeightEntries(
  left: PoseLandmarkType,
  right: PoseLandmarkType,
  weight: number,
): Array<[PoseLandmarkType, number]> {
  return [
    [left, weight],
    [right, weight],
  ];
}

/**
 * The default weights for each pose landmark type when computing the overall pose similarity.
 *
 * Left and right counterparts share the same weight. The weights are relative:
 * they are normalized by their sum before use, so only their ratios matter.
 *
 * NOTE: this is a plain, mutable `Map` that `computeOverallPoseSimilarity` uses as
 * its default. Mutating it changes the default scoring for every caller, so copy
 * it (`new Map(DEFAULT_WEIGHTS)`) before customizing.
 */
export const DEFAULT_WEIGHTS = new Map<PoseLandmarkType, number>([
  ...leftRightWeightEntries(PoseLandmarkType.LeftShoulder, PoseLandmarkType.RightShoulder, 5.0),
  ...leftRightWeightEntries(PoseLandmarkType.LeftElbow, PoseLandmarkType.RightElbow, 5.0),
  ...leftRightWeightEntries(PoseLandmarkType.LeftWrist, PoseLandmarkType.RightWrist, 3.0),
  ...leftRightWeightEntries(PoseLandmarkType.LeftKnee, PoseLandmarkType.RightKnee, 5.0),
  ...leftRightWeightEntries(PoseLandmarkType.LeftAnkle, PoseLandmarkType.RightAnkle, 2.0),
]);

/**
 * Restrict a number to the interval [min, max].
 *
 * The lower bound is applied first, then the upper bound, so when `min > max`
 * the result is `max`. A NaN `value` propagates unchanged.
 *
 * @param {number} value - the number to restrict.
 * @param {number} min - the lower bound.
 * @param {number} max - the upper bound.
 * @return {number} `value` limited to the interval.
 */
function clamp(value: number, min: number, max: number): number {
  const atLeastMin = Math.max(min, value);
  const withinBounds = Math.min(max, atLeastMin);
  return withinBounds;
}

/**
 * Converts weights to probabilities by dividing each weight by the sum of all weights.
 * Each weight must be a finite, non-negative number, and the sum of weights must be positive.
 * The returned map preserves the key order of `weights`.
 *
 * @template K
 * @param {ReadonlyMap<K, number>} weights - The map of weights, e.g. one per pose landmark type.
 * @return {Map<K, number>} The map of probabilities (summing to 1) for each key.
 * @throws {Error} If a weight is not finite, if a weight is negative, or if the sum of
 *     weights is not positive (including when `weights` is empty).
 */
function convertWeightsToProbabilities<K>(
  weights: ReadonlyMap<K, number>,
): Map<K, number> {
  // Validate every weight before checking the sum, so a negative or NaN weight
  // is reported directly rather than masked by (or leaking into) the sum.
  // Starting from 0 also makes an empty map fail the sum check below instead
  // of making `reduce` throw a generic TypeError.
  let sumOfWeights = 0;
  weights.forEach((weight) => {
    if (!Number.isFinite(weight)) {
      throw new Error("Weight must be a finite number");
    }
    if (weight < 0) {
      throw new Error("Weight must be nonnegative");
    }
    sumOfWeights += weight;
  });

  // Check if the sum is positive
  if (sumOfWeights <= 0) {
    throw new Error("Sum of weights must be positive");
  }

  // Calculate the probabilities
  const probabilities = new Map<K, number>();
  weights.forEach((weight, key) => {
    probabilities.set(key, weight / sumOfWeights);
  });

  return probabilities;
}

/**
 * Computes the penalty for body parts whose similarity falls below a threshold.
 *
 * Every entry of `poseBodyPartSimilarities` is inspected. An entry counts as a
 * mismatch when its similarity is strictly less than `mismatchThreshold` (NaN
 * never counts). The penalty is the number of mismatches times
 * `mismatchPenaltyPerBodyPart`.
 *
 * @param {Map<PoseLandmarkType, number>} poseBodyPartSimilarities - A map of pose landmark types to their similarity values.
 * @param {Object} options - The options for computing the mismatch penalty.
 * @param {number} options.mismatchThreshold - Similarities strictly below this value are mismatches.
 * @param {number} options.mismatchPenaltyPerBodyPart - The penalty applied for each mismatched body part.
 * @return {number} The total mismatch penalty (0 when nothing is mismatched).
 */
function computeMismatchPenalty(
  poseBodyPartSimilarities: Map<PoseLandmarkType, number>,
  {
    mismatchThreshold: threshold,
    mismatchPenaltyPerBodyPart: penaltyPerPart,
  }: {
    mismatchThreshold: number;
    mismatchPenaltyPerBodyPart: number;
  },
): number {
  let mismatchCount = 0;
  poseBodyPartSimilarities.forEach((bodyPartSimilarity) => {
    if (bodyPartSimilarity < threshold) {
      mismatchCount += 1;
    }
  });

  return mismatchCount * penaltyPerPart;
}
