import {nodeInstance, NodeParameters, registerNode, skipped, visitNone} from "@cm/graph"
import {NamedNodeParameters, namedNodeParameters} from "#template-nodes/nodes/named-node"
import {CurveLike, curveLike, isCurveLike, meshLike} from "#template-nodes/node-types"
import {DeclareMeshNodeTS, TemplateMeshNode} from "#template-nodes/declare-mesh-node"
import {z} from "zod"
import {
    geomToMesh,
    meshToGeom,
    Operators,
    AttributeRef,
    continuousCurveAttr,
    Primitives,
    GeomBuilderContext,
} from "#template-nodes/geometry-processing/geometry-graph"
import {MeshNodes} from "@cm/render-nodes"
import {hashObject} from "@cm/utils/hashing"
import {GenerateMesh} from "#template-nodes/runtime-graph/nodes/generate-mesh"
import {StoredMesh} from "#template-nodes/nodes/stored-mesh"
import {ProceduralMesh} from "#template-nodes/nodes/procedural-mesh"
import {MeshInputFwd, NodesInput} from "#template-nodes/nodes/input"
import {MeshSwitchFwd} from "#template-nodes/nodes/switch"
import {MeshOutputFwd} from "#template-nodes/nodes/output"
import {ObjectData, isCurveObjectData} from "#template-nodes/interfaces/object-data"

/**
 * Cuts a mesh graph along one or more curve groups and writes a per-group `materialID` attribute.
 *
 * Each entry of `curvePoints` is one group. `undefined` entries contribute no curve but keep their
 * index, so group `i` always maps to material ID `i + 1`. Elements whose winding for no group exceeds
 * the cut threshold get material ID 0. When several groups claim the same element, the group with
 * the highest index wins, because later groups overwrite earlier selections.
 *
 * @param baseMeshGraph - Mesh graph to cut.
 * @param curvePoints - Flat point buffers, one per group: 3 floats per point in "position" space,
 *   2 floats per point in "uv" space.
 * @param curveSpace - "uv": points are already in UV space. "position": points are projected onto
 *   the mesh and their UVs in `cutUVChannel` are used.
 * @param numUVChannels - Number of UV channels carried over from the base mesh.
 * @param cutUVChannel - Index of the UV channel in which the cut is performed.
 * @returns The cut mesh graph, or `baseMeshGraph` unchanged when no group has any curve points.
 * @throws Error("Invalid UV channel") if `cutUVChannel` is not an integer in `[0, numUVChannels)`.
 */
function applyCutToGeometry(
    baseMeshGraph: MeshNodes.Mesh,
    curvePoints: (Float32Array | undefined)[],
    curveSpace: "position" | "uv",
    numUVChannels: number,
    cutUVChannel: number,
): MeshNodes.Mesh {
    // Validate before building any graph nodes. Negative, fractional or NaN channels are rejected too.
    if (!Number.isInteger(cutUVChannel) || cutUVChannel < 0 || cutUVChannel >= numUVChannels) {
        throw new Error("Invalid UV channel")
    }

    // Nothing to cut with. Without this guard `Operators.merge` would receive no curve regions and
    // the winding loop below would index an empty attribute list.
    if (curvePoints.every((points) => !points)) {
        return baseMeshGraph
    }

    const numGroups = curvePoints.length
    const ctx = new GeomBuilderContext()
    const baseMesh = meshToGeom(ctx, baseMeshGraph, numUVChannels)

    // One (uv, groupId) attribute pair per group that actually has points. Group IDs are still
    // assigned to every index, so material IDs stay aligned with the input positions.
    const curveRegions: [AttributeRef, AttributeRef][] = []
    const groupIDs: number[] = []
    for (let groupIdx = 0; groupIdx < numGroups; groupIdx++) {
        const points = curvePoints[groupIdx]
        if (points) {
            let curveUV: AttributeRef
            if (curveSpace === "uv") {
                curveUV = Primitives.curve(ctx, {
                    uv: continuousCurveAttr(points, 2),
                }).uv
            } else {
                const baseCurve = Primitives.curve(ctx, {
                    position: continuousCurveAttr(points, 3),
                })
                // Project the 3D curve onto the mesh, read back its UVs in the cut channel and
                // weld near-coincident points.
                curveUV = Operators.project(baseCurve.position, baseCurve, baseMesh.position, baseMesh).uvs[cutUVChannel].weld(1e-4)
            }
            const curveGroup = curveUV.constInt(groupIdx)
            curveRegions.push([curveUV, curveGroup])
        }
        groupIDs.push(groupIdx)
    }

    const [curveUV, curveGroup] = Operators.merge(...curveRegions)

    const [cutMesh, windingAttrs] = Operators.cut(baseMesh.uvs[cutUVChannel], baseMesh, curveUV, curveGroup, groupIDs)

    // Start every element at material 0. Each group whose winding exceeds the cut threshold then
    // overrides it with its own 1-based ID; later groups take precedence.
    let materialID = windingAttrs[0].constInt(0)
    for (let groupIdx = 0; groupIdx < numGroups; groupIdx++) {
        const winding = windingAttrs[groupIdx]
        const cond = winding.gt(Operators.getCutThreshold(winding))
        materialID = cond.select(groupIdx + 1, materialID)
    }
    // Presumably rounds a possibly non-integer selected value to the nearest integer ID
    // (NOTE(review): confirm `select` yields a float attribute).
    materialID = materialID.add(0.5).castInt()

    return geomToMesh({...cutMesh, materialID})
}

