//!
//! \file RayTracer.cpp
//! \brief Implementation of the RayTracer class
//!
//! \author Austin McGee amcgee@digipen.edu
//! \date February 17, 2006
//!
//! Copyright 2006 DigiPen (USA) Corporation, all rights reserved.
//!
////////////////////////////////////////////////////////////////////////////////

#include "RayTracer.h"
#include <assert.h>
#include "Sphere.h"
#include <math.h>
#include "resource.h"

// The single, program-wide RayTracer instance; starts out null and is
// presumably created/assigned by the application entry point (not shown here).
RayTracer *WINDOW = NULL;

/*------------------------------------------------------------------------------
RayTracer
------------------------------------------------------------------------------*/
/*!	Default constructor.  Sets up a default camera, the window class/title
	strings and the windowed-mode size.

	Handles and pointers that are otherwise only filled in by InitWindow,
	CreateBackBuffer, Render and Clear are given defined (null/zero) values
	here, so that code running before InitWindow -- e.g. the destructor, which
	reads windowClass_.hInstance -- never reads indeterminate values.
*/
RayTracer::RayTracer( void )
: camera_( Vector3( 0.f, 0.f, 0.f ), Vector3( 1.f, 0.f, 0.f ), Vector3( 0.f, 0.f, 1.f ), Vector3( 0.f, 1.f, 0.f ) ),
	redraw_( false ), drawing_( false )
{
	className_ = L"Austin McGee Class";
	windowName_ = L"CS 400 Assignment 2";

	// window / GDI state, created later by InitWindow and CreateBackBuffer
	window_ = NULL;
	windowClass_.hInstance = NULL;	// read by the destructor's UnregisterClass
	backBuffer_ = NULL;
	backBitmap_ = NULL;
	savedBitmap_ = NULL;
	basePtr_ = 0;

	// progressive-render cursor (Render resets these when a new pass starts)
	iStep_ = -1.f;
	iWidth_ = 0;

	// no ambient light until a scene sets one
	ambient_ = Vector3::Zero;

	SetFullscreen( false );
}

/*------------------------------------------------------------------------------
~RayTracer
------------------------------------------------------------------------------*/
/*!	Destructor.  Unregisters the window class (the window itself is not
	destroyed here), deletes the scene objects via Clear, and releases the
	GDI back buffer built by CreateBackBuffer.

	NOTE(review): assumes InitWindow has run (or the constructor nulled the
	handles), since windowClass_ and backBuffer_ are read unconditionally.
*/
RayTracer::~RayTracer( void )
{
	UnregisterClass( className_, windowClass_.hInstance );
	Clear();

	// free the back buffer: put the memory DC's original bitmap back first,
	// then delete our DIB section and the DC itself
	if ( backBuffer_ )
	{
		::SelectObject( backBuffer_, savedBitmap_ );
		::DeleteObject( backBitmap_ );
		::DeleteDC( backBuffer_ );
		backBuffer_ = NULL;
	}
}

/*------------------------------------------------------------------------------
InitWindow
------------------------------------------------------------------------------*/
/*!	Fills in and registers the window class, creates the application window
	using the current style_/width_/height_, and then builds the back buffer
	for it.  Anything that uses window_ (SetCamera, Render) needs this first.

	\param DefaultWindowProc	message handler installed as the class's window procedure
*/
void RayTracer::InitWindow( LRESULT CALLBACK DefaultWindowProc( HWND, UINT, WPARAM, LPARAM ) )
{
	// identity of the class: owning module, public name and message handler
	windowClass_.hInstance      = GetModuleHandle( NULL );
	windowClass_.lpszClassName  = className_;
	windowClass_.lpfnWndProc    = DefaultWindowProc;

	// appearance: stock icon and arrow cursor, black background, menu from resources
	windowClass_.hIcon          = LoadIcon( NULL, IDI_APPLICATION );
	windowClass_.hCursor        = LoadCursor( NULL, IDC_ARROW );
	windowClass_.hbrBackground  = static_cast<HBRUSH>( GetStockObject( BLACK_BRUSH ) );
	windowClass_.lpszMenuName   = MAKEINTRESOURCE( IDR_MENU1 );

	// behaviour: each window gets its own DC; no extra class/window storage
	windowClass_.style          = CS_OWNDC;
	windowClass_.cbClsExtra     = 0;
	windowClass_.cbWndExtra     = 0;

	if ( RegisterClass( &windowClass_ ) == 0 )
	{
		// TODO: class registration failed -- handle error message here
	}

	window_ = CreateWindow(	className_, windowName_, style_,
							40, 40,						// top-left corner of the window
							width_, height_,			// outer size of the window
							NULL,						// no parent window
							NULL,						// no menu handle: the class menu is used
							windowClass_.hInstance,
							NULL						// no creation data
							);
	if ( window_ == NULL )
	{
		// TODO: window creation failed -- handle error message here
	}

	CreateBackBuffer();
}

