import {ImageRef} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/image-ref"
import {ImageOpCommandQueueWebGL2} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/image-op-command-queue-webgl2"
import {assertSameChannelLayout, assertSameSize} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/utils"
import {convertRgbToLab} from "@app/textures/texture-editor/operator-stack/operators/tiling/helpers/convert-rgb-to-lab"
import {math} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/primitive/image-op-math"
import {Vector2, Vector2Like} from "@cm/math"
import {AStarPathFinding} from "@app/textures/texture-editor/operator-stack/operators/tiling/helpers/a-star-path-finding"
import {createImage} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/primitive/image-op-create-image"
import {BoundaryDirection} from "@app/textures/texture-editor/operator-stack/operators/tiling/toolbox/tiling-area/boundary-item"
import {FloodFill} from "@app/textures/texture-editor/operator-stack/operators/tiling/helpers/flood-fill"
import {gaussianBlur} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/composite/gaussian-blur"
import {TextureEditorSettings} from "@app/textures/texture-editor/texture-editor-settings"
import {createTypedArrayImage} from "@common/models/hal/hal-image/utils"

// Computes a mask image from a foreground and background image by finding the path (seam) of least perceptual difference between the two images and filling the image on one side of that path.

// when enabled, the CPU part of computeBorderMaskImage logs its timings (path finding, filling, copy to GPU) and the found path to the console
const TRACE = TextureEditorSettings.EnableFullTrace

// name of the command-queue scope opened around all operations recorded by computeBorderMaskImage
const SCOPE_NAME = "ComputeMaskImage"

/**
 * Inputs of computeBorderMaskImage.
 * Both images must have the same channel layout and the same size (asserted at runtime).
 */
export type ParameterType = {
    backgroundImage: ImageRef
    foregroundImage: ImageRef
    // orientation of the seam that separates background from foreground
    direction: BoundaryDirection
}

/** The resulting single-channel ("R") mask image, smoothed with a Gaussian blur. */
export type ReturnType = ImageRef

/**
 * Computes a single-channel mask that separates the background image from the foreground image along the seam
 * (path) of least perceptual difference between them.
 *
 * The per-pixel cost is the distance between both images (after conversion to CIE-LAB for multi-channel inputs).
 * Two strips along the image borders parallel to the seam get a prohibitive cost so the seam stays away from them.
 * The seam is then found on the CPU with A* (8-connectivity):
 * - for `BoundaryDirection.Horizontal` it runs from the middle of the left edge to the middle of the right edge;
 * - otherwise it runs from the middle of the top edge to the middle of the bottom edge.
 *
 * The seam pixels are set to 0, all other pixels to 1, and the region containing the top edge (horizontal) or the
 * left edge (vertical) is flood-filled with 0. The result is smoothed with a Gaussian blur (sigma 1).
 *
 * @param cmdQueue - command queue the operations are recorded into; the CPU part runs inside a `cmdQueue.lambda`
 * @param backgroundImage - background image; must match the foreground image in channel layout and size
 * @param foregroundImage - foreground image
 * @param direction - orientation of the seam
 * @returns the blurred single-channel mask image
 */
