package framework;

import static java.lang.Math.acos;

import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;

/**
 * A polygon mesh: a linked list of {@link Face}s whose corners are shared
 * {@link Vertex} objects.
 *
 * <p>Each vertex keeps track of the faces it belongs to, which allows a
 * per-vertex normal to be derived from the normals of the adjacent faces
 * (see {@link #computeAllNormals()}).
 */
@SuppressWarnings("serial")
public class Mesh extends LinkedList<Mesh.Face> {
  
  //////////////////////////////////////////////
  
  /**
   * A mesh corner: an {@code M3d} position plus the set of faces that use it
   * and a vertex normal derived from those faces.
   */
  public static class Vertex extends M3d {

    /**
     * Faces this vertex belongs to, mapped to the vertex's index inside each face.
     * Insertion-ordered so that {@link #computeNormal()} always sums the face
     * contributions in the same order; {@code Face} uses identity hashing, so a
     * plain hash map would iterate in an order that can differ between runs and
     * make the floating-point result non-reproducible.
     */
    Map<Face, Integer> faces = new LinkedHashMap<Face, Integer>();

    /** Vertex normal; only meaningful after {@link #computeNormal()} has run. */
    public M3d normal = new M3d();

    /**
     * Creates a vertex that copies its position from {@code src}.
     *
     * @param src the value handed to the {@code M3d} copy constructor
     */
    public Vertex(M3d src) {
      super(src);
    }
    
    /**
     * Records that this vertex is corner number {@code index} of {@code face}.
     * Registering the same face again replaces the previously stored index.
     *
     * @param face  the face using this vertex
     * @param index position of this vertex in the face's vertex array
     */
    public void addFace(Face face, int index) {
      faces.put(face, index);
    }
    
    /**
     * Recomputes {@link #normal} as the normalized sum of the normals of all
     * registered faces, each weighted by the face's angle at this vertex
     * ({@link Face#getFaceAngle(Vertex)}).
     */
    public void computeNormal() {
      // Accumulate locally so the public field is only replaced by a finished value.
      M3d sum = new M3d();
      for (Face face : faces.keySet()) {
        sum = sum.plus(face.getNormal().times(face.getFaceAngle(this)));
      }
      normal = sum.normalized();
    }
    
    /**
     * Returns the index of this vertex within {@code f}.
     *
     * @param f a face this vertex was registered with via {@link #addFace}
     * @return the index stored for {@code f}
     * @throws IllegalArgumentException if this vertex is not registered with {@code f}
     */
    int getFaceIndex(Face f) {
      Integer index = faces.get(f);
      if (index == null) {
        throw new IllegalArgumentException("vertex is not registered with the given face");
      }
      return index;
    }
    
    /**
     * @return the current vertex normal (see {@link #computeNormal()})
     */
    public M3d getNormal() {
      return normal;
    }
  }
  
  //////////////////////////////////////////////
  
  /**
   * A polygon described by an ordered list of vertices. The face normal is
   * computed once, at construction time.
   */
  public static class Face {

    /** Corners of the face in winding order. */
    Vertex[] arrVerts;

    /** Face normal computed by the constructor. */
    M3d normal;
    
    /**
     * Builds a face from the given corners, registers the face with every
     * corner and computes the face normal as the normalized cross product of
     * the edge from vertex 0 to vertex 1 with the edge from vertex 0 to the
     * last vertex.
     *
     * <p>The normal is only well defined for three or more non-collinear
     * vertices; shorter vertex lists are accepted but yield a degenerate normal.
     *
     * @param verts corners in winding order; the array is copied
     * @throws NullPointerException     if {@code verts} is {@code null}
     * @throws IllegalArgumentException if {@code verts} is empty
     */
    public Face(Vertex... verts) {
      if (verts == null) {
        throw new NullPointerException("verts");
      }
      if (verts.length == 0) {
        throw new IllegalArgumentException("a face needs at least one vertex");
      }
      // Copy so that later changes to the caller's array cannot alter this face.
      arrVerts = verts.clone();
      for (int i = 0; i < arrVerts.length; i++) {
        arrVerts[i].addFace(this, i);
      }
      normal = getVertex(1).minus(getVertex(0))
          .cross(getVertex(-1).minus(getVertex(0)))
          .normalized();
    }

    /**
     * @return the face normal computed at construction time
     */
    public M3d getNormal() {
      return normal;
    }
    
    /**
     * @return the number of corners of this face
     */
    public int getNumVerts() {
      return arrVerts.length;
    }
    
    /**
     * Returns a corner by cyclic index: {@code -1} is the last vertex and
     * {@code getNumVerts()} wraps around to the first.
     *
     * @param i any integer index
     * @return the vertex at position {@code i} modulo the vertex count
     */
    public Vertex getVertex(int i) {
      int n = arrVerts.length;
      // (i % n) lies in (-n, n); adding n and reducing again maps it into [0, n).
      return arrVerts[((i % n) + n) % n];
    }
    
    /**
     * @return the internal vertex array (not a copy); callers should treat it
     *         as read-only
     */
    public Vertex[] getVertices() {
      return arrVerts;
    }
    
    /**
     * Returns the angle, in radians, between the two edges that leave
     * {@code v} towards its previous and next vertex in this face.
     *
     * @param v a vertex of this face
     * @return the angle at {@code v}
     * @throws IllegalArgumentException if {@code v} is not registered with this face
     */
    public double getFaceAngle(Vertex v) {
      int i = v.getFaceIndex(this);
      double cosine = getVertex(i - 1).minus(v).normalized()
          .dot(getVertex(i + 1).minus(v).normalized());
      // Rounding can push the dot product of two unit vectors slightly outside
      // [-1, 1], for which acos returns NaN; clamp to keep the angle finite.
      return acos(Math.max(-1.0, Math.min(1.0, cosine)));
    }
  }

  //////////////////////////////////////////////
  
  /** Creates an empty mesh. */
  public Mesh() {
  }
  
  /**
   * Recomputes the normal of every vertex referenced by a face of this mesh.
   * Each vertex is processed once, however many faces share it.
   */
  public void computeAllNormals() {
    Set<Vertex> visited = new HashSet<Vertex>();
    
    for (Face face : this) {
      for (Vertex v : face.getVertices()) {
        // add() returns false when the vertex has already been handled.
        if (visited.add(v)) {
          v.computeNormal();
        }
      }
    }
  }
  
  /**
   * Returns a new mesh whose vertex coordinates are those of this mesh
   * multiplied component-wise by {@code scale}. This mesh is left untouched;
   * the copy gets fresh vertices and faces, and its vertex normals are computed.
   *
   * @param scale per-axis scale factors
   * @return the scaled copy
   */
  public Mesh scaled(M3d scale) {
    Map<Vertex, Vertex> replacements = new HashMap<Vertex, Vertex>();
    
    for (Face face : this) {
      for (Vertex v : face.getVertices()) {
        if (!replacements.containsKey(v)) {
          replacements.put(v, new Vertex(new M3d(
              v.getX() * scale.getX(),
              v.getY() * scale.getY(),
              v.getZ() * scale.getZ())));
        }
      }
    }
    
    return rebuilt(replacements, false);
  }
  
  /**
   * Returns a new mesh with the same vertex positions but with the vertex
   * order of every face reversed. This mesh is left untouched; the copy gets
   * fresh vertices and faces, and its vertex normals are computed.
   *
   * @return the copy with reversed face winding
   */
  public Mesh flipped() {
    Map<Vertex, Vertex> replacements = new HashMap<Vertex, Vertex>();
    
    for (Face face : this) {
      for (Vertex v : face.getVertices()) {
        if (!replacements.containsKey(v)) {
          replacements.put(v, new Vertex(v));
        }
      }
    }
    
    return rebuilt(replacements, true);
  }

  /**
   * Builds a new mesh with one face per face of this mesh, substituting each
   * vertex by its entry in {@code replacements}, then computes all vertex
   * normals of the result.
   *
   * @param replacements   maps every vertex of this mesh to its counterpart in the new mesh
   * @param reverseWinding whether the vertex order of each face is reversed
   * @return the newly built mesh
   */
  private Mesh rebuilt(Map<Vertex, Vertex> replacements, boolean reverseWinding) {
    Mesh result = new Mesh();
    
    for (Face face : this) {
      Vertex[] source = face.getVertices();
      int n = source.length;
      Vertex[] target = new Vertex[n];
      
      for (int i = 0; i < n; i++) {
        target[reverseWinding ? n - 1 - i : i] = replacements.get(source[i]);
      }
      
      result.add(new Face(target));
    }
    result.computeAllNormals();
    return result;
  }
}
