// The following functions compute square roots modulo prime numbers.
// Please note:
//  - primality is not checked!
//    (if necessary, do this beforehand, e.g. using is_probab_prime)
//  - it is not checked whether the square root exists!
//    (if necessary, check it beforehand using legendre(Radikant,Primenumber)==1)

// written by Thorsten Reinecke 1998,1999
// last change: 2005-02-28

/*! @file
 * @brief
 * compute square roots modulo prime numbers (mpz numbers)
 */


// remark: Lucas sequences can also be evaluated without dynamic programming
// (caching) by scanning the binary digits of the index, which needs less memory.
// This isn't implemented yet.

/*!
 * @brief Workspace for computing square roots modulo a prime p = 1 (mod 4)
 *        via Lucas sequences (see lucas()).
 *
 * Holds the Lucas parameters P, Q and P^-1 (mod p), plus a cache of
 * (index, value) pairs for already computed sequence values.  P, Q and P^-1
 * are mpz_init'ed in the constructor; cache slots are mpz_init'ed lazily by
 * lucasv() and counted in lucas_cache_mpz_init_index, so that the destructor
 * clears exactly the slots that were initialized.
 *
 * The object owns raw GMP resources, so it must not be copied: a copy would
 * share the same mpz_t limb pointers and both destructors would clear them.
 * Copy construction and copy assignment are therefore declared private and
 * intentionally left undefined.
 *
 * Note: the cache makes the object fairly large (2 * lucas_cache_mpz_size
 * mpz_t structs held inline).
 */
class Clucas_capsule_mpz
{
 private:
  static const int lucas_cache_mpz_size = 500;
   // maximum number of (index, value) pairs the cache can hold

  mpz_t lucas_cache_mpz [lucas_cache_mpz_size][2];
   // lucas_cache_mpz[i][0] holds the index m, lucas_cache_mpz[i][1] the value V_m

  int lucas_cache_mpz_index;
   // position for the next value to place in cache (= number of live entries)

  int lucas_cache_mpz_init_index;
   // number of cache slots whose mpz_t have already been mpz_init'ed

  mpz_t lucas_p_mpz, lucas_q_mpz, lucas_p_inv_mpz;
   // Lucas parameters P and Q, and the inverse of P modulo the prime

  void lucasv(mpz_t res, const mpz_t Primenumber, mpz_t m);

  // non-copyable (owns mpz_t resources): declared but deliberately not defined
  Clucas_capsule_mpz(const Clucas_capsule_mpz&);
  Clucas_capsule_mpz& operator=(const Clucas_capsule_mpz&);

 public:
  Clucas_capsule_mpz() : lucas_cache_mpz_index(0), lucas_cache_mpz_init_index(0)
  {
    mpz_init(lucas_p_mpz); mpz_init(lucas_q_mpz); mpz_init(lucas_p_inv_mpz);
  }
  ~Clucas_capsule_mpz()
  {
    mpz_clear(lucas_p_mpz); mpz_clear(lucas_q_mpz); mpz_clear(lucas_p_inv_mpz);
    // release only the cache slots that lucasv() actually initialized
    for (int i=0; i<lucas_cache_mpz_init_index; ++i)
     {
       mpz_clear(lucas_cache_mpz[i][0]);
       mpz_clear(lucas_cache_mpz[i][1]);
     }
  }

  /*!
   * @brief store a square root of Radikant modulo Primenumber in v
   *        (Primenumber must be = 1 (mod 4))
   */
  void lucas(mpz_t v, const mpz_t Radikant, const mpz_t Primenumber);
};
 

/*!
 * @brief Compute the Lucas sequence value V_m(P,Q) modulo Primenumber.
 *
 * P = lucas_p_mpz, Q = lucas_q_mpz and P^-1 = lucas_p_inv_mpz must have been
 * set up by lucas() beforehand.  The recurrences used are
 *   V_0 = 2, V_1 = P
 *   m even: V_m = V_{m/2}^2 - 2*Q^(m/2)
 *   m odd:  V_m = (Q*V_{m-1} + V_{m+1}) * P^-1
 *           (rearranged from V_{m+1} = P*V_m - Q*V_{m-1})
 *
 * Values for even indices are stored in lucas_cache_mpz.  An entry is
 * removed from the cache as soon as it has been hit once.
 *
 * @param res         receives V_m mod Primenumber
 * @param Primenumber the modulus
 * @param m           the index; it is modified temporarily during the
 *                    recursion but holds its original value again when the
 *                    function returns (which is why it is not const).
 *                    NOTE(review): res must not be the same variable as m,
 *                    since res is written while m is still needed.
 */