export const computeBorderMaskImage = (cmdQueue: ImageOpCommandQueueWebGL2, {backgroundImage, foregroundImage, direction}: ParameterType): ReturnType => {
    cmdQueue.beginScope(SCOPE_NAME)

    assertSameChannelLayout(backgroundImage.descriptor, foregroundImage.descriptor)
    assertSameSize(backgroundImage.descriptor, foregroundImage.descriptor)

    // multi-channel inputs (presumably RGB) are converted to CIE-LAB so the distance below approximates perceptual difference
    if (backgroundImage.descriptor.channelLayout !== "R") {
        backgroundImage = convertRgbToLab(cmdQueue, {sourceImage: backgroundImage})
        foregroundImage = convertRgbToLab(cmdQueue, {sourceImage: foregroundImage})
    }

    // use perceptual difference between foreground and background image as cost
    const costImage = math(cmdQueue, {
        operator: "distance",
        operandA: backgroundImage,
        operandB: foregroundImage,
    })

    // construct mask image by finding the path of the least cost
    let maskImage = createImage(cmdQueue, {
        imageOrDescriptor: {
            width: backgroundImage.descriptor.width,
            height: backgroundImage.descriptor.height,
            channelLayout: "R",
            dataType: "uint8",
        },
        fillColor: undefined,
    })
    // TODO remove this dummy operation (this is currently necessary to avoid a bug in the image op system: when the lambda function below is executed it gets-writes-releases the WebGl2 mask image,
    // TODO but since nobody else has written or in any way accessed the WebGl2 image, the release call will dispose it. the image op system should keep a reference and avoid release in this case)
    maskImage = math(cmdQueue, {
        operator: "+",
        operandA: maskImage,
        operandB: 0,
    })
    cmdQueue.lambda({maskImage, costImage, direction}, async ({maskImage, costImage, direction}) => {
        // we perform A* search to find the path of least perceptual difference using a CPU implementation for now
        const {width, height} = costImage.descriptor
        const isHorizontal = direction === BoundaryDirection.Horizontal
        const toIndex = (p: Vector2Like) => p.y * width + p.x

        // read the cost image back to the CPU; release the image reference even if the read-back fails
        const readCostData = async () => {
            const costImageWebGl2 = await cmdQueue.context.getImage(costImage)
            try {
                return await costImageWebGl2.ref.halImage.readImageDataFloat()
            } finally {
                costImageWebGl2.release()
            }
        }
        const costData = await readCostData()

        // fill border boundary with high cost to avoid getting too near it
        const borderCost = 1000000
        const borderWidthRatio = 0.2
        const numLines = Math.ceil((isHorizontal ? height : width) * borderWidthRatio)
        // the strips are walked line by line in the row-major buffer: offsetPixel steps along a line, offsetLine steps
        // to the next line, and offsetOtherBoundary maps a pixel of the first strip to its counterpart in the opposite strip
        const {offsetPixel, offsetLine, numPixels, offsetOtherBoundary} = isHorizontal
            ? {offsetPixel: 1, offsetLine: width, numPixels: width, offsetOtherBoundary: (height - numLines) * width}
            : {offsetPixel: width, offsetLine: 1, numPixels: height, offsetOtherBoundary: width - numLines}
        let lineStart = 0
        for (let line = 0; line < numLines; line++) {
            let pos = lineStart
            for (let i = 0; i < numPixels; i++) {
                costData[pos] = borderCost
                costData[pos + offsetOtherBoundary] = borderCost
                pos += offsetPixel
            }
            lineStart += offsetLine
        }

        // determine start and end pixel (middle of the two opposite edges the seam connects)
        const midX = Math.floor(width / 2)
        const midY = Math.floor(height / 2)
        const startPixel = isHorizontal ? new Vector2(0, midY) : new Vector2(midX, 0)
        const endPixel = isHorizontal ? new Vector2(width - 1, midY) : new Vector2(midX, height - 1)

        // chebyshev distance due to 8-connectivity; computed on flat indices to avoid allocations in the A* hot path
        const chebyshevDistance = (from: number, to: number) =>
            Math.max(Math.abs((from % width) - (to % width)), Math.abs(Math.floor(from / width) - Math.floor(to / width)))
        // 8-connected neighbors inside the image, in row-major order
        const neighborsOf = (index: number) => {
            const x0 = index % width
            const y0 = Math.floor(index / width)
            const neighbors: number[] = []
            for (let y = Math.max(y0 - 1, 0); y <= Math.min(y0 + 1, height - 1); y++) {
                for (let x = Math.max(x0 - 1, 0); x <= Math.min(x0 + 1, width - 1); x++) {
                    if (x !== x0 || y !== y0) {
                        neighbors.push(y * width + x)
                    }
                }
            }
            return neighbors
        }
        const costWeight = 5 // weight of cost in path finding; this value seems to work well

        // find path
        const pathFindingStartTime = performance.now()
        const path = AStarPathFinding.findPath(
            costData.length,
            toIndex(startPixel),
            toIndex(endPixel),
            (from: number, to: number) => chebyshevDistance(from, to),
            (index: number) => neighborsOf(index),
            (from: number, to: number) => chebyshevDistance(from, to) + (costData[to] + costData[from]) * costWeight,
        )
        const pathFindingEndTime = performance.now()
        if (TRACE) {
            console.log("path finding took", pathFindingEndTime - pathFindingStartTime, "ms")
            console.log("path", path)
        }

        const fillStartTime = performance.now()
        const maskData = new Float32Array(maskImage.descriptor.width * maskImage.descriptor.height).fill(1)
        path.forEach((index) => {
            maskData[index] = 0
        })
        // fill the side of the seam containing the top edge (horizontal) or the left edge (vertical) with 0;
        // NOTE(review): assumes FloodFill.fill stops at the 0-valued seam pixels - confirm against its implementation
        const fillStartPos = isHorizontal ? new Vector2(midX, 0) : new Vector2(0, midY)
        FloodFill.fill(maskData, maskImage.descriptor, fillStartPos, 0)
        const fillEndTime = performance.now()
        if (TRACE) {
            console.log("filling took", fillEndTime - fillStartTime, "ms")
        }

        // upload the mask; release the image reference even if the upload fails
        const copyToGpuStartTime = performance.now()
        const maskImageWebGl2 = await cmdQueue.context.getImage(maskImage)
        try {
            maskImageWebGl2.ref.halImage.writeImageData(createTypedArrayImage(maskImage.descriptor.width, maskImage.descriptor.height, "R", maskData))
        } finally {
            maskImageWebGl2.release()
        }
        const copyToGpuEndTime = performance.now()
        if (TRACE) {
            console.log("copy to GPU took", copyToGpuEndTime - copyToGpuStartTime, "ms")
        }
    })
    maskImage = gaussianBlur(cmdQueue, {
        sourceImage: maskImage,
        sigma: 1,
        borderMode: "renormalize",
    })

    cmdQueue.endScope(SCOPE_NAME)
    return maskImage
}
