
Commit 9a47e26

Create SimpleMind.py
1 parent bbfa90f commit 9a47e26

File tree

1 file changed: +137 -0 lines changed


SimpleMind.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
import logging

import jax
from jax import numpy as jnp
from jax import grad
import optax  # gradient-transformation optimizers (adam, sgd)


class SimpleMind:
    def __init__(self, input_size, hidden_sizes, output_size, activation='relu',
                 optimizer='adam', learning_rate=0.001, regularization=None,
                 reg_lambda=0.01, seed=0):
        """
        Initialize the SimpleMind neural network.

        :param input_size: Number of input neurons.
        :param hidden_sizes: List of the number of neurons in each hidden layer.
        :param output_size: Number of output neurons.
        :param activation: Activation function to use ('sigmoid', 'tanh', 'relu').
        :param optimizer: Optimizer to use ('sgd', 'adam').
        :param learning_rate: Learning rate for training.
        :param regularization: Regularization method ('l2') or None.
        :param reg_lambda: Regularization strength.
        :param seed: Seed for the PRNG key used to initialize the weights.
        """
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size
        self.learning_rate = learning_rate
        self.regularization = regularization
        self.reg_lambda = reg_lambda
        self.activation = activation

        self.params = self._initialize_parameters(jax.random.PRNGKey(seed))
        self.optimizer = self._make_optimizer(optimizer)
        self.opt_state = self.optimizer.init(self.params)
        self._setup_logging()

    def _setup_logging(self):
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(levelname)s - %(message)s')

    def _activation_function(self, s):
        if self.activation == 'sigmoid':
            return 1 / (1 + jnp.exp(-s))
        elif self.activation == 'tanh':
            return jnp.tanh(s)
        elif self.activation == 'relu':
            return jnp.maximum(0, s)
        else:
            raise ValueError("Unsupported activation function.")

    def _activation_derivative(self, s):
        # Unused during training: jax.grad differentiates the loss directly.
        # Kept for reference; note these formulas expect the *activated*
        # output s, not the pre-activation.
        if self.activation == 'sigmoid':
            return s * (1 - s)
        elif self.activation == 'tanh':
            return 1 - jnp.power(s, 2)
        elif self.activation == 'relu':
            return jnp.where(s > 0, 1, 0)
        else:
            raise ValueError("Unsupported activation function.")

    def _initialize_parameters(self, key):
        # JAX has no jnp.random.randn; weights are drawn with jax.random.normal
        # from an explicit PRNG key, split once per layer.
        params = {}
        layer_sizes = [self.input_size] + self.hidden_sizes + [self.output_size]
        for i in range(len(layer_sizes) - 1):
            key, subkey = jax.random.split(key)
            params[f'W{i}'] = jax.random.normal(subkey, (layer_sizes[i], layer_sizes[i + 1])) * 0.01
            params[f'b{i}'] = jnp.zeros(layer_sizes[i + 1])
        return params

    def _make_optimizer(self, name):
        if name == 'adam':
            return optax.adam(self.learning_rate)
        elif name == 'sgd':
            return optax.sgd(self.learning_rate)
        else:
            raise ValueError("Unsupported optimizer.")

    def forward(self, X, params):
        # Hidden layers apply the chosen activation; the output layer is linear.
        activations = X
        for i in range(len(self.hidden_sizes) + 1):
            z = jnp.dot(activations, params[f'W{i}']) + params[f'b{i}']
            activations = self._activation_function(z) if i < len(self.hidden_sizes) else z
        return activations

    def backpropagate(self, X, y, params, opt_state):
        # One full-batch optimizer step. Not decorated with @jit: jit would
        # trace self as an argument, and SimpleMind is not a pytree; a jitted
        # functional variant is sketched after the diff.
        def loss_fn(params):
            predictions = self.forward(X, params)
            loss = jnp.mean(jnp.square(y - predictions))
            if self.regularization == 'l2':
                l2_penalty = sum(jnp.sum(jnp.square(params[f'W{i}']))
                                 for i in range(len(self.hidden_sizes) + 1))
                loss += self.reg_lambda * l2_penalty / 2
            return loss

        grads = grad(loss_fn)(params)
        updates, opt_state = self.optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, opt_state

    def train(self, X, y, epochs):
        # Sequential full-batch training. A thread pool buys nothing here:
        # concurrent steps starting from the same params would race and
        # silently overwrite one another's updates.
        for epoch in range(epochs):
            self.params, self.opt_state = self.backpropagate(X, y, self.params, self.opt_state)
            if epoch % 100 == 0:
                loss = self._calculate_loss(X, y, self.params)
                logging.info(f"Epoch {epoch}, Loss: {loss}")

    def _calculate_loss(self, X, y, params):
        output = self.forward(X, params)
        loss = jnp.mean(jnp.square(y - output))
        if self.regularization == 'l2':
            loss += self.reg_lambda / 2 * sum(jnp.sum(jnp.square(params[f'W{i}']))
                                              for i in range(len(self.hidden_sizes) + 1))
        return loss


# Example Usage
if __name__ == "__main__":
    input_size = 3
    hidden_sizes = [5, 5]
    output_size = 1
    learning_rate = 0.001
    epochs = 1000

    X = jnp.array([[0.1, 0.2, 0.3]])
    y = jnp.array([[0.5]])

    mind = SimpleMind(input_size, hidden_sizes, output_size, activation='relu',
                      optimizer='adam', learning_rate=learning_rate,
                      regularization='l2', reg_lambda=0.01)

    mind.train(X, y, epochs)
    print("Final Output:", mind.forward(X, mind.params))

0 commit comments
