import {ControlPoint, CurvePoints} from "#template-nodes/interfaces/object-data"
import {Inlet, NotReady, Outlet} from "#template-nodes/runtime-graph/slots"
import {TypeDescriptors} from "#template-nodes/runtime-graph/type-descriptors"
import {NodeClassImpl} from "#template-nodes/runtime-graph/types"
import {Vector3} from "@cm/math"
import * as THREE from "three"

// Short alias for the type-descriptor helpers used by the node descriptors in this file.
const TD = TypeDescriptors

/**
 * Ports of the SampleCurve node.
 *
 * Inlets:
 * - `closed`: `true` closes the path back to its start; `"repeat"` closes it as well and additionally
 *   repeats the first sample once more at the end of every output array.
 * - `allowScaling`: when true, each sub-curve gets `round(length / segmentLength)` samples (at least one)
 *   and reports the stretch needed to fill it in `scales`; when false, only sub-curves longer than
 *   `segmentLength` are sampled, with `floor(length / segmentLength)` samples at scale 1.
 * - `controlPoints`: points the curve interpolates; a point with `corner` set splits the path into
 *   separate sub-curves.
 * - `segmentLength`: nominal arc-length spacing between samples.
 * - `offsetLength`: arc-length shift applied to every sample position (`undefined` is treated as 0);
 *   samples that end up outside `[0, curveLength - offsetLength]` are dropped.
 *
 * Outlet `curvePoints` (`null` when no curve can be built, i.e. fewer than two control points):
 * - `points`, `normals`, `tangents`: flattened `[x, y, z]` triples, one per sample.
 * - `segments`: per sample, the control-point index of the enclosing interval's start plus the
 *   fractional position within that interval.
 * - `scales`: per-sample scale factor.
 */
const sampleCurveDescriptor = {
    closed: TD.inlet(TD.Primitive<boolean | "repeat">()),
    allowScaling: TD.inlet(TD.Primitive<boolean>()),
    controlPoints: TD.inlet(TD.Identity<ControlPoint[]>()),
    segmentLength: TD.inlet(TD.Primitive<number>()),
    offsetLength: TD.inlet(TD.Primitive<number | undefined>()),
    curvePoints: TD.outlet(
        TD.Nullable(
            TD.Identity<{
                points: Float32Array
                normals: Float32Array
                tangents: Float32Array
                segments: Float32Array
                scales: Float32Array
            }>(),
        ),
    ),
}

/** Copies a `@cm/math` vector into a newly allocated `THREE.Vector3`. */
function toThreeVector(vector: Vector3): THREE.Vector3 {
    const {x, y, z} = vector
    return new THREE.Vector3(x, y, z)
}

/**
 * A `THREE.CurvePath` whose samples also report which sub-curve they lie on (`curveId`) and the
 * curve-local parameter (`t`) the vector was evaluated at.
 *
 * When `autoClose` is set, the sampling methods append the first sample once more at the end of the
 * result (the same object, not a copy).
 */
class IndexedCurvePath extends THREE.CurvePath<THREE.Vector3> {
    /**
     * Evaluates the path at a global parameter `t` measured as a fraction of the total arc length.
     *
     * The arc-length to sub-curve mapping follows `THREE.CurvePath.getPoint`.
     *
     * @param t global path parameter, expected in [0, 1]
     * @param type whether to evaluate the position or the tangent
     * @param optionalTarget vector to write the result into, forwarded to three.js
     * @returns the evaluated vector, the index of the sub-curve, and the curve-local parameter
     * @throws Error("Invalid t value") when `t` maps past the end of the path (e.g. `t > 1` or NaN)
     */
    private getIndexedVector(t: number, type: "position" | "tangent", optionalTarget?: THREE.Vector3) {
        const distance = t * this.getLength()
        const curveLengths = this.getCurveLengths()

        for (let curveId = 0; curveId < curveLengths.length; curveId++) {
            // curveLengths is cumulative: the first entry reaching `distance` is the containing curve.
            if (curveLengths[curveId] >= distance) {
                const remaining = curveLengths[curveId] - distance
                const curve = this.curves[curveId]

                const curveLength = curve.getLength()
                const u = curveLength === 0 ? 0 : 1 - remaining / curveLength

                // Convert the arc-length fraction `u` into the curve's own parameter. The cast only
                // satisfies the declared signature; three.js falls back to `u` when distance is falsy.
                const curveT = curve.getUtoTmapping(u, undefined as unknown as number)

                const vector = type === "position" ? curve.getPoint(curveT, optionalTarget) : curve.getTangent(curveT, optionalTarget)
                return {vector, curveId, t: curveT}
            }
        }

        throw new Error("Invalid t value")
    }

