/*
 * Decompiled with CFR 0.152.
 */
package tsg.LTSG;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Hashtable;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import settings.Parameters;
import tsg.LTSG.LDQQueue;
import tsg.LTSG.LTSG;
import tsg.LTSG.LexDerivationQueue;
import tsg.LTSG.LexicalDerivation;
import tsg.TSNode;
import tsg.corpora.ConstCorpus;
import tsg.parser.Parser;
import util.FileUtil;
import util.PrintProgress;
import util.Utility;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Lexicalized tree-substitution grammar whose elementary-tree ("template")
 * probabilities are estimated with an EM procedure restricted to the n-best
 * lexical derivations of every training tree.
 *
 * <p>For each tree the n-best derivations are built level by level from the
 * node lists returned by {@code TSNode.getNodesInDepthLevels()}; the table of a
 * node is assembled from the tables already stored for its substitution sites.
 * NOTE(review): this presumes the levels are delivered deepest-first - confirm
 * against {@code TSNode}.
 *
 * <p>Several methods temporarily rewrite the shared
 * {@code Parameters.trainingCorpus}; they restore it in a {@code finally}
 * clause.
 */
public class LTSG_EM_nBest
extends LTSG {
    /** Directory prefix under which the driver methods place their output. */
    private static final String OUTPUT_ROOT = "/home/fsangati/PROJECTS/TSG/RESULTS/LTSG/";

    /** Current weight of every template, keyed by its string form. */
    Hashtable<String, Double> template_prob;
    /** Sum of the template weights per root label; rebuilt by {@link #extractRootProb()}. */
    Hashtable<String, Double> root_prob;

    /**
     * Rebuilds {@link #root_prob} by summing the weights in
     * {@link #template_prob} per root label (as given by
     * {@code TSNode.get_unique_root}).
     */
    private void extractRootProb() {
        this.root_prob = new Hashtable<String, Double>();
        for (Map.Entry<String, Double> e : this.template_prob.entrySet()) {
            String root = TSNode.get_unique_root(e.getKey());
            Utility.increaseStringDouble(this.root_prob, root, e.getValue());
        }
    }

    /**
     * Replaces {@link #template_prob} with the counts currently held in
     * {@code template_freq} and normalizes them per root label.
     */
    private void loadTemplateProbFromFrequencies() {
        this.template_prob = new Hashtable<String, Double>();
        for (Map.Entry<?, ?> e : this.template_freq.entrySet()) {
            String tree = (String)e.getKey();
            double count = ((Integer)e.getValue()).intValue();
            this.template_prob.put(tree, count);
        }
        this.normalizeTemplateProb();
    }

    /**
     * Extracts all lexicalized trees and initializes the template weights from
     * their frequencies.
     */
    private void initializeUniformTreeProb() {
        this.extractAllLexTrees();
        this.loadTemplateProbFromFrequencies();
    }

    /**
     * Initializes the template weights from the frequency table with the
     * elementary trees of {@code treeToExclude} temporarily removed, then puts
     * those trees back into the frequency table.
     *
     * @param treeToExclude tree whose elementary trees must not contribute
     */
    private void extractUniformTreeProbExcluding(TSNode treeToExclude) {
        this.decreaseElementayTreesFrom(treeToExclude);
        this.loadTemplateProbFromFrequencies();
        this.increaseElementaryTreeesFrom(treeToExclude);
    }

    /**
     * Divides every template weight by the total weight of the templates that
     * share its root label, updating {@link #template_prob} in place.
     */
    private void normalizeTemplateProb() {
        this.extractRootProb();
        for (Map.Entry<String, Double> e : this.template_prob.entrySet()) {
            double treeProb = e.getValue();
            String root = TSNode.get_unique_root(e.getKey());
            double rootCount = this.root_prob.get(root);
            e.setValue(treeProb / rootCount);
        }
    }

    /**
     * Appends to the log file the largest value of
     * {@code TSNode.lexDerivations()} found in the training treebank.
     */
    public void reportMaxLexicalDerivations() {
        long maxDerivation = 0L;
        for (TSNode inputTree : Parameters.trainingCorpus.treeBank) {
            maxDerivation = Math.max(maxDerivation, inputTree.lexDerivations());
        }
        FileUtil.append("Max number of derivation per tree: " + maxDerivation, Parameters.logFile);
    }

    /**
     * Counts the training trees that can be derived from templates extracted
     * from the other trees and prints the result to standard output.
     *
     * <p>Each tree is tested with its own elementary trees removed from the
     * frequency table. When {@code Parameters.posTagConversion} is set the
     * corpus is converted first and restored afterwards, even on failure.
     */
    public void checkEMCoverage() {
        ConstCorpus originalTrainingCorpus = null;
        if (Parameters.posTagConversion) {
            originalTrainingCorpus = Parameters.trainingCorpus.deepClone();
            Parameters.trainingCorpus.makePosTagsLexicon();
        }
        int covered = 0;
        try {
            this.extractAllLexTrees();
            for (TSNode inputTree : Parameters.trainingCorpus.treeBank) {
                this.decreaseElementayTreesFrom(inputTree);
                if (this.checkCoverageTree(inputTree)) {
                    ++covered;
                }
                this.increaseElementaryTreeesFrom(inputTree);
            }
        } finally {
            // The corpus is shared global state: always undo the conversion.
            if (Parameters.posTagConversion) {
                Parameters.trainingCorpus.unMakePosTagsLexicon(originalTrainingCorpus);
            }
        }
        int treebankSize = Parameters.trainingCorpus.size();
        float ratio = (float)covered / (float)treebankSize;
        System.out.println("Covered tree in training corpus: " + covered + " / " + treebankSize + " (" + ratio + ")");
    }

    /**
     * Tells whether {@code inputTree} can be derived from the templates
     * currently present in {@code template_freq}.
     *
     * <p>Lexical nodes and unique daughters are skipped. A prelexical or
     * non-branching node is covered when its own string form is a known
     * template; any other node is covered when some anchor yields a known
     * template whose substitution sites are all covered already.
     *
     * @param inputTree tree to test
     * @return {@code true} if the root of {@code inputTree} is covered
     */
    private boolean checkCoverageTree(TSNode inputTree) {
        ArrayList<ArrayList<TSNode>> levelsSubTrees = inputTree.getNodesInDepthLevels();
        IdentityHashMap<TSNode, Boolean> coveredSubTrees = new IdentityHashMap<TSNode, Boolean>();
        for (ArrayList<TSNode> level : levelsSubTrees) {
            for (TSNode TN : level) {
                if (TN.isLexical || TN.isUniqueDaughter()) {
                    continue;
                }
                boolean covered;
                if (TN.isPrelexical() || !TN.hasMoreThanNBranching(1)) {
                    covered = this.template_freq.containsKey(TN.toString(false, true));
                } else {
                    covered = this.isCoveredByLexicalizedTemplate(TN, coveredSubTrees);
                }
                if (covered) {
                    coveredSubTrees.put(TN, true);
                }
            }
        }
        return coveredSubTrees.containsKey(inputTree);
    }

    /**
     * Searches the lexical items of {@code TN} for an anchor whose lexicalized
     * template is known and whose substitution sites are all covered.
     *
     * @param TN              branching node under test
     * @param coveredSubTrees nodes already established as covered (by identity)
     * @return {@code true} as soon as one such anchor is found
     */
    private boolean isCoveredByLexicalizedTemplate(TSNode TN, IdentityHashMap<TSNode, Boolean> coveredSubTrees) {
        List<TSNode> lexicon = TN.collectLexicalItems();
        for (TSNode anchor : lexicon) {
            // The head path is marked only long enough to read off the
            // template and its substitution sites.
            TN.markHeadPathToAnchor(anchor);
            TSNode lexTemplate = TN.lexicalizedTreeCopy();
            lexTemplate.applyAllConversions();
            List<TSNode> subSites = TN.collectSubstitutionSites();
            TN.unmarkHeadPathToAnchor(anchor);
            if (!this.template_freq.containsKey(lexTemplate.toString(false, true))) {
                continue;
            }
            boolean allSitesCovered = true;
            for (TSNode SS : subSites) {
                if (!coveredSubTrees.containsKey(SS)) {
                    allSitesCovered = false;
                    break;
                }
            }
            if (allSitesCovered) {
                return true;
            }
        }
        return false;
    }

    /**
     * Runs a separate EM estimation for every training tree, each time starting
     * from template weights computed without that tree's own elementary trees.
     *
     * <p>Trees without any derivation are logged and receive random heads.
     * When {@code Parameters.posTagConversion} is set the corpus is converted
     * first and restored afterwards, even on failure.
     */
    public void EMHeldOutAlgorithm() {
        ConstCorpus originalTrainingCorpus = null;
        if (Parameters.posTagConversion) {
            originalTrainingCorpus = Parameters.trainingCorpus.deepClone();
            Parameters.trainingCorpus.makePosTagsLexicon();
        }
        try {
            this.extractAllLexTrees();
            PrintProgress.start("Estimating EM param. sentence:");
            for (TSNode observedTree : Parameters.trainingCorpus.treeBank) {
                PrintProgress.next();
                this.emOnHeldOutTree(observedTree);
            }
            PrintProgress.end();
        } finally {
            // The corpus is shared global state: always undo the conversion.
            if (Parameters.posTagConversion) {
                Parameters.trainingCorpus.unMakePosTagsLexicon(originalTrainingCorpus);
            }
        }
    }

    /**
     * Iterates EM on a single tree until the likelihood gain drops to
     * {@code Parameters.EM_deltaThreshold} or below, or
     * {@code Parameters.EM_maxCycle} cycles have run.
     *
     * @param observedTree tree to annotate; left with random heads when it has
     *                     no derivation
     */
    private void emOnHeldOutTree(TSNode observedTree) {
        this.extractUniformTreeProbExcluding(observedTree);
        int cycle = 0;
        double previousLikelihood = -Double.MAX_VALUE;
        double delta;
        do {
            ++cycle;
            observedTree.removeHeadAnnotations();
            Hashtable<String, Double> new_template_prob = new Hashtable<String, Double>();
            Double actualLikelihood = this.getNBestHeadAnnotations(observedTree, new_template_prob);
            if (actualLikelihood == null) {
                FileUtil.append("No coverage for " + observedTree, Parameters.logFile);
                observedTree.assignRandomHeads();
                return;
            }
            delta = actualLikelihood - previousLikelihood;
            previousLikelihood = actualLikelihood;
            this.template_prob = new_template_prob;
            this.normalizeTemplateProb();
        } while (delta > 0.0 && delta > Parameters.EM_deltaThreshold && cycle < Parameters.EM_maxCycle);
    }

    /**
     * Runs EM over the whole training corpus, logging the likelihood of every
     * cycle, until the gain drops to {@code Parameters.EM_deltaThreshold} or
     * below, or {@code Parameters.EM_maxCycle} cycles have run.
     *
     * <p>When {@code Parameters.posTagConversion} is set the corpus is
     * converted first and restored afterwards, even on failure.
     */
    public void EMalgorithm() {
        ConstCorpus originalTrainingCorpus = null;
        if (Parameters.posTagConversion) {
            originalTrainingCorpus = Parameters.trainingCorpus.deepClone();
            Parameters.trainingCorpus.makePosTagsLexicon();
        }
        try {
            this.initializeUniformTreeProb();
            int cycle = 0;
            double previousLikelihood = -Double.MAX_VALUE;
            double delta;
            do {
                double actualLikelihood = this.emSteps();
                delta = actualLikelihood - previousLikelihood;
                previousLikelihood = actualLikelihood;
                ++cycle;
                String line = "EM cycle: " + cycle + "\tActual LikeLihood: " + actualLikelihood + "\tDelta LikeLihood: " + delta;
                FileUtil.append(line, Parameters.logFile);
            } while (delta > 0.0 && delta > Parameters.EM_deltaThreshold && cycle < Parameters.EM_maxCycle);
        } finally {
            // The corpus is shared global state: always undo the conversion.
            if (Parameters.posTagConversion) {
                Parameters.trainingCorpus.unMakePosTagsLexicon(originalTrainingCorpus);
            }
        }
    }

    /**
     * Performs one EM cycle over the training corpus: re-annotates every tree,
     * collects the new template weights and installs them, normalized, as
     * {@link #template_prob}.
     *
     * <p>A tree without any derivation is logged, receives random heads and
     * contributes nothing to the likelihood, mirroring the held-out variant.
     *
     * @return the sum of the log-likelihoods of the covered trees
     */
    private double emSteps() {
        double likelihood = 0.0;
        Hashtable<String, Double> new_template_prob = new Hashtable<String, Double>();
        for (TSNode inputTree : Parameters.trainingCorpus.treeBank) {
            inputTree.removeHeadAnnotations();
            Double treeLikelihood = this.getNBestHeadAnnotations(inputTree, new_template_prob);
            if (treeLikelihood == null) {
                // Previously an unconditional unboxing threw here.
                FileUtil.append("No coverage for " + inputTree, Parameters.logFile);
                inputTree.assignRandomHeads();
                continue;
            }
            likelihood += treeLikelihood.doubleValue();
            if (inputTree.hasWrongHeadAssignment()) {
                System.err.println("Wrong Head Assignment: " + inputTree.toString(true, true));
            }
        }
        this.template_prob = new_template_prob;
        this.normalizeTemplateProb();
        return likelihood;
    }

    /**
     * Marks the heads of the derivation stored at position {@code index} in the
     * table of {@code inputTree}, recursing into every substitution site.
     *
     * <p>The path from the anchor's parent up to (but excluding)
     * {@code inputTree} is head-marked. Sub-site indexes are consumed bottom-up
     * along that path and in daughter order, which is the order in which
     * {@link #getNBestTable} lays the sub-site tables out.
     *
     * @param inputTree                root of the subtree to annotate
     * @param nBestDerivationsSubTrees n-best tables keyed by node identity
     * @param index                    position of the derivation in the table
     */
    private void assignHeadAnnotation(TSNode inputTree, IdentityHashMap<TSNode, LexicalDerivation[]> nBestDerivationsSubTrees, int index) {
        LexicalDerivation lexD = nBestDerivationsSubTrees.get(inputTree)[index];
        int[] indexes = lexD.subSiteDerivationsIndexes;
        TSNode lexiconPath = lexD.anchor.parent;
        TSNode lexiconPathDaughter = lexD.anchor;
        int substitutionIndex = 0;
        while (lexiconPath != inputTree) {
            lexiconPath.headMarked = true;
            lexiconPath = lexiconPath.parent;
            lexiconPathDaughter = lexiconPathDaughter.parent;
            for (TSNode D : lexiconPath.daughters) {
                if (D != lexiconPathDaughter) {
                    this.assignHeadAnnotation(D, nBestDerivationsSubTrees, indexes[substitutionIndex++]);
                }
            }
        }
    }

    /**
     * Builds the n-best derivation table of a branching node by trying each of
     * its terminals as anchor.
     *
     * <p>An anchor is usable when its lexicalized template has a weight in
     * {@link #template_prob} and every substitution site off its spine already
     * has a table. All head marks set while reading off a template are cleared
     * again before the next anchor is tried.
     *
     * @param TN                       node to build the table for
     * @param nBestDerivationsSubTrees tables of the nodes processed so far
     * @return the n-best table, or {@code null} when no anchor is usable
     */
    private LexicalDerivation[] getNBestTable(TSNode TN, IdentityHashMap<TSNode, LexicalDerivation[]> nBestDerivationsSubTrees) {
        TSNode[] lexicalsTN = TN.collectTerminals().toArray(new TSNode[0]);
        double[] lexTreeWeights = new double[lexicalsTN.length];
        LexicalDerivation[][][] lex_SubSite_NBestTable = new LexicalDerivation[lexicalsTN.length][][];
        int nonNullLex = 0;
        for (int lexicalIndex = 0; lexicalIndex < lexicalsTN.length; ++lexicalIndex) {
            TSNode anchor = lexicalsTN[lexicalIndex];

            // Mark the spine below TN and size the sub-site array from the
            // prole() of every spine node, TN included.
            // NOTE(review): prole() presumably returns the number of daughters - confirm.
            int substitutionSites = 0;
            TSNode lexiconPath = anchor.parent;
            while (lexiconPath != TN) {
                substitutionSites += lexiconPath.prole() - 1;
                lexiconPath.headMarked = true;
                lexiconPath = lexiconPath.parent;
            }
            substitutionSites += lexiconPath.prole() - 1;

            TSNode lexicalTree = TN.lexicalizedTreeCopy();
            lexicalTree.applyAllConversions();
            Double weight = this.template_prob.get(lexicalTree.toString(false, true));
            if (weight == null) {
                // Unknown template: this anchor is unusable.
                TN.unmarkHeadPathToAnchor(anchor);
                lexTreeWeights[lexicalIndex] = -1.0;
                lex_SubSite_NBestTable[lexicalIndex] = null;
                continue;
            }
            lexTreeWeights[lexicalIndex] = Math.log(weight);

            // Climb the spine again, clearing the marks and collecting the
            // table of every daughter that is off the spine.
            lexiconPath = anchor.parent;
            lexiconPath.headMarked = false;
            TSNode lexiconPathDaughter = anchor;
            lex_SubSite_NBestTable[lexicalIndex] = new LexicalDerivation[substitutionSites][];
            int substitutionIndex = 0;
            boolean nullSubSide = false;
            do {
                lexiconPath = lexiconPath.parent;
                lexiconPathDaughter = lexiconPathDaughter.parent;
                lexiconPath.headMarked = false;
                for (TSNode D : lexiconPath.daughters) {
                    if (D == lexiconPathDaughter) {
                        continue;
                    }
                    LexicalDerivation[] subSiteTable = nBestDerivationsSubTrees.get(D);
                    lex_SubSite_NBestTable[lexicalIndex][substitutionIndex] = subSiteTable;
                    if (subSiteTable == null) {
                        nullSubSide = true;
                    }
                    ++substitutionIndex;
                }
            } while (lexiconPath != TN);

            if (nullSubSide) {
                // At least one substitution site has no derivation at all.
                lexTreeWeights[lexicalIndex] = -1.0;
                lex_SubSite_NBestTable[lexicalIndex] = null;
            } else {
                ++nonNullLex;
            }
        }
        return LTSG_EM_nBest.computeNBestTable(nBestDerivationsSubTrees, lex_SubSite_NBestTable, nonNullLex, lexicalsTN, lexTreeWeights);
    }

    /**
     * Merges the per-anchor derivation queues into one table holding at most
     * {@code Parameters.EM_nBest} entries.
     *
     * <p>The number of available combinations is summed in a {@code long} so
     * that the sum cannot wrap around; a negative sum is still read as "more
     * than enough", as before.
     *
     * @param nBestDerivationsSubTrees tables of the nodes processed so far (not used here)
     * @param lex_SubSite_NBestTable   per anchor, the tables of its substitution
     *                                 sites, or {@code null} for an unusable anchor
     * @param nonNullLex               number of usable anchors
     * @param lexicalsTN               candidate anchors
     * @param lexTreeWeights           log weight of each anchor's template
     * @return the merged table, or {@code null} when there is no usable anchor
     *         or the table would be empty
     */
    private static LexicalDerivation[] computeNBestTable(IdentityHashMap<TSNode, LexicalDerivation[]> nBestDerivationsSubTrees, LexicalDerivation[][][] lex_SubSite_NBestTable, int nonNullLex, TSNode[] lexicalsTN, double[] lexTreeWeights) {
        if (nonNullLex == 0) {
            return null;
        }
        LexDerivationQueue[] lexQueues = new LexDerivationQueue[lexicalsTN.length];
        long totalCombination = 0L;
        for (int i = 0; i < lexicalsTN.length; ++i) {
            if (lex_SubSite_NBestTable[i] == null) {
                continue;
            }
            lexQueues[i] = new LexDerivationQueue(lex_SubSite_NBestTable[i], lexicalsTN[i], i, lexTreeWeights[i]);
            totalCombination += lexQueues[i].combinations;
        }
        LDQQueue superQueue = new LDQQueue(lexQueues);
        int nBestTableSize = Parameters.EM_nBest;
        if (totalCombination >= 0L && totalCombination < nBestTableSize) {
            nBestTableSize = (int)totalCombination;
        }
        if (nBestTableSize <= 0) {
            // Nothing to extract; avoids writing to slot 0 of an empty array.
            return null;
        }
        LexicalDerivation[] nBestTable = new LexicalDerivation[nBestTableSize];
        nBestTable[0] = superQueue.pollFirst();
        for (int i = 1; i < nBestTableSize; ++i) {
            nBestTable[i] = superQueue.addNeighboursAndPoll();
        }
        return nBestTable;
    }

    /**
     * Computes the n-best derivations of {@code inputTree}, adds the posterior
     * weight of each derivation to the templates it uses, and leaves the tree
     * annotated with the derivation stored first in its table.
     *
     * <p>The posterior weights are computed with the largest log probability
     * subtracted first, so that very small derivation probabilities do not
     * underflow to zero.
     *
     * @param inputTree         tree to analyse; its head annotations are overwritten
     * @param new_template_prob accumulator receiving the expected template counts
     * @return the log of the summed n-best derivation probabilities, or
     *         {@code null} when the tree has no derivation
     */
    private Double getNBestHeadAnnotations(TSNode inputTree, Hashtable<String, Double> new_template_prob) {
        ArrayList<ArrayList<TSNode>> levelsSubTrees = inputTree.getNodesInDepthLevels();
        IdentityHashMap<TSNode, LexicalDerivation[]> nBestDerivationsSubTrees = new IdentityHashMap<TSNode, LexicalDerivation[]>();
        for (ArrayList<TSNode> level : levelsSubTrees) {
            for (TSNode TN : level) {
                if (TN.isLexical || TN.isUniqueDaughter()) {
                    continue;
                }
                LexicalDerivation[] table;
                if (TN.isPrelexical() || !TN.hasMoreThanNBranching(1)) {
                    // No substitution sites: the node is a single template.
                    Double weight = this.template_prob.get(TN.toString(false, true));
                    if (weight == null) {
                        continue;
                    }
                    double logWeight = Math.log(weight);
                    table = new LexicalDerivation[]{new LexicalDerivation(TN.getAnchor(), 0, logWeight, null, null)};
                } else {
                    table = this.getNBestTable(TN, nBestDerivationsSubTrees);
                }
                nBestDerivationsSubTrees.put(TN, table);
            }
        }
        LexicalDerivation[] nBestTable = nBestDerivationsSubTrees.get(inputTree);
        if (nBestTable == null) {
            return null;
        }

        // Shift by the largest log probability before exponentiating. When the
        // maximum is not finite the shift is dropped, which reproduces the
        // plain sum of exponentials.
        double shift = Double.NEGATIVE_INFINITY;
        for (LexicalDerivation derivation : nBestTable) {
            shift = Math.max(shift, derivation.logDerivationProb);
        }
        if (Double.isInfinite(shift) || Double.isNaN(shift)) {
            shift = 0.0;
        }
        double scaledTotal = 0.0;
        for (LexicalDerivation derivation : nBestTable) {
            scaledTotal += Math.exp(derivation.logDerivationProb - shift);
        }

        // Walk from the last entry to the first so that the tree ends up
        // annotated with entry 0 (the one obtained from pollFirst()).
        for (int i = nBestTable.length - 1; i > -1; --i) {
            inputTree.removeHeadAnnotations();
            this.assignHeadAnnotation(inputTree, nBestDerivationsSubTrees, i);
            List<TSNode> eTrees = inputTree.lexicalizedTreesFromHeadAnnotation();
            double newWeight = Math.exp(nBestTable[i].logDerivationProb - shift) / scaledTotal;
            for (TSNode lexTree : eTrees) {
                lexTree.applyAllConversions();
                Utility.increaseStringDouble(new_template_prob, lexTree.toString(false, true), newWeight);
            }
        }
        return shift + Math.log(scaledTotal);
    }

    /**
     * Greedily decides how many leading elements to take from each weight list
     * so that the product of the counts reaches {@code n}, then returns
     * {@code Utility.combinations} of those counts.
     *
     * <p>Every list starts with one element taken. At each step the list whose
     * next element is the largest is extended by one. Elements that are
     * {@code null} or not greater than zero are never selected, and a list that
     * is fully consumed (or {@code null}) is skipped. The search stops early
     * when no list can be extended.
     *
     * @param substitutionSiteWeightLists one weight list per substitution site
     * @param n                           number of combinations wanted
     * @return the result of {@code Utility.combinations} for the chosen counts
     */
    public static int[][] bestNcombinations(Double[][] substitutionSiteWeightLists, int n) {
        int elements = substitutionSiteWeightLists.length;
        int[] indexes = new int[elements];
        Arrays.fill(indexes, 1);
        while (Utility.product(indexes) < n) {
            double max = 0.0;
            int indexMax = -1;
            for (int i = 0; i < elements; ++i) {
                Double[] weights = substitutionSiteWeightLists[i];
                if (weights == null || indexes[i] >= weights.length) {
                    // Nothing left to take from this list.
                    continue;
                }
                Double nextWeightElement = weights[indexes[i]];
                if (nextWeightElement != null && nextWeightElement.doubleValue() > max) {
                    max = nextWeightElement.doubleValue();
                    indexMax = i;
                }
            }
            if (indexMax == -1) {
                break;
            }
            indexes[indexMax] = indexes[indexMax] + 1;
        }
        return Utility.combinations(indexes);
    }

    /**
     * Driver: reports the EM coverage of the training corpus with sentence
     * length limit 40 and both POS-tag and spine conversion enabled.
     */
    public static void coverage() {
        Parameters.setDefaultParam();
        Parameters.lengthLimitTraining = 40;
        Parameters.posTagConversion = true;
        Parameters.spineConversion = true;
        Parameters.LTSGtype = "EM";
        Parameters.outputPath = OUTPUT_ROOT + Parameters.LTSGtype + "/";
        LTSG_EM_nBest grammar = new LTSG_EM_nBest();
        grammar.checkEMCoverage();
    }

    /**
     * Driver: runs corpus-wide EM (400-best, threshold 0.1, no cycle limit) on
     * sentences of length up to 10, writes the grammar files and starts the
     * parser.
     *
     * @param args command-line arguments (not used)
     */
    public static void EmStandard(String[] args) {
        Parameters.setDefaultParam();
        Parameters.lengthLimitTraining = 10;
        Parameters.lengthLimitTest = 10;
        Parameters.LTSGtype = "EM";
        Parameters.outputPath = OUTPUT_ROOT + Parameters.LTSGtype + "/";
        Parameters.EM_nBest = 400;
        Parameters.EM_deltaThreshold = 0.1;
        Parameters.EM_maxCycle = Integer.MAX_VALUE;
        Parameters.parserName = "bitPar";
        Parameters.nBest = 1;
        Parameters.cachingActive = false;
        Parameters.posTagConversion = false;
        Parameters.spineConversion = false;
        LTSG_EM_nBest grammar = new LTSG_EM_nBest();
        grammar.reportMaxLexicalDerivations();
        grammar.EMalgorithm();
        grammar.readTreesFromCorpus();
        grammar.printTemplatesToFile();
        grammar.treatTreeBank();
        grammar.toPCFG();
        Parameters.printTrainingCorpusToFile();
        grammar.printLexiconAndGrammarFiles();
        new Parser(grammar);
    }

    /**
     * Driver: runs the held-out EM variant (100-best, threshold 0.1, a single
     * cycle, spine conversion on, smoothing off), writes the grammar files and
     * starts the parser.
     */
    public static void EMHeldOut() {
        Parameters.setDefaultParam();
        Parameters.smoothing = false;
        Parameters.LTSGtype = "EM";
        Parameters.outputPath = OUTPUT_ROOT + Parameters.LTSGtype + "/";
        Parameters.EM_nBest = 100;
        Parameters.EM_deltaThreshold = 0.1;
        Parameters.EM_maxCycle = 1;
        Parameters.parserName = "bitPar";
        Parameters.nBest = 1;
        Parameters.cachingActive = false;
        Parameters.posTagConversion = false;
        Parameters.spineConversion = true;
        LTSG_EM_nBest grammar = new LTSG_EM_nBest();
        grammar.reportMaxLexicalDerivations();
        grammar.EMHeldOutAlgorithm();
        grammar.readTreesFromCorpus();
        grammar.printTemplatesToFile();
        grammar.treatTreeBank();
        grammar.toPCFG();
        grammar.printTrainingCorpusToFile();
        grammar.printLexiconAndGrammarFiles();
        new Parser(grammar);
    }

    /**
     * Entry point; runs {@link #EmStandard(String[])}.
     *
     * @param args command-line arguments, passed through
     */
    public static void main(String[] args) {
        LTSG_EM_nBest.EmStandard(args);
    }
}

