[BACK]Return to sntrup4591761.c CVS log [TXT][DIR] Up to [local] / src / usr.bin / ssh

Annotation of src/usr.bin/ssh/sntrup4591761.c, Revision 1.2

1.1       djm         1: #include <string.h>
                      2: #include "crypto_api.h"
                      3:
1.2     ! djm         4: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/int32_sort.h */
        !             5: #ifndef int32_sort_h
        !             6: #define int32_sort_h
1.1       djm         7:
                      8:
1.2     ! djm         9: static void int32_sort(crypto_int32 *,int);
1.1       djm        10:
1.2     ! djm        11: #endif
        !            12:
/* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/int32_sort.c */
/* See https://ntruprime.cr.yp.to/software.html for detailed documentation. */


/*
 * Compare-exchange of two signed 32-bit values: afterwards *x holds the
 * smaller and *y the larger of the two inputs. The exchange is done with an
 * XOR mask instead of a branch, so control flow does not depend on the data.
 */
static void minmax(crypto_int32 *x,crypto_int32 *y)
{
  crypto_uint32 xi = *x;
  crypto_uint32 yi = *y;
  crypto_uint32 xy = xi ^ yi;
  crypto_uint32 c = yi - xi;
  /* when xi and yi differ in sign, the sign of yi - xi may be wrong
     (overflow); in that case take the sign bit from yi instead */
  c ^= xy & (c ^ yi);
  c >>= 31;   /* 1 iff y < x as signed integers, else 0 */
  c = -c;     /* all-ones mask iff the pair must be exchanged */
  c &= xy;    /* with this mask: xi ^ c == yi and yi ^ c == xi */
  /* NOTE: storing a crypto_uint32 above INT32_MAX back into crypto_int32
     is an implementation-defined conversion (two's complement assumed) */
  *x = xi ^ c;
  *y = yi ^ c;
}
        !            30:
/*
 * Sort x[0..n-1] into ascending order with a sorting network built from
 * minmax() compare-exchanges. Which index pairs are compared depends only
 * on n, never on the values stored in x. Does nothing when n < 2.
 *
 * NOTE: p and q here are local loop variables; the params.h macros with
 * the same names are only defined further down in this file.
 */
static void int32_sort(crypto_int32 *x,int n)
{
  int top,p,q,i;

  if (n < 2) return;
  /* top = largest power of two strictly less than n */
  top = 1;
  while (top < n - top) top += top;

  for (p = top;p > 0;p >>= 1) {
    /* compare elements p apart, within each aligned block of 2p */
    for (i = 0;i < n - p;++i)
      if (!(i & p))
        minmax(x + i,x + i + p);
    /* further merge passes at the larger strides q = top, top/2, ..., 2p */
    for (q = top;q > p;q >>= 1)
      for (i = 0;i < n - q;++i)
        if (!(i & p))
          minmax(x + i + p,x + i + q);
  }
}
                     49:
1.2     ! djm        50: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/small.h */
1.1       djm        51: #ifndef small_h
                     52: #define small_h
                     53:
                     54:
                     55: typedef crypto_int8 small;
                     56:
                     57: static void small_encode(unsigned char *,const small *);
                     58:
                     59: static void small_decode(small *,const unsigned char *);
                     60:
                     61:
                     62: static void small_random(small *);
                     63:
                     64: static void small_random_weightw(small *);
                     65:
                     66: #endif
                     67:
/* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/mod3.h */
#ifndef mod3_h
#define mod3_h


/* -1 if x is nonzero, 0 otherwise */
/* (valid for x in {-1,0,1}, where x*x is 0 or 1) */
static inline int mod3_nonzero_mask(small x)
{
  return -x*x;
}

/* input between -100000 and 100000 */
/* output between -1 and 1 */
/* Branch-free reduction mod 3: two rounds of subtracting 3 times a */
/* fixed-point estimate of a/3. */
/* NOTE: relies on >> of a negative value being an arithmetic shift */
static inline small mod3_freeze(crypto_int32 a)
{
  a -= 3 * ((10923 * a) >> 15);                 /* 10923/2^15 ~ 1/3 */
  a -= 3 * ((89478485 * a + 134217728) >> 28);  /* 89478485/2^28 ~ 1/3; +2^27 rounds */
  return a;
}
                     87:
                     88: static inline small mod3_minusproduct(small a,small b,small c)
                     89: {
                     90:   crypto_int32 A = a;
                     91:   crypto_int32 B = b;
                     92:   crypto_int32 C = c;
                     93:   return mod3_freeze(A - B * C);
                     94: }
                     95:
                     96: static inline small mod3_plusproduct(small a,small b,small c)
                     97: {
                     98:   crypto_int32 A = a;
                     99:   crypto_int32 B = b;
                    100:   crypto_int32 C = c;
                    101:   return mod3_freeze(A + B * C);
                    102: }
                    103:
                    104: static inline small mod3_product(small a,small b)
                    105: {
                    106:   return a * b;
                    107: }
                    108:
                    109: static inline small mod3_sum(small a,small b)
                    110: {
                    111:   crypto_int32 A = a;
                    112:   crypto_int32 B = b;
                    113:   return mod3_freeze(A + B);
                    114: }
                    115:
                    116: static inline small mod3_reciprocal(small a1)
                    117: {
                    118:   return a1;
                    119: }
                    120:
                    121: static inline small mod3_quotient(small num,small den)
                    122: {
                    123:   return mod3_product(num,mod3_reciprocal(den));
                    124: }
                    125:
                    126: #endif
                    127:
/* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/modq.h */
#ifndef modq_h
#define modq_h


/* An element of Z/4591, held as a signed representative (see modq_freeze). */
typedef crypto_int16 modq;

/* -1 if x is nonzero, 0 otherwise */
/* x is reinterpreted as 0..65535 and negated; the right shift then yields */
/* -1 for any negative value and 0 for zero. */
/* NOTE: relies on >> of a negative value being an arithmetic shift */
static inline int modq_nonzero_mask(modq x)
{
  crypto_int32 r = (crypto_uint16) x;
  r = -r;
  r >>= 30;
  return r;
}