/*------------------------------------------------------------------------------
SetFullscreen
------------------------------------------------------------------------------*/
/*!	Records the requested display mode and sets style_, width_ and height_ to
	that mode's defaults.  The window is not created or resized here; these
	values are consumed by InitWindow (and by SetCamera after a Clear).

	\param value true for a borderless window covering the primary screen,
	             false for a 500-pixel-wide framed window (eHeight_ is added to
	             the height -- presumably the non-client/menu allowance)
*/
void RayTracer::SetFullscreen( bool value )
{
	fullscreen_ = value;

	if ( !value )
	{
		style_ = WS_OVERLAPPED | WS_SYSMENU | WS_VISIBLE;
		width_ = 500;
		height_ = 500 + eHeight_;
		return;
	}

	style_ = WS_POPUP | WS_VISIBLE;
	width_ = GetSystemMetrics( SM_CXSCREEN );
	height_ = GetSystemMetrics( SM_CYSCREEN );
}

/*------------------------------------------------------------------------------
CreateBackBuffer
------------------------------------------------------------------------------*/
/*  Creates the backbuffer for the scene: a 32 bits-per-pixel, bottom-up DIB
	section sized to the window's client area and selected into a memory DC.
	On return width_/height_ hold the back-buffer size and basePtr_ points at
	its pixels.

	If the client rectangle cannot be read (e.g. window_ is invalid) the
	buffer is sized 0x0 instead of from an uninitialized RECT.  If the DIB
	section cannot be created, width_/height_ are set to 0 and basePtr_ to
	null rather than being left stale.

	Note: a previously created back buffer is not released here; callers that
	rebuild it are responsible for freeing the old DC and bitmap.
*/
void RayTracer::CreateBackBuffer( void )
{
	// GetClientRect leaves the rectangle undefined on failure, so start from
	// an empty rectangle and reset it if the call fails
	RECT rect = { 0, 0, 0, 0 };
	if ( !GetClientRect( window_, &rect ) )
	{
		rect.left = rect.top = rect.right = rect.bottom = 0;
	}

	// describe a new bitmap: positive height => bottom-up row order
	BITMAPINFOHEADER bitmapInfo = { 0 };
	bitmapInfo.biSize = sizeof(BITMAPINFOHEADER);
	bitmapInfo.biBitCount = 32;
	bitmapInfo.biCompression = BI_RGB;
	bitmapInfo.biWidth = rect.right;
	bitmapInfo.biHeight = rect.bottom;
	bitmapInfo.biPlanes = 1;

	// a temporary display DC makes the memory DC compatible with the screen
	HDC tempDC = CreateDC(L"Display", 0, 0, 0);
	backBuffer_ = CreateCompatibleDC(tempDC);

	backBitmap_ = CreateDIBSection(backBuffer_, reinterpret_cast<BITMAPINFO *>(&bitmapInfo), 
		DIB_RGB_COLORS, reinterpret_cast<void **>(&basePtr_), 0, 0);

	if ( backBitmap_ )
	{
		savedBitmap_ = static_cast<HBITMAP>(::SelectObject(backBuffer_, backBitmap_));
		width_ = rect.right;
		height_ = rect.bottom;
	}
	else
	{
		// no pixel storage: report an empty buffer and a null pixel pointer
		// instead of keeping stale values
		savedBitmap_ = NULL;
		basePtr_ = 0;
		width_ = 0;
		height_ = 0;
	}

	// Cleanup
	::DeleteDC(tempDC);
}

/*------------------------------------------------------------------------------
SetPixel
------------------------------------------------------------------------------*/
/*  Writes one 32-bit color value into the back buffer.  Rows are stored
	width_ pixels apart; with the bottom-up DIB built by CreateBackBuffer,
	row 0 is the bottom of the image.  The bounds are only checked by an
	assert, i.e. in debug builds.

	param: x column, 0 <= x < width_
	param: y row, 0 <= y < height_
	param: color packed pixel value to store
*/
void RayTracer::SetPixel(int x, int y, int color)
{
	assert( x >= 0 && x < width_ && y >= 0 && y < height_ );

	const int offset = y * width_ + x;
	basePtr_[ offset ] = color;
}