    /**
     * Samples the path at `divisions + 1` evenly spaced global parameters `0, 1/divisions, ..., 1`.
     */
    getSpacedIndexedVector(divisions: number, type: "position" | "tangent") {
        const tSamples: number[] = []
        for (let i = 0; i <= divisions; i++) {
            tSamples.push(i / divisions)
        }
        return this.sampleIndexedVector(tSamples, type)
    }

    /**
     * Samples the path at the given global parameters, in order.
     *
     * @throws Error("Invalid t value") for any parameter that maps past the end of the path
     */
    sampleIndexedVector(tSamples: number[], type: "position" | "tangent") {
        const samples = tSamples.map((t) => this.getIndexedVector(t, type))

        if (this.autoClose && samples.length > 0) samples.push(samples[0])

        return samples
    }
}

/** A control point tagged with its position in the node's `controlPoints` input. */
type IndexedControlPoint = ControlPoint & {index: number}

/** Builds a centripetal Catmull-Rom curve through the positions of the given control points. */
function createCentripetalCurve(points: IndexedControlPoint[], closed: boolean): THREE.CatmullRomCurve3 {
    return new THREE.CatmullRomCurve3(
        points.map(({position}) => toThreeVector(position)),
        closed,
        "centripetal",
    )
}

/**
 * Splits the control points into sub-curves at corner points and assembles them into one path.
 *
 * - Fewer than two control points: the path stays empty.
 * - No corners: one curve through all points. It is itself closed only when `closed` is set and
 *   there are more than two points.
 * - With corners: every corner ends one open sub-curve and starts the next. When `closed`, the points
 *   are first rotated so the first corner comes first, and the path is closed back to that corner.
 *
 * @returns the path, plus the control points each sub-curve interpolates (parallel to `curvePath.curves`)
 */
function buildSampleCurvePath(controlPoints: ControlPoint[], closed: boolean) {
    const curvePath = new IndexedCurvePath()
    const controlPointsPerCurve: IndexedControlPoint[][] = []

    const addCurve = (points: IndexedControlPoint[], curveClosed: boolean) => {
        curvePath.add(createCentripetalCurve(points, curveClosed))
        controlPointsPerCurve.push(points)
    }

    const indexedControlPoints: IndexedControlPoint[] = controlPoints.map((controlPoint, index) => ({...controlPoint, index}))
    if (indexedControlPoints.length < 2) return {curvePath, controlPointsPerCurve}

    const firstCornerIndex = indexedControlPoints.findIndex((point) => point.corner)
    if (firstCornerIndex === -1) {
        addCurve(indexedControlPoints, closed && indexedControlPoints.length > 2)
        return {curvePath, controlPointsPerCurve}
    }

    // For closed paths, start at the first corner so every sub-curve runs corner-to-corner.
    const orderedPoints = closed
        ? indexedControlPoints.slice(firstCornerIndex).concat(indexedControlPoints.slice(0, firstCornerIndex))
        : indexedControlPoints
    const lastIndex = orderedPoints.length - 1

    let currentCurvePoints: IndexedControlPoint[] = []
    for (let i = 0; i < orderedPoints.length; i++) {
        const point = orderedPoints[i]
        currentCurvePoints.push(point)

        const isLast = i === lastIndex
        if ((point.corner && i !== 0) || isLast) {
            // The final run of a closed path that ends on a non-corner wraps back to the first corner.
            if (isLast && closed && !point.corner && currentCurvePoints.length >= 2) currentCurvePoints.push(orderedPoints[0])
            if (currentCurvePoints.length >= 2) addCurve(currentCurvePoints, false)
            currentCurvePoints = [point]
        }
    }

    // A closed path whose last point is a corner still needs the segment back to the first corner.
    if (closed && currentCurvePoints.length === 1 && currentCurvePoints[0].corner) {
        addCurve([currentCurvePoints[0], orderedPoints[0]], false)
    }

    return {curvePath, controlPointsPerCurve}
}

