/*
 * Decompiled with CFR 0.152.
 */
package tsg;

import java.io.File;
import java.io.PrintWriter;
import java.math.BigDecimal;
import java.math.MathContext;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Map;
import java.util.Scanner;
import java.util.TreeSet;
import kernels.NodeSetCollector;
import kernels.NodeSetCollectorSimple;
import kernels.NodeSetCollectorStandard;
import settings.Parameters;
import tsg.Label;
import tsg.TSNodeLabel;
import tsg.TSNodeLabelIndex;
import tsg.TSNodeLabelStructure;
import tsg.corpora.Wsj;
import util.FileUtil;
import util.PrintProgress;
import util.Utility;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class DOP_EM_BigDecimal {
    public static Hashtable<TSNodeLabel, BigDecimal[]> fragmentTableFreq;
    public static Hashtable<Label, BigDecimal[]> rootTableFreq;
    public static ArrayList<TSNodeLabelIndex> trainingCorpus;
    public static int endCycle;
    public static String workingDir;

    static {
        endCycle = 10;
    }

    /**
     * Replaces {@link #fragmentTableFreq} with the content of a fragment file.
     *
     * <p>Every non-empty line must hold the fragment's string form (handed to the
     * {@code TSNodeLabel} constructor), a tab, and the frequency (handed to the
     * {@code BigDecimal} string constructor). Empty lines are skipped.
     *
     * @param fragmentFile the tab-separated fragment/frequency file to read
     * @throws IllegalArgumentException if a non-empty line has no tab-separated frequency column
     * @throws Exception whatever the scanner factory or the fragment/number parsers throw
     */
    public static void readFragmentsFile(File fragmentFile) throws Exception {
        fragmentTableFreq = new Hashtable<TSNodeLabel, BigDecimal[]>();
        Scanner scan = FileUtil.getScanner(fragmentFile);
        int countFragments = 0;
        try {
            while (scan.hasNextLine()) {
                String line = scan.nextLine();
                if (line.equals("")) {
                    continue;
                }
                ++countFragments;
                String[] fragmentFreq = line.split("\t");
                if (fragmentFreq.length < 2) {
                    // Fail with the offending line instead of a bare ArrayIndexOutOfBoundsException.
                    throw new IllegalArgumentException("Malformed fragment line " + countFragments
                            + " in " + fragmentFile + " (expected <fragment>TAB<frequency>): " + line);
                }
                BigDecimal freq = new BigDecimal(fragmentFreq[1]);
                TSNodeLabel fragment = new TSNodeLabel(fragmentFreq[0], false);
                // The frequency lives in a one-element array so it can be updated in place.
                fragmentTableFreq.put(fragment, new BigDecimal[]{freq});
            }
        } finally {
            // Release the file handle even when a line fails to parse.
            scan.close();
        }
        System.out.println("Read " + countFragments + " fragments");
    }

    /**
     * Rebuilds {@link #trainingCorpus} by wrapping each tree of the given treebank
     * in a {@code TSNodeLabelIndex}, preserving the treebank order.
     *
     * @param treebank the trees to index
     * @throws Exception declared for callers; propagates whatever the wrapping throws
     */
    private static void readTreeBank(ArrayList<TSNodeLabel> treebank) throws Exception {
        trainingCorpus = new ArrayList<TSNodeLabelIndex>();
        Iterator<TSNodeLabel> trees = treebank.iterator();
        while (trees.hasNext()) {
            TSNodeLabelIndex indexedTree = new TSNodeLabelIndex(trees.next());
            trainingCorpus.add(indexedTree);
        }
    }

    /**
     * Writes {@link #fragmentTableFreq} to a file, one line per fragment:
     * the fragment string ({@code toString(false, true)}), a tab, and the frequency.
     *
     * <p>The frequency is written as a {@code double}, so digits beyond double
     * precision are not preserved in the file.
     *
     * @param outputFile destination file
     */
    public static void printFragmentFreq(File outputFile) {
        PrintWriter pw = FileUtil.getPrintWriter(outputFile);
        try {
            for (Map.Entry<TSNodeLabel, BigDecimal[]> e : fragmentTableFreq.entrySet()) {
                String fragmentString = e.getKey().toString(false, true);
                double freq = e.getValue()[0].doubleValue();
                pw.println(fragmentString + "\t" + freq);
            }
        } finally {
            // Flush and release the file even if formatting an entry throws.
            pw.close();
        }
    }

    /*
     * WARNING - void declaration
     */
    public static void addCFGfragments() throws Exception {
        void var1_4;
        Hashtable ruleTable = new Hashtable();
        for (TSNodeLabel tSNodeLabel : trainingCorpus) {
            ArrayList<TSNodeLabel> nodes = tSNodeLabel.collectAllNodes();
            for (TSNodeLabel n : nodes) {
                if (n.isLexical) continue;
                String rule = n.cfgRule();
                Utility.increaseInTableInt(ruleTable, rule);
            }
        }
        System.out.println("Read " + ruleTable.size() + " CFG fragments");
        boolean bl = false;
        for (Map.Entry e : ruleTable.entrySet()) {
            TSNodeLabel ruleFragment = new TSNodeLabel("( " + (String)e.getKey() + ")", false);
            if (fragmentTableFreq.containsKey(ruleFragment)) continue;
            BigDecimal freq = new BigDecimal(((int[])e.getValue())[0]);
            fragmentTableFreq.put(ruleFragment, new BigDecimal[]{freq});
            ++var1_4;
        }
        System.out.println("Added " + (int)var1_4 + " CFG fragments");
    }

    /**
     * Rebuilds {@link #rootTableFreq} from {@link #fragmentTableFreq}: for every
     * fragment, its frequency is passed to {@code Utility.increaseInTableBigDecimalArray}
     * under the fragment's root label (presumably summing the frequencies per label;
     * the helper is defined outside this file).
     */
    public static void getRootFreq() {
        rootTableFreq = new Hashtable<Label, BigDecimal[]>();
        Iterator<Map.Entry<TSNodeLabel, BigDecimal[]>> entries = fragmentTableFreq.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry<TSNodeLabel, BigDecimal[]> entry = entries.next();
            Utility.increaseInTableBigDecimalArray(rootTableFreq, entry.getKey().label, entry.getValue()[0]);
        }
        int rootCount = rootTableFreq.size();
        System.out.println("Built root freq. table: " + rootCount + " entries.");
    }

    /**
     * Runs EM re-estimation of the fragment frequencies over {@link #trainingCorpus}.
     *
     * <p>Each cycle builds a fresh frequency table from all training trees and
     * multiplies the per-tree probabilities into the corpus likelihood. The cycle's
     * table replaces {@link #fragmentTableFreq} (followed by {@link #getRootFreq()}
     * and a dump under {@link #workingDir}) unless the likelihood dropped below the
     * previous cycle's, in which case the loop stops and the new table is discarded.
     * The method returns immediately if any tree gets probability zero. At least one
     * cycle is always attempted; iteration stops once {@link #endCycle} cycles have run.
     */
    public static void runEM() {
        int cycle = 0;
        BigDecimal previousLikelihood = BigDecimal.ZERO;
        do {
            Hashtable<TSNodeLabel, BigDecimal[]> newFragmentTableFreq = new Hashtable<TSNodeLabel, BigDecimal[]>();
            BigDecimal currentLikelihood = BigDecimal.ONE;
            PrintProgress.start("Iterating Training Corpus:");
            for (TSNodeLabelIndex t : trainingCorpus) {
                PrintProgress.next();
                BigDecimal prob = DOP_EM_BigDecimal.updateNewFragmentTableFreq(t, newFragmentTableFreq);
                if (prob.compareTo(BigDecimal.ZERO) == 0) {
                    System.err.println("Zero prob. + " + PrintProgress.currentIndex());
                    return;
                }
                currentLikelihood = currentLikelihood.multiply(prob);
            }
            PrintProgress.end();
            // The message shows the zero-based cycle; the file name below uses the incremented value.
            System.out.println("EM cyle " + cycle++ + ". Likelihood: " + currentLikelihood);
            if (currentLikelihood.compareTo(previousLikelihood) < 0) {
                break;
            }
            previousLikelihood = currentLikelihood;
            fragmentTableFreq = newFragmentTableFreq;
            DOP_EM_BigDecimal.getRootFreq();
            DOP_EM_BigDecimal.printFragmentFreq(new File(workingDir + "kernelsMUB_CFG_freq_EM_cycle" + cycle + ".txt"));
            // '<' rather than '!=': with endCycle <= 0 the counter would never equal the bound.
        } while (cycle < endCycle);
    }

    /**
     * Processes one training tree: finds every occurrence of every current fragment
     * in the tree, builds a {@code ProbChart} over those occurrences, and lets the
     * chart add this tree's contribution to the new frequency table.
     *
     * @param t the indexed training tree
     * @param newFragmentTableFreq table receiving the re-estimated frequencies
     * @return the probability the chart computes for the tree
     */
    private static BigDecimal updateNewFragmentTableFreq(TSNodeLabelIndex t, Hashtable<TSNodeLabel, BigDecimal[]> newFragmentTableFreq) {
        NodeSetCollectorSimple occurrences = new NodeSetCollectorSimple();
        HashMap<BitSet, TSNodeLabel_BigDecimal> fragmentByNodeSet = new HashMap<BitSet, TSNodeLabel_BigDecimal>();
        Iterator<Map.Entry<TSNodeLabel, BigDecimal[]>> entries = fragmentTableFreq.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry<TSNodeLabel, BigDecimal[]> entry = entries.next();
            TSNodeLabel fragment = entry.getKey();
            BigDecimal frequency = entry.getValue()[0];
            DOP_EM_BigDecimal.getCFGSetCoveringFragment(t, fragment, frequency, occurrences, fragmentByNodeSet);
        }
        ProbChart chart = new ProbChart(occurrences, new TSNodeLabelStructure(t), fragmentByNodeSet, newFragmentTableFreq);
        BigDecimal treeProb = chart.getProb();
        chart.extractNewFragmentFrequencies();
        return treeProb;
    }

    /**
     * Records every occurrence of {@code fragment} inside the subtree rooted at {@code t}.
     *
     * <p>At each non-lexical node whose label matches the fragment root, the fragment
     * is matched top-down; a successful match with a non-empty set of covered node
     * indexes is added to {@code setCollector} and registered in {@code bitSetFreqTable}
     * together with the fragment and its frequency. All daughters are then searched
     * recursively, whether or not the match at this node succeeded.
     *
     * @param t node of the training tree to search from
     * @param fragment fragment to look for
     * @param fragmentFreq current frequency of the fragment
     * @param setCollector receives the node set of each occurrence
     * @param bitSetFreqTable maps each recorded node set to its fragment and frequency
     */
    private static void getCFGSetCoveringFragment(TSNodeLabelIndex t, TSNodeLabel fragment, BigDecimal fragmentFreq, NodeSetCollector setCollector, HashMap<BitSet, TSNodeLabel_BigDecimal> bitSetFreqTable) {
        if (t.isLexical) {
            return;
        }
        if (t.sameLabel(fragment)) {
            BitSet covered = new BitSet();
            boolean matched = DOP_EM_BigDecimal.getCFGSetCoveringFragmentNonRecursive(t, fragment, covered);
            if (matched && !covered.isEmpty()) {
                setCollector.add(covered);
                bitSetFreqTable.put(covered, new TSNodeLabel_BigDecimal(fragment, fragmentFreq));
            }
        }
        for (TSNodeLabel daughter : t.daughters) {
            DOP_EM_BigDecimal.getCFGSetCoveringFragment((TSNodeLabelIndex) daughter, fragment, fragmentFreq, setCollector, bitSetFreqTable);
        }
    }

    /**
     * Matches {@code fragment} against the tree node {@code t}, anchored at {@code t}
     * (it does not search further down for other anchor points, but it does recurse
     * along the fragment's own structure).
     *
     * <p>Matching stops successfully where the tree node is lexical or the fragment
     * node is terminal. Otherwise the daughters' labels must agree and every daughter
     * pair must match; the index of each tree node matched this way is set in
     * {@code set}. On failure {@code set} may already contain bits from sub-matches
     * that succeeded, so callers must discard it.
     *
     * @param t tree node to match at
     * @param fragment fragment node to match
     * @param set accumulator for the indexes of the covered tree nodes
     * @return {@code true} if the fragment matches at {@code t}
     */
    private static boolean getCFGSetCoveringFragmentNonRecursive(TSNodeLabelIndex t, TSNodeLabel fragment, BitSet set) {
        boolean atFrontier = t.isLexical || fragment.isTerminal();
        if (atFrontier) {
            return true;
        }
        if (!t.sameDaughtersLabel(fragment)) {
            return false;
        }
        int daughterCount = t.prole();
        for (int pos = 0; pos < daughterCount; ++pos) {
            TSNodeLabelIndex treeDaughter = (TSNodeLabelIndex) t.daughters[pos];
            TSNodeLabel fragmentDaughter = fragment.daughters[pos];
            boolean daughterMatches = DOP_EM_BigDecimal.getCFGSetCoveringFragmentNonRecursive(treeDaughter, fragmentDaughter, set);
            if (!daughterMatches) {
                return false;
            }
        }
        set.set(t.index);
        return true;
    }

    /**
     * Entry point: loads the WSJ sections 02-21 treebank and the MUB fragment file,
     * adds the missing CFG fragments, builds the root table and runs EM for 10 cycles.
     * Paths are derived from {@code Parameters.resultsPath} and
     * {@code Wsj.WsjOriginalCleanedTop}.
     *
     * @param args ignored
     * @throws Exception if reading the treebank or the fragment file fails
     */
    public static void main(String[] args) throws Exception {
        workingDir = Parameters.resultsPath + "TSG/DOP_EM/";
        System.out.println("Working Dir: " + workingDir);

        String fragmentFileDir = Parameters.resultsPath + "TSG/TSGkernels/Wsj/KenelFragments/SemTagOff_Top/all/";
        File fragmentFile = new File(fragmentFileDir + "fragments_MUB_freq_all.txt");
        File corpusFile = new File(Wsj.WsjOriginalCleanedTop + "wsj-02-21.mrg");

        ArrayList<TSNodeLabel> treebank = TSNodeLabel.getTreebank(corpusFile);
        TSNodeLabel.removeSemanticTagsInTreebank(treebank);

        endCycle = 10;
        readTreeBank(treebank);
        readFragmentsFile(fragmentFile);
        addCFGfragments();
        getRootFreq();
        runEM();
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    /**
     * Chart cell for one tree node: the partial derivations recorded for the node,
     * the sum of their probabilities, and the probability mass assigned to the node
     * while new fragment frequencies are extracted.
     */
    static class DerivationsNode {
        /** Derivations recorded through {@link #addDerivation}. */
        ArrayList<PartialDerivation> partialDerivations = new ArrayList<PartialDerivation>();
        /** Sum of the probabilities of all recorded derivations. */
        BigDecimal totalProb = BigDecimal.ZERO;
        /** Mass accumulated through {@link #addProbMass}. */
        BigDecimal newProbMass = BigDecimal.ZERO;

        /**
         * Records a derivation and adds its probability to {@link #totalProb}.
         *
         * @param intialFragment fragment the derivation starts with
         * @param subSites indexes of the derivation's substitution sites
         * @param derivationProb probability of the derivation
         */
        public void addDerivation(TSNodeLabel intialFragment, ArrayList<Integer> subSites, BigDecimal derivationProb) {
            PartialDerivation recorded = new PartialDerivation(intialFragment, subSites, derivationProb);
            partialDerivations.add(recorded);
            totalProb = totalProb.add(derivationProb);
        }

        /**
         * Adds {@code probMass} to {@link #newProbMass}.
         *
         * @param probMass mass to add
         */
        public void addProbMass(BigDecimal probMass) {
            newProbMass = newProbMass.add(probMass);
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    /**
     * Plain holder for one derivation of a tree node: the fragment it starts with,
     * the indexes of its substitution sites, and its probability.
     */
    static class PartialDerivation {
        /** Fragment the derivation starts with. */
        TSNodeLabel intialFragment;
        /** Indexes of the substitution sites. */
        ArrayList<Integer> subSites;
        /** Probability of this derivation. */
        BigDecimal partialDerivProb;

        /**
         * Stores the three values as given; nothing is copied.
         *
         * @param intialFragment fragment the derivation starts with
         * @param subSites indexes of the substitution sites
         * @param partialDerivProb probability of the derivation
         */
        public PartialDerivation(TSNodeLabel intialFragment, ArrayList<Integer> subSites, BigDecimal partialDerivProb) {
            this.partialDerivProb = partialDerivProb;
            this.subSites = subSites;
            this.intialFragment = intialFragment;
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    /**
     * Chart over the indexed nodes of one training tree.
     *
     * <p>{@link #getProb()} computes, for the root (index 0) and recursively for every
     * substitution site, the total probability of deriving the subtree below it with
     * the fragment occurrences supplied by the collector. A fragment's weight is its
     * frequency divided by the frequency of its root label in the enclosing class's
     * {@code rootTableFreq}. {@link #extractNewFragmentFrequencies()} then distributes
     * a probability mass of one from the root down the recorded derivations and adds
     * each fragment's share to the new frequency table.
     *
     * <p>Fragment occurrences are grouped by their lowest set bit, which is treated as
     * the index of the node the occurrence is rooted at.
     */
    static class ProbChart {
        NodeSetCollectorSimple setCollector;
        TSNodeLabelStructure t;
        int totalNodes;
        DerivationsNode[] derivationsNodes;
        NodeSetCollectorStandard[] nodesCollector;
        HashMap<BitSet, TSNodeLabel_BigDecimal> bitSetFreqTable;
        Hashtable<TSNodeLabel, BigDecimal[]> newFragmentTableFreq;

        /**
         * @param setCollector node sets of all fragment occurrences found in the tree
         * @param t indexed structure of the tree
         * @param bitSetFreqTable fragment and frequency for each occurrence's node set
         * @param newFragmentTableFreq table receiving the re-estimated frequencies
         */
        public ProbChart(NodeSetCollectorSimple setCollector, TSNodeLabelStructure t, HashMap<BitSet, TSNodeLabel_BigDecimal> bitSetFreqTable, Hashtable<TSNodeLabel, BigDecimal[]> newFragmentTableFreq) {
            this.setCollector = setCollector;
            this.t = t;
            this.newFragmentTableFreq = newFragmentTableFreq;
            this.totalNodes = t.length;
            this.derivationsNodes = new DerivationsNode[this.totalNodes];
            this.nodesCollector = new NodeSetCollectorStandard[this.totalNodes];
            this.bitSetFreqTable = bitSetFreqTable;
        }

        /**
         * Groups the fragment occurrences by root node and computes the total
         * derivation probability of the whole tree.
         *
         * @return probability of the tree; zero if it has no derivation
         */
        public BigDecimal getProb() {
            for (BitSet bs : this.setCollector.bitSetSet) {
                int firstIndex = bs.nextSetBit(0);
                if (this.nodesCollector[firstIndex] == null) {
                    this.nodesCollector[firstIndex] = new NodeSetCollectorStandard();
                }
                this.nodesCollector[firstIndex].add(bs);
            }
            return this.getProbRecursive(0);
        }

        /**
         * Returns the total derivation probability of the subtree rooted at the node
         * with the given index, computing and caching its chart cell on first use.
         */
        private BigDecimal getProbRecursive(int index) {
            DerivationsNode cached = this.derivationsNodes[index];
            if (cached != null) {
                return cached.totalProb;
            }
            DerivationsNode derivation = new DerivationsNode();
            NodeSetCollectorStandard rootedHere = this.nodesCollector[index];
            if (rootedHere == null) {
                // No fragment is rooted here: cache an empty cell (total probability zero).
                // The cell has to be created first; writing through the null slot would throw.
                this.derivationsNodes[index] = derivation;
                return derivation.totalProb;
            }
            TSNodeLabelIndex root = this.t.structure[index];
            BigDecimal rootFreq = rootTableFreq.get(root.label)[0];
            for (BitSet initialSubTree : rootedHere.bitSetArray) {
                ArrayList<Integer> subSitesIndexes = new ArrayList<Integer>();
                this.collectSubSites(root, initialSubTree, subSitesIndexes);
                BigDecimal partialProb = BigDecimal.ONE;
                for (int subSiteIndex : subSitesIndexes) {
                    BigDecimal subSiteProb = this.getProbRecursive(subSiteIndex);
                    if (subSiteProb.compareTo(BigDecimal.ZERO) == 0) {
                        partialProb = BigDecimal.ZERO;
                        break;
                    }
                    partialProb = partialProb.multiply(subSiteProb);
                }
                if (partialProb.compareTo(BigDecimal.ZERO) == 0) {
                    // Some substitution site cannot be derived: this occurrence contributes nothing.
                    continue;
                }
                TSNodeLabel_BigDecimal treeDouble = this.bitSetFreqTable.get(initialSubTree);
                BigDecimal initialSubTreeFreq = treeDouble.d;
                if (initialSubTreeFreq.compareTo(BigDecimal.ZERO) == 0) {
                    // Legacy marker for a zero-frequency fragment; kept so stdout is unchanged.
                    System.out.println();
                }
                TSNodeLabel initialFragment = treeDouble.tree;
                partialProb = partialProb.multiply(initialSubTreeFreq.divide(rootFreq, MathContext.DECIMAL128));
                derivation.addDerivation(initialFragment, subSitesIndexes, partialProb);
            }
            this.derivationsNodes[index] = derivation;
            if (derivation.totalProb.compareTo(BigDecimal.ZERO) == 0) {
                // Legacy marker for a node without any derivation; kept so stdout is unchanged.
                System.out.println();
            }
            return derivation.totalProb;
        }

        /**
         * Collects, in left-to-right order, the indexes of the substitution sites of a
         * fragment occurrence: the non-lexical nodes that hang directly below a covered
         * node but are not themselves in {@code initialSubTree}.
         */
        private void collectSubSites(TSNodeLabelIndex root, BitSet initialSubTree, ArrayList<Integer> subSitesIndexes) {
            for (TSNodeLabel d : root.daughters) {
                if (d.isLexical) {
                    // Skip only this daughter; returning here would drop any later siblings.
                    continue;
                }
                TSNodeLabelIndex di = (TSNodeLabelIndex) d;
                int index = di.index;
                if (!initialSubTree.get(index)) {
                    subSitesIndexes.add(index);
                } else {
                    this.collectSubSites(di, initialSubTree, subSitesIndexes);
                }
            }
        }

        /**
         * Distributes a probability mass of one from the root over the recorded
         * derivations and adds every fragment's share to the new frequency table.
         * Must be called after {@link #getProb()}.
         *
         * <p>Nodes are processed in topological order of the "has substitution site"
         * relation, each exactly once. A node reachable along several derivation paths
         * therefore hands down its mass only after all its parents have contributed,
         * and never more than once.
         */
        public void extractNewFragmentFrequencies() {
            this.derivationsNodes[0].newProbMass = BigDecimal.ONE;
            ArrayList<Integer> postOrder = new ArrayList<Integer>();
            this.collectPostOrder(0, new BitSet(this.totalNodes), postOrder);
            // Reverse post-order of a DAG lists every node before all of its successors.
            for (int i = postOrder.size() - 1; i >= 0; --i) {
                this.distributeMass(this.derivationsNodes[postOrder.get(i)]);
            }
        }

        /** Depth-first post-order of the nodes reachable from {@code index} through recorded derivations. */
        private void collectPostOrder(int index, BitSet visited, ArrayList<Integer> postOrder) {
            if (visited.get(index)) {
                return;
            }
            visited.set(index);
            for (PartialDerivation pd : this.derivationsNodes[index].partialDerivations) {
                for (int subSite : pd.subSites) {
                    this.collectPostOrder(subSite, visited, postOrder);
                }
            }
            postOrder.add(index);
        }

        /** Splits the node's mass over its derivations in proportion to their probabilities. */
        private void distributeMass(DerivationsNode node) {
            BigDecimal derivationsTotProb = node.totalProb;
            if (derivationsTotProb.signum() == 0) {
                // Nothing can be apportioned, and dividing by the zero total would throw.
                return;
            }
            BigDecimal totalMass = node.newProbMass;
            for (PartialDerivation pd : node.partialDerivations) {
                BigDecimal partialMass = pd.partialDerivProb.divide(derivationsTotProb, MathContext.DECIMAL128).multiply(totalMass);
                Utility.increaseInTableBigDecimalArray(this.newFragmentTableFreq, pd.intialFragment, partialMass);
                for (int subSite : pd.subSites) {
                    this.derivationsNodes[subSite].addProbMass(partialMass);
                }
            }
        }
    }

    /** Plain pair of a tree (fragment) and a {@code BigDecimal} value. */
    static class TSNodeLabel_BigDecimal {
        /** The tree half of the pair. */
        TSNodeLabel tree;
        /** The numeric half of the pair. */
        BigDecimal d;

        /**
         * Stores both values as given.
         *
         * @param tree the tree to hold
         * @param d the value associated with the tree
         */
        public TSNodeLabel_BigDecimal(TSNodeLabel tree, BigDecimal d) {
            this.d = d;
            this.tree = tree;
        }
    }
}

