import {HalImageChannelLayout} from "@common/models/hal/hal-image/types"
import {HalPainterParametersType} from "@common/models/hal/hal-painter/types"
import {Matrix3x2, Matrix3x2Like, Matrix4} from "@cm/math"
import {assertNever} from "@cm/utils"
import {WebGl2Context} from "@common/models/webgl2/webgl2-context"
import * as WebGl2ShaderUtils from "@common/helpers/webgl2/webgl2-shader-utils"
import {
    ADDRESS_MODE_BORDER,
    ADDRESS_MODE_CLAMP_TO_EDGE,
    ADDRESS_MODE_MIRRORED_REPEAT,
    ADDRESS_MODE_REPEAT,
    MAX_TEXTURE_UNITS,
} from "@common/helpers/webgl2/constants"
import {HalPainterTarget} from "@common/models/hal/hal-painter-primitive"
import {checkForGlError} from "@common/helpers/webgl2/utils"
import {WebGl2Image} from "@common/models/webgl2/webgl2-image"
import {WebGl2ImageView} from "@common/models/webgl2/webgl2-image-view"
import {WebGl2ImagePhysical} from "@common/models/webgl2"
import {WebGl2ImageVirtual} from "@common/models/webgl2/webgl2-image-virtual"
import {WebGl2ImageAtlas} from "@common/models/webgl2/webgl2-image-atlas"

/** Set to true to log shader construction to the console (debugging aid). */
const TRACE = false

/**
 * First texture unit used for the page-table textures of virtual images: the page table of image slot i is bound to
 * unit ATLAS_PAGE_INDEX_TEXTURE_UNIT_OFFSET + i, i.e. directly after the MAX_TEXTURE_UNITS units used for the images.
 */
const ATLAS_PAGE_INDEX_TEXTURE_UNIT_OFFSET = MAX_TEXTURE_UNITS

/**
 * This class is shared between multiple WebGlLayerGeometry instances that use the same WebGlCanvas and shading-function.
 * The shading-function is a GLSL function with the signature "vec4 computeColor(vec2 worldPosition, vec2 uv, vec4 color)".
 * It can sample the bound images through per-unit helper functions, where N ranges from 0 to MAX_TEXTURE_UNITS - 1:
 * - "vec4 texelFetchN(ivec2 texelIndex[, int addressMode[, vec4 borderColor]])" reads a single texel; texelIndex is in
 *   pixels and the address mode defaults to ADDRESS_MODE_REPEAT.
 * - "vec4 texturePxN(vec2 texelIndex[, int addressMode[, vec4 borderColor]])" interpolates bilinearly at a pixel
 *   position (texel i has its center at i + 0.5); the address mode defaults to ADDRESS_MODE_REPEAT.
 * - "vec4 textureUvN(vec2 uv[, int addressMode])" is texturePxN with uv scaled by the image size; the address mode
 *   defaults to ADDRESS_MODE_CLAMP_TO_EDGE.
 * - texelFetchLodN, textureLodPxN and "float texelChannelFetchLodN(...)" take an explicit mip level.
 * The border color defaults to vec4(0) and is only used with ADDRESS_MODE_BORDER.
 */

