import {DeclareMaterialNode, DeclareMaterialNodeType, MaterialNode, materialSlots} from "#material-nodes/declare-material-node"
import {threeConvert, threeVec3Node} from "#material-nodes/three-utils"
import {vec3} from "#material-nodes/types"
import {GetProperty} from "@cm/graph"
import * as THREENodes from "three/examples/jsm/nodes/Nodes.js"
import {z} from "zod"

// Outputs of the TexCoord node (used as its `returns` schema); every output is typed as `materialSlots`.
const ReturnTypeSchema = z.object({
    camera: materialSlots,
    generated: materialSlots,
    normal: materialSlots,
    object: materialSlots,
    reflection: materialSlots,
    uv: materialSlots,
    window: materialSlots,
})
// TexCoord declares no graph inputs.
const InputTypeSchema = z.object({})
// Optional vec3 values, one per output name.
// NOTE(review): TexCoord.toThree only reads camera, normal, object and reflection;
// generated, uv and window are accepted by the schema but not consumed there.
const ParametersTypeSchema = z.object({
    camera: vec3.optional(),
    generated: vec3.optional(),
    normal: vec3.optional(),
    object: vec3.optional(),
    reflection: vec3.optional(),
    uv: vec3.optional(),
    window: vec3.optional(),
})

/**
 * GLSL helper `vec2 windowCoordinates(vec3 position)`.
 *
 * Transforms `position` by `projectionMatrix * modelViewMatrix` into clip space and divides
 * xy by w to get normalized device coordinates. It then remaps them from [-1, 1] to [0, 1].
 *
 * The body references the `projectionMatrix` and `modelViewMatrix` uniforms, so they must be
 * declared in whatever shader stage ends up calling this function. TexCoord requests those
 * declarations for the fragment stage when its "window" output is read.
 */
const windowCoordinates = new THREENodes.FunctionNode(`
vec2 windowCoordinates(vec3 position) {
    vec4 clipPosition = projectionMatrix * modelViewMatrix * vec4(position, 1.0); // Transform to clip space
    vec2 ndc = clipPosition.xy / clipPosition.w; // Convert to NDC [-1,1]
    return ndc * 0.5 + 0.5; // Convert to [0,1] range (window space)
}
`)

/**
 * `vec2` temp node that evaluates the `windowCoordinates` GLSL function, passing the
 * local-space position (`THREENodes.positionLocal`) as its `position` argument.
 */
class WindowCoordinatesNode extends THREENodes.TempNode {
    constructor() {
        super("vec2")
    }

    /**
     * Builds the function call for `builder`, requesting this node's declared output type.
     */
    override generate(builder: THREENodes.NodeBuilder) {
        const outputType = this.getNodeType(builder)
        const invocation = THREENodes.call(windowCoordinates, {position: THREENodes.positionLocal})
        return invocation.build(builder, outputType)
    }
}

/**
 * Texture-coordinate material node.
 *
 * Outputs `camera`, `generated`, `normal`, `object`, `reflection`, `uv` and `window`:
 * - `camera`, `normal`, `object` and `reflection` are converted from the matching optional vec3
 *   parameter via `threeConvert`. When that yields null/undefined (presumably when the parameter
 *   is absent — confirm against threeConvert), each falls back to its own constant zero vector.
 * - `generated` and `uv` both read UV channel 0.
 * - `window` is a `WindowCoordinatesNode` (screen position remapped to [0, 1]).
 * The `generated`, `uv` and `window` parameters are not read here.
 *
 * Side effect: if the context supplies `onThreeRequestShaderAdditions` and some `GetProperty`
 * parent reads the "window" output, the callback is invoked once. The call asks for the
 * `modelViewMatrix`/`projectionMatrix` uniform declarations in the fragment stage.
 */
export class TexCoord extends (DeclareMaterialNode(
    {
        returns: ReturnTypeSchema,
        inputs: InputTypeSchema,
        parameters: ParametersTypeSchema,
    },
    {
        toThree: async function (this: MaterialNode, {parameters, context}) {
            // A fresh zero node is created for every fallback, so outputs never share an instance.
            const vec3OrZero = (value: typeof parameters.camera) =>
                threeConvert(value, threeVec3Node) ?? threeVec3Node({x: 0, y: 0, z: 0})

            // Outputs are constructed in the same order they are returned.
            const camera = vec3OrZero(parameters.camera)
            const generated = THREENodes.uv(0)
            const normal = vec3OrZero(parameters.normal)
            const object = vec3OrZero(parameters.object)
            const reflection = vec3OrZero(parameters.reflection)
            const uv = THREENodes.uv(0)
            const window = new WindowCoordinatesNode()

            const requestShaderAdditions = context.onThreeRequestShaderAdditions

            // True as soon as one parent is a GetProperty reading this node's "window" output.
            const windowOutputIsRead = (): boolean => {
                for (const parent of this.parents) {
                    if (parent instanceof GetProperty && parent.parameters.key === "window") return true
                }
                return false
            }

            if (requestShaderAdditions && windowOutputIsRead()) {
                // windowCoordinates() reads these matrices, so they must be declared for the fragment stage.
                const matrixUniforms = "uniform mat4 modelViewMatrix;\n                        uniform mat4 projectionMatrix;"
                requestShaderAdditions("fragment", matrixUniforms)
            }

            return {camera, generated, normal, object, reflection, uv, window}
        },
    },
) as DeclareMaterialNodeType<typeof ReturnTypeSchema, typeof InputTypeSchema, typeof ParametersTypeSchema>) {}
