package de.fau.cs.jstk.stat;

import de.fau.cs.jstk.io.ChunkedDataSet;
import de.fau.cs.jstk.io.FrameInputStream;
import java.io.IOException;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.log4j.Logger;

/* loaded from: input_file:de/fau/cs/jstk/stat/ParallelML.class */
/**
 * Two-pass, multi-threaded estimation of a single Gaussian density (mean and
 * covariance) from a {@link ChunkedDataSet}.
 *
 * <p>Pass one sums all frames to obtain the mean; pass two accumulates the
 * squared deviations from that mean to obtain either a diagonal or a full
 * covariance. In each pass {@code numThreads} workers pull chunks from the
 * shared data set until {@code nextChunk()} returns {@code null}.
 *
 * <p>NOTE(review): workers call {@code data.nextChunk()} concurrently without
 * additional locking, so this relies on {@code ChunkedDataSet} being safe for
 * that use -- confirm against its implementation. {@link #mlEstimate()} itself
 * is not synchronized and should not be invoked concurrently on one instance.
 */
public class ParallelML {
    private static final Logger logger = Logger.getLogger(ParallelML.class);

    /** Number of worker threads used per pass. */
    private final int numThreads;

    /** Feature dimension, i.e. length of each frame read from the data set. */
    private final int fd;

    /** Data set the workers pull chunks from; rewound after every pass. */
    private final ChunkedDataSet data;

    /** {@code true}: estimate per-dimension variances only; {@code false}: packed full covariance. */
    private final boolean diagonalCovariance;

    /** Cached result of the first successful {@link #mlEstimate()} call. */
    private Density estimate = null;

    /** Total number of frames seen in the first pass; the normalisation constant. */
    private long samples = 0;

    /**
     * Worker that drains chunks from the shared data set and accumulates one
     * partial statistic per chunk. Without a mean it accumulates plain sums
     * (first pass); with a mean it accumulates squared deviations (second pass).
     */
    private class Estimator implements Runnable {
        /** Counted down exactly once when {@link #run()} finishes, normally or not. */
        final CountDownLatch latch;

        /** Reusable frame buffer; centred in place during the second pass. */
        final double[] buf = new double[fd];

        /** Mean to subtract, or {@code null} while estimating the mean itself. */
        final double[] mue;

        /** Accumulator for the chunk currently being processed. */
        final double[] est;

        /** Number of chunks processed by this worker. */
        long cnt_chunk = 0L;

        /** Number of frames processed by this worker. */
        long cnt_frame = 0L;

        /** One accumulator snapshot per processed chunk. */
        final LinkedList<double[]> partial = new LinkedList<>();

        /**
         * Creates a first-pass worker that sums frames.
         *
         * @param countDownLatch latch to count down when the worker is done
         */
        Estimator(CountDownLatch countDownLatch) {
            this.latch = countDownLatch;
            this.mue = null;
            this.est = new double[fd];
        }

        /**
         * Creates a second-pass worker that accumulates deviations from the mean.
         *
         * @param countDownLatch latch to count down when the worker is done
         * @param dArr mean vector of length {@code fd}; read only, never modified
         */
        Estimator(CountDownLatch countDownLatch, double[] dArr) {
            this.latch = countDownLatch;
            this.mue = dArr;
            this.est = new double[covarianceSize()];
        }

        /**
         * Processes chunks until the data set is exhausted. An
         * {@link IOException} ends this worker early and is logged; the latch
         * is counted down in every case so the coordinating thread never hangs.
         */
        @Override // java.lang.Runnable
        public void run() {
            try {
                ChunkedDataSet.Chunk chunk;
                while ((chunk = data.nextChunk()) != null) {
                    FrameInputStream reader = chunk.getFrameReader();

                    // start every chunk from a clean accumulator
                    for (int i = 0; i < est.length; i++) {
                        est[i] = 0.0d;
                    }

                    while (reader.read(buf)) {
                        if (mue == null) {
                            accumulateSum();
                        } else {
                            accumulateDeviation();
                        }
                        cnt_frame++;
                    }

                    // snapshot, because est is reused for the next chunk
                    partial.add(est.clone());
                    cnt_chunk++;
                }

                // only the first pass contributes to the global sample count
                if (mue == null) {
                    processedSamples(cnt_frame);
                }
                logger.info("ParallelML.Estimator#" + Thread.currentThread().getId() + ".run(): processed " + cnt_frame + " in " + cnt_chunk + " chunks");
            } catch (IOException e) {
                // Runnable cannot throw checked exceptions; log with stack trace
                logger.error("Exception in Estimator Thread #" + Thread.currentThread().getId() + ": " + e.toString(), e);
            } finally {
                latch.countDown();
            }
        }

