import {Pass, FullScreenQuad} from "@cm/material-nodes/three"
import {Three as THREE} from "@cm/material-nodes/three"
import {ToneMapping} from "@cm/template-nodes"
import {ToneMappingFunctions, buildLUTEntries} from "@cm/image-processing/tone-mapping"
import {DEFAULT_FLOAT_TEXTURE_TYPE, Float16ArrayBuilder, Float32ArrayBuilder, ThreeFloatTextureType} from "./three-utils"

/**
 * Post-processing pass that applies exposure and a LUT-based tone-mapping curve to the read buffer.
 *
 * The tone-mapping curve is baked into a cube LUT (a `size`-wide, `size * size`-tall float texture) whenever
 * {@link ToneMappingRenderPass.setToneMapping} is called. The pass turns itself off (`enabled = false`) when it
 * would be a no-op, i.e. linear tone mapping combined with an exposure of exactly 1.
 */
export class ToneMappingRenderPass extends Pass {
    static ToneMappingShader = {
        uniforms: {
            tDiffuse: {value: null as THREE.Texture | null},
            toneMapExposure: {value: 1.0},
            toneMapLUT: {value: null as THREE.Texture | null},
            toneMapLUTSize: {value: 1.0},
            toneMapLUTRangeScale: {value: 1.0},
        },
        vertexShader: /* glsl */ `
            varying vec2 vUv;
            void main() {
                vUv = uv;
                gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
            }`,
        fragmentShader: /* glsl */ `
            uniform sampler2D tDiffuse;
            uniform float toneMapExposure;
            uniform sampler2D toneMapLUT;
            uniform float toneMapLUTSize;
            uniform float toneMapLUTRangeScale;
            varying vec2 vUv;
            float lutCoordConstrain(float x, float coord) {
                if(coord < 0.0) return 0.0;
                else if(coord > 1.0) return 1.0;
                return x;
            }
            vec4 sampleCube(sampler2D tex, vec3 coord, float sz) {
                float sz_1 = sz - 1.0;
                float rsz = 1.0 / sz;
                float rsz2 = rsz * rsz;
                float tx = (coord.x * sz_1 + 0.5) * rsz;
                float ty = (clamp(coord.y, 0.0, 1.0) * sz_1 + 0.5) * rsz2;
                float _iz = coord.z * sz_1;
                float iz = floor(_iz);
                float zFrac = _iz - iz;
                float tyz0 = ty + clamp(iz, 0.0, sz_1) * rsz;
                float tyz1 = ty + clamp(iz + 1.0, 0.0, sz_1) * rsz;
                vec4 z0c = texture2D(tex, vec2(tx, tyz0));
                vec4 z1c = texture2D(tex, vec2(tx, tyz1));
                vec4 result = mix(z0c, z1c, zFrac);
                result.r = lutCoordConstrain(result.r, coord.x);
                result.g = lutCoordConstrain(result.g, coord.y);
                result.b = lutCoordConstrain(result.b, coord.z);
                return result;
            }
            void main() {
                vec4 texel = texture2D(tDiffuse, vUv);
                texel.rgb /= max(texel.a, 1e-3); // un-premultiply alpha
                texel.rgb *= toneMapExposure;
                texel = LinearTosRGB(max(texel, 0.0)); // Sample in sRGB space, to utilize the range of the LUT better
                texel.rgb = sampleCube(toneMapLUT, texel.rgb * toneMapLUTRangeScale, toneMapLUTSize).rgb;
                texel.rgb *= texel.a; // re-premultiply alpha
                gl_FragColor = texel;
            }`,
    }

    // 32 entries per axis; range 2.0 means LUT coordinates [0, 1] correspond to sRGB-encoded values [0, 2]
    // (the shader multiplies by toneMapLUTRangeScale = 1 / range before sampling).
    private readonly toneMapLUT = new CubeLUTTexture(32, 2.0, DEFAULT_FLOAT_TEXTURE_TYPE)
    // Per-instance copy of the static uniform set, so several passes don't share uniform state.
    private readonly toneMappingUniforms: (typeof ToneMappingRenderPass.ToneMappingShader)["uniforms"]
    private readonly toneMappingQuad: FullScreenQuad
    private toneMapping!: ToneMapping
    private exposure!: number

    /**
     * @param toneMapping Initial tone-mapping settings; baked into the LUT immediately.
     * @param exposure Initial linear exposure multiplier applied before tone mapping.
     */
    constructor(toneMapping: ToneMapping, exposure: number) {
        super()

        this.toneMappingUniforms = THREE.UniformsUtils.clone(ToneMappingRenderPass.ToneMappingShader.uniforms)
        const material = new THREE.ShaderMaterial({
            uniforms: this.toneMappingUniforms,
            vertexShader: ToneMappingRenderPass.ToneMappingShader.vertexShader,
            fragmentShader: ToneMappingRenderPass.ToneMappingShader.fragmentShader,
            depthTest: false,
            depthWrite: false,
        })
        this.toneMappingQuad = new FullScreenQuad(material)

        this.setToneMapping(toneMapping)
        this.setExposure(exposure)
    }

