// SPDX-FileComment: This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
// SPDX-License-Identifier: MIT

#pragma once

#include <complex>
#include <type_traits>

#include <TNL/Backend/Macros.h>
#include <TNL/Math.h>
#include <TNL/TypeTraits.h>

namespace TNL::Arithmetics {
/**
 * \brief Complex number type whose members can be called from host code and, where marked
 * \c __cuda_callable__, from CUDA device code.
 *
 * A value is stored as two members of type \e Value: the real part and the imaginary part.
 * Overloads that accept \c std::complex are not marked \c __cuda_callable__ and are therefore
 * host-only.
 *
 * Warning: Only basic algebraic operations like addition, subtraction, multiplication and division are implemented currently.
 *
 * \tparam Value arithmetic type of the real and imaginary parts (defaults to \c double).
 */
template< typename Value = double >
struct Complex
{
   //! Type of the real and imaginary parts (TNL naming).
   using ValueType = Value;
   //! Type of the real and imaginary parts (STL naming, mirrors \c std::complex::value_type).
   using value_type = Value;

   // ---------------------------------------------------------------------------------------
   // Construction
   // ---------------------------------------------------------------------------------------

   //! Default constructor (initial values are defined in Complex.hpp).
   __cuda_callable__
   constexpr Complex();

   //! Constructs a number from a real part only.
   __cuda_callable__
   constexpr Complex( const Value& realPart );

   //! Constructs a number from explicit real and imaginary parts.
   __cuda_callable__
   constexpr Complex( const Value& realPart, const Value& imagPart );

   //! Copy constructor.
   __cuda_callable__
   constexpr Complex( const Complex< Value >& other );

   //! Converting constructor from a Complex with a different component type.
   template< typename Value_ >
   __cuda_callable__
   constexpr Complex( const Complex< Value_ >& other );

   //! Converting constructor from \c std::complex (host only).
   template< typename Value_ >
   constexpr Complex( const std::complex< Value_ >& other );

   // ---------------------------------------------------------------------------------------
   // Assignment
   // ---------------------------------------------------------------------------------------

   __cuda_callable__
   constexpr Complex&
   operator=( const Value& scalar );

   __cuda_callable__
   constexpr Complex&
   operator=( const Complex< Value >& other );

   // NOTE(review): unlike the compound-assignment templates below, this template is not
   // constrained to scalar types; other argument types are only rejected inside its body.
   template< typename Value_ >
   __cuda_callable__
   constexpr Complex&
   operator=( const Value_& scalar );

   template< typename Value_ >
   __cuda_callable__
   constexpr Complex&
   operator=( const Complex< Value_ >& other );

   //! Assignment from \c std::complex (host only).
   constexpr Complex&
   operator=( const std::complex< Value >& other );

   //! Assignment from \c std::complex with a different component type (host only).
   template< typename Value_ >
   constexpr Complex&
   operator=( const std::complex< Value_ >& other );

   // ---------------------------------------------------------------------------------------
   // Compound assignment. Scalar templates participate only for scalar, non-complex types;
   // std::complex overloads are host only.
   // ---------------------------------------------------------------------------------------

   __cuda_callable__
   Complex&
   operator+=( const Value& scalar );

   __cuda_callable__
   Complex&
   operator+=( const Complex< Value >& other );

   template< typename Value_ >
   __cuda_callable__
   std::enable_if_t< IsScalarType< Value_ >::value && ! is_complex_v< Value_ >, Complex& >
   operator+=( const Value_& scalar );

   template< typename Value_ >
   __cuda_callable__
   Complex&
   operator+=( const Complex< Value_ >& other );

   Complex&
   operator+=( const std::complex< Value >& other );

   template< typename Value_ >
   Complex&
   operator+=( const std::complex< Value_ >& other );

   __cuda_callable__
   Complex&
   operator-=( const Value& scalar );

   __cuda_callable__
   Complex&
   operator-=( const Complex< Value >& other );

   template< typename Value_ >
   __cuda_callable__
   std::enable_if_t< IsScalarType< Value_ >::value && ! is_complex_v< Value_ >, Complex& >
   operator-=( const Value_& scalar );

   template< typename Value_ >
   __cuda_callable__
   Complex&
   operator-=( const Complex< Value_ >& other );

   Complex&
   operator-=( const std::complex< Value >& other );

   template< typename Value_ >
   Complex&
   operator-=( const std::complex< Value_ >& other );

   __cuda_callable__
   Complex&
   operator*=( const Value& scalar );

   __cuda_callable__
   Complex&
   operator*=( const Complex< Value >& other );

