// Discrete Fast Fourier Transform
// written by Thorsten Reinecke, 2003-07-04
// last change: 2003-11-11

/*! @file
 * @brief
 * Discrete Fast Fourier Transform
 */


#include "modulo.cc"

namespace polynomial
{

class CDFT_base : private ForbidAssignment
{
 public:
  // largest transform length supported by this object (a power of two, at least 2)
  const unsigned int max_size;

  // static helper function (for constructor to initialize max_size):
  // returns the smallest power of two >= _size (at least 2).
  // Doubling stops at the largest representable power of two, so the loop
  // cannot overflow to 0 and spin forever for huge _size.
  static const unsigned int calc_max_size(const unsigned int _size)
   {
     unsigned int i=2;
     while (i<_size && i<=(~0u>>1)) i<<=1;
     return i;
   }

  // this procedure calculates "count" valid primes (beginning at "Start") suitable for
  // doing dft (with recursion depth "Depth") and returns them in a newly created Polynom "primes",
  // which needs to be initially an empty reference (NULL).
  static void get_valid_primes_for(TPolynom &primes, const unsigned int count,
                                   const mpz_t Start, const unsigned Depth);

 private:
  // temporary scratch variable that every routine may use...
  mpz_t h;
  // inverse[k] holds 2^(-k) (mod M), k=0..31 (filled by calc_roots_and_inverse)
  mpz_t inverse[32];
  // w[i] holds the i-th power of the root of unity (mod M) used by the transform
  TPolynom w;

  // allocates and initializes all GMP workspace; shared by both constructors
  void init_workspace()
   {
     mpz_init(M);
     mpz_init(h);
     w = new mpz_t[max_size];
     for (unsigned int i=0; i<max_size; ++i) mpz_init(w[i]);
     for (unsigned int i=0; i<32; ++i) mpz_init(inverse[i]);
   }

 protected:
  int size;  // transform length of the current multiplication (set via use_size)
  mpz_t M;   // modulus: all transform arithmetic is done (mod M)!

  inline const mpz_t& invpow2(const unsigned int i) const { return inverse[i]; }

  // returns the smallest power of two >= input_size (at least 2);
  // terminates the program if that would exceed max_size.
  inline const unsigned int use_size(const unsigned int input_size) const
  {
    unsigned int i = 2;
    // never double beyond max_size (a power of two): this keeps the loop from
    // overflowing, e.g. for a huge input_size stemming from a negative int
    while (i<input_size && i<max_size) i<<=1;
    if (i<input_size)
     {
       MARK;
       cerr << "input_size is invalid!" << endl;
       exit(1);
     }
    return i;
  }

  void calc_roots_and_inverse();
  void convolute(const TPolynom p, const unsigned int n, const unsigned int k);

 public:
  // creates the workspace only; M stays 0 until a derived class sets it up
  explicit CDFT_base(const unsigned int _size)
   : max_size(calc_max_size(_size)), size(0)
   {
     MARK; cout << "CDFT_base-constructor IN for maximum size=" << max_size << endl;
     init_workspace();
     MARK; cout << "CDFT_base-constructor OUT" << endl;
   }

  // creates the workspace and prepares the transform for the given modulus _M
  CDFT_base(const unsigned int _size, const mpz_t _M)
   : max_size(calc_max_size(_size)), size(0)
   {
     MARK; cout << "CDFT_base-constructor IN for maximum size=" << max_size << endl;
     init_workspace();
     mpz_set(M,_M);
     calc_roots_and_inverse();
     MARK; cout << "CDFT_base-constructor OUT" << endl;
   }

  virtual ~CDFT_base()
   {
     for (int i=31; i>=0; --i) mpz_clear(inverse[i]);

     for (unsigned int i=max_size; i-- > 0; ) mpz_clear(w[i]);
     delete [] w;

     mpz_clear(h);
     mpz_clear(M);
   }

