package de.fau.cs.jstk.trans;

import Jama.EigenvalueDecomposition;
import Jama.Matrix;
import de.fau.cs.jstk.io.FrameInputStream;
import de.fau.cs.jstk.io.SampleInputStream;
import de.fau.cs.jstk.stat.Sample;
import de.fau.cs.jstk.trans.Projection;
import de.fau.cs.jstk.util.Arithmetics;
import de.fau.cs.jstk.util.Pair;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.log4j.BasicConfigurator;

/* loaded from: input_file:de/fau/cs/jstk/trans/LDA.class */
public class LDA extends Projection {
    /** Statistics accumulated over every sample, regardless of class label. */
    private final Projection.Accumulator global;

    /** Per-class statistics, keyed by the class label passed to {@link #accumulate(short, double[])}. */
    private final HashMap<Short, Projection.Accumulator> stats = new HashMap<Short, Projection.Accumulator>();

    /** Pseudo-inverse of the weighted within-class scatter matrix; null until {@link #estimate} ran. */
    private double[][] Swi = null;

    /** Eigenvalues belonging to the retained projection rows (descending); null until {@link #estimate} ran. */
    private double[] evals = null;

    public static final String SYNOPSIS = "sikoried, 2/2/2011\nCompute LDA using (regularized) pseudo-inverse (SVD) and save the resulting\ntransformation y = A * (x-m) to the given projection file.\nusage: transformations.LDA proj list1 [list2 ...] indir\n  proj  : output file for projection (Frame format)\n  list  : file list(s); in case of single list expecting binary sample format instead of frame.\n  indir : directory where the input files are located (use . for current dir)\n";

    /**
     * Creates an empty LDA estimator.
     *
     * @param i feature dimension of the samples that will be accumulated
     */
    public LDA(int i) {
        super(i);
        this.global = new Projection.Accumulator(i);
    }

    /**
     * Accumulates a list of labeled samples; each sample's class is taken from {@code Sample.c}.
     *
     * @param list samples to accumulate
     */
    public void accumulate(List<Sample> list) {
        for (Sample sample : list) {
            accumulate(sample.c, sample.x);
        }
    }

    /**
     * Accumulates a single labeled sample.
     *
     * @param sample sample whose {@code c} is the class label and {@code x} the feature vector
     */
    public void accumulate(Sample sample) {
        accumulate(sample.c, sample.x);
    }

    /**
     * Adds one feature vector to the global statistics and to the statistics of class {@code s}.
     *
     * @param s    class label
     * @param dArr feature vector
     */
    public void accumulate(short s, double[] dArr) {
        global.accumulate(dArr);
        Projection.Accumulator acc = stats.get(s);
        if (acc == null) {
            acc = new Projection.Accumulator(dArr.length);
            stats.put(s, acc);
        }
        acc.accumulate(dArr);
    }