/**
 * Chooses the global path parameters (fractions of total arc length) at which samples are placed,
 * together with the scale factor of each sample.
 *
 * Each sub-curve is sampled independently.
 * - Leftover length is spread evenly between the samples.
 * - Every sample is shifted by `offsetLength`.
 * - Samples outside `[0, curveLength - offsetLength]` are dropped.
 *
 * Returns no samples when `segmentLength` is not a positive finite number, or when the path has zero length.
 */
function computeSampleParameters(curvePath: IndexedCurvePath, segmentLength: number, offsetLength: number, allowScaling: boolean) {
    const tSamples: number[] = []
    const scales: number[] = []

    const totalCurveLength = curvePath.getLength()
    // Guards:
    // - A zero segment length makes the sample count Infinity, so the loop below never terminates.
    // - Non-finite or negative spacings yield NaN/negative samples.
    // - A zero-length path turns every parameter into 0 / 0 = NaN, which the sampler rejects by throwing.
    if (!(segmentLength > 0 && Number.isFinite(segmentLength)) || !(totalCurveLength > 0)) return {tSamples, scales}

    const accumulatedCurveLengths = curvePath.getCurveLengths()
    curvePath.curves.forEach((curve, i) => {
        const currentCurveLength = curve.getLength()
        if (!(currentCurveLength > segmentLength || allowScaling)) return

        const curveLengthOffset = i > 0 ? accumulatedCurveLengths[i - 1] : 0
        const numSamples = allowScaling
            ? Math.max(Math.round(currentCurveLength / segmentLength), 1)
            : Math.max(Math.floor(currentCurveLength / segmentLength), 0)
        // Distribute the length not covered by whole segments evenly into the gaps around the samples.
        const unusedSpace = currentCurveLength - numSamples * segmentLength
        const addedStepSize = unusedSpace / (numSamples + 1)
        const stepSize = segmentLength + addedStepSize
        const currentScale = allowScaling ? currentCurveLength / (numSamples * segmentLength) : 1.0

        for (let j = 0; j < numSamples; j++) {
            const localT = offsetLength + segmentLength / 2 + addedStepSize + j * stepSize
            if (localT < 0 || localT > currentCurveLength - offsetLength) continue
            tSamples.push((curveLengthOffset + localT) / totalCurveLength)
            scales.push(currentScale)
        }
    })

    return {tSamples, scales}
}

/**
 * Maps a curve-local parameter `t` of sub-curve `curveId` to the control-point interval it lies in.
 *
 * The mapping mirrors how `CatmullRomCurve3` spreads its points over `t`: `t * (n - 1)` for open curves
 * and `t * n` for closed ones, whose last interval wraps back to the first point.
 */
function locateControlPointInterval(curvePath: IndexedCurvePath, controlPointsPerCurve: IndexedControlPoint[][], curveId: number, t: number) {
    const points = controlPointsPerCurve[curveId]
    const {closed} = curvePath.curves[curveId] as THREE.CatmullRomCurve3
    const index = t * (points.length - (closed ? 0 : 1))
    const startIndex = Math.floor(index)

    return {
        // Wrap: on a closed curve, t === 1 gives index === n, which is the first point again.
        start: points[startIndex % points.length],
        end: points[Math.ceil(index) % points.length],
        localT: index - startIndex,
    }
}

export class SampleCurve implements NodeClassImpl<typeof sampleCurveDescriptor, typeof SampleCurve> {
    static descriptor = sampleCurveDescriptor
    static uniqueName = "SampleCurve"
    closed!: Inlet<boolean | "repeat">
    allowScaling!: Inlet<boolean>
    controlPoints!: Inlet<ControlPoint[]>
    segmentLength!: Inlet<number>
    offsetLength!: Inlet<number | undefined>
    curvePoints!: Outlet<CurvePoints | null>

