#include <gmp.h>
#include <NTL/ZZ.h>
#include <cmath>
#include <cstdlib>
#include <functional>
#include <set>
#include <vector>
#include "bern_modp_util.h"
#include "bern_modp.h"
#include "bern_rat.h"
#ifdef USE_THREADS
#include <pthread.h>
#endif
using namespace std;
using namespace NTL;
namespace bernmm {
// Denominator of the Bernoulli number B_k (for even k >= 2): the product of
// all primes q such that (q - 1) divides k (the von Staudt-Clausen
// denominator).  The result is written to `res`, which must already be
// initialised.  `table.is_prime()` is queried with values up to k + 1.
void bern_den(mpz_t res, long k, const PrimeTable& table)
{
   mpz_set_ui(res, 1);

   // Enumerate the divisors of k in pairs (d, k/d) with d <= sqrt(k); for
   // each divisor e, q = e + 1 is a candidate prime factor.
   for (long d = 1; d * d <= k; ++d)
   {
      if (k % d != 0)
         continue;

      long lo = d;
      long hi = k / d;

      if (table.is_prime(lo + 1))
         mpz_mul_ui(res, res, lo + 1);

      // When k is a perfect square, sqrt(k) pairs with itself: count it once.
      if (hi != lo && table.is_prime(hi + 1))
         mpz_mul_ui(res, res, hi + 1);
   }
}
// Granularity of the work handed to each worker: block number `next` covers
// the candidate primes in [next*BLOCK_SIZE, (next+1)*BLOCK_SIZE) below the
// overall prime bound.  The size is measured in integers, not in primes.
#define BLOCK_SIZE 1000
// One partial result of the multimodular computation: `residue` is the value
// reconstructed (by CRT) from the bern_modp() results for the primes whose
// product is `modulus`.  Both GMP integers are initialised on construction
// and released on destruction.
struct Item
{
   mpz_t modulus;
   mpz_t residue;

   Item()
   {
      mpz_init(modulus);
      mpz_init(residue);
   }

   ~Item()
   {
      mpz_clear(residue);
      mpz_clear(modulus);
   }

private:
   // Item owns raw GMP limb storage.  The implicit member-wise copy would make
   // two Items share the same limbs, and both destructors would then
   // mpz_clear() them (double free).  Copying is therefore disabled; the
   // functions are declared but never defined (C++98-compatible idiom).
   Item(const Item&);
   Item& operator=(const Item&);
};
struct Item_cmp
{
bool operator()(const Item* x, const Item* y)
{
return mpz_cmp(x->modulus, y->modulus) < 0;
}
};
// Chinese Remainder combination of two partial results.  With
// m1 = op1->modulus and m2 = op2->modulus (assumed coprime), returns a newly
// allocated Item, owned by the caller, whose modulus is m1*m2 and whose
// residue r lies in [0, m1*m2) with r = op1->residue (mod m1) and
// r = op2->residue (mod m2).  The inputs are neither modified nor freed.
Item* CRT(Item* op1, Item* op2)
{
   mpz_srcptr m1 = op1->modulus;
   mpz_srcptr m2 = op2->modulus;
   mpz_srcptr r1 = op1->residue;
   mpz_srcptr r2 = op2->residue;

   Item* out = new Item;
   mpz_ptr M = out->modulus;
   mpz_ptr R = out->residue;

   // R <- r2 - r1
   mpz_sub(R, r2, r1);

   // M <- e = m1 * (m1^{-1} mod m2), so e = 0 (mod m1) and e = 1 (mod m2).
   // mpz_invert's return value is not checked: the moduli are assumed coprime
   // (the workers build them from disjoint ranges of primes).
   mpz_invert(M, m1, m2);
   mpz_mul(M, M, m1);

   // R <- r1 + (r2 - r1) * e, then reduce into [0, m1*m2).
   mpz_mul(R, R, M);
   mpz_add(R, R, r1);
   mpz_mul(M, m1, m2);
   mpz_mod(R, R, M);

   return out;
}
struct State
{
long k;
long bound;
const PrimeTable* table;
long next;
std::set<Item*, Item_cmp> items;
#ifdef USE_THREADS
pthread_mutex_t lock;
#endif
State(long k, long bound, const PrimeTable& table)
{
this->k = k;
this->bound = bound;
this->next = 0;
this->table = &table;
#ifdef USE_THREADS
pthread_mutex_init(&lock, NULL);
#endif
}
~State()
{
#ifdef USE_THREADS
pthread_mutex_destroy(&lock);
#endif
}
};
// Worker entry point: run by each helper thread, and also called directly on
// the calling thread by bern_rat.  `arg` must point to the shared State.
//
// It loops until there is nothing left to do:
//  * While unclaimed blocks remain, it claims block `next` and builds an Item
//    whose modulus is the product of the primes p in
//    [next*BLOCK_SIZE, (next+1)*BLOCK_SIZE) with p >= 5, p < state.bound and
//    (p - 1) not dividing k.  The Item's residue agrees with bern_modp(p, k)
//    modulo each such p.  The Item is then inserted into state.items.
//  * Otherwise it pops the two items with the smallest moduli (the set is
//    ordered by modulus), combines them with CRT() and inserts the result.
//  * It returns once no blocks remain and at most one item is left.
//
// Under USE_THREADS the state lock is held at the top of every iteration.
// It is released only around the expensive arithmetic.  A thread that
// inserts an item always re-examines the set before it can exit, so the
// last thread to insert performs any remaining merges.
// Always returns NULL.
void* worker(void* arg)
{
State& state = *((State*) arg);
long k = state.k;
#ifdef USE_THREADS
pthread_mutex_lock(&state.lock);
#endif
while (1)
{
if (state.next * BLOCK_SIZE < state.bound)
{
// Claim the next block while still holding the lock.
long next = state.next++;
#ifdef USE_THREADS
pthread_mutex_unlock(&state.lock);
#endif
// Start from the trivial item: residue 0 modulo 1.
Item* item = new Item;
mpz_set_ui(item->modulus, 1);
mpz_set_ui(item->residue, 0);
// Primes 2 and 3 are never used; for even k, p - 1 divides k for both, so
// they would be skipped anyway.  next * BLOCK_SIZE is never itself prime,
// so it does not matter whether next_prime() is strict (presumably it
// returns the least prime > its argument -- confirm in PrimeTable).
for (long p = max(5, state.table->next_prime(next * BLOCK_SIZE));
p < state.bound && p < (next+1) * BLOCK_SIZE;
p = state.table->next_prime(p))
{
// Primes with (p - 1) | k divide the denominator of B_k (see bern_den),
// so B_k has no residue modulo them.
if (k % (p-1) == 0)
continue;
long b = bern_modp(p, k);
// Incremental CRT: choose x (mod p) so that residue + modulus * x = b
// (mod p).  Since modulus = 0 modulo every earlier prime, the residues
// already established for those primes are unchanged.
long x = MulMod(SubMod(b, mpz_fdiv_ui(item->residue, p), p),
InvMod(mpz_fdiv_ui(item->modulus, p), p), p);
mpz_addmul_ui(item->residue, item->modulus, x);
mpz_mul_ui(item->modulus, item->modulus, p);
}
#ifdef USE_THREADS
pthread_mutex_lock(&state.lock);
#endif
// The lock is held again here and remains held into the next iteration.
state.items.insert(item);
}
else
{
// No blocks left.  With at most one item remaining there is nothing to
// merge, so this thread is done.
if (state.items.size() <= 1)
{
#ifdef USE_THREADS
pthread_mutex_unlock(&state.lock);
#endif
return NULL;
}
// Remove the two smallest-modulus items while holding the lock, so no
// other thread can reach them.  Merging the smallest first presumably
// keeps the CRT operands balanced in size.
Item* item1 = *(state.items.begin());
state.items.erase(state.items.begin());
Item* item2 = *(state.items.begin());
state.items.erase(state.items.begin());
#ifdef USE_THREADS
pthread_mutex_unlock(&state.lock);
#endif
Item* item3 = CRT(item1, item2);
delete item1;
delete item2;
#ifdef USE_THREADS
pthread_mutex_lock(&state.lock);
#endif
state.items.insert(item3);
}
}
}
// Compute the Bernoulli number B_k exactly and store it in `res`.
//
// B_0, B_1, B_2 and odd k are handled directly.  For even k >= 4:
//   * bern_den() gives the denominator;
//   * B_k is computed modulo enough primes (worker() / bern_modp()) that
//     their product M exceeds the estimated size of the numerator;
//   * the residues are combined by CRT, and the numerator is read off from
//     (B_k * den) mod M together with the known sign of B_k.
//
// res         -- output; must be an initialised mpq_t.
// k           -- index of the Bernoulli number (presumably k >= 0; negative k
//                is not handled -- TODO confirm callers never pass one).
// num_threads -- total number of threads, including the caller; values <= 0
//                are treated as 1.  Has no effect unless compiled with
//                USE_THREADS.
void bern_rat(mpq_t res, long k, int num_threads)
{
   // Small cases: B_0 = 1, B_1 = -1/2, B_2 = 1/6, and B_k = 0 for odd k > 1.
   if (k == 0)
   {
      mpq_set_ui(res, 1, 1);
      return;
   }
   if (k == 1)
   {
      mpq_set_si(res, -1, 2);
      return;
   }
   if (k == 2)
   {
      mpq_set_si(res, 1, 6);
      return;
   }
   if (k & 1)
   {
      mpq_set_ui(res, 0, 1);
      return;
   }

   if (num_threads <= 0)
      num_threads = 1;

   mpz_t num, den;
   mpz_init(num);
   mpz_init(den);

   const double invlog2 = 1.44269504088896340735992;   // 1 / ln(2)

   // Prime table shared by the denominator computation and the modular
   // phase.  NOTE(review): bound1 is presumably an upper bound on every prime
   // needed below -- confirm against PrimeTable / next_prime requirements.
   long bound1 = (long) max(37.0, ceil((k + 0.5) * log(k) * invlog2));
   PrimeTable table(bound1);

   // Denominator: product of the primes p with (p - 1) | k.
   bern_den(den, k, table);

   // Estimated bit length of the numerator |B_k| * den.  This appears to be a
   // Stirling-type bound on log2 |B_k| (4.094 ~ log2(2*pi*e)) plus
   // log2(den).
   long bits = (long) ceil((k + 0.5) * log(k) * invlog2 - 4.094 * k + 2.470
                           + log(mpz_get_d(den)) * invlog2);

   // Walk the primes p >= 5 until the product of those with (p - 1) not
   // dividing k exceeds 2^(bits + 1).  The product is tracked in floating
   // point as prod * 2^prod_bits, renormalised by frexp() each step so it
   // never overflows.
   double prod = 1.0;
   long prod_bits = 0;
   long p;
   for (p = 5; prod_bits < bits + 1; p = table.next_prime(p))
   {
      // The per-prime arithmetic in worker() uses NTL's single-precision
      // MulMod/SubMod/InvMod, which require p < NTL_SP_BOUND.
      if (p >= NTL_SP_BOUND)
         abort();
      if (k % (p-1) != 0)
         prod *= (double) p;
      int shift;
      prod = frexp(prod, &shift);
      prod_bits += shift;
   }
   // Every prime used in the modular phase lies strictly below bound2.
   long bound2 = p;

   State state(k, bound2, table);

#ifdef USE_THREADS
   // Start up to num_threads - 1 helpers.  If a helper cannot be created we
   // carry on with fewer: the calling thread runs worker() itself, so all
   // blocks are still processed.  Only successfully created threads are
   // recorded, so only those are joined.
   vector<pthread_t> threads;
   threads.reserve(num_threads - 1);
   pthread_attr_t attr;
   pthread_attr_init(&attr);
#ifdef THREAD_STACK_SIZE
   pthread_attr_setstacksize(&attr, THREAD_STACK_SIZE * 1024);
#endif
   for (long i = 0; i < num_threads - 1; i++)
   {
      pthread_t tid;
      if (pthread_create(&tid, &attr, worker, &state) == 0)
         threads.push_back(tid);
   }
   // Destroying the attributes object does not affect threads already
   // created from it.
   pthread_attr_destroy(&attr);
#endif

   worker(&state);

#ifdef USE_THREADS
   for (size_t i = 0; i < threads.size(); i++)
      pthread_join(threads[i], NULL);
#endif

   // All workers have returned.  A worker exits only when no blocks remain and
   // at most one item is left, and block 0 always exists (bound2 > 0).  So
   // exactly one item remains: the residue modulo M = product of all selected
   // primes.  Take it out of the set so the set holds no dangling pointer.
   Item* item = *(state.items.begin());
   state.items.erase(state.items.begin());

   // num <- (residue * den) mod M, in [0, M).
   mpz_mul(num, item->residue, den);
   mpz_mod(num, num, item->modulus);
   // For even k >= 4, B_k < 0 exactly when k = 0 (mod 4).  In that case the
   // numerator is the negative representative num - M.
   if (k % 4 == 0)
   {
      mpz_sub(num, item->modulus, num);
      mpz_neg(num, num);
   }
   delete item;

   // Move the results into res.  num and den then hold res's previous
   // numerator and denominator, which are released below.
   mpz_swap(num, mpq_numref(res));
   mpz_swap(den, mpq_denref(res));
   mpz_clear(num);
   mpz_clear(den);
}
};