package de.fau.cs.jstk.stat.hmm;

import de.fau.cs.jstk.io.IOUtil;
import de.fau.cs.jstk.stat.Mixture;
import de.fau.cs.jstk.util.Arithmetics;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteOrder;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.apache.log4j.Logger;

/* loaded from: input_file:de/fau/cs/jstk/stat/hmm/Hmm.class */
public final class Hmm {
    private static final Logger logger = Logger.getLogger(Hmm.class);

    /** Numeric model identifier; the sole basis of {@link #equals} and {@link #hashCode}. */
    public int id;

    /** Optional human-readable name; appended in {@link #toString()}, not written by {@link #write}. */
    public transient String textualId;

    /** Number of states. */
    public short ns;

    /** The states of the model, one per state index. */
    public State[] s;

    /** Entry (initial state) probabilities, length {@code ns}. */
    public float[] pi;

    /** Transition probabilities; {@code a[i][j]} is the weight of moving from state i to state j. */
    public float[][] a;

    /** Training statistics; {@code null} until {@link #init()} is called. */
    public transient Accumulator accumulator = null;

    /**
     * Statistics for re-estimating {@link Hmm#pi} and {@link Hmm#a}. Per-state emission statistics
     * are accumulated by the {@link State} objects themselves.
     */
    public final class Accumulator {
        /** Accumulated transition weights, {@code a[i][j]} for the move i -> j. */
        double[][] a;
        /** Accumulated entry weight per state. */
        double[] pi;
        /** Number of observation sequences accumulated. */
        long segments = 0;
        /** Number of frames accumulated. */
        long frames = 0;

        public Accumulator() {
            this.a = new double[Hmm.this.ns][Hmm.this.ns];
            this.pi = new double[Hmm.this.ns];
        }

        /**
         * Adds the statistics of {@code source} to this accumulator. A source that saw no frames is
         * ignored entirely.
         */
        void propagate(Accumulator source) {
            if (source.frames == 0) {
                return;
            }
            segments += source.segments;
            frames += source.frames;
            for (int i = 0; i < a.length; i++) {
                pi[i] += source.pi[i];
                // NOTE(review): assumed to add source.a[i] into a[i] element-wise
                Arithmetics.vadd2(a[i], source.a[i]);
            }
        }

        /**
         * Interpolates this accumulator towards {@code source}. For state i the weight given to
         * {@code source} is {@code rho / (rho + s[i].gamma())}. Segment and frame counts are not
         * changed.
         */
        void interpolate(Accumulator source, double rho) {
            for (int i = 0; i < a.length; i++) {
                double w = rho / (rho + Hmm.this.s[i].gamma());
                pi[i] = (w * source.pi[i]) + ((1.0d - w) * pi[i]);
                // NOTE(review): presumably the same convex combination applied to the row
                Arithmetics.interp1(a[i], source.a[i], w);
            }
        }

        @Override
        public String toString() {
            return "HMM.Accumulator nseq=" + segments + " nfrm=" + frames;
        }
    }

    /** Transition topologies known to {@link Hmm#setTransitions(Topology)}. */
    public enum Topology {
        /** Chain: each state either loops (0.5) or advances to the next state (0.5). */
        LINEAR,
        /** Not implemented by {@code setTransitions}. */
        BAKIS,
        /** Not implemented by {@code setTransitions}. */
        LEFT_TO_RIGHT,
        /** Fully connected with uniform entry and transition probabilities. */
        ERGODIC;

        /** Returns a fresh array of all constants; equivalent to {@link #values()}. */
        public static Topology[] valuesCustom() {
            return values();
        }
    }