void Clucas_capsule_mpz::lucasv(mpz_t res, const mpz_t Primenumber, mpz_t m)
{
  // base cases V_0 = 2 and V_1 = P
  if (mpz_cmp_ui(m,0)==0) { mpz_set_ui(res,2); return; }
  if (mpz_cmp_ui(m,1)==0) { mpz_set(res,lucas_p_mpz); return; }
  
  // value already known?
  // (linear search from the most recently inserted entry backwards)
  for (int i=lucas_cache_mpz_index-1; i>=0; --i)
    {
      if (mpz_cmp(lucas_cache_mpz[i][0],m)==0)
       { 
         //cout << "Cache Hit!" << i << "/" << lucas_cache_mpz_index << ": " << m << endl;
         mpz_set(res,lucas_cache_mpz[i][1]);
         // Evict the hit entry: swap it with the last live entry and shrink
         // the live region.  The mpz_t stays initialized and is reused later.
         // Presumably each cached value is looked up at most once, so evicting
         // keeps the linear search short.
         --lucas_cache_mpz_index; // this one does not hit anymore
         mpz_swap(lucas_cache_mpz[i][0],lucas_cache_mpz[lucas_cache_mpz_index][0]);
         mpz_swap(lucas_cache_mpz[i][1],lucas_cache_mpz[lucas_cache_mpz_index][1]);
         return;
       }
    }
  
  if (mpz_odd_p(m)) // m odd
    {
      // V_m = (Q*V_{m-1} + V_{m+1}) * P^-1  (mod p)
      // m is stepped to m-1, then to m+1, then back to m.
      mpz_t h1;
      mpz_init(h1); mpz_sub_ui(m,m,1); lucasv(h1,Primenumber,m);
      mpz_mul(res,h1,lucas_q_mpz);
      mpz_add_ui(m,m,2); lucasv(h1,Primenumber,m); mpz_sub_ui(m,m,1);
      mpz_add(res,res,h1); mpz_mod(res,res,Primenumber);
      mpz_mul(res,res,lucas_p_inv_mpz); mpz_mod(res,res,Primenumber);
      mpz_clear(h1);
    }
  else // m even
    {
      // V_m = V_{m/2}^2 - 2*Q^(m/2)  (mod p)
      // m is halved for the recursion and doubled back afterwards (exact, since m is even).
      mpz_t h1;
      mpz_init(h1); mpz_div_ui(m,m,2);
      lucasv(h1,Primenumber,m);

      mpz_powm(res,lucas_q_mpz,m,Primenumber); mpz_mul_ui(res,res,2);
      mpz_mod(res,res,Primenumber);
      mpz_mul_ui(m,m,2);
      
      mpz_mul(h1,h1,h1); mpz_mod(h1,h1,Primenumber);
      // adding p keeps h1 - res non-negative before the final reduction
      mpz_add(h1,h1,Primenumber); mpz_sub(res,h1,res);
      mpz_mod(res,res,Primenumber);
      mpz_clear(h1);
  
      // place value into cache...
      // (only even indices are cached: an odd index is always resolved through
      //  recursive calls for the even indices m-1 and m+1, which are cached.)
      if (lucas_cache_mpz_index<lucas_cache_mpz_size)
        {
          // first use of this slot: its mpz_t must be initialized before use
          if (lucas_cache_mpz_init_index==lucas_cache_mpz_index)
           {
             mpz_init(lucas_cache_mpz[lucas_cache_mpz_init_index][0]);
             mpz_init(lucas_cache_mpz[lucas_cache_mpz_init_index++][1]);
           }
          mpz_set(lucas_cache_mpz[lucas_cache_mpz_index][0],m);
          mpz_set(lucas_cache_mpz[lucas_cache_mpz_index++][1],res);
        } else cerr << "Lucas-Cache needs to be increased!" << endl; // cache full: value is simply not cached
    }
}


/*!
 * @brief Compute a square root of Radikant modulo Primenumber using Lucas sequences.
 *
 * Sets Q = Radikant and searches for the smallest P >= 2 such that the
 * discriminant P^2 - 4Q is a quadratic non-residue modulo Primenumber.
 * Then V_{(p+1)/2}(P,Q) is twice a square root of Q (mod p), so it is
 * evaluated and halved modulo p.
 *
 * @param v           receives the square root, in [0, Primenumber)
 * @param Radikant    the radicand Q; it is assumed to be a quadratic residue
 *                    (this is not verified)
 * @param Primenumber the prime modulus p; it must be = 1 (mod 4), otherwise
 *                    an error is printed and the program exits with status 1.
 *                    Primality is not verified.
 */