export class WebGl2Shader {
    /**
     * Compiles and links the shader program and looks up its uniform locations.
     *
     * @param context - WebGL2 context the program is created in; used by every other method.
     * @param shadingFunction - GLSL source spliced into the fragment shader. It must define
     *   `vec4 computeColor(vec2 worldPosition, vec2 uv, vec4 color)` and may call the per-unit sampling helpers
     *   (`texelFetchN`, `texturePxN`, `textureUvN`, ...) generated below for N in 0..MAX_TEXTURE_UNITS-1.
     */
    constructor(
        readonly context: WebGl2Context,
        shadingFunction: string,
    ) {
        if (TRACE) {
            console.log("Creating WebGl2Shader")
        }
        const vertexSrc = `#version 300 es
            precision highp float;
            precision highp int;

            layout(location = ${this.LOC_POSITION}) in vec2 a_position;
            layout(location = ${this.LOC_UV}) in vec2 a_uv;
            layout(location = ${this.LOC_COLOR}) in vec4 a_color;

            uniform ivec2 u_targetSize;
            uniform mat3x2 u_worldTransform;
            uniform mat3x2 u_uvTransform;
            uniform mat3x2 u_viewTransform;

            out vec2 v_worldPosition;
            out vec2 v_uv;
            out vec4 v_color;

            void main() {
                v_worldPosition = u_worldTransform * vec3(a_position, 1);
                v_uv = u_uvTransform * vec3(a_uv, 1);
                v_color = a_color;
                gl_Position = vec4(u_viewTransform * vec3(v_worldPosition, 1), 0.0, 1.0);
            }
        `
        const textureFn = (index: number) => `
            vec4 texelFetchLodPhysical${index}(ivec2 texelIndex, int lod) {
                ivec2 shardTexel = shardTexelByTexelIndex(${index}, texelIndex);
                int shardIndex = shardIndexByTexelIndex(${index}, texelIndex);
                vec4 color = texelFetch(u_image[${index}], ivec3(shardTexel, shardIndex), lod);
                return fillMissingChannels(${index}, color);
            }

            float texelChannelFetchLodPhysical${index}(ivec2 texelIndex, int lod, int channel) {
                vec4 color = texelFetchLodPhysical${index}(texelIndex, lod);
                switch (channel) {
                    case 0: return color.r;
                    case 1: return color.g;
                    case 2: return color.b;
                    case 3: return color.a;
                    default: return 0.0;
                }
            }

            // TODO think about using layered channels for physical images as well
            // float texelChannelFetchLodPhysical${index}(ivec2 texelIndex, int lod, int channel) {
            //     ivec3 index = computePhysicalIndex(${index}, texelIndex, channel);
            //     float value = texelFetch(u_image[${index}], index, lod).r;
            //     return value;
            // }

            float texelChannelFetchLodVirtual${index}(ivec2 texelIndex, int lod, int channel) {
                const int VT_PAGE_SIZE_SHIFT = ${WebGl2ImageAtlas.pageSizeShift};
                const int VT_PAGE_MARGIN = ${WebGl2ImageAtlas.pageMargin};
                ivec2 lodTexelIndex = texelIndex >> lod;
                ivec2 pageIndex = lodTexelIndex >> VT_PAGE_SIZE_SHIFT;
                ivec2 pageTexelIndex = lodTexelIndex & ((1 << VT_PAGE_SIZE_SHIFT) - 1);
                uvec4 pageTableData = texelFetch(u_pageTableVt[${index}], ivec3(pageIndex, channel), lod);
                if (pageTableData.w != 0u) {
                    ivec2 atlasTexelIndex = ivec2(ivec2(pageTableData.xy) * ((1 << VT_PAGE_SIZE_SHIFT) + VT_PAGE_MARGIN) + pageTexelIndex);
                    return texelChannelFetchLodPhysical${index}(atlasTexelIndex, 0, 0);
                }
                return 0.0;
            }

            float texelChannelFetchLod${index}(ivec2 texelIndex, int lod, int channel, int addressMode, float borderColor) {
                if (isBorderTexel(${index}, texelIndex, addressMode)) {
                    return borderColor;
                }
                texelIndex = applyAddressMode(${index}, texelIndex, addressMode);
                texelIndex += u_imageOffset[${index}];
                int atlasIndex = u_atlasIndex[${index}];
                if (atlasIndex != -1) {
                    return texelChannelFetchLodVirtual${index}(texelIndex, lod, channel);
                }
                else {
                    return texelChannelFetchLodPhysical${index}(texelIndex, lod, channel);
                }
            }

            vec4 texelFetchLod${index}(ivec2 texelIndex, int lod, int addressMode, vec4 borderColor) {
                if (isBorderTexel(${index}, texelIndex, addressMode)) {
                    return borderColor;
                }
                texelIndex = applyAddressMode(${index}, texelIndex, addressMode);
                texelIndex += u_imageOffset[${index}];
                int atlasIndex = u_atlasIndex[${index}];
                if (atlasIndex != -1) {
                    int numChannels = u_numChannels[${index}];
                    float r = texelChannelFetchLodVirtual${index}(texelIndex, lod, 0);
                    float g = numChannels >= 2 ? texelChannelFetchLodVirtual${index}(texelIndex, lod, 1) : 0.0;
                    float b = numChannels >= 3 ? texelChannelFetchLodVirtual${index}(texelIndex, lod, 2) : 0.0;
                    float a = numChannels >= 4 ? texelChannelFetchLodVirtual${index}(texelIndex, lod, 3) : 1.0;
                    vec4 color = vec4(r, g, b, a);
                    return fillMissingChannels(${index}, color);
                }
                else {
                    return texelFetchLodPhysical${index}(texelIndex, lod);
                }
            }

            vec4 texelFetchLod${index}(ivec2 texelIndex, int lod, int addressMode) {
                return texelFetchLod${index}(texelIndex, lod, addressMode, vec4(0, 0, 0, 0));
            }

            vec4 texelFetch${index}(ivec2 texelIndex, int addressMode, vec4 borderColor) {
                return texelFetchLod${index}(texelIndex, 0, addressMode, borderColor);
            }

            vec4 texelFetch${index}(ivec2 texelIndex, int addressMode) {
                return texelFetch${index}(texelIndex, addressMode, vec4(0, 0, 0, 0));
            }

            vec4 texelFetch${index}(ivec2 texelIndex) {
                return texelFetch${index}(texelIndex, ADDRESS_MODE_REPEAT);
            }

            vec4 textureLodPx${index}(vec2 texelIndex, int lod, int addressMode, vec4 borderColor) {
                texelIndex -= 0.5;  // we offset by 0.5 to sample at the center of the texel instead of the corner
                ivec2 texelIndexI = ivec2(floor(texelIndex));
                vec2 texelIndexF = texelIndex - vec2(texelIndexI);
                vec4 texColor00 = texelFetchLod${index}(texelIndexI + ivec2(0, 0), lod, addressMode, borderColor);
                vec4 texColor10 = texelFetchLod${index}(texelIndexI + ivec2(1, 0), lod, addressMode, borderColor);
                vec4 texColor01 = texelFetchLod${index}(texelIndexI + ivec2(0, 1), lod, addressMode, borderColor);
                vec4 texColor11 = texelFetchLod${index}(texelIndexI + ivec2(1, 1), lod, addressMode, borderColor);
                vec4 texColor = mix(mix(texColor00, texColor10, texelIndexF.x), mix(texColor01, texColor11, texelIndexF.x), texelIndexF.y);
                return texColor;
            }

            vec4 textureLodPx${index}(vec2 texelIndex, int lod, int addressMode) {
                return textureLodPx${index}(texelIndex, lod, addressMode, vec4(0, 0, 0, 0));
            }

            vec4 texturePx${index}(vec2 texelIndex, int addressMode, vec4 borderColor) {
                return textureLodPx${index}(texelIndex, 0, addressMode, borderColor);
            }

            vec4 texturePx${index}(vec2 texelIndex, int addressMode) {
                return texturePx${index}(texelIndex, addressMode, vec4(0, 0, 0, 0));
            }

            vec4 texturePx${index}(vec2 texelIndex) {
                return texturePx${index}(texelIndex, ADDRESS_MODE_REPEAT);
            }

            vec4 textureUv${index}(vec2 uv, int addressMode) {
                vec2 texelIndex = uv * vec2(u_imageSize[${index}]);
                return texturePx${index}(texelIndex, addressMode);
            }

            vec4 textureUv${index}(vec2 uv) {
                return textureUv${index}(uv, ADDRESS_MODE_CLAMP_TO_EDGE);
            }
        `
        // One set of sampling helpers per texture unit (see the NOTE in the fragment shader for why they cannot be
        // shared). Deriving the list from MAX_TEXTURE_UNITS keeps it in sync with the uniform arrays declared below.
        const textureFns = Array.from({length: MAX_TEXTURE_UNITS}, (_, i) => textureFn(i)).join("\n")
        const fragmentSrc = `#version 300 es
            precision highp float;
            precision highp int;
            precision highp sampler2DArray;
            precision highp usampler2DArray;

            layout(location = 0) out vec4 color;

            uniform ivec2 u_targetSize;
            uniform sampler2DArray u_image[${MAX_TEXTURE_UNITS}];
            uniform ivec2 u_imageOffset[${MAX_TEXTURE_UNITS}];
            uniform ivec2 u_imageSize[${MAX_TEXTURE_UNITS}];
            uniform ivec2 u_shardSize[${MAX_TEXTURE_UNITS}];
            uniform ivec2 u_numShards[${MAX_TEXTURE_UNITS}];
            uniform int u_numChannels[${MAX_TEXTURE_UNITS}];
            uniform int u_atlasIndex[${MAX_TEXTURE_UNITS}];
            uniform usampler2DArray u_pageTableVt[${MAX_TEXTURE_UNITS}];
            uniform vec4 u_modulationColor;
            uniform int u_clearAlpha;

            in vec2 v_worldPosition;
            in vec2 v_uv;
            in vec4 v_color;

            const int ADDRESS_MODE_CLAMP_TO_EDGE = ${ADDRESS_MODE_CLAMP_TO_EDGE};
            const int ADDRESS_MODE_REPEAT = ${ADDRESS_MODE_REPEAT};
            const int ADDRESS_MODE_MIRRORED_REPEAT = ${ADDRESS_MODE_MIRRORED_REPEAT};
            const int ADDRESS_MODE_BORDER = ${ADDRESS_MODE_BORDER};

            float wrapFloat(float value, float maxValue) {
                value = mod(value, maxValue);
                return value >= 0.0 ? value : value + maxValue;
            }

            vec2 wrapTexelIndex(int index, vec2 texelIndex) {
                vec2 imageSize = vec2(u_imageSize[index]);
                texelIndex.x = wrapFloat(texelIndex.x, imageSize.x);
                texelIndex.y = wrapFloat(texelIndex.y, imageSize.y);
                return texelIndex;
            }

            int wrapInt(int value, int maxValue) {
                value = value % maxValue;
                return value >= 0 ? value : value + maxValue;
            }

            ivec2 wrapTexelIndex(int index, ivec2 texelIndex) {
                ivec2 imageSize = u_imageSize[index];
                texelIndex.x = wrapInt(texelIndex.x, imageSize.x);
                texelIndex.y = wrapInt(texelIndex.y, imageSize.y);
                return texelIndex;
            }

            // vec2 uvByTexelIndex(int index, ivec2 texelIndex) {
            //     ivec2 shardTexelIndex = texelIndex % ivec2(u_shardSize[index]);
            //     return (vec2(shardTexelIndex) + 0.5) / vec2(u_shardSize[index]);
            // }

            ivec2 shardTexelByTexelIndex(int index, ivec2 texelIndex) {
                ivec2 shardTexelIndex = texelIndex % u_shardSize[index];
                return shardTexelIndex;
            }

            int shardIndexByTexelIndex(int index, ivec2 texelIndex) {
                ivec2 shardIndex = texelIndex / u_shardSize[index];
                return shardIndex.y * u_numShards[index].x + shardIndex.x;
            }

            ivec3 computePhysicalIndex(int index, ivec2 texelIndex, int channel) {
                ivec2 shardTexel = shardTexelByTexelIndex(index, texelIndex);
                int shardIndex = shardIndexByTexelIndex(index, texelIndex);
                int layer = shardIndex * u_numChannels[index] + channel;
                return ivec3(shardTexel, layer);
            }

            bool isBorderTexel(int index, vec2 texelIndex, int addressMode) {
                switch (addressMode) {
                    case ADDRESS_MODE_BORDER:
                        return texelIndex.x < 0.0
                            || texelIndex.y < 0.0
                            || texelIndex.x >= float(u_imageSize[index].x)
                            || texelIndex.y >= float(u_imageSize[index].y);
                    default:
                        return false;
                }
            }

            bool isBorderTexel(int index, ivec2 texelIndex, int addressMode) {
                switch (addressMode) {
                    case ADDRESS_MODE_BORDER:
                        return texelIndex.x < 0
                            || texelIndex.y < 0
                            || texelIndex.x >= u_imageSize[index].x
                            || texelIndex.y >= u_imageSize[index].y;
                    default:
                        return false;
                }
            }

            bool isBorderTexel(int index, ivec2 texelIndex) {
                return isBorderTexel(index, texelIndex, ADDRESS_MODE_BORDER);
            }

            vec2 applyAddressMode(int index, vec2 texelIndex, int addressMode) {
                ivec2 imageSize = u_imageSize[index];
                switch (addressMode) {
                    case ADDRESS_MODE_CLAMP_TO_EDGE:
                    case ADDRESS_MODE_BORDER:
                        return clamp(texelIndex, vec2(0), vec2(imageSize - 1));
                    case ADDRESS_MODE_MIRRORED_REPEAT:
                        if (texelIndex.x < 0.0) texelIndex.x += 1.0;
                        if (texelIndex.y < 0.0) texelIndex.y += 1.0;
                        texelIndex = abs(texelIndex);
                        texelIndex = vec2(mod(vec2(texelIndex), vec2(imageSize) * 2.0));
                        if (texelIndex.x >= float(imageSize.x))
                            texelIndex.x = float(imageSize.x) * 2.0 - 1.0 - texelIndex.x;
                        if (texelIndex.y >= float(imageSize.y))
                            texelIndex.y = float(imageSize.y) * 2.0 - 1.0 - texelIndex.y;
                        return texelIndex;
                    case ADDRESS_MODE_REPEAT:
                    default:
                        return wrapTexelIndex(index, texelIndex);
                }
            }

            ivec2 applyAddressMode(int index, ivec2 texelIndex, int addressMode) {
                ivec2 imageSize = u_imageSize[index];
                switch (addressMode) {
                    case ADDRESS_MODE_CLAMP_TO_EDGE:
                    case ADDRESS_MODE_BORDER:
                        return clamp(texelIndex, ivec2(0), imageSize - 1);
                    case ADDRESS_MODE_MIRRORED_REPEAT:
                        if (texelIndex.x < 0) texelIndex.x++;
                        if (texelIndex.y < 0) texelIndex.y++;
                        texelIndex = abs(texelIndex);
                        texelIndex = ivec2(mod(vec2(texelIndex), vec2(imageSize) * 2.0));
                        if (texelIndex.x >= imageSize.x)
                            texelIndex.x = imageSize.x * 2 - 1 - texelIndex.x;
                        if (texelIndex.y >= imageSize.y)
                            texelIndex.y = imageSize.y * 2 - 1 - texelIndex.y;
                        return texelIndex;
                    case ADDRESS_MODE_REPEAT:
                    default:
                        return wrapTexelIndex(index, texelIndex);
                }
            }

            vec4 fillMissingChannels(int index, vec4 color) {
                switch (u_numChannels[index]) {
                    case 1:
                        return vec4(color.r, color.r, color.r, 1);
                    case 2:
                        return vec4(color.r, color.g, 0, 1);
                    case 3:
                        return vec4(color.r, color.g, color.b, 1);
                    case 4:
                    default:
                        return color;
                }
            }

            // NOTE: We provide separate functions for each texture unit, because if we make the index a parameter we will get "ERROR: array index for samplers must be constant integral expressions" for "texture(u_image[index], ...)" even when it is a constant from the caller's point of view. :(
            ${textureFns}

            ${shadingFunction}

            void main() {
                vec4 computedColor = computeColor(v_worldPosition, v_uv, v_color);
                color = computedColor * u_modulationColor;
                if (u_clearAlpha != 0) {
                    color.a = 1.0;
                }
            }
        `

        const gl = this.context.gl
        this.program = WebGl2ShaderUtils.compileAndLinkProgram(gl, vertexSrc, fragmentSrc)
        // Required uniforms: looked up with the non-optional helper, which presumably throws if they are missing.
        this.locWorldTransform = WebGl2ShaderUtils.getUniformLocation(gl, this.program, "u_worldTransform")
        this.locUVTransform = WebGl2ShaderUtils.getUniformLocation(gl, this.program, "u_uvTransform")
        this.locViewTransform = WebGl2ShaderUtils.getUniformLocation(gl, this.program, "u_viewTransform")
        // Optional uniforms: they may have no location, e.g. when the shading function never samples an image.
        this.locTargetSize = this.getOptionalUniformLocation("u_targetSize")
        this.locImage = this.getOptionalUniformLocation("u_image")
        this.locImageOffset = this.getOptionalUniformLocation("u_imageOffset")
        this.locImageSize = this.getOptionalUniformLocation("u_imageSize")
        this.locNumChannels = this.getOptionalUniformLocation("u_numChannels")
        this.locShardSize = this.getOptionalUniformLocation("u_shardSize")
        this.locNumShards = this.getOptionalUniformLocation("u_numShards")
        this.locAtlasIndex = this.getOptionalUniformLocation("u_atlasIndex")
        this.locPageTableVt = this.getOptionalUniformLocation("u_pageTableVt")
        this.locModulationColor = WebGl2ShaderUtils.getUniformLocation(gl, this.program, "u_modulationColor")
        this.locClearAlpha = WebGl2ShaderUtils.getUniformLocation(gl, this.program, "u_clearAlpha")
    }

