import {HalImageChannelLayout, HalImageDataType, HalImageDescriptor, HalImageSourceData, TypedArrayImage} from "@common/models/hal/hal-image/types"
import {WebGl2Context} from "@common/models/webgl2/webgl2-context"
import {WebGl2ImageDescriptor} from "@common/models/webgl2/webgl2-image"
import {Box2, Box2Like, ColorLike, Matrix3x2, Size2, Size2Like, Vector2, Vector2Like} from "@cm/math"
import {
    completeHalImageOptions,
    extractSubregionImageData,
    extractSubregionTypedArray,
    getGlFormat,
    getGlType,
    getInternalFormat,
    getMipLevelSize,
} from "@app/common/models/webgl2/webgl2-image-utils"
import {checkForGlError} from "@common/helpers/webgl2/utils"
import {assertNever} from "@cm/utils"
import {createTypedArrayImage, getDataTypeFromTypedArray, getNumChannels, isNativeImageData, isTypedArrayImage} from "@common/models/hal/hal-image/utils"
import {DrawArgs} from "@common/models/hal/hal-paintable"
import {TextureEditorSettings} from "@app/textures/texture-editor/texture-editor-settings"
import {HalImagePhysical} from "@common/models/hal/hal-image"
import {clearPaintable} from "@common/helpers/webgl2/webgl2-painter-utils"

// Enables console tracing of texture creation/disposal and upload timings; follows the experimental-features flag.
const {EnableExperimental: TRACE} = TextureEditorSettings

/**
 * Physical (GPU-resident) HAL image backed by a single WebGL2 `TEXTURE_2D_ARRAY`.
 *
 * The image is split into a grid of equally sized "shards" so that it can exceed the device's maximum texture size.
 * Each shard is stored as one layer of the array texture and gets its own framebuffer so it can be rendered to.
 * Shard `i` covers the image region starting at
 * `((i % numShards.x) * shardSize.width, floor(i / numShards.x) * shardSize.height)`.
 */
export class WebGl2ImagePhysical implements HalImagePhysical {
    readonly isHalImage = true
    readonly isHalImagePhysical = true

    /** Set when the requested channel layout is RGB, which is stored internally as RGBA (see `createTexture`). */
    readonly forceAlphaToOne: boolean

    /**
     * Allocates the texture storage and one framebuffer per shard.
     *
     * @param context WebGL2 context that owns the texture.
     * @param descriptor Requested size, channel layout, data type and options.
     * @throws OutOfMemoryError if the GPU reports out-of-memory while allocating the texture storage.
     */
    constructor(
        readonly context: WebGl2Context,
        descriptor: HalImageDescriptor,
    ) {
        this.forceAlphaToOne = descriptor.channelLayout === "RGB"
        const {descriptor: webGl2Descriptor, internalChannelLayout, internalDataType} = this.createTexture(descriptor)
        this._descriptor = webGl2Descriptor
        this._internalChannelLayout = internalChannelLayout
        this._internalDataType = internalDataType
        try {
            this._frameBuffers = this.createFrameBuffers(this._descriptor)
        } catch (e) {
            // don't leak the texture if the framebuffers could not be created
            this.context.gl.deleteTexture(this._descriptor.texture)
            throw e
        }
    }

    // HalEntity
    /** Releases the per-shard framebuffers and the backing texture. */
    dispose(): void {
        if (TRACE) {
            console.log(
                `Disposing physical WebGL texture with ${this._descriptor.numShards.x}x${this._descriptor.numShards.y} shards of size ${this._descriptor.shardSize.width}x${this._descriptor.shardSize.height} containing ${this._descriptor.numMipLevels} mipmap levels for image of size ${this._descriptor.width}x${this._descriptor.height}.`,
            )
        }
        const gl = this.context.gl
        this._frameBuffers.forEach((frameBuffer) => gl.deleteFramebuffer(frameBuffer))
        gl.deleteTexture(this._descriptor.texture)
    }

    // HalImage
    get descriptor(): WebGl2ImageDescriptor {
        return this._descriptor
    }

    // HalPaintable
    /** @returns The number of draw passes, one per shard. */
    beginDraw(_args: DrawArgs): number {
        return this._descriptor.numShards.x * this._descriptor.numShards.y
    }

