import {CmLogger} from "#utils/log"

/**
 * A function that can appear in a graph, together with hints read by `compileFunctionGraph`.
 *
 * - `fn`: called with the values of the node's arguments, in `args` order; its awaited result
 *   becomes the node's value.
 * - `inPlaceArgs`: argument indices whose storage slot may be reused for this node's output when
 *   this node is the last remaining consumer of that argument. NOTE(review): presumably this means
 *   `fn` may mutate/return that argument in place — confirm with the function authors.
 * - `memCost` / `cpuCost`: optional cost estimates summed during compilation (`memCost` only when
 *   a new output slot is allocated). Units are not defined in this file.
 */
export type FunctionInfo = {fn: (...args: any[]) => any; inPlaceArgs?: number[]; memCost?: number; cpuCost?: number}
/**
 * A graph node: `fn` applied to the values of `args`. The same node object may be shared by several
 * parents; it is then compiled and evaluated only once.
 */
export type Node = {fn: FunctionInfo; args: Node[]}

/**
 * One compiled instruction: either call `fn` with the values held in `argSlots` and store the result
 * in `outSlot`, or set `clearSlot` to `undefined` to drop the value it holds.
 */
type Entry = {fn: FunctionInfo["fn"]; argSlots: number[]; outSlot: number} | {clearSlot: number}

/**
 * Counts how often each node reachable from `root` is used as an argument.
 *
 * Every occurrence in an `args` list counts separately (so `[a, a]` adds 2 to `a`), and `root`
 * itself starts at 1. A node's children are walked only on its first visit, so a shared subgraph
 * is not re-walked.
 *
 * @param root Entry node of the graph.
 * @returns Map from every reachable node to its reference count, keyed in first-visit (pre-order) order.
 */
function countReferences(root: Node) {
    const counts = new Map<Node, number>()
    const visit = (current: Node): void => {
        const previous = counts.get(current) ?? 0
        counts.set(current, previous + 1)
        if (previous > 0) {
            // already expanded on an earlier visit
            return
        }
        for (const child of current.args) {
            visit(child)
        }
    }
    visit(root)
    return counts
}

/**
 * Compiles a function graph into a flat instruction list and returns an async runner for it.
 *
 * The graph is walked depth-first in post-order: each node's arguments are compiled before the
 * node itself, and a node reachable through several parents is compiled (and later evaluated)
 * only once. Every node's value lives in a numbered slot. As soon as the last node reading a value
 * has been emitted, a `clearSlot` instruction is emitted for it and the slot is handed to later
 * nodes, so the runner only needs as many slots as values that are alive at the same time. If
 * `fn.inPlaceArgs` lists an argument for which the current node is the last remaining consumer,
 * the node's output is written into that argument's slot instead of a new one.
 *
 * @param root   Node whose value the returned runner resolves to.
 * @param logger Receives the failing instruction index and a dump of all instructions when a
 *               function throws while the runner executes.
 * @returns A runner that calls the functions one after another (each is awaited before the next
 *          starts) and resolves to the root node's value. Each call uses fresh slot storage.
 *          Errors thrown by a function are logged and rethrown unchanged.
 * @throws Error if the graph reachable from `root` contains a cycle.
 */
export function compileFunctionGraph(root: Node, logger: CmLogger): () => Promise<any> {
    // const compileBegin = performance.now();

    const refCounts = countReferences(root)
    const slotForNode = new Map<Node, number>()
    // Nodes whose arguments are still being compiled; reaching one of them again means a cycle.
    const inProgress = new Set<Node>()
    const entries: Entry[] = []
    // Released slot indices, reused before a brand-new index is created (O(1) instead of a scan).
    const freeSlots: number[] = []
    let slotCount = 0
    let totalMemCost = 0
    let totalCpuCost = 0

    const traverse = (node: Node): number => {
        const existingSlot = slotForNode.get(node)
        if (existingSlot != null) {
            // already compiled through another parent: share its value
            return existingSlot
        }
        if (inProgress.has(node)) {
            // Without this check a cyclic graph recurses until the call stack overflows.
            throw new Error("compileFunctionGraph: function graph contains a cycle")
        }
        inProgress.add(node)
        const argSlots = node.args.map(traverse)
        inProgress.delete(node)

        let outSlot = -1
        for (const idx of node.fn.inPlaceArgs ?? []) {
            // A remaining count of 1 means this node is the argument's last consumer, so
            // writing the output into its slot cannot clobber a value that is still needed.
            if (refCounts.get(node.args[idx]) === 1) {
                outSlot = argSlots[idx]
            }
        }
        if (outSlot === -1) {
            // allocate a slot, preferring one that has been released
            totalMemCost += node.fn.memCost ?? 0
            outSlot = freeSlots.pop() ?? slotCount++
        }
        totalCpuCost += node.fn.cpuCost ?? 0
        entries.push({fn: node.fn.fn, argSlots, outSlot})
        slotForNode.set(node, outSlot)

        // This node has consumed each argument once; release slots nobody will read again.
        for (let idx = 0; idx < argSlots.length; idx++) {
            const arg = node.args[idx]
            const argSlot = argSlots[idx]
            const refCount = refCounts.get(arg)! - 1
            refCounts.set(arg, refCount)
            if (refCount === 0 && argSlot !== outSlot) {
                // release slot (an in-place slot now belongs to this node and stays live)
                entries.push({clearSlot: argSlot})
                freeSlots.push(argSlot)
            }
        }

        return outSlot
    }
    traverse(root)
    // console.log(`Compilation time: ${Math.round(performance.now() - compileBegin)} ms`);
    // console.log(`Estimated cost factors: mem = ${Math.round(totalMemCost*100)/100} / cpu = ${Math.round(totalCpuCost*100)/100}`);
    const numSlots = slotCount
    return async () => {
        const slots = new Array(numSlots).fill(undefined)
        // The root is compiled last, so the last call's result is the graph's value.
        let lastOut: any
        let idx = 0
        // const begin = performance.now();
        try {
            for (const entry of entries) {
                if ("fn" in entry) {
                    // const fnBegin = performance.now();
                    lastOut = await entry.fn(...entry.argSlots.map((x) => slots[x]))
                    // console.log(`Function ${entry.fn.name} took ${Math.round(performance.now() - fnBegin)} ms`);
                    slots[entry.outSlot] = lastOut
                } else {
                    slots[entry.clearSlot] = undefined
                }
                ++idx
            }
        } catch (e) {
            logger.error(`At index ${idx}:`)
            entries.forEach((x, i) => logger.log(i.toString(), "fn" in x ? x.fn.name : undefined, x))
            throw e
        }
        // console.log(`Total execution time: ${Math.round(performance.now() - begin)} ms`);
        return lastOut
    }
}
