[BACK]Return to sshkey-xmss.c CVS log [TXT][DIR] Up to [local] / src / usr.bin / ssh

Annotation of src/usr.bin/ssh/sshkey-xmss.c, Revision 1.9

1.9     ! dtucker     1: /* $OpenBSD: sshkey-xmss.c,v 1.8 2019/11/13 07:53:10 markus Exp $ */
1.1       markus      2: /*
                      3:  * Copyright (c) 2017 Markus Friedl.  All rights reserved.
                      4:  *
                      5:  * Redistribution and use in source and binary forms, with or without
                      6:  * modification, are permitted provided that the following conditions
                      7:  * are met:
                      8:  * 1. Redistributions of source code must retain the above copyright
                      9:  *    notice, this list of conditions and the following disclaimer.
                     10:  * 2. Redistributions in binary form must reproduce the above copyright
                     11:  *    notice, this list of conditions and the following disclaimer in the
                     12:  *    documentation and/or other materials provided with the distribution.
                     13:  *
                     14:  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
                     15:  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
                     16:  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
                     17:  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
                     18:  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
                     19:  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
                     20:  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
                     21:  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
                     22:  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
                     23:  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
                     24:  */
                     25:
                     26: #include <sys/types.h>
                     27: #include <sys/uio.h>
                     28:
                     29: #include <stdio.h>
                     30: #include <string.h>
                     31: #include <unistd.h>
                     32: #include <fcntl.h>
                     33: #include <errno.h>
                     34:
                     35: #include "ssh2.h"
                     36: #include "ssherr.h"
                     37: #include "sshbuf.h"
                     38: #include "cipher.h"
                     39: #include "sshkey.h"
                     40: #include "sshkey-xmss.h"
                     41: #include "atomicio.h"
1.9     ! dtucker    42: #include "log.h"
1.1       markus     43:
                     44: #include "xmss_fast.h"
                     45:
                     46: /* opaque internal XMSS state */
                     47: #define XMSS_MAGIC             "xmss-state-v1"
                     48: #define XMSS_CIPHERNAME                "aes256-gcm@openssh.com"
                     49: struct ssh_xmss_state {
                     50:        xmss_params     params;
                     51:        u_int32_t       n, w, h, k;
                     52:
                     53:        bds_state       bds;
                     54:        u_char          *stack;
                     55:        u_int32_t       stackoffset;
                     56:        u_char          *stacklevels;
                     57:        u_char          *auth;
                     58:        u_char          *keep;
                     59:        u_char          *th_nodes;
                     60:        u_char          *retain;
                     61:        treehash_inst   *treehash;
                     62:
                     63:        u_int32_t       idx;            /* state read from file */
1.2       djm        64:        u_int32_t       maxidx;         /* restricted # of signatures */
1.1       markus     65:        int             have_state;     /* .state file exists */
                     66:        int             lockfd;         /* locked in sshkey_xmss_get_state() */
1.8       markus     67:        u_char          allow_update;   /* allow sshkey_xmss_update_state() */
1.1       markus     68:        char            *enc_ciphername;/* encrypt state with cipher */
                     69:        u_char          *enc_keyiv;     /* encrypt state with key */
                     70:        u_int32_t       enc_keyiv_len;  /* length of enc_keyiv */
                     71: };
                     72:
                     73: int     sshkey_xmss_init_bds_state(struct sshkey *);
                     74: int     sshkey_xmss_init_enc_key(struct sshkey *, const char *);
                     75: void    sshkey_xmss_free_bds(struct sshkey *);
                     76: int     sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
1.9     ! dtucker    77:            int *, int);
1.1       markus     78: int     sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
                     79:            struct sshbuf **);
                     80: int     sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
                     81:            struct sshbuf **);
                     82: int     sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
                     83: int     sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);
                     84:
1.9     ! dtucker    85: #define PRINT(...) do { if (printerror) sshlog(__FILE__, __func__, __LINE__, \
        !            86:     0, SYSLOG_LEVEL_ERROR, __VA_ARGS__); } while (0)
