Path: blob/master/sage/schemes/hyperelliptic_curves/hypellfrob/recurrences_ntl.cpp
/* ============================================================================

   recurrences_ntl.cpp: recurrences solved via NTL polynomial arithmetic

   This file is part of hypellfrob (version 2.1.1).

   Copyright (C) 2007, 2008, David Harvey

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License along
   with this program; if not, write to the Free Software Foundation, Inc.,
   51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

============================================================================ */


#include <NTL/ZZ_pX.h>
#include <NTL/mat_ZZ_p.h>
#include <NTL/lzz_pX.h>
#include <NTL/mat_lzz_p.h>
#include <cassert>
#include "recurrences_ntl.h"


NTL_CLIENT


namespace hypellfrob {


/* ============================================================================

   Some template stuff

   The matrix evaluation code is templated so that it can work over either
   ZZ_p or zz_p. There are several template parameters, which can have two
   settings:

      SCALAR:       ZZ_p           zz_p
      POLY:         ZZ_pX          zz_pX
      VECTOR:       vec_ZZ_p       vec_zz_p
      MATRIX:       mat_ZZ_p       mat_zz_p
      POLYMODULUS:  ZZ_pXModulus   zz_pXModulus
      FFTREP:       FFTRep         fftRep

   For the most part NTL uses the same function names for both columns, which
   makes life easy, but there are a few I needed to define explicitly:

      to_scalar()     converts a ZZ or int into a SCALAR
      forward_fft()   runs a forward FFT (either ToFFTRep() or TofftRep())
      inverse_fft()   runs an inverse FFT (either FromFFTRep() or FromfftRep())

============================================================================ */


// to_scalar(ZZ)

template <typename SCALAR> SCALAR to_scalar(const ZZ& input);

template<> inline ZZ_p to_scalar<ZZ_p>(const ZZ& input)
{
   return to_ZZ_p(input);
}

template<> inline zz_p to_scalar<zz_p>(const ZZ& input)
{
   return to_zz_p(input);
}


// to_scalar(int)

template <typename SCALAR> SCALAR to_scalar(int input);

template<> inline ZZ_p to_scalar<ZZ_p>(int input)
{
   return to_ZZ_p(input);
}

template<> inline zz_p to_scalar<zz_p>(int input)
{
   return to_zz_p(input);
}


// forward_fft

template <typename POLY, typename FFTREP>
void forward_fft(FFTREP& y, const POLY& x, long k, long lo, long hi);

template<> inline void
forward_fft<ZZ_pX, FFTRep>(FFTRep& y, const ZZ_pX& x, long k, long lo, long hi)
{
   ToFFTRep(y, x, k, lo, hi);
}

template<> inline void
forward_fft<zz_pX, fftRep>(fftRep& y, const zz_pX& x, long k, long lo, long hi)
{
   TofftRep(y, x, k, lo, hi);
}


// inverse_fft

template <typename POLY, typename FFTREP>
void inverse_fft(POLY& x, FFTREP& y, long lo, long hi);

template<> inline void
inverse_fft<ZZ_pX, FFTRep>(ZZ_pX& x, FFTRep& y, long lo, long hi)
{
   FromFFTRep(x, y, lo, hi);
}

template<> inline void
inverse_fft<zz_pX, fftRep>(zz_pX& x, fftRep& y, long lo, long hi)
{
   FromfftRep(x, y, lo, hi);
}



/* ============================================================================

   Dyadic evaluation stuff

   This section essentially implements Theorem 8 of [BGS], for the particular
   case of the parameters that are used in Theorem 15.

============================================================================ */


/*
   Assume that f has degree d = 2^n and that g has degree 2d.

   This function computes a polynomial h whose x^d through x^{2d} coefficients
   (inclusive) are the same as those of f*g. The bottom d coefficients of h
   will be junk.

   The parameter g_fft should be the precomputed length 2d FFT of g.

   The algorithm in this function is based on the paper "The middle product
   algorithm" by Hanrot, Quercia and Zimmermann. (Many thanks to Victor Shoup
   for writing such wonderfully modular FFT code. It really made my day.)
*/
template <typename SCALAR, typename POLY, typename FFTREP>
void middle_product(POLY& h, const POLY& f,
                    const POLY& g, const FFTREP& g_fft, int n)
{
   int d = 1 << n;
   h.rep.SetLength(2*d + 1);

   FFTREP f_fft(INIT_SIZE, n+1);

   // Compute length 2d cyclic convolutions of f and g, letting the top
   // third of f*g "wrap around" to the bottom half of the output.
   forward_fft<POLY, FFTREP>(f_fft, f, n+1, 0, 2*d);
   mul(f_fft, f_fft, g_fft);
   inverse_fft<POLY, FFTREP>(h, f_fft, 0, 2*d);

   // Need to correct for the x^{2d} term of g which got wrapped around to the
   // constant term.
   h.rep[d] -= g.rep[2*d] * f.rep[d];

   // Now h contains terms x^d through x^{2d-1} of f*g.
   // To finish off, we just need the x^{2d} term.
   SCALAR temp;
   SCALAR& sum = h.rep[2*d];
   sum = 0;
   for (int i = 0; i <= d; i++)
   {
      mul(temp, f.rep[i], g.rep[2*d-i]);
      add(sum, sum, temp);
   }
}



/*
   This struct stores precomputed information that can then be used to shift
   evaluation values of a polynomial F(x) of degree d = 2^n.

   Specifically, given the values
      F(0), F(b), F(2*b), ..., F(d*b),
   the shift() method computes
      F(a), F(a + b), F(a + 2*b), ..., F(a + d*b).

   PRECONDITIONS:
      n >= 1
      1, 2, ..., d + 1 are invertible
      a + i*b are invertible for -d <= i <= d
*/
template <typename SCALAR, typename POLY, typename VECTOR, typename FFTREP>
struct DyadicShifter
{
   int d, n;

   // input_twist is a vector of length d/2 + 1.
   // The i-th entry is \prod_{0 <= j <= d, j != i} (i-j)^(-1).
   VECTOR input_twist;

   // output_twist is a vector of length d+1.
   // The i-th entry is b^(-d) \prod_{0 <= j <= d} (a + (i-j)*b).
   VECTOR output_twist;

   // kernel is a polynomial of degree 2d.
   // The coefficients are (a + k*b)^(-1) for -d <= k <= d.
   // We also store its length 2d FFT.
   POLY kernel;
   FFTREP kernel_fft;

   // Polynomials for scratch space in shift()
   POLY scratch, scratch2;

   // Constructor (performs various precomputations)
   DyadicShifter(int n, const SCALAR& a, const SCALAR& b)
   {
      assert(n >= 1);
      this->n = n;
      d = 1 << n;

      // ------------------------ compute input_twist -------------------------

      input_twist.SetLength(d/2 + 1);

      // prod = (d!)^(-1)
      SCALAR prod;
      prod = 1;
      for (int i = 2; i <= d; i++)
         mul(prod, prod, i);
      prod = 1 / prod;

      // input_twist[i] = ((d-i)!)^(-1)
      input_twist[0] = prod;
      for (int i = 1; i <= d/2; i++)
         mul(input_twist[i], input_twist[i-1], d-(i-1));

      // input_twist[i] = ((d-i)!*i!)^(-1)
      prod = input_twist[d/2];
      for (int i = d/2; i >= 0; i--)
      {
         mul(input_twist[i], input_twist[i], prod);
         mul(prod, prod, i);
      }

      // input_twist[i] = \prod_{0 <= j <= d, j != i} (i-j)^(-1) :-)
      for (int i = 1; i <= d/2; i += 2)
         NTL::negate(input_twist[i], input_twist[i]);

      // ----------------- compute output_twist and kernel --------------------

      output_twist.SetLength(d+1);

      // c[i] = c_i = a + (i-d)*b for 0 <= i <= 2d
      VECTOR c;
      c.SetLength(2*d+1);
      c[0] = a - d*b;
      for (int i = 1; i <= 2*d; i++)
         add(c[i], c[i-1], b);

      // accum[i] = c_0 * c_1 * ... * c_i for 0 <= i <= 2d
      VECTOR accum;
      accum.SetLength(2*d+1);
      accum[0] = c[0];
      for (int i = 1; i <= 2*d; i++)
         mul(accum[i], accum[i-1], c[i]);

      // accum_inv[i] = (c_0 * c_1 * ... * c_i)^(-1) for 0 <= i <= 2d
      VECTOR accum_inv;
      accum_inv.SetLength(2*d+1);
      accum_inv[2*d] = 1 / accum[2*d];
      for (int i = 2*d-1; i >= 0; i--)
         mul(accum_inv[i], accum_inv[i+1], c[i+1]);

      // kernel[i] = (c_i)^(-1) for 0 <= i <= 2d
      kernel.rep.SetLength(2*d+1);
      kernel.rep[0] = accum_inv[0];
      for (int i = 1; i <= 2*d; i++)
         mul(kernel.rep[i], accum_inv[i], accum[i-1]);

      // precompute transform of kernel
      forward_fft<POLY, FFTREP>(kernel_fft, kernel, n+1, 0, 2*d);

      // output_twist[i] = b^{-d} * c_i * c_{i+1} * ... * c_{i+d}
      // for 0 <= i <= d
      SCALAR factor = power(b, -d);
      SCALAR temp;
      output_twist.SetLength(d+1);
      output_twist[0] = factor * accum[d];
      for (int i = 1; i <= d; i++)
      {
         mul(temp, factor, accum[i+d]);
         mul(output_twist[i], temp, accum_inv[i-1]);
      }
   }


   // Shifts evaluation values as described above.
   // Assumes both output and input have length d + 1.
   void shift(VECTOR& output, const VECTOR& input)
   {
      assert(input.length() == d+1);
      assert(output.length() == d+1);

      // multiply inputs pointwise by input_twist
      scratch.rep.SetLength(d+1);
      for (int i = 0; i <= d/2; i++)
         mul(scratch.rep[i], input[i], input_twist[i]);
      for (int i = 1; i <= d/2; i++)
         mul(scratch.rep[i+d/2], input[i+d/2], input_twist[d/2-i]);

      middle_product<SCALAR, POLY, FFTREP>(scratch2, scratch, kernel,
                                           kernel_fft, n);

      // multiply outputs pointwise by output_twist
      for (int i = 0; i <= d; i++)
         mul(output[i], scratch2.rep[i+d], output_twist[i]);
   }
};



/*
   Let M0 and M1 be square matrices of size n*n. Let M(x) = M0 + x*M1; this is
   a matrix of linear polys in x. Let P(x) = M(x+1) M(x+2) ... M(x+2^s); this
   is a matrix of polynomials of degree 2^s. This function computes the values
      P(a), P(a + 2^t), P(a + 2*2^t), ..., P(a + 2^s*2^t).

   The output array should have length n^2. Each entry should be a vector of
   length 2^s+1, pre-initialised to all zeroes. The (y*n + x)-th vector will be
   the values of the (y, x) entries of the above list of matrices. (This data
   format is optimised for the case that 2^s+1 is much larger than n.)

   PRECONDITIONS:
      0 <= s <= t
      2, 3, ..., 2^t + 1 must be invertible
*/
template <typename SCALAR, typename POLY, typename VECTOR,
          typename MATRIX, typename FFTREP>
void dyadic_evaluation(vector<VECTOR>& output,
                       const MATRIX& M0, const MATRIX& M1,
                       int s, int t, const SCALAR& a)
{
   int n = M0.NumRows();

   // base cases; just evaluate naively
   if (s <= 1)
   {
      MATRIX X[3];

      if (s == 0)
      {
         X[0] = M0 + (a+1) * M1;
         X[1] = M0 + (a+1 + (1 << t)) * M1;
      }
      else
      {
         for (int i = 0; i <= 2; i++)
            X[i] = (M0 + (a+1 + (i << t)) * M1) * (M0 + (a+2 + (i << t)) * M1);
      }

      for (int x = 0; x < n; x++)
         for (int y = 0; y < n; y++)
            for (int i = 0; i < output[0].length(); i++)
               output[y*n + x][i] = X[i][y][x];

      return;
   }

   // General case.
   // Let Q(x) = M(x+1) M(x+2) ... M(x+2^(s-1)).

   // Recursively compute Q(a), Q(a + 2^t), ..., Q(a + 2^(s-1)*2^t).
   vector<VECTOR> X(n*n);
   for (int i = 0; i < n*n; i++)
      X[i].SetLength((1 << (s-1)) + 1);
   dyadic_evaluation<SCALAR, POLY, VECTOR, MATRIX, FFTREP>
      (X, M0, M1, s-1, t, a);

   // Do precomputations for shifting by 2^(s-1) and by (2^(s-1)+1)*2^t
   SCALAR c, b;
   c = 1 << (s-1);
   b = 1 << t;
   DyadicShifter<SCALAR, POLY, VECTOR, FFTREP> shifter1(s-1, c, b);
   DyadicShifter<SCALAR, POLY, VECTOR, FFTREP> shifter2(s-1, (c + 1) * b, b);

   // Shift by 2^(s-1) to obtain
   // Q(a + 2^(s-1)), Q(a + 2^t + 2^(s-1)), ..., Q(a + 2^(s-1)*2^t + 2^(s-1))
   vector<VECTOR> Y(n*n);
   for (int i = 0; i < n*n; i++)
   {
      Y[i].SetLength((1 << (s-1)) + 1);
      shifter1.shift(Y[i], X[i]);
   }

   // Multiply matrices to obtain
   // P(a), P(a + 2^t), ..., P(a + 2^(s-1)*2^t).
   SCALAR temp;
   for (int i = 0; i <= (1 << (s-1)); i++)
      for (int x = 0; x < n; x++)
         for (int y = 0; y < n; y++)
            for (int z = 0; z < n; z++)
            {
               mul(temp, X[y*n + z][i], Y[z*n + x][i]);
               output[y*n + x][i] += temp;
            }

   // Shift original sequence by (2^(s-1)+1)*2^t to obtain
   // Q(a + (2^(s-1)+1)*2^t), Q(a + (2^(s-1)+2)*2^t), ..., Q(a + (2^s+1)*2^t).
   for (int i = 0; i < n*n; i++)
      shifter2.shift(Y[i], X[i]);

   // Shift again by 2^(s-1) to obtain
   // Q(a + (2^(s-1)+1)*2^t + 2^(s-1)), Q(a + (2^(s-1)+2)*2^t + 2^(s-1)), ...,
   // Q(a + (2^s+1)*2^t + 2^(s-1)).
   for (int i = 0; i < n*n; i++)
      shifter1.shift(X[i], Y[i]);

   // Multiply matrices to obtain
   // P(a + (2^(s-1)+1)*2^t), P(a + (2^(s-1)+2)*2^t), ..., P(a + (2^s+1)*2^t).
   // (we throw out the last one since it's surplus to requirements)
   for (int i = 0; i < (1 << (s-1)); i++)
      for (int x = 0; x < n; x++)
         for (int y = 0; y < n; y++)
            for (int z = 0; z < n; z++)
            {
               mul(temp, Y[y*n + z][i], X[z*n + x][i]);
               output[y*n + x][i + (1 << (s-1)) + 1] += temp;
            }
}



/* ============================================================================

   General evaluation stuff

   This section essentially implements Corollary 10 of [BGS].

============================================================================ */


/*
   This struct stores the product tree associated to a vector a[0], ..., a[n-1].

   The top node stores the polynomial product
      (x - a[0]) ... (x - a[n-1]).
   The two children nodes store
      (x - a[0]) ... (x - a[m-1])
   and
      (x - a[m]) ... (x - a[n-1]),
   where m = floor(n/2). This continues recursively until we reach n = 1, in
   which case just the polynomial x - a[0] is stored and there are no children.
*/
template <typename SCALAR, typename POLY, typename VECTOR>
struct ProductTree
{
   // polynomial product stored at this node
   POLY poly;

   // children for left and right halves, if deg(poly) > 1
   ProductTree* child1;
   ProductTree* child2;

   // These are temp polys used by the Evaluator and Interpolator classes.
   // It's not very hygienic to keep them here... but it makes things more
   // efficient, because we need two temps for each node, and this prevents
   // unnecessary reallocations.
   // (The lengths will be the same on repeated calls to evaluate() and
   // interpolate().)
   POLY scratch1, scratch2;

   // Constructs product tree for the supplied vector.
   ProductTree(const VECTOR& points)
   {
      build(points, 0, points.length());
   }

   ProductTree(const VECTOR& points, int start, int end)
   {
      build(points, start, end);
   }

   // Constructs product tree recursively for the subset [start, end) of
   // the supplied vector.
   void build(const VECTOR& points, int start, int end)
   {
      assert(end - start >= 1);
      assert(start >= 0);
      assert(end <= points.length());

      if (end - start == 1)
      {
         SetCoeff(poly, 1, 1);
         SetCoeff(poly, 0, -points[start]);
      }
      else
      {
         int m = (end - start) / 2;
         child1 = new ProductTree(points, start, start + m);
         child2 = new ProductTree(points, start + m, end);
         mul(poly, child1->poly, child2->poly);
      }
   }

   ~ProductTree()
   {
      if (deg(poly) > 1)
      {
         delete child1;
         delete child2;
      }
   }
};



/*
   Given a list of evaluation points a[0], ..., a[n-1], this struct stores some
   precomputed information to permit evaluating an arbitrary polynomial at
   those points.
*/
template <typename SCALAR, typename POLY,
          typename POLYMODULUS, typename VECTOR>
struct Evaluator
{
   // The product tree for the evaluation points
   ProductTree<SCALAR, POLY, VECTOR>* tree;

   // A list of NTL ZZ_pXModulus/zz_pXModulus objects corresponding to the
   // polynomials in the product tree, in the order that they get used as the
   // tree is traversed in recursive_evaluate().
   vector<POLYMODULUS> moduli;

   // Constructs evaluator object for the given list of evaluation points
   Evaluator(const VECTOR& points)
   {
      assert(points.length() >= 1);
      tree = new ProductTree<SCALAR, POLY, VECTOR>(points);
      moduli.reserve(2*points.length());
      build(tree);
      assert(moduli.size() <= 2*points.length());
   }

   // Computes modulus objects for each polynomial under the supplied node of
   // the product tree, and appends them in traversal order to "moduli".
   void build(const ProductTree<SCALAR, POLY, VECTOR>* node)
   {
      if (deg(node->poly) > 1)
      {
         moduli.push_back(POLYMODULUS(node->poly));
         build(node->child1);
         build(node->child2);
      }
   }

   ~Evaluator()
   {
      delete tree;
   }

   // Evaluates the input polynomial at the evaluation points, and writes the
   // results to output. The output array must have the correct length.
   void evaluate(VECTOR& output, const POLY& input)
   {
      recursive_evaluate(output, input, tree, 0, 0);
   }

   // Evaluates the input polynomial at the subset [start, end) of the
   // evaluation points, which should correspond to the supplied product tree
   // node. (The length of the interval is implied by the degree of the poly
   // at that node.) Writes the output to the subset [start, end) of the
   // output array. The index parameter indicates which modulus in "moduli"
   // to use for this node of the tree.
   // The return value is the index for the modulus that should be used
   // immediately after this call.
   int recursive_evaluate(VECTOR& output, const POLY& input,
                          ProductTree<SCALAR, POLY, VECTOR>* node,
                          int start, int index)
   {
      if (deg(node->poly) == 1)
      {
         eval(output[start], input, -coeff(node->poly, 0));
      }
      else
      {
         rem(node->scratch1, input, moduli[index++]);
         index = recursive_evaluate(output, node->scratch1, node->child1,
                                    start, index);
         index = recursive_evaluate(output, node->scratch1, node->child2,
                                    start + deg(node->child1->poly), index);
      }
      return index;
   }
};


/*
   Given an integer L >= 1, this struct does some precomputations to permit
   interpolating a polynomial whose values at 0, 1, ..., L are known.

   PRECONDITIONS:
      1, 2, ..., L must be invertible.
*/
template <typename SCALAR, typename POLY, typename VECTOR>
struct Interpolator
{
   ProductTree<SCALAR, POLY, VECTOR>* tree;
   int L;

   // input_twist is a vector of length L+1.
   // The i-th entry is \prod_{0 <= j <= L, j != i} (i-j)^(-1).
   VECTOR input_twist;

   // vector of length L+1, used in interpolate()
   VECTOR temp;

   // Performs various precomputations for the given L.
   Interpolator(int L)
   {
      this->L = L;
      temp.SetLength(L+1);

      // Build a product tree for the evaluation points
      for (int i = 0; i <= L; i++)
         temp[i] = i;
      tree = new ProductTree<SCALAR, POLY, VECTOR>(temp);

      // prod = (L!)^(-1)
      SCALAR prod;
      prod = 1;
      for (int i = 2; i <= L; i++)
         mul(prod, prod, i);
      prod = 1 / prod;

      // input_twist[i] = (i!)^(-1), 0 <= i <= L
      input_twist.SetLength(L+1);
      input_twist[L] = prod;
      for (int i = L; i >= 1; i--)
         mul(input_twist[i-1], input_twist[i], i);

      // input_twist[i] = \prod_{0 <= j <= L, j != i} (i-j)^(-1).
      for (int i = 0; i <= L/2; i++)
      {
         mul(input_twist[i], input_twist[i], input_twist[L-i]);
         input_twist[L-i] = input_twist[i];
      }
      for (int i = L-1; i >= 0; i -= 2)
         NTL::negate(input_twist[i], input_twist[i]);
   }

   ~Interpolator()
   {
      delete tree;
   }


   // Returns the polynomial
   //    \sum_{i=start}^{end-1} values[i] * (x - start) (x - (start+1)) ... (x - (end-1)),
   // where [start, end) is the interval associated to the supplied product
   // tree node, and where the (x-i) term is omitted in each product.
   void combine(POLY& output, const VECTOR& values,
                ProductTree<SCALAR, POLY, VECTOR>* node, int start)
   {
      if (deg(node->poly) == 1)
      {
         // base case
         clear(output);
         SetCoeff(output, 0, values[start]);
      }
      else
      {
         // recursively build up from two halves:
         // i.e. if f1, f2 are the results of "combine" for the two halves,
         // and if p1, p2 are the associated product tree polys, we compute
         // f1*p2 + f2*p1

         combine(node->scratch1, values, node->child1, start);
         mul(output, node->scratch1, node->child2->poly);

         combine(node->scratch1, values, node->child2,
                 start + deg(node->child1->poly));
         mul(node->scratch2, node->scratch1, node->child1->poly);

         add(output, output, node->scratch2);
      }
   }

   // Returns a polynomial F(x) of degree at most L such that F(i) = values[i]
   // for each 0 <= i <= L.
   void interpolate(POLY& output, const VECTOR& values)
   {
      assert(values.length() == L+1);

      // multiply input values pointwise by input_twist; this corrects for
      // the factor (i-0) (i-1) ... (i-L) (where the i-i factor is omitted).
      for (int i = 0; i <= L; i++)
         mul(temp[i], values[i], input_twist[i]);

      // do the interpolation
      combine(output, temp, tree, 0);
   }
};



/* ============================================================================

   Matrix products over arbitrary, relatively short intervals

   This section implements something similar to steps 1, 2, ... and the final
   refining step of Theorem 15 of [BGS].

============================================================================ */


/*
   Let M0 and M1 be matrices of constants. This function evaluates
      M(x) = M0 + x*M1
   at x = a.

   The output matrix must already have the correct dimensions.
*/
template <typename SCALAR, typename MATRIX>
void eval_matrix(MATRIX& output, const MATRIX& M0, const MATRIX& M1,
                 const SCALAR& a)
{
   int n = M0.NumRows();
   for (int x = 0; x < n; x++)
      for (int y = 0; y < n; y++)
      {
         mul(output[x][y], a, M1[x][y]);
         add(output[x][y], output[x][y], M0[x][y]);
      }
}



/*
   Similar to ntl_interval_products. This is used as a subroutine of
   ntl_interval_products() to handle the smaller "refining" subintervals.
   Its asymptotic complexity theoretically has an extra logarithmic factor
   over that of ntl_interval_products().

   PRECONDITIONS:
      Let d = sum of lengths of intervals. Then 2, 3, ..., 1 + floor(sqrt(d))
      must all be invertible.
*/
template <typename SCALAR, typename POLY, typename POLYMODULUS,
          typename VECTOR, typename MATRIX>
void ntl_short_interval_products(vector<MATRIX>& output,
                                 const MATRIX& M0, const MATRIX& M1,
                                 const vector<ZZ>& target)
{
   output.clear();

   if (target.size() == 0)
      return;

   int dim = M0.NumRows();
   int num_intervals = target.size() / 2;

   // Determine maximum target interval length
   int max_length = -1;
   for (int i = 0; i < target.size(); i += 2)
   {
      int temp = to_ulong(target[i+1] - target[i]);
      if (temp > max_length)
         max_length = temp;
   }

   // Select an appropriate length for the matrix products we'll use
   int L, max_eval_points;
   if (max_length > 2*num_intervals)
   {
      // The intervals are still pretty long relative to the number of
      // intervals, so we're only going to do a single multipoint
      // evaluation.
      L = 1 + to_ulong(SqrRoot(num_intervals * to_ZZ(max_length)));
      max_eval_points = L;
   }
   else
   {
      // The intervals are getting pretty short, so we probably will need
      // to do several shorter multipoint evaluations.
      L = 1 + max_length/2;
      max_eval_points = num_intervals;
   }

   // =========================================================================
   // Step 1: compute entries of M(X, X+L) as polynomials in X.

   vector<POLY> polys(dim*dim);
   {
      // left_accum[i] = M(L-i-1, L) for 0 <= i <= L-1
      // right_accum[i] = M(L, L+i+1) for 0 <= i <= L-1
      vector<MATRIX> left_accum(L), right_accum(L);

      MATRIX temp;
      temp.SetDims(dim, dim);

      left_accum[0].SetDims(dim, dim);
      eval_matrix<SCALAR, MATRIX>(left_accum[0], M0, M1, to_scalar<SCALAR>(L));
      for (int i = L-1; i >= 1; i--)
      {
         eval_matrix<SCALAR, MATRIX>(temp, M0, M1, to_scalar<SCALAR>(i));
         mul(left_accum[L-i], temp, left_accum[L-i-1]);
      }

      right_accum[0].SetDims(dim, dim);
      eval_matrix<SCALAR, MATRIX>(right_accum[0], M0, M1,
                                  to_scalar<SCALAR>(L+1));
      for (int i = 1; i <= L-1; i++)
      {
         eval_matrix<SCALAR, MATRIX>(temp, M0, M1, to_scalar<SCALAR>(L+1+i));
         mul(right_accum[i], right_accum[i-1], temp);
      }

      // Use left_accum and right_accum to compute:
      //    initial[i] = M(i, L+i) for 0 <= i <= L
      // i.e. initial[i] are the values of M(X, X+L) at X = 0, 1, ..., L.
      vector<MATRIX> initial(L+1);
      initial[0] = left_accum.back();
      initial[L] = right_accum.back();
      for (int i = 1; i <= L-1; i++)
         mul(initial[i], left_accum[L-1-i], right_accum[i-1]);

      // Now interpolate entries of initial[i] to get entries of M(X, X+L)
      // as polynomials of degree L.
      Interpolator<SCALAR, POLY, VECTOR> interpolator(L);
      VECTOR values;
      values.SetLength(L+1);
      for (int x = 0; x < dim; x++)
         for (int y = 0; y < dim; y++)
         {
            for (int j = 0; j <= L; j++)
               values[j] = initial[j][y][x];
            interpolator.interpolate(polys[y*dim + x], values);
         }
   }

   // =========================================================================
   // Step 2: decompose intervals into subintervals of length L which we'll
   // attack by direct multipoint evaluation, plus leftover pieces that we'll
   // handle with a recursive call to ntl_short_interval_products().

   // eval_points holds all the values of X for which we want to
   // evaluate M(X, X+L)
   VECTOR eval_points;
   eval_points.SetMaxLength(max_eval_points);

   // leftover_target is the list of leftover intervals that we're going to
   // handle recursively later
   vector<ZZ> leftover_target;
   leftover_target.reserve(target.size());

   ZZ current, next;
   for (int i = 0; i < target.size(); i += 2)
   {
      current = target[i];
      next = current + L;
      while (next <= target[i+1])
      {
         // [current, next) fits inside this interval, so peel it off into
         // eval_points
         append(eval_points, to_scalar<SCALAR>(current));
         swap(current, next);
         next = current + L;
      }
      if (current < target[i+1])
      {
         // the rest of this interval is too short to handle with M(X, X+L),
         // so put it in the leftover bin
         leftover_target.push_back(current);
         leftover_target.push_back(target[i+1]);
      }
   }

   // =========================================================================
   // Step 3: recursively handle leftover pieces

   // leftover_matrices[i] holds the matrix for leftover interval #i
   vector<MATRIX> leftover_matrices;
   ntl_short_interval_products<SCALAR, POLY, POLYMODULUS, VECTOR, MATRIX>
      (leftover_matrices, M0, M1, leftover_target);

   // =========================================================================
   // Step 4: evaluate M(X, X+L) at each of the evaluation points.
   // We do this by breaking up the list of evaluation points into blocks of
   // length at most L+1, and using multipoint evaluation on each block.

   // main_matrices[i] will hold M(X, X+L) for the i-th evaluation point X.
   vector<MATRIX> main_matrices(eval_points.length());
   for (int i = 0; i < main_matrices.size(); i++)
      main_matrices[i].SetDims(dim, dim);

   VECTOR block, values;
   block.SetMaxLength(L+1);
   values.SetMaxLength(L+1);

   // for each block...
   for (int i = 0; i < eval_points.length(); i += (L+1))
   {
      // determine length of this block, which is at most L+1
      int length = eval_points.length() - i;
      if (length >= (L+1))
         length = (L+1);
      block.SetLength(length);

      // construct Evaluator object for evaluating at these points
      for (int j = 0; j < length; j++)
         block[j] = eval_points[i+j];
      Evaluator<SCALAR, POLY, POLYMODULUS, VECTOR> evaluator(block);

      // evaluate each entry of M(X, X+L) at those points
      for (int x = 0; x < dim; x++)
         for (int y = 0; y < dim; y++)
         {
            evaluator.evaluate(values, polys[y*dim + x]);
            for (int k = 0; k < length; k++)
               main_matrices[i+k][y][x] = values[k];
         }
   }

   // =========================================================================
   // Step 5: merge together the matrices obtained from the multipoint
   // evaluation step and the recursive leftover interval step.

   output.clear();
   output.resize(target.size() / 2);
   for (int i = 0; i < target.size()/2; i++)
      output[i].SetDims(dim, dim);

   int main_index = 0;       // index into main_matrices
   int leftover_index = 0;   // index into leftover_matrices

   MATRIX temp;
   temp.SetDims(dim, dim);

   for (int i = 0; i < target.size(); i += 2)
   {
      current = target[i];
      next = current + L;
      ident(output[i/2], dim);

      while (next <= target[i+1])
      {
         // merge in a matrix from multipoint evaluation step
         mul(temp, output[i/2], main_matrices[main_index++]);
         swap(temp, output[i/2]);
         swap(current, next);
         next = current + L;
      }
      if (current < target[i+1])
      {
         // merge in a matrix from leftover interval step
         mul(temp, output[i/2], leftover_matrices[leftover_index++]);
         swap(temp, output[i/2]);
      }
   }
}


/* ============================================================================

   Matrix products over arbitrary, long intervals

   This section implements an algorithm similar to Theorem 15 of [BGS].

============================================================================ */


/*
   See interval_products_wrapper().

   NOTE:
   This algorithm works best if the intervals are very long and don't have
   much space between them.
   The case where the gaps are relatively large is best handled by
   ntl_short_interval_products().
*/
template <typename SCALAR, typename POLY, typename POLYMODULUS,
          typename VECTOR, typename MATRIX, typename FFTREP>
void ntl_interval_products(vector<MATRIX>& output,
                           const MATRIX& M0, const MATRIX& M1,
                           const vector<ZZ>& target)
{
   assert(target.size() % 2 == 0);
   output.resize(target.size() / 2);

   int dim = M0.NumRows();
   assert(dim == M0.NumCols());
   assert(dim == M1.NumRows());
   assert(dim == M1.NumCols());

   // =========================================================================
   // Step 0: get as many intervals as possible using dyadic_evaluation().

   // step0_matrix[i] is the transition matrix between step0_index[2*i]
   // and step0_index[2*i+1].
   vector<MATRIX> step0_matrix;
   vector<ZZ> step0_index;
   // preallocate the maximum number of matrices that could arise (plus safety)
   int reserve_size = target.size() +
                      4*NumBits(target.back() - target.front());
   step0_matrix.reserve(reserve_size);
   step0_index.reserve(2 * reserve_size);

   ZZ current_index = target.front();
   int next_target = 0;   // index into "target" array

   // This flag indicates whether the last entry of step0_matrix is
   // still accumulating matrices (in which case the right endpoint of the
   // corresponding interval hasn't been written to step0_index yet).
   int active = 0;

   MATRIX temp_mat;
   temp_mat.SetDims(dim, dim);

   while (current_index < target.back() - 3)
   {
      // find largest t such that 2^t*(2^t + 1) <= remaining distance to go
      ZZ remaining = target.back() - current_index;
      int t = 0;
      while ((to_ZZ(1) << (2*t)) + (1 << t) <= remaining)
         t++;
      t--;

      // evaluate matrices for 2^t+1 intervals of length 2^t
      vector<VECTOR> dyadic_output(dim*dim);
      for (int i = 0; i < dim*dim; i++)
         dyadic_output[i].SetLength((1 << t) + 1);
      dyadic_evaluation<SCALAR, POLY, VECTOR, MATRIX, FFTREP>
         (dyadic_output, M0, M1, t, t, to_scalar<SCALAR>(current_index));

      // Walk through the intervals we just computed. Find maximal subsequences
      // of intervals none of which contain any target endpoints.
      // Merge them together (by multiplying the appropriate matrices) and
      // store results in step0_matrix, step0_index.

      SCALAR scratch;

      for (int i = 0; i <= (1 << t); i++, current_index += (1 << t))
      {
         assert(next_target == target.size() ||
                target[next_target] >= current_index);

         // Skip over target endpoints which are exactly at the beginning
         // of this interval
         while ((next_target < target.size()) &&
                (target[next_target] == current_index))
         {
            // if there's an active matrix, don't forget to close it off
            if (active)
            {
               step0_index.push_back(current_index);
               active = 0;
            }
            next_target++;
         }

         // Test if any target endpoints are strictly within this interval.
         if ((next_target == target.size()) ||
             (target[next_target] >= current_index + (1 << t)))
         {
            // There are no target endpoints in this interval.
            if (active)
            {
               // Merge this matrix with the active one
               MATRIX& active_mat = step0_matrix.back();
               for (int y = 0; y < dim; y++)
                  for (int x = 0; x < dim; x++)
                  {
                     SCALAR& accum = temp_mat[y][x];
                     accum = 0;
                     for (int z = 0; z < dim; z++)
                     {
                        mul(scratch, active_mat[y][z],
                            dyadic_output[z*dim + x][i]);
                        add(accum, accum, scratch);
                     }
                  }

               swap(temp_mat, active_mat);
            }
            else
            {
               // Make this matrix into a new active one
               step0_index.push_back(current_index);
               step0_matrix.resize(step0_matrix.size() + 1);
               MATRIX& X = step0_matrix.back();
               X.SetDims(dim, dim);
               for (int y = 0; y < dim; y++)
                  for (int x = 0; x < dim; x++)
                     X[y][x] = dyadic_output[y*dim + x][i];
               active = 1;
            }
         }
         else
         {
            // There are target endpoints in this interval.
            if (active)
            {
               // If there is still an active matrix, close it off.
               step0_index.push_back(current_index);
               active = 0;
            }

            // skip over any other endpoints in this interval
            while ((next_target < target.size()) &&
                   (target[next_target] < current_index + (1 << t)))
            {
               next_target++;
            }
         }
      }
   }

   // If there is still an active matrix, close it off.
   if (active)
      step0_index.push_back(current_index);

   assert(step0_index.size() == 2*step0_matrix.size());

   // =========================================================================
   // Step 1: Make a list of all subintervals that we are going to need in
   // the refining steps.

   int next_step0 = 0;       // index into step0_index
   vector<ZZ> step1_index;   // list of pairs of endpoints of needed intervals
   step1_index.reserve(2*target.size());

   // add sentinel endpoints to make the next loop simpler:
   step0_index.push_back(target.back() + 10);
   step0_index.push_back(target.back() + 20);

   for (next_target = 0; next_target < target.size(); next_target += 2)
   {
      // skip dyadic intervals that come before this target interval
      while (step0_index[next_step0+1] <= target[next_target])
         next_step0 += 2;

      if (step0_index[next_step0] < target[next_target+1])
      {
         // The next dyadic interval starts before the end of this target
         // interval.
         if (step0_index[next_step0] > target[next_target])
         {
            // The next dyadic interval starts strictly within this target
            // interval, so we need to create a refining subinterval for the
            // initial segment of this target interval.
            step1_index.push_back(target[next_target]);
            step1_index.push_back(step0_index[next_step0]);
         }

         // Skip over dyadic intervals to find the last one still contained
         // within this target interval.
         while (step0_index[next_step0+3] <= target[next_target+1])
            next_step0 += 2;

         if (step0_index[next_step0+1] < target[next_target+1])
         {
            // The next dyadic interval finishes strictly within this target
            // interval, so we need to create a refining subinterval for the
            // final segment of this target interval.
            step1_index.push_back(step0_index[next_step0+1]);
            step1_index.push_back(target[next_target+1]);
         }

         // Move on to next dyadic interval
         next_step0 += 2;
      }
      else
      {
         // The next dyadic interval starts beyond (or just at the end of)
         // this target interval, so we need to create a refining subinterval
         // for this *whole* target interval.
         step1_index.push_back(target[next_target]);
         step1_index.push_back(target[next_target+1]);
      }
   }

   // remove sentinels for my sanity
   step0_index.pop_back();
   step0_index.pop_back();

   // Step 1b: Compute matrix products over those refining subintervals.
   vector<MATRIX> step1_matrix;
   ntl_short_interval_products<SCALAR, POLY, POLYMODULUS, VECTOR, MATRIX>
      (step1_matrix, M0, M1, step1_index);

   assert(step1_index.size() == 2 * step1_matrix.size());

   // =========================================================================
   // Step 2: Merge together the dyadic intervals and refining intervals into
   // a single list, in sorted order.

   vector<MATRIX> step2_matrix(step0_matrix.size() + step1_matrix.size());
   vector<ZZ> step2_index(step0_index.size() + step1_index.size());

   // add sentinels to make the next loop simpler
   step0_index.push_back(target.back() + 10);
   step0_index.push_back(target.back() + 20);
   step1_index.push_back(target.back() + 10);
   step1_index.push_back(target.back() + 20);

   next_step0 = 0;       // index into step0_matrix
   int next_step1 = 0;   // index into step1_matrix

   for (int next_step2 = 0; next_step2 < step2_matrix.size(); next_step2++)
   {
      if (step0_index[2*next_step0] < step1_index[2*next_step1])
      {
         // grab a matrix and pair of indices from step0
         swap(step2_matrix[next_step2], step0_matrix[next_step0]);
         step2_index[2*next_step2] = step0_index[2*next_step0];
         step2_index[2*next_step2+1] = step0_index[2*next_step0+1];
         next_step0++;
      }
      else
      {
         // grab a matrix and pair of indices from step1
         swap(step2_matrix[next_step2], step1_matrix[next_step1]);
         step2_index[2*next_step2] = step1_index[2*next_step1];
         step2_index[2*next_step2+1] = step1_index[2*next_step1+1];
         next_step1++;
      }
   }

   // remove sentinels for my sanity
   step0_index.pop_back();
   step0_index.pop_back();
   step1_index.pop_back();
   step1_index.pop_back();

   assert(step2_index.size() == 2*step2_matrix.size());

   // =========================================================================
   // Step 3: Walk through target intervals, and merge together appropriate
   // intervals from step2 to get those target intervals.

   int next_step2 = 0;   // index into step2_matrix

   // add sentinels to make the next loop simpler
   step2_index.push_back(target.back() + 1);
   step2_index.push_back(target.back() + 2);

   for (int next_target = 0; next_target < target.size(); next_target += 2)
   {
      // search for step2 interval matching the start of this target interval
      while (step2_index[2*next_step2] < target[next_target])
         next_step2++;

      assert(step2_index[2*next_step2] == target[next_target]);

      // merge together matrices for step2 intervals contained in this target
      // interval
      swap(step2_matrix[next_step2++], output[next_target/2]);
      while (step2_index[2*next_step2+1] <= target[next_target+1])
      {
         mul(temp_mat, output[next_target/2], step2_matrix[next_step2]);
         swap(temp_mat, output[next_target/2]);
         next_step2++;
      }
   }
}


// explicit instantiations for zz_p and ZZ_p versions:

template void ntl_interval_products
   <ZZ_p, ZZ_pX, ZZ_pXModulus, vec_ZZ_p, mat_ZZ_p, FFTRep>
   (vector<mat_ZZ_p>& output, const mat_ZZ_p& M0, const mat_ZZ_p& M1,
    const vector<ZZ>& target);

template void ntl_interval_products
   <zz_p, zz_pX, zz_pXModulus, vec_zz_p, mat_zz_p, fftRep>
   (vector<mat_zz_p>& output, const mat_zz_p& M0, const mat_zz_p& M1,
    const vector<ZZ>& target);


}; // namespace hypellfrob


// ----------------------- end of file
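

A minimal usage sketch (not part of the original source) may help readers coming to this file from the Sage tree. It assumes that "recurrences_ntl.h" declares ntl_interval_products as instantiated above, and it assumes, from a reading of the code rather than from any stated documentation, that the matrix returned for a target interval [u, v] is M(u+1) M(u+2) ... M(v). Under those assumptions, driving the 1x1 matrix M(x) = x over the single interval [0, 20] should print 20! modulo the arbitrarily chosen prime 1000003. The higher-level entry point is presumably interval_products_wrapper(), mentioned in the NOTE above.

#include <iostream>
#include <vector>
#include <NTL/lzz_p.h>
#include <NTL/lzz_pX.h>
#include <NTL/mat_lzz_p.h>
#include "recurrences_ntl.h"   // assumed to declare ntl_interval_products

NTL_CLIENT

int main()
{
   zz_p::init(1000003);   // arbitrary prime modulus, large enough for the
                          // invertibility preconditions stated above

   // M(x) = M0 + x*M1 with M0 = [0] and M1 = [1], i.e. the 1x1 matrix [x].
   mat_zz_p M0, M1;
   M0.SetDims(1, 1);
   M1.SetDims(1, 1);
   M1[0][0] = 1;

   // A single target interval [0, 20], encoded as a flat list of endpoints.
   std::vector<ZZ> target;
   target.push_back(to_ZZ(0));
   target.push_back(to_ZZ(20));

   // One output matrix per target interval.
   std::vector<mat_zz_p> out;
   hypellfrob::ntl_interval_products<zz_p, zz_pX, zz_pXModulus,
                                     vec_zz_p, mat_zz_p, fftRep>
      (out, M0, M1, target);

   // Expected (under the assumptions above): 20! mod 1000003.
   std::cout << out[0][0][0] << "\n";
   return 0;
}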