package ir.classifiers;

import ir.vsr.DocumentIterator;
import ir.vsr.FileDocument;
import ir.vsr.InvertedIndex;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Random;
import java.util.Vector;

/* loaded from: input_file:ir/classifiers/CVLearningCurve.class */
/**
 * Runs k-fold cross validation at increasing training-set fractions (see {@link #POINTS})
 * and writes the resulting learning curve as gnuplot input files.
 *
 * <p>Examples are read from a directory, grouped by class ID (see {@link #findClassID}),
 * shuffled within each class, and dealt round-robin into per-class fold bins so that every
 * fold receives examples from every class.
 */
public class CVLearningCurve {
    /**
     * Per-class example lists, indexed by class ID. An entry stays {@code null} until the
     * first example of that class is added.
     */
    protected Vector[] totalExamples;

    /** Fold bins indexed as {@code [classId][fold]}; (re)filled by {@link #binExamples()}. */
    protected Vector[][] foldBins;

    /** Classifier that is trained and tested on every fold. */
    protected Classifier classifier;

    /** Number of classes; sizes the first dimension of the example and bin arrays. */
    protected int numClasses;

    /** Number of cross-validation folds. */
    protected int numFolds;

    /** When true, the training and test sets of every fold are printed. */
    protected boolean debug;

    /** Accumulated wall-clock training time in milliseconds. */
    public double trainTime;

    /** Accumulated wall-clock testing time in milliseconds. */
    public double testTime;

    /** Substrings looked for in file names; the array index is the class ID. */
    public static final String[] CLASSES = {"bio", "chem", "phys"};

    /** Fractions of each training bin used for the successive points of the curve. */
    public static final double[] POINTS = {0.1d, 0.2d, 0.3d, 0.4d, 0.5d, 0.6d, 0.7d, 0.8d, 0.9d, 1.0d};

    /** Index shared by all instances; created on first use and re-indexed for every fold. */
    private static InvertedIndex invertedIndex = null;

    /**
     * Creates the experiment and immediately loads all examples from the given directory.
     * (Decompiler note: the access modifier was widened from package-private.)
     *
     * @param i          number of folds
     * @param classifier classifier to evaluate
     * @param i2         number of classes
     * @param str        directory containing the documents
     * @param z          whether to print debug output
     */
    public CVLearningCurve(int i, Classifier classifier, int i2, String str, boolean z) {
        this.debug = false;
        this.numFolds = i;
        this.classifier = classifier;
        this.numClasses = i2;
        this.totalExamples = new Vector[this.numClasses];
        this.foldBins = new Vector[this.numClasses][this.numFolds];
        System.out.println("\nReading in documents from directory: " + str);
        setTotalExamples(str);
        this.debug = z;
        this.testTime = 0.0d;
        this.trainTime = 0.0d;
    }

    /** @return the classifier being evaluated */
    public Classifier getClassifier() {
        return this.classifier;
    }

    /** @param classifier the classifier to evaluate from now on */
    public void setClassifier(Classifier classifier) {
        this.classifier = classifier;
    }

    /** @return the per-class example lists (the internal array, not a copy) */
    public Vector[] getTotalExamples() {
        return this.totalExamples;
    }

    /** @param vectorArr per-class example lists that replace the current ones */
    public void setTotalExamples(Vector[] vectorArr) {
        this.totalExamples = vectorArr;
    }

    /** @return the fold bins, indexed {@code [classId][fold]} (the internal array, not a copy) */
    public Vector[][] getFoldBins() {
        return this.foldBins;
    }

    /** @param vectorArr fold bins that replace the current ones */
    public void setFoldBins(Vector[][] vectorArr) {
        this.foldBins = vectorArr;
    }

    /**
     * Reads every document under the given directory and appends it, as an {@code Example},
     * to the list of the class derived from its file name.
     *
     * <p>Documents whose file name does not map to a class ID within
     * {@code totalExamples} are skipped with a warning instead of aborting the load with an
     * array index error.
     *
     * @param str directory containing the documents
     */
    public void setTotalExamples(String str) {
        DocumentIterator documentIterator = new DocumentIterator(new File(str), (short) 1, false);
        while (documentIterator.hasMoreDocuments()) {
            FileDocument nextDocument = documentIterator.nextDocument();
            String fileName = nextDocument.file.getName();
            int classId = findClassID(fileName);
            if (classId < 0 || classId >= this.totalExamples.length) {
                System.out.println("Skipping document with unrecognized class: " + fileName);
                continue;
            }
            Example example = new Example(nextDocument.hashMapVector(), classId, fileName, nextDocument);
            if (this.totalExamples[classId] == null) {
                this.totalExamples[classId] = new Vector();
            }
            this.totalExamples[classId].add(example);
        }
    }