  const int dftmul(const TPolynom R, const int kR,
                   const TconstPolynom P1, const int k1,
                   const TconstPolynom P2, const int k2);
};



class CDFT : public CDFT_base
{
 private:
  mpz_t N; // the modulus the caller actually works with

  // constructor helper: chooses the prime modulus M, the roots of unity
  // and the inverse powers of two
  void calc_field_and_roots_and_inverse();

 protected:
  // common implementation behind mul, mulmod, square and squaremod
  const int _mul(const TPolynom R, const int kR,
                 const TconstPolynom P1, const int k1,
                 const TconstPolynom P2, const int k2,
                 const bool reduce_result_modN);

 public:

  // DFT (discrete Fourier transform) for multiplying two polynomials whose
  // product must not have more than max_size coefficients.
  // _N specifies the modulus in which the result still has to be correct.
  CDFT(const unsigned int _size, const mpz_t _N)
   : CDFT_base(_size)
   {
     MARK; cout << "CDFT-constructor IN for maximum size=" << max_size << endl;
     mpz_init_set(N,_N);
     calc_field_and_roots_and_inverse();
     MARK; cout << "CDFT-constructor OUT" << endl;
   }

  virtual ~CDFT() { mpz_clear(N); }

  // modulus this object was set up for
  inline const mpz_t& get_N(void) const { return N; }

  // R := P1*P2 with coefficients reduced (mod M) only; returns k1+k2-1
  inline const int mul(const TPolynom R, const int kR,
                       const TconstPolynom P1, const int k1,
                       const TconstPolynom P2, const int k2)
   {
     const bool reduce_modN = false;
     return _mul(R,kR,P1,k1,P2,k2,reduce_modN);
   }

  // R := P1*P2 with coefficients additionally reduced (mod N); returns k1+k2-1
  inline const int mulmod(const TPolynom R, const int kR,
                          const TconstPolynom P1, const int k1,
                          const TconstPolynom P2, const int k2)
   {
     const bool reduce_modN = true;
     return _mul(R,kR,P1,k1,P2,k2,reduce_modN);
   }

  // R := P*P with coefficients reduced (mod M) only; returns 2k-1
  inline const int square(const TPolynom R, const int kR,
                          const TconstPolynom P, const int k)
   {
     return mul(R,kR,P,k,P,k);
   }

