import {DeclareObjectNode, ObjectNode, TemplateObjectNode} from "#template-nodes/declare-object-node"
import {SceneNodes} from "#template-nodes/interfaces/scene-object"
import {CurveObjectData, CurvePoints, isCurveObjectData, ObjectData} from "#template-nodes/interfaces/object-data"
import {CurveLike, curveLike, isCurveLike, ObjectLike, objectLike} from "#template-nodes/node-types"
import {MeshCurve} from "#template-nodes/nodes/mesh-curve"
import {NamedNodeParameters, namedNodeParameters} from "#template-nodes/nodes/named-node"
import {BuilderInlet} from "#template-nodes/runtime-graph/graph-builder"
import {SampleBridgeCurve} from "#template-nodes/runtime-graph/nodes/sample-bridge-curve"
import {SampleCurve} from "#template-nodes/runtime-graph/nodes/sample-curve"
import {mergeBounds, transformBounds} from "#template-nodes/utils/scene-geometry-utils"
import {skipped, visitNone, VisitorNodeVersion} from "@cm/graph/declare-visitor-node"
import {CircularRefNode, versionChain} from "@cm/graph/node-graph"
import {registerNode} from "@cm/graph/register-node"
import {Matrix4} from "@cm/math"
import {z} from "zod"
import {nodeInstance} from "@cm/graph"
import {NodesInput} from "#template-nodes/nodes/input"
import {GraphBuilderScope} from "#template-nodes/runtime-graph/graph-builder-scope"
import {BoundsData} from "#template-nodes/geometry-processing/mesh-data"

/** Optional length source: either a plain number or an object whose mesh extent is measured. */
const optionalLengthSource = () => z.number().or(objectLike).optional()

/** Schema for the fields specific to the Seam node (field order mirrors `SeamParameters`). */
const seamOwnParameters = z.object({
    item: objectLike.nullable(),
    bridge: objectLike.optional(),
    repeatSize: optionalLengthSource(),
    repeatSizeFactor: z.number().optional(),
    offsetSize: optionalLengthSource(),
    offsetFactor: z.number().optional(),
    curves: z.array(curveLike.or(nodeInstance(NodesInput))),
    allowScaling: z.boolean(),
})

/** Complete Seam parameter schema: the shared named-node fields merged with the seam fields. */
const seamParameters = namedNodeParameters.merge(seamOwnParameters)
/** Statically typed view of the seam-specific parameters validated by `seamParameters`. */
export type SeamParameters = {
    /** Object whose meshes are placed along the curves; when `null` the node is skipped. */
    item: null | ObjectLike
    /** Object whose meshes are used for the bridge seams; may be left undefined. */
    bridge: undefined | ObjectLike
    /** Segment length as a number, or an object whose meshes' z-extent defines it (falls back to the item meshes). */
    repeatSize: undefined | number | ObjectLike
    /** Multiplier applied to the segment length; treated as 1 when undefined. */
    repeatSizeFactor: undefined | number
    /** Offset length as a number, or an object whose meshes' z-extent defines it; no offset when undefined. */
    offsetSize: undefined | number | ObjectLike
    /** Multiplier applied to the offset length; treated as 1 when undefined. */
    offsetFactor: undefined | number
    /** Curve nodes, or input nodes whose evaluated curve objects are used; an empty list skips the node. */
    curves: Array<CurveLike | NodesInput>
    /** Forwarded to curve sampling as its `allowScaling` option. */
    allowScaling: boolean
}

/** Parameter shape of the legacy (v0) Seam node, which referenced at most one curve. */
type V0 = ObjectNode &
    NamedNodeParameters & {
        item: null | ObjectLike
        repeatSize: undefined | number | ObjectLike
        curve: null | MeshCurve
        allowScaling: boolean
    }