/* input between -9000000 and 9000000 */
/* output between -2295 and 2295 */
/* Branch-free reduction mod 4591: two rounds of subtracting 4591 times a */
/* fixed-point estimate of a/4591. */
static inline modq modq_freeze(crypto_int32 a)
{
  a -= 4591 * ((228 * a) >> 20);                 /* 228/2^20 ~ 1/4591 */
  a -= 4591 * ((58470 * a + 134217728) >> 28);   /* 58470/2^28 ~ 1/4591; +2^27 rounds */
  return a;
}
                    152:
                    153: static inline modq modq_minusproduct(modq a,modq b,modq c)
                    154: {
                    155:   crypto_int32 A = a;
                    156:   crypto_int32 B = b;
                    157:   crypto_int32 C = c;
                    158:   return modq_freeze(A - B * C);
                    159: }
                    160:
                    161: static inline modq modq_plusproduct(modq a,modq b,modq c)
                    162: {
                    163:   crypto_int32 A = a;
                    164:   crypto_int32 B = b;
                    165:   crypto_int32 C = c;
                    166:   return modq_freeze(A + B * C);
                    167: }
                    168:
                    169: static inline modq modq_product(modq a,modq b)
                    170: {
                    171:   crypto_int32 A = a;
                    172:   crypto_int32 B = b;
                    173:   return modq_freeze(A * B);
                    174: }
                    175:
                    176: static inline modq modq_square(modq a)
                    177: {
                    178:   crypto_int32 A = a;
                    179:   return modq_freeze(A * A);
                    180: }
                    181:
                    182: static inline modq modq_sum(modq a,modq b)
                    183: {
                    184:   crypto_int32 A = a;
                    185:   crypto_int32 B = b;
                    186:   return modq_freeze(A + B);
                    187: }
                    188:
/*
 * Inverse of a1 mod 4591, computed as a1^(4591-2) = a1^4589 (Fermat's little
 * theorem; 4591 is prime) with a fixed addition chain, so the sequence of
 * operations does not depend on a1. Returns 0 when a1 == 0.
 */
static inline modq modq_reciprocal(modq a1)
{
  modq a2 = modq_square(a1);
  modq a3 = modq_product(a2,a1);
  modq a4 = modq_square(a2);
  modq a8 = modq_square(a4);
  modq a16 = modq_square(a8);
  modq a32 = modq_square(a16);
  modq a35 = modq_product(a32,a3);
  modq a70 = modq_square(a35);
  modq a140 = modq_square(a70);
  modq a143 = modq_product(a140,a3);
  modq a286 = modq_square(a143);
  modq a572 = modq_square(a286);
  modq a1144 = modq_square(a572);
  modq a1147 = modq_product(a1144,a3);
  modq a2294 = modq_square(a1147);
  modq a4588 = modq_square(a2294);
  modq a4589 = modq_product(a4588,a1);
  return a4589;
}

/* num / den mod 4591 (den == 0 yields 0) */
static inline modq modq_quotient(modq num,modq den)
{
  return modq_product(num,modq_reciprocal(den));
}

#endif
                    217:
/* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/params.h */
#ifndef params_h
#define params_h

/*
 * NOTE: these are object-like macros with one-letter names, so every later
 * occurrence of the identifiers q, p and w in this file is replaced by the
 * numbers below (they cannot be used as local names after this point).
 * They must stay macros: the #if checks ahead of
 * crypto_kem_sntrup4591761_keypair compare against rq_encode_len and
 * small_encode_len in the preprocessor.
 */

/* modulus for Rq coefficients */
#define q 4591
/* XXX: also built into modq in various ways */

/* (q-1)/2; rq_encode adds it to map coefficients into 0..q-1 */
#define qshift 2295
/* number of coefficients per polynomial (polynomial arrays are sized p) */
#define p 761
/* required number of nonzero coefficients of weight-w polynomials */
#define w 286

/* byte lengths of an encoded Rq polynomial and an encoded small polynomial */
#define rq_encode_len 1218
#define small_encode_len 191

#endif
                    233:
1.2     ! djm       234: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/r3.h */
1.1       djm       235: #ifndef r3_h
                    236: #define r3_h
                    237:
                    238:
                    239: static void r3_mult(small *,const small *,const small *);
                    240:
                    241: extern int r3_recip(small *,const small *);
                    242:
                    243: #endif
                    244:
1.2     ! djm       245: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/rq.h */
1.1       djm       246: #ifndef rq_h
                    247: #define rq_h
                    248:
                    249:
                    250: static void rq_encode(unsigned char *,const modq *);
                    251:
                    252: static void rq_decode(modq *,const unsigned char *);
                    253:
                    254: static void rq_encoderounded(unsigned char *,const modq *);
                    255:
                    256: static void rq_decoderounded(modq *,const unsigned char *);
                    257:
                    258: static void rq_round3(modq *,const modq *);
                    259:
                    260: static void rq_mult(modq *,const modq *,const small *);
                    261:
                    262: int rq_recip3(modq *,const small *);
                    263:
                    264: #endif
                    265:
/* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/swap.h */
#ifndef swap_h
#define swap_h

/*
 * Masked exchange of `bytes` bytes between x and y. Callers in this file
 * pass a mask that is either 0 or -1. Presumably -1 swaps and 0 leaves both
 * buffers unchanged; the definition is not visible here, so confirm.
 */
static void swap(void *x, void *y, int bytes, int mask);

#endif /* swap_h */
                    273:
/* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/dec.c */
/* See https://ntruprime.cr.yp.to/software.html for detailed documentation. */

#ifdef KAT
#endif


/*
 * Decapsulation.
 *   k:    output, 32-byte session key.
 *   cstr: ciphertext: 32 bytes of hash confirmation, followed by the
 *         rounded encoding of an Rq element (read with rq_decoderounded).
 *   sk:   secret key: small encoding of f, small encoding of grecip
 *         (1/g in R3), then the encoded public key h.
 * Every validity check is OR-ed into the mask `result` instead of causing
 * an early return. The return value is that mask: 0 on success, nonzero on
 * failure (-1, assuming crypto_verify_32 returns 0 or -1). Since k is
 * written as hash[32..63] & ~result, a result of -1 gives all-zero k.
 */
