import {HalImageDescriptor, HalImageSourceData} from "@common/models/hal/hal-image/types"
import {WebGl2Context} from "@common/models/webgl2/webgl2-context"
import {PageHandle, WebGl2ImageAtlas} from "@common/models/webgl2/webgl2-image-atlas"
import {Box2, Box2Like, ColorLike, Matrix3x2, Size2Like, Vector2, Vector2Like} from "@cm/math"
import {EventEmitter} from "@angular/core"
import {WebGl2ImageDescriptor} from "@common/models/webgl2/webgl2-image"
import {completeHalImageOptions, getMipLevelSize} from "@common/models/webgl2/webgl2-image-utils"
import {getChannelLayoutFromImageSourceData, getDataTypeFromImageSourceData, getNumChannels, isHalImagePhysical} from "@common/models/hal/hal-image/utils"
import {WebGl2PainterPrimitive} from "@common/models/webgl2/webgl2-painter-primitive"
import {WebGl2Geometry} from "@common/models/webgl2/webgl2-geometry"
import {DrawArgs} from "@common/models/hal/hal-paintable"
import {HalImagePhysical, HalImageVirtual} from "@common/models/hal/hal-image"
import {checkForGlError} from "@common/helpers/webgl2/utils"
import {clearPaintable} from "@common/helpers/webgl2/webgl2-painter-utils"

const TRACE = false

/**
 * Sparse ("virtual") HAL image whose pixels are stored page by page in the shared {@link WebGl2ImageAtlas}.
 *
 * The image itself only owns a small RGBA8UI 2D-array texture acting as a page table. For mip level `m`, texel (x, y)
 * of layer `channel` holds the atlas page index of that page (`pageIndex.toArray()`, presumably x/y) followed by `0, 1`
 * while the page is resident. The texel is reset to all zeros when the atlas invokes the page's removal callback.
 * Pages are allocated lazily through {@link requestRegion}.
 *
 * Drawing happens into a temporary buffer image whose contents are copied into the atlas pages in {@link endDraw}.
 */
export class WebGl2ImageVirtual implements HalImageVirtual {
    readonly isHalImage = true
    readonly isHalImageVirtual = true

    /** Emitted after a page has been allocated in the atlas and registered in the page table. */
    readonly pageAdded = new EventEmitter<PageDescriptor>()
    /** Emitted after the atlas invoked the removal callback of one of this image's pages and its entry was cleared. */
    readonly pageRemoved = new EventEmitter<PageDescriptor>()

    readonly forceAlphaToOne = false

    /**
     * @param context WebGL2 context providing the GL handle, the shared atlas and temporary buffer images.
     * @param descriptor Size, channel layout, data type and options of the image to create.
     * @throws Error if the dimensions are not positive integers or the page-table texture cannot be created.
     * @throws OutOfMemoryError if GL reports OUT_OF_MEMORY while allocating the page-table texture.
     */
    constructor(
        readonly context: WebGl2Context,
        descriptor: HalImageDescriptor,
    ) {
        const {descriptor: webGl2Descriptor, atlasDescriptor} = this.createTexture(descriptor)
        this._descriptor = webGl2Descriptor
        this._atlasDescriptor = atlasDescriptor
        this._copyToAtlasInfo = this.createCopyToAtlasInfo()
    }

    // HalEntity
    /**
     * Releases everything this image holds: the copy painter and geometry, all of its atlas pages and the page-table
     * texture. Page bookkeeping is cleared so that a second call does not free the same pages again.
     */
    dispose(): void {
        const gl = this.context.gl
        this._copyToAtlasInfo.geometry.dispose()
        this._copyToAtlasInfo.painter.dispose()
        // Iterate over a snapshot. If freePage invokes the removal callback registered in requestPage, that callback
        // deletes from pageHandles while we iterate. NOTE(review): whether it does with `false` depends on freePage.
        const pageHandles = Array.from(this._atlasDescriptor.pageHandles)
        for (const pageHandle of pageHandles) {
            this._atlasDescriptor.atlas.freePage(pageHandle, false) // remove all pages from atlas
        }
        this._atlasDescriptor.pageHandles.clear()
        this._atlasDescriptor.pageDescriptorsByPageHash.clear()
        gl.deleteTexture(this._descriptor.texture)
    }

