#include <limits.h>
#include <signal.h>
#include <cstring>
#include <gmp.h>
#include <NTL/ZZ.h>
#include "bern_modp_util.h"
#include "bern_modp.h"
NTL_CLIENT;
using namespace std;
namespace bernmm {
long bernsum_powg(long p, double pinv, long k, long g)
{
long half_gm1 = (g + ((g & 1) ? 0 : p) - 1) / 2;
long g_to_jm1 = 1;
long g_to_km1 = PowerMod(g, k-1, p, pinv);
long g_to_km1_to_j = g_to_km1;
long sum = 0;
double g_pinv = ((double) g) / ((double) p);
mulmod_precon_t g_to_km1_pinv = PrepMulModPrecon(g_to_km1, p, pinv);
for (long j = 1; j <= (p-1)/2; j++)
{
long q;
g_to_jm1 = MulDivRem(q, g_to_jm1, g, p, g_pinv);
long h = SubMod(q, half_gm1, p);
sum = SubMod(sum, MulMod(h, g_to_km1_to_j, p, pinv), p);
g_to_km1_to_j = MulModPrecon(g_to_km1_to_j, g_to_km1, p, g_to_km1_pinv);
}
return sum;
}
// Largest number of words of 1/p that an Expander precomputes; also the
// chunk length (in words) used by the bernsum_pow2*() routines.
#define MAX_INV 256

#if (GMP_NAIL_BITS == 0) && (GMP_LIMB_BITS >= ULONG_BITS)

// Fast mpn-based version: a word_t is exactly one GMP limb.

typedef mp_limb_t word_t;
#define WORD_BITS GMP_LIMB_BITS

/*
   Produces leading words of the binary expansion of s/p for 0 < s < p.

   The constructor precomputes max_words + 1 fractional limbs of 1/p.
   expand() then gets s/p with a single mpn_mul_1 in the common case, and
   falls back to an exact mpn_divrem_1 when the truncated approximation
   might be wrong.
*/
class Expander
{
private:
   // pinv[0 .. max_words] are the fractional limbs of 1/p (least
   // significant first); pinv[max_words + 1] is the (zero) integer limb.
   mp_limb_t pinv[MAX_INV + 2];
   mp_limb_t p;
   int max_words;

public:
   /*
      p         = modulus (assumed > 1)
      max_words = largest n that will be passed to expand();
                  must satisfy 1 <= max_words <= MAX_INV
   */
   Expander(long p, int max_words)
   {
      assert(max_words >= 1);
      assert(max_words <= MAX_INV);

      this->max_words = max_words;
      this->p = p;

      // max_words + 1 fraction limbs of 1 / p, plus one integer limb.
      mp_limb_t one = 1;
      mpn_divrem_1(pinv, max_words + 1, &one, 1, p);
   }

   /*
      Writes the n most significant words of the fractional part of s/p to
      res[1], ..., res[n] (res[n] is the most significant).

      res[0] always receives a lower-order guard word, so every entry of
      res[0 .. n] is defined afterwards. Callers may read res[0] even when
      they do not use it. In the rare division fallback res[n + 1] is also
      written (with zero), so res must have room for n + 2 words.

      Requires 0 < s < p and 1 <= n <= max_words.
   */
   void expand(word_t* res, long s, int n)
   {
      assert(s > 0 && s < p);
      assert(n >= 1);
      assert(n <= max_words);

      if (s == 1)
      {
         // 1/p is already known. Copy the relevant limbs, including the
         // guard limb just below them, so that res[0] is never left unset.
         for (int i = 0; i <= n; i++)
            res[i] = pinv[max_words - n + i];
      }
      else
      {
         mpn_mul_1(res, pinv + max_words - n, n + 1, (mp_limb_t) s);

         // The neglected lower limbs of 1/p contribute less than s units to
         // res[0]. If res[0] is that close to overflowing, the error might
         // carry into res[1 .. n], so recompute exactly by division
         // (expected to be very rare).
         if (res[0] > -((mp_limb_t) s))
         {
            mp_limb_t ss = s;
            mpn_divrem_1(res, n + 1, &ss, 1, p);
         }
      }
   }
};

#else

// Portable fallback based on mpz arithmetic. Used when GMP has nail bits or
// when its limbs are narrower than an unsigned long.

typedef unsigned long word_t;
#define WORD_BITS ULONG_BITS

class Expander
{
private:
   // Stored as unsigned long rather than mp_limb_t. This branch is also
   // taken when limbs are narrower than a long, and p must not be truncated
   // before it reaches mpz_fdiv_q_ui().
   unsigned long p;
   mpz_t temp;

