import * as THREENodes from "three/examples/jsm/nodes/Nodes.js"
import {Color, Vec2, Vec3, Vec4} from "#material-nodes/types"
import * as THREE from "three"
import {z} from "zod"

/**
 * Wraps an RGB colour value in a three.js colour node.
 *
 * @param color - Source colour; only its `r`, `g` and `b` components are read.
 * @returns The node returned by `THREENodes.color` for a new `THREE.Color`.
 */
export const threeRGBColorNode = (color: Color) => THREENodes.color(new THREE.Color(color.r, color.g, color.b))

/**
 * Wraps a plain number in a three.js float node.
 *
 * @param value - Scalar to expose to the node graph.
 * @returns The node returned by `THREENodes.float`.
 */
export const threeValueNode = (value: number) => THREENodes.float(value)

/**
 * Wraps a two-component value in a three.js vec2 node.
 *
 * @param value - Source vector; its `x` and `y` components are copied into a new `THREE.Vector2`.
 * @returns The node returned by `THREENodes.vec2`.
 */
export const threeVec2Node = (value: Vec2) => {
    const {x, y} = value
    return THREENodes.vec2(new THREE.Vector2(x, y))
}

/**
 * Wraps a three-component value in a three.js vec3 node.
 *
 * @param value - Source vector; its `x`, `y` and `z` components are copied into a new `THREE.Vector3`.
 * @returns The node returned by `THREENodes.vec3`.
 */
export const threeVec3Node = (value: Vec3) => {
    const {x, y, z} = value
    return THREENodes.vec3(new THREE.Vector3(x, y, z))
}

/**
 * Wraps a four-component value in a three.js vec4 node.
 *
 * @param value - Source vector; its `x`, `y`, `z` and `w` components are copied into a new `THREE.Vector4`.
 * @returns The node returned by `THREENodes.vec4`.
 */
export const threeVec4Node = (value: Vec4) => {
    const {x, y, z, w} = value
    return THREENodes.vec4(new THREE.Vector4(x, y, z, w))
}

/**
 * Converts an optional value, passing `undefined` through untouched.
 *
 * @param value - Value to convert; `undefined` short-circuits to `undefined`.
 * @param converter - Applied to `value` when it is defined and accepted by `checker`.
 * @param checker - Optional predicate; when supplied and it rejects `value`, the result is `undefined`
 *                  and `converter` is not called.
 * @returns The converted value, or `undefined` when there was nothing (acceptable) to convert.
 */
export const threeConvert = <T, R>(value: T | undefined, converter: (value: T) => R, checker?: (value: T) => boolean): R | undefined => {
    // Only `undefined` counts as "missing"; other falsy values (0, "", null) are converted normally.
    if (value !== undefined && (checker === undefined || checker(value))) {
        return converter(value)
    }
    return undefined
}

/**
 * GLSL "color burn" blend, evaluated independently for the x, y and z channels.
 *
 * Per channel, with A = inputA, B = inputB:
 *   tmp = (1 - f) + f * B
 *   - tmp <= 0                      -> 0
 *   - otherwise 1 - (1 - A) / tmp, clamped to the range [0, 1]
 *
 * Note the shader reuses `tmp` via an assignment inside the `else if` condition, so the
 * second and third comparisons test the burned value, not the original denominator.
 */
