package de.lmu.ifi.dbs.elki.application.greedyensemble;

import de.lmu.ifi.dbs.elki.application.AbstractApplication;
import de.lmu.ifi.dbs.elki.data.NumberVector;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.ids.ArrayModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.ids.DBID;
import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDMIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDRef;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.HashSetModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDoubleDistanceFunction;
import de.lmu.ifi.dbs.elki.distance.distancefunction.correlation.WeightedPearsonCorrelationDistanceFunction;
import de.lmu.ifi.dbs.elki.evaluation.roc.ROC;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.math.MeanVariance;
import de.lmu.ifi.dbs.elki.math.geometry.XYCurve;
import de.lmu.ifi.dbs.elki.utilities.DatabaseUtil;
import de.lmu.ifi.dbs.elki.utilities.datastructures.heap.TiedTopBoundedHeap;
import de.lmu.ifi.dbs.elki.utilities.datastructures.heap.TopBoundedHeap;
import de.lmu.ifi.dbs.elki.utilities.documentation.Reference;
import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.pairs.DoubleIntPair;
import de.lmu.ifi.dbs.elki.utilities.pairs.DoubleObjPair;
import de.lmu.ifi.dbs.elki.workflow.InputStep;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.Set;
import java.util.TreeSet;

@Reference(authors = "E. Schubert, R. Wojdanowski, A. Zimek, H.-P. Kriegel", title = "On Evaluation of Outlier Rankings and Outlier Scores", booktitle = "Proc. 12th SIAM International Conference on Data Mining (SDM), Anaheim, CA, 2012.")
/* loaded from: input_file:de/lmu/ifi/dbs/elki/application/greedyensemble/GreedyEnsembleExperiment.class */
public class GreedyEnsembleExperiment extends AbstractApplication {
    /** Class logger; assigned in the static initializer at the bottom of this class. */
    private static final Logging logger;
    /** Input step that supplies the database read by {@link #run()}; set by the constructor. */
    private InputStep inputstep;
    /**
     * When true, {@link #run()} decrements the per-dimension outlier counts of every rejected
     * ensemble candidate and recomputes the estimated target. The constructor sets this to false.
     */
    boolean refine_truth;
    /**
     * Explicit copy of the flag javac normally synthesizes for {@code assert}; assigned in the
     * static initializer and checked manually before each {@code AssertionError} in {@link #run()}.
     */
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:de/lmu/ifi/dbs/elki/application/greedyensemble/GreedyEnsembleExperiment$Parameterizer.class */
    /**
     * Parameterization helper: instantiates the {@link InputStep} from the given
     * parameterization and builds the {@link GreedyEnsembleExperiment} from it.
     */
    public static class Parameterizer extends AbstractApplication.Parameterizer {
        /** Input step obtained in {@link #makeOptions}; handed to the application in {@link #makeInstance}. */
        InputStep inputstep;

        /**
         * Lets the superclass register its options, then tries to instantiate the input step.
         *
         * @param parameterization parameterization to read the options from
         */
        @Override // de.lmu.ifi.dbs.elki.application.AbstractApplication.Parameterizer, de.lmu.ifi.dbs.elki.utilities.optionhandling.AbstractParameterizer
        public void makeOptions(Parameterization parameterization) {
            super.makeOptions(parameterization);
            final InputStep step = (InputStep) parameterization.tryInstantiate(InputStep.class);
            this.inputstep = step;
        }

        /**
         * Creates the application from the verbose flag and the instantiated input step.
         *
         * @return a new {@link GreedyEnsembleExperiment}
         */
        @Override // de.lmu.ifi.dbs.elki.application.AbstractApplication.Parameterizer, de.lmu.ifi.dbs.elki.utilities.optionhandling.AbstractParameterizer
        public AbstractApplication makeInstance() {
            final GreedyEnsembleExperiment experiment = new GreedyEnsembleExperiment(this.verbose, this.inputstep);
            return experiment;
        }
    }

