/* Matrix.C:
 *   matrix of doubles
 *   implementation
 *
 * Elements are stored column-major in the heap array `val`:
 * element (i,j), with 1-based i and j, lives at val[(i-1) + nRows*(j-1)].
 *
 * Errors are reported by throwing a char* produced by ostrstream::str().
 * The message is NUL-terminated, and the buffer was allocated by the
 * stream (operator new[] by default), so the catcher owns it and should
 * release it with delete[].
 */

#include "Matrix.H"
#include <strstream.h>
#include <limits.h>

// Terminates the text collected in `os` and returns its buffer for throwing.
// ostrstream does not NUL-terminate on its own; without `ends` the thrown
// char* would not mark where the message stops.
static char* finishMessage(ostrstream& os)
{
    os << ends;
    return os.str();   // freezes the buffer: ownership passes to the caller
}

// Builds the "Invalid index" message shared by set() and get().
static char* indexErrorMessage(int i, int j)
{
    ostrstream myString;
    myString << "Invalid index: (" << i << "," << j << ")";
    return finishMessage(myString);
}

// Creates a new matrix with the given number of rows and columns, with
// every element initialised to 0.0.
// Throws (char*) if a dimension is negative, or if rows*cols does not fit
// in an int (the product would otherwise overflow and under-allocate val).
Matrix::Matrix(int rows, int cols)
{
    // error handling
    if (rows < 0 || cols < 0) {
        ostrstream myString;
        myString << "Invalid array size: " << rows << " x " << cols;
        throw finishMessage(myString);
    }
    if (cols != 0 && rows > INT_MAX / cols) {
        ostrstream myString;
        myString << "Matrix size too large: " << rows << " x " << cols;
        throw finishMessage(myString);
    }

    nRows = rows;
    nCols = cols;
    int count = nRows * nCols;
    val = new double[count];
    // new double[] leaves the elements indeterminate; give them a defined value.
    for (int k = 0; k < count; k++)
        val[k] = 0.0;
}

// Releases the element storage.
// NOTE(review): this class owns a raw array. If Matrix.H does not declare
// (or disable) a copy constructor and copy assignment operator, copying a
// Matrix shares `val` and leads to a double delete[] -- confirm in Matrix.H.
Matrix::~Matrix()
{
    delete[] val;
}

// Returns the shape of the matrix as a new int[2]: {rows, columns}.
// The array is allocated with new[]; the caller must delete[] it.
int* Matrix::shape()
{
    int* sizes = new int[2];
    sizes[0] = nRows;
    sizes[1] = nCols;
    return sizes;
}

// Sets Matrix(i,j) to v, using 1-based indices.
// Throws (char*) if (i,j) lies outside 1..rows x 1..cols.
void Matrix::set(int i, int j, double v)
{
    // check indices
    if (i < 1 || i > nRows || j < 1 || j > nCols)
        throw indexErrorMessage(i, j);
    val[i - 1 + nRows * (j - 1)] = v;
}

// Returns Matrix(i,j), using 1-based indices.
// Throws (char*) if (i,j) lies outside 1..rows x 1..cols.
double Matrix::get(int i, int j)
{
    // check indices
    if (i < 1 || i > nRows || j < 1 || j > nCols)
        throw indexErrorMessage(i, j);
    return val[i - 1 + nRows * (j - 1)];
}

// result = matmul(this, b)
// `this` is m x l, `b` must be l x n, and `result` must already be m x n;
// otherwise a char* message is thrown and result is left untouched.
// `result` may be the same object as `this` and/or `b`.
void Matrix::mult(Matrix &b, Matrix &result)
{
    int m = nRows;
    int l = nCols;
    int n = b.nCols;

    // check dimensions
    if (l != b.nRows) {
        ostrstream myString;
        myString << "Invalid matrix dimensions for multiplication: ("
                 << nRows << "," << nCols << ") x ("
                 << b.nRows << "," << b.nCols << ")";
        throw finishMessage(myString);
    }
    else if (result.nRows != m || result.nCols != n) {
        ostrstream myString;
        myString << "Invalid matrix dimensions for result: is ("
                 << result.nRows << "," << result.nCols << "), should be ("
                 << m << "," << n << ")";
        throw finishMessage(myString);
    }

    // Writing straight into result.val would overwrite operand elements that
    // are still needed when result shares storage with this or b, so in that
    // case accumulate into a scratch buffer and copy it over at the end.
    int aliased = (&result == this || &result == &b);
    double* out = aliased ? new double[m * n] : result.val;

    for (int j = 0; j < n; j++) {
        for (int i = 0; i < m; i++) {
            double sum = 0.0;
            for (int k = 0; k < l; k++)
                sum += val[i + m * k] * b.val[k + l * j];
            out[i + m * j] = sum;
        }
    }

    if (aliased) {
        for (int p = 0; p < m * n; p++)
            result.val[p] = out[p];
        delete[] out;
    }
}