import {BoundsData} from "#template-nodes/geometry-processing/mesh-data"
import {CurvePoints} from "#template-nodes/interfaces/object-data"
import {Inlet, NotReady, Outlet} from "#template-nodes/runtime-graph/slots"
import {TypeDescriptors} from "#template-nodes/runtime-graph/type-descriptors"
import {NodeClassImpl} from "#template-nodes/runtime-graph/types"
import * as THREE from "three"

// Short alias for the runtime-graph type-descriptor factories used by the port descriptor below.
const TD: typeof TypeDescriptors = TypeDescriptors

/** Packed per-sample curve arrays carried on the input and output curve ports. */
type SampledCurveData = {
    points: Float32Array
    normals: Float32Array
    tangents: Float32Array
    segments: Float32Array
    scales: Float32Array
}

/**
 * Port layout for {@link SampleBridgeCurve}: a nullable input curve, the two hole bounds,
 * the bridge bounds, a closed-curve flag, and a nullable resampled output curve.
 */
const sampleBridgeCurveDescriptor = {
    inputCurvePoints: TD.inlet(TD.Nullable(TD.Identity<SampledCurveData>())),
    holeBounds: TD.inlet(TD.Identity<BoundsData[]>()),
    bridgeBounds: TD.inlet(TD.Identity<BoundsData>()),
    closed: TD.inlet(TD.Primitive<boolean>()),
    outputCurvePoints: TD.outlet(TD.Nullable(TD.Identity<SampledCurveData>())),
}

/**
 * Resamples a curve so that one bridge instance sits between every pair of consecutive
 * curve samples.
 *
 * For each input sample, both hole centroids (from `holeBounds`) are transformed into that
 * sample's local frame. A bridge is then placed halfway between the right hole of sample
 * `k` and the left hole of sample `k + 1` (wrapping around when `closed` is true). It is
 * oriented along the line joining them and scaled by
 * `(gap + 0.5 * (leftHoleDiameter + rightHoleDiameter)) / bridgeLength`.
 *
 * Emits `NotReady` while any inlet is not ready. Emits `null` when there is no input curve,
 * when there are not exactly two holes, or when the bridge bounds are degenerate.
 */
export class SampleBridgeCurve implements NodeClassImpl<typeof sampleBridgeCurveDescriptor, typeof SampleBridgeCurve> {
    static descriptor = sampleBridgeCurveDescriptor
    static uniqueName = "SampleBridgeCurve"

    inputCurvePoints!: Inlet<CurvePoints | null>
    holeBounds!: Inlet<BoundsData[]>
    bridgeBounds!: Inlet<BoundsData>
    closed!: Inlet<boolean>
    outputCurvePoints!: Outlet<CurvePoints | null>

