package de.fau.cs.jstk.app;

import de.fau.cs.jstk.arch.Configuration;
import de.fau.cs.jstk.arch.Tokenization;
import de.fau.cs.jstk.io.FrameInputStream;
import de.fau.cs.jstk.io.FrameSource;
import de.fau.cs.jstk.stat.hmm.Alignment;
import de.fau.cs.jstk.stat.hmm.MetaAlignment;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.log4j.BasicConfigurator;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

/* loaded from: input_file:de/fau/cs/jstk/app/Trainer.class */
/**
 * Command-line trainer for the models of a token hierarchy.
 *
 * <p>For every turn of a turn list it obtains an alignment (read from disk or computed on the
 * fly), accumulates training statistics in a pool of worker threads, then merges the statistics
 * of all workers, re-estimates the codebook and writes it to the output file. See
 * {@link #SYNOPSIS} for the command-line syntax.
 */
public class Trainer {
    private static final Logger logger = Logger.getLogger(Trainer.class);

    /** Usage text, printed when fewer than the five mandatory arguments are given. */
    public static final String SYNOPSIS = "sikoried, 8/2/2010\nTrain the models of a TokenHierarchy using the given transcriptions or\nalignments. The required alignments can be read or computed on the fly.\n\nusage: app.Trainer config codebook codebook-out turn-list feat-dir [options]\n-a type\n  Use the given alignment strategy. Currently supported:\n  manual <directory>        : assume manual alignments in the given directory.\n  manual_linear <directory> : same as manual, but use linear alignment in absence of state alignment\n  forced [directory]        : compute forced Viterbi alignment; specify a directory to save\n                              the alignment result (default strategy).\n  linear [directory]        : estimate a linear alignment dependent on the number\n                              of states; specify a directory to save the alignment result.\n-t type\n  Use the given training strategy. Currently supported:\n  vt : Viterbi training aka EM* -- fast and efficient; great with linear alignment for\n       initialization (default).\n  bw : Baum-Welch training based on EM. Computationally more expensive but greater\n       accuracy.\n-p num\n  Number of threads to use for the training. If set to 0, the number of threads\n  is set to the number of CPU (default).\n--silent\n  Mute the DebugOutput.\n--prop\n  Propagate the sufficient statistics prior to reestimation.\n--interp <rho>\n  Perform suff.stat. propagation and interpolation prior to reestimation\n";

    /** How the state alignment of each turn is obtained (command-line option {@code -a}). */
    enum AlignmentType {
        /** {@code -a manual <dir>}: read existing alignments from the given directory. */
        MANUAL,
        /** {@code -a manual_linear <dir>}: like MANUAL, but linear alignment where no state alignment is given. */
        MANUAL_LINEAR,
        /** {@code -a forced [dir]}: forced Viterbi alignment against the transcription (default). */
        FORCED,
        /** {@code -a linear [dir]}: linear alignment against the transcription. */
        LINEAR;

        /**
         * Returns a copy of {@link #values()}. Kept for compatibility with callers of the
         * decompiled class, which exposed this name.
         */
        public static AlignmentType[] valuesCustom() {
            return values().clone();
        }
    }

    /**
     * Hands out the queued jobs one at a time. {@link #next()} is synchronized so that all
     * workers can share a single instance.
     */
    private static final class Distributor {
        List<Job> turns;
        Iterator<Job> it = null;

        Distributor(List<Job> list) {
            this.turns = list;
            rewind();
        }

        /** Restarts the iteration at the first job. */
        void rewind() {
            this.it = this.turns.iterator();
        }

        /** Returns the next job, or {@code null} once every job has been handed out. */
        synchronized Job next() {
            return this.it.hasNext() ? this.it.next() : null;
        }
    }

    /** One unit of work: a turn together with the alignment and training strategy to apply. */
    public static final class Job {
        MetaAlignment.Turn turn;
        AlignmentType align;
        TrainingType train;

        Job(MetaAlignment.Turn turn, AlignmentType alignmentType, TrainingType trainingType) {
            this.turn = turn;
            this.align = alignmentType;
            this.train = trainingType;
        }
    }

    /** How statistics are accumulated (command-line option {@code -t}). */
    enum TrainingType {
        /** {@code -t vt}: Viterbi training (default). */
        VITERBI,
        /** {@code -t bw}: Baum-Welch training. */
        BAUM_WELCH;

        /**
         * Returns a copy of {@link #values()}. Kept for compatibility with callers of the
         * decompiled class, which exposed this name.
         */
        public static TrainingType[] valuesCustom() {
            return values().clone();
        }
    }

    /**
     * Worker body. It fetches jobs from the shared {@link Distributor} until none are left. For
     * each job it aligns the turn, optionally writes the alignment, and adds the statistics to
     * this worker's own {@link Configuration}.
     *
     * <p>The latch is counted down when the worker stops, whether it ran out of jobs or failed.
     * The first exception stops this worker. The remaining jobs are still served to the other
     * workers.
     */
    private static final class Worker implements Runnable {
        final Distributor dist;
        final Configuration conf;
        final CountDownLatch latch;
        /** Number of jobs this worker completed. */
        long jobs = 0;
        /** True if the worker stopped because of an exception; read by main after the latch opens. */
        boolean failed = false;