   template< typename Value_ >
   __cuda_callable__
   std::enable_if_t< IsScalarType< Value_ >::value && ! is_complex_v< Value_ >, Complex& >
   operator*=( const Value_& scalar );

   template< typename Value_ >
   __cuda_callable__
   Complex&
   operator*=( const Complex< Value_ >& other );

   Complex&
   operator*=( const std::complex< Value >& other );

   template< typename Value_ >
   Complex&
   operator*=( const std::complex< Value_ >& other );

   __cuda_callable__
   Complex&
   operator/=( const Value& scalar );

   __cuda_callable__
   Complex&
   operator/=( const Complex< Value >& other );

   template< typename Value_ >
   __cuda_callable__
   std::enable_if_t< IsScalarType< Value_ >::value && ! is_complex_v< Value_ >, Complex& >
   operator/=( const Value_& scalar );

   template< typename Value_ >
   __cuda_callable__
   Complex&
   operator/=( const Complex< Value_ >& other );

   Complex&
   operator/=( const std::complex< Value >& other );

   template< typename Value_ >
   Complex&
   operator/=( const std::complex< Value_ >& other );

   // ---------------------------------------------------------------------------------------
   // Comparison (Complex on the left-hand side)
   // ---------------------------------------------------------------------------------------

   [[nodiscard]] __cuda_callable__
   bool
   operator==( const Value& scalar ) const;

   [[nodiscard]] __cuda_callable__
   bool
   operator==( const Complex< Value >& other ) const;

   template< typename Value_ >
   [[nodiscard]] __cuda_callable__
   std::enable_if_t< IsScalarType< Value_ >::value && ! is_complex_v< Value_ >, bool >
   operator==( const Value_& scalar ) const;

   template< typename Value_ >
   [[nodiscard]] __cuda_callable__
   bool
   operator==( const Complex< Value_ >& other ) const;

   [[nodiscard]] bool
   operator==( const std::complex< Value >& other ) const;

   template< typename Value_ >
   [[nodiscard]] bool
   operator==( const std::complex< Value_ >& other ) const;

   [[nodiscard]] __cuda_callable__
   bool
   operator!=( const Value& scalar ) const;

   [[nodiscard]] __cuda_callable__
   bool
   operator!=( const Complex< Value >& other ) const;

   template< typename Value_ >
   [[nodiscard]] __cuda_callable__
   std::enable_if_t< IsScalarType< Value_ >::value && ! is_complex_v< Value_ >, bool >
   operator!=( const Value_& scalar ) const;

   template< typename Value_ >
   [[nodiscard]] __cuda_callable__
   bool
   operator!=( const Complex< Value_ >& other ) const;

   [[nodiscard]] bool
   operator!=( const std::complex< Value >& other ) const;

   template< typename Value_ >
   [[nodiscard]] bool
   operator!=( const std::complex< Value_ >& other ) const;

   // ---------------------------------------------------------------------------------------
   // Unary operators
   // ---------------------------------------------------------------------------------------

   [[nodiscard]] __cuda_callable__
   Complex
   operator-() const;

   [[nodiscard]] __cuda_callable__
   Complex
   operator+() const;

   // ---------------------------------------------------------------------------------------
   // Component access. The volatile overloads return non-volatile references.
   // NOTE(review): the definitions in Complex.hpp presumably cast away volatile to do
   // so — confirm that this is the intended semantics.
   // ---------------------------------------------------------------------------------------

   [[nodiscard]] __cuda_callable__
   const Value&
   real() const volatile;

   [[nodiscard]] __cuda_callable__
   const Value&
   real() const;

   [[nodiscard]] __cuda_callable__
   Value&
   real() volatile;

   [[nodiscard]] __cuda_callable__
   Value&
   real();

   [[nodiscard]] __cuda_callable__
   const Value&
   imag() const volatile;

   [[nodiscard]] __cuda_callable__
   const Value&
   imag() const;

   [[nodiscard]] __cuda_callable__
   Value&
   imag() volatile;