    /**
     * Builds a new model by chaining {@code models} in order.
     *
     * <p>The states are shared with (not copied from) the source models. Each source's transition
     * matrix is placed on the block diagonal. The last state of every model that is followed by
     * another one gets {@code 0.5} for the move into the next model's first state, and its
     * self-loop is overwritten with {@code 0.5}. The result always enters in state 0. {@code id}
     * is left at 0.
     *
     * @param models the models to concatenate, in order
     * @throws IllegalArgumentException if {@code models} is empty, the total number of states is
     *     zero, or it does not fit into a {@code short}
     */
    public Hmm(Hmm[] models) {
        if (models.length == 0) {
            throw new IllegalArgumentException("Hmm(Hmm[]): no models to concatenate");
        }
        int total = 0;
        for (Hmm m : models) {
            total += m.ns;
        }
        if (total <= 0 || total > Short.MAX_VALUE) {
            throw new IllegalArgumentException("Hmm(Hmm[]): invalid total number of states " + total);
        }

        this.ns = (short) total;
        this.s = new State[total];
        this.pi = new float[total];
        this.a = new float[total][total];
        this.pi[0] = 1.0f;

        int offset = 0;
        for (Hmm m : models) {
            System.arraycopy(m.s, 0, this.s, offset, m.ns);
            for (int i = 0; i < m.ns; i++) {
                System.arraycopy(m.a[i], 0, this.a[offset + i], offset, m.ns);
            }
            // link the last state of the previous model to the first state of this one
            if (offset > 0 && offset < total) {
                this.a[offset - 1][offset] = 0.5f;
                this.a[offset - 1][offset - 1] = 0.5f;
            }
            offset += m.ns;
        }
    }

    /**
     * Creates a model with {@code numStates} states. Each state is a copy of {@code prototype}
     * made by the copy constructor of its concrete type. Entry and transition probabilities start
     * out as all zeros, for example to be filled in by {@link #setTransitions(Topology)}.
     *
     * @param id model id
     * @param numStates number of states
     * @param prototype state to copy; must be a DState, CState or SCState
     * @throws RuntimeException if the prototype's type is not supported
     */
    public Hmm(int id, short numStates, State prototype) {
        this.id = id;
        this.ns = numStates;
        this.s = new State[numStates];
        this.pi = new float[numStates];
        this.a = new float[numStates][numStates];

        // the order of these checks is significant if the state types subclass each other
        if (prototype instanceof DState) {
            DState p = (DState) prototype;
            for (int i = 0; i < numStates; i++) {
                this.s[i] = new DState(p);
            }
        } else if (prototype instanceof CState) {
            CState p = (CState) prototype;
            for (int i = 0; i < numStates; i++) {
                this.s[i] = new CState(p);
            }
        } else if (prototype instanceof SCState) {
            SCState p = (SCState) prototype;
            for (int i = 0; i < numStates; i++) {
                this.s[i] = new SCState(p);
            }
        } else {
            throw new RuntimeException("Hmm(): Unsupported state type " + prototype.getClass().getCanonicalName());
        }
    }

    /**
     * Reads a model in the layout produced by {@link #write(OutputStream)}.
     *
     * <p>The fields are, all little-endian: int id, short ns, ns entry probabilities, ns rows of
     * ns transition probabilities, then ns states read via {@code State.read}.
     *
     * @param in source stream
     * @param shared map handed through to {@code State.read} (presumably shared mixtures by id;
     *     TODO confirm)
     * @throws IOException if the probabilities cannot be read completely, or as propagated from
     *     the underlying reads
     */
    public Hmm(InputStream in, HashMap<Integer, Mixture> shared) throws IOException {
        this.id = IOUtil.readInt(in, ByteOrder.LITTLE_ENDIAN);
        this.ns = IOUtil.readShort(in, ByteOrder.LITTLE_ENDIAN);
        this.pi = new float[ns];
        if (!IOUtil.readFloat(in, pi, ByteOrder.LITTLE_ENDIAN)) {
            throw new IOException("could not read entry probs");
        }
        this.a = new float[ns][ns];
        for (int i = 0; i < ns; i++) {
            if (!IOUtil.readFloat(in, a[i], ByteOrder.LITTLE_ENDIAN)) {
                throw new IOException("could not read transition probs");
            }
        }
        this.s = new State[ns];
        for (int i = 0; i < ns; i++) {
            this.s[i] = State.read(in, shared);
        }
    }

    /**
     * Writes the model as little-endian id, ns, pi, the rows of a, then every state.
     * {@code textualId} is not written.
     */
    public void write(OutputStream out) throws IOException {
        IOUtil.writeInt(out, id, ByteOrder.LITTLE_ENDIAN);
        IOUtil.writeShort(out, ns, ByteOrder.LITTLE_ENDIAN);
        IOUtil.writeFloat(out, pi, ByteOrder.LITTLE_ENDIAN);
        for (int i = 0; i < ns; i++) {
            IOUtil.writeFloat(out, a[i], ByteOrder.LITTLE_ENDIAN);
        }
        for (State st : s) {
            st.write(out);
        }
    }

