package uk.ac.cam.cl.gfxintro.gd355.tick2;

import static org.lwjgl.opengl.GL11.GL_COLOR_BUFFER_BIT;
import static org.lwjgl.opengl.GL11.GL_DEPTH_BUFFER_BIT;
import static org.lwjgl.opengl.GL11.GL_FLOAT;
import static org.lwjgl.opengl.GL11.GL_TEXTURE_2D;
import static org.lwjgl.opengl.GL11.GL_TRIANGLES;
import static org.lwjgl.opengl.GL11.GL_UNSIGNED_INT;
import static org.lwjgl.opengl.GL11.glBindTexture;
import static org.lwjgl.opengl.GL11.glClear;
import static org.lwjgl.opengl.GL11.glClearColor;
import static org.lwjgl.opengl.GL11.glDrawElements;
import static org.lwjgl.opengl.GL15.GL_ARRAY_BUFFER;
import static org.lwjgl.opengl.GL15.GL_ELEMENT_ARRAY_BUFFER;
import static org.lwjgl.opengl.GL15.GL_STATIC_DRAW;
import static org.lwjgl.opengl.GL15.glBindBuffer;
import static org.lwjgl.opengl.GL15.glBufferData;
import static org.lwjgl.opengl.GL15.glGenBuffers;
import static org.lwjgl.opengl.GL20.glEnableVertexAttribArray;
import static org.lwjgl.opengl.GL20.glGetAttribLocation;
import static org.lwjgl.opengl.GL20.glGetUniformLocation;
import static org.lwjgl.opengl.GL20.*;
import static org.lwjgl.opengl.GL20.glVertexAttribPointer;
import static org.lwjgl.opengl.GL30.glBindVertexArray;
import static org.lwjgl.opengl.GL30.glGenVertexArrays;

import java.nio.FloatBuffer;
import java.nio.IntBuffer;

import org.joml.AxisAngle4f;
import org.joml.Matrix4f;
import org.joml.Vector3f;
import org.joml.Vector4f;
import org.lwjgl.BufferUtils;

/**
 * Abstract base class for a renderable 3D triangle mesh.
 *
 * <p>Subclasses supply per-vertex 3D positions, normals and UV texture coordinates, plus the
 * triangle indices, through the {@code initialize*} methods. {@link #iniatialise()} uploads that
 * data into OpenGL buffer objects recorded in a vertex-array object, and {@link #render(Camera)}
 * draws the mesh with its {@link ShaderProgram}.
 *
 * <p>The model matrix is built from a position, a rotation (angles passed to JOML's
 * {@code rotateAffineXYZ}) and a per-axis scale, and is recomputed lazily after one of them changes.
 * All GL-touching methods must be called on a thread with a current OpenGL context.
 */
public abstract class Mesh {

	// shape/rendering properties

	/** OpenGL name of the vertex-array object holding this mesh's attribute and index buffers. */
	protected int vertexArrayObj;

	/**
	 * Number of vertex INDICES passed to glDrawElements (three per triangle with GL_TRIANGLES).
	 * The name is historical and kept for compatibility with subclasses.
	 */
	protected int no_of_triangles;

	/** Shader program used to draw this mesh. */
	protected ShaderProgram shaders;

	// time for animated meshes

	/** Monotonic timestamp (System.nanoTime) taken when this mesh was constructed. */
	private final long startTime = System.nanoTime();

	/** Seconds elapsed since construction; refreshed at the start of every render() call. */
	protected float runTimeInSeconds;

	/**
	 * Scratch buffer reused for every 4x4 matrix uniform upload. Allocating a new direct buffer per
	 * uniform per frame is wasteful; glUniformMatrix4fv copies the data during the call, so one
	 * buffer can safely be reused.
	 */
	private final FloatBuffer matrixBuffer = BufferUtils.createFloatBuffer(16);

	// object location/rotation/scale
	protected Vector3f position = new Vector3f(0, 0, 0);
	protected Vector3f scale = new Vector3f(1, 1, 1);
	protected Vector3f rotation = new Vector3f(0, 0, 0);
	/** Set whenever position/rotation/scale change; render() then rebuilds {@link #modelMatrix}. */
	protected boolean modelNeedsRecompute = true;
	/** Model (object-to-world) matrix: translate * rotate * scale. */
	protected Matrix4f modelMatrix;

	// abstract methods that all subclasses should implement

	/** @return flat array of vertex positions, three floats (x, y, z) per vertex */
	abstract float[] initializeVertexPositions();

	/** @return flat array of vertex indices, three per triangle (drawn as GL_TRIANGLES) */
	abstract int[] initializeVertexIndices();

	/** @return flat array of vertex normals, three floats per vertex */
	abstract float[] initializeVertexNormals();

	/** @return flat array of UV texture coordinates, two floats per vertex */
	abstract float[] initializeTextureCoordinates();

