[BACK]Return to sshkey-xmss.c CVS log [TXT][DIR] Up to [local] / src / usr.bin / ssh

Annotation of src/usr.bin/ssh/sshkey-xmss.c, Revision 1.8

1.8     ! markus      1: /* $OpenBSD: sshkey-xmss.c,v 1.7 2019/10/14 06:00:02 djm Exp $ */
1.1       markus      2: /*
                      3:  * Copyright (c) 2017 Markus Friedl.  All rights reserved.
                      4:  *
                      5:  * Redistribution and use in source and binary forms, with or without
                      6:  * modification, are permitted provided that the following conditions
                      7:  * are met:
                      8:  * 1. Redistributions of source code must retain the above copyright
                      9:  *    notice, this list of conditions and the following disclaimer.
                     10:  * 2. Redistributions in binary form must reproduce the above copyright
                     11:  *    notice, this list of conditions and the following disclaimer in the
                     12:  *    documentation and/or other materials provided with the distribution.
                     13:  *
                     14:  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
                     15:  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
                     16:  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
                     17:  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
                     18:  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
                     19:  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
                     20:  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
                     21:  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
                     22:  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
                     23:  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
                     24:  */
                     25:
                     26: #include <sys/types.h>
                     27: #include <sys/uio.h>
                     28:
                     29: #include <stdio.h>
                     30: #include <string.h>
                     31: #include <unistd.h>
                     32: #include <fcntl.h>
                     33: #include <errno.h>
                     34:
                     35: #include "ssh2.h"
                     36: #include "ssherr.h"
                     37: #include "sshbuf.h"
                     38: #include "cipher.h"
                     39: #include "sshkey.h"
                     40: #include "sshkey-xmss.h"
                     41: #include "atomicio.h"
                     42:
                     43: #include "xmss_fast.h"
                     44:
                     45: /* opaque internal XMSS state */
                     46: #define XMSS_MAGIC             "xmss-state-v1"
                     47: #define XMSS_CIPHERNAME                "aes256-gcm@openssh.com"
                     48: struct ssh_xmss_state {
                     49:        xmss_params     params;
                     50:        u_int32_t       n, w, h, k;
                     51:
                     52:        bds_state       bds;
                     53:        u_char          *stack;
                     54:        u_int32_t       stackoffset;
                     55:        u_char          *stacklevels;
                     56:        u_char          *auth;
                     57:        u_char          *keep;
                     58:        u_char          *th_nodes;
                     59:        u_char          *retain;
                     60:        treehash_inst   *treehash;
                     61:
                     62:        u_int32_t       idx;            /* state read from file */
1.2       djm        63:        u_int32_t       maxidx;         /* restricted # of signatures */
1.1       markus     64:        int             have_state;     /* .state file exists */
                     65:        int             lockfd;         /* locked in sshkey_xmss_get_state() */
1.8     ! markus     66:        u_char          allow_update;   /* allow sshkey_xmss_update_state() */
1.1       markus     67:        char            *enc_ciphername;/* encrypt state with cipher */
                     68:        u_char          *enc_keyiv;     /* encrypt state with key */
                     69:        u_int32_t       enc_keyiv_len;  /* length of enc_keyiv */
                     70: };
                     71:
                     72: int     sshkey_xmss_init_bds_state(struct sshkey *);
                     73: int     sshkey_xmss_init_enc_key(struct sshkey *, const char *);
                     74: void    sshkey_xmss_free_bds(struct sshkey *);
                     75: int     sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
                     76:            int *, sshkey_printfn *);
                     77: int     sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
                     78:            struct sshbuf **);
                     79: int     sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
                     80:            struct sshbuf **);
                     81: int     sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
                     82: int     sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);
                     83:
                     84: #define PRINT(s...) do { if (pr) pr(s); } while (0)
                     85:
                     86: int
                     87: sshkey_xmss_init(struct sshkey *key, const char *name)
                     88: {
                     89:        struct ssh_xmss_state *state;
                     90:
                     91:        if (key->xmss_state != NULL)
                     92:                return SSH_ERR_INVALID_FORMAT;
                     93:        if (name == NULL)
                     94:                return SSH_ERR_INVALID_FORMAT;
                     95:        state = calloc(sizeof(struct ssh_xmss_state), 1);
                     96:        if (state == NULL)
                     97:                return SSH_ERR_ALLOC_FAIL;
                     98:        if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
                     99:                state->n = 32;
                    100:                state->w = 16;
                    101:                state->h = 10;
                    102:        } else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
                    103:                state->n = 32;
                    104:                state->w = 16;
                    105:                state->h = 16;
                    106:        } else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
                    107:                state->n = 32;
                    108:                state->w = 16;
                    109:                state->h = 20;
                    110:        } else {
                    111:                free(state);
                    112:                return SSH_ERR_KEY_TYPE_UNKNOWN;
                    113:        }
                    114:        if ((key->xmss_name = strdup(name)) == NULL) {
                    115:                free(state);
                    116:                return SSH_ERR_ALLOC_FAIL;
                    117:        }
                    118:        state->k = 2;   /* XXX hardcoded */
                    119:        state->lockfd = -1;
                    120:        if (xmss_set_params(&state->params, state->n, state->h, state->w,
                    121:            state->k) != 0) {
                    122:                free(state);
                    123:                return SSH_ERR_INVALID_FORMAT;
                    124:        }
                    125:        key->xmss_state = state;
                    126:        return 0;
                    127: }
                    128:
                    129: void
                    130: sshkey_xmss_free_state(struct sshkey *key)
                    131: {
                    132:        struct ssh_xmss_state *state = key->xmss_state;
                    133:
                    134:        sshkey_xmss_free_bds(key);
                    135:        if (state) {
                    136:                if (state->enc_keyiv) {
                    137:                        explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
                    138:                        free(state->enc_keyiv);
                    139:                }
                    140:                free(state->enc_ciphername);
                    141:                free(state);
                    142:        }
                    143:        key->xmss_state = NULL;
                    144: }
                    145:
                    146: #define SSH_XMSS_K2_MAGIC      "k=2"
                    147: #define num_stack(x)           ((x->h+1)*(x->n))
                    148: #define num_stacklevels(x)     (x->h+1)
                    149: #define num_auth(x)            ((x->h)*(x->n))
                    150: #define num_keep(x)            ((x->h >> 1)*(x->n))
                    151: #define num_th_nodes(x)                ((x->h - x->k)*(x->n))
                    152: #define num_retain(x)          (((1ULL << x->k) - x->k - 1) * (x->n))
                    153: #define num_treehash(x)                ((x->h) - (x->k))
                    154:
                    155: int
                    156: sshkey_xmss_init_bds_state(struct sshkey *key)
                    157: {
                    158:        struct ssh_xmss_state *state = key->xmss_state;
                    159:        u_int32_t i;
                    160:
                    161:        state->stackoffset = 0;
                    162:        if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
                    163:            (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
                    164:            (state->auth = calloc(num_auth(state), 1)) == NULL ||
                    165:            (state->keep = calloc(num_keep(state), 1)) == NULL ||
                    166:            (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
                    167:            (state->retain = calloc(num_retain(state), 1)) == NULL ||
                    168:            (state->treehash = calloc(num_treehash(state),
                    169:            sizeof(treehash_inst))) == NULL) {
                    170:                sshkey_xmss_free_bds(key);
                    171:                return SSH_ERR_ALLOC_FAIL;
                    172:        }
                    173:        for (i = 0; i < state->h - state->k; i++)
                    174:                state->treehash[i].node = &state->th_nodes[state->n*i];
                    175:        xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
                    176:            state->stacklevels, state->auth, state->keep, state->treehash,
                    177:            state->retain, 0);
                    178:        return 0;
                    179: }
                    180:
                    181: void
                    182: sshkey_xmss_free_bds(struct sshkey *key)
                    183: {
                    184:        struct ssh_xmss_state *state = key->xmss_state;
                    185:
                    186:        if (state == NULL)
                    187:                return;
                    188:        free(state->stack);
                    189:        free(state->stacklevels);
                    190:        free(state->auth);
                    191:        free(state->keep);
                    192:        free(state->th_nodes);
                    193:        free(state->retain);
                    194:        free(state->treehash);
                    195:        state->stack = NULL;
                    196:        state->stacklevels = NULL;
                    197:        state->auth = NULL;
                    198:        state->keep = NULL;
                    199:        state->th_nodes = NULL;
                    200:        state->retain = NULL;
                    201:        state->treehash = NULL;
                    202: }
                    203:
                    204: void *
                    205: sshkey_xmss_params(const struct sshkey *key)
                    206: {
                    207:        struct ssh_xmss_state *state = key->xmss_state;
                    208:
                    209:        if (state == NULL)
                    210:                return NULL;
                    211:        return &state->params;
                    212: }
                    213:
                    214: void *
                    215: sshkey_xmss_bds_state(const struct sshkey *key)
                    216: {
                    217:        struct ssh_xmss_state *state = key->xmss_state;
                    218:
                    219:        if (state == NULL)
                    220:                return NULL;
                    221:        return &state->bds;
                    222: }
                    223:
                    224: int
                    225: sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
                    226: {
                    227:        struct ssh_xmss_state *state = key->xmss_state;
                    228:
                    229:        if (lenp == NULL)
                    230:                return SSH_ERR_INVALID_ARGUMENT;
                    231:        if (state == NULL)
                    232:                return SSH_ERR_INVALID_FORMAT;
                    233:        *lenp = 4 + state->n +
                    234:            state->params.wots_par.keysize +
                    235:            state->h * state->n;
                    236:        return 0;
                    237: }
                    238:
                    239: size_t
                    240: sshkey_xmss_pklen(const struct sshkey *key)
                    241: {
                    242:        struct ssh_xmss_state *state = key->xmss_state;
                    243:
                    244:        if (state == NULL)
                    245:                return 0;
                    246:        return state->n * 2;
                    247: }
                    248:
                    249: size_t
                    250: sshkey_xmss_sklen(const struct sshkey *key)
                    251: {
                    252:        struct ssh_xmss_state *state = key->xmss_state;
                    253:
                    254:        if (state == NULL)
                    255:                return 0;
                    256:        return state->n * 4 + 4;
                    257: }
                    258:
                    259: int
                    260: sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
                    261: {
                    262:        struct ssh_xmss_state *state = k->xmss_state;
                    263:        const struct sshcipher *cipher;
                    264:        size_t keylen = 0, ivlen = 0;
                    265:
                    266:        if (state == NULL)
                    267:                return SSH_ERR_INVALID_ARGUMENT;
                    268:        if ((cipher = cipher_by_name(ciphername)) == NULL)
                    269:                return SSH_ERR_INTERNAL_ERROR;
                    270:        if ((state->enc_ciphername = strdup(ciphername)) == NULL)
                    271:                return SSH_ERR_ALLOC_FAIL;
                    272:        keylen = cipher_keylen(cipher);
                    273:        ivlen = cipher_ivlen(cipher);
                    274:        state->enc_keyiv_len = keylen + ivlen;
                    275:        if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
                    276:                free(state->enc_ciphername);
                    277:                state->enc_ciphername = NULL;
                    278:                return SSH_ERR_ALLOC_FAIL;
                    279:        }
                    280:        arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
                    281:        return 0;
                    282: }
                    283:
                    284: int
                    285: sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
                    286: {
                    287:        struct ssh_xmss_state *state = k->xmss_state;
                    288:        int r;
                    289:
                    290:        if (state == NULL || state->enc_keyiv == NULL ||
                    291:            state->enc_ciphername == NULL)
                    292:                return SSH_ERR_INVALID_ARGUMENT;
                    293:        if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
                    294:            (r = sshbuf_put_string(b, state->enc_keyiv,
                    295:            state->enc_keyiv_len)) != 0)
                    296:                return r;
                    297:        return 0;
                    298: }
                    299:
                    300: int
                    301: sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
                    302: {
                    303:        struct ssh_xmss_state *state = k->xmss_state;
                    304:        size_t len;
                    305:        int r;
                    306:
                    307:        if (state == NULL)
                    308:                return SSH_ERR_INVALID_ARGUMENT;
                    309:        if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
                    310:            (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
                    311:                return r;
                    312:        state->enc_keyiv_len = len;
                    313:        return 0;
                    314: }
                    315:
                    316: int
                    317: sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
                    318:     enum sshkey_serialize_rep opts)
                    319: {
                    320:        struct ssh_xmss_state *state = k->xmss_state;
                    321:        u_char have_info = 1;
                    322:        u_int32_t idx;
                    323:        int r;
                    324:
                    325:        if (state == NULL)
                    326:                return SSH_ERR_INVALID_ARGUMENT;
                    327:        if (opts != SSHKEY_SERIALIZE_INFO)
                    328:                return 0;
                    329:        idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
                    330:        if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
                    331:            (r = sshbuf_put_u32(b, idx)) != 0 ||
                    332:            (r = sshbuf_put_u32(b, state->maxidx)) != 0)
                    333:                return r;
                    334:        return 0;
                    335: }
                    336:
                    337: int
                    338: sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
                    339: {
                    340:        struct ssh_xmss_state *state = k->xmss_state;
                    341:        u_char have_info;
                    342:        int r;
                    343:
                    344:        if (state == NULL)
                    345:                return SSH_ERR_INVALID_ARGUMENT;
                    346:        /* optional */
                    347:        if (sshbuf_len(b) == 0)
                    348:                return 0;
                    349:        if ((r = sshbuf_get_u8(b, &have_info)) != 0)
                    350:                return r;
                    351:        if (have_info != 1)
                    352:                return SSH_ERR_INVALID_ARGUMENT;
                    353:        if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
                    354:            (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
                    355:                return r;
                    356:        return 0;
                    357: }
                    358:
                    359: int
                    360: sshkey_xmss_generate_private_key(struct sshkey *k, u_int bits)
                    361: {
                    362:        int r;
                    363:        const char *name;
                    364:
                    365:        if (bits == 10) {
                    366:                name = XMSS_SHA2_256_W16_H10_NAME;
                    367:        } else if (bits == 16) {
                    368:                name = XMSS_SHA2_256_W16_H16_NAME;
                    369:        } else if (bits == 20) {
                    370:                name = XMSS_SHA2_256_W16_H20_NAME;
                    371:        } else {
                    372:                name = XMSS_DEFAULT_NAME;
                    373:        }
                    374:        if ((r = sshkey_xmss_init(k, name)) != 0 ||
                    375:            (r = sshkey_xmss_init_bds_state(k)) != 0 ||
                    376:            (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
                    377:                return r;
                    378:        if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
                    379:            (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
                    380:                return SSH_ERR_ALLOC_FAIL;
                    381:        }
                    382:        xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
                    383:            sshkey_xmss_params(k));
                    384:        return 0;
                    385: }
                    386:
                    387: int
                    388: sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
                    389:     int *have_file, sshkey_printfn *pr)
                    390: {
                    391:        struct sshbuf *b = NULL, *enc = NULL;
                    392:        int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
                    393:        u_int32_t len;
                    394:        unsigned char buf[4], *data = NULL;
                    395:
                    396:        *have_file = 0;
                    397:        if ((fd = open(filename, O_RDONLY)) >= 0) {
                    398:                *have_file = 1;
                    399:                if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
                    400:                        PRINT("%s: corrupt state file: %s", __func__, filename);
                    401:                        goto done;
                    402:                }
                    403:                len = PEEK_U32(buf);
                    404:                if ((data = calloc(len, 1)) == NULL) {
                    405:                        ret = SSH_ERR_ALLOC_FAIL;
                    406:                        goto done;
                    407:                }
                    408:                if (atomicio(read, fd, data, len) != len) {
                    409:                        PRINT("%s: cannot read blob: %s", __func__, filename);
                    410:                        goto done;
                    411:                }
                    412:                if ((enc = sshbuf_from(data, len)) == NULL) {
                    413:                        ret = SSH_ERR_ALLOC_FAIL;
                    414:                        goto done;
                    415:                }
                    416:                sshkey_xmss_free_bds(k);
                    417:                if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
                    418:                        ret = r;
                    419:                        goto done;
                    420:                }
                    421:                if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
                    422:                        ret = r;
                    423:                        goto done;
                    424:                }
                    425:                ret = 0;
                    426:        }
                    427: done:
                    428:        if (fd != -1)
                    429:                close(fd);
                    430:        free(data);
                    431:        sshbuf_free(enc);
                    432:        sshbuf_free(b);
                    433:        return ret;
                    434: }
                    435:
                    436: int
                    437: sshkey_xmss_get_state(const struct sshkey *k, sshkey_printfn *pr)
                    438: {
                    439:        struct ssh_xmss_state *state = k->xmss_state;
                    440:        u_int32_t idx = 0;
                    441:        char *filename = NULL;
                    442:        char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
                    443:        int lockfd = -1, have_state = 0, have_ostate, tries = 0;
                    444:        int ret = SSH_ERR_INVALID_ARGUMENT, r;
                    445:
                    446:        if (state == NULL)
                    447:                goto done;
                    448:        /*
                    449:         * If maxidx is set, then we are allowed a limited number
                    450:         * of signatures, but don't need to access the disk.
                    451:         * Otherwise we need to deal with the on-disk state.
                    452:         */
                    453:        if (state->maxidx) {
                    454:                /* xmss_sk always contains the current state */
                    455:                idx = PEEK_U32(k->xmss_sk);
                    456:                if (idx < state->maxidx) {
                    457:                        state->allow_update = 1;
                    458:                        return 0;
                    459:                }
                    460:                return SSH_ERR_INVALID_ARGUMENT;
                    461:        }
                    462:        if ((filename = k->xmss_filename) == NULL)
                    463:                goto done;
1.4       deraadt   464:        if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
                    465:            asprintf(&statefile, "%s.state", filename) == -1 ||
                    466:            asprintf(&ostatefile, "%s.ostate", filename) == -1) {
1.1       markus    467:                ret = SSH_ERR_ALLOC_FAIL;
                    468:                goto done;
                    469:        }