    // HalPaintable
    /**
     * Binds the framebuffer of shard `pass` and restricts the viewport to the part of `args.region` inside that shard.
     *
     * @returns Transform for this pass, combining origin translation, normalization to [-1, 1] and the shard offset.
     */
    beginDrawPass(args: DrawArgs, pass: number): Matrix3x2 {
        if (!this._frameBuffers) {
            throw Error("Frame buffers not initialized")
        }
        const gl = this.context.gl
        const renderTargetShardIndex = pass

        const shardRegion = new Box2(
            (renderTargetShardIndex % this._descriptor.numShards.x) * this._descriptor.shardSize.width,
            Math.floor(renderTargetShardIndex / this._descriptor.numShards.x) * this._descriptor.shardSize.height,
            this._descriptor.shardSize.width,
            this._descriptor.shardSize.height,
        )
        const regionIntersection = Box2.intersect(shardRegion, args.region)
        gl.bindFramebuffer(gl.FRAMEBUFFER, this._frameBuffers[renderTargetShardIndex])
        // the framebuffer only exposes this shard's layer, so the viewport is shard-local
        gl.viewport(regionIntersection.x - shardRegion.x, regionIntersection.y - shardRegion.y, regionIntersection.width, regionIntersection.height)
        const transform = new Matrix3x2([1, 0, 0, 1, -1, -1]) // origin at top left
        transform.append(new Matrix3x2([2 / regionIntersection.width, 0, 0, 2 / regionIntersection.height, 0, 0])) // normalize to [-1, 1]
        transform.append(new Matrix3x2([1, 0, 0, 1, -Math.max(0, shardRegion.x - args.region.x), -Math.max(0, shardRegion.y - args.region.y)])) // offset to shard
        return transform
    }

    // HalPaintable
    endDrawPass(_args: DrawArgs, _pass: number): void {}

    // HalPaintable
    endDraw(_args: DrawArgs): void {
        const gl = this.context.gl
        gl.bindFramebuffer(gl.FRAMEBUFFER, null) // unbind framebuffer to avoid potential subsequent feedback framebuffer operation
    }

    // HalPaintable
    get width(): number {
        return this._descriptor.width
    }

    // HalPaintable
    get height(): number {
        return this._descriptor.height
    }

    // HalPaintable
    clear(color?: ColorLike, mipLevel?: number) {
        clearPaintable(this, color, mipLevel)
    }

    // HalImage
    get numMipLevels(): number {
        return this._descriptor.numMipLevels
    }

    // HalImage
    /** @throws Error if `mipLevel` is outside `[0, numMipLevels)`. */
    getMipLevelSize(mipLevel: number): Size2Like {
        if (mipLevel < 0 || mipLevel >= this._descriptor.numMipLevels) {
            throw Error("Invalid mip level.")
        }
        return getMipLevelSize(this._descriptor, mipLevel)
    }