const colorBurnNode = new THREENodes.FunctionNode(`
vec3 colorBurn(vec3 inputA, vec3 inputB, float f) {
    vec3 resultColor;
 
    float tmp = (1.0 - f) + f * inputB.x;
    if (tmp <= 0.0) {
        resultColor.x = 0.0;
    }
    else if ((tmp = (1.0 - (1.0 - inputA.x) / tmp)) < 0.0) {
        resultColor.x = 0.0;
    }
    else if (tmp > 1.0) {
        resultColor.x = 1.0;
    }
    else {
        resultColor.x = tmp;
    }
 
    tmp = (1.0 - f) + f * inputB.y;
    if (tmp <= 0.0) {
        resultColor.y = 0.0;
    }
    else if ((tmp = (1.0 - (1.0 - inputA.y) / tmp)) < 0.0) {
        resultColor.y = 0.0;
    }
    else if (tmp > 1.0) {
        resultColor.y = 1.0;
    }
    else {
        resultColor.y = tmp;
    }
 
    tmp = (1.0 - f) + f * inputB.z;
    if (tmp <= 0.0) {
        resultColor.z = 0.0;
    }
    else if ((tmp = (1.0 - (1.0 - inputA.z) / tmp)) < 0.0) {
        resultColor.z = 0.0;
    }
    else if (tmp > 1.0) {
        resultColor.z = 1.0;
    }
    else {
        resultColor.z = tmp;
    }
 
    return resultColor;
}
`)

/**
 * Creates a node that calls the `colorBurn` shader function.
 *
 * @param inputA - Base colour node (first shader argument).
 * @param inputB - Blend colour node (second shader argument).
 * @param fac - Blend factor node (third shader argument).
 * @returns The call node produced by `THREENodes.call`.
 */
export function threeColorBurnNode(inputA: THREENodes.Node, inputB: THREENodes.Node, fac: THREENodes.Node): THREENodes.Node {
    // Arguments are positional and must follow the shader signature (inputA, inputB, f).
    const burnCall = THREENodes.call(colorBurnNode, [inputA, inputB, fac])
    return burnCall
}

/**
 * GLSL "color dodge" blend, evaluated independently for the x, y and z channels.
 *
 * The result starts as a copy of inputA. For every channel where A > 0:
 *   tmp = 1 - f * B
 *   - tmp <= 0            -> 1
 *   - otherwise A / tmp, capped at 1 (there is no lower clamp)
 * Channels where A <= 0 keep the value of inputA.
 *
 * As in the shader text, `tmp` is reassigned inside the `else if` condition.
 */
const colorDodgeNode = new THREENodes.FunctionNode(`
vec3 colorDodge(vec3 inputA, vec3 inputB, float f) {
    vec3 resultColor;
    float tmp;
 
    resultColor = inputA;
 
    if (inputA.x > 0.0) {
        tmp = 1.0 - f * inputB.x;
        if (tmp <= 0.0) {
            resultColor.x = 1.0;
        }
        else if ((tmp = (inputA.x / tmp)) > 1.0) {
            resultColor.x = 1.0;
        }
        else {
            resultColor.x = tmp;
        }
    }
 
    if (inputA.y > 0.0) {
        tmp = 1.0 - f * inputB.y;
        if (tmp <= 0.0) {
            resultColor.y = 1.0;
        }
        else if ((tmp = (inputA.y / tmp)) > 1.0) {
            resultColor.y = 1.0;
        }
        else {
            resultColor.y = tmp;
        }
    }
 
    if (inputA.z > 0.0) {
        tmp = 1.0 - f * inputB.z;
        if (tmp <= 0.0) {
            resultColor.z = 1.0;
        }
        else if ((tmp = (inputA.z / tmp)) > 1.0) {
            resultColor.z = 1.0;
        }
        else {
            resultColor.z = tmp;
        }
    }
 
    return resultColor;
}
`)

/**
 * Creates a node that calls the `colorDodge` shader function.
 *
 * @param inputA - Base colour node (first shader argument).
 * @param inputB - Blend colour node (second shader argument).
 * @param fac - Blend factor node (third shader argument).
 * @returns The call node produced by `THREENodes.call`.
 */
export function threeColorDodgeNode(inputA: THREENodes.Node, inputB: THREENodes.Node, fac: THREENodes.Node): THREENodes.Node {
    // Arguments are positional and must follow the shader signature (inputA, inputB, f).
    const dodgeCall = THREENodes.call(colorDodgeNode, [inputA, inputB, fac])
    return dodgeCall
}