        Worker(Configuration configuration, Distributor distributor, CountDownLatch countDownLatch) {
            this.conf = configuration;
            this.dist = distributor;
            this.latch = countDownLatch;
        }

        /** Prefix used in this worker's messages, identifying the executing thread. */
        private static String tag() {
            return "Trainer.Worker#" + Thread.currentThread().getId();
        }

        @Override
        public void run() {
            try {
                Job job;
                while ((job = this.dist.next()) != null) {
                    process(job);
                    this.jobs++;
                }
                logger.info(tag() + ".run(): finished; computed " + this.jobs + " jobs");
            } catch (Exception e) {
                this.failed = true;
                e.printStackTrace();
                System.err.println(tag() + ".run(): " + e);
            } finally {
                this.latch.countDown();
            }
        }

        /** Aligns one turn, optionally saves the alignment, and accumulates its statistics. */
        private void process(Job job) throws Exception {
            FrameInputStream features = new FrameInputStream(new File(job.turn.canonicalInputName()));
            try {
                MetaAlignment alignment = align(job, features);
                // Computed alignments are saved only if an alignment directory was given.
                if ((job.align == AlignmentType.LINEAR || job.align == AlignmentType.FORCED)
                        && job.turn.outDir != null) {
                    writeAlignment(alignment, job.turn.canonicalOutputName());
                }
                accumulate(job, alignment);
            } finally {
                // Close the feature file after the job so that one handle is not leaked per turn.
                // A failure to close is reported but must not discard statistics already added.
                try {
                    features.close();
                } catch (Exception e) {
                    System.err.println(tag() + ".run(): could not close " + job.turn.canonicalInputName() + ": " + e);
                }
            }
        }

        /** Builds the alignment for the job according to its alignment strategy. */
        private MetaAlignment align(Job job, FrameInputStream features) throws Exception {
            if (job.align == AlignmentType.MANUAL || job.align == AlignmentType.MANUAL_LINEAR) {
                BufferedReader reader = new BufferedReader(new FileReader(job.turn.canonicalOutputName()));
                try {
                    return new MetaAlignment(features, reader, this.conf.th, job.align == AlignmentType.MANUAL);
                } finally {
                    reader.close();
                }
            } else if (job.align == AlignmentType.FORCED) {
                return new MetaAlignment((FrameSource) features,
                        (Iterable<Tokenization>) this.conf.tok.getSentenceTokenization(job.turn.transcription),
                        this.conf.th, true);
            } else if (job.align == AlignmentType.LINEAR) {
                return new MetaAlignment((FrameSource) features,
                        (Iterable<Tokenization>) this.conf.tok.getSentenceTokenization(job.turn.transcription),
                        this.conf.th, false);
            }
            throw new Exception(tag() + ".run(): invalid alignment strategy!");
        }

        /** Writes the alignment to the given file, closing it even if writing fails. */
        private static void writeAlignment(MetaAlignment alignment, String fileName) throws Exception {
            BufferedWriter writer = new BufferedWriter(new FileWriter(fileName));
            try {
                alignment.write(writer);
            } finally {
                writer.close();
            }
        }

        /** Adds the statistics of every segment alignment to its model. */
        private void accumulate(Job job, MetaAlignment alignment) throws Exception {
            if (job.train == TrainingType.VITERBI) {
                for (Alignment a : alignment.alignments) {
                    a.model.incrementVT(a);
                }
            } else if (job.train == TrainingType.BAUM_WELCH) {
                for (Alignment a : alignment.alignments) {
                    a.model.incrementBW(a.observation);
                }
            } else {
                throw new Exception(tag() + ".run(): invalid training strategy!");
            }
        }
    }

    /**
     * Returns {@code args[i]}. Fails with a descriptive message if the command line ends before
     * the value expected for {@code option}.
     */
    private static String requireArgument(String[] args, int i, String option) throws Exception {
        if (i >= args.length) {
            throw new Exception("Trainer.main(): missing argument for " + option);
        }
        return args[i];
    }

