[BACK]Return to sshkey-xmss.c CVS log [TXT][DIR] Up to [local] / src / usr.bin / ssh

Annotation of src/usr.bin/ssh/sshkey-xmss.c, Revision 1.11

1.11    ! djm         1: /* $OpenBSD: sshkey-xmss.c,v 1.10 2021/03/06 20:36:31 millert Exp $ */
1.1       markus      2: /*
                      3:  * Copyright (c) 2017 Markus Friedl.  All rights reserved.
                      4:  *
                      5:  * Redistribution and use in source and binary forms, with or without
                      6:  * modification, are permitted provided that the following conditions
                      7:  * are met:
                      8:  * 1. Redistributions of source code must retain the above copyright
                      9:  *    notice, this list of conditions and the following disclaimer.
                     10:  * 2. Redistributions in binary form must reproduce the above copyright
                     11:  *    notice, this list of conditions and the following disclaimer in the
                     12:  *    documentation and/or other materials provided with the distribution.
                     13:  *
                     14:  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
                     15:  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
                     16:  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
                     17:  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
                     18:  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
                     19:  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
                     20:  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
                     21:  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
                     22:  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
                     23:  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
                     24:  */
                     25:
                     26: #include <sys/types.h>
                     27: #include <sys/uio.h>
                     28:
                     29: #include <stdio.h>
                     30: #include <string.h>
                     31: #include <unistd.h>
                     32: #include <fcntl.h>
                     33: #include <errno.h>
                     34:
                     35: #include "ssh2.h"
                     36: #include "ssherr.h"
                     37: #include "sshbuf.h"
                     38: #include "cipher.h"
                     39: #include "sshkey.h"
                     40: #include "sshkey-xmss.h"
                     41: #include "atomicio.h"
1.9       dtucker    42: #include "log.h"
1.1       markus     43:
                     44: #include "xmss_fast.h"
                     45:
                     46: /* opaque internal XMSS state */
                     47: #define XMSS_MAGIC             "xmss-state-v1"
                     48: #define XMSS_CIPHERNAME                "aes256-gcm@openssh.com"
                     49: struct ssh_xmss_state {
                     50:        xmss_params     params;
                     51:        u_int32_t       n, w, h, k;
                     52:
                     53:        bds_state       bds;
                     54:        u_char          *stack;
                     55:        u_int32_t       stackoffset;
                     56:        u_char          *stacklevels;
                     57:        u_char          *auth;
                     58:        u_char          *keep;
                     59:        u_char          *th_nodes;
                     60:        u_char          *retain;
                     61:        treehash_inst   *treehash;
                     62:
                     63:        u_int32_t       idx;            /* state read from file */
1.2       djm        64:        u_int32_t       maxidx;         /* restricted # of signatures */
1.1       markus     65:        int             have_state;     /* .state file exists */
                     66:        int             lockfd;         /* locked in sshkey_xmss_get_state() */
1.8       markus     67:        u_char          allow_update;   /* allow sshkey_xmss_update_state() */
1.1       markus     68:        char            *enc_ciphername;/* encrypt state with cipher */
                     69:        u_char          *enc_keyiv;     /* encrypt state with key */
                     70:        u_int32_t       enc_keyiv_len;  /* length of enc_keyiv */
                     71: };
                     72:
                     73: int     sshkey_xmss_init_bds_state(struct sshkey *);
                     74: int     sshkey_xmss_init_enc_key(struct sshkey *, const char *);
                     75: void    sshkey_xmss_free_bds(struct sshkey *);
                     76: int     sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
1.9       dtucker    77:            int *, int);
1.1       markus     78: int     sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
                     79:            struct sshbuf **);
                     80: int     sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
                     81:            struct sshbuf **);
                     82: int     sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
                     83: int     sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);
                     84:
1.9       dtucker    85: #define PRINT(...) do { if (printerror) sshlog(__FILE__, __func__, __LINE__, \
1.10      millert    86:     0, SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__); } while (0)
1.1       markus     87:
                     88: int
                     89: sshkey_xmss_init(struct sshkey *key, const char *name)
                     90: {
                     91:        struct ssh_xmss_state *state;
                     92:
                     93:        if (key->xmss_state != NULL)
                     94:                return SSH_ERR_INVALID_FORMAT;
                     95:        if (name == NULL)
                     96:                return SSH_ERR_INVALID_FORMAT;
                     97:        state = calloc(sizeof(struct ssh_xmss_state), 1);
                     98:        if (state == NULL)
                     99:                return SSH_ERR_ALLOC_FAIL;
                    100:        if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
                    101:                state->n = 32;
                    102:                state->w = 16;
                    103:                state->h = 10;
                    104:        } else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
                    105:                state->n = 32;
                    106:                state->w = 16;
                    107:                state->h = 16;
                    108:        } else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
                    109:                state->n = 32;
                    110:                state->w = 16;
                    111:                state->h = 20;
                    112:        } else {
                    113:                free(state);
                    114:                return SSH_ERR_KEY_TYPE_UNKNOWN;
                    115:        }
                    116:        if ((key->xmss_name = strdup(name)) == NULL) {
                    117:                free(state);
                    118:                return SSH_ERR_ALLOC_FAIL;
                    119:        }
                    120:        state->k = 2;   /* XXX hardcoded */
                    121:        state->lockfd = -1;
                    122:        if (xmss_set_params(&state->params, state->n, state->h, state->w,
                    123:            state->k) != 0) {
                    124:                free(state);
                    125:                return SSH_ERR_INVALID_FORMAT;
                    126:        }
                    127:        key->xmss_state = state;
                    128:        return 0;
                    129: }
                    130:
                    131: void
                    132: sshkey_xmss_free_state(struct sshkey *key)
                    133: {
                    134:        struct ssh_xmss_state *state = key->xmss_state;
                    135:
                    136:        sshkey_xmss_free_bds(key);
                    137:        if (state) {
                    138:                if (state->enc_keyiv) {
                    139:                        explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
                    140:                        free(state->enc_keyiv);
                    141:                }
                    142:                free(state->enc_ciphername);
                    143:                free(state);
                    144:        }
                    145:        key->xmss_state = NULL;
                    146: }
                    147:
                    148: #define SSH_XMSS_K2_MAGIC      "k=2"
                    149: #define num_stack(x)           ((x->h+1)*(x->n))
                    150: #define num_stacklevels(x)     (x->h+1)
                    151: #define num_auth(x)            ((x->h)*(x->n))
                    152: #define num_keep(x)            ((x->h >> 1)*(x->n))
                    153: #define num_th_nodes(x)                ((x->h - x->k)*(x->n))
                    154: #define num_retain(x)          (((1ULL << x->k) - x->k - 1) * (x->n))
                    155: #define num_treehash(x)                ((x->h) - (x->k))
                    156:
                    157: int
                    158: sshkey_xmss_init_bds_state(struct sshkey *key)
                    159: {
                    160:        struct ssh_xmss_state *state = key->xmss_state;
                    161:        u_int32_t i;
                    162:
                    163:        state->stackoffset = 0;
                    164:        if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
                    165:            (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
                    166:            (state->auth = calloc(num_auth(state), 1)) == NULL ||
                    167:            (state->keep = calloc(num_keep(state), 1)) == NULL ||
                    168:            (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
                    169:            (state->retain = calloc(num_retain(state), 1)) == NULL ||
                    170:            (state->treehash = calloc(num_treehash(state),
                    171:            sizeof(treehash_inst))) == NULL) {
                    172:                sshkey_xmss_free_bds(key);
                    173:                return SSH_ERR_ALLOC_FAIL;
                    174:        }
                    175:        for (i = 0; i < state->h - state->k; i++)
                    176:                state->treehash[i].node = &state->th_nodes[state->n*i];
                    177:        xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
                    178:            state->stacklevels, state->auth, state->keep, state->treehash,
                    179:            state->retain, 0);
                    180:        return 0;
                    181: }
                    182:
                    183: void
                    184: sshkey_xmss_free_bds(struct sshkey *key)
                    185: {
                    186:        struct ssh_xmss_state *state = key->xmss_state;
                    187:
                    188:        if (state == NULL)
                    189:                return;
                    190:        free(state->stack);
                    191:        free(state->stacklevels);
                    192:        free(state->auth);
                    193:        free(state->keep);
                    194:        free(state->th_nodes);
                    195:        free(state->retain);
                    196:        free(state->treehash);
                    197:        state->stack = NULL;
                    198:        state->stacklevels = NULL;
                    199:        state->auth = NULL;
                    200:        state->keep = NULL;
                    201:        state->th_nodes = NULL;
                    202:        state->retain = NULL;
                    203:        state->treehash = NULL;
                    204: }
                    205:
                    206: void *
                    207: sshkey_xmss_params(const struct sshkey *key)
                    208: {
                    209:        struct ssh_xmss_state *state = key->xmss_state;
                    210:
                    211:        if (state == NULL)
                    212:                return NULL;
                    213:        return &state->params;
                    214: }
                    215:
                    216: void *
                    217: sshkey_xmss_bds_state(const struct sshkey *key)
                    218: {
                    219:        struct ssh_xmss_state *state = key->xmss_state;
                    220:
                    221:        if (state == NULL)
                    222:                return NULL;
                    223:        return &state->bds;
                    224: }
                    225:
                    226: int
                    227: sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
                    228: {
                    229:        struct ssh_xmss_state *state = key->xmss_state;
                    230:
                    231:        if (lenp == NULL)
                    232:                return SSH_ERR_INVALID_ARGUMENT;
                    233:        if (state == NULL)
                    234:                return SSH_ERR_INVALID_FORMAT;
                    235:        *lenp = 4 + state->n +
                    236:            state->params.wots_par.keysize +
                    237:            state->h * state->n;
                    238:        return 0;
                    239: }
                    240:
                    241: size_t
                    242: sshkey_xmss_pklen(const struct sshkey *key)
                    243: {
                    244:        struct ssh_xmss_state *state = key->xmss_state;
                    245:
                    246:        if (state == NULL)
                    247:                return 0;
                    248:        return state->n * 2;
                    249: }
                    250:
                    251: size_t
                    252: sshkey_xmss_sklen(const struct sshkey *key)
                    253: {
                    254:        struct ssh_xmss_state *state = key->xmss_state;
                    255:
                    256:        if (state == NULL)
                    257:                return 0;
                    258:        return state->n * 4 + 4;
                    259: }
                    260:
                    261: int
                    262: sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
                    263: {
                    264:        struct ssh_xmss_state *state = k->xmss_state;
                    265:        const struct sshcipher *cipher;
                    266:        size_t keylen = 0, ivlen = 0;
                    267:
                    268:        if (state == NULL)
                    269:                return SSH_ERR_INVALID_ARGUMENT;
                    270:        if ((cipher = cipher_by_name(ciphername)) == NULL)
                    271:                return SSH_ERR_INTERNAL_ERROR;
                    272:        if ((state->enc_ciphername = strdup(ciphername)) == NULL)
                    273:                return SSH_ERR_ALLOC_FAIL;
                    274:        keylen = cipher_keylen(cipher);
                    275:        ivlen = cipher_ivlen(cipher);
                    276:        state->enc_keyiv_len = keylen + ivlen;
                    277:        if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
                    278:                free(state->enc_ciphername);
                    279:                state->enc_ciphername = NULL;
                    280:                return SSH_ERR_ALLOC_FAIL;
                    281:        }
                    282:        arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
                    283:        return 0;
                    284: }
                    285:
                    286: int
                    287: sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
                    288: {
                    289:        struct ssh_xmss_state *state = k->xmss_state;
                    290:        int r;
                    291:
                    292:        if (state == NULL || state->enc_keyiv == NULL ||
                    293:            state->enc_ciphername == NULL)
                    294:                return SSH_ERR_INVALID_ARGUMENT;
                    295:        if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
                    296:            (r = sshbuf_put_string(b, state->enc_keyiv,
                    297:            state->enc_keyiv_len)) != 0)
                    298:                return r;
                    299:        return 0;
                    300: }
                    301:
                    302: int
                    303: sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
                    304: {
                    305:        struct ssh_xmss_state *state = k->xmss_state;
                    306:        size_t len;
                    307:        int r;
                    308:
                    309:        if (state == NULL)
                    310:                return SSH_ERR_INVALID_ARGUMENT;
                    311:        if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
                    312:            (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
                    313:                return r;
                    314:        state->enc_keyiv_len = len;
                    315:        return 0;
                    316: }
                    317:
                    318: int
                    319: sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
                    320:     enum sshkey_serialize_rep opts)
                    321: {
                    322:        struct ssh_xmss_state *state = k->xmss_state;
                    323:        u_char have_info = 1;
                    324:        u_int32_t idx;
                    325:        int r;
                    326:
                    327:        if (state == NULL)
                    328:                return SSH_ERR_INVALID_ARGUMENT;
                    329:        if (opts != SSHKEY_SERIALIZE_INFO)
                    330:                return 0;
                    331:        idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
                    332:        if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
                    333:            (r = sshbuf_put_u32(b, idx)) != 0 ||
                    334:            (r = sshbuf_put_u32(b, state->maxidx)) != 0)
                    335:                return r;
                    336:        return 0;
                    337: }
                    338:
                    339: int
                    340: sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
                    341: {
                    342:        struct ssh_xmss_state *state = k->xmss_state;
                    343:        u_char have_info;
                    344:        int r;
                    345:
                    346:        if (state == NULL)
                    347:                return SSH_ERR_INVALID_ARGUMENT;
                    348:        /* optional */
                    349:        if (sshbuf_len(b) == 0)
                    350:                return 0;
                    351:        if ((r = sshbuf_get_u8(b, &have_info)) != 0)
                    352:                return r;
                    353:        if (have_info != 1)
                    354:                return SSH_ERR_INVALID_ARGUMENT;
                    355:        if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
                    356:            (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
                    357:                return r;
                    358:        return 0;
                    359: }
                    360:
                    361: int
                    362: sshkey_xmss_generate_private_key(struct sshkey *k, u_int bits)
                    363: {
                    364:        int r;
                    365:        const char *name;
                    366:
                    367:        if (bits == 10) {
                    368:                name = XMSS_SHA2_256_W16_H10_NAME;
                    369:        } else if (bits == 16) {
                    370:                name = XMSS_SHA2_256_W16_H16_NAME;
                    371:        } else if (bits == 20) {
                    372:                name = XMSS_SHA2_256_W16_H20_NAME;
                    373:        } else {
                    374:                name = XMSS_DEFAULT_NAME;
                    375:        }
                    376:        if ((r = sshkey_xmss_init(k, name)) != 0 ||
                    377:            (r = sshkey_xmss_init_bds_state(k)) != 0 ||
                    378:            (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
                    379:                return r;
                    380:        if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
                    381:            (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
                    382:                return SSH_ERR_ALLOC_FAIL;
                    383:        }
                    384:        xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
                    385:            sshkey_xmss_params(k));
                    386:        return 0;
                    387: }
                    388:
                    389: int
                    390: sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
1.9       dtucker   391:     int *have_file, int printerror)
1.1       markus    392: {
                    393:        struct sshbuf *b = NULL, *enc = NULL;
                    394:        int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
                    395:        u_int32_t len;
                    396:        unsigned char buf[4], *data = NULL;
                    397:
                    398:        *have_file = 0;
                    399:        if ((fd = open(filename, O_RDONLY)) >= 0) {
                    400:                *have_file = 1;
                    401:                if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
1.10      millert   402:                        PRINT("corrupt state file: %s", filename);
1.1       markus    403:                        goto done;
                    404:                }
                    405:                len = PEEK_U32(buf);
                    406:                if ((data = calloc(len, 1)) == NULL) {
                    407:                        ret = SSH_ERR_ALLOC_FAIL;
                    408:                        goto done;
                    409:                }
                    410:                if (atomicio(read, fd, data, len) != len) {
1.10      millert   411:                        PRINT("cannot read blob: %s", filename);
1.1       markus    412:                        goto done;
                    413:                }
                    414:                if ((enc = sshbuf_from(data, len)) == NULL) {
                    415:                        ret = SSH_ERR_ALLOC_FAIL;
                    416:                        goto done;
                    417:                }
                    418:                sshkey_xmss_free_bds(k);
                    419:                if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
                    420:                        ret = r;
                    421:                        goto done;
                    422:                }
                    423:                if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
                    424:                        ret = r;
                    425:                        goto done;
                    426:                }
                    427:                ret = 0;
                    428:        }
                    429: done:
                    430:        if (fd != -1)
                    431:                close(fd);
                    432:        free(data);
                    433:        sshbuf_free(enc);
                    434:        sshbuf_free(b);
                    435:        return ret;
                    436: }
                    437:
                    438: int