/** v1 parameter shape: the single optional `curve` is replaced by a `curves` list. */
type V1 = Omit<V0, "curve"> & {curves: Array<CurveLike>}
/**
 * Migration from the v0 Seam parameters (single optional `curve`) to v1 (a `curves` list).
 *
 * Only the fields other than `curve` are carried over, so the legacy `curve` key does not
 * survive into the v1 parameter object. A present curve becomes a one-element list; a
 * missing (`null`) curve becomes an empty list.
 *
 * @throws Error when `curve` is a circular reference, which cannot be resolved during migration.
 */
export const v0: VisitorNodeVersion<V0, V1> = {
    toNextVersion: (parameters) => {
        const {curve, ...rest} = parameters
        if (curve instanceof CircularRefNode) throw new Error("Cannot resolve circular references when going from v0 to v1 in seam node")
        // `curves` stays first so any same-named field in `rest` still wins, matching the previous spread order.
        return {
            curves: curve ? [curve] : [],
            ...rest,
        }
    },
}

/** Per-curve data produced by `mapper` and turned into `SceneNodes.Seam` entries. */
type SeamCurve = {
    /** Stringified `curveIndex` passed to `mapper`. */
    id: string
    /** The curve object's `matrix`. */
    transform: Matrix4
    /** The single `MeshCurveControl` found in the curve's display lists. */
    meshCurveControl: SceneNodes.MeshCurveControl
    /** Sampled curve points (or bridge points for a bridge curve); the type permits `null`. */
    curvePoints: null | CurvePoints
}

/**
 * Builds the graph fragment for a single seam curve.
 *
 * The curve's control points are sampled with `SampleCurve` using the shared segment/offset lengths,
 * its single `MeshCurveControl` is located, and both are packaged with the curve's matrix as a
 * `SeamCurve`. The same sampled points feed `SampleBridgeCurve`; their input is `null` when there are
 * no bridge meshes or the curve is hidden, and a bridge curve is produced only when that sampler
 * yields points.
 *
 * @param curveScope scope in which every node for this curve is created
 * @param curveData the curve object to sample
 * @param ctx shared inputs: sampling options, bridge meshes, and `bounds` (forwarded as `holeBounds`)
 * @param curveIndex position of this curve; its string form becomes the `SeamCurve` id
 * @returns the seam curve plus optional bridge curve, or `null` when `curveData.visible` is `false`
 * @throws Error from the `meshCurveControl` lambda when the curve does not contain exactly one control
 */