1.5       deraadt   470:        if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
1.1       markus    471:                ret = SSH_ERR_SYSTEM_ERROR;
                    472:                PRINT("%s: cannot open/create: %s", __func__, lockfile);
                    473:                goto done;
                    474:        }
1.5       deraadt   475:        while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
1.1       markus    476:                if (errno != EWOULDBLOCK) {
                    477:                        ret = SSH_ERR_SYSTEM_ERROR;
                    478:                        PRINT("%s: cannot lock: %s", __func__, lockfile);
                    479:                        goto done;
                    480:                }
                    481:                if (++tries > 10) {
                    482:                        ret = SSH_ERR_SYSTEM_ERROR;
                    483:                        PRINT("%s: giving up on: %s", __func__, lockfile);
                    484:                        goto done;
                    485:                }
                    486:                usleep(1000*100*tries);
                    487:        }
                    488:        /* XXX no longer const */
                    489:        if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
                    490:            statefile, &have_state, pr)) != 0) {
                    491:                if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
                    492:                    ostatefile, &have_ostate, pr)) == 0) {
                    493:                        state->allow_update = 1;
                    494:                        r = sshkey_xmss_forward_state(k, 1);
                    495:                        state->idx = PEEK_U32(k->xmss_sk);
                    496:                        state->allow_update = 0;
                    497:                }
                    498:        }
                    499:        if (!have_state && !have_ostate) {
                    500:                /* check that bds state is initialized */
                    501:                if (state->bds.auth == NULL)
                    502:                        goto done;
                    503:                PRINT("%s: start from scratch idx 0: %u", __func__, state->idx);
                    504:        } else if (r != 0) {
                    505:                ret = r;
                    506:                goto done;
                    507:        }
                    508:        if (state->idx + 1 < state->idx) {
                    509:                PRINT("%s: state wrap: %u", __func__, state->idx);
                    510:                goto done;
                    511:        }
                    512:        state->have_state = have_state;
                    513:        state->lockfd = lockfd;
                    514:        state->allow_update = 1;
                    515:        lockfd = -1;
                    516:        ret = 0;
                    517: done:
                    518:        if (lockfd != -1)
                    519:                close(lockfd);
                    520:        free(lockfile);
                    521:        free(statefile);
                    522:        free(ostatefile);
                    523:        return ret;
                    524: }
                    525:
                    526: int
                    527: sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
                    528: {
                    529:        struct ssh_xmss_state *state = k->xmss_state;
                    530:        u_char *sig = NULL;
                    531:        size_t required_siglen;
                    532:        unsigned long long smlen;
                    533:        u_char data;
                    534:        int ret, r;
                    535:
                    536:        if (state == NULL || !state->allow_update)
                    537:                return SSH_ERR_INVALID_ARGUMENT;
                    538:        if (reserve == 0)
                    539:                return SSH_ERR_INVALID_ARGUMENT;
                    540:        if (state->idx + reserve <= state->idx)
                    541:                return SSH_ERR_INVALID_ARGUMENT;
                    542:        if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
                    543:                return r;
                    544:        if ((sig = malloc(required_siglen)) == NULL)
                    545:                return SSH_ERR_ALLOC_FAIL;
                    546:        while (reserve-- > 0) {
                    547:                state->idx = PEEK_U32(k->xmss_sk);
                    548:                smlen = required_siglen;
                    549:                if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
                    550:                    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
                    551:                        r = SSH_ERR_INVALID_ARGUMENT;
                    552:                        break;
                    553:                }
                    554:        }
                    555:        free(sig);
                    556:        return r;
                    557: }
                    558:
                    559: int
                    560: sshkey_xmss_update_state(const struct sshkey *k, sshkey_printfn *pr)
                    561: {
                    562:        struct ssh_xmss_state *state = k->xmss_state;
                    563:        struct sshbuf *b = NULL, *enc = NULL;
                    564:        u_int32_t idx = 0;
                    565:        unsigned char buf[4];
                    566:        char *filename = NULL;
                    567:        char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
                    568:        int fd = -1;
                    569:        int ret = SSH_ERR_INVALID_ARGUMENT;
                    570:
                    571:        if (state == NULL || !state->allow_update)
                    572:                return ret;
                    573:        if (state->maxidx) {
                    574:                /* no update since the number of signatures is limited */
                    575:                ret = 0;
                    576:                goto done;
                    577:        }
                    578:        idx = PEEK_U32(k->xmss_sk);
                    579:        if (idx == state->idx) {
1.2       djm       580:                /* no signature happened, no need to update */
1.1       markus    581:                ret = 0;
                    582:                goto done;
                    583:        } else if (idx != state->idx + 1) {
                    584:                PRINT("%s: more than one signature happened: idx %u state %u",
                    585:                     __func__, idx, state->idx);
                    586:                goto done;
                    587:        }
                    588:        state->idx = idx;
                    589:        if ((filename = k->xmss_filename) == NULL)
                    590:                goto done;
1.4       deraadt   591:        if (asprintf(&statefile, "%s.state", filename) == -1 ||
                    592:            asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
                    593:            asprintf(&nstatefile, "%s.nstate", filename) == -1) {
1.1       markus    594:                ret = SSH_ERR_ALLOC_FAIL;
                    595:                goto done;
                    596:        }
                    597:        unlink(nstatefile);
                    598:        if ((b = sshbuf_new()) == NULL) {
                    599:                ret = SSH_ERR_ALLOC_FAIL;
                    600:                goto done;
                    601:        }
                    602:        if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
                    603:                PRINT("%s: SERLIALIZE FAILED: %d", __func__, ret);
                    604:                goto done;
                    605:        }
                    606:        if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
                    607:                PRINT("%s: ENCRYPT FAILED: %d", __func__, ret);
                    608:                goto done;
                    609:        }
