import {ImageOpCommandQueueWebGL2} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/image-op-command-queue-webgl2"
import {ImageRef} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/image-ref"
import {
    createGaussianPyramid,
    ReturnType as GaussianImagePyramidReturnType,
} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/composite/create-gaussian-pyramid"
import {OddPixelStrategy} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/composite/down-sample"
import {normalizedCrossCorrelation} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/composite/normalized-cross-correlation"
import {createImage} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/primitive/image-op-create-image"
import {CachedGaussianImagePyramid} from "@app/textures/texture-editor/operator-stack/image-op-system/utils/caching/cached-gaussian-image-pyramid"
import {DebugImage} from "@app/textures/texture-editor/operator-stack/image-op-system/utils/debug-image"
import {findCorrelationPeak} from "@app/textures/texture-editor/operator-stack/operators/tiling/helpers/find-correlation-peak"
import {penalizeByDistance} from "@app/textures/texture-editor/operator-stack/operators/tiling/helpers/penalize-by-distance"
import {Box2Like, Vector2, Vector2Like} from "@cm/math"

/** Name of the command-queue scope that wraps every operation recorded by {@link hierarchicalCrossCorrelation}. */
const SCOPE_NAME = "HierarchicalCrossCorrelation"

/** Parameters of {@link hierarchicalCrossCorrelation}. */
export type ParameterType = {
    /** Image that is searched for the best match of the template neighborhood. */
    sourceImage: ImageRef
    /** Reference position in the source image (in pixels); the search starts here. Floored before use. */
    sourceImageReferencePosition: Vector2Like
    /** Image providing the neighborhood to match; may be the very same ImageRef as `sourceImage`. */
    templateImage: ImageRef
    /** Reference position in the template image (in pixels); the matched neighborhood is centred here. Floored before use. */
    templateImageReferencePosition: Vector2Like
    /** Maximum number of pixels to search around the source reference position; determines the number of pyramid levels. */
    maxSearchRadius: number
    /** Passed to `penalizeByDistance` as `penaltyPerPixel`; a single number is used for both axes. */
    correlationPenaltyPerPixel: number | Vector2Like
    /** Passed through to `penalizeByDistance` unchanged. */
    penaltyDirectionX?: Vector2Like
    /** Optional cache: pyramids found here are reused, and the pyramids that were used are stored back into it. */
    cacheData?: CacheData
    /** Optional debug sink: receives intermediate images and enables per-level peak logging to the console. */
    debugImage?: DebugImage
    /** Currently not read by {@link hierarchicalCrossCorrelation}. */
    debugRectFn?: (rect: Box2Like, color: string) => void
}

/**
 * 1x1 RGB float32 image: (r, g) = matched position in the source image (in pixels),
 * b = distance-penalized peak correlation averaged over all pyramid levels.
 */
export type ReturnType = ImageRef

/**
 * Locates the best match of a template-image neighborhood in the source image with a coarse-to-fine search over
 * Gaussian pyramids of both images.
 *
 * The estimate starts at `sourceImageReferencePosition` scaled to the coarsest level. On every level, a
 * (2 * correlationNeighborhood + 1)^2 window around `templateImageReferencePosition` is correlated
 * (`normalizedCrossCorrelation`) against a slightly larger source window placed at the current estimate. The result
 * is penalized with `penalizeByDistance`, and its peak (`findCorrelationPeak`) updates the estimate, which is doubled
 * before moving on to the next finer level.
 *
 * The number of levels is `ceil(log2(maxSearchRadius))`, clamped to at least 1 and to at most the number of levels
 * the image sizes allow.
 *
 * @param cmdQueue Command queue on which all image operations are recorded, inside the `SCOPE_NAME` scope.
 *                 The second argument is described by {@link ParameterType}.
 * @returns A 1x1 RGB float32 image: (r, g) = matched position in the source image (in pixels),
 *          b = distance-penalized peak correlation averaged over all levels.
 * @throws Error if any (possibly cached) pyramid has fewer levels than required.
 */