    /**
     * Assigns {@code cb} as the codebook of every state. All states must be SCState instances,
     * otherwise a ClassCastException is thrown.
     */
    public void setSharedCodebook(Mixture cb) {
        for (State st : s) {
            ((SCState) st).cb = cb;
        }
    }

    /** Returns the codebook of state 0, which must be an SCState. */
    public Mixture getSharedCodebook() {
        return ((SCState) s[0]).cb;
    }

    /** Two models are equal iff both are Hmm instances with the same {@link #id}. */
    @Override
    public boolean equals(Object obj) {
        return (obj instanceof Hmm) && ((Hmm) obj).id == this.id;
    }

    /** Consistent with {@link #equals(Object)}: derived from {@link #id} only. */
    @Override
    public int hashCode() {
        return id;
    }

    /**
     * Prepares for training: creates a fresh {@link Accumulator} (warning if one existed) and
     * initialises every state.
     */
    public void init() {
        if (accumulator != null) {
            logger.warn("replacing existing Accumulator!");
        }
        accumulator = new Accumulator();
        for (State st : s) {
            st.init();
        }
    }

    /** Drops the accumulator and calls {@code discard()} on every state. */
    public void discard() {
        accumulator = null;
        for (State st : s) {
            st.discard();
        }
    }

    /**
     * Re-estimates the parameters from the accumulator; {@link #init()} must have been called.
     *
     * <p>Each row of {@code a} is replaced by its normalised accumulated row, and the
     * corresponding state is re-estimated, but only if the row has positive total weight.
     * {@code pi} is replaced by the normalised accumulated entry weights if their sum is positive.
     * Skipped rows and a skipped {@code pi} are logged as warnings.
     */
    public void reestimate() {
        double entrySum = 0.0d;
        for (int i = 0; i < ns; i++) {
            entrySum += accumulator.pi[i];
            double rowSum = 0.0d;
            for (int j = 0; j < ns; j++) {
                rowSum += accumulator.a[i][j];
            }
            if (rowSum > 0.0d) {
                for (int j = 0; j < ns; j++) {
                    a[i][j] = (float) (accumulator.a[i][j] / rowSum);
                }
                s[i].reestimate();
            } else {
                logger.warn("hmm(" + id + ")[" + i + "] no transition weight, no re-estimation of a[" + i + "][] and s[" + i + "]");
            }
        }
        if (entrySum <= 0.0d) {
            logger.warn("hmm(" + id + ") no entries logged => no re-estimation of pi");
            return;
        }
        for (int i = 0; i < ns; i++) {
            pi[i] = (float) (accumulator.pi[i] / entrySum);
        }
    }

    /**
     * Adds the statistics of {@code source} (states and accumulator) into this model's.
     *
     * @throws RuntimeException if the state counts differ
     */
    public void propagate(Hmm source) {
        if (source.ns != ns) {
            throw new RuntimeException("HMM.propagate(): Source HMM has different number of states!");
        }
        for (int i = 0; i < ns; i++) {
            s[i].propagate(source.s[i]);
        }
        accumulator.propagate(source.accumulator);
    }

    /**
     * Interpolates this model's statistics towards those of {@code source}, controlled by
     * {@code rho}.
     *
     * @throws RuntimeException if the state counts differ
     */
    public void interpolate(Hmm source, double rho) {
        if (source.ns != ns) {
            throw new RuntimeException("HMM.interpolate(): Source HMM has different number of states!");
        }
        // accumulator first: its weights read s[i].gamma(), which interpolating the states may alter
        accumulator.interpolate(source.accumulator, rho);
        for (int i = 0; i < ns; i++) {
            s[i].interpolate(source.s[i], rho);
        }
    }

    /**
     * Interpolates the parameters (not the statistics) towards {@code other}.
     *
     * <p>{@code pi[i]} becomes {@code w * other.pi[i] + (1 - w) * pi[i]}. The rows of {@code a}
     * and the states are interpolated with the same weight. Each row of {@code a} and {@code pi}
     * are then renormalised with {@code Arithmetics.makesumto1}.
     *
     * @throws RuntimeException if the state counts differ
     */
    public void pinterpolate(double w, Hmm other) {
        if (ns != other.ns) {
            throw new RuntimeException("Hmm.pinterpolate(): different numbers of states");
        }
        for (int i = 0; i < ns; i++) {
            pi[i] = (float) ((w * other.pi[i]) + ((1.0d - w) * pi[i]));
            Arithmetics.interp1(a[i], other.a[i], (float) w);
            Arithmetics.makesumto1(a[i]);
            s[i].pinterpolate(w, other.s[i]);
        }
        Arithmetics.makesumto1(pi);
    }