int crypto_kem_sntrup4591761_dec(
  unsigned char *k,
  const unsigned char *cstr,
  const unsigned char *sk
)
{
  small f[p];
  modq h[p];
  small grecip[p];
  modq c[p];
  modq t[p];
  small t3[p];
  small r[p];
  modq hr[p];
  unsigned char rstr[small_encode_len];
  unsigned char hash[64];
  int i;
  int result = 0;
  int weight;

  small_decode(f,sk);
  small_decode(grecip,sk + small_encode_len);
  rq_decode(h,sk + 2 * small_encode_len);

  rq_decoderounded(c,cstr + 32);

  /* t = c*f in Rq; t3 = 3*t with each coefficient reduced into {-1,0,1} */
  rq_mult(t,c,f);
  for (i = 0;i < p;++i) t3[i] = mod3_freeze(modq_freeze(3*t[i]));

  /* r = t3 * grecip in R3 */
  r3_mult(r,t3,grecip);

#ifdef KAT
  {
    int j;
    printf("decrypt r:");
    for (j = 0;j < p;++j)
      if (r[j] == 1) printf(" +%d",j);
      else if (r[j] == -1) printf(" -%d",j);
    printf("\n");
  }
#endif

  /* check 1: r must have exactly w nonzero coefficients */
  /* (1 & r[i] is 1 for r[i] = +-1 and 0 for r[i] = 0) */
  weight = 0;
  for (i = 0;i < p;++i) weight += (1 & r[i]);
  weight -= w;
  result |= modq_nonzero_mask(weight); /* XXX: puts limit on p */

  /* check 2: re-encrypting r under h must reproduce c */
  rq_mult(hr,h,r);
  rq_round3(hr,hr);
  for (i = 0;i < p;++i) result |= modq_nonzero_mask(hr[i] - c[i]);

  /* check 3: first half of SHA-512(encode(r)) must equal cstr[0..31] */
  small_encode(rstr,r);
  crypto_hash_sha512(hash,rstr,sizeof rstr);
  result |= crypto_verify_32(hash,cstr);

  /* session key = second half of the hash, masked off on failure */
  for (i = 0;i < 32;++i) k[i] = (hash[32 + i] & ~result);
  return result;
}
                    339:
1.2     ! djm       340: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/enc.c */
1.1       djm       341: /* See https://ntruprime.cr.yp.to/software.html for detailed documentation. */
                    342:
                    343: #ifdef KAT
                    344: #endif
                    345:
                    346:
                    347: int crypto_kem_sntrup4591761_enc(
                    348:   unsigned char *cstr,
                    349:   unsigned char *k,
                    350:   const unsigned char *pk
                    351: )
                    352: {
                    353:   small r[p];
                    354:   modq h[p];
                    355:   modq c[p];
                    356:   unsigned char rstr[small_encode_len];
                    357:   unsigned char hash[64];
                    358:
                    359:   small_random_weightw(r);
                    360:
                    361: #ifdef KAT
                    362:   {
                    363:     int i;
                    364:     printf("encrypt r:");
                    365:     for (i = 0;i < p;++i)
                    366:       if (r[i] == 1) printf(" +%d",i);
                    367:       else if (r[i] == -1) printf(" -%d",i);
                    368:     printf("\n");
                    369:   }
                    370: #endif
                    371:
                    372:   small_encode(rstr,r);
                    373:   crypto_hash_sha512(hash,rstr,sizeof rstr);
                    374:
                    375:   rq_decode(h,pk);
                    376:   rq_mult(c,h,r);
                    377:   rq_round3(c,c);
                    378:
                    379:   memcpy(k,hash + 32,32);
                    380:   memcpy(cstr,hash,32);
                    381:   rq_encoderounded(cstr + 32,c);
                    382:
                    383:   return 0;
                    384: }
                    385:
1.2     ! djm       386: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/keypair.c */
1.1       djm       387: /* See https://ntruprime.cr.yp.to/software.html for detailed documentation. */
                    388:
                    389:
                    390: #if crypto_kem_sntrup4591761_PUBLICKEYBYTES != rq_encode_len
                    391: #error "crypto_kem_sntrup4591761_PUBLICKEYBYTES must match rq_encode_len"
                    392: #endif
                    393: #if crypto_kem_sntrup4591761_SECRETKEYBYTES != rq_encode_len + 2 * small_encode_len
                    394: #error "crypto_kem_sntrup4591761_SECRETKEYBYTES must match rq_encode_len + 2 * small_encode_len"
                    395: #endif
                    396:
                    397: int crypto_kem_sntrup4591761_keypair(unsigned char *pk,unsigned char *sk)
                    398: {
                    399:   small g[p];
                    400:   small grecip[p];
                    401:   small f[p];
                    402:   modq f3recip[p];
                    403:   modq h[p];
                    404:
                    405:   do
                    406:     small_random(g);
                    407:   while (r3_recip(grecip,g) != 0);
                    408:
                    409:   small_random_weightw(f);
                    410:   rq_recip3(f3recip,f);
                    411:
                    412:   rq_mult(h,f3recip,g);
                    413:
                    414:   rq_encode(pk,h);
                    415:   small_encode(sk,f);
                    416:   small_encode(sk + small_encode_len,grecip);
                    417:   memcpy(sk + 2 * small_encode_len,pk,rq_encode_len);
                    418:
                    419:   return 0;
                    420: }
                    421:
1.2     ! djm       422: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/r3_mult.c */
1.1       djm       423: /* See https://ntruprime.cr.yp.to/software.html for detailed documentation. */
                    424:
                    425:
                    426: static void r3_mult(small *h,const small *f,const small *g)
                    427: {
                    428:   small fg[p + p - 1];
                    429:   small result;
                    430:   int i, j;
                    431:
                    432:   for (i = 0;i < p;++i) {
                    433:     result = 0;
                    434:     for (j = 0;j <= i;++j)
                    435:       result = mod3_plusproduct(result,f[j],g[i - j]);
                    436:     fg[i] = result;
                    437:   }
                    438:   for (i = p;i < p + p - 1;++i) {
                    439:     result = 0;
                    440:     for (j = i - p + 1;j < p;++j)
                    441:       result = mod3_plusproduct(result,f[j],g[i - j]);
                    442:     fg[i] = result;
                    443:   }
                    444:
                    445:   for (i = p + p - 2;i >= p;--i) {
                    446:     fg[i - p] = mod3_sum(fg[i - p],fg[i]);
                    447:     fg[i - p + 1] = mod3_sum(fg[i - p + 1],fg[i]);
                    448:   }
                    449:
                    450:   for (i = 0;i < p;++i)
                    451:     h[i] = fg[i];
                    452: }
                    453:
