#include "list.h"
#include "stdint.h"
#include "stdlib.h"
#include "string.h"
#include "debug.h"

/* Capacity (in elements) used for the first allocation of a list whose
 * capacity is still 0; list_reserve doubles from here as the list grows. */
#define LIST_DEFAULT_RESERVE 4

List list_init(size_t element_size) {
    return list_with_len(element_size, 0);
}

List list_with_len(size_t element_size, size_t len) {
    List self = {
        .element_size = element_size,
        .cap = 0,
        .len = 0,
        .data = NULL,
    };
    if(len != 0) {
        list_set_len(&self, len);
        if (self.data == NULL) {
            LOG_ERROR("Failed to allocate list with starting capacity of %d", LIST_DEFAULT_RESERVE);
            self.cap = 0;
        }
    }
    return self;
}

/**
 * Return a deep copy of `source`: a new buffer containing the same
 * `len` elements, copied byte-for-byte.
 *
 * On allocation failure an error is logged and an empty list (len == 0)
 * with the source's element size is returned.
 */
List list_copy(const List* source) {
    List self = list_init(source->element_size);
    list_set_len(&self, source->len);
    if(self.cap == 0) {
        LOG_ERROR("Failed to reserve space");
        /* Do not report elements that have no storage behind them. */
        self.len = 0;
        return self;
    }
    /* An empty source may have data == NULL; memcpy with a NULL argument is
     * undefined even for a zero length, so skip it. */
    if(source->len > 0)
        memcpy(self.data, source->data, source->element_size * source->len);
    return self;
}

/**
 * Release the list's storage and reset it to the empty, unallocated state
 * (data == NULL, cap == 0, len == 0). The element size is kept.
 *
 * The buffer is only freed when both a pointer and a non-zero capacity are
 * recorded. NOTE(review): a non-NULL buffer with cap == 0 is therefore not
 * freed here; confirm whether such lists are meant to be non-owning.
 */
void list_empty(List* self) {
    void* old_data = self->data;
    size_t old_cap = self->cap;

    self->len = 0;
    self->cap = 0;
    self->data = NULL;

    if(old_cap != 0 && old_data != NULL)
        free(old_data);
}

/**
 * Grow the list's buffer so that its capacity is strictly greater than
 * `at_least` elements. The capacity starts at LIST_DEFAULT_RESERVE (or the
 * current capacity) and is doubled until it exceeds `at_least`.
 *
 * Nothing happens if the capacity already exceeds `at_least`. On failure
 * (size_t overflow of the capacity or byte count, or allocation failure)
 * the error is reported via ASSERT_RETURN and the list is left unchanged,
 * so callers must check `cap` afterwards if they need the space.
 */
void list_reserve(List* self, size_t at_least) {
    if(at_least < self->cap)
        return;

    size_t new_cap = self->cap > 0 ? self->cap : LIST_DEFAULT_RESERVE;
    while(at_least >= new_cap) {
        /* Without this check, doubling would wrap to 0 and loop forever. */
        ASSERT_RETURN(new_cap <= SIZE_MAX / 2,, "Cannot grow list capacity beyond %zu elements", new_cap);
        new_cap *= 2;
    }

    /* Reject byte counts that do not fit in size_t instead of allocating a
     * silently truncated (too small) buffer. */
    ASSERT_RETURN(self->element_size == 0 || new_cap <= SIZE_MAX / self->element_size,,
                  "List allocation of %zu elements overflows size_t", new_cap);

    /* realloc(NULL, n) behaves like malloc(n), so one call covers both cases;
     * on failure the old buffer is untouched. */
    void* grown = realloc(self->data, new_cap * self->element_size);
    ASSERT_RETURN(grown != NULL,, "Failed to reserve space for %zu elements in list", new_cap);

    self->data = grown;
    self->cap = new_cap;
}

/**
 * Set the number of elements in the list, growing the buffer if needed.
 * New elements (when growing) are left uninitialized.
 *
 * If the buffer cannot be grown to hold `len` elements (list_reserve has
 * already reported the failure), the length is left unchanged rather than
 * pointing past the allocation. Shrinking always succeeds.
 */
void list_set_len(List* self, size_t len) {
    list_reserve(self, len);
    if(len > self->cap)
        return;
    self->len = len;
}

/**
 * Return a pointer to element `at` without bounds checking.
 *
 * `at == len` yields the one-past-the-end pointer. If the list has never
 * allocated (data == NULL), NULL is returned: arithmetic on a null pointer
 * is undefined in C even for an offset of 0.
 */
void* list_at_unchecked(List* self, size_t at) {
    if(self->data == NULL)
        return NULL;
    return (uint8_t*)self->data + self->element_size * at;
}