    /**
     * Estimates the LDA projection from the accumulated statistics.
     *
     * <p>Builds the prior-weighted within-class scatter Sw (sum of class covariances) and
     * between-class scatter Sb (sum of outer products of class-mean offsets from the global mean),
     * then takes the eigenvectors of pinv(Sw) * Sb sorted by descending eigenvalue. At most
     * {@code min(C - 1, fd)} rows are kept, where C is the number of observed classes.
     *
     * @param hashMap class priors keyed by class label, or {@code null} to use the relative
     *                class frequencies of the accumulated data
     * @throws IllegalStateException    if no data was accumulated
     * @throws IllegalArgumentException if {@code hashMap} lacks a prior for an observed class
     */
    public void estimate(HashMap<Short, Double> hashMap) {
        if (stats.isEmpty()) {
            throw new IllegalStateException("LDA.estimate(): no data accumulated");
        }

        if (hashMap == null) {
            hashMap = new HashMap<Short, Double>();
            for (Map.Entry<Short, Projection.Accumulator> e : stats.entrySet()) {
                // cast first so the prior is never computed by integer division
                hashMap.put(e.getKey(), (double) e.getValue().getCount() / global.getCount());
            }
        }

        this.fd = global.getFd();
        final int packed = fd * (fd + 1) / 2;
        double[] mean = global.getMean();

        // packed lower-triangular (row-major) scatter matrices, same layout as getCovariance()
        double[] sw = new double[packed];
        double[] sb = new double[packed];
        for (Map.Entry<Short, Projection.Accumulator> e : stats.entrySet()) {
            Double prior = hashMap.get(e.getKey());
            if (prior == null) {
                throw new IllegalArgumentException("LDA.estimate(): no prior given for class " + e.getKey());
            }
            double w = prior.doubleValue();
            double[] classMean = e.getValue().getMean();
            double[] cov = e.getValue().getCovariance();

            for (int p = 0; p < packed; p++) {
                sw[p] += w * cov[p];
            }

            int p = 0;
            for (int r = 0; r < fd; r++) {
                double dr = classMean[r] - mean[r];
                for (int c = 0; c <= r; c++) {
                    sb[p++] += w * dr * (classMean[c] - mean[c]);
                }
            }
        }

        // unpack into full symmetric matrices
        Matrix swMat = new Matrix(fd, fd);
        Matrix sbMat = new Matrix(fd, fd);
        int p = 0;
        for (int r = 0; r < fd; r++) {
            for (int c = 0; c <= r; c++, p++) {
                swMat.set(r, c, sw[p]);
                swMat.set(c, r, sw[p]);
                sbMat.set(r, c, sb[p]);
                sbMat.set(c, r, sb[p]);
            }
        }

        // pseudo-inverse tolerates a singular Sw (e.g. linearly dependent features)
        Matrix swInv = new Matrix(Arithmetics.pinv(swMat.getArray(), 1.0E-12d));
        this.Swi = swInv.getArray();

        EigenvalueDecomposition eig = new EigenvalueDecomposition(swInv.times(sbMat));
        // pinv(Sw)*Sb is not symmetric; only the real parts of the eigenvalues are used
        final double[] lambda = eig.getRealEigenvalues();
        double[][] vectors = eig.getV().transpose().getArray(); // row k = eigenvector k

        Integer[] order = new Integer[fd];
        for (int k = 0; k < fd; k++) {
            order[k] = k;
        }
        Arrays.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer a, Integer b) {
                return Double.compare(lambda[b], lambda[a]); // descending
            }
        });

        this.mean = global.getMean();

        // Sb has rank <= C-1, and there are only fd eigenvectors
        int keep = Math.min(stats.size() - 1, fd);
        this.proj = new double[keep][];
        this.evals = new double[keep];
        for (int k = 0; k < keep; k++) {
            this.proj[k] = vectors[order[k]];
            this.evals[k] = lambda[order[k]];
        }
    }

    /**
     * @return eigenvalues of the retained projection rows in descending order, or {@code null}
     *         if {@link #estimate} has not been called
     */
    public double[] getEigenvalues() {
        return this.evals;
    }

    @Override // de.fau.cs.jstk.trans.Projection
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Projection = \n");
        sb.append(super.toString());
        sb.append("LDA = \n");
        sb.append("Swi = \n");
        if (this.Swi != null) { // not yet estimated
            for (double[] row : this.Swi) {
                sb.append(Arrays.toString(row)).append("\n");
            }
        }
        sb.append("evals = ").append(Arrays.toString(this.evals));
        return sb.toString();
    }

    /**
     * Command-line entry point; see {@link #SYNOPSIS}.
     *
     * <p>With a single list, files are read as binary samples carrying their own class labels.
     * With several lists, files are read as frames and the list index is the class label.
     */
    public static void main(String[] strArr) throws IOException {
        BasicConfigurator.configure();

        // proj, at least one list, and indir are all mandatory
        if (strArr.length < 3) {
            System.err.println(SYNOPSIS);
            System.exit(1);
            return;
        }

        String projFile = strArr[0];
        String inDir = strArr[strArr.length - 1] + System.getProperty("file.separator");
        String[] lists = Arrays.copyOfRange(strArr, 1, strArr.length - 1);

        LDA lda = null;
        if (lists.length == 1) {
            try (BufferedReader br = new BufferedReader(new FileReader(lists[0]))) {
                String line;
                while ((line = br.readLine()) != null) {
                    try (FileInputStream in = new FileInputStream(inDir + line)) {
                        SampleInputStream sis = new SampleInputStream(in);
                        Sample s;
                        while ((s = sis.read()) != null) {
                            if (lda == null) {
                                lda = new LDA(s.x.length);
                            }
                            lda.accumulate(s);
                        }
                    }
                }
            }
        } else {
            for (int i = 0; i < lists.length; i++) {
                try (BufferedReader br = new BufferedReader(new FileReader(lists[i]))) {
                    String line;
                    while ((line = br.readLine()) != null) {
                        // NOTE(review): this stream is never closed; confirm whether FrameInputStream
                        // offers close() and release it here.
                        FrameInputStream fis = new FrameInputStream(new File(inDir + line));
                        double[] buf = new double[fis.getFrameSize()];
                        if (lda == null) {
                            lda = new LDA(buf.length);
                        }
                        while (fis.read(buf)) {
                            lda.accumulate((short) i, buf);
                        }
                    }
                }
            }
        }

        if (lda == null) {
            System.err.println("LDA: no input data found");
            System.exit(1);
            return;
        }

        lda.estimate(null);
        lda.save(new File(projFile));
    }
}