    // WebGlEntity
    /** Deletes the GL program. The shader must not be used afterwards. */
    dispose(): void {
        this.context.gl.deleteProgram(this.program)
        this.optionalUniformLocationCache.clear()
    }

    /** Returns the location of a uniform that must exist in the program (delegates to WebGl2ShaderUtils). */
    getUniformLocation(uniformName: string): WebGLUniformLocation {
        return WebGl2ShaderUtils.getUniformLocation(this.context.gl, this.program, uniformName)
    }

    /**
     * Returns the location of a uniform, or null if the program has none under that name.
     * Results, including misses, are memoized: locations do not change after linking. This keeps
     * per-draw parameter uploads from re-querying GL (and from throwing/catching) every time.
     */
    getOptionalUniformLocation(uniformName: string): WebGLUniformLocation | null {
        const cached = this.optionalUniformLocationCache.get(uniformName)
        if (cached !== undefined) {
            return cached
        }
        let location: WebGLUniformLocation | null
        try {
            location = WebGl2ShaderUtils.getUniformLocation(this.context.gl, this.program, uniformName)
            // eslint-disable-next-line unused-imports/no-unused-vars
        } catch (_error: unknown) {
            // A missing uniform is an expected outcome for this lookup, not an error.
            location = null
        }
        this.optionalUniformLocationCache.set(uniformName, location)
        return location
    }

