<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <link rel="stylesheet" href="/static/components.css">
  <script src="/static/annotation.js"></script>
  <script src="/static/components.js"></script>
  <script src="/static/util.js"></script>
  <style>
.box-border {
  border: 1px solid #262e3d;
}
  </style>
</head>

<body class="box-border" style="padding: 0px;">

<script>

/**
 * Menu applet that sends the current frame to a Triton Inference Server
 * object detection model and creates one box localization per detection
 * in the current media, frame and version.
 *
 * State read from the parent class: this._shadow (content root) and
 * this._data (media, frame, version, projectId).
 */
class TritonApplet extends tatorUi.annotation.MenuAppletElement {
  constructor() {
    super();

    // Create the main div content area
    this._contentDiv = document.createElement("div");
    this._contentDiv.setAttribute("class", "d-flex flex-grow flex-column text-gray f2 px-4 py-2");
    this._contentDiv.style.height = "212px";
    this._shadow.appendChild(this._contentDiv);

    const label = document.createElement("div");
    label.setAttribute("class", "h3 text-white text-semibold px-3 py-2 my-3 text-center");
    label.textContent = "Run object detector and create corresponding localizations in current frame and version?";
    this._contentDiv.appendChild(label);

    // The following fields will be displayed and are read-only
    this._mediaText = this._addReadOnlyField("Media", "string");
    this._frameText = this._addReadOnlyField("Frame Number", "int");
    this._versionText = this._addReadOnlyField("Version", "string");
  }

  /**
   * Creates a read-only <text-input> and appends it to the content area.
   * @param {string} name - Value of the element's "name" attribute
   * @param {string} type - Value of the element's "type" attribute
   * @returns The created text-input element
   */
  _addReadOnlyField(name, type) {
    const field = document.createElement("text-input");
    field.setAttribute("name", name);
    field.setAttribute("type", type);
    field.permission = false;
    this._contentDiv.appendChild(field);
    return field;
  }

  /**
   * Override parent method
   * @returns string
   */
  getModalTitle() {
    return "Faster R-CNN Inception v2";
  }

  /**
   * Override parent method
   * @returns string
   */
  getModalHeight() {
    return "230px";
  }

  /**
   * Override parent method
   * @returns string
   */
  getAcceptButtonText() {
    return "Detect";
  }

  /**
   * Override parent method
   *
   * Runs the full pipeline: look up the localization type, fetch the frame,
   * send it for inference, create the localizations. Dispatches
   * "displayLoadingScreen" first, then either "refreshDataType" +
   * "hideLoadingScreen" + "displaySuccessMessage" on success, or
   * "hideLoadingScreen" + "displayErrorMessage" on any failure.
   */
  accept() {

    this.dispatchEvent(new Event("displayLoadingScreen"));

    // Kept so the raw inference response can still be logged after the
    // localizations have been created.
    let inferenceResults = null;

    this.getDetectionType()
    .then(() => this.getImage())
    .then((imageBlob) => this.sendImage(imageBlob))
    .then((results) => {
      inferenceResults = results;
      return this.processResults(results);
    })
    .then((objectCount) => {

      console.log("Results from Triton Inference Server:")
      console.log(inferenceResults);
      const msg = `${objectCount} objects detected`;

      this.dispatchEvent(new CustomEvent(
        "refreshDataType", {
          detail: {
            dataType: this._detectionType
          }
        }
      ));
      this.dispatchEvent(new Event("hideLoadingScreen"));
      this.dispatchEvent(new CustomEvent(
        "displaySuccessMessage", {
          detail: {
            message: msg
          }
      }));
    })
    // A single handler covers every stage of the chain above.
    .catch((error) => this._reportError(error));
  }

  /**
   * Hides the loading screen and asks the host to display an error message.
   * @param {*} error - Rejection reason from any pipeline stage
   */
  _reportError(error) {
    this.dispatchEvent(new Event("hideLoadingScreen"));
    this.dispatchEvent(new CustomEvent(
      "displayErrorMessage", {
        detail: {
          message: `ERROR: ${error}`
        }
    }));
  }