    // HalImage
    get descriptor(): WebGl2ImageDescriptor {
        return this._descriptor
    }

    // HalPaintable
    /**
     * Starts drawing into `args.region` of this image.
     *
     * Drawing goes to a temporary RGBA/float32 buffer image of the region's size. Unless `args.hints.fullWrite` is set,
     * the region's current contents are blitted into that buffer first. The result reaches the atlas in endDraw.
     * @returns The value returned by the buffer image's `beginDraw` (presumably the number of draw passes; TODO confirm).
     * @throws Error if a draw is already in progress.
     */
    beginDraw(args: DrawArgs): number {
        if (this._currentDrawInfo) {
            throw new Error("Drawing already in progress.")
        }
        const bufferImage = this.context.requestSynchronousBufferImage({
            width: args.region.width,
            height: args.region.height,
            channelLayout: "RGBA", // fix this to maximum channel layout to facilitate reusing the buffer image
            dataType: "float32", // fix this to maximum data type to facilitate reusing the buffer image
        })
        this._currentDrawInfo = {
            bufferImage,
        }
        try {
            if (!args.hints.fullWrite) {
                // move atlas pages buffer
                this.context.blit({
                    sourceImage: this,
                    sourceRegion: args.region,
                    targetImage: bufferImage,
                })
            }
            return bufferImage.beginDraw(this.toBufferDrawArgs(args))
        } catch (error) {
            // do not leave the image permanently stuck in the "drawing" state
            this._currentDrawInfo = null
            throw error
        }
    }

    // HalPaintable
    beginDrawPass(args: DrawArgs, pass: number): Matrix3x2 {
        if (!this._currentDrawInfo) {
            throw new Error("Drawing not started.")
        }
        return this._currentDrawInfo.bufferImage.beginDrawPass(this.toBufferDrawArgs(args), pass)
    }

    // HalPaintable
    endDrawPass(args: DrawArgs, pass: number): void {
        if (!this._currentDrawInfo) {
            throw new Error("Drawing not started.")
        }
        this._currentDrawInfo.bufferImage.endDrawPass(this.toBufferDrawArgs(args), pass)
    }

    // HalPaintable
    /**
     * Finishes the draw started by beginDraw and copies the buffer image into the atlas pages of `args.region` at
     * `args.mipLevel`. The drawing state is reset even if finishing or copying fails.
     * @throws Error if no draw is in progress.
     */
    endDraw(args: DrawArgs): void {
        const drawInfo = this._currentDrawInfo
        if (!drawInfo) {
            throw new Error("Drawing not started.")
        }
        try {
            drawInfo.bufferImage.endDraw(this.toBufferDrawArgs(args))
            this.copyToAtlas(args.region, args.mipLevel, drawInfo.bufferImage, {x: 0, y: 0})
        } finally {
            this._currentDrawInfo = null
        }
    }

    // HalPaintable
    get width(): number {
        return this._descriptor.width
    }

    // HalPaintable
    get height(): number {
        return this._descriptor.height
    }

    // HalPaintable
    clear(color?: ColorLike, mipLevel?: number) {
        clearPaintable(this, color, mipLevel)
    }

    // Maps draw arguments for this image onto the buffer image, which covers exactly `args.region` at its own origin
    // and is always drawn at mip level 0.
    private toBufferDrawArgs(args: DrawArgs): DrawArgs {
        return {region: {x: 0, y: 0, width: args.region.width, height: args.region.height}, mipLevel: 0, hints: args.hints}
    }

