Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/pkg
Path: blob/main/external/libecc/scripts/expand_libecc.py
2065 views
1
#/*
2
# * Copyright (C) 2017 - This file is part of libecc project
3
# *
4
# * Authors:
5
# * Ryad BENADJILA <[email protected]>
6
# * Arnaud EBALARD <[email protected]>
7
# * Jean-Pierre FLORI <[email protected]>
8
# *
9
# * Contributors:
10
# * Nicolas VIVET <[email protected]>
11
# * Karim KHALFALLAH <[email protected]>
12
# *
13
# * This software is licensed under a dual BSD and GPL v2 license.
14
# * See LICENSE file at the root folder of the project.
15
# */
16
#! /usr/bin/env python
17
18
import random, sys, re, math, os, getopt, glob, copy, hashlib, binascii, string, signal, base64
19
20
# External dependency for SHA-3
21
# It is an independent module, since hashlib has no support
22
# for SHA-3 functions for now
23
import sha3
24
25
# Handle Python 2/3 issues
26
def is_python_2():
    """
    Return True when the running interpreter is Python 2, False otherwise.
    """
    major_version = sys.version_info[0]
    return major_version < 3
31
32
### Ctrl-C handler
33
def handler(signal, frame):
    """
    Ctrl-C (SIGINT) handler: print a notice and exit with status 0.

    The (signal, frame) parameters follow the signal-handler calling
    convention (presumably installed via signal.signal() elsewhere in the
    script); both are ignored.
    """
    print("\nSIGINT caught: exiting ...")
    # sys.exit() is always available, unlike the site-provided exit() helper
    # (absent e.g. under `python -S` or in frozen executables).
    sys.exit(0)
36
37
# Helper to ask the user for something
38
def get_user_input(prompt):
    """
    Show prompt and return the line typed by the user.

    Uses raw_input() on Python 2 (where input() would evaluate the text)
    and input() on Python 3.
    """
    if is_python_2():
        return raw_input(prompt)
    return input(prompt)
44
45
##########################################################
46
#### Math helpers
47
def egcd(b, n):
    """
    Extended Euclidean algorithm.

    Returns (g, x, y) where g is the gcd of b and n and b*x + n*y == g.
    """
    (s_prev, s_cur) = (1, 0)
    (t_prev, t_cur) = (0, 1)
    while n:
        quot, rem = divmod(b, n)
        b, n = n, rem
        s_prev, s_cur = s_cur, s_prev - quot * s_cur
        t_prev, t_cur = t_cur, t_prev - quot * t_cur
    return b, s_prev, t_prev
54
55
def modinv(a, m):
    """
    Inverse of a modulo m, reduced with Python's % (so in [0, m) for m > 0).

    Raises Exception when a and m are not coprime.
    """
    (gcd, bezout_a, _) = egcd(a, m)
    if gcd == 1:
        return bezout_a % m
    raise Exception("Error: modular inverse does not exist")
61
62
def compute_monty_coef(prime, pbitlen, wlen):
    """
    Compute the Montgomery constants of prime.

    Returns (r, r_square, mpinv) with r = 2^pbitlen mod prime,
    r_square = 2^(2*pbitlen) mod prime and mpinv = 2^wlen minus the inverse
    of prime modulo 2^wlen. pbitlen is the size of p in bits and is expected
    to be a multiple of the word bit size wlen.
    """
    nbits = int(pbitlen)
    word_modulus = 2**wlen
    r = (1 << nbits) % prime
    r_square = (1 << (2 * nbits)) % prime
    mpinv = word_modulus - modinv(prime, word_modulus)
    return r, r_square, mpinv
72
73
def compute_div_coef(prime, pbitlen, wlen):
    """
    Compute the division coefficients of prime.

    Returns (p_shift, p_normalized, p_reciprocal), where:
      - p_shift is pbitlen minus the bit length of prime;
      - p_normalized is prime << p_shift;
      - p_reciprocal = floor(B^3 / (hi + 1)) - B, with B = 2^wlen and hi the
        value of p_normalized shifted right by pbitlen - 2*wlen bits.
    """
    # Bit length of prime, counted by shifting it down to zero
    nbits = 0
    remaining = prime
    while remaining:
        nbits += 1
        remaining >>= 1
    shift = int(pbitlen - nbits)
    normalized = prime << shift
    base = 2**wlen
    top_words = normalized >> int(pbitlen - 2 * wlen)
    reciprocal = base**3 // (top_words + 1) - base
    return shift, normalized, reciprocal
87
88
def is_probprime(n):
    """
    Miller-Rabin probabilistic primality test using 5 random bases.

    Returns False when n is certainly composite (or smaller than 2), and True
    when n is prime or passed all 5 rounds (a composite passes each round with
    probability at most 1/4).
    """
    if n < 2:
        return False
    if n < 4:
        # 2 and 3 are prime; 3 also leaves no room for a base in [2, n-1[
        return True
    if n % 2 == 0:
        return False
    # Write n-1 as 2**s * d with d odd
    s = 0
    d = n - 1
    while d % 2 == 0:
        d //= 2
        s += 1
    assert(2**s * d == n - 1)

    def try_composite(a):
        # True when base a proves that n is composite
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            return False
        # Check a^(2^i * d) for i = 1 .. s-1 by repeated squaring
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                return False
        return True

    for _ in range(5):
        a = random.randrange(2, n)
        if try_composite(a):
            return False
    return True  # no base tested showed n as composite
116
117
def legendre_symbol(a, p):
    """
    Legendre symbol (a/p) computed with Euler's criterion (intended for odd
    primes p).

    Returns 1 when a is a non-zero square mod p, -1 when it is not a square,
    and 0 when p divides a. For p == 2 this formula always yields -1.
    """
    euler = pow(a, (p - 1) // 2, p)
    if euler == p - 1:
        return -1
    return euler
120
121
# Tonelli-Shanks algorithm to find square roots
122
# over prime fields
123
def mod_sqrt(a, p):
    """
    Square root of a modulo the prime p (Tonelli-Shanks).

    Returns one root x in [0, p) with x*x % p == a % p; 0 when p divides a;
    None when a is not a quadratic residue modulo p. p is assumed prime.
    """
    a = a % p
    # Square root of 0 (or of any multiple of p) is 0
    if a == 0:
        return 0
    # GF(2): every element is its own square root. This must be tested before
    # Euler's criterion, which degenerates for p == 2.
    if p == 2:
        return a
    # Euler's criterion: a^((p-1)/2) is 1 for residues, p-1 for non-residues
    if pow(a, (p - 1) // 2, p) != 1:
        return None
    if p % 4 == 3:
        return pow(a, (p + 1) // 4, p)
    # Write p-1 as s * 2^e with s odd
    s = p - 1
    e = 0
    while s % 2 == 0:
        s = s // 2
        e += 1
    # Find a quadratic non-residue n
    n = 2
    while pow(n, (p - 1) // 2, p) != p - 1:
        n += 1
    x = pow(a, (s + 1) // 2, p)
    b = pow(a, s, p)
    g = pow(n, s, p)
    r = e
    while True:
        # Least m such that b^(2^m) == 1
        t = b
        m = 0
        for m in range(r):
            if t == 1:
                break
            t = pow(t, 2, p)
        if m == 0:
            return x
        gs = pow(g, 2 ** (r - m - 1), p)
        g = (gs * gs) % p
        x = (x * gs) % p
        b = (b * g) % p
        r = m
167
168
##########################################################
169
### Math elliptic curves basic blocks
170
171
# WARNING: these blocks are only here for testing purpose and
172
# are not intended to be used in a security oriented library!
173
# This explains the usage of naive affine coordinates formulas
174
class Curve(object):
    """
    Elliptic curve parameters: y^2 = x^3 + a*x + b over GF(prime), with the
    generator (gx, gy), its order, the cofactor, the number of points, a name
    and an OID.
    """
    def __init__(self, a, b, prime, order, cofactor, gx, gy, npoints, name, oid):
        self.a = a
        self.b = b
        self.p = prime
        self.q = order
        self.c = cofactor
        self.gx = gx
        self.gy = gy
        self.n = npoints
        self.name = name
        self.oid = oid

    # Equality testing: two curves are equal when all their parameters match
    def __eq__(self, other):
        if not isinstance(other, Curve):
            return NotImplemented
        return self.__dict__ == other.__dict__

    # Python 2 does not derive != from __eq__: without this, two equal but
    # distinct Curve objects would compare as different there.
    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    # Deep copy is implemented using the ~X operator
    def __invert__(self):
        return copy.deepcopy(self)
192
193
194
class Point(object):
    """
    Point of a Curve in affine coordinates; the point at infinity is
    (None, None).

    Coordinates are reduced modulo curve.p and checked against
    y^2 = x^3 + a*x + b (mod p) on construction. Naive, non constant time
    formulas: for test purposes only.
    """
    def __init__(self, curve, x, y):
        self.curve = curve
        self.x = None if x is None else (x % curve.p)
        self.y = None if y is None else (y % curve.p)
        # Check that the point is indeed on the curve (infinity is not checked)
        if x is not None:
            lhs = pow(y, 2, curve.p)
            rhs = (pow(x, 3, curve.p) + (curve.a * x) + curve.b) % curve.p
            if lhs != rhs:
                raise Exception("Error: point is not on curve!")

    # Addition
    def __add__(self, Q):
        curve = self.curve
        # "not ==" rather than "!=": Python 2 does not derive != from __eq__
        if not (Q.curve == curve):
            raise Exception("Point add error: two point don't have the same curve")
        # If Q is infinity point, return a copy of ourself
        if Q.x is None:
            return Point(curve, self.x, self.y)
        # If we are the infinity point return Q
        if self.x is None:
            return Q
        p = curve.p
        (x1, y1, x2, y2) = (self.x, self.y, Q.x, Q.y)
        if x1 == x2:
            if ((y1 + y2) % p) == 0:
                # Q == -self (this includes doubling a point with y == 0)
                return Point(curve, None, None)
            # Doubling: slope of the tangent
            L = ((3 * pow(x1, 2, p) + curve.a) * modinv(2 * y1, p)) % p
        else:
            # Addition: slope of the chord
            L = ((y2 - y1) * modinv((x2 - x1) % p, p)) % p
        resx = (pow(L, 2, p) - x1 - x2) % p
        resy = ((L * (x1 - resx)) - y1) % p
        return Point(curve, resx, resy)

    # Negation
    def __neg__(self):
        if self.x is None:
            return Point(self.curve, None, None)
        return Point(self.curve, self.x, -self.y)

    # Subtraction
    def __sub__(self, other):
        return self + (-other)

    # Scalar mul: scalar * P with a left-to-right double and add
    def __rmul__(self, scalar):
        Q = Point(self.curve, None, None)
        for i in range(getbitlen(scalar), 0, -1):
            Q = Q + Q
            if (scalar >> (i - 1)) & 0x1 == 0x1:
                Q = Q + self
        return Q

    # Equality testing: same curve and same coordinates
    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return self.__dict__ == other.__dict__

    # Needed on Python 2, where != is not derived from __eq__
    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    # Deep copy is implemented using the ~X operator
    def __invert__(self):
        return copy.deepcopy(self)

    def __str__(self):
        if self.x is None:
            return "Inf"
        return ("(x = %s, y = %s)" % (hex(self.x), hex(self.y)))
271
272
##########################################################
273
### Private and public keys structures
274
class PrivKey(object):
    """Private key: the secret scalar x and the curve it belongs to."""
    def __init__(self, curve, x):
        self.x = x            # secret scalar
        self.curve = curve    # Curve the key is used on
278
279
class PubKey(object):
    """
    Public key: the point Y together with its curve.

    Raises Exception when Y does not belong to curve.
    """
    def __init__(self, curve, Y):
        # Sanity check. "not ==" rather than "!=": on Python 2, != is not
        # derived from __eq__ and would fall back to an identity comparison.
        if not (Y.curve == curve):
            raise Exception("Error: curve and point curve differ in public key!")
        self.curve = curve
        self.Y = Y
286
287
class KeyPair(object):
    """Container bundling a public key with its private key."""
    def __init__(self, pubkey, privkey):
        self.privkey = privkey
        self.pubkey = pubkey
291
292
293
def fromprivkey(privkey, is_eckcdsa=False):
    """
    Derive the PubKey matching privkey.

    The public point is x*G in general, and (x^-1 mod q)*G for ECKCDSA
    (is_eckcdsa).
    """
    curve = privkey.curve
    generator = Point(curve, curve.gx, curve.gy)
    if is_eckcdsa == False:
        scalar = privkey.x
    else:
        scalar = modinv(privkey.x, curve.q)
    return PubKey(curve, scalar * generator)
303
304
def genKeyPair(curve, is_eckcdsa=False):
    """
    Generate a random KeyPair on curve.

    The private scalar x is drawn until it lies in ]0,q[; the public key is
    derived with fromprivkey (x^-1 * G instead of x * G when is_eckcdsa).
    """
    q = curve.q
    while True:
        # getrandomint(q) may return 0 or q itself, and both are invalid
        # private keys (q would give the point at infinity as public key).
        x = getrandomint(q)
        if 0 < x < q:
            break
    privkey = PrivKey(curve, x)
    pubkey = fromprivkey(privkey, is_eckcdsa)
    return KeyPair(pubkey, privkey)
319
320
##########################################################
321
### Signature algorithms helpers
322
def getrandomint(modulo):
    """
    Return a random integer drawn uniformly from [0, modulo].

    Note that modulo itself is included in the range. Uses the
    non-cryptographic random module.
    """
    upper_bound = modulo + 1
    return random.randrange(0, upper_bound)
324
325
def getbitlen(bint):
    """
    Number of bits needed to encode an integer.

    None counts as 0 bits and zero as 1 bit; any other value gives the
    bit_length() of int(bint).
    """
    if bint is None:
        return 0
    return int(bint).bit_length() if bint != 0 else 1
336
337
def getbytelen(bint):
    """
    Number of bytes needed to encode an integer, i.e. getbitlen(bint)
    rounded up to a whole number of bytes.
    """
    nbits = getbitlen(bint)
    (full_bytes, extra_bits) = divmod(nbits, 8)
    return int(full_bytes) + (1 if extra_bits != 0 else 0)
346
347
def stringtoint(bitstring):
    """
    Big-endian conversion of a string (one byte per character) to an integer.
    """
    value = 0
    for char in bitstring:
        # Horner's rule: shift the accumulated value by one byte
        value = value * 256 + ord(char)
    return value
353
354
def inttostring(a):
    """
    Big-endian encoding of integer a as a string of getbytelen(a) characters
    (one byte per character).

    0 encodes to a single NUL character and None to the empty string.
    """
    nbytes = int(getbytelen(a))
    chars = [chr((a >> (8 * shift)) & 0xFF) for shift in range(nbytes - 1, -1, -1)]
    return "".join(chars)
360
361
def expand(bitstring, bitlen, direction):
    """
    Pad bitstring with NUL characters up to ceil(bitlen / 8) bytes.

    direction "LEFT" prepends the padding and "RIGHT" appends it. A string
    that is already long enough is returned untouched, without checking
    direction. Raises Exception for any other direction.
    """
    target = int(math.ceil(bitlen / 8.))
    missing = target - len(bitstring)
    if missing <= 0:
        return bitstring
    padding = missing * "\x00"
    if direction == "LEFT":
        return padding + bitstring
    if direction == "RIGHT":
        return bitstring + padding
    raise Exception("Error: unknown direction "+direction+" in expand")
372
373
def truncate(bitstring, bitlen, keep):
    """
    Keep only the bitlen leftmost ("LEFT") or rightmost ("RIGHT") bits of
    bitstring.

    The result is left-padded to ceil(bitlen / 8) bytes. Strings of at most
    bitlen bits are returned unchanged. Raises Exception for any other keep
    value when truncation is needed.
    """
    total_bits = 8 * len(bitstring)
    if total_bits <= bitlen:
        # No need to truncate!
        return bitstring
    if keep == "LEFT":
        kept = stringtoint(bitstring) >> int(total_bits - bitlen)
    elif keep == "RIGHT":
        kept = stringtoint(bitstring) & ((2**bitlen) - 1)
    else:
        raise Exception("Error: unknown direction "+keep+" in truncate")
    return expand(inttostring(kept), bitlen, "LEFT")
391
392
##########################################################
393
### Hash algorithms
394
def sha224(message):
    """
    SHA-224 of message, a string holding one byte per character.

    Returns (digest, digest_size, block_size); the digest uses the same
    one-byte-per-character convention (latin-1 round trip on Python 3).
    """
    py2 = sys.version_info[0] < 3
    ctx = hashlib.sha224()
    ctx.update(message if py2 else message.encode('latin-1'))
    raw = ctx.digest()
    return (raw if py2 else raw.decode('latin-1'), ctx.digest_size, ctx.block_size)
403
404
def sha256(message):
    """
    SHA-256 of message, a string holding one byte per character.

    Returns (digest, digest_size, block_size); the digest uses the same
    one-byte-per-character convention (latin-1 round trip on Python 3).
    """
    py2 = sys.version_info[0] < 3
    ctx = hashlib.sha256()
    ctx.update(message if py2 else message.encode('latin-1'))
    raw = ctx.digest()
    return (raw if py2 else raw.decode('latin-1'), ctx.digest_size, ctx.block_size)
413
414
def sha384(message):
    """
    SHA-384 of message, a string holding one byte per character.

    Returns (digest, digest_size, block_size); the digest uses the same
    one-byte-per-character convention (latin-1 round trip on Python 3).
    """
    py2 = sys.version_info[0] < 3
    ctx = hashlib.sha384()
    ctx.update(message if py2 else message.encode('latin-1'))
    raw = ctx.digest()
    return (raw if py2 else raw.decode('latin-1'), ctx.digest_size, ctx.block_size)
423
424
def sha512(message):
    """
    SHA-512 of message, a string holding one byte per character.

    Returns (digest, digest_size, block_size); the digest uses the same
    one-byte-per-character convention (latin-1 round trip on Python 3).
    """
    py2 = sys.version_info[0] < 3
    ctx = hashlib.sha512()
    ctx.update(message if py2 else message.encode('latin-1'))
    raw = ctx.digest()
    return (raw if py2 else raw.decode('latin-1'), ctx.digest_size, ctx.block_size)
433
434
def sha3_224(message):
    """
    SHA3-224 of message (one byte per character), computed with the external
    sha3 module.

    Returns (digest, digest_size, block_size) with the digest in the same
    one-byte-per-character convention.
    """
    py2 = sys.version_info[0] < 3
    ctx = sha3.Sha3_ctx(224)
    ctx.update(message if py2 else message.encode('latin-1'))
    raw = ctx.digest()
    return (raw if py2 else raw.decode('latin-1'), ctx.digest_size, ctx.block_size)
443
444
def sha3_256(message):
    """
    SHA3-256 of message (one byte per character), computed with the external
    sha3 module.

    Returns (digest, digest_size, block_size) with the digest in the same
    one-byte-per-character convention.
    """
    py2 = sys.version_info[0] < 3
    ctx = sha3.Sha3_ctx(256)
    ctx.update(message if py2 else message.encode('latin-1'))
    raw = ctx.digest()
    return (raw if py2 else raw.decode('latin-1'), ctx.digest_size, ctx.block_size)
453
454
def sha3_384(message):
    """
    SHA3-384 of message (one byte per character), computed with the external
    sha3 module.

    Returns (digest, digest_size, block_size) with the digest in the same
    one-byte-per-character convention.
    """
    py2 = sys.version_info[0] < 3
    ctx = sha3.Sha3_ctx(384)
    ctx.update(message if py2 else message.encode('latin-1'))
    raw = ctx.digest()
    return (raw if py2 else raw.decode('latin-1'), ctx.digest_size, ctx.block_size)
463
464
def sha3_512(message):
    """
    SHA3-512 of message (one byte per character), computed with the external
    sha3 module.

    Returns (digest, digest_size, block_size) with the digest in the same
    one-byte-per-character convention.
    """
    py2 = sys.version_info[0] < 3
    ctx = sha3.Sha3_ctx(512)
    ctx.update(message if py2 else message.encode('latin-1'))
    raw = ctx.digest()
    return (raw if py2 else raw.decode('latin-1'), ctx.digest_size, ctx.block_size)
473
474
##########################################################
475
### Signature algorithms
476
477
# *| IUF - ECDSA signature
478
# *|
479
# *| UF 1. Compute h = H(m)
480
# *| F 2. If |h| > bitlen(q), set h to bitlen(q)
481
# *| leftmost (most significant) bits of h
482
# *| F 3. e = OS2I(h) mod q
483
# *| F 4. Get a random value k in ]0,q[
484
# *| F 5. Compute W = (W_x,W_y) = kG
485
# *| F 6. Compute r = W_x mod q
486
# *| F 7. If r is 0, restart the process at step 4.
487
# *| F 8. If e == rx, restart the process at step 4.
488
# *| F 9. Compute s = k^-1 * (xr + e) mod q
489
# *| F 10. If s is 0, restart the process at step 4.
490
# *| F 11. Return (r,s)
491
def ecdsa_sign(hashfunc, keypair, message, k=None):
    """
    ECDSA signature of message (see the step list above).

    hashfunc returns (digest, digest_size, block_size) like the hash helpers
    of this file. When k is None, a random nonce is drawn, and redrawn, until
    a valid signature is produced. When k is given it is used as is, and an
    Exception is raised if it cannot produce a valid signature.

    Returns (sig, k): sig is r || s, each left-padded to bytelen(q), and k is
    the nonce actually used.
    """
    privkey = keypair.privkey
    curve = privkey.curve
    q = curve.q
    x = privkey.x
    G = Point(curve, curve.gx, curve.gy)
    # Steps 1-3: hash, keep the bitlen(q) leftmost bits, convert mod q
    (h, _, _) = hashfunc(message)
    h = truncate(h, getbitlen(q), "LEFT")
    e = stringtoint(h) % q

    def _try_nonce(nonce):
        # Steps 5-10 for one nonce: returns (r, s), or None to reject the nonce
        if nonce % q == 0:
            # nonce * G would be the point at infinity
            return None
        W = nonce * G
        r = W.x % q
        if r == 0:
            return None
        # Step 8: e and rx are compared as elements of Z/qZ
        if e == (r * x) % q:
            return None
        s = (modinv(nonce, q) * ((x * r) + e)) % q
        if s == 0:
            return None
        return (r, s)

    while True:
        nonce = getrandomint(q) if k is None else k
        rs = _try_nonce(nonce)
        if rs is not None:
            break
        if k is not None:
            # A caller-supplied nonce cannot be redrawn: fail instead of looping forever
            raise Exception("ECDSA sign: the provided nonce k does not yield a valid signature")
    (r, s) = rs
    qbits = 8 * getbytelen(q)
    return (expand(inttostring(r), qbits, "LEFT") + expand(inttostring(s), qbits, "LEFT"), nonce)
523
524
# *| IUF - ECDSA verification
525
# *|
526
# *| I 1. Reject the signature if r or s is 0.
527
# *| UF 2. Compute h = H(m)
528
# *| F 3. If |h| > bitlen(q), set h to bitlen(q)
529
# *| leftmost (most significant) bits of h
530
# *| F 4. Compute e = OS2I(h) mod q
531
# *| F 5. Compute u = (s^-1)e mod q
532
# *| F 6. Compute v = (s^-1)r mod q
533
# *| F 7. Compute W' = uG + vY
534
# *| F 8. If W' is the point at infinity, reject the signature.
535
# *| F 9. Compute r' = W'_x mod q
536
# *| F 10. Accept the signature if and only if r equals r'
537
def ecdsa_verify(hashfunc, keypair, message, sig):
    """
    ECDSA verification of sig = r || s over message (see the step list above).

    Raises Exception when sig does not have length 2*bytelen(q). Returns False
    for any other invalid signature, including r or s outside ]0,q[, and True
    when the signature is valid.
    """
    pubkey = keypair.pubkey
    curve = pubkey.curve
    q = curve.q
    G = Point(curve, curve.gx, curve.gy)
    qlen = getbytelen(q)
    # Extract r and s
    if len(sig) != 2 * qlen:
        raise Exception("ECDSA verify: bad signature length!")
    r = stringtoint(sig[:qlen])
    s = stringtoint(sig[qlen:])
    # Step 1: r and s must be in ]0,q[ (an s multiple of q would also make
    # modinv() below raise)
    if not (0 < r < q) or not (0 < s < q):
        return False
    # Steps 2-4: hash, keep the bitlen(q) leftmost bits, convert mod q
    (h, _, _) = hashfunc(message)
    h = truncate(h, getbitlen(q), "LEFT")
    e = stringtoint(h) % q
    s_inv = modinv(s, q)
    u = (s_inv * e) % q
    v = (s_inv * r) % q
    W_ = (u * G) + (v * pubkey.Y)
    if W_.x is None:
        return False
    return r == (W_.x % q)
569
570
def eckcdsa_genKeyPair(curve):
    """Generate an ECKCDSA key pair (public point derived from x^-1 instead of x)."""
    return genKeyPair(curve, is_eckcdsa=True)
572
573
# *| IUF - ECKCDSA signature
574
# *|
575
# *| IUF 1. Compute h = H(z||m)
576
# *| F 2. If hsize > bitlen(q), set h to bitlen(q)
577
# *| rightmost (least significant) bits of h.
578
# *| F 3. Get a random value k in ]0,q[
579
# *| F 4. Compute W = (W_x,W_y) = kG
580
# *| F 5. Compute r = h(FE2OS(W_x)).
581
# *| F 6. If hsize > bitlen(q), set r to bitlen(q)
582
# *| rightmost (least significant) bits of r.
583
# *| F 7. Compute e = OS2I(r XOR h) mod q
584
# *| F 8. Compute s = x(k - e) mod q
585
# *| F 9. if s == 0, restart at step 3.
586
# *| F 10. return (r,s)
587
def eckcdsa_sign(hashfunc, keypair, message, k=None):
    """
    ECKCDSA signature of message (see the step list above).

    When k is None, a random nonce is drawn, and redrawn, until a valid
    signature is produced. When k is given it is used as is, and an Exception
    is raised if it cannot produce a valid signature.

    Returns (sig, k): sig is r || s, with r the (possibly truncated) hash and
    s left-padded to bytelen(q); k is the nonce actually used.
    """
    privkey = keypair.privkey
    curve = privkey.curve
    p = curve.p
    q = curve.q
    x = privkey.x
    G = Point(curve, curve.gx, curve.gy)
    plen = 8 * getbytelen(p)
    # Hash outputs are truncated to bitlen(q) rounded up to whole bytes. The
    # float division keeps math.ceil() meaningful on Python 2 too, as in the
    # verifier.
    hbits = 8 * int(math.ceil(getbitlen(q) / 8.))
    # Compute the certificate data z = FE2OS(Y_x) || FE2OS(Y_y), fitted to the
    # hash block size
    (_, _, hblocksize) = hashfunc("")
    Y = keypair.pubkey.Y
    z = expand(inttostring(Y.x), plen, "LEFT") + expand(inttostring(Y.y), plen, "LEFT")
    if len(z) > hblocksize:
        z = truncate(z, 8 * hblocksize, "LEFT")
    else:
        z = expand(z, 8 * hblocksize, "RIGHT")
    # Steps 1-2: h = H(z || m), keeping its rightmost bits
    (h, _, _) = hashfunc(z + message)
    h = truncate(h, hbits, "RIGHT")
    h_int = stringtoint(h)

    def _try_nonce(nonce):
        # Steps 4-9 for one nonce: returns (r, s), or None to reject the nonce
        if nonce % q == 0:
            # nonce * G would be the point at infinity
            return None
        W = nonce * G
        (r, _, _) = hashfunc(expand(inttostring(W.x), plen, "LEFT"))
        r = truncate(r, hbits, "RIGHT")
        e = (stringtoint(r) ^ h_int) % q
        s = (x * (nonce - e)) % q
        if s == 0:
            return None
        return (r, s)

    while True:
        nonce = getrandomint(q) if k is None else k
        rs = _try_nonce(nonce)
        if rs is not None:
            break
        if k is not None:
            # A caller-supplied nonce cannot be redrawn: fail instead of looping forever
            raise Exception("ECKCDSA sign: the provided nonce k does not yield a valid signature")
    (r, s) = rs
    return (r + expand(inttostring(s), 8 * getbytelen(q), "LEFT"), nonce)
625
626
# *| IUF - ECKCDSA verification
627
# *|
628
# *| I 1. Check the length of r:
629
# *| - if hsize > bitlen(q), r must be of
630
# *| length bitlen(q)
631
# *| - if hsize <= bitlen(q), r must be of
632
# *| length hsize
633
# *| I 2. Check that s is in ]0,q[
634
# *| IUF 3. Compute h = H(z||m)
635
# *| F 4. If hsize > bitlen(q), set h to bitlen(q)
636
# *| rightmost (least significant) bits of h.
637
# *| F 5. Compute e = OS2I(r XOR h) mod q
638
# *| F 6. Compute W' = sY + eG, where Y is the public key
639
# *| F 7. Compute r' = h(FE2OS(W'x))
640
# *| F 8. If hsize > bitlen(q), set r' to bitlen(q)
641
# *| rightmost (least significant) bits of r'.
642
# *| F 9. Check if r == r'
643
def eckcdsa_verify(hashfunc, keypair, message, sig):
    """
    ECKCDSA verification of sig = r || s over message (see the step list above).

    Raises Exception when sig does not have the expected length. Returns
    False when s is outside ]0,q[ or the signature does not match, and True
    when it is valid.
    """
    pubkey = keypair.pubkey
    curve = pubkey.curve
    p = curve.p
    q = curve.q
    G = Point(curve, curve.gx, curve.gy)
    plen = 8 * getbytelen(p)
    q_limit_len = getbitlen(q)
    # Truncation length of hash outputs; float division so that math.ceil()
    # also rounds up on Python 2
    hbits = 8 * int(math.ceil(q_limit_len / 8.))
    (_, hsize, hblocksize) = hashfunc("")
    # Step 1: length of r
    if (8 * hsize) > q_limit_len:
        r_len = int(math.ceil(q_limit_len / 8.))
    else:
        r_len = hsize
    if len(sig) != r_len + getbytelen(q):
        raise Exception("ECKCDSA verify: bad signature length!")
    r = stringtoint(sig[:r_len])
    s = stringtoint(sig[r_len:])
    # Step 2: s must be in ]0,q[
    if s == 0 or s >= q:
        return False
    # Compute the certificate data z = FE2OS(Y_x) || FE2OS(Y_y), fitted to the
    # hash block size
    Y = pubkey.Y
    z = expand(inttostring(Y.x), plen, "LEFT") + expand(inttostring(Y.y), plen, "LEFT")
    if len(z) > hblocksize:
        z = truncate(z, 8 * hblocksize, "LEFT")
    else:
        z = expand(z, 8 * hblocksize, "RIGHT")
    # Steps 3-4: h = H(z || m), keeping its rightmost bits
    (h, _, _) = hashfunc(z + message)
    h = truncate(h, hbits, "RIGHT")
    # Steps 5-6
    e = (r ^ stringtoint(h)) % q
    W_ = (s * Y) + (e * G)
    if W_.x is None:
        return False
    # Steps 7-9
    (rh, _, _) = hashfunc(expand(inttostring(W_.x), plen, "LEFT"))
    r_ = truncate(rh, hbits, "RIGHT")
    return stringtoint(r_) == r
683
684
# *| IUF - ECFSDSA signature
685
# *|
686
# *| I 1. Get a random value k in ]0,q[
687
# *| I 2. Compute W = (W_x,W_y) = kG
688
# *| I 3. Compute r = FE2OS(W_x)||FE2OS(W_y)
689
# *| I 4. If r is an all zero string, restart the process at step 1.
690
# *| IUF 5. Compute h = H(r||m)
691
# *| F 6. Compute e = OS2I(h) mod q
692
# *| F 7. Compute s = (k + ex) mod q
693
# *| F 8. If s is 0, restart the process at step 1 (see c. below)
694
# *| F 9. Return (r,s)
695
def ecfsdsa_sign(hashfunc, keypair, message, k=None):
    """
    ECFSDSA signature of message (see the step list above).

    When k is None, a random nonce is drawn, and redrawn, until a valid
    signature is produced. When k is given it is used as is, and an Exception
    is raised if it cannot produce a valid signature.

    Returns (sig, k): sig is r || s with r = FE2OS(W_x) || FE2OS(W_y) and s
    left-padded to bytelen(q); k is the nonce actually used.
    """
    privkey = keypair.privkey
    curve = privkey.curve
    q = curve.q
    x = privkey.x
    G = Point(curve, curve.gx, curve.gy)
    plen = 8 * getbytelen(curve.p)

    def _try_nonce(nonce):
        # Steps 2-8 for one nonce: returns (r, s), or None to reject the nonce
        if nonce % q == 0:
            # nonce * G would be the point at infinity
            return None
        W = nonce * G
        r = expand(inttostring(W.x), plen, "LEFT") + expand(inttostring(W.y), plen, "LEFT")
        if stringtoint(r) == 0:
            return None
        (h, _, _) = hashfunc(r + message)
        e = stringtoint(h) % q
        s = (nonce + e * x) % q
        if s == 0:
            return None
        return (r, s)

    while True:
        nonce = getrandomint(q) if k is None else k
        rs = _try_nonce(nonce)
        if rs is not None:
            break
        if k is not None:
            # A caller-supplied nonce cannot be redrawn: fail instead of looping forever
            raise Exception("ECFSDSA sign: the provided nonce k does not yield a valid signature")
    (r, s) = rs
    return (r + expand(inttostring(s), 8 * getbytelen(q), "LEFT"), nonce)
720
721
722
# *| IUF - ECFSDSA verification
723
# *|
724
# *| I 1. Reject the signature if r is not a valid point on the curve.
725
# *| I 2. Reject the signature if s is not in ]0,q[
726
# *| IUF 3. Compute h = H(r||m)
727
# *| F 4. Convert h to an integer and then compute e = -h mod q
728
# *| F 5. compute W' = sG + eY, where Y is the public key
729
# *| F 6. Compute r' = FE2OS(W'_x)||FE2OS(W'_y)
730
# *| F 7. Accept the signature if and only if r equals r'
731
def ecfsdsa_verify(hashfunc, keypair, message, sig):
    """
    ECFSDSA verification of sig = r || s over message (see the step list above).

    Raises Exception on a bad signature length, when r is not a point of the
    curve, or when s is outside ]0,q[. Otherwise returns whether r equals the
    recomputed r'.
    """
    pubkey = keypair.pubkey
    curve = pubkey.curve
    q = curve.q
    G = Point(curve, curve.gx, curve.gy)
    plen = getbytelen(curve.p)
    qlen = getbytelen(q)
    # Extract coordinates from r and s from signature
    if len(sig) != (2 * plen) + qlen:
        raise Exception("ECFSDSA verify: bad signature length!")
    wx = sig[:plen]
    wy = sig[plen:2 * plen]
    r = wx + wy
    s = stringtoint(sig[2 * plen:(2 * plen) + qlen])
    # Step 1: building the point raises if (wx, wy) is not on the curve
    Point(curve, stringtoint(wx), stringtoint(wy))
    # Step 2: check s is in ]0,q[
    if s == 0 or s >= q:
        raise Exception("ECFSDSA verify: s not in ]0,q[")
    # Steps 3-5
    (h, _, _) = hashfunc(r + message)
    e = (-stringtoint(h)) % q
    W_ = s * G + e * pubkey.Y
    # Steps 6-7
    r_ = expand(inttostring(W_.x), 8 * plen, "LEFT") + expand(inttostring(W_.y), 8 * plen, "LEFT")
    return r == r_
759
760
761
# NOTE: ISO/IEC 14888-3 standard seems to diverge from the existing implementations
762
# of ECRDSA when treating the message hash, and from the examples of certificates provided
763
# in RFC 7091 and draft-deremin-rfc4491-bis. While in ISO/IEC 14888-3 it is explicitly asked
764
# to proceed with the hash of the message as big endian, the RFCs derived from the Russian
765
# standard expect the hash value to be treated as little endian when importing it as an integer
766
# (this discrepancy is exhibited and confirmed by test vectors present in ISO/IEC 14888-3, and
767
# by X.509 certificates present in the RFCs). This seems (to be confirmed) to be a discrepancy of
768
# ISO/IEC 14888-3 algorithm description that must be fixed there.
769
#
770
# In order to be conservative, libecc uses the Russian standard behavior as expected to be in line with
771
# other implementations, but keeps the ISO/IEC 14888-3 behavior if forced/asked by the user using
772
# the USE_ISO14888_3_ECRDSA toggle. This allows to keep backward compatibility with previous versions of the
773
# library if needed.
774
775
# *| IUF - ECRDSA signature
776
# *|
777
# *| UF 1. Compute h = H(m)
778
# *| F 2. Get a random value k in ]0,q[
779
# *| F 3. Compute W = (W_x,W_y) = kG
780
# *| F 4. Compute r = W_x mod q
781
# *| F 5. If r is 0, restart the process at step 2.
782
# *| F 6. Compute e = OS2I(h) mod q. If e is 0, set e to 1.
783
# *| NOTE: here, ISO/IEC 14888-3 and RFCs differ in the way e is treated.
784
# *| e = OS2I(h) for ISO/IEC 14888-3, or e = OS2I(reversed(h)) when endianness of h
785
# *| is reversed for RFCs.
786
# *| F 7. Compute s = (rx + ke) mod q
787
# *| F 8. If s is 0, restart the process at step 2.
788
# *| F 9. Return (r,s)
789
def ecrdsa_sign(hashfunc, keypair, message, k=None, use_iso14888_divergence=False):
    """
    ECRDSA signature of message (see the step list above).

    Unless use_iso14888_divergence is set, the hash is byte-reversed before
    conversion to an integer (RFC behavior, see the NOTE above). When k is
    None, a random nonce is drawn, and redrawn, until a valid signature is
    produced. When k is given it is used as is, and an Exception is raised if
    it cannot produce a valid signature.

    Returns (sig, k): sig is r || s, each left-padded to bytelen(q), and k is
    the nonce actually used.
    """
    privkey = keypair.privkey
    curve = privkey.curve
    q = curve.q
    x = privkey.x
    G = Point(curve, curve.gx, curve.gy)
    (h, _, _) = hashfunc(message)
    if use_iso14888_divergence == False:
        # Reverse the endianness for Russian standard RFC ECRDSA (contrary to ISO/IEC 14888-3 case)
        h = h[::-1]
    # Step 6 does not depend on the nonce; a null e is replaced by 1
    e = stringtoint(h) % q
    if e == 0:
        e = 1

    def _try_nonce(nonce):
        # Steps 3-8 for one nonce: returns (r, s), or None to reject the nonce
        if nonce % q == 0:
            # nonce * G would be the point at infinity
            return None
        W = nonce * G
        r = W.x % q
        if r == 0:
            return None
        s = ((r * x) + (nonce * e)) % q
        if s == 0:
            return None
        return (r, s)

    while True:
        nonce = getrandomint(q) if k is None else k
        rs = _try_nonce(nonce)
        if rs is not None:
            break
        if k is not None:
            # A caller-supplied nonce cannot be redrawn: fail instead of looping forever
            raise Exception("ECRDSA sign: the provided nonce k does not yield a valid signature")
    (r, s) = rs
    qbits = 8 * getbytelen(q)
    return (expand(inttostring(r), qbits, "LEFT") + expand(inttostring(s), qbits, "LEFT"), nonce)
819
820
# *| IUF - ECRDSA verification
821
# *|
822
# *| UF 1. Check that r and s are both in ]0,q[
823
# *| F 2. Compute h = H(m)
824
# *| F 3. Compute e = OS2I(h)^-1 mod q
825
# *| NOTE: here, ISO/IEC 14888-3 and RFCs differ in the way e is treated.
826
# *| e = OS2I(h) for ISO/IEC 14888-3, or e = OS2I(reversed(h)) when endianness of h
827
# *| is reversed for RFCs.
828
# *| F 4. Compute u = es mod q
829
# *| F 4. Compute v = -er mod q
830
# *| F 5. Compute W' = uG + vY = (W'_x, W'_y)
831
# *| F 6. Let's now compute r' = W'_x mod q
832
# *| F 7. Check r and r' are the same
833
def ecrdsa_verify(hashfunc, keypair, message, sig, use_iso14888_divergence=False):
    """
    ECRDSA verification of sig = r || s over message (see the step list above).

    Raises Exception on a bad signature length or when r or s is outside
    ]0,q[. Otherwise returns whether r equals the recomputed r'. The hash
    endianness convention must match the one used by ecrdsa_sign.
    """
    pubkey = keypair.pubkey
    curve = pubkey.curve
    q = curve.q
    G = Point(curve, curve.gx, curve.gy)
    qlen = getbytelen(q)
    # Extract r and s from signature
    if len(sig) != 2 * qlen:
        raise Exception("ECRDSA verify: bad signature length!")
    r = stringtoint(sig[:qlen])
    s = stringtoint(sig[qlen:2 * qlen])
    # Step 1: r and s must be in ]0,q[
    if r == 0 or r >= q:
        raise Exception("ECRDSA verify: r not in ]0,q[")
    if s == 0 or s >= q:
        raise Exception("ECRDSA verify: s not in ]0,q[")
    (h, _, _) = hashfunc(message)
    if use_iso14888_divergence == False:
        # Reverse the endianness for Russian standard RFC ECRDSA (contrary to ISO/IEC 14888-3 case)
        h = h[::-1]
    # Step 3: same convention as the signer, a null e is replaced by 1
    # (0 has no inverse mod q)
    e = stringtoint(h) % q
    if e == 0:
        e = 1
    e_inv = modinv(e, q)
    u = (e_inv * s) % q
    v = (-e_inv * r) % q
    W_ = u * G + v * pubkey.Y
    if W_.x is None:
        return False
    return r == (W_.x % q)
863
864
865
# *| IUF - ECGDSA signature
866
# *|
867
# *| UF 1. Compute h = H(m). If |h| > bitlen(q), set h to bitlen(q)
868
# *| leftmost (most significant) bits of h
869
# *| F 2. Convert e = - OS2I(h) mod q
870
# *| F 3. Get a random value k in ]0,q[
871
# *| F 4. Compute W = (W_x,W_y) = kG
872
# *| F 5. Compute r = W_x mod q
873
# *| F 6. If r is 0, restart the process at step 4.
874
# *| F 7. Compute s = x(kr + e) mod q
875
# *| F 8. If s is 0, restart the process at step 4.
876
# *| F 9. Return (r,s)
877
def ecgdsa_sign(hashfunc, keypair, message, k=None):
    """
    ECGDSA signature of message (see the step list above).

    When k is None, a random nonce is drawn, and redrawn, until a valid
    signature is produced. When k is given it is used as is, and an Exception
    is raised if it cannot produce a valid signature.

    Returns (sig, k): sig is r || s, each left-padded to bytelen(q), and k is
    the nonce actually used.
    """
    privkey = keypair.privkey
    curve = privkey.curve
    q = curve.q
    x = privkey.x
    G = Point(curve, curve.gx, curve.gy)
    # Steps 1-2: hash, keep the bitlen(q) leftmost bits, e = -OS2I(h) mod q
    (h, _, _) = hashfunc(message)
    h = truncate(h, getbitlen(q), "LEFT")
    e = (-stringtoint(h)) % q

    def _try_nonce(nonce):
        # Steps 4-8 for one nonce: returns (r, s), or None to reject the nonce
        if nonce % q == 0:
            # nonce * G would be the point at infinity
            return None
        W = nonce * G
        r = W.x % q
        if r == 0:
            return None
        s = (x * ((nonce * r) + e)) % q
        if s == 0:
            return None
        return (r, s)

    while True:
        nonce = getrandomint(q) if k is None else k
        rs = _try_nonce(nonce)
        if rs is not None:
            break
        if k is not None:
            # A caller-supplied nonce cannot be redrawn: fail instead of looping forever
            raise Exception("ECGDSA sign: the provided nonce k does not yield a valid signature")
    (r, s) = rs
    qbits = 8 * getbytelen(q)
    return (expand(inttostring(r), qbits, "LEFT") + expand(inttostring(s), qbits, "LEFT"), nonce)
905
906
# *| IUF - ECGDSA verification
907
# *|
908
# *| I 1. Reject the signature if r or s is 0.
909
# *| UF 2. Compute h = H(m). If |h| > bitlen(q), set h to bitlen(q)
910
# *| leftmost (most significant) bits of h
911
# *| F 3. Compute e = OS2I(h) mod q
912
# *| F 4. Compute u = ((r^-1)e mod q)
913
# *| F 5. Compute v = ((r^-1)s mod q)
914
# *| F 6. Compute W' = uG + vY
915
# *| F 7. Compute r' = W'_x mod q
916
# *| F 8. Accept the signature if and only if r equals r'
917
def ecgdsa_verify(hashfunc, keypair, message, sig):
    """
    ECGDSA verification of sig = r || s over message (see the step list above).

    Raises Exception on a bad signature length or when r or s is outside
    ]0,q[. Otherwise returns whether r equals the recomputed r'.
    """
    pubkey = keypair.pubkey
    curve = pubkey.curve
    q = curve.q
    G = Point(curve, curve.gx, curve.gy)
    qlen = getbytelen(q)
    # Extract r and s from signature
    if len(sig) != 2 * qlen:
        raise Exception("ECGDSA verify: bad signature length!")
    r = stringtoint(sig[:qlen])
    s = stringtoint(sig[qlen:2 * qlen])
    # Step 1: r and s must be in ]0,q[ (r == q would also make modinv() fail)
    if r == 0 or r >= q:
        raise Exception("ECGDSA verify: r not in ]0,q[")
    if s == 0 or s >= q:
        raise Exception("ECGDSA verify: s not in ]0,q[")
    # Steps 2-3
    (h, _, _) = hashfunc(message)
    h = truncate(h, getbitlen(q), "LEFT")
    e = stringtoint(h) % q
    # Steps 4-8
    r_inv = modinv(r, q)
    u = (r_inv * e) % q
    v = (r_inv * s) % q
    W_ = u * G + v * pubkey.Y
    if W_.x is None:
        return False
    return r == (W_.x % q)
948
949
# *| IUF - ECSDSA/ECOSDSA signature
950
# *|
951
# *| I 1. Get a random value k in ]0, q[
952
# *| I 2. Compute W = kG = (Wx, Wy)
953
# *| IUF 3. Compute r = H(Wx [|| Wy] || m)
954
# *| - In the normal version (ECSDSA), r = h(Wx || Wy || m).
955
# *| - In the optimized version (ECOSDSA), r = h(Wx || m).
956
# *| F 4. Compute e = OS2I(r) mod q
957
# *| F 5. if e == 0, restart at step 1.
958
# *| F 6. Compute s = (k + ex) mod q.
959
# *| F 7. if s == 0, restart at step 1.
960
# *| F 8. Return (r, s)
961
def ecsdsa_common_sign(hashfunc, keypair, message, optimized, k=None):
    """
    ECSDSA (optimized == False) or ECOSDSA (otherwise) signature of message
    (see the step list above).

    When k is None, a random nonce is drawn, and redrawn, until a valid
    signature is produced. When k is given it is used as is, and an Exception
    is raised if it cannot produce a valid signature.

    Returns (sig, k): sig is r || s with r the hash value and s left-padded to
    bytelen(q); k is the nonce actually used.
    """
    privkey = keypair.privkey
    curve = privkey.curve
    q = curve.q
    x = privkey.x
    G = Point(curve, curve.gx, curve.gy)
    plen = 8 * getbytelen(curve.p)

    def _try_nonce(nonce):
        # Steps 2-7 for one nonce: returns (r, s), or None to reject the nonce
        if nonce % q == 0:
            # nonce * G would be the point at infinity
            return None
        W = nonce * G
        data = expand(inttostring(W.x), plen, "LEFT")
        if optimized == False:
            # ECSDSA also hashes W_y
            data = data + expand(inttostring(W.y), plen, "LEFT")
        (r, _, _) = hashfunc(data + message)
        e = stringtoint(r) % q
        if e == 0:
            return None
        s = (nonce + (e * x)) % q
        if s == 0:
            return None
        return (r, s)

    while True:
        nonce = getrandomint(q) if k is None else k
        rs = _try_nonce(nonce)
        if rs is not None:
            break
        if k is not None:
            # A caller-supplied nonce cannot be redrawn: fail instead of looping forever
            raise Exception("EC[O]SDSA sign: the provided nonce k does not yield a valid signature")
    (r, s) = rs
    return (r + expand(inttostring(s), 8 * getbytelen(q), "LEFT"), nonce)
988
989
def ecsdsa_sign(hashfunc, keypair, message, k=None):
    """ECSDSA signature (r hashes both W_x and W_y); returns (sig, k)."""
    return ecsdsa_common_sign(hashfunc, keypair, message, optimized=False, k=k)
991
992
def ecosdsa_sign(hashfunc, keypair, message, k=None):
    """ECOSDSA (optimized) signature (r hashes only W_x); returns (sig, k)."""
    return ecsdsa_common_sign(hashfunc, keypair, message, optimized=True, k=k)
994
995
# *| IUF - ECSDSA/ECOSDSA verification
996
# *|
997
# *| I 1. if s is not in ]0,q[, reject the signature.
998
# *| I 2. Compute e = -r mod q
999
# *| I 3. If e == 0, reject the signature.
1000
# *| I 4. Compute W' = sG + eY
1001
# *| IUF 5. Compute r' = H(W'x [|| W'y] || m)
1002
# *| - In the normal version (ECSDSA), r = h(W'x || W'y || m).
1003
# *| - In the optimized version (ECOSDSA), r = h(W'x || m).
1004
# *| F 6. Accept the signature if and only if r and r' are the same
1005
def ecsdsa_common_verify(hashfunc, keypair, message, sig, optimized):
    """
    Common EC-SDSA / EC-OSDSA signature verification.

    sig is r || s, where r has the digest length of hashfunc and s the byte
    length of q. optimized selects EC-OSDSA (r' = H(W'x || m)) instead of
    EC-SDSA (r' = H(W'x || W'y || m)).

    Returns True when r' equals r, False otherwise.
    Raises an Exception on a malformed signature: bad length, s not in
    ]0,q[, or e == 0.
    """
    pubkey = keypair.pubkey
    # Get important parameters from the curve
    p = pubkey.curve.p
    q = pubkey.curve.q
    G = Point(pubkey.curve, pubkey.curve.gx, pubkey.curve.gy)
    # Only the digest length is needed here
    (_, hlen, _) = hashfunc("")
    hlen = int(hlen)
    qlen = getbytelen(q)
    # Extract r and s from the signature
    if len(sig) != hlen + qlen:
        raise Exception("EC[O]SDSA verify: bad signature length!")
    r_bytes = sig[:hlen]
    r = stringtoint(r_bytes)
    s = stringtoint(sig[hlen:hlen + qlen])
    # s must lie in the open interval ]0, q[ (s == q is rejected too)
    if s == 0 or s >= q:
        raise Exception("EC[O]SDSA verify: s not in ]0,q[")
    e = (-r) % q
    if e == 0:
        raise Exception("EC[O]SDSA verify: e is null")
    W_ = s * G + e * pubkey.Y
    coord_bitlen = 8 * getbytelen(p)
    to_hash = expand(inttostring(W_.x), coord_bitlen, "LEFT")
    if not optimized:
        to_hash = to_hash + expand(inttostring(W_.y), coord_bitlen, "LEFT")
    (r_, _, _) = hashfunc(to_hash + message)
    return r_bytes == r_
1033
1034
def ecsdsa_verify(hashfunc, keypair, message, sig):
    """Verify an EC-SDSA (non-optimized) signature; returns True or False."""
    return ecsdsa_common_verify(hashfunc, keypair, message, sig, optimized=False)
1036
1037
def ecosdsa_verify(hashfunc, keypair, message, sig):
    """Verify an EC-OSDSA (optimized, Wx only) signature; returns True or False."""
    return ecsdsa_common_verify(hashfunc, keypair, message, sig, optimized=True)
1039
1040
1041
##########################################################
1042
### Generate self-tests for all the algorithms
1043
1044
# Hash functions used for self-test generation: (hash wrapper, C hash name)
all_hash_funcs = [
    (sha224, "SHA224"),
    (sha256, "SHA256"),
    (sha384, "SHA384"),
    (sha512, "SHA512"),
    (sha3_224, "SHA3_224"),
    (sha3_256, "SHA3_256"),
    (sha3_384, "SHA3_384"),
    (sha3_512, "SHA3_512"),
]

# Signature algorithms: (sign, verify, key pair generator, C algorithm name)
all_sig_algs = [
    (ecdsa_sign, ecdsa_verify, genKeyPair, "ECDSA"),
    (eckcdsa_sign, eckcdsa_verify, eckcdsa_genKeyPair, "ECKCDSA"),
    (ecfsdsa_sign, ecfsdsa_verify, genKeyPair, "ECFSDSA"),
    (ecrdsa_sign, ecrdsa_verify, genKeyPair, "ECRDSA"),
    (ecgdsa_sign, ecgdsa_verify, eckcdsa_genKeyPair, "ECGDSA"),
    (ecsdsa_sign, ecsdsa_verify, genKeyPair, "ECSDSA"),
    (ecosdsa_sign, ecosdsa_verify, genKeyPair, "ECOSDSA"),
]

# Progress counter of generated (hash, algorithm) combinations
curr_test = 0
1056
def pretty_print_curr_test(num_test, total_gen_tests):
    """
    Display an in-place "current/total" progress counter on stdout.

    The previously displayed counter is erased with backspaces, then the new
    one is written with both numbers zero-padded to the number of decimal
    digits of total_gen_tests. A newline ends the display on the last test.
    """
    width = int(math.log10(total_gen_tests)) + 1
    counter_fmt = "%0{0}d/%0{0}d".format(width)
    # Erase "NN/NN": two numbers of 'width' digits plus the slash
    sys.stdout.write('\b' * (2 * width + 1))
    sys.stdout.flush()
    sys.stdout.write(counter_fmt % (num_test, total_gen_tests))
    if num_test == total_gen_tests:
        print("")
1065
1066
def _self_test_hexlify(buf):
    """
    Hex-encode buf as a native str for use in an error message.

    binascii.hexlify() needs a bytes-like object and returns bytes under
    Python 3, so text input is encoded as latin-1 first and the result is
    decoded back to str. Under Python 2 this is a plain hexlify().
    """
    if not isinstance(buf, bytes):
        buf = buf.encode('latin-1')
    out = binascii.hexlify(buf)
    if not isinstance(out, str):
        out = out.decode('ascii')
    return out


def _self_test_verify_error(test_name, message, sig, k, keypair):
    """Build the exception raised when a freshly generated signature does not verify."""
    return Exception("Error during self test generation: sig verify failed! " + test_name
                     + " / msg=" + message + " / sig=" + _self_test_hexlify(sig)
                     + " / k=" + hex(k) + " / privkey.x=" + hex(keypair.privkey.x))


def _self_test_to_C(curve, hashfunc_name, sig_alg_name, test_name, keypair, message, sig, k, guard=None):
    """
    Produce the C source of one self-test case.

    guard is None or a (opening directive, closing comment) pair that wraps
    the whole output in an extra preprocessor condition (used for the two
    ECRDSA flavors).

    Returns (test case reference snippet, test vector definitions).
    """
    vec = []
    if guard is not None:
        vec.append(guard[0] + "\n")
    vec.append("#ifdef WITH_HASH_" + hashfunc_name.upper() + "\n")
    vec.append("#ifdef WITH_CURVE_" + curve.name.upper() + "\n")
    vec.append("#ifdef WITH_SIG_" + sig_alg_name.upper() + "\n")
    vec.append("/* " + test_name + " known test vectors */\n")
    vec.append("static int " + test_name + "_test_vectors_get_random(nn_t out, nn_src_t q)\n{\n")
    # k_buf MUST be exported padded to the length of q
    vec.append("\tconst u8 k_buf[] = " + bigint_to_C_array(k, getbytelen(curve.q)))
    vec.append("\tint ret, cmp;\n\tret = nn_init_from_buf(out, k_buf, sizeof(k_buf)); EG(ret, err);\n\tret = nn_cmp(out, q, &cmp); EG(ret, err);\n\tret = (cmp >= 0) ? -1 : 0;\nerr:\n\treturn ret;\n}\n")
    vec.append("static const u8 " + test_name + "_test_vectors_priv_key[] = \n" + bigint_to_C_array(keypair.privkey.x, getbytelen(keypair.privkey.x)))
    vec.append("static const u8 " + test_name + "_test_vectors_expected_sig[] = \n" + bigint_to_C_array(stringtoint(sig), len(sig)))
    vec.append("static const ec_test_case " + test_name + "_test_case = {\n")
    vec.append("\t.name = \"" + test_name + "\",\n")
    vec.append("\t.ec_str_p = &" + curve.name + "_str_params,\n")
    vec.append("\t.priv_key = " + test_name + "_test_vectors_priv_key,\n")
    vec.append("\t.priv_key_len = sizeof(" + test_name + "_test_vectors_priv_key),\n")
    vec.append("\t.nn_random = " + test_name + "_test_vectors_get_random,\n")
    vec.append("\t.hash_type = " + hashfunc_name + ",\n")
    vec.append("\t.msg = \"" + message + "\",\n")
    vec.append("\t.msglen = " + str(len(message)) + ",\n")
    vec.append("\t.sig_type = " + sig_alg_name + ",\n")
    vec.append("\t.exp_sig = " + test_name + "_test_vectors_expected_sig,\n")
    vec.append("\t.exp_siglen = sizeof(" + test_name + "_test_vectors_expected_sig),\n};\n")
    vec.append("#endif /* WITH_HASH_" + hashfunc_name + " */\n")
    vec.append("#endif /* WITH_CURVE_" + curve.name + " */\n")
    vec.append("#endif /* WITH_SIG_" + sig_alg_name + " */\n")
    if guard is not None:
        vec.append("#endif /* " + guard[1] + " */\n")

    for_suffix = "/* For " + test_name + " */"
    name = []
    if guard is not None:
        name.append(guard[0] + for_suffix + "\n")
    name.append("#ifdef WITH_HASH_" + hashfunc_name.upper() + for_suffix + "\n")
    name.append("#ifdef WITH_CURVE_" + curve.name.upper() + for_suffix + "\n")
    name.append("#ifdef WITH_SIG_" + sig_alg_name.upper() + for_suffix + "\n")
    name.append("\t&" + test_name + "_test_case,\n")
    name.append("#endif /* WITH_HASH_" + hashfunc_name + " for " + test_name + " */\n")
    name.append("#endif /* WITH_CURVE_" + curve.name + " for " + test_name + " */\n")
    name.append("#endif /* WITH_SIG_" + sig_alg_name + " for " + test_name + " */")
    if guard is not None:
        name.append("\n#endif /* " + guard[1] + " */" + for_suffix)
    return ("".join(name), "".join(vec))


def gen_self_test(curve, hashfunc, sig_alg_sign, sig_alg_verify, sig_alg_genkeypair, num, hashfunc_name, sig_alg_name, total_gen_tests):
    """
    Generate num random self-test vectors for one (curve, hash, algorithm).

    Each test uses a fresh key pair from sig_alg_genkeypair(curve) and a
    random alphanumeric message of getrandomint(256) characters. The message
    is signed, and the signature is checked with sig_alg_verify before being
    exported as C code. For ECRDSA, each test also produces an ISO/IEC
    14888-3 flavored vector guarded by USE_ISO14888_3_ECRDSA, the standard
    one being guarded by its negation.

    Increments the global curr_test progress counter and displays progress
    when num != 0.

    Returns a list of (test case reference snippet, test vector definitions)
    C source string pairs.
    Raises an Exception if a generated signature does not verify.
    """
    global curr_test
    curr_test = curr_test + 1
    if num != 0:
        pretty_print_curr_test(curr_test, total_gen_tests)
    is_ecrdsa = (sig_alg_name == "ECRDSA")
    alphabet = string.ascii_letters + string.digits
    output_list = []
    for test_num in range(0, num):
        # Generate a random key pair
        keypair = sig_alg_genkeypair(curve)
        # Generate a random message with a random size
        size = getrandomint(256)
        message = ''.join([random.choice(alphabet) for _ in range(size)])
        test_name = sig_alg_name + "_" + hashfunc_name + "_" + curve.name.upper() + "_" + str(test_num)
        # Sign the message and check that everything is OK with a verify
        (sig, k) = sig_alg_sign(hashfunc, keypair, message)
        if sig_alg_verify(hashfunc, keypair, message, sig) != True:
            raise _self_test_verify_error(test_name, message, sig, k, keypair)
        guard = ("#ifndef USE_ISO14888_3_ECRDSA", "!USE_ISO14888_3_ECRDSA") if is_ecrdsa else None
        output_list.append(_self_test_to_C(curve, hashfunc_name, sig_alg_name, test_name, keypair, message, sig, k, guard))
        # In the specific case of ECRDSA, we also generate an ISO/IEC compatible test vector
        if is_ecrdsa:
            (sig, k) = sig_alg_sign(hashfunc, keypair, message, use_iso14888_divergence=True)
            if sig_alg_verify(hashfunc, keypair, message, sig, use_iso14888_divergence=True) != True:
                raise _self_test_verify_error(test_name, message, sig, k, keypair)
            iso_guard = ("#ifdef USE_ISO14888_3_ECRDSA", "USE_ISO14888_3_ECRDSA")
            output_list.append(_self_test_to_C(curve, hashfunc_name, sig_alg_name, test_name, keypair, message, sig, k, iso_guard))
    return output_list
1179
1180
def gen_self_tests(curve, num):
    """
    Generate num test vectors for every (signature algorithm, hash) pair.

    Resets the global progress counter curr_test. Returns a list with one
    entry per algorithm (all_sig_algs order), each entry being the list of
    gen_self_test() results for every hash function (all_hash_funcs order).
    """
    global curr_test
    curr_test = 0
    total_gen_tests = len(all_hash_funcs) * len(all_sig_algs)
    vectors = []
    for (sign, verify, genkp, sig_alg_name) in all_sig_algs:
        per_algorithm = []
        for (hashf, hash_name) in all_hash_funcs:
            per_algorithm.append(gen_self_test(curve, hashf, sign, verify, genkp, num,
                                               hash_name, sig_alg_name, total_gen_tests))
        vectors.append(per_algorithm)
    return vectors
1187
1188
##########################################################
1189
### ASN.1 stuff
1190
def parse_DER_extract_size(derbuf):
    """
    Decode the DER length field at the start of derbuf.

    derbuf must start right after the tag byte, i.e. with the length
    octets. Both the short form (one byte < 0x80 holding the length) and the
    long form (0x80 | n followed by n big-endian length bytes) are handled.

    Returns (True, header_len, length), where header_len is the number of
    bytes of the length field and length the content size. Returns
    (False, 0, 0) when the field is malformed (empty buffer, indefinite
    length 0x80) or when derbuf cannot hold the length field plus the whole
    announced content.
    """
    if len(derbuf) == 0:
        return (False, 0, 0)
    first = ord(derbuf[0])
    if first & 0x80 != 0:
        # Long form: the low 7 bits give the number of length bytes
        num_len_bytes = first & 0x7f
        header_len = 1 + num_len_bytes
        # 0x80 alone is the BER indefinite length, forbidden in DER
        if num_len_bytes == 0 or len(derbuf) < header_len:
            return (False, 0, 0)
        length = 0
        for c in derbuf[1:header_len]:
            length = (length << 8) | ord(c)
    else:
        # Short form: this byte is the length itself
        header_len = 1
        length = first
    # The buffer must contain the length field AND the whole content
    if len(derbuf) < header_len + length:
        return (False, 0, 0)
    return (True, header_len, length)
1207
1208
def extract_DER_object(derbuf, object_tag):
    """
    Extract the DER object tagged object_tag at the start of derbuf.

    Returns (True, total_len, content), where total_len counts the tag byte,
    the length field and the content (so derbuf[total_len:] starts at the
    next object). Returns (False, 0, "") when the tag does not match or the
    encoding is malformed or truncated.
    """
    # A tag byte and at least one length byte are required
    if len(derbuf) < 2 or ord(derbuf[0]) != object_tag:
        return (False, 0, "")
    derbuf = derbuf[1:]
    # Extract the size
    (check, encoding_len, size) = parse_DER_extract_size(derbuf)
    if not check or len(derbuf) < encoding_len + size:
        return (False, 0, "")
    return (True, size + encoding_len + 1, derbuf[encoding_len:encoding_len + size])
1224
1225
def extract_DER_sequence(derbuf):
    """Extract a DER SEQUENCE (tag 0x30) from the start of derbuf."""
    return extract_DER_object(derbuf, object_tag=0x30)
1227
1228
def extract_DER_integer(derbuf):
    """Extract a DER INTEGER (tag 0x02) from the start of derbuf."""
    return extract_DER_object(derbuf, object_tag=0x02)
1230
1231
def extract_DER_octetstring(derbuf):
    """Extract a DER OCTET STRING (tag 0x04) from the start of derbuf."""
    return extract_DER_object(derbuf, object_tag=0x04)
1233
1234
def extract_DER_bitstring(derbuf):
    """Extract a DER BIT STRING (tag 0x03) from the start of derbuf."""
    return extract_DER_object(derbuf, object_tag=0x03)
1236
1237
def extract_DER_oid(derbuf):
    """Extract a DER OBJECT IDENTIFIER (tag 0x06) from the start of derbuf."""
    return extract_DER_object(derbuf, object_tag=0x06)
1239
1240
# See ECParameters sequence in RFC 3279
1241
def parse_DER_ECParameters(derbuf):
    """
    Parse an explicit ECParameters structure (see RFC 3279) over a prime field.

    XXX: this is a very ugly way of extracting the information regarding an
    EC curve, but since the ASN.1 structure is quite "static", this might be
    sufficient without embedding a full ASN.1 parser ...

    Returns (True, (a, b, prime, order, cofactor, gx, gy)) on success and
    (False, (0, 0, 0, 0, 0, 0, 0)) on any parse error or unsupported
    encoding (non prime field, unknown point type, invalid compressed point).
    """
    default_ret = (0, 0, 0, 0, 0, 0, 0)
    failure = (False, default_ret)
    # ECParameters wrapping sequence
    (check, _, ECParameters) = extract_DER_sequence(derbuf)
    if not check:
        return failure
    # version INTEGER
    (check, size_ECPVer, _) = extract_DER_integer(ECParameters)
    if not check:
        return failure
    # Running offset of the next element inside ECParameters
    offset = size_ECPVer
    # fieldID SEQUENCE
    (check, size_FieldID, FieldID) = extract_DER_sequence(ECParameters[offset:])
    if not check:
        return failure
    offset += size_FieldID
    # fieldType OID: only prime-field (1.2.840.10045.1.1) is supported
    (check, size_Oid, Oid) = extract_DER_oid(FieldID)
    if not check:
        return failure
    if Oid != "\x2A\x86\x48\xCE\x3D\x01\x01":
        print("DER parse error: only prime fields are supported ...")
        return failure
    # Prime p of the prime field
    (check, _, P) = extract_DER_integer(FieldID[size_Oid:])
    if not check:
        return failure
    # curve SEQUENCE holding a and b
    (check, size_Curve, Curve) = extract_DER_sequence(ECParameters[offset:])
    if not check:
        return failure
    offset += size_Curve
    (check, size_A, A) = extract_DER_octetstring(Curve)
    if not check:
        return failure
    (check, _, B) = extract_DER_octetstring(Curve[size_A:])
    if not check:
        return failure
    # base ECPoint
    (check, size_ECPoint, ECPoint) = extract_DER_octetstring(ECParameters[offset:])
    if not check:
        return failure
    offset += size_ECPoint
    # order INTEGER
    (check, size_Order, Order) = extract_DER_integer(ECParameters[offset:])
    if not check:
        return failure
    offset += size_Order
    # cofactor INTEGER
    (check, _, Cofactor) = extract_DER_integer(ECParameters[offset:])
    if not check:
        return failure
    # If we end up here, everything is OK, we can extract all our elements
    prime = stringtoint(P)
    a = stringtoint(A)
    b = stringtoint(B)
    order = stringtoint(Order)
    cofactor = stringtoint(Cofactor)
    # Extract Gx and Gy, see X9.62-1998
    if len(ECPoint) < 1:
        return failure
    ECPoint_type = ord(ECPoint[0])
    coords = ECPoint[1:]
    if ECPoint_type in (0x04, 0x06, 0x07):
        # Uncompressed (0x04) and hybrid (0x06/0x07) points: X || Y
        if len(coords) % 2 != 0:
            return failure
        half = len(coords) // 2
        gx = stringtoint(coords[:half])
        gy = stringtoint(coords[half:])
    elif ECPoint_type in (0x02, 0x03):
        # Compressed point: uncompress it, see X9.62-1998 section 4.2.1
        gx = stringtoint(coords)
        alpha = (pow(gx, 3, prime) + (a * gx) + b) % prime
        beta = mod_sqrt(alpha, prime)
        if (beta is None) or ((beta == 0) and (alpha != 0)):
            # Keep the (check, params) return shape so callers can unpack it
            return failure
        # The low bit of the point type gives the parity of y
        if (beta & 0x1) == (ECPoint_type & 0x1):
            gy = beta
        else:
            gy = prime - beta
    else:
        print("DER parse error: unsupported point encoding type!")
        return failure
    return (True, (a, b, prime, order, cofactor, gx, gy))
1329
1330
##########################################################
1331
### Text and format helpers
1332
def bigint_to_C_array(bint, size):
    """
    Format a Python big int as a C hex byte array initializer.

    The value is written big-endian, left-padded with zeros to at least
    size bytes (never truncated), eight "0xNN, " bytes per tab-indented
    line, wrapped in "{\\n" ... "\\n};\\n".
    """
    digits = format(int(bint), 'x')
    missing = int((2 * size) - len(digits))
    if missing > 0:
        digits = "0" * missing + digits
    # Make sure we have an integral number of bytes
    if len(digits) % 2 == 1:
        digits = "0" + digits
    byte_items = ["0x" + digits[pos:pos + 2] + ", " for pos in range(0, len(digits), 2)]
    rows = ["\t" + "".join(byte_items[start:start + 8]) for start in range(0, len(byte_items), 8)]
    return "{\n" + "\n".join(rows) + "\n};\n"
1349
1350
def check_in_file(fname, pat):
    """Return True if the regex pat matches at least one line of file fname."""
    with open(fname) as f:
        found = any(re.search(pat, line) for line in f)
    return found
1357
1358
def num_patterns_in_file(fname, pat):
    """Count the lines of file fname on which the regex pat matches (at most once per line)."""
    with open(fname) as f:
        return sum(1 for line in f if re.search(pat, line))
1365
1366
def file_replace_pattern(fname, pat, s_after):
    """
    Replace, line by line, every match of the regex pat with s_after in fname.

    The file is left untouched when pat does not occur in it. Otherwise the
    new content is written to fname + ".tmp", which then replaces fname.
    """
    # First, see if the pattern is even in the file
    with open(fname) as f:
        if not any(re.search(pat, line) for line in f):
            return  # pattern does not occur in file so we are done
    # Pattern is in the file, so perform the replace operation. Both files
    # are managed by 'with' so the temporary file is flushed and closed even
    # if a substitution raises.
    out_fname = fname + ".tmp"
    with open(fname) as f, open(out_fname, "w") as out:
        for line in f:
            out.write(re.sub(pat, s_after, line))
    # os.replace (Python 3.3+) overwrites an existing target on every platform
    if hasattr(os, "replace"):
        os.replace(out_fname, fname)
    else:
        os.rename(out_fname, fname)
1380
1381
def file_remove_pattern(fname, pat):
    """
    Remove from fname every line on which the regex pat matches.

    The file is left untouched when pat does not occur in it. Otherwise the
    kept lines are written to fname + ".tmp", which then replaces fname.
    """
    # First, see if the pattern is even in the file
    with open(fname) as f:
        if not any(re.search(pat, line) for line in f):
            return  # pattern does not occur in file so we are done
    # Pattern is in the file, so perform the remove operation; 'with'
    # guarantees the temporary file is closed even on error
    out_fname = fname + ".tmp"
    with open(fname) as f, open(out_fname, "w") as out:
        for line in f:
            if not re.search(pat, line):
                out.write(line)
    if hasattr(os, "replace"):
        # Overwrites an existing target on every platform (Python 3.3+)
        os.replace(out_fname, fname)
    else:
        # os.rename() fails on Windows when the target exists
        if os.path.exists(fname):
            os.remove(fname)
        os.rename(out_fname, fname)
1399
1400
def remove_file(fname):
    """Delete the file fname; a missing file raises OSError."""
    os.unlink(fname)
1403
1404
def remove_files_pattern(fpattern):
    """Delete every file matching the glob pattern fpattern."""
    # A plain loop: the previous list comprehension built a throwaway list
    # only for its side effects
    for matched in glob.glob(fpattern):
        remove_file(matched)
1406
1407
def buffer_remove_pattern(buff, pat):
    """
    Remove every match of the regex pat from buff.

    Under Python 3, buff is decoded as latin-1 first, so the returned buffer
    is a str. Returns (True, new_buffer) when pat occurred in buff, and
    (False, buffer) otherwise.
    """
    if not is_python_2():
        buff = buff.decode('latin-1')
    (stripped, count) = re.subn(pat, "", buff)
    if count == 0:
        return (False, buff)  # pattern does not occur, nothing removed
    return (True, stripped)
1415
1416
def is_base64(s):
    """
    Tell whether s (possibly spread over several lines) is valid base64.

    Each line is stripped of surrounding whitespace and the lines are
    concatenated. The result is base64 when decoding then re-encoding it
    gives back the same text.

    Returns True or False. Malformed input (bad padding, non-ASCII text)
    yields False instead of raising.
    """
    joined = ''.join([line.strip() for line in s.split("\n")])
    try:
        enc = base64.b64encode(base64.b64decode(joined)).strip()
        if type(enc) is bytes:
            return enc == joined.encode('latin-1')
        return enc == joined
    # Python 2 raises TypeError on bad input; Python 3 raises binascii.Error,
    # a ValueError subclass (also raised for non-ASCII text)
    except (TypeError, ValueError):
        return False
1426
1427
### Curve helpers
1428
def export_curve_int(curvename, intname, bigint, size):
    """
    Emit the C definition of one integer curve parameter.

    A None value is exported as a single zero byte declared with
    TO_EC_STR_PARAM_FIXED_SIZE(..., 0). Any other value is exported as a
    big-endian byte array of at least size bytes, declared with
    TO_EC_STR_PARAM().
    """
    ident = curvename + "_" + intname
    if bigint is None:
        declaration = "static const u8 " + ident + "[] = {\n\t0x00,\n};\n"
        wrapper = "TO_EC_STR_PARAM_FIXED_SIZE(" + ident + ", 0);\n\n"
    else:
        declaration = "static const u8 " + ident + "[] = " + bigint_to_C_array(bigint, size) + "\n"
        wrapper = "TO_EC_STR_PARAM(" + ident + ");\n\n"
    return declaration + wrapper
1436
1437
def export_curve_string(curvename, stringname, stringvalue):
    """Emit the C definition of a string curve parameter (e.g. its name or OID)."""
    ident = curvename + "_" + stringname
    declaration = "static const u8 " + ident + "[] = \"" + stringvalue + "\";\n"
    wrapper = "TO_EC_STR_PARAM(" + ident + ");\n\n"
    return declaration + wrapper
1441
1442
def export_curve_struct(curvename, paramname, paramnamestr):
    """Emit one designated-initializer line of the ec_str_params structure."""
    target = curvename + "_" + paramnamestr + "_str_param"
    return "\t." + paramname + " = &" + target + ", \n"
1444
1445
def curve_params(name, prime, pbitlen, a, b, gx, gy, order, cofactor, oid, alpha_montgomery, gamma_montgomery, alpha_edwards):
    """
    Take as input some elliptic curve parameters and generate the C header
    defining them, returned as a string.

    The header is guarded by WITH_CURVE_<NAME>. It exports p, a, b, the
    generator, orders and cofactor, the Montgomery/division helper constants
    for 64, 32 and 16-bit words (selected by WORD_BYTES), the name and OID
    strings, the ec_str_params structure, and updates the CURVES_MAX_* and
    MAX_CURVE_*_LEN macros.
    """
    NAME = name.upper()
    # Byte length of p, rounded up
    bytesize = int(pbitlen / 8)
    if pbitlen % 8 != 0:
        bytesize += 1

    def words_bitsize(word_bytes):
        # Bit size of p rounded up to a whole number of word_bytes-byte words
        if bytesize % word_bytes != 0:
            return 8 * ((int(bytesize / word_bytes) + 1) * word_bytes)
        return 8 * bytesize

    # (word bit size, preprocessor line opening its section)
    word_sections = (
        (64, "#if (WORD_BYTES == 8) /* 64-bit words */\n"),
        (32, "#elif (WORD_BYTES == 4) /* 32-bit words */\n"),
        (16, "#elif (WORD_BYTES == 2) /* 16-bit words */\n"),
    )
    wbits = dict((wlen, words_bitsize(wlen // 8)) for (wlen, _) in word_sections)
    # Montgomery coefficients first, then division coefficients
    monty = [compute_monty_coef(prime, wbits[wlen], wlen) for (wlen, _) in word_sections]
    div = [compute_div_coef(prime, wbits[wlen], wlen) for (wlen, _) in word_sections]
    # Number of points on the curve
    npoints = order * cofactor

    out = []
    out.append("#include <libecc/lib_ecc_config.h>\n")
    out.append("#ifdef WITH_CURVE_" + NAME + "\n\n")
    out.append("#ifndef __EC_PARAMS_" + NAME + "_H__\n")
    out.append("#define __EC_PARAMS_" + NAME + "_H__\n")
    out.append("#include <libecc/curves/known/ec_params_external.h>\n")
    out.append(export_curve_int(name, "p", prime, bytesize))

    out.append("#define CURVE_" + NAME + "_P_BITLEN " + str(pbitlen) + "\n")
    out.append(export_curve_int(name, "p_bitlen", pbitlen, getbytelen(pbitlen)))

    for idx, (_, directive) in enumerate(word_sections):
        (r, r_square, mpinv) = monty[idx]
        (pshift, primenorm, p_reciprocal) = div[idx]
        out.append(directive)
        for (pname, value) in (("r", r), ("r_square", r_square), ("mpinv", mpinv),
                               ("p_shift", pshift), ("p_normalized", primenorm),
                               ("p_reciprocal", p_reciprocal)):
            out.append(export_curve_int(name, pname, value, getbytelen(value)))
    out.append("#else /* unknown word size */\n")
    out.append("#error \"Unsupported word size\"\n")
    out.append("#endif\n\n")

    out.append(export_curve_int(name, "a", a, bytesize))
    out.append(export_curve_int(name, "b", b, bytesize))

    curve_order_bitlen = getbitlen(npoints)
    out.append("#define CURVE_" + NAME + "_CURVE_ORDER_BITLEN " + str(curve_order_bitlen) + "\n")
    out.append(export_curve_int(name, "curve_order", npoints, getbytelen(npoints)))

    out.append(export_curve_int(name, "gx", gx, bytesize))
    out.append(export_curve_int(name, "gy", gy, bytesize))
    out.append(export_curve_int(name, "gz", 0x01, bytesize))

    qbitlen = getbitlen(order)
    out.append(export_curve_int(name, "gen_order", order, getbytelen(order)))
    out.append("#define CURVE_" + NAME + "_Q_BITLEN " + str(qbitlen) + "\n")
    out.append(export_curve_int(name, "gen_order_bitlen", qbitlen, getbytelen(qbitlen)))

    out.append(export_curve_int(name, "cofactor", cofactor, getbytelen(cofactor)))

    for (pname, value) in (("alpha_montgomery", alpha_montgomery),
                           ("gamma_montgomery", gamma_montgomery),
                           ("alpha_edwards", alpha_edwards)):
        out.append(export_curve_int(name, pname, value, getbytelen(value)))

    out.append(export_curve_string(name, "name", NAME))

    if oid is None:
        oid = ""
    out.append(export_curve_string(name, "oid", oid))

    struct_fields = ("p", "p_bitlen", "r", "r_square", "mpinv", "p_shift",
                     "p_normalized", "p_reciprocal", "a", "b", "curve_order",
                     "gx", "gy", "gz", "gen_order", "gen_order_bitlen",
                     "cofactor", "alpha_montgomery", "gamma_montgomery",
                     "alpha_edwards", "oid", "name")
    out.append("static const ec_str_params " + name + "_str_params = {\n")
    for field in struct_fields:
        out.append(export_curve_struct(name, field, field))
    out.append("};\n\n")

    max_bitlen_tmpl = (
        "/*\n"
        " * Compute max bit length of all curves for p and q\n"
        " */\n"
        "#ifndef CURVES_MAX_P_BIT_LEN\n"
        "#define CURVES_MAX_P_BIT_LEN 0\n"
        "#endif\n"
        "#if (CURVES_MAX_P_BIT_LEN < CURVE_%(N)s_P_BITLEN)\n"
        "#undef CURVES_MAX_P_BIT_LEN\n"
        "#define CURVES_MAX_P_BIT_LEN CURVE_%(N)s_P_BITLEN\n"
        "#endif\n"
        "#ifndef CURVES_MAX_Q_BIT_LEN\n"
        "#define CURVES_MAX_Q_BIT_LEN 0\n"
        "#endif\n"
        "#if (CURVES_MAX_Q_BIT_LEN < CURVE_%(N)s_Q_BITLEN)\n"
        "#undef CURVES_MAX_Q_BIT_LEN\n"
        "#define CURVES_MAX_Q_BIT_LEN CURVE_%(N)s_Q_BITLEN\n"
        "#endif\n"
        "#ifndef CURVES_MAX_CURVE_ORDER_BIT_LEN\n"
        "#define CURVES_MAX_CURVE_ORDER_BIT_LEN 0\n"
        "#endif\n"
        "#if (CURVES_MAX_CURVE_ORDER_BIT_LEN < CURVE_%(N)s_CURVE_ORDER_BITLEN)\n"
        "#undef CURVES_MAX_CURVE_ORDER_BIT_LEN\n"
        "#define CURVES_MAX_CURVE_ORDER_BIT_LEN CURVE_%(N)s_CURVE_ORDER_BITLEN\n"
        "#endif\n\n"
    )
    out.append(max_bitlen_tmpl % {"N": NAME})

    max_len_tmpl = (
        "/*\n"
        " * Compute and adapt max name and oid length\n"
        " */\n"
        "#ifndef MAX_CURVE_OID_LEN\n"
        "#define MAX_CURVE_OID_LEN 0\n"
        "#endif\n"
        "#ifndef MAX_CURVE_NAME_LEN\n"
        "#define MAX_CURVE_NAME_LEN 0\n"
        "#endif\n"
        "#if (MAX_CURVE_OID_LEN < %(OID_LEN)s)\n"
        "#undef MAX_CURVE_OID_LEN\n"
        "#define MAX_CURVE_OID_LEN %(OID_LEN)s\n"
        "#endif\n"
        "#if (MAX_CURVE_NAME_LEN < %(NAME_LEN)s)\n"
        "#undef MAX_CURVE_NAME_LEN\n"
        "#define MAX_CURVE_NAME_LEN %(NAME_LEN)s\n"
        "#endif\n\n"
    )
    # +1 accounts for the terminating NUL of the C strings
    out.append(max_len_tmpl % {"OID_LEN": str(len(oid) + 1), "NAME_LEN": str(len(NAME) + 1)})

    out.append("#endif /* __EC_PARAMS_" + NAME + "_H__ */\n\n" + "#endif /* WITH_CURVE_" + NAME + " */\n")

    return "".join(out)
1613
1614
def usage():
    """Print the command line help of the script on stdout."""
    print("This script is intended to *statically* expand the ECC library with user defined curves.")
    print("By statically we mean that the source code of libecc is expanded with new curve parameters through")
    print("automatic code generation filling place holders in the existing code base of the library. Though the")
    print("choice of static code generation versus dynamic curves import (such as what OpenSSL does) might be")
    print("argued, this choice has been driven by simplicity and security design decisions: we want libecc to have")
    print("all its parameters (such as memory consumption) set at compile time and statically adapted to the curves.")
    print("Since libecc only supports curves over prime fields, the script can only add this kind of curves.")
    print("This script implements elliptic curves and ISO signature algorithms from scratch over Python's multi-precision")
    print("big numbers library. Addition and doubling over curves use naive formulas. Please DO NOT use the functions of this")
    print("script for production code: they are not securely implemented and are very inefficient. Their only purpose is to expand")
    print("libecc and produce test vectors.")
    print("")
    print("In order to add a curve, there are two ways:")
    print("Adding a user defined curve with explicit parameters:")
    print("-----------------------------------------------------")
    print(sys.argv[0]+" --name=\"YOURCURVENAME\" --prime=... --order=... --a=... --b=... --gx=... --gy=... --cofactor=... --oid=THEOID")
    print("\t> name: name of the curve in the form of a string")
    print("\t> prime: prime number representing the curve prime field")
    print("\t> order: prime number representing the generator order")
    print("\t> cofactor: cofactor of the curve")
    print("\t> a: 'a' coefficient of the short Weierstrass equation of the curve")
    print("\t> b: 'b' coefficient of the short Weierstrass equation of the curve")
    print("\t> gx: x coordinate of the generator G")
    print("\t> gy: y coordinate of the generator G")
    print("\t> oid: optional OID of the curve")
    print(" Notes:")
    print(" ******")
    print("\t1) These elements are verified to indeed satisfy the curve equation.")
    print("\t2) All the numbers can be given either in decimal or hexadecimal format with a prepending '0x'.")
    print("\t3) The script automatically generates all the necessary files for the curve to be included in the library." )
    print("\tYou will find the new curve definition in the usual 'lib_ecc_config.h' file (one can activate it or not at compile time).")
    print("")
    print("Adding a user defined curve through RFC3279 ASN.1 parameters:")
    print("-------------------------------------------------------------")
    print(sys.argv[0]+" --name=\"YOURCURVENAME\" --ECfile=... --oid=THEOID")
    print("\t> ECfile: the DER or PEM encoded file containing the curve parameters (see RFC3279)")
    print(" Notes:")
    print("\tCurve parameters encoded in DER or PEM format can be generated with tools like OpenSSL (among others). As an illustrative example,")
    print("\tone can list all the supported curves under OpenSSL with:")
    print("\t $ openssl ecparam -list_curves")
    print("\tOnly the listed so called \"prime\" curves are supported. Then, one can extract an explicit curve representation in ASN.1")
    print("\tas defined in RFC3279, for example for BRAINPOOLP320R1:")
    print("\t $ openssl ecparam -param_enc explicit -outform DER -name brainpoolP320r1 -out brainpoolP320r1.der")
    print("")
    print("Removing user defined curves:")
    print("-----------------------------")
    print("\t*All the user defined curves can be removed with the --remove-all toggle.")
    print("\t*A specific named user defined curve can be removed with the --remove toggle: in this case the --name option is used to ")
    print("\tlocate which named curve must be deleted.")
    print("")
    print("Test vectors:")
    print("-------------")
    print("\tTest vectors can be automatically generated and added to the library self tests when providing the --add-test-vectors=X toggle.")
    print("\tIn this case, X test vectors will be generated for *each* (curve, sign algorithm, hash algorithm) triplet (beware of combinatorial")
    print("\tissues when X is big!). These tests are transparently added and compiled with the self tests.")
    return
1671
1672
def get_int(instring):
    """
    Convert a command-line integer string to an int.

    An empty string yields 0. A value whose digits start with a "0x"/"0X"
    prefix (optionally preceded by a "+" or "-" sign) is parsed as
    hexadecimal; anything else is parsed as decimal.

    :param instring: the string to convert (spaces are expected to have
        been removed by the caller).
    :return: the parsed integer.
    :raises ValueError: if the string is not a valid integer.
    """
    if not instring:
        return 0
    # Look past an optional sign so that e.g. "-0x10" is recognised as hex;
    # int(..., 16) itself accepts both the sign and the 0x/0X prefix.
    digits = instring.lstrip("+-")
    if digits[:2].lower() == "0x":
        return int(instring, 16)
    return int(instring)
1679
1680
# Long options whose value is an integer (decimal, or hexadecimal with a 0x prefix)
_INT_OPTIONS = ("--prime", "--a", "--b", "--gx", "--gy", "--order", "--cofactor",
                "--alpha_montgomery", "--gamma_montgomery", "--alpha_edwards",
                "--add-test-vectors")


def _libecc_paths():
    """
    Return the libecc source locations this script reads and patches.

    All locations are derived from the directory containing the script
    (sys.argv[0]); directory entries end with "/" so that file names can be
    appended directly.
    """
    script_path = os.path.abspath(os.path.dirname(sys.argv[0])) + "/"
    curves_dir = script_path + "../include/libecc/curves/"
    include_dir = script_path + "../include/libecc/"
    tests_dir = script_path + "../src/tests/"
    return {
        "ec_params_dir": curves_dir + "user_defined/",
        "curves_list_h": curves_dir + "curves_list.h",
        "lib_ecc_types_h": include_dir + "lib_ecc_types.h",
        "lib_ecc_config_h": include_dir + "lib_ecc_config.h",
        "self_tests_dir": tests_dir,
        "self_tests_core_h": tests_dir + "ec_self_tests_core.h",
        "meson_options": script_path + "../meson.options",
    }


def _confirm(prompt):
    """Ask `prompt` until the user answers "y" or "n"; return True for "y"."""
    answer = ""
    while answer != "y" and answer != "n":
        answer = get_user_input(prompt)
    return answer == "y"


def _remove_named_curve(name, paths):
    """
    Remove every trace of the user-defined curve `name` (already prefixed
    with "user_defined_") after asking the user for confirmation.

    Returns False if the curve parameters header cannot be removed, True
    otherwise (including when the user cancels).
    """
    short_name = name.replace("user_defined_", "")
    if not _confirm("You asked to remove everything related to user defined "+short_name+" curve. Enter y to confirm, n to cancel [y/n]. "):
        print("NOT removing curve "+short_name+" (cancelled).")
        return True
    print("Removing user defined curve "+short_name+" ...")
    file_remove_pattern(paths["curves_list_h"], ".*"+name+".*")
    file_remove_pattern(paths["curves_list_h"], ".*"+name.upper()+".*")
    file_remove_pattern(paths["lib_ecc_types_h"], ".*"+name.upper()+".*")
    file_remove_pattern(paths["lib_ecc_config_h"], ".*"+name.upper()+".*")
    file_remove_pattern(paths["self_tests_core_h"], ".*"+name+".*")
    file_remove_pattern(paths["self_tests_core_h"], ".*"+name.upper()+".*")
    file_remove_pattern(paths["meson_options"], ".*"+name.lower()+".*")
    try:
        remove_file(paths["ec_params_dir"] + "ec_params_"+name+".h")
    except Exception:
        print("Error: curve name "+name+" does not seem to be present in the sources!")
        return False
    try:
        remove_file(paths["self_tests_dir"] + "ec_self_tests_core_"+name+".h")
    except Exception:
        # Self tests are optional (only generated with --add-test-vectors)
        print("Warning: curve name "+name+" self tests do not seem to be present ...")
    return True


def _remove_all_user_curves(paths):
    """
    Remove every user-defined curve (whatever its name) after asking the
    user for confirmation. Always returns True.
    """
    if not _confirm("You asked to remove everything related to ALL user defined curves. Enter y to confirm, n to cancel [y/n]. "):
        print("NOT removing user defined curves (cancelled).")
        return True
    print("Removing ALL user defined curves ...")
    file_remove_pattern(paths["curves_list_h"], ".*user_defined.*")
    file_remove_pattern(paths["curves_list_h"], ".*USER_DEFINED.*")
    file_remove_pattern(paths["lib_ecc_types_h"], ".*USER_DEFINED.*")
    file_remove_pattern(paths["lib_ecc_config_h"], ".*USER_DEFINED.*")
    file_remove_pattern(paths["self_tests_core_h"], ".*USER_DEFINED.*")
    file_remove_pattern(paths["self_tests_core_h"], ".*user_defined.*")
    file_remove_pattern(paths["meson_options"], ".*user_defined.*")
    remove_files_pattern(paths["ec_params_dir"] + "ec_params_user_defined_*.h")
    remove_files_pattern(paths["self_tests_dir"] + "ec_self_tests_core_user_defined_*.h")
    return True


def _split_generator(g):
    """
    Split a hexadecimal generator string into integer coordinates (gx, gy).

    Accepts either the plain concatenation gx||gy, or an uncompressed point
    "04"||gx||gy, optionally prefixed with "0x". Returns None when the string
    is not conforming (bad prefix, empty coordinate, non-hex digit).
    """
    if g[:2].lower() == "0x":
        g = g[2:]
    if (len(g) // 2) % 2 != 0:
        # Odd number of hex digit pairs: probably an uncompressed point
        # encapsulated in a bit string, which must start with 04
        if g[0:2] != "04":
            return None
        g = g[2:]
    half = len(g) // 2
    if half == 0:
        return None
    try:
        # Coordinates are hex digit strings without their own 0x prefix,
        # so they must be parsed in base 16 (not via get_int, which is decimal here)
        return int(g[:half], 16), int(g[half:], 16)
    except ValueError:
        return None


def _insert_before_c_magic(path, magic, text):
    """Insert `text` right before the C comment marker "/* magic */" in file `path`."""
    magic_re = r"\/\* " + magic + r" \*\/"
    magic_back = "/* " + magic + " */"
    file_replace_pattern(path, magic_re, text + magic_back)


def parse_cmd_line(args=None):
    """
    Parse the command line and perform the requested action.

    Depending on the options this either removes one (--remove --name=...)
    or all (--remove-all) user-defined curves, or adds a new user-defined
    curve whose parameters come either from explicit options (--prime, --a,
    --b, --gx/--gy or --generator, --order, --cofactor) or from an RFC 3279
    DER/PEM file (--ECfile). Adding a curve writes its parameters header and
    patches the libecc headers and meson.options at their magic marker
    comments; --add-test-vectors=X also generates self-test vectors.

    :param args: argument list without the program name; defaults to
        sys.argv[1:].
    :return: True on success, help display or user cancellation; False on
        any error (a message is printed).
    """
    if args is None:
        args = sys.argv[1:]
    name = oid = g = ECfile = None
    remove = remove_all = None
    int_values = {}
    try:
        opts, _ = getopt.getopt(args, ":h", ["help", "remove", "remove-all", "name=", "prime=", "a=", "b=", "generator=", "gx=", "gy=", "order=", "cofactor=", "alpha_montgomery=","gamma_montgomery=", "alpha_edwards=", "ECfile=", "oid=", "add-test-vectors="])
    except getopt.GetoptError as err:
        # print help information and exit:
        print(err) # will print something like "option -a not recognized"
        usage()
        return False
    for o, arg in opts:
        if o in ("-h", "--help"):
            usage()
            return True
        elif o == "--name":
            # Prefix to avoid any collision with built-in curves; '-' is not
            # valid in the C identifiers derived from the name
            name = ("user_defined_" + arg).replace("-", "_")
        elif o == "--oid":
            oid = arg
        elif o == "--generator":
            g = arg.replace(' ', '')
        elif o in _INT_OPTIONS:
            int_values[o] = get_int(arg.replace(' ', ''))
        elif o == "--remove":
            remove = True
        elif o == "--remove-all":
            remove_all = True
        elif o == "--ECfile":
            ECfile = arg
        else:
            print("unhandled option")
            usage()
            return False
    prime = int_values.get("--prime")
    a = int_values.get("--a")
    b = int_values.get("--b")
    gx = int_values.get("--gx")
    gy = int_values.get("--gy")
    order = int_values.get("--order")
    cofactor = int_values.get("--cofactor")
    alpha_montgomery = int_values.get("--alpha_montgomery")
    gamma_montgomery = int_values.get("--gamma_montgomery")
    alpha_edwards = int_values.get("--alpha_edwards")
    add_test_vectors = int_values.get("--add-test-vectors")

    paths = _libecc_paths()

    # Removal requests do not need (nor check) any curve parameter
    if remove:
        if name is None:
            print("--remove option expects a curve name provided with --name")
            return False
        return _remove_named_curve(name, paths)
    if remove_all:
        return _remove_all_user_curves(paths)

    # If a g is provided, split it in two gx and gy
    if g is not None:
        coords = _split_generator(g)
        if coords is None:
            print("Error: provided generator g is not conforming!")
            return False
        gx, gy = coords

    if ECfile is not None:
        # ASN.1 DER input incompatible with other options
        if any(v is not None for v in (prime, a, b, gx, gy, order, cofactor)):
            print("Error: option ECfile incompatible with explicit (prime, a, b, gx, gy, order, cofactor) options!")
            return False
        # We need at least a name
        if name is None:
            print("Error: option ECfile needs a curve name!")
            return False
        try:
            with open(ECfile, 'rb') as ec_fd:
                buf = ec_fd.read()
        except (IOError, OSError):
            print("Error: cannot open ECfile file "+ECfile)
            return False
        # Check if we have a PEM or a DER file
        (check, derbuf) = buffer_remove_pattern(buf, "-----.*-----")
        if check == True:
            # This a PEM file, proceed with base64 decoding
            if is_base64(derbuf) == False:
                print("Error: error when decoding ECfile file "+ECfile+" (seems to be PEM, but failed to decode)")
                return False
            derbuf = base64.b64decode(derbuf)
        (check, (a, b, prime, order, cofactor, gx, gy)) = parse_DER_ECParameters(derbuf)
        if check == False:
            print("Error: error when parsing ECfile file "+ECfile+" (malformed or unsupported ASN.1)")
            return False
    else:
        required = (("prime", prime), ("a", a), ("b", b), ("gx", gx), ("gy", gy),
                    ("order", order), ("cofactor", cofactor), ("name", name))
        missing = [label for (label, value) in required if value is None]
        if missing:
            print("Error: missing "+" ".join(missing)+" in explicit curve definition (name, prime, a, b, gx, gy, order, cofactor)!")
            print("See the help with -h or --help")
            return False

    # Some sanity checks here
    if is_probprime(prime) == False:
        print("Error: given prime is *NOT* prime!")
        return False
    if is_probprime(order) == False:
        print("Error: given order is *NOT* prime!")
        return False
    # Field elements must be canonical representatives, i.e. in [0, prime)
    out_of_range = [label for (label, value) in (("a", a), ("b", b), ("gx", gx), ("gy", gy))
                    if not (0 <= value < prime)]
    if out_of_range:
        print("Error: "+" ".join(out_of_range)+" not in range [0, prime)")
        return False
    # Check that the provided generator is on the curve y^2 = x^3 + a*x + b
    if pow(gy, 2, prime) != (pow(gx, 3, prime) + a * gx + b) % prime:
        print("Error: the given parameters (prime, a, b, gx, gy) do not verify the elliptic curve equation!")
        return False

    # Check Montgomery and Edwards transfer coefficients
    if (alpha_montgomery is None) != (gamma_montgomery is None):
        print("Error: alpha_montgomery and gamma_montgomery must be both defined if used!")
        return False
    if (alpha_edwards is not None) and ((alpha_montgomery is None) or (gamma_montgomery is None)):
        print("Error: alpha_edwards needs alpha_montgomery and gamma_montgomery to be both defined if used!")
        return False

    # Now that we have our parameters, call the function to get bitlen
    pbitlen = getbitlen(prime)
    ec_params = curve_params(name, prime, pbitlen, a, b, gx, gy, order, cofactor, oid, alpha_montgomery, gamma_montgomery, alpha_edwards)
    # Check if there is a name collision somewhere
    ec_params_file = paths["ec_params_dir"] + "ec_params_"+name+".h"
    if os.path.exists(ec_params_file):
        print("Error: file %s already exists!" % ec_params_file)
        return False
    upper_name = name.upper()
    if (check_in_file(paths["curves_list_h"], "ec_params_"+name+"_str_params") == True) or (check_in_file(paths["curves_list_h"], "WITH_CURVE_"+upper_name+"\n") == True) or (check_in_file(paths["lib_ecc_types_h"], "WITH_CURVE_"+upper_name+"\n") == True):
        print("Error: name %s already exists in files" % ("ec_params_"+name))
        return False
    # Create the "user_defined" folder if it does not exist, then the parameters header
    if not os.path.exists(paths["ec_params_dir"]):
        os.mkdir(paths["ec_params_dir"])
    with open(ec_params_file, 'w') as params_fd:
        params_fd.write(ec_params)
    # Include the file in curves_list.h
    _insert_before_c_magic(paths["curves_list_h"], "ADD curves header here",
                           "#include <libecc/curves/user_defined/ec_params_"+name+".h>\n")
    # Add the curve mapping
    _insert_before_c_magic(paths["curves_list_h"], "ADD curves mapping here",
                           "#ifdef WITH_CURVE_"+upper_name+"\n\t{ .type = "+upper_name+", .params = &"+name+"_str_params },\n#endif /* WITH_CURVE_"+upper_name+" */\n")
    # Add the new curve type in the enum: count the already defined curves
    # (before patching) so that the enum counter is incremented
    num_with_curve = num_patterns_in_file(paths["lib_ecc_types_h"], "#ifdef WITH_CURVE_")
    _insert_before_c_magic(paths["lib_ecc_types_h"], "ADD curves type here",
                           "#ifdef WITH_CURVE_"+upper_name+"\n\t"+upper_name+" = "+str(num_with_curve+1)+",\n#endif /* WITH_CURVE_"+upper_name+" */\n")
    # Add the new curve define in the config
    _insert_before_c_magic(paths["lib_ecc_config_h"], "ADD curves define here",
                           "#define WITH_CURVE_"+upper_name+"\n")
    # Add the new curve meson option (meson.options uses '#' comments)
    meson_magic = "# ADD curves meson option here"
    file_replace_pattern(paths["meson_options"], meson_magic, "\t'"+name.lower()+"',\n"+meson_magic)

    # Do we need to add some test vectors?
    if add_test_vectors is not None:
        print("Test vectors generation asked: this can take some time! Please wait ...")
        c = Curve(a, b, prime, order, cofactor, gx, gy, cofactor * order, name, oid)
        vectors = gen_self_tests(c, add_test_vectors)
        with open(paths["self_tests_dir"] + "ec_self_tests_core_"+name+".h", 'w') as tests_fd:
            for per_algo in vectors:
                for per_hash in per_algo:
                    for (case_name, case_vector) in per_hash:
                        # Register the test case, then append its vector to the header
                        _insert_before_c_magic(paths["self_tests_core_h"], "ADD curve test case here", case_name+"\n")
                        tests_fd.write(case_vector)
        # Add the new test cases header
        _insert_before_c_magic(paths["self_tests_core_h"], "ADD curve test vectors header here",
                               "#include \"ec_self_tests_core_"+name+".h\"\n")
    return True
1951
1952
1953
#### Main
if __name__ == "__main__":
    signal.signal(signal.SIGINT, handler)
    # Propagate the outcome as the process exit status so that callers
    # (build systems, CI scripts) can detect failures.
    sys.exit(0 if parse_cmd_line(sys.argv[1:]) else 1)
1957
1958