/*

Copyright © 2023-24 Sean Holden. All rights reserved.

*/
/*

This file is part of Connect++.

Connect++ is free software: you can redistribute it and/or modify it 
under the terms of the GNU General Public License as published by the 
Free Software Foundation, either version 3 of the License, or (at your 
option) any later version.

Connect++ is distributed in the hope that it will be useful, but WITHOUT 
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 
FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for 
more details.

You should have received a copy of the GNU General Public License along 
with Connect++. If not, see <https://www.gnu.org/licenses/>. 

*/

#include "Matrix.hpp"

// Random number generator shared by every Matrix (static data member),
// used by random_reorder() to shuffle the clause order. It is seeded from
// params::random_seed when this definition is initialised.
// NOTE(review): this happens during static initialisation, so a value
// assigned to params::random_seed later (e.g. from the command line) does
// not reseed d unless that is done elsewhere -- confirm.
std::mt19937 Matrix::d(params::random_seed);

//----------------------------------------------------------------------
Matrix::Matrix(size_t num_predicates)
: clauses()
, index((num_predicates * 2), vector<MatrixPairType>())
, literal_clause_index((num_predicates * 2), vector<LiteralClausePairType>())
, positive_clauses()
, negative_clauses()
, ground_clauses()
, clause_roles()
, num_equals(0)
, clauses_copy()
, roles_copy()
, copy_saved(false)
{}
//----------------------------------------------------------------------
/**
* Discard both indices and re-create them with 2 * num_predicates empty
* slots each. Clauses already stored are not touched.
*/
void Matrix::set_num_preds(size_t num_predicates) {
    const size_t slots = 2 * num_predicates;
    index.assign(slots, vector<MatrixPairType>());
    literal_clause_index.assign(slots, vector<LiteralClausePairType>());
}
//----------------------------------------------------------------------
bool Matrix::is_conjecture(size_t i) const {
  return clause_roles[i] == "negated_conjecture" ||
         clause_roles[i] == "conjecture";
}
//----------------------------------------------------------------------
/**
* Choose the clause to start a proof attempt from.
*
* Returns (false, 0) unless the matrix contains at least one positive
* clause AND at least one negative clause. Otherwise returns (true, i),
* where i is:
*  - the first conjecture clause (see is_conjecture) that is positive
*    when params::positive_representation is set, or negative otherwise;
*  - failing that, the first positive (respectively negative) clause.
*/
pair<bool,size_t> Matrix::find_start() const {
  bool found_positive = false;
  bool found_negative = false;
  size_t first_positive = 0;
  size_t first_negative = 0;
  bool found_candidate = false;
  size_t candidate = 0;
  const size_t num_clauses = positive_clauses.size();
  for (size_t i = 0; i < num_clauses; i++) {
    if (positive_clauses[i] && !found_positive) {
      first_positive = i;
      found_positive = true;
    }
    if (negative_clauses[i] && !found_negative) {
      first_negative = i;
      found_negative = true;
    }
    if (!found_candidate && is_conjecture(i)) {
      // Only a conjecture clause of the polarity selected by the
      // representation qualifies as the preferred start clause.
      const bool right_polarity = params::positive_representation
                                  ? positive_clauses[i]
                                  : negative_clauses[i];
      if (right_polarity) {
        found_candidate = true;
        candidate = i;
      }
    }
  }
  // Both polarities must be present, otherwise no start clause is offered.
  if (!(found_positive && found_negative))
    return pair<bool, size_t>(false, 0);
  if (found_candidate)
    return pair<bool, size_t>(true, candidate);
  if (params::positive_representation)
    return pair<bool, size_t>(true, first_positive);
  return pair<bool, size_t>(true, first_negative);
}
//----------------------------------------------------------------------
// Note: we're assuming here that we don't have to use arity to 
// distinguish predicates.
//----------------------------------------------------------------------
/**
* Append clause (with the given role) to the end of the matrix.
*
* For every literal of the clause, keyed by get_pred_as_index():
*  - index receives (number of this clause, position of the literal);
*  - literal_clause_index receives (the literal, a copy of the clause
*    with that literal dropped).
* The clause's positive/negative/ground flags and role are recorded in
* the parallel vectors.
*/
void Matrix::add_clause(Clause& clause, string role) {
    const ClauseNum clause_num = clauses.size();
    LitNum lit_num = 0;
    for (size_t pos = 0; pos < clause.size(); ++pos, ++lit_num) {
        const size_t slot = clause[pos].get_pred_as_index();
        index[slot].push_back(MatrixPairType(clause_num, lit_num));
        Clause remainder(clause);
        remainder.drop_literal(lit_num);
        literal_clause_index[slot].push_back(LiteralClausePairType(clause[pos], remainder));
    }
    clauses.push_back(clause);
    positive_clauses.push_back(clause.is_positive());
    negative_clauses.push_back(clause.is_negative());
    ground_clauses.push_back(clause.is_ground());
    clause_roles.push_back(role);
}
//----------------------------------------------------------------------
/**
* Replace the contents of the matrix with the clauses in cs, whose roles
* are given by the corresponding entries of ss, rebuilding both indices
* from scratch. The number of slots in index and literal_clause_index is
* kept as it was.
*
* The clause count is taken from the arguments: reading cs/ss past their
* end (as could happen when bounding by the old clauses.size()) is
* undefined behaviour. If the two disagree in length, only the common
* prefix is used.
*
* cs and ss must not alias members of this Matrix: the members are
* cleared before the arguments are read.
*/
void Matrix::rebuild_index(vector<Clause>& cs, vector<string>& ss) {
  const size_t n = std::min(cs.size(), ss.size());
  clauses.clear();
  positive_clauses.clear();
  negative_clauses.clear();
  ground_clauses.clear();
  clause_roles.clear();
  // Empty every slot but keep the number of slots.
  for (vector<MatrixPairType>& slot : index)
    slot.clear();
  for (vector<LiteralClausePairType>& slot : literal_clause_index)
    slot.clear();
  for (size_t i = 0; i < n; i++) {
    add_clause(cs[i], ss[i]);
  }
}
//----------------------------------------------------------------------
void Matrix::deterministic_reorder(size_t n) {
  /**
  * Only store a copy of the clauses the first time you do this.
  */
  if (!copy_saved) {
    copy_saved = true;
    make_clauses_copy();
  }
  /*
  * Find a suitably reordered set of indices.
  */
  vector<uint32_t> new_order;
  size_t s = clauses.size();
  for (size_t i = 0; i < s; i++)
    new_order.push_back(i);
  new_order = deterministic_reorder_n_times<uint32_t>(new_order, n);
  /*
  * Do the reordering.
  */
  vector<Clause> cs;
  vector<string> ss;
  for (size_t i = 0; i < s; i++) {
    cs.push_back(clauses_copy[new_order[i]]);
    ss.push_back(roles_copy[new_order[i]]);
  }
  /*
  * Clear and rebuild as necessary.
  */
  rebuild_index(cs, ss);
}
//----------------------------------------------------------------------
/**
* Shuffle the order of the clauses using the shared generator d, then
* rebuild the indices. Roles travel with their clauses.
*
* The reordered clauses are gathered straight from the members into fresh
* local vectors; rebuild_index only clears the members after that, so no
* separate snapshot of clauses/clause_roles is needed. The permutation is
* built and shuffled exactly as before, so the same seed gives the same
* order.
*/
void Matrix::random_reorder() {
  const size_t s = clauses.size();
  vector<uint32_t> new_order(s);
  for (size_t i = 0; i < s; i++)
    new_order[i] = static_cast<uint32_t>(i);
  std::shuffle(new_order.begin(), new_order.end(), d);
  vector<Clause> cs;
  vector<string> ss;
  cs.reserve(s);
  ss.reserve(s);
  for (uint32_t from : new_order) {
    cs.push_back(clauses[from]);
    ss.push_back(clause_roles[from]);
  }
  rebuild_index(cs, ss);
}
//----------------------------------------------------------------------
void Matrix::random_reorder_literals() {
  vector<Clause> saved_clauses(clauses);
  vector<string> saved_clause_roles(clause_roles);
  size_t s = clauses.size();
  /*
  * Reorder.
  */
  for (size_t i = 0; i < s; i++) {
    Clause c(saved_clauses[i]);
    c.random_reorder();
    saved_clauses[i] = c;
  }
  /*
  * Clear and rebuild as necessary.
  */
  rebuild_index(saved_clauses, saved_clause_roles);
}
//----------------------------------------------------------------------
void Matrix::move_equals_to_start() {
  if (num_equals == 0) {
    cerr << "Why are you trying to move equality axioms - there aren't any?" << endl;
    return;
  }
  make_clauses_copy();
  clauses.clear();
  positive_clauses.clear();
  negative_clauses.clear();
  ground_clauses.clear();
  clause_roles.clear();
  size_t s2 = index.size();
  size_t s3 = literal_clause_index.size();
  index.clear();
  literal_clause_index.clear();
  index = vector<vector<MatrixPairType> >(s2, vector<MatrixPairType>());
  literal_clause_index = vector<vector<LiteralClausePairType> >(s3, vector<LiteralClausePairType>());
  size_t n_clauses = clauses_copy.size();
  for (size_t i = n_clauses - num_equals; i < n_clauses; i++) {
    add_clause(clauses_copy[i], roles_copy[i]);
  }
  for (size_t i = 0; i < n_clauses - num_equals; i++) {
    add_clause(clauses_copy[i], roles_copy[i]);
  }
  clauses_copy.clear();
  roles_copy.clear();
}
//----------------------------------------------------------------------
/**
* Render the matrix, one clause per line. Each line begins with a marker
* column: '*' for a conjecture clause (else a space), then '+' for a
* positive clause, '-' for a negative one, or a space. Markers go through
* ColourString::orange(), built with params::use_colours. The markers are
* followed by a space and the clause's own to_string().
*/
string Matrix::to_string() const {
  colour_string::ColourString cs(params::use_colours);
  string result;
  const size_t count = clauses.size();
  for (size_t i = 0; i < count; ++i) {
    if (is_conjecture(i)) { result += cs("*").orange(); }
    else                  { result += " "; }
    if (positive_clauses[i])      { result += cs("+").orange(); }
    else if (negative_clauses[i]) { result += cs("-").orange(); }
    else                          { result += " "; }
    result += " ";
    result += clauses[i].to_string();
    result += "\n";
  }
  return result;
}
//----------------------------------------------------------------------
/**
* Produce a LaTeX display-math "split" environment listing the matrix M,
* one clause per row, each rendered with Clause::make_LaTeX(subbed).
*/
string Matrix::make_LaTeX(bool subbed) const {
  string latex("\\[\n\\begin{split}\n");
  latex += "\\textcolor{magenta}{M} = ";
  for (const Clause& clause : clauses)
    latex.append("&\\,").append(clause.make_LaTeX(subbed)).append("\\\\");
  latex += "\n\\end{split}\n\\]\n";
  return latex;
}
//----------------------------------------------------------------------
void Matrix::write_to_prolog_file(const path& path_to_file) const {
  std::ofstream file(path_to_file);
  size_t matrix_i = 0;
  // Nasty hack needed to stop the proof checker from printing 
  // a bunch of warnings when reading the matrix.
  file << ":- style_check(-singleton)." << std::endl;
  for (const Clause& c : clauses) {
    file << "matrix(";
    file << std::to_string(matrix_i++);
    file << ", ";
    file << c.to_prolog_string();
    file << ").";
    file << std::endl;
  }
  file.close();
}
//----------------------------------------------------------------------
/**
* Print the matrix on cout as TPTP CNF: one
*     cnf(matrix_N, plain, <clause>).
* line per clause, N being the clause's position in the matrix.
*
* TPTP formula names must be atomic words (letters, digits and '_') or
* integers. The previous "matrix-N" is not a valid name, because '-' ends
* the word, so the output could not be parsed by TPTP tools.
*/
void Matrix::show_tptp() const {
  size_t matrix_i = 0;
  for (const Clause& c : clauses) {
    cout << "cnf(matrix_" << std::to_string(matrix_i++) << ", plain, "
         << c.to_tptp_string() << ")." << std::endl;
  }
}
//----------------------------------------------------------------------
/**
* Copy out the _i-th entry of slot _l of literal_clause_index.
*
* _l selects a slot of literal_clause_index. In add_clause that slot is
* the literal's get_pred_as_index(), not a literal position, despite the
* LitNum parameter type.
* On return _lit holds the stored literal, and _clause holds the clause it
* came from with that literal dropped.
*
* The entry is read through a const reference, so the (Literal, Clause)
* pair is no longer copied as a whole before its parts are copied again.
*/
void Matrix::get_literal_clause_pair(LitNum _l, size_t _i, Literal& _lit, Clause& _clause) const {
  const LiteralClausePairType& entry = literal_clause_index[_l][_i];
  _lit = entry.first;
  _clause = entry.second;
}
//----------------------------------------------------------------------
/**
* Debug dump of a Matrix. It prints:
*  1. every clause with its number, flagging ground clauses;
*  2. index: for each slot, its (clause, literal) pairs;
*  3. literal_clause_index: for each slot, its (literal, clause without
*     literal) pairs.
* Each slot is labelled with get_small_lit() of the last literal in the
* matrix that maps to it (empty if none does).
*
* Index vectors and pairs are now visited by const reference instead of
* being copied per iteration. The inner loop no longer shadows the outer
* counter.
*/
ostream& operator<<(ostream& out, const Matrix& m) {
    out << "Clauses in matrix:" << endl;
    out << "------------------" << endl;
    for (size_t i = 0; i < m.clauses.size(); i++) {
        out << setw(params::output_width) << i
        << ": " << m.clauses[i];
        if (m.ground_clauses[i]) {
          out << " (ground)";
        }
        out << endl;
    }
    vector<string> index_lits(m.index.size(), string(""));
    // NOTE(review): clauses are visited by value as before; switch to
    // const Clause& once Clause::operator[] and the Literal accessors used
    // here are confirmed to be const-qualified.
    for (Clause c : m.clauses)
      for (size_t j = 0; j < c.size(); j++)
        index_lits[c[j].get_pred_as_index()] = c[j].get_small_lit();
    out << endl << "Index: (Clause, Literal):" << endl;
    out <<         "-------------------------" << endl;
    size_t slot = 0;
    for (const vector<MatrixPairType>& v : m.index) {
        out << setw(params::output_width + 20) << index_lits[slot++] << ": ";
        for (const MatrixPairType& p : v) {
            out << "(" << p.first << ", " << p.second << ") ";
        }
        out << endl;
    }
    out << endl << "Index: (Literal, Clause without Literal):" << endl;
    out <<         "-----------------------------------------" << endl;
    slot = 0;
    for (const vector<LiteralClausePairType>& v : m.literal_clause_index) {
        out << setw(params::output_width + 20) << index_lits[slot++] << ": " << endl;
        for (const LiteralClausePairType& p : v) {
            out << p.first.to_string() << " --- " << p.second.to_string() << endl;
        }
        out << endl;
    }
    return out;
}

