[BACK]Return to xmss_wots.c CVS log [TXT][DIR] Up to [local] / src / usr.bin / ssh

Annotation of src/usr.bin/ssh/xmss_wots.c, Revision 1.2

1.2     ! dtucker     1: /* $OpenBSD$ */
1.1       markus      2: /*
                      3: wots.c version 20160722
                      4: Andreas Hülsing
                      5: Joost Rijneveld
                      6: Public domain.
                      7: */
                      8:
                      9: #include <stdlib.h>
                     10: #include <stdint.h>
                     11: #include <limits.h>
                     12: #include "xmss_commons.h"
                     13: #include "xmss_hash.h"
                     14: #include "xmss_wots.h"
                     15: #include "xmss_hash_address.h"
                     16:
                     17:
                     18: /* libm-free version of log2() for wots */
                     19: static inline int
                     20: wots_log2(uint32_t v)
                     21: {
                     22:   int      b;
                     23:
                     24:   for (b = sizeof (v) * CHAR_BIT - 1; b >= 0; b--) {
                     25:     if ((1U << b) & v) {
                     26:       return b;
                     27:     }
                     28:   }
                     29:   return 0;
                     30: }
                     31:
                     32: void
                     33: wots_set_params(wots_params *params, int n, int w)
                     34: {
                     35:   params->n = n;
                     36:   params->w = w;
                     37:   params->log_w = wots_log2(params->w);
                     38:   params->len_1 = (CHAR_BIT * n) / params->log_w;
                     39:   params->len_2 = (wots_log2(params->len_1 * (w - 1)) / params->log_w) + 1;
                     40:   params->len = params->len_1 + params->len_2;
                     41:   params->keysize = params->len * params->n;
                     42: }
                     43:
                     44: /**
                     45:  * Helper method for pseudorandom key generation
                     46:  * Expands an n-byte array into a len*n byte array
                     47:  * this is done using PRF
                     48:  */
                     49: static void expand_seed(unsigned char *outseeds, const unsigned char *inseed, const wots_params *params)
                     50: {
                     51:   uint32_t i = 0;
                     52:   unsigned char ctr[32];
                     53:   for(i = 0; i < params->len; i++){
                     54:     to_byte(ctr, i, 32);
                     55:     prf((outseeds + (i*params->n)), ctr, inseed, params->n);
                     56:   }
                     57: }
                     58:
                     59: /**
                     60:  * Computes the chaining function.
                     61:  * out and in have to be n-byte arrays
                     62:  *
                     63:  * interpretes in as start-th value of the chain
                     64:  * addr has to contain the address of the chain
                     65:  */
                     66: static void gen_chain(unsigned char *out, const unsigned char *in, unsigned int start, unsigned int steps, const wots_params *params, const unsigned char *pub_seed, uint32_t addr[8])
                     67: {
                     68:   uint32_t i, j;
                     69:   for (j = 0; j < params->n; j++)
                     70:     out[j] = in[j];
                     71:
                     72:   for (i = start; i < (start+steps) && i < params->w; i++) {
                     73:     setHashADRS(addr, i);
                     74:     hash_f(out, out, pub_seed, addr, params->n);
                     75:   }
                     76: }
                     77:
                     78: /**
                     79:  * base_w algorithm as described in draft.
                     80:  *
                     81:  *
                     82:  */
                     83: static void base_w(int *output, const int out_len, const unsigned char *input, const wots_params *params)
                     84: {
                     85:   int in = 0;
                     86:   int out = 0;
                     87:   uint32_t total = 0;
                     88:   int bits = 0;
                     89:   int consumed = 0;
                     90:
                     91:   for (consumed = 0; consumed < out_len; consumed++) {
                     92:     if (bits == 0) {
                     93:       total = input[in];
                     94:       in++;
                     95:       bits += 8;
                     96:     }
                     97:     bits -= params->log_w;
                     98:     output[out] = (total >> bits) & (params->w - 1);
                     99:     out++;
                    100:   }
                    101: }
                    102:
                    103: void wots_pkgen(unsigned char *pk, const unsigned char *sk, const wots_params *params, const unsigned char *pub_seed, uint32_t addr[8])
                    104: {
                    105:   uint32_t i;
                    106:   expand_seed(pk, sk, params);
                    107:   for (i=0; i < params->len; i++) {
                    108:     setChainADRS(addr, i);
                    109:     gen_chain(pk+i*params->n, pk+i*params->n, 0, params->w-1, params, pub_seed, addr);
                    110:   }
                    111: }
                    112:
                    113:
                    114: int wots_sign(unsigned char *sig, const unsigned char *msg, const unsigned char *sk, const wots_params *params, const unsigned char *pub_seed, uint32_t addr[8])
                    115: {
                    116:   //int basew[params->len];
                    117:   int csum = 0;
                    118:   uint32_t i = 0;
                    119:   int *basew = calloc(params->len, sizeof(int));
                    120:   if (basew == NULL)
                    121:     return -1;
                    122:
                    123:   base_w(basew, params->len_1, msg, params);
                    124:
                    125:   for (i=0; i < params->len_1; i++) {
                    126:     csum += params->w - 1 - basew[i];
                    127:   }
                    128:
                    129:   csum = csum << (8 - ((params->len_2 * params->log_w) % 8));
                    130:
                    131:   int len_2_bytes = ((params->len_2 * params->log_w) + 7) / 8;
                    132:
                    133:   unsigned char csum_bytes[len_2_bytes];
                    134:   to_byte(csum_bytes, csum, len_2_bytes);
                    135:
                    136:   int csum_basew[params->len_2];
                    137:   base_w(csum_basew, params->len_2, csum_bytes, params);
                    138:
                    139:   for (i = 0; i < params->len_2; i++) {
                    140:     basew[params->len_1 + i] = csum_basew[i];
                    141:   }
                    142:
                    143:   expand_seed(sig, sk, params);
                    144:
                    145:   for (i = 0; i < params->len; i++) {
                    146:     setChainADRS(addr, i);
                    147:     gen_chain(sig+i*params->n, sig+i*params->n, 0, basew[i], params, pub_seed, addr);
                    148:   }
                    149:   free(basew);
                    150:   return 0;
                    151: }
                    152:
                    153: int wots_pkFromSig(unsigned char *pk, const unsigned char *sig, const unsigned char *msg, const wots_params *params, const unsigned char *pub_seed, uint32_t addr[8])
                    154: {
                    155:   int csum = 0;
                    156:   uint32_t i = 0;
                    157:   int *basew = calloc(params->len, sizeof(int));
                    158:   if (basew == NULL)
                    159:     return -1;
                    160:
                    161:   base_w(basew, params->len_1, msg, params);
                    162:
                    163:   for (i=0; i < params->len_1; i++) {
                    164:     csum += params->w - 1 - basew[i];
                    165:   }
                    166:
                    167:   csum = csum << (8 - ((params->len_2 * params->log_w) % 8));
                    168:
                    169:   int len_2_bytes = ((params->len_2 * params->log_w) + 7) / 8;
                    170:
                    171:   unsigned char csum_bytes[len_2_bytes];
                    172:   to_byte(csum_bytes, csum, len_2_bytes);
                    173:
                    174:   int csum_basew[params->len_2];
                    175:   base_w(csum_basew, params->len_2, csum_bytes, params);
                    176:
                    177:   for (i = 0; i < params->len_2; i++) {
                    178:     basew[params->len_1 + i] = csum_basew[i];
                    179:   }
                    180:   for (i=0; i < params->len; i++) {
                    181:     setChainADRS(addr, i);
                    182:     gen_chain(pk+i*params->n, sig+i*params->n, basew[i], params->w-1-basew[i], params, pub_seed, addr);
                    183:   }
                    184:   free(basew);
                    185:   return 0;
                    186: }