    /**
     * Shuffles and bins the examples, then trains and tests the classifier on every fold
     * for every fraction in {@link #POINTS}.
     *
     * @return one entry per point; each entry is a Vector with one element per fold, and
     *         each of those is the two-element result of {@link #getTestPrediction}
     */
    public Vector getCVPredictions() {
        Vector allPoints = new Vector();
        randomizeOrder();
        binExamples();
        for (int p = 0; p < POINTS.length; p++) {
            double fraction = POINTS[p];
            System.out.println("\nPercentage: " + fraction);
            Vector foldResults = new Vector();
            for (int fold = 0; fold < this.numFolds; fold++) {
                System.out.println("\nCalculating results for fold: " + fold);
                Vector trainCV = getTrainCV(fold, fraction);
                Vector testCV = getTestCV(fold);
                Vector testPrediction = getTestPrediction(trainCV, testCV);
                if (this.debug) {
                    System.out.println("Training on:\n" + trainCV);
                    System.out.println("Testing on:\n" + testCV);
                }
                foldResults.add(testPrediction);
            }
            allPoints.add(foldResults);
        }
        System.out.println("\nFinished calculating all prediction vectors...");
        return allPoints;
    }

    /**
     * Trains the classifier on one set of examples and tests it on another, adding the
     * elapsed times to {@link #trainTime} and {@link #testTime}.
     *
     * @param vector  training examples
     * @param vector2 test examples
     * @return a two-element Vector: the training-set size (Integer) and the fraction of test
     *         examples for which {@code classifier.test} returned true (Double; NaN when
     *         the test set is empty)
     */
    public Vector getTestPrediction(Vector vector, Vector vector2) {
        int numCorrect = 0;
        Vector result = new Vector();
        long trainStart = System.currentTimeMillis();
        if (this.classifier.usesInvertedIndex()) {
            if (invertedIndex == null) {
                invertedIndex = new InvertedIndex(vector, (short) 1, false, false);
            } else {
                // Reuse the shared index: drop the previous fold's contents and re-index.
                invertedIndex.docRefs.clear();
                invertedIndex.tokenHash.clear();
                invertedIndex.indexDocuments(vector);
            }
            this.classifier.setInvertedIndex(invertedIndex);
        }
        this.classifier.train(vector);
        this.trainTime += System.currentTimeMillis() - trainStart;
        long testStart = System.currentTimeMillis();
        for (int k = 0; k < vector2.size(); k++) {
            if (this.classifier.test((Example) vector2.get(k))) {
                numCorrect++;
            }
        }
        this.testTime += System.currentTimeMillis() - testStart;
        result.add(new Integer(vector.size()));
        result.add(new Double((1.0d * numCorrect) / vector2.size()));
        return result;
    }

    /**
     * Deals the examples of every class round-robin into that class's fold bins.
     *
     * <p>All bins are reset to empty first, so calling this repeatedly does not duplicate
     * examples, and a bin that receives no example is an empty Vector rather than
     * {@code null}.
     */
    public void binExamples() {
        for (int classId = 0; classId < this.numClasses; classId++) {
            for (int fold = 0; fold < this.numFolds; fold++) {
                this.foldBins[classId][fold] = new Vector();
            }
            Vector examples = this.totalExamples[classId];
            if (examples == null) {
                // No example of this class was loaded; leave its bins empty.
                continue;
            }
            for (int k = 0; k < examples.size(); k++) {
                int fold = this.numFolds > 1 ? k % this.numFolds : 0;
                this.foldBins[classId][fold].add(examples.get(k));
            }
        }
    }