  // R := P*P with coefficients additionally reduced (mod N); returns 2k-1
  inline const int squaremod(const TPolynom R, const int kR,
                             const TconstPolynom P, const int k)
   {
     return mulmod(R,kR,P,k,P,k);
   }

};



// ----------------- Implementation ---------------------------------------


void CDFT_base::get_valid_primes_for(TPolynom &primes, const unsigned int count,
                                     const mpz_t Start, const unsigned Depth)
{
  // Calculates "count" distinct (probable) primes p with p == 1 (mod Depth),
  // all greater than 2^(ld(Start)+1)*Depth, suitable for a dft of length Depth.
  // They are returned in a newly created Polynom "primes", which must initially
  // be NULL. The caller owns the result: mpz_clear() each of the "count"
  // entries and delete[] the array.

  if (primes!=NULL)
   {
     cout << __FILE__ << ", " << __FUNCTION__ << ": line " <<  __LINE__ << endl;
     cerr << "First parameter is a call by reference," << endl;
     cerr << "it should initilly point to NULL (to avoid memory-leaks)," << endl;
     cerr << "because a new pointer to new data will be generated and" << endl;
     cerr << "there is no need for initially data pointed by \"primes!\"" << endl;
     exit(1);
   }
  if (Depth==0)
   {
     // every candidate would equal M: the search below could never advance
     MARK; cerr << "Depth must be positive!" << endl;
     exit(1);
   }

  MARK;
  primes = new mpz_t[count];
  for (unsigned int i=0; i<count; ++i) mpz_init(primes[i]);

  mpz_t x,M;
  mpz_init(x); mpz_init(M);

  // magnitude ld(Start)+1, plus "safety bits" in case some polynomial
  // coefficients later turn out to be too large...
  const size_t bits = mpz_sizeinbase(Start,2)+1;
  // M := 2^bits * Depth + 1, so M == 1 (mod Depth); all candidates keep this property
  mpz_set_ui(M,1); mpz_mul_2exp(M,M,bits); mpz_mul_ui(M,M,Depth);
  mpz_add_ui(M,M,1);

  const unsigned int interval = 10000;

  for (unsigned int bisher=0; bisher<count; ++bisher)
   {
     cerr << bisher+1 << "/" << count << ": ";
     do
      {
        // sieve[i] -> true, if M+i*Depth is divisible by a small odd prime (composite)
        // sieve[i] -> false: unknown
        bool sieve[interval] = { false };
        for (unsigned int p=3; p<1000; p+=2) if (numtheory::is_prime(p))
         {
           // If p divides Depth, then M == 1 (mod p) and no candidate is
           // divisible by p; skipping also avoids an endless residue walk.
           const unsigned int d = Depth%p;
           if (d==0) continue;
           // find the first i with M+i*Depth == 0 (mod p); reducing Depth
           // modulo p first avoids overflowing r+Depth
           unsigned int i = 0;
           unsigned int r = mpz_fdiv_ui(M,p);
           while (r) { r=(r+d)%p; ++i; }
           while (i<interval) { sieve[i]=true; i+=p; }
         }
        sieve[interval-1]=false; // sentinel: the scan for unmarked entries stays in range
        unsigned int i=0;
        while(i<interval)
         {
           while(sieve[i]) ++i;
           mpz_set_ui(x,Depth); mpz_mul_ui(x,x,i); mpz_add(x,x,M);
           cerr << i << " ";
           if (mpz_probab_prime_p(x,10)) break;
           ++i;
         }
        // advance M to the prime found (or past the whole interval)
        mpz_set_ui(x,Depth); mpz_mul_ui(x,x,i); mpz_add(M,M,x);
        cerr << " # +" << i << endl;
      } while (mpz_probab_prime_p(M,10)==0);
     mpz_set(primes[bisher],M);
     // step past the prime just stored; otherwise the next round would find it again
     mpz_add_ui(M,M,Depth);
   }

  mpz_clear(M); mpz_clear(x);
}


void CDFT::calc_field_and_roots_and_inverse()
{
  mpz_mul(M,N,N); mpz_mul_ui(M,M,max_size);
  mpz_mul_ui(M,M,4); // + "Sicherheitsbits", falls spter bei Polynomen einige Koeffizienten zu gro sein sollten...

  TPolynom MyField = NULL;
  get_valid_primes_for(MyField,1,M,max_size);
  mpz_set(M,MyField[0]);
  mpz_clear(MyField[0]); delete [] MyField;
  calc_roots_and_inverse();
}

void CDFT_base::calc_roots_and_inverse()
{
  mpz_t x,e;
  mpz_init(x); mpz_init(e);

  if (!mpz_probab_prime_p(M,10))
   {
     MARK;
     cerr << "invalid M for dft!" << endl;
     exit(1);
   }

  mpz_sub_ui(e,M,1);
  if (mpz_div_ui(e,e,max_size)!=0)
   {
     MARK;
     cerr << "invalid M for dft!" << endl;
     exit(1);
   }

  unsigned int r=911;
try_r:
  mpz_set_ui(x,r); mpz_powm(w[1],x,e,M); mpz_powm_ui(x,w[1],max_size/2,M);
  mpz_add_ui(x,x,1); mpz_mod(x,x,M);
  //cout << "Restklasse " << "M="; mpz_out_str(stdout,10,M); cout << endl;
  //cout << "-1? ";  mpz_out_str(stdout,10,x); cout << endl << endl;
  if (mpz_cmp_ui(x,0)!=0)
   {
     r+=2; if (r<2000) goto try_r;
     cerr << "unable to find valid roots..." << endl;
     exit(1);
   }

  // ansonsten ist w[1] die erste Hauptwurzel...
  mpz_set_ui(w[0],1); // w^0 = 1
  for (unsigned int i=2; i<max_size; ++i)
   {
     mpz_mul(x,w[i-1],w[1]); mpz_mod(w[i],x,M);
     if (mpz_cmp_ui(w[i],1)==0)
      {
        MARK;
        cerr << "invalid roots..." << endl;
        exit(1);
      }
   }

  mpz_clear(x); mpz_clear(e);

  // finally precalculate inverse of 2^k (mod M), k=0..31
  mpz_set_ui(inverse[0],1);
  mpz_set_ui(h,2);
  if (!mpz_invert(h,h,M)) { MARK; cerr << "inverse of 2 does not exist!" << endl; exit(1); }
  mpz_mod(h,h,M); mpz_set(inverse[1],h);
  for (int i=2; i<32; ++i)
   {
     mpz_mul(h,inverse[i-1],inverse[1]); mpz_mod(h,h,M);
     mpz_init_set(inverse[i],h);
   }
}

void CDFT_base::convolute(const TPolynom p, const unsigned int n, const unsigned int k)
{
  // In-place recursive radix-2 transform of the n coefficients p[k..k+n-1]
  // using the roots stored in w[]. Expects n to be a power of two >= 2
  // (as produced by use_size). Intermediate values are not fully reduced:
  // they may be negative or exceed M, so callers reduce (mod M) afterwards.
  const unsigned int nh = n>>1;
  if (n==2)
   {
     // butterfly: (a,b) -> (a+b, a-b)
#if 0
     mpz_set(h,p[k]);
     mpz_add(p[k],p[k],p[k+1]); mpz_sub(p[k+1],h,p[k+1]);
#else
     mpz_add(p[k],p[k],p[k+1]); mpz_mul_2exp(p[k+1],p[k+1],1); mpz_sub(p[k+1],p[k],p[k+1]);
#endif
   }
  else
   {
     {
       // Reorder: even-indexed coefficients into the first half, odd-indexed
       // ones into the second half. temp only serves as swap storage. It is
       // value-initialized (all-zero structs) so that mpz_swap never reads
       // indeterminate values. Its entries are never handed to other GMP
       // routines, so they need neither mpz_init nor mpz_clear.
       mpz_t* const temp = new mpz_t[nh]();
       for (unsigned int i=0, j=k; i<nh; ++i)
        {
          mpz_swap(p[i+k],p[j++]);
          mpz_swap(temp[i],p[j++]);
        }
       for (unsigned int i=0; i<nh; ++i) mpz_swap(p[i+k+nh],temp[i]);
       delete [] temp;
     }

     convolute(p,nh,k);
     convolute(p,nh,nh+k);

     // combine both halves: (a,b) -> (a + w^j*b, a - w^j*b), where the root
     // stride dj picks the powers of an n-th root of unity from w[]
     const unsigned int dj = max_size/n;
     for (unsigned int i=k,j=0; i<nh+k; ++i,j+=dj)
      {
        mpz_mul(h,w[j],p[i+nh]); mpz_mod(h,h,M);
        mpz_sub(p[i+nh],p[i],h);
        mpz_add(p[i],p[i],h);
      }
   }
}


const int CDFT_base::dftmul(const TPolynom R, const int kR,
                            const TconstPolynom P1, const int k1,
                            const TconstPolynom P2, const int k2)
{
  const size_t ld_M = mpz_sizeinbase(M,2); // ld(any input coefficient)>ld_N -> result could be wrong!!
  const unsigned int estimated_memusage_in_bits = mpz_sizeinbase(M,2)*2+5; // for optimizing mpz-heap allocation
  const int result_size = k1+k2-1;

  size = use_size(result_size);

  // sanity checks
  if (result_size>size)
   {
     MARK; cerr << "(result_size>size)" << endl;
     exit(1);
   }
  if (kR<result_size)
   {
     MARK; cerr << "bereitgestelltes Resultatpolynom ist zu klein!" << endl;
     exit(1);
   }

  // wenn in R bereitgestellter Platz fr convolution ausreicht, R auch nehmen,
  // ansonsten temporren Speicher verwenden...
  const TPolynom p = (kR>=size && mpz_sizeinbase(R[0],2)>=estimated_memusage_in_bits) ? R : new mpz_t[size];

  for (int i=0; i<size; ++i) mpz_init2(p[i],estimated_memusage_in_bits);
  for (int i=0; i<k1; ++i) mpz_set(p[i],P1[i]); // get first multiplicant

  // and this is done by mpz_init already :-)
  // for (int i=k1; i<size; ++i) mpz_set_ui(p[i],0); // padding with zeros


#if 0
  for (int i=0; i<k1; ++i) mpz_mod(p[i],p[i],M); // just to be on the safe side
#else
  {
   int j=0;
   for (int i=0; i<k1; ++i)
    {
     if (mpz_sgn(p[i])<0)
      {
        ++j; mpz_mod(p[i],p[i],M); // just to be on the safe side
      }
     else
      if (mpz_sizeinbase(p[i],2)>ld_M)
       {
         ++j; mpz_mod(p[i],p[i],M); // just to be on the safe side
       }
    }
#if defined(VERBOSE)
   if (j) cout << "P1: " << j << " out of " << k1 << " coefficients corrected." << endl;
#endif
  }
#endif

  convolute(p,size,0); // do fft

  const TPolynom q = (P1==P2 && k1==k2) ? p : new mpz_t[size]; // fr Spezialfall p*q = p^2
  if (p!=q)
   {
     for (int i=0; i<size; ++i) mpz_init2(q[i],estimated_memusage_in_bits);
     for (int i=0; i<k2; ++i) mpz_set(q[i],P2[i]); // get second multiplicant
     // is done already by init2!! for (int i=k2; i<size; ++i) mpz_init(q[i]); // padding with zeros

#if 0
     for (int i=0; i<k2; ++i) mpz_mod(q[i],q[i],M); // just to be on the safe side
#else
     {
      int j=0;
      for (int i=0; i<k2; ++i)
       {
        if (mpz_sgn(q[i])<0)
         {
           ++j; mpz_mod(q[i],q[i],M); // just to be on the safe side
         }
        else
         if (mpz_sizeinbase(q[i],2)>ld_M)
          {
            ++j; mpz_mod(q[i],q[i],M); // just to be on the safe side
          }
       }
#if defined(VERBOSE)
      if (j) cout << "P2: " << j << " out of " << k2 << " coefficients corrected." << endl;
#endif
     }
#endif

      convolute(q,size,0); // do fft
    }

  // IMPORTANT: store result for last fft in p (to save memory space)
  for (int i=0; i<size; ++i)
   {
     mpz_mul(p[i],p[i],q[i]);
     mpz_mod(p[i],p[i],M);
   }
   // the result will be in p now!!!

  if (q!=p)
   {
     // we can delete the temporary polynom q
     for (int i=0; i<size; ++i) mpz_clear(q[i]);
     delete [] q;
   }

  convolute(p,size,0); // do fft
  for (int i=1; i<size/2; ++i) mpz_swap(p[i],p[size-i]);

  int inv_index=0;
  for (int i=1; i<size; i<<=1) ++inv_index;

  for (int i=0; i<result_size; ++i)
   {
     mpz_mul(p[i],p[i],invpow2(inv_index));
     mpz_mod(p[i],p[i],M);
   }

  if (p!=R)
   {
     // Resultat kopieren und temporres Polynom freigeben
     for (int i=result_size-1; i>=0; --i) mpz_set(R[i],p[i]); // schneller wre mpz_swap, aber: Speicherfragmentierung?
     if (size-result_size>10) cout << size-result_size << " Auswertungen gespart..." << endl;
     // release temporary polynom
     for (int i=0; i<size; ++i) mpz_clear(p[i]);
     delete [] p;
   }
  //for (int i=result_size; i<kR; ++i) mpz_set_ui(R[i],0); // fhrende Nullen nicht notwendig (wegen result_size-Rckgabe)
  return result_size; // return size of result
}


const int CDFT::_mul(const TPolynom R, const int kR,
               const TconstPolynom P1, const int k1,
               const TconstPolynom P2, const int k2,
               const bool reduce_result_modN)
{
  const size_t ld_N = mpz_sizeinbase(N,2); // ld(any input coefficient)>ld_N -> result could be wrong!!
  const unsigned int estimated_memusage_in_bits = mpz_sizeinbase(M,2)*2+5; // for optimizing mpz-heap allocation
  const int result_size = k1+k2-1;

  size = use_size(result_size);

  // sanity checks
  if (result_size>size)
   {
     MARK; cerr << "(result_size>size)" << endl;
     exit(1);
   }
  if (kR<result_size)
   {
     MARK; cerr << "bereitgestelltes Resultatpolynom ist zu klein!" << endl;
     exit(1);
   }

  // wenn in R bereitgestellter Platz fr convolution ausreicht, R auch nehmen,
  // ansonsten temporren Speicher verwenden...
  const TPolynom p = (kR>=size && mpz_sizeinbase(R[0],2)>=estimated_memusage_in_bits) ? R : new mpz_t[size];

  // IMPORTANT!!! all coefficients of P1 and P2 must be
  // between 0 and N-1!
  // This is *very* important for DFFT!!

  for (int i=0; i<size; ++i) mpz_init2(p[i],estimated_memusage_in_bits);
  for (int i=0; i<k1; ++i) mpz_set(p[i],P1[i]); // get first multiplicant

  // and this is done by mpz_init already :-)
  // for (int i=k1; i<size; ++i) mpz_set_ui(p[i],0); // padding with zeros


#if 0
  for (int i=0; i<k1; ++i) mpz_mod(p[i],p[i],N); // just to be on the safe side
#else
  {
   int j=0;
   for (int i=0; i<k1; ++i)
    {
     if ( mpz_sgn(p[i])<0 || mpz_sizeinbase(p[i],2)>ld_N )
      {
        ++j; mpz_mod(p[i],p[i],N); // just to be on the safe side
      }
    }
#if defined(VERBOSE)
   if (j) cout << "P1: " << j << " out of " << k1 << " coefficients corrected." << endl;
#endif
  }
#endif

  convolute(p,size,0); // do fft

  const TPolynom q = (P1==P2 && k1==k2) ? p : new mpz_t[size]; // fr Spezialfall p*q = p^2
  if (p!=q)
   {
     // IMPORTANT!!! all input coefficients of P1 and P2 must be
     // between 0 and N-1!
     // This is *very* important for DFFT!!

     for (int i=0; i<size; ++i) mpz_init2(q[i],estimated_memusage_in_bits);
     for (int i=0; i<k2; ++i) mpz_set(q[i],P2[i]); // get second multiplicant
     // is done already by init2!! for (int i=k2; i<size; ++i) mpz_init(q[i]); // padding with zeros

#if 0
     for (int i=0; i<k2; ++i) mpz_mod(q[i],q[i],N); // just to be on the safe side
#else
     {
      int j=0;
      for (int i=0; i<k2; ++i)
       {
        if (mpz_sgn(q[i])<0 || mpz_sizeinbase(q[i],2)>ld_N )
         {
           ++j; mpz_mod(q[i],q[i],N); // just to be on the safe side
         }
       }
#if defined(VERBOSE)
      if (j) cout << "P2: " << j << " out of " << k2 << " coefficients corrected." << endl;
#endif
     }
#endif

      convolute(q,size,0); // do fft
    }

  // IMPORTANT: store result for last fft in p (to save memory space)
  for (int i=0; i<size; ++i)
   {
     mpz_mul(p[i],p[i],q[i]);
     mpz_mod(p[i],p[i],M);
   }
   // the result will be in p now!!!

  if (q!=p)
   {
     // we can delete the temporary polynom q
     for (int i=0; i<size; ++i) mpz_clear(q[i]);
     delete [] q;
   }

  convolute(p,size,0); // do fft
  for (int i=1; i<size/2; ++i) mpz_swap(p[i],p[size-i]);

  int inv_index=0;
  for (int i=1; i<size; i<<=1) ++inv_index;

  for (int i=0; i<result_size; ++i)
   {
     mpz_mul(p[i],p[i],invpow2(inv_index));
     mpz_mod(p[i],p[i],M);
     if (reduce_result_modN) mpz_mod(p[i],p[i],N);
   }

  if (p!=R)
   {
     // Resultat kopieren und temporres Polynom freigeben
     for (int i=result_size-1; i>=0; --i) mpz_set(R[i],p[i]); // schneller wre mpz_swap, aber: Speicherfragmentierung?
     if (size-result_size>10) cout << size-result_size << " Auswertungen gespart..." << endl;
#ifdef DEBUG
     // sanity check
     for (int i=result_size; i<size; ++i)
      {
        mpz_mul(p[i],p[i],invpov2(inv_index));
        mpz_mod(p[i],p[i],M);
        if (mpz_cmp_ui(p[i],0)!=0)
         {
           MARK;
           cerr << "These values should be ZERO!" << endl;
         }
      }
#endif
     // release temporary polynom
     for (int i=0; i<size; ++i) mpz_clear(p[i]);
     delete [] p;
   }
  //for (int i=result_size; i<kR; ++i) mpz_set_ui(R[i],0); // fhrende Nullen nicht notwendig (wegen result_size-Rckgabe)
  return result_size; // return size of result
}



// ------------------------------------------------------------------------



// the following parts of the program are more specific
// to our application...

// DFT implementation used by the helpers below
typedef CDFT TDFT;
// pointer to a DFT object (the type handed out by get_dft)
typedef TDFT* PDFT;


#if 1
// Tuning heuristic: should polynomials with k1 and k2 coefficients be
// multiplied via the DFT rather than by another method?
inline bool dft_mul_is_recommended(const int k1, const int k2)
{
  // tune...
  if (k1<14000 || k2<14000) return false;
  if (k1>25000 && k2>25000) return true;
  // both operands have at most 16384 coefficients: the product (at most
  // 32767 coefficients) fits exactly into a transform of length 32768
  if (k1<=16384 && k2<=16384) return true;
  return false;
}

// Tuning heuristic: should a polynomial with k coefficients be squared via the DFT?
inline bool dft_square_is_recommended(const int k)
{
  // tune...
  return k>=8192;
}

#else

// alternative (disabled) thresholds: use the DFT for almost every size
inline bool dft_mul_is_recommended(const int k1, const int k2)
{
  return k1+k2>=4;
}

inline bool dft_square_is_recommended(const int k)
{
  return k>=4;
}

#endif



const PDFT get_dft(const unsigned int n, const mpz_t m)
{
 // Returns a cached DFT object for polynomial products (mod m) with up to
 // n coefficients. The object is created on first use (returns NULL if n==0
 // and none exists yet). It is renewed whenever n exceeds its capacity or the
 // modulus m differs from the cached one. If the modulus differs and n==0,
 // the cached object is released and NULL is returned. The object is owned
 // by this function: callers must not delete it.

 // minimum transform length for every (re)creation: renewing recomputes the
 // prime field and roots, so a generous size avoids repeated renewals
 const unsigned int min_size = 32768;

 static PDFT pdft = NULL;
 if (!pdft)
  {
    if (n==0) return NULL;
    pdft = new TDFT(n>min_size ? n : min_size,m);
  }

 if ( n > pdft->max_size // resize is necessary!
       ||
      mpz_cmp(pdft->get_N(),m)!=0 // modulo-base has changed!
    )
  {
    delete pdft; pdft=NULL;
    if (n>0)
     {
       cout << "renewing dft-object..." << endl;
       pdft = new TDFT(n>min_size ? n : min_size,m);
     }
    else
     {
       // if n==0, then no new dft-object will be created...
       cout << "dft-obejct is released..." << endl;
     }
  }
 return pdft;
}

} // namespace polynomial