void Clucas_capsule_mpz::lucas(mpz_t v, const mpz_t Radikant, const mpz_t Primenumber)
{
  if (mpz_remainder_ui(Primenumber,4)!=1)
    { cerr << "Fehler in Lucassequenz Primenumber%4<>1!: " << Primenumber << endl; exit(1); }
  
  /* determine lucas_p_mpz: smallest P >= 2 with (P^2 - 4Q | p) == -1 */
  mpz_set(lucas_q_mpz,Radikant);
  mpz_set_ui(lucas_p_mpz,1);
 
  {
    mpz_t disc, four_q;
    mpz_init(disc); mpz_init(four_q);

    // 4Q mod p does not depend on P: compute it once, outside the search loop
    mpz_mul_ui(four_q,lucas_q_mpz,4); mpz_mod(four_q,four_q,Primenumber);

    do
     {
       mpz_add_ui(lucas_p_mpz,lucas_p_mpz,1);
       mpz_powm_ui(disc,lucas_p_mpz,2,Primenumber); // disc = P^2 mod p
       mpz_sub(disc,disc,four_q);                    // may be negative here ...
       mpz_mod(disc,disc,Primenumber);               // ... mpz_mod maps it into [0,p)
     } while (mpz_legendre(disc,Primenumber)!=-1);
    mpz_clear(disc); mpz_clear(four_q);
  }

  // compute inverse of lucas_p_mpz modulo Primenumber (needed by lucasv for odd indices);
  // the extra mpz_mod is a defensive normalization into [0,p)
  mpz_invert(lucas_p_inv_mpz,lucas_p_mpz,Primenumber); mpz_mod(lucas_p_inv_mpz,lucas_p_inv_mpz,Primenumber);
  
  // clear cache! (already initialized slots are kept and reused)
  lucas_cache_mpz_index=0;
  
  // now compute V_{(p+1)/2}(P,Q), which is twice the square root
  mpz_t m;
  mpz_init(m); mpz_add_ui(m,Primenumber,1); mpz_div_ui(m,m,2);
  lucasv(v,Primenumber,m);
  mpz_clear(m);

  // and divide by 2 (but modulo Primenumber!!):
  // p is odd, so adding p to an odd v makes it even without changing v mod p
  if (mpz_odd_p(v)) mpz_add(v,v,Primenumber);
  mpz_div_ui(v,v,2);
}


/*!
 * @brief Store a square root of Radikant_bel modulo the prime Primenumber in res.
 *
 * The radicand is first reduced into [0, Primenumber).  The method then
 * depends on the residue class of the modulus p:
 *  - p = 3 (mod 4): res = a^((p+1)/4)
 *  - p = 5 (mod 8): y = a^((p+3)/8); res = y if y^2 = a,
 *                   otherwise res = y * 2^((p-1)/4) mod p
 *  - otherwise:     a = 0 and a = 1 are returned as they are; larger values
 *                   are handled by Clucas_capsule_mpz::lucas (which accepts
 *                   only p = 1 (mod 4)).
 *
 * @param res          receives the square root; it may be the same variable
 *                     as Radikant_bel, because the radicand is copied first
 * @param Radikant_bel the radicand (any integer)
 * @param Primenumber  the modulus; neither primality nor the existence of
 *                     the root is checked
 */
void mpz_sqrtmod(mpz_t res, const mpz_t Radikant_bel, const mpz_t Primenumber)
{
  // a primality check is deliberately skipped here for the sake of efficiency.

  mpz_t a; // the radicand, reduced into [0, Primenumber)
  mpz_init(a);
  mpz_mod(a,Radikant_bel,Primenumber);

  if (mpz_remainder_ui(Primenumber,4)==3)
    {
      // p = 3 (mod 4): a single modular exponentiation suffices
      mpz_t exponent;
      mpz_init(exponent);
      mpz_add_ui(exponent,Primenumber,1);
      mpz_div_ui(exponent,exponent,4);          // (p+1)/4
      mpz_powm(res,a,exponent,Primenumber);
      mpz_clear(exponent);
    }
  else if (mpz_remainder_ui(Primenumber,8)==5)
    {
      // p = 5 (mod 8): a candidate root, possibly corrected by a factor
      mpz_t candidate;
      mpz_init(candidate);
      mpz_add_ui(candidate,Primenumber,3);
      mpz_div_ui(candidate,candidate,8);        // (p+3)/8
      mpz_powm(candidate,a,candidate,Primenumber);
      mpz_powm_ui(res,candidate,2,Primenumber); // candidate^2 mod p

      if (mpz_cmp(res,a)!=0)
        {
          // for a quadratic residue a, candidate^2 = -a here; multiplying by
          // 2^((p-1)/4), a square root of -1 when p = 5 (mod 8), fixes the sign
          mpz_t quarter;
          mpz_init(quarter);
          mpz_div_ui(quarter,Primenumber,4);    // floor(p/4) = (p-1)/4
          mpz_set_ui(res,2);
          mpz_powm(res,res,quarter,Primenumber);
          mpz_mul(res,res,candidate);
          mpz_mod(res,res,Primenumber);
          mpz_clear(quarter);
        }
      else
        {
          mpz_set(res,candidate);
        }
      mpz_clear(candidate);
    }
  else
    {
      mpz_set(res,a); // a is already reduced modulo Primenumber
      if (mpz_cmp_ui(res,1)>0) // 0 and 1 are their own square roots
        {
          // general case: Lucas sequences for Primenumber = 1 (mod 4)
          // (a static capsule would be faster, but would not be thread-safe)
          Clucas_capsule_mpz capsulated;
          capsulated.lucas(res,a,Primenumber);
        }
    }

#if 0
  // final check:
  cout << "mpz_sqrtmod: checking squares" << endl;
  mpz_t x;
  mpz_init(x); mpz_powm_ui(x,res,2,Primenumber);
  if (mpz_cmp(x,a)!=0)
   {
     cerr << "computing square root failed!" << endl;
     cerr << a << "," << x << endl;
     exit(1);
   }
  mpz_clear(x);
#endif

  mpz_clear(a);
}