    // HalImage
    /**
     * Reads back a region of the image as 32-bit floats.
     *
     * The result is tightly packed row by row, with `getNumChannels(descriptor.channelLayout)` channels per pixel.
     * The direct path uses `readPixels(RGBA, FLOAT)`. Images that are not stored as float32 RGBA are converted first:
     * each shard region is blitted into an intermediate float32 RGBA image and read back from there.
     *
     * @param region Region in image pixel coordinates; defaults to the full image.
     * @throws Error if the region lies (partly) outside the image.
     */
    readImageDataFloat(region?: Box2Like): Float32Array {
        region ??= {x: 0, y: 0, width: this._descriptor.width, height: this._descriptor.height}
        if (
            region.x < 0 ||
            region.y < 0 ||
            region.width < 0 ||
            region.height < 0 ||
            region.x + region.width > this._descriptor.width ||
            region.y + region.height > this._descriptor.height
        ) {
            throw Error("Invalid region.")
        }
        const numChannels = getNumChannels(this._descriptor.channelLayout)
        const imageData = new Float32Array(region.width * region.height * numChannels)
        if (region.width === 0 || region.height === 0) {
            return imageData // nothing to read
        }

        const gl = this.context.gl
        // conversion target (at most one shard in size) for anything that readPixels(RGBA, FLOAT) can't read directly
        const sourceShardImage =
            this._descriptor.dataType !== "float32" || this._descriptor.channelLayout !== "RGBA"
                ? this.context.requestSynchronousBufferImage({
                      width: Math.min(region.width, this._descriptor.shardSize.width),
                      height: Math.min(region.height, this._descriptor.shardSize.height),
                      channelLayout: "RGBA",
                      dataType: "float32",
                  })
                : null

        const shardSize = Vector2.fromSize2Like(this._descriptor.shardSize)
        const minShard = Vector2.fromVector2Like(region).divInPlace(shardSize).floorInPlace()
        const maxShard = Vector2.fromVector2Like(region).addInPlace(Vector2.fromSize2Like(region)).divInPlace(shardSize).ceilInPlace()
        const isMultiShard = maxShard.x - minShard.x > 1 || maxShard.y - minShard.y > 1
        // a single-shard float32 RGBA read can go straight into the result; everything else is copied per shard
        const needsShardRegionData = isMultiShard || !!sourceShardImage
        const scratchData =
            !sourceShardImage && needsShardRegionData
                ? new Float32Array(Math.min(shardSize.x, region.width) * Math.min(shardSize.y, region.height) * 4)
                : null

        gl.bindTexture(gl.TEXTURE_2D_ARRAY, this._descriptor.texture)
        gl.pixelStorei(gl.PACK_ALIGNMENT, 1) // make sure to tightly pack the data
        const readFbo = sourceShardImage ? null : gl.createFramebuffer()
        if (!sourceShardImage && !readFbo) {
            throw Error("Failed to create framebuffer")
        }
        const previousReadFramebuffer = readFbo ? gl.getParameter(gl.READ_FRAMEBUFFER_BINDING) : null

        try {
            for (let sy = minShard.y; sy < maxShard.y; sy++) {
                for (let sx = minShard.x; sx < maxShard.x; sx++) {
                    const shardIndex = sy * this._descriptor.numShards.x + sx
                    const shardRegion = {
                        x: sx * shardSize.x,
                        y: sy * shardSize.y,
                        width: shardSize.x,
                        height: shardSize.y,
                    }
                    const intersectRegion = Box2.intersect(shardRegion, region)
                    let shardRegionData: Float32Array
                    if (sourceShardImage) {
                        this.context.blit({
                            sourceImage: this,
                            sourceRegion: intersectRegion,
                            targetImage: sourceShardImage,
                        })
                        shardRegionData = sourceShardImage.readImageDataFloat({x: 0, y: 0, width: intersectRegion.width, height: intersectRegion.height})
                    } else {
                        shardRegionData = scratchData ?? imageData
                        gl.bindFramebuffer(gl.READ_FRAMEBUFFER, readFbo)
                        gl.framebufferTextureLayer(gl.READ_FRAMEBUFFER, gl.COLOR_ATTACHMENT0, this._descriptor.texture, 0, shardIndex)
                        // the framebuffer only exposes this shard's layer, so read in shard-local coordinates
                        gl.readPixels(
                            intersectRegion.x - shardRegion.x,
                            intersectRegion.y - shardRegion.y,
                            intersectRegion.width,
                            intersectRegion.height,
                            gl.RGBA, // readPixels only supports RGBA
                            gl.FLOAT,
                            shardRegionData,
                        )
                        checkForGlError(gl, "Reading texture data")
                    }
                    if (needsShardRegionData) {
                        WebGl2ImagePhysical.copyPackedRgbaToImage(shardRegionData, intersectRegion, imageData, region, numChannels)
                    }
                }
            }
        } finally {
            if (readFbo) {
                gl.bindFramebuffer(gl.READ_FRAMEBUFFER, previousReadFramebuffer)
                gl.deleteFramebuffer(readFbo)
            }
        }
        return imageData
    }

