package ir.classifiers;

import ir.utilities.Weight;
import ir.vsr.InvertedIndex;
import java.util.Hashtable;
import java.util.Vector;

/* loaded from: input_file:ir/classifiers/NaiveBayes.class */
/**
 * Naive Bayes classifier over weighted token vectors, computed in log space.
 *
 * <p>Training ({@link #train(Vector)}) produces a {@link BayesResult} holding
 * one log prior per category and, for every token seen in training, an array
 * of per-category log conditional probabilities. Classification
 * ({@link #test(Example)}) adds the log prior of each category to the log
 * conditionals of the tokens present in the example and picks the category
 * with the largest total.
 *
 * <p>Method signatures keep the raw {@code Vector}/{@code Hashtable} types so
 * that they continue to match existing callers and the declarations they
 * override.
 */
public class NaiveBayes extends Classifier {
    /** Category names, in the order passed to {@link #setCategories(String[])}. */
    protected Vector Categories;

    /** When true, add-one (Laplace) smoothing is applied to priors and conditionals. */
    protected boolean isLaplace = true;

    /**
     * Probability substituted for a token/category pair with zero accumulated
     * weight when Laplace smoothing is switched off.
     */
    protected double EPSILON = 1.0E-6d;

    /** Result of the most recent call to {@link #train(Vector)}; null before training. */
    protected BayesResult trainResult;

    /** Name reported by {@link #getName()}. */
    public static final String name = "NaiveBayes";

    /** Number of categories, set by {@link #setCategories(String[])}. */
    public int numCategories;

    /** Number of distinct tokens seen in training, set by {@link #conditionalProbs(Vector)}. */
    public int numFeatures;

    /** Number of training examples, set by {@link #train(Vector)}. */
    public int numExamples;

    /** When true, intermediate values are printed to standard output. */
    public boolean debug;

    /**
     * Reports that this classifier does not use an inverted index.
     *
     * @return always {@code false}
     */
    @Override // ir.classifiers.Classifier
    public boolean usesInvertedIndex() {
        return false;
    }

    /**
     * Intentionally does nothing; see {@link #usesInvertedIndex()}.
     *
     * @param invertedIndex ignored
     */
    @Override // ir.classifiers.Classifier
    public void setInvertedIndex(InvertedIndex invertedIndex) {
    }

    /**
     * Creates a classifier for the given categories.
     *
     * @param strArr category names; a category's index in this array is its numeric id
     * @param z      initial value of the {@link #debug} flag
     */
    public NaiveBayes(String[] strArr, boolean z) {
        // Assign the flag before setCategories so its debug trace is honoured
        // during construction as well.
        this.debug = z;
        setCategories(strArr);
    }

    /**
     * Replaces the category list and updates {@link #numCategories}.
     *
     * @param strArr category names
     */
    public void setCategories(String[] strArr) {
        this.Categories = new Vector(strArr.length);
        for (String category : strArr) {
            this.Categories.add(category);
        }
        if (this.debug) {
            System.out.println("Categories: " + this.Categories);
        }
        this.numCategories = strArr.length;
    }

    /**
     * Sets the {@link #debug} flag.
     *
     * @param z new flag value
     */
    public void setDebug(boolean z) {
        this.debug = z;
    }

    /**
     * Enables or disables Laplace smoothing.
     *
     * @param z new value of {@link #isLaplace}
     */
    public void setLaplace(boolean z) {
        this.isLaplace = z;
    }

    /**
     * Sets the substitute probability used for zero counts when smoothing is off.
     *
     * @param d new value of {@link #EPSILON}
     */
    public void setEpsilon(double d) {
        this.EPSILON = d;
    }

    /**
     * @return the constant {@link #name}
     */
    @Override // ir.classifiers.Classifier
    public String getName() {
        return name;
    }

    /**
     * @return the current value of {@link #EPSILON}
     */
    public double getEpsilon() {
        return this.EPSILON;
    }

    /**
     * @return the result of the last training run, or null if never trained
     */
    public BayesResult getTrainResult() {
        return this.trainResult;
    }

    /**
     * @return whether Laplace smoothing is enabled
     */
    public boolean getIsLaplace() {
        return this.isLaplace;
    }

    /**
     * Trains on the given examples, replacing any previous training result.
     *
     * @param vector training examples; each element is cast to {@link Example}
     */
    @Override // ir.classifiers.Classifier
    public void train(Vector vector) {
        this.trainResult = new BayesResult();
        this.numExamples = vector.size();
        this.trainResult.setClassPriors(calculatePriors(vector));
        this.trainResult.setFeatureTable(conditionalProbs(vector));
        if (this.debug) {
            displayProbs(this.trainResult.getClassPriors(), this.trainResult.getFeatureTable());
        }
    }

    /**
     * Classifies one example and reports whether the prediction is correct.
     *
     * @param example example to classify; requires a prior call to {@link #train(Vector)}
     * @return true when the highest-scoring category index equals {@code example.getCategory()}
     */
    @Override // ir.classifiers.Classifier
    public boolean test(Example example) {
        double[] scores = calculateProbs(example);
        // argMax is not defined in this class; presumably inherited from Classifier.
        int predicted = argMax(scores);
        if (this.debug) {
            System.out.print("Document: " + example.name + "\nResults: ");
            for (int c = 0; c < this.numCategories; c++) {
                System.out.print(this.Categories.get(c) + "(" + scores[c] + ")\t");
            }
            System.out.println("\nCorrect class: " + example.getCategory()
                    + ", Predicted class: " + predicted + "\n");
        }
        return predicted == example.getCategory();
    }