function mapper(
    curveScope: GraphBuilderScope,
    curveData: CurveObjectData,
    ctx: BuilderInlet<{
        allowScaling: boolean
        segmentLength: number
        offsetLength: number | undefined
        bridgeMeshes: SceneNodes.Mesh[]
        bounds: BoundsData[]
    }>,
    curveIndex: number,
): BuilderInlet<{
    seamCurve: SeamCurve
    bridgeCurve: SeamCurve | null
} | null> {
    // Hidden curves produce no seam at all.
    if (curveData.visible === false) {
        return null
    }

    // Sample the control points using the shared segment/offset lengths.
    const sampled = curveScope.node(SampleCurve, {
        closed: curveScope.get(curveData, "closed"),
        allowScaling: curveScope.get(ctx, "allowScaling"),
        controlPoints: curveScope.get(curveData, "controlPoints"),
        segmentLength: curveScope.get(ctx, "segmentLength"),
        offsetLength: curveScope.get(ctx, "offsetLength"),
    })
    const sampledPoints = sampled.curvePoints

    // Exactly one MeshCurveControl must exist across the pre-display and display lists.
    const meshCurveControl = curveScope.pureLambda(
        curveData,
        ({preDisplayList, displayList}) => {
            const controls = [...preDisplayList, ...displayList].filter(SceneNodes.MeshCurveControl.is)
            if (controls.length !== 1) throw new Error("MeshCurveControl not found")
            return controls[0]
        },
        "meshCurveControl",
    )

    const seamCurve = curveScope.struct<SeamCurve>("SeamCurve", {
        id: String(curveIndex),
        transform: curveScope.get(curveData, "matrix"),
        meshCurveControl,
        curvePoints: sampledPoints,
    })

    // Inputs for the bridge sampler (created in the same order the node call would evaluate them).
    const bridgeMeshes = curveScope.get(ctx, "bridgeMeshes")
    const holeBounds = curveScope.get(ctx, "bounds")
    const bridgeInputPoints = curveScope.pureLambda(
        curveScope.tuple(curveData, bridgeMeshes, sampledPoints),
        ([data, meshes, points]) => (meshes.length === 0 || data.visible === false ? null : points),
        "bridgeCurvePoints",
    )
    const bridgeBounds = curveScope.pureLambda(
        bridgeMeshes,
        (meshes) => mergeBounds(meshes.map((mesh) => transformBounds(mesh.completeMeshData.reified.bounds, mesh.transform))),
        "bridgeBounds",
    )
    const bridgeClosed = curveScope.pureLambda(curveData, ({closed, controlPoints}) => closed && controlPoints.length > 2, "closed")

    const {outputCurvePoints} = curveScope.node(SampleBridgeCurve, {
        inputCurvePoints: bridgeInputPoints,
        holeBounds,
        bridgeBounds,
        closed: bridgeClosed,
    })

    // A bridge curve reuses the seam curve's id, transform and control, swapping in the bridge points.
    const bridgeCurve = curveScope.pureLambda<[CurvePoints | null, SeamCurve], SeamCurve | null>(
        curveScope.tuple(outputCurvePoints, seamCurve),
        ([bridgePoints, base]) => (bridgePoints ? {...base, curvePoints: bridgePoints} : null),
        "bridgeCurve",
    )

    return curveScope.struct("Result", {seamCurve, bridgeCurve})
}

/**
 * Seam node: places the meshes of `item` along one or more curves.
 *
 * For every resolved curve that passes `mapper` (visible curves), it emits one `SceneNodes.Seam`
 * whose `item` is the item's meshes. When the bridge sampler yields points for a curve (its input is
 * `null` if there are no bridge meshes), a second "bridge" seam is emitted using the `bridge` object's
 * meshes. The node is skipped when `item` is `null` or `curves` is empty.
 */