/*------------------------------------------------------------------------------
GetReflectionCoefficient
------------------------------------------------------------------------------*/
/*  Fresnel reflectance for unpolarized light: the average of the squared
	perpendicular and parallel amplitude coefficients.  For details on the
	algorithm, read Hanson's CS 400 notes.

	This function does not detect total internal reflection; the caller must
	only call it when cos(theta_t) is real (RayCast keeps the coefficient at 1
	in the TIR case instead).

	param: n relative index of refraction ni / nt
	param: dot cos( theta_i ), the dot product of the incident and normal vectors
	param: square cos( theta_t ), already square-rooted

	return: the reflected fraction of the light
*/
float RayTracer::GetReflectionCoefficient( const float n, const float dot, const float square )
{
	const float nCosI = n * dot;
	const float nCosT = n * square;

	// perpendicular (s) and parallel (p) amplitude ratios
	const float rPerp = ( nCosI - square ) / ( nCosI + square );
	const float rPar  = ( dot - nCosT ) / ( dot + nCosT );

	return 0.5f * ( rPerp * rPerp + rPar * rPar );
}

/*------------------------------------------------------------------------------
GetTransmissionVector
------------------------------------------------------------------------------*/
/*  Computes the direction of the refracted (transmitted) ray -- a vector, not
	a color.  For details on the algorithm, read Hanson's CS 400 notes.

	The result is (n * cos(theta_i) -/+ cos(theta_t)) * normal - n * incident,
	where cos(theta_t) is subtracted when dot >= 0 and added otherwise.

	param: incident unit vector pointing back along the incoming ray
	param: normal the surface normal
	param: n the relative index of refraction ni / nt
	param: dot the dot product of the incident and normal vectors
	param: square cos( theta_t )

	return: the transmission direction
*/
Vector3 RayTracer::GetTransmissionVector( const Vector3 &incident, const Vector3 &normal, const float n, const float dot, const float square )
{
	// the cos(theta_t) term takes the opposite sign of the incident-side cosine
	const float cosTerm = ( dot >= 0.f ) ? -square : square;

	return ( cosTerm + n * dot ) * normal - n * incident;
}

/*------------------------------------------------------------------------------
GetCollision
------------------------------------------------------------------------------*/
/*  Tests the ray against every object in the scene and finds the nearest
	non-negative hit time.

	param: ray the ray
	param: objHit set to the index of the closest object hit; left untouched
	              when nothing is hit

	return: the time in which the closest object was hit, -1 for no hits
*/
float RayTracer::GetCollision( const Ray &ray, unsigned &objHit )
{
	float closest = -1.f;

	for ( unsigned index = 0; index < objectList_.size(); ++index )
	{
		float t = objectList_[index]->IntersectionTest( ray );

		// negative (or NaN) results are misses
		if ( !( t >= 0.f ) )
			continue;

		if ( closest == -1.f || t < closest )
		{
			closest = t;
			objHit = index;
		}
	}

	return closest;
}