   [[nodiscard]] __cuda_callable__
   Value&
   imag();

protected:
   // Real and imaginary parts, in this order.
   Value real_, imag_;
};

////
// Equality operators whose left-hand operand is not a TNL Complex.
// (Complex on the left is handled by the member Complex::operator==.)

// std::complex == Complex (host only).
// NOTE(review): Value_ does not occur in any parameter, so it cannot be deduced, and this
// template is not viable for a plain `lhs == rhs` call. The second parameter was presumably
// meant to be `Complex< Value_ >`. Fixing it requires the same change to the definition
// in Complex.hpp.
template< typename Value, typename Value_ >
bool
operator==( const std::complex< Value >& lhs, const Complex< Value >& rhs );

// Scalar == Complex; enabled only for scalar, non-complex left operands.
template< typename Value, typename Value_ >
__cuda_callable__
std::enable_if_t< IsScalarType< Value >::value && ! is_complex_v< Value >, bool >
operator==( const Value& scalar, const Complex< Value_ >& rhs );

////
// Inequality operators whose left-hand operand is not a TNL Complex.
// (Complex on the left is handled by the member Complex::operator!=.)

// std::complex != Complex (host only).
// NOTE(review): Value_ is non-deducible here (it appears in no parameter), so this template
// cannot be selected by an ordinary call. `Complex< Value_ >` was presumably intended for
// the second parameter. Any fix must be mirrored in Complex.hpp.
template< typename Value, typename Value_ >
bool
operator!=( const std::complex< Value >& lhs, const Complex< Value >& rhs );

// Scalar != Complex; enabled only for scalar, non-complex left operands.
template< typename Value, typename Value_ >
__cuda_callable__
std::enable_if_t< IsScalarType< Value >::value && ! is_complex_v< Value >, bool >
operator!=( const Value& scalar, const Complex< Value_ >& rhs );

////
// Binary addition. In every overload the result component type is the `Value` template
// parameter, which is the type of the FIRST function parameter (or its component type).

// Complex + Complex.
template< typename Value, typename Value_ >
__cuda_callable__
Complex< Value >
operator+( const Complex< Value >& lhs, const Complex< Value_ >& rhs );

// Complex + scalar; enabled only for scalar, non-complex right operands.
template< typename Value, typename Value_ >
__cuda_callable__
std::enable_if_t< IsScalarType< Value_ >::value && ! is_complex_v< Value_ >, Complex< Value > >
operator+( const Complex< Value >& lhs, const Value_& scalar );

// Scalar + Complex; enabled only for scalar, non-complex left operands.
// NOTE(review): the result is Complex< Value > where Value is the SCALAR type, so e.g.
// `int + Complex< double >` yields Complex< int >. This looks unintended
// (Complex< Value_ > seems more natural); confirm against Complex.hpp.
template< typename Value, typename Value_ >
__cuda_callable__
std::enable_if_t< IsScalarType< Value >::value && ! is_complex_v< Value >, Complex< Value > >
operator+( const Value& scalar, const Complex< Value_ >& rhs );

// Complex + std::complex (host only).
template< typename Value, typename Value_ >
Complex< Value >
operator+( const Complex< Value >& lhs, const std::complex< Value_ >& rhs );

// std::complex + Complex (host only).
// NOTE(review): Value_ is non-deducible (it appears in no parameter), so this template is
// not usable by a plain `lhs + rhs` call.
template< typename Value, typename Value_ >
Complex< Value >
operator+( const std::complex< Value >& lhs, const Complex< Value >& rhs );

////
// Binary subtraction. The result component type is always the `Value` template parameter,
// taken from the FIRST function parameter.

// Complex - Complex.
template< typename Value, typename Value_ >
__cuda_callable__
Complex< Value >
operator-( const Complex< Value >& lhs, const Complex< Value_ >& rhs );

// Complex - scalar; enabled only for scalar, non-complex right operands.
template< typename Value, typename Value_ >
__cuda_callable__
std::enable_if_t< IsScalarType< Value_ >::value && ! is_complex_v< Value_ >, Complex< Value > >
operator-( const Complex< Value >& lhs, const Value_& scalar );

// Scalar - Complex; enabled only for scalar, non-complex left operands.
// NOTE(review): returns Complex< scalar type >, e.g. `int - Complex< double >` gives
// Complex< int >. This looks unintended; confirm against Complex.hpp.
template< typename Value, typename Value_ >
__cuda_callable__
std::enable_if_t< IsScalarType< Value >::value && ! is_complex_v< Value >, Complex< Value > >
operator-( const Value& scalar, const Complex< Value_ >& rhs );

// Complex - std::complex (host only).
template< typename Value, typename Value_ >
Complex< Value >
operator-( const Complex< Value >& lhs, const std::complex< Value_ >& rhs );

// std::complex - Complex (host only).
// NOTE(review): Value_ is non-deducible here, so ordinary calls cannot select this template.
template< typename Value, typename Value_ >
Complex< Value >
operator-( const std::complex< Value >& lhs, const Complex< Value >& rhs );

////
// Binary multiplication. The result component type is always the `Value` template
// parameter, taken from the FIRST function parameter.

// Complex * Complex.
template< typename Value, typename Value_ >
__cuda_callable__
Complex< Value >
operator*( const Complex< Value >& lhs, const Complex< Value_ >& rhs );

// Complex * scalar; enabled only for scalar, non-complex right operands.
template< typename Value, typename Value_ >
__cuda_callable__
std::enable_if_t< IsScalarType< Value_ >::value && ! is_complex_v< Value_ >, Complex< Value > >
operator*( const Complex< Value >& lhs, const Value_& scalar );

// Scalar * Complex; enabled only for scalar, non-complex left operands.
// NOTE(review): returns Complex< scalar type >, e.g. `2 * Complex< double >` gives
// Complex< int >. This looks unintended; confirm against Complex.hpp.
template< typename Value, typename Value_ >
__cuda_callable__
std::enable_if_t< IsScalarType< Value >::value && ! is_complex_v< Value >, Complex< Value > >
operator*( const Value& scalar, const Complex< Value_ >& rhs );

// Complex * std::complex (host only).
template< typename Value, typename Value_ >
Complex< Value >
operator*( const Complex< Value >& lhs, const std::complex< Value_ >& rhs );

// std::complex * Complex (host only).
// NOTE(review): Value_ is non-deducible here, so ordinary calls cannot select this template.
template< typename Value, typename Value_ >
Complex< Value >
operator*( const std::complex< Value >& lhs, const Complex< Value >& rhs );

////
// Binary division. The result component type is always the `Value` template parameter,
// taken from the FIRST function parameter.

// Complex / Complex.
template< typename Value, typename Value_ >
__cuda_callable__
Complex< Value >
operator/( const Complex< Value >& lhs, const Complex< Value_ >& rhs );

// Complex / scalar; enabled only for scalar, non-complex right operands.
template< typename Value, typename Value_ >
__cuda_callable__
std::enable_if_t< IsScalarType< Value_ >::value && ! is_complex_v< Value_ >, Complex< Value > >
operator/( const Complex< Value >& lhs, const Value_& scalar );

// Scalar / Complex; enabled only for scalar, non-complex left operands.
// NOTE(review): returns Complex< scalar type >, e.g. `1 / Complex< double >` gives
// Complex< int >, which would truncate the quotient. This looks unintended; confirm
// against Complex.hpp.
template< typename Value, typename Value_ >
__cuda_callable__
std::enable_if_t< IsScalarType< Value >::value && ! is_complex_v< Value >, Complex< Value > >
operator/( const Value& scalar, const Complex< Value_ >& rhs );

// Complex / std::complex (host only).
template< typename Value, typename Value_ >
Complex< Value >
operator/( const Complex< Value >& lhs, const std::complex< Value_ >& rhs );

// std::complex / Complex (host only).
// NOTE(review): Value_ is non-deducible here, so ordinary calls cannot select this template.
template< typename Value, typename Value_ >
Complex< Value >
operator/( const std::complex< Value >& lhs, const Complex< Value >& rhs );

// Non-member functions named after their <complex> counterparts. Their exact
// semantics are defined in Complex.hpp.

// Presumably the squared magnitude (as std::norm) — TODO confirm in Complex.hpp.
template< typename Value >
__cuda_callable__
Value
norm( const Complex< Value >& z );

// Presumably the magnitude (as std::abs) — TODO confirm in Complex.hpp.
template< typename Value >
__cuda_callable__
Value
abs( const Complex< Value >& z );

// Presumably the phase angle (as std::arg) — TODO confirm in Complex.hpp.
template< typename Value >
__cuda_callable__
Value
arg( const Complex< Value >& z );

// Presumably the complex conjugate (as std::conj) — TODO confirm in Complex.hpp.
template< typename Value >
__cuda_callable__
Complex< Value >
conj( const Complex< Value >& z );

// Stream output (host only).
// NOTE(review): std::ostream is not declared by any header included here except transitively
// (e.g. via <complex>); consider including <iosfwd> explicitly.
template< typename Value >
std::ostream&
operator<<( std::ostream& stream, const Complex< Value >& z );

}  // namespace TNL::Arithmetics

namespace TNL {

// Specialization of the TNL::is_complex trait that reports Arithmetics::Complex< T > as a
// complex type (presumably what the is_complex_v checks in Arithmetics consult).
template< typename T >
struct is_complex< Arithmetics::Complex< T > > : std::true_type
{};

// Re-export the non-member complex functions so they can be called as TNL::norm, TNL::abs, ...
using Arithmetics::norm;
using Arithmetics::abs;
using Arithmetics::arg;
using Arithmetics::conj;

}  // namespace TNL

#include <TNL/Arithmetics/Complex.hpp>