    // HalImage
    /**
     * Uploads `sourceRegion` of `sourceData` into this image at `targetOffset`, then calls `generateMipmaps`.
     *
     * Native image data and typed-array images are uploaded shard by shard. `texSubImage3D` is used when the source
     * GL format/type matches the internal one. Otherwise each piece is written to an intermediate image first and
     * blitted into place. Any other source is blitted directly.
     *
     * @param sourceData Image data to upload.
     * @param sourceRegion Region of the source to copy; defaults to this image's full size.
     * @param targetOffset Top-left position in this image to copy to; defaults to (0, 0).
     * @throws Error if the target region lies (partly) outside this image.
     */
    writeImageData(sourceData: HalImageSourceData, sourceRegion?: Box2Like, targetOffset?: Vector2Like) {
        sourceRegion ??= {x: 0, y: 0, width: this._descriptor.width, height: this._descriptor.height}
        targetOffset ??= {x: 0, y: 0}
        // NOTE(review): the source region is not validated against the size of the source data here
        if (
            targetOffset.x < 0 ||
            targetOffset.y < 0 ||
            targetOffset.x + sourceRegion.width > this._descriptor.width ||
            targetOffset.y + sourceRegion.height > this._descriptor.height
        ) {
            throw Error("Invalid region.")
        }

        const start = performance.now()

        if (isNativeImageData(sourceData) || isTypedArrayImage(sourceData)) {
            const gl = this.context.gl
            gl.bindTexture(gl.TEXTURE_2D_ARRAY, this._descriptor.texture)
            gl.pixelStorei(gl.UNPACK_ALIGNMENT, 1) // make sure to tightly pack the data

            const glFormat = getGlFormat(gl, this._internalChannelLayout)
            const glType = getGlType(gl, this._internalDataType)

            let extractSourceRegion: (region: Box2Like) => TypedArrayImage
            let sourceChannelLayout: HalImageChannelLayout
            let sourceDataType: HalImageDataType
            if (isNativeImageData(sourceData)) {
                sourceDataType = sourceData.isSrgb ? "uint8srgb" : "uint8"
                const data = sourceData.data
                if (data instanceof HTMLImageElement || data instanceof HTMLCanvasElement) {
                    sourceChannelLayout = "RGBA"
                    extractSourceRegion = (region: Box2Like) => {
                        const ctx = this.context.requestSynchronousRenderingContext2D(region.width, region.height)
                        // clear first in case the requested context still holds earlier content; drawImage composites
                        ctx.clearRect(0, 0, region.width, region.height)
                        ctx.drawImage(data, region.x, region.y, region.width, region.height, 0, 0, region.width, region.height)
                        return {
                            width: region.width,
                            height: region.height,
                            channelLayout: sourceChannelLayout,
                            data: ctx.getImageData(0, 0, region.width, region.height).data,
                        }
                    }
                } else if (data instanceof ImageData) {
                    sourceChannelLayout = "RGBA"
                    extractSourceRegion = (region: Box2Like) => {
                        return {
                            width: region.width,
                            height: region.height,
                            channelLayout: sourceChannelLayout,
                            data: extractSubregionImageData(data, region).data,
                        }
                    }
                } else {
                    assertNever(data)
                }
            } else if (isTypedArrayImage(sourceData)) {
                const numChannels = getNumChannels(sourceData.channelLayout)
                if (sourceData.data.length !== sourceData.width * sourceData.height * numChannels) {
                    throw Error("Invalid image data")
                }
                sourceDataType = getDataTypeFromTypedArray(sourceData.data)
                let arrayType: new (size: number) => Uint8ClampedArray | Uint8Array | Uint16Array | Float32Array
                if (sourceData.data instanceof Uint8ClampedArray) {
                    arrayType = Uint8ClampedArray
                } else if (sourceData.data instanceof Uint8Array) {
                    arrayType = Uint8Array
                } else if (sourceData.data instanceof Uint16Array) {
                    arrayType = Uint16Array
                } else if (sourceData.data instanceof Float32Array) {
                    arrayType = Float32Array
                } else {
                    assertNever(sourceData.data)
                }
                sourceChannelLayout = sourceData.channelLayout
                extractSourceRegion = (region: Box2Like): TypedArrayImage => {
                    const regionData = extractSubregionTypedArray(sourceData.data, arrayType, sourceData, numChannels, region)
                    return createTypedArrayImage(region.width, region.height, sourceData.channelLayout, regionData)
                }
            } else {
                assertNever(sourceData)
            }

            const glSourceFormat = getGlFormat(gl, sourceChannelLayout)
            const glSourceType = getGlType(gl, sourceDataType)
            // if the format does not match, use intermediate image and blit to the texture for conversion
            const sourceShardImage =
                glFormat !== glSourceFormat || glType !== glSourceType
                    ? this.context.requestSynchronousBufferImage({
                          width: Math.min(sourceRegion.width, this._descriptor.shardSize.width),
                          height: Math.min(sourceRegion.height, this._descriptor.shardSize.height),
                          channelLayout: sourceChannelLayout,
                          dataType: sourceDataType,
                      })
                    : null

            // move the data to the gpu
            const shardSize = Vector2.fromSize2Like(this._descriptor.shardSize)
            const minShard = Vector2.fromVector2Like(targetOffset).divInPlace(shardSize).floorInPlace()
            const maxShard = Vector2.fromVector2Like(targetOffset).addInPlace(Vector2.fromSize2Like(sourceRegion)).divInPlace(shardSize).ceilInPlace()
            const targetRegion = new Box2(targetOffset.x, targetOffset.y, sourceRegion.width, sourceRegion.height)
            for (let sy = minShard.y; sy < maxShard.y; sy++) {
                for (let sx = minShard.x; sx < maxShard.x; sx++) {
                    const shardIndex = sy * this._descriptor.numShards.x + sx
                    const shardRegion = {
                        x: sx * shardSize.x,
                        y: sy * shardSize.y,
                        width: shardSize.x,
                        height: shardSize.y,
                    }
                    const intersectRegion = Box2.intersect(shardRegion, targetRegion)
                    // the source pixels that land in this shard: shift by how far the intersection lies into the target
                    const shardSourceRegion = {
                        x: sourceRegion.x + (intersectRegion.x - targetRegion.x),
                        y: sourceRegion.y + (intersectRegion.y - targetRegion.y),
                        width: intersectRegion.width,
                        height: intersectRegion.height,
                    }
                    const regionData = extractSourceRegion(shardSourceRegion)
                    if (sourceShardImage) {
                        sourceShardImage.writeImageData(regionData, {
                            x: 0,
                            y: 0,
                            width: intersectRegion.width,
                            height: intersectRegion.height,
                        })
                        this.context.blit({
                            sourceImage: sourceShardImage,
                            sourceRegion: {
                                x: 0,
                                y: 0,
                                width: intersectRegion.width,
                                height: intersectRegion.height,
                            },
                            targetImage: this,
                            // blit addresses the whole target image, so this offset is in image (not shard-local) coordinates
                            targetOffset: {
                                x: intersectRegion.x,
                                y: intersectRegion.y,
                            },
                        })
                    } else {
                        gl.texSubImage3D(
                            gl.TEXTURE_2D_ARRAY,
                            0,
                            intersectRegion.x - shardRegion.x, // shard-local offset within the layer
                            intersectRegion.y - shardRegion.y,
                            shardIndex,
                            intersectRegion.width,
                            intersectRegion.height,
                            1,
                            glFormat,
                            glType,
                            regionData.data,
                        )
                        checkForGlError(gl)
                    }
                }
            }
        } else {
            this.context.blit({
                sourceImage: sourceData,
                sourceRegion: sourceRegion,
                targetImage: this,
                targetOffset: targetOffset,
            })
        }

        this.generateMipmaps()

        const end = performance.now()
        if (TRACE) {
            console.log("Uploaded image to GPU in " + (end - start) + "ms")
        }
    }

