// SPDX-FileComment: This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
// SPDX-License-Identifier: MIT

#pragma once

#include <TNL/Solvers/Optimization/Momentum.h>

namespace TNL::Solvers::Optimization {

template< typename Vector, typename SolverMonitor >
auto
Momentum< Vector, SolverMonitor >::configSetup( Config::ConfigDescription& config, const std::string& prefix ) -> void
{
   // The generic iterative-solver options are registered first, then the two
   // momentum-specific entries (all keys are prefixed with `prefix`).
   using BaseSolver = IterativeSolver< RealType, IndexType, SolverMonitor >;
   BaseSolver::configSetup( config, prefix );

   // Coefficient multiplying the gradient in the velocity update (default 1.0).
   config.addEntry< double >( prefix + "relaxation", "Relaxation parameter for the momentum method.", 1.0 );
   // Coefficient multiplying the previous velocity in the velocity update (default 0.9).
   config.addEntry< double >( prefix + "momentum", "Momentum parameter for the momentum method.", 0.9 );
}

template< typename Vector, typename SolverMonitor >
auto
Momentum< Vector, SolverMonitor >::setup( const Config::ParameterContainer& parameters, const std::string& prefix ) -> bool
{
   // Each value is read and applied before the next one is read, so a failing
   // lookup of "momentum" leaves an already-applied relaxation in place.
   const double relaxationValue = parameters.getParameter< double >( prefix + "relaxation" );
   this->setRelaxation( relaxationValue );
   const double momentumValue = parameters.getParameter< double >( prefix + "momentum" );
   this->setMomentum( momentumValue );

   // The remaining (generic) options are handled by the base iterative solver,
   // whose result is returned unchanged.
   using BaseSolver = IterativeSolver< RealType, IndexType, SolverMonitor >;
   return BaseSolver::setup( parameters, prefix );
}

template< typename Vector, typename SolverMonitor >
auto
Momentum< Vector, SolverMonitor >::setRelaxation( const RealType& relaxationParameter ) -> void
{
   // Stores the coefficient applied to the gradient in the velocity update.
   // No range check is performed here.
   this->relaxation = relaxationParameter;
}

template< typename Vector, typename SolverMonitor >
auto
Momentum< Vector, SolverMonitor >::getRelaxation() const -> const RealType&
{
   return this->relaxation;
}

template< typename Vector, typename SolverMonitor >
auto
Momentum< Vector, SolverMonitor >::setMomentum( const RealType& momentumParameter ) -> void
{
   // Stores the coefficient applied to the previous velocity in the velocity update.
   // No range check is performed here.
   this->momentum = momentumParameter;
}

template< typename Vector, typename SolverMonitor >
auto
Momentum< Vector, SolverMonitor >::getMomentum() const -> const RealType&
{
   return this->momentum;
}

// Runs the momentum (heavy-ball) iteration, updating `w` in place.
//
// Each iteration calls `getGradient( w_view, gradient_view )`, which is expected to fill
// `gradient_view` with the gradient at the current `w`. The velocity and point are then updated:
//    v <- momentum * v - relaxation * gradient
//    w <- w + v
// The residue is sum_i |v_i| / ( relaxation * size of w ).
//
// Returns the value of checkConvergence() once nextIteration() returns false, or true as soon as
// the residue drops below a nonzero convergence residue.
template< typename Vector, typename SolverMonitor >
template< typename GradientGetter >
bool
Momentum< Vector, SolverMonitor >::solve( VectorView& w, GradientGetter&& getGradient )
{
   /////
   // Allocate the work vectors like `w` and start from a zero velocity
   this->gradient.setLike( w );
   this->v.setLike( w );
   auto gradient_view = gradient.getView();
   auto w_view = w.getView();
   auto v_view = v.getView();
   this->gradient = 0.0;
   this->v = 0.0;

   /////
   // Set necessary parameters
   this->resetIterations();
   // Start above the convergence threshold so the initial state is not treated as converged.
   this->setResidue( this->getConvergenceResidue() + 1.0 );

   /////
   // Start the main loop
   while( true ) {
      /////
      // Compute the gradient and update the velocity
      getGradient( w_view, gradient_view );
      v_view = this->momentum * v_view - this->relaxation * gradient_view;

      /////
      // Apply the step w <- w + v and, in the same pass, accumulate sum_i |v_i|.
      // The views are captured by value (they are shallow handles); `mutable` allows writing through w_view.
      // NOTE(review): the residue below is not finite when relaxation == 0 or w is empty.
      const auto velocityL1Norm = Algorithms::reduce< DeviceType >(
         static_cast< IndexType >( 0 ),
         w_view.getSize(),
         [ = ] __cuda_callable__( IndexType i ) mutable
         {
            w_view[ i ] += v_view[ i ];
            return abs( v_view[ i ] );
         },
         TNL::Plus() );
      this->setResidue( velocityL1Norm / ( this->relaxation * static_cast< RealType >( w.getSize() ) ) );

      if( ! this->nextIteration() )
         return this->checkConvergence();

      /////
      // Check the stop condition
      if( this->getConvergenceResidue() != 0.0 && this->getResidue() < this->getConvergenceResidue() )
         return true;
   }
}

}  // namespace TNL::Solvers::Optimization