    /**
     * Copies `targetRegion` of mip level `mipLevel` from `sourceImage`, starting at `sourceOffset`, into the atlas
     * pages that back the region. Missing pages are allocated on the way.
     *
     * Every (page, channel) pair overlapping the region becomes one rectangle. The rectangle color's red component
     * carries the channel index, and the copy shader (createCopyToAtlasInfo) uses it to select the source channel.
     * That channel is written to the red component of the atlas.
     */
    private copyToAtlas(targetRegion: Box2Like, mipLevel: number, sourceImage: HalImagePhysical, sourceOffset: Vector2Like) {
        const geometry = this._copyToAtlasInfo.geometry
        geometry.clear()

        const pageSize = WebGl2ImageAtlas.pageSize
        const pageSizeWithMargin = WebGl2ImageAtlas.pageSizeWithMargin
        const pageDescriptors = this.requestRegion(targetRegion, mipLevel)
        for (const pageDescriptor of pageDescriptors) {
            const pageRegion = new Box2(pageDescriptor.x * pageSize, pageDescriptor.y * pageSize, pageSize, pageSize)
            const regionIntersection = Box2.intersect(pageRegion, targetRegion)
            if (regionIntersection.isEmpty()) {
                continue
            }
            const pageIndex = pageDescriptor.pageHandle.value.pageIndex
            geometry.addRect(
                // destination: overlapping part of the page, in atlas image coordinates
                {
                    x: pageIndex.x * pageSizeWithMargin + regionIntersection.x - pageRegion.x,
                    y: pageIndex.y * pageSizeWithMargin + regionIntersection.y - pageRegion.y,
                    width: regionIntersection.width,
                    height: regionIntersection.height,
                },
                // source: the same pixels in source image coordinates
                {
                    x: sourceOffset.x + regionIntersection.x - targetRegion.x,
                    y: sourceOffset.y + regionIntersection.y - targetRegion.y,
                    width: regionIntersection.width,
                    height: regionIntersection.height,
                },
                // color.r selects the source channel in the copy shader
                {
                    r: pageDescriptor.channel,
                    g: 0,
                    b: 0,
                    a: 1,
                },
            )
        }

        this._copyToAtlasInfo.painter.paint({
            target: this._atlasDescriptor.atlas.image,
            sourceImages: [sourceImage],
            geometry: geometry,
        })
    }

    // HalImage
    get numMipLevels(): number {
        return this._descriptor.numMipLevels
    }

    // HalImage
    getMipLevelSize(mipLevel: number): Size2Like {
        if (mipLevel < 0 || mipLevel >= this._descriptor.numMipLevels) {
            throw Error("Invalid mip level.")
        }
        return getMipLevelSize(this._descriptor, mipLevel)
    }

    // HalImage
    readImageDataFloat(_region?: Box2Like): Float32Array {
        throw new Error("Method not implemented.")
    }

    // HalImage
    /**
     * Writes pixel data into mip level 0 of this image and allocates atlas pages as needed.
     * @param sourceData Either a physical HAL image, which is read directly, or raw source data. Raw data is first
     *   uploaded into a temporary buffer image of `sourceRegion`'s size.
     * @param sourceRegion Region of `sourceData` to copy. NOTE(review): defaults to this image's full size, not the
     *   source's size.
     * @param targetOffset Position in this image where the region's origin is written; defaults to (0, 0).
     */
    writeImageData(sourceData: HalImageSourceData, sourceRegion?: Box2Like, targetOffset?: Vector2Like) {
        sourceRegion ??= {x: 0, y: 0, width: this._descriptor.width, height: this._descriptor.height}
        targetOffset ??= {x: 0, y: 0}
        let sourceImage: HalImagePhysical
        let sourceOffset: Vector2Like
        if (isHalImagePhysical(sourceData)) {
            sourceImage = sourceData
            sourceOffset = {x: sourceRegion.x, y: sourceRegion.y}
        } else {
            const channelLayout = getChannelLayoutFromImageSourceData(sourceData)
            const dataType = getDataTypeFromImageSourceData(sourceData)
            sourceImage = this.context.requestSynchronousBufferImage({
                width: sourceRegion.width,
                height: sourceRegion.height,
                channelLayout: channelLayout,
                dataType: dataType,
            })
            sourceImage.writeImageData(sourceData, sourceRegion)
            sourceOffset = {x: 0, y: 0}
        }
        this.copyToAtlas(
            {
                x: targetOffset.x,
                y: targetOffset.y,
                width: sourceRegion.width,
                height: sourceRegion.height,
            },
            0,
            sourceImage,
            sourceOffset,
        )
    }

    // HalImageVirtual
    /**
     * Writes `sourceData` at the pixel origin of `pageDescriptor.region`.
     * NOTE(review): this always goes through writeImageData, which targets mip level 0. `pageDescriptor.mipLevel` is
     * ignored, so pages of higher mip levels are not written where their descriptor points. Confirm the intent.
     */
    writePageImageData(pageDescriptor: PageDescriptor, sourceData: HalImageSourceData, sourceRegion?: Box2Like) {
        this.writeImageData(sourceData, sourceRegion, pageDescriptor.region)
    }