    /**
     * Accumulates Baum-Welch (forward-backward) statistics for one observation sequence.
     *
     * <p>The forward and backward variables are normalised per frame to avoid underflow. If a
     * frame's forward normaliser is zero or not finite (e.g. every state emits 0 for that frame),
     * the sequence is skipped with a warning. It does not touch the accumulator, which would
     * otherwise be filled with NaN. {@link #init()} must have been called.
     *
     * @param observation feature frames of one segment
     */
    public void incrementBW(List<double[]> observation) {
        final int len = observation.size();
        if (len < ns) {
            logger.info("HMM.incrementBW(): WARNING -- observation sequence (" + len + ") shorter than model length (" + ((int) ns) + ")!");
        }
        if (len == 0) {
            return;
        }

        double[][] alpha = new double[len][ns];
        double[][] beta = new double[len][ns];
        double[][] emis = new double[len][ns];
        double[] scale = new double[len];

        if (!forwardScaled(observation, alpha, emis, scale)) {
            logger.warn("HMM.incrementBW(): hmm(" + id + ") zero or non-finite forward probability, segment skipped");
            return;
        }
        accumulator.segments++;
        accumulator.frames += len;
        backwardScaled(emis, scale, beta);
        accumulateBW(observation, alpha, beta, emis);
    }

    /**
     * Scaled forward pass. Fills {@code emis[t][j]}, the per-frame-normalised {@code alpha[t][j]}
     * and the normalisers {@code scale[t]}. Returns false as soon as a normaliser is not a
     * positive finite number.
     */
    private boolean forwardScaled(List<double[]> observation, double[][] alpha, double[][] emis, double[] scale) {
        int t = 0;
        for (double[] x : observation) {
            double sum = 0.0d;
            for (int j = 0; j < ns; j++) {
                emis[t][j] = s[j].emits(x);
                double v;
                if (t == 0) {
                    v = pi[j];
                } else {
                    v = 0.0d;
                    for (int i = 0; i < ns; i++) {
                        v += alpha[t - 1][i] * a[i][j];
                    }
                }
                alpha[t][j] = v * emis[t][j];
                sum += alpha[t][j];
            }
            // !(sum > 0) also catches NaN
            if (!(sum > 0.0d) || Double.isInfinite(sum)) {
                return false;
            }
            for (int j = 0; j < ns; j++) {
                alpha[t][j] /= sum;
            }
            scale[t] = sum;
            t++;
        }
        return true;
    }

    /** Scaled backward pass, using the forward normalisers in {@code scale}; fills {@code beta}. */
    private void backwardScaled(double[][] emis, double[] scale, double[][] beta) {
        final int last = scale.length - 1;
        for (int i = 0; i < ns; i++) {
            beta[last][i] = 1.0d / scale[last];
        }
        for (int t = last - 1; t >= 0; t--) {
            for (int i = 0; i < ns; i++) {
                double v = 0.0d;
                for (int j = 0; j < ns; j++) {
                    v += a[i][j] * emis[t + 1][j] * beta[t + 1][j];
                }
                beta[t][i] = v / scale[t];
            }
        }
    }

    /**
     * Adds state posteriors (to the states and, for frame 0, to the entry weights) and transition
     * posteriors (to the accumulator) computed from the scaled forward/backward variables.
     */
    private void accumulateBW(List<double[]> observation, double[][] alpha, double[][] beta, double[][] emis) {
        final int last = alpha.length - 1;
        double[] occ = new double[ns];
        int t = 0;
        for (double[] x : observation) {
            double norm = 0.0d;
            for (int i = 0; i < ns; i++) {
                occ[i] = alpha[t][i] * beta[t][i];
                norm += occ[i];
            }
            if (t == 0) {
                for (int i = 0; i < ns; i++) {
                    accumulator.pi[i] += occ[i] / norm;
                }
            }
            for (int i = 0; i < ns; i++) {
                s[i].accumulate(occ[i] / norm, x);
            }
            if (t == last) {
                return;
            }
            for (int i = 0; i < ns; i++) {
                for (int j = 0; j < ns; j++) {
                    accumulator.a[i][j] += alpha[t][i] * a[i][j] * emis[t + 1][j] * beta[t + 1][j];
                }
            }
            t++;
        }
    }

