// SPDX-FileComment: This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
// SPDX-License-Identifier: MIT

#pragma once

#include <algorithm>

#include <TNL/Assert.h>
#include <TNL/Backend/Macros.h>
#include <TNL/Algorithms/staticFor.h>
#include <TNL/Containers/ndarray/Meta.h>

namespace TNL::Containers::detail {

// Assigns the given sizes to the holder, one size per dimension.
//
// The k-th argument of `otherSizes` is passed to `holder.setSize< k >()`.
// The number of passed sizes must be equal to the dimension of the holder,
// which is verified at compile time.
template< typename SizesHolder, typename... IndexTypes >
void
setSizesHelper( SizesHolder& holder, IndexTypes&&... otherSizes )
{
   constexpr std::size_t count = sizeof...( otherSizes );
   static_assert( SizesHolder::getDimension() == count, "invalid number of sizes passed to setSizesHelper" );

   // `level` is a compile-time constant supplied by staticFor, so it can be
   // used as a template argument to select both the dimension of the holder
   // and the matching element of the parameter pack
   auto assignLevel = [ &holder, &otherSizes... ]( auto level )
   {
      holder.template setSize< level >( detail::get_from_pack< level >( std::forward< IndexTypes >( otherSizes )... ) );
   };
   Algorithms::staticFor< std::size_t, 0, count >( assignLevel );
}

// A variadic bounds-checker for indices
// Recursion terminator: called when no indices are left to check, so there is
// nothing to verify. The parameters are intentionally unnamed because they are
// not used (this also avoids unused-parameter warnings).
template< typename SizesHolder, typename Overlaps >
__cuda_callable__
void
assertIndicesInBounds( const SizesHolder&, const Overlaps& )
{}

// Checks the first index of the pack against the bounds of its dimension and
// then recurses on the remaining indices.
//
// The checked dimension ("level") is derived from the number of indices that
// are still left in the pack: level = dimension - (remaining indices) - 1.
// Hence the first index corresponds to level 0 only when the initial call
// receives exactly SizesHolder::getDimension() indices.
//
// When NDEBUG is not defined, the index `i` is required to satisfy
//    -overlap <= i < size + overlap
// where `size` and `overlap` are read from `sizes` and `overlaps` at the
// checked level. When NDEBUG is defined, the function only recurses.
template< typename SizesHolder, typename Overlaps, typename Index, typename... IndexTypes >
__cuda_callable__
void
assertIndicesInBounds( const SizesHolder& sizes, const Overlaps& overlaps, Index&& i, IndexTypes&&... indices )
{
#ifndef NDEBUG
   // sizes.template getSize<...>() cannot be inside the assert macro, but the variables
   // shouldn't be declared when compiling without assertions
   constexpr std::size_t level = SizesHolder::getDimension() - sizeof...( indices ) - 1;
   const auto size = sizes.template getSize< level >();
   const auto overlap = overlaps.template getSize< level >();
   // the index is converted to the type of the size before the comparison,
   // so both operands of the upper-bound check have the same type
   TNL_ASSERT_LE( -overlap, (decltype( size )) i, "Input error - some index is below the lower bound." );
   TNL_ASSERT_LT( (decltype( size )) i, size + overlap, "Input error - some index is above the upper bound." );
#endif
   // continue with the next index (or with the terminating overload when the pack is empty)
   assertIndicesInBounds( sizes, overlaps, std::forward< IndexTypes >( indices )... );
}

// A variadic bounds-checker for distributed indices with overlaps
// Recursion terminator: called when no indices are left to check, so there is
// nothing to verify. The parameters are unnamed because they are not used.
template< typename SizesHolder1, typename SizesHolder2, typename Overlaps >
__cuda_callable__
void
assertIndicesInRange( const SizesHolder1&, const SizesHolder2&, const Overlaps& )
{}

// Checks the first index of the pack against the range of its dimension and
// then recurses on the remaining indices.
//
// The checked dimension ("level") is derived from the number of indices that
// are still left in the pack: level = dimension - (remaining indices) - 1.
// Hence the first index corresponds to level 0 only when the initial call
// receives exactly SizesHolder1::getDimension() indices.
//
// When NDEBUG is not defined, the index `i` is required to satisfy
//    begin - overlap <= i < end + overlap
// where `begin`, `end` and `overlap` are read from `begins`, `ends` and
// `overlaps` at the checked level. When NDEBUG is defined, the function only
// recurses.
template< typename SizesHolder1, typename SizesHolder2, typename Overlaps, typename Index, typename... IndexTypes >
__cuda_callable__
void
assertIndicesInRange( const SizesHolder1& begins,
                      const SizesHolder2& ends,
                      const Overlaps& overlaps,
                      Index&& i,
                      IndexTypes&&... indices )
{
   static_assert( SizesHolder1::getDimension() == SizesHolder2::getDimension(), "Inconsistent begins and ends." );
#ifndef NDEBUG
   // sizes.template getSize<...>() cannot be inside the assert macro, but the variables
   // shouldn't be declared when compiling without assertions
   constexpr std::size_t level = SizesHolder1::getDimension() - sizeof...( indices ) - 1;
   const auto begin = begins.template getSize< level >();
   const auto end = ends.template getSize< level >();
   const auto overlap = overlaps.template getSize< level >();
   // The index is converted to the type of the bound before the comparison,
   // the same way as in assertIndicesInBounds. Without the conversion, an
   // unsigned index compared with a negative lower bound (begin - overlap < 0)
   // would be subject to the signed/unsigned conversion and the check could
   // fail for a valid index.
   TNL_ASSERT_LE( begin - overlap, (decltype( begin )) i, "Input error - some index is below the lower bound." );
   TNL_ASSERT_LT( (decltype( end )) i, end + overlap, "Input error - some index is above the upper bound." );
#endif
   // continue with the next index (or with the terminating overload when the pack is empty)
   assertIndicesInRange( begins, ends, overlaps, std::forward< IndexTypes >( indices )... );
}

// helper for the assignment operator in NDArray
// Copies the sizes of the levels `level`, `level - 1`, ..., 0 from `source`
// to `target`.
//
// For each level:
//   - if the target reports a static size of 0 (the size is set at run time),
//     the size of the source is assigned to the target,
//   - otherwise the size of the source must be non-negative and equal to the
//     static size of the target; std::logic_error is thrown if it is not.
//
// This primary template handles one level and delegates the lower levels to
// the instantiation for `level - 1`; the recursion ends in the specialization
// for level 0.
template< typename TargetHolder, typename SourceHolder, std::size_t level = TargetHolder::getDimension() - 1 >
struct SetSizesCopyHelper
{
   static void
   copy( TargetHolder& target, const SourceHolder& source )
   {
      if( target.template getStaticSize< level >() == 0 )
         target.template setSize< level >( source.template getSize< level >() );
      else if( source.template getSize< level >() < 0
               || target.template getStaticSize< level >() != (std::size_t) source.template getSize< level >() )
         throw std::logic_error( "Cannot copy sizes due to inconsistent underlying types (static sizes don't match)." );

      // The lower levels must be processed no matter whether this level is
      // static or dynamic. Recursing only from the dynamic branch would leave
      // all levels below a matching static level neither copied nor checked.
      SetSizesCopyHelper< TargetHolder, SourceHolder, level - 1 >::copy( target, source );
   }
};

// Specialization for level 0 which terminates the recursion.
//
// If the target reports a static size of 0 at level 0, the size of the source
// is assigned to the target. Otherwise the size of the source must be
// non-negative and equal to the static size of the target; std::logic_error
// is thrown if it is not.
template< typename TargetHolder, typename SourceHolder >
struct SetSizesCopyHelper< TargetHolder, SourceHolder, 0 >
{
   static void
   copy( TargetHolder& target, const SourceHolder& source )
   {
      // dynamic dimension: just take over the size of the source
      if( target.template getStaticSize< 0 >() == 0 ) {
         target.template setSize< 0 >( source.template getSize< 0 >() );
         return;
      }

      // static dimension: the sizes must agree
      const bool sizesMatch = source.template getSize< 0 >() >= 0
                              && target.template getStaticSize< 0 >() == (std::size_t) source.template getSize< 0 >();
      if( ! sizesMatch )
         throw std::logic_error( "Cannot copy sizes due to inconsistent underlying types (static sizes don't match)." );
   }
};

// helper for the assignment operator in NDArrayView
// Compares two size holders level by level using the values returned by
// getSize< level >().
//
// Returns true if the sizes are equal at all levels, false otherwise.
// Both holders must have the same dimension, which is verified at compile time.
template< typename SizesHolder1, typename SizesHolder2 >
[[nodiscard]] __cuda_callable__
bool
sizesWeakCompare( const SizesHolder1& sizes1, const SizesHolder2& sizes2 )
{
   constexpr std::size_t dimension = SizesHolder1::getDimension();
   static_assert( dimension == SizesHolder2::getDimension(),
                  "Cannot compare sizes of different dimensions." );

   bool allEqual = true;
   Algorithms::staticFor< std::size_t, 0, dimension >(
      [ &allEqual, &sizes1, &sizes2 ]( auto level )
      {
         // once a mismatch was found, the sizes at the remaining levels are not read
         if( allEqual && ! ( sizes1.template getSize< level >() == sizes2.template getSize< level >() ) )
            allEqual = false;
      } );
   return allEqual;
}

// helper for the forInterior and forBoundary methods (NDArray and DistributedNDArray)
// Sets the sizes of `target` based on the sizes of `source`, processing the
// levels `level`, `level - 1`, ..., 0.
//
// Only the levels where SourceHolder reports a static size of 0 are modified:
//   - if the overlap at the level is 0, the target size is set to the source
//     size minus `ConstValue`,
//   - otherwise the target size is set to the unchanged source size.
//
// This primary template handles one level and delegates the lower levels to
// the instantiation for `level - 1`.
template< std::size_t ConstValue,
          typename TargetHolder,
          typename SourceHolder,
          typename Overlaps,
          std::size_t level = TargetHolder::getDimension() - 1 >
struct SetSizesSubtractHelper
{
   static void
   subtract( TargetHolder& target, const SourceHolder& source, const Overlaps& overlaps )
   {
      constexpr bool isDynamic = SourceHolder::template getStaticSize< level >() == 0;
      if constexpr( isDynamic ) {
         const auto sourceSize = source.template getSize< level >();
         if( overlaps.template getSize< level >() != 0 )
            target.template setSize< level >( sourceSize );
         else
            target.template setSize< level >( sourceSize - ConstValue );
      }

      using LowerLevels = SetSizesSubtractHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, level - 1 >;
      LowerLevels::subtract( target, source, overlaps );
   }
};

// Specialization for level 0 which terminates the recursion.
//
// If SourceHolder reports a static size of 0 at level 0, the target size is
// set to the source size minus `ConstValue` when the overlap is 0, and to the
// unchanged source size otherwise. Nothing is done for a non-zero static size.
template< std::size_t ConstValue, typename TargetHolder, typename SourceHolder, typename Overlaps >
struct SetSizesSubtractHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, 0 >
{
   static void
   subtract( TargetHolder& target, const SourceHolder& source, const Overlaps& overlaps )
   {
      constexpr bool isDynamic = SourceHolder::template getStaticSize< 0 >() == 0;
      if constexpr( isDynamic ) {
         const auto sourceSize = source.template getSize< 0 >();
         if( overlaps.template getSize< 0 >() != 0 )
            target.template setSize< 0 >( sourceSize );
         else
            target.template setSize< 0 >( sourceSize - ConstValue );
      }
   }
};

// helper for the forInterior and forBoundary methods (DistributedNDArray)
// Sets the sizes of `target` based on the sizes of `source`, processing the
// levels `level`, `level - 1`, ..., 0.
//
// Only the levels where SourceHolder reports a static size of 0 are modified:
//   - if the overlap at the level is 0, the target size is set to the source
//     size plus `ConstValue`,
//   - otherwise the target size is set to the unchanged source size.
//
// This primary template handles one level and delegates the lower levels to
// the instantiation for `level - 1`.
template< std::size_t ConstValue,
          typename TargetHolder,
          typename SourceHolder,
          typename Overlaps,
          std::size_t level = TargetHolder::getDimension() - 1 >
struct SetSizesAddHelper
{
   static void
   add( TargetHolder& target, const SourceHolder& source, const Overlaps& overlaps )
   {
      constexpr bool isDynamic = SourceHolder::template getStaticSize< level >() == 0;
      if constexpr( isDynamic ) {
         const auto sourceSize = source.template getSize< level >();
         if( overlaps.template getSize< level >() != 0 )
            target.template setSize< level >( sourceSize );
         else
            target.template setSize< level >( sourceSize + ConstValue );
      }

      using LowerLevels = SetSizesAddHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, level - 1 >;
      LowerLevels::add( target, source, overlaps );
   }
};

// Specialization for level 0 which terminates the recursion.
//
// If SourceHolder reports a static size of 0 at level 0, the target size is
// set to the source size plus `ConstValue` when the overlap is 0, and to the
// unchanged source size otherwise. Nothing is done for a non-zero static size.
template< std::size_t ConstValue, typename TargetHolder, typename SourceHolder, typename Overlaps >
struct SetSizesAddHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, 0 >
{
   static void
   add( TargetHolder& target, const SourceHolder& source, const Overlaps& overlaps )
   {
      constexpr bool isDynamic = SourceHolder::template getStaticSize< 0 >() == 0;
      if constexpr( isDynamic ) {
         const auto sourceSize = source.template getSize< 0 >();
         if( overlaps.template getSize< 0 >() != 0 )
            target.template setSize< 0 >( sourceSize );
         else
            target.template setSize< 0 >( sourceSize + ConstValue );
      }
   }
};

// helper for the forLocalInterior, forLocalBoundary and forGhosts methods (DistributedNDArray)
// Sets the sizes of `target` to the sizes of `source` decreased by the
// overlaps, processing the levels `level`, `level - 1`, ..., 0.
//
// Only the levels where SourceHolder reports a static size of 0 are modified.
// This primary template handles one level and delegates the lower levels to
// the instantiation for `level - 1`.
template< typename TargetHolder, typename SourceHolder, typename Overlaps, std::size_t level = TargetHolder::getDimension() - 1 >
struct SetSizesSubtractOverlapsHelper
{
   static void
   subtract( TargetHolder& target, const SourceHolder& source, const Overlaps& overlaps )
   {
      constexpr bool isDynamic = SourceHolder::template getStaticSize< level >() == 0;
      if constexpr( isDynamic ) {
         const auto shrunkSize = source.template getSize< level >() - overlaps.template getSize< level >();
         target.template setSize< level >( shrunkSize );
      }

      using LowerLevels = SetSizesSubtractOverlapsHelper< TargetHolder, SourceHolder, Overlaps, level - 1 >;
      LowerLevels::subtract( target, source, overlaps );
   }
};

// Specialization for level 0 which terminates the recursion.
//
// If SourceHolder reports a static size of 0 at level 0, the target size is
// set to the source size minus the overlap at level 0. Nothing is done for a
// non-zero static size.
template< typename TargetHolder, typename SourceHolder, typename Overlaps >
struct SetSizesSubtractOverlapsHelper< TargetHolder, SourceHolder, Overlaps, 0 >
{
   static void
   subtract( TargetHolder& target, const SourceHolder& source, const Overlaps& overlaps )
   {
      constexpr bool isDynamic = SourceHolder::template getStaticSize< 0 >() == 0;
      if constexpr( isDynamic ) {
         const auto shrunkSize = source.template getSize< 0 >() - overlaps.template getSize< 0 >();
         target.template setSize< 0 >( shrunkSize );
      }
   }
};

// helper for the forLocalInterior, forLocalBoundary and forGhosts methods (DistributedNDArray)
// Sets the sizes of `target` to the sizes of `source` increased by the
// overlaps, processing the levels `level`, `level - 1`, ..., 0.
//
// Only the levels where SourceHolder reports a static size of 0 are modified.
// This primary template handles one level and delegates the lower levels to
// the instantiation for `level - 1`.
template< typename TargetHolder, typename SourceHolder, typename Overlaps, std::size_t level = TargetHolder::getDimension() - 1 >
struct SetSizesAddOverlapsHelper
{
   static void
   add( TargetHolder& target, const SourceHolder& source, const Overlaps& overlaps )
   {
      constexpr bool isDynamic = SourceHolder::template getStaticSize< level >() == 0;
      if constexpr( isDynamic ) {
         const auto grownSize = source.template getSize< level >() + overlaps.template getSize< level >();
         target.template setSize< level >( grownSize );
      }

      using LowerLevels = SetSizesAddOverlapsHelper< TargetHolder, SourceHolder, Overlaps, level - 1 >;
      LowerLevels::add( target, source, overlaps );
   }
};

// Specialization for level 0 which terminates the recursion.
//
// If SourceHolder reports a static size of 0 at level 0, the target size is
// set to the source size plus the overlap at level 0. Nothing is done for a
// non-zero static size.
template< typename TargetHolder, typename SourceHolder, typename Overlaps >
struct SetSizesAddOverlapsHelper< TargetHolder, SourceHolder, Overlaps, 0 >
{
   static void
   add( TargetHolder& target, const SourceHolder& source, const Overlaps& overlaps )
   {
      constexpr bool isDynamic = SourceHolder::template getStaticSize< 0 >() == 0;
      if constexpr( isDynamic ) {
         const auto grownSize = source.template getSize< 0 >() + overlaps.template getSize< 0 >();
         target.template setSize< 0 >( grownSize );
      }
   }
};

// helper for the forInterior method (DistributedNDArray)
// Replaces the sizes of `target` by the maximum of the target size and the
// source size, processing the levels `level`, `level - 1`, ..., 0.
//
// Only the levels where SourceHolder reports a static size of 0 are modified.
// This primary template handles one level and delegates the lower levels to
// the instantiation for `level - 1`.
template< typename TargetHolder, typename SourceHolder, std::size_t level = TargetHolder::getDimension() - 1 >
struct SetSizesMaxHelper
{
   static void
   max( TargetHolder& target, const SourceHolder& source )
   {
      constexpr bool isDynamic = SourceHolder::template getStaticSize< level >() == 0;
      if constexpr( isDynamic ) {
         const auto currentSize = target.template getSize< level >();
         const auto otherSize = source.template getSize< level >();
         target.template setSize< level >( std::max( currentSize, otherSize ) );
      }

      using LowerLevels = SetSizesMaxHelper< TargetHolder, SourceHolder, level - 1 >;
      LowerLevels::max( target, source );
   }
};

// Specialization for level 0 which terminates the recursion.
//
// If SourceHolder reports a static size of 0 at level 0, the target size is
// replaced by the maximum of the target size and the source size. Nothing is
// done for a non-zero static size.
template< typename TargetHolder, typename SourceHolder >
struct SetSizesMaxHelper< TargetHolder, SourceHolder, 0 >
{
   static void
   max( TargetHolder& target, const SourceHolder& source )
   {
      constexpr bool isDynamic = SourceHolder::template getStaticSize< 0 >() == 0;
      if constexpr( isDynamic ) {
         const auto currentSize = target.template getSize< 0 >();
         const auto otherSize = source.template getSize< 0 >();
         target.template setSize< 0 >( std::max( currentSize, otherSize ) );
      }
   }
};

// helper for the forInterior method (DistributedNDArray)
// Replaces the sizes of `target` by the minimum of the target size and the
// source size, processing the levels `level`, `level - 1`, ..., 0.
//
// Only the levels where SourceHolder reports a static size of 0 are modified.
// This primary template handles one level and delegates the lower levels to
// the instantiation for `level - 1`.
template< typename TargetHolder, typename SourceHolder, std::size_t level = TargetHolder::getDimension() - 1 >
struct SetSizesMinHelper
{
   static void
   min( TargetHolder& target, const SourceHolder& source )
   {
      constexpr bool isDynamic = SourceHolder::template getStaticSize< level >() == 0;
      if constexpr( isDynamic ) {
         const auto currentSize = target.template getSize< level >();
         const auto otherSize = source.template getSize< level >();
         target.template setSize< level >( std::min( currentSize, otherSize ) );
      }

      using LowerLevels = SetSizesMinHelper< TargetHolder, SourceHolder, level - 1 >;
      LowerLevels::min( target, source );
   }
};

// Specialization for level 0 which terminates the recursion.
//
// If SourceHolder reports a static size of 0 at level 0, the target size is
// replaced by the minimum of the target size and the source size. Nothing is
// done for a non-zero static size.
template< typename TargetHolder, typename SourceHolder >
struct SetSizesMinHelper< TargetHolder, SourceHolder, 0 >
{
   static void
   min( TargetHolder& target, const SourceHolder& source )
   {
      constexpr bool isDynamic = SourceHolder::template getStaticSize< 0 >() == 0;
      if constexpr( isDynamic ) {
         const auto currentSize = target.template getSize< 0 >();
         const auto otherSize = source.template getSize< 0 >();
         target.template setSize< 0 >( std::min( currentSize, otherSize ) );
      }
   }
};

}  // namespace TNL::Containers::detail