    // Returns the descriptor of one page/channel. A missing page is allocated in the atlas and registered; an existing
    // page has its usage registered with the atlas.
    private requestPage(args: {x: number; y: number; channel: number; mipLevel: number}) {
        if (!this._atlasDescriptor) {
            throw Error("Atlas descriptor is not set.")
        }
        const pageHash = this.computePageHash(args)
        let pageDescriptor = this._atlasDescriptor.pageDescriptorsByPageHash.get(pageHash)
        if (!pageDescriptor) {
            // this page is still missing; the callback runs when the atlas removes the page again
            const pageHandle = this._atlasDescriptor.atlas.addPage((pageHandle) => this.onPageRemoved(pageDescriptor!, pageHandle))
            pageDescriptor = {
                ...args,
                region: {
                    x: args.x * WebGl2ImageAtlas.pageSize,
                    y: args.y * WebGl2ImageAtlas.pageSize,
                    width: WebGl2ImageAtlas.pageSize,
                    height: WebGl2ImageAtlas.pageSize,
                },
                hash: pageHash,
                pageHandle: pageHandle,
            }
            this.onPageAdded(pageDescriptor)
        } else {
            // this page has been requested before; update its usage
            this._atlasDescriptor.atlas.registerPageUsage(pageDescriptor.pageHandle)
        }
        return pageDescriptor
    }

    // Requests the page at (x, y, mipLevel) for every channel of the image, in channel order.
    private requestPages(args: {x: number; y: number; mipLevel: number}) {
        const pageDescriptors: PageDescriptor[] = []
        for (let channel = 0; channel < this._descriptor.numChannels; channel++) {
            const pageDescriptor = this.requestPage({...args, channel: channel})
            pageDescriptors.push(pageDescriptor)
        }
        return pageDescriptors
    }

    // HalImageVirtual
    /**
     * Makes sure all pages (for all channels) overlapping `region` at `mipLevel` are resident in the atlas.
     * @returns Descriptors of those pages, row by row, with all channels of a page adjacent.
     * @throws Error if `mipLevel` is out of range.
     */
    requestRegion(region: Box2Like, mipLevel: number): PageDescriptor[] {
        if (mipLevel < 0 || mipLevel >= this._descriptor.numMipLevels) {
            throw new Error(`Invalid mip level: ${mipLevel}`)
        }
        if (TRACE) {
            console.log("requestRegion", region, mipLevel)
        }
        // determine which pages are still missing
        const mipLevelSize = this.getMipLevelSize(mipLevel)
        const clampedRegion = Box2.intersect(region, {x: 0, y: 0, width: mipLevelSize.width, height: mipLevelSize.height})
        // NOTE(review): the region is clamped to the size of `mipLevel`, i.e. treated as mip-level coordinates.
        // Page indices below divide by `pageSize * 2^mipLevel`, i.e. treat it as level-0 coordinates.
        // copyToAtlas and PageDescriptor.region use `index * pageSize`.
        // These conventions only agree for mip level 0; confirm the intended coordinate space for mipLevel > 0.
        const pixelToPage = 1 / (WebGl2ImageAtlas.pageSize * (1 << mipLevel))
        const p00 = new Vector2(clampedRegion.x, clampedRegion.y).mul(pixelToPage).floor()
        const p11 = new Vector2(clampedRegion.x + clampedRegion.width, clampedRegion.y + clampedRegion.height).mul(pixelToPage).ceil()
        const allPageDescriptors: PageDescriptor[] = []
        for (let y = p00.y; y < p11.y; y++) {
            for (let x = p00.x; x < p11.x; x++) {
                allPageDescriptors.push(...this.requestPages({x, y, mipLevel}))
            }
        }
        return allPageDescriptors
    }

    /** Registers the usage of every page currently held by this image with the atlas. */
    registerCurrentPageUsage() {
        this._atlasDescriptor.pageHandles.forEach((pageHandle) => this._atlasDescriptor.atlas.registerPageUsage(pageHandle))
    }