/**
 * Return a pointer to element `at`, or NULL (after reporting through
 * ASSERT_RETURN) when `at` is not less than the list length.
 */
void* list_at(List* self, size_t at) {
    ASSERT_RETURN(at < self->len, NULL, "Index %zu out of bounds", at);

    void* const element = list_at_unchecked(self, at);
    return element;
}

/**
 * Append a copy of the `element_size` bytes at `item` to the end of the list.
 *
 * Returns the index of the new element. If the buffer could not be grown
 * (list_reserve reports the failure), nothing is written and `len` is
 * returned, which is out of range (the same "no index" value list_find uses).
 *
 * NOTE: `item` must not point into this list's own storage. Growing may
 * realloc the buffer, leaving such a pointer dangling.
 */
size_t list_add(List* self, void* item) {
    list_reserve(self, self->len + 1);
    if(self->len >= self->cap)
        return self->len;

    uint8_t* into = (uint8_t*)self->data + self->element_size * self->len;
    memcpy(into, item, self->element_size);

    return self->len++;
}

/**
 * Insert a copy of the `element_size` bytes at `item` so it ends up at
 * index `at`, shifting later elements up by one. An `at` at or beyond the
 * current length appends instead.
 *
 * If the buffer could not be grown (list_reserve reports the failure), the
 * list is left unchanged.
 *
 * NOTE: `item` must not point into this list's own storage. It may be
 * invalidated by realloc or overwritten by the shift.
 */
void list_insert(List* self, void* item, size_t at) {
    list_reserve(self, self->len + 1);
    if(self->len >= self->cap)
        return;

    if(at >= self->len) {
        list_add(self, item);
        return;
    }

    uint8_t* slot = (uint8_t*)self->data + self->element_size * at;
    /* Ranges overlap, hence memmove: elements [at, len) move to [at+1, len+1). */
    memmove(slot + self->element_size, slot, self->element_size * (self->len - at));
    memcpy(slot, item, self->element_size);
    ++self->len;
}

/**
 * Remove the element at index `at`, shifting later elements down by one.
 * An out-of-range `at` is reported through ASSERT_RETURN and ignored.
 *
 * After removal, the capacity is halved while it exceeds the new length,
 * then doubled once. If that differs from the current capacity, the buffer
 * is shrunk. When the result is 0, the buffer is freed and the list returns
 * to the unallocated state (data == NULL, cap == 0). A failed shrink is
 * reported and leaves the existing, larger buffer in place.
 */
void list_erase(List* self, size_t at) {
    ASSERT_RETURN(at < self->len,, "Index %zu out of bounds", at);

    uint8_t* base = (uint8_t*)self->data;
    /* Elements after `at` that must slide down by one slot. */
    size_t tail = self->len - at - 1;
    if(tail > 0)
        memmove(base + at * self->element_size,
                base + (at + 1) * self->element_size,
                tail * self->element_size);
    --self->len;

    size_t new_cap = self->cap;
    while(new_cap > self->len) {
        new_cap /= 2;
    }
    new_cap *= 2;

    if(new_cap == self->cap)
        return;

    if(new_cap == 0) {
        /* realloc(p, 0) is implementation-defined (undefined in C23) and may
         * return a non-NULL pointer; free explicitly instead. */
        free(self->data);
        self->data = NULL;
        self->cap = 0;
        return;
    }

    void* shrunk = realloc(self->data, new_cap * self->element_size);
    ASSERT_RETURN(shrunk != NULL,, "Failed to shrink List to %zu", new_cap);

    self->data = shrunk;
    self->cap = new_cap;
}

/**
 * Pointer to the first element, i.e. the start of the list's buffer.
 * NULL for a list that has never allocated storage.
 */
void* list_iterator_begin(List* self) {
    /* Element 0 starts at the buffer's base address. */
    void* first = self->data;
    return first;
}

/**
 * One-past-the-last-element pointer, for use with list_iterator_begin.
 *
 * For a list that has never allocated (data == NULL), NULL is returned
 * without doing pointer arithmetic on a null pointer (undefined in C).
 */
void* list_iterator_end(List* self) {
    if(self->data == NULL)
        return NULL;
    return list_at_unchecked(self, self->len);
}

/**
 * Linear search for the first element whose `element_size` bytes equal the
 * bytes at `query` (memcmp). For struct elements, any padding bytes take
 * part in the comparison.
 *
 * Returns the index of the first match, or `len` if there is none.
 */
size_t list_find(List* self, void* query) {
    const uint8_t* cursor = (const uint8_t*)self->data;
    size_t index = 0;

    while(index < self->len) {
        if(memcmp(cursor, query, self->element_size) == 0)
            return index;
        cursor += self->element_size;
        ++index;
    }

    return self->len;
}