import {cyclesNode, DeclareMaterialNode, DeclareMaterialNodeType} from "#material-nodes/declare-material-node"
import {getDefaultMaterial} from "#material-nodes/nodes/bsdf-principled"
import {getAll} from "@cm/graph"
import * as THREENodes from "three/examples/jsm/nodes/Nodes.js"
import {z} from "zod"

/**
 * Schema for a single shader socket. It accepts a three.js physical node material or a
 * value matching `cyclesNode` (presumably the Cycles-side node representation; confirm).
 */
const ShaderSocketSchema = z.instanceof(THREENodes.MeshPhysicalNodeMaterial).or(cyclesNode)

/** Output sockets: the combined shader. */
const ReturnsSchema = z.object({
    shader: ShaderSocketSchema,
})

/** Input sockets: two shaders to combine. Either may be left unconnected. */
const InputSchema = z.object({
    shader: ShaderSocketSchema.optional(),
    shader_001: ShaderSocketSchema.optional(),
})

/** This node declares no parameters. */
const ParametersSchema = z.object({})
/**
 * Combines two optional shader nodes.
 *
 * @returns the other node when one is missing, `null` when both are missing,
 *          otherwise a `THREENodes.add` node summing the two.
 */
const addOptionalNodes = (node1: THREENodes.Node | null, node2: THREENodes.Node | null) => {
    if (!node1) return node2
    if (!node2) return node1
    return THREENodes.add(node1, node2)
}

/**
 * Material node that adds two shader inputs into a single three.js physical material.
 *
 * NOTE(review): the `shader` / `shader_001` socket names suggest this mirrors Blender's
 * "Add Shader" node; confirm. The three.js approximation sums the inputs' individual node
 * slots (color, roughness, metalness, ...).
 */
const AddShaderBaseClass: DeclareMaterialNodeType<typeof ReturnsSchema, typeof InputSchema, typeof ParametersSchema> = DeclareMaterialNode(
    {
        returns: ReturnsSchema,
        inputs: InputSchema,
        parameters: ParametersSchema,
    },
    {
        /**
         * Builds the three.js material for this node.
         *
         * - Only one input connected: that input's material is returned unchanged.
         * - No input connected: the result of `getDefaultMaterial()` is returned.
         * - Both connected: each node slot of a material from `getDefaultMaterial()` is set to
         *   the sum of the inputs' nodes, or to the single non-null one. The default slot is
         *   kept when neither input sets it. Alpha and transmission are treated as mutually
         *   exclusive. When both inputs contribute them, alpha wins and a warning is logged.
         */
        toThree: async ({get, inputs}) => {
            const {shader, shader_001} = await getAll(inputs, get)

            if (!shader || !shader_001) {
                // At most one input is connected: pass it through, or fall back to a default.
                // The default material is only created on the path that actually returns it.
                if (shader) return {shader: shader}
                if (shader_001) return {shader: shader_001}
                return {shader: getDefaultMaterial()}
            }

            // Both inputs are connected: combine them into a default material, which is mutated below.
            // NOTE(review): assumes getDefaultMaterial() returns a new instance per call; confirm.
            const material = getDefaultMaterial()

            material.colorNode = addOptionalNodes(shader.colorNode, shader_001.colorNode) ?? material.colorNode
            material.roughnessNode = addOptionalNodes(shader.roughnessNode, shader_001.roughnessNode) ?? material.roughnessNode
            material.metalnessNode = addOptionalNodes(shader.metalnessNode, shader_001.metalnessNode) ?? material.metalnessNode
            material.iorNode = addOptionalNodes(shader.iorNode ?? null, shader_001.iorNode ?? null) ?? material.iorNode
            material.specularColorNode = addOptionalNodes(shader.specularColorNode, shader_001.specularColorNode) ?? material.specularColorNode
            material.sheenNode = addOptionalNodes(shader.sheenNode, shader_001.sheenNode) ?? material.sheenNode
            material.clearcoatNode = addOptionalNodes(shader.clearcoatNode, shader_001.clearcoatNode) ?? material.clearcoatNode
            material.clearcoatRoughnessNode =
                addOptionalNodes(shader.clearcoatRoughnessNode, shader_001.clearcoatRoughnessNode) ?? material.clearcoatRoughnessNode
            material.clearcoatNormalNode = addOptionalNodes(shader.clearcoatNormalNode, shader_001.clearcoatNormalNode) ?? material.clearcoatNormalNode
            material.normalNode = addOptionalNodes(shader.normalNode, shader_001.normalNode) ?? material.normalNode
            material.emissiveNode = addOptionalNodes(shader.emissiveNode, shader_001.emissiveNode) ?? material.emissiveNode

            // Decide alpha vs. transmission from the inputs only. The default material's own
            // opacity/transmission nodes are already on `material`. Falling back to them here
            // would only steer the branch below: it could suppress transmission, or mark the
            // material translucent and log a warning when no input uses those features.
            const alphaValue = addOptionalNodes(shader.opacityNode, shader_001.opacityNode)
            const transmissionValue = addOptionalNodes(shader.transmissionNode, shader_001.transmissionNode)

            if (alphaValue || transmissionValue) {
                material.transparent = shader.transparent || shader_001.transparent
                if (alphaValue) {
                    if (transmissionValue) console.warn("Material uses both alpha and transmission! Preferring alpha.")
                    material.opacityNode = alphaValue
                } else if (transmissionValue) {
                    material.transmissionNode = transmissionValue
                    //@ts-expect-error This is just to indicate to three that the material is translucent
                    material.transmission = 0.01
                }
            }

            return {shader: material}
        },
    },
)

/** Public Add Shader material node. All behavior is inherited from `AddShaderBaseClass`. */
export class AddShader extends AddShaderBaseClass {}