1.9       dtucker   439: sshkey_xmss_get_state(const struct sshkey *k, int printerror)
1.1       markus    440: {
                    441:        struct ssh_xmss_state *state = k->xmss_state;
                    442:        u_int32_t idx = 0;
                    443:        char *filename = NULL;
                    444:        char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
                    445:        int lockfd = -1, have_state = 0, have_ostate, tries = 0;
                    446:        int ret = SSH_ERR_INVALID_ARGUMENT, r;
                    447:
                    448:        if (state == NULL)
                    449:                goto done;
                    450:        /*
                    451:         * If maxidx is set, then we are allowed a limited number
                    452:         * of signatures, but don't need to access the disk.
                    453:         * Otherwise we need to deal with the on-disk state.
                    454:         */
                    455:        if (state->maxidx) {
                    456:                /* xmss_sk always contains the current state */
                    457:                idx = PEEK_U32(k->xmss_sk);
                    458:                if (idx < state->maxidx) {
                    459:                        state->allow_update = 1;
                    460:                        return 0;
                    461:                }
                    462:                return SSH_ERR_INVALID_ARGUMENT;
                    463:        }
                    464:        if ((filename = k->xmss_filename) == NULL)
                    465:                goto done;
1.4       deraadt   466:        if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
                    467:            asprintf(&statefile, "%s.state", filename) == -1 ||
                    468:            asprintf(&ostatefile, "%s.ostate", filename) == -1) {
1.1       markus    469:                ret = SSH_ERR_ALLOC_FAIL;
                    470:                goto done;
                    471:        }
