// SPDX-FileComment: This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
// SPDX-License-Identifier: MIT

#pragma once

#include <TNL/Containers/Array.h>
#include <TNL/Containers/DistributedArrayView.h>

namespace TNL::Containers {

/**
 * \brief Distributed array.
 *
 * The class holds two data members: \e localData, an ordinary
 * \ref TNL::Containers::Array "Array", and \e view, a
 * \ref TNL::Containers::DistributedArrayView "DistributedArrayView".
 * All member functions are only declared here; their definitions are
 * expected in the implementation file included at the end of this header.
 *
 * \tparam Value Type of the array elements.
 * \tparam Device Device type passed to the local array and to the views.
 * \tparam Index Type used for indexing and for sizes.
 * \tparam Allocator Allocator type used by the local array. By default it is
 *         the default allocator for the given \e Device and \e Value.
 *
 * \par Example
 * \include Containers/DistributedArrayExample.cpp
 * \par Output
 * \include DistributedArrayExample.out
 */
template< typename Value,
          typename Device = Devices::Host,
          typename Index = int,
          typename Allocator = typename Allocators::Default< Device >::template Allocator< Value > >
class DistributedArray
{
   // Type of the (non-distributed) array used for the \e localData member.
   using LocalArrayType = Containers::Array< Value, Device, Index, Allocator >;

public:
   // Type of the array elements.
   using ValueType = Value;
   // Device type of the array.
   using DeviceType = Device;
   // Type used for indexing and sizes.
   using IndexType = Index;
   // Allocator type of the local array.
   using AllocatorType = Allocator;
   // Type describing the range of global indices owned by a rank.
   using LocalRangeType = Subrange< Index >;
   // Modifiable and non-modifiable (non-distributed) views returned by the
   // get*LocalView* methods.
   using LocalViewType = Containers::ArrayView< Value, Device, Index >;
   using ConstLocalViewType = Containers::ArrayView< std::add_const_t< Value >, Device, Index >;
   // Modifiable and non-modifiable distributed views returned by getView
   // and getConstView.
   using ViewType = DistributedArrayView< Value, Device, Index >;
   using ConstViewType = DistributedArrayView< std::add_const_t< Value >, Device, Index >;
   // The synchronizer type is taken over from the distributed view.
   using SynchronizerType = typename ViewType::SynchronizerType;

   /**
    * \brief A template which allows to quickly obtain a
    * \ref TNL::Containers::DistributedArray "DistributedArray" type with
    * changed template parameters.
    */
   template< typename _Value,
             typename _Device = Device,
             typename _Index = Index,
             typename _Allocator = typename Allocators::Default< _Device >::template Allocator< _Value > >
   using Self = DistributedArray< _Value, _Device, _Index, _Allocator >;

   // User-provided destructor (declared here, defined out of line).
   // NOTE(review): presumably it has to deal with a pending synchronization
   // before the data is released - confirm against the implementation.
   ~DistributedArray();

   /**
    * \brief Constructs an empty array with zero size.
    */
   DistributedArray() = default;

   /**
    * \brief Constructs an empty array and sets the provided allocator.
    *
    * \param allocator The allocator to be associated with this array.
    */
   explicit DistributedArray( const AllocatorType& allocator );

   /**
    * \brief Copy constructor (makes a deep copy).
    *
    * Because the constructor is declared \e explicit, copy-initialization
    * from an lvalue (e.g. `DistributedArray b = a;`) does not compile;
    * direct-initialization (`DistributedArray b( a );`) has to be used.
    *
    * \param array The array to be copied.
    */
   explicit DistributedArray( const DistributedArray& array );

   // default move-constructor
   DistributedArray( DistributedArray&& ) noexcept = default;

   /**
    * \brief Copy constructor with a specific allocator (makes a deep copy).
    *
    * \param array The array to be copied.
    * \param allocator The allocator to be associated with this array.
    */
   explicit DistributedArray( const DistributedArray& array, const AllocatorType& allocator );

   /**
    * \brief Constructs an array with the given distribution.
    *
    * The first four parameters have the same names and types as the
    * parameters of \ref setDistribution.
    * NOTE(review): presumably they also have the same meaning - confirm
    * against the implementation.
    *
    * \param localRange The range of elements in the global array that is owned by this rank.
    * \param ghosts Number of ghost elements allocated by this rank.
    * \param globalSize The size of the global array.
    * \param communicator Reference to the MPI communicator on which the array is distributed.
    * \param allocator The allocator to be associated with this array
    *                  (a default-constructed allocator if omitted).
    */
   DistributedArray( LocalRangeType localRange,
                     Index ghosts,
                     Index globalSize,
                     const MPI::Comm& communicator,
                     const AllocatorType& allocator = AllocatorType() );

   /**
    * \brief Set new global size and distribution of the array.
    *
    * \param localRange The range of elements in the global array that is owned by this rank.
    * \param ghosts Number of ghost elements allocated by this rank.
    * \param globalSize The size of the global array.
    * \param communicator Reference to the MPI communicator on which the array is distributed.
    */
   void
   setDistribution( LocalRangeType localRange, Index ghosts, Index globalSize, const MPI::Comm& communicator );

   /**
    * \brief Returns the local range of the distributed array.
    */
   [[nodiscard]] const LocalRangeType&
   getLocalRange() const;

   /**
    * \brief Returns the number of ghost elements (see the \e ghosts
    * parameter of \ref setDistribution).
    */
   [[nodiscard]] IndexType
   getGhosts() const;

   /**
    * \brief Returns the MPI communicator associated to the array.
    */
   [[nodiscard]] const MPI::Comm&
   getCommunicator() const;

   /**
    * \brief Returns the allocator associated to the array.
    *
    * The allocator is returned by value, i.e. the caller gets a copy.
    */
   [[nodiscard]] AllocatorType
   getAllocator() const;

   /**
    * \brief Returns a modifiable view of the local part of the array.
    */
   [[nodiscard]] LocalViewType
   getLocalView();

   /**
    * \brief Returns a non-modifiable view of the local part of the array.
    */
   [[nodiscard]] ConstLocalViewType
   getConstLocalView() const;

   /**
    * \brief Returns a modifiable view of the local part of the array,
    * including ghost values.
    */
   [[nodiscard]] LocalViewType
   getLocalViewWithGhosts();

   /**
    * \brief Returns a non-modifiable view of the local part of the array,
    * including ghost values.
    */
   [[nodiscard]] ConstLocalViewType
   getConstLocalViewWithGhosts() const;

   /**
    * \brief Copies data from a non-distributed array view into this array.
    *
    * NOTE(review): judging by the name, \e globalArray presumably spans the
    * whole global array and only the part relevant for this rank is copied -
    * confirm against the implementation.
    *
    * \param globalArray Non-modifiable (non-distributed) view of the source data.
    */
   void
   copyFromGlobal( ConstLocalViewType globalArray );

   // synchronizer stuff

   /**
    * \brief Sets the synchronizer object and the number of values per element.
    *
    * NOTE(review): the exact meaning of \e valuesPerElement is not visible in
    * this header - see the synchronizer implementation.
    *
    * \param synchronizer Shared pointer to the synchronizer.
    * \param valuesPerElement Number of values per element (1 by default).
    */
   void
   setSynchronizer( std::shared_ptr< SynchronizerType > synchronizer, int valuesPerElement = 1 );

   /**
    * \brief Returns the shared pointer to the synchronizer.
    */
   [[nodiscard]] std::shared_ptr< SynchronizerType >
   getSynchronizer() const;

   /**
    * \brief Returns the number of values per element (see \ref setSynchronizer).
    */
   [[nodiscard]] int
   getValuesPerElement() const;

   /**
    * \brief Starts the synchronization.
    *
    * NOTE(review): the names suggest that this call is non-blocking and has
    * to be paired with \ref waitForSynchronization - confirm against the
    * implementation.
    */
   void
   startSynchronization();

   /**
    * \brief Waits for the synchronization to finish.
    *
    * Unlike \ref startSynchronization, this method is declared \e const.
    */
   void
   waitForSynchronization() const;

   // Usual Array methods follow below.

   /**
    * \brief Returns a modifiable view of the array.
    */
   [[nodiscard]] ViewType
   getView();

   /**
    * \brief Returns a non-modifiable view of the array.
    */
   [[nodiscard]] ConstViewType
   getConstView() const;

   /**
    * \brief Conversion operator to a modifiable view of the array.
    */
   operator ViewType();

   /**
    * \brief Conversion operator to a non-modifiable view of the array.
    */
   operator ConstViewType() const;

   /**
    * \brief Sets up this array "like" the given array.
    *
    * NOTE(review): presumably the size/distribution of \e array is taken
    * over without copying the data - confirm against the implementation.
    *
    * \tparam Array Type of the source array.
    * \param array The array serving as the template.
    */
   template< typename Array >
   void
   setLike( const Array& array );