1.2     ! djm       454: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/r3_recip.c */
1.1       djm       455: /* See https://ntruprime.cr.yp.to/software.html for detailed documentation. */
                    456:
                    457:
                    458: /* caller must ensure that x-y does not overflow */
                    459: static int smaller_mask_r3_recip(int x,int y)
                    460: {
                    461:   return (x - y) >> 31;
                    462: }
                    463:
                    464: static void vectormod3_product(small *z,int len,const small *x,const small c)
                    465: {
                    466:   int i;
                    467:   for (i = 0;i < len;++i) z[i] = mod3_product(x[i],c);
                    468: }
                    469:
                    470: static void vectormod3_minusproduct(small *z,int len,const small *x,const small *y,const small c)
                    471: {
                    472:   int i;
                    473:   for (i = 0;i < len;++i) z[i] = mod3_minusproduct(x[i],y[i],c);
                    474: }
                    475:
                    476: static void vectormod3_shift(small *z,int len)
                    477: {
                    478:   int i;
                    479:   for (i = len - 1;i > 0;--i) z[i] = z[i - 1];
                    480:   z[0] = 0;
                    481: }
                    482:
                    483: /*
                    484: r = s^(-1) mod m, returning 0, if s is invertible mod m
                    485: or returning -1 if s is not invertible mod m
                    486: r,s are polys of degree <p
                    487: m is x^p-x-1
                    488: */
                    489: int r3_recip(small *r,const small *s)
                    490: {
                    491:   const int loops = 2*p + 1;
                    492:   int loop;
                    493:   small f[p + 1];
                    494:   small g[p + 1];
                    495:   small u[loops + 1];
                    496:   small v[loops + 1];
                    497:   small c;
                    498:   int i;
                    499:   int d = p;
                    500:   int e = p;
                    501:   int swapmask;
                    502:
                    503:   for (i = 2;i < p;++i) f[i] = 0;
                    504:   f[0] = -1;
                    505:   f[1] = -1;
                    506:   f[p] = 1;
                    507:   /* generalization: can initialize f to any polynomial m */
                    508:   /* requirements: m has degree exactly p, nonzero constant coefficient */
                    509:
                    510:   for (i = 0;i < p;++i) g[i] = s[i];
                    511:   g[p] = 0;
                    512:
                    513:   for (i = 0;i <= loops;++i) u[i] = 0;
                    514:
                    515:   v[0] = 1;
                    516:   for (i = 1;i <= loops;++i) v[i] = 0;
                    517:
                    518:   loop = 0;
                    519:   for (;;) {
                    520:     /* e == -1 or d + e + loop <= 2*p */
                    521:
                    522:     /* f has degree p: i.e., f[p]!=0 */
                    523:     /* f[i]==0 for i < p-d */
                    524:
                    525:     /* g has degree <=p (so it fits in p+1 coefficients) */
                    526:     /* g[i]==0 for i < p-e */
                    527:
                    528:     /* u has degree <=loop (so it fits in loop+1 coefficients) */
                    529:     /* u[i]==0 for i < p-d */
                    530:     /* if invertible: u[i]==0 for i < loop-p (so can look at just p+1 coefficients) */
                    531:
                    532:     /* v has degree <=loop (so it fits in loop+1 coefficients) */
                    533:     /* v[i]==0 for i < p-e */
                    534:     /* v[i]==0 for i < loop-p (so can look at just p+1 coefficients) */
                    535:
                    536:     if (loop >= loops) break;
                    537:
                    538:     c = mod3_quotient(g[p],f[p]);
                    539:
                    540:     vectormod3_minusproduct(g,p + 1,g,f,c);
                    541:     vectormod3_shift(g,p + 1);
                    542:
                    543: #ifdef SIMPLER
                    544:     vectormod3_minusproduct(v,loops + 1,v,u,c);
                    545:     vectormod3_shift(v,loops + 1);
                    546: #else
                    547:     if (loop < p) {
                    548:       vectormod3_minusproduct(v,loop + 1,v,u,c);
                    549:       vectormod3_shift(v,loop + 2);
                    550:     } else {
                    551:       vectormod3_minusproduct(v + loop - p,p + 1,v + loop - p,u + loop - p,c);
                    552:       vectormod3_shift(v + loop - p,p + 2);
                    553:     }
                    554: #endif
                    555:
                    556:     e -= 1;
                    557:
                    558:     ++loop;
                    559:
                    560:     swapmask = smaller_mask_r3_recip(e,d) & mod3_nonzero_mask(g[p]);
                    561:     swap(&e,&d,sizeof e,swapmask);
                    562:     swap(f,g,(p + 1) * sizeof(small),swapmask);
                    563:
                    564: #ifdef SIMPLER
                    565:     swap(u,v,(loops + 1) * sizeof(small),swapmask);
                    566: #else
                    567:     if (loop < p) {
                    568:       swap(u,v,(loop + 1) * sizeof(small),swapmask);
                    569:     } else {
                    570:       swap(u + loop - p,v + loop - p,(p + 1) * sizeof(small),swapmask);
                    571:     }
                    572: #endif
                    573:   }
                    574:
                    575:   c = mod3_reciprocal(f[p]);
                    576:   vectormod3_product(r,p,u + p,c);
                    577:   return smaller_mask_r3_recip(0,d);
                    578: }
                    579:
1.2     ! djm       580: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/randomsmall.c */
1.1       djm       581: /* See https://ntruprime.cr.yp.to/software.html for detailed documentation. */
                    582:
                    583:
                    584: static void small_random(small *g)
                    585: {
                    586:   int i;
                    587:
                    588:   for (i = 0;i < p;++i) {
                    589:     crypto_uint32 r = small_random32();
                    590:     g[i] = (small) (((1073741823 & r) * 3) >> 30) - 1;
                    591:   }
                    592: }
                    593:
1.2     ! djm       594: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/randomweightw.c */
1.1       djm       595: /* See https://ntruprime.cr.yp.to/software.html for detailed documentation. */
                    596:
                    597:
                    598: static void small_random_weightw(small *f)
                    599: {
                    600:   crypto_int32 r[p];
                    601:   int i;
                    602:
                    603:   for (i = 0;i < p;++i) r[i] = small_random32();
                    604:   for (i = 0;i < w;++i) r[i] &= -2;
                    605:   for (i = w;i < p;++i) r[i] = (r[i] & -3) | 1;
1.2     ! djm       606:   int32_sort(r,p);
1.1       djm       607:   for (i = 0;i < p;++i) f[i] = ((small) (r[i] & 3)) - 1;
                    608: }
                    609:
/* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/rq.c */
/* See https://ntruprime.cr.yp.to/software.html for detailed documentation. */


/*
 * Serialize the p coefficients of f (each assumed in -qshift..qshift) into
 * rq_encode_len bytes, least significant byte first.
 * Groups of five coefficients are handled together: each is shifted to
 * 0..q-1, and the group is written as the radix-6144 number
 * f0 + 6144*f1 + ... + 6144^4*f4, which fits in 8 bytes since 6144^5 < 2^64.
 * That gives p/5 = 152 groups of 8 bytes. The one leftover coefficient
 * (p = 5*152 + 1) takes 2 more bytes, for 1218 bytes in total.
 */
static void rq_encode(unsigned char *c,const modq *f)
{
  crypto_int32 f0, f1, f2, f3, f4;
  int i;

  for (i = 0;i < p/5;++i) {
    f0 = *f++ + qshift;
    f1 = *f++ + qshift;
    f2 = *f++ + qshift;
    f3 = *f++ + qshift;
    f4 = *f++ + qshift;
    /* now want f0 + 6144*f1 + ... as a 64-bit integer */
    /* 6144^k = 3^k * 2^(11k), so fold the powers of 3 in here... */
    f1 *= 3;
    f2 *= 9;
    f3 *= 27;
    f4 *= 81;
    /* now want f0 + f1<<11 + f2<<22 + f3<<33 + f4<<44 */
    /* ...and emit bytes as they complete, staying within 32 bits; the */
    /* shift amounts below are relative to the bits already emitted */
    f0 += f1 << 11;
    *c++ = f0; f0 >>= 8;
    *c++ = f0; f0 >>= 8;
    f0 += f2 << 6;          /* 22 - 16 bits emitted */
    *c++ = f0; f0 >>= 8;
    *c++ = f0; f0 >>= 8;
    f0 += f3 << 1;          /* 33 - 32 bits emitted */
    *c++ = f0; f0 >>= 8;
    f0 += f4 << 4;          /* 44 - 40 bits emitted */
    *c++ = f0; f0 >>= 8;
    *c++ = f0; f0 >>= 8;
    *c++ = f0;
  }
  /* XXX: using p mod 5 = 1 */
  f0 = *f++ + qshift;
  *c++ = f0; f0 >>= 8;
  *c++ = f0;
}
                    649:
/*
 * rq_decode: decode p coefficients of an Rq polynomial from the bytes at c.
 *
 * Layout, as consumed below:
 *   - p/5 groups of 8 bytes.  Each group is a little-endian radix-256 number
 *     equal to f0 + f1*6144 + f2*6144^2 + f3*6144^3 + f4*6144^4, where every
 *     f is a coefficient offset into [0, 4590].
 *   - 2 trailing little-endian bytes holding the single leftover coefficient.
 *     This read assumes p mod 5 = 1.
 *
 * Each f is extracted from the most significant end with a multiply-and-
 * shift approximation of division.  The per-step bound proofs are in the
 * comments.  After each extraction, f's contribution is subtracted and the
 * remainder is folded into the next lower byte.  Decoded values have the
 * qshift offset removed and are reduced with modq_freeze.
 *
 * NOTE(review): the proof bounds (e.g. c6 <= 23241) hold for encodings of
 * in-range coefficients.  For arbitrary input bytes the crypto_uint32
 * products can wrap (e.g. 103564 * 65535 > 2^32).  That is unsigned, so
 * defined, but the decoded values are then garbage (still passed through
 * modq_freeze).
 */
static void rq_decode(modq *f,const unsigned char *c)
{
  crypto_uint32 c0, c1, c2, c3, c4, c5, c6, c7;
  crypto_uint32 f0, f1, f2, f3, f4;
  int i;

  for (i = 0;i < p/5;++i) {
    c0 = *c++;
    c1 = *c++;
    c2 = *c++;
    c3 = *c++;
    c4 = *c++;
    c5 = *c++;
    c6 = *c++;
    c7 = *c++;

    /* f0 + f1*6144 + f2*6144^2 + f3*6144^3 + f4*6144^4 */
    /* = c0 + c1*256 + ... + c6*256^6 + c7*256^7 */
    /* with each f between 0 and 4590 */

    /* merge the top two bytes so c6 is the value's coefficient of 256^6 */
    c6 += c7 << 8;
    /* c6 <= 23241 = floor(4591*6144^4/2^48) */
    /* f4 = (16/81)c6 + (1/1296)(c5+[0,1]) - [0,0.75] */
    /* claim: 2^19 f4 < x < 2^19(f4+1) */
    /* where x = 103564 c6 + 405(c5+1) */
    /* proof: x - 2^19 f4 = (76/81)c6 + (37/81)c5 + 405 - (32768/81)[0,1] + 2^19[0,0.75] */
    /* at least 405 - 32768/81 > 0 */
    /* at most (76/81)23241 + (37/81)255 + 405 + 2^19 0.75 < 2^19 */
    f4 = (103564*c6 + 405*(c5+1)) >> 19;

    /* remove f4*6144^4 = f4*1296*256^5 (1296 = 81<<4), then fold into c4 */
    c5 += c6 << 8;
    c5 -= (f4 * 81) << 4;
    c4 += c5 << 8;

    /* f0 + f1*6144 + f2*6144^2 + f3*6144^3 */
    /* = c0 + c1*256 + c2*256^2 + c3*256^3 + c4*256^4 */
    /* c4 <= 247914 = floor(4591*6144^3/2^32) */
    /* f3 = (1/54)(c4+[0,1]) - [0,0.75] */
    /* claim: 2^19 f3 < x < 2^19(f3+1) */
    /* where x = 9709(c4+2) */
    /* proof: x - 2^19 f3 = 19418 - (1/27)c4 - (262144/27)[0,1] + 2^19[0,0.75] */
    /* at least 19418 - 247914/27 - 262144/27 > 0 */
    /* at most 19418 + 2^19 0.75 < 2^19 */
    f3 = (9709*(c4+2)) >> 19;

    /* remove f3*6144^3 = f3*54*256^4 (54 = 27<<1), then fold into c3 */
    c4 -= (f3 * 27) << 1;
    c3 += c4 << 8;
    /* f0 + f1*6144 + f2*6144^2 */
    /* = c0 + c1*256 + c2*256^2 + c3*256^3 */
    /* c3 <= 10329 = floor(4591*6144^2/2^24) */
    /* f2 = (4/9)c3 + (1/576)c2 + (1/147456)c1 + (1/37748736)c0 - [0,0.75] */
    /* claim: 2^19 f2 < x < 2^19(f2+1) */
    /* where x = 233017 c3 + 910(c2+2) */
    /* proof: x - 2^19 f2 = 1820 + (1/9)c3 - (2/9)c2 - (32/9)c1 - (1/72)c0 + 2^19[0,0.75] */
    /* at least 1820 - (2/9)255 - (32/9)255 - (1/72)255 > 0 */
    /* at most 1820 + (1/9)10329 + 2^19 0.75 < 2^19 */
    f2 = (233017*c3 + 910*(c2+2)) >> 19;

    /* remove f2*6144^2 = f2*576*256^2 (576 = 9<<6), then fold into c1 */
    c2 += c3 << 8;
    c2 -= (f2 * 9) << 6;
    c1 += c2 << 8;
    /* f0 + f1*6144 */
    /* = c0 + c1*256 */
    /* c1 <= 110184 = floor(4591*6144/2^8) */
    /* f1 = (1/24)c1 + (1/6144)c0 - (1/6144)f0 */
    /* claim: 2^19 f1 < x < 2^19(f1+1) */
    /* where x = 21845(c1+2) + 85 c0 */
    /* proof: x - 2^19 f1 = 43690 - (1/3)c1 - (1/3)c0 + 2^19 [0,0.75] */
    /* at least 43690 - (1/3)110184 - (1/3)255 > 0 */
    /* at most 43690 + 2^19 0.75 < 2^19 */
    f1 = (21845*(c1+2) + 85*c0) >> 19;

    /* remove f1*6144 = f1*24*256 (24 = 3<<3); what remains is f0 */
    c1 -= (f1 * 3) << 3;
    c0 += c1 << 8;
    f0 = c0;

    /* undo the qshift offset; adding q first presumably keeps the unsigned
       expression from wrapping (assumes q > qshift) before modq_freeze */
    *f++ = modq_freeze(f0 + q - qshift);
    *f++ = modq_freeze(f1 + q - qshift);
    *f++ = modq_freeze(f2 + q - qshift);
    *f++ = modq_freeze(f3 + q - qshift);
    *f++ = modq_freeze(f4 + q - qshift);
  }

  /* leftover coefficient (p mod 5 = 1): 2 little-endian bytes */
  c0 = *c++;
  c1 = *c++;
  c0 += c1 << 8;
  *f++ = modq_freeze(c0 + q - qshift);
}
                    738:
1.2     ! djm       739: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/rq_mult.c */
1.1       djm       740: /* See https://ntruprime.cr.yp.to/software.html for detailed documentation. */
                    741:
                    742:
                    743: static void rq_mult(modq *h,const modq *f,const small *g)
                    744: {
                    745:   modq fg[p + p - 1];
                    746:   modq result;
                    747:   int i, j;
                    748:
                    749:   for (i = 0;i < p;++i) {
                    750:     result = 0;
                    751:     for (j = 0;j <= i;++j)
                    752:       result = modq_plusproduct(result,f[j],g[i - j]);
                    753:     fg[i] = result;
                    754:   }
                    755:   for (i = p;i < p + p - 1;++i) {
                    756:     result = 0;
                    757:     for (j = i - p + 1;j < p;++j)
                    758:       result = modq_plusproduct(result,f[j],g[i - j]);
                    759:     fg[i] = result;
                    760:   }
                    761:
                    762:   for (i = p + p - 2;i >= p;--i) {
                    763:     fg[i - p] = modq_sum(fg[i - p],fg[i]);
                    764:     fg[i - p + 1] = modq_sum(fg[i - p + 1],fg[i]);
                    765:   }
                    766:
                    767:   for (i = 0;i < p;++i)
                    768:     h[i] = fg[i];
                    769: }
                    770:
1.2     ! djm       771: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/rq_recip3.c */
1.1       djm       772: /* See https://ntruprime.cr.yp.to/software.html for detailed documentation. */
                    773:
                    774:
                    775: /* caller must ensure that x-y does not overflow */
                    776: static int smaller_mask_rq_recip3(int x,int y)
                    777: {
                    778:   return (x - y) >> 31;
                    779: }
                    780:
                    781: static void vectormodq_product(modq *z,int len,const modq *x,const modq c)
                    782: {
                    783:   int i;
                    784:   for (i = 0;i < len;++i) z[i] = modq_product(x[i],c);
                    785: }
                    786:
                    787: static void vectormodq_minusproduct(modq *z,int len,const modq *x,const modq *y,const modq c)
                    788: {
                    789:   int i;
                    790:   for (i = 0;i < len;++i) z[i] = modq_minusproduct(x[i],y[i],c);
                    791: }
                    792:
                    793: static void vectormodq_shift(modq *z,int len)
                    794: {
                    795:   int i;
                    796:   for (i = len - 1;i > 0;--i) z[i] = z[i - 1];
                    797:   z[0] = 0;
                    798: }
                    799:
/*
 * rq_recip3: compute r = (3s)^(-1) mod m in Rq, where m is x^p-x-1.
 *
 * r and s are polys of degree < p (p coefficients each).  s is scaled by 3
 * when g is initialized.
 *
 * Returns 0 if s is invertible mod m, or -1 if s is not invertible mod m.
 * Concretely, the result is -1 exactly when d > 0 after the last iteration.
 * r is written in both cases.
 *
 * The main loop always runs loops = 2*p + 1 iterations.  Its only branches
 * test the loop counter (loop < p).  The data-dependent exchange of
 * (d, f, u) with (e, g, v) is performed by the mask-driven swap(), so the
 * control flow does not depend on the contents of s.
 *
 * Without SIMPLER, the u/v updates and swaps touch only the window of
 * coefficients that the invariants below say can be nonzero.  With SIMPLER,
 * they process all loops+1 entries.
 */
