|
| 1 | +import numpy as np |
| 2 | +import logging |
| 3 | +from jax import numpy as jnp |
| 4 | +import jax |
| 5 | +from jax import grad, jit, vmap, pmap |
| 6 | +from typing import List, Tuple, Union |
| 7 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
| 8 | +import optax # For advanced optimization algorithms |
| 9 | +import haiku as hk # For model building and parameter management |
| 10 | + |
| 11 | +class SimpleMind: |
| 12 | + def __init__(self, input_size, hidden_sizes, output_size, activation='relu', optimizer='adam', learning_rate=0.001, regularization=None, reg_lambda=0.01): |
| 13 | + """ |
| 14 | + Initialize the SimpleMind neural network. |
| 15 | + |
| 16 | + :param input_size: Number of input neurons. |
| 17 | + :param hidden_sizes: List of the number of neurons in each hidden layer. |
| 18 | + :param output_size: Number of output neurons. |
| 19 | + :param activation: Activation function to use ('sigmoid', 'tanh', 'relu'). |
| 20 | + :param optimizer: Optimizer to use ('sgd', 'adam'). |
| 21 | + :param learning_rate: Learning rate for training. |
| 22 | + :param regularization: Regularization method ('l2'). |
| 23 | + :param reg_lambda: Regularization strength. |
| 24 | + """ |
| 25 | + self.input_size = input_size |
| 26 | + self.hidden_sizes = hidden_sizes |
| 27 | + self.output_size = output_size |
| 28 | + self.learning_rate = learning_rate |
| 29 | + self.regularization = regularization |
| 30 | + self.reg_lambda = reg_lambda |
| 31 | + |
| 32 | + self.params = self._initialize_parameters() |
| 33 | + |
| 34 | + self.activation = activation |
| 35 | + self.optimizer = optimizer |
| 36 | + self.opt_state = self._setup_optimizer() |
| 37 | + self._setup_logging() |
| 38 | + |
| 39 | + def _setup_logging(self): |
| 40 | + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
| 41 | + |
| 42 | + def _activation_function(self, s): |
| 43 | + if self.activation == 'sigmoid': |
| 44 | + return 1 / (1 + jnp.exp(-s)) |
| 45 | + elif self.activation == 'tanh': |
| 46 | + return jnp.tanh(s) |
| 47 | + elif self.activation == 'relu': |
| 48 | + return jnp.maximum(0, s) |
| 49 | + else: |
| 50 | + raise ValueError("Unsupported activation function.") |
| 51 | + |
| 52 | + def _activation_derivative(self, s): |
| 53 | + if self.activation == 'sigmoid': |
| 54 | + return s * (1 - s) |
| 55 | + elif self.activation == 'tanh': |
| 56 | + return 1 - jnp.power(s, 2) |
| 57 | + elif self.activation == 'relu': |
| 58 | + return jnp.where(s > 0, 1, 0) |
| 59 | + else: |
| 60 | + raise ValueError("Unsupported activation function.") |
| 61 | + |
| 62 | + def _initialize_parameters(self): |
| 63 | + params = {} |
| 64 | + layer_sizes = [self.input_size] + self.hidden_sizes + [self.output_size] |
| 65 | + for i in range(len(layer_sizes) - 1): |
| 66 | + params[f'W{i}'] = jnp.random.randn(layer_sizes[i], layer_sizes[i+1]) * 0.01 |
| 67 | + params[f'b{i}'] = jnp.zeros(layer_sizes[i+1]) |
| 68 | + return params |
| 69 | + |
| 70 | + def forward(self, X, params): |
| 71 | + activations = X |
| 72 | + for i in range(len(self.hidden_sizes) + 1): |
| 73 | + z = jnp.dot(activations, params[f'W{i}']) + params[f'b{i}'] |
| 74 | + activations = self._activation_function(z) if i < len(self.hidden_sizes) else z |
| 75 | + return activations |
| 76 | + |
| 77 | + @jit |
| 78 | + def backpropagate(self, X, y, params, opt_state): |
| 79 | + def loss_fn(params): |
| 80 | + predictions = self.forward(X, params) |
| 81 | + loss = jnp.mean(jnp.square(y - predictions)) |
| 82 | + if self.regularization == 'l2': |
| 83 | + l2_penalty = sum(jnp.sum(jnp.square(params[f'W{i}'])) for i in range(len(self.hidden_sizes) + 1)) |
| 84 | + loss += self.reg_lambda * l2_penalty / 2 |
| 85 | + return loss |
| 86 | + |
| 87 | + grads = grad(loss_fn)(params) |
| 88 | + updates, opt_state = self.optimizer.update(grads, opt_state) |
| 89 | + new_params = optax.apply_updates(params, updates) |
| 90 | + return new_params, opt_state |
| 91 | + |
| 92 | + def _setup_optimizer(self): |
| 93 | + if self.optimizer == 'adam': |
| 94 | + self.optimizer = optax.adam(self.learning_rate) |
| 95 | + elif self.optimizer == 'sgd': |
| 96 | + self.optimizer = optax.sgd(self.learning_rate) |
| 97 | + else: |
| 98 | + raise ValueError("Unsupported optimizer.") |
| 99 | + return self.optimizer.init(self.params) |
| 100 | + |
| 101 | + def train(self, X, y, epochs): |
| 102 | + for epoch in range(epochs): |
| 103 | + self.params, self.opt_state = self._parallel_backpropagate(X, y, self.params, self.opt_state) |
| 104 | + if epoch % 100 == 0: |
| 105 | + loss = self._calculate_loss(X, y, self.params) |
| 106 | + logging.info(f"Epoch {epoch}, Loss: {loss}") |
| 107 | + |
| 108 | + def _parallel_backpropagate(self, X, y, params, opt_state): |
| 109 | + with ThreadPoolExecutor() as executor: |
| 110 | + futures = [executor.submit(self.backpropagate, X[i], y[i], params, opt_state) for i in range(len(X))] |
| 111 | + for future in as_completed(futures): |
| 112 | + params, opt_state = future.result() |
| 113 | + return params, opt_state |
| 114 | + |
| 115 | + @jit |
| 116 | + def _calculate_loss(self, X, y, params): |
| 117 | + output = self.forward(X, params) |
| 118 | + loss = jnp.mean(jnp.square(y - output)) |
| 119 | + if self.regularization == 'l2': |
| 120 | + loss += self.reg_lambda / 2 * sum(jnp.sum(jnp.square(params[f'W{i}'])) for i in range(len(self.hidden_sizes) + 1)) |
| 121 | + return loss |
| 122 | + |
| 123 | +# Example Usage |
| 124 | +if __name__ == "__main__": |
| 125 | + input_size = 3 |
| 126 | + hidden_sizes = [5, 5] |
| 127 | + output_size = 1 |
| 128 | + learning_rate = 0.001 |
| 129 | + epochs = 1000 |
| 130 | + |
| 131 | + X = jnp.array([[0.1, 0.2, 0.3]]) |
| 132 | + y = jnp.array([[0.5]]) |
| 133 | + |
| 134 | + mind = SimpleMind(input_size, hidden_sizes, output_size, activation='relu', optimizer='adam', learning_rate=learning_rate, regularization='l2', reg_lambda=0.01) |
| 135 | + |
| 136 | + mind.train(X, y, epochs) |
| 137 | + print("Final Output:", mind.forward(X, mind.params)) |
0 commit comments