// TODO: Select as first centroid the color with the highest pixel count
// TODO: Select largest groups of colors as centroids (or 50% of the k)

// K-means clustering algorithm
import {scan as d3Scan} from 'd3-array';
import {ColorDistance} from './colorDistance';

/**
 * Public entry points for k-means color quantization.
 *
 * Every method forwards to `doKMCQuantization` and returns whatever it returns
 * (a Uint8Array of RGBA values with the same length as `pixels`).
 */
export const kmc = {
    /**
     * Quantizes `pixels` using a caller-selected color comparison method.
     *
     * Unlike the two convenience methods below, `maxIterations` is the LAST
     * positional parameter here; it is moved into place when forwarding.
     *
     * @param {Uint8Array} pixels - RGBA pixel data
     * @param {number} k - Number of colors to reduce to
     * @param {string} [centroidInitialization='random'] - Centroid initialization strategy
     * @param {string} [colorComparisionMethod='Euclidean'] - Name of the color distance method
     * @param {number} [numDistributedColors] - Forwarded to the centroid initializer
     * @param {number} [colorSelectionPool] - Forwarded to the centroid initializer
     * @param {number} [maxIterations=100] - Upper bound on k-means iterations
     */
    quantize(pixels, k, centroidInitialization = 'random', colorComparisionMethod = 'Euclidean', numDistributedColors, colorSelectionPool, maxIterations = 100) {
        return doKMCQuantization(
            pixels,
            k,
            centroidInitialization,
            colorComparisionMethod,
            maxIterations,
            numDistributedColors,
            colorSelectionPool
        );
    },

    /**
     * Quantizes `pixels` with the comparison method fixed to 'Euclidean'.
     *
     * @param {Uint8Array} pixels - RGBA pixel data
     * @param {number} k - Number of colors to reduce to
     * @param {string} [centroidInitialization='random'] - Centroid initialization strategy
     * @param {number} [maxIterations=100] - Upper bound on k-means iterations
     * @param {number} [numDistributedColors] - Forwarded to the centroid initializer
     * @param {number} [colorSelectionPool] - Forwarded to the centroid initializer
     */
    quantizeEuclidean(pixels, k, centroidInitialization = 'random', maxIterations = 100, numDistributedColors, colorSelectionPool) {
        return doKMCQuantization(
            pixels,
            k,
            centroidInitialization,
            'Euclidean',
            maxIterations,
            numDistributedColors,
            colorSelectionPool
        );
    },

    /**
     * Quantizes `pixels` with the comparison method fixed to 'CIEDE2000'.
     *
     * @param {Uint8Array} pixels - RGBA pixel data
     * @param {number} k - Number of colors to reduce to
     * @param {string} [centroidInitialization='random'] - Centroid initialization strategy
     * @param {number} [maxIterations=100] - Upper bound on k-means iterations
     * @param {number} [numDistributedColors] - Forwarded to the centroid initializer
     * @param {number} [colorSelectionPool] - Forwarded to the centroid initializer
     */
    quantizeCIEDE2000(pixels, k, centroidInitialization = 'random', maxIterations = 100, numDistributedColors, colorSelectionPool) {
        return doKMCQuantization(
            pixels,
            k,
            centroidInitialization,
            'CIEDE2000',
            maxIterations,
            numDistributedColors,
            colorSelectionPool
        );
    },
};

/**
 * Reduces an RGBA pixel buffer to at most `k` colors using k-means clustering.
 *
 * The alpha channel of the input is ignored and every output pixel is written
 * fully opaque (alpha = 255). A one-line diagnostic is written to the console
 * on every call.
 *
 * @param {Uint8Array} pixels - Pixel data laid out as consecutive [r, g, b, a] groups
 * @param {number} k - Number of centroids (colors) to quantize to
 * @param {string} [centroidInitialization='random'] - 'random', 'k-means++' or
 *     'top n distributed' (matched case-insensitively)
 * @param {string} [colorComparisionMethod='Euclidean'] - Method name handed to `ColorDistance.distance`
 * @param {number} [maxIterations=100] - Upper bound on assignment/update rounds
 * @param {number} [numDistributedColors] - Only used by 'top n distributed'
 * @param {number} [colorSelectionPool] - Only used by 'top n distributed'
 * @returns {Uint8Array} New RGBA buffer, same length as `pixels`, where each pixel
 *     carries the (rounded) color of its nearest final centroid
 * @throws {Error} If `centroidInitialization` is not one of the supported names
 */