    /**
     * Releases the GPU resources owned by this pass: the LUT texture, the shader material and the full-screen quad.
     */
    override dispose() {
        // The LUT is a DataTexture created and owned by this pass; without this its GPU texture would leak.
        this.toneMapLUT.dispose()
        this.toneMappingQuad.material.dispose()
        this.toneMappingQuad.dispose()
    }

    // The pass samples `readBuffer` via normalized UVs, so it has no size-dependent state.
    override setSize(width: number, height: number) {}

    /**
     * Tone-maps `readBuffer` into `writeBuffer` (or to the screen when `renderToScreen` is set).
     * The renderer's previously bound render target is restored afterwards.
     */
    override render(
        renderer: THREE.WebGLRenderer,
        writeBuffer: THREE.WebGLRenderTarget,
        readBuffer: THREE.WebGLRenderTarget,
        deltaTime: number,
        maskActive: boolean,
    ) {
        const oldTarget = renderer.getRenderTarget()

        renderer.setRenderTarget(this.renderToScreen ? null : writeBuffer)
        this.toneMappingUniforms.tDiffuse.value = readBuffer.texture
        this.toneMappingUniforms.toneMapExposure.value = this.exposure
        this.toneMappingUniforms.toneMapLUT.value = this.toneMapLUT
        this.toneMappingUniforms.toneMapLUTSize.value = this.toneMapLUT.size
        this.toneMappingUniforms.toneMapLUTRangeScale.value = 1.0 / this.toneMapLUT.range
        this.toneMappingQuad.render(renderer)

        renderer.setRenderTarget(oldTarget)
    }

    /**
     * Stores the tone-mapping settings, rebuilds the LUT from them and re-evaluates whether the pass is needed.
     */
    setToneMapping(toneMapping: ToneMapping) {
        this.toneMapping = toneMapping
        this.toneMapLUT.updateWithTonemapping(toneMapping)
        this.autoEnable()
    }

    /**
     * Stores the exposure multiplier (applied in the next `render`) and re-evaluates whether the pass is needed.
     */
    setExposure(exposure: number) {
        this.exposure = exposure
        this.autoEnable()
    }

    // Skip the pass entirely when it would not change the image.
    private autoEnable() {
        this.enabled = this.toneMapping.mode !== "linear" || this.exposure !== 1.0
    }
}

/**
 * RGBA float texture storing a cube LUT with `size` entries per axis.
 *
 * The texture is `size` texels wide and `size * size` texels tall, so it holds `size`^3 RGBA entries. It is
 * sampled with clamp-to-edge wrapping and linear filtering.
 */
class CubeLUTTexture extends THREE.DataTexture {
    // CPU-side backing store whose `array` was handed to the DataTexture.
    private data: Float16ArrayBuilder | Float32ArrayBuilder
    readonly size: number
    readonly range: number

    /**
     * @param size Number of LUT entries per axis.
     * @param range Input range covered by the LUT; forwarded to `buildLUTEntries`.
     * @param type Texture component type; `THREE.HalfFloatType` selects 16-bit storage, anything else 32-bit.
     */
    constructor(size: number, range: number, type: ThreeFloatTextureType) {
        const texelCount = size * size * size
        const useHalfFloat = type === THREE.HalfFloatType
        const storage = useHalfFloat ? new Float16ArrayBuilder(texelCount * 4) : new Float32ArrayBuilder(texelCount * 4)
        super(storage.array, size, size * size, THREE.RGBAFormat, type)
        this.data = storage
        this.size = size
        this.range = range
        this.wrapS = THREE.ClampToEdgeWrapping
        this.wrapT = THREE.ClampToEdgeWrapping
        this.minFilter = THREE.LinearFilter
        this.magFilter = THREE.LinearFilter
    }

    /**
     * Copies packed RGB triplets from `array` into the RGBA backing store (alpha forced to 1)
     * and flags the texture for re-upload.
     */
    updateWithArray(array: Float32Array) {
        for (let src = 0, dst = 0; src < array.length; src += 3, dst += 4) {
            this.data.set(dst, array[src])
            this.data.set(dst + 1, array[src + 1])
            this.data.set(dst + 2, array[src + 2])
            this.data.set(dst + 3, 1.0)
        }
        this.needsUpdate = true
    }

    /**
     * Rebuilds the LUT contents from the given tone-mapping settings.
     */
    updateWithTonemapping(toneMapping: ToneMapping) {
        const curve = ToneMappingFunctions.createForToneMappingData(toneMapping)
        // Sample in sRGB space, to utilize the range of the LUT better
        // (NOTE(review): presumably selected by the boolean flags below — confirm against buildLUTEntries).
        const entries = buildLUTEntries(this.size, this.range, curve, true, false)
        this.updateWithArray(entries)
    }
}