/**
 * GLSL "overlay" blend, evaluated independently for the x, y and z channels.
 *
 * Per channel, with A = inputA, B = inputB and invF = 1 - f:
 *   - A <  0.5 -> A * (invF + 2 * f * B)
 *   - A >= 0.5 -> 1 - (invF + 2 * f * (1 - B)) * (1 - A)
 */
const overlayNode = new THREENodes.FunctionNode(`
vec3 overlay(vec3 inputA, vec3 inputB, float f) {
    vec3 resultColor;
    float invF = 1.0 - f;
 
    if (inputA.x < 0.5) {
        resultColor.x = inputA.x * (invF + 2.0 * f * inputB.x);
    }
    else {
        resultColor.x = 1.0 - (invF + 2.0 * f * (1.0 - inputB.x)) * (1.0 - inputA.x);
    }
 
    if (inputA.y < 0.5) {
        resultColor.y = inputA.y * (invF + 2.0 * f * inputB.y);
    }
    else {
        resultColor.y = 1.0 - (invF + 2.0 * f * (1.0 - inputB.y)) * (1.0 - inputA.y);
    }
 
    if (inputA.z < 0.5) {
        resultColor.z = inputA.z * (invF + 2.0 * f * inputB.z);
    }
    else {
        resultColor.z = 1.0 - (invF + 2.0 * f * (1.0 - inputB.z)) * (1.0 - inputA.z);
    }
 
    return resultColor;
}
`)

/**
 * Creates a node that calls the `overlay` shader function.
 *
 * @param inputA - Base colour node (first shader argument).
 * @param inputB - Blend colour node (second shader argument).
 * @param fac - Blend factor node (third shader argument).
 * @returns The call node produced by `THREENodes.call`.
 */
export function threeOverlayNode(inputA: THREENodes.Node, inputB: THREENodes.Node, fac: THREENodes.Node): THREENodes.Node {
    // Arguments are positional and must follow the shader signature (inputA, inputB, f).
    const overlayCall = THREENodes.call(overlayNode, [inputA, inputB, fac])
    return overlayCall
}

/**
 * GLSL RGB -> HSV conversion. The result is packed as (x = hue, y = saturation, z = value),
 * with hue normalised to [0, 1).
 *
 * - value      = max(r, g, b)
 * - saturation = (max - min) / max, or 0 when max == 0
 * - hue        = 0 for achromatic input (saturation == 0); otherwise derived from whichever
 *                channel holds the maximum, divided by 6 and wrapped into the positive range.
 *
 * All branches are braced so the control flow matches the indentation, and the division by
 * `delta` only happens in the chromatic branch, where `delta` cannot be zero for finite
 * input (saturation != 0 implies max != min).
 */
const rgbToHsvNode = new THREENodes.FunctionNode(`
vec3 rgb2hsv(vec3 rgb) {
    vec3 hsv;
    float rgbmax = max(rgb.x, max(rgb.y, rgb.z));
    float rgbmin = min(rgb.x, min(rgb.y, rgb.z));
    float delta = rgbmax - rgbmin;
    hsv.z = rgbmax;

    if (rgbmax == 0.0) {
        hsv = vec3(0.0, 0.0, 0.0);
    }
    else {
        hsv.y = delta / rgbmax;
    }

    if (hsv.y == 0.0) {
        hsv.x = 0.0;
    }
    else {
        vec3 c = (vec3(rgbmax, rgbmax, rgbmax) - rgb) / delta;

        if (rgb.x == rgbmax) {
            hsv.x = c.z - c.y;
        }
        else if (rgb.y == rgbmax) {
            hsv.x = 2.0 + c.x - c.z;
        }
        else {
            hsv.x = 4.0 + c.y - c.x;
        }

        hsv.x /= 6.0;
        if (hsv.x < 0.0) {
            hsv.x += 1.0;
        }
    }

    return hsv;
}
`)