function doKMCQuantization(pixels, k, centroidInitialization = 'random', colorComparisionMethod = 'Euclidean', maxIterations = 100, numDistributedColors, colorSelectionPool) {
    console.log(`KMC Quantization. Centroid Initialization: ${centroidInitialization}, Max Iterations: ${maxIterations}, K: ${k}, Pixels: ${pixels.length / 4}, Color Comparison Method: ${colorComparisionMethod}`
        + `, Num Distributed Colors: ${numDistributedColors}, Color Selection Pool: ${colorSelectionPool}`);

    function distance(color1, color2) {
        return ColorDistance.distance(color1, color2, colorComparisionMethod);
    }

    function initCentroids(pixels, k, numDistributedColors, colorSelectionPool) {
        switch (centroidInitialization.toLowerCase()) {
            case 'random':
                return CentroidInitialization.initCentroidsRandom(pixels, k);
            case 'k-means++':
                return CentroidInitialization.initCentroidsPlusPlus(pixels, k, distance);
            case 'top n distributed':
                return CentroidInitialization.initCentroidsTopNDistributed(pixels, k, numDistributedColors, colorSelectionPool, distance);
            default:
                throw new Error('Invalid centroid initialization method');
        }
    }

    // Never work with more than k centroids, even if an initializer hands back
    // extras; otherwise the palette would exceed the requested size.
    const centroids = initCentroids(pixels, k, numDistributedColors, colorSelectionPool).slice(0, k);

    // Index of the centroid closest to `color`. Ties go to the lowest index.
    function nearestCentroidIndex(color) {
        let bestIndex = 0;
        let bestDistance = distance(color, centroids[0]);
        for (let c = 1; c < centroids.length; c++) {
            const candidate = distance(color, centroids[c]);
            if (candidate < bestDistance) {
                bestDistance = candidate;
                bestIndex = c;
            }
        }
        return bestIndex;
    }

    let iteration = 0;
    let centroidsChanged = true;

    while (centroidsChanged && iteration < maxIterations) {
        // Running per-cluster channel sums and member counts; the mean is all
        // that is needed, so the member colors themselves are not stored.
        const sums = centroids.map(() => [0, 0, 0]);
        const counts = centroids.map(() => 0);

        // Assign each color to the nearest centroid
        for (let i = 0; i < pixels.length; i += 4) {
            const color = [pixels[i], pixels[i + 1], pixels[i + 2]];
            const target = nearestCentroidIndex(color);

            sums[target][0] += color[0];
            sums[target][1] += color[1];
            sums[target][2] += color[2];
            counts[target]++;
        }

        // Update centroids
        centroidsChanged = false;
        for (let c = 0; c < centroids.length; c++) {
            if (counts[c] === 0) {
                // An empty cluster has no mean (0 / 0 would give NaN and poison
                // every later distance comparison), so keep its previous position.
                continue;
            }

            const newCentroid = [
                sums[c][0] / counts[c],
                sums[c][1] / counts[c],
                sums[c][2] / counts[c],
            ];
            // Movements of at most 1 distance unit are treated as converged.
            if (distance(newCentroid, centroids[c]) > 1) {
                centroidsChanged = true;
            }
            centroids[c] = newCentroid;
        }

        iteration++;
    }

    const reducedPixels = new Uint8Array(pixels.length);

    // Assign the final centroid colors to the corresponding pixels
    for (let i = 0; i < pixels.length; i += 4) {
        const color = [pixels[i], pixels[i + 1], pixels[i + 2]];
        const centroid = centroids[nearestCentroidIndex(color)];

        // Uint8Array stores truncate fractions, so round explicitly to get the
        // closest representable channel value.
        reducedPixels[i] = Math.round(centroid[0]);
        reducedPixels[i + 1] = Math.round(centroid[1]);
        reducedPixels[i + 2] = Math.round(centroid[2]);
        reducedPixels[i + 3] = 255; // Assuming fully opaque pixels
    }

    return reducedPixels;
}

/**
 * Reads the [r, g, b] components of pixel number `pixelIndex` from an RGBA buffer.
 */
function readPixelColor(pixels, pixelIndex) {
    const offset = pixelIndex * 4;
    return [pixels[offset], pixels[offset + 1], pixels[offset + 2]];
}

/**
 * Finds the pixel whose distance to its nearest centroid is largest.
 *
 * Ties go to the lowest pixel index. When every pixel is equally far away
 * (for example all distances are 0 because the image has no unseen colors, or
 * `centroids` is empty) pixel 0 is returned, so callers always receive a
 * usable index.
 */
function farthestPixelIndex(pixels, centroids, colorDistanceFunction) {
    const pixelCount = pixels.length / 4;
    let bestIndex = 0;
    let bestDistance = -Infinity;

    for (let p = 0; p < pixelCount; p++) {
        const pixel = readPixelColor(pixels, p);
        let nearestCentroidDist = Infinity;
        for (const centroid of centroids) {
            nearestCentroidDist = Math.min(nearestCentroidDist, colorDistanceFunction(centroid, pixel));
        }

        if (nearestCentroidDist > bestDistance) {
            bestDistance = nearestCentroidDist;
            bestIndex = p;
        }
    }

    return bestIndex;
}

/**
 * Strategies for choosing the initial k-means centroids from RGBA pixel data.
 * Every strategy returns an array of [r, g, b] arrays.
 */