export const hierarchicalCrossCorrelation = (
    cmdQueue: ImageOpCommandQueueWebGL2,
    {
        sourceImage,
        sourceImageReferencePosition,
        templateImage,
        templateImageReferencePosition,
        maxSearchRadius,
        correlationPenaltyPerPixel,
        penaltyDirectionX,
        cacheData,
        debugImage,
        // NOTE(review): `debugRectFn` is accepted by ParameterType but is not used here.
    }: ParameterType,
): ReturnType => {
    cmdQueue.beginScope(SCOPE_NAME)
    try {
        const correlationNeighborhood = 8 // half-size of the template window, in pixels, in each direction
        const searchOfs = -1 // first candidate offset per axis, relative to the current estimate
        const searchSize = 4 // number of candidate offsets per axis evaluated on each level

        const sourceReferencePosition = Vector2.fromVector2Like(sourceImageReferencePosition).floorInPlace()
        const templateReferencePosition = Vector2.fromVector2Like(templateImageReferencePosition).floorInPlace()

        const minSourceSize = Math.min(sourceImage.descriptor.width, sourceImage.descriptor.height)
        const minTemplateSize = Math.min(templateImage.descriptor.width, templateImage.descriptor.height)
        const maxLevels = Math.ceil(Math.log2(Math.min(minSourceSize, minTemplateSize))) + 1
        // Use at least one level: with zero levels the refinement loop would not run and the returned position would
        // be the untouched initial estimate (reference / 2^-1), i.e. twice the reference position.
        const numLevels = Math.max(1, Math.min(maxLevels, Math.ceil(Math.log2(maxSearchRadius))))

        const sigma = 0
        const oddPixelStrategy: OddPixelStrategy = "zero"
        const buildPyramid = (image: ImageRef) => createGaussianPyramid(cmdQueue, {sourceImage: image, sigma, oddPixelStrategy})
        // Coverage pyramid: built from a single-channel all-ones image of the same size as `image`; its levels are
        // used as the weight images for normalizedCrossCorrelation.
        const buildCoveragePyramid = (image: ImageRef) =>
            buildPyramid(
                createImage(cmdQueue, {
                    imageOrDescriptor: {
                        ...image.descriptor,
                        channelLayout: "R",
                        dataType: "uint8",
                        options: undefined,
                    },
                    fillColor: {r: 1, g: 1, b: 1},
                }),
            )

        // source image pyramids
        const sourceImagePyramid = cacheData?.sourceImagePyramid ?? buildPyramid(sourceImage)
        const sourceCoverageImagePyramid = cacheData?.sourceCoverageImagePyramid ?? buildCoveragePyramid(sourceImage)
        if (sourceImagePyramid.descriptor.levels < numLevels || sourceCoverageImagePyramid.descriptor.levels < numLevels) {
            throw new Error("Invalid number of levels in source image pyramid")
        }

        // template image pyramids; when matching an image against itself the source pyramids are reused.
        // The conditional is deliberately parenthesized on the right of `??`: `a ?? b ? c : d` parses as
        // `(a ?? b) ? c : d`, which would replace a cached template pyramid with the source pyramid.
        const isSelfMatch = templateImage === sourceImage
        const templateImagePyramid =
            cacheData?.templateImagePyramid ?? (isSelfMatch ? sourceImagePyramid : buildPyramid(templateImage))
        const templateCoverageImagePyramid =
            cacheData?.templateCoverageImagePyramid ??
            (isSelfMatch ? sourceCoverageImagePyramid : buildCoveragePyramid(templateImage))
        if (templateImagePyramid.descriptor.levels < numLevels || templateCoverageImagePyramid.descriptor.levels < numLevels) {
            throw new Error("Invalid number of levels in template image pyramid")
        }

        // NOTE(review): CacheData.set disposes its current entries before storing these, and these may be the very
        // pyramids just read from the cache - confirm CachedGaussianImagePyramid tolerates dispose-then-set.
        cacheData?.set(cmdQueue, {
            sourceImagePyramid,
            sourceCoverageImagePyramid,
            templateImagePyramid,
            templateCoverageImagePyramid,
        })

        // Accumulates the estimate: input 0 = current (offset, correlation sum), input 1 = (peak value, peak index).
        const painterUpdateSourceOffset = cmdQueue.createPainter(
            "compositor",
            "updateSourceOffset",
            `
        uniform int u_level;
        uniform int u_numLevels;

        vec4 computeColor(ivec2 targetPixel) {
            vec3 sourceOffsetAndCorrelation = texelFetch0(targetPixel).rgb;
            ivec2 prevSourceOffset = ivec2(sourceOffsetAndCorrelation.rg);
            float prevCorrelation = sourceOffsetAndCorrelation.b;
            vec3 peakCorrelationAndOffset = texelFetch1(targetPixel).rgb;
            float peakCorrelation = peakCorrelationAndOffset.r;
            ivec2 peakOffset = ivec2(peakCorrelationAndOffset.gb);
            float newCorrelation = prevCorrelation + peakCorrelation;
            ivec2 newSourceOffset = prevSourceOffset + peakOffset + ivec2(${searchOfs});
            if (u_level != 0) {
                newSourceOffset *= 2;   // prepare for next level
            } else {
                // divide newCorrelation by numLevels to get the final averaged correlation across levels
                newCorrelation /= float(u_numLevels);
            }
            return vec4(vec2(newSourceOffset), newCorrelation, 1);
        }
    `,
        )

        // The estimate starts at the source reference position expressed in coarsest-level pixels.
        const initialSourceOffset = Vector2.fromVector2Like(sourceReferencePosition)
            .divInPlace(2 ** (numLevels - 1))
            .floorInPlace()
        let sourceOffsetAndCorrelation = createImage(cmdQueue, {
            imageOrDescriptor: {
                width: 1,
                height: 1,
                channelLayout: "RGB",
                dataType: "float32",
            },
            fillColor: {r: initialSourceOffset.x, g: initialSourceOffset.y, b: 0},
        })
        const vecCorrelationPenaltyPerPixel =
            typeof correlationPenaltyPerPixel === "number"
                ? new Vector2(correlationPenaltyPerPixel, correlationPenaltyPerPixel)
                : Vector2.fromVector2Like(correlationPenaltyPerPixel)

        for (let level = numLevels - 1; level >= 0; level--) {
            const levelScale = 2 ** level
            const targetOffset = Vector2.fromVector2Like(templateReferencePosition).divInPlace(levelScale).floorInPlace()

            // correlate the template window against every candidate offset around the current estimate
            const correlation = normalizedCrossCorrelation(cmdQueue, {
                sourceImage: sourceImagePyramid.resultImages[level],
                sourceWeightImage: sourceCoverageImagePyramid.resultImages[level],
                sourceRegion: {
                    offsetImage: sourceOffsetAndCorrelation,
                    offset: {x: -correlationNeighborhood + searchOfs, y: -correlationNeighborhood + searchOfs},
                    width: 1 + 2 * correlationNeighborhood + searchSize - 1,
                    height: 1 + 2 * correlationNeighborhood + searchSize - 1,
                },
                templateImage: templateImagePyramid.resultImages[level],
                templateWeightImage: templateCoverageImagePyramid.resultImages[level],
                templateRegion: {
                    x: targetOffset.x - correlationNeighborhood,
                    y: targetOffset.y - correlationNeighborhood,
                    width: 1 + 2 * correlationNeighborhood,
                    height: 1 + 2 * correlationNeighborhood,
                },
                options: {premultipliedImages: true},
                debugImage: debugImage,
            })
            const distancePenalizedCorrelation = penalizeByDistance(cmdQueue, {
                correlation: correlation,
                offsetImage: sourceOffsetAndCorrelation,
                offset: {x: searchOfs, y: searchOfs},
                referencePosition: sourceReferencePosition,
                imageSize: sourceImage.descriptor,
                level: level,
                penaltyPerPixel: vecCorrelationPenaltyPerPixel,
                penaltyDirectionX: penaltyDirectionX,
            })
            debugImage?.addImage(distancePenalizedCorrelation, {scale: 0.5, offset: 0.5})

            const maxValueAndIndex = findCorrelationPeak(cmdQueue, {correlation: distancePenalizedCorrelation})
            if (debugImage) {
                cmdQueue.lambda({maxValueAndIndex}, async ({maxValueAndIndex}) => {
                    const maxValueAndIndexImage = await cmdQueue.context.getImage(maxValueAndIndex)
                    const data = await maxValueAndIndexImage.ref.halImage.readImageDataFloat()
                    const maxValue = data[0]
                    const maxIndex = new Vector2(data[1], data[2])
                    console.log("Peak found at level", level, ":", maxIndex, maxValue)
                })
            }

            // fold the peak into the running estimate (doubled for the next level, averaged on level 0)
            const nextSourceOffsetAndCorrelation = createImage(cmdQueue, {
                imageOrDescriptor: sourceOffsetAndCorrelation,
                fillColor: undefined,
            })
            cmdQueue.paint(painterUpdateSourceOffset, {
                parameters: {
                    u_level: {type: "int", value: level},
                    u_numLevels: {type: "int", value: numLevels},
                },
                sourceImages: [sourceOffsetAndCorrelation, maxValueAndIndex],
                resultImage: nextSourceOffsetAndCorrelation,
            })
            sourceOffsetAndCorrelation = nextSourceOffsetAndCorrelation
        }
        return sourceOffsetAndCorrelation
    } finally {
        // keep the scope stack balanced even when a pyramid level check throws
        cmdQueue.endScope(SCOPE_NAME)
    }
}