    /**
     * Reads back the page-table entries of mip level 0, layer (channel) 0 for the `numPages.x * numPages.y` level-0
     * pages, leaving the read-framebuffer binding as it was.
     * @returns Four bytes per page, row by row, as written by onPageAdded / onPageRemoved.
     */
    private readPageIndices(): Uint8Array {
        if (!this._atlasDescriptor) {
            throw Error("Atlas descriptor is not set.")
        }
        const numChannels = 4
        const numPagesX = this._atlasDescriptor.numPages.x
        const numPagesY = this._atlasDescriptor.numPages.y
        const gl = this.context.gl
        // RGBA_INTEGER/UNSIGNED_INT is the readPixels combination WebGL2 guarantees for unsigned-integer color
        // buffers such as this RGBA8UI texture. RGBA_INTEGER/UNSIGNED_BYTE only works where it happens to be the
        // implementation-defined read format.
        const texels = new Uint32Array(numPagesX * numPagesY * numChannels)
        const previousReadFramebuffer = gl.getParameter(gl.READ_FRAMEBUFFER_BINDING)
        const readFbo = gl.createFramebuffer()
        try {
            gl.bindFramebuffer(gl.READ_FRAMEBUFFER, readFbo)
            gl.framebufferTextureLayer(gl.READ_FRAMEBUFFER, gl.COLOR_ATTACHMENT0, this._descriptor.texture, 0, 0)
            gl.readPixels(0, 0, numPagesX, numPagesY, gl.RGBA_INTEGER, gl.UNSIGNED_INT, texels)
            checkForGlError(gl, "Reading pixels")
        } finally {
            gl.bindFramebuffer(gl.READ_FRAMEBUFFER, previousReadFramebuffer) // restore previous binding
            gl.deleteFramebuffer(readFbo)
        }
        // every stored component is an 8-bit unsigned value, so narrowing to bytes is lossless
        return Uint8Array.from(texels)
    }

    // Dense, unique key for (mipLevel, y, x, channel), provided x < numPages.x, y < numPages.y and
    // channel < numChannels.
    private computePageHash(args: {x: number; y: number; channel: number; mipLevel: number}): number {
        if (!this._atlasDescriptor) {
            throw Error("Atlas descriptor is not set.")
        }
        return (
            ((args.mipLevel * this._atlasDescriptor.numPages.y + args.y) * this._atlasDescriptor.numPages.x + args.x) * this._descriptor.numChannels +
            args.channel
        )
    }

    // Writes one RGBA8UI texel of the page-table texture at the page's (x, y, channel) in its mip level. The
    // caller's TEXTURE_2D_ARRAY binding and UNPACK_ALIGNMENT are restored afterwards, even on error.
    private writePageTableEntry(pageDescriptor: PageDescriptor, texel: Uint8Array): void {
        const gl = this.context.gl
        const previousBinding = gl.getParameter(gl.TEXTURE_BINDING_2D_ARRAY)
        const previousUnpackAlignment = gl.getParameter(gl.UNPACK_ALIGNMENT)
        gl.pixelStorei(gl.UNPACK_ALIGNMENT, 1) // make sure to tightly pack the data
        gl.bindTexture(gl.TEXTURE_2D_ARRAY, this._descriptor.texture)
        try {
            gl.texSubImage3D(
                gl.TEXTURE_2D_ARRAY,
                pageDescriptor.mipLevel,
                pageDescriptor.x,
                pageDescriptor.y,
                pageDescriptor.channel,
                1,
                1,
                1,
                gl.RGBA_INTEGER,
                gl.UNSIGNED_BYTE,
                texel,
            )
            checkForGlError(gl)
        } finally {
            gl.bindTexture(gl.TEXTURE_2D_ARRAY, previousBinding) // restore previous binding
            gl.pixelStorei(gl.UNPACK_ALIGNMENT, previousUnpackAlignment)
        }
    }