    /**
     * Mipmap generation hook; currently a no-op because the implementation is disabled.
     * If re-enabled, the texture needs to be bound to TEXTURE_2D_ARRAY before calling this.
     */
    generateMipmaps() {
        // if (this._descriptor.numMipLevels > 1) {
        //     const gl = this.context.gl
        //     gl.bindTexture(gl.TEXTURE_2D_ARRAY, this._descriptor.texture)
        //     gl.generateMipmap(gl.TEXTURE_2D_ARRAY)
        // }
    }

    /**
     * Copies a tightly packed RGBA float block covering `blockRegion` into `imageData`.
     * `imageData` holds `imageRegion` tightly packed with `numChannels` channels per pixel.
     * Source channels beyond `numChannels` are dropped.
     */
    private static copyPackedRgbaToImage(
        blockData: Float32Array,
        blockRegion: Box2Like,
        imageData: Float32Array,
        imageRegion: Box2Like,
        numChannels: number,
    ) {
        const srcRowLength = blockRegion.width * 4
        const dstRowLength = imageRegion.width * numChannels
        const dstX = blockRegion.x - imageRegion.x
        const dstY = blockRegion.y - imageRegion.y
        for (let row = 0; row < blockRegion.height; row++) {
            const srcStart = row * srcRowLength
            const dstStart = (dstY + row) * dstRowLength + dstX * numChannels
            if (numChannels === 4) {
                imageData.set(blockData.subarray(srcStart, srcStart + srcRowLength), dstStart)
            } else {
                for (let x = 0; x < blockRegion.width; x++) {
                    for (let c = 0; c < numChannels; c++) {
                        imageData[dstStart + x * numChannels + c] = blockData[srcStart + x * 4 + c]
                    }
                }
            }
        }
    }