/**
 * Cache for the four Gaussian pyramids used by {@link hierarchicalCrossCorrelation}: source image, source coverage,
 * template image and template coverage. Each pyramid lives in its own {@link CachedGaussianImagePyramid}.
 */
export class CacheData {
    private readonly _sourceImagePyramid = new CachedGaussianImagePyramid()
    private readonly _sourceCoverageImagePyramid = new CachedGaussianImagePyramid()
    private readonly _templateImagePyramid = new CachedGaussianImagePyramid()
    private readonly _templateCoverageImagePyramid = new CachedGaussianImagePyramid()

    /** Cached source image pyramid (`getIfExists()` of the underlying cache). */
    get sourceImagePyramid() {
        return this._sourceImagePyramid.getIfExists()
    }

    /** Cached source coverage pyramid (`getIfExists()` of the underlying cache). */
    get sourceCoverageImagePyramid() {
        return this._sourceCoverageImagePyramid.getIfExists()
    }

    /** Cached template image pyramid (`getIfExists()` of the underlying cache). */
    get templateImagePyramid() {
        return this._templateImagePyramid.getIfExists()
    }

    /** Cached template coverage pyramid (`getIfExists()` of the underlying cache). */
    get templateCoverageImagePyramid() {
        return this._templateCoverageImagePyramid.getIfExists()
    }