    /**
     * Builds a piecewise Catmull-Rom path through the control points and samples it at roughly
     * `segmentLength` arc-length spacing.
     *
     * Emits NotReady while any inlet is NotReady, and `null` when no curve can be built (fewer than
     * two control points). Otherwise it emits flattened per-sample points, normals, tangents,
     * fractional segment indices and scales.
     */
    run() {
        if (
            this.closed === NotReady ||
            this.allowScaling === NotReady ||
            this.controlPoints === NotReady ||
            this.segmentLength === NotReady ||
            this.offsetLength === NotReady
        ) {
            this.curvePoints.emitIfChanged(NotReady)
            return
        }

        const closed = this.closed === "repeat" || this.closed
        const {curvePath, controlPointsPerCurve} = buildSampleCurvePath(this.controlPoints, closed)

        if (curvePath.curves.length === 0) {
            this.curvePoints.emitIfChanged(null)
            return
        }

        // "repeat" additionally makes the sampler append the first sample once more at the end.
        curvePath.autoClose = this.closed === "repeat"

        const {tSamples, scales} = computeSampleParameters(curvePath, this.segmentLength, this.offsetLength ?? 0, this.allowScaling)

        const samples = curvePath.sampleIndexedVector(tSamples, "position")
        if (curvePath.autoClose && scales.length > 0) scales.push(scales[0])

        // Normals: interpolate (slerp) the rotations taking +X onto the neighbouring control-point normals.
        const referenceAxis = new THREE.Vector3(1, 0, 0)
        // setFromUnitVectors requires unit vectors; incoming normals are not guaranteed to be normalized.
        const rotationTo = ({x, y, z}: {x: number; y: number; z: number}) =>
            new THREE.Quaternion().setFromUnitVectors(referenceAxis, new THREE.Vector3(x, y, z).normalize())

        const normals = samples.map(({curveId, t}) => {
            const {start, end, localT} = locateControlPointInterval(curvePath, controlPointsPerCurve, curveId, t)
            const interpolatedQuat = rotationTo(start.normal).slerp(rotationTo(end.normal), localT)
            return referenceAxis.clone().applyQuaternion(interpolatedQuat).normalize()
        })

        // Reuse each sample's (curveId, t) instead of repeating the arc-length search.
        const tangents = samples.map(({curveId, t}) => curvePath.curves[curveId].getTangent(t))

        const segments = samples.map(({curveId, t}) => {
            const {start, localT} = locateControlPointInterval(curvePath, controlPointsPerCurve, curveId, t)
            return start.index * (1 - localT) + (start.index + 1) * localT
        })

        this.curvePoints.emitIfChanged({
            points: new Float32Array(samples.flatMap(({vector}) => [vector.x, vector.y, vector.z])),
            normals: new Float32Array(normals.flatMap((normal) => [normal.x, normal.y, normal.z])),
            tangents: new Float32Array(tangents.flatMap((tangent) => [tangent.x, tangent.y, tangent.z])),
            segments: new Float32Array(segments),
            scales: new Float32Array(scales),
        })
    }
}

/**
 * Serialized control point as received through the `controlPointsData` JSON inlet.
 * `GetControlPoints` converts the `[x, y, z]` tuples into `Vector3` instances.
 */
type ControlPointData = {
    position: [number, number, number]
    normal: [number, number, number]
    // Consumed by SampleCurve, which splits the sampled path into separate sub-curves at corner points.
    corner: boolean
}

/**
 * Ports of the GetControlPoints node.
 * - `controlPointsData`: the JSON-serialized control points.
 * - `controlPoints`: the same points with positions and normals as `Vector3`.
 */
const getControlPointsDescriptor = {
    controlPointsData: TD.inlet<ControlPointData[]>(TD.JSON()),
    controlPoints: TD.outlet<
        {
            position: Vector3
            normal: Vector3
            corner: boolean
        }[]
    >(TD.Identity()),
}

export class GetControlPoints implements NodeClassImpl<typeof getControlPointsDescriptor, typeof GetControlPoints> {
    static descriptor = getControlPointsDescriptor
    static uniqueName = "GetControlPoints"
    controlPointsData!: Inlet<ControlPointData[]>
    controlPoints!: Outlet<ControlPoint[]>

    /**
     * Converts the serialized control points into `Vector3`-based control points.
     *
     * While the input is NotReady, NotReady is propagated downstream, as `SampleCurve` does.
     * Returning without emitting would leave consumers working with the previously emitted points.
     */
    run() {
        if (this.controlPointsData === NotReady) {
            this.controlPoints.emitIfChanged(NotReady)
            return
        }

        const transformedControlPoints = this.controlPointsData.map((controlPoint) => ({
            position: Vector3.fromArray(controlPoint.position),
            normal: Vector3.fromArray(controlPoint.normal),
            corner: controlPoint.corner,
        }))

        this.controlPoints.emitIfChanged(transformedControlPoints)
    }
}