    /**
     * Validates the descriptor and chooses a supported internal format and shard layout.
     * Allocates immutable array-texture storage with one layer per shard.
     *
     * RGB is stored as RGBA. float16/float32 fall back to each other when only one float color-buffer extension is
     * available.
     *
     * @throws Error for invalid dimensions, unsupported float formats, too many shards, or GL allocation failures.
     * @throws OutOfMemoryError if the GPU reports GL_OUT_OF_MEMORY.
     */
    private createTexture(descriptor: HalImageDescriptor): {
        descriptor: WebGl2ImageDescriptor
        internalChannelLayout: HalImageChannelLayout
        internalDataType: HalImageDataType
    } {
        const descriptorOptions = completeHalImageOptions(descriptor.options)

        // zero-sized images cannot be allocated (texStorage3D requires dimensions >= 1)
        if (descriptor.width <= 0 || descriptor.height <= 0) {
            throw Error("Image dimensions must be positive.")
        }
        if (!Number.isInteger(descriptor.width) || !Number.isInteger(descriptor.height)) {
            throw Error("Image dimensions must be integers.")
        }

        let internalChannelLayout = descriptor.channelLayout
        if (internalChannelLayout === "RGB") {
            // RGB format is not supported to write to in WebGL2, so we convert it to RGBA
            internalChannelLayout = "RGBA"
        }

        let internalDataType = descriptor.dataType
        switch (internalDataType) {
            case "float16":
                if (!this.context.EXT_color_buffer_half_float) {
                    if (this.context.EXT_color_buffer_float) {
                        console.warn("Device does not support float16 format. Falling back to float32.")
                        internalDataType = "float32"
                    } else {
                        throw new Error("Float format not supported by device")
                    }
                }
                break
            case "float32":
                if (!this.context.EXT_color_buffer_float) {
                    if (this.context.EXT_color_buffer_half_float) {
                        console.warn("Device does not support float32 format. Falling back to float16.")
                        internalDataType = "float16"
                    } else {
                        throw new Error("Float format not supported by device")
                    }
                }
                break
        }

        const start = performance.now()

        const shardSize = new Size2(this.computeOptimalShardSize(descriptor.width), this.computeOptimalShardSize(descriptor.height))
        const numShards = new Vector2(Math.ceil(descriptor.width / shardSize.width), Math.ceil(descriptor.height / shardSize.height))
        // every shard is one array layer, so the total (not just the per-axis) count must fit
        const totalNumShards = numShards.x * numShards.y
        if (totalNumShards > this.context.maxTextureLayers) {
            throw Error(`Image would require more shards (${totalNumShards}) than the GPU allows (${this.context.maxTextureLayers}).`)
        }
        const numMipLevels = descriptorOptions.useMipMaps ? Math.floor(Math.log2(Math.max(shardSize.width, shardSize.height))) + 1 : 1
        const numChannels = getNumChannels(descriptor.channelLayout)

        if (TRACE) {
            console.log(
                `Creating physical WebGL texture with ${numShards.x}x${numShards.y} shards of size ${shardSize.width}x${shardSize.height} containing ${numMipLevels} mipmap levels for image of size ${descriptor.width}x${descriptor.height}.`,
            )
        }

        const gl = this.context.gl
        const texture = gl.createTexture()
        if (!texture) {
            throw new Error("Failed to create texture")
        }
        const currentBinding = gl.getParameter(gl.TEXTURE_BINDING_2D_ARRAY)
        gl.bindTexture(gl.TEXTURE_2D_ARRAY, texture)
        gl.texStorage3D(
            gl.TEXTURE_2D_ARRAY,
            numMipLevels,
            getInternalFormat(gl, internalChannelLayout, internalDataType),
            shardSize.width,
            shardSize.height,
            totalNumShards,
        )
        const lastError = gl.getError()
        gl.bindTexture(gl.TEXTURE_2D_ARRAY, currentBinding) // restore previous binding
        if (lastError !== gl.NO_ERROR) {
            gl.deleteTexture(texture) // don't leak the texture object on failure
            if (lastError === gl.OUT_OF_MEMORY) {
                throw new OutOfMemoryError("Failed to create texture: Out of memory.")
            } else {
                throw Error(`Failed to create texture (${lastError}).`)
            }
        }

        const end = performance.now()
        if (TRACE) {
            console.log("Created physical GPU image in " + (end - start) + "ms")
        }

        return {
            descriptor: {
                width: descriptor.width,
                height: descriptor.height,
                channelLayout: descriptor.channelLayout,
                dataType: descriptor.dataType,
                textureOffset: new Vector2(0, 0),
                shardSize: shardSize,
                numShards: numShards,
                numMipLevels: numMipLevels,
                numChannels: numChannels,
                texture: texture,
                atlas: null,
            },
            internalChannelLayout: internalChannelLayout,
            internalDataType: internalDataType,
        }
    }