    /**
     * Uploads the per-draw uniforms. The program must already be in use (see setProgramAndData).
     *
     * @param target - Render target; provides u_targetSize and whether alpha is forced to 1 (u_clearAlpha).
     * @param worldTransform - Maps vertex positions to world space (u_worldTransform).
     * @param uvTransform - Applied to vertex UVs (u_uvTransform).
     * @param images - Image per texture slot; slot i feeds the GLSL helpers with suffix i. Entries past
     *   MAX_TEXTURE_UNITS are ignored, and missing/undefined slots upload neutral values.
     * @param modulationColor - Multiplied with the shading function's output (u_modulationColor).
     * @param parameters - Optional extra uniforms consumed by the shading function.
     */
    setUniforms(
        target: HalPainterTarget,
        worldTransform: Matrix3x2,
        uvTransform: Matrix3x2,
        images: (WebGl2Image | undefined)[],
        modulationColor: [r: number, g: number, b: number, a: number],
        parameters?: HalPainterParametersType,
    ) {
        // Normalize to exactly MAX_TEXTURE_UNITS slots so every element of each uniform array is rewritten. Uniform
        // values persist per program, so slots beyond images.length would otherwise keep values from an earlier draw.
        const slots = Array.from({length: MAX_TEXTURE_UNITS}, (_, i) => images[i])
        const targetSize = [target.width, target.height]
        const imageOffsets = slots.flatMap((image) => image?.descriptor.textureOffset.toArray() ?? [0, 0])
        const imageSizes = slots.flatMap((image) => (image ? [image.descriptor.width, image.descriptor.height] : [0, 0]))
        const imageShardSizes = slots.flatMap((image) => image?.descriptor.shardSize.toArray() ?? [0, 0])
        const imageNumShards = slots.flatMap((image) => image?.descriptor.numShards.toArray() ?? [0, 0])
        const imageNumChannels = slots.map((image) => (image ? WebGl2Shader.numChannelsByChannelLayout(image.descriptor.channelLayout) : 0))
        // -1 selects the physical sampling path in the shader; any other value selects the virtual (atlas) path.
        const imageAtlasIndices = slots.map((image) => (image?.descriptor.atlas ? 0 : -1))

        const gl = this.context.gl
        gl.uniformMatrix3x2fv(this.locWorldTransform, false, worldTransform.toArray())
        gl.uniformMatrix3x2fv(this.locUVTransform, false, uvTransform.toArray())
        if (this.locTargetSize) {
            gl.uniform2iv(this.locTargetSize, targetSize)
            checkForGlError(gl)
        }
        if (this.locImage) {
            // Sampler uniforms hold texture-unit indices; these must match the units bound in setProgramAndData.
            gl.uniform1iv(this.locImage, WebGl2Shader.imageTextureUnits)
            checkForGlError(gl)
        }
        if (this.locImageOffset) {
            gl.uniform2iv(this.locImageOffset, imageOffsets)
            checkForGlError(gl)
        }
        if (this.locImageSize) {
            gl.uniform2iv(this.locImageSize, imageSizes)
            checkForGlError(gl)
        }
        if (this.locNumChannels) {
            gl.uniform1iv(this.locNumChannels, imageNumChannels)
            checkForGlError(gl)
        }
        if (this.locShardSize) {
            gl.uniform2iv(this.locShardSize, imageShardSizes)
            checkForGlError(gl)
        }
        if (this.locNumShards) {
            gl.uniform2iv(this.locNumShards, imageNumShards)
            checkForGlError(gl)
        }
        if (this.locAtlasIndex) {
            gl.uniform1iv(this.locAtlasIndex, imageAtlasIndices)
            checkForGlError(gl)
        }
        if (this.locPageTableVt) {
            gl.uniform1iv(this.locPageTableVt, WebGl2Shader.pageTableTextureUnits)
            checkForGlError(gl)
        }
        gl.uniform4fv(this.locModulationColor, modulationColor)
        checkForGlError(gl)
        gl.uniform1i(this.locClearAlpha, target.forceAlphaToOne ? 1 : 0)
        checkForGlError(gl)
        if (parameters) {
            this.setParameters(parameters)
        }
    }