/*------------------------------------------------------------------------------
RayCast
------------------------------------------------------------------------------*/
/*  Casts the specified ray into the scene and returns the color it sees: the
	local illumination at the closest hit plus, recursively, the reflected and
	transmitted rays weighted by their Fresnel coefficients (scaled by the
	object's sc_ factor).

	The medium model has two levels.  ni == 1 means the ray travels in air and
	is entering the hit object (nt = sqrt( er_ * ur_ )).  Any other ni means
	the ray is inside an object and leaves into air (nt = 1), and the normal
	is flipped.  Nested or overlapping media are not modeled.

	param: ray the ray
	param: depth remaining recursion depth; a negative value returns black
	param: ni index of refraction of the medium the ray currently travels in

	return: the color for that ray trace (black when nothing is hit)
*/
Vector3 RayTracer::RayCast( const Ray &ray, int depth, float ni )
{
	// recursion budget exhausted: contribute nothing
	if ( depth < 0 )
	{
		return Vector3::Zero;
	}

	// find the closest object along the ray
	unsigned objHit;
	float time = GetCollision( ray, objHit );

	Vector3 color = Vector3::Zero;

	// only shade when something was hit
	if ( time >= 0.f )
	{
		Vector3 intersection = ray.pos + time * ray.dir;
		// NOTE(review): presumably GetNormal() reports the normal from the most
		// recent IntersectionTest on this object (the one GetCollision just ran)
		Vector3 normal = objectList_[objHit]->GetNormal();

		// shadow-ray origin, nudged off the surface; computed before the flip
		// below, so it always uses the normal exactly as GetNormal() returned it
		Vector3 shadowInt = intersection + normal * 0.001f;

		// pick the index of refraction on the far side of the surface
		float nt = 1.f;
		// offset used to push reflected/transmitted ray origins off the surface
		float intDir = 0.001f;
		if ( ni == 1.f )
		{
			// in air: entering the object
			nt = sqrt( objectList_[objHit]->er_ * objectList_[objHit]->ur_ );
		}
		else
		{
			// inside an object: leaving into air (nt stays 1), flip the normal
			normal *= -1.f;
		}

		// unit vector pointing back along the ray, away from the surface
		Vector3 incident = (-ray.dir).Normal();

		// relative index, |cos(theta_i)| and cos^2(theta_t) from Snell's law
		float indexOfRefraction = ni / nt;
		float dot = fabs( incident * normal );
		float square = 1.f - indexOfRefraction * indexOfRefraction * ( 1.f - dot * dot );

		float reflectCoeff = 1.f;

		// square <= 0 is total internal reflection: reflectCoeff stays 1, so
		// transCoeff becomes 0 and the (un-rooted) square is never used below
		if ( square > 0.f )
		{
			square = sqrt( square );
			reflectCoeff = GetReflectionCoefficient( indexOfRefraction, dot, square );
		}
		float transCoeff = 1.f - reflectCoeff;

		// multiply in the object's light-loss factor
		reflectCoeff *= objectList_[objHit]->sc_;
		transCoeff *= objectList_[objHit]->sc_;

		// ambient + diffuse + specular at this hit (specular scaled by reflectCoeff)
		color = GetLocalIllumination( incident, shadowInt, objHit, normal, reflectCoeff );

		// trace the mirror reflection only when its weight is at least 1/255
		// (roughly one 8-bit color step); it stays in the current medium ni
		if ( reflectCoeff >= 1.f / 255.f )
		{
			Ray reflect( intersection + normal * intDir, 2.f * ( incident * normal ) * normal - incident );
			color += reflectCoeff * RayCast( reflect, depth - 1, ni );
		}

		// trace the refracted ray only when its weight is at least 1/255; it
		// starts just behind the surface and continues in the new medium nt
		if ( transCoeff >= 1.f / 255.f )
		{
			Ray transmission( intersection + normal * -intDir, GetTransmissionVector( incident, normal, indexOfRefraction, dot, square ) );
			color += transCoeff * RayCast( transmission, depth - 1, nt );
		}
	}

	return color;
}

/*------------------------------------------------------------------------------
GetLocalIllumination
------------------------------------------------------------------------------*/
/*  Computes the direct lighting at a surface point.  It sums an ambient term
	with a diffuse term and a Phong specular term for every light whose shadow
	ray is not blocked.

	param: incident unit vector pointing from the surface back along the ray
	param: shadowInt point just off the surface where shadow rays start
	param: objHit index into objectList_ of the surface being lit
	param: normal the normal used for lighting
	param: reflectCoeff scale applied to the specular highlight

	return: ambient + diffuse + specular color for this point
*/
Vector3 RayTracer::GetLocalIllumination( const Vector3 &incident, const Vector3 &shadowInt, const unsigned objHit, const Vector3 &normal, const float reflectCoeff )
{
	// ambient: the surface's diffuse color modulated component-wise by the scene ambient
	Vector3 result = objectList_[objHit]->diffuse_;
	result.x *= ambient_.x;
	result.y *= ambient_.y;
	result.z *= ambient_.z;

	for ( unsigned i = 0; i < lightList_.size(); ++i )
	{
		// vector from the surface to the light, deliberately left unnormalized
		Vector3 toLight = ( lightList_[i].pos - shadowInt );

		// shadow test: presumably IntersectionTest returns a parameter along the
		// unnormalized direction, so t in [0, 1] is a blocker before the light
		unsigned blocker;
		float blockTime = GetCollision( Ray( shadowInt, toLight ), blocker );
		bool inShadow = ( blockTime >= 0.f && blockTime <= 1.f );
		if ( inShadow )
			continue;

		toLight.Normalize();

		// diffuse: surface color * |n . l| * light color (two-sided)
		float cosine = fabs( normal * toLight );
		Vector3 diffuse = objectList_[objHit]->diffuse_ * cosine;
		diffuse.x *= lightList_[i].rgb.x;
		diffuse.y *= lightList_[i].rgb.y;
		diffuse.z *= lightList_[i].rgb.z;

		// specular: mirror the light direction about the normal, raise its
		// alignment with the view direction to the object's exponent
		Vector3 mirrored = 2.f * ( toLight * normal ) * normal - toLight;
		Vector3 specular = Vector3::Zero;
		if ( mirrored * incident > 0.f )
			specular = reflectCoeff * pow( mirrored * incident, objectList_[objHit]->se_ ) * lightList_[i].rgb;

		result += diffuse + specular;
	}

	return result;
}