    /**
     * Builds the training set for one fold: the leading fraction of every bin except the
     * held-out fold's, across all classes.
     *
     * @param i fold that is held out for testing
     * @param d fraction of each remaining bin to include (truncated to a whole count)
     * @return the training examples
     */
    public Vector getTrainCV(int i, double d) {
        Vector training = new Vector();
        for (int classId = 0; classId < this.numClasses; classId++) {
            for (int fold = 0; fold < this.numFolds; fold++) {
                if (fold == i) {
                    continue;
                }
                Vector bin = this.foldBins[classId][fold];
                if (bin == null) {
                    continue;
                }
                int count = (int) (d * bin.size());
                for (int k = 0; k < count; k++) {
                    training.add(bin.get(k));
                }
            }
        }
        return training;
    }

    /**
     * Builds the test set for one fold: the complete bin of that fold from every class.
     *
     * @param i fold to test on
     * @return the test examples
     */
    public Vector getTestCV(int i) {
        Vector test = new Vector();
        for (int classId = 0; classId < this.numClasses; classId++) {
            Vector bin = this.foldBins[classId][i];
            if (bin != null) {
                test.addAll(bin);
            }
        }
        return test;
    }

    /**
     * Derives a class ID from a file name by substring match against {@link #CLASSES}.
     *
     * <p>When several class names occur in the name, "chem" wins over "phys", which wins
     * over "bio".
     *
     * @param str file name to classify
     * @return 1 for "chem", 2 for "phys", 0 for "bio", or -1 when none occurs
     */
    public static int findClassID(String str) {
        if (str.indexOf(CLASSES[1]) != -1) {
            return 1;
        }
        if (str.indexOf(CLASSES[2]) != -1) {
            return 2;
        }
        if (str.indexOf(CLASSES[0]) != -1) {
            return 0;
        }
        return -1;
    }

    /**
     * Shuffles the examples of every class in place with a Fisher-Yates shuffle.
     * Classes without any examples are left untouched.
     */
    private final void randomizeOrder() {
        Random random = new Random();
        for (int classId = 0; classId < this.numClasses; classId++) {
            Vector examples = this.totalExamples[classId];
            if (examples == null) {
                continue;
            }
            for (int last = examples.size() - 1; last > 0; last--) {
                // Pick only from the not-yet-fixed prefix [0, last]; drawing from the whole
                // list would make some permutations more likely than others.
                int pick = random.nextInt(last + 1);
                Object tmp = examples.get(last);
                examples.set(last, examples.get(pick));
                examples.set(pick, tmp);
            }
        }
    }

    /**
     * Writes the learning-curve data to {@code <str>.data}: one line per point with the
     * training-set size of the first fold and the accuracy averaged over all folds,
     * separated by a tab.
     *
     * @param vector results as returned by {@link #getCVPredictions()}
     * @param str    output path without the ".data" extension
     * @throws IOException if the file cannot be opened for writing
     */
    void writeCurve(Vector vector, String str) throws IOException {
        PrintWriter printWriter = new PrintWriter(new FileWriter(str + ".data"));
        try {
            for (int p = 0; p < vector.size(); p++) {
                double accuracySum = 0.0d;
                Vector foldResults = (Vector) vector.get(p);
                int trainSize = ((Integer) ((Vector) foldResults.get(0)).get(0)).intValue();
                for (int fold = 0; fold < foldResults.size(); fold++) {
                    accuracySum += ((Double) ((Vector) foldResults.get(fold)).get(1)).doubleValue();
                }
                printWriter.println(trainSize + "\t" + (accuracySum / foldResults.size()));
            }
        } finally {
            // Release the file handle even if a malformed result vector throws mid-way.
            printWriter.close();
        }
    }

    /**
     * Writes the data file (see {@link #writeCurve}) and a matching gnuplot script
     * {@code <str>.gplot} that plots it.
     * (Decompiler note: the access modifier was widened from package-private.)
     *
     * @param vector results as returned by {@link #getCVPredictions()}
     * @param str    output path without extension; also used in the plot title
     * @throws IOException if either file cannot be opened for writing
     */
    public void makeGnuplotFile(Vector vector, String str) throws IOException {
        writeCurve(vector, str);
        PrintWriter printWriter = new PrintWriter(new FileWriter(new File(str + ".gplot")));
        try {
            // The title reports the configured fold count rather than a fixed "10".
            printWriter.print("set xlabel \"Size of training set\"\nset ylabel \"Accuracy\"\n\nset terminal postscript color\nset size 0.75,0.75\n\nset data style linespoints\n\nplot '" + str + ".data' title \"" + str + ": " + this.numFolds + "-fold CV Learning Curve\"");
        } finally {
            printWriter.close();
        }
    }
}