    /**
     * Creates one framebuffer per shard, each with the shard's layer (mip level 0) as color attachment 0.
     * Restores the previous FRAMEBUFFER binding. On failure, deletes any framebuffers already created.
     */
    private createFrameBuffers(descriptor: WebGl2ImageDescriptor): WebGLFramebuffer[] {
        const frameBuffers: WebGLFramebuffer[] = []
        const numShards = descriptor.numShards.x * descriptor.numShards.y
        if (numShards <= 0) {
            return frameBuffers
        }
        const gl = this.context.gl
        const currentBinding = gl.getParameter(gl.FRAMEBUFFER_BINDING)
        try {
            for (let i = 0; i < numShards; i++) {
                const frameBuffer = gl.createFramebuffer()
                if (!frameBuffer) {
                    throw new Error("Failed to create framebuffer")
                }
                frameBuffers.push(frameBuffer)
                gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer)
                gl.framebufferTextureLayer(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, descriptor.texture, 0, i)
            }
        } catch (e) {
            frameBuffers.forEach((frameBuffer) => gl.deleteFramebuffer(frameBuffer))
            throw e
        } finally {
            gl.bindFramebuffer(gl.FRAMEBUFFER, currentBinding) // restore previous binding
        }
        return frameBuffers
    }

    /**
     * Chooses the shard edge length for one image axis.
     * It uses as few shards as the maximum texture size allows, then splits the length evenly across them.
     *
     * @throws Error if this axis alone would need more shards than the GPU allows array layers.
     */
    private computeOptimalShardSize(length: number): number {
        // we want as little shards as possible, but we also don't want to waste too much unused space in the last shard
        const maxShardSize = this.context.maxTextureSize
        const maxNumShards = this.context.maxTextureLayers
        const numRequiredShards = Math.ceil(length / maxShardSize)
        if (numRequiredShards > maxNumShards) {
            throw Error(`Image would require more shards (${numRequiredShards}) than the GPU allows (${maxNumShards}).`)
        }
        return Math.ceil(length / numRequiredShards)
    }

    private _descriptor: WebGl2ImageDescriptor
    private _internalChannelLayout: HalImageChannelLayout
    private _internalDataType: HalImageDataType
    private _frameBuffers: WebGLFramebuffer[]
}

/**
 * Raised when the GPU reports an out-of-memory condition, e.g. while allocating texture storage.
 * Its `name` is fixed so callers can distinguish it from generic errors.
 */
export class OutOfMemoryError extends Error {
    override name = "OutOfMemoryError"

    constructor(message: string) {
        super(message)
    }
}