    /**
     * Computes the log prior of every category from the training examples.
     *
     * <p>With Laplace smoothing the prior is
     * {@code log((count + 1) / (examples + numCategories))}; otherwise it is
     * {@code log(count / examples)}, which is negative infinity for a category
     * with no examples.
     *
     * @param vector training examples; each element is cast to {@link Example}
     * @return array of length {@link #numCategories} holding the log priors
     */
    protected double[] calculatePriors(Vector vector) {
        // Use the vector's own size rather than the numExamples field so the
        // method is correct even when called outside train().
        int exampleCount = vector.size();
        double[] priors = new double[this.numCategories]; // zero-initialised by the JVM
        for (int e = 0; e < exampleCount; e++) {
            priors[((Example) vector.get(e)).getCategory()] += 1.0d;
        }
        for (int c = 0; c < this.numCategories; c++) {
            if (this.isLaplace) {
                priors[c] = Math.log((priors[c] + 1.0d) / (exampleCount + this.numCategories));
            } else {
                priors[c] = Math.log(priors[c] / exampleCount);
            }
        }
        if (this.debug) {
            System.out.println("\nLog Class Priors:");
            for (int c = 0; c < this.numCategories; c++) {
                System.out.print(priors[c] + " ");
            }
            System.out.println();
        }
        return priors;
    }

    /**
     * Builds the table of per-category log conditional probabilities for every
     * token seen in the training examples, and sets {@link #numFeatures}.
     *
     * <p>For each token the weights found in the examples are summed per
     * category. With Laplace smoothing the probability is
     * {@code (tokenWeight + 1) / (categoryWeight + numFeatures)}; otherwise it
     * is {@code tokenWeight / categoryWeight}, with {@link #EPSILON} used in
     * place of a zero token weight. The natural log of that value is stored.
     *
     * @param vector training examples; each element is cast to {@link Example}
     * @return table mapping each token (String) to a {@code double[]} of
     *         length {@link #numCategories}
     */
    protected Hashtable conditionalProbs(Vector vector) {
        Hashtable<String, double[]> table = new Hashtable<String, double[]>();
        // Total token weight accumulated per category (the denominators).
        double[] categoryTotals = new double[this.numCategories];
        int exampleCount = vector.size();
        for (int e = 0; e < exampleCount; e++) {
            Example example = (Example) vector.get(e);
            int category = example.getCategory();
            if (this.debug) {
                System.out.println("\nExample " + e + ": " + example);
                System.out.println("Number of tokens: " + example.getHashMapVector().hashMap.size());
            }
            // Keys are iterated as Object and cast so this compiles whether the
            // example's map is declared raw or with type parameters.
            for (Object key : example.getHashMapVector().hashMap.keySet()) {
                String token = (String) key;
                if (this.debug) {
                    System.out.println("Counts of token: " + token);
                }
                double[] counts = table.get(token);
                if (counts == null) {
                    counts = new double[this.numCategories];
                    table.put(token, counts);
                }
                double weight = ((Weight) example.getHashMapVector().hashMap.get(token)).getValue();
                counts[category] += weight;
                categoryTotals[category] += weight;
                if (this.debug) {
                    for (double count : counts) {
                        System.out.print(count + " ");
                    }
                    System.out.println();
                }
            }
        }
        this.numFeatures = table.size();
        if (this.debug) {
            System.out.println("\nLog Probs before multiplying priors...\n");
        }
        // Convert the accumulated weights to log probabilities in place.
        for (String token : table.keySet()) {
            double[] probs = table.get(token);
            for (int c = 0; c < this.numCategories; c++) {
                if (this.isLaplace) {
                    probs[c] = (probs[c] + 1.0d) / (categoryTotals[c] + this.numFeatures);
                } else if (probs[c] == 0.0d) {
                    probs[c] = this.EPSILON;
                } else {
                    probs[c] = probs[c] / categoryTotals[c];
                }
                probs[c] = Math.log(probs[c]);
            }
            if (this.debug) {
                System.out.println("Log probs of " + token);
                for (double logProb : probs) {
                    System.out.print(logProb + " ");
                }
                System.out.println();
            }
        }
        return table;
    }

    /**
     * Scores an example against every category.
     *
     * <p>Starts from a copy of the trained log priors and, for each distinct
     * token of the example that is present in the feature table, adds that
     * token's log conditional once per category. Tokens absent from the table
     * are skipped; the token's weight in the example is not used.
     *
     * @param example example to score; requires a prior call to {@link #train(Vector)}
     * @return array of length {@link #numCategories} with the unnormalised log scores
     */
    protected double[] calculateProbs(Example example) {
        double[] scores = (double[]) this.trainResult.getClassPriors().clone();
        Hashtable featureTable = this.trainResult.getFeatureTable();
        for (Object key : example.getHashMapVector().hashMap.keySet()) {
            // Hashtable never stores null values, so a null result means "absent".
            Object entry = featureTable.get(key);
            if (entry != null) {
                double[] logProbs = (double[]) entry;
                for (int c = 0; c < this.numCategories; c++) {
                    scores[c] += logProbs[c];
                }
            }
        }
        return scores;
    }

    /**
     * Prints, for every token in the table, {@code exp(logPrior + logConditional)}
     * for each category.
     *
     * @param dArr      log priors, indexed by category
     * @param hashtable table mapping tokens to {@code double[]} log conditionals
     */
    protected void displayProbs(double[] dArr, Hashtable hashtable) {
        System.out.println("\nAfter multiplying priors...");
        for (Object key : hashtable.keySet()) {
            System.out.print("\nFeature: " + key + ", Probs: ");
            double[] logProbs = (double[]) hashtable.get(key);
            for (int c = 0; c < logProbs.length; c++) {
                System.out.print(" " + Math.exp(dArr[c] + logProbs[c]));
            }
        }
        System.out.println();
    }
}