// Fields specific to CutMesh, layered on top of the shared named-node parameter schema.
const cutMeshOwnParameters = z.object({
    mesh: meshLike.nullable(),
    curves: z.array(curveLike.or(nodeInstance(NodesInput))),
    uvChannel: z.number().int().nonnegative().default(0),
})

/**
 * Runtime validation schema for {@link CutMeshParameters}; `uvChannel` defaults to 0 when omitted.
 * The cast widens it to the generic schema type used as `paramsSchema` in the node declaration.
 */
const cutMeshParameters = namedNodeParameters.merge(cutMeshOwnParameters) as z.ZodType<NodeParameters>

/** Mesh node types accepted as CutMesh input, excluding CutMesh itself. */
type MeshLikeNoCutMesh =
    | StoredMesh
    | ProceduralMesh
    | MeshInputFwd
    | MeshSwitchFwd
    | MeshOutputFwd

/** Parameters of the CutMesh node; validated at runtime by `cutMeshParameters`. */
export type CutMeshParameters = NamedNodeParameters & {
    /**
     * Mesh to cut; `null` makes the node inactive.
     * CutMeshFwd is listed separately from the other mesh types as a workaround for a recursive type.
     */
    mesh: null | CutMeshFwd | MeshLikeNoCutMesh
    /** Curves, or node inputs that may evaluate to curves, used to cut the mesh. */
    curves: Array<CurveLike | NodesInput>
    /** Index of the mesh UV channel in which the cut is performed. */
    uvChannel: number
}

/**
 * Template mesh node that cuts its input mesh along a set of curves.
 *
 * `curves` may mix curve nodes and `NodesInput` nodes. Inputs are evaluated, and any result that is
 * not curve object data contributes no curve. Curve points are treated as positions and projected
 * onto UV channel `uvChannel` of the mesh. Both the display and the render geometry graphs are cut.
 * The node is skipped when `mesh` is null or `curves` is empty.
 */
@registerNode
export class CutMesh extends DeclareMeshNodeTS<CutMeshParameters>(
    {
        validation: {paramsSchema: cutMeshParameters},
        onVisited: {
            // Skip the node when there is no mesh to cut or nothing to cut it with.
            onFilterActive: ({parameters}) => {
                if (parameters.mesh === null || parameters.curves.length === 0) return skipped
                return visitNone(parameters)
            },
            onCompile: function (this: CutMeshFwd, {context, parameters}) {
                const {evaluator} = context
                const {templateContext} = evaluator
                const {sceneManager} = templateContext
                const {mesh, curves, uvChannel} = parameters

                if (mesh === null || curves.length === 0) return skipped

                const scope = evaluator.getScope(this)
                // The second branch output is forwarded to setupMesh below.
                const [baseMeshObjectData, baseMeshObjectDataInvalid] = scope.branch(evaluator.evaluateMesh(scope, mesh))

                // Flatten all curve parameters into one list of curve object data (or null).
                // Curve nodes are evaluated directly. NodesInput entries are evaluated, and anything
                // that is not curve object data is mapped to null.
                const curveData = scope.merge(
                    curves.map((curveNode, idx) => {
                        const curveScope = scope.scope(`evalCurve${idx}`)
                        if (isCurveLike(curveNode)) {
                            return curveScope.list([evaluator.evaluateCurve(curveScope, curveNode)])
                        } else {
                            return curveScope.pureLambda(
                                evaluator.evaluateNodes(curveScope, [curveNode]),
                                (evaluatedNodes) => {
                                    return evaluatedNodes.map((evaluatedNode) => {
                                        if (evaluatedNode.type === "object") {
                                            const objectData = evaluatedNode.value as ObjectData | null
                                            if (objectData && isCurveObjectData(objectData)) return objectData
                                        }

                                        return null
                                    })
                                },
                                "curveData",
                            )
                        }
                    }),
                )

                const transformed = scope.pureLambda(
                    scope.tuple(curveData, baseMeshObjectData, uvChannel),
                    ([curveObjects, meshObjectData, cutUVChannel]) => {
                        const abstract = meshObjectData.completeMeshData.abstract
                        const numUVs = meshObjectData.completeMeshData.reified.uvs.length
                        const curvePoints = curveObjects.map((curveObject) => curveObject?.curvePoints?.points)
                        return {
                            // The cut channel changes the resulting geometry, so it must be part of the
                            // hash. Otherwise cuts in different UV channels would share a content hash.
                            contentHash: hashObject(["cutMesh", abstract.contentHash, numUVs, cutUVChannel, curvePoints]),
                            displayGeometryGraph: applyCutToGeometry(abstract.displayGeometryGraph, curvePoints, "position", numUVs, cutUVChannel),
                            renderGeometryGraph: applyCutToGeometry(abstract.renderGeometryGraph, curvePoints, "position", numUVs, cutUVChannel),
                            displayGeometryGraphResources: abstract.displayGeometryGraphResources,
                        }
                    },
                    "transformedGraphs",
                )

                const {completeMeshData} = scope.node(GenerateMesh, {
                    sceneManager,
                    inputMeshData: transformed,
                })

                this.setupMesh(scope, context, completeMeshData, true, baseMeshObjectDataInvalid)

                return visitNone(parameters)
            },
        },
    },
    {nodeClass: "CutMesh"},
) {}

/**
 * Structural node type for CutMesh instances, expressed through TemplateMeshNode so that
 * CutMeshParameters can reference it (presumably to avoid a circular class reference).
 */
export type CutMeshFwd = TemplateMeshNode<CutMeshParameters>