1.1       markus     87:
                     88: int
                     89: sshkey_xmss_init(struct sshkey *key, const char *name)
                     90: {
                     91:        struct ssh_xmss_state *state;
                     92:
                     93:        if (key->xmss_state != NULL)
                     94:                return SSH_ERR_INVALID_FORMAT;
                     95:        if (name == NULL)
                     96:                return SSH_ERR_INVALID_FORMAT;
                     97:        state = calloc(sizeof(struct ssh_xmss_state), 1);
                     98:        if (state == NULL)
                     99:                return SSH_ERR_ALLOC_FAIL;
                    100:        if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
                    101:                state->n = 32;
                    102:                state->w = 16;
                    103:                state->h = 10;
                    104:        } else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
                    105:                state->n = 32;
                    106:                state->w = 16;
                    107:                state->h = 16;
                    108:        } else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
                    109:                state->n = 32;
                    110:                state->w = 16;
                    111:                state->h = 20;
                    112:        } else {
                    113:                free(state);
                    114:                return SSH_ERR_KEY_TYPE_UNKNOWN;
                    115:        }
                    116:        if ((key->xmss_name = strdup(name)) == NULL) {
                    117:                free(state);
                    118:                return SSH_ERR_ALLOC_FAIL;
                    119:        }
                    120:        state->k = 2;   /* XXX hardcoded */
                    121:        state->lockfd = -1;
                    122:        if (xmss_set_params(&state->params, state->n, state->h, state->w,
                    123:            state->k) != 0) {
                    124:                free(state);
                    125:                return SSH_ERR_INVALID_FORMAT;
                    126:        }
                    127:        key->xmss_state = state;
                    128:        return 0;
                    129: }
                    130:
                    131: void
                    132: sshkey_xmss_free_state(struct sshkey *key)
                    133: {
                    134:        struct ssh_xmss_state *state = key->xmss_state;
                    135:
                    136:        sshkey_xmss_free_bds(key);
                    137:        if (state) {
                    138:                if (state->enc_keyiv) {
                    139:                        explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
                    140:                        free(state->enc_keyiv);
                    141:                }
                    142:                free(state->enc_ciphername);
                    143:                free(state);
                    144:        }
                    145:        key->xmss_state = NULL;
                    146: }
                    147:
                    148: #define SSH_XMSS_K2_MAGIC      "k=2"
                    149: #define num_stack(x)           ((x->h+1)*(x->n))
                    150: #define num_stacklevels(x)     (x->h+1)
                    151: #define num_auth(x)            ((x->h)*(x->n))
                    152: #define num_keep(x)            ((x->h >> 1)*(x->n))
                    153: #define num_th_nodes(x)                ((x->h - x->k)*(x->n))
                    154: #define num_retain(x)          (((1ULL << x->k) - x->k - 1) * (x->n))
                    155: #define num_treehash(x)                ((x->h) - (x->k))
                    156:
                    157: int
                    158: sshkey_xmss_init_bds_state(struct sshkey *key)
                    159: {
                    160:        struct ssh_xmss_state *state = key->xmss_state;
                    161:        u_int32_t i;
                    162:
                    163:        state->stackoffset = 0;
                    164:        if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
                    165:            (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
                    166:            (state->auth = calloc(num_auth(state), 1)) == NULL ||
                    167:            (state->keep = calloc(num_keep(state), 1)) == NULL ||
                    168:            (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
                    169:            (state->retain = calloc(num_retain(state), 1)) == NULL ||
                    170:            (state->treehash = calloc(num_treehash(state),
                    171:            sizeof(treehash_inst))) == NULL) {
                    172:                sshkey_xmss_free_bds(key);
                    173:                return SSH_ERR_ALLOC_FAIL;
                    174:        }
                    175:        for (i = 0; i < state->h - state->k; i++)
                    176:                state->treehash[i].node = &state->th_nodes[state->n*i];
                    177:        xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
                    178:            state->stacklevels, state->auth, state->keep, state->treehash,
                    179:            state->retain, 0);
                    180:        return 0;
                    181: }
                    182:
                    183: void
                    184: sshkey_xmss_free_bds(struct sshkey *key)
                    185: {
                    186:        struct ssh_xmss_state *state = key->xmss_state;
                    187:
                    188:        if (state == NULL)
                    189:                return;
                    190:        free(state->stack);
                    191:        free(state->stacklevels);
                    192:        free(state->auth);
                    193:        free(state->keep);
                    194:        free(state->th_nodes);
                    195:        free(state->retain);
                    196:        free(state->treehash);
                    197:        state->stack = NULL;
                    198:        state->stacklevels = NULL;
                    199:        state->auth = NULL;
                    200:        state->keep = NULL;
                    201:        state->th_nodes = NULL;
                    202:        state->retain = NULL;
                    203:        state->treehash = NULL;
                    204: }
                    205:
                    206: void *
                    207: sshkey_xmss_params(const struct sshkey *key)
                    208: {
                    209:        struct ssh_xmss_state *state = key->xmss_state;
                    210:
                    211:        if (state == NULL)
                    212:                return NULL;
                    213:        return &state->params;
                    214: }
                    215:
                    216: void *
                    217: sshkey_xmss_bds_state(const struct sshkey *key)
                    218: {
                    219:        struct ssh_xmss_state *state = key->xmss_state;
                    220:
                    221:        if (state == NULL)
                    222:                return NULL;
                    223:        return &state->bds;
                    224: }
                    225:
                    226: int
                    227: sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
                    228: {
                    229:        struct ssh_xmss_state *state = key->xmss_state;
                    230:
                    231:        if (lenp == NULL)
                    232:                return SSH_ERR_INVALID_ARGUMENT;
                    233:        if (state == NULL)
                    234:                return SSH_ERR_INVALID_FORMAT;
                    235:        *lenp = 4 + state->n +
                    236:            state->params.wots_par.keysize +
                    237:            state->h * state->n;
                    238:        return 0;
                    239: }
                    240:
                    241: size_t
                    242: sshkey_xmss_pklen(const struct sshkey *key)
                    243: {
                    244:        struct ssh_xmss_state *state = key->xmss_state;
                    245:
                    246:        if (state == NULL)
                    247:                return 0;
                    248:        return state->n * 2;
                    249: }
                    250:
                    251: size_t
                    252: sshkey_xmss_sklen(const struct sshkey *key)
                    253: {
                    254:        struct ssh_xmss_state *state = key->xmss_state;
                    255:
                    256:        if (state == NULL)
                    257:                return 0;
                    258:        return state->n * 4 + 4;
                    259: }
                    260:
                    261: int
                    262: sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
                    263: {
                    264:        struct ssh_xmss_state *state = k->xmss_state;
                    265:        const struct sshcipher *cipher;
                    266:        size_t keylen = 0, ivlen = 0;
                    267:
                    268:        if (state == NULL)
                    269:                return SSH_ERR_INVALID_ARGUMENT;
                    270:        if ((cipher = cipher_by_name(ciphername)) == NULL)
                    271:                return SSH_ERR_INTERNAL_ERROR;
                    272:        if ((state->enc_ciphername = strdup(ciphername)) == NULL)
                    273:                return SSH_ERR_ALLOC_FAIL;
                    274:        keylen = cipher_keylen(cipher);
                    275:        ivlen = cipher_ivlen(cipher);
                    276:        state->enc_keyiv_len = keylen + ivlen;
                    277:        if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
                    278:                free(state->enc_ciphername);
                    279:                state->enc_ciphername = NULL;
                    280:                return SSH_ERR_ALLOC_FAIL;
                    281:        }
                    282:        arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
                    283:        return 0;
                    284: }
                    285:
                    286: int
                    287: sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
                    288: {
                    289:        struct ssh_xmss_state *state = k->xmss_state;
                    290:        int r;
                    291:
                    292:        if (state == NULL || state->enc_keyiv == NULL ||
                    293:            state->enc_ciphername == NULL)
                    294:                return SSH_ERR_INVALID_ARGUMENT;
                    295:        if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
                    296:            (r = sshbuf_put_string(b, state->enc_keyiv,
                    297:            state->enc_keyiv_len)) != 0)
                    298:                return r;
                    299:        return 0;
                    300: }
                    301:
                    302: int
                    303: sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
                    304: {
                    305:        struct ssh_xmss_state *state = k->xmss_state;
                    306:        size_t len;
                    307:        int r;
                    308:
                    309:        if (state == NULL)
                    310:                return SSH_ERR_INVALID_ARGUMENT;
                    311:        if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
                    312:            (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
                    313:                return r;
                    314:        state->enc_keyiv_len = len;
                    315:        return 0;
                    316: }
                    317:
                    318: int
                    319: sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
                    320:     enum sshkey_serialize_rep opts)
                    321: {
                    322:        struct ssh_xmss_state *state = k->xmss_state;
                    323:        u_char have_info = 1;
                    324:        u_int32_t idx;
                    325:        int r;
                    326:
                    327:        if (state == NULL)
                    328:                return SSH_ERR_INVALID_ARGUMENT;
                    329:        if (opts != SSHKEY_SERIALIZE_INFO)
                    330:                return 0;
                    331:        idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
                    332:        if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
                    333:            (r = sshbuf_put_u32(b, idx)) != 0 ||
                    334:            (r = sshbuf_put_u32(b, state->maxidx)) != 0)
                    335:                return r;
                    336:        return 0;
                    337: }
                    338:
                    339: int
                    340: sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
                    341: {
                    342:        struct ssh_xmss_state *state = k->xmss_state;
                    343:        u_char have_info;
                    344:        int r;
                    345:
                    346:        if (state == NULL)
                    347:                return SSH_ERR_INVALID_ARGUMENT;
                    348:        /* optional */
                    349:        if (sshbuf_len(b) == 0)
                    350:                return 0;
                    351:        if ((r = sshbuf_get_u8(b, &have_info)) != 0)
                    352:                return r;
                    353:        if (have_info != 1)
                    354:                return SSH_ERR_INVALID_ARGUMENT;
                    355:        if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
                    356:            (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
                    357:                return r;
                    358:        return 0;
                    359: }
                    360:
                    361: int
                    362: sshkey_xmss_generate_private_key(struct sshkey *k, u_int bits)
                    363: {
                    364:        int r;
                    365:        const char *name;
                    366:
                    367:        if (bits == 10) {
                    368:                name = XMSS_SHA2_256_W16_H10_NAME;
                    369:        } else if (bits == 16) {
                    370:                name = XMSS_SHA2_256_W16_H16_NAME;
                    371:        } else if (bits == 20) {
                    372:                name = XMSS_SHA2_256_W16_H20_NAME;
                    373:        } else {
                    374:                name = XMSS_DEFAULT_NAME;
                    375:        }
                    376:        if ((r = sshkey_xmss_init(k, name)) != 0 ||
                    377:            (r = sshkey_xmss_init_bds_state(k)) != 0 ||
                    378:            (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
                    379:                return r;
                    380:        if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
                    381:            (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
                    382:                return SSH_ERR_ALLOC_FAIL;
                    383:        }
                    384:        xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
                    385:            sshkey_xmss_params(k));
                    386:        return 0;
                    387: }
                    388:
                    389: int
                    390: sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
1.9     ! dtucker   391:     int *have_file, int printerror)
1.1       markus    392: {
                    393:        struct sshbuf *b = NULL, *enc = NULL;
                    394:        int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
                    395:        u_int32_t len;
                    396:        unsigned char buf[4], *data = NULL;
                    397:
                    398:        *have_file = 0;
                    399:        if ((fd = open(filename, O_RDONLY)) >= 0) {
                    400:                *have_file = 1;
                    401:                if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
                    402:                        PRINT("%s: corrupt state file: %s", __func__, filename);
                    403:                        goto done;
                    404:                }
                    405:                len = PEEK_U32(buf);
                    406:                if ((data = calloc(len, 1)) == NULL) {
                    407:                        ret = SSH_ERR_ALLOC_FAIL;
                    408:                        goto done;
                    409:                }
                    410:                if (atomicio(read, fd, data, len) != len) {
                    411:                        PRINT("%s: cannot read blob: %s", __func__, filename);
                    412:                        goto done;
                    413:                }
                    414:                if ((enc = sshbuf_from(data, len)) == NULL) {
                    415:                        ret = SSH_ERR_ALLOC_FAIL;
                    416:                        goto done;
                    417:                }
                    418:                sshkey_xmss_free_bds(k);
                    419:                if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
                    420:                        ret = r;
                    421:                        goto done;
                    422:                }
                    423:                if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
                    424:                        ret = r;
                    425:                        goto done;
                    426:                }
                    427:                ret = 0;
                    428:        }
                    429: done:
                    430:        if (fd != -1)
                    431:                close(fd);
                    432:        free(data);
                    433:        sshbuf_free(enc);
                    434:        sshbuf_free(b);
                    435:        return ret;
                    436: }
                    437:
                    438: int