    /**
     * Replaces all cached pyramids: first disposes the current entries, then hands each entry of `data`
     * (possibly `undefined`) to its cache via `set`.
     *
     * NOTE(review): callers may pass back the same pyramids they just read from this cache; confirm that
     * CachedGaussianImagePyramid tolerates being disposed and then set to that same pyramid.
     */
    set(
        cmdQueue: ImageOpCommandQueueWebGL2,
        data: {
            sourceImagePyramid?: GaussianImagePyramidReturnType
            sourceCoverageImagePyramid?: GaussianImagePyramidReturnType
            templateImagePyramid?: GaussianImagePyramidReturnType
            templateCoverageImagePyramid?: GaussianImagePyramidReturnType
        },
    ) {
        this.dispose()
        const {sourceImagePyramid, sourceCoverageImagePyramid, templateImagePyramid, templateCoverageImagePyramid} = data
        this._sourceImagePyramid.set(cmdQueue, sourceImagePyramid)
        this._sourceCoverageImagePyramid.set(cmdQueue, sourceCoverageImagePyramid)
        this._templateImagePyramid.set(cmdQueue, templateImagePyramid)
        this._templateCoverageImagePyramid.set(cmdQueue, templateCoverageImagePyramid)
    }

    /** Disposes all four cached pyramids (source, source coverage, template, template coverage — in that order). */
    dispose() {
        const caches = [
            this._sourceImagePyramid,
            this._sourceCoverageImagePyramid,
            this._templateImagePyramid,
            this._templateCoverageImagePyramid,
        ]
        for (const cache of caches) {
            cache.dispose()
        }
    }
}
