Skip to content

Commit f9094ff

Browse files
authored
Added DecisionTree
1 parent 1ddc0f7 commit f9094ff

File tree

1 file changed

+162
-0
lines changed

1 file changed

+162
-0
lines changed

Decision Tree.ipynb

+162
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import numpy as np"
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": 2,
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
def entropy(y):
    """Return the Shannon entropy, in bits, of the labels in ``y``.

    Labels may be any values ``np.unique`` can sort (non-negative or negative
    integers, floats, strings, ...), not just the non-negative integers that
    ``np.bincount`` accepts.

    Args:
        y: Array-like of class labels.

    Returns:
        ``-sum(p * log2(p))`` over the relative frequency ``p`` of each
        distinct label; ``0.0`` for empty input.
    """
    labels = np.asarray(y).ravel()
    if labels.size == 0:
        # No samples means no uncertainty; also avoids dividing by zero.
        return 0.0

    # Every count reported by np.unique is >= 1, so log2 never sees a zero.
    _, counts = np.unique(labels, return_counts=True)
    ps = counts / labels.size
    return -np.sum(ps * np.log2(ps))
22+
]
23+
},
24+
{
25+
"cell_type": "code",
26+
"execution_count": 3,
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
class Node:
    """A single vertex of a binary decision tree.

    A node either describes a split (``feature``/``threshold`` plus the
    ``left`` and ``right`` subtrees) or, when ``value`` is set, acts as a
    leaf carrying that value.
    """

    def __init__(self, feature=None, left=None, right=None, threshold=None, value=None):
        # Split description.
        self.feature, self.threshold = feature, threshold
        # Child subtrees.
        self.left, self.right = left, right
        # Leaf payload; stays None on nodes that are not leaves.
        self.value = value

    def is_leaf_node(self):
        """Tell whether this node is a leaf, i.e. ``value`` is not ``None``."""
        if self.value is None:
            return False
        return True
40+
]
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": 4,
45+
"metadata": {},
46+
"outputs": [],
47+
"source": [
class DecisionTree:
    """Binary decision-tree classifier that splits on information gain.

    Every internal node tests ``x[feature] <= threshold``: samples that pass
    go to the left child, all others to the right. Leaves store the most
    common label among the training samples that reached them.

    Args:
        min_sample_split: A node holding fewer samples than this becomes a
            leaf.
        max_depth: A node at this depth (root is depth 0) becomes a leaf.
        n_feature: Number of randomly drawn features examined at each split.
            A falsy value means "all features"; ``fit`` caps it at the number
            of columns in the training data.
    """

    def __init__(self, min_sample_split=2, max_depth=100, n_feature=None):
        self.n_feature = n_feature
        self.min_sample_split = min_sample_split
        self.max_depth = max_depth
        # Root of the fitted tree; populated by fit().
        self.root = None

    def fit(self, X, y):
        """Grow the tree from training data and store it in ``self.root``.

        Args:
            X: Array-like of shape (n_samples, n_features).
            y: Array-like of n_samples class labels.

        Raises:
            ValueError: If ``X`` is not 2-D, holds no samples or no features,
                or ``y`` is not a 1-D sequence matching ``X`` in length.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 2:
            raise ValueError("X must be 2-D: (n_samples, n_features)")
        n_samples, n_features = X.shape
        if n_samples == 0 or n_features == 0:
            raise ValueError("X must contain at least one sample and one feature")
        if y.ndim != 1 or y.shape[0] != n_samples:
            raise ValueError("y must be 1-D with one label per row of X")

        self.n_feature = n_features if not self.n_feature else min(self.n_feature, n_features)
        self.root = self.growTree(X, y)

    def predict(self, X):
        """Return an array with the predicted label for every row of ``X``."""
        X = np.asarray(X)
        return np.array([self.traverseTree(x, self.root) for x in X])

    def growTree(self, X, y, depth=0):
        """Recursively build the subtree for the samples ``X``/``y``.

        Returns a leaf ``Node`` when a stopping rule fires (depth limit, too
        few samples, a single remaining label) or when no threshold separates
        the samples; otherwise an internal ``Node`` with two grown children.
        """
        n_samples, n_features = X.shape
        n_labels = len(np.unique(y))

        if (depth >= self.max_depth
                or n_samples < self.min_sample_split
                or n_labels == 1):
            return Node(value=self.most_common_label(y))

        # Fall back to all features if growTree is used without fit().
        n_pick = min(self.n_feature or n_features, n_features)
        feature_idxs = np.random.choice(n_features, n_pick, replace=False)
        bestFeature, bestThreshold = self.bestCriteria(X, y, feature_idxs)
        if bestFeature is None:
            # Every candidate column is constant: nothing left to split on.
            return Node(value=self.most_common_label(y))

        left_idxs, right_idxs = self.split(X[:, bestFeature], bestThreshold)
        if len(left_idxs) == 0 or len(right_idxs) == 0:
            # Defensive: never recurse into an empty child.
            return Node(value=self.most_common_label(y))

        left = self.growTree(X[left_idxs, :], y[left_idxs], depth + 1)
        right = self.growTree(X[right_idxs, :], y[right_idxs], depth + 1)
        return Node(feature=bestFeature, left=left, right=right, threshold=bestThreshold)

    def bestCriteria(self, X, y, feature_idxs):
        """Find the (feature index, threshold) pair with the highest gain.

        Only the columns listed in ``feature_idxs`` are searched. Returns
        ``(None, None)`` when none of them offers a threshold that puts at
        least one sample on each side.
        """
        best_gain = -1
        split_idx, split_thresh = None, None

        for f in feature_idxs:
            X_column = X[:, f]
            # The largest value is skipped: `<= max` would send every sample
            # left and leave the right child empty.
            for t in np.unique(X_column)[:-1]:
                gain = self.infoGain(X_column, y, t)
                if gain > best_gain:
                    best_gain = gain
                    split_idx = f
                    split_thresh = t

        return split_idx, split_thresh

    def infoGain(self, X, y, t):
        """Return the entropy reduction from splitting column ``X`` at ``t``.

        The gain is the parent entropy minus the size-weighted entropy of the
        two children; it is 0 when the split leaves one side empty.
        """
        y = np.asarray(y)
        pE = entropy(y)

        leftIdx, rightIdx = self.split(X, t)
        if len(leftIdx) == 0 or len(rightIdx) == 0:
            return 0

        n = len(y)
        n_l, n_r = len(leftIdx), len(rightIdx)
        e_l, e_r = entropy(y[leftIdx]), entropy(y[rightIdx])

        child_entropy = (n_l * e_l) / n + (n_r * e_r) / n
        return pE - child_entropy

    def split(self, X, threshold):
        """Partition the positions of a 1-D feature column around ``threshold``.

        Returns ``(left, right)`` index arrays: ``left`` holds positions whose
        value is ``<= threshold`` and ``right`` the remainder, the same rule
        ``traverseTree`` applies at prediction time.
        """
        X = np.asarray(X)
        leftIdxs = np.flatnonzero(X <= threshold)
        rightIdxs = np.flatnonzero(X > threshold)
        return leftIdxs, rightIdxs

    def traverseTree(self, X, node):
        """Walk from ``node`` down to a leaf for sample ``X``; return its value."""
        while not node.is_leaf_node():
            node = node.left if X[node.feature] <= node.threshold else node.right
        return node.value

    def most_common_label(self, label):
        """Return the most frequent entry of the 1-D label sequence ``label``."""
        # Imported locally because the notebook's import cell only loads numpy.
        from collections import Counter

        counts = Counter(np.asarray(label).tolist())
        return counts.most_common(1)[0][0]
131+
]
132+
},
133+
{
134+
"cell_type": "code",
135+
"execution_count": null,
136+
"metadata": {},
137+
"outputs": [],
138+
"source": []
139+
}
140+
],
141+
"metadata": {
142+
"kernelspec": {
143+
"display_name": "Python 3",
144+
"language": "python",
145+
"name": "python3"
146+
},
147+
"language_info": {
148+
"codemirror_mode": {
149+
"name": "ipython",
150+
"version": 3
151+
},
152+
"file_extension": ".py",
153+
"mimetype": "text/x-python",
154+
"name": "python",
155+
"nbconvert_exporter": "python",
156+
"pygments_lexer": "ipython3",
157+
"version": "3.7.6"
158+
}
159+
},
160+
"nbformat": 4,
161+
"nbformat_minor": 4
162+
}

0 commit comments

Comments
 (0)