#include "CMatrix.h"
#include "CDiagonalMatrix.h"
#include <iostream>
using namespace std;

// Default constructor: delegates to the CMatrix default constructor and
// performs no additional initialization in its own body.
CDiagonalMatrix::CDiagonalMatrix() : CMatrix() {}

// Destructor: the body is intentionally empty; this class releases nothing
// itself here.
CDiagonalMatrix::~CDiagonalMatrix() {}

// Copy constructor.
// Takes over the dimensions of `right` and allocates a fresh array of
// numberOfRows + 1 CVector slots (one per diagonal entry plus the trailing
// slot at index numberOfRows), filling each slot by CVector assignment.
CDiagonalMatrix::CDiagonalMatrix(const CDiagonalMatrix& right)
{
	numberOfRows = right.numberOfRows;
	numberOfCols = right.numberOfCols;

	const int slotCount = numberOfRows + 1;
	rows = new CVector[slotCount];

	int slot = 0;
	while (slot < slotCount) {
		rows[slot] = right.rows[slot];
		++slot;
	}
}

// Builds a row x col diagonal matrix whose diagonal entries all equal `init`.
//
// Storage layout: rows[0 .. numberOfRows-1] each hold a one-element CVector
// with the diagonal value; rows[numberOfRows] is an extra one-element CVector
// built with CVector's single-argument constructor (presumably zero-filled --
// confirm in CVector) that operator() hands out for off-diagonal positions.
//
// Throws SquareMatrixExpectedException when row != col.
CDiagonalMatrix::CDiagonalMatrix(int row, int col, double init)
					    throw (SquareMatrixExpectedException)
{
	if (row != col) {
		throw SquareMatrixExpectedException("Square Matrix expected in Diagonal Matrix generation",row,col);
	}

	numberOfRows = row;
	numberOfCols = col;

	rows = new CVector[numberOfRows + 1];

	// Every diagonal slot starts as a copy of the same one-element vector.
	CVector diagonalEntry(1, init);
	int k = 0;
	while (k < numberOfRows) {
		rows[k] = diagonalEntry;
		++k;
	}

	// Trailing slot shared by all off-diagonal positions.
	CVector offDiagonal(1);
	rows[numberOfRows] = offDiagonal;
}

// Transpose.
// A diagonal matrix equals its own transpose, so the result is simply an
// independent copy of this matrix.
//
// The copy must own its storage: pointing the result's `rows` at this
// object's array would leak the result's own allocation and leave two
// matrices sharing one array, which operator= on either of them would
// delete[] out from under the other.
//
// Returns a reference to a heap-allocated matrix; this function never frees
// it, so the caller is responsible for its lifetime.
CDiagonalMatrix& CDiagonalMatrix::operator~() const
{
	CDiagonalMatrix* transposed = new CDiagonalMatrix(*this);

	return *transposed;
}

// Column access.
// Returns a newly allocated CVector of length numberOfRows that is 0.0
// everywhere except at position `index`, which holds the diagonal element
// (index, index).
//
// Valid indices are 0 .. numberOfRows-1. The upper bound is checked with >=
// because index == numberOfRows would write past the end of the result
// vector, which has exactly numberOfRows elements.
//
// Throws NegativeIndexException for index < 0 and OutOfBoundsException for
// index >= numberOfRows.
// The returned vector is heap-allocated and is not freed by this function.
CVector& CDiagonalMatrix::operator[](int index) const
						    throw (OutOfBoundsException,NegativeIndexException)
{
	if (index < 0)
	   throw NegativeIndexException("Negative index in column vector access from Diagonal Matrix",index);
	if (index >= numberOfRows)
	   throw OutOfBoundsException("Out of bounds in column vector access from Diagonal Matrix",0,index);

	CVector* column = new CVector(numberOfRows, 0.0);

	(*column)[index] = (*this) (index, index);

	return *column;
}

// Copy assignment.
// Replaces this matrix's dimensions and storage with a copy of `right`.
//
// Self-assignment is a no-op: without the guard the old array would be
// deleted and then read back as the copy source.
// The replacement array is fully built before the old one is released, so
// this object is left untouched if allocation or element copying throws.
CDiagonalMatrix& CDiagonalMatrix::operator=(const CDiagonalMatrix& right)
{
	if (this == &right)
		return *this;

	const int slotCount = right.numberOfRows + 1;
	CVector* replacement = new CVector[slotCount];
	try {
		for (int i = 0; i < slotCount; i++)
			replacement[i] = right.rows[i];
	} catch (...) {
		// Do not leak the half-filled array if a CVector copy fails.
		delete[] replacement;
		throw;
	}

	delete[] rows;  // delete[] on a null pointer is a no-op
	rows = replacement;
	numberOfRows = right.numberOfRows;
	numberOfCols = right.numberOfCols;
	return *this;
}

// Element-wise sum of two diagonal matrices.
// Starts from a copy of this matrix and adds the matching diagonal slot of
// `right` to each of the first numberOfRows slots; the trailing off-diagonal
// slot is left as copied.
//
// Throws SizeMismatchException when the two row sizes differ.
// The result is heap-allocated and is not freed by this function.
CDiagonalMatrix& CDiagonalMatrix::operator+(const CDiagonalMatrix& right) const
							      throw (SizeMismatchException)
{
	if (getRowSize() != right.getRowSize()) {
		throw SizeMismatchException("Size Mismatch in Diagonal Matrix addition",getRowSize(),getRowSize(),right.getRowSize(),right.getRowSize());
	}

	CDiagonalMatrix* sum = new CDiagonalMatrix(*this);
	int k = 0;
	while (k < sum->numberOfRows) {
		CVector& entry = sum->rows[k];
		entry = entry + right.rows[k];
		++k;
	}
	return *sum;
}

