/*
	A pretty dumb implementation of the Strassen algorithm (matrix multiplication)
	
	(C) 2004 by Andrej Krutak
	Released under the GPL 2.0 or later :o)
	andrek.wz.cz; andrek(at)inmail.sk
*/


#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <fstream>
#include <iostream>
#include <string>

using namespace std;

// Element type stored in every Matrix cell (Matrix::data_ is DTYPE **).
typedef short DTYPE;

// Compile-time switches read by Matrix and by multiplystrassen().
enum {
	// Non-zero: Matrix::copyfrom() allocates fresh rows and memcpy()s the cells.
	// Zero:     Matrix::copyfrom() takes over the source's rows and leaves the
	//           source matrix empty (no data, zero width and height).
	COPYFROM_4REAL=0,
	// 0:             multiplystrassen() uses the m1..m7 formulation.
	// anything else: multiplystrassen() uses the s/t/p/u formulation, which
	//                needs fewer matrix additions and subtractions.
	STRASSEN_TYPE=1, //STRASSEN_TYPE==1   =>    COPYFROM_4REAL=0 recommended
};

// Recursion cut-off: multiplystrassen() hands its operands to multiplyo3()
// once both widths are <= MINW and both heights are <= MINH.
// main() overwrites both values with its first command-line argument.
int MINW=128, MINH=128;

/*
	Matrix of DTYPE cells, stored as an array of h_ row pointers
	(a cell is data_[row][column]).

	sub_==0: the rows were allocated by create(int, int) or handed over by
	         copyfrom().
	sub_==1: the matrix is a view set up by create(const Matrix &, ...); its row
	         pointers point into the rows of another matrix, which therefore has
	         to stay alive for as long as the view is used.

	Copy construction and assignment both go through copyfrom(). Depending on
	COPYFROM_4REAL that is either a deep copy or a hand-over that leaves the
	source matrix empty, which is why the source is taken by non-const
	reference.

	Errors are reported by throwing string literals (const char *).
*/
class Matrix {
public:
	// Allocates an h x w matrix; the cells are left uninitialized.
	Matrix(int h, int w) { setEmpty(); create(h, w); }
	// Empty matrix: no data, width and height 0.
	Matrix() { setEmpty(); }
	// Copies (or takes over, see COPYFROM_4REAL) the contents of m.
	Matrix(Matrix &m) { setEmpty(); copyfrom(m); }
	// Same as above for temporaries, which cannot bind to Matrix & in
	// standard C++ (requires C++11).
	Matrix(Matrix &&m) { setEmpty(); copyfrom(m); }
	// View onto m starting at row yo, column xo and reaching to m's last row and column.
	Matrix(Matrix &m, int yo, int xo) { setEmpty(); create(m, yo, xo, -1, -1); }
	// Loads a matrix from the text file named s.
	Matrix(char *s);
	~Matrix() {destroy();}

	// Turns this matrix into an h x w view onto m at (yo, xo); -1 means "up to the edge".
	void create(const Matrix &m, int yo, int xo, int h, int w);
	// Allocates own storage for an h x w matrix.
	void create(int h, int w);
	// Releases the storage held by this matrix, if any.
	void destroy();

	// Makes m11, m12, m21 and m22 views onto the four quadrants of this matrix.
	void getQuarters(Matrix &m11, Matrix &m12, Matrix &m21, Matrix &m22) const;
	// Rebuilds this matrix from four quadrants; throws "Bad sub-matrixes" if they do not fit together.
	void setQuarters(const Matrix &m11, const Matrix &m12, const Matrix &m21, const Matrix &m22);

	void copyfrom(Matrix &m);
	// Unchecked element access: i is the row, j the column.
	inline void setData(int i, int j, DTYPE val) { data_[i][j]=val; }
	inline DTYPE getData(int i, int j) const { return data_[i][j]; }

	friend Matrix multiplyo3(const Matrix &a, const Matrix &b);
	friend Matrix multiplystrassen(const Matrix &a, const Matrix &b);

	// Assignment releases whatever this matrix held before taking m's contents;
	// assigning a matrix to itself is a no-op.
	Matrix &operator =(Matrix &m) { if (this!=&m) { destroy(); copyfrom(m); } return *this; }
	// Overload for temporaries such as the result of a+b (requires C++11).
	Matrix &operator =(Matrix &&m) { if (this!=&m) { destroy(); copyfrom(m); } return *this; }
	// Product via multiplystrassen(); neither operand is modified, so the right
	// hand side may be a temporary.
	Matrix operator*(const Matrix &m) const {return multiplystrassen(*this, m);}
	// Element-wise sum / difference; both throw "BAD_SIZES" when the sizes differ.
	Matrix operator+(Matrix &m);
	Matrix operator-(Matrix &m);
	// Prints one text line per row, every cell followed by a tab.
	friend ostream& operator<< ( ostream& os, Matrix& m) { return operator <<(os, static_cast<const Matrix &>(m)); }
	friend ostream& operator<< ( ostream& os, const Matrix& m);

	int getW() const {return w_;}
	int getH() const {return h_;}
private:
	// Puts every member into the "empty matrix" state without freeing anything;
	// used by the constructors so that no member is ever read uninitialized.
	void setEmpty() { sub_=0; data_=0; w_=0; h_=0; }

	int w_, h_;     // width (columns) and height (rows)
	DTYPE **data_;  // h_ row pointers, or 0 when the matrix is empty

	int sub_;       // 1 when the rows belong to another matrix (view), else 0
};

/*
	Loads a matrix from the text file named by s.

	Expected layout: two integers "height width", followed by height*width
	integer cell values, row after row, all separated by whitespace.

	Throws (const char *):
		"cannot_open_file"   the file could not be opened
		"bad_matrix_header"  the two dimensions are missing or negative
		"bad_matrix_data"    fewer cell values than height*width could be read
*/
Matrix::Matrix(char *s)
{
	sub_=0;
	data_=0;
	w_=0;
	h_=0;

	FILE *f=fopen(s, "r");
	if (!f)
		throw "cannot_open_file";

	int h=0, w=0;
	if (fscanf(f, "%d %d", &h, &w)!=2 || h<0 || w<0) {
		fclose(f);
		throw "bad_matrix_header";
	}
	create(h, w);

	for (int i=0; i<h_; i++) {
		for (int j=0; j<w_; j++) {
			int d;
			if (fscanf(f, "%d ", &d)!=1) {
				// The destructor does not run when a constructor throws,
				// so release what create() allocated before leaving.
				destroy();
				fclose(f);
				throw "bad_matrix_data";
			}
			data_[i][j]=static_cast<DTYPE>(d);
		}
	}
	fclose(f);
}

void Matrix::create(int h, int w)
{
	destroy();
	
	sub_=0;
	data_=new DTYPE*[h];
	for (int i=0; i<h; i++)
		data_[i]=new DTYPE[w];
	//data_=new DTYPE[h*w];
	w_=w;
	h_=h;
}

void Matrix::create(const Matrix &m, int yo, int xo, int h, int w)
{
	sub_=1;
	
	if (w==-1)
		w_=m.w_-xo;
	else {
		if (m.w_-xo<w)
			throw "too_big_width";
		w_=w;
	}
	
	if (h==-1)
		h_=m.h_-yo;
	else {
		if (m.h_-yo<h)
			throw "too_big_height";
		h_=h;
	}

	data_=new DTYPE*[h_];
	for (int i=0; i<h_; i++)
		data_[i]=&(m.data_[i+yo][xo]);
}

/*
	Releases the storage held by this matrix and marks it as holding no data.
	Does nothing when the matrix is already empty. w_ and h_ are left untouched.

	A view (sub_!=0) only frees its array of row pointers, because the rows
	belong to another matrix. Any other matrix is treated as the owner of every
	row in data_ and frees them too, regardless of COPYFROM_4REAL: when
	copyfrom() hands rows over, it clears the source's data_, so the rows have
	exactly one holder at any time.
*/
void Matrix::destroy()
{
	if (!data_)
		return;

	if (!sub_) {
		for (int i=0; i<h_; i++)
			delete[] data_[i];
	}

	delete[] data_;
	data_=0;
}

/*
	Makes this matrix hold the contents of m.

	COPYFROM_4REAL != 0: deep copy. New rows are allocated and the cells are
	                     copied; m is left unchanged.
	COPYFROM_4REAL == 0: hand-over. This matrix takes m's row-pointer array as
	                     it is, and m becomes empty (no data, zero width and
	                     height). The view flag travels with the data, so a
	                     matrix taken from a view is still a view.

	In both modes the storage previously held by this matrix is released first,
	copying a matrix onto itself is a no-op, and copying from an empty matrix
	leaves this matrix empty.
*/
void Matrix::copyfrom(Matrix &m)
{
	if (&m==this)
		return;

	if (!m.data_) {
		destroy();
		sub_=0;
		w_=0;
		h_=0;
		return;
	}

	if (COPYFROM_4REAL) {
		// create() releases the old storage and clears sub_.
		create(m.h_, m.w_);
		for (int i=0; i<h_; i++)
			memcpy(data_[i], m.data_[i], w_*sizeof(DTYPE));
	} else {
		// Release with the *current* sub_ before it is overwritten below.
		destroy();

		sub_=m.sub_;
		w_=m.w_;
		h_=m.h_;
		data_=m.data_;

		// m no longer refers to the rows, so its destructor will not touch them.
		m.data_=0;
		m.w_=0;
		m.h_=0;
	}
}

/*
	Writes m to os as text: one line per row, every cell followed by a tab,
	every row terminated (and flushed) with endl. Returns os.
*/
ostream &operator<<(ostream &os, const Matrix &m)
{
	for (int row=0; row<m.h_; ++row) {
		for (int col=0; col<m.w_; ++col)
			os << m.data_[row][col] << "\t";
		os << endl;
	}
	return os;
}

/*
	Element-wise sum of this matrix and mb, returned as a new matrix.
	Throws "BAD_SIZES" unless both matrices have the same width and height.
*/
Matrix Matrix::operator+(Matrix &mb)
{
	if (!(w_==mb.w_ && h_==mb.h_))
		throw "BAD_SIZES";

	Matrix sum(h_, w_);

	for (int row=0; row<h_; ++row) {
		DTYPE *out=sum.data_[row];
		const DTYPE *lhs=data_[row];
		const DTYPE *rhs=mb.data_[row];

		for (int col=0; col<w_; ++col)
			out[col]=lhs[col]+rhs[col];
	}

	return sum;
}

/*
	Element-wise difference (this matrix minus mb), returned as a new matrix.
	Throws "BAD_SIZES" unless both matrices have the same width and height.
*/
Matrix Matrix::operator-(Matrix &mb)
{
	if (!(w_==mb.w_ && h_==mb.h_))
		throw "BAD_SIZES";

	Matrix diff(h_, w_);

	for (int row=0; row<h_; ++row) {
		DTYPE *out=diff.data_[row];
		const DTYPE *lhs=data_[row];
		const DTYPE *rhs=mb.data_[row];

		for (int col=0; col<w_; ++col)
			out[col]=lhs[col]-rhs[col];
	}

	return diff;
}

/*
	Splits this matrix into four quadrants and turns m11 (top-left),
	m12 (top-right), m21 (bottom-left) and m22 (bottom-right) into views onto
	them. The split point is at half the width / height, rounded down, so with
	an odd size the right / bottom quadrants are one cell larger.
*/
void Matrix::getQuarters(Matrix &m11, Matrix &m12, Matrix &m21, Matrix &m22) const
{
	const int leftW=w_/2;
	const int rightW=w_-leftW;
	const int topH=h_/2;
	const int bottomH=h_-topH;

	m11.create(*this, 0,    0,     topH,    leftW);
	m12.create(*this, 0,    leftW, topH,    rightW);
	m21.create(*this, topH, 0,     bottomH, leftW);
	m22.create(*this, topH, leftW, bottomH, rightW);
}

void Matrix::setQuarters(const Matrix &m11, const Matrix &m12, const Matrix &m21, const Matrix &m22)
{
	if (m11.w_!=m21.w_ || m12.w_!=m22.w_ ||
		m11.h_!=m12.h_ || m21.h_!=m22.h_)
			throw "Bad sub-matrixes";

	int i, j;

	create(m11.h_+m21.h_, m11.w_+m12.w_);

	for (i=0; i<m11.h_; i++) {
		for (j=0; j<m11.w_; j++) {
			data_[i][j]=m11.data_[i][j];
		}
	}

	for (i=0; i<m12.h_; i++) {
		for (j=0; j<m12.w_; j++) {
			data_[i][j+m11.w_]=m12.data_[i][j];
		}
	}

	for (i=0; i<m21.h_; i++) {
		for (j=0; j<m21.w_; j++) {
			data_[i+m11.h_][j]=m21.data_[i][j];
		}
	}

	for (i=0; i<m22.h_; i++) {
		for (j=0; j<m22.w_; j++) {
			data_[i+m12.h_][j+m21.w_]=m22.data_[i][j];
		}
	}
}

/*
	Schoolbook matrix product a*b (three nested loops).
	Throws "BAD_SIZES" unless the width of a equals the height of b.
	Each result cell is accumulated in a DTYPE, so it is narrowed to DTYPE
	after every single addition.
*/
Matrix multiplyo3(const Matrix &a, const Matrix &b)
{
	if (a.w_!=b.h_)
		throw "BAD_SIZES";

	Matrix product(a.h_, b.w_);

	for (int row=0; row<a.h_; ++row) {
		const DTYPE *lhs=a.data_[row];
		DTYPE *out=product.data_[row];

		for (int col=0; col<b.w_; ++col) {
			DTYPE acc=0;
			for (int k=0; k<a.w_; ++k)
				acc+=lhs[k]*b.data_[k][col];
			out[col]=acc;
		}
	}

	return product;
}

/*
	Matrix product a*b by recursive 2x2 block decomposition with seven block
	products per level instead of eight.

	Throws "BAD_SIZES" unless the width of a equals the height of b.

	The recursion stops, and multiplyo3() is used instead, when
	  - both operands are within the MINW / MINH threshold, or
	  - one of the three dimensions involved (rows of a, columns of a / rows of
	    b, columns of b) is odd or smaller than 2. Such a matrix cannot be cut
	    into four equally sized quadrants, and the block sums below require
	    equally sized quadrants.

	STRASSEN_TYPE selects between two formulations of the seven products.
	Matrix multiplication is not commutative, so in every product the factor
	built from a's quadrants must be on the left and the one built from b's
	quadrants on the right.
*/
Matrix multiplystrassen(const Matrix &a, const Matrix &b)
{
	if (a.w_!=b.h_)
		throw "BAD_SIZES";

	const bool belowThreshold=
		a.w_<=MINW && b.w_<=MINW && a.h_<=MINH && b.h_<=MINH;
	const bool cannotHalve=
		a.h_<2 || a.w_<2 || b.w_<2 || (((a.h_|a.w_|b.w_)&1)!=0);
	if (belowThreshold || cannotHalve)
		return multiplyo3(a, b);

	Matrix a11, a12, a21, a22;
	Matrix b11, b12, b21, b22;

	a.getQuarters(a11, a12, a21, a22);
	b.getQuarters(b11, b12, b21, b22);

	Matrix r1;
	if (STRASSEN_TYPE==0) {
		Matrix m1, m2, m3, m4, m5, m6, m7;

		m1=(a11+a22)*(b11+b22);
		m2=(a21+a22)*b11;
		m3=a11*(b12-b22);
		m4=a22*(b21-b11);
		m5=(a11+a12)*b22;
		m6=(a21-a11)*(b11+b12);
		m7=(a12-a22)*(b21+b22);

		r1.setQuarters(
			m1+m4-m5+m7,	m3+m5,
			m2+m4,			m1-m2+m3+m6
			);
	} else { //a more sophisticated algorithm to also reduce the + and - operations...
		Matrix s1, s2, s3, s4;
		Matrix t1, t2, t3, t4;
		Matrix p1, p2, p3, p4, p5, p6, p7;
		Matrix u1, u2, u3, u4, u5, u6, u7;

		// Sums of a's quadrants.
		s1=a21+a22;
		s2=s1-a11;
		s3=a11-a21;
		s4=a12-s2;

		// Sums of b's quadrants.
		t1=b12-b11;
		t2=b22-t1;
		t3=b22-b12;
		t4=b21-t2;

		// The seven products: a-side factor on the left, b-side factor on the right.
		p1=a11*b11;
		p2=a12*b21;
		p3=s1*t1;
		p4=s2*t2;
		p5=s3*t3;
		p6=s4*b22;
		p7=a22*t4;

		u1=p1+p2;	// top-left quadrant of the result
		u2=p1+p4;
		u3=u2+p5;
		u4=u3+p7;	// bottom-left
		u5=u3+p3;	// bottom-right
		u6=u2+p3;
		u7=u6+p6;	// top-right

		r1.setQuarters(u1, u7, u4, u5);
	}
	return r1;
}

int main(int argc, char **argv)
{
	if (argc<3)
		return 1;
	MINW=atol(argv[1]);
	MINH=MINW;
	
	Matrix a(argv[2]), b(argv[3]), c, d;

	clock_t time0, time1;

	cerr << "MINW/H: " << MINW << endl << "SIZE: " << a.getW() << endl;
	cout << "MINW/H: " << MINW << endl << "SIZE: " << a.getW() << endl;

	/*
	cout << "O3: ";
	time0=clock();
	c=multiplyo3(a, b);
	time1=clock();
	cout << (double)(time1-time0)/CLOCKS_PER_SEC << endl;
	*/

	cout << "Strassen: ";
	time0=clock();
	d=multiplystrassen(a, b);
	time1=clock();
	cout << (double)(time1-time0)/CLOCKS_PER_SEC << endl;

	cout << endl;
	//cout << endl << a << "*" << endl << b << "=" << endl << "[O(N3)]" << endl << c << "==" << endl << "[Strassen]" << endl << d;
	return 0;
}