/*------------------------------------------------------------------------------
Render
------------------------------------------------------------------------------*/
/*  Renders the scene (if necessary) and copies the back buffer to the window.

	The image is traced progressively: each call traces one column (every row
	of column iWidth_), then blits, so the picture fills in left to right over
	successive calls.  Rays go through pixel centres on the view plane
	c + s * u + t * v, with s, t in (-1, 1), from the point c + e.

	Integer pixel counters drive the loops, so exactly width_ columns and
	height_ rows are written and SetPixel is never called outside the buffer.
	Float-stepped loops ( j += 2 / height_ while j <= 1 ) can produce one
	extra row and column.  A zero-sized buffer finishes the pass at once.
*/
void RayTracer::Render( void )
{
	// make sure we need to redraw the scene in the first place
	if ( redraw_ )
	{
		// if we're not in the middle of drawing, reset everything
		if ( !drawing_ )
		{
			iStep_ = -1.f;
			iWidth_ = 0;
			BitBlt(backBuffer_, 0, 0, width_, height_, NULL, 0, 0, BLACKNESS);
			drawing_ = true;
		}

		// trace the current column, if any are left
		if ( iWidth_ < width_ && height_ > 0 )
		{
			Vector3 camPos = camera_.c + camera_.e;

			// horizontal view-plane coordinate of this column's pixel centre
			iStep_ = -1.f + ( 2.f * iWidth_ + 1.f ) / width_;

			for ( int row = 0; row < height_; ++row )
			{
				// vertical view-plane coordinate of this row's pixel centre
				float j = -1.f + ( 2.f * row + 1.f ) / height_;
				Vector3 pos = camera_.c + iStep_ * camera_.u + j * camera_.v;

				// cast the ray out into the scene and find out what color
				// we want for this particular pixel
				Ray ray( pos, Normal(pos - camPos) );
				Vector3 color = RayCast( ray, 10, 1.f );
				SetPixel( iWidth_, row, color.Color() );
			}

			// advance to the next column
			++ iWidth_;
		}

		// end the drawing session once every column is done (or nothing can be drawn)
		if ( iWidth_ >= width_ || height_ <= 0 )
		{
			redraw_ = false;
			drawing_ = false;
		}
	}

	HDC temp = GetDC( window_ );
	BitBlt(temp, 0, 0, width_, height_, backBuffer_, 0, 0, SRCCOPY);
	ReleaseDC( window_, temp);
}

/*------------------------------------------------------------------------------
SetCamera
------------------------------------------------------------------------------*/
/*  Sets the camera for the scene.  Resizes the window based off of the
	camera ratio so the scene isn't skewed.

	param: c center of view plane
	param: u horizontal view vector
	param: v vertical view vector
	param: e eye of the camera
*/
void RayTracer::SetCamera( const Vector3 &c, const Vector3 &u, const Vector3 &v, const Vector3 &e )
{
	camera_ = Camera( c, u, v, e );

	float camRatio = u.Length() / v.Length();

	width_ = static_cast<int>( width_ * camRatio ) + 1;
	SetWindowPos( window_, HWND_TOP, 40, 40, width_ + eWidth_, height_ + eHeight_, 0 );
	CreateBackBuffer();
}

/*------------------------------------------------------------------------------
Clear
------------------------------------------------------------------------------*/
/*  Resets the scene.  It deletes every object, forgets all lights and zeroes
	the ambient light.  It also schedules a fresh render from the first
	column, and resets style_/width_/height_ to the defaults for the current
	display mode.  The window itself is not resized here.

	NOTE(review): after this, width_/height_ describe the default window size
	rather than the back buffer CreateBackBuffer built; presumably SetCamera
	(which rebuilds the back buffer) always runs before the next Render.
*/
void RayTracer::Clear( void )
{
	// the scene owns its objects; delete on a null pointer is a no-op
	for ( unsigned index = 0; index < objectList_.size(); ++index )
		delete objectList_[index];
	objectList_.clear();

	lightList_.clear();
	ambient_ = Vector3::Zero;

	// Render restarts from column 0 whenever drawing_ is false
	drawing_ = false;
	redraw_ = true;

	SetFullscreen( fullscreen_ );
}
