// SPDX-FileComment: This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
// SPDX-License-Identifier: MIT

#pragma once

#include <TNL/Meshes/Mesh.h>
#include <TNL/Meshes/MeshEntity.h>
#include <TNL/Meshes/MeshBuilder.h>
#include <TNL/Meshes/Geometry/EntityRefiner.h>
#include <TNL/Algorithms/parallelFor.h>
#include <TNL/Algorithms/scan.h>

namespace TNL::Meshes {

// TODO: refactor to avoid duplicate points altogether - first split edges, then faces, then cells
template< EntityRefinerVersion RefinerVersion,
          typename MeshConfig,
          std::enable_if_t< std::is_same_v< typename MeshConfig::CellTopology, Topologies::Triangle >
                               || std::is_same_v< typename MeshConfig::CellTopology, Topologies::Quadrangle >
                               || std::is_same_v< typename MeshConfig::CellTopology, Topologies::Tetrahedron >
                               || std::is_same_v< typename MeshConfig::CellTopology, Topologies::Hexahedron >,
                            bool > = true >
auto  // returns MeshBuilder
refineMesh( const Mesh< MeshConfig, Devices::Host >& inMesh )
{
   using namespace TNL;
   using namespace TNL::Containers;
   using namespace TNL::Algorithms;

   using Mesh = Mesh< MeshConfig, Devices::Host >;
   using MeshBuilder = MeshBuilder< Mesh >;
   using GlobalIndexType = typename Mesh::GlobalIndexType;
   using PointType = typename Mesh::PointType;
   using EntityRefiner = EntityRefiner< MeshConfig, typename MeshConfig::CellTopology, RefinerVersion >;
   constexpr int CellDimension = Mesh::getMeshDimension();

   MeshBuilder meshBuilder;

   const GlobalIndexType inPointsCount = inMesh.template getEntitiesCount< 0 >();
   const GlobalIndexType inCellsCount = inMesh.template getEntitiesCount< CellDimension >();

   // Find the number of output points and cells as well as
   // starting indices at which every cell will start writing new refined points and cells
   using IndexPair = std::pair< GlobalIndexType, GlobalIndexType >;
   Array< IndexPair, Devices::Host > indices( inCellsCount + 1 );
   auto setCounts = [ & ]( GlobalIndexType i )
   {
      const auto cell = inMesh.template getEntity< CellDimension >( i );
      indices[ i ] = EntityRefiner::getExtraPointsAndEntitiesCount( cell );
   };
   parallelFor< Devices::Host >( 0, inCellsCount, setCounts );
   indices[ inCellsCount ] = { 0,
                               0 };  // extend exclusive prefix sum by one element to also get result of reduce at the same time
   auto reduction = []( const IndexPair& a, const IndexPair& b ) -> IndexPair
   {
      return { a.first + b.first, a.second + b.second };
   };
   inplaceExclusiveScan( indices, 0, indices.getSize(), reduction, std::make_pair( 0, 0 ) );
   const auto& reduceResult = indices[ inCellsCount ];
   const GlobalIndexType outPointsCount = inPointsCount + reduceResult.first;
   const GlobalIndexType outCellsCount = reduceResult.second;
   meshBuilder.setEntitiesCount( outPointsCount, outCellsCount );

   // Copy the points from inMesh to outMesh
   auto copyPoint = [ & ]( GlobalIndexType i ) mutable
   {
      meshBuilder.setPoint( i, inMesh.getPoint( i ) );
   };
   parallelFor< Devices::Host >( 0, inPointsCount, copyPoint );

   // Refine each cell
   auto refineCell = [ & ]( GlobalIndexType i ) mutable
   {
      const auto cell = inMesh.template getEntity< CellDimension >( i );
      const auto& indexPair = indices[ i ];

      // Lambda for adding new points
      GlobalIndexType setPointIndex = inPointsCount + indexPair.first;
      auto addPoint = [ & ]( const PointType& point )
      {
         const auto pointIdx = setPointIndex++;
         meshBuilder.setPoint( pointIdx, point );
         return pointIdx;
      };

      // Lambda for adding new cells
      GlobalIndexType setCellIndex = indexPair.second;
      auto addCell = [ & ]( auto... vertexIndices )
      {
         auto entitySeed = meshBuilder.getCellSeed( setCellIndex++ );
         entitySeed.setCornerIds( vertexIndices... );
      };

      EntityRefiner::decompose( cell, addPoint, addCell );
   };
   parallelFor< Devices::Host >( 0, inCellsCount, refineCell );

   return meshBuilder;
}

template< EntityRefinerVersion RefinerVersion,
          typename MeshConfig,
          std::enable_if_t< std::is_same_v< typename MeshConfig::CellTopology, Topologies::Triangle >
                               || std::is_same_v< typename MeshConfig::CellTopology, Topologies::Quadrangle >
                               || std::is_same_v< typename MeshConfig::CellTopology, Topologies::Tetrahedron >
                               || std::is_same_v< typename MeshConfig::CellTopology, Topologies::Hexahedron >,
                            bool > = true >
auto  // returns Mesh
getRefinedMesh( const Mesh< MeshConfig, Devices::Host >& inMesh )
{
   using Mesh = Mesh< MeshConfig, Devices::Host >;

   Mesh outMesh;
   auto meshBuilder = refineMesh< RefinerVersion >( inMesh );
   meshBuilder.deduplicatePoints();
   meshBuilder.build( outMesh );
   return outMesh;
}

}  // namespace TNL::Meshes