    /**
     * Entry point; see {@link #SYNOPSIS}.
     *
     * <p>Positional arguments are: config, codebook (input), codebook-out, turn-list and
     * feat-dir. Options follow them.
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 5) {
            System.err.println(SYNOPSIS);
            System.exit(1);
        }
        BasicConfigurator.configure();
        // Apply --silent before parsing so that warnings emitted while parsing are muted too.
        for (String arg : args) {
            if (arg.equals("--silent")) {
                logger.setLevel(Level.FATAL);
            }
        }

        int numThreads = Runtime.getRuntime().availableProcessors();
        AlignmentType alignmentType = AlignmentType.FORCED;
        TrainingType trainingType = TrainingType.VITERBI;
        String alignmentDir = null;
        String configFile = args[0];
        String codebookIn = args[1];
        String codebookOut = args[2];
        String turnList = args[3];
        String featureDir = args[4];
        // rho < 0: no propagation; rho == 0 (--prop): propagate only; rho > 0: propagate and interpolate
        double rho = -1.0d;

        for (int i = 5; i < args.length; i++) {
            String opt = args[i];
            if (opt.equals("--silent")) {
                logger.setLevel(Level.FATAL);
            } else if (opt.equals("-t")) {
                String strategy = requireArgument(args, ++i, opt);
                if (strategy.equals("bw")) {
                    trainingType = TrainingType.BAUM_WELCH;
                } else if (strategy.equals("vt")) {
                    trainingType = TrainingType.VITERBI;
                } else {
                    throw new Exception("Trainer.main(): unknown training strategy");
                }
            } else if (opt.equals("-a")) {
                String strategy = requireArgument(args, ++i, opt);
                if (strategy.equals("manual")) {
                    alignmentType = AlignmentType.MANUAL;
                    alignmentDir = requireArgument(args, ++i, "-a manual");
                } else if (strategy.equals("manual_linear")) {
                    alignmentType = AlignmentType.MANUAL_LINEAR;
                    alignmentDir = requireArgument(args, ++i, "-a manual_linear");
                } else if (strategy.equals("forced") || strategy.equals("linear")) {
                    alignmentType = strategy.equals("forced") ? AlignmentType.FORCED : AlignmentType.LINEAR;
                    // Optional directory in which the computed alignments are saved.
                    if (i + 1 < args.length && !args[i + 1].startsWith("-")) {
                        alignmentDir = args[++i];
                    }
                } else {
                    throw new Exception("Trainer.main(): unknown alignment strategy");
                }
            } else if (opt.equals("-p")) {
                int requested = Integer.parseInt(requireArgument(args, ++i, opt));
                if (requested < 0) {
                    throw new Exception("Trainer.main(): invalid number of threads");
                }
                int cpus = Runtime.getRuntime().availableProcessors();
                if (requested == 0) {
                    numThreads = cpus;
                } else {
                    // Compare against the CPU count, not against an earlier -p value.
                    if (requested > cpus) {
                        logger.info("Trainer.main(): warning -- using more threads than CPUs!");
                    }
                    numThreads = requested;
                }
            } else if (opt.equals("--prop")) {
                rho = 0.0d;
            } else if (opt.equals("--interp")) {
                rho = Double.parseDouble(requireArgument(args, ++i, opt));
            } else {
                logger.info("Trainer.main(): warning -- ignoring unknown argument \"" + opt + "\"");
            }
        }

        List<MetaAlignment.Turn> turns = MetaAlignment.Turn.readTurnList(turnList, featureDir, alignmentDir);
        logger.info("Trainer.main(): read " + turns.size() + " turns");
        List<Job> jobs = new LinkedList<Job>();
        for (MetaAlignment.Turn turn : turns) {
            jobs.add(new Job(turn, alignmentType, trainingType));
        }

        // Each worker loads its own Configuration and codebook.
        // Its statistics are merged into the first worker's codebook below.
        CountDownLatch latch = new CountDownLatch(numThreads);
        Distributor distributor = new Distributor(jobs);
        Worker[] workers = new Worker[numThreads];
        for (int i = 0; i < numThreads; i++) {
            logger.info("Trainer.main(): preparing thread #" + i);
            Configuration configuration = new Configuration(new File(configFile));
            configuration.loadCodebook(new File(codebookIn));
            configuration.cb.init();
            workers[i] = new Worker(configuration, distributor, latch);
        }

        logger.info("Trainer.main(): begin training using " + numThreads + " threads");
        ExecutorService pool = Executors.newFixedThreadPool(numThreads);
        try {
            for (Worker worker : workers) {
                pool.execute(worker);
            }
            latch.await();
        } finally {
            // Also runs if the wait is interrupted, so the pool's threads do not keep the JVM alive.
            pool.shutdownNow();
        }

        int aborted = 0;
        for (Worker worker : workers) {
            if (worker.failed) {
                aborted++;
            }
        }
        if (aborted > 0) {
            logger.info("Trainer.main(): warning -- " + aborted + " worker thread(s) aborted; statistics may be incomplete");
        }

        logger.info("Trainer.main(): re-estimating...");
        Configuration merged = workers[0].conf;
        for (int i = 1; i < numThreads; i++) {
            merged.cb.consume(workers[i].conf.cb);
        }
        if (rho >= 0.0d) {
            logger.info("propagating statistics...");
            merged.th.propagate();
        }
        if (rho > 0.0d) {
            logger.info("interpolating statistics (rho = " + rho + ")...");
            merged.th.interpolate(rho);
        }
        merged.cb.reestimate();
        logger.info("Trainer.main(): writing out " + codebookOut);
        merged.cb.write(new File(codebookOut));
    }
}