  /**
   * Override parent method
   *
   * Fills the read-only fields from this._data (media, frame, version).
   */
  updateUI() {

    // Set the input fields
    const mediaText = `${this._data.media.name} (ID: ${this._data.media.id})`;
    this._mediaText.setValue(mediaText);
    this._frameText.setValue(this._data.frame);
    const versionText = `${this._data.version.name} (ID: ${this._data.version.id})`;
    this._versionText.setValue(versionText);
  }

  /**
   * @returns Object of the project's detection / box info:
   *          name  - localization type name to look up
   *          label - attribute name that receives the detected class
   *          score - attribute name that receives the detection score
   */
  getDetectionInfo() {
    return {
      name: "Boxes",
      label: "Label",
      score: "Score"
    };
  }

  /**
   * Looks up the project's localization type whose name matches
   * getDetectionInfo().name and stores it in this._detectionType, along with
   * this._detectionLabelAttribute and this._detectionScoreAttribute.
   *
   * @returns Promise that is resolved (with no value) once the localization
   *          type has been stored; rejected if the request fails or no
   *          localization type with that name exists.
   */
  getDetectionType() {
    // The executor turns synchronous failures (e.g. missing this._data)
    // into rejections so callers only have to handle the promise.
    return new Promise((resolve, reject) => {

      const localizationRestUrl = `/rest/LocalizationTypes/${this._data.projectId}`;
      fetch(localizationRestUrl, {
        method: "GET",
        mode: "cors",
        credentials: "include",
        headers: {
          "X-CSRFToken": tatorUi.util.getCookie("csrftoken"),
          "Accept": "application/json",
          "Content-Type": "application/json"
        },
      })
      .then(response => {
        if (!response.ok) {
          throw new Error(`Localization type request failed (HTTP ${response.status})`);
        }
        return response.json();
      })
      .then(localizationTypes => {
        this._detectionType = null;
        const detInfo = this.getDetectionInfo();
        for (const locType of localizationTypes) {
          if (locType.name === detInfo.name) {
            this._detectionType = locType;
            this._detectionLabelAttribute = detInfo.label;
            this._detectionScoreAttribute = detInfo.score;
            break;
          }
        }
        // Fail here with a clear message instead of letting a later stage
        // dereference a null type.
        if (this._detectionType === null) {
          throw new Error(`No localization type named "${detInfo.name}" found in this project`);
        }
        resolve();
      })
      .catch(reject);
    });
  }

  /**
   * Fetches the current media/frame as an image (requested at 1024x600).
   *
   * @returns Promise that is resolved with the image Blob. Before it
   *          resolves, this._imageWidth and this._imageHeight are saved if
   *          the browser was able to decode the image.
   */
  getImage() {
    return new Promise((resolve, reject) => {
      fetch(`/rest/GetFrame/${this._data.media.id}?frames=${this._data.frame}&force_scale=1024x600`, {
          method: "GET",
          mode: "cors",
          credentials: "include",
          headers: {
            "X-CSRFToken": tatorUi.util.getCookie("csrftoken"),
            "Accept": "image/*",
            "Content-Type": "image/*"
          }
      })
      .then(response => {
        if (!response.ok) {
          throw new Error(`Frame request failed (HTTP ${response.status})`);
        }
        return response.blob();
      })
      .then(imageBlob => {

        // Decode the blob only to learn its dimensions. Resolve from the
        // load/error handlers so the dimensions are set before callers
        // continue, and release the object URL in both cases.
        const objectUrl = URL.createObjectURL(imageBlob);
        const image = new Image();
        image.onload = () => {
          this._imageWidth = image.width;
          this._imageHeight = image.height;
          URL.revokeObjectURL(objectUrl);
          resolve(imageBlob);
        };
        image.onerror = () => {
          // Dimensions stay unset; the blob is still handed on unchanged.
          URL.revokeObjectURL(objectUrl);
          resolve(imageBlob);
        };
        image.src = objectUrl;
      })
      .catch(reject);
    });
  }