    /** Uploads the transform applied after the world transform to produce clip-space positions (u_viewTransform). */
    setViewTransform(viewTransform: Matrix3x2Like) {
        const gl = this.context.gl
        gl.uniformMatrix3x2fv(this.locViewTransform, false, Matrix3x2.fromMatrix3x2Like(viewTransform).toArray())
    }

    /**
     * Makes this program current and binds the textures of up to MAX_TEXTURE_UNITS images.
     *
     * Slot i uses texture unit i for the image data. For a virtual image, unit i gets the atlas texture and unit
     * ATLAS_PAGE_INDEX_TEXTURE_UNIT_OFFSET + i gets its page table. Image views are unwrapped to their underlying image.
     *
     * @throws Error if a virtual image has no atlas, or an image is of an unsupported type.
     */
    setProgramAndData(images: (WebGl2Image | undefined)[]) {
        const gl = this.context.gl
        gl.useProgram(this.program)
        for (let i = 0; i < MAX_TEXTURE_UNITS; i++) {
            const pageTableUnit = ATLAS_PAGE_INDEX_TEXTURE_UNIT_OFFSET + i
            let image = images[i]
            if (image instanceof WebGl2ImageView) {
                image = image.image
            }
            if (!image) {
                this.unbindTextureArray(i)
                this.unbindTextureArray(pageTableUnit)
            } else if (image instanceof WebGl2ImagePhysical) {
                this.bindTextureArray(i, image.descriptor.texture)
                this.unbindTextureArray(pageTableUnit)
            } else if (image instanceof WebGl2ImageVirtual) {
                if (!image.descriptor.atlas) {
                    throw new Error("Atlas missing in virtual image")
                }

                image.registerCurrentPageUsage() // mark all current pages as used TODO not all pages might actually be read so this is a very conservative approach - revise !

                this.bindTextureArray(i, image.descriptor.atlas.image.descriptor.texture)
                this.bindTextureArray(pageTableUnit, image.descriptor.texture)
            } else {
                throw new Error("Invalid image type")
            }
        }
    }

