diff --git a/reftable/basics.c b/reftable/basics.c index 70b1091d14..fe2b83ff83 100644 --- a/reftable/basics.c +++ b/reftable/basics.c @@ -124,11 +124,8 @@ int reftable_buf_add(struct reftable_buf *buf, const void *data, size_t len) size_t newlen = buf->len + len; if (newlen + 1 > buf->alloc) { - char *reallocated = buf->buf; - REFTABLE_ALLOC_GROW(reallocated, newlen + 1, buf->alloc); - if (!reallocated) + if (REFTABLE_ALLOC_GROW(buf->buf, newlen + 1, buf->alloc)) return REFTABLE_OUT_OF_MEMORY_ERROR; - buf->buf = reallocated; } memcpy(buf->buf + buf->len, data, len); @@ -233,11 +230,9 @@ char **parse_names(char *buf, int size) next = end; } if (p < next) { - char **names_grown = names; - REFTABLE_ALLOC_GROW(names_grown, names_len + 1, names_cap); - if (!names_grown) + if (REFTABLE_ALLOC_GROW(names, names_len + 1, + names_cap)) goto err; - names = names_grown; names[names_len] = reftable_strdup(p); if (!names[names_len++]) @@ -246,7 +241,8 @@ char **parse_names(char *buf, int size) p = next + 1; } - REFTABLE_REALLOC_ARRAY(names, names_len + 1); + if (REFTABLE_ALLOC_GROW(names, names_len + 1, names_cap)) + goto err; names[names_len] = NULL; return names; diff --git a/reftable/basics.h b/reftable/basics.h index 36beda2c25..4bf71b0954 100644 --- a/reftable/basics.h +++ b/reftable/basics.h @@ -120,15 +120,38 @@ char *reftable_strdup(const char *str); #define REFTABLE_ALLOC_ARRAY(x, alloc) (x) = reftable_malloc(st_mult(sizeof(*(x)), (alloc))) #define REFTABLE_CALLOC_ARRAY(x, alloc) (x) = reftable_calloc((alloc), sizeof(*(x))) #define REFTABLE_REALLOC_ARRAY(x, alloc) (x) = reftable_realloc((x), st_mult(sizeof(*(x)), (alloc))) -#define REFTABLE_ALLOC_GROW(x, nr, alloc) \ - do { \ - if ((nr) > alloc) { \ - alloc = 2 * (alloc) + 1; \ - if (alloc < (nr)) \ - alloc = (nr); \ - REFTABLE_REALLOC_ARRAY(x, alloc); \ - } \ - } while (0) + +static inline void *reftable_alloc_grow(void *p, size_t nelem, size_t elsize, + size_t *allocp) +{ + void *new_p; + size_t alloc = *allocp * 2 + 1; + if (alloc < nelem) + alloc = nelem; + new_p = reftable_realloc(p, st_mult(elsize, alloc)); + if (!new_p) + return p; + *allocp = alloc; + return new_p; +} + +#define REFTABLE_ALLOC_GROW(x, nr, alloc) ( \ + (nr) > (alloc) && ( \ + (x) = reftable_alloc_grow((x), (nr), sizeof(*(x)), &(alloc)), \ + (nr) > (alloc) \ + ) \ +) + +#define REFTABLE_ALLOC_GROW_OR_NULL(x, nr, alloc) do { \ + size_t reftable_alloc_grow_or_null_alloc = alloc; \ + if (REFTABLE_ALLOC_GROW((x), (nr), reftable_alloc_grow_or_null_alloc)) { \ + REFTABLE_FREE_AND_NULL(x); \ + alloc = 0; \ + } else { \ + alloc = reftable_alloc_grow_or_null_alloc; \ + } \ +} while (0) + #define REFTABLE_FREE_AND_NULL(p) do { reftable_free(p); (p) = NULL; } while (0) #ifndef REFTABLE_ALLOW_BANNED_ALLOCATORS diff --git a/reftable/block.c b/reftable/block.c index 0198078485..9858bbc7c5 100644 --- a/reftable/block.c +++ b/reftable/block.c @@ -53,7 +53,8 @@ static int block_writer_register_restart(struct block_writer *w, int n, if (2 + 3 * rlen + n > w->block_size - w->next) return -1; if (is_restart) { - REFTABLE_ALLOC_GROW(w->restarts, w->restart_len + 1, w->restart_cap); + REFTABLE_ALLOC_GROW_OR_NULL(w->restarts, w->restart_len + 1, + w->restart_cap); if (!w->restarts) return REFTABLE_OUT_OF_MEMORY_ERROR; w->restarts[w->restart_len++] = w->next; @@ -176,7 +177,8 @@ int block_writer_finish(struct block_writer *w) * is guaranteed to return `Z_STREAM_END`. */ compressed_len = deflateBound(w->zstream, src_len); - REFTABLE_ALLOC_GROW(w->compressed, compressed_len, w->compressed_cap); + REFTABLE_ALLOC_GROW_OR_NULL(w->compressed, compressed_len, + w->compressed_cap); if (!w->compressed) { ret = REFTABLE_OUT_OF_MEMORY_ERROR; return ret; @@ -235,8 +237,8 @@ int block_reader_init(struct block_reader *br, struct reftable_block *block, uLong src_len = block->len - block_header_skip; /* Log blocks specify the *uncompressed* size in their header. */ - REFTABLE_ALLOC_GROW(br->uncompressed_data, sz, - br->uncompressed_cap); + REFTABLE_ALLOC_GROW_OR_NULL(br->uncompressed_data, sz, + br->uncompressed_cap); if (!br->uncompressed_data) { err = REFTABLE_OUT_OF_MEMORY_ERROR; goto done; diff --git a/reftable/pq.c b/reftable/pq.c index 6ee1164dd3..5591e875e1 100644 --- a/reftable/pq.c +++ b/reftable/pq.c @@ -49,7 +49,7 @@ int merged_iter_pqueue_add(struct merged_iter_pqueue *pq, const struct pq_entry { size_t i = 0; - REFTABLE_ALLOC_GROW(pq->heap, pq->len + 1, pq->cap); + REFTABLE_ALLOC_GROW_OR_NULL(pq->heap, pq->len + 1, pq->cap); if (!pq->heap) return REFTABLE_OUT_OF_MEMORY_ERROR; pq->heap[pq->len++] = *e; diff --git a/reftable/record.c b/reftable/record.c index fb5652ed57..04429d23fe 100644 --- a/reftable/record.c +++ b/reftable/record.c @@ -246,8 +246,8 @@ static int reftable_ref_record_copy_from(void *rec, const void *src_rec, if (src->refname) { size_t refname_len = strlen(src->refname); - REFTABLE_ALLOC_GROW(ref->refname, refname_len + 1, - ref->refname_cap); + REFTABLE_ALLOC_GROW_OR_NULL(ref->refname, refname_len + 1, + ref->refname_cap); if (!ref->refname) { err = REFTABLE_OUT_OF_MEMORY_ERROR; goto out; @@ -385,7 +385,7 @@ static int reftable_ref_record_decode(void *rec, struct reftable_buf key, SWAP(r->refname, refname); SWAP(r->refname_cap, refname_cap); - REFTABLE_ALLOC_GROW(r->refname, key.len + 1, r->refname_cap); + REFTABLE_ALLOC_GROW_OR_NULL(r->refname, key.len + 1, r->refname_cap); if (!r->refname) { err = REFTABLE_OUT_OF_MEMORY_ERROR; goto done; @@ -839,7 +839,7 @@ static int reftable_log_record_decode(void *rec, struct reftable_buf key, if (key.len <= 9 || key.buf[key.len - 9] != 0) return REFTABLE_FORMAT_ERROR; - REFTABLE_ALLOC_GROW(r->refname, key.len - 8, r->refname_cap); + REFTABLE_ALLOC_GROW_OR_NULL(r->refname, key.len - 8, r->refname_cap); if (!r->refname) { err = REFTABLE_OUT_OF_MEMORY_ERROR; goto done; @@ -947,8 +947,8 @@ static int reftable_log_record_decode(void *rec, struct reftable_buf key, } string_view_consume(&in, n); - REFTABLE_ALLOC_GROW(r->value.update.message, scratch->len + 1, - r->value.update.message_cap); + REFTABLE_ALLOC_GROW_OR_NULL(r->value.update.message, scratch->len + 1, + r->value.update.message_cap); if (!r->value.update.message) { err = REFTABLE_OUT_OF_MEMORY_ERROR; goto done; diff --git a/reftable/stack.c b/reftable/stack.c index 634f0c5425..531660a49f 100644 --- a/reftable/stack.c +++ b/reftable/stack.c @@ -317,7 +317,9 @@ static int reftable_stack_reload_once(struct reftable_stack *st, * thus need to keep them alive here, which we * do by bumping their refcount. */ - REFTABLE_ALLOC_GROW(reused, reused_len + 1, reused_alloc); + REFTABLE_ALLOC_GROW_OR_NULL(reused, + reused_len + 1, + reused_alloc); if (!reused) { err = REFTABLE_OUT_OF_MEMORY_ERROR; goto done; @@ -949,8 +951,8 @@ int reftable_addition_add(struct reftable_addition *add, if (err < 0) goto done; - REFTABLE_ALLOC_GROW(add->new_tables, add->new_tables_len + 1, - add->new_tables_cap); + REFTABLE_ALLOC_GROW_OR_NULL(add->new_tables, add->new_tables_len + 1, + add->new_tables_cap); if (!add->new_tables) { err = REFTABLE_OUT_OF_MEMORY_ERROR; goto done; diff --git a/reftable/writer.c b/reftable/writer.c index 624e90fb53..740c98038e 100644 --- a/reftable/writer.c +++ b/reftable/writer.c @@ -254,7 +254,8 @@ static int writer_index_hash(struct reftable_writer *w, struct reftable_buf *has if (key->offset_len > 0 && key->offsets[key->offset_len - 1] == off) return 0; - REFTABLE_ALLOC_GROW(key->offsets, key->offset_len + 1, key->offset_cap); + REFTABLE_ALLOC_GROW_OR_NULL(key->offsets, key->offset_len + 1, + key->offset_cap); if (!key->offsets) return REFTABLE_OUT_OF_MEMORY_ERROR; key->offsets[key->offset_len++] = off; @@ -820,7 +821,7 @@ static int writer_flush_nonempty_block(struct reftable_writer *w) * Note that this also applies when flushing index blocks, in which * case we will end up with a multi-level index. */ - REFTABLE_ALLOC_GROW(w->index, w->index_len + 1, w->index_cap); + REFTABLE_ALLOC_GROW_OR_NULL(w->index, w->index_len + 1, w->index_cap); if (!w->index) return REFTABLE_OUT_OF_MEMORY_ERROR; diff --git a/t/unit-tests/t-reftable-basics.c b/t/unit-tests/t-reftable-basics.c index 65d50df091..990dc1a244 100644 --- a/t/unit-tests/t-reftable-basics.c +++ b/t/unit-tests/t-reftable-basics.c @@ -20,6 +20,11 @@ static int integer_needle_lesseq(size_t i, void *_args) return args->needle <= args->haystack[i]; } +static void *realloc_stub(void *p UNUSED, size_t size UNUSED) +{ + return NULL; +} + int cmd_main(int argc UNUSED, const char *argv[] UNUSED) { if_test ("binary search with binsearch works") { @@ -141,5 +146,56 @@ int cmd_main(int argc UNUSED, const char *argv[] UNUSED) check_int(in, ==, out); } + if_test ("REFTABLE_ALLOC_GROW works") { + int *arr = NULL, *old_arr; + size_t alloc = 0, old_alloc; + + check(!REFTABLE_ALLOC_GROW(arr, 1, alloc)); + check(arr != NULL); + check_uint(alloc, >=, 1); + arr[0] = 42; + + old_alloc = alloc; + old_arr = arr; + reftable_set_alloc(malloc, realloc_stub, free); + check(REFTABLE_ALLOC_GROW(arr, old_alloc + 1, alloc)); + check(arr == old_arr); + check_uint(alloc, ==, old_alloc); + + old_alloc = alloc; + reftable_set_alloc(malloc, realloc, free); + check(!REFTABLE_ALLOC_GROW(arr, old_alloc + 1, alloc)); + check(arr != NULL); + check_uint(alloc, >, old_alloc); + arr[alloc - 1] = 42; + + reftable_free(arr); + } + + if_test ("REFTABLE_ALLOC_GROW_OR_NULL works") { + int *arr = NULL; + size_t alloc = 0, old_alloc; + + REFTABLE_ALLOC_GROW_OR_NULL(arr, 1, alloc); + check(arr != NULL); + check_uint(alloc, >=, 1); + arr[0] = 42; + + old_alloc = alloc; + REFTABLE_ALLOC_GROW_OR_NULL(arr, old_alloc + 1, alloc); + check(arr != NULL); + check_uint(alloc, >, old_alloc); + arr[alloc - 1] = 42; + + old_alloc = alloc; + reftable_set_alloc(malloc, realloc_stub, free); + REFTABLE_ALLOC_GROW_OR_NULL(arr, old_alloc + 1, alloc); + check(arr == NULL); + check_uint(alloc, ==, 0); + reftable_set_alloc(malloc, realloc, free); + + reftable_free(arr); + } + return test_done(); } diff --git a/t/unit-tests/t-reftable-merged.c b/t/unit-tests/t-reftable-merged.c index a12bd0e1a3..60836f80d6 100644 --- a/t/unit-tests/t-reftable-merged.c +++ b/t/unit-tests/t-reftable-merged.c @@ -178,7 +178,7 @@ static void t_merged_refs(void) if (err > 0) break; - REFTABLE_ALLOC_GROW(out, len + 1, cap); + check(!REFTABLE_ALLOC_GROW(out, len + 1, cap)); out[len++] = ref; } reftable_iterator_destroy(&it); @@ -459,7 +459,7 @@ static void t_merged_logs(void) if (err > 0) break; - REFTABLE_ALLOC_GROW(out, len + 1, cap); + check(!REFTABLE_ALLOC_GROW(out, len + 1, cap)); out[len++] = log; } reftable_iterator_destroy(&it);