/*
 * Decompiled with CFR 0.152.
 */
package tsg;

import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Scanner;
import kernels.NodeSetCollector;
import kernels.NodeSetCollectorMUB;
import settings.Parameters;
import tsg.TSNodeLabel;
import tsg.TSNodeLabelIndex;
import tsg.corpora.Wsj;
import tsg.parseEval.EvalF;
import util.FileUtil;
import util.PrintProgress;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Reranks n-best parse lists by preferring, for every sentence, the candidate
 * tree that can be covered with the smallest number of fragments from
 * {@link #fragmentBag} (as computed by {@code getMinDerivationSize}).
 *
 * <p>All state is static; {@link #fragmentBag} is filled by
 * {@link #readFragmentsFile(File)} and {@link #addCFGfragments(File)} and then
 * read by the reranking code.
 */
public class DOP_SD_reranker {
    /** Fragments used to cover candidate trees; (re)created by {@link #readFragmentsFile(File)}. */
    public static ArrayList<TSNodeLabel> fragmentBag;

    /**
     * Reads the next group of candidate trees from an n-best file in which
     * groups are separated by empty lines.
     *
     * <p>At most {@code nBest} trees are parsed. If the group holds more lines,
     * the remaining ones are consumed (up to and including the terminating
     * empty line) and discarded, so that the scanner is positioned at the start
     * of the following group.
     *
     * @param nBest maximum number of trees to return for this group
     * @param s     scanner positioned at the first line of a group
     * @return the parsed trees in file order; empty if the group is empty or
     *         the input is exhausted
     * @throws Exception if a line cannot be turned into a {@code TSNodeLabelIndex}
     */
    public static ArrayList<TSNodeLabelIndex> nextNBest(int nBest, Scanner s) throws Exception {
        ArrayList<TSNodeLabelIndex> result = new ArrayList<TSNodeLabelIndex>(nBest);
        while (s.hasNextLine() && result.size() < nBest) {
            String line = s.nextLine();
            if (line.equals("")) {
                // Group ended before nBest trees were collected.
                return result;
            }
            result.add(new TSNodeLabelIndex(line));
        }
        // Skip the candidates beyond nBest, including the separating empty line.
        while (s.hasNextLine() && !s.nextLine().equals("")) {
            // intentionally empty: the lines are only consumed
        }
        return result;
    }

    /**
     * Replaces {@link #fragmentBag} with the fragments listed in the given file.
     *
     * <p>Every non-empty line contributes one fragment: the text before the
     * first tab, with all backslash characters removed.
     *
     * @param fragmentFile file with one fragment per line
     * @throws Exception if the file cannot be opened or a fragment cannot be parsed
     */
    public static void readFragmentsFile(File fragmentFile) throws Exception {
        fragmentBag = new ArrayList<TSNodeLabel>();
        Scanner scan = FileUtil.getScanner(fragmentFile);
        int countFragments = 0;
        try {
            while (scan.hasNextLine()) {
                String line = scan.nextLine();
                if (line.equals("")) {
                    continue;
                }
                ++countFragments;
                String fragmentString = line.split("\t")[0];
                // Literal replacement; equivalent to the regex form replaceAll("\\\\", "").
                fragmentString = fragmentString.replace("\\", "");
                fragmentBag.add(new TSNodeLabel(fragmentString, false));
            }
        } finally {
            // Release the file handle even when a fragment fails to parse.
            scan.close();
        }
        System.out.println("Read " + countFragments + " fragments");
    }

    /**
     * Adds one fragment per distinct CFG rule found in the training corpus to
     * {@link #fragmentBag}. Rules of lexical nodes are skipped.
     *
     * <p>If the bag has not been created yet it is created here, so the method
     * can be called without a preceding {@link #readFragmentsFile(File)}.
     *
     * @param trainingCorpus treebank file from which the rules are collected
     * @throws Exception if the treebank cannot be read or a rule cannot be parsed
     */
    public static void addCFGfragments(File trainingCorpus) throws Exception {
        if (fragmentBag == null) {
            fragmentBag = new ArrayList<TSNodeLabel>();
        }
        ArrayList<TSNodeLabel> corpus = TSNodeLabel.getTreebank(trainingCorpus);
        HashSet<String> ruleSet = new HashSet<String>();
        for (TSNodeLabel tree : corpus) {
            tree.addTop();
            for (TSNodeLabel node : tree.collectAllNodes()) {
                if (!node.isLexical) {
                    ruleSet.add(node.cfgRule());
                }
            }
        }
        for (String rule : ruleSet) {
            fragmentBag.add(new TSNodeLabel("( " + rule + ")", false));
        }
        System.out.println("Read " + ruleSet.size() + " CFG fragments");
    }

    /**
     * Scores a candidate tree by the size of a fragment cover of its nodes.
     *
     * @param t candidate tree
     * @return {@code Integer.MAX_VALUE} if at least one internal node of
     *         {@code t} is not contained in any collected node set; otherwise
     *         the value of {@code getMinCover} plus the number of lexical
     *         items whose parent node is not in the union of the node sets
     */
    private static int getMinDerivationSize(TSNodeLabelIndex t) {
        NodeSetCollectorMUB setCollector = new NodeSetCollectorMUB();
        for (TSNodeLabel fragment : fragmentBag) {
            DOP_SD_reranker.getCFGSetCoveringFragment(t, fragment, (NodeSetCollector)setCollector);
        }
        BitSet union = setCollector.uniteSubGraphs();
        ArrayList<TSNodeLabel> internalNodes = t.collectInternalNodes();
        // Cheap rejection: fewer covered indices than internal nodes means some node is uncovered.
        if (union.cardinality() < internalNodes.size()) {
            return Integer.MAX_VALUE;
        }
        for (TSNodeLabel internal : internalNodes) {
            if (!union.get(((TSNodeLabelIndex)internal).index)) {
                return Integer.MAX_VALUE;
            }
        }
        int lexNonCovered = 0;
        for (TSNodeLabel lexical : t.collectLexicalItems()) {
            TSNodeLabelIndex lexicalParent = (TSNodeLabelIndex)lexical.parent;
            if (!union.get(lexicalParent.index)) {
                ++lexNonCovered;
            }
        }
        int minCover = DOP_SD_reranker.getMinCover(setCollector, union);
        return minCover + lexNonCovered;
    }

    /**
     * Visits {@code t} and all of its non-lexical descendants; at every node
     * whose label equals the fragment's root label, tries to match the fragment
     * there and adds the resulting non-empty node set to the collector.
     *
     * @param t            (sub)tree of the candidate being scanned
     * @param fragment     fragment to match
     * @param setCollector receives one node-index set per non-empty match
     */
    private static void getCFGSetCoveringFragment(TSNodeLabelIndex t, TSNodeLabel fragment, NodeSetCollector setCollector) {
        if (t.isLexical) {
            return;
        }
        if (t.sameLabel(fragment)) {
            BitSet set = new BitSet();
            DOP_SD_reranker.getCFGSetCoveringFragmentNonRecursive(t, fragment, set);
            if (!set.isEmpty()) {
                setCollector.add(set);
            }
        }
        for (TSNodeLabel daughter : t.daughters) {
            DOP_SD_reranker.getCFGSetCoveringFragment((TSNodeLabelIndex)daughter, fragment, setCollector);
        }
    }

    /**
     * Matches {@code fragment} against the tree rooted at {@code t}, top-down,
     * and records the indices of the tree nodes that were matched.
     *
     * <p>Nothing is recorded for {@code t} when it is lexical, when the
     * fragment node is terminal, or when the daughters' labels differ.
     *
     * @param t        tree node aligned with {@code fragment}
     * @param fragment fragment node aligned with {@code t}
     * @param set      accumulator for the indices of matched tree nodes
     */
    private static void getCFGSetCoveringFragmentNonRecursive(TSNodeLabelIndex t, TSNodeLabel fragment, BitSet set) {
        if (t.isLexical || fragment.isTerminal()) {
            return;
        }
        if (!t.sameDaughtersLabel(fragment)) {
            return;
        }
        int prole = t.prole();
        for (int i = 0; i < prole; ++i) {
            TSNodeLabelIndex treeDaughter = (TSNodeLabelIndex)t.daughters[i];
            DOP_SD_reranker.getCFGSetCoveringFragmentNonRecursive(treeDaughter, fragment.daughters[i], set);
        }
        // NOTE(review): t is marked even when a recursive call above returned early
        // because a daughter's own daughters did not match the fragment - confirm
        // that recording such partial matches is intended.
        set.set(t.index);
    }

    /**
     * Greedily counts how many collected node sets are needed to cover
     * {@code union}.
     *
     * <p>First, for every still-uncovered element for which the collector
     * reports a unique containing set, that set is taken. Afterwards the set
     * with the most uncovered elements is taken repeatedly until the cover has
     * the same cardinality as {@code union}. Every set taken is removed from
     * {@code setCollector}, i.e. the collector is modified by this call.
     *
     * @param setCollector collector holding the candidate node sets
     * @param union        union of all node sets in {@code setCollector}
     * @return number of sets taken
     */
    private static int getMinCover(NodeSetCollectorMUB setCollector, BitSet union) {
        int result = 0;
        int unionCardinality = union.cardinality();
        BitSet currentCover = new BitSet();
        for (int element = union.nextSetBit(0); element != -1; element = union.nextSetBit(element + 1)) {
            if (currentCover.get(element)) {
                continue;
            }
            BitSet uniqueSet = setCollector.getUniqueBitSetContainingElement(element);
            if (uniqueSet != null) {
                currentCover.or(uniqueSet);
                setCollector.removeSet(uniqueSet);
                ++result;
            }
        }
        while (currentCover.cardinality() != unionCardinality) {
            BitSet nextSet = setCollector.getSetWithMaxUncoveredElements(currentCover);
            currentCover.or(nextSet);
            setCollector.removeSet(nextSet);
            ++result;
        }
        return result;
    }

    /**
     * Runs the complete reranking experiment for the given n-best size: loads
     * the fragments, reranks the n-best list of every sentence of the gold
     * treebank, writes the selected trees to a file and prints the evaluation
     * returned by {@code EvalF.staticEvalF}.
     *
     * <p>The first candidate of a group is kept unless a later candidate has a
     * strictly smaller derivation size.
     *
     * @param nBest number of candidates considered per sentence
     * @throws IllegalStateException if the n-best file provides no candidate
     *         for one of the gold sentences
     * @throws Exception if any input cannot be read or parsed
     */
    public static void rerank(int nBest) throws Exception {
        String fragmentFileDir = String.valueOf(Parameters.resultsPath) + "TSG/TSGkernels/Wsj/KenelFragments/SemTagOff_Top/all/correctCount/";
        File fragmentFile = new File(fragmentFileDir + "fragments_MUB_freq_all_correctCount.txt");
        DOP_SD_reranker.readFragmentsFile(fragmentFile);
        File trainingCorpus = new File(String.valueOf(Wsj.WsjOriginalCleanedSemTagsOff) + "wsj-02-21.mrg");
        DOP_SD_reranker.addCFGfragments(trainingCorpus);
        String baseDir = String.valueOf(Parameters.resultsPath) + "TSG/DOP_SD_Reranker/";
        File nBestFile = new File(baseDir + "wsj-22_chiarniak_parsed1000_cleaned.mrg");
        File goldFile = new File(baseDir + "wsj-22_gold.mrg");
        File rerankedFile = new File(baseDir + "wsj-22_reranked_" + nBest + "best.mrg");
        File rerankedFileEvalF = new File(baseDir + "wsj-22_reranked_" + nBest + "best.evalF");
        int totalActivelyReranked = 0;
        int totalNonCovered = 0;
        Scanner nBestScanner = FileUtil.getScanner(nBestFile);
        try {
            ArrayList<TSNodeLabel> goldTreebank = TSNodeLabel.getTreebank(goldFile);
            PrintWriter pw = FileUtil.getPrintWriter(rerankedFile);
            try {
                int size = goldTreebank.size();
                System.out.println("Rerankin n = " + nBest);
                PrintProgress.start("Sentence ");
                for (int i = 0; i < size; ++i) {
                    PrintProgress.next();
                    ArrayList<TSNodeLabelIndex> nBestTrees = DOP_SD_reranker.nextNBest(nBest, nBestScanner);
                    if (nBestTrees.isEmpty()) {
                        // Without this check the failure would be a message-less NoSuchElementException.
                        throw new IllegalStateException("No n-best candidates for sentence " + (i + 1) + " of " + size + " in " + nBestFile);
                    }
                    TSNodeLabelIndex bestReranked = nBestTrees.get(0);
                    int minDerivSize = DOP_SD_reranker.getMinDerivationSize(bestReranked);
                    boolean reranked = false;
                    for (int k = 1; k < nBestTrees.size(); ++k) {
                        TSNodeLabelIndex candidate = nBestTrees.get(k);
                        int derSize = DOP_SD_reranker.getMinDerivationSize(candidate);
                        // Strict comparison: ties keep the earlier candidate.
                        if (derSize < minDerivSize) {
                            minDerivSize = derSize;
                            bestReranked = candidate;
                            reranked = true;
                        }
                    }
                    if (reranked) {
                        ++totalActivelyReranked;
                    }
                    if (minDerivSize == Integer.MAX_VALUE) {
                        // Even the best candidate has a node that no fragment covers.
                        ++totalNonCovered;
                    }
                    pw.println(bestReranked.toString());
                }
            } finally {
                // Flush and release the output file before it is evaluated, also on failure.
                pw.close();
            }
        } finally {
            nBestScanner.close();
        }
        PrintProgress.end();
        float[] rerankedFScore = EvalF.staticEvalF(goldFile, rerankedFile, rerankedFileEvalF, true);
        System.out.println("Actively reranked: " + totalActivelyReranked);
        System.out.println("Non covered: " + totalNonCovered);
        System.out.println("Reranked Recall Precision FScore: " + Arrays.toString(rerankedFScore));
    }

    /**
     * Entry point: runs {@link #rerank(int)} once for every configured n-best size.
     *
     * @param args ignored
     * @throws Exception if a reranking run fails
     */
    public static void main(String[] args) throws Exception {
        int[] nBestSizes = new int[]{5};
        for (int nBest : nBestSizes) {
            DOP_SD_reranker.rerank(nBest);
        }
    }
}