  /**
   * Posts the raw image bytes to the inference endpoint as a single UINT8
   * input named "INPUT" with shape [1, byteCount].
   *
   * @param {blob} imageBlob - Image to send to the Triton Inference Server
   * @returns Promise that is resolved with the parsed JSON response
   */
  sendImage(imageBlob) {
    return new Promise((resolve, reject) => {
      imageBlob.arrayBuffer()
      .then(buffer => {
        const typedArray = new Uint8Array(buffer);
        const inputs = {
          inputs: [{
            name: "INPUT",
            shape: [1, typedArray.length],
            datatype: "UINT8",
            data: [Array.from(typedArray)],
          }]
        };
        const requestOptions = {
          method: "POST",
          headers: {"Accept": "application/json"},
          body: JSON.stringify(inputs),
        };
        // NOTE: placeholder host; must be replaced with the real inference server domain.
        const url = "https://YOUR_API_DOMAIN/v2/models/faster_rcnn_inception_v2_with_preproc/infer";
        return fetch(url, requestOptions);
      })
      .then(response => response.json())
      .then(resolve)
      // Also covers a failing arrayBuffer(), which would otherwise leave
      // the returned promise pending forever.
      .catch(reject);
    });
  }

  /**
   * Converts the inference outputs into localization specs and creates them
   * with a single POST. Reads the outputs named num_detections,
   * detection_boxes, detection_classes and detection_scores; each box is
   * read as four consecutive values in the order y0, x0, y1, x1.
   * (Presumably these are relative coordinates; verify against the model.)
   *
   * @param {Object} results - Parsed response from the Triton Inference Server
   * @returns Promise that is resolved with the number of objects detected;
   *          rejected if the response has no outputs, more than 500
   *          localizations would be created, or the creation request fails.
   */
  processResults(results){
    return new Promise((resolve, reject) => {
      if (!results.outputs) {
        // Surface the server's own error text when the response carries one.
        const detail = results.error;
        reject(detail ? `Inference server error: ${detail}` : "Server results.success == false");
        return;
      }

      // Index the outputs by name for direct lookup
      const outputs = {};
      for (const output of results.outputs) {
        outputs[output.name] = output;
      }

      const newLocSpecs = [];
      for (let i = 0; i < outputs.num_detections.data[0]; i++) {
        const y0 = outputs.detection_boxes.data[i*4 + 0];
        const x0 = outputs.detection_boxes.data[i*4 + 1];
        const y1 = outputs.detection_boxes.data[i*4 + 2];
        const x1 = outputs.detection_boxes.data[i*4 + 3];
        const newLocSpec = {
          media_id: this._data.media.id,
          type: this._detectionType.id,
          frame: this._data.frame,
          version: this._data.version.id,
          x: x0,
          y: y0,
          width: Math.abs(x1 - x0),
          height: Math.abs(y1 - y0),
        };
        newLocSpec[this._detectionLabelAttribute] = outputs.detection_classes.data[i];
        newLocSpec[this._detectionScoreAttribute] = outputs.detection_scores.data[i];
        newLocSpecs.push(newLocSpec);
      }

      if (newLocSpecs.length == 0) {
        resolve(0);
        return;
      }
      if (newLocSpecs.length > 500) {
        reject("More than 500 detections would be created.");
        return;
      }

      fetch(`/rest/Localizations/${this._data.projectId}`, {
        method: "POST",
        credentials: "same-origin",
        headers: {
          "X-CSRFToken": tatorUi.util.getCookie("csrftoken"),
          "Accept": "application/json",
          "Content-Type": "application/json",
        },
        body: JSON.stringify(newLocSpecs)
      })
      .then(response => response.json().then(body => {
        // Do not report success when the server refused the request.
        if (!response.ok) {
          // Presumably the error body carries a "message" field; falls back to the status code.
          const reason = body && body.message ? body.message : `HTTP ${response.status}`;
          throw new Error(`Localization creation failed: ${reason}`);
        }
        return body;
      }))
      .then(() => {
        resolve(newLocSpecs.length);
      })
      .catch(reject);
    });
  }
}

// Register the applet class as the <triton-applet> custom element used in the markup below.
customElements.define("triton-applet", TritonApplet);

</script>

<div>
  <triton-applet id="mainApplet"></triton-applet>
</div>

</body>
</html>