import { PoseLandmarkMap, PoseLandmarkType } from "../models";
import { mirrorPoseWorldLandmarkMap } from "../transforms";
import {
  computeOverallPoseSimilarity,
  computePoseBodyPartSimilarities,
  DEFAULT_WEIGHTS,
} from "./algorithm";

/**
 * Optional settings for {@link PoseSimilarityEvaluator}.
 *
 * Every field may be omitted; the evaluator's constructor supplies the default
 * noted on each field.
 */
export interface PoseSimilarityEvaluatorProps {
  /**
   * Per-landmark weights, forwarded to `computeOverallPoseSimilarity` as `weights`.
   * Defaults to `DEFAULT_WEIGHTS` from `./algorithm`.
   */
  poseLandmarkWeights?: Map<PoseLandmarkType, number>;
  /**
   * Forwarded to `computeOverallPoseSimilarity` as `mismatchThreshold`. Defaults to `0.7`.
   * NOTE(review): presumably body parts scoring below this count as mismatches — confirm in `./algorithm`.
   */
  mismatchThreshold?: number;
  /**
   * Forwarded to `computeOverallPoseSimilarity` as `mismatchPenaltyPerBodyPart`. Defaults to `0.1`.
   * NOTE(review): presumably deducted once per mismatched body part — confirm in `./algorithm`.
   */
  mismatchPenaltyPerBodyPart?: number;
  /**
   * When `true` (the default), the mirrored current pose is scored as well and the
   * better-scoring orientation is reported; when `false`, only the pose as given is scored.
   */
  allowMirroring?: boolean;
}

/**
 * Result of {@link PoseSimilarityEvaluator.evaluate}.
 */
export interface PoseSimilarityEvaluationResult {
  /**
   * The current-pose world landmark map that the similarities below were computed for.
   *
   * When mirroring is allowed and the mirrored pose scores strictly higher than the
   * unmirrored one, this is the mirrored map produced by `mirrorPoseWorldLandmarkMap`.
   * Otherwise (including on a tie), it is the map passed to `evaluate` — the same
   * reference, not a copy.
   *
   * `poseBodyPartSimilarities` and `poseSimilarity` always describe this map.
   */
  poseWorldLandmarkMap: PoseLandmarkMap;

  /**
   * Per-body-part similarities between `poseWorldLandmarkMap` and the target pose,
   * as returned by `computePoseBodyPartSimilarities`.
   */
  poseBodyPartSimilarities: Map<PoseLandmarkType, number>;

  /**
   * The overall pose similarity, computed by `computeOverallPoseSimilarity` from
   * `poseBodyPartSimilarities` and the evaluator's weight and mismatch settings.
   */
  poseSimilarity: number;
}

/**
 * Scores how closely a current pose matches a target pose, optionally also trying
 * the mirrored current pose and keeping whichever orientation scores higher.
 */
export class PoseSimilarityEvaluator {
  // Configuration is fixed at construction time and never reassigned.
  private readonly poseLandmarkWeights: Map<PoseLandmarkType, number>;
  private readonly mismatchThreshold: number;
  private readonly mismatchPenaltyPerBodyPart: number;
  private readonly allowMirroring: boolean;

  /**
   * Creates an evaluator with the given scoring options.
   *
   * @param props - Optional settings; any omitted field falls back to its default.
   * @param props.poseLandmarkWeights - Per-landmark weights (a `Map` keyed by
   *   `PoseLandmarkType`, not an array) passed to `computeOverallPoseSimilarity`.
   *   Defaults to `DEFAULT_WEIGHTS`. The map is stored by reference, not copied.
   * @param props.mismatchThreshold - Passed to `computeOverallPoseSimilarity`. Defaults to `0.7`.
   * @param props.mismatchPenaltyPerBodyPart - Passed to `computeOverallPoseSimilarity`.
   *   Defaults to `0.1`.
   * @param props.allowMirroring - When `true` (default), `evaluate` also scores the
   *   mirrored current pose and returns the better-scoring orientation.
   */
  constructor({
    poseLandmarkWeights = DEFAULT_WEIGHTS,
    mismatchThreshold = 0.7,
    mismatchPenaltyPerBodyPart = 0.1,
    allowMirroring = true,
  }: PoseSimilarityEvaluatorProps = {}) {
    this.poseLandmarkWeights = poseLandmarkWeights;
    this.mismatchThreshold = mismatchThreshold;
    this.mismatchPenaltyPerBodyPart = mismatchPenaltyPerBodyPart;
    this.allowMirroring = allowMirroring;
  }

  /**
   * Computes the similarity between the current pose and the target pose.
   *
   * If mirroring is allowed, the mirrored current pose is scored too, and its result
   * is returned only when its overall similarity is strictly higher; ties keep the
   * unmirrored pose.
   *
   * @param curPoseWorldLandmarkMap - The current pose world landmark map.
   * @param tgtPoseWorldLandmarkMap - The target pose world landmark map.
   * @returns The landmark map that was scored (the input, or its mirror) together
   *   with its per-body-part and overall similarities.
   */
  evaluate(
    curPoseWorldLandmarkMap: PoseLandmarkMap,
    tgtPoseWorldLandmarkMap: PoseLandmarkMap,
  ): PoseSimilarityEvaluationResult {
    const unmirrored = this.scorePose(
      curPoseWorldLandmarkMap,
      tgtPoseWorldLandmarkMap,
    );

    if (!this.allowMirroring) {
      return unmirrored;
    }

    const mirrored = this.scorePose(
      mirrorPoseWorldLandmarkMap(curPoseWorldLandmarkMap),
      tgtPoseWorldLandmarkMap,
    );

    // Strict `>`: on a tie the unmirrored pose (the caller's own map) wins.
    return mirrored.poseSimilarity > unmirrored.poseSimilarity
      ? mirrored
      : unmirrored;
  }

  /**
   * Scores a single candidate pose against the target, using this evaluator's
   * weights and mismatch settings.
   *
   * @param poseWorldLandmarkMap - The candidate pose to score; returned unchanged in the result.
   * @param tgtPoseWorldLandmarkMap - The target pose world landmark map.
   * @returns The candidate map with its per-body-part and overall similarities.
   */
  private scorePose(
    poseWorldLandmarkMap: PoseLandmarkMap,
    tgtPoseWorldLandmarkMap: PoseLandmarkMap,
  ): PoseSimilarityEvaluationResult {
    const poseBodyPartSimilarities = computePoseBodyPartSimilarities(
      poseWorldLandmarkMap,
      tgtPoseWorldLandmarkMap,
    );

    const poseSimilarity = computeOverallPoseSimilarity(
      poseBodyPartSimilarities,
      {
        weights: this.poseLandmarkWeights,
        mismatchThreshold: this.mismatchThreshold,
        mismatchPenaltyPerBodyPart: this.mismatchPenaltyPerBodyPart,
      },
    );

    return { poseWorldLandmarkMap, poseBodyPartSimilarities, poseSimilarity };
  }
}