    // Records a newly allocated page in the page table and the bookkeeping, then notifies listeners.
    private onPageAdded(pageDescriptor: PageDescriptor): void {
        if (!this._atlasDescriptor) {
            throw Error("Atlas descriptor is not set.")
        }
        if (TRACE) {
            console.log("onPageAdded", pageDescriptor.pageHandle)
        }
        // update pageHandle in pageLookupTexture
        this.writePageTableEntry(pageDescriptor, new Uint8Array([...pageDescriptor.pageHandle.value.pageIndex.toArray(), 0, 1]))
        this._atlasDescriptor.pageDescriptorsByPageHash.set(pageDescriptor.hash, pageDescriptor)
        this._atlasDescriptor.pageHandles.add(pageDescriptor.pageHandle)
        this.pageAdded.emit(pageDescriptor)
    }

    // Removal callback handed to the atlas in requestPage: clears the page-table entry and the bookkeeping, then
    // notifies listeners.
    private onPageRemoved(pageDescriptor: PageDescriptor, pageHandle: PageHandle): void {
        if (!this._atlasDescriptor) {
            throw Error("Atlas descriptor is not set.")
        }
        if (TRACE) {
            console.log("onPageRemoved", pageHandle)
        }
        // update pageHandle in pageLookupTexture
        this.writePageTableEntry(pageDescriptor, new Uint8Array([0, 0, 0, 0]))
        this._atlasDescriptor.pageDescriptorsByPageHash.delete(pageDescriptor.hash)
        this._atlasDescriptor.pageHandles.delete(pageHandle)
        this.pageRemoved.emit(pageDescriptor)
    }

    /**
     * Validates `descriptor` and allocates the page-table texture together with the atlas bookkeeping.
     *
     * The page table is a square, power-of-two RGBA8UI 2D-array texture with one layer per image channel. It is
     * large enough for one texel per level-0 page, and its mip levels hold the entries of the image's mip levels.
     * @throws Error if the dimensions are not positive integers or texture allocation fails.
     * @throws OutOfMemoryError if GL reports OUT_OF_MEMORY during allocation.
     */
    private createTexture(descriptor: HalImageDescriptor): {descriptor: WebGl2ImageDescriptor; atlasDescriptor: AtlasDescriptor} {
        const descriptorOptions = completeHalImageOptions(descriptor.options)

        // zero-sized images would yield zero pages and an invalid (0-sized) texStorage3D call
        if (descriptor.width <= 0 || descriptor.height <= 0) {
            throw Error("Image dimensions must be positive.")
        }
        if (!Number.isInteger(descriptor.width) || !Number.isInteger(descriptor.height)) {
            throw Error("Image dimensions must be integers.")
        }
        const gl = this.context.gl
        const atlas = this.context.atlas
        // TODO check atlas channellayout/datatype compatibility

        // the pixels always live in the shared atlas; this image only tracks which pages it owns
        const atlasDescriptor: AtlasDescriptor = {
            atlas: atlas,
            numPages: {
                x: Math.ceil(descriptor.width / WebGl2ImageAtlas.pageSize),
                y: Math.ceil(descriptor.height / WebGl2ImageAtlas.pageSize),
            },
            pageDescriptorsByPageHash: new Map(),
            pageHandles: new Set(),
        }

        const start = performance.now()

        const numPages = new Vector2(atlasDescriptor.numPages.x, atlasDescriptor.numPages.y)
        const numMipLevels = descriptorOptions.useMipMaps ? Math.floor(Math.log2(Math.max(numPages.x, numPages.y))) + 1 : 1
        const numChannels = getNumChannels(descriptor.channelLayout)

        if (TRACE) {
            console.log(`Creating virtual WebGL texture with of size ${descriptor.width}x${descriptor.height} containing ${numMipLevels} mipmap levels.`)
        }

        const texture = gl.createTexture()
        if (!texture) {
            throw Error("Failed to create texture.")
        }
        const currentBinding = gl.getParameter(gl.TEXTURE_BINDING_2D_ARRAY)
        gl.bindTexture(gl.TEXTURE_2D_ARRAY, texture)
        const nextPowerOfTwo = (value: number) => 2 ** Math.ceil(Math.log2(value))
        const texSize = nextPowerOfTwo(Math.max(numPages.x, numPages.y))
        gl.texStorage3D(gl.TEXTURE_2D_ARRAY, numMipLevels, gl.RGBA8UI, texSize, texSize, numChannels)
        // query the error before restoring the binding so that it reflects the allocation above
        const lastError = gl.getError()
        gl.bindTexture(gl.TEXTURE_2D_ARRAY, currentBinding) // restore previous binding
        if (lastError !== gl.NO_ERROR) {
            gl.deleteTexture(texture) // do not leak the texture object when allocation failed
            if (lastError === gl.OUT_OF_MEMORY) {
                throw new OutOfMemoryError("Failed to create texture: Out of memory.")
            }
            throw Error(`Failed to create texture (${lastError}).`)
        }

        const end = performance.now()
        if (TRACE) {
            console.log("Created virtual GPU image in " + (end - start) + "ms")
        }

        return {
            descriptor: {
                width: descriptor.width,
                height: descriptor.height,
                channelLayout: descriptor.channelLayout,
                dataType: descriptor.dataType,
                textureOffset: new Vector2(0, 0),
                shardSize: atlas.image.descriptor.shardSize,
                numShards: atlas.image.descriptor.numShards,
                numMipLevels: numMipLevels,
                numChannels: numChannels,
                texture: texture,
                atlas: atlas,
            },
            atlasDescriptor: atlasDescriptor,
        }
    }