    /**
     * Accumulates Viterbi statistics from an alignment of this model.
     *
     * @throws RuntimeException if the alignment belongs to another model or has no state sequence
     */
    public void incrementVT(Alignment alignment) {
        if (!alignment.model.equals(this)) {
            throw new RuntimeException("HMM[" + id + "].incrementVT(): Alignment.model and this model do not match.");
        }
        if (alignment.q == null) {
            throw new RuntimeException("HMM[" + id + "].incrementVT(): No state alignment present.");
        }
        incrementVT(alignment.observation, alignment.q);
    }

    /**
     * Accumulates Viterbi statistics for a fixed state sequence.
     *
     * <p>The entry weight of {@code q[0]} is increased by 1. Each frame is accumulated into its
     * state with weight 1, and each consecutive pair {@code q[t] -> q[t+1]} adds 1 to the
     * transition counts. {@link #init()} must have been called.
     *
     * @param observation feature frames, one per entry of {@code q}
     * @param q state index per frame
     * @throws RuntimeException if the lengths differ
     */
    public void incrementVT(List<double[]> observation, int[] q) {
        if (observation.size() != q.length) {
            logger.fatal("HMM.incrementVT(): observation.size() != q.length");
            throw new RuntimeException("HMM.incrementVT(): observation.size() != q.length");
        }
        if (observation.size() < ns) {
            logger.info("HMM.incrementVT(): WARNING -- observation sequence (" + observation.size() + ") shorter than model length (" + ((int) ns) + ")!");
        }
        if (q.length == 0) {
            return;
        }
        accumulator.segments++;
        accumulator.frames += q.length;

        Iterator<double[]> it = observation.iterator();
        accumulator.pi[q[0]] += 1.0d;
        for (int t = 0; t < q.length; t++) {
            s[q[t]].accumulate(1.0d, it.next());
            if (t + 1 < q.length) {
                accumulator.a[q[t]][q[t + 1]] += 1.0d;
            }
        }
    }

    /** Multi-line dump of id, ns, textualId, pi, a and every state. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("hmm.HMM id=").append(id)
                .append(" ns=").append((int) ns)
                .append(textualId != null ? textualId : "")
                .append('\n');
        sb.append("pi = [");
        for (int i = 0; i < ns; i++) {
            sb.append(' ').append(pi[i]);
        }
        sb.append(" ]\na = [\n");
        for (int i = 0; i < ns; i++) {
            for (int j = 0; j < ns; j++) {
                sb.append(' ').append(a[i][j]);
            }
            sb.append('\n');
        }
        sb.append("]\n");
        for (int i = 0; i < ns; i++) {
            sb.append("s[").append(i).append("] ").append(s[i]).append('\n');
        }
        return sb.toString();
    }

    /**
     * Resets {@code pi} and {@code a} to zero, then fills them according to {@code topology}.
     *
     * <p>LINEAR enters in state 0. Each state but the last loops or advances with 0.5 each, and
     * the last state loops with 1. ERGODIC sets every entry of {@code pi} and {@code a} to
     * {@code 1/ns}. Other topologies are not implemented: they leave everything zero and log an
     * info message.
     */
    public void setTransitions(Topology topology) {
        for (int i = 0; i < ns; i++) {
            pi[i] = 0.0f;
            for (int j = 0; j < ns; j++) {
                a[i][j] = 0.0f;
            }
        }
        switch (topology) {
            case LINEAR:
                pi[0] = 1.0f;
                for (int i = 0; i < ns - 1; i++) {
                    a[i][i] = 0.5f;
                    a[i][i + 1] = 0.5f;
                }
                a[ns - 1][ns - 1] = 1.0f;
                break;
            case ERGODIC: {
                float uniform = (float) (1.0d / ns);
                for (int i = 0; i < ns; i++) {
                    pi[i] = uniform;
                    for (int j = 0; j < ns; j++) {
                        a[i][j] = uniform;
                    }
                }
                break;
            }
            default:
                logger.info("HMM.setTransitions(): requested topology not implemented (yet)");
                break;
        }
    }
}
