#include "physics_entity.h"
#include "camera.h"
#include "rigidbody.h"
#include "shape.h"
#include "render.h"
#include "debug.h"
#include "program.h"

/*
 * Debug visualisation for a physics entity.
 *
 * Draws the entity's shape at its transform and the rigid body's contacts,
 * then two lines starting at the entity position (converted to pixels via
 * g_camera):
 *   - green: from the position to position + velocity
 *   - cyan:  from the position to position + rigidbody_get_force(body)
 */
void physics_entity_debug_draw(PhysicsEntity self) {
    RigidBody* rb = self.tc->get_rigidbody(self.data);
    Shape* outline = self.tc->get_shape(self.data);
    Transform* xform = self.transformable->get_transform(self.data);

    shape_draw(outline, *xform);
    rigidbody_debug_draw_contacts(rb);

    // velocity indicator (world space first, then both ends to pixels)
    const Vector origin = xform->position;
    const Vector velocity_tip = vaddf(origin, rigidbody_get_velocity(rb));
    const Vector origin_px = camera_world_to_pixel_point(&g_camera, origin);
    const Vector velocity_tip_px = camera_world_to_pixel_point(&g_camera, velocity_tip);
    SDL_SetRenderDrawColor(g_renderer, 0, 255, 0, 255);
    SDL_RenderDrawLine(g_renderer, origin_px.x, origin_px.y, velocity_tip_px.x, velocity_tip_px.y);

    // force indicator, sharing the same pixel-space start point
    const Vector force_tip_px = camera_world_to_pixel_point(&g_camera, vaddf(xform->position, rigidbody_get_force(rb)));
    SDL_SetRenderDrawColor(g_renderer, 0, 255, 255, 255);
    SDL_RenderDrawLine(g_renderer, origin_px.x, origin_px.y, force_tip_px.x, force_tip_px.y);
}
/*
 * Computes the impulse to apply to `self` in response to a single contact.
 *
 * Currently disabled: always returns ZeroVector. The response this was meant
 * to implement (elastic push-out along the contact normal plus damping of the
 * approach speed) is
 *
 *     push    = dot(hit.normal, hit.velocity)
 *     impulse = hit.normal * (elasticity * max(1, -push))
 *             - hit.normal * (damping    * min(0,  push))
 *
 * With the coefficients previously hard-coded here (elasticity = damping = 0)
 * that formula is zero for finite inputs anyway, so returning ZeroVector
 * directly keeps the behaviour without computing unused locals.
 */
static inline
Vector _internal_calculate_contact_force(RigidBody* self, Contact* contact) {
    (void)self;
    (void)contact;
    return ZeroVector;
}

/*
 * One position-correction step for a single contact.
 *
 * The anchor is hit.point mapped through the pre-solve transform and offset
 * by hit.penetration_vector (hit.point is presumably body-local; confirm).
 * The same hit.point is then mapped through the body's current transform.
 * If that point lies behind the anchor along hit.normal by more than a small
 * tolerance, the body's position is moved along the normal towards the spot
 * that cancels that distance, using vmovetowardsf with a step of 1.f.
 *
 * Returns 1 if the contact is already satisfied, 0 if the body was moved.
 */
static inline
int _internal_default_contact_solver(RigidBody* body, Contact* contact, Transform pre_solve) {
    const Collision c = contact->hit;
    Transform* now = rigidbody_get_transform(body);

    const Vector anchor = vaddf(transform_point(&pre_solve, c.point), c.penetration_vector);
    const Vector current_point = transform_point(now, c.point);
    const float depth = vdotf(c.normal, vsubf(current_point, anchor));

    /* At or in front of the anchor (within tolerance): nothing to correct. */
    if(depth >= -0.0001)
        return 1;

    /* Any position where the point is further along the normal than the
     * anchor is acceptable; aim for the closest one. */
    const Vector goal = vaddf(now->position, vmulff(c.normal, -depth));
    now->position = vmovetowardsf(now->position, goal, 1.f);
    return 0;
}

/*
 * Adds the collision-response impulse of every contact in `contacts` to the
 * entity's rigid body (rigidbody_add_impulse with a final argument of 1).
 */
void physics_entity_apply_collision_forces(PhysicsEntity self, List* contacts) {
    RigidBody* const target = self.tc->get_rigidbody(self.data);
    list_foreach(Contact, contact, contacts) {
        const Vector impulse = _internal_calculate_contact_force(target, contact);
        rigidbody_add_impulse(target, impulse, 1);
    }
}

/*
 * Resolves the given contacts for the entity's rigid body.
 *
 * 1. Applies the collision impulses for every contact.
 * 2. Repeatedly runs the per-contact position solver over all contacts,
 *    always measuring against the same pre-solve transform snapshot, until a
 *    full pass reports every contact satisfied. After a fixed number of
 *    passes it gives up and logs a warning.
 * 3. Removes the component of the velocity that points opposite to the net
 *    position correction, so the body does not head straight back into what
 *    it was just pushed out of. Motion away from the contacts is kept.
 */
void physics_entity_solve_contacts(PhysicsEntity self, List* contacts) {
    physics_entity_apply_collision_forces(self, contacts);
    RigidBody* body = self.tc->get_rigidbody(self.data);
    const Transform pre_solve = *rigidbody_get_transform(body);

    // Iterative position solve: stop once a whole pass moves nothing.
    const size_t max_iterations = 100;
    int solved = 0;
    for(size_t iteration = 0; iteration < max_iterations && !solved; ++iteration) {
        solved = 1;
        list_foreach(Contact, contact, contacts) {
            if(!_internal_default_contact_solver(body, contact, pre_solve))
                solved = 0;
        }
    }
    if(!solved)
        LOG_WARNING("gave up on solving %zu contacts", contacts->len);

    // Net correction applied by the solver. If it is zero (nothing moved, e.g.
    // every contact was already satisfied) there is no direction to clip
    // against, and normalizing a zero-length vector may divide by zero
    // depending on vnormalizedf, so the velocity is left unchanged.
    Vector vel = rigidbody_get_velocity(body);
    const Vector displacement = vsubf(rigidbody_get_transform(body)->position, pre_solve.position);
    if(vdotf(displacement, displacement) > 0.f) {
        const Vector dir = vnormalizedf(displacement);
        const float along = vdotf(dir, vel);
        // only cancel motion heading back against the correction
        if(along < 0)
            vel = vsubf(vel, vmulff(dir, along));
    }
    rigidbody_set_velocity(body, vel);
}

/*
 * Per-step physics update for an entity.
 *
 * If the rigid body currently holds contacts, the entity type's
 * collision_solver is run on them first, then on_collision is invoked once
 * per contact with its Collision. Afterwards rigidbody_collect_contacts is
 * called (presumably refreshing the contact list for the next step; confirm).
 * The velocity is asserted to be non-NaN before and after; a failed assertion
 * returns early via ASSERT_RETURN.
 */
void physics_entity_update(PhysicsEntity self) {
    RigidBody* body = self.tc->get_rigidbody(self.data);

    ASSERT_RETURN(!visnanf(rigidbody_get_velocity(body)),, "Velocity is NaN (0)");

    List* pending = rigidbody_get_contacts(body);
    if(pending->len > 0) {
        // resolve penetration first, then report each contact to the entity
        self.tc->collision_solver(self.data, pending);
        list_foreach(Contact, contact, pending) {
            const Collision hit = contact->hit;
            self.tc->on_collision(self.data, hit);
        }
    }
    rigidbody_collect_contacts(body);

    ASSERT_RETURN(!visnanf(rigidbody_get_velocity(body)),, "Velocity is NaN (1)");
}