package uk.ac.cam.cl.gfxintro.gd355.tick2;

import static org.lwjgl.opengl.GL11.GL_TEXTURE_2D;
import static org.lwjgl.opengl.GL11.glBindTexture;
import static org.lwjgl.opengl.GL13.*;
import static org.lwjgl.opengl.GL20.*;

import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;

import javax.imageio.ImageIO;

import org.joml.Vector3f;
import org.lwjgl.opengl.GL13;

/**
 * A terrain mesh whose geometry is generated from a heightmap image and which
 * is drawn with a single 2D texture.
 *
 * <p>The heightmap is stored as {@code heightmap[row][col]}, where {@code row}
 * indexes image lines (mapped to the z axis) and {@code col} indexes pixels
 * within a line (mapped to the x axis). Every per-vertex array produced by this
 * class uses the same row-major vertex numbering,
 * {@code vertex = row * width + col}, so positions, normals, texture
 * coordinates and indices all agree, including for non-square heightmaps.
 */
public class Terrain extends Mesh {

    /** Factor applied to the stored height samples to obtain the vertex y coordinate. */
    private static final float HEIGHTMAP_SCALE = 3.0f;

    /** Extent from which the grid spacing along x and z is derived. */
    private static final float MAP_SIZE = 10;

    // Filenames for vertex and fragment shader source code
    private static final String VSHADER_FN = "resources/vertex_shader.glsl";
    private static final String FSHADER_FN = "resources/fragment_shader.glsl";

    /** Height samples indexed as {@code [row][col]}; filled by {@link #initializeHeightmap}. */
    private float[][] heightmap;

    /** Texture bound while the terrain is drawn. */
    private final Texture terrainTexture;

    /**
     * Creates the terrain's shader program and loads the heightmap.
     *
     * @param heightmapFilename path of the heightmap image to load
     * @param texture           texture bound in {@link #preRender(Camera)}
     * @throws RuntimeException if the heightmap cannot be read or decoded
     */
    public Terrain(String heightmapFilename, Texture texture)
    {
        super(new ShaderProgram(new Shader(GL_VERTEX_SHADER, VSHADER_FN), new Shader(GL_FRAGMENT_SHADER, FSHADER_FN), "colour"));
        terrainTexture = texture;
        initializeHeightmap(heightmapFilename);
    }

    /**
     * Binds the terrain texture to texture unit 0 and sets the {@code tex}
     * uniform to that unit.
     *
     * @param c camera for the current frame (not used here)
     */
    @Override
    void preRender(Camera c) {
        // bind terrain texture before we draw
        glActiveTexture(GL_TEXTURE0);
        glBindTexture(GL_TEXTURE_2D, terrainTexture.getTexId());

        // NOTE(review): glUniform1i targets the program currently in use; this
        // assumes the caller has already activated `shaders` - confirm in Mesh.
        int textureLocation = glGetUniformLocation(shaders.getHandle(), "tex");
        glUniform1i(textureLocation, 0);
    }

    /**
     * Unbinds the 2D texture again once drawing has finished.
     *
     * @param c camera for the current frame (not used here)
     */
    @Override
    void postRender(Camera c) {
        // unbind texture
        glBindTexture(GL_TEXTURE_2D, 0);
    }

    /**
     * Loads the heightmap image and converts it to height samples.
     *
     * <p>Only the lowest 8 bits of each pixel's packed RGB value are used. The
     * sample is normalised to [0, 1] and raised to the power 2.2 before being
     * stored.
     *
     * @param heightmapFilename path of the image file to read
     * @throws RuntimeException if the file cannot be read, or if no registered
     *                          image reader can decode it
     */
    public void initializeHeightmap(String heightmapFilename) {
        BufferedImage heightmapImg;
        try {
            heightmapImg = ImageIO.read(new File(heightmapFilename));
        } catch (IOException e) {
            // Keep the cause so the underlying I/O problem stays visible.
            throw new RuntimeException("Error loading heightmap: " + heightmapFilename, e);
        }
        if (heightmapImg == null) {
            // ImageIO.read returns null (rather than throwing) when no reader
            // understands the file; fail clearly instead of with a later NPE.
            throw new RuntimeException(
                    "Error loading heightmap: unsupported image format: " + heightmapFilename);
        }

        int widthPx = heightmapImg.getWidth();
        int heightPx = heightmapImg.getHeight();

        float[][] samples = new float[heightPx][widthPx];
        for (int row = 0; row < heightPx; row++) {
            for (int col = 0; col < widthPx; col++) {
                float height = (float) (heightmapImg.getRGB(col, row) & 0xFF) / 0xFF;
                samples[row][col] = (float) Math.pow(height, 2.2);
            }
        }
        // Publish only once the whole map has been filled in.
        heightmap = samples;
    }