int rq_recip3(modq *r,const small *s)
{
  const int loops = 2*p + 1;
  int loop;
  modq f[p + 1];
  modq g[p + 1];
  modq u[loops + 1];
  modq v[loops + 1];
  modq c;
  int i;
  int d = p;
  int e = p;
  int swapmask;

  /* f = m = x^p - x - 1 (coefficient of x^i stored at index i) */
  for (i = 2;i < p;++i) f[i] = 0;
  f[0] = -1;
  f[1] = -1;
  f[p] = 1;
  /* generalization: can initialize f to any polynomial m */
  /* requirements: m has degree exactly p, nonzero constant coefficient */

  /* g = 3s */
  for (i = 0;i < p;++i) g[i] = 3 * s[i];
  g[p] = 0;

  /* u = 0, v = 1 */
  for (i = 0;i <= loops;++i) u[i] = 0;

  v[0] = 1;
  for (i = 1;i <= loops;++i) v[i] = 0;

  loop = 0;
  for (;;) {
    /* e == -1 or d + e + loop <= 2*p */

    /* f has degree p: i.e., f[p]!=0 */
    /* f[i]==0 for i < p-d */

    /* g has degree <=p (so it fits in p+1 coefficients) */
    /* g[i]==0 for i < p-e */

    /* u has degree <=loop (so it fits in loop+1 coefficients) */
    /* u[i]==0 for i < p-d */
    /* if invertible: u[i]==0 for i < loop-p (so can look at just p+1 coefficients) */

    /* v has degree <=loop (so it fits in loop+1 coefficients) */
    /* v[i]==0 for i < p-e */
    /* v[i]==0 for i < loop-p (so can look at just p+1 coefficients) */

    if (loop >= loops) break;

    /* multiplier meant to cancel g[p] against f[p] (f[p] != 0 by the
       invariant); presumably the Zq quotient g[p]/f[p] */
    c = modq_quotient(g[p],f[p]);

    /* g -= c*f, then move every coefficient of g up one index */
    vectormodq_minusproduct(g,p + 1,g,f,c);
    vectormodq_shift(g,p + 1);

    /* apply the same update to v, mirrored from u */
#ifdef SIMPLER
    vectormodq_minusproduct(v,loops + 1,v,u,c);
    vectormodq_shift(v,loops + 1);
#else
    if (loop < p) {
      vectormodq_minusproduct(v,loop + 1,v,u,c);
      vectormodq_shift(v,loop + 2);
    } else {
      vectormodq_minusproduct(v + loop - p,p + 1,v + loop - p,u + loop - p,c);
      vectormodq_shift(v + loop - p,p + 2);
    }
#endif

    e -= 1;

    ++loop;

    /* swap (d,f,u) with (e,g,v) when e < d and g[p] is nonzero
       (assumes modq_nonzero_mask returns -1 for nonzero, 0 otherwise) */
    swapmask = smaller_mask_rq_recip3(e,d) & modq_nonzero_mask(g[p]);
    swap(&e,&d,sizeof e,swapmask);
    swap(f,g,(p + 1) * sizeof(modq),swapmask);

#ifdef SIMPLER
    swap(u,v,(loops + 1) * sizeof(modq),swapmask);
#else
    if (loop < p) {
      swap(u,v,(loop + 1) * sizeof(modq),swapmask);
    } else {
      swap(u + loop - p,v + loop - p,(p + 1) * sizeof(modq),swapmask);
    }
#endif
  }

  /* r = u[p .. 2p-1] scaled by f[p]^(-1) */
  c = modq_reciprocal(f[p]);
  vectormodq_product(r,p,u + p,c);
  /* -1 if d > 0 (not invertible), else 0 */
  return smaller_mask_rq_recip3(0,d);
}
                    896:
1.2     ! djm       897: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/rq_round3.c */
1.1       djm       898: /* See https://ntruprime.cr.yp.to/software.html for detailed documentation. */
                    899:
                    900:
                    901: static void rq_round3(modq *h,const modq *f)
                    902: {
                    903:   int i;
                    904:
                    905:   for (i = 0;i < p;++i)
                    906:     h[i] = ((21846 * (f[i] + 2295) + 32768) >> 16) * 3 - 2295;
                    907: }
                    908:
1.2     ! djm       909: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/rq_rounded.c */
1.1       djm       910: /* See https://ntruprime.cr.yp.to/software.html for detailed documentation. */
                    911:
                    912:
                    913: static void rq_encoderounded(unsigned char *c,const modq *f)
                    914: {
                    915:   crypto_int32 f0, f1, f2;
                    916:   int i;
                    917:
                    918:   for (i = 0;i < p/3;++i) {
                    919:     f0 = *f++ + qshift;
                    920:     f1 = *f++ + qshift;
                    921:     f2 = *f++ + qshift;
                    922:     f0 = (21846 * f0) >> 16;
                    923:     f1 = (21846 * f1) >> 16;
                    924:     f2 = (21846 * f2) >> 16;
                    925:     /* now want f0 + f1*1536 + f2*1536^2 as a 32-bit integer */
                    926:     f2 *= 3;
                    927:     f1 += f2 << 9;
                    928:     f1 *= 3;
                    929:     f0 += f1 << 9;
                    930:     *c++ = f0; f0 >>= 8;
                    931:     *c++ = f0; f0 >>= 8;
                    932:     *c++ = f0; f0 >>= 8;
                    933:     *c++ = f0;
                    934:   }
                    935:   /* XXX: using p mod 3 = 2 */
                    936:   f0 = *f++ + qshift;
                    937:   f1 = *f++ + qshift;
                    938:   f0 = (21846 * f0) >> 16;
                    939:   f1 = (21846 * f1) >> 16;
                    940:   f1 *= 3;
                    941:   f0 += f1 << 9;
                    942:   *c++ = f0; f0 >>= 8;
                    943:   *c++ = f0; f0 >>= 8;
                    944:   *c++ = f0;
                    945: }
                    946:
/*
 * rq_decoderounded: decode p coefficients of f from the bytes at c.
 *
 * Layout, as consumed below:
 *   - p/3 groups of 4 bytes.  Each group is a little-endian radix-256 number
 *     equal to f0 + f1*1536 + f2*1536^2, with each f in [0, 1530].
 *   - 3 trailing bytes holding f0 + f1*1536 for the last two coefficients.
 *     This assumes p mod 3 = 2.
 *
 * Digits are extracted from the most significant end with multiply-and-
 * shift approximations of division (bound proofs in the comments).  Each
 * decoded digit is multiplied by 3, has the qshift offset removed, and is
 * reduced with modq_freeze.
 *
 * NOTE(review): the proof bounds (e.g. c2 <= 35) hold for encodings of
 * in-range values.  For arbitrary input bytes the crypto_uint32
 * intermediates can wrap (unsigned, so defined) and the decoded values are
 * garbage.  Confirm that modq_freeze's parameter type tolerates such values.
 */
