// SPDX-FileComment: This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
// SPDX-License-Identifier: MIT

#pragma once

#include <cmath>
#include <filesystem>
#include <fstream>
#include <limits>
#include <map>
#include <numeric>
#include <ostream>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include <TNL/Timer.h>
#include <TNL/PerformanceCounters.h>
#include <TNL/Devices/Cuda.h>
#include <TNL/Containers/Vector.h>
#include <TNL/Solvers/IterativeSolverMonitor.h>

#include <TNL/Devices/Host.h>
#include <TNL/SystemInfo.h>
#include <TNL/Backend.h>
#include <TNL/Config/ConfigDescription.h>
#include <TNL/MPI/Comm.h>
#include <TNL/MPI/Wrappers.h>

namespace TNL::Benchmarks {

// Repeatedly runs `compute` and collects timing statistics.
//
// One untimed warm-up run (`reset()` + `compute()`) is performed first. Then
// timed samples are collected until at least `maxLoops` samples have been taken
// AND the accumulated compute time reaches `minTime` (in the units returned by
// Timer::getRealTime). Before every sample, `reset()` is called outside of the
// timed region. For CUDA devices, the device is synchronized before the timer
// is started and before it is stopped.
//
// Returns a tuple (loops, mean_time, stddev_time, mean_cpu_cycles, stddev_cpu_cycles):
// - loops: the number of performed loops (i.e. timing samples),
// - mean_time, stddev_time: the arithmetic mean and the sample standard
//   deviation of the real computation times,
// - mean_cpu_cycles, stddev_cpu_cycles: the same statistics for the CPU cycles
//   reported by PerformanceCounters. Cycles are recorded only for the
//   Sequential and Host devices; for other devices the samples are zero.
// The standard deviations are NaN when fewer than two loops were performed;
// the means are NaN when no loop was performed.
template< typename Device,
          typename ComputeFunction,
          typename ResetFunction,
          typename Monitor = TNL::Solvers::IterativeSolverMonitor< double > >
std::tuple< std::size_t, double, double, double, double >
timeFunction( ComputeFunction compute,
              ResetFunction reset,
              std::size_t maxLoops,
              const double& minTime,
              Monitor&& monitor = Monitor() )
{
   constexpr bool isCuda = std::is_same_v< Device, Devices::Cuda >;
   constexpr bool recordCPUCycles =
      std::is_same_v< Device, Devices::Sequential > || std::is_same_v< Device, Devices::Host >;

   // the timer is constructed zero-initialized and stopped
   Timer timer;

   // set timer to the monitor
   monitor.setTimer( timer );

   PerformanceCounters performanceCounters;

   // warm up
   reset();
   compute();

   // The storage must be able to grow: the minTime criterion may require more
   // than maxLoops samples, so a fixed-size buffer of maxLoops would overflow.
   std::vector< double > results_time;
   std::vector< long long int > results_cpu_cycles;
   results_time.reserve( maxLoops );
   results_cpu_cycles.reserve( maxLoops );

   // running total of the measured times (avoids re-summing all samples every loop)
   double total_time = 0.0;

   std::size_t loops = 0;
   for( ; loops < maxLoops || total_time < minTime; loops++ ) {
      // abuse the monitor's "time" for loops
      monitor.setTime( loops + 1 );
      reset();

      // Explicit synchronization of the CUDA device
      if constexpr( isCuda )
         Backend::deviceSynchronize();

      // reset timer and performance counters before each computation
      timer.reset();
      performanceCounters.reset();
      timer.start();
      performanceCounters.start();
      compute();
      if constexpr( isCuda )
         Backend::deviceSynchronize();
      timer.stop();
      performanceCounters.stop();

      const double time = timer.getRealTime();
      total_time += time;
      results_time.push_back( time );
      if constexpr( recordCPUCycles )
         results_cpu_cycles.push_back( performanceCounters.getCPUCycles() );
      else
         results_cpu_cycles.push_back( 0 );
   }

   // arithmetic mean and sample standard deviation of the collected samples;
   // the standard deviation is NaN when fewer than two samples are available
   const auto statistics = [ loops ]( const auto& samples )
   {
      const double mean = std::accumulate( samples.begin(), samples.end(), 0.0 ) / static_cast< double >( loops );
      double stddev = std::numeric_limits< double >::quiet_NaN();
      if( loops > 1 ) {
         double squares = 0.0;
         for( const auto sample : samples ) {
            const double diff = static_cast< double >( sample ) - mean;
            squares += diff * diff;
         }
         stddev = std::sqrt( squares / static_cast< double >( loops - 1 ) );
      }
      return std::make_pair( mean, stddev );
   };

   const auto [ mean_time, stddev_time ] = statistics( results_time );
   const auto [ mean_cpu_cycles, stddev_cpu_cycles ] = statistics( results_cpu_cycles );
   return std::make_tuple( loops, mean_time, stddev_time, mean_cpu_cycles, stddev_cpu_cycles );
}

// Collects a description of the host system as a map of human-readable
// key/value strings, e.g. for storing alongside benchmark results.
//
// The map covers host/OS/compiler identification, the OpenMP configuration,
// and the CPU model, cores, frequency and cache sizes. When compiled with CUDA
// or HIP, it also describes the active GPU. When built with MPI, it includes
// the number of MPI processes, reported as 1 if MPI has not been initialized.
inline std::map< std::string, std::string >
getHardwareMetadata()
{
   const CPUCacheSizes cacheSizes = getCPUCacheSizes();
   const std::string cacheInfo = std::to_string( cacheSizes.L1data ) + ", " + std::to_string( cacheSizes.L1instruction ) + ", "
                               + std::to_string( cacheSizes.L2 ) + ", " + std::to_string( cacheSizes.L3 );

   // query the CPU info once and reuse it for all CPU-related entries
   const auto cpuInfo = getCPUInfo();
   // Guard the division: the reported core count may be zero (presumably when
   // the CPU info cannot be determined on this platform), which would otherwise
   // be an integer division by zero.
   const std::string threadsPerCore =
      cpuInfo.cores > 0 ? std::to_string( cpuInfo.threads / cpuInfo.cores ) : std::string( "unknown" );

#if defined( __CUDACC__ ) || defined( __HIP__ )
   const int activeGPU = Backend::getDevice();
   const std::string deviceArch = std::to_string( Backend::getArchitectureMajor( activeGPU ) ) + "."
                                + std::to_string( Backend::getArchitectureMinor( activeGPU ) );
#endif

#ifdef HAVE_MPI
   int nproc = 1;
   // check if MPI was initialized (some benchmarks do not initialize MPI even when
   // they are built with HAVE_MPI and thus MPI::GetSize() cannot be used blindly)
   if( TNL::MPI::Initialized() )
      nproc = TNL::MPI::GetSize();
#endif

   std::map< std::string, std::string > metadata{
      { "host name", getHostname() },
      { "architecture", getSystemArchitecture() },
      { "system", getSystemName() },
      { "system release", getSystemRelease() },
      { "compiler", getCompilerName() },
      { "start time", getCurrentTime() },
#ifdef HAVE_MPI
      { "number of MPI processes", std::to_string( nproc ) },
#endif
      { "OpenMP enabled", Devices::Host::isOMPEnabled() ? "yes" : "no" },
      { "OpenMP threads", std::to_string( Devices::Host::getMaxThreadsCount() ) },
      { "CPU model name", cpuInfo.modelName },
      { "CPU cores", std::to_string( cpuInfo.cores ) },
      { "CPU threads per core", threadsPerCore },
      { "CPU max frequency (MHz)", std::to_string( getCPUMaxFrequency() / 1e3 ) },
      { "CPU cache sizes (L1d, L1i, L2, L3) (kiB)", cacheInfo },
#if defined( __CUDACC__ ) || defined( __HIP__ )
      { "GPU name", Backend::getDeviceName( activeGPU ) },
      { "GPU architecture", deviceArch },
      { "GPU CUDA cores", std::to_string( Backend::getDeviceCores( activeGPU ) ) },
      { "GPU global memory (GB)", std::to_string( (double) Backend::getGlobalMemorySize( activeGPU ) / 1e9 ) },
      { "GPU memory ECC enabled", TNL::convertToString( Backend::getECCEnabled( activeGPU ) ) },
#endif
   };

   return metadata;
}

// Writes `data` to `out` as a flat JSON object with string values, one
// tab-indented entry per line, followed by a newline, and flushes the stream.
//
// Keys and values are escaped per the JSON string grammar: `"` and `\`, the
// short escapes \b \f \n \r \t, and \u00XX for other control characters.
// Escaping ensures that arbitrary metadata (e.g. host or CPU names) cannot
// produce invalid JSON. Strings without such characters are written verbatim.
inline void
writeMapAsJson( const std::map< std::string, std::string >& data, std::ostream& out )
{
   // writes the contents of a JSON string literal (without the surrounding quotes)
   const auto writeEscaped = [ &out ]( const std::string& str )
   {
      const char* const hexDigits = "0123456789abcdef";
      for( const char c : str ) {
         switch( c ) {
            case '"':
               out << "\\\"";
               break;
            case '\\':
               out << "\\\\";
               break;
            case '\b':
               out << "\\b";
               break;
            case '\f':
               out << "\\f";
               break;
            case '\n':
               out << "\\n";
               break;
            case '\r':
               out << "\\r";
               break;
            case '\t':
               out << "\\t";
               break;
            default: {
               const auto uc = static_cast< unsigned char >( c );
               if( uc < 0x20 )
                  // remaining control characters must use the \u00XX form
                  out << "\\u00" << hexDigits[ uc >> 4 ] << hexDigits[ uc & 0xF ];
               else
                  out << c;
            }
         }
      }
   };

   out << "{\n";
   bool first = true;
   for( const auto& [ key, value ] : data ) {
      // separate entries with a comma; no trailing comma after the last one
      if( ! first )
         out << ",\n";
      first = false;
      out << "\t\"";
      writeEscaped( key );
      out << "\": \"";
      writeEscaped( value );
      out << "\"";
   }
   if( ! data.empty() )
      out << "\n";
   out << "}\n" << std::flush;
}

inline void
writeMapAsJson( const std::map< std::string, std::string >& data, std::string filename, const std::string& newExtension = "" )
{
   namespace fs = std::filesystem;

   if( ! newExtension.empty() ) {
      const fs::path oldPath = filename;
      const fs::path newPath = oldPath.parent_path() / ( oldPath.stem().string() + newExtension );
      filename = newPath.string();
   }

   std::ofstream file( filename );
   // enable exceptions
   file.exceptions( std::ostream::failbit | std::ostream::badbit | std::ostream::eofbit );
   writeMapAsJson( data, file );
}

}  // namespace TNL::Benchmarks
