/*
 * Decompiled with CFR 0.152.
 */
package tsg;

import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Map;
import java.util.Scanner;
import kernels.NodeSetCollector;
import kernels.NodeSetCollectorSimple;
import kernels.NodeSetCollectorStandard;
import settings.Parameters;
import tsg.Label;
import tsg.TSNodeLabel;
import tsg.TSNodeLabelIndex;
import tsg.TSNodeLabelStructure;
import tsg.corpora.Wsj;
import util.FileUtil;
import util.PrintProgress;
import util.Utility;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class DOP_IO_Log {
    public static Hashtable<TSNodeLabel, double[]> fragmentTableLogFreq;
    public static Hashtable<Label, double[]> rootTableLogFreq;
    public static ArrayList<TSNodeLabelIndex> trainingCorpus;
    public static int minFreqFragment;
    public static int endCycle;
    public static double deltaLogLikelihoodThreshold;
    public static String workingDir;

    static {
        minFreqFragment = 0;
        endCycle = 10;
        deltaLogLikelihoodThreshold = 1.0E-5;
    }

    public static void readFragmentsFile(File fragmentFile) throws Exception {
        fragmentTableLogFreq = new Hashtable();
        Scanner scan = FileUtil.getScanner(fragmentFile);
        int countFragments = 0;
        int discarded = 0;
        while (scan.hasNextLine()) {
            String line = scan.nextLine();
            if (line.equals("")) continue;
            ++countFragments;
            String[] fragmentFreq = line.split("\t");
            String fragmentString = fragmentFreq[0];
            int freq = Integer.parseInt(fragmentFreq[1]);
            if (freq < minFreqFragment) {
                ++discarded;
                continue;
            }
            double logFreq = Math.log(freq);
            TSNodeLabel fragment = new TSNodeLabel(fragmentString, false);
            fragmentTableLogFreq.put(fragment, new double[]{logFreq});
        }
        System.out.println("Read " + countFragments + " fragments");
        System.out.println("Discarded " + discarded + " (freq < " + minFreqFragment + ")");
        scan.close();
    }

    private static void readTreeBank(ArrayList<TSNodeLabel> treebank) throws Exception {
        trainingCorpus = new ArrayList();
        for (TSNodeLabel t : treebank) {
            trainingCorpus.add(new TSNodeLabelIndex(t));
        }
    }

    public static void printFragmentFreq(File outputFile) {
        PrintWriter pw = FileUtil.getPrintWriter(outputFile);
        for (Map.Entry<TSNodeLabel, double[]> e : fragmentTableLogFreq.entrySet()) {
            String fragmentString = e.getKey().toString(false, true);
            double freq = Math.exp(e.getValue()[0]);
            if (freq == 0.0) continue;
            pw.println(String.valueOf(fragmentString) + "\t" + freq);
        }
        pw.close();
    }

    /*
     * WARNING - void declaration
     */
    public static void addCFGfragments() throws Exception {
        void var1_4;
        Hashtable ruleTable = new Hashtable();
        for (TSNodeLabel tSNodeLabel : trainingCorpus) {
            ArrayList<TSNodeLabel> nodes = tSNodeLabel.collectAllNodes();
            for (TSNodeLabel n : nodes) {
                if (n.isLexical) continue;
                String rule = n.cfgRule();
                Utility.increaseInTableInt(ruleTable, rule);
            }
        }
        System.out.println("Read " + ruleTable.size() + " CFG fragments");
        boolean bl = false;
        for (Map.Entry e : ruleTable.entrySet()) {
            TSNodeLabel ruleFragment = new TSNodeLabel("( " + (String)e.getKey() + ")", false);
            if (fragmentTableLogFreq.containsKey(ruleFragment)) continue;
            double logFreq = Math.log(((int[])e.getValue())[0]);
            fragmentTableLogFreq.put(ruleFragment, new double[]{logFreq});
            ++var1_4;
        }
        System.out.println("Added " + (int)var1_4 + " CFG fragments");
    }

    public static void getRootFreq() {
        rootTableLogFreq = new Hashtable();
        for (Map.Entry<TSNodeLabel, double[]> e : fragmentTableLogFreq.entrySet()) {
            Label rootLabel = e.getKey().label;
            double logFreq = e.getValue()[0];
            Utility.increaseInTableDoubleLogArray(rootTableLogFreq, rootLabel, logFreq);
        }
        System.out.println("Built root freq. table: " + rootTableLogFreq.size() + " entries.");
    }

    public static void runEM() {
        int cycle = 0;
        double previousLogLikelihood = Double.NEGATIVE_INFINITY;
        DOP_IO_Log.printFragmentFreq(new File(String.valueOf(workingDir) + "kernelsMUB_CFG_freq_EM_cycle_" + cycle + ".txt"));
        do {
            Hashtable<TSNodeLabel, double[]> newFragmentTableLogFreq = new Hashtable<TSNodeLabel, double[]>();
            double currentLogLikelihood = 0.0;
            PrintProgress.start("Iterating Training Corpus:");
            int index = -1;
            for (TSNodeLabelIndex t : trainingCorpus) {
                ++index;
                PrintProgress.next();
                double logInsideProb = DOP_IO_Log.updateNewFragmentTableFreq(t, newFragmentTableLogFreq);
                currentLogLikelihood += logInsideProb;
            }
            PrintProgress.end();
            double deltaLogLikelihood = currentLogLikelihood - previousLogLikelihood;
            System.out.println("EM cyle " + ++cycle + ". Log-Likelihood: " + currentLogLikelihood + " Delta: " + deltaLogLikelihood);
            if (deltaLogLikelihood <= deltaLogLikelihoodThreshold) break;
            previousLogLikelihood = currentLogLikelihood;
            fragmentTableLogFreq = newFragmentTableLogFreq;
            DOP_IO_Log.getRootFreq();
            DOP_IO_Log.printFragmentFreq(new File(String.valueOf(workingDir) + "kernelsMUB_CFG_freq_EM_cycle_" + cycle + ".txt"));
        } while (cycle != endCycle);
    }

    private static double updateNewFragmentTableFreq(TSNodeLabelIndex t, Hashtable<TSNodeLabel, double[]> newFragmentTableFreq) {
        NodeSetCollectorSimple setCollector = new NodeSetCollectorSimple();
        HashMap<BitSet, TSNodeLabel_Double> bitSetFreqTable = new HashMap<BitSet, TSNodeLabel_Double>();
        for (Map.Entry<TSNodeLabel, double[]> e : fragmentTableLogFreq.entrySet()) {
            DOP_IO_Log.getCFGSetCoveringFragment(t, e.getKey(), e.getValue()[0], (NodeSetCollector)setCollector, bitSetFreqTable);
        }
        TSNodeLabelStructure tStructure = new TSNodeLabelStructure(t);
        IOChart pc = new IOChart(setCollector, tStructure, bitSetFreqTable);
        pc.buildChart();
        pc.updateNewFragmentLogFreq(newFragmentTableFreq);
        return pc.getInsideLogProb();
    }

    private static void getCFGSetCoveringFragment(TSNodeLabelIndex t, TSNodeLabel fragment, double fragmentFreq, NodeSetCollector setCollector, HashMap<BitSet, TSNodeLabel_Double> bitSetFreqLogTable) {
        BitSet set;
        if (t.isLexical) {
            return;
        }
        if (t.sameLabel(fragment) && DOP_IO_Log.getCFGSetCoveringFragmentNonRecursive(t, fragment, set = new BitSet()) && !set.isEmpty()) {
            setCollector.add(set);
            bitSetFreqLogTable.put(set, new TSNodeLabel_Double(fragment, fragmentFreq));
        }
        TSNodeLabel[] tSNodeLabelArray = t.daughters;
        int n = t.daughters.length;
        int n2 = 0;
        while (n2 < n) {
            TSNodeLabel d = tSNodeLabelArray[n2];
            TSNodeLabelIndex di = (TSNodeLabelIndex)d;
            DOP_IO_Log.getCFGSetCoveringFragment(di, fragment, fragmentFreq, setCollector, bitSetFreqLogTable);
            ++n2;
        }
    }

    private static boolean getCFGSetCoveringFragmentNonRecursive(TSNodeLabelIndex t, TSNodeLabel fragment, BitSet set) {
        if (t.isLexical || fragment.isTerminal()) {
            return true;
        }
        if (!t.sameDaughtersLabel(fragment)) {
            return false;
        }
        int prole = t.prole();
        int i = 0;
        while (i < prole) {
            TSNodeLabel thisDaughter = t.daughters[i];
            TSNodeLabelIndex thisDaughterIndex = (TSNodeLabelIndex)thisDaughter;
            TSNodeLabel otherDaughter = fragment.daughters[i];
            if (!DOP_IO_Log.getCFGSetCoveringFragmentNonRecursive(thisDaughterIndex, otherDaughter, set)) {
                return false;
            }
            ++i;
        }
        set.set(t.index);
        return true;
    }

    public static void main(String[] args) throws Exception {
        workingDir = new String(String.valueOf(Parameters.resultsPath) + "TSG/DOP_IO_SMALL/");
        System.out.println("Working Dir: " + workingDir);
        String fragmentFileDir = String.valueOf(Parameters.resultsPath) + "TSG/TSGkernels/Wsj/KenelFragments/SemTagOff_Top/all/correctCount/";
        File fragmentFile = new File(String.valueOf(fragmentFileDir) + "fragments_MUB_freq_all_correctCount.txt");
        File corpusFile = new File(String.valueOf(Wsj.WsjOriginalCleanedTop) + "wsj-02-21.mrg");
        ArrayList<TSNodeLabel> treebank = TSNodeLabel.getTreebank(corpusFile);
        TSNodeLabel.removeSemanticTagsInTreebank(treebank);
        minFreqFragment = 5;
        endCycle = 50;
        deltaLogLikelihoodThreshold = 0.0;
        DOP_IO_Log.readTreeBank(treebank);
        DOP_IO_Log.readFragmentsFile(fragmentFile);
        DOP_IO_Log.addCFGfragments();
        DOP_IO_Log.getRootFreq();
        DOP_IO_Log.runEM();
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    static class IOChart {
        NodeSetCollectorSimple setCollector;
        TSNodeLabelStructure t;
        int totalNodes;
        IOSubNode[] IOSubNodesChart;
        NodeSetCollectorStandard[] nodesCollector;
        HashMap<BitSet, TSNodeLabel_Double> bitSetFreqTable;

        public IOChart(NodeSetCollectorSimple setCollector, TSNodeLabelStructure t, HashMap<BitSet, TSNodeLabel_Double> bitSetFreqTable) {
            this.setCollector = setCollector;
            this.t = t;
            this.totalNodes = t.length;
            this.IOSubNodesChart = new IOSubNode[this.totalNodes];
            this.nodesCollector = new NodeSetCollectorStandard[this.totalNodes];
            this.bitSetFreqTable = bitSetFreqTable;
        }

        public void buildChart() {
            for (BitSet bs : this.setCollector.bitSetSet) {
                int firstIndex = bs.nextSetBit(0);
                if (this.nodesCollector[firstIndex] == null) {
                    this.nodesCollector[firstIndex] = new NodeSetCollectorStandard();
                }
                this.nodesCollector[firstIndex].add(bs);
            }
            this.buildInsideProb();
            this.buildOutsideProb();
        }

        public double getInsideLogProb() {
            return this.IOSubNodesChart[0].insideLogProb;
        }

        private void buildInsideProb() {
            int i = this.totalNodes - 1;
            while (i >= 0) {
                if (!this.t.structure[i].isLexical) {
                    this.buildInsideProb(i);
                }
                --i;
            }
        }

        private void buildInsideProb(int index) {
            NodeSetCollectorStandard setCollector = this.nodesCollector[index];
            TSNodeLabelIndex root = this.t.structure[index];
            double rootLogFreq = rootTableLogFreq.get(root.label)[0];
            IOSubNode IOSubNodeIndex = new IOSubNode();
            for (BitSet initialSubTree : setCollector.bitSetArray) {
                ArrayList<Integer> subSitesIndexes = new ArrayList<Integer>();
                this.collectSubSites(root, initialSubTree, subSitesIndexes);
                double initialSubTreeInsideLogProb = 0.0;
                for (int subSiteIndex : subSitesIndexes) {
                    double subSiteInsideLogProb = this.IOSubNodesChart[subSiteIndex].insideLogProb;
                    initialSubTreeInsideLogProb += subSiteInsideLogProb;
                }
                TSNodeLabel_Double treeDouble = this.bitSetFreqTable.get(initialSubTree);
                double initialSubTreeFreq = treeDouble.logFreq;
                TSNodeLabel initialFragment = treeDouble.fragment;
                IOSubNodeIndex.addDerivation(initialFragment, subSitesIndexes, initialSubTreeInsideLogProb += initialSubTreeFreq - rootLogFreq);
            }
            this.IOSubNodesChart[index] = IOSubNodeIndex;
        }

        private void collectSubSites(TSNodeLabelIndex root, BitSet initialSubTree, ArrayList<Integer> subSitesIndexes) {
            TSNodeLabel[] tSNodeLabelArray = root.daughters;
            int n = root.daughters.length;
            int n2 = 0;
            while (n2 < n) {
                TSNodeLabel d = tSNodeLabelArray[n2];
                if (d.isLexical) {
                    return;
                }
                TSNodeLabelIndex di = (TSNodeLabelIndex)d;
                int index = di.index;
                if (!initialSubTree.get(index)) {
                    subSitesIndexes.add(index);
                } else {
                    this.collectSubSites(di, initialSubTree, subSitesIndexes);
                }
                ++n2;
            }
        }

        private void buildOutsideProb() {
            this.IOSubNodesChart[0].outsideLogProb = 0.0;
            int i = 0;
            while (i < this.totalNodes) {
                if (!this.t.structure[i].isLexical) {
                    this.buildOutsideProb(i);
                }
                ++i;
            }
        }

        private void buildOutsideProb(int index) {
            IOSubNode IOSubNodeIndex = this.IOSubNodesChart[index];
            double ousideLogProb = IOSubNodeIndex.outsideLogProb;
            for (InitFragmentDerivations ifd : IOSubNodeIndex.partialDerivations) {
                double initialFragmInsideLogProb = ifd.initFragmentInsideLogProb;
                for (int subSite : ifd.subSites) {
                    double subSiteInsideLogProb = this.IOSubNodesChart[subSite].insideLogProb;
                    double outsideLogProbToAdd = ousideLogProb + initialFragmInsideLogProb - subSiteInsideLogProb;
                    IOSubNode subSiteDerivation = this.IOSubNodesChart[subSite];
                    subSiteDerivation.addOutisdeProb(outsideLogProbToAdd);
                }
            }
        }

        public void updateNewFragmentLogFreq(Hashtable<TSNodeLabel, double[]> newFragmentTableFreq) {
            double insideLogProbTOP = this.IOSubNodesChart[0].insideLogProb;
            int i = 0;
            while (i < this.totalNodes) {
                if (!this.t.structure[i].isLexical) {
                    IOSubNode IOSubNodeIndex = this.IOSubNodesChart[i];
                    double ousideLogProb = IOSubNodeIndex.outsideLogProb;
                    for (InitFragmentDerivations ifd : IOSubNodeIndex.partialDerivations) {
                        TSNodeLabel initialFragment = ifd.intialFragment;
                        double initialFragmInsideLogProb = ifd.initFragmentInsideLogProb;
                        double newFreqToAdd = ousideLogProb + initialFragmInsideLogProb - insideLogProbTOP;
                        Utility.increaseInTableDoubleLogArray(newFragmentTableFreq, initialFragment, newFreqToAdd);
                    }
                }
                ++i;
            }
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    static class IOSubNode {
        ArrayList<InitFragmentDerivations> partialDerivations = new ArrayList();
        double insideLogProb = Double.NEGATIVE_INFINITY;
        double outsideLogProb = Double.NEGATIVE_INFINITY;

        public void addDerivation(TSNodeLabel intialFragment, ArrayList<Integer> subSites, double initFragmInsideLogProb) {
            this.partialDerivations.add(new InitFragmentDerivations(intialFragment, subSites, initFragmInsideLogProb));
            this.insideLogProb = Utility.logSum(this.insideLogProb, initFragmInsideLogProb);
        }

        public void addOutisdeProb(double outsideLogProbToAdd) {
            this.outsideLogProb = Utility.logSum(this.outsideLogProb, outsideLogProbToAdd);
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    static class InitFragmentDerivations {
        TSNodeLabel intialFragment;
        ArrayList<Integer> subSites;
        double initFragmentInsideLogProb;

        public InitFragmentDerivations(TSNodeLabel intialFragment, ArrayList<Integer> subSites, double initFragmentInsideLogProb) {
            this.intialFragment = intialFragment;
            this.subSites = subSites;
            this.initFragmentInsideLogProb = initFragmentInsideLogProb;
        }
    }

    static class TSNodeLabel_Double {
        TSNodeLabel fragment;
        double logFreq;

        public TSNodeLabel_Double(TSNodeLabel fragment, double logFreq) {
            this.fragment = fragment;
            this.logFreq = logFreq;
        }
    }
}