// Element-wise difference of two diagonal matrices.
// Starts from a copy of this matrix and subtracts the matching diagonal slot
// of `right` from each of the first numberOfRows slots.
//
// Throws SizeMismatchException when the two row sizes differ.
// The result is heap-allocated and is not freed by this function.
CDiagonalMatrix& CDiagonalMatrix::operator-(const CDiagonalMatrix& right) const
							     throw (SizeMismatchException)
{
	if (getRowSize() != right.getRowSize()) {
		throw SizeMismatchException("Size mismatch in Diagonal Matrix subtraction",getRowSize(),getRowSize(),right.getRowSize(),right.getRowSize());
	}

	CDiagonalMatrix* difference = new CDiagonalMatrix(*this);
	int k = 0;
	while (k < difference->numberOfRows) {
		CVector& entry = difference->rows[k];
		entry = entry - right.rows[k];
		++k;
	}
	return *difference;
}

// Unary minus.
// Copies this matrix and replaces each of the first numberOfRows slots with
// its negation (via CVector's unary minus).
// The result is heap-allocated and is not freed by this function.
CDiagonalMatrix& CDiagonalMatrix::operator-() const
{
	CDiagonalMatrix* negated = new CDiagonalMatrix(*this);
	int k = 0;
	while (k < negated->numberOfRows) {
		CVector& entry = negated->rows[k];
		entry = -entry;
		++k;
	}
	return *negated;
}

// Product of two diagonal matrices.
// Each diagonal element of the result is the product of the corresponding
// diagonal elements of the operands; only (d, d) positions are written.
//
// Throws SizeMismatchException when the two row sizes differ.
// The result is heap-allocated and is not freed by this function.
CDiagonalMatrix& CDiagonalMatrix::operator*(const CDiagonalMatrix& right) const
							     throw (SizeMismatchException)
{
	if (getRowSize() != right.getRowSize()) {
		throw SizeMismatchException("Size mismatch in Diagonal Matrix multiplication",getRowSize(),getRowSize(),right.getRowSize(),right.getRowSize());
	}

	CDiagonalMatrix* product = new CDiagonalMatrix(numberOfRows, right.numberOfCols);
	for (int d = 0; d < product->numberOfRows; ++d) {
		const double lhs = (*this) (d, d);
		const double rhs = right(d, d);
		(*product) (d, d) = lhs * rhs;
	}

	return *product;
}

// Division by a scalar.
// Copies this matrix and divides each of the first numberOfRows slots by
// `divisor` (via CVector's operator/).
//
// Throws DivideByZeroException when divisor compares equal to zero.
// The result is heap-allocated and is not freed by this function.
CDiagonalMatrix& CDiagonalMatrix::operator/(double divisor) const
							      throw (DivideByZeroException)
{
	if (divisor == 0.0) {
		throw DivideByZeroException("Divide by zero in Diagonal Matrix division");
	}

	CDiagonalMatrix* quotient = new CDiagonalMatrix(*this);
	int k = 0;
	while (k < quotient->numberOfRows) {
		CVector& entry = quotient->rows[k];
		entry = entry / divisor;
		++k;
	}
	return *quotient;
}

// Returns a reference to a heap-allocated copy of this matrix, made with the
// copy constructor. This function does not free the copy.
CDiagonalMatrix& CDiagonalMatrix::clone() const
{
	return *new CDiagonalMatrix(*this);
}

// Element access.
// For i == j returns a reference to the stored diagonal value rows[i][0].
// For i != j returns a reference to the single trailing slot
// rows[numberOfRows][0], which every off-diagonal position shares -- writing
// through that reference therefore changes what all off-diagonal reads see.
//
// Valid indices are 0 .. getRowSize()-1. The upper bound is checked with >=
// so that (n, n) is rejected instead of silently exposing the shared
// off-diagonal slot as if it were a diagonal element.
//
// Throws NegativeIndexException when i or j is negative and
// OutOfBoundsException when i or j is >= getRowSize().
double& CDiagonalMatrix::operator()(int i, int j) const
						    throw (OutOfBoundsException,NegativeIndexException)
{
	if (i < 0)
	   throw NegativeIndexException("Negative index in Diagonal Matrix access",i);
	if (j < 0)
	   throw NegativeIndexException("Negative index in Diagonal Matrix access",j);
	if (i >= getRowSize() || j >= getRowSize())
	   throw OutOfBoundsException("Out of bounds in Diagonal Matrix access",i,j);

	if (i == j)
		return rows[i][0];

	// Off-diagonal: hand out the shared trailing slot.
	return rows[numberOfRows][0];
}

// Streams the matrix one row per line in the form "| <row vector> |",
// followed by a blank line.
// A single scratch CVector of length numberOfCols is reused for every row:
// the diagonal value is placed at position r, the vector is printed, and the
// position is reset to 0 before the next row. This relies on CVector's
// one-argument constructor producing zeros (presumably -- confirm in CVector).
ostream& operator<<(ostream& out, const CDiagonalMatrix& m)
{
	CVector rowBuffer(m.numberOfCols);

	int r = 0;
	while (r < m.numberOfRows) {
		rowBuffer[r] = m.rows[r][0];
		out << "| " << rowBuffer << " |" << endl;
		rowBuffer[r] = 0;
		++r;
	}

	out << endl;
	return out;
}