    /**
     * Creates the experiment.
     *
     * @param z         flag forwarded unchanged to the {@link AbstractApplication} constructor
     * @param inputStep input step that {@link #run()} obtains the database from
     */
    public GreedyEnsembleExperiment(boolean z, InputStep inputStep) {
        super(z);
        // Truth refinement is switched off for every instance created through this constructor.
        this.refine_truth = false;
        this.inputstep = inputStep;
    }

    /**
     * Runs the experiment on the database delivered by the input step.
     *
     * <p>Steps, in order:
     * <ol>
     * <li>Take the first object of the label relation as reference vector; abort unless its label
     * matches {@code bylabel}. Dimensions where the reference is positive form the positive set
     * used for ROC AUC.</li>
     * <li>For every other vector, collect its top {@code 0.5%} dimensions in a bounded heap and
     * count per dimension how often it was collected.</li>
     * <li>Derive weights and an estimated target vector from these counts
     * ({@link #updateEstimations}) and build the weighted distance function.</li>
     * <li>Compute the naive ensemble (mean of all non-reference vectors) and per-vector
     * statistics (best ROC AUC, best cost to the reference, best distance to the estimate).</li>
     * <li>Greedily grow an ensemble, starting from the vector closest to the estimate: a candidate
     * is added when the resulting mean is strictly closer to the estimate than the current
     * mean.</li>
     * <li>Compare against 5000 random ensembles of the same size and log all results.</li>
     * </ol>
     *
     * @throws AbortException if the first label does not match {@code bylabel}, or if no vector
     *         other than the reference could be selected to seed the ensemble
     */
    @Override // de.lmu.ifi.dbs.elki.application.AbstractApplication
    public void run() {
        Database database = this.inputstep.getDatabase();
        Relation relation = database.getRelation(TypeUtil.NUMBER_VECTOR_FIELD, new Object[0]);
        Relation<String> labels = DatabaseUtil.guessLabelRepresentation(database);
        // The reference ("by label") vector is expected to be the first object.
        DBIDRef truthId = labels.iterDBIDs().getDBID();
        if (!labels.get(truthId).matches("bylabel")) {
            throw new AbortException("No 'by label' reference outlier found, which is needed for weighting!");
        }
        int dim = DatabaseUtil.dimensionality(relation);
        NumberVector truthVector = (NumberVector) relation.get(truthId);
        // Zero-based indexes of all dimensions with a positive reference value.
        Set<Integer> positiveIds = new TreeSet<Integer>();
        for (int k = 0; k < dim; k++) {
            // doubleValue is addressed with 1-based dimension numbers throughout this method.
            if (truthVector.doubleValue(k + 1) > 0.0d) {
                positiveIds.add(Integer.valueOf(k));
            }
        }

        // Count, per dimension, how many vectors have it among their top entries.
        int topK = (int) (0.005d * dim);
        int unionSize = 0;
        int[] outlierCounts = new int[dim];
        for (DBIDIter iter = relation.iterDBIDs(); iter.valid(); iter.advance()) {
            DBID id = iter.getDBID();
            if (truthId.sameDBID(id)) {
                continue;
            }
            NumberVector vector = (NumberVector) relation.get(id);
            TiedTopBoundedHeap topHeap = new TiedTopBoundedHeap(topK, Collections.reverseOrder());
            for (int k = 0; k < dim; k++) {
                topHeap.add(new DoubleIntPair(vector.doubleValue(k + 1), k));
            }
            if (topHeap.size() >= 2 * topK) {
                logger.warning("Too many ties. Expected: " + topK + " got: " + topHeap.size());
            }
            Iterator topIter = topHeap.iterator();
            while (topIter.hasNext()) {
                DoubleIntPair entry = (DoubleIntPair) topIter.next();
                if (outlierCounts[entry.second] == 0) {
                    // First time this dimension shows up: it enlarges the union.
                    unionSize++;
                }
                outlierCounts[entry.second]++;
            }
        }
        logger.verbose("Merged top " + topK + " outliers to: " + unionSize + " outliers");

        // Weights and estimated target derived from the counts.
        double[] weights = new double[dim];
        double[] estimatedTruth = new double[dim];
        updateEstimations(outlierCounts, unionSize, weights, estimatedTruth);
        NumberVector estimatedVector = truthVector.newNumberVector(estimatedTruth);
        PrimitiveDoubleDistanceFunction<NumberVector<?, ?>> distanceFunction = getDistanceFunction(weights);

        // Naive ensemble: plain mean over all vectors except the reference.
        double[] naiveMean = new double[dim];
        for (DBIDIter iter = relation.iterDBIDs(); iter.valid(); iter.advance()) {
            DBID id = iter.getDBID();
            if (truthId.equals(id)) {
                continue;
            }
            NumberVector vector = (NumberVector) relation.get(id);
            for (int k = 0; k < dim; k++) {
                naiveMean[k] += vector.doubleValue(k + 1);
            }
        }
        for (int k = 0; k < dim; k++) {
            naiveMean[k] = naiveMean[k] / (relation.size() - 1);
        }
        NumberVector<?, ?> naiveVector = truthVector.newNumberVector(naiveMean);

        // Per-vector statistics and selection of the ensemble seed.
        double bestAuc = 0.0d;
        String bestAucLabel = "";
        double bestCost = Double.POSITIVE_INFINITY;
        String bestCostLabel = "";
        DBID seedId = null;
        double seedDistance = Double.POSITIVE_INFINITY;
        for (DBIDIter iter = relation.iterDBIDs(); iter.valid(); iter.advance()) {
            DBID id = iter.getDBID();
            if (truthId.equals(id)) {
                continue;
            }
            NumberVector<?, ?> vector = (NumberVector) relation.get(id);
            double auc = computeROCAUC(vector, positiveIds, dim);
            double estimated = distanceFunction.doubleDistance(vector, estimatedVector);
            double cost = distanceFunction.doubleDistance(vector, truthVector);
            logger.verbose("ROC AUC: " + auc + " estimated " + estimated + " cost " + cost + " " + labels.get(id));
            if (auc > bestAuc) {
                bestAuc = auc;
                bestAucLabel = labels.get(id);
            }
            if (cost < bestCost) {
                bestCost = cost;
                bestCostLabel = labels.get(id);
            }
            if (estimated < seedDistance) {
                seedDistance = estimated;
                seedId = id;
            }
        }
        logger.verbose("Distance function: " + distanceFunction);
        logger.verbose("Initial estimation of outliers: " + unionSize);
        if (seedId == null) {
            // No candidate at all (or no comparable distance); fail with a clear message instead of
            // handing a null id to the relation below.
            throw new AbortException("No ensemble candidate found besides the 'by label' reference.");
        }
        logger.verbose("Initializing ensemble with: " + labels.get(seedId));

        // Greedy ensemble construction.
        ArrayModifiableDBIDs ensemble = DBIDUtil.newArray(seedId);
        HashSetModifiableDBIDs candidates = DBIDUtil.newHashSet(relation.getDBIDs());
        candidates.remove(seedId);
        candidates.remove(truthId);
        double[] ensembleMean = new double[dim];
        NumberVector seedVector = (NumberVector) relation.get(seedId);
        for (int k = 0; k < dim; k++) {
            ensembleMean[k] = seedVector.doubleValue(k + 1);
        }
        double[] trialMean = new double[dim];
        while (candidates.size() > 0) {
            NumberVector currentVector = truthVector.newNumberVector(ensembleMean);
            // Mixing factors for adding one member to an ensemble of the current size.
            double oldWeight = ensemble.size() / (ensemble.size() + 1.0d);
            double newWeight = 1.0d / (ensemble.size() + 1.0d);
            // NOTE(review): presumably polls candidates in a distance-based order relative to the
            // current ensemble - confirm the direction against TopBoundedHeap/reverseOrder.
            TopBoundedHeap heap = new TopBoundedHeap(candidates.size(), Collections.reverseOrder());
            for (DBIDMIter iter = candidates.iter(); iter.valid(); iter.advance()) {
                DBID id = iter.getDBID();
                heap.add(new DoubleObjPair(distanceFunction.doubleDistance((NumberVector) relation.get(id), currentVector), id));
            }
            // Terminates when the heap is drained. Every polled id is also removed from the
            // candidate set, so a round without any improvement ends the outer loop as well.
            while (heap.size() > 0) {
                DBID candidate = (DBID) ((DoubleObjPair) heap.poll()).second;
                candidates.remove(candidate);
                NumberVector candidateVector = (NumberVector) relation.get(candidate);
                for (int k = 0; k < dim; k++) {
                    trialMean[k] = (ensembleMean[k] * oldWeight) + (candidateVector.doubleValue(k + 1) * newWeight);
                }
                double trialScore = distanceFunction.doubleDistance(estimatedVector, truthVector.newNumberVector(trialMean));
                double currentScore = distanceFunction.doubleDistance(estimatedVector, currentVector);
                if (trialScore < currentScore) {
                    // Accept the candidate and restart with a fresh heap for the enlarged ensemble.
                    System.arraycopy(trialMean, 0, ensembleMean, 0, dim);
                    ensemble.add(candidate);
                    break;
                }
                if (this.refine_truth) {
                    // Rejected candidate: withdraw its votes from the outlier counts.
                    boolean changed = false;
                    TiedTopBoundedHeap rejectedTop = new TiedTopBoundedHeap(topK, Collections.reverseOrder());
                    for (int k = 0; k < dim; k++) {
                        rejectedTop.add(new DoubleIntPair(candidateVector.doubleValue(k + 1), k));
                    }
                    Iterator rejectedIter = rejectedTop.iterator();
                    while (rejectedIter.hasNext()) {
                        DoubleIntPair entry = (DoubleIntPair) rejectedIter.next();
                        // Manual assertion: the class declares $assertionsDisabled itself, so the
                        // assert keyword is deliberately not used here.
                        if (!$assertionsDisabled && outlierCounts[entry.second] <= 0) {
                            throw new AssertionError();
                        }
                        outlierCounts[entry.second]--;
                        if (outlierCounts[entry.second] == 0) {
                            unionSize--;
                            changed = true;
                        }
                    }
                    if (changed) {
                        updateEstimations(outlierCounts, unionSize, weights, estimatedTruth);
                        estimatedVector = truthVector.newNumberVector(estimatedTruth);
                    }
                }
            }
        }

        // Report the greedy ensemble and compare it with the single best and the naive ensemble.
        StringBuilder memberLabels = new StringBuilder();
        for (DBIDMIter iter = ensemble.iter(); iter.valid(); iter.advance()) {
            if (memberLabels.length() > 0) {
                memberLabels.append(" ");
            }
            memberLabels.append(labels.get(iter));
        }
        NumberVector<?, ?> greedyVector = truthVector.newNumberVector(ensembleMean);
        logger.verbose("Estimated outliers remaining: " + unionSize);
        logger.verbose("Greedy ensemble: " + memberLabels.toString());
        logger.verbose("Best single ROC AUC: " + bestAuc + " (" + bestAucLabel + ")");
        logger.verbose("Best single cost:    " + bestCost + " (" + bestCostLabel + ")");
        double naiveAuc = computeROCAUC(naiveVector, positiveIds, dim);
        double naiveCost = distanceFunction.doubleDistance(naiveVector, truthVector);
        logger.verbose("Naive ensemble AUC:   " + naiveAuc + " cost: " + naiveCost);
        logger.verbose("Naive ensemble Gain:  " + gain(naiveAuc, bestAuc, 1.0d) + " cost gain: " + gain(naiveCost, bestCost, 0.0d));
        double greedyAuc = computeROCAUC(greedyVector, positiveIds, dim);
        double greedyCost = distanceFunction.doubleDistance(greedyVector, truthVector);
        logger.verbose("Greedy ensemble AUC:  " + greedyAuc + " cost: " + greedyCost);
        logger.verbose("Greedy ensemble Gain to best:  " + gain(greedyAuc, bestAuc, 1.0d) + " cost gain: " + gain(greedyCost, bestCost, 0.0d));
        logger.verbose("Greedy ensemble Gain to naive: " + gain(greedyAuc, naiveAuc, 1.0d) + " cost gain: " + gain(greedyCost, naiveCost, 0.0d));

        // Baseline: random ensembles of the same size, seeded by the trial number.
        MeanVariance randomAuc = new MeanVariance();
        MeanVariance randomCost = new MeanVariance();
        HashSetModifiableDBIDs samplePool = DBIDUtil.newHashSet(relation.getDBIDs());
        samplePool.remove(truthId);
        final int randomTrials = 5000;
        for (int trial = 0; trial < randomTrials; trial++) {
            double[] randomMean = new double[dim];
            DBIDIter sampleIter = DBIDUtil.randomSample(samplePool, ensemble.size(), Long.valueOf(trial)).iter();
            while (sampleIter.valid()) {
                DBID id = sampleIter.getDBID();
                // The reference vector was removed from the pool and must never be sampled.
                if (!$assertionsDisabled && truthId.equals(id)) {
                    throw new AssertionError();
                }
                NumberVector vector = (NumberVector) relation.get(id);
                for (int k = 0; k < dim; k++) {
                    randomMean[k] += vector.doubleValue(k + 1);
                }
                sampleIter.advance();
            }
            for (int k = 0; k < dim; k++) {
                randomMean[k] = randomMean[k] / ensemble.size();
            }
            NumberVector<?, ?> randomVector = truthVector.newNumberVector(randomMean);
            randomAuc.put(computeROCAUC(randomVector, positiveIds, dim));
            randomCost.put(distanceFunction.doubleDistance(randomVector, truthVector));
        }
        logger.verbose("Random ensemble AUC:  " + randomAuc.getMean() + " + stddev: " + randomAuc.getSampleStddev() + " = " + (randomAuc.getMean() + randomAuc.getSampleStddev()));
        logger.verbose("Random ensemble Gain: " + gain(randomAuc.getMean(), bestAuc, 1.0d));
        logger.verbose("Greedy improvement:   " + ((greedyAuc - randomAuc.getMean()) / randomAuc.getSampleStddev()) + " standard deviations.");
        // The total uses the cost statistics' own standard deviation, matching the value printed next to it.
        logger.verbose("Random ensemble Cost: " + randomCost.getMean() + " + stddev: " + randomCost.getSampleStddev() + " = " + (randomCost.getMean() + randomCost.getSampleStddev()));
        logger.verbose("Random ensemble Gain: " + gain(randomCost.getMean(), bestCost, 0.0d));
        logger.verbose("Greedy improvement:   " + ((randomCost.getMean() - greedyCost) / randomCost.getSampleStddev()) + " standard deviations.");
        logger.verbose("Naive ensemble Gain to random: " + gain(naiveAuc, randomAuc.getMean(), 1.0d) + " cost gain: " + gain(naiveCost, randomCost.getMean(), 0.0d));
        logger.verbose("Random ensemble Gain to naive: " + gain(randomAuc.getMean(), naiveAuc, 1.0d) + " cost gain: " + gain(randomCost.getMean(), naiveCost, 0.0d));
        logger.verbose("Greedy ensemble Gain to random: " + gain(greedyAuc, randomAuc.getMean(), 1.0d) + " cost gain: " + gain(greedyCost, randomCost.getMean(), 0.0d));
    }