1.5       deraadt   610:        if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
1.1       markus    611:                ret = SSH_ERR_SYSTEM_ERROR;
                    612:                PRINT("%s: open new state file: %s", __func__, nstatefile);
                    613:                goto done;
                    614:        }
                    615:        POKE_U32(buf, sshbuf_len(enc));
                    616:        if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
                    617:                ret = SSH_ERR_SYSTEM_ERROR;
                    618:                PRINT("%s: write new state file hdr: %s", __func__, nstatefile);
                    619:                close(fd);
                    620:                goto done;
                    621:        }
1.3       markus    622:        if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
1.1       markus    623:            sshbuf_len(enc)) {
                    624:                ret = SSH_ERR_SYSTEM_ERROR;
                    625:                PRINT("%s: write new state file data: %s", __func__, nstatefile);
                    626:                close(fd);
                    627:                goto done;
                    628:        }
1.5       deraadt   629:        if (fsync(fd) == -1) {
1.1       markus    630:                ret = SSH_ERR_SYSTEM_ERROR;
                    631:                PRINT("%s: sync new state file: %s", __func__, nstatefile);
                    632:                close(fd);
                    633:                goto done;
                    634:        }
1.5       deraadt   635:        if (close(fd) == -1) {
1.1       markus    636:                ret = SSH_ERR_SYSTEM_ERROR;
                    637:                PRINT("%s: close new state file: %s", __func__, nstatefile);
                    638:                goto done;
                    639:        }
                    640:        if (state->have_state) {
                    641:                unlink(ostatefile);
                    642:                if (link(statefile, ostatefile)) {
                    643:                        ret = SSH_ERR_SYSTEM_ERROR;
                    644:                        PRINT("%s: backup state %s to %s", __func__, statefile,
                    645:                            ostatefile);
                    646:                        goto done;
                    647:                }
                    648:        }