1.5       deraadt   472:        if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
1.1       markus    473:                ret = SSH_ERR_SYSTEM_ERROR;
1.10      millert   474:                PRINT("cannot open/create: %s", lockfile);
1.1       markus    475:                goto done;
                    476:        }
1.5       deraadt   477:        while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
1.1       markus    478:                if (errno != EWOULDBLOCK) {
                    479:                        ret = SSH_ERR_SYSTEM_ERROR;
1.10      millert   480:                        PRINT("cannot lock: %s", lockfile);
1.1       markus    481:                        goto done;
                    482:                }
                    483:                if (++tries > 10) {
                    484:                        ret = SSH_ERR_SYSTEM_ERROR;
1.10      millert   485:                        PRINT("giving up on: %s", lockfile);
1.1       markus    486:                        goto done;
                    487:                }
                    488:                usleep(1000*100*tries);
                    489:        }
                    490:        /* XXX no longer const */
                    491:        if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
1.9       dtucker   492:            statefile, &have_state, printerror)) != 0) {
1.1       markus    493:                if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
1.9       dtucker   494:                    ostatefile, &have_ostate, printerror)) == 0) {
1.1       markus    495:                        state->allow_update = 1;
                    496:                        r = sshkey_xmss_forward_state(k, 1);
                    497:                        state->idx = PEEK_U32(k->xmss_sk);
                    498:                        state->allow_update = 0;
                    499:                }
                    500:        }
                    501:        if (!have_state && !have_ostate) {
                    502:                /* check that bds state is initialized */
                    503:                if (state->bds.auth == NULL)
                    504:                        goto done;
1.10      millert   505:                PRINT("start from scratch idx 0: %u", state->idx);
1.1       markus    506:        } else if (r != 0) {
                    507:                ret = r;
                    508:                goto done;
                    509:        }
                    510:        if (state->idx + 1 < state->idx) {
1.10      millert   511:                PRINT("state wrap: %u", state->idx);
1.1       markus    512:                goto done;
                    513:        }
                    514:        state->have_state = have_state;
                    515:        state->lockfd = lockfd;
                    516:        state->allow_update = 1;
                    517:        lockfd = -1;
                    518:        ret = 0;
                    519: done:
                    520:        if (lockfd != -1)
                    521:                close(lockfd);
                    522:        free(lockfile);
                    523:        free(statefile);
                    524:        free(ostatefile);
                    525:        return ret;
                    526: }
                    527:
                    528: int
                    529: sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
                    530: {
                    531:        struct ssh_xmss_state *state = k->xmss_state;
                    532:        u_char *sig = NULL;
                    533:        size_t required_siglen;
                    534:        unsigned long long smlen;
                    535:        u_char data;
                    536:        int ret, r;
                    537:
                    538:        if (state == NULL || !state->allow_update)
                    539:                return SSH_ERR_INVALID_ARGUMENT;
                    540:        if (reserve == 0)
                    541:                return SSH_ERR_INVALID_ARGUMENT;
                    542:        if (state->idx + reserve <= state->idx)
                    543:                return SSH_ERR_INVALID_ARGUMENT;
                    544:        if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
                    545:                return r;
                    546:        if ((sig = malloc(required_siglen)) == NULL)
                    547:                return SSH_ERR_ALLOC_FAIL;
                    548:        while (reserve-- > 0) {
                    549:                state->idx = PEEK_U32(k->xmss_sk);
                    550:                smlen = required_siglen;
                    551:                if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
                    552:                    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
                    553:                        r = SSH_ERR_INVALID_ARGUMENT;
                    554:                        break;
                    555:                }
                    556:        }
                    557:        free(sig);
                    558:        return r;
                    559: }
                    560:
                    561: int
1.9       dtucker   562: sshkey_xmss_update_state(const struct sshkey *k, int printerror)
1.1       markus    563: {
                    564:        struct ssh_xmss_state *state = k->xmss_state;
                    565:        struct sshbuf *b = NULL, *enc = NULL;
                    566:        u_int32_t idx = 0;
                    567:        unsigned char buf[4];
                    568:        char *filename = NULL;
                    569:        char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
                    570:        int fd = -1;
                    571:        int ret = SSH_ERR_INVALID_ARGUMENT;
                    572:
                    573:        if (state == NULL || !state->allow_update)
                    574:                return ret;
                    575:        if (state->maxidx) {
                    576:                /* no update since the number of signatures is limited */
                    577:                ret = 0;
                    578:                goto done;
                    579:        }
                    580:        idx = PEEK_U32(k->xmss_sk);
                    581:        if (idx == state->idx) {
1.2       djm       582:                /* no signature happened, no need to update */
1.1       markus    583:                ret = 0;
                    584:                goto done;
                    585:        } else if (idx != state->idx + 1) {
1.10      millert   586:                PRINT("more than one signature happened: idx %u state %u",
1.11    ! djm       587:                    idx, state->idx);
1.1       markus    588:                goto done;
                    589:        }
                    590:        state->idx = idx;
                    591:        if ((filename = k->xmss_filename) == NULL)
                    592:                goto done;
1.4       deraadt   593:        if (asprintf(&statefile, "%s.state", filename) == -1 ||
                    594:            asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
                    595:            asprintf(&nstatefile, "%s.nstate", filename) == -1) {
1.1       markus    596:                ret = SSH_ERR_ALLOC_FAIL;
                    597:                goto done;
                    598:        }
                    599:        unlink(nstatefile);
                    600:        if ((b = sshbuf_new()) == NULL) {
                    601:                ret = SSH_ERR_ALLOC_FAIL;
                    602:                goto done;
                    603:        }
                    604:        if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
1.10      millert   605:                PRINT("SERLIALIZE FAILED: %d", ret);
1.1       markus    606:                goto done;
                    607:        }
                    608:        if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
1.10      millert   609:                PRINT("ENCRYPT FAILED: %d", ret);
1.1       markus    610:                goto done;
                    611:        }
1.5       deraadt   612:        if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
1.1       markus    613:                ret = SSH_ERR_SYSTEM_ERROR;
1.10      millert   614:                PRINT("open new state file: %s", nstatefile);
1.1       markus    615:                goto done;
                    616:        }
                    617:        POKE_U32(buf, sshbuf_len(enc));
                    618:        if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
                    619:                ret = SSH_ERR_SYSTEM_ERROR;
1.10      millert   620:                PRINT("write new state file hdr: %s", nstatefile);
1.1       markus    621:                close(fd);
                    622:                goto done;
                    623:        }
1.3       markus    624:        if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
1.1       markus    625:            sshbuf_len(enc)) {
                    626:                ret = SSH_ERR_SYSTEM_ERROR;
1.10      millert   627:                PRINT("write new state file data: %s", nstatefile);
1.1       markus    628:                close(fd);
                    629:                goto done;
                    630:        }
1.5       deraadt   631:        if (fsync(fd) == -1) {
1.1       markus    632:                ret = SSH_ERR_SYSTEM_ERROR;
1.10      millert   633:                PRINT("sync new state file: %s", nstatefile);
1.1       markus    634:                close(fd);
                    635:                goto done;
                    636:        }
1.5       deraadt   637:        if (close(fd) == -1) {
1.1       markus    638:                ret = SSH_ERR_SYSTEM_ERROR;
1.10      millert   639:                PRINT("close new state file: %s", nstatefile);
1.1       markus    640:                goto done;
                    641:        }
                    642:        if (state->have_state) {
                    643:                unlink(ostatefile);
                    644:                if (link(statefile, ostatefile)) {
                    645:                        ret = SSH_ERR_SYSTEM_ERROR;
1.10      millert   646:                        PRINT("backup state %s to %s", statefile, ostatefile);
1.1       markus    647:                        goto done;
                    648:                }
                    649:        }