   // Non-copyable: temp is released in the destructor, so a copy would
   // lead to a double mpz_clear(). Declared but intentionally not defined.
   Expander(const Expander&);
   Expander& operator=(const Expander&);

public:
   Expander(long p, int max_words)
   {
      (void) max_words;   // nothing is precomputed in this version
      this->p = p;
      mpz_init(temp);
   }

   ~Expander()
   {
      mpz_clear(temp);
   }

   /*
      Writes floor(s * 2^(WORD_BITS*n) / p), i.e. the first n words of the
      expansion of s/p, to res[1 .. n] (least significant first). res[0] is
      set to zero so that it is always initialised.
      Requires 0 < s < p and n >= 1.
   */
   void expand(word_t* res, long s, int n)
   {
      assert(s > 0 && (unsigned long) s < p);
      assert(n >= 1);

      res[0] = 0;
      mpz_set_ui(temp, s);
      mpz_mul_2exp(temp, temp, WORD_BITS * n);
      mpz_fdiv_q_ui(temp, temp, p);
      mpz_export(res + 1, NULL, -1, sizeof(word_t), 0, 0, temp);
   }
};

#endif
// Lookup-table layout shared by bernsum_pow2() and bernsum_pow2_redc().
// Each WORD_BITS-bit word of an expansion is cut into NUM_TABLES slices of
// TABLE_LG_SIZE bits. The value of a slice selects one of TABLE_SIZE
// buckets in that slice's table.
#define TABLE_LG_SIZE 8
#define TABLE_SIZE (((word_t) 1) << TABLE_LG_SIZE)
#define TABLE_MASK (TABLE_SIZE - 1)
#define NUM_TABLES (WORD_BITS / TABLE_LG_SIZE)
// The slices must tile a word exactly.
#if WORD_BITS % TABLE_LG_SIZE != 0
#error Number of bits in a long must be divisible by TABLE_LG_SIZE
#endif
/*
   Main sum of the "pow2" method, using ordinary (non-REDC) modular
   arithmetic. _bern_modp_pow2() uses it when p is too large for
   bernsum_pow2_redc(). The result is a residue in [0, p).

   p    = modulus (odd; the caller treats it as prime)
   pinv = 1/p as a double, for the MulMod-style helpers
   k    = exponent
   g    = generator chosen by the caller (a primitive root there)
   n    = passed by the caller as the multiplicative order of 2 mod p

   For each i visited, s starts at g^i mod p and the binary expansion of
   s/p is walked. Bit number j of that expansion contributes -x if the bit
   is 1 and +x if it is 0, where x = (g^i)^(k-1) * (2^(k-1))^j mod p.
   Whole words are not summed bit by bit. Instead, each TABLE_LG_SIZE-bit
   slice of a word is bucketed into tables[][] (phase 1) and the buckets are
   resolved with precomputed weights at the end (phase 2).
*/
long bernsum_pow2(long p, double pinv, long k, long g, long n)
{
// tables[h][b] accumulates -x (mod p) over every processed word whose h-th
// slice (counting from the least significant end) equals b.
long tables[NUM_TABLES][TABLE_SIZE];
memset(tables, 0, sizeof(long) * NUM_TABLES * TABLE_SIZE);
// Outer loop runs over i = 0 .. m-1.
long m = (p-1) / n;
// Only half of the work is done: if n is odd, visit half as many i;
// otherwise walk only the first n/2 bits of each expansion.
if (n & 1)
m >>= 1;
else
n >>= 1;
// Constants mod p: g^(k-1), 2^(k-1), (2^(k-1))^WORD_BITS (advance past one
// whole word) and 2^(MAX_INV*WORD_BITS) (advance s past one whole chunk).
long g_to_km1 = PowerMod(g, k-1, p, pinv);
long two_to_km1 = PowerMod(2, k-1, p, pinv);
long B_to_km1 = PowerMod(two_to_km1, WORD_BITS, p, pinv);
long s_jump = PowerMod(2, MAX_INV * WORD_BITS, p, pinv);
mulmod_precon_t g_pinv = PrepMulModPrecon(g, p, pinv);
mulmod_precon_t g_to_km1_pinv = PrepMulModPrecon(g_to_km1, p, pinv);
mulmod_precon_t two_to_km1_pinv = PrepMulModPrecon(two_to_km1, p, pinv);
mulmod_precon_t B_to_km1_pinv = PrepMulModPrecon(B_to_km1, p, pinv);
mulmod_precon_t s_jump_pinv = PrepMulModPrecon(s_jump, p, pinv);
// g_to_i = g^i, g_to_km1_to_i = (g^(k-1))^i, both mod p.
long g_to_km1_to_i = 1;
long g_to_i = 1;
long sum = 0;
// Precompute only as many words of 1/p as the longest chunk needs.
Expander expander(p, (n >= MAX_INV * WORD_BITS)
? MAX_INV : ((n - 1) / WORD_BITS + 1));
// ===== Phase 1: walk the binary expansions.
for (long i = 0; i < m; i++)
{
// s = g^i * 2^(bits already consumed) mod p; x = matching weight.
long s = g_to_i;
long x = g_to_km1_to_i;
// Process the n bits in chunks of at most MAX_INV * WORD_BITS bits.
for (long nn = n; nn > 0; nn -= MAX_INV * WORD_BITS)
{
word_t s_over_p[MAX_INV + 2];
long bits, words;
if (nn >= MAX_INV * WORD_BITS)
{
bits = MAX_INV * WORD_BITS;
words = MAX_INV;
}
else
{
bits = nn;
words = (nn - 1) / WORD_BITS + 1;
}
// s_over_p[1 .. words] = leading words of s/p, most significant last.
expander.expand(s_over_p, s, words);
// Consume whole words, most significant first, via the tables.
word_t* next = s_over_p + words;
for (; bits >= WORD_BITS; bits -= WORD_BITS, next--)
{
word_t y = *next;
#if NUM_TABLES != 8 && NUM_TABLES != 4
// Generic version: one table per slice.
for (long h = 0; h < NUM_TABLES; h++)
{
long& target = tables[h][y & TABLE_MASK];
target = SubMod(target, x, p);
y >>= TABLE_LG_SIZE;
}
#else
// Unrolled for 4 tables (32-bit words) or 8 tables (64-bit words).
long& target0 = tables[0][y & TABLE_MASK];
target0 = SubMod(target0, x, p);
long& target1 = tables[1][(y >> TABLE_LG_SIZE) & TABLE_MASK];
target1 = SubMod(target1, x, p);
long& target2 = tables[2][(y >> (2*TABLE_LG_SIZE)) & TABLE_MASK];
target2 = SubMod(target2, x, p);
long& target3 = tables[3][(y >> (3*TABLE_LG_SIZE)) & TABLE_MASK];
target3 = SubMod(target3, x, p);
#if NUM_TABLES == 8
long& target4 = tables[4][(y >> (4*TABLE_LG_SIZE)) & TABLE_MASK];
target4 = SubMod(target4, x, p);
long& target5 = tables[5][(y >> (5*TABLE_LG_SIZE)) & TABLE_MASK];
target5 = SubMod(target5, x, p);
long& target6 = tables[6][(y >> (6*TABLE_LG_SIZE)) & TABLE_MASK];
target6 = SubMod(target6, x, p);
long& target7 = tables[7][(y >> (7*TABLE_LG_SIZE)) & TABLE_MASK];
target7 = SubMod(target7, x, p);
#endif
#endif
// Advance x past the WORD_BITS bits just consumed.
x = MulModPrecon(x, B_to_km1, p, B_to_km1_pinv);
}
// Leftover (< WORD_BITS) bits of a partial last word, one at a time from
// the top: a 1 bit subtracts x, a 0 bit adds x. When bits == 0, next
// points at s_over_p[0] and y is read but never used.
word_t y = *next;
for (; bits > 0; bits--)
{
if (y & (((word_t) 1) << (WORD_BITS - 1)))
sum = SubMod(sum, x, p);
else
sum = AddMod(sum, x, p);
x = MulModPrecon(x, two_to_km1, p, two_to_km1_pinv);
y <<= 1;
}
// Move s to the start of the next chunk.
s = MulModPrecon(s, s_jump, p, s_jump_pinv);
}
g_to_i = MulModPrecon(g_to_i, g, p, g_pinv);
g_to_km1_to_i = MulModPrecon(g_to_km1_to_i, g_to_km1, p, g_to_km1_pinv);
}
// ===== Phase 2: resolve the tables.
// weights[b] = sum over t in [0, TABLE_LG_SIZE) of +/-(2^(k-1))^t. The sign
// is '-' when the bit t places below the most significant bit of b is set.
long weights[TABLE_SIZE];
weights[0] = 0;
for (long h = 0, x = 1; h < TABLE_LG_SIZE;
h++, x = MulModPrecon(x, two_to_km1, p, two_to_km1_pinv))
{
for (long i = (1L << h) - 1; i >= 0; i--)
{
weights[2*i+1] = SubMod(weights[i], x, p);
weights[2*i] = AddMod(weights[i], x, p);
}
}
// The top bit of table h's slice lies (NUM_TABLES-1-h)*TABLE_LG_SIZE bits
// below the top of the word, so table h is scaled by (2^(k-1)) raised to
// that power. x steps by x_jump = (2^(k-1))^TABLE_LG_SIZE as h decreases.
// The tables hold -x sums, so they are subtracted here.
long x_jump = PowerMod(two_to_km1, TABLE_LG_SIZE, p, pinv);
for (long h = NUM_TABLES - 1, x = 1; h >= 0; h--)
{
mulmod_precon_t x_pinv = PrepMulModPrecon(x, p, pinv);
for (long i = 0; i < TABLE_SIZE; i++)
{
long y = MulMod(tables[h][i], weights[i], p, pinv);
y = MulModPrecon(y, x, p, x_pinv);
sum = SubMod(sum, y, p);
}
x = MulModPrecon(x_jump, x, p, x_pinv);
}
return sum;
}
// Mask selecting the low half of an unsigned long. R = 2^(ULONG_BITS/2) is
// the REDC radix used by the functions below.
#define LOW_MASK ((1L << (ULONG_BITS / 2)) - 1)

/*
   Lazy REDC step. Returns a value congruent to x * R^(-1) mod n, where
   R = 2^(ULONG_BITS/2) and ninv2 = -1/n mod R (see PrepRedc()).

   The result is not fully reduced: it lies in [0, x/R + n).
   Requires n odd, x >= 0 and x + n*(R-1) < 2^ULONG_BITS.
*/
static inline long RedcFast(long x, long n, long ninv2)
{
   // All products are formed in unsigned arithmetic. x * ninv2 routinely
   // exceeds LONG_MAX, and signed overflow is undefined behaviour. Only the
   // low half of that product is needed, which wraparound gives exactly.
   unsigned long ux = (unsigned long) x;
   unsigned long y = (ux * (unsigned long) ninv2) & LOW_MASK;
   // By construction ux + n*y is divisible by R.
   unsigned long z = ux + ((unsigned long) n * y);
   return (long) (z >> (ULONG_BITS / 2));
}
/*
   REDC step followed by one conditional subtraction. Returns a value
   congruent to x * R^(-1) mod n, with R = 2^(ULONG_BITS/2). The result is
   fully reduced to [0, n) whenever RedcFast() yields a value below 2n
   (e.g. when x < n*R). Same preconditions as RedcFast().
*/
static inline long Redc(long x, long n, long ninv2)
{
   const long lazy = RedcFast(x, n, ninv2);
   return (lazy < n) ? lazy : lazy - n;
}
/*
   Returns ninv2 = -1/n mod R, with R = 2^(ULONG_BITS/2), as needed by
   RedcFast() and Redc(). Requires n odd.
*/
long PrepRedc(long n)
{
   // Work in unsigned arithmetic: the Newton step below overflows a signed
   // long (undefined behaviour), while unsigned wraparound gives exactly
   // the residue mod 2^ULONG_BITS that the iteration needs.
   unsigned long un = (unsigned long) n;

   // -n is already a correct inverse of -n mod 8, since n*n == 1 mod 8
   // for odd n.
   unsigned long ninv2 = -un;

   // Newton/Hensel lifting of the 2-adic inverse of -n. Each step doubles
   // the number of correct low bits (3 -> 6 -> 12 -> ...).
   for (long bits = 3; bits < ULONG_BITS/2; bits *= 2)
      ninv2 = 2*ninv2 + un * ninv2 * ninv2;

   return (long) (ninv2 & LOW_MASK);
}
/*
   Same sum as bernsum_pow2(), computed with REDC (Montgomery) arithmetic
   with radix R = 2^(ULONG_BITS/2) and delayed reduction. _bern_modp_pow2()
   calls it only when p < 2^(ULONG_BITS/2 - 1), i.e. p < R/2. The bounds
   noted below rely on that. The result is a residue in [0, p).

   A constant c is stored as c_redc = c*R mod p, so that
   Redc(a * c_redc) == a*c mod p.
*/
long bernsum_pow2_redc(long p, double pinv, long k, long g, long n)
{
// pinv2 = -1/p mod R; F = R mod p (used to convert constants to REDC form).
long pinv2 = PrepRedc(p);
long F = (1L << (ULONG_BITS/2)) % p;
// tables[h][b] accumulates +x with no modular reduction. Each addend is
// < 2p, and an entry receives at most about (p-1)/(2*WORD_BITS) additions,
// so entries stay below p^2.
long tables[NUM_TABLES][TABLE_SIZE];
memset(tables, 0, sizeof(long) * NUM_TABLES * TABLE_SIZE);
long m = (p-1) / n;
// Only half of the work is done: if n is odd, visit half as many i;
// otherwise walk only the first n/2 bits of each expansion.
if (n & 1)
m >>= 1;
else
n >>= 1;
long g_to_km1 = PowerMod(g, k-1, p, pinv);
long two_to_km1 = PowerMod(2, k-1, p, pinv);
long B_to_km1 = PowerMod(two_to_km1, WORD_BITS, p, pinv);
long s_jump = PowerMod(2, MAX_INV * WORD_BITS, p, pinv);
// REDC forms (c * R mod p) of the multipliers used in the loops.
long g_redc = MulMod(g, F, p, pinv);
long g_to_km1_redc = MulMod(g_to_km1, F, p, pinv);
long two_to_km1_redc = MulMod(two_to_km1, F, p, pinv);
long B_to_km1_redc = MulMod(B_to_km1, F, p, pinv);
long s_jump_redc = MulMod(s_jump, F, p, pinv);
// Plain residues, but updated with RedcFast, so they lie in [0, 2p).
long g_to_km1_to_i = 1;
long g_to_i = 1;
long sum = 0;
// Precompute only as many words of 1/p as the longest chunk needs.
Expander expander(p, (n >= MAX_INV * WORD_BITS)
? MAX_INV : ((n - 1) / WORD_BITS + 1));
// ===== Phase 1: walk the binary expansions.
for (long i = 0; i < m; i++)
{
// g_to_i may be in [p, 2p); expand() needs 0 < s < p.
long s = g_to_i;
if (s >= p)
s -= p;
long x = g_to_km1_to_i;
// Process the n bits in chunks of at most MAX_INV * WORD_BITS bits.
for (long nn = n; nn > 0; nn -= MAX_INV * WORD_BITS)
{
word_t s_over_p[MAX_INV + 2];
long bits, words;
if (nn >= MAX_INV * WORD_BITS)
{
bits = MAX_INV * WORD_BITS;
words = MAX_INV;
}
else
{
bits = nn;
words = (nn - 1) / WORD_BITS + 1;
}
// s_over_p[1 .. words] = leading words of s/p, most significant last.
expander.expand(s_over_p, s, words);
// Consume whole words, most significant first, via the tables.
word_t* next = s_over_p + words;
for (; bits >= WORD_BITS; bits -= WORD_BITS, next--)
{
word_t y = *next;
#if NUM_TABLES != 8 && NUM_TABLES != 4
// Generic version: one table per slice.
for (long h = 0; h < NUM_TABLES; h++)
{
tables[h][y & TABLE_MASK] += x;
y >>= TABLE_LG_SIZE;
}
#else
// Unrolled for 4 tables (32-bit words) or 8 tables (64-bit words).
tables[0][ y & TABLE_MASK] += x;
tables[1][(y >> TABLE_LG_SIZE ) & TABLE_MASK] += x;
tables[2][(y >> (2*TABLE_LG_SIZE)) & TABLE_MASK] += x;
tables[3][(y >> (3*TABLE_LG_SIZE)) & TABLE_MASK] += x;
#if NUM_TABLES == 8
tables[4][(y >> (4*TABLE_LG_SIZE)) & TABLE_MASK] += x;
tables[5][(y >> (5*TABLE_LG_SIZE)) & TABLE_MASK] += x;
tables[6][(y >> (6*TABLE_LG_SIZE)) & TABLE_MASK] += x;
tables[7][(y >> (7*TABLE_LG_SIZE)) & TABLE_MASK] += x;
#endif
#endif
// Lazy update; x stays in [0, 2p) because p < R/2.
x = RedcFast(x * B_to_km1_redc, p, pinv2);
}
// AddMod/SubMod below need x fully reduced to [0, p).
if (x >= p)
x -= p;
// Leftover (< WORD_BITS) bits, one at a time from the top: a 1 bit
// subtracts x, a 0 bit adds x. When bits == 0, next points at
// s_over_p[0] and y is read but never used.
word_t y = *next;
for (; bits > 0; bits--)
{
if (y & (((word_t) 1) << (WORD_BITS - 1)))
sum = SubMod(sum, x, p);
else
sum = AddMod(sum, x, p);
x = Redc(x * two_to_km1_redc, p, pinv2);
y <<= 1;
}
// Move s to the start of the next chunk.
s = Redc(s * s_jump_redc, p, pinv2);
}
g_to_i = RedcFast(g_to_i * g_redc, p, pinv2);
g_to_km1_to_i = RedcFast(g_to_km1_to_i * g_to_km1_redc, p, pinv2);
}
// ===== Phase 2: resolve the tables.
// weights[b] as in bernsum_pow2(), but pre-multiplied by
// 2^(3*ULONG_BITS/2) = R^3. This cancels the three R^(-1) factors
// introduced by the three RedcFast calls per entry below.
long weights[TABLE_SIZE];
weights[0] = 0;
for (long h = 0, x = PowerMod(2, 3*ULONG_BITS/2, p, pinv);
h < TABLE_LG_SIZE; h++, x = Redc(x * two_to_km1_redc, p, pinv2))
{
for (long i = (1L << h) - 1; i >= 0; i--)
{
weights[2*i+1] = SubMod(weights[i], x, p);
weights[2*i] = AddMod(weights[i], x, p);
}
}
// x runs through plain residues (2^(k-1))^(TABLE_LG_SIZE*(NUM_TABLES-1-h)),
// since Redc(x * x_jump_redc) == x * x_jump mod p.
long x_jump = PowerMod(two_to_km1, TABLE_LG_SIZE, p, pinv);
long x_jump_redc = MulMod(x_jump, F, p, pinv);
for (long h = NUM_TABLES - 1, x = 1; h >= 0; h--)
{
for (long i = 0; i < TABLE_SIZE; i++)
{
// Each y is < 2p. sum collects NUM_TABLES*TABLE_SIZE of them without
// reduction and is reduced once on return.
long y;
y = RedcFast(tables[h][i], p, pinv2);
y = RedcFast(y * weights[i], p, pinv2);
y = RedcFast(y * x, p, pinv2);
sum += y;
}
x = Redc(x * x_jump_redc, p, pinv2);
}
return sum % p;
}
long _bern_modp_powg(long p, double pinv, long k)
{
Factorisation F(p-1);
long g = primitive_root(p, pinv, F);
long x = bernsum_powg(p, pinv, k, g);
long g_to_k = PowerMod(g, k, p, pinv);
long t = InvMod(p + 1 - g_to_k, p);
x = MulMod(x, t, p, pinv);
x = AddMod(x, x, p);
return x;
}
long _bern_modp_pow2(long p, double pinv, long k)
{
Factorisation F(p-1);
long g = primitive_root(p, pinv, F);
long n = order(2, p, pinv, F);
long x;
if (p < (1L << (ULONG_BITS/2 - 1)))
x = bernsum_pow2_redc(p, pinv, k, g, n);
else
x = bernsum_pow2(p, pinv, k, g, n);
long t = PowerMod(2, -k, p, pinv) - 1;
t = AddMod(t, t, p);
t = InvMod(t, p);
x = MulMod(x, t, p, pinv);
return x;
}
/*
   Dispatches between the two summation methods. The pow2 method divides by
   2 * (2^(-k) - 1), which is zero when 2^k == 1 mod p, so that case falls
   back to the slower powg method.
*/
long _bern_modp(long p, double pinv, long k)
{
   const bool two_to_k_is_one = (PowerMod(2, k, p, pinv) == 1);
   return two_to_k_is_one ? _bern_modp_powg(p, pinv, k)
                          : _bern_modp_pow2(p, pinv, k);
}
/*
   Returns the Bernoulli number B_k reduced mod p, as a residue in [0, p).
   Returns -1 when p divides the denominator of B_k, so no residue exists.

   Requires k >= 0 and 2 <= p < NTL_SP_BOUND (p is presumably prime;
   the helpers factor p-1 and look for a primitive root).
*/
long bern_modp(long p, long k)
{
   assert(k >= 0);
   assert(2 <= p && p < NTL_SP_BOUND);

   // B_0 = 1
   if (k == 0)
      return 1;

   // B_1 = -1/2: undefined mod 2; otherwise -1/2 == (p-1)/2 mod p
   if (k == 1)
   {
      if (p == 2)
         return -1;
      return (p-1)/2;
   }

   // B_k = 0 for odd k >= 3
   if (k & 1)
      return 0;

   // For even k >= 2, the denominator of B_k is divisible by 6
   // (von Staudt-Clausen).
   if (p <= 3)
      return -1;

   // Kummer's congruence: if k == m mod (p-1) and (p-1) does not divide k,
   // then B_k/k == B_m/m mod p.
   long m = k % (p-1);
   if (m == 0)
      return -1;    // (p-1) | k: p divides the denominator of B_k

   double pinv = 1 / ((double) p);
   long x = _bern_modp(p, pinv, m);

   // MulMod requires both operands reduced mod p. k may be far larger than
   // p, which would break the floating-point quotient estimate (and can
   // overflow x*k), so reduce it first.
   return MulMod(x, k % p, p, pinv);
}
};