    run() {
        if (this.inputCurvePoints === NotReady || this.holeBounds === NotReady || this.bridgeBounds === NotReady || this.closed === NotReady) {
            this.outputCurvePoints.emitIfChanged(NotReady)
            return
        }

        // A bridge needs exactly two holes (one per end); without an input curve there is nothing to place.
        if (this.inputCurvePoints === null || this.holeBounds.length !== 2) {
            this.outputCurvePoints.emitIfChanged(null)
            return
        }

        // Twice the larger |z| of the bridge AABB, i.e. its z-extent when the mesh is centred on z = 0.
        const bridgeLength = Math.max(Math.abs(this.bridgeBounds.aabb[0][2]), Math.abs(this.bridgeBounds.aabb[1][2])) * 2
        // A degenerate bridge would turn every output scale into Infinity/NaN, so treat it as "cannot build".
        if (!Number.isFinite(bridgeLength) || bridgeLength <= 0) {
            this.outputCurvePoints.emitIfChanged(null)
            return
        }

        // The hole with the smaller centroid z is the "left" hole; the other is the "right" hole.
        const firstIsLeft = this.holeBounds[0].centroid[2] < this.holeBounds[1].centroid[2]
        const leftHole = this.holeBounds[firstIsLeft ? 0 : 1]
        const rightHole = this.holeBounds[firstIsLeft ? 1 : 0]
        // Hole sizes measured as AABB extent along z.
        const leftHoleDiameter = leftHole.aabb[1][2] - leftHole.aabb[0][2]
        const rightHoleDiameter = rightHole.aabb[1][2] - rightHole.aabb[0][2]

        const {points, normals, tangents, segments, scales} = this.inputCurvePoints
        // One curve sample per entry in `segments`; points/normals/tangents are packed xyz triples.
        const sampleCount = segments.length

        // Per-sample normal, plus the world-space positions of both hole centroids at that sample.
        const threeNormals: THREE.Vector3[] = []
        const threeLeftHole: THREE.Vector3[] = []
        const threeRightHole: THREE.Vector3[] = []
        for (let i = 0; i < sampleCount; i++) {
            const position = new THREE.Vector3(points[i * 3], points[i * 3 + 1], points[i * 3 + 2])
            const normal = new THREE.Vector3(normals[i * 3], normals[i * 3 + 1], normals[i * 3 + 2])
            // Remove the tangent's component along the normal so the two are perpendicular.
            const tangent = new THREE.Vector3(tangents[i * 3], tangents[i * 3 + 1], tangents[i * 3 + 2]).projectOnPlane(normal).normalize()
            const bitangent = new THREE.Vector3().crossVectors(normal, tangent).normalize()

            // Local frame: x = bitangent, y = normal, z = tangent, translated to `position`;
            // the local z axis is stretched by this sample's scale.
            const matrix = new THREE.Matrix4()
                .makeBasis(bitangent, normal, tangent)
                .setPosition(position)
                .scale(new THREE.Vector3(1, 1, scales[i]))

            threeNormals.push(normal)
            threeLeftHole.push(new THREE.Vector3(...leftHole.centroid).applyMatrix4(matrix))
            threeRightHole.push(new THREE.Vector3(...rightHole.centroid).applyMatrix4(matrix))
        }

        const newPoints: THREE.Vector3[] = []
        const newNormals: THREE.Vector3[] = []
        const newTangents: THREE.Vector3[] = []
        const newSegments: number[] = []
        const newScales: number[] = []

        // Reference axis for the normal interpolation; setFromUnitVectors only reads it.
        const xAxis = new THREE.Vector3(1, 0, 0)
        // Blends two normals by slerping (t = 0.5) between the rotations that take +X onto each of them.
        const interpolateNormal = (n0: THREE.Vector3, n1: THREE.Vector3) => {
            const quatStart = new THREE.Quaternion().setFromUnitVectors(xAxis, n0)
            const quatEnd = new THREE.Quaternion().setFromUnitVectors(xAxis, n1)
            return xAxis.clone().applyQuaternion(quatStart.slerp(quatEnd, 0.5)).normalize()
        }

        // An open curve has one bridge between each adjacent sample pair; a closed curve also bridges last -> first.
        const bridgeCount = this.closed ? sampleCount : sampleCount - 1
        for (let leftIndex = 0; leftIndex < bridgeCount; leftIndex++) {
            const rightIndex = (leftIndex + 1) % sampleCount

            // The bridge spans from the right hole of the earlier sample to the left hole of the later one.
            const start = threeRightHole[leftIndex]
            const end = threeLeftHole[rightIndex]

            newPoints.push(start.clone().lerp(end, 0.5))
            newNormals.push(interpolateNormal(threeNormals[leftIndex], threeNormals[rightIndex]))
            newTangents.push(end.clone().sub(start).normalize())
            newSegments.push(segments[leftIndex])

            const gap = start.distanceTo(end)
            newScales.push((gap + 0.5 * (leftHoleDiameter + rightHoleDiameter)) / bridgeLength)
        }

        this.outputCurvePoints.emitIfChanged({
            points: new Float32Array(newPoints.flatMap((vector) => [vector.x, vector.y, vector.z])),
            normals: new Float32Array(newNormals.flatMap((vector) => [vector.x, vector.y, vector.z])),
            tangents: new Float32Array(newTangents.flatMap((vector) => [vector.x, vector.y, vector.z])),
            segments: new Float32Array(newSegments),
            scales: new Float32Array(newScales),
        })
    }
}
