package de.fau.cs.jstk.stat;

import de.fau.cs.jstk.io.ChunkedDataSet;
import de.fau.cs.jstk.io.FrameInputStream;
import de.fau.cs.jstk.stat.Density;
import de.fau.cs.jstk.stat.MleDensityAccumulator;
import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.log4j.Logger;

/* loaded from: input_file:de/fau/cs/jstk/stat/ParallelEM.class */
/**
 * Runs expectation-maximization iterations for a {@link Mixture} over a {@link ChunkedDataSet},
 * distributing the per-frame evaluation and accumulation across a fixed pool of worker threads.
 *
 * <p>Each iteration:
 * <ol>
 *   <li>stores a copy of {@link #current} in {@link #previous};</li>
 *   <li>gives every worker its own mixture copy and accumulator, and lets the workers pull chunks
 *       from the shared data set until it returns {@code null};</li>
 *   <li>waits for all workers, rewinds the data set, merges the per-worker accumulators;</li>
 *   <li>calls {@code MleMixtureAccumulator.MleUpdate(previous, opts, flags, acc, current)}.</li>
 * </ol>
 *
 * <p>NOTE(review): several workers call {@code data.nextChunk()} concurrently, so this class assumes
 * {@link ChunkedDataSet#nextChunk()} is thread-safe — confirm.
 */
public final class ParallelEM {
    private static final Logger logger = Logger.getLogger(ParallelEM.class);

    /** Number of worker threads (and per-worker accumulators) used per iteration. */
    private int numThreads;

    /** Source of training chunks; rewound at the end of each iteration. */
    private ChunkedDataSet data;

    /** Copy of {@link #current} taken at the start of the most recent iteration; null before. */
    public Mixture previous;

    /** Working mixture; passed as the update target to {@code MleMixtureAccumulator.MleUpdate}. */
    public Mixture current;

    /** {@code mixture.nd}; sizes each worker's posterior buffer. */
    private int nd;

    /** {@code mixture.fd}; sizes each worker's frame buffer. */
    private int fd;

    /** Number of iterations started so far. */
    public int ni;

    private MleDensityAccumulator.MleOptions opts;
    private Density.Flags flags;

    /**
     * Worker that consumes chunks from the shared data set, evaluates every frame against its
     * private mixture copy and accumulates into its private accumulator.
     *
     * <p>Counts down the supplied latch exactly once, when it stops (data exhausted, I/O error or
     * unexpected exception), so the coordinator can wait for all workers to finish.
     */
    public class Worker implements Runnable {
        Mixture m;
        MleMixtureAccumulator a;
        CountDownLatch latch;
        double[] f;
        double[] p;
        int cnt_chunk = 0;
        int cnt_frame = 0;

        /**
         * @param mixture worker-private copy of the mixture to evaluate frames with
         * @param mleMixtureAccumulator worker-private accumulator; flushed here before use
         * @param countDownLatch latch counted down once when this worker terminates
         */
        Worker(Mixture mixture, MleMixtureAccumulator mleMixtureAccumulator, CountDownLatch countDownLatch) {
            this.latch = countDownLatch;
            this.m = mixture;
            this.a = mleMixtureAccumulator;
            this.f = new double[fd];
            this.p = new double[nd];
            mleMixtureAccumulator.flush();
        }

        @Override
        public void run() {
            // The try/finally must enclose the whole loop: the latch is sized to the number of
            // workers, so it may be counted down only once per worker, not once per chunk.
            try {
                ChunkedDataSet.Chunk chunk;
                while ((chunk = data.nextChunk()) != null) {
                    FrameInputStream reader = chunk.getFrameReader();
                    while (reader.read(f)) {
                        // evaluate the frame, then accumulate it weighted by the values
                        // posteriors() writes into p (presumably per-component posteriors)
                        m.evaluate(f);
                        m.posteriors(p);
                        a.accumulate(p, f);
                        cnt_frame++;
                    }
                    cnt_chunk++;
                }
                logger.info("ParallelEM.Worker#" + Thread.currentThread().getId() + ".run(): processed "
                        + cnt_frame + " in " + cnt_chunk + " chunks");
            } catch (IOException e) {
                // Best effort: this worker stops, and the statistics gathered so far are still merged.
                logger.warn("ParallelEM.Worker#" + Thread.currentThread().getId() + ".run(): IOException: " + e, e);
            } finally {
                latch.countDown();
            }
        }
    }