    // Creates the painter and geometry used by copyToAtlas. The shader reads the source channel selected by
    // color.r (from source image 0, presumably) and writes it to the red component of the target.
    private createCopyToAtlasInfo(): CopyToAtlasInfo {
        const painter = new WebGl2PainterPrimitive(
            this.context,
            `
            vec4 computeColor(vec2 position, vec2 uv, vec4 color) {
                int channel = int(round(color.r));
                float value = texelChannelFetchLod0(ivec2(uv), 0, channel, ADDRESS_MODE_CLAMP_TO_EDGE, 0.0);
                return vec4(value, 0, 0, 1);
            }
            `,
        )
        const geometry = new WebGl2Geometry(this.context)
        return {
            painter: painter,
            geometry: geometry,
        }
    }

    private _descriptor: WebGl2ImageDescriptor
    private _atlasDescriptor: AtlasDescriptor
    // non-null only between beginDraw and endDraw
    private _currentDrawInfo: DrawInfo | null = null
    private _copyToAtlasInfo: CopyToAtlasInfo
}

/** Bookkeeping that ties a virtual image to the shared atlas storing its pages. */
interface AtlasDescriptor {
    /** Shared atlas that owns the page storage. */
    atlas: WebGl2ImageAtlas
    /** Number of pages covering mip level 0 horizontally (x) and vertically (y). */
    numPages: Vector2Like
    /** Resident pages of the image, keyed by their page hash. */
    pageDescriptorsByPageHash: Map<number, PageDescriptor>
    /** Atlas handles of all pages currently held by the image. */
    pageHandles: Set<PageHandle>
}

/** State of a draw that has been started but not yet finished. */
interface DrawInfo {
    /** Temporary image receiving the drawing before it is copied into the atlas. */
    bufferImage: HalImagePhysical
}

/** GPU resources reused for every copy of pixel data into the atlas. */
interface CopyToAtlasInfo {
    /** Painter whose shader copies one selected source channel per rectangle. */
    painter: WebGl2PainterPrimitive
    /** Geometry buffer that is cleared and refilled for each copy. */
    geometry: WebGl2Geometry
}

/** One page of a virtual image: a single channel of one page-sized tile at one mip level. */
export type PageDescriptor = {
    /** Page column in the page table of `mipLevel`. */
    x: number
    /** Page row in the page table of `mipLevel`. */
    y: number
    /** Mip level the page belongs to. */
    mipLevel: number
    /** Image channel stored in this page. */
    channel: number
    /** Pixel region covered by the page. */
    region: Box2Like
    /** Handle of the atlas page holding the pixels. */
    pageHandle: PageHandle
    /** Key identifying the page within its image. */
    hash: number
}

/**
 * Error used when WebGL reports `OUT_OF_MEMORY` while allocating GPU storage. It lets callers tell allocation failures
 * apart from other errors.
 */
export class OutOfMemoryError extends Error {
    /**
     * @param message Human-readable description of the failed allocation.
     */
    constructor(message: string) {
        super(message)
        // stable name shown in stack traces and when the error is stringified
        const errorName = "OutOfMemoryError"
        this.name = errorName
    }
}
