diff --git a/content/number-theory/ModularArithmetic.h b/content/number-theory/ModularArithmetic.h index 4a62aae5f..91b7c718c 100644 --- a/content/number-theory/ModularArithmetic.h +++ b/content/number-theory/ModularArithmetic.h @@ -8,23 +8,21 @@ */ #pragma once -#include "euclid.h" - const ll mod = 17; // change to something else struct Mod { - ll x; - Mod(ll xx) : x(xx) {} - Mod operator+(Mod b) { return Mod((x + b.x) % mod); } - Mod operator-(Mod b) { return Mod((x - b.x + mod) % mod); } - Mod operator*(Mod b) { return Mod((x * b.x) % mod); } + ll v; + Mod() : v(0) {} + Mod(ll vv) : v(vv % mod) {} + Mod operator+(Mod b) { return Mod((v + b.v) % mod); } + Mod operator-(Mod b) { return Mod(v - b.v + mod); } + Mod operator*(Mod b) { return Mod(v * b.v); } Mod operator/(Mod b) { return *this * invert(b); } - Mod invert(Mod a) { - ll x, y, g = euclid(a.x, mod, x, y); - assert(g == 1); return Mod((x + mod) % mod); - } + Mod invert(Mod a) { return a^(mod-2); } Mod operator^(ll e) { - if (!e) return Mod(1); - Mod r = *this ^ (e / 2); r = r * r; - return e&1 ? *this * r : r; + ll ans = 1, b = (*this).v; + for (; e; b = b * b % mod, e /= 2) + if (e & 1) ans = ans * b % mod; + return ans; } -}; + explicit operator ll() const { return v; } +}; \ No newline at end of file diff --git a/content/numerical/FastFourierTransform.h b/content/numerical/FastFourierTransform.h index 0238f69db..7ed1d86a0 100644 --- a/content/numerical/FastFourierTransform.h +++ b/content/numerical/FastFourierTransform.h @@ -3,22 +3,31 @@ * Date: 2019-01-09 * License: CC0 * Source: http://neerc.ifmo.ru/trains/toulouse/2017/fft2.pdf (do read, it's excellent) - Papers about accuracy: http://www.daemonology.net/papers/fft.pdf, http://www.cs.berkeley.edu/~fateman/papers/fftvsothers.pdf - For integers rounding works if $(|a| + |b|)\max(a, b) < \mathtt{\sim} 10^9$, or in theory maybe $10^6$. + Accuracy bound from http://www.daemonology.net/papers/fft.pdf * Description: fft(a, ...) computes $\hat f(k) = \sum_x a[x] \exp(2\pi i \cdot k x / N)$ for all $k$. Useful for convolution: \texttt{conv(a, b) = c}, where $c[x] = \sum a[i]b[x-i]$. For convolution of complex numbers or more than two vectors: FFT, multiply pointwise, divide by n, reverse(start+1, end), FFT back. - For integers, consider using a number-theoretic transform instead, to avoid rounding issues. + Let N be $\max(|a|,|b|)$. Is guaranteed safe as long as $N\log_2{N}\max(a)\max(b) < \mathtt{\sim} 10^{16}$ . + Consider using number-theoretic transform or FFTMod instead if precision is an issue. * Time: O(N \log N), where $N = |A|+|B|-1$ ($\tilde 1s$ for $N=2^{22}$) * Status: somewhat tested */ #pragma once typedef complex C; +typedef complex Cd; typedef vector vd; - -void fft(vector &a, vector &rt, vi& rev, int n) { +void fft(vector &a, int n, int L, vector &rt) { + vi rev(n); + rep(i,0,n) rev[i] = (rev[i / 2] | (i & 1) << L) / 2; + if (rt.empty()) { + rt.assign(n, 1); + for (int k = 2; k < n; k *= 2) { + Cd z[] = {1, polar(1.0, M_PI / k)}; + rep(i, k, 2 * k) rt[i] = Cd(rt[i / 2]) * z[i & 1]; + } + } rep(i,0,n) if (i < rev[i]) swap(a[i], a[rev[i]]); for (int k = 1; k < n; k *= 2) for (int i = 0; i < n; i += 2 * k) rep(j,0,k) { @@ -27,25 +36,19 @@ void fft(vector &a, vector &rt, vi& rev, int n) { C z(x[0]*y[0] - x[1]*y[1], x[0]*y[1] + x[1]*y[0]); /// exclude-line a[i + j + k] = a[i + j] - z; a[i + j] += z; - } + } } - -vd conv(const vd& a, const vd& b) { +vd conv(const vd &a, const vd &b) { if (a.empty() || b.empty()) return {}; vd res(sz(a) + sz(b) - 1); int L = 32 - __builtin_clz(sz(res)), n = 1 << L; - vector in(n), out(n), rt(n, 1); vi rev(n); - rep(i,0,n) rev[i] = (rev[i/2] | (i&1) << L) / 2; - for (int k = 2; k < n; k *= 2) { - C z[] = {1, polar(1.0, M_PI / k)}; - rep(i,k,2*k) rt[i] = rt[i/2] * z[i&1]; - } + vector in(n), out(n), rt; copy(all(a), begin(in)); rep(i,0,sz(b)) in[i].imag(b[i]); - fft(in, rt, rev, n); + fft(in, n, L, rt); trav(x, in) x *= x; rep(i,0,n) out[i] = in[-i & (n - 1)] - conj(in[i]); - fft(out, rt, rev, n); - rep(i,0,sz(res)) res[i] = imag(out[i]) / (4*n); + fft(out, n, L, rt); + rep(i,0,sz(res)) res[i] = imag(out[i]) / (4 * n); return res; } diff --git a/content/numerical/FastFourierTransformMod.h b/content/numerical/FastFourierTransformMod.h new file mode 100644 index 000000000..7ee520ef1 --- /dev/null +++ b/content/numerical/FastFourierTransformMod.h @@ -0,0 +1,36 @@ +/** + * Author: chilli + * Date: 2019-04-25 + * License: CC0 + * Source: http://neerc.ifmo.ru/trains/toulouse/2017/fft2.pdf + * Description: Higher precision FFT, can be used for convolutions modulo arbitrary integers. + * Let N be $\max(|a|,|b|)$. Is guaranteed safe as long as $N\log_2{N}\sqrt{\max(a)\max(b)} < \mathtt{\sim} 10^{16}$ . + * Time: O(N \log N), where $N = |A|+|B|-1$ (twice as slow as NTT or FFT) + * Status: somewhat tested + */ +#pragma once + +#include "FastFourierTransform.h" + +typedef vector vl; +template vl convMod(const vl &a, const vl &b) { + if (a.empty() || b.empty()) return {}; + vl res(sz(a) + sz(b) - 1); + int B=32-__builtin_clz(sz(res)), n = 1< L(n), R(n), outs(n), outl(n), rt; + rep(i,0,sz(a)) L[i] = Cd(a[i] / cut, a[i] % cut); + rep(i,0,sz(b)) R[i] = Cd(b[i] / cut, b[i] % cut); + fft(L, n, B, rt), fft(R, n, B, rt); + rep(i,0,n) { + int j = -i & (n - 1); + outl[j] = (L[i] + conj(L[j])) * R[i] / (2.0 * n); + outs[j] = (L[i] - conj(L[j])) * R[i] / (2.0 * n) / 1i; + } + fft(outl, n, B, rt), fft(outs, n, B, rt); + rep(i,0,sz(res)) { + ll av = ll(outl[i].real()+.5), cv = ll(outs[i].imag()+.5); + ll bv = ll(outl[i].imag()+.5) + ll(outs[i].real()+.5); + res[i] = ((av % M * cut + bv % M) * cut + cv % M) % M; + } + return res; +} diff --git a/content/numerical/PolyBase.h b/content/numerical/PolyBase.h new file mode 100644 index 000000000..b70b9dff1 --- /dev/null +++ b/content/numerical/PolyBase.h @@ -0,0 +1,47 @@ +/** + * Author: chilli, Andrew He, Adamant + * Date: 2019-04-27 + * Description: A FFT based Polynomial class. + */ +#pragma once + +#include "../number-theory/ModularArithmetic.h" +#include "FastFourierTransform.h" +#include "FastFourierTransformMod.h" +#include "NumberTheoreticTransform.h" + +typedef Mod num; +typedef vector poly; +poly &operator+=(poly &a, const poly &b) { + a.resize(max(sz(a), sz(b))); + rep(i, 0, sz(b)) a[i] = a[i] + b[i]; + return a; +} +poly &operator-=(poly &a, const poly &b) { + a.resize(max(sz(a), sz(b))); + rep(i, 0, sz(b)) a[i] = a[i] - b[i]; + return a; +} + +poly &operator*=(poly &a, const poly &b) { + if (sz(a) + sz(b) < 100){ + poly res(sz(a) + sz(b) - 1); + rep(i,0,sz(a)) rep(j,0,sz(b)) + res[i + j] = (res[i + j] + a[i] * b[j]); + return (a = res); + } + // auto res = convMod(vl(all(a)), vl(all(b))); + auto res = conv(vl(all(a)), vl(all(b))); + return (a = poly(all(res))); +} +poly operator*(poly a, const num b) { + poly c = a; + trav(i, c) i = i * b; + return c; +} +#define OP(o, oe) \ + poly operator o(poly a, poly b) { \ + poly c = a; \ + return c o##= b; \ + } +OP(*, *=) OP(+, +=) OP(-, -=); \ No newline at end of file diff --git a/content/numerical/PolyEvaluate.h b/content/numerical/PolyEvaluate.h new file mode 100644 index 000000000..f9aedeb7c --- /dev/null +++ b/content/numerical/PolyEvaluate.h @@ -0,0 +1,25 @@ +/** + * Author: chilli, Andrew He, Adamant + * Date: 2019-04-27 + * Description: Multi-point evaluation. Evaluates a given polynomial A at $A(x_0), ... A(x_n)$. + * Time: O(n \log^2 n) + */ +#pragma once + +#include "PolyBase.h" +#include "PolyMod.h" + +vector eval(const poly &a, const vector &x) { + int n = sz(x); + if (!n) return {}; + vector up(2 * n); + rep(i, 0, n) up[i + n] = poly({num(0) - x[i], 1}); + for (int i = n - 1; i > 0; i--) + up[i] = up[2 * i] * up[2 * i + 1]; + vector down(2 * n); + down[1] = a % up[1]; + rep(i, 2, 2 * n) down[i] = down[i / 2] % up[i]; + vector y(n); + rep(i, 0, n) y[i] = down[i + n][0]; + return y; +} diff --git a/content/numerical/PolyIntegDeriv.h b/content/numerical/PolyIntegDeriv.h new file mode 100644 index 000000000..ec0c71263 --- /dev/null +++ b/content/numerical/PolyIntegDeriv.h @@ -0,0 +1,22 @@ +/** + * Author: chilli, Andrew He, Adamant + * Date: 2019-04-27 + * Description: A FFT based Polynomial class. + */ +#pragma once +#include "PolyBase.h" + +poly deriv(poly a) { + if (a.empty()) return {}; + poly b(sz(a) - 1); + rep(i, 1, sz(a)) b[i - 1] = a[i] * num(i); + return b; +} +poly integr(poly a) { + if (a.empty()) return {0}; + poly b(sz(a) + 1); + b[1] = num(1); + rep(i, 2, sz(b)) b[i] = b[mod%i]*Mod(-mod/i+mod); + rep(i, 1 ,sz(b)) b[i] = a[i-1] * b[i]; + return b; +} diff --git a/content/numerical/PolyInterpolate.h b/content/numerical/PolyInterpolate.h index 9343edbc3..0a71def06 100644 --- a/content/numerical/PolyInterpolate.h +++ b/content/numerical/PolyInterpolate.h @@ -1,25 +1,24 @@ /** - * Author: Simon Lindholm - * Date: 2017-05-10 - * License: CC0 - * Source: Wikipedia + * Author: chilli, Andrew He, Adamant + * Date: 2019-04-27 * Description: Given $n$ points (x[i], y[i]), computes an n-1-degree polynomial $p$ that * passes through them: $p(x) = a[0]*x^0 + ... + a[n-1]*x^{n-1}$. - * For numerical precision, pick $x[k] = c*\cos(k/(n-1)*\pi), k=0 \dots n-1$. - * Time: O(n^2) + * Time: O(n \log^2 n) */ #pragma once -typedef vector vd; -vd interpolate(vd x, vd y, int n) { - vd res(n), temp(n); - rep(k,0,n-1) rep(i,k+1,n) - y[i] = (y[i] - y[k]) / (x[i] - x[k]); - double last = 0; temp[0] = 1; - rep(k,0,n) rep(i,0,n) { - res[i] += y[k] * temp[i]; - swap(last, temp[i]); - temp[i] -= last * x[k]; - } - return res; +#include "PolyBase.h" +#include "PolyIntegDeriv.h" +#include "PolyEvaluate.h" + +poly interp(vector x, vector y) { + int n=sz(x); + vector up(n*2); + rep(i,0,n) up[i+n] = poly({num(0)-x[i], num(1)}); + for(int i=n-1; i>0;i--) up[i] = up[2*i]*up[2*i+1]; + vector a = eval(deriv(up[1]), x); + vector down(2*n); + rep(i,0,n) down[i+n] = poly({y[i]*(num(1)/a[i])}); + for(int i=n-1;i>0;i--) down[i] = down[i*2] * up[i*2+1] + down[i*2+1] * up[i*2]; + return down[1]; } diff --git a/content/numerical/PolyInterpolateSlow.h b/content/numerical/PolyInterpolateSlow.h new file mode 100644 index 000000000..9343edbc3 --- /dev/null +++ b/content/numerical/PolyInterpolateSlow.h @@ -0,0 +1,25 @@ +/** + * Author: Simon Lindholm + * Date: 2017-05-10 + * License: CC0 + * Source: Wikipedia + * Description: Given $n$ points (x[i], y[i]), computes an n-1-degree polynomial $p$ that + * passes through them: $p(x) = a[0]*x^0 + ... + a[n-1]*x^{n-1}$. + * For numerical precision, pick $x[k] = c*\cos(k/(n-1)*\pi), k=0 \dots n-1$. + * Time: O(n^2) + */ +#pragma once + +typedef vector vd; +vd interpolate(vd x, vd y, int n) { + vd res(n), temp(n); + rep(k,0,n-1) rep(i,k+1,n) + y[i] = (y[i] - y[k]) / (x[i] - x[k]); + double last = 0; temp[0] = 1; + rep(k,0,n) rep(i,0,n) { + res[i] += y[k] * temp[i]; + swap(last, temp[i]); + temp[i] -= last * x[k]; + } + return res; +} diff --git a/content/numerical/PolyInverse.h b/content/numerical/PolyInverse.h new file mode 100644 index 000000000..bcc93b8dd --- /dev/null +++ b/content/numerical/PolyInverse.h @@ -0,0 +1,16 @@ +/** + * Author: chilli, Andrew He, Adamant + * Date: 2019-04-27 + * Description: A FFT based Polynomial class. + */ +#pragma once + +#include "PolyBase.h" + +poly modK(poly a, int k) { return {a.begin(), a.begin() + min(k, sz(a))}; } +poly inverse(poly A) { + poly B = poly({num(1) / A[0]}); + while (sz(B) < sz(A)) + B = modK(B * (poly({num(2)}) - modK(A, 2*sz(B)) * B), 2 * sz(B)); + return modK(B, sz(A)); +} \ No newline at end of file diff --git a/content/numerical/PolyLogExp.h b/content/numerical/PolyLogExp.h new file mode 100644 index 000000000..cb4bbb9c2 --- /dev/null +++ b/content/numerical/PolyLogExp.h @@ -0,0 +1,25 @@ +/** + * Author: chilli, Andrew He, Adamant + * Date: 2019-04-27 + * Description: A FFT based Polynomial class. + */ +#pragma once + +#include "PolyBase.h" +#include "PolyInverse.h" +#include "PolyIntegDeriv.h" + +poly log(poly a) { + return modK(integr(deriv(a) * inverse(a)), sz(a)); +} +poly exp(poly a) { + poly b(1, num(1)); + if (a.empty()) + return b; + while (sz(b) < sz(a)) { + b.resize(sz(b) * 2); + b *= (poly({num(1)}) + modK(a, sz(b)) - log(b)); + b.resize(sz(b) / 2 + 1); + } + return modK(b, sz(a)); +} \ No newline at end of file diff --git a/content/numerical/PolyMod.h b/content/numerical/PolyMod.h new file mode 100644 index 000000000..35c650491 --- /dev/null +++ b/content/numerical/PolyMod.h @@ -0,0 +1,30 @@ +/** + * Author: chilli, Andrew He, Adamant + * Date: 2019-04-27 + * Description: A FFT based Polynomial class. + */ +#pragma once + +#include "PolyBase.h" +#include "PolyInverse.h" + +poly &operator/=(poly &a, poly b) { + if (sz(a) < sz(b)) + return a = {}; + int s = sz(a) - sz(b) + 1; + reverse(all(a)), reverse(all(b)); + a.resize(s), b.resize(s); + a = a * inverse(b); + a.resize(s), reverse(all(a)); + return a; +} +OP(/, /=) +poly &operator%=(poly &a, poly &b) { + if (sz(a) < sz(b)) + return a; + poly c = (a / b) * b; + a.resize(sz(b) - 1); + rep(i, 0, sz(a)) a[i] = a[i] - c[i]; + return a; +} +OP(%, %=) \ No newline at end of file diff --git a/content/numerical/PolyPow.h b/content/numerical/PolyPow.h new file mode 100644 index 000000000..b259e7b3d --- /dev/null +++ b/content/numerical/PolyPow.h @@ -0,0 +1,23 @@ +/** + * Author: chilli, Andrew He, Adamant + * Date: 2019-04-27 + * Description: A FFT based Polynomial class. + */ +#pragma once + +#include "PolyBase.h" +#include "PolyLogExp.h" + +poly pow(poly a, ll m) { + int p = 0, n = sz(a); + while (p < sz(a) && a[p].v == 0) + ++p; + if (ll(m)*p >= sz(a)) return poly(sz(a)); + num j = a[p]; + a = {a.begin() + p, a.end()}; + a = a * (num(1) / j); + a.resize(n); + auto res = exp(log(a) * num(m)) * (j ^ m); + res.insert(res.begin(), p*m, 0); + return {res.begin(), res.begin()+n}; +} diff --git a/content/numerical/Polynomial.h b/content/numerical/Polynomial.h deleted file mode 100644 index f0569eb65..000000000 --- a/content/numerical/Polynomial.h +++ /dev/null @@ -1,24 +0,0 @@ -/** - * Author: David Rydh, Per Austrin - * Date: 2003-03-16 - * Description: - */ -#pragma once - -struct Poly { - vector a; - double operator()(double x) const { - double val = 0; - for(int i = sz(a); i--;) (val *= x) += a[i]; - return val; - } - void diff() { - rep(i,1,sz(a)) a[i-1] = i*a[i]; - a.pop_back(); - } - void divroot(double x0) { - double b = a.back(), c; a.back() = 0; - for(int i=sz(a)-1; i--;) c = a[i], a[i] = a[i+1]*x0+b, b=c; - a.pop_back(); - } -}; diff --git a/content/numerical/chapter.tex b/content/numerical/chapter.tex index aab4b5699..13b354611 100644 --- a/content/numerical/chapter.tex +++ b/content/numerical/chapter.tex @@ -1,9 +1,15 @@ \chapter{Numerical} \kactlimport{GoldenSectionSearch.h} -\kactlimport{Polynomial.h} -\kactlimport{PolyRoots.h} +\kactlimport{PolyBase.h} +\kactlimport{PolyInverse.h} +\kactlimport{PolyMod.h} +\kactlimport{PolyIntegDeriv.h} +\kactlimport{PolyLogExp.h} +\kactlimport{PolyPow.h} \kactlimport{PolyInterpolate.h} +\kactlimport{PolyEvaluate.h} +\kactlimport{PolyRoots.h} \kactlimport{BerlekampMassey.h} \kactlimport{LinearRecurrence.h} \kactlimport{HillClimbing.h} @@ -19,5 +25,6 @@ \chapter{Numerical} \kactlimport{Tridiagonal.h} \section{Fourier transforms} \kactlimport{FastFourierTransform.h} + \kactlimport{FastFourierTransformMod.h} \kactlimport{NumberTheoreticTransform.h} \kactlimport{FastSubsetTransform.h} diff --git a/fuzz-tests/numerical/Polynomial.cpp b/fuzz-tests/numerical/Polynomial.cpp new file mode 100644 index 000000000..1b53166e0 --- /dev/null +++ b/fuzz-tests/numerical/Polynomial.cpp @@ -0,0 +1,630 @@ +#include + +#define all(x) begin(x), end(x) +typedef long long ll; +using namespace std; + +#define per(i, a, b) for (int i = (b)-1; i >= (a); --i) +#define rep(i, a, b) for (int i = a; i < (b); ++i) +#define trav(a, x) for (auto &a : x) +#define sz(x) (int)(x).size() +typedef long long ll; +typedef pair pii; +typedef vector vi; + +struct timeit { + decltype(chrono::high_resolution_clock::now()) begin; + const string label; + timeit(string label = "???") : label(label) { begin = chrono::high_resolution_clock::now(); } + ~timeit() { + auto end = chrono::high_resolution_clock::now(); + auto duration = chrono::duration_cast(end - begin).count(); + cerr << duration << "ms elapsed [" << label << "]" << endl; + } +}; +namespace MIT { +namespace fft { +#if FFT +// FFT +using dbl = double; +struct num { /// start-hash + dbl x, y; + num(dbl x_ = 0, dbl y_ = 0) : x(x_), y(y_) {} +}; +inline num operator+(num a, num b) { return num(a.x + b.x, a.y + b.y); } +inline num operator-(num a, num b) { return num(a.x - b.x, a.y - b.y); } +inline num operator*(num a, num b) { return num(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); } +inline num conj(num a) { return num(a.x, -a.y); } +inline num inv(num a) { + dbl n = (a.x * a.x + a.y * a.y); + return num(a.x / n, -a.y / n); +} +/// end-hash +#else +// NTT +const int mod = 998244353, g = 3; +// For p < 2^30 there is also (5 << 25, 3), (7 << 26, 3), +// (479 << 21, 3) and (483 << 21, 5). Last two are > 10^9. +struct num { /// start-hash + int v; + num(ll v_ = 0) : v(int(v_ % mod)) { + if (v < 0) + v += mod; + } + explicit operator int() const { return v; } +}; +inline num operator+(num a, num b) { return num(a.v + b.v); } +inline num operator-(num a, num b) { return num(a.v + mod - b.v); } +inline num operator*(num a, num b) { return num(1ll * a.v * b.v); } +inline num pow(num a, int b) { + num r = 1; + do { + if (b & 1) + r = r * a; + a = a * a; + } while (b >>= 1); + return r; +} +inline num inv(num a) { return pow(a, mod - 2); } +/// end-hash +#endif + +using vn = vector; +vi rev({0, 1}); +vn rt(2, num(1)), fa, fb; + +inline void init(int n) { /// start-hash + if (n <= sz(rt)) + return; + rev.resize(n); + rep(i, 0, n) rev[i] = (rev[i >> 1] | ((i & 1) * n)) >> 1; + rt.reserve(n); + for (int k = sz(rt); k < n; k *= 2) { + rt.resize(2 * k); +#if FFT + double a = M_PI / k; + num z(cos(a), sin(a)); // FFT +#else + num z = pow(num(g), (mod - 1) / (2 * k)); // NTT +#endif + rep(i, k / 2, k) rt[2 * i] = rt[i], rt[2 * i + 1] = rt[i] * z; + } +} /// end-hash + +inline void fft(vector &a, int n) { /// start-hash + init(n); + int s = __builtin_ctz(sz(rev) / n); + rep(i, 0, n) if (i> s) swap(a[i], a[rev[i] >> s]); + for (int k = 1; k < n; k *= 2) + for (int i = 0; i < n; i += 2 * k) + rep(j, 0, k) { + num t = rt[j + k] * a[i + j + k]; + a[i + j + k] = a[i + j] - t; + a[i + j] = a[i + j] + t; + } +} /// end-hash + +// Complex/NTT +vn multiply(vn a, vn b) { /// start-hash + int s = sz(a) + sz(b) - 1; + if (s <= 0) + return {}; + int L = s > 1 ? 32 - __builtin_clz(s - 1) : 0, n = 1 << L; + a.resize(n), b.resize(n); + fft(a, n); + fft(b, n); + num d = inv(num(n)); + rep(i, 0, n) a[i] = a[i] * b[i] * d; + reverse(a.begin() + 1, a.end()); + fft(a, n); + a.resize(s); + return a; +} /// end-hash + +// Complex/NTT power-series inverse +// Doubles b as b[:n] = (2 - a[:n] * b[:n/2]) * b[:n/2] +vn inverse(const vn &a) { /// start-hash + if (a.empty()) + return {}; + vn b({inv(a[0])}); + b.reserve(2 * a.size()); + while (sz(b) < sz(a)) { + int n = 2 * sz(b); + b.resize(2 * n, 0); + if (sz(fa) < 2 * n) + fa.resize(2 * n); + fill(fa.begin(), fa.begin() + 2 * n, 0); + copy(a.begin(), a.begin() + min(n, sz(a)), fa.begin()); + fft(b, 2 * n); + fft(fa, 2 * n); + num d = inv(num(2 * n)); + rep(i, 0, 2 * n) b[i] = b[i] * (2 - fa[i] * b[i]) * d; + reverse(b.begin() + 1, b.end()); + fft(b, 2 * n); + b.resize(n); + } + b.resize(a.size()); + return b; +} /// end-hash + +#if FFT +// Double multiply (num = complex) +using vd = vector; +vd multiply(const vd &a, const vd &b) { /// start-hash + int s = sz(a) + sz(b) - 1; + if (s <= 0) + return {}; + int L = s > 1 ? 32 - __builtin_clz(s - 1) : 0, n = 1 << L; + if (sz(fa) < n) + fa.resize(n); + if (sz(fb) < n) + fb.resize(n); + + fill(fa.begin(), fa.begin() + n, 0); + rep(i, 0, sz(a)) fa[i].x = a[i]; + rep(i, 0, sz(b)) fa[i].y = b[i]; + fft(fa, n); + trav(x, fa) x = x * x; + rep(i, 0, n) fb[i] = fa[(n - i) & (n - 1)] - conj(fa[i]); + fft(fb, n); + vd r(s); + rep(i, 0, s) r[i] = fb[i].y / (4 * n); + return r; +} /// end-hash + +// Integer multiply mod m (num = complex) /// start-hash +vi multiply_mod(const vi &a, const vi &b, int m) { + int s = sz(a) + sz(b) - 1; + if (s <= 0) + return {}; + int L = s > 1 ? 32 - __builtin_clz(s - 1) : 0, n = 1 << L; + if (sz(fa) < n) + fa.resize(n); + if (sz(fb) < n) + fb.resize(n); + + rep(i, 0, sz(a)) fa[i] = num(a[i] & ((1 << 15) - 1), a[i] >> 15); + fill(fa.begin() + sz(a), fa.begin() + n, 0); + rep(i, 0, sz(b)) fb[i] = num(b[i] & ((1 << 15) - 1), b[i] >> 15); + fill(fb.begin() + sz(b), fb.begin() + n, 0); + + fft(fa, n); + fft(fb, n); + double r0 = 0.5 / n; // 1/2n + rep(i, 0, n / 2 + 1) { + int j = (n - i) & (n - 1); + num g0 = (fb[i] + conj(fb[j])) * r0; + num g1 = (fb[i] - conj(fb[j])) * r0; + swap(g1.x, g1.y); + g1.y *= -1; + if (j != i) { + swap(fa[j], fa[i]); + fb[j] = fa[j] * g1; + fa[j] = fa[j] * g0; + } + fb[i] = fa[i] * conj(g1); + fa[i] = fa[i] * conj(g0); + } + fft(fa, n); + fft(fb, n); + vi r(s); + rep(i, 0, s) r[i] = int((ll(fa[i].x + 0.5) + (ll(fa[i].y + 0.5) % m << 15) + (ll(fb[i].x + 0.5) % m << 15) + + (ll(fb[i].y + 0.5) % m << 30)) % + m); + return r; +} /// end-hash +#endif + +} // namespace fft + +// For multiply_mod, use num = modnum, poly = vector +using fft::num; +using poly = fft::vn; +using fft::inverse; +using fft::multiply; +/// start-hash +poly &operator+=(poly &a, const poly &b) { + if (sz(a) < sz(b)) + a.resize(b.size()); + rep(i, 0, sz(b)) a[i] = a[i] + b[i]; + return a; +} +poly operator+(const poly &a, const poly &b) { + poly r = a; + r += b; + return r; +} +poly &operator-=(poly &a, const poly &b) { + if (sz(a) < sz(b)) + a.resize(b.size()); + rep(i, 0, sz(b)) a[i] = a[i] - b[i]; + return a; +} +poly operator-(const poly &a, const poly &b) { + poly r = a; + r -= b; + return r; +} +poly operator*(const poly &a, const poly &b) { + // TODO: small-case? + return multiply(a, b); +} +poly &operator*=(poly &a, const poly &b) { return a = a * b; } +/// end-hash +poly &operator*=(poly &a, const num &b) { // Optional + trav(x, a) x = x * b; + return a; +} +poly operator*(const poly &a, const num &b) { + poly r = a; + r *= b; + return r; +} + +// Polynomial floor division; no leading 0's plz +poly operator/(poly a, poly b) { /// start-hash + if (sz(a) < sz(b)) + return {}; + int s = sz(a) - sz(b) + 1; + reverse(a.begin(), a.end()); + reverse(b.begin(), b.end()); + a.resize(s); + b.resize(s); + a = a * inverse(move(b)); + a.resize(s); + reverse(a.begin(), a.end()); + return a; +} /// end-hash +poly &operator/=(poly &a, const poly &b) { return a = a / b; } +poly &operator%=(poly &a, const poly &b) { /// start-hash + if (sz(a) >= sz(b)) { + poly c = (a / b) * b; + a.resize(sz(b) - 1); + rep(i, 0, sz(a)) a[i] = a[i] - c[i]; + } + return a; +} /// end-hash +poly operator%(const poly &a, const poly &b) { + poly r = a; + r %= b; + return r; +} + +// Log/exp/pow +poly deriv(const poly &a) { /// start-hash + if (a.empty()) + return {}; + poly b(sz(a) - 1); + rep(i, 1, sz(a)) b[i - 1] = a[i] * i; + return b; +} /// end-hash +poly integ(const poly &a) { /// start-hash + if (a.empty()) + return {0}; + poly b(sz(a) + 1); + b[1] = 1; // mod p + rep(i, 2, sz(b)) b[i] = b[fft::mod % i] * (-fft::mod / i); // mod p + rep(i, 1, sz(b)) b[i] = a[i - 1] * b[i]; // mod p + // rep(i,1,sz(b)) b[i]=a[i-1]*inv(num(i)); // else + return b; +} /// end-hash +poly log(const poly &a) { // a[0] == 1 /// start-hash + poly b = integ(deriv(a) * inverse(a)); + b.resize(a.size()); + return b; +} /// end-hash +poly exp(const poly &a) { // a[0] == 0 /// start-hash + poly b(1, num(1)); + if (a.empty()) + return b; + while (sz(b) < sz(a)) { + int n = min(sz(b) * 2, sz(a)); + b.resize(n); + poly v = poly(a.begin(), a.begin() + n) - log(b); + v[0] = v[0] + num(1); + b *= v; + b.resize(n); + } + return b; +} /// end-hash +poly pow(const poly &a, int m) { // m >= 0 /// start-hash + poly b(a.size()); + if (!m) { + b[0] = 1; + return b; + } + int p = 0; + while (p < sz(a) && a[p].v == 0) + ++p; + if (1ll * m * p >= sz(a)) + return b; + num mu = pow(a[p], m), di = inv(a[p]); + poly c(sz(a) - m * p); + rep(i, 0, sz(c)) c[i] = a[i + p] * di; + c = log(c); + trav(v, c) v = v * m; + c = exp(c); + rep(i, 0, sz(c)) b[i + m * p] = c[i] * mu; + return b; +} /// end-hash + +// Multipoint evaluation/interpolation +/// start-hash +vector eval(const poly &a, const vector &x) { + int n = sz(x); + if (!n) + return {}; + vector up(2 * n); + rep(i, 0, n) up[i + n] = poly({0 - x[i], 1}); + per(i, 1, n) up[i] = up[2 * i] * up[2 * i + 1]; + vector down(2 * n); + down[1] = a % up[1]; + { + rep(i, 2, 2 * n) { + down[i] = down[i / 2] % up[i]; + } + } + vector y(n); + rep(i, 0, n) y[i] = down[i + n][0]; + return y; +} /// end-hash +/// start-hash +poly interp(const vector &x, const vector &y) { + int n = sz(x); + assert(n); + vector up(n * 2); + rep(i, 0, n) up[i + n] = poly({0 - x[i], 1}); + per(i, 1, n) up[i] = up[2 * i] * up[2 * i + 1]; + vector a = eval(deriv(up[1]), x); + vector down(2 * n); + rep(i, 0, n) down[i + n] = poly({y[i] * inv(a[i])}); + per(i, 1, n) down[i] = down[i * 2] * up[i * 2 + 1] + down[i * 2 + 1] * up[i * 2]; + return down[1]; +} /// end-hash +} // namespace MIT + +namespace mine { +namespace ignore1 { +#include "../../content/number-theory/ModPow.h" +} +namespace ignore2 { +#include "../../content/number-theory/ModularArithmetic.h" +} +ll modpow(ll a, ll e); +#include "../../content/numerical/NumberTheoreticTransform.h" +ll modpow(ll a, ll e) { + if (e == 0) return 1; + ll x = modpow(a * a % mod, e >> 1); + return e & 1 ? x * a % mod : x; +} +#include "../../content/numerical/FastFourierTransformMod.h" +struct Mod { + ll v; + Mod() : v(0) {} + Mod(ll vv) : v(vv % mod) {} + Mod operator+(Mod b) { return Mod((v + b.v) % mod); } + Mod operator-(Mod b) { return Mod(v < b.v ? v - b.v + mod : v - b.v); } + Mod operator*(Mod b) { return Mod(v * b.v); } + Mod operator/(Mod b) { return *this * invert(b); } + Mod invert(Mod a) { return a^(mod-2); } + Mod operator^(ll e) { + ll ans = 1, b = (*this).v; + for (; e; b = b * b % mod, e /= 2) + if (e & 1) ans = ans * b % mod; + return ans; + } + explicit operator ll() const { return v; } +}; + +typedef Mod num; +typedef vector poly; + +#include "../../content/numerical/PolyBase.h" +#include "../../content/numerical/PolyMod.h" +#include "../../content/numerical/PolyIntegDeriv.h" +#include "../../content/numerical/PolyLogExp.h" +#include "../../content/numerical/PolyPow.h" +#include "../../content/numerical/PolyEvaluate.h" +#include "../../content/numerical/PolyInterpolate.h" +} // namespace mine + +pair genVec(int sz) { + mine::poly a; + MIT::poly am; + for (int i = 0; i < sz; i++) { + int val = rand(); + a.push_back(val); + am.push_back(val); + } + return {a, am}; +} +bool checkEqual(mine::poly a, MIT::poly b) { + if (sz(a) != sz(b)) + return false; + int ml = min(sz(a), sz(b)); + for (int i = 0; i < ml; i++) + if (a[i].v != b[i].v) + return false; + // for (int i = ml; i < sz(a); i++) + // if (a[i].v != 0) + // return false; + // for (int i = ml; i < sz(b); i++) + // if (b[i].v != 0) + // return false; + return true; +} + +template void fail(A mine, B mit) { + cout<<"mine: "; + for (auto i : mine) + cout << i.v << ' '; + cout << endl; + cout<<"MIT: "; + for (auto i : mit) + cout << i.v << ' '; + cout << endl; + +} + +const int NUMITERS=10; +template void testBinary(string name, A f1, B f2, int mxSz = 5) { + for (int it = 0; it < NUMITERS; it++) { + auto a = genVec((rand() % mxSz) + 1); + auto b = genVec((rand() % mxSz) + 1); + mine::poly res; + res = f1(a.first, b.first); + auto t = f2(a.second, b.second); + if (!checkEqual(res, t)) + fail(res, t); + + assert(checkEqual(res, t)); + } + cout << name << " tests passed!" << endl; + auto a = genVec(mxSz); + auto b = genVec(mxSz/2); + { + timeit x("mine"); + for (int it=0; it void testUnary(string name, A f1, B f2, int mxSz = 5) { + for (int it = 0; it < NUMITERS; it++) { + auto a = genVec((rand() % mxSz) + 1); + auto res = f1(a.first); + auto t = f2(a.second); + if (!checkEqual(res, t)) + fail(res, t); + assert(checkEqual(res, t)); + } + cout << name + " tests passed!" << endl; + auto a = genVec(mxSz); + { + timeit x("mine"); + for (int it=0; it void testPow(string name, A f1, B f2, int mxSz = 5, int mxPref=5) { + for (int it = 0; it < NUMITERS; it++) { + auto a = genVec((rand() % mxSz) + 1); + int pref = rand()%mxSz; + for (int j=0; j void testEval(string name, A f1, B f2, int mxSz = 5) { + for (int it = 0; it < NUMITERS; it++) { + break; + auto a = genVec((rand() % mxSz) + 1); + auto b = genVec((rand() % mxSz)+1); + auto res = f1(a.first, b.first); + auto t = f2(a.second, b.second); + if (!checkEqual(res, t)) + fail(res, t); + assert(checkEqual(res, t)); + } + cout << name + " tests passed!" << endl; + auto a = genVec(mxSz); + auto b = genVec(mxSz); + { + timeit x("mine"); + for (int it = 0; it < NUMITERS; it++) { + f1(a.first, b.first); + } + } + { + timeit x("MIT"); + for (int it = 0; it < NUMITERS; it++) { + f2(a.second, b.second); + } + } + cout< void testInterp(string name, A f1, B f2, int mxSz = 5) { + for (int it = 0; it < NUMITERS; it++) { + int s = (rand()%mxSz) + 1; + auto a = genVec(s); + auto b = genVec(s); + auto res = f1(a.first, b.first); + auto t = f2(a.second, b.second); + if (!checkEqual(res, t)) + fail(res, t); + assert(checkEqual(res, t)); + } + cout << name + " tests passed!" << endl; + auto a = genVec(mxSz); + auto b = genVec(mxSz); + { + timeit x("mine"); + for (int it = 0; it < NUMITERS; it++) { + f1(a.first, b.first); + } + } + { + timeit x("MIT"); + for (int it = 0; it < NUMITERS; it++) { + f2(a.second, b.second); + } + } + cout<