	/**
	 * Allow the subclass to change the rendering state before drawing (e.g. bind textures).
	 * Called after the shader program is in use and the standard uniforms are set.
	 * @param c camera the mesh is being rendered from
	 */
	abstract void preRender(Camera c);

	/**
	 * Allow the subclass to reset the rendering state after drawing (e.g. unbind textures).
	 * @param c camera the mesh was rendered from
	 */
	abstract void postRender(Camera c);

	/**
	 * Create a mesh object. No OpenGL resources are created until {@link #iniatialise()} is called.
	 * @param drawShaders shader program used to draw this mesh
	 */
	public Mesh(ShaderProgram drawShaders) {
		shaders = drawShaders;
	}

	/**
	 * Initialise the mesh: fetch the geometry from the subclass and upload it to OpenGL.
	 * Make sure this is called before you start using the mesh. (The method name's spelling is
	 * kept for compatibility with existing callers.)
	 */
	public void iniatialise() {
		float[] vertPositions = initializeVertexPositions();
		int[] indices = initializeVertexIndices();
		float[] vertNormals = initializeVertexNormals();
		float[] textureCoordinates = initializeTextureCoordinates();
		no_of_triangles = indices.length; // index count, as required by glDrawElements

		loadDataOntoGPU(vertPositions, indices, vertNormals, textureCoordinates);
	}

	/**
	 * Draw the mesh from a camera's perspective.
	 *
	 * <p>Uploads {@code mvp_matrix} and, if the shader declares them, {@code m_matrix},
	 * {@code camera_location} and {@code clip_plane}. It then gives the subclass a chance to adjust
	 * state via {@link #preRender(Camera)}, draws the indexed triangles, and calls
	 * {@link #postRender(Camera)}.
	 * @param camera camera supplying the view/projection matrices, position and clip plane
	 */
	public void render(Camera camera) {
		// If shaders modified on disk, reload them
		shaders.reloadIfNeeded();
		shaders.startUsing();

		// update time (monotonic clock, so wall-clock adjustments cannot make it jump)
		runTimeInSeconds = (float) ((System.nanoTime() - startTime) / 1e9);

		// if the object moved, we may need to recompute an updated model matrix
		if (modelNeedsRecompute) {
			recomputeModel();
		}

		// Uniform locations are looked up every frame rather than cached, because the program may
		// have just been reloaded above (presumably with new locations).
		int programHandle = shaders.getHandle();

		// compute and upload MVP = projection * view * model
		int mvp_location = glGetUniformLocation(programHandle, "mvp_matrix");
		if (mvp_location >= 0) {
			Matrix4f mvp_matrix = new Matrix4f(camera.getProjectionMatrix())
					.mul(camera.getViewMatrix())
					.mul(modelMatrix);
			uploadMatrix(mvp_location, mvp_matrix);
		}

		// if shader has an m_matrix uniform, send it the model matrix
		int m_location = glGetUniformLocation(programHandle, "m_matrix");
		if (m_location >= 0) {
			uploadMatrix(m_location, modelMatrix);
		}

		// if shader has a camera_location uniform, send it the camera location
		int camera_location = glGetUniformLocation(programHandle, "camera_location");
		if (camera_location >= 0) {
			Vector3f cameraPos = camera.getCameraPosition();
			glUniform3f(camera_location, cameraPos.x, cameraPos.y, cameraPos.z);
		}

		// if shader has a clip_plane uniform, send it the clipping plane info
		int clip_location = glGetUniformLocation(programHandle, "clip_plane");
		if (clip_location >= 0) {
			Vector4f clipPlane = camera.getClipPlane();
			glUniform4f(clip_location, clipPlane.x, clipPlane.y, clipPlane.z, clipPlane.w);
		}

		// subclass's chance to modify the render state
		preRender(camera);

		// draw
		glBindVertexArray(vertexArrayObj); // Bind the existing VertexArray object
		glDrawElements(GL_TRIANGLES, no_of_triangles, GL_UNSIGNED_INT, 0); // Draw it as triangles
		glBindVertexArray(0);              // Remove the binding

		// subclass's chance to undo any changes it made
		postRender(camera);
	}

	/**
	 * Upload a 4x4 matrix (column-major, not transposed) to the uniform at {@code location} of the
	 * program currently in use, via the shared scratch buffer.
	 */
	private void uploadMatrix(int location, Matrix4f matrix) {
		// JOML writes 16 floats at the buffer's position without advancing it, so the buffer is
		// always ready to be read from index 0.
		matrix.get(matrixBuffer);
		glUniformMatrix4fv(location, false, matrixBuffer);
	}