        /** First pass: adds the current frame to the running sum. */
        private void accumulateSum() {
            for (int i = 0; i < fd; i++) {
                est[i] += buf[i];
            }
        }

        /**
         * Second pass: centres the current frame in place and adds its squared
         * deviations (diagonal) or outer product (packed lower triangle, row by
         * row) to the accumulator.
         */
        private void accumulateDeviation() {
            for (int i = 0; i < fd; i++) {
                buf[i] -= mue[i];
            }

            if (diagonalCovariance) {
                // buf is already centred; subtracting the mean again would be wrong
                for (int i = 0; i < fd; i++) {
                    est[i] += buf[i] * buf[i];
                }
            } else {
                int k = 0;
                for (int row = 0; row < fd; row++) {
                    for (int col = 0; col <= row; col++) {
                        est[k++] += buf[row] * buf[col];
                    }
                }
            }
        }
    }

    /**
     * Creates an estimator; no work is done until {@link #mlEstimate()} is called.
     *
     * @param fd feature dimension of the frames in {@code data}
     * @param data data set to estimate from
     * @param numThreads number of worker threads per pass
     * @param diagonalCovariance {@code true} for a diagonal covariance,
     *        {@code false} for a full one
     */
    public ParallelML(int fd, ChunkedDataSet data, int numThreads, boolean diagonalCovariance) {
        this.numThreads = numThreads;
        this.data = data;
        this.fd = fd;
        this.diagonalCovariance = diagonalCovariance;
    }

    /**
     * Computes the maximum-likelihood mean and covariance of the data set, or
     * returns the cached result of an earlier successful call.
     *
     * @return a {@code DensityDiagonal} or {@code DensityFull}, depending on
     *         the covariance type chosen at construction
     * @throws IOException if the coordinating thread hits an I/O error (the
     *         visible calls here are the {@code data.rewind()} calls between
     *         passes); I/O errors inside workers are logged, not rethrown
     * @throws InterruptedException if the calling thread is interrupted while
     *         waiting for the workers
     */
    public Density mlEstimate() throws IOException, InterruptedException {
        logger.info("ParallelML.estimate(): BEGIN");
        if (estimate != null) {
            return estimate;
        }

        // a previous, failed attempt must not inflate the normalisation constant
        samples = 0;

        ExecutorService pool = Executors.newFixedThreadPool(numThreads);
        try {
            logger.info("ParallelML.estimate(): starting thread pool (1)");
            Estimator[] workers = runPass(pool, null);

            logger.info("ParallelML.estimate(): normalizing mean value");
            double[] mean = averagePartials(workers, fd);

            logger.info("ParallelML.estimate(): starting thread pool (2)");
            workers = runPass(pool, mean);

            logger.info("ParallelML.estimate(): normalizing covariance");
            double[] covariance = averagePartials(workers, covarianceSize());

            Density result = diagonalCovariance
                    ? new DensityDiagonal(1.0d, mean, covariance)
                    : new DensityFull(1.0d, mean, covariance);
            result.update();

            // cache only a fully initialised density
            estimate = result;
            return result;
        } finally {
            // always release the worker threads, also on interrupt or I/O failure
            pool.shutdownNow();
        }
    }

    /**
     * Runs one pass over the data set with {@code numThreads} workers, waits
     * for all of them and rewinds the data set afterwards.
     *
     * @param pool executor to run the workers on
     * @param mean mean to subtract, or {@code null} for the mean-estimation pass
     * @return the finished workers, holding their per-chunk partial statistics
     */
    private Estimator[] runPass(ExecutorService pool, double[] mean) throws IOException, InterruptedException {
        CountDownLatch latch = new CountDownLatch(numThreads);
        Estimator[] workers = new Estimator[numThreads];
        for (int i = 0; i < numThreads; i++) {
            workers[i] = (mean == null) ? new Estimator(latch) : new Estimator(latch, mean);
            pool.execute(workers[i]);
        }
        latch.await();
        data.rewind();
        return workers;
    }

    /**
     * Sums the first {@code length} entries of every partial statistic of
     * every worker, each divided by the total sample count.
     */
    private double[] averagePartials(Estimator[] workers, int length) {
        double[] result = new double[length];
        for (Estimator worker : workers) {
            for (double[] part : worker.partial) {
                for (int i = 0; i < length; i++) {
                    result[i] += part[i] / samples;
                }
            }
        }
        return result;
    }

    /** Number of covariance parameters: {@code fd} if diagonal, else the packed triangle size. */
    private int covarianceSize() {
        return diagonalCovariance ? fd : (fd * (fd + 1)) / 2;
    }

    /**
     * Adds a worker's frame count to the global sample count. Synchronized
     * because several workers may report at the same time.
     *
     * @param j number of frames the calling worker processed
     */
    public synchronized void processedSamples(long j) {
        this.samples += j;
    }
}