    /**
     * Creates an EM trainer using default MLE options and updating all parameters.
     *
     * @param mixture initial mixture; becomes {@link #current} and is updated by each iteration
     * @param chunkedDataSet training data
     * @param i number of worker threads
     * @throws IOException declared for compatibility
     */
    public ParallelEM(Mixture mixture, ChunkedDataSet chunkedDataSet, int i) throws IOException {
        this(mixture, chunkedDataSet, MleDensityAccumulator.MleOptions.pDefaultOptions, Density.Flags.fAllParams, i);
    }

    /**
     * Creates an EM trainer.
     *
     * @param mixture initial mixture; becomes {@link #current} and is updated by each iteration
     * @param chunkedDataSet training data
     * @param mleOptions options passed to the M-step
     * @param flags selects which density parameters the M-step updates
     * @param i number of worker threads
     * @throws IOException declared for compatibility
     */
    public ParallelEM(Mixture mixture, ChunkedDataSet chunkedDataSet, MleDensityAccumulator.MleOptions mleOptions,
            Density.Flags flags, int i) throws IOException {
        this.data = chunkedDataSet;
        this.numThreads = i;
        this.current = mixture;
        this.previous = null;
        this.ni = 0;
        this.fd = mixture.fd;
        this.nd = mixture.nd;
        this.opts = mleOptions;
        this.flags = flags;
    }

    /** Replaces the training data used by subsequent iterations. */
    public void setChunkedDataSet(ChunkedDataSet chunkedDataSet) {
        this.data = chunkedDataSet;
    }

    /** Sets the number of worker threads used by subsequent iterations. */
    public void setNumberOfThreads(int i) {
        this.numThreads = i;
    }

    /**
     * Performs {@code i} EM iterations; does nothing if {@code i <= 0}.
     */
    public void iterate(int i) throws ClassNotFoundException, IOException, InterruptedException {
        for (int k = 0; k < i; k++) {
            iterate();
        }
    }

    /**
     * Performs a single EM iteration (see class documentation).
     *
     * @throws IllegalStateException if the configured number of threads is less than 1
     * @throws InterruptedException if interrupted while waiting for the workers; the workers are
     *     then interrupted and no update is applied
     */
    public void iterate() throws ClassNotFoundException, IOException, InterruptedException {
        if (numThreads < 1) {
            throw new IllegalStateException("ParallelEM.iterate(): numThreads must be >= 1, got " + numThreads);
        }
        logger.info("ParallelEM.iterate(): BEGIN iteration " + (++ni));

        previous = current.m63clone();

        // Build the shared template accumulator before the pool, so a failure here cannot leak threads.
        MleMixtureAccumulator total = new MleMixtureAccumulator(current.fd, current.nd,
                current.diagonal() ? DensityDiagonal.class : DensityFull.class);
        MleMixtureAccumulator[] partial = new MleMixtureAccumulator[numThreads];
        CountDownLatch done = new CountDownLatch(numThreads);

        ExecutorService pool = Executors.newFixedThreadPool(numThreads);
        try {
            for (int t = 0; t < numThreads; t++) {
                partial[t] = new MleMixtureAccumulator(total);
                pool.execute(new Worker(current.m63clone(), partial[t], done));
            }
            done.await();
        } finally {
            // Always release the pool threads; if await() was interrupted this also interrupts workers.
            pool.shutdownNow();
        }

        data.rewind();
        for (MleMixtureAccumulator acc : partial) {
            total.propagate(acc);
        }
        MleMixtureAccumulator.MleUpdate(previous, opts, flags, total, current);
        logger.info("ParallelEM.iterate(): END");
    }
}