   // Resets the array to the empty state.
   void
   reset();

   // Returns true if the current array size is zero.
   [[nodiscard]] bool
   empty() const;

   // TODO: swap

   // Returns the *global* size
   [[nodiscard]] IndexType
   getSize() const;

   // Sets all elements of the array to the given value
   void
   setValue( ValueType value );

   // Safe device-independent element setter
   void
   setElement( IndexType i, ValueType value );

   // Safe device-independent element getter
   [[nodiscard]] ValueType
   getElement( IndexType i ) const;

   // Unsafe element accessor usable only from the Device
   [[nodiscard]] __cuda_callable__
   ValueType&
   operator[]( IndexType i );

   // Unsafe element accessor usable only from the Device
   [[nodiscard]] __cuda_callable__
   const ValueType&
   operator[]( IndexType i ) const;

   // Copy-assignment operator
   DistributedArray&
   operator=( const DistributedArray& array );

   // Move-assignment operator
   // (defaulted, but explicitly declared as potentially throwing)
   DistributedArray&
   operator=( DistributedArray&& ) noexcept( false ) = default;

   // Assignment from another array-like type. The overload participates in
   // overload resolution only if HasSubscriptOperator< Array >::value is true.
   // The unnamed parameter pack in front of the enable_if parameter absorbs
   // any explicitly specified template arguments, so the defaulted argument
   // cannot be overridden by the caller.
   template< typename Array, typename..., typename = std::enable_if_t< HasSubscriptOperator< Array >::value > >
   DistributedArray&
   operator=( const Array& array );

   // Comparison operators
   template< typename Array >
   [[nodiscard]] bool
   operator==( const Array& array ) const;

   template< typename Array >
   [[nodiscard]] bool
   operator!=( const Array& array ) const;

   /**
    * \brief Process the lambda function \e f for each array element in interval [ \e begin, \e end).
    *
    * The lambda function is supposed to be declared as
    *
    * ```
    * f( IndexType elementIdx, ValueType& elementValue )
    * ```
    *
    * where
    *
    * - \e elementIdx is an index of the array element being currently processed
    * - \e elementValue is a value of the array element being currently processed
    *
    * This is performed at the same place where the array is allocated,
    * i.e. it is efficient even on GPU.
    *
    * \param begin The beginning of the array elements interval.
    * \param end The end of the array elements interval.
    * \param f The lambda function to be processed.
    */
   template< typename Function >
   void
   forElements( IndexType begin, IndexType end, Function&& f );

   /**
    * \brief Process the lambda function \e f for each array element in interval [ \e begin, \e end) for constant instances of
    * the array.
    *
    * The lambda function is supposed to be declared as
    *
    * ```
    * f( IndexType elementIdx, ValueType& elementValue )
    * ```
    *
    * NOTE(review): this overload is const, so the element is presumably
    * passed as `const ValueType&` and the signature above was copied from
    * the non-const overload - confirm against the implementation.
    *
    * where
    *
    * - \e elementIdx is an index of the array element being currently processed
    * - \e elementValue is a value of the array element being currently processed
    *
    * This is performed at the same place where the array is allocated,
    * i.e. it is efficient even on GPU.
    *
    * \param begin The beginning of the array elements interval.
    * \param end The end of the array elements interval.
    * \param f The lambda function to be processed.
    */
   template< typename Function >
   void
   forElements( IndexType begin, IndexType end, Function&& f ) const;

   /**
    * \brief Loads the array from a file identified by its name.
    *
    * NOTE(review): judging by the name, the file presumably contains the
    * whole global array; \e allowCasting presumably permits loading data
    * stored with a different value type - confirm against the implementation.
    *
    * \param fileName Name of the file to load from.
    * \param allowCasting Disabled by default.
    */
   void
   loadFromGlobalFile( const String& fileName, bool allowCasting = false );

   /**
    * \brief Loads the array from the given \e File object.
    *
    * Same as the overload taking a file name, but operating on a
    * \e File object supplied by the caller.
    *
    * \param file The file object to load from.
    * \param allowCasting Disabled by default.
    */
   void
   loadFromGlobalFile( File& file, bool allowCasting = false );

protected:
   // Distributed view; note that it is declared (and thus initialized)
   // before \e localData.
   // NOTE(review): presumably the view is bound to the data held in
   // \e localData - confirm against the implementation.
   ViewType view;
   // Ordinary (non-distributed) array held by this object.
   LocalArrayType localData;
};

}  // namespace TNL::Containers

#include "DistributedArray.hpp"