    /**
     * Generates one (x, y, z) position per heightmap sample.
     *
     * <p>x grows with the column and z with the row, both starting at
     * {@code -MAP_SIZE / 2}; y is the height sample times
     * {@link #HEIGHTMAP_SCALE}.
     *
     * @return positions packed as 3 floats per vertex in row-major vertex order
     */
    @Override
    float[] initializeVertexPositions() {
        int widthPx = heightmap[0].length;
        int heightPx = heightmap.length;

        float startX = -MAP_SIZE / 2;
        float startZ = -MAP_SIZE / 2;
        float deltaX = MAP_SIZE / widthPx;
        float deltaZ = MAP_SIZE / heightPx;

        float[] vertPositions = new float[widthPx * heightPx * 3];
        for (int row = 0; row < heightPx; row++) {
            for (int col = 0; col < widthPx; col++) {
                int base = 3 * (row * widthPx + col);
                vertPositions[base] = startX + deltaX * col;
                vertPositions[base + 1] = HEIGHTMAP_SCALE * heightmap[row][col];
                vertPositions[base + 2] = startZ + deltaZ * row;
            }
        }
        return vertPositions;
    }

    /**
     * Generates the triangle index list: two triangles for every grid cell
     * formed by four neighbouring vertices.
     *
     * @return 6 indices per cell, {@code (width - 1) * (height - 1)} cells
     */
    @Override
    int[] initializeVertexIndices() {
        int widthPx = heightmap[0].length;
        int heightPx = heightmap.length;

        int[] indices = new int[6 * (widthPx - 1) * (heightPx - 1)];

        int count = 0;
        for (int row = 0; row < heightPx - 1; row++) {
            for (int col = 0; col < widthPx - 1; col++) {
                int topLeft = widthPx * row + col;
                int bottomLeft = topLeft + widthPx;

                // first triangle of the cell
                indices[count++] = topLeft;
                indices[count++] = bottomLeft;
                indices[count++] = bottomLeft + 1;
                // second triangle of the cell
                indices[count++] = topLeft;
                indices[count++] = bottomLeft + 1;
                indices[count++] = topLeft + 1;
            }
        }
        return indices;
    }

    /**
     * Generates one normal per vertex.
     *
     * <p>Interior vertices use central differences of the scaled surface
     * height along x and z; the normal is the normalised cross product of the
     * two resulting tangents. Vertices on the border of the grid keep the
     * default normal (0, 1, 0) because they lack a neighbour on one side.
     *
     * @return normals packed as 3 floats per vertex in row-major vertex order
     */
    @Override
    float[] initializeVertexNormals() {
        int widthPx = heightmap[0].length;
        int heightPx = heightmap.length;

        float[] vertNormals = new float[3 * widthPx * heightPx];

        float deltaX = MAP_SIZE / widthPx;
        float deltaZ = MAP_SIZE / heightPx;

        // Default every normal to straight up; x and z are already zero.
        for (int i = 1; i < vertNormals.length; i += 3) {
            vertNormals[i] = 1;
        }

        for (int row = 1; row < heightPx - 1; row++) {
            for (int col = 1; col < widthPx - 1; col++) {
                // The rendered surface is HEIGHTMAP_SCALE * heightmap, so the
                // height differences must carry the same scale as the positions.
                float riseAlongX = HEIGHTMAP_SCALE * (heightmap[row][col + 1] - heightmap[row][col - 1]);
                float riseAlongZ = HEIGHTMAP_SCALE * (heightmap[row + 1][col] - heightmap[row - 1][col]);

                Vector3f tangentX = new Vector3f(deltaX * 2, riseAlongX, 0);
                Vector3f tangentZ = new Vector3f(0, riseAlongZ, deltaZ * 2);

                // Tz x Tx has a positive y component, i.e. it points upwards.
                Vector3f normal = tangentZ.cross(tangentX).normalize();

                int base = 3 * (row * widthPx + col);
                vertNormals[base] = normal.x;
                vertNormals[base + 1] = normal.y;
                vertNormals[base + 2] = normal.z;
            }
        }

        return vertNormals;
    }

    /**
     * Generates one (u, v) texture coordinate per vertex, with
     * {@code u = col / width} and {@code v = row / height}.
     *
     * @return coordinates packed as 2 floats per vertex in row-major vertex order
     */
    @Override
    float[] initializeTextureCoordinates() {
        int widthPx = heightmap[0].length;
        int heightPx = heightmap.length;

        float[] texcoords = new float[widthPx * heightPx * 2];
        // Rows are bounded by the image height, not its width.
        for (int row = 0; row < heightPx; row++) {
            for (int col = 0; col < widthPx; col++) {
                int base = 2 * (row * widthPx + col);
                texcoords[base] = col / (float) widthPx;
                texcoords[base + 1] = row / (float) heightPx;
            }
        }

        return texcoords;
    }
}