1.9     ! dtucker   439: sshkey_xmss_get_state(const struct sshkey *k, int printerror)
1.1       markus    440: {
                    441:        struct ssh_xmss_state *state = k->xmss_state;
                    442:        u_int32_t idx = 0;
                    443:        char *filename = NULL;
                    444:        char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
                    445:        int lockfd = -1, have_state = 0, have_ostate, tries = 0;
                    446:        int ret = SSH_ERR_INVALID_ARGUMENT, r;
                    447:
                    448:        if (state == NULL)
                    449:                goto done;
                    450:        /*
                    451:         * If maxidx is set, then we are allowed a limited number
                    452:         * of signatures, but don't need to access the disk.
                    453:         * Otherwise we need to deal with the on-disk state.
                    454:         */
                    455:        if (state->maxidx) {
                    456:                /* xmss_sk always contains the current state */
                    457:                idx = PEEK_U32(k->xmss_sk);
                    458:                if (idx < state->maxidx) {
                    459:                        state->allow_update = 1;
                    460:                        return 0;
                    461:                }
                    462:                return SSH_ERR_INVALID_ARGUMENT;
                    463:        }
                    464:        if ((filename = k->xmss_filename) == NULL)
                    465:                goto done;
1.4       deraadt   466:        if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
                    467:            asprintf(&statefile, "%s.state", filename) == -1 ||
                    468:            asprintf(&ostatefile, "%s.ostate", filename) == -1) {
1.1       markus    469:                ret = SSH_ERR_ALLOC_FAIL;
                    470:                goto done;
                    471:        }