1.5       deraadt   649:        if (rename(nstatefile, statefile) == -1) {
1.1       markus    650:                ret = SSH_ERR_SYSTEM_ERROR;
                    651:                PRINT("%s: rename %s to %s", __func__, nstatefile, statefile);
                    652:                goto done;
                    653:        }
                    654:        ret = 0;
                    655: done:
                    656:        if (state->lockfd != -1) {
                    657:                close(state->lockfd);
                    658:                state->lockfd = -1;
                    659:        }
                    660:        if (nstatefile)
                    661:                unlink(nstatefile);
                    662:        free(statefile);
                    663:        free(ostatefile);
                    664:        free(nstatefile);
                    665:        sshbuf_free(b);
                    666:        sshbuf_free(enc);
                    667:        return ret;
                    668: }
                    669:
                    670: int
                    671: sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
                    672: {
                    673:        struct ssh_xmss_state *state = k->xmss_state;
                    674:        treehash_inst *th;
                    675:        u_int32_t i, node;
                    676:        int r;
                    677:
                    678:        if (state == NULL)
                    679:                return SSH_ERR_INVALID_ARGUMENT;
                    680:        if (state->stack == NULL)
                    681:                return SSH_ERR_INVALID_ARGUMENT;
                    682:        state->stackoffset = state->bds.stackoffset;    /* copy back */
                    683:        if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
                    684:            (r = sshbuf_put_u32(b, state->idx)) != 0 ||
                    685:            (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
                    686:            (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
                    687:            (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
                    688:            (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
                    689:            (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
                    690:            (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
                    691:            (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
                    692:            (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
                    693:                return r;
                    694:        for (i = 0; i < num_treehash(state); i++) {
                    695:                th = &state->treehash[i];
                    696:                node = th->node - state->th_nodes;
                    697:                if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
                    698:                    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
                    699:                    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
                    700:                    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
                    701:                    (r = sshbuf_put_u32(b, node)) != 0)
                    702:                        return r;
                    703:        }
                    704:        return 0;
                    705: }
                    706:
                    707: int
                    708: sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
                    709:     enum sshkey_serialize_rep opts)
                    710: {
                    711:        struct ssh_xmss_state *state = k->xmss_state;
                    712:        int r = SSH_ERR_INVALID_ARGUMENT;
1.8     ! markus    713:        u_char have_stack, have_filename, have_enc;
1.1       markus    714:
                    715:        if (state == NULL)
                    716:                return SSH_ERR_INVALID_ARGUMENT;
                    717:        if ((r = sshbuf_put_u8(b, opts)) != 0)
                    718:                return r;
                    719:        switch (opts) {
                    720:        case SSHKEY_SERIALIZE_STATE:
                    721:                r = sshkey_xmss_serialize_state(k, b);
                    722:                break;
                    723:        case SSHKEY_SERIALIZE_FULL:
                    724:                if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
1.8     ! markus    725:                        return r;
1.1       markus    726:                r = sshkey_xmss_serialize_state(k, b);
                    727:                break;
1.8     ! markus    728:        case SSHKEY_SERIALIZE_SHIELD:
        !           729:                /* all of stack/filename/enc are optional */
        !           730:                have_stack = state->stack != NULL;
        !           731:                if ((r = sshbuf_put_u8(b, have_stack)) != 0)
        !           732:                        return r;
        !           733:                if (have_stack) {
        !           734:                        state->idx = PEEK_U32(k->xmss_sk);      /* update */
        !           735:                        if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
        !           736:                                return r;
        !           737:                }
        !           738:                have_filename = k->xmss_filename != NULL;
        !           739:                if ((r = sshbuf_put_u8(b, have_filename)) != 0)
        !           740:                        return r;
        !           741:                if (have_filename &&
        !           742:                    (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
        !           743:                        return r;
        !           744:                have_enc = state->enc_keyiv != NULL;
        !           745:                if ((r = sshbuf_put_u8(b, have_enc)) != 0)
        !           746:                        return r;
        !           747:                if (have_enc &&
        !           748:                    (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
        !           749:                        return r;
        !           750:                if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
        !           751:                    (r = sshbuf_put_u8(b, state->allow_update)) != 0)
        !           752:                        return r;
        !           753:                break;
1.1       markus    754:        case SSHKEY_SERIALIZE_DEFAULT:
                    755:                r = 0;
                    756:                break;
                    757:        default:
                    758:                r = SSH_ERR_INVALID_ARGUMENT;
                    759:                break;
                    760:        }
                    761:        return r;
                    762: }
                    763:
                    764: int
                    765: sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
                    766: {
                    767:        struct ssh_xmss_state *state = k->xmss_state;
                    768:        treehash_inst *th;
                    769:        u_int32_t i, lh, node;
                    770:        size_t ls, lsl, la, lk, ln, lr;
                    771:        char *magic;
1.7       djm       772:        int r = SSH_ERR_INTERNAL_ERROR;
1.1       markus    773:
                    774:        if (state == NULL)
                    775:                return SSH_ERR_INVALID_ARGUMENT;
                    776:        if (k->xmss_sk == NULL)
                    777:                return SSH_ERR_INVALID_ARGUMENT;
                    778:        if ((state->treehash = calloc(num_treehash(state),
                    779:            sizeof(treehash_inst))) == NULL)
                    780:                return SSH_ERR_ALLOC_FAIL;
                    781:        if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
                    782:            (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
                    783:            (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
                    784:            (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
                    785:            (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
                    786:            (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
                    787:            (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
                    788:            (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
                    789:            (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
                    790:            (r = sshbuf_get_u32(b, &lh)) != 0)
1.7       djm       791:                goto out;
                    792:        if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
                    793:                r = SSH_ERR_INVALID_ARGUMENT;
                    794:                goto out;
                    795:        }
1.1       markus    796:        /* XXX check stackoffset */
                    797:        if (ls != num_stack(state) ||
                    798:            lsl != num_stacklevels(state) ||
                    799:            la != num_auth(state) ||
                    800:            lk != num_keep(state) ||
                    801:            ln != num_th_nodes(state) ||
                    802:            lr != num_retain(state) ||
1.7       djm       803:            lh != num_treehash(state)) {
                    804:                r = SSH_ERR_INVALID_ARGUMENT;
                    805:                goto out;
                    806:        }
1.1       markus    807:        for (i = 0; i < num_treehash(state); i++) {
                    808:                th = &state->treehash[i];
                    809:                if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
                    810:                    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
                    811:                    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
                    812:                    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
                    813:                    (r = sshbuf_get_u32(b, &node)) != 0)
1.7       djm       814:                        goto out;
1.1       markus    815:                if (node < num_th_nodes(state))
                    816:                        th->node = &state->th_nodes[node];
                    817:        }
                    818:        POKE_U32(k->xmss_sk, state->idx);
                    819:        xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
                    820:            state->stacklevels, state->auth, state->keep, state->treehash,
                    821:            state->retain, 0);
1.7       djm       822:        /* success */
                    823:        r = 0;
                    824:  out:
                    825:        free(magic);
                    826:        return r;
1.1       markus    827: }
                    828:
                    829: int
                    830: sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
                    831: {
1.8     ! markus    832:        struct ssh_xmss_state *state = k->xmss_state;
1.1       markus    833:        enum sshkey_serialize_rep opts;
1.8     ! markus    834:        u_char have_state, have_stack, have_filename, have_enc;
1.1       markus    835:        int r;
                    836:
                    837:        if ((r = sshbuf_get_u8(b, &have_state)) != 0)
                    838:                return r;
                    839:
                    840:        opts = have_state;
                    841:        switch (opts) {
                    842:        case SSHKEY_SERIALIZE_DEFAULT:
                    843:                r = 0;
1.8     ! markus    844:                break;
        !           845:        case SSHKEY_SERIALIZE_SHIELD:
        !           846:                if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
        !           847:                        return r;
        !           848:                if (have_stack &&
        !           849:                    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
        !           850:                        return r;
        !           851:                if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
        !           852:                        return r;
        !           853:                if (have_filename &&
        !           854:                    (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
        !           855:                        return r;
        !           856:                if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
        !           857:                        return r;
        !           858:                if (have_enc &&
        !           859:                    (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
        !           860:                        return r;
        !           861:                if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
        !           862:                    (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
        !           863:                        return r;
1.1       markus    864:                break;
                    865:        case SSHKEY_SERIALIZE_STATE:
                    866:                if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
                    867:                        return r;
                    868:                break;
                    869:        case SSHKEY_SERIALIZE_FULL:
                    870:                if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
                    871:                    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
                    872:                        return r;
                    873:                break;
                    874:        default:
                    875:                r = SSH_ERR_INVALID_FORMAT;
                    876:                break;
                    877:        }
                    878:        return r;
                    879: }
                    880:
                    881: int
                    882: sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
                    883:    struct sshbuf **retp)
                    884: {
                    885:        struct ssh_xmss_state *state = k->xmss_state;
                    886:        struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
                    887:        struct sshcipher_ctx *ciphercontext = NULL;
                    888:        const struct sshcipher *cipher;
                    889:        u_char *cp, *key, *iv = NULL;
                    890:        size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
                    891:        int r = SSH_ERR_INTERNAL_ERROR;
                    892:
                    893:        if (retp != NULL)
                    894:                *retp = NULL;
                    895:        if (state == NULL ||
                    896:            state->enc_keyiv == NULL ||
                    897:            state->enc_ciphername == NULL)
                    898:                return SSH_ERR_INTERNAL_ERROR;
                    899:        if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
                    900:                r = SSH_ERR_INTERNAL_ERROR;
                    901:                goto out;
                    902:        }
                    903:        blocksize = cipher_blocksize(cipher);
                    904:        keylen = cipher_keylen(cipher);
                    905:        ivlen = cipher_ivlen(cipher);
                    906:        authlen = cipher_authlen(cipher);
                    907:        if (state->enc_keyiv_len != keylen + ivlen) {
                    908:                r = SSH_ERR_INVALID_FORMAT;
                    909:                goto out;
                    910:        }
                    911:        key = state->enc_keyiv;
                    912:        if ((encrypted = sshbuf_new()) == NULL ||
                    913:            (encoded = sshbuf_new()) == NULL ||
                    914:            (padded = sshbuf_new()) == NULL ||
                    915:            (iv = malloc(ivlen)) == NULL) {
                    916:                r = SSH_ERR_ALLOC_FAIL;
                    917:                goto out;
                    918:        }
                    919:
                    920:        /* replace first 4 bytes of IV with index to ensure uniqueness */
                    921:        memcpy(iv, key + keylen, ivlen);
                    922:        POKE_U32(iv, state->idx);
                    923:
                    924:        if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
                    925:            (r = sshbuf_put_u32(encoded, state->idx)) != 0)
                    926:                goto out;
                    927:
                    928:        /* padded state will be encrypted */
                    929:        if ((r = sshbuf_putb(padded, b)) != 0)
                    930:                goto out;
                    931:        i = 0;
                    932:        while (sshbuf_len(padded) % blocksize) {
                    933:                if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
                    934:                        goto out;
                    935:        }
                    936:        encrypted_len = sshbuf_len(padded);
                    937:
                    938:        /* header including the length of state is used as AAD */
                    939:        if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
                    940:                goto out;
                    941:        aadlen = sshbuf_len(encoded);
                    942:
                    943:        /* concat header and state */
                    944:        if ((r = sshbuf_putb(encoded, padded)) != 0)
                    945:                goto out;
                    946:
                    947:        /* reserve space for encryption of encoded data plus auth tag */
                    948:        /* encrypt at offset addlen */
                    949:        if ((r = sshbuf_reserve(encrypted,
                    950:            encrypted_len + aadlen + authlen, &cp)) != 0 ||
                    951:            (r = cipher_init(&ciphercontext, cipher, key, keylen,
                    952:            iv, ivlen, 1)) != 0 ||
                    953:            (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
                    954:            encrypted_len, aadlen, authlen)) != 0)
                    955:                goto out;
                    956:
                    957:        /* success */
                    958:        r = 0;
                    959:  out:
                    960:        if (retp != NULL) {
                    961:                *retp = encrypted;
                    962:                encrypted = NULL;
                    963:        }
                    964:        sshbuf_free(padded);
                    965:        sshbuf_free(encoded);
                    966:        sshbuf_free(encrypted);
                    967:        cipher_free(ciphercontext);
                    968:        free(iv);
                    969:        return r;
                    970: }
                    971:
                    972: int
                    973: sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
                    974:    struct sshbuf **retp)
                    975: {
                    976:        struct ssh_xmss_state *state = k->xmss_state;
                    977:        struct sshbuf *copy = NULL, *decrypted = NULL;
                    978:        struct sshcipher_ctx *ciphercontext = NULL;
                    979:        const struct sshcipher *cipher = NULL;
                    980:        u_char *key, *iv = NULL, *dp;
                    981:        size_t keylen, ivlen, authlen, aadlen;
                    982:        u_int blocksize, encrypted_len, index;
                    983:        int r = SSH_ERR_INTERNAL_ERROR;
                    984:
                    985:        if (retp != NULL)
                    986:                *retp = NULL;
                    987:        if (state == NULL ||
                    988:            state->enc_keyiv == NULL ||
                    989:            state->enc_ciphername == NULL)
                    990:                return SSH_ERR_INTERNAL_ERROR;
                    991:        if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
                    992:                r = SSH_ERR_INVALID_FORMAT;
                    993:                goto out;
                    994:        }
                    995:        blocksize = cipher_blocksize(cipher);
                    996:        keylen = cipher_keylen(cipher);
                    997:        ivlen = cipher_ivlen(cipher);
                    998:        authlen = cipher_authlen(cipher);
                    999:        if (state->enc_keyiv_len != keylen + ivlen) {
                   1000:                r = SSH_ERR_INTERNAL_ERROR;
                   1001:                goto out;
                   1002:        }
                   1003:        key = state->enc_keyiv;
                   1004:
                   1005:        if ((copy = sshbuf_fromb(encoded)) == NULL ||
                   1006:            (decrypted = sshbuf_new()) == NULL ||
                   1007:            (iv = malloc(ivlen)) == NULL) {
                   1008:                r = SSH_ERR_ALLOC_FAIL;
                   1009:                goto out;
                   1010:        }
                   1011:
                   1012:        /* check magic */
                   1013:        if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
                   1014:            memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
                   1015:                r = SSH_ERR_INVALID_FORMAT;
                   1016:                goto out;
                   1017:        }
                   1018:        /* parse public portion */
                   1019:        if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
                   1020:            (r = sshbuf_get_u32(encoded, &index)) != 0 ||
                   1021:            (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
                   1022:                goto out;
                   1023:
                   1024:        /* check size of encrypted key blob */
                   1025:        if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
                   1026:                r = SSH_ERR_INVALID_FORMAT;
                   1027:                goto out;
                   1028:        }
                   1029:        /* check that an appropriate amount of auth data is present */
1.6       djm      1030:        if (sshbuf_len(encoded) < authlen ||
                   1031:            sshbuf_len(encoded) - authlen < encrypted_len) {
1.1       markus   1032:                r = SSH_ERR_INVALID_FORMAT;
                   1033:                goto out;
                   1034:        }
                   1035:
                   1036:        aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
                   1037:
                   1038:        /* replace first 4 bytes of IV with index to ensure uniqueness */
                   1039:        memcpy(iv, key + keylen, ivlen);
                   1040:        POKE_U32(iv, index);
                   1041:
                   1042:        /* decrypt private state of key */
                   1043:        if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
                   1044:            (r = cipher_init(&ciphercontext, cipher, key, keylen,
                   1045:            iv, ivlen, 0)) != 0 ||
                   1046:            (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
                   1047:            encrypted_len, aadlen, authlen)) != 0)
                   1048:                goto out;
                   1049:
                   1050:        /* there should be no trailing data */
                   1051:        if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
                   1052:                goto out;
                   1053:        if (sshbuf_len(encoded) != 0) {
                   1054:                r = SSH_ERR_INVALID_FORMAT;
                   1055:                goto out;
                   1056:        }
                   1057:
                   1058:        /* remove AAD */
                   1059:        if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
                   1060:                goto out;
                   1061:        /* XXX encrypted includes unchecked padding */
                   1062:
                   1063:        /* success */
                   1064:        r = 0;
                   1065:        if (retp != NULL) {
                   1066:                *retp = decrypted;
                   1067:                decrypted = NULL;
                   1068:        }
                   1069:  out:
                   1070:        cipher_free(ciphercontext);
                   1071:        sshbuf_free(copy);
                   1072:        sshbuf_free(decrypted);
                   1073:        free(iv);
                   1074:        return r;
                   1075: }
                   1076:
                   1077: u_int32_t
                   1078: sshkey_xmss_signatures_left(const struct sshkey *k)
                   1079: {
                   1080:        struct ssh_xmss_state *state = k->xmss_state;
                   1081:        u_int32_t idx;
                   1082:
                   1083:        if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
                   1084:            state->maxidx) {
                   1085:                idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
                   1086:                if (idx < state->maxidx)
                   1087:                        return state->maxidx - idx;
                   1088:        }
                   1089:        return 0;
                   1090: }
                   1091:
                   1092: int
                   1093: sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
                   1094: {
                   1095:        struct ssh_xmss_state *state = k->xmss_state;
                   1096:
                   1097:        if (sshkey_type_plain(k->type) != KEY_XMSS)
                   1098:                return SSH_ERR_INVALID_ARGUMENT;
                   1099:        if (maxsign == 0)
                   1100:                return 0;
                   1101:        if (state->idx + maxsign < state->idx)
                   1102:                return SSH_ERR_INVALID_ARGUMENT;
                   1103:        state->maxidx = state->idx + maxsign;
                   1104:        return 0;
                   1105: }