@registerNode
export class Seam extends DeclareObjectNode(
    {parameters: seamParameters},
    {
        onVisited: {
            // Nothing to place (no item) or nowhere to place it (no curves): skip the node.
            onFilterActive: ({parameters}) => {
                if (parameters.item === null || parameters.curves.length === 0) return skipped
                return visitNone(parameters)
            },
            onCompile: function (this: SeamFwd, {context, parameters}) {
                const {item, bridge, curves, repeatSize, allowScaling} = parameters

                // Same guard as onFilterActive.
                if (item === null || curves.length === 0) return skipped

                const {evaluator} = context

                const scope = evaluator.getScope(this)

                // Item meshes re-expressed relative to the item object's own matrix (its inverse is
                // pre-multiplied), with a "-seamItem" id suffix and realtime shadows disabled.
                // `bounds` holds each mesh's bounds in that same item-local space.
                // `objectDataInvalid` is handed to setupObject below.
                const [objectData, objectDataInvalid] = scope.branch(evaluator.evaluateObject(scope, item))
                const meshesAndBounds = scope.pureLambda(
                    objectData,
                    ({preDisplayList, displayList, matrix}) => {
                        const inverseTransform = matrix.inverse()
                        const meshes = [...preDisplayList, ...displayList].filter(SceneNodes.Mesh.is).map<SceneNodes.Mesh>(({id, transform, ...mesh}) => ({
                            ...mesh,
                            id: `${id}-seamItem`,
                            transform: inverseTransform.multiply(transform),
                            receiveRealtimeShadows: false,
                        }))
                        const bounds = meshes.map((mesh) => transformBounds(mesh.completeMeshData.reified.bounds, mesh.transform))
                        return {meshes, bounds}
                    },
                    "item",
                )
                const meshes = scope.get(meshesAndBounds, "meshes")
                const bounds = scope.get(meshesAndBounds, "bounds")

                // Bridge meshes, likewise relative to the bridge object's matrix ("-origin" id suffix);
                // an empty list when no bridge object is set.
                const bridgeMeshes = scope.pureLambda(
                    evaluator.evaluateObject(scope, bridge ?? null),
                    (bridge) => {
                        if (!bridge) return []
                        const {preDisplayList, displayList, matrix} = bridge
                        const inverseTransform = matrix.inverse()
                        const meshes = [...preDisplayList, ...displayList].filter(SceneNodes.Mesh.is).map<SceneNodes.Mesh>(({id, transform, ...mesh}) => ({
                            ...mesh,
                            id: `${id}-origin`,
                            transform: inverseTransform.multiply(transform),
                        }))
                        return meshes
                    },
                    "bridge",
                )

                // Segment length along each curve. A numeric repeatSize is used directly; otherwise the
                // length is measured from the repeat object's meshes, or from the item meshes when no
                // repeat object resolves. The result is scaled by repeatSizeFactor (1 when undefined).
                const repeat = typeof repeatSize === "number" ? repeatSize : evaluator.evaluateObject(scope, repeatSize ?? null)
                const segmentLength = scope.pureLambda(
                    scope.tuple(meshes, repeat, parameters.repeatSizeFactor),
                    ([meshes, repeat, repeatSizeFactor]) => {
                        const factor = repeatSizeFactor ?? 1
                        if (typeof repeat === "number") return repeat * factor

                        const repeatMeshes = (() => {
                            if (repeat) {
                                const inverseTransform = repeat.matrix.inverse()
                                return [...repeat.preDisplayList, ...repeat.displayList]
                                    .filter(SceneNodes.Mesh.is)
                                    .map<SceneNodes.Mesh>((mesh) => ({...mesh, transform: inverseTransform.multiply(mesh.transform)}))
                            } else return meshes
                        })()

                        const bounds = mergeBounds(repeatMeshes.map((mesh) => transformBounds(mesh.completeMeshData.reified.bounds, mesh.transform)))
                        // 2 * max(|aabb[0][2]|, |aabb[1][2]|): width of the smallest origin-centred interval on
                        // coordinate index 2 (presumably z) that contains both aabb corners.
                        const size = Math.max(Math.abs(bounds.aabb[0][2]), Math.abs(bounds.aabb[1][2])) * 2
                        return size * factor
                    },
                    "segmentLength",
                )

                // Offset length, computed the same way as the segment length but from offsetSize /
                // offsetFactor; `undefined` when no offset is given or no offset object resolves.
                const offset =
                    typeof parameters.offsetSize === "number" ? parameters.offsetSize : evaluator.evaluateObject(scope, parameters.offsetSize ?? null)
                const offsetLength = scope.pureLambda(
                    scope.tuple(offset, parameters.offsetFactor),
                    ([offset, offsetFactor]) => {
                        const factor = offsetFactor ?? 1
                        if (typeof offset === "number") return offset * factor
                        if (!offset) return undefined

                        const offsetMeshes = (() => {
                            const inverseTransform = offset.matrix.inverse()
                            return [...offset.preDisplayList, ...offset.displayList]
                                .filter(SceneNodes.Mesh.is)
                                .map<SceneNodes.Mesh>((mesh) => ({...mesh, transform: inverseTransform.multiply(mesh.transform)}))
                        })()

                        const bounds = mergeBounds(offsetMeshes.map((mesh) => transformBounds(mesh.completeMeshData.reified.bounds, mesh.transform)))
                        const size = Math.max(Math.abs(bounds.aabb[0][2]), Math.abs(bounds.aabb[1][2])) * 2
                        return size * factor
                    },
                    "offsetLength",
                )

                // Resolve every entry of `curves` to curve object data, each in its own sub-scope.
                // Curve nodes are evaluated directly; input nodes are evaluated and any result that is
                // not curve object data becomes `null`. The per-entry lists are merged and passed through
                // scope.filterInvalid (presumably dropping the null/invalid entries — confirm).
                const curveData = scope.filterInvalid(
                    scope.merge(
                        curves.map((curveNode, idx) => {
                            const curveScope = scope.scope(`evalCurve${idx}`)
                            if (isCurveLike(curveNode)) {
                                return curveScope.list([evaluator.evaluateCurve(curveScope, curveNode)])
                            } else {
                                return curveScope.pureLambda(
                                    evaluator.evaluateNodes(curveScope, [curveNode]),
                                    (evaluatedNodes) => {
                                        return evaluatedNodes.map((evaluatedNode) => {
                                            if (evaluatedNode.type === "object") {
                                                // NOTE(review): unchecked cast; only values passing isCurveObjectData are kept.
                                                const objectData = evaluatedNode.value as ObjectData | null
                                                if (objectData && isCurveObjectData(objectData)) return objectData
                                            }

                                            return null
                                        })
                                    },
                                    "curveData",
                                )
                            }
                        }),
                    ),
                )

                // Build the per-curve sampling subgraph (see `mapper`) with the shared context; `mapper`
                // returns null for hidden curves, which filterInvalid is expected to remove.
                const seamCurveData = scope.filterInvalid(
                    scope.map(curveData, mapper, scope.struct("Context", {allowScaling, segmentLength, offsetLength, bridgeMeshes, bounds})),
                )

                this.setupObject(
                    scope,
                    context,
                    "Seam",
                    undefined,
                    undefined,
                    (objectProps) => {
                        return scope.pureLambda(
                            scope.tuple(seamCurveData, objectProps, meshes, bridgeMeshes),
                            // The object's own `transform` is dropped so each seam uses its curve's
                            // transform; `id`/`$id` are suffixed with the curve id to keep them distinct.
                            ([seamCurveData, {transform, id, $id, ...rest}, meshes, bridgeMeshes]) => {
                                const seamCurves = seamCurveData.map((curve) => curve.seamCurve)
                                const bridgeCurves = seamCurveData.map((curve) => curve.bridgeCurve).filter((curve): curve is SeamCurve => curve !== null)

                                // One Seam per curve carrying the item meshes, then one "bridge" Seam per
                                // bridge curve carrying the bridge meshes.
                                return [
                                    ...seamCurves.map<SceneNodes.Seam>(({id: curveId, curvePoints, meshCurveControl, transform}) => ({
                                        type: "Seam",
                                        id: `${id}_${curveId}`,
                                        $id: `${$id}_${curveId}`,
                                        ...rest,
                                        item: meshes,
                                        curvePoints,
                                        meshCurveControl,
                                        transform,
                                    })),
                                    ...bridgeCurves.map<SceneNodes.Seam>(({id: curveId, curvePoints, meshCurveControl, transform}) => ({
                                        type: "Seam",
                                        id: `${id}_bridge_${curveId}`,
                                        $id: `${$id}_bridge_${curveId}`,
                                        ...rest,
                                        item: bridgeMeshes,
                                        curvePoints,
                                        meshCurveControl,
                                        transform,
                                    })),
                                ]
                            },
                            "Seam",
                        )
                    },
                    objectDataInvalid,
                )

                return visitNone(parameters)
            },
        },
    },
    {nodeClass: "Seam", versionChain: versionChain([v0])},
) {}

/** Type of `this` inside the Seam node's `onCompile` callback (annotated there as `this: SeamFwd`). */
export type SeamFwd = TemplateObjectNode<SeamParameters>