1.5       deraadt   472:        if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
1.1       markus    473:                ret = SSH_ERR_SYSTEM_ERROR;
                    474:                PRINT("%s: cannot open/create: %s", __func__, lockfile);
                    475:                goto done;
                    476:        }
1.5       deraadt   477:        while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
1.1       markus    478:                if (errno != EWOULDBLOCK) {
                    479:                        ret = SSH_ERR_SYSTEM_ERROR;
                    480:                        PRINT("%s: cannot lock: %s", __func__, lockfile);
                    481:                        goto done;
                    482:                }
                    483:                if (++tries > 10) {
                    484:                        ret = SSH_ERR_SYSTEM_ERROR;
                    485:                        PRINT("%s: giving up on: %s", __func__, lockfile);
                    486:                        goto done;
                    487:                }
                    488:                usleep(1000*100*tries);
                    489:        }
                    490:        /* XXX no longer const */
                    491:        if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
1.9     ! dtucker   492:            statefile, &have_state, printerror)) != 0) {
1.1       markus    493:                if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
1.9     ! dtucker   494:                    ostatefile, &have_ostate, printerror)) == 0) {
1.1       markus    495:                        state->allow_update = 1;
                    496:                        r = sshkey_xmss_forward_state(k, 1);
                    497:                        state->idx = PEEK_U32(k->xmss_sk);
                    498:                        state->allow_update = 0;
                    499:                }
                    500:        }
                    501:        if (!have_state && !have_ostate) {
                    502:                /* check that bds state is initialized */
                    503:                if (state->bds.auth == NULL)
                    504:                        goto done;
                    505:                PRINT("%s: start from scratch idx 0: %u", __func__, state->idx);
                    506:        } else if (r != 0) {
                    507:                ret = r;
                    508:                goto done;
                    509:        }
                    510:        if (state->idx + 1 < state->idx) {
                    511:                PRINT("%s: state wrap: %u", __func__, state->idx);
                    512:                goto done;
                    513:        }
                    514:        state->have_state = have_state;
                    515:        state->lockfd = lockfd;
                    516:        state->allow_update = 1;
                    517:        lockfd = -1;
                    518:        ret = 0;
                    519: done:
                    520:        if (lockfd != -1)
                    521:                close(lockfd);
                    522:        free(lockfile);
                    523:        free(statefile);
                    524:        free(ostatefile);
                    525:        return ret;
                    526: }
                    527:
                    528: int
                    529: sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
                    530: {
                    531:        struct ssh_xmss_state *state = k->xmss_state;
                    532:        u_char *sig = NULL;
                    533:        size_t required_siglen;
                    534:        unsigned long long smlen;
                    535:        u_char data;
                    536:        int ret, r;
                    537:
                    538:        if (state == NULL || !state->allow_update)
                    539:                return SSH_ERR_INVALID_ARGUMENT;
                    540:        if (reserve == 0)
                    541:                return SSH_ERR_INVALID_ARGUMENT;
                    542:        if (state->idx + reserve <= state->idx)
                    543:                return SSH_ERR_INVALID_ARGUMENT;
                    544:        if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
                    545:                return r;
                    546:        if ((sig = malloc(required_siglen)) == NULL)
                    547:                return SSH_ERR_ALLOC_FAIL;
                    548:        while (reserve-- > 0) {
                    549:                state->idx = PEEK_U32(k->xmss_sk);
                    550:                smlen = required_siglen;
                    551:                if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
                    552:                    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
                    553:                        r = SSH_ERR_INVALID_ARGUMENT;
                    554:                        break;
                    555:                }
                    556:        }
                    557:        free(sig);
                    558:        return r;
                    559: }
                    560:
                    561: int