    /**
     * Fills the weight and target arrays from the per-position counts.
     *
     * <p>Positions with a positive count get weight {@code 0.5 / i} and target {@code 1.0}; all
     * other positions get weight {@code 0.5 / (iArr.length - i)} and target {@code 0.0}.
     *
     * @param iArr  per-position counts; only whether a count is positive matters
     * @param i     divisor for the positive positions (callers pass the number of positive counts)
     * @param dArr  output: weight per position, overwritten
     * @param dArr2 output: target value per position, overwritten
     */
    protected void updateEstimations(int[] iArr, int i, double[] dArr, double[] dArr2) {
        final int total = iArr.length;
        // Both weights are the same for every position, so compute them once.
        final double positiveWeight = 0.5d / i;
        final double otherWeight = 0.5d / (total - i);
        for (int pos = 0; pos < total; pos++) {
            final boolean positive = iArr[pos] > 0;
            dArr[pos] = positive ? positiveWeight : otherWeight;
            dArr2[pos] = positive ? 1.0d : 0.0d;
        }
    }

    /**
     * Builds the distance function used throughout {@link #run()}.
     *
     * @param dArr per-dimension weights handed to the distance function
     * @return a {@link WeightedPearsonCorrelationDistanceFunction} constructed from {@code dArr}
     */
    private PrimitiveDoubleDistanceFunction<NumberVector<?, ?>> getDistanceFunction(double[] dArr) {
        final PrimitiveDoubleDistanceFunction<NumberVector<?, ?>> weighted = new WeightedPearsonCorrelationDistanceFunction(dArr);
        return weighted;
    }