    /** Unbinds every texture unit that setProgramAndData may have bound, including the page-table units. */
    unsetProgramAndData() {
        for (let i = 0; i < MAX_TEXTURE_UNITS; i++) {
            // Page-table unit first, so the image unit is the last one made active.
            this.unbindTextureArray(ATLAS_PAGE_INDEX_TEXTURE_UNIT_OFFSET + i)
            this.unbindTextureArray(i)
        }
    }

    /**
     * Binds `texture` as TEXTURE_2D_ARRAY on texture unit `unit` (leaving that unit active).
     * Sets nearest filtering and clamp-to-edge wrapping. The shader reads via texelFetch and applies address modes
     * and interpolation itself.
     */
    private bindTextureArray(unit: number, texture: WebGLTexture | null) {
        const gl = this.context.gl
        gl.activeTexture(gl.TEXTURE0 + unit)
        gl.bindTexture(gl.TEXTURE_2D_ARRAY, texture)
        gl.texParameteri(gl.TEXTURE_2D_ARRAY, gl.TEXTURE_MAG_FILTER, gl.NEAREST)
        gl.texParameteri(gl.TEXTURE_2D_ARRAY, gl.TEXTURE_MIN_FILTER, gl.NEAREST_MIPMAP_NEAREST) // float textures require this filtering to be NEAREST
        gl.texParameteri(gl.TEXTURE_2D_ARRAY, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE)
        gl.texParameteri(gl.TEXTURE_2D_ARRAY, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE)
    }