1.9     ! dtucker   562: sshkey_xmss_update_state(const struct sshkey *k, int printerror)
1.1       markus    563: {
                    564:        struct ssh_xmss_state *state = k->xmss_state;
                    565:        struct sshbuf *b = NULL, *enc = NULL;
                    566:        u_int32_t idx = 0;
                    567:        unsigned char buf[4];
                    568:        char *filename = NULL;
                    569:        char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
                    570:        int fd = -1;
                    571:        int ret = SSH_ERR_INVALID_ARGUMENT;
                    572:
                    573:        if (state == NULL || !state->allow_update)
                    574:                return ret;
                    575:        if (state->maxidx) {
                    576:                /* no update since the number of signatures is limited */
                    577:                ret = 0;
                    578:                goto done;
                    579:        }
                    580:        idx = PEEK_U32(k->xmss_sk);
                    581:        if (idx == state->idx) {
1.2       djm       582:                /* no signature happened, no need to update */
1.1       markus    583:                ret = 0;
                    584:                goto done;
                    585:        } else if (idx != state->idx + 1) {
                    586:                PRINT("%s: more than one signature happened: idx %u state %u",
                    587:                     __func__, idx, state->idx);
                    588:                goto done;
                    589:        }
                    590:        state->idx = idx;
                    591:        if ((filename = k->xmss_filename) == NULL)
                    592:                goto done;
1.4       deraadt   593:        if (asprintf(&statefile, "%s.state", filename) == -1 ||
                    594:            asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
                    595:            asprintf(&nstatefile, "%s.nstate", filename) == -1) {
1.1       markus    596:                ret = SSH_ERR_ALLOC_FAIL;
                    597:                goto done;
                    598:        }
                    599:        unlink(nstatefile);
                    600:        if ((b = sshbuf_new()) == NULL) {
                    601:                ret = SSH_ERR_ALLOC_FAIL;
                    602:                goto done;
                    603:        }
                    604:        if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
                    605:                PRINT("%s: SERLIALIZE FAILED: %d", __func__, ret);
                    606:                goto done;
                    607:        }
                    608:        if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
                    609:                PRINT("%s: ENCRYPT FAILED: %d", __func__, ret);
                    610:                goto done;
                    611:        }
1.5       deraadt   612:        if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
1.1       markus    613:                ret = SSH_ERR_SYSTEM_ERROR;
                    614:                PRINT("%s: open new state file: %s", __func__, nstatefile);
                    615:                goto done;
                    616:        }
                    617:        POKE_U32(buf, sshbuf_len(enc));
                    618:        if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
                    619:                ret = SSH_ERR_SYSTEM_ERROR;
                    620:                PRINT("%s: write new state file hdr: %s", __func__, nstatefile);
                    621:                close(fd);
                    622:                goto done;
                    623:        }
1.3       markus    624:        if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
1.1       markus    625:            sshbuf_len(enc)) {
                    626:                ret = SSH_ERR_SYSTEM_ERROR;
                    627:                PRINT("%s: write new state file data: %s", __func__, nstatefile);
                    628:                close(fd);
                    629:                goto done;
                    630:        }
1.5       deraadt   631:        if (fsync(fd) == -1) {
1.1       markus    632:                ret = SSH_ERR_SYSTEM_ERROR;
                    633:                PRINT("%s: sync new state file: %s", __func__, nstatefile);
                    634:                close(fd);
                    635:                goto done;
                    636:        }
1.5       deraadt   637:        if (close(fd) == -1) {
1.1       markus    638:                ret = SSH_ERR_SYSTEM_ERROR;
                    639:                PRINT("%s: close new state file: %s", __func__, nstatefile);
                    640:                goto done;
                    641:        }
                    642:        if (state->have_state) {
                    643:                unlink(ostatefile);
                    644:                if (link(statefile, ostatefile)) {
                    645:                        ret = SSH_ERR_SYSTEM_ERROR;
                    646:                        PRINT("%s: backup state %s to %s", __func__, statefile,
                    647:                            ostatefile);
                    648:                        goto done;
                    649:                }
                    650:        }