/**
 * GLSL HSV -> RGB conversion. The input is packed as (x = hue in [0, 1], y = saturation,
 * z = value).
 *
 * - saturation == 0 -> grey (v, v, v); the hue is ignored entirely.
 * - otherwise the hue is scaled to six sectors (hue == 1 wraps to 0) and the usual
 *   p / q / t interpolation picks the channel order for that sector.
 *
 * All branches are braced so the sector computation is really confined to the chromatic
 * case. Previously the brace-less `else` covered only the `h == 1.0` wrap, so the sector
 * code ran for grey input too and overwrote the result. That gave the same value for finite
 * hues but let a NaN hue leak into the output.
 */
const hsvToRgbNode = new THREENodes.FunctionNode(`
vec3 hsv2rgb(vec3 hsv) {
    vec3 rgb;
    float h = hsv.x;
    float s = hsv.y;
    float v = hsv.z;

    if (s == 0.0) {
        rgb = vec3(v, v, v);
    }
    else {
        if (h == 1.0) {
            h = 0.0;
        }
        h *= 6.0;
        float i = floor(h);
        float f = h - i;
        float p = v * (1.0 - s);
        float q = v * (1.0 - s * f);
        float t = v * (1.0 - s * (1.0 - f));

        if (i == 0.0) {
            rgb = vec3(v, t, p);
        }
        else if (i == 1.0) {
            rgb = vec3(q, v, p);
        }
        else if (i == 2.0) {
            rgb = vec3(p, v, t);
        }
        else if (i == 3.0) {
            rgb = vec3(p, q, v);
        }
        else if (i == 4.0) {
            rgb = vec3(t, p, v);
        }
        else {
            rgb = vec3(v, p, q);
        }
    }

    return rgb;
}
`)

/**
 * Creates a node that calls the `rgb2hsv` shader function.
 *
 * @param rgb - Colour node bound to the shader's `rgb` parameter.
 * @returns The call node produced by `THREENodes.call`.
 */
export function threeRgbToHsvNode(rgb: THREENodes.Node) {
    // Arguments are passed by name; the key must match the GLSL parameter name.
    const hsvCall = THREENodes.call(rgbToHsvNode, {rgb: rgb})
    return hsvCall
}

/**
 * Creates a node that calls the `hsv2rgb` shader function.
 *
 * @param hsv - Node bound to the shader's `hsv` parameter.
 * @returns The call node produced by `THREENodes.call`.
 */
export function threeHsvToRgbNode(hsv: THREENodes.Node) {
    // Arguments are passed by name; the key must match the GLSL parameter name.
    const rgbCall = THREENodes.call(hsvToRgbNode, {hsv: hsv})
    return rgbCall
}

/**
 * Divisor used for the half-texel offset baked into the `applyLut` shader below.
 * NOTE(review): presumably the width in texels of the LUT texture — confirm against the
 * code that creates the texture.
 */
export const lutSize = 256
// GLSL lookup-table remap. Each input channel is used as the horizontal coordinate into
// `lut` (the vertical coordinate is fixed at 0.5) and the matching channel of the fetched
// texel becomes the remapped value; the result is then mixed with the input by `fac`.
// The coordinate is shifted by 0.5 / lutSize; per the original note this offset is needed
// because the LUT texture is sampled with THREE.NearestFilter.
// The offset is interpolated into the shader text when this module is evaluated.
const applyLut = new THREENodes.FunctionNode(`
vec4 applyLut(vec4 rgbaIn, sampler2D lut, float fac) {
    vec4 rgbaOut = vec4(
        texture(lut, vec2(rgbaIn.r + ${0.5 / lutSize}, 0.5)).r,
        texture(lut, vec2(rgbaIn.g + ${0.5 / lutSize}, 0.5)).g,
        texture(lut, vec2(rgbaIn.b + ${0.5 / lutSize}, 0.5)).b,
        texture(lut, vec2(rgbaIn.a + ${0.5 / lutSize}, 0.5)).a
    );
    return mix(rgbaIn, rgbaOut, fac);
}
`)