    /** Clears the TEXTURE_2D_ARRAY binding of texture unit `unit` (leaving that unit active). */
    private unbindTextureArray(unit: number) {
        const gl = this.context.gl
        gl.activeTexture(gl.TEXTURE0 + unit)
        gl.bindTexture(gl.TEXTURE_2D_ARRAY, null)
    }

    /** Number of channels stored for a channel layout; uploaded as u_numChannels and used by fillMissingChannels. */
    private static numChannelsByChannelLayout(channelLayout: HalImageChannelLayout) {
        switch (channelLayout) {
            case "RGBA":
                return 4
            case "RGB":
                return 3
            case "R":
                return 1
            default:
                assertNever(channelLayout)
        }
    }

    /**
     * Uploads user parameters as uniforms of the same name.
     * A parameter whose uniform is missing is skipped, with a warning unless it is marked optional.
     */
    private setParameters(parameters: HalPainterParametersType) {
        const gl = this.context.gl
        for (const [name, parameterValue] of Object.entries(parameters)) {
            const location = this.getOptionalUniformLocation(name)
            if (!location) {
                if (!parameterValue.isOptional) {
                    console.warn("Parameter not found: " + name)
                }
                continue
            }
            switch (parameterValue.type) {
                case "float": {
                    gl.uniform1f(location, parameterValue.value)
                    break
                }
                case "float2": {
                    gl.uniform2f(location, parameterValue.value.x, parameterValue.value.y)
                    break
                }
                case "float3": {
                    gl.uniform3f(location, parameterValue.value.x, parameterValue.value.y, parameterValue.value.z)
                    break
                }
                case "float4": {
                    gl.uniform4f(location, parameterValue.value.x, parameterValue.value.y, parameterValue.value.z, parameterValue.value.w)
                    break
                }
                case "int": {
                    gl.uniform1i(location, parameterValue.value)
                    break
                }
                case "int2": {
                    gl.uniform2i(location, parameterValue.value.x, parameterValue.value.y)
                    break
                }
                case "int3": {
                    gl.uniform3i(location, parameterValue.value.x, parameterValue.value.y, parameterValue.value.z)
                    break
                }
                case "int4": {
                    gl.uniform4i(location, parameterValue.value.x, parameterValue.value.y, parameterValue.value.z, parameterValue.value.w)
                    break
                }
                case "uint": {
                    gl.uniform1ui(location, parameterValue.value)
                    break
                }
                case "uint2": {
                    gl.uniform2ui(location, parameterValue.value.x, parameterValue.value.y)
                    break
                }
                case "uint3": {
                    gl.uniform3ui(location, parameterValue.value.x, parameterValue.value.y, parameterValue.value.z)
                    break
                }
                case "uint4": {
                    gl.uniform4ui(location, parameterValue.value.x, parameterValue.value.y, parameterValue.value.z, parameterValue.value.w)
                    break
                }
                case "float[]": {
                    gl.uniform1fv(location, parameterValue.value)
                    break
                }
                case "int[]": {
                    gl.uniform1iv(location, parameterValue.value)
                    break
                }
                case "uint[]": {
                    gl.uniform1uiv(location, parameterValue.value)
                    break
                }
                case "float2[]": {
                    gl.uniform2fv(
                        location,
                        parameterValue.value.flatMap((v) => [v.x, v.y]),
                    )
                    break
                }
                case "int2[]": {
                    gl.uniform2iv(
                        location,
                        parameterValue.value.flatMap((v) => [v.x, v.y]),
                    )
                    break
                }
                case "uint2[]": {
                    gl.uniform2uiv(
                        location,
                        parameterValue.value.flatMap((v) => [v.x, v.y]),
                    )
                    break
                }
                case "float3[]": {
                    gl.uniform3fv(
                        location,
                        parameterValue.value.flatMap((v) => [v.x, v.y, v.z]),
                    )
                    break
                }
                case "int3[]": {
                    gl.uniform3iv(
                        location,
                        parameterValue.value.flatMap((v) => [v.x, v.y, v.z]),
                    )
                    break
                }
                case "uint3[]": {
                    gl.uniform3uiv(
                        location,
                        parameterValue.value.flatMap((v) => [v.x, v.y, v.z]),
                    )
                    break
                }
                case "float4[]": {
                    gl.uniform4fv(
                        location,
                        parameterValue.value.flatMap((v) => [v.x, v.y, v.z, v.w]),
                    )
                    break
                }
                case "int4[]": {
                    gl.uniform4iv(
                        location,
                        parameterValue.value.flatMap((v) => [v.x, v.y, v.z, v.w]),
                    )
                    break
                }
                case "uint4[]": {
                    gl.uniform4uiv(
                        location,
                        parameterValue.value.flatMap((v) => [v.x, v.y, v.z, v.w]),
                    )
                    break
                }
                case "float3x2": {
                    const matrixValues = Matrix3x2.fromMatrix3x2Like(parameterValue.value).toArray()
                    gl.uniformMatrix3x2fv(location, false, matrixValues)
                    break
                }
                case "float4x4": {
                    const matrixValues = Matrix4.fromMatrix4Like(parameterValue.value).toArray()
                    gl.uniformMatrix4fv(location, false, matrixValues)
                    break
                }
                default:
                    assertNever(parameterValue)
                    break
            }
            checkForGlError(gl, `Setting parameter ${name}`)
        }
    }