    /**
     * Computes the area under the ROC curve for one score vector.
     *
     * <p>Each of the first {@code i} dimensions becomes a (value, zero-based index) pair; the
     * pairs are sorted with the reversed {@code BYFIRST_COMPARATOR} and handed to
     * {@link ROC#materializeROC} together with the positive index set.
     *
     * @param numberVector vector whose values are ranked (read with 1-based dimension numbers)
     * @param set          zero-based indexes counted as positives
     * @param i            number of dimensions to rank
     * @return area under the materialized ROC curve
     */
    private double computeROCAUC(NumberVector<?, ?> numberVector, Set<Integer> set, int i) {
        final DoubleIntPair[] ranked = new DoubleIntPair[i];
        int pos = 0;
        while (pos < i) {
            final double value = numberVector.doubleValue(pos + 1);
            ranked[pos] = new DoubleIntPair(value, pos);
            pos++;
        }
        Arrays.sort(ranked, Collections.reverseOrder(DoubleIntPair.BYFIRST_COMPARATOR));
        final XYCurve curve = ROC.materializeROC(i, set, Arrays.asList(ranked).iterator());
        return XYCurve.areaUnderCurve(curve);
    }

    /**
     * Relative gain of a score over a reference score, measured against an optimum.
     *
     * <p>Evaluates {@code 1 - (d3 - d) / (d3 - d2)}: 0 when {@code d} equals the reference
     * {@code d2}, 1 when {@code d} equals the optimum {@code d3}. No guard is applied when
     * {@code d3 == d2}; the result then follows IEEE 754 division by zero.
     *
     * @param d  score being evaluated
     * @param d2 reference score
     * @param d3 optimal score
     * @return the relative gain
     */
    double gain(double d, double d2, double d3) {
        final double remaining = d3 - d;
        final double available = d3 - d2;
        final double shortfall = remaining / available;
        return 1.0d - shortfall;
    }

    /**
     * Command line entry point; delegates to {@code runCLIApplication} with this class.
     *
     * @param strArr command line arguments, passed on unchanged
     */
    public static void main(String[] strArr) {
        final Class<GreedyEnsembleExperiment> application = GreedyEnsembleExperiment.class;
        runCLIApplication(application, strArr);
    }

    static {
        $assertionsDisabled = !GreedyEnsembleExperiment.class.desiredAssertionStatus();
        logger = Logging.getLogger((Class<?>) GreedyEnsembleExperiment.class);
    }
}
