-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmatrix.h
112 lines (95 loc) · 2.83 KB
/
matrix.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#ifndef MATRIX_H
#define MATRIX_H
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <functional>
/// Fixed-size matrix whose dimensions are compile-time template arguments.
/// Elements are stored contiguously in row-major order.
///
/// @tparam T     element type. The arithmetic members require T to support
///               +, -, * and +=; roughly_equal additionally requires std::abs.
/// @tparam rows  number of rows (must be non-negative).
/// @tparam cols  number of columns (must be non-negative).
template <class T, int rows, int cols>
class matrix
{
    // Negative extents would otherwise wrap to an enormous std::array size.
    static_assert(rows >= 0 && cols >= 0, "matrix dimensions must be non-negative");

    // Flat row-major storage; element (r, c) lives at r * cols + c.
    using storage_type = std::array<T, std::size_t(rows) * std::size_t(cols)>;

private:
    storage_type data;

    /// Maps (row, col) to its offset in `data`. Bounds are checked with
    /// assert, i.e. only in builds without NDEBUG.
    std::size_t index(const int row, const int col) const
    {
        assert(row >= 0 && col >= 0 && row < rows && col < cols);
        return std::size_t(row) * std::size_t(cols) + std::size_t(col);
    }

public:
    /// Constructs a matrix with every element value-initialized
    /// (0 for arithmetic T, default-constructed for class types).
    matrix() : data{}
    {}

    /// Constructs a matrix from its elements given in row-major order.
    /// Deliberately non-explicit so existing implicit conversions keep working.
    matrix(const storage_type &m) : data(m)
    {}

    /// @return the number of rows (the `rows` template argument).
    int get_rows() const
    {
        return rows;
    }

    /// @return the number of columns (the `cols` template argument).
    int get_cols() const
    {
        return cols;
    }

    /// Mutable access to element (row, col).
    T & at(const int row, const int col)
    {
        return data[index(row, col)];
    }

    /// Returns a copy of element (row, col).
    T at(const int row, const int col) const
    {
        return data[index(row, col)];
    }

    /// Exact element-wise equality.
    bool operator ==(const matrix<T, rows, cols> &m) const
    {
        return data == m.data;
    }

    /// Element-wise comparison with an absolute tolerance.
    /// @return true iff |this(r,c) - m(r,c)| <= tolerance for every element;
    ///         an element whose difference is NaN makes the result false.
    bool roughly_equal(const matrix<T, rows, cols> &m, T tolerance) const
    {
        for (std::size_t k = 0; k < data.size(); ++k)
            // Phrased as !(diff <= tol) so a NaN difference fails the check;
            // the naive `diff > tol` is false for NaN and would accept it.
            if (!(std::abs(data[k] - m.data[k]) <= tolerance))
                return false;
        return true;
    }

    /// Negation of operator==.
    bool operator !=(const matrix<T, rows, cols> &m) const
    {
        return !(*this == m);
    }

    /// Element-wise sum.
    matrix operator +(const matrix<T, rows, cols> &m) const
    {
        matrix<T, rows, cols> ans;
        for (std::size_t k = 0; k < data.size(); ++k)
            ans.data[k] = data[k] + m.data[k];
        return ans;
    }

    /// Element-wise difference (*this - m).
    matrix operator -(const matrix<T, rows, cols> &m) const
    {
        matrix<T, rows, cols> ans;
        for (std::size_t k = 0; k < data.size(); ++k)
            ans.data[k] = data[k] - m.data[k];
        return ans;
    }

    /// Matrix product: (rows x cols) * (cols x N) -> (rows x N).
    template<int N>
    matrix<T, rows, N> operator *(const matrix<T, cols, N> &m) const
    {
        matrix<T, rows, N> ans;
        for (int row = 0; row < rows; ++row)
            for (int col = 0; col < N; ++col)
            {
                // Value-initialize instead of `= 0` so T need not convert from int.
                T sum{};
                for (int i = 0; i < cols; ++i)
                    sum += data[index(row, i)] * m.at(i, col);
                ans.at(row, col) = sum;
            }
        return ans;
    }

    /// Replaces every element x with foo(x), visiting elements in row-major order.
    void transform_each(const std::function<T(T)> &foo)
    {
        for (T &element : data)
            element = foo(element);
    }
};
#endif