[BACK]Return to sntrup761.c CVS log [TXT][DIR] Up to [local] / src / usr.bin / ssh

Annotation of src/usr.bin/ssh/sntrup761.c, Revision 1.2

1.2     ! tobhe       1: /*  $OpenBSD: sntrup761.c,v 1.1 2020/12/29 00:59:15 djm Exp $ */
1.1       djm         2:
                      3: /*
                      4:  * Public Domain, Authors:
                      5:  * - Daniel J. Bernstein
                      6:  * - Chitchanok Chuengsatiansup
                      7:  * - Tanja Lange
                      8:  * - Christine van Vredendaal
                      9:  */
                     10:
                     11: #include <string.h>
                     12: #include "crypto_api.h"
1.2     ! tobhe      13: #include "int32_minmax.inc"
1.1       djm        14:
                     15: #define CRYPTO_NAMESPACE(s) s
                     16:
                     17: /* from supercop-20201130/crypto_sort/int32/portable4/sort.c */
                     18: #define int32 crypto_int32
                     19:
                     20:
                     21: static void crypto_sort_int32(void *array,long long n)
                     22: {
                     23:   long long top,p,q,r,i,j;
                     24:   int32 *x = array;
                     25:
                     26:   if (n < 2) return;
                     27:   top = 1;
                     28:   while (top < n - top) top += top;
                     29:
                     30:   for (p = top;p >= 1;p >>= 1) {
                     31:     i = 0;
                     32:     while (i + 2 * p <= n) {
                     33:       for (j = i;j < i + p;++j)
                     34:         int32_MINMAX(x[j],x[j+p]);
                     35:       i += 2 * p;
                     36:     }
                     37:     for (j = i;j < n - p;++j)
                     38:       int32_MINMAX(x[j],x[j+p]);
                     39:
                     40:     i = 0;
                     41:     j = 0;
                     42:     for (q = top;q > p;q >>= 1) {
                     43:       if (j != i) for (;;) {
                     44:         if (j == n - q) goto done;
                     45:         int32 a = x[j + p];
                     46:         for (r = q;r > p;r >>= 1)
                     47:           int32_MINMAX(a,x[j + r]);
                     48:         x[j + p] = a;
                     49:         ++j;
                     50:         if (j == i + p) {
                     51:           i += 2 * p;
                     52:           break;
                     53:         }
                     54:       }
                     55:       while (i + p <= n - q) {
                     56:         for (j = i;j < i + p;++j) {
                     57:           int32 a = x[j + p];
                     58:           for (r = q;r > p;r >>= 1)
                     59:             int32_MINMAX(a,x[j+r]);
                     60:           x[j + p] = a;
                     61:         }
                     62:         i += 2 * p;
                     63:       }
                     64:       /* now i + p > n - q */
                     65:       j = i;
                     66:       while (j < n - q) {
                     67:         int32 a = x[j + p];
                     68:         for (r = q;r > p;r >>= 1)
                     69:           int32_MINMAX(a,x[j+r]);
                     70:         x[j + p] = a;
                     71:         ++j;
                     72:       }
                     73:
                     74:       done: ;
                     75:     }
                     76:   }
                     77: }
                     78:
                     79: /* from supercop-20201130/crypto_sort/uint32/useint32/sort.c */
                     80:
                     81: /* can save time by vectorizing xor loops */
                     82: /* can save time by integrating xor loops with int32_sort */
                     83:
                     84: static void crypto_sort_uint32(void *array,long long n)
                     85: {
                     86:   crypto_uint32 *x = array;
                     87:   long long j;
                     88:   for (j = 0;j < n;++j) x[j] ^= 0x80000000;
                     89:   crypto_sort_int32(array,n);
                     90:   for (j = 0;j < n;++j) x[j] ^= 0x80000000;
                     91: }
                     92:
                     93: /* from supercop-20201130/crypto_kem/sntrup761/ref/uint64.h */
                     94: #ifndef UINT64_H
                     95: #define UINT64_H
                     96:
                     97:
                     98: typedef uint64_t uint64;
                     99:
                    100: #endif
                    101:
                    102: /* from supercop-20201130/crypto_kem/sntrup761/ref/uint16.h */
                    103: #ifndef UINT16_H
                    104: #define UINT16_H
                    105:
                    106: typedef uint16_t uint16;
                    107:
                    108: #endif
                    109:
                    110: /* from supercop-20201130/crypto_kem/sntrup761/ref/uint32.h */
                    111: #ifndef UINT32_H
                    112: #define UINT32_H
                    113:
                    114: #define uint32_div_uint14 CRYPTO_NAMESPACE(uint32_div_uint14)
                    115: #define uint32_mod_uint14 CRYPTO_NAMESPACE(uint32_mod_uint14)
                    116: #define uint32_divmod_uint14 CRYPTO_NAMESPACE(uint32_divmod_uint14)
                    117:
                    118:
                    119: typedef uint32_t uint32;
                    120:
                    121: /*
                    122: assuming 1 <= m < 16384:
                    123: q = uint32_div_uint14(x,m) means q = x/m
                    124: r = uint32_mod_uint14(x,m) means r = x/m
                    125: uint32_moddiv_uint14(&q,&r,x,m) means q = x/m, r = x%m
                    126: */
                    127:
                    128: extern uint32 uint32_div_uint14(uint32,uint16);
                    129: extern uint16 uint32_mod_uint14(uint32,uint16);
                    130: static void uint32_divmod_uint14(uint32 *,uint16 *,uint32,uint16);
                    131:
                    132: #endif
                    133:
                    134: /* from supercop-20201130/crypto_kem/sntrup761/ref/int8.h */
                    135: #ifndef INT8_H
                    136: #define INT8_H
                    137:
                    138: typedef int8_t int8;
                    139:
                    140: #endif
                    141:
                    142: /* from supercop-20201130/crypto_kem/sntrup761/ref/int16.h */
                    143: #ifndef INT16_H
                    144: #define INT16_H
                    145:
                    146: typedef int16_t int16;
                    147:
                    148: #endif
                    149:
                    150: /* from supercop-20201130/crypto_kem/sntrup761/ref/int32.h */
                    151: #ifndef INT32_H
                    152: #define INT32_H
                    153:
                    154: #define int32_div_uint14 CRYPTO_NAMESPACE(int32_div_uint14)
                    155: #define int32_mod_uint14 CRYPTO_NAMESPACE(int32_mod_uint14)
                    156: #define int32_divmod_uint14 CRYPTO_NAMESPACE(int32_divmod_uint14)
                    157:
                    158:
                    159: typedef int32_t int32;
                    160:
                    161: /*
                    162: assuming 1 <= m < 16384:
                    163: q = int32_div_uint14(x,m) means q = x/m
                    164: r = int32_mod_uint14(x,m) means r = x/m
                    165: int32_moddiv_uint14(&q,&r,x,m) means q = x/m, r = x%m
                    166: */
                    167:
                    168: extern int32 int32_div_uint14(int32,uint16);
                    169: extern uint16 int32_mod_uint14(int32,uint16);
                    170: static void int32_divmod_uint14(int32 *,uint16 *,int32,uint16);
                    171:
                    172: #endif
                    173:
                    174: /* from supercop-20201130/crypto_kem/sntrup761/ref/uint32.c */
                    175:
                    176: /*
                    177: CPU division instruction typically takes time depending on x.
                    178: This software is designed to take time independent of x.
                    179: Time still varies depending on m; user must ensure that m is constant.
                    180: Time also varies on CPUs where multiplication is variable-time.
                    181: There could be more CPU issues.
                    182: There could also be compiler issues.
                    183: */
                    184:
                    185: static void uint32_divmod_uint14(uint32 *q,uint16 *r,uint32 x,uint16 m)
                    186: {
                    187:   uint32 v = 0x80000000;
                    188:   uint32 qpart;
                    189:   uint32 mask;
                    190:
                    191:   v /= m;
                    192:
                    193:   /* caller guarantees m > 0 */
                    194:   /* caller guarantees m < 16384 */
                    195:   /* vm <= 2^31 <= vm+m-1 */
                    196:   /* xvm <= 2^31 x <= xvm+x(m-1) */
                    197:
                    198:   *q = 0;
                    199:
                    200:   qpart = (x*(uint64)v)>>31;
                    201:   /* 2^31 qpart <= xv <= 2^31 qpart + 2^31-1 */
                    202:   /* 2^31 qpart m <= xvm <= 2^31 qpart m + (2^31-1)m */
                    203:   /* 2^31 qpart m <= 2^31 x <= 2^31 qpart m + (2^31-1)m + x(m-1) */
                    204:   /* 0 <= 2^31 newx <= (2^31-1)m + x(m-1) */
                    205:   /* 0 <= newx <= (1-1/2^31)m + x(m-1)/2^31 */
                    206:   /* 0 <= newx <= (1-1/2^31)(2^14-1) + (2^32-1)((2^14-1)-1)/2^31 */
                    207:
                    208:   x -= qpart*m; *q += qpart;
                    209:   /* x <= 49146 */
                    210:
                    211:   qpart = (x*(uint64)v)>>31;
                    212:   /* 0 <= newx <= (1-1/2^31)m + x(m-1)/2^31 */
                    213:   /* 0 <= newx <= m + 49146(2^14-1)/2^31 */
                    214:   /* 0 <= newx <= m + 0.4 */
                    215:   /* 0 <= newx <= m */
                    216:
                    217:   x -= qpart*m; *q += qpart;
                    218:   /* x <= m */
                    219:
                    220:   x -= m; *q += 1;
                    221:   mask = -(x>>31);
                    222:   x += mask&(uint32)m; *q += mask;
                    223:   /* x < m */
                    224:
                    225:   *r = x;
                    226: }
                    227:
                    228: uint32 uint32_div_uint14(uint32 x,uint16 m)
                    229: {
                    230:   uint32 q;
                    231:   uint16 r;
                    232:   uint32_divmod_uint14(&q,&r,x,m);
                    233:   return q;
                    234: }
                    235:
                    236: uint16 uint32_mod_uint14(uint32 x,uint16 m)
                    237: {
                    238:   uint32 q;
                    239:   uint16 r;
                    240:   uint32_divmod_uint14(&q,&r,x,m);
                    241:   return r;
                    242: }
                    243:
                    244: /* from supercop-20201130/crypto_kem/sntrup761/ref/int32.c */
                    245:
                    246: static void int32_divmod_uint14(int32 *q,uint16 *r,int32 x,uint16 m)
                    247: {
                    248:   uint32 uq,uq2;
                    249:   uint16 ur,ur2;
                    250:   uint32 mask;
                    251:
                    252:   uint32_divmod_uint14(&uq,&ur,0x80000000+(uint32)x,m);
                    253:   uint32_divmod_uint14(&uq2,&ur2,0x80000000,m);
                    254:   ur -= ur2; uq -= uq2;
                    255:   mask = -(uint32)(ur>>15);
                    256:   ur += mask&m; uq += mask;
                    257:   *r = ur; *q = uq;
                    258: }
                    259:
                    260: int32 int32_div_uint14(int32 x,uint16 m)
                    261: {
                    262:   int32 q;
                    263:   uint16 r;
                    264:   int32_divmod_uint14(&q,&r,x,m);
                    265:   return q;
                    266: }
                    267:
                    268: uint16 int32_mod_uint14(int32 x,uint16 m)
                    269: {
                    270:   int32 q;
                    271:   uint16 r;
                    272:   int32_divmod_uint14(&q,&r,x,m);
                    273:   return r;
                    274: }
                    275:
                    276: /* from supercop-20201130/crypto_kem/sntrup761/ref/paramsmenu.h */
                    277: /* pick one of these three: */
                    278: #define SIZE761
                    279: #undef SIZE653
                    280: #undef SIZE857
                    281:
                    282: /* pick one of these two: */
                    283: #define SNTRUP /* Streamlined NTRU Prime */
                    284: #undef LPR /* NTRU LPRime */
                    285:
                    286: /* from supercop-20201130/crypto_kem/sntrup761/ref/params.h */
                    287: #ifndef params_H
                    288: #define params_H
                    289:
                    290: /* menu of parameter choices: */
                    291:
                    292:
                    293: /* what the menu means: */
                    294:
                    295: #if defined(SIZE761)
                    296: #define p 761
                    297: #define q 4591
                    298: #define Rounded_bytes 1007
                    299: #ifndef LPR
                    300: #define Rq_bytes 1158
                    301: #define w 286
                    302: #else
                    303: #define w 250
                    304: #define tau0 2156
                    305: #define tau1 114
                    306: #define tau2 2007
                    307: #define tau3 287
                    308: #endif
                    309:
                    310: #elif defined(SIZE653)
                    311: #define p 653
                    312: #define q 4621
                    313: #define Rounded_bytes 865
                    314: #ifndef LPR
                    315: #define Rq_bytes 994
                    316: #define w 288
                    317: #else
                    318: #define w 252
                    319: #define tau0 2175
                    320: #define tau1 113
                    321: #define tau2 2031
                    322: #define tau3 290
                    323: #endif
                    324:
                    325: #elif defined(SIZE857)
                    326: #define p 857
                    327: #define q 5167
                    328: #define Rounded_bytes 1152
                    329: #ifndef LPR
                    330: #define Rq_bytes 1322
                    331: #define w 322
                    332: #else
                    333: #define w 281
                    334: #define tau0 2433
                    335: #define tau1 101
                    336: #define tau2 2265
                    337: #define tau3 324
                    338: #endif
                    339:
                    340: #else
                    341: #error "no parameter set defined"
                    342: #endif
                    343:
                    344: #ifdef LPR
                    345: #define I 256
                    346: #endif
                    347:
                    348: #endif
                    349:
                    350: /* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.h */
                    351: #ifndef Decode_H
                    352: #define Decode_H
                    353:
                    354: #define Decode CRYPTO_NAMESPACE(Decode)
                    355:
                    356: /* Decode(R,s,M,len) */
                    357: /* assumes 0 < M[i] < 16384 */
                    358: /* produces 0 <= R[i] < M[i] */
                    359: static void Decode(uint16 *,const unsigned char *,const uint16 *,long long);
                    360:
                    361: #endif
                    362:
                    363: /* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.c */
                    364:
                    365: static void Decode(uint16 *out,const unsigned char *S,const uint16 *M,long long len)
                    366: {
                    367:   if (len == 1) {
                    368:     if (M[0] == 1)
                    369:       *out = 0;
                    370:     else if (M[0] <= 256)
                    371:       *out = uint32_mod_uint14(S[0],M[0]);
                    372:     else
                    373:       *out = uint32_mod_uint14(S[0]+(((uint16)S[1])<<8),M[0]);
                    374:   }
                    375:   if (len > 1) {
                    376:     uint16 R2[(len+1)/2];
                    377:     uint16 M2[(len+1)/2];
                    378:     uint16 bottomr[len/2];
                    379:     uint32 bottomt[len/2];
                    380:     long long i;
                    381:     for (i = 0;i < len-1;i += 2) {
                    382:       uint32 m = M[i]*(uint32) M[i+1];
                    383:       if (m > 256*16383) {
                    384:         bottomt[i/2] = 256*256;
                    385:         bottomr[i/2] = S[0]+256*S[1];
                    386:         S += 2;
                    387:         M2[i/2] = (((m+255)>>8)+255)>>8;
                    388:       } else if (m >= 16384) {
                    389:         bottomt[i/2] = 256;
                    390:         bottomr[i/2] = S[0];
                    391:         S += 1;
                    392:         M2[i/2] = (m+255)>>8;
                    393:       } else {
                    394:         bottomt[i/2] = 1;
                    395:         bottomr[i/2] = 0;
                    396:         M2[i/2] = m;
                    397:       }
                    398:     }
                    399:     if (i < len)
                    400:       M2[i/2] = M[i];
                    401:     Decode(R2,S,M2,(len+1)/2);
                    402:     for (i = 0;i < len-1;i += 2) {
                    403:       uint32 r = bottomr[i/2];
                    404:       uint32 r1;
                    405:       uint16 r0;
                    406:       r += bottomt[i/2]*R2[i/2];
                    407:       uint32_divmod_uint14(&r1,&r0,r,M[i]);
                    408:       r1 = uint32_mod_uint14(r1,M[i+1]); /* only needed for invalid inputs */
                    409:       *out++ = r0;
                    410:       *out++ = r1;
                    411:     }
                    412:     if (i < len)
                    413:       *out++ = R2[i/2];
                    414:   }
                    415: }
                    416:
                    417: /* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.h */
                    418: #ifndef Encode_H
                    419: #define Encode_H
                    420:
                    421: #define Encode CRYPTO_NAMESPACE(Encode)
                    422:
                    423: /* Encode(s,R,M,len) */
                    424: /* assumes 0 <= R[i] < M[i] < 16384 */
                    425: static void Encode(unsigned char *,const uint16 *,const uint16 *,long long);
                    426:
                    427: #endif
                    428:
                    429: /* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.c */
                    430:
                    431: /* 0 <= R[i] < M[i] < 16384 */
                    432: static void Encode(unsigned char *out,const uint16 *R,const uint16 *M,long long len)
                    433: {
                    434:   if (len == 1) {
                    435:     uint16 r = R[0];
                    436:     uint16 m = M[0];
                    437:     while (m > 1) {
                    438:       *out++ = r;
                    439:       r >>= 8;
                    440:       m = (m+255)>>8;
                    441:     }
                    442:   }
                    443:   if (len > 1) {
                    444:     uint16 R2[(len+1)/2];
                    445:     uint16 M2[(len+1)/2];
                    446:     long long i;
                    447:     for (i = 0;i < len-1;i += 2) {
                    448:       uint32 m0 = M[i];
                    449:       uint32 r = R[i]+R[i+1]*m0;
                    450:       uint32 m = M[i+1]*m0;
                    451:       while (m >= 16384) {
                    452:         *out++ = r;
                    453:         r >>= 8;
                    454:         m = (m+255)>>8;
                    455:       }
                    456:       R2[i/2] = r;
                    457:       M2[i/2] = m;
                    458:     }
                    459:     if (i < len) {
                    460:       R2[i/2] = R[i];
                    461:       M2[i/2] = M[i];
                    462:     }
                    463:     Encode(out,R2,M2,(len+1)/2);
                    464:   }
                    465: }
                    466:
                    467: /* from supercop-20201130/crypto_kem/sntrup761/ref/kem.c */
                    468:
                    469: #ifdef LPR
                    470: #endif
                    471:
                    472:
                    473: /* ----- masks */
                    474:
                    475: #ifndef LPR
                    476:
                    477: /* return -1 if x!=0; else return 0 */
                    478: static int int16_nonzero_mask(int16 x)
                    479: {
                    480:   uint16 u = x; /* 0, else 1...65535 */
                    481:   uint32 v = u; /* 0, else 1...65535 */
                    482:   v = -v; /* 0, else 2^32-65535...2^32-1 */
                    483:   v >>= 31; /* 0, else 1 */
                    484:   return -v; /* 0, else -1 */
                    485: }
                    486:
                    487: #endif
                    488:
                    489: /* return -1 if x<0; otherwise return 0 */
                    490: static int int16_negative_mask(int16 x)
                    491: {
                    492:   uint16 u = x;
                    493:   u >>= 15;
                    494:   return -(int) u;
                    495:   /* alternative with gcc -fwrapv: */
                    496:   /* x>>15 compiles to CPU's arithmetic right shift */
                    497: }
                    498:
                    499: /* ----- arithmetic mod 3 */
                    500:
                    501: typedef int8 small;
                    502:
                    503: /* F3 is always represented as -1,0,1 */
                    504: /* so ZZ_fromF3 is a no-op */
                    505:
                    506: /* x must not be close to top int16 */
                    507: static small F3_freeze(int16 x)
                    508: {
                    509:   return int32_mod_uint14(x+1,3)-1;
                    510: }
                    511:
                    512: /* ----- arithmetic mod q */
                    513:
                    514: #define q12 ((q-1)/2)
                    515: typedef int16 Fq;
                    516: /* always represented as -q12...q12 */
                    517: /* so ZZ_fromFq is a no-op */
                    518:
                    519: /* x must not be close to top int32 */
                    520: static Fq Fq_freeze(int32 x)
                    521: {
                    522:   return int32_mod_uint14(x+q12,q)-q12;
                    523: }
                    524:
                    525: #ifndef LPR
                    526:
                    527: static Fq Fq_recip(Fq a1)
                    528: {
                    529:   int i = 1;
                    530:   Fq ai = a1;
                    531:
                    532:   while (i < q-2) {
                    533:     ai = Fq_freeze(a1*(int32)ai);
                    534:     i += 1;
                    535:   }
                    536:   return ai;
                    537: }
                    538:
                    539: #endif
                    540:
                    541: /* ----- Top and Right */
                    542:
                    543: #ifdef LPR
                    544: #define tau 16
                    545:
                    546: static int8 Top(Fq C)
                    547: {
                    548:   return (tau1*(int32)(C+tau0)+16384)>>15;
                    549: }
                    550:
                    551: static Fq Right(int8 T)
                    552: {
                    553:   return Fq_freeze(tau3*(int32)T-tau2);
                    554: }
                    555: #endif
                    556:
                    557: /* ----- small polynomials */
                    558:
                    559: #ifndef LPR
                    560:
                    561: /* 0 if Weightw_is(r), else -1 */
                    562: static int Weightw_mask(small *r)
                    563: {
                    564:   int weight = 0;
                    565:   int i;
                    566:
                    567:   for (i = 0;i < p;++i) weight += r[i]&1;
                    568:   return int16_nonzero_mask(weight-w);
                    569: }
                    570:
                    571: /* R3_fromR(R_fromRq(r)) */
                    572: static void R3_fromRq(small *out,const Fq *r)
                    573: {
                    574:   int i;
                    575:   for (i = 0;i < p;++i) out[i] = F3_freeze(r[i]);
                    576: }
                    577:
                    578: /* h = f*g in the ring R3 */
                    579: static void R3_mult(small *h,const small *f,const small *g)
                    580: {
                    581:   small fg[p+p-1];
                    582:   small result;
                    583:   int i,j;
                    584:
                    585:   for (i = 0;i < p;++i) {
                    586:     result = 0;
                    587:     for (j = 0;j <= i;++j) result = F3_freeze(result+f[j]*g[i-j]);
                    588:     fg[i] = result;
                    589:   }
                    590:   for (i = p;i < p+p-1;++i) {
                    591:     result = 0;
                    592:     for (j = i-p+1;j < p;++j) result = F3_freeze(result+f[j]*g[i-j]);
                    593:     fg[i] = result;
                    594:   }
                    595:
                    596:   for (i = p+p-2;i >= p;--i) {
                    597:     fg[i-p] = F3_freeze(fg[i-p]+fg[i]);
                    598:     fg[i-p+1] = F3_freeze(fg[i-p+1]+fg[i]);
                    599:   }
                    600:
                    601:   for (i = 0;i < p;++i) h[i] = fg[i];
                    602: }
                    603:
                    604: /* returns 0 if recip succeeded; else -1 */
                    605: static int R3_recip(small *out,const small *in)
                    606: {
                    607:   small f[p+1],g[p+1],v[p+1],r[p+1];
                    608:   int i,loop,delta;
                    609:   int sign,swap,t;
                    610:
                    611:   for (i = 0;i < p+1;++i) v[i] = 0;
                    612:   for (i = 0;i < p+1;++i) r[i] = 0;
                    613:   r[0] = 1;
                    614:   for (i = 0;i < p;++i) f[i] = 0;
                    615:   f[0] = 1; f[p-1] = f[p] = -1;
                    616:   for (i = 0;i < p;++i) g[p-1-i] = in[i];
                    617:   g[p] = 0;
                    618:
                    619:   delta = 1;
                    620:
                    621:   for (loop = 0;loop < 2*p-1;++loop) {
                    622:     for (i = p;i > 0;--i) v[i] = v[i-1];
                    623:     v[0] = 0;
                    624:
                    625:     sign = -g[0]*f[0];
                    626:     swap = int16_negative_mask(-delta) & int16_nonzero_mask(g[0]);
                    627:     delta ^= swap&(delta^-delta);
                    628:     delta += 1;
                    629:
                    630:     for (i = 0;i < p+1;++i) {
                    631:       t = swap&(f[i]^g[i]); f[i] ^= t; g[i] ^= t;
                    632:       t = swap&(v[i]^r[i]); v[i] ^= t; r[i] ^= t;
                    633:     }
                    634:
                    635:     for (i = 0;i < p+1;++i) g[i] = F3_freeze(g[i]+sign*f[i]);
                    636:     for (i = 0;i < p+1;++i) r[i] = F3_freeze(r[i]+sign*v[i]);
                    637:
                    638:     for (i = 0;i < p;++i) g[i] = g[i+1];
                    639:     g[p] = 0;
                    640:   }
                    641:
                    642:   sign = f[0];
                    643:   for (i = 0;i < p;++i) out[i] = sign*v[p-1-i];
                    644:
                    645:   return int16_nonzero_mask(delta);
                    646: }
                    647:
                    648: #endif
                    649:
                    650: /* ----- polynomials mod q */
                    651:
                    652: /* h = f*g in the ring Rq */
                    653: static void Rq_mult_small(Fq *h,const Fq *f,const small *g)
                    654: {
                    655:   Fq fg[p+p-1];
                    656:   Fq result;
                    657:   int i,j;
                    658:
                    659:   for (i = 0;i < p;++i) {
                    660:     result = 0;
                    661:     for (j = 0;j <= i;++j) result = Fq_freeze(result+f[j]*(int32)g[i-j]);
                    662:     fg[i] = result;
                    663:   }
                    664:   for (i = p;i < p+p-1;++i) {
                    665:     result = 0;
                    666:     for (j = i-p+1;j < p;++j) result = Fq_freeze(result+f[j]*(int32)g[i-j]);
                    667:     fg[i] = result;
                    668:   }
                    669:
                    670:   for (i = p+p-2;i >= p;--i) {
                    671:     fg[i-p] = Fq_freeze(fg[i-p]+fg[i]);
                    672:     fg[i-p+1] = Fq_freeze(fg[i-p+1]+fg[i]);
                    673:   }
                    674:
                    675:   for (i = 0;i < p;++i) h[i] = fg[i];
                    676: }
                    677:
                    678: #ifndef LPR
                    679:
                    680: /* h = 3f in Rq */
                    681: static void Rq_mult3(Fq *h,const Fq *f)
                    682: {
                    683:   int i;
                    684:
                    685:   for (i = 0;i < p;++i) h[i] = Fq_freeze(3*f[i]);
                    686: }
                    687:
/* out = 1/(3*in) in Rq */
/* returns 0 if recip succeeded; else -1 */
/*
 * Runs a fixed 2*p-1 iterations independent of the input and uses masks
 * instead of branches.
 *
 * Coefficients are held in reversed order (note g[p-1-i] = in[i]):
 *  - f starts as f[0] = 1, f[p-1] = f[p] = -1, i.e. x^p - x - 1 in that
 *    ordering;
 *  - g starts as the reversed input;
 *  - v and r appear to carry the accompanying cofactors, with r seeded by
 *    1/3 so that the result includes the 1/3 factor.
 *
 * A nonzero delta at the end signals failure.
 */
static int Rq_recip3(Fq *out,const small *in)
{
  Fq f[p+1],g[p+1],v[p+1],r[p+1];
  int i,loop,delta;
  int swap,t;
  int32 f0,g0;
  Fq scale;

  for (i = 0;i < p+1;++i) v[i] = 0;
  for (i = 0;i < p+1;++i) r[i] = 0;
  r[0] = Fq_recip(3);
  for (i = 0;i < p;++i) f[i] = 0;
  f[0] = 1; f[p-1] = f[p] = -1;
  for (i = 0;i < p;++i) g[p-1-i] = in[i];
  g[p] = 0;

  delta = 1;

  for (loop = 0;loop < 2*p-1;++loop) {
    /* shift v up by one position */
    for (i = p;i > 0;--i) v[i] = v[i-1];
    v[0] = 0;

    /* presumably all-ones when delta > 0 and g[0] != 0, else 0 */
    swap = int16_negative_mask(-delta) & int16_nonzero_mask(g[0]);
    /* conditionally negate delta when swapping */
    delta ^= swap&(delta^-delta);
    delta += 1;

    /* masked (branch-free) swap of f<->g and v<->r */
    for (i = 0;i < p+1;++i) {
      t = swap&(f[i]^g[i]); f[i] ^= t; g[i] ^= t;
      t = swap&(v[i]^r[i]); v[i] ^= t; r[i] ^= t;
    }

    /* g = f0*g - g0*f makes g[0] zero; apply the same combination to r */
    f0 = f[0];
    g0 = g[0];
    for (i = 0;i < p+1;++i) g[i] = Fq_freeze(f0*g[i]-g0*f[i]);
    for (i = 0;i < p+1;++i) r[i] = Fq_freeze(f0*r[i]-g0*v[i]);

    /* g[0] is now zero, so shift g down by one position */
    for (i = 0;i < p;++i) g[i] = g[i+1];
    g[p] = 0;
  }

  /* normalize by 1/f[0] and undo the reversed coefficient order */
  scale = Fq_recip(f[0]);
  for (i = 0;i < p;++i) out[i] = Fq_freeze(scale*(int32)v[p-1-i]);

  return int16_nonzero_mask(delta);
}
                    735:
                    736: #endif
                    737:
                    738: /* ----- rounded polynomials mod q */
                    739:
                    740: static void Round(Fq *out,const Fq *a)
                    741: {
                    742:   int i;
                    743:   for (i = 0;i < p;++i) out[i] = a[i]-F3_freeze(a[i]);
                    744: }
                    745:
                    746: /* ----- sorting to generate short polynomial */
                    747:
                    748: static void Short_fromlist(small *out,const uint32 *in)
                    749: {
                    750:   uint32 L[p];
                    751:   int i;
                    752:
                    753:   for (i = 0;i < w;++i) L[i] = in[i]&(uint32)-2;
                    754:   for (i = w;i < p;++i) L[i] = (in[i]&(uint32)-3)|1;
                    755:   crypto_sort_uint32(L,p);
                    756:   for (i = 0;i < p;++i) out[i] = (L[i]&3)-1;
                    757: }
                    758:
                    759: /* ----- underlying hash function */
                    760:
                    761: #define Hash_bytes 32
                    762:
                    763: /* e.g., b = 0 means out = Hash0(in) */
                    764: static void Hash_prefix(unsigned char *out,int b,const unsigned char *in,int inlen)
                    765: {
                    766:   unsigned char x[inlen+1];
                    767:   unsigned char h[64];
                    768:   int i;
                    769:
                    770:   x[0] = b;
                    771:   for (i = 0;i < inlen;++i) x[i+1] = in[i];
                    772:   crypto_hash_sha512(h,x,inlen+1);
                    773:   for (i = 0;i < 32;++i) out[i] = h[i];
                    774: }
                    775:
                    776: /* ----- higher-level randomness */
                    777:
                    778: static uint32 urandom32(void)
                    779: {
                    780:   unsigned char c[4];
                    781:   uint32 out[4];
                    782:
                    783:   randombytes(c,4);
                    784:   out[0] = (uint32)c[0];
                    785:   out[1] = ((uint32)c[1])<<8;
                    786:   out[2] = ((uint32)c[2])<<16;
                    787:   out[3] = ((uint32)c[3])<<24;
                    788:   return out[0]+out[1]+out[2]+out[3];
                    789: }
                    790:
                    791: static void Short_random(small *out)
                    792: {
                    793:   uint32 L[p];
                    794:   int i;
                    795:
                    796:   for (i = 0;i < p;++i) L[i] = urandom32();
                    797:   Short_fromlist(out,L);
                    798: }
                    799:
                    800: #ifndef LPR
                    801:
                    802: static void Small_random(small *out)
                    803: {
                    804:   int i;
                    805:
                    806:   for (i = 0;i < p;++i) out[i] = (((urandom32()&0x3fffffff)*3)>>30)-1;
                    807: }
                    808:
                    809: #endif
                    810:
                    811: /* ----- Streamlined NTRU Prime Core */
                    812:
                    813: #ifndef LPR
                    814:
                    815: /* h,(f,ginv) = KeyGen() */
                    816: static void KeyGen(Fq *h,small *f,small *ginv)
                    817: {
                    818:   small g[p];
                    819:   Fq finv[p];
                    820:
                    821:   for (;;) {
                    822:     Small_random(g);
                    823:     if (R3_recip(ginv,g) == 0) break;
                    824:   }
                    825:   Short_random(f);
                    826:   Rq_recip3(finv,f); /* always works */
                    827:   Rq_mult_small(h,finv,g);
                    828: }
                    829:
                    830: /* c = Encrypt(r,h) */
                    831: static void Encrypt(Fq *c,const small *r,const Fq *h)
                    832: {
                    833:   Fq hr[p];
                    834:
                    835:   Rq_mult_small(hr,h,r);
                    836:   Round(c,hr);
                    837: }
                    838:
/* r = Decrypt(c,(f,ginv)) */
/*
 * Computes:
 *  - e  = R3_fromRq(3*(c*f));
 *  - ev = e*ginv in R3.
 *
 * If ev has weight w it is returned as r. Otherwise r is replaced, without
 * branching, by the fixed polynomial with coefficient 1 at positions
 * 0..w-1 and 0 elsewhere.
 */
static void Decrypt(small *r,const Fq *c,const small *f,const small *ginv)
{
  Fq cf[p];
  Fq cf3[p];
  small e[p];
  small ev[p];
  int mask;
  int i;

  Rq_mult_small(cf,c,f);
  Rq_mult3(cf3,cf);
  R3_fromRq(e,cf3);
  R3_mult(ev,e,ginv);

  mask = Weightw_mask(ev); /* 0 if weight w, else -1 */
  /* mask == 0: r[i] = ev[i]; mask == -1: r[i] = 1 */
  for (i = 0;i < w;++i) r[i] = ((ev[i]^1)&~mask)^1;
  /* mask == 0: r[i] = ev[i]; mask == -1: r[i] = 0 */
  for (i = w;i < p;++i) r[i] = ev[i]&~mask;
}
                    858:
                    859: #endif
                    860:
                    861: /* ----- NTRU LPRime Core */
                    862:
                    863: #ifdef LPR
                    864:
                    865: /* (G,A),a = KeyGen(G); leaves G unchanged */
                    866: static void KeyGen(Fq *A,small *a,const Fq *G)
                    867: {
                    868:   Fq aG[p];
                    869:
                    870:   Short_random(a);
                    871:   Rq_mult_small(aG,G,a);
                    872:   Round(A,aG);
                    873: }
                    874:
                    875: /* B,T = Encrypt(r,(G,A),b) */
                    876: static void Encrypt(Fq *B,int8 *T,const int8 *r,const Fq *G,const Fq *A,const small *b)
                    877: {
                    878:   Fq bG[p];
                    879:   Fq bA[p];
                    880:   int i;
                    881:
                    882:   Rq_mult_small(bG,G,b);
                    883:   Round(B,bG);
                    884:   Rq_mult_small(bA,A,b);
                    885:   for (i = 0;i < I;++i) T[i] = Top(Fq_freeze(bA[i]+r[i]*q12));
                    886: }
                    887:
/* r = Decrypt((B,T),a) */
/*
 * For each of the first I coefficients, reduce Right(T[i]) - (a*B)[i]
 * + 4*w + 1 with Fq_freeze. Set r[i] = 1 when that value is negative and
 * 0 otherwise, assuming int16_negative_mask yields -1 for negative input
 * and 0 otherwise. No data-dependent branch is taken.
 */
static void Decrypt(int8 *r,const Fq *B,const int8 *T,const small *a)
{
  Fq aB[p];
  int i;

  Rq_mult_small(aB,B,a);
  for (i = 0;i < I;++i)
    r[i] = -int16_negative_mask(Fq_freeze(Right(T[i])-aB[i]+4*w+1));
}
                    898:
                    899: #endif
                    900:
                    901: /* ----- encoding I-bit inputs */
                    902:
                    903: #ifdef LPR
                    904:
                    905: #define Inputs_bytes (I/8)
                    906: typedef int8 Inputs[I]; /* passed by reference */
                    907:
                    908: static void Inputs_encode(unsigned char *s,const Inputs r)
                    909: {
                    910:   int i;
                    911:   for (i = 0;i < Inputs_bytes;++i) s[i] = 0;
                    912:   for (i = 0;i < I;++i) s[i>>3] |= r[i]<<(i&7);
                    913: }
                    914:
                    915: #endif
                    916:
                    917: /* ----- Expand */
                    918:
                    919: #ifdef LPR
                    920:
                    921: static const unsigned char aes_nonce[16] = {0};
                    922:
                    923: static void Expand(uint32 *L,const unsigned char *k)
                    924: {
                    925:   int i;
                    926:   crypto_stream_aes256ctr((unsigned char *) L,4*p,aes_nonce,k);
                    927:   for (i = 0;i < p;++i) {
                    928:     uint32 L0 = ((unsigned char *) L)[4*i];
                    929:     uint32 L1 = ((unsigned char *) L)[4*i+1];
                    930:     uint32 L2 = ((unsigned char *) L)[4*i+2];
                    931:     uint32 L3 = ((unsigned char *) L)[4*i+3];
                    932:     L[i] = L0+(L1<<8)+(L2<<16)+(L3<<24);
                    933:   }
                    934: }
                    935:
                    936: #endif
                    937:
                    938: /* ----- Seeds */
                    939:
                    940: #ifdef LPR
                    941:
                    942: #define Seeds_bytes 32
                    943:
                    944: static void Seeds_random(unsigned char *s)
                    945: {
                    946:   randombytes(s,Seeds_bytes);
                    947: }
                    948:
                    949: #endif
                    950:
                    951: /* ----- Generator, HashShort */
                    952:
                    953: #ifdef LPR
                    954:
                    955: /* G = Generator(k) */
                    956: static void Generator(Fq *G,const unsigned char *k)
                    957: {
                    958:   uint32 L[p];
                    959:   int i;
                    960:
                    961:   Expand(L,k);
                    962:   for (i = 0;i < p;++i) G[i] = uint32_mod_uint14(L[i],q)-q12;
                    963: }
                    964:
                    965: /* out = HashShort(r) */
                    966: static void HashShort(small *out,const Inputs r)
                    967: {
                    968:   unsigned char s[Inputs_bytes];
                    969:   unsigned char h[Hash_bytes];
                    970:   uint32 L[p];
                    971:
                    972:   Inputs_encode(s,r);
                    973:   Hash_prefix(h,5,s,sizeof s);
                    974:   Expand(L,h);
                    975:   Short_fromlist(out,L);
                    976: }
                    977:
                    978: #endif
                    979:
                    980: /* ----- NTRU LPRime Expand */
                    981:
                    982: #ifdef LPR
                    983:
                    984: /* (S,A),a = XKeyGen() */
                    985: static void XKeyGen(unsigned char *S,Fq *A,small *a)
                    986: {
                    987:   Fq G[p];
                    988:
                    989:   Seeds_random(S);
                    990:   Generator(G,S);
                    991:   KeyGen(A,a,G);
                    992: }
                    993:
                    994: /* B,T = XEncrypt(r,(S,A)) */
                    995: static void XEncrypt(Fq *B,int8 *T,const int8 *r,const unsigned char *S,const Fq *A)
                    996: {
                    997:   Fq G[p];
                    998:   small b[p];
                    999:
                   1000:   Generator(G,S);
                   1001:   HashShort(b,r);
                   1002:   Encrypt(B,T,r,G,A,b);
                   1003: }
                   1004:
                   1005: #define XDecrypt Decrypt
                   1006:
                   1007: #endif
                   1008:
                   1009: /* ----- encoding small polynomials (including short polynomials) */
                   1010:
                   1011: #define Small_bytes ((p+3)/4)
                   1012:
                   1013: /* these are the only functions that rely on p mod 4 = 1 */
                   1014:
                   1015: static void Small_encode(unsigned char *s,const small *f)
                   1016: {
                   1017:   small x;
                   1018:   int i;
                   1019:
                   1020:   for (i = 0;i < p/4;++i) {
                   1021:     x = *f++ + 1;
                   1022:     x += (*f++ + 1)<<2;
                   1023:     x += (*f++ + 1)<<4;
                   1024:     x += (*f++ + 1)<<6;
                   1025:     *s++ = x;
                   1026:   }
                   1027:   x = *f++ + 1;
                   1028:   *s++ = x;
                   1029: }
                   1030:
                   1031: static void Small_decode(small *f,const unsigned char *s)
                   1032: {
                   1033:   unsigned char x;
                   1034:   int i;
                   1035:
                   1036:   for (i = 0;i < p/4;++i) {
                   1037:     x = *s++;
                   1038:     *f++ = ((small)(x&3))-1; x >>= 2;
                   1039:     *f++ = ((small)(x&3))-1; x >>= 2;
                   1040:     *f++ = ((small)(x&3))-1; x >>= 2;
                   1041:     *f++ = ((small)(x&3))-1;
                   1042:   }
                   1043:   x = *s++;
                   1044:   *f++ = ((small)(x&3))-1;
                   1045: }
                   1046:
                   1047: /* ----- encoding general polynomials */
                   1048:
                   1049: #ifndef LPR
                   1050:
                   1051: static void Rq_encode(unsigned char *s,const Fq *r)
                   1052: {
                   1053:   uint16 R[p],M[p];
                   1054:   int i;
                   1055:
                   1056:   for (i = 0;i < p;++i) R[i] = r[i]+q12;
                   1057:   for (i = 0;i < p;++i) M[i] = q;
                   1058:   Encode(s,R,M,p);
                   1059: }
                   1060:
                   1061: static void Rq_decode(Fq *r,const unsigned char *s)
                   1062: {
                   1063:   uint16 R[p],M[p];
                   1064:   int i;
                   1065:
                   1066:   for (i = 0;i < p;++i) M[i] = q;
                   1067:   Decode(R,s,M,p);
                   1068:   for (i = 0;i < p;++i) r[i] = ((Fq)R[i])-q12;
                   1069: }
                   1070:
                   1071: #endif
                   1072:
                   1073: /* ----- encoding rounded polynomials */
                   1074:
                   1075: static void Rounded_encode(unsigned char *s,const Fq *r)
                   1076: {
                   1077:   uint16 R[p],M[p];
                   1078:   int i;
                   1079:
                   1080:   for (i = 0;i < p;++i) R[i] = ((r[i]+q12)*10923)>>15;
                   1081:   for (i = 0;i < p;++i) M[i] = (q+2)/3;
                   1082:   Encode(s,R,M,p);
                   1083: }
                   1084:
                   1085: static void Rounded_decode(Fq *r,const unsigned char *s)
                   1086: {
                   1087:   uint16 R[p],M[p];
                   1088:   int i;
                   1089:
                   1090:   for (i = 0;i < p;++i) M[i] = (q+2)/3;
                   1091:   Decode(R,s,M,p);
                   1092:   for (i = 0;i < p;++i) r[i] = R[i]*3-q12;
                   1093: }
                   1094:
                   1095: /* ----- encoding top polynomials */
                   1096:
                   1097: #ifdef LPR
                   1098:
                   1099: #define Top_bytes (I/2)
                   1100:
                   1101: static void Top_encode(unsigned char *s,const int8 *T)
                   1102: {
                   1103:   int i;
                   1104:   for (i = 0;i < Top_bytes;++i)
                   1105:     s[i] = T[2*i]+(T[2*i+1]<<4);
                   1106: }
                   1107:
                   1108: static void Top_decode(int8 *T,const unsigned char *s)
                   1109: {
                   1110:   int i;
                   1111:   for (i = 0;i < Top_bytes;++i) {
                   1112:     T[2*i] = s[i]&15;
                   1113:     T[2*i+1] = s[i]>>4;
                   1114:   }
                   1115: }
                   1116:
                   1117: #endif
                   1118:
                   1119: /* ----- Streamlined NTRU Prime Core plus encoding */
                   1120:
                   1121: #ifndef LPR
                   1122:
/* Streamlined NTRU Prime: the encrypted input is a short polynomial. */
typedef small Inputs[p]; /* passed by reference */
#define Inputs_random Short_random
#define Inputs_encode Small_encode
#define Inputs_bytes Small_bytes

/* Wire sizes:
   - ciphertext: one rounded polynomial;
   - secret key: two Small encodings (f and ginv, as written by ZKeyGen);
   - public key: h encoded in Rq. */
#define Ciphertexts_bytes Rounded_bytes
#define SecretKeys_bytes (2*Small_bytes)
#define PublicKeys_bytes Rq_bytes
                   1131:
                   1132: /* pk,sk = ZKeyGen() */
                   1133: static void ZKeyGen(unsigned char *pk,unsigned char *sk)
                   1134: {
                   1135:   Fq h[p];
                   1136:   small f[p],v[p];
                   1137:
                   1138:   KeyGen(h,f,v);
                   1139:   Rq_encode(pk,h);
                   1140:   Small_encode(sk,f); sk += Small_bytes;
                   1141:   Small_encode(sk,v);
                   1142: }
                   1143:
                   1144: /* C = ZEncrypt(r,pk) */
                   1145: static void ZEncrypt(unsigned char *C,const Inputs r,const unsigned char *pk)
                   1146: {
                   1147:   Fq h[p];
                   1148:   Fq c[p];
                   1149:   Rq_decode(h,pk);
                   1150:   Encrypt(c,r,h);
                   1151:   Rounded_encode(C,c);
                   1152: }
                   1153:
                   1154: /* r = ZDecrypt(C,sk) */
                   1155: static void ZDecrypt(Inputs r,const unsigned char *C,const unsigned char *sk)
                   1156: {
                   1157:   small f[p],v[p];
                   1158:   Fq c[p];
                   1159:
                   1160:   Small_decode(f,sk); sk += Small_bytes;
                   1161:   Small_decode(v,sk);
                   1162:   Rounded_decode(c,C);
                   1163:   Decrypt(r,c,f,v);
                   1164: }
                   1165:
                   1166: #endif
                   1167:
                   1168: /* ----- NTRU LPRime Expand plus encoding */
                   1169:
                   1170: #ifdef LPR
                   1171:
/* NTRU LPRime wire sizes:
   - ciphertext: rounded B followed by Top-encoded T;
   - secret key: Small encoding of a;
   - public key: seed S followed by rounded A. */
#define Ciphertexts_bytes (Rounded_bytes+Top_bytes)
#define SecretKeys_bytes Small_bytes
#define PublicKeys_bytes (Seeds_bytes+Rounded_bytes)
                   1175:
                   1176: static void Inputs_random(Inputs r)
                   1177: {
                   1178:   unsigned char s[Inputs_bytes];
                   1179:   int i;
                   1180:
                   1181:   randombytes(s,sizeof s);
                   1182:   for (i = 0;i < I;++i) r[i] = 1&(s[i>>3]>>(i&7));
                   1183: }
                   1184:
                   1185: /* pk,sk = ZKeyGen() */
                   1186: static void ZKeyGen(unsigned char *pk,unsigned char *sk)
                   1187: {
                   1188:   Fq A[p];
                   1189:   small a[p];
                   1190:
                   1191:   XKeyGen(pk,A,a); pk += Seeds_bytes;
                   1192:   Rounded_encode(pk,A);
                   1193:   Small_encode(sk,a);
                   1194: }
                   1195:
                   1196: /* c = ZEncrypt(r,pk) */
                   1197: static void ZEncrypt(unsigned char *c,const Inputs r,const unsigned char *pk)
                   1198: {
                   1199:   Fq A[p];
                   1200:   Fq B[p];
                   1201:   int8 T[I];
                   1202:
                   1203:   Rounded_decode(A,pk+Seeds_bytes);
                   1204:   XEncrypt(B,T,r,pk,A);
                   1205:   Rounded_encode(c,B); c += Rounded_bytes;
                   1206:   Top_encode(c,T);
                   1207: }
                   1208:
                   1209: /* r = ZDecrypt(C,sk) */
                   1210: static void ZDecrypt(Inputs r,const unsigned char *c,const unsigned char *sk)
                   1211: {
                   1212:   small a[p];
                   1213:   Fq B[p];
                   1214:   int8 T[I];
                   1215:
                   1216:   Small_decode(a,sk);
                   1217:   Rounded_decode(B,c);
                   1218:   Top_decode(T,c+Rounded_bytes);
                   1219:   XDecrypt(r,B,T,a);
                   1220: }
                   1221:
                   1222: #endif
                   1223:
                   1224: /* ----- confirmation hash */
                   1225:
                   1226: #define Confirm_bytes 32
                   1227:
                   1228: /* h = HashConfirm(r,pk,cache); cache is Hash4(pk) */
                   1229: static void HashConfirm(unsigned char *h,const unsigned char *r,const unsigned char *pk,const unsigned char *cache)
                   1230: {
                   1231: #ifndef LPR
                   1232:   unsigned char x[Hash_bytes*2];
                   1233:   int i;
                   1234:
                   1235:   Hash_prefix(x,3,r,Inputs_bytes);
                   1236:   for (i = 0;i < Hash_bytes;++i) x[Hash_bytes+i] = cache[i];
                   1237: #else
                   1238:   unsigned char x[Inputs_bytes+Hash_bytes];
                   1239:   int i;
                   1240:
                   1241:   for (i = 0;i < Inputs_bytes;++i) x[i] = r[i];
                   1242:   for (i = 0;i < Hash_bytes;++i) x[Inputs_bytes+i] = cache[i];
                   1243: #endif
                   1244:   Hash_prefix(h,2,x,sizeof x);
                   1245: }
                   1246:
                   1247: /* ----- session-key hash */
                   1248:
                   1249: /* k = HashSession(b,y,z) */
                   1250: static void HashSession(unsigned char *k,int b,const unsigned char *y,const unsigned char *z)
                   1251: {
                   1252: #ifndef LPR
                   1253:   unsigned char x[Hash_bytes+Ciphertexts_bytes+Confirm_bytes];
                   1254:   int i;
                   1255:
                   1256:   Hash_prefix(x,3,y,Inputs_bytes);
                   1257:   for (i = 0;i < Ciphertexts_bytes+Confirm_bytes;++i) x[Hash_bytes+i] = z[i];
                   1258: #else
                   1259:   unsigned char x[Inputs_bytes+Ciphertexts_bytes+Confirm_bytes];
                   1260:   int i;
                   1261:
                   1262:   for (i = 0;i < Inputs_bytes;++i) x[i] = y[i];
                   1263:   for (i = 0;i < Ciphertexts_bytes+Confirm_bytes;++i) x[Inputs_bytes+i] = z[i];
                   1264: #endif
                   1265:   Hash_prefix(k,b,x,sizeof x);
                   1266: }
                   1267:
                   1268: /* ----- Streamlined NTRU Prime and NTRU LPRime */
                   1269:
                   1270: /* pk,sk = KEM_KeyGen() */
                   1271: static void KEM_KeyGen(unsigned char *pk,unsigned char *sk)
                   1272: {
                   1273:   int i;
                   1274:
                   1275:   ZKeyGen(pk,sk); sk += SecretKeys_bytes;
                   1276:   for (i = 0;i < PublicKeys_bytes;++i) *sk++ = pk[i];
                   1277:   randombytes(sk,Inputs_bytes); sk += Inputs_bytes;
                   1278:   Hash_prefix(sk,4,pk,PublicKeys_bytes);
                   1279: }
                   1280:
                   1281: /* c,r_enc = Hide(r,pk,cache); cache is Hash4(pk) */
                   1282: static void Hide(unsigned char *c,unsigned char *r_enc,const Inputs r,const unsigned char *pk,const unsigned char *cache)
                   1283: {
                   1284:   Inputs_encode(r_enc,r);
                   1285:   ZEncrypt(c,r,pk); c += Ciphertexts_bytes;
                   1286:   HashConfirm(c,r_enc,pk,cache);
                   1287: }
                   1288:
                   1289: /* c,k = Encap(pk) */
                   1290: static void Encap(unsigned char *c,unsigned char *k,const unsigned char *pk)
                   1291: {
                   1292:   Inputs r;
                   1293:   unsigned char r_enc[Inputs_bytes];
                   1294:   unsigned char cache[Hash_bytes];
                   1295:
                   1296:   Hash_prefix(cache,4,pk,PublicKeys_bytes);
                   1297:   Inputs_random(r);
                   1298:   Hide(c,r_enc,r,pk,cache);
                   1299:   HashSession(k,1,r_enc,c);
                   1300: }
                   1301:
                   1302: /* 0 if matching ciphertext+confirm, else -1 */
                   1303: static int Ciphertexts_diff_mask(const unsigned char *c,const unsigned char *c2)
                   1304: {
                   1305:   uint16 differentbits = 0;
                   1306:   int len = Ciphertexts_bytes+Confirm_bytes;
                   1307:
                   1308:   while (len-- > 0) differentbits |= (*c++)^(*c2++);
                   1309:   return (1&((differentbits-1)>>8))-1;
                   1310: }
                   1311:
/* k = Decap(c,sk) */
/*
 * sk layout, as written by KEM_KeyGen:
 *  - ZKeyGen secret key (SecretKeys_bytes);
 *  - pk (PublicKeys_bytes);
 *  - rho (Inputs_bytes);
 *  - cache = Hash4(pk).
 *
 * Decrypts c, re-encrypts the result with Hide and compares against c.
 *  - On a match: k = HashSession(1, r_enc, c).
 *  - On a mismatch: r_enc is replaced by rho and k = HashSession(0, rho, c).
 * The selection uses a mask, not a branch.
 */
static void Decap(unsigned char *k,const unsigned char *c,const unsigned char *sk)
{
  const unsigned char *pk = sk + SecretKeys_bytes;
  const unsigned char *rho = pk + PublicKeys_bytes;
  const unsigned char *cache = rho + Inputs_bytes;
  Inputs r;
  unsigned char r_enc[Inputs_bytes];
  unsigned char cnew[Ciphertexts_bytes+Confirm_bytes];
  int mask;
  int i;

  ZDecrypt(r,c,sk);
  Hide(cnew,r_enc,r,pk,cache);
  mask = Ciphertexts_diff_mask(c,cnew); /* 0 on match, -1 otherwise */
  /* r_enc = (mask == -1) ? rho : r_enc, byte by byte */
  for (i = 0;i < Inputs_bytes;++i) r_enc[i] ^= mask&(r_enc[i]^rho[i]);
  /* hash prefix 1 on match, 0 on mismatch */
  HashSession(k,1+mask,r_enc,c);
}
                   1330:
                   1331: /* ----- crypto_kem API */
                   1332:
                   1333:
                   1334: int crypto_kem_sntrup761_keypair(unsigned char *pk,unsigned char *sk)
                   1335: {
                   1336:   KEM_KeyGen(pk,sk);
                   1337:   return 0;
                   1338: }
                   1339:
                   1340: int crypto_kem_sntrup761_enc(unsigned char *c,unsigned char *k,const unsigned char *pk)
                   1341: {
                   1342:   Encap(c,k,pk);
                   1343:   return 0;
                   1344: }
                   1345:
                   1346: int crypto_kem_sntrup761_dec(unsigned char *k,const unsigned char *c,const unsigned char *sk)
                   1347: {
                   1348:   Decap(k,c,sk);
                   1349:   return 0;
                   1350: }
                   1351: