// SPDX-FileCopyrightText: Copyright (c) 2025 SpacemiT. All rights reserved.
// SPDX-License-Identifier: MIT

#include <assert.h>

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "spacemit_ort_env.h"
#include "utils.h"

// (resolved image path, integer class label)
using ImageNetLabel = std::pair<std::string, int>;

// Reads an ImageNet evaluation list file.
//
// Each non-empty line is expected to look like "<image_path>,<label>". Image
// paths are resolved relative to the directory that contains the list file.
// Lines that do not split into exactly two comma-separated fields are skipped.
//
// Returns the (image path, label) pairs in file order.
// Throws std::runtime_error if the list file cannot be opened; std::stoi may
// throw std::invalid_argument / std::out_of_range on a malformed label field.
std::vector<ImageNetLabel> ReadImageNetLabels(const std::string& file_path) {
  std::ifstream ifs(file_path);
  if (!ifs) {
    throw std::runtime_error("open file failed: " + file_path);
  }

  // Compose with path::operator/ instead of `dir + "/" + name`: for a bare list
  // name like "val.txt" the parent path is empty, and string concatenation would
  // turn a relative image path into an absolute one rooted at "/".
  const std::filesystem::path img_directory = std::filesystem::path(file_path).parent_path();

  std::vector<ImageNetLabel> labels;
  std::string line;
  while (std::getline(ifs, line)) {
    if (line.empty()) {
      continue;
    }
    const auto fields = StringSplit(line, ',');
    if (fields.size() == 2) {
      labels.emplace_back((img_directory / fields[0]).string(), std::stoi(fields[1]));
    }
  }
  return labels;
}

int main(int argc, char** argv) {
  if (argc < 5) {
    fprintf(stderr,
            "Usage: %s [net_name=str] [net_param_path=str] [img_file_list_path=str] [num_threads=int] [mean_value=str] "
            "[std_value=str]\n",
            argv[0]);
    return -1;
  }
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ort_test");
  const char* net_name = argv[1];
  const char* net_param_path = argv[2];
  const char* img_file_list_path = argv[3];
  char* mean_str = nullptr;
  char* std_str = nullptr;
  const int num_threads = atoi(argv[4]);

  if (argc >= 6) {
    mean_str = argv[5];
  }

  if (argc >= 7) {
    std_str = argv[6];
  }

  std::cout << "imagenet_test [" << net_name << "], num_threads=" << num_threads << std::endl;

  auto labels = ReadImageNetLabels(img_file_list_path);

  std::cout << "Load ImageNet test lables " << labels.size() << std::endl;

  Ort::AllocatorWithDefaultOptions allocator;
  Ort::SessionOptions session_options;
  OrtStatus* status = Ort::SessionOptionsSpaceMITEnvInit(session_options);
  session_options.SetIntraOpNumThreads(num_threads);
  session_options.SetInterOpNumThreads(num_threads);

  NetSession session(env, net_param_path, session_options);

  std::vector<float> mean_value = {123.675f, 116.28f, 103.53f};
  std::vector<float> scale_value = {58.395f, 57.12f, 57.375f};

  if (mean_str != nullptr && std_str != nullptr) {
    std::cout << "parse mean and std " << mean_str << ", " << std_str << std::endl;
    ParseMeanStd(mean_value, scale_value, mean_str, std_str);
  }

  auto input_count = session.GetInputCount();
  auto output_count = session.GetOutputCount();

  if (input_count != 1 || (output_count != 1 && output_count != 2)) {
    throw std::runtime_error("input or output count mismatch");
  }

  std::vector<Ort::Value> input_values;
  input_values.reserve(input_count);
  for (size_t i = 0; i < input_count; i++) {
    auto input_shape = session.GetInputShape(i);
    // set dyn shape
    for (size_t si = 0; si < input_shape.size(); si++) {
      if (input_shape[si] <= 0) {
        input_shape[si] = 1;
      }
    }
    session.SetInputShape(i, input_shape);
    input_values.push_back(session.CreatorInputValue(i));
  }
  auto input_shape = session.GetInputShape(0);

  float top1_acc = 0.0;
  float top5_acc = 0.0;
  size_t test_idx = 0;

  for (auto label : labels) {
    auto file_path = label.first;
    auto label_idx = label.second;

    float* input_img_data = input_values[0].GetTensorMutableData<float>();
    ReadImageRawFile<float>(file_path, input_img_data, input_shape[0], input_shape[1], input_shape[2] * input_shape[3],
                            mean_value, scale_value);

    auto output_tensors = session.Run(input_values.data());
    Ort::Value output_value = std::move(output_tensors[output_count - 1]);
    auto output_shape = output_value.GetTensorTypeAndShapeInfo().GetShape();
    size_t numel = 1;
    for (size_t i = 0; i < output_shape.size(); i++) {
      numel *= output_shape[i];
    }
    const float* output_data_ptr = output_value.GetTensorData<float>();
    std::vector<std::pair<size_t, float>> output_data_vec;
    for (size_t i = 0; i < numel; i++) {
      output_data_vec.push_back(std::make_pair(i, output_data_ptr[i]));
    }
    std::sort(output_data_vec.begin(), output_data_vec.end(),
              [](std::pair<size_t, float> lhs, std::pair<size_t, float> rhs) { return lhs.second > rhs.second; });
    if (output_data_vec.size() == 1001) {
      if (output_data_vec[0].first == label_idx + 1) {
        top1_acc += 100;
      }

      for (size_t i = 0; i < 5; i++) {
        if (output_data_vec[i].first == label_idx + 1) {
          top5_acc += 100;
          break;
        }
      }
    } else {
      if (output_data_vec[0].first == label_idx) {
        top1_acc += 100;
      }

      for (size_t i = 0; i < 5; i++) {
        if (output_data_vec[i].first == label_idx) {
          top5_acc += 100;
          break;
        }
      }
    }

    test_idx += 1;
    if (test_idx % 100 == 0) {
      std::cout << "test_idx: " << test_idx << ", top1_accuracy: " << top1_acc / (float)test_idx
                << ", top5_accuracy: " << top5_acc / (float)test_idx << std::endl;
    }
  }

  std::cout << net_name << ", top1_accuracy: " << top1_acc / (float)test_idx
            << ", top5_accuracy: " << top5_acc / (float)test_idx << std::endl;

  return 0;
}