/**
 * Node of type `vec4` that runs its input through the `applyLut` shader function.
 */
export class ApplyLUTNode extends THREENodes.TempNode {
    /**
     * @param rgbaInput - Node bound to the shader's `rgbaIn` parameter.
     * @param lutTexture - Texture bound (via a texture node) to the shader's `lut` parameter.
     * @param fac - Node bound to the shader's `fac` parameter.
     */
    constructor(
        public rgbaInput: THREENodes.Node,
        public lutTexture: THREE.Texture,
        public fac: THREENodes.Node,
    ) {
        super("vec4")
    }

    /**
     * Builds the `applyLut` call for the current inputs and emits it with this node's type.
     *
     * @param builder - Builder the generated call is emitted into.
     * @returns Whatever `build` returns for the generated call node.
     */
    override generate(builder: THREENodes.NodeBuilder) {
        const outputType = this.getNodeType(builder)

        // Wrap the raw texture in a texture node and convert it to "texture" before passing it as `lut`.
        const lutSampler = THREENodes.convert(THREENodes.texture(this.lutTexture), "texture")
        const lutCall = THREENodes.call(applyLut, {
            rgbaIn: this.rgbaInput,
            lut: lutSampler,
            fac: this.fac,
        })
        return lutCall.build(builder, outputType)
    }
}
/**
 * Zod schema listing every blend mode name accepted by `getColor`.
 * The order of the entries is part of the schema's `options` and is kept as-is.
 */
export const rgbBlendType = z.enum([
    "MIX",
    "ADD",
    "SUBTRACT",
    "MULTIPLY",
    "SCREEN",
    "DIVIDE",
    "DIFFERENCE",
    "EXCLUSION",
    "DARKEN",
    "LIGHTEN",
    "OVERLAY",
    "COLOR_DODGE",
    "COLOR_BURN",
    "HUE",
    "SATURATION",
    "VALUE",
    "COLOR",
    "SOFT_LIGHT",
    "LINEAR_LIGHT",
])

/** String-literal union of the blend mode names, inferred from the `rgbBlendType` schema. */
export type RgbBlendType = z.infer<typeof rgbBlendType>

/**
 * Builds the node graph that blends two colour nodes with the requested blend mode.
 *
 * In the per-mode comments below, A = `inputA`, B = `inputB`, f = `facNode`, and
 * `mix(x, y, f)` stands for (1-f)x + f*y.
 *
 * @param operation - Blend mode to apply.
 * @param inputA - Base colour node.
 * @param inputB - Blend colour node.
 * @param facNode - Blend factor node.
 * @returns The node that evaluates to the blended colour.
 * @throws Error when `operation` is not one of the handled blend modes.
 */