export const CentroidInitialization = {
    /**
     * Picks `k` pixels uniformly at random (with replacement, so duplicate
     * colors are possible) and uses their colors as centroids.
     *
     * @param {Uint8Array} pixels - The input pixel data in RGBA format
     * @param {number} k - The number of centroids to initialize
     * @returns {Array<Array<number>>} An array of k centroids represented as [r, g, b] arrays
     */
    initCentroidsRandom(pixels, k) {
        const centroids = [];
        const pixelCount = pixels.length / 4;

        for (let i = 0; i < k; i++) {
            centroids.push(readPixelColor(pixels, Math.floor(Math.random() * pixelCount)));
        }

        return centroids;
    },

    /**
     * Picks the first centroid at random and every following centroid as the
     * pixel that is farthest from all centroids chosen so far.
     *
     * NOTE: despite the name this is a deterministic farthest-point variant;
     * classic k-means++ would instead SAMPLE the next pixel with probability
     * proportional to its squared distance. The existing selection rule is kept
     * so results stay reproducible for a given first centroid.
     *
     * @param {Uint8Array} pixels - The input pixel data in RGBA format
     * @param {number} k - The number of centroids to initialize
     * @param {function} colorDistanceFunction - A function that calculates the distance between two colors
     * @returns {Array<Array<number>>} An array of centroids represented as [r, g, b] arrays
     */
    initCentroidsPlusPlus(pixels, k, colorDistanceFunction) {
        const pixelCount = pixels.length / 4;
        const centroids = [readPixelColor(pixels, Math.floor(Math.random() * pixelCount))];

        while (centroids.length < k) {
            centroids.push(readPixelColor(pixels, farthestPixelIndex(pixels, centroids, colorDistanceFunction)));
        }

        return centroids;
    },

    /**
     * Initializes centroids by taking the `colorSelectionPool` most frequent colors and greedily
     * selecting up to `numDistributedColors` of them as mutually distant pairs (farthest pair
     * first). The remaining centroids are then filled with the pixel that is farthest from all
     * centroids chosen so far, one at a time.
     *
     * The number of distributed colors is capped by `k` and by the number of colors available
     * in the pool, so the result never exceeds `k` centroids and selection always terminates.
     *
     * @param {Uint8Array} pixels - The input pixel data in RGBA format
     * @param {number} k - The number of centroids to initialize
     * @param {number} numDistributedColors - The number of top colors with the most pixels to consider for distribution
     * @param {number} colorSelectionPool - The number of colors with the highest pixel count to choose from
     * @param {function} colorDistanceFunction - A function that calculates the distance between two colors
     * @returns {Array<Array<number>>} An array of k centroids represented as [r, g, b] arrays
     */
    initCentroidsTopNDistributed(pixels, k, numDistributedColors, colorSelectionPool, colorDistanceFunction) {
        const pixelCount = pixels.length / 4;

        // Calculate the histogram of colors
        const colorHistogram = new Map();
        for (let p = 0; p < pixelCount; p++) {
            const colorKey = readPixelColor(pixels, p).join(',');
            colorHistogram.set(colorKey, (colorHistogram.get(colorKey) || 0) + 1);
        }

        // Select the top M colors with the highest pixel counts
        const topMColors = Array.from(colorHistogram.entries())
            .sort((a, b) => b[1] - a[1])
            .slice(0, colorSelectionPool)
            .map(([colorKey]) => colorKey.split(',').map(Number));

        // A missing/invalid count means "no distributed colors"; never aim for
        // more than k centroids or more colors than the pool can supply.
        const requestedCount = Number(numDistributedColors) || 0;
        const targetCount = Math.min(requestedCount, k, topMColors.length);

        const selectedCentroids = new Set();

        if (targetCount > 0) {
            // All pairwise distances inside the pool, farthest pair first. The
            // sort is stable, so equal distances keep their (i, j) order.
            const pairs = [];
            for (let i = 0; i < topMColors.length; i++) {
                for (let j = i + 1; j < topMColors.length; j++) {
                    const dist = colorDistanceFunction(topMColors[i], topMColors[j]);
                    if (!Number.isNaN(dist)) {
                        pairs.push({first: i, second: j, dist});
                    }
                }
            }
            pairs.sort((a, b) => b.dist - a.dist);

            // Walking the sorted list once is equivalent to repeatedly searching
            // for the farthest pair whose colors are both still unselected: a
            // pair that is rejected can never become eligible again.
            for (const {first, second} of pairs) {
                if (selectedCentroids.size >= targetCount) {
                    break;
                }
                if (selectedCentroids.has(first) || selectedCentroids.has(second)) {
                    continue;
                }

                selectedCentroids.add(first);
                // Only take the partner when there is still room (odd targets).
                if (selectedCentroids.size < targetCount) {
                    selectedCentroids.add(second);
                }
            }

            // A pool with a single color has no pairs; seed with the most
            // frequent color instead of selecting nothing.
            if (selectedCentroids.size === 0) {
                selectedCentroids.add(0);
            }
        }

        // Add the selected centroids
        const centroids = Array.from(selectedCentroids, (index) => topMColors[index]);

        // Complete the centroids with the farthest remaining pixels
        while (centroids.length < k) {
            centroids.push(readPixelColor(pixels, farthestPixelIndex(pixels, centroids, colorDistanceFunction)));
        }

        return centroids;
    },
};