    /** Vertex attribute locations; interpolated into the vertex shader's layout(location = ...) qualifiers. */
    readonly LOC_POSITION = 0
    readonly LOC_UV = 1
    readonly LOC_COLOR = 2

    /** Texture unit assigned to each u_image[i] sampler (unit i), matching the binding in setProgramAndData. */
    private static readonly imageTextureUnits = Array.from({length: MAX_TEXTURE_UNITS}, (_, i) => i)
    /** Texture unit assigned to each u_pageTableVt[i] sampler, matching the binding in setProgramAndData. */
    private static readonly pageTableTextureUnits = Array.from({length: MAX_TEXTURE_UNITS}, (_, i) => ATLAS_PAGE_INDEX_TEXTURE_UNIT_OFFSET + i)

    private program: WebGLProgram
    private locWorldTransform: WebGLUniformLocation
    private locUVTransform: WebGLUniformLocation
    private locViewTransform: WebGLUniformLocation
    private locTargetSize: WebGLUniformLocation | null
    private locImage: WebGLUniformLocation | null
    private locImageOffset: WebGLUniformLocation | null
    private locImageSize: WebGLUniformLocation | null
    private locShardSize: WebGLUniformLocation | null
    private locNumShards: WebGLUniformLocation | null
    private locNumChannels: WebGLUniformLocation | null
    private locAtlasIndex: WebGLUniformLocation | null
    private locPageTableVt: WebGLUniformLocation | null
    private locModulationColor: WebGLUniformLocation
    private locClearAlpha: WebGLUniformLocation
    /** Memoized results of getOptionalUniformLocation; null records a uniform the program does not have. */
    private readonly optionalUniformLocationCache = new Map<string, WebGLUniformLocation | null>()
}