export const getColor = (operation: RgbBlendType, inputA: THREENodes.Node, inputB: THREENodes.Node, facNode: THREENodes.Node) => {
    switch (operation) {
        case "MIX":
            // mix(A, B, f)
            return THREENodes.mix(inputA, inputB, facNode)

        case "LINEAR_LIGHT": {
            // A + f(2B - 1)
            const twiceB = THREENodes.mul(inputB, threeValueNode(2.0))
            const shiftedB = THREENodes.sub(twiceB, threeValueNode(1))
            const weightedB = THREENodes.mul(facNode, shiftedB)
            return THREENodes.add(inputA, weightedB)
        }

        case "SOFT_LIGHT": {
            // mix(A, (1-A)BA + A*scr, f) with scr = 1 - (1-A)(1-B)
            const oneMinusA = THREENodes.sub(threeValueNode(1), inputA)
            const oneMinusB = THREENodes.sub(threeValueNode(1), inputB)
            const invProduct = THREENodes.mul(oneMinusA, oneMinusB)
            const screened = THREENodes.sub(threeValueNode(1), invProduct)
            const invATimesB = THREENodes.mul(oneMinusA, inputB)
            const invATimesBA = THREENodes.mul(invATimesB, inputA)
            const aTimesScreen = THREENodes.mul(inputA, screened)
            const softLight = THREENodes.add(invATimesBA, aTimesScreen)
            return THREENodes.mix(inputA, softLight, facNode)
        }

        case "MULTIPLY": {
            // mix(A, A*B, f)
            const product = THREENodes.mul(inputA, inputB)
            return THREENodes.mix(inputA, product, facNode)
        }

        case "ADD": {
            // mix(A, A+B, f)
            const sum = THREENodes.add(inputA, inputB)
            return THREENodes.mix(inputA, sum, facNode)
        }

        case "SUBTRACT": {
            // mix(A, A-B, f)
            const difference = THREENodes.sub(inputA, inputB)
            return THREENodes.mix(inputA, difference, facNode)
        }

        case "DIFFERENCE": {
            // mix(A, |A-B|, f)
            const signedDiff = THREENodes.sub(inputA, inputB)
            const absDiff = THREENodes.abs(signedDiff)
            return THREENodes.mix(inputA, absDiff, facNode)
        }

        case "EXCLUSION": {
            // mix(A, A + B - 2AB, f)
            const sum = THREENodes.add(inputA, inputB)
            const two = threeValueNode(2)
            const product = THREENodes.mul(inputA, inputB)
            const twiceProduct = THREENodes.mul(two, product)
            const exclusion = THREENodes.sub(sum, twiceProduct)
            return THREENodes.mix(inputA, exclusion, facNode)
        }

        case "LIGHTEN": {
            // max(A, f*B) — note: not a mix, the factor scales B directly
            const scaledB = THREENodes.mul(facNode, inputB)
            return THREENodes.max(inputA, scaledB)
        }

        case "DARKEN": {
            // mix(A, min(A, B), f)
            const darker = THREENodes.min(inputA, inputB)
            return THREENodes.mix(inputA, darker, facNode)
        }

        case "DIVIDE": {
            // Per channel: A/B where B > 0, otherwise A unchanged; then mix(A, result, f)
            const aRed = new THREENodes.SplitNode(inputA, "r")
            const aGreen = new THREENodes.SplitNode(inputA, "g")
            const aBlue = new THREENodes.SplitNode(inputA, "b")
            const bRed = new THREENodes.SplitNode(inputB, "r")
            const bGreen = new THREENodes.SplitNode(inputB, "g")
            const bBlue = new THREENodes.SplitNode(inputB, "b")
            const quotientRed = THREENodes.div(aRed, bRed)
            const quotientGreen = THREENodes.div(aGreen, bGreen)
            const quotientBlue = THREENodes.div(aBlue, bBlue)
            const safeRed = THREENodes.cond(THREENodes.greaterThan(bRed, threeValueNode(0)), quotientRed, aRed)
            const safeGreen = THREENodes.cond(THREENodes.greaterThan(bGreen, threeValueNode(0)), quotientGreen, aGreen)
            const safeBlue = THREENodes.cond(THREENodes.greaterThan(bBlue, threeValueNode(0)), quotientBlue, aBlue)
            const divided = new THREENodes.JoinNode([safeRed, safeGreen, safeBlue])

            return THREENodes.mix(inputA, divided, facNode)
        }

        case "SCREEN": {
            // 1 - [(1-f)(1-A) + f(1-A)(1-B)]
            const oneMinusA = THREENodes.sub(threeValueNode(1), inputA)
            const oneMinusB = THREENodes.sub(threeValueNode(1), inputB)
            const oneMinusFac = THREENodes.sub(threeValueNode(1), facNode)
            const keepTerm = THREENodes.mul(oneMinusFac, oneMinusA)
            const facTimesInvA = THREENodes.mul(facNode, oneMinusA)
            const blendTerm = THREENodes.mul(facTimesInvA, oneMinusB)
            const combined = THREENodes.add(keepTerm, blendTerm)
            return THREENodes.sub(threeValueNode(1), combined)
        }

        case "COLOR": {
            // Hue and saturation from B, value from A; skipped when B has no saturation
            const hsvB = threeRgbToHsvNode(inputB)
            const saturationB = new THREENodes.SplitNode(hsvB, "y")

            const hsvA = threeRgbToHsvNode(inputA)
            const hueB = new THREENodes.SplitNode(hsvB, "x")
            const valueA = new THREENodes.SplitNode(hsvA, "z")
            const recombinedHsv = new THREENodes.JoinNode([hueB, saturationB, valueA])
            const recombinedRgb = threeHsvToRgbNode(recombinedHsv)
            const blended = THREENodes.mix(inputA, recombinedRgb, facNode)

            return THREENodes.cond(THREENodes.greaterThan(saturationB, threeValueNode(0)), blended, inputA)
        }

        case "HUE": {
            // Hue from B, saturation and value from A; skipped when B has no saturation
            const hsvB = threeRgbToHsvNode(inputB)
            const saturationB = new THREENodes.SplitNode(hsvB, "y")

            const hsvA = threeRgbToHsvNode(inputA)
            const hueB = new THREENodes.SplitNode(hsvB, "x")
            const saturationA = new THREENodes.SplitNode(hsvA, "y")
            const valueA = new THREENodes.SplitNode(hsvA, "z")
            const recombinedHsv = new THREENodes.JoinNode([hueB, saturationA, valueA])
            const recombinedRgb = threeHsvToRgbNode(recombinedHsv)
            const blended = THREENodes.mix(inputA, recombinedRgb, facNode)

            return THREENodes.cond(THREENodes.greaterThan(saturationB, threeValueNode(0)), blended, inputA)
        }

        case "SATURATION": {
            // Saturation mixed from A to B by f; skipped when A has no saturation
            const hsvA = threeRgbToHsvNode(inputA)
            const saturationA = new THREENodes.SplitNode(hsvA, "y")

            const hsvB = threeRgbToHsvNode(inputB)
            const hueA = new THREENodes.SplitNode(hsvA, "x")
            const valueA = new THREENodes.SplitNode(hsvA, "z")
            const saturationB = new THREENodes.SplitNode(hsvB, "y")
            const mixedSaturation = THREENodes.mix(saturationA, saturationB, facNode)
            const recombinedHsv = new THREENodes.JoinNode([hueA, mixedSaturation, valueA])
            const recombinedRgb = threeHsvToRgbNode(recombinedHsv)

            return THREENodes.cond(THREENodes.greaterThan(saturationA, threeValueNode(0)), recombinedRgb, inputA)
        }

        case "VALUE": {
            // Value mixed from A to B by f; hue and saturation stay those of A
            const hsvA = threeRgbToHsvNode(inputA)
            const hsvB = threeRgbToHsvNode(inputB)
            const hueA = new THREENodes.SplitNode(hsvA, "x")
            const saturationA = new THREENodes.SplitNode(hsvA, "y")
            const valueA = new THREENodes.SplitNode(hsvA, "z")
            const valueB = new THREENodes.SplitNode(hsvB, "z")
            const mixedValue = THREENodes.mix(valueA, valueB, facNode)
            const recombinedHsv = new THREENodes.JoinNode([hueA, saturationA, mixedValue])

            return threeHsvToRgbNode(recombinedHsv)
        }

        case "COLOR_BURN":
            return threeColorBurnNode(inputA, inputB, facNode)

        case "COLOR_DODGE":
            return threeColorDodgeNode(inputA, inputB, facNode)

        case "OVERLAY":
            return threeOverlayNode(inputA, inputB, facNode)

        default:
            throw new Error(`Unsupported operation: ${operation}`)
    }
}