1.5       deraadt   650:        if (rename(nstatefile, statefile) == -1) {
1.1       markus    651:                ret = SSH_ERR_SYSTEM_ERROR;
1.10      millert   652:                PRINT("rename %s to %s", nstatefile, statefile);
1.1       markus    653:                goto done;
                    654:        }
                    655:        ret = 0;
                    656: done:
                    657:        if (state->lockfd != -1) {
                    658:                close(state->lockfd);
                    659:                state->lockfd = -1;
                    660:        }
                    661:        if (nstatefile)
                    662:                unlink(nstatefile);
                    663:        free(statefile);
                    664:        free(ostatefile);
                    665:        free(nstatefile);
                    666:        sshbuf_free(b);
                    667:        sshbuf_free(enc);
                    668:        return ret;
                    669: }
                    670:
                    671: int
                    672: sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
                    673: {
                    674:        struct ssh_xmss_state *state = k->xmss_state;
                    675:        treehash_inst *th;
                    676:        u_int32_t i, node;
                    677:        int r;
                    678:
                    679:        if (state == NULL)
                    680:                return SSH_ERR_INVALID_ARGUMENT;
                    681:        if (state->stack == NULL)
                    682:                return SSH_ERR_INVALID_ARGUMENT;
                    683:        state->stackoffset = state->bds.stackoffset;    /* copy back */
                    684:        if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
                    685:            (r = sshbuf_put_u32(b, state->idx)) != 0 ||
                    686:            (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
                    687:            (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
                    688:            (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
                    689:            (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
                    690:            (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
                    691:            (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
                    692:            (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
                    693:            (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
                    694:                return r;
                    695:        for (i = 0; i < num_treehash(state); i++) {
                    696:                th = &state->treehash[i];
                    697:                node = th->node - state->th_nodes;
                    698:                if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
                    699:                    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
                    700:                    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
                    701:                    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
                    702:                    (r = sshbuf_put_u32(b, node)) != 0)
                    703:                        return r;
                    704:        }
                    705:        return 0;
                    706: }
                    707:
                    708: int
                    709: sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
                    710:     enum sshkey_serialize_rep opts)
                    711: {
                    712:        struct ssh_xmss_state *state = k->xmss_state;
                    713:        int r = SSH_ERR_INVALID_ARGUMENT;
1.8       markus    714:        u_char have_stack, have_filename, have_enc;
1.1       markus    715:
                    716:        if (state == NULL)
                    717:                return SSH_ERR_INVALID_ARGUMENT;
                    718:        if ((r = sshbuf_put_u8(b, opts)) != 0)
                    719:                return r;
                    720:        switch (opts) {
                    721:        case SSHKEY_SERIALIZE_STATE:
                    722:                r = sshkey_xmss_serialize_state(k, b);
                    723:                break;
                    724:        case SSHKEY_SERIALIZE_FULL:
                    725:                if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
1.8       markus    726:                        return r;
1.1       markus    727:                r = sshkey_xmss_serialize_state(k, b);
                    728:                break;
1.8       markus    729:        case SSHKEY_SERIALIZE_SHIELD:
                    730:                /* all of stack/filename/enc are optional */
                    731:                have_stack = state->stack != NULL;
                    732:                if ((r = sshbuf_put_u8(b, have_stack)) != 0)
                    733:                        return r;
                    734:                if (have_stack) {
                    735:                        state->idx = PEEK_U32(k->xmss_sk);      /* update */
                    736:                        if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
                    737:                                return r;
                    738:                }
                    739:                have_filename = k->xmss_filename != NULL;
                    740:                if ((r = sshbuf_put_u8(b, have_filename)) != 0)
                    741:                        return r;
                    742:                if (have_filename &&
                    743:                    (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
                    744:                        return r;
                    745:                have_enc = state->enc_keyiv != NULL;
                    746:                if ((r = sshbuf_put_u8(b, have_enc)) != 0)
                    747:                        return r;
                    748:                if (have_enc &&
                    749:                    (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
                    750:                        return r;
                    751:                if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
                    752:                    (r = sshbuf_put_u8(b, state->allow_update)) != 0)
                    753:                        return r;
                    754:                break;
1.1       markus    755:        case SSHKEY_SERIALIZE_DEFAULT:
                    756:                r = 0;
                    757:                break;
                    758:        default:
                    759:                r = SSH_ERR_INVALID_ARGUMENT;
                    760:                break;
                    761:        }
                    762:        return r;
                    763: }
                    764:
                    765: int
                    766: sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
                    767: {
                    768:        struct ssh_xmss_state *state = k->xmss_state;
                    769:        treehash_inst *th;
                    770:        u_int32_t i, lh, node;
                    771:        size_t ls, lsl, la, lk, ln, lr;
                    772:        char *magic;
1.7       djm       773:        int r = SSH_ERR_INTERNAL_ERROR;
1.1       markus    774:
                    775:        if (state == NULL)
                    776:                return SSH_ERR_INVALID_ARGUMENT;
                    777:        if (k->xmss_sk == NULL)
                    778:                return SSH_ERR_INVALID_ARGUMENT;
                    779:        if ((state->treehash = calloc(num_treehash(state),
                    780:            sizeof(treehash_inst))) == NULL)
                    781:                return SSH_ERR_ALLOC_FAIL;
                    782:        if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
                    783:            (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
                    784:            (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
                    785:            (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
                    786:            (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
                    787:            (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
                    788:            (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
                    789:            (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
                    790:            (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
                    791:            (r = sshbuf_get_u32(b, &lh)) != 0)
1.7       djm       792:                goto out;
                    793:        if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
                    794:                r = SSH_ERR_INVALID_ARGUMENT;
                    795:                goto out;
                    796:        }
1.1       markus    797:        /* XXX check stackoffset */
                    798:        if (ls != num_stack(state) ||
                    799:            lsl != num_stacklevels(state) ||
                    800:            la != num_auth(state) ||
                    801:            lk != num_keep(state) ||
                    802:            ln != num_th_nodes(state) ||
                    803:            lr != num_retain(state) ||
1.7       djm       804:            lh != num_treehash(state)) {
                    805:                r = SSH_ERR_INVALID_ARGUMENT;
                    806:                goto out;
                    807:        }
1.1       markus    808:        for (i = 0; i < num_treehash(state); i++) {
                    809:                th = &state->treehash[i];
                    810:                if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
                    811:                    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
                    812:                    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
                    813:                    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
                    814:                    (r = sshbuf_get_u32(b, &node)) != 0)
1.7       djm       815:                        goto out;
1.1       markus    816:                if (node < num_th_nodes(state))
                    817:                        th->node = &state->th_nodes[node];
                    818:        }
                    819:        POKE_U32(k->xmss_sk, state->idx);
                    820:        xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
                    821:            state->stacklevels, state->auth, state->keep, state->treehash,
                    822:            state->retain, 0);
1.7       djm       823:        /* success */
                    824:        r = 0;
                    825:  out:
                    826:        free(magic);
                    827:        return r;
1.1       markus    828: }
                    829:
                    830: int
                    831: sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
                    832: {
1.8       markus    833:        struct ssh_xmss_state *state = k->xmss_state;
1.1       markus    834:        enum sshkey_serialize_rep opts;
1.8       markus    835:        u_char have_state, have_stack, have_filename, have_enc;
1.1       markus    836:        int r;
                    837:
                    838:        if ((r = sshbuf_get_u8(b, &have_state)) != 0)
                    839:                return r;
                    840:
                    841:        opts = have_state;
                    842:        switch (opts) {
                    843:        case SSHKEY_SERIALIZE_DEFAULT:
                    844:                r = 0;
1.8       markus    845:                break;
                    846:        case SSHKEY_SERIALIZE_SHIELD:
                    847:                if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
                    848:                        return r;
                    849:                if (have_stack &&
                    850:                    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
                    851:                        return r;
                    852:                if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
                    853:                        return r;
                    854:                if (have_filename &&
                    855:                    (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
                    856:                        return r;
                    857:                if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
                    858:                        return r;
                    859:                if (have_enc &&
                    860:                    (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
                    861:                        return r;
                    862:                if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
                    863:                    (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
                    864:                        return r;
1.1       markus    865:                break;
                    866:        case SSHKEY_SERIALIZE_STATE:
                    867:                if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
                    868:                        return r;
                    869:                break;
                    870:        case SSHKEY_SERIALIZE_FULL:
                    871:                if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
                    872:                    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
                    873:                        return r;
                    874:                break;
                    875:        default:
                    876:                r = SSH_ERR_INVALID_FORMAT;
                    877:                break;
                    878:        }
                    879:        return r;
                    880: }
                    881:
                    882: int
                    883: sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
                    884:    struct sshbuf **retp)
                    885: {
                    886:        struct ssh_xmss_state *state = k->xmss_state;
                    887:        struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
                    888:        struct sshcipher_ctx *ciphercontext = NULL;
                    889:        const struct sshcipher *cipher;
                    890:        u_char *cp, *key, *iv = NULL;
                    891:        size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
                    892:        int r = SSH_ERR_INTERNAL_ERROR;
                    893:
                    894:        if (retp != NULL)
                    895:                *retp = NULL;
                    896:        if (state == NULL ||
                    897:            state->enc_keyiv == NULL ||
                    898:            state->enc_ciphername == NULL)
                    899:                return SSH_ERR_INTERNAL_ERROR;
                    900:        if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
                    901:                r = SSH_ERR_INTERNAL_ERROR;
                    902:                goto out;
                    903:        }
                    904:        blocksize = cipher_blocksize(cipher);
                    905:        keylen = cipher_keylen(cipher);
                    906:        ivlen = cipher_ivlen(cipher);
                    907:        authlen = cipher_authlen(cipher);
                    908:        if (state->enc_keyiv_len != keylen + ivlen) {
                    909:                r = SSH_ERR_INVALID_FORMAT;
                    910:                goto out;
                    911:        }
                    912:        key = state->enc_keyiv;
                    913:        if ((encrypted = sshbuf_new()) == NULL ||
                    914:            (encoded = sshbuf_new()) == NULL ||
                    915:            (padded = sshbuf_new()) == NULL ||
                    916:            (iv = malloc(ivlen)) == NULL) {
                    917:                r = SSH_ERR_ALLOC_FAIL;
                    918:                goto out;
                    919:        }
                    920:
                    921:        /* replace first 4 bytes of IV with index to ensure uniqueness */
                    922:        memcpy(iv, key + keylen, ivlen);
                    923:        POKE_U32(iv, state->idx);
                    924:
                    925:        if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
                    926:            (r = sshbuf_put_u32(encoded, state->idx)) != 0)
                    927:                goto out;
                    928:
                    929:        /* padded state will be encrypted */
                    930:        if ((r = sshbuf_putb(padded, b)) != 0)
                    931:                goto out;
                    932:        i = 0;
                    933:        while (sshbuf_len(padded) % blocksize) {
                    934:                if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
                    935:                        goto out;
                    936:        }
                    937:        encrypted_len = sshbuf_len(padded);
                    938:
                    939:        /* header including the length of state is used as AAD */
                    940:        if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
                    941:                goto out;
                    942:        aadlen = sshbuf_len(encoded);
                    943:
                    944:        /* concat header and state */
                    945:        if ((r = sshbuf_putb(encoded, padded)) != 0)
                    946:                goto out;
                    947:
                    948:        /* reserve space for encryption of encoded data plus auth tag */
                    949:        /* encrypt at offset addlen */
                    950:        if ((r = sshbuf_reserve(encrypted,
                    951:            encrypted_len + aadlen + authlen, &cp)) != 0 ||
                    952:            (r = cipher_init(&ciphercontext, cipher, key, keylen,
                    953:            iv, ivlen, 1)) != 0 ||
                    954:            (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
                    955:            encrypted_len, aadlen, authlen)) != 0)
                    956:                goto out;
                    957:
                    958:        /* success */
                    959:        r = 0;
                    960:  out:
                    961:        if (retp != NULL) {
                    962:                *retp = encrypted;
                    963:                encrypted = NULL;
                    964:        }
                    965:        sshbuf_free(padded);
                    966:        sshbuf_free(encoded);
                    967:        sshbuf_free(encrypted);
                    968:        cipher_free(ciphercontext);
                    969:        free(iv);
                    970:        return r;
                    971: }
                    972:
                    973: int
                    974: sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
                    975:    struct sshbuf **retp)
                    976: {
                    977:        struct ssh_xmss_state *state = k->xmss_state;
                    978:        struct sshbuf *copy = NULL, *decrypted = NULL;
                    979:        struct sshcipher_ctx *ciphercontext = NULL;
                    980:        const struct sshcipher *cipher = NULL;
                    981:        u_char *key, *iv = NULL, *dp;
                    982:        size_t keylen, ivlen, authlen, aadlen;
                    983:        u_int blocksize, encrypted_len, index;
                    984:        int r = SSH_ERR_INTERNAL_ERROR;
                    985:
                    986:        if (retp != NULL)
                    987:                *retp = NULL;
                    988:        if (state == NULL ||
                    989:            state->enc_keyiv == NULL ||
                    990:            state->enc_ciphername == NULL)
                    991:                return SSH_ERR_INTERNAL_ERROR;
                    992:        if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
                    993:                r = SSH_ERR_INVALID_FORMAT;
                    994:                goto out;
                    995:        }
                    996:        blocksize = cipher_blocksize(cipher);
                    997:        keylen = cipher_keylen(cipher);
                    998:        ivlen = cipher_ivlen(cipher);
                    999:        authlen = cipher_authlen(cipher);
                   1000:        if (state->enc_keyiv_len != keylen + ivlen) {
                   1001:                r = SSH_ERR_INTERNAL_ERROR;
                   1002:                goto out;
                   1003:        }
                   1004:        key = state->enc_keyiv;
                   1005:
                   1006:        if ((copy = sshbuf_fromb(encoded)) == NULL ||
                   1007:            (decrypted = sshbuf_new()) == NULL ||
                   1008:            (iv = malloc(ivlen)) == NULL) {
                   1009:                r = SSH_ERR_ALLOC_FAIL;
                   1010:                goto out;
                   1011:        }
                   1012:
                   1013:        /* check magic */
                   1014:        if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
                   1015:            memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
                   1016:                r = SSH_ERR_INVALID_FORMAT;
                   1017:                goto out;
                   1018:        }
                   1019:        /* parse public portion */
                   1020:        if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
                   1021:            (r = sshbuf_get_u32(encoded, &index)) != 0 ||
                   1022:            (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
                   1023:                goto out;
                   1024:
                   1025:        /* check size of encrypted key blob */
                   1026:        if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
                   1027:                r = SSH_ERR_INVALID_FORMAT;
                   1028:                goto out;
                   1029:        }
                   1030:        /* check that an appropriate amount of auth data is present */
1.6       djm      1031:        if (sshbuf_len(encoded) < authlen ||
                   1032:            sshbuf_len(encoded) - authlen < encrypted_len) {
1.1       markus   1033:                r = SSH_ERR_INVALID_FORMAT;
                   1034:                goto out;
                   1035:        }
                   1036:
                   1037:        aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
                   1038:
                   1039:        /* replace first 4 bytes of IV with index to ensure uniqueness */
                   1040:        memcpy(iv, key + keylen, ivlen);
                   1041:        POKE_U32(iv, index);
                   1042:
                   1043:        /* decrypt private state of key */
                   1044:        if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
                   1045:            (r = cipher_init(&ciphercontext, cipher, key, keylen,
                   1046:            iv, ivlen, 0)) != 0 ||
                   1047:            (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
                   1048:            encrypted_len, aadlen, authlen)) != 0)
                   1049:                goto out;
                   1050:
                   1051:        /* there should be no trailing data */
                   1052:        if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
                   1053:                goto out;
                   1054:        if (sshbuf_len(encoded) != 0) {
                   1055:                r = SSH_ERR_INVALID_FORMAT;
                   1056:                goto out;
                   1057:        }
                   1058:
                   1059:        /* remove AAD */
                   1060:        if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
                   1061:                goto out;
                   1062:        /* XXX encrypted includes unchecked padding */
                   1063:
                   1064:        /* success */
                   1065:        r = 0;
                   1066:        if (retp != NULL) {
                   1067:                *retp = decrypted;
                   1068:                decrypted = NULL;
                   1069:        }
                   1070:  out:
                   1071:        cipher_free(ciphercontext);
                   1072:        sshbuf_free(copy);
                   1073:        sshbuf_free(decrypted);
                   1074:        free(iv);
                   1075:        return r;
                   1076: }
                   1077:
                   1078: u_int32_t
                   1079: sshkey_xmss_signatures_left(const struct sshkey *k)
                   1080: {
                   1081:        struct ssh_xmss_state *state = k->xmss_state;
                   1082:        u_int32_t idx;
                   1083:
                   1084:        if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
                   1085:            state->maxidx) {
                   1086:                idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
                   1087:                if (idx < state->maxidx)
                   1088:                        return state->maxidx - idx;
                   1089:        }
                   1090:        return 0;
                   1091: }
                   1092:
                   1093: int
                   1094: sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
                   1095: {
                   1096:        struct ssh_xmss_state *state = k->xmss_state;
                   1097:
                   1098:        if (sshkey_type_plain(k->type) != KEY_XMSS)
                   1099:                return SSH_ERR_INVALID_ARGUMENT;
                   1100:        if (maxsign == 0)
                   1101:                return 0;
                   1102:        if (state->idx + maxsign < state->idx)
                   1103:                return SSH_ERR_INVALID_ARGUMENT;
                   1104:        state->maxidx = state->idx + maxsign;
                   1105:        return 0;
                   1106: }