// Path: blob/master/sage/schemes/hyperelliptic_curves/hypellfrob/recurrences_zn_poly.cpp
// 4108 views
/* ============================================================================12recurrences_zn_poly.cpp: recurrences solved via zn_poly arithmetic34This file is part of hypellfrob (version 2.1.1).56Copyright (C) 2007, 2008, David Harvey78This program is free software; you can redistribute it and/or modify9it under the terms of the GNU General Public License as published by10the Free Software Foundation; either version 2 of the License, or11(at your option) any later version.1213This program is distributed in the hope that it will be useful,14but WITHOUT ANY WARRANTY; without even the implied warranty of15MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the16GNU General Public License for more details.1718You should have received a copy of the GNU General Public License along19with this program; if not, write to the Free Software Foundation, Inc.,2051 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.2122============================================================================ */232425#include "recurrences_zn_poly.h"262728NTL_CLIENT293031namespace hypellfrob {323334Shifter::~Shifter()35{36zn_array_mulmid_precomp1_clear(kernel_precomp);37free(input_twist);38}394041Shifter::Shifter(ulong d, ulong a, ulong b, const zn_mod_t mod)42{43this->d = d;44this->mod = mod;4546input_twist = (ulong*) malloc(sizeof(ulong) * (3*d + 3));47output_twist = input_twist + d + 1;48scratch = output_twist + d + 1;4950ZZ modulus;51modulus = zn_mod_get (mod);5253// ------------------------ compute input_twist -------------------------5455// prod = (d!)^(-1)56ulong prod = 1;57for (ulong i = 2; i <= d; i++)58prod = zn_mod_mul(prod, i, mod);59prod = to_ulong(InvMod(to_ZZ(prod), modulus));6061// input_twist[i] = ((d-i)!)^(-1)62input_twist[0] = prod;63for (ulong i = 1; i <= d; i++)64input_twist[i] = zn_mod_mul(input_twist[i-1], d - (i-1), mod);6566// input_twist[i] = ((d-i)!*i!)^(-1)67for (ulong i = 0; i <= d/2; i++)68{69input_twist[i] = zn_mod_mul(input_twist[i], 
input_twist[d-i], mod);70input_twist[d-i] = input_twist[i];71}7273// input_twist[i] = \prod_{0 <= j <= d, j != i} (i-j)^(-1)74// = (-1)^(d-i) ((d-i)!*i!)^(-1)75for (long i = d - 1; i >= 0; i -= 2)76input_twist[i] = zn_mod_neg(input_twist[i], mod);7778// ----------------- compute output_twist and kernel --------------------7980// need some temp space:81// c, accum, accum_inv each of length 2d+182ulong* c = (ulong*) malloc(sizeof(ulong) * (6*d+3));83ulong* accum = c + 2*d + 1;84ulong* accum_inv = accum + 2*d + 1;85ulong* kernel = c; // overwrites c8687// c[i] = c_i = a + (i-d)*b for 0 <= i <= 2d88c[0] = zn_mod_sub(a, zn_mod_mul(zn_mod_reduce(d, mod), b, mod), mod);89for (ulong i = 1; i <= 2*d; i++)90c[i] = zn_mod_add(c[i-1], b, mod);9192// accum[i] = c_0 * c_1 * ... * c_i for 0 <= i <= 2d93accum[0] = c[0];94for (ulong i = 1; i <= 2*d; i++)95accum[i] = zn_mod_mul(accum[i-1], c[i], mod);9697// accum_inv[i] = (c_0 * c_1 * ... * c_i)^(-1) for 0 <= i <= 2d98accum_inv[2*d] = to_ulong(InvMod(to_ZZ(accum[2*d]), modulus));99100for (long i = 2*d - 1; i >= 0; i--)101accum_inv[i] = zn_mod_mul(accum_inv[i+1], c[i+1], mod);102103// output_twist[i] = b^{-d} * c_i * c_{i+1} * ... 
* c_{i+d}104// for 0 <= i <= d105ulong factor = to_long(PowerMod(to_ZZ(b), -((long)d), modulus));106output_twist[0] = zn_mod_mul(factor, accum[d], mod);107for (ulong i = 1; i <= d; i++)108output_twist[i] = zn_mod_mul(zn_mod_mul(factor, accum[i+d], mod),109accum_inv[i-1], mod);110111// kernel[i] = (c_i)^(-1) for 0 <= i <= 2d112kernel[0] = accum_inv[0];113for (ulong i = 1; i <= 2*d; i++)114kernel[i] = zn_mod_mul(accum_inv[i], accum[i-1], mod);115116// precompute FFT of kernel117zn_array_mulmid_precomp1_init(kernel_precomp, kernel, 2*d+1, d+1, mod);118119free(c);120}121122123void Shifter::shift(ulong* output, const ulong* input)124{125// multiply inputs pointwise by input_twist126for (ulong i = 0; i <= d; i++)127scratch[i] = zn_mod_mul(input[i], input_twist[i], mod);128129// do middle product130zn_array_mulmid_precomp1_execute(output, scratch, kernel_precomp);131132// multiply outputs pointwise by output_twist133for (ulong i = 0; i <= d; i++)134output[i] = zn_mod_mul(output[i], output_twist[i], mod);135}136137138/*139Checks whether large_evaluate() will succeed for these choices of k and u.140Returns 1 if okay, otherwise 0.141*/142int check_params(ulong k, ulong u, const zn_mod_t mod)143{144ulong n = zn_mod_get (mod);145146if (k >= n || u >= n)147return 0;148149if (k <= 1)150return 1;151152if (k == n - 1)153return 0;154155ulong k2 = k / 2;156157// need the following elements to be invertible:158// u159// 1, 2, ..., k + 1160// k2 + i*u for -k2 <= i <= k2161ulong prod = u;162for (ulong i = 2; i <= k; i++)163prod = zn_mod_mul(prod, i, mod);164ulong temp = zn_mod_mul(k2, zn_mod_sub(1, u, mod), mod);165for (ulong i = 0; i <= 2*k2; i++)166{167prod = zn_mod_mul(prod, temp, mod);168temp = zn_mod_add(temp, u, mod);169}170171ZZ x, y;172x = prod;173y = n;174if (GCD(x, y) != 1)175return 0;176177// check recursively below178return check_params(k2, u, mod);179}180181182183/*184Let M0 and M1 be square matrices of size r*r. 
Let M(x) = M0 + x*M1; this185is a matrix of linear polys in x. The matrices M0 and M1 are passed in186row-major order.187188Let P(x) = M(x+1) M(x+2) ... M(x+k); this is a matrix of polynomials of189degree k.190191This class attempts to compute the matrices192P(0), P(u), P(2u), ..., P(ku).193194Usage:195196* Call the constructor.197198* Call evaluate() with half == 0. This computes the first half of the199outputs, specifically P(mu) for 0 <= m <= k2, where k2 = floor(k/2).200The (i, j) entry of P(mu) is stored in output[i*r+j][m+offset]. Each array201output[i*r+j] must be preallocated to length k2 + 3. (The extra two matrices202at the end are used for scratch space.)203204* Call evaluate() with half == 1. This computes the second half, namely205P(mu) for k2 + 1 <= m <= k. The (i, j) entry of P(mu) is stored in206output[i*r+j][m-(k2+1)+offset]. Each array output[i*r+j] must be207preallocated to length k2 + 3. It's okay for this second half to overwrite208the first half generated earlier (in fact this property is used by209zn_poly_interval_products to conserve memory).210211The computation may fail for certain bad choices of k and u. Let p = residue212characteristic (i.e. assume mod is p^m for some m). Typically this function213gets called for k ~ u and k*u ~ M*p, for some M much smaller than p. In this214situation, I expect most choices of (k, u) are not bad. Failure is very215bad: the program will crash. Luckily, you can call check_params() to test216whether (k, u) is bad. If it's bad, you probably should just increment u217until you find one that's not bad. (Sorry, I haven't done a proper analysis218of the situation yet.) 
The main routines fall back on the zz_pX version if219they can't find a good parameter choice.220221Must have 0 <= k < n and 0 < u < n.222223*/224LargeEvaluator::LargeEvaluator(int r, ulong k, ulong u,225const vector<vector<ulong> >& M0,226const vector<vector<ulong> >& M1,227const zn_mod_t& mod) : M0(M0), M1(M1), mod(mod)228{229assert(k < zn_mod_get(mod));230assert(u < zn_mod_get(mod));231assert(r >= 1);232assert(k >= 1);233234this->r = r;235this->k = k;236this->k2 = k / 2;237this->odd = k & 1;238this->u = u;239this->shifter = NULL;240}241242243LargeEvaluator::~LargeEvaluator()244{245if (this->shifter)246delete shifter;247}248249250void LargeEvaluator::evaluate(int half, vector<ulong_array>& output,251ulong offset)252{253// base cases254255if (k == 0)256{257// identity matrix258for (int x = 0; x < r; x++)259for (int y = 0; y < r; y++)260output[y*r + x].data[offset] = (x == y);261return;262}263264if (k == 1)265{266// evaluate M(1) (for half == 0) or M(u+1) (for half == 1)267for (int x = 0; x < r; x++)268for (int y = 0; y < r; y++)269{270ulong temp = zn_mod_add(M0[y][x], M1[y][x], mod);271output[y*r + x].data[offset] = half ?272zn_mod_add(temp, zn_mod_mul(u, M1[y][x], mod), mod) : temp;273}274return;275}276277// recursive case278279// Let Q(x) = M(x+1) ... 
M(x+k2), where k2 = floor(k/2).280// Then we have either281// P(x) = Q(x) Q(x+k2) if k is even282// P(x) = Q(x) Q(x+k2) M(x+k) if k is odd283284if (half == 0)285{286// This is the first time through this function.287// Allocate scratch space.288scratch.resize(r*r);289for (int i = 0; i < r*r; i++)290scratch[i].resize(k2 + 3);291292{293// Recursively compute Q(0), Q(u), ..., Q(k2*u).294// These are stored in scratch.295LargeEvaluator recurse(r, k2, u, M0, M1, mod);296recurse.evaluate_all(scratch);297}298299// Precomputations for value-shifting300shifter = new Shifter(k2, k2, u, mod);301}302else // half == 1303{304// Shift original sequence by (k2+1)*u to obtain305// Q((k2+1)*u), Q((k2+2)*u), ..., Q((2*k2+1)*u)306Shifter big_shifter(k2, zn_mod_mul(k2+1, u, mod), u, mod);307for (int i = 0; i < r*r; i++)308big_shifter.shift(scratch[i].data, scratch[i].data);309}310311// Let H = (k2+1)*u*half, so now scratch contains312// Q(H), Q(H + u), ..., Q(H + k2*u).313314// Shift by k2 to obtain Q(H + k2), Q(H + u + k2), ..., Q(H + k2*u + k2).315// Results are stored directly in output array. 
We put them one slot316// to the right, to make room for inplace matrix multiplies later on.317// (If k is odd, they're shifted right by yet another slot, to make room318// for another round of multiplies.)319for (int i = 0; i < r*r; i++)320shifter->shift(output[i].data + offset + odd + 1, scratch[i].data);321322// If k is odd, right-multiply each Q(H + i*u + k2) by M(H + i*u + k)323// (results are stored in output array, shifted one entry to the left)324if (odd)325{326ulong_array cruft(r*r); // for storing M(H + i*u + k)327ulong point = k; // evaluation point328if (half)329point = zn_mod_add(point, zn_mod_mul(k2 + 1, u, mod), mod);330331for (ulong i = 0; i <= k2; i++, point = zn_mod_add(point, u, mod))332{333// compute M(H + i*u + k) = M0 + M1*point334for (int x = 0; x < r; x++)335for (int y = 0; y < r; y++)336cruft.data[y*r + x] = zn_mod_add(M0[y][x],337zn_mod_mul(M1[y][x], point, mod), mod);338339// multiply340for (int x = 0; x < r; x++)341for (int y = 0; y < r; y++)342{343ulong accum = 0;344for (int z = 0; z < r; z++)345accum = zn_mod_add(accum,346zn_mod_mul(output[y*r + z].data[offset + i + 2],347cruft.data[z*r + x], mod), mod);348output[y*r + x].data[offset + i + 1] = accum;349}350}351}352353ulong n = zn_mod_get(mod);354355// Multiply to obtain P(H), P(H + u), ..., P(H + k2*u)356// (except for the last one, in the second half, if k is even)357// Store results directly in output array.358for (ulong i = 0; i + (half && !odd) <= k2; i++)359for (int x = 0; x < r; x++)360for (int y = 0; y < r; y++)361{362ulong sum_hi = 0;363ulong sum_lo = 0;364for (int z = 0; z < r; z++)365{366ulong hi, lo;367ZNP_MUL_WIDE(hi, lo, scratch[y*r + z].data[i],368output[z*r + x].data[offset + i + 1]);369ZNP_ADD_WIDE(sum_hi, sum_lo, sum_hi, sum_lo, hi, lo);370if (sum_hi >= n)371sum_hi -= n;372}373output[y*r + x].data[offset + i] =374zn_mod_reduce_wide(sum_hi, sum_lo, mod);375}376}377378379/*380Evaluates both halves, storing results in output array.381*/382void 
LargeEvaluator::evaluate_all(vector<ulong_array>& output)383{384evaluate(0, output, 0);385evaluate(1, output, k/2 + 1);386}387388389390/*391See interval_products_wrapper().392393NOTE 1:394I haven't proved that the algorithm here always succeeds, although I expect395that it almost always will, especially when p is very large. If it396succeeds, it returns 1. If it fails, it returns 0 (practically instantly).397In the latter case, at least the caller can fall back on398ntl_interval_products().399400NOTE 2:401The algorithm here is similar to ntl_interval_products(). However, it402doesn't do the "refining step" -- it just handles the smaller intervals403in the naive fashion. Also, instead of breaking intervals into power-of-four404lengths, it just does the whole thing in one chunk. The performance ends up405being smoother, but it's harder to prove anything about avoiding406non-invertible elements. Hence the caveat in Note 1.407408*/409int zn_poly_interval_products(vector<vector<vector<ulong> > >& output,410const vector<vector<ulong> >& M0,411const vector<vector<ulong> >& M1,412const vector<ZZ>& target, const zn_mod_t& mod)413{414output.resize(target.size() / 2);415416// select k such that k*(k+1) >= total length of intervals417ulong k;418{419ZZ len = target.back() - target.front();420ZZ kk = SqrRoot(len);421k = to_ulong(kk);422if (kk * (kk + 1) < len)423k++;424}425426int r = M0.size();427428// try to find good parameters u and k429ulong u = k;430for (int trial = 0; ; trial++, u++)431{432if (check_params(k, u, mod))433break; // found some good parameters434if (trial == 5)435return 0; // too many failures, give up436}437438ulong n = zn_mod_get(mod);439440// shift M0 over to account for starting index441vector<vector<ulong> > M0_shifted = M0;442ulong shift = target.front() % n;443for (int x = 0; x < r; x++)444for (int y = 0; y < r; y++)445M0_shifted[y][x] = zn_mod_add(M0[y][x],446zn_mod_mul(shift, M1[y][x], mod), mod);447448// prepare for evaluating products over the big 
intervals449// [0, k), [u, u+k), ..., [ku, ku+k)450LargeEvaluator evaluator(r, k, u, M0_shifted, M1, mod);451452vector<ulong_array> big(r*r);453for (int i = 0; i < r*r; i++)454big[i].resize(k/2 + 3); // space for half the products, plus two more455456// evaluate the first half of the products457evaluator.evaluate(0, big, 0);458// flag indicating which half we currently have stored in the "big" array459int half = 0;460461vector<vector<ulong> > accum(r, vector<ulong>(r));462vector<vector<ulong> > temp1(r, vector<ulong>(r));463vector<vector<ulong> > temp2(r, vector<ulong>(r));464465// for each target interval....466for (int i = 0; i < target.size()/2; i++)467{468// doing interval [s0, s1)469ZZ s0 = target[2*i];470ZZ s1 = target[2*i+1];471472// product accumulated so far is [t0, t1).473ZZ t0 = s0;474ZZ t1 = s0;475for (int x = 0; x < r; x++)476for (int y = 0; y < r; y++)477accum[x][y] = (x == y); // identity matrix478479while (t1 < s1)480{481// if we are exactly on the left end of a big interval, and the482// big interval fits inside the target interval, then roll it in483if ((t1 - target[0]) % u == 0 && t1 + k <= s1)484{485// compute which "big" interval we are rolling in486int index = to_ulong((t1 - target[0]) / u);487488if (index >= k/2 + 1)489{490// if the "big" interval is in the second half, and we haven't491// computed the second half of the intervals yet, then go and492// compute them (overwriting the first half)493if (half == 0)494{495evaluator.evaluate(1, big, 0);496half = 1;497}498index -= (k/2 + 1);499}500501for (int x = 0; x < r; x++)502for (int y = 0; y < r; y++)503temp1[y][x] = big[y*r + x].data[index];504505t1 += k;506}507else508{509// otherwise just multiply by the single matrix M(t1 + 1)510ulong e = (t1 + 1) % n;511512for (int x = 0; x < r; x++)513for (int y = 0; y < r; y++)514temp1[y][x] = zn_mod_add(M0[y][x],515zn_mod_mul(M1[y][x], e, mod), mod);516517t1++;518}519520// multiply by whichever matrix we picked up above521for (int x = 0; x < r; 
x++)522for (int y = 0; y < r; y++)523{524ulong sum_hi = 0;525ulong sum_lo = 0;526for (int z = 0; z < r; z++)527{528ulong hi, lo;529ZNP_MUL_WIDE(hi, lo, accum[y][z], temp1[z][x]);530ZNP_ADD_WIDE(sum_hi, sum_lo, sum_hi, sum_lo, hi, lo);531if (sum_hi >= n)532sum_hi -= n;533}534temp2[y][x] = zn_mod_reduce_wide(sum_hi, sum_lo, mod);535}536537accum.swap(temp2);538}539540// store result in output array541output[i] = accum;542}543544return 1;545}546547548}; // namespace hypellfrob549550551// ----------------------- end of file552553554