1.5       deraadt   651:        if (rename(nstatefile, statefile) == -1) {
1.1       markus    652:                ret = SSH_ERR_SYSTEM_ERROR;
                    653:                PRINT("%s: rename %s to %s", __func__, nstatefile, statefile);
                    654:                goto done;
                    655:        }
                    656:        ret = 0;
                    657: done:
                    658:        if (state->lockfd != -1) {
                    659:                close(state->lockfd);
                    660:                state->lockfd = -1;
                    661:        }
                    662:        if (nstatefile)
                    663:                unlink(nstatefile);
                    664:        free(statefile);
                    665:        free(ostatefile);
                    666:        free(nstatefile);
                    667:        sshbuf_free(b);
                    668:        sshbuf_free(enc);
                    669:        return ret;
                    670: }
                    671:
                    672: int
                    673: sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
                    674: {
                    675:        struct ssh_xmss_state *state = k->xmss_state;
                    676:        treehash_inst *th;
                    677:        u_int32_t i, node;
                    678:        int r;
                    679:
                    680:        if (state == NULL)
                    681:                return SSH_ERR_INVALID_ARGUMENT;
                    682:        if (state->stack == NULL)
                    683:                return SSH_ERR_INVALID_ARGUMENT;
                    684:        state->stackoffset = state->bds.stackoffset;    /* copy back */
                    685:        if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
                    686:            (r = sshbuf_put_u32(b, state->idx)) != 0 ||
                    687:            (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
                    688:            (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
                    689:            (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
                    690:            (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
                    691:            (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
                    692:            (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
                    693:            (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
                    694:            (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
                    695:                return r;
                    696:        for (i = 0; i < num_treehash(state); i++) {
                    697:                th = &state->treehash[i];
                    698:                node = th->node - state->th_nodes;
                    699:                if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
                    700:                    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
                    701:                    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
                    702:                    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
                    703:                    (r = sshbuf_put_u32(b, node)) != 0)
                    704:                        return r;
                    705:        }
                    706:        return 0;
                    707: }
                    708:
                    709: int
                    710: sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
                    711:     enum sshkey_serialize_rep opts)
                    712: {
                    713:        struct ssh_xmss_state *state = k->xmss_state;
                    714:        int r = SSH_ERR_INVALID_ARGUMENT;
1.8       markus    715:        u_char have_stack, have_filename, have_enc;
1.1       markus    716:
                    717:        if (state == NULL)
                    718:                return SSH_ERR_INVALID_ARGUMENT;
                    719:        if ((r = sshbuf_put_u8(b, opts)) != 0)
                    720:                return r;
                    721:        switch (opts) {
                    722:        case SSHKEY_SERIALIZE_STATE:
                    723:                r = sshkey_xmss_serialize_state(k, b);
                    724:                break;
                    725:        case SSHKEY_SERIALIZE_FULL:
                    726:                if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
1.8       markus    727:                        return r;
1.1       markus    728:                r = sshkey_xmss_serialize_state(k, b);
                    729:                break;
1.8       markus    730:        case SSHKEY_SERIALIZE_SHIELD:
                    731:                /* all of stack/filename/enc are optional */
                    732:                have_stack = state->stack != NULL;
                    733:                if ((r = sshbuf_put_u8(b, have_stack)) != 0)
                    734:                        return r;
                    735:                if (have_stack) {
                    736:                        state->idx = PEEK_U32(k->xmss_sk);      /* update */
                    737:                        if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
                    738:                                return r;
                    739:                }
                    740:                have_filename = k->xmss_filename != NULL;
                    741:                if ((r = sshbuf_put_u8(b, have_filename)) != 0)
                    742:                        return r;
                    743:                if (have_filename &&
                    744:                    (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
                    745:                        return r;
                    746:                have_enc = state->enc_keyiv != NULL;
                    747:                if ((r = sshbuf_put_u8(b, have_enc)) != 0)
                    748:                        return r;
                    749:                if (have_enc &&
                    750:                    (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
                    751:                        return r;
                    752:                if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
                    753:                    (r = sshbuf_put_u8(b, state->allow_update)) != 0)
                    754:                        return r;
                    755:                break;
1.1       markus    756:        case SSHKEY_SERIALIZE_DEFAULT:
                    757:                r = 0;
                    758:                break;
                    759:        default:
                    760:                r = SSH_ERR_INVALID_ARGUMENT;
                    761:                break;
                    762:        }
                    763:        return r;
                    764: }
                    765:
                    766: int
                    767: sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
                    768: {
                    769:        struct ssh_xmss_state *state = k->xmss_state;
                    770:        treehash_inst *th;
                    771:        u_int32_t i, lh, node;
                    772:        size_t ls, lsl, la, lk, ln, lr;
                    773:        char *magic;
1.7       djm       774:        int r = SSH_ERR_INTERNAL_ERROR;
1.1       markus    775:
                    776:        if (state == NULL)
                    777:                return SSH_ERR_INVALID_ARGUMENT;
                    778:        if (k->xmss_sk == NULL)
                    779:                return SSH_ERR_INVALID_ARGUMENT;
                    780:        if ((state->treehash = calloc(num_treehash(state),
                    781:            sizeof(treehash_inst))) == NULL)
                    782:                return SSH_ERR_ALLOC_FAIL;
                    783:        if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
                    784:            (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
                    785:            (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
                    786:            (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
                    787:            (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
                    788:            (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
                    789:            (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
                    790:            (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
                    791:            (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
                    792:            (r = sshbuf_get_u32(b, &lh)) != 0)
1.7       djm       793:                goto out;
                    794:        if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
                    795:                r = SSH_ERR_INVALID_ARGUMENT;
                    796:                goto out;
                    797:        }
1.1       markus    798:        /* XXX check stackoffset */
                    799:        if (ls != num_stack(state) ||
                    800:            lsl != num_stacklevels(state) ||
                    801:            la != num_auth(state) ||
                    802:            lk != num_keep(state) ||
                    803:            ln != num_th_nodes(state) ||
                    804:            lr != num_retain(state) ||
1.7       djm       805:            lh != num_treehash(state)) {
                    806:                r = SSH_ERR_INVALID_ARGUMENT;
                    807:                goto out;
                    808:        }
1.1       markus    809:        for (i = 0; i < num_treehash(state); i++) {
                    810:                th = &state->treehash[i];
                    811:                if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
                    812:                    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
                    813:                    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
                    814:                    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
                    815:                    (r = sshbuf_get_u32(b, &node)) != 0)
1.7       djm       816:                        goto out;
1.1       markus    817:                if (node < num_th_nodes(state))
                    818:                        th->node = &state->th_nodes[node];
                    819:        }
                    820:        POKE_U32(k->xmss_sk, state->idx);
                    821:        xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
                    822:            state->stacklevels, state->auth, state->keep, state->treehash,
                    823:            state->retain, 0);
1.7       djm       824:        /* success */
                    825:        r = 0;
                    826:  out:
                    827:        free(magic);
                    828:        return r;
1.1       markus    829: }
                    830:
                    831: int
                    832: sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
                    833: {
1.8       markus    834:        struct ssh_xmss_state *state = k->xmss_state;
1.1       markus    835:        enum sshkey_serialize_rep opts;
1.8       markus    836:        u_char have_state, have_stack, have_filename, have_enc;
1.1       markus    837:        int r;
                    838:
                    839:        if ((r = sshbuf_get_u8(b, &have_state)) != 0)
                    840:                return r;
                    841:
                    842:        opts = have_state;
                    843:        switch (opts) {
                    844:        case SSHKEY_SERIALIZE_DEFAULT:
                    845:                r = 0;
1.8       markus    846:                break;
                    847:        case SSHKEY_SERIALIZE_SHIELD:
                    848:                if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
                    849:                        return r;
                    850:                if (have_stack &&
                    851:                    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
                    852:                        return r;
                    853:                if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
                    854:                        return r;
                    855:                if (have_filename &&
                    856:                    (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
                    857:                        return r;
                    858:                if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
                    859:                        return r;
                    860:                if (have_enc &&
                    861:                    (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
                    862:                        return r;
                    863:                if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
                    864:                    (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
                    865:                        return r;
1.1       markus    866:                break;
                    867:        case SSHKEY_SERIALIZE_STATE:
                    868:                if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
                    869:                        return r;
                    870:                break;
                    871:        case SSHKEY_SERIALIZE_FULL:
                    872:                if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
                    873:                    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
                    874:                        return r;
                    875:                break;
                    876:        default:
                    877:                r = SSH_ERR_INVALID_FORMAT;
                    878:                break;
                    879:        }
                    880:        return r;
                    881: }
                    882:
                    883: int
                    884: sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
                    885:    struct sshbuf **retp)
                    886: {
                    887:        struct ssh_xmss_state *state = k->xmss_state;
                    888:        struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
                    889:        struct sshcipher_ctx *ciphercontext = NULL;
                    890:        const struct sshcipher *cipher;
                    891:        u_char *cp, *key, *iv = NULL;
                    892:        size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
                    893:        int r = SSH_ERR_INTERNAL_ERROR;
                    894:
                    895:        if (retp != NULL)
                    896:                *retp = NULL;
                    897:        if (state == NULL ||
                    898:            state->enc_keyiv == NULL ||
                    899:            state->enc_ciphername == NULL)
                    900:                return SSH_ERR_INTERNAL_ERROR;
                    901:        if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
                    902:                r = SSH_ERR_INTERNAL_ERROR;
                    903:                goto out;
                    904:        }
                    905:        blocksize = cipher_blocksize(cipher);
                    906:        keylen = cipher_keylen(cipher);
                    907:        ivlen = cipher_ivlen(cipher);
                    908:        authlen = cipher_authlen(cipher);
                    909:        if (state->enc_keyiv_len != keylen + ivlen) {
                    910:                r = SSH_ERR_INVALID_FORMAT;
                    911:                goto out;
                    912:        }
                    913:        key = state->enc_keyiv;
                    914:        if ((encrypted = sshbuf_new()) == NULL ||
                    915:            (encoded = sshbuf_new()) == NULL ||
                    916:            (padded = sshbuf_new()) == NULL ||
                    917:            (iv = malloc(ivlen)) == NULL) {
                    918:                r = SSH_ERR_ALLOC_FAIL;
                    919:                goto out;
                    920:        }
                    921:
                    922:        /* replace first 4 bytes of IV with index to ensure uniqueness */
                    923:        memcpy(iv, key + keylen, ivlen);
                    924:        POKE_U32(iv, state->idx);
                    925:
                    926:        if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
                    927:            (r = sshbuf_put_u32(encoded, state->idx)) != 0)
                    928:                goto out;
                    929:
                    930:        /* padded state will be encrypted */
                    931:        if ((r = sshbuf_putb(padded, b)) != 0)
                    932:                goto out;
                    933:        i = 0;
                    934:        while (sshbuf_len(padded) % blocksize) {
                    935:                if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
                    936:                        goto out;
                    937:        }
                    938:        encrypted_len = sshbuf_len(padded);
                    939:
                    940:        /* header including the length of state is used as AAD */
                    941:        if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
                    942:                goto out;
                    943:        aadlen = sshbuf_len(encoded);
                    944:
                    945:        /* concat header and state */
                    946:        if ((r = sshbuf_putb(encoded, padded)) != 0)
                    947:                goto out;
                    948:
                    949:        /* reserve space for encryption of encoded data plus auth tag */
                    950:        /* encrypt at offset addlen */
                    951:        if ((r = sshbuf_reserve(encrypted,
                    952:            encrypted_len + aadlen + authlen, &cp)) != 0 ||
                    953:            (r = cipher_init(&ciphercontext, cipher, key, keylen,
                    954:            iv, ivlen, 1)) != 0 ||
                    955:            (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
                    956:            encrypted_len, aadlen, authlen)) != 0)
                    957:                goto out;
                    958:
                    959:        /* success */
                    960:        r = 0;
                    961:  out:
                    962:        if (retp != NULL) {
                    963:                *retp = encrypted;
                    964:                encrypted = NULL;
                    965:        }
                    966:        sshbuf_free(padded);
                    967:        sshbuf_free(encoded);
                    968:        sshbuf_free(encrypted);
                    969:        cipher_free(ciphercontext);
                    970:        free(iv);
                    971:        return r;
                    972: }
                    973:
                    974: int
                    975: sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
                    976:    struct sshbuf **retp)
                    977: {
                    978:        struct ssh_xmss_state *state = k->xmss_state;
                    979:        struct sshbuf *copy = NULL, *decrypted = NULL;
                    980:        struct sshcipher_ctx *ciphercontext = NULL;
                    981:        const struct sshcipher *cipher = NULL;
                    982:        u_char *key, *iv = NULL, *dp;
                    983:        size_t keylen, ivlen, authlen, aadlen;
                    984:        u_int blocksize, encrypted_len, index;
                    985:        int r = SSH_ERR_INTERNAL_ERROR;
                    986:
                    987:        if (retp != NULL)
                    988:                *retp = NULL;
                    989:        if (state == NULL ||
                    990:            state->enc_keyiv == NULL ||
                    991:            state->enc_ciphername == NULL)
                    992:                return SSH_ERR_INTERNAL_ERROR;
                    993:        if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
                    994:                r = SSH_ERR_INVALID_FORMAT;
                    995:                goto out;
                    996:        }
                    997:        blocksize = cipher_blocksize(cipher);
                    998:        keylen = cipher_keylen(cipher);
                    999:        ivlen = cipher_ivlen(cipher);
                   1000:        authlen = cipher_authlen(cipher);
                   1001:        if (state->enc_keyiv_len != keylen + ivlen) {
                   1002:                r = SSH_ERR_INTERNAL_ERROR;
                   1003:                goto out;
                   1004:        }
                   1005:        key = state->enc_keyiv;
                   1006:
                   1007:        if ((copy = sshbuf_fromb(encoded)) == NULL ||
                   1008:            (decrypted = sshbuf_new()) == NULL ||
                   1009:            (iv = malloc(ivlen)) == NULL) {
                   1010:                r = SSH_ERR_ALLOC_FAIL;
                   1011:                goto out;
                   1012:        }
                   1013:
                   1014:        /* check magic */
                   1015:        if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
                   1016:            memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
                   1017:                r = SSH_ERR_INVALID_FORMAT;
                   1018:                goto out;
                   1019:        }
                   1020:        /* parse public portion */
                   1021:        if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
                   1022:            (r = sshbuf_get_u32(encoded, &index)) != 0 ||
                   1023:            (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
                   1024:                goto out;
                   1025:
                   1026:        /* check size of encrypted key blob */
                   1027:        if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
                   1028:                r = SSH_ERR_INVALID_FORMAT;
                   1029:                goto out;
                   1030:        }
                   1031:        /* check that an appropriate amount of auth data is present */
1.6       djm      1032:        if (sshbuf_len(encoded) < authlen ||
                   1033:            sshbuf_len(encoded) - authlen < encrypted_len) {
1.1       markus   1034:                r = SSH_ERR_INVALID_FORMAT;
                   1035:                goto out;
                   1036:        }
                   1037:
                   1038:        aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
                   1039:
                   1040:        /* replace first 4 bytes of IV with index to ensure uniqueness */
                   1041:        memcpy(iv, key + keylen, ivlen);
                   1042:        POKE_U32(iv, index);
                   1043:
                   1044:        /* decrypt private state of key */
                   1045:        if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
                   1046:            (r = cipher_init(&ciphercontext, cipher, key, keylen,
                   1047:            iv, ivlen, 0)) != 0 ||
                   1048:            (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
                   1049:            encrypted_len, aadlen, authlen)) != 0)
                   1050:                goto out;
                   1051:
                   1052:        /* there should be no trailing data */
                   1053:        if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
                   1054:                goto out;
                   1055:        if (sshbuf_len(encoded) != 0) {
                   1056:                r = SSH_ERR_INVALID_FORMAT;
                   1057:                goto out;
                   1058:        }
                   1059:
                   1060:        /* remove AAD */
                   1061:        if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
                   1062:                goto out;
                   1063:        /* XXX encrypted includes unchecked padding */
                   1064:
                   1065:        /* success */
                   1066:        r = 0;
                   1067:        if (retp != NULL) {
                   1068:                *retp = decrypted;
                   1069:                decrypted = NULL;
                   1070:        }
                   1071:  out:
                   1072:        cipher_free(ciphercontext);
                   1073:        sshbuf_free(copy);
                   1074:        sshbuf_free(decrypted);
                   1075:        free(iv);
                   1076:        return r;
                   1077: }
                   1078:
                   1079: u_int32_t
                   1080: sshkey_xmss_signatures_left(const struct sshkey *k)
                   1081: {
                   1082:        struct ssh_xmss_state *state = k->xmss_state;
                   1083:        u_int32_t idx;
                   1084:
                   1085:        if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
                   1086:            state->maxidx) {
                   1087:                idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
                   1088:                if (idx < state->maxidx)
                   1089:                        return state->maxidx - idx;
                   1090:        }
                   1091:        return 0;
                   1092: }
                   1093:
                   1094: int
                   1095: sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
                   1096: {
                   1097:        struct ssh_xmss_state *state = k->xmss_state;
                   1098:
                   1099:        if (sshkey_type_plain(k->type) != KEY_XMSS)
                   1100:                return SSH_ERR_INVALID_ARGUMENT;
                   1101:        if (maxsign == 0)
                   1102:                return 0;
                   1103:        if (state->idx + maxsign < state->idx)
                   1104:                return SSH_ERR_INVALID_ARGUMENT;
                   1105:        state->maxidx = state->idx + maxsign;
                   1106:        return 0;
                   1107: }