static void rq_decoderounded(modq *f,const unsigned char *c)
{
  crypto_uint32 c0, c1, c2, c3;
  crypto_uint32 f0, f1, f2;
  int i;

  for (i = 0;i < p/3;++i) {
    c0 = *c++;
    c1 = *c++;
    c2 = *c++;
    c3 = *c++;

    /* f0 + f1*1536 + f2*1536^2 */
    /* = c0 + c1*256 + c2*256^2 + c3*256^3 */
    /* with each f between 0 and 1530 */

    /* f2 = (64/9)c3 + (1/36)c2 + (1/9216)c1 + (1/2359296)c0 - [0,0.99675] */
    /* claim: 2^21 f2 < x < 2^21(f2+1) */
    /* where x = 14913081*c3 + 58254*c2 + 228*(c1+2) */
    /* proof: x - 2^21 f2 = 456 - (8/9)c0 + (4/9)c1 - (2/9)c2 + (1/9)c3 + 2^21 [0,0.99675] */
    /* at least 456 - (8/9)255 - (2/9)255 > 0 */
    /* at most 456 + (4/9)255 + (1/9)255 + 2^21 0.99675 < 2^21 */
    f2 = (14913081*c3 + 58254*c2 + 228*(c1+2)) >> 21;

    /* fold c3 into c2 and remove f2*1536^2 = f2*36*256^2 (36 = 9<<2) */
    c2 += c3 << 8;
    c2 -= (f2 * 9) << 2;
    /* f0 + f1*1536 */
    /* = c0 + c1*256 + c2*256^2 */
    /* c2 <= 35 = floor((1530+1530*1536)/256^2) */
    /* f1 = (128/3)c2 + (1/6)c1 + (1/1536)c0 - (1/1536)f0 */
    /* claim: 2^21 f1 < x < 2^21(f1+1) */
    /* where x = 89478485*c2 + 349525*c1 + 1365*(c0+1) */
    /* proof: x - 2^21 f1 = 1365 - (1/3)c2 - (1/3)c1 - (1/3)c0 + (4096/3)f0 */
    /* at least 1365 - (1/3)35 - (1/3)255 - (1/3)255 > 0 */
    /* at most 1365 + (4096/3)1530 < 2^21 */
    f1 = (89478485*c2 + 349525*c1 + 1365*(c0+1)) >> 21;

    /* fold c2 into c1 and remove f1*1536 = f1*6*256 (6 = 3<<1) */
    c1 += c2 << 8;
    c1 -= (f1 * 3) << 1;

    /* what remains is f0 */
    c0 += c1 << 8;
    f0 = c0;

    /* digits are coefficient/3 after the qshift offset: scale back by 3,
       unshift (adding q first presumably avoids unsigned wrap), reduce */
    *f++ = modq_freeze(f0 * 3 + q - qshift);
    *f++ = modq_freeze(f1 * 3 + q - qshift);
    *f++ = modq_freeze(f2 * 3 + q - qshift);
  }

  /* last two coefficients (p mod 3 = 2): 3 bytes holding f0 + f1*1536,
     decoded with the same f1 formula as the second step above */
  c0 = *c++;
  c1 = *c++;
  c2 = *c++;

  f1 = (89478485*c2 + 349525*c1 + 1365*(c0+1)) >> 21;

  c1 += c2 << 8;
  c1 -= (f1 * 3) << 1;

  c0 += c1 << 8;
  f0 = c0;

  *f++ = modq_freeze(f0 * 3 + q - qshift);
  *f++ = modq_freeze(f1 * 3 + q - qshift);
}
                   1010:
1.2     ! djm      1011: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/small.c */
1.1       djm      1012: /* See https://ntruprime.cr.yp.to/software.html for detailed documentation. */
                   1013:
                   1014:
                   1015: /* XXX: these functions rely on p mod 4 = 1 */
                   1016:
                   1017: /* all coefficients in -1, 0, 1 */
                   1018: static void small_encode(unsigned char *c,const small *f)
                   1019: {
                   1020:   small c0;
                   1021:   int i;
                   1022:
                   1023:   for (i = 0;i < p/4;++i) {
                   1024:     c0 = *f++ + 1;
                   1025:     c0 += (*f++ + 1) << 2;
                   1026:     c0 += (*f++ + 1) << 4;
                   1027:     c0 += (*f++ + 1) << 6;
                   1028:     *c++ = c0;
                   1029:   }
                   1030:   c0 = *f++ + 1;
                   1031:   *c++ = c0;
                   1032: }
                   1033:
                   1034: static void small_decode(small *f,const unsigned char *c)
                   1035: {
                   1036:   unsigned char c0;
                   1037:   int i;
                   1038:
                   1039:   for (i = 0;i < p/4;++i) {
                   1040:     c0 = *c++;
                   1041:     *f++ = ((small) (c0 & 3)) - 1; c0 >>= 2;
                   1042:     *f++ = ((small) (c0 & 3)) - 1; c0 >>= 2;
                   1043:     *f++ = ((small) (c0 & 3)) - 1; c0 >>= 2;
                   1044:     *f++ = ((small) (c0 & 3)) - 1;
                   1045:   }
                   1046:   c0 = *c++;
                   1047:   *f++ = ((small) (c0 & 3)) - 1;
                   1048: }
                   1049:
1.2     ! djm      1050: /* from libpqcrypto-20180314/crypto_kem/sntrup4591761/ref/swap.c */
1.1       djm      1051: /* See https://ntruprime.cr.yp.to/software.html for detailed documentation. */
                   1052:
                   1053:
                   1054: static void swap(void *x,void *y,int bytes,int mask)
                   1055: {
                   1056:   int i;
                   1057:   char xi, yi, c, t;
                   1058:
                   1059:   c = mask;
                   1060:
                   1061:   for (i = 0;i < bytes;++i) {
                   1062:     xi = i[(char *) x];
                   1063:     yi = i[(char *) y];
                   1064:     t = c & (xi ^ yi);
                   1065:     xi ^= t;
                   1066:     yi ^= t;
                   1067:     i[(char *) x] = xi;
                   1068:     i[(char *) y] = yi;
                   1069:   }
                   1070: }
                   1071: