Annotation of src/usr.bin/ssh/xmss_fast.c, Revision 1.2
1.2 ! dtucker 1: /* $OpenBSD$ */
1.1 markus 2: /*
3: xmss_fast.c version 20160722
4: Andreas Hülsing
5: Joost Rijneveld
6: Public domain.
7: */
8:
9: #include "xmss_fast.h"
10: #include <stdlib.h>
11: #include <string.h>
12: #include <stdint.h>
13:
14: #include "crypto_api.h"
15: #include "xmss_wots.h"
16: #include "xmss_hash.h"
17:
18: #include "xmss_commons.h"
19: #include "xmss_hash_address.h"
20: // For testing
21: #include "stdio.h"
22:
23:
24:
/**
 * Derive the n-byte secret seed of the WOTS key pair addressed by addr
 * (pseudorandom key generation).
 *
 * The chain, hash and key-and-mask fields of addr are reset to 0 in place,
 * so only the remaining address fields influence the result. The 32-byte
 * encoding of the address is then passed to prf() together with sk_seed,
 * and the n-byte output is written to seed.
 */
static void get_seed(unsigned char *seed, const unsigned char *sk_seed, int n, uint32_t addr[8])
{
	unsigned char addr_bytes[32];

	/* These three fields must not contribute to the seed. */
	setChainADRS(addr, 0);
	setHashADRS(addr, 0);
	setKeyAndMask(addr, 0);

	addr_to_byte(addr_bytes, addr);
	prf(seed, addr_bytes, sk_seed, n);
}
42:
43: /**
44: * Initialize xmss params struct
45: * parameter names are the same as in the draft
46: * parameter k is K as used in the BDS algorithm
47: */
48: int xmss_set_params(xmss_params *params, int n, int h, int w, int k)
49: {
50: if (k >= h || k < 2 || (h - k) % 2) {
51: fprintf(stderr, "For BDS traversal, H - K must be even, with H > K >= 2!\n");
52: return 1;
53: }
54: params->h = h;
55: params->n = n;
56: params->k = k;
57: wots_params wots_par;
58: wots_set_params(&wots_par, n, w);
59: params->wots_par = wots_par;
60: return 0;
61: }
62:
63: /**
64: * Initialize BDS state struct
65: * parameter names are the same as used in the description of the BDS traversal
66: */
67: void xmss_set_bds_state(bds_state *state, unsigned char *stack, int stackoffset, unsigned char *stacklevels, unsigned char *auth, unsigned char *keep, treehash_inst *treehash, unsigned char *retain, int next_leaf)
68: {
69: state->stack = stack;
70: state->stackoffset = stackoffset;
71: state->stacklevels = stacklevels;
72: state->auth = auth;
73: state->keep = keep;
74: state->treehash = treehash;
75: state->retain = retain;
76: state->next_leaf = next_leaf;
77: }
78:
79: /**
80: * Initialize xmssmt_params struct
81: * parameter names are the same as in the draft
82: *
83: * Especially h is the total tree height, i.e. the XMSS trees have height h/d
84: */
85: int xmssmt_set_params(xmssmt_params *params, int n, int h, int d, int w, int k)
86: {
87: if (h % d) {
88: fprintf(stderr, "d must divide h without remainder!\n");
89: return 1;
90: }
91: params->h = h;
92: params->d = d;
93: params->n = n;
94: params->index_len = (h + 7) / 8;
95: xmss_params xmss_par;
96: if (xmss_set_params(&xmss_par, n, (h/d), w, k)) {
97: return 1;
98: }
99: params->xmss_par = xmss_par;
100: return 0;
101: }
102:
103: /**
104: * Computes a leaf from a WOTS public key using an L-tree.
105: */
106: static void l_tree(unsigned char *leaf, unsigned char *wots_pk, const xmss_params *params, const unsigned char *pub_seed, uint32_t addr[8])
107: {
108: unsigned int l = params->wots_par.len;
109: unsigned int n = params->n;
110: uint32_t i = 0;
111: uint32_t height = 0;
112: uint32_t bound;
113:
114: //ADRS.setTreeHeight(0);
115: setTreeHeight(addr, height);
116:
117: while (l > 1) {
118: bound = l >> 1; //floor(l / 2);
119: for (i = 0; i < bound; i++) {
120: //ADRS.setTreeIndex(i);
121: setTreeIndex(addr, i);
122: //wots_pk[i] = RAND_HASH(pk[2i], pk[2i + 1], SEED, ADRS);
123: hash_h(wots_pk+i*n, wots_pk+i*2*n, pub_seed, addr, n);
124: }
125: //if ( l % 2 == 1 ) {
126: if (l & 1) {
127: //pk[floor(l / 2) + 1] = pk[l];
128: memcpy(wots_pk+(l>>1)*n, wots_pk+(l-1)*n, n);
129: //l = ceil(l / 2);
130: l=(l>>1)+1;
131: }
132: else {
133: //l = ceil(l / 2);
134: l=(l>>1);
135: }
136: //ADRS.setTreeHeight(ADRS.getTreeHeight() + 1);
137: height++;
138: setTreeHeight(addr, height);
139: }
140: //return pk[0];
141: memcpy(leaf, wots_pk, n);
142: }
143:
144: /**
145: * Computes the leaf at a given address. First generates the WOTS key pair, then computes leaf using l_tree. As this happens position independent, we only require that addr encodes the right ltree-address.
146: */
147: static void gen_leaf_wots(unsigned char *leaf, const unsigned char *sk_seed, const xmss_params *params, const unsigned char *pub_seed, uint32_t ltree_addr[8], uint32_t ots_addr[8])
148: {
149: unsigned char seed[params->n];
150: unsigned char pk[params->wots_par.keysize];
151:
152: get_seed(seed, sk_seed, params->n, ots_addr);
153: wots_pkgen(pk, seed, &(params->wots_par), pub_seed, ots_addr);
154:
155: l_tree(leaf, pk, params, pub_seed, ltree_addr);
156: }
157:
158: static int treehash_minheight_on_stack(bds_state* state, const xmss_params *params, const treehash_inst *treehash) {
159: unsigned int r = params->h, i;
160: for (i = 0; i < treehash->stackusage; i++) {
161: if (state->stacklevels[state->stackoffset - i - 1] < r) {
162: r = state->stacklevels[state->stackoffset - i - 1];
163: }
164: }
165: return r;
166: }
167:
168: /**
169: * Merkle's TreeHash algorithm. The address only needs to initialize the first 78 bits of addr. Everything else will be set by treehash.
170: * Currently only used for key generation.
171: *
172: */
173: static void treehash_setup(unsigned char *node, int height, int index, bds_state *state, const unsigned char *sk_seed, const xmss_params *params, const unsigned char *pub_seed, const uint32_t addr[8])
174: {
175: unsigned int idx = index;
176: unsigned int n = params->n;
177: unsigned int h = params->h;
178: unsigned int k = params->k;
179: // use three different addresses because at this point we use all three formats in parallel
180: uint32_t ots_addr[8];
181: uint32_t ltree_addr[8];
182: uint32_t node_addr[8];
183: // only copy layer and tree address parts
184: memcpy(ots_addr, addr, 12);
185: // type = ots
186: setType(ots_addr, 0);
187: memcpy(ltree_addr, addr, 12);
188: setType(ltree_addr, 1);
189: memcpy(node_addr, addr, 12);
190: setType(node_addr, 2);
191:
192: uint32_t lastnode, i;
193: unsigned char stack[(height+1)*n];
194: unsigned int stacklevels[height+1];
195: unsigned int stackoffset=0;
196: unsigned int nodeh;
197:
198: lastnode = idx+(1<<height);
199:
200: for (i = 0; i < h-k; i++) {
201: state->treehash[i].h = i;
202: state->treehash[i].completed = 1;
203: state->treehash[i].stackusage = 0;
204: }
205:
206: i = 0;
207: for (; idx < lastnode; idx++) {
208: setLtreeADRS(ltree_addr, idx);
209: setOTSADRS(ots_addr, idx);
210: gen_leaf_wots(stack+stackoffset*n, sk_seed, params, pub_seed, ltree_addr, ots_addr);
211: stacklevels[stackoffset] = 0;
212: stackoffset++;
213: if (h - k > 0 && i == 3) {
214: memcpy(state->treehash[0].node, stack+stackoffset*n, n);
215: }
216: while (stackoffset>1 && stacklevels[stackoffset-1] == stacklevels[stackoffset-2])
217: {
218: nodeh = stacklevels[stackoffset-1];
219: if (i >> nodeh == 1) {
220: memcpy(state->auth + nodeh*n, stack+(stackoffset-1)*n, n);
221: }
222: else {
223: if (nodeh < h - k && i >> nodeh == 3) {
224: memcpy(state->treehash[nodeh].node, stack+(stackoffset-1)*n, n);
225: }
226: else if (nodeh >= h - k) {
227: memcpy(state->retain + ((1 << (h - 1 - nodeh)) + nodeh - h + (((i >> nodeh) - 3) >> 1)) * n, stack+(stackoffset-1)*n, n);
228: }
229: }
230: setTreeHeight(node_addr, stacklevels[stackoffset-1]);
231: setTreeIndex(node_addr, (idx >> (stacklevels[stackoffset-1]+1)));
232: hash_h(stack+(stackoffset-2)*n, stack+(stackoffset-2)*n, pub_seed,
233: node_addr, n);
234: stacklevels[stackoffset-2]++;
235: stackoffset--;
236: }
237: i++;
238: }
239:
240: for (i = 0; i < n; i++)
241: node[i] = stack[i];
242: }
243:
244: static void treehash_update(treehash_inst *treehash, bds_state *state, const unsigned char *sk_seed, const xmss_params *params, const unsigned char *pub_seed, const uint32_t addr[8]) {
245: int n = params->n;
246:
247: uint32_t ots_addr[8];
248: uint32_t ltree_addr[8];
249: uint32_t node_addr[8];
250: // only copy layer and tree address parts
251: memcpy(ots_addr, addr, 12);
252: // type = ots
253: setType(ots_addr, 0);
254: memcpy(ltree_addr, addr, 12);
255: setType(ltree_addr, 1);
256: memcpy(node_addr, addr, 12);
257: setType(node_addr, 2);
258:
259: setLtreeADRS(ltree_addr, treehash->next_idx);
260: setOTSADRS(ots_addr, treehash->next_idx);
261:
262: unsigned char nodebuffer[2 * n];
263: unsigned int nodeheight = 0;
264: gen_leaf_wots(nodebuffer, sk_seed, params, pub_seed, ltree_addr, ots_addr);
265: while (treehash->stackusage > 0 && state->stacklevels[state->stackoffset-1] == nodeheight) {
266: memcpy(nodebuffer + n, nodebuffer, n);
267: memcpy(nodebuffer, state->stack + (state->stackoffset-1)*n, n);
268: setTreeHeight(node_addr, nodeheight);
269: setTreeIndex(node_addr, (treehash->next_idx >> (nodeheight+1)));
270: hash_h(nodebuffer, nodebuffer, pub_seed, node_addr, n);
271: nodeheight++;
272: treehash->stackusage--;
273: state->stackoffset--;
274: }
275: if (nodeheight == treehash->h) { // this also implies stackusage == 0
276: memcpy(treehash->node, nodebuffer, n);
277: treehash->completed = 1;
278: }
279: else {
280: memcpy(state->stack + state->stackoffset*n, nodebuffer, n);
281: treehash->stackusage++;
282: state->stacklevels[state->stackoffset] = nodeheight;
283: state->stackoffset++;
284: treehash->next_idx++;
285: }
286: }
287:
288: /**
289: * Computes a root node given a leaf and an authapth
290: */
291: static void validate_authpath(unsigned char *root, const unsigned char *leaf, unsigned long leafidx, const unsigned char *authpath, const xmss_params *params, const unsigned char *pub_seed, uint32_t addr[8])
292: {
293: unsigned int n = params->n;
294:
295: uint32_t i, j;
296: unsigned char buffer[2*n];
297:
298: // If leafidx is odd (last bit = 1), current path element is a right child and authpath has to go to the left.
299: // Otherwise, it is the other way around
300: if (leafidx & 1) {
301: for (j = 0; j < n; j++)
302: buffer[n+j] = leaf[j];
303: for (j = 0; j < n; j++)
304: buffer[j] = authpath[j];
305: }
306: else {
307: for (j = 0; j < n; j++)
308: buffer[j] = leaf[j];
309: for (j = 0; j < n; j++)
310: buffer[n+j] = authpath[j];
311: }
312: authpath += n;
313:
314: for (i=0; i < params->h-1; i++) {
315: setTreeHeight(addr, i);
316: leafidx >>= 1;
317: setTreeIndex(addr, leafidx);
318: if (leafidx&1) {
319: hash_h(buffer+n, buffer, pub_seed, addr, n);
320: for (j = 0; j < n; j++)
321: buffer[j] = authpath[j];
322: }
323: else {
324: hash_h(buffer, buffer, pub_seed, addr, n);
325: for (j = 0; j < n; j++)
326: buffer[j+n] = authpath[j];
327: }
328: authpath += n;
329: }
330: setTreeHeight(addr, (params->h-1));
331: leafidx >>= 1;
332: setTreeIndex(addr, leafidx);
333: hash_h(root, buffer, pub_seed, addr, n);
334: }
335:
336: /**
337: * Performs one treehash update on the instance that needs it the most.
338: * Returns 1 if such an instance was not found
339: **/
340: static char bds_treehash_update(bds_state *state, unsigned int updates, const unsigned char *sk_seed, const xmss_params *params, unsigned char *pub_seed, const uint32_t addr[8]) {
341: uint32_t i, j;
342: unsigned int level, l_min, low;
343: unsigned int h = params->h;
344: unsigned int k = params->k;
345: unsigned int used = 0;
346:
347: for (j = 0; j < updates; j++) {
348: l_min = h;
349: level = h - k;
350: for (i = 0; i < h - k; i++) {
351: if (state->treehash[i].completed) {
352: low = h;
353: }
354: else if (state->treehash[i].stackusage == 0) {
355: low = i;
356: }
357: else {
358: low = treehash_minheight_on_stack(state, params, &(state->treehash[i]));
359: }
360: if (low < l_min) {
361: level = i;
362: l_min = low;
363: }
364: }
365: if (level == h - k) {
366: break;
367: }
368: treehash_update(&(state->treehash[level]), state, sk_seed, params, pub_seed, addr);
369: used++;
370: }
371: return updates - used;
372: }
373:
374: /**
375: * Updates the state (typically NEXT_i) by adding a leaf and updating the stack
376: * Returns 1 if all leaf nodes have already been processed
377: **/
378: static char bds_state_update(bds_state *state, const unsigned char *sk_seed, const xmss_params *params, unsigned char *pub_seed, const uint32_t addr[8]) {
379: uint32_t ltree_addr[8];
380: uint32_t node_addr[8];
381: uint32_t ots_addr[8];
382:
383: int n = params->n;
384: int h = params->h;
385: int k = params->k;
386:
387: int nodeh;
388: int idx = state->next_leaf;
389: if (idx == 1 << h) {
390: return 1;
391: }
392:
393: // only copy layer and tree address parts
394: memcpy(ots_addr, addr, 12);
395: // type = ots
396: setType(ots_addr, 0);
397: memcpy(ltree_addr, addr, 12);
398: setType(ltree_addr, 1);
399: memcpy(node_addr, addr, 12);
400: setType(node_addr, 2);
401:
402: setOTSADRS(ots_addr, idx);
403: setLtreeADRS(ltree_addr, idx);
404:
405: gen_leaf_wots(state->stack+state->stackoffset*n, sk_seed, params, pub_seed, ltree_addr, ots_addr);
406:
407: state->stacklevels[state->stackoffset] = 0;
408: state->stackoffset++;
409: if (h - k > 0 && idx == 3) {
410: memcpy(state->treehash[0].node, state->stack+state->stackoffset*n, n);
411: }
412: while (state->stackoffset>1 && state->stacklevels[state->stackoffset-1] == state->stacklevels[state->stackoffset-2]) {
413: nodeh = state->stacklevels[state->stackoffset-1];
414: if (idx >> nodeh == 1) {
415: memcpy(state->auth + nodeh*n, state->stack+(state->stackoffset-1)*n, n);
416: }
417: else {
418: if (nodeh < h - k && idx >> nodeh == 3) {
419: memcpy(state->treehash[nodeh].node, state->stack+(state->stackoffset-1)*n, n);
420: }
421: else if (nodeh >= h - k) {
422: memcpy(state->retain + ((1 << (h - 1 - nodeh)) + nodeh - h + (((idx >> nodeh) - 3) >> 1)) * n, state->stack+(state->stackoffset-1)*n, n);
423: }
424: }
425: setTreeHeight(node_addr, state->stacklevels[state->stackoffset-1]);
426: setTreeIndex(node_addr, (idx >> (state->stacklevels[state->stackoffset-1]+1)));
427: hash_h(state->stack+(state->stackoffset-2)*n, state->stack+(state->stackoffset-2)*n, pub_seed, node_addr, n);
428:
429: state->stacklevels[state->stackoffset-2]++;
430: state->stackoffset--;
431: }
432: state->next_leaf++;
433: return 0;
434: }
435:
436: /**
437: * Returns the auth path for node leaf_idx and computes the auth path for the
438: * next leaf node, using the algorithm described by Buchmann, Dahmen and Szydlo
439: * in "Post Quantum Cryptography", Springer 2009.
440: */
441: static void bds_round(bds_state *state, const unsigned long leaf_idx, const unsigned char *sk_seed, const xmss_params *params, unsigned char *pub_seed, uint32_t addr[8])
442: {
443: unsigned int i;
444: unsigned int n = params->n;
445: unsigned int h = params->h;
446: unsigned int k = params->k;
447:
448: unsigned int tau = h;
449: unsigned int startidx;
450: unsigned int offset, rowidx;
451: unsigned char buf[2 * n];
452:
453: uint32_t ots_addr[8];
454: uint32_t ltree_addr[8];
455: uint32_t node_addr[8];
456: // only copy layer and tree address parts
457: memcpy(ots_addr, addr, 12);
458: // type = ots
459: setType(ots_addr, 0);
460: memcpy(ltree_addr, addr, 12);
461: setType(ltree_addr, 1);
462: memcpy(node_addr, addr, 12);
463: setType(node_addr, 2);
464:
465: for (i = 0; i < h; i++) {
466: if (! ((leaf_idx >> i) & 1)) {
467: tau = i;
468: break;
469: }
470: }
471:
472: if (tau > 0) {
473: memcpy(buf, state->auth + (tau-1) * n, n);
474: // we need to do this before refreshing state->keep to prevent overwriting
475: memcpy(buf + n, state->keep + ((tau-1) >> 1) * n, n);
476: }
477: if (!((leaf_idx >> (tau + 1)) & 1) && (tau < h - 1)) {
478: memcpy(state->keep + (tau >> 1)*n, state->auth + tau*n, n);
479: }
480: if (tau == 0) {
481: setLtreeADRS(ltree_addr, leaf_idx);
482: setOTSADRS(ots_addr, leaf_idx);
483: gen_leaf_wots(state->auth, sk_seed, params, pub_seed, ltree_addr, ots_addr);
484: }
485: else {
486: setTreeHeight(node_addr, (tau-1));
487: setTreeIndex(node_addr, leaf_idx >> tau);
488: hash_h(state->auth + tau * n, buf, pub_seed, node_addr, n);
489: for (i = 0; i < tau; i++) {
490: if (i < h - k) {
491: memcpy(state->auth + i * n, state->treehash[i].node, n);
492: }
493: else {
494: offset = (1 << (h - 1 - i)) + i - h;
495: rowidx = ((leaf_idx >> i) - 1) >> 1;
496: memcpy(state->auth + i * n, state->retain + (offset + rowidx) * n, n);
497: }
498: }
499:
500: for (i = 0; i < ((tau < h - k) ? tau : (h - k)); i++) {
501: startidx = leaf_idx + 1 + 3 * (1 << i);
502: if (startidx < 1U << h) {
503: state->treehash[i].h = i;
504: state->treehash[i].next_idx = startidx;
505: state->treehash[i].completed = 0;
506: state->treehash[i].stackusage = 0;
507: }
508: }
509: }
510: }
511:
512: /*
513: * Generates a XMSS key pair for a given parameter set.
514: * Format sk: [(32bit) idx || SK_SEED || SK_PRF || PUB_SEED || root]
515: * Format pk: [root || PUB_SEED] omitting algo oid.
516: */
517: int xmss_keypair(unsigned char *pk, unsigned char *sk, bds_state *state, xmss_params *params)
518: {
519: unsigned int n = params->n;
520: // Set idx = 0
521: sk[0] = 0;
522: sk[1] = 0;
523: sk[2] = 0;
524: sk[3] = 0;
525: // Init SK_SEED (n byte), SK_PRF (n byte), and PUB_SEED (n byte)
526: randombytes(sk+4, 3*n);
527: // Copy PUB_SEED to public key
528: memcpy(pk+n, sk+4+2*n, n);
529:
530: uint32_t addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
531:
532: // Compute root
533: treehash_setup(pk, params->h, 0, state, sk+4, params, sk+4+2*n, addr);
534: // copy root to sk
535: memcpy(sk+4+3*n, pk, n);
536: return 0;
537: }
538:
539: /**
540: * Signs a message.
541: * Returns
542: * 1. an array containing the signature followed by the message AND
543: * 2. an updated secret key!
544: *
545: */
546: int xmss_sign(unsigned char *sk, bds_state *state, unsigned char *sig_msg, unsigned long long *sig_msg_len, const unsigned char *msg, unsigned long long msglen, const xmss_params *params)
547: {
548: unsigned int h = params->h;
549: unsigned int n = params->n;
550: unsigned int k = params->k;
551: uint16_t i = 0;
552:
553: // Extract SK
554: unsigned long idx = ((unsigned long)sk[0] << 24) | ((unsigned long)sk[1] << 16) | ((unsigned long)sk[2] << 8) | sk[3];
555: unsigned char sk_seed[n];
556: memcpy(sk_seed, sk+4, n);
557: unsigned char sk_prf[n];
558: memcpy(sk_prf, sk+4+n, n);
559: unsigned char pub_seed[n];
560: memcpy(pub_seed, sk+4+2*n, n);
561:
562: // index as 32 bytes string
563: unsigned char idx_bytes_32[32];
564: to_byte(idx_bytes_32, idx, 32);
565:
566: unsigned char hash_key[3*n];
567:
568: // Update SK
569: sk[0] = ((idx + 1) >> 24) & 255;
570: sk[1] = ((idx + 1) >> 16) & 255;
571: sk[2] = ((idx + 1) >> 8) & 255;
572: sk[3] = (idx + 1) & 255;
573: // -- Secret key for this non-forward-secure version is now updated.
574: // -- A productive implementation should use a file handle instead and write the updated secret key at this point!
575:
576: // Init working params
577: unsigned char R[n];
578: unsigned char msg_h[n];
579: unsigned char ots_seed[n];
580: uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
581:
582: // ---------------------------------
583: // Message Hashing
584: // ---------------------------------
585:
586: // Message Hash:
587: // First compute pseudorandom value
588: prf(R, idx_bytes_32, sk_prf, n);
589: // Generate hash key (R || root || idx)
590: memcpy(hash_key, R, n);
591: memcpy(hash_key+n, sk+4+3*n, n);
592: to_byte(hash_key+2*n, idx, n);
593: // Then use it for message digest
594: h_msg(msg_h, msg, msglen, hash_key, 3*n, n);
595:
596: // Start collecting signature
597: *sig_msg_len = 0;
598:
599: // Copy index to signature
600: sig_msg[0] = (idx >> 24) & 255;
601: sig_msg[1] = (idx >> 16) & 255;
602: sig_msg[2] = (idx >> 8) & 255;
603: sig_msg[3] = idx & 255;
604:
605: sig_msg += 4;
606: *sig_msg_len += 4;
607:
608: // Copy R to signature
609: for (i = 0; i < n; i++)
610: sig_msg[i] = R[i];
611:
612: sig_msg += n;
613: *sig_msg_len += n;
614:
615: // ----------------------------------
616: // Now we start to "really sign"
617: // ----------------------------------
618:
619: // Prepare Address
620: setType(ots_addr, 0);
621: setOTSADRS(ots_addr, idx);
622:
623: // Compute seed for OTS key pair
624: get_seed(ots_seed, sk_seed, n, ots_addr);
625:
626: // Compute WOTS signature
627: wots_sign(sig_msg, msg_h, ots_seed, &(params->wots_par), pub_seed, ots_addr);
628:
629: sig_msg += params->wots_par.keysize;
630: *sig_msg_len += params->wots_par.keysize;
631:
632: // the auth path was already computed during the previous round
633: memcpy(sig_msg, state->auth, h*n);
634:
635: if (idx < (1U << h) - 1) {
636: bds_round(state, idx, sk_seed, params, pub_seed, ots_addr);
637: bds_treehash_update(state, (h - k) >> 1, sk_seed, params, pub_seed, ots_addr);
638: }
639:
640: /* TODO: save key/bds state here! */
641:
642: sig_msg += params->h*n;
643: *sig_msg_len += params->h*n;
644:
645: //Whipe secret elements?
646: //zerobytes(tsk, CRYPTO_SECRETKEYBYTES);
647:
648:
649: memcpy(sig_msg, msg, msglen);
650: *sig_msg_len += msglen;
651:
652: return 0;
653: }
654:
655: /**
656: * Verifies a given message signature pair under a given public key.
657: */
658: int xmss_sign_open(unsigned char *msg, unsigned long long *msglen, const unsigned char *sig_msg, unsigned long long sig_msg_len, const unsigned char *pk, const xmss_params *params)
659: {
660: unsigned int n = params->n;
661:
662: unsigned long long i, m_len;
663: unsigned long idx=0;
664: unsigned char wots_pk[params->wots_par.keysize];
665: unsigned char pkhash[n];
666: unsigned char root[n];
667: unsigned char msg_h[n];
668: unsigned char hash_key[3*n];
669:
670: unsigned char pub_seed[n];
671: memcpy(pub_seed, pk+n, n);
672:
673: // Init addresses
674: uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
675: uint32_t ltree_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
676: uint32_t node_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
677:
678: setType(ots_addr, 0);
679: setType(ltree_addr, 1);
680: setType(node_addr, 2);
681:
682: // Extract index
683: idx = ((unsigned long)sig_msg[0] << 24) | ((unsigned long)sig_msg[1] << 16) | ((unsigned long)sig_msg[2] << 8) | sig_msg[3];
684: printf("verify:: idx = %lu\n", idx);
685:
686: // Generate hash key (R || root || idx)
687: memcpy(hash_key, sig_msg+4,n);
688: memcpy(hash_key+n, pk, n);
689: to_byte(hash_key+2*n, idx, n);
690:
691: sig_msg += (n+4);
692: sig_msg_len -= (n+4);
693:
694: // hash message
695: unsigned long long tmp_sig_len = params->wots_par.keysize+params->h*n;
696: m_len = sig_msg_len - tmp_sig_len;
697: h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n);
698:
699: //-----------------------
700: // Verify signature
701: //-----------------------
702:
703: // Prepare Address
704: setOTSADRS(ots_addr, idx);
705: // Check WOTS signature
706: wots_pkFromSig(wots_pk, sig_msg, msg_h, &(params->wots_par), pub_seed, ots_addr);
707:
708: sig_msg += params->wots_par.keysize;
709: sig_msg_len -= params->wots_par.keysize;
710:
711: // Compute Ltree
712: setLtreeADRS(ltree_addr, idx);
713: l_tree(pkhash, wots_pk, params, pub_seed, ltree_addr);
714:
715: // Compute root
716: validate_authpath(root, pkhash, idx, sig_msg, params, pub_seed, node_addr);
717:
718: sig_msg += params->h*n;
719: sig_msg_len -= params->h*n;
720:
721: for (i = 0; i < n; i++)
722: if (root[i] != pk[i])
723: goto fail;
724:
725: *msglen = sig_msg_len;
726: for (i = 0; i < *msglen; i++)
727: msg[i] = sig_msg[i];
728:
729: return 0;
730:
731:
732: fail:
733: *msglen = sig_msg_len;
734: for (i = 0; i < *msglen; i++)
735: msg[i] = 0;
736: *msglen = -1;
737: return -1;
738: }
739:
740: /*
741: * Generates a XMSSMT key pair for a given parameter set.
742: * Format sk: [(ceil(h/8) bit) idx || SK_SEED || SK_PRF || PUB_SEED || root]
743: * Format pk: [root || PUB_SEED] omitting algo oid.
744: */
745: int xmssmt_keypair(unsigned char *pk, unsigned char *sk, bds_state *states, unsigned char *wots_sigs, xmssmt_params *params)
746: {
747: unsigned int n = params->n;
748: unsigned int i;
749: unsigned char ots_seed[params->n];
750: // Set idx = 0
751: for (i = 0; i < params->index_len; i++) {
752: sk[i] = 0;
753: }
754: // Init SK_SEED (n byte), SK_PRF (n byte), and PUB_SEED (n byte)
755: randombytes(sk+params->index_len, 3*n);
756: // Copy PUB_SEED to public key
757: memcpy(pk+n, sk+params->index_len+2*n, n);
758:
759: // Set address to point on the single tree on layer d-1
760: uint32_t addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
761: setLayerADRS(addr, (params->d-1));
762: // Set up state and compute wots signatures for all but topmost tree root
763: for (i = 0; i < params->d - 1; i++) {
764: // Compute seed for OTS key pair
765: treehash_setup(pk, params->xmss_par.h, 0, states + i, sk+params->index_len, &(params->xmss_par), pk+n, addr);
766: setLayerADRS(addr, (i+1));
767: get_seed(ots_seed, sk+params->index_len, n, addr);
768: wots_sign(wots_sigs + i*params->xmss_par.wots_par.keysize, pk, ots_seed, &(params->xmss_par.wots_par), pk+n, addr);
769: }
770: treehash_setup(pk, params->xmss_par.h, 0, states + i, sk+params->index_len, &(params->xmss_par), pk+n, addr);
771: memcpy(sk+params->index_len+3*n, pk, n);
772: return 0;
773: }
774:
775: /**
776: * Signs a message.
777: * Returns
778: * 1. an array containing the signature followed by the message AND
779: * 2. an updated secret key!
780: *
781: */
782: int xmssmt_sign(unsigned char *sk, bds_state *states, unsigned char *wots_sigs, unsigned char *sig_msg, unsigned long long *sig_msg_len, const unsigned char *msg, unsigned long long msglen, const xmssmt_params *params)
783: {
784: unsigned int n = params->n;
785:
786: unsigned int tree_h = params->xmss_par.h;
787: unsigned int h = params->h;
788: unsigned int k = params->xmss_par.k;
789: unsigned int idx_len = params->index_len;
790: uint64_t idx_tree;
791: uint32_t idx_leaf;
792: uint64_t i, j;
793: int needswap_upto = -1;
794: unsigned int updates;
795:
796: unsigned char sk_seed[n];
797: unsigned char sk_prf[n];
798: unsigned char pub_seed[n];
799: // Init working params
800: unsigned char R[n];
801: unsigned char msg_h[n];
802: unsigned char hash_key[3*n];
803: unsigned char ots_seed[n];
804: uint32_t addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
805: uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
806: unsigned char idx_bytes_32[32];
807: bds_state tmp;
808:
809: // Extract SK
810: unsigned long long idx = 0;
811: for (i = 0; i < idx_len; i++) {
812: idx |= ((unsigned long long)sk[i]) << 8*(idx_len - 1 - i);
813: }
814:
815: memcpy(sk_seed, sk+idx_len, n);
816: memcpy(sk_prf, sk+idx_len+n, n);
817: memcpy(pub_seed, sk+idx_len+2*n, n);
818:
819: // Update SK
820: for (i = 0; i < idx_len; i++) {
821: sk[i] = ((idx + 1) >> 8*(idx_len - 1 - i)) & 255;
822: }
823: // -- Secret key for this non-forward-secure version is now updated.
824: // -- A productive implementation should use a file handle instead and write the updated secret key at this point!
825:
826:
827: // ---------------------------------
828: // Message Hashing
829: // ---------------------------------
830:
831: // Message Hash:
832: // First compute pseudorandom value
833: to_byte(idx_bytes_32, idx, 32);
834: prf(R, idx_bytes_32, sk_prf, n);
835: // Generate hash key (R || root || idx)
836: memcpy(hash_key, R, n);
837: memcpy(hash_key+n, sk+idx_len+3*n, n);
838: to_byte(hash_key+2*n, idx, n);
839:
840: // Then use it for message digest
841: h_msg(msg_h, msg, msglen, hash_key, 3*n, n);
842:
843: // Start collecting signature
844: *sig_msg_len = 0;
845:
846: // Copy index to signature
847: for (i = 0; i < idx_len; i++) {
848: sig_msg[i] = (idx >> 8*(idx_len - 1 - i)) & 255;
849: }
850:
851: sig_msg += idx_len;
852: *sig_msg_len += idx_len;
853:
854: // Copy R to signature
855: for (i = 0; i < n; i++)
856: sig_msg[i] = R[i];
857:
858: sig_msg += n;
859: *sig_msg_len += n;
860:
861: // ----------------------------------
862: // Now we start to "really sign"
863: // ----------------------------------
864:
865: // Handle lowest layer separately as it is slightly different...
866:
867: // Prepare Address
868: setType(ots_addr, 0);
869: idx_tree = idx >> tree_h;
870: idx_leaf = (idx & ((1 << tree_h)-1));
871: setLayerADRS(ots_addr, 0);
872: setTreeADRS(ots_addr, idx_tree);
873: setOTSADRS(ots_addr, idx_leaf);
874:
875: // Compute seed for OTS key pair
876: get_seed(ots_seed, sk_seed, n, ots_addr);
877:
878: // Compute WOTS signature
879: wots_sign(sig_msg, msg_h, ots_seed, &(params->xmss_par.wots_par), pub_seed, ots_addr);
880:
881: sig_msg += params->xmss_par.wots_par.keysize;
882: *sig_msg_len += params->xmss_par.wots_par.keysize;
883:
884: memcpy(sig_msg, states[0].auth, tree_h*n);
885: sig_msg += tree_h*n;
886: *sig_msg_len += tree_h*n;
887:
888: // prepare signature of remaining layers
889: for (i = 1; i < params->d; i++) {
890: // put WOTS signature in place
891: memcpy(sig_msg, wots_sigs + (i-1)*params->xmss_par.wots_par.keysize, params->xmss_par.wots_par.keysize);
892:
893: sig_msg += params->xmss_par.wots_par.keysize;
894: *sig_msg_len += params->xmss_par.wots_par.keysize;
895:
896: // put AUTH nodes in place
897: memcpy(sig_msg, states[i].auth, tree_h*n);
898: sig_msg += tree_h*n;
899: *sig_msg_len += tree_h*n;
900: }
901:
902: updates = (tree_h - k) >> 1;
903:
904: setTreeADRS(addr, (idx_tree + 1));
905: // mandatory update for NEXT_0 (does not count towards h-k/2) if NEXT_0 exists
906: if ((1 + idx_tree) * (1 << tree_h) + idx_leaf < (1ULL << h)) {
907: bds_state_update(&states[params->d], sk_seed, &(params->xmss_par), pub_seed, addr);
908: }
909:
910: for (i = 0; i < params->d; i++) {
911: // check if we're not at the end of a tree
912: if (! (((idx + 1) & ((1ULL << ((i+1)*tree_h)) - 1)) == 0)) {
913: idx_leaf = (idx >> (tree_h * i)) & ((1 << tree_h)-1);
914: idx_tree = (idx >> (tree_h * (i+1)));
915: setLayerADRS(addr, i);
916: setTreeADRS(addr, idx_tree);
917: if (i == (unsigned int) (needswap_upto + 1)) {
918: bds_round(&states[i], idx_leaf, sk_seed, &(params->xmss_par), pub_seed, addr);
919: }
920: updates = bds_treehash_update(&states[i], updates, sk_seed, &(params->xmss_par), pub_seed, addr);
921: setTreeADRS(addr, (idx_tree + 1));
922: // if a NEXT-tree exists for this level;
923: if ((1 + idx_tree) * (1 << tree_h) + idx_leaf < (1ULL << (h - tree_h * i))) {
924: if (i > 0 && updates > 0 && states[params->d + i].next_leaf < (1ULL << h)) {
925: bds_state_update(&states[params->d + i], sk_seed, &(params->xmss_par), pub_seed, addr);
926: updates--;
927: }
928: }
929: }
930: else if (idx < (1ULL << h) - 1) {
931: memcpy(&tmp, states+params->d + i, sizeof(bds_state));
932: memcpy(states+params->d + i, states + i, sizeof(bds_state));
933: memcpy(states + i, &tmp, sizeof(bds_state));
934:
935: setLayerADRS(ots_addr, (i+1));
936: setTreeADRS(ots_addr, ((idx + 1) >> ((i+2) * tree_h)));
937: setOTSADRS(ots_addr, (((idx >> ((i+1) * tree_h)) + 1) & ((1 << tree_h)-1)));
938:
939: get_seed(ots_seed, sk+params->index_len, n, ots_addr);
940: wots_sign(wots_sigs + i*params->xmss_par.wots_par.keysize, states[i].stack, ots_seed, &(params->xmss_par.wots_par), pub_seed, ots_addr);
941:
942: states[params->d + i].stackoffset = 0;
943: states[params->d + i].next_leaf = 0;
944:
945: updates--; // WOTS-signing counts as one update
946: needswap_upto = i;
947: for (j = 0; j < tree_h-k; j++) {
948: states[i].treehash[j].completed = 1;
949: }
950: }
951: }
952:
953: //Whipe secret elements?
954: //zerobytes(tsk, CRYPTO_SECRETKEYBYTES);
955:
956: memcpy(sig_msg, msg, msglen);
957: *sig_msg_len += msglen;
958:
959: return 0;
960: }
961:
962: /**
963: * Verifies a given message signature pair under a given public key.
964: */
965: int xmssmt_sign_open(unsigned char *msg, unsigned long long *msglen, const unsigned char *sig_msg, unsigned long long sig_msg_len, const unsigned char *pk, const xmssmt_params *params)
966: {
967: unsigned int n = params->n;
968:
969: unsigned int tree_h = params->xmss_par.h;
970: unsigned int idx_len = params->index_len;
971: uint64_t idx_tree;
972: uint32_t idx_leaf;
973:
974: unsigned long long i, m_len;
975: unsigned long long idx=0;
976: unsigned char wots_pk[params->xmss_par.wots_par.keysize];
977: unsigned char pkhash[n];
978: unsigned char root[n];
979: unsigned char msg_h[n];
980: unsigned char hash_key[3*n];
981:
982: unsigned char pub_seed[n];
983: memcpy(pub_seed, pk+n, n);
984:
985: // Init addresses
986: uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
987: uint32_t ltree_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
988: uint32_t node_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
989:
990: // Extract index
991: for (i = 0; i < idx_len; i++) {
992: idx |= ((unsigned long long)sig_msg[i]) << (8*(idx_len - 1 - i));
993: }
994: printf("verify:: idx = %llu\n", idx);
995: sig_msg += idx_len;
996: sig_msg_len -= idx_len;
997:
998: // Generate hash key (R || root || idx)
999: memcpy(hash_key, sig_msg,n);
1000: memcpy(hash_key+n, pk, n);
1001: to_byte(hash_key+2*n, idx, n);
1002:
1003: sig_msg += n;
1004: sig_msg_len -= n;
1005:
1006:
1007: // hash message (recall, R is now on pole position at sig_msg
1008: unsigned long long tmp_sig_len = (params->d * params->xmss_par.wots_par.keysize) + (params->h * n);
1009: m_len = sig_msg_len - tmp_sig_len;
1010: h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n);
1011:
1012:
1013: //-----------------------
1014: // Verify signature
1015: //-----------------------
1016:
1017: // Prepare Address
1018: idx_tree = idx >> tree_h;
1019: idx_leaf = (idx & ((1 << tree_h)-1));
1020: setLayerADRS(ots_addr, 0);
1021: setTreeADRS(ots_addr, idx_tree);
1022: setType(ots_addr, 0);
1023:
1024: memcpy(ltree_addr, ots_addr, 12);
1025: setType(ltree_addr, 1);
1026:
1027: memcpy(node_addr, ltree_addr, 12);
1028: setType(node_addr, 2);
1029:
1030: setOTSADRS(ots_addr, idx_leaf);
1031:
1032: // Check WOTS signature
1033: wots_pkFromSig(wots_pk, sig_msg, msg_h, &(params->xmss_par.wots_par), pub_seed, ots_addr);
1034:
1035: sig_msg += params->xmss_par.wots_par.keysize;
1036: sig_msg_len -= params->xmss_par.wots_par.keysize;
1037:
1038: // Compute Ltree
1039: setLtreeADRS(ltree_addr, idx_leaf);
1040: l_tree(pkhash, wots_pk, &(params->xmss_par), pub_seed, ltree_addr);
1041:
1042: // Compute root
1043: validate_authpath(root, pkhash, idx_leaf, sig_msg, &(params->xmss_par), pub_seed, node_addr);
1044:
1045: sig_msg += tree_h*n;
1046: sig_msg_len -= tree_h*n;
1047:
1048: for (i = 1; i < params->d; i++) {
1049: // Prepare Address
1050: idx_leaf = (idx_tree & ((1 << tree_h)-1));
1051: idx_tree = idx_tree >> tree_h;
1052:
1053: setLayerADRS(ots_addr, i);
1054: setTreeADRS(ots_addr, idx_tree);
1055: setType(ots_addr, 0);
1056:
1057: memcpy(ltree_addr, ots_addr, 12);
1058: setType(ltree_addr, 1);
1059:
1060: memcpy(node_addr, ltree_addr, 12);
1061: setType(node_addr, 2);
1062:
1063: setOTSADRS(ots_addr, idx_leaf);
1064:
1065: // Check WOTS signature
1066: wots_pkFromSig(wots_pk, sig_msg, root, &(params->xmss_par.wots_par), pub_seed, ots_addr);
1067:
1068: sig_msg += params->xmss_par.wots_par.keysize;
1069: sig_msg_len -= params->xmss_par.wots_par.keysize;
1070:
1071: // Compute Ltree
1072: setLtreeADRS(ltree_addr, idx_leaf);
1073: l_tree(pkhash, wots_pk, &(params->xmss_par), pub_seed, ltree_addr);
1074:
1075: // Compute root
1076: validate_authpath(root, pkhash, idx_leaf, sig_msg, &(params->xmss_par), pub_seed, node_addr);
1077:
1078: sig_msg += tree_h*n;
1079: sig_msg_len -= tree_h*n;
1080:
1081: }
1082:
1083: for (i = 0; i < n; i++)
1084: if (root[i] != pk[i])
1085: goto fail;
1086:
1087: *msglen = sig_msg_len;
1088: for (i = 0; i < *msglen; i++)
1089: msg[i] = sig_msg[i];
1090:
1091: return 0;
1092:
1093:
1094: fail:
1095: *msglen = sig_msg_len;
1096: for (i = 0; i < *msglen; i++)
1097: msg[i] = 0;
1098: *msglen = -1;
1099: return -1;
1100: }