	/**
	 * Move the data from Java arrays to OpenGL buffers (these are most likely on the GPU) and
	 * record the attribute/index bindings in a new vertex-array object, stored in
	 * {@link #vertexArrayObj}. The VAO and GL_ARRAY_BUFFER are unbound on return.
	 * @param vertPositions vertex positions, 3 floats per vertex (shader attribute "position")
	 * @param indices triangle vertex indices, 3 per triangle
	 * @param vertNormals vertex normals, 3 floats per vertex (shader attribute "normal")
	 * @param textureCoordinates UV coordinates, 2 floats per vertex (shader attribute "texcoord")
	 */
	protected void loadDataOntoGPU(float[] vertPositions, int[] indices, float[] vertNormals, float[] textureCoordinates) {
		int shaders_handle = shaders.getHandle();

		vertexArrayObj = glGenVertexArrays(); // Get an OGL "name" for a vertex-array object
		glBindVertexArray(vertexArrayObj);    // Create a new vertex-array object with that name

		// Per-vertex attributes: each gets its own buffer object
		loadVertexAttribute(shaders_handle, "position", vertPositions, 3);
		loadVertexAttribute(shaders_handle, "normal", vertNormals, 3);

		// Vertex indices: the GL_ELEMENT_ARRAY_BUFFER binding is stored in the bound VAO
		IntBuffer index_buffer = BufferUtils.createIntBuffer(indices.length);
		index_buffer.put(indices).flip(); // flip: switch from writing (put) to reading from index 0
		int index_handle = glGenBuffers();
		glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, index_handle);
		glBufferData(GL_ELEMENT_ARRAY_BUFFER, index_buffer, GL_STATIC_DRAW);

		loadVertexAttribute(shaders_handle, "texcoord", textureCoordinates, 2);

		// Unbind the VAO so later GL_ELEMENT_ARRAY_BUFFER binds elsewhere cannot overwrite the index
		// buffer recorded in it. GL_ARRAY_BUFFER is not VAO state (the attribute pointers already
		// captured their buffers), so it can be unbound too.
		glBindVertexArray(0);
		glBindBuffer(GL_ARRAY_BUFFER, 0);
	}

	/**
	 * Upload one per-vertex float attribute into a new GL_ARRAY_BUFFER. If the shader has an active
	 * attribute called {@code attributeName}, point it at that buffer and enable it. Must be called
	 * while the target vertex-array object is bound.
	 * @param shaderHandle program to query for the attribute location
	 * @param attributeName name of the vertex attribute in the shader
	 * @param data tightly packed attribute values
	 * @param componentsPerVertex number of floats per vertex (e.g. 3 for xyz, 2 for uv)
	 */
	private void loadVertexAttribute(int shaderHandle, String attributeName, float[] data, int componentsPerVertex) {
		// Construct the buffer in CPU memory
		FloatBuffer cpu_buffer = BufferUtils.createFloatBuffer(data.length);
		cpu_buffer.put(data).flip(); // flip: switch from writing (put) to reading from index 0

		int buffer_handle = glGenBuffers();                           // Get an OGL name for a buffer object
		glBindBuffer(GL_ARRAY_BUFFER, buffer_handle);                 // Bring that buffer object into existence
		glBufferData(GL_ARRAY_BUFFER, cpu_buffer, GL_STATIC_DRAW);    // Load it with data

		// A location of -1 means the attribute is not active in the shader, so it must not be used
		int location = glGetAttribLocation(shaderHandle, attributeName);
		if (location != -1) {
			// Source the attribute from the currently bound GL_ARRAY_BUFFER, tightly packed floats
			glVertexAttribPointer(location, componentsPerVertex, GL_FLOAT, false, 0, 0);
			glEnableVertexAttribArray(location);
		}
	}

	// setters and getters for location/rotation and scale.
	// Setters store a copy so later mutation of the caller's vector cannot silently change the
	// mesh without flagging a model-matrix recompute; getters likewise return copies.

	public void setScale(Vector3f scale) {
		this.scale = new Vector3f(scale);
		modelNeedsRecompute = true;
	}

	public void setPosition(Vector3f position) {
		this.position = new Vector3f(position);
		modelNeedsRecompute = true;
	}

	public void setRotation(Vector3f rotation) {
		this.rotation = new Vector3f(rotation);
		modelNeedsRecompute = true;
	}

	/** @return a copy of the current position */
	public Vector3f getPosition() {
		return new Vector3f(position);
	}

	/** @return a copy of the current rotation angles */
	public Vector3f getRotation() {
		return new Vector3f(rotation);
	}

	/**
	 * Rebuild {@link #modelMatrix} as translate * rotate * scale: vertices are scaled first, then
	 * rotated (rotateAffineXYZ with the x, y, z angles), then translated to {@link #position}.
	 */
	protected void recomputeModel() {
		modelMatrix = new Matrix4f()
				.translate(position)
				.rotateAffineXYZ(rotation.x, rotation.y, rotation.z)
				.scale(scale);
		modelNeedsRecompute = false;
	}
}
