|
| 1 | +# dataClassifier.py |
| 2 | +# ----------------- |
| 3 | + |
| 4 | +import mostFrequent |
| 5 | +import perceptron |
| 6 | +import svm |
| 7 | +import mlp |
| 8 | +import samples |
| 9 | +import sys |
| 10 | +import util |
| 11 | + |
| 12 | +TRAINING_SET_SIZE = 5000 |
| 13 | +TEST_SET_SIZE = 1000 |
| 14 | +DIGIT_DATUM_WIDTH = 28 |
| 15 | +DIGIT_DATUM_HEIGHT = 28 |
| 16 | + |
| 17 | + |
| 18 | +def basicFeatureExtractorDigit(datum): |
| 19 | + """ |
| 20 | + Returns a set of pixel features indicating whether |
| 21 | + each pixel in the provided datum is white (0) or gray/black (1) |
| 22 | + """ |
| 23 | + features = util.Counter() |
| 24 | + for x in range(DIGIT_DATUM_WIDTH): |
| 25 | + for y in range(DIGIT_DATUM_HEIGHT): |
| 26 | + if datum.getPixel(x, y) > 0: |
| 27 | + features[(x, y)] = 1 |
| 28 | + else: |
| 29 | + features[(x, y)] = 0 |
| 30 | + return features |
| 31 | + |
| 32 | +def analysis(classifier, guesses, testLabels, testData, rawTestData, printImage): |
| 33 | + """ |
| 34 | + This function is called after learning. |
| 35 | + Include any code that you want here to help you analyze your results. |
| 36 | +
|
| 37 | + Use the printImage(<list of pixels>) function to visualize features. |
| 38 | +
|
| 39 | + An example of use has been given to you. |
| 40 | +
|
| 41 | + - classifier is the trained classifier |
| 42 | + - guesses is the list of labels predicted by your classifier on the test set |
| 43 | + - testLabels is the list of true labels |
| 44 | + - testData is the list of training datapoints (as util.Counter of features) |
| 45 | + - rawTestData is the list of training datapoints (as samples.Datum) |
| 46 | + - printImage is a method to visualize the features |
| 47 | + (see its use in the odds ratio part in runClassifier method) |
| 48 | +
|
| 49 | + This code won't be evaluated. It is for your own optional use |
| 50 | + (and you can modify the signature if you want). |
| 51 | + """ |
| 52 | + |
| 53 | + # Put any code here... |
| 54 | + # Example of use: |
| 55 | + for i in range(len(guesses)): |
| 56 | + prediction = guesses[i] |
| 57 | + truth = testLabels[i] |
| 58 | + if (prediction != truth): |
| 59 | + print "===================================" |
| 60 | + print "Mistake on example %d" % i |
| 61 | + print "Predicted %d; truth is %d" % (prediction, truth) |
| 62 | + print "Image: " |
| 63 | + print rawTestData[i] |
| 64 | + break |
| 65 | + |
| 66 | + |
| 67 | +class ImagePrinter: |
| 68 | + def __init__(self, width, height): |
| 69 | + self.width = width |
| 70 | + self.height = height |
| 71 | + |
| 72 | + def printImage(self, pixels): |
| 73 | + """ |
| 74 | + Prints a Datum object that contains all pixels in the |
| 75 | + provided list of pixels. This will serve as a helper function |
| 76 | + to the analysis function you write. |
| 77 | +
|
| 78 | + Pixels should take the form |
| 79 | + [(2,2), (2, 3), ...] |
| 80 | + where each tuple represents a pixel. |
| 81 | + """ |
| 82 | + image = samples.Datum(None, self.width, self.height) |
| 83 | + for pix in pixels: |
| 84 | + try: |
| 85 | + # This is so that new features that you could define which |
| 86 | + # which are not of the form of (x,y) will not break |
| 87 | + # this image printer... |
| 88 | + x, y = pix |
| 89 | + image.pixels[x][y] = 2 |
| 90 | + except: |
| 91 | + print "new features:", pix |
| 92 | + continue |
| 93 | + print image |
| 94 | + |
| 95 | + |
| 96 | +def default(str): |
| 97 | + return str + ' [Default: %default]' |
| 98 | + |
| 99 | + |
| 100 | +def readCommand(argv): |
| 101 | + "Processes the command used to run from the command line." |
| 102 | + from optparse import OptionParser |
| 103 | + parser = OptionParser(USAGE_STRING) |
| 104 | + |
| 105 | + parser.add_option('-c', '--classifier', help=default('The type of classifier'), |
| 106 | + choices=['mostFrequent', 'perceptron', 'mlp', 'svm'], default='mostFrequent') |
| 107 | + parser.add_option('-t', '--training', help=default('The size of the training set'), default=TRAINING_SET_SIZE, |
| 108 | + type="int") |
| 109 | + parser.add_option('-w', '--weights', help=default('Whether to print weights'), default=False, action="store_true") |
| 110 | + parser.add_option('-i', '--iterations', help=default("Maximum iterations to run training"), default=3, type="int") |
| 111 | + parser.add_option('-s', '--test', help=default("Amount of test data to use"), default=TEST_SET_SIZE, type="int") |
| 112 | + |
| 113 | + options, otherjunk = parser.parse_args(argv) |
| 114 | + if len(otherjunk) != 0: raise Exception('Command line input not understood: ' + str(otherjunk)) |
| 115 | + args = {} |
| 116 | + |
| 117 | + # Set up variables according to the command line input. |
| 118 | + print "Doing classification" |
| 119 | + print "--------------------" |
| 120 | + print "classifier:\t\t" + options.classifier |
| 121 | + print "training set size:\t" + str(options.training) |
| 122 | + |
| 123 | + printImage = ImagePrinter(DIGIT_DATUM_WIDTH, DIGIT_DATUM_HEIGHT).printImage |
| 124 | + featureFunction = basicFeatureExtractorDigit |
| 125 | + legalLabels = range(10) |
| 126 | + |
| 127 | + if options.training <= 0: |
| 128 | + print "Training set size should be a positive integer (you provided: %d)" % options.training |
| 129 | + print USAGE_STRING |
| 130 | + sys.exit(2) |
| 131 | + |
| 132 | + if (options.classifier == "mostFrequent"): |
| 133 | + classifier = mostFrequent.MostFrequentClassifier(legalLabels) |
| 134 | + elif (options.classifier == "mlp"): |
| 135 | + classifier = mlp.MLPClassifier(legalLabels, options.iterations) |
| 136 | + elif (options.classifier == "perceptron"): |
| 137 | + classifier = perceptron.PerceptronClassifier(legalLabels, options.iterations) |
| 138 | + elif (options.classifier == "svm"): |
| 139 | + classifier = svm.SVMClassifier(legalLabels) |
| 140 | + else: |
| 141 | + print "Unknown classifier:", options.classifier |
| 142 | + print USAGE_STRING |
| 143 | + |
| 144 | + sys.exit(2) |
| 145 | + |
| 146 | + args['classifier'] = classifier |
| 147 | + args['featureFunction'] = featureFunction |
| 148 | + args['printImage'] = printImage |
| 149 | + |
| 150 | + return args, options |
| 151 | + |
| 152 | + |
| 153 | +USAGE_STRING = """ |
| 154 | + USAGE: python dataClassifier.py <options> |
| 155 | + EXAMPLES: (1) python dataClassifier.py |
| 156 | + - trains the default mostFrequent classifier on the digit dataset |
| 157 | + using the default 100 training examples and |
| 158 | + then test the classifier on test data |
| 159 | + (2) python dataClassifier.py -c perceptron -t 1000 -s 500 |
| 160 | + - would run the perceptron classifier on 1000 training examples, would |
| 161 | + test the classifier on 500 test data points |
| 162 | + """ |
| 163 | + |
| 164 | + |
| 165 | +# Main harness code |
| 166 | + |
| 167 | +def runClassifier(args, options): |
| 168 | + featureFunction = args['featureFunction'] |
| 169 | + classifier = args['classifier'] |
| 170 | + printImage = args['printImage'] |
| 171 | + |
| 172 | + # Load data |
| 173 | + numTraining = options.training |
| 174 | + numTest = options.test |
| 175 | + |
| 176 | + rawTrainingData = samples.loadDataFile("data/digitdata/trainingimages", numTraining, DIGIT_DATUM_WIDTH, |
| 177 | + DIGIT_DATUM_HEIGHT) |
| 178 | + trainingLabels = samples.loadLabelsFile("data/digitdata/traininglabels", numTraining) |
| 179 | + completeRawTrainingData = samples.loadDataFile("data/digitdata/trainingimages", 5000, DIGIT_DATUM_WIDTH, |
| 180 | + DIGIT_DATUM_HEIGHT) |
| 181 | + completeTrainingLabels = samples.loadLabelsFile("data/digitdata/traininglabels", 5000) |
| 182 | + rawValidationData = samples.loadDataFile("data/digitdata/validationimages", numTest, DIGIT_DATUM_WIDTH, |
| 183 | + DIGIT_DATUM_HEIGHT) |
| 184 | + validationLabels = samples.loadLabelsFile("data/digitdata/validationlabels", numTest) |
| 185 | + rawTestData = samples.loadDataFile("data/digitdata/testimages", numTest, DIGIT_DATUM_WIDTH, DIGIT_DATUM_HEIGHT) |
| 186 | + testLabels = samples.loadLabelsFile("data/digitdata/testlabels", numTest) |
| 187 | + |
| 188 | + # Extract features |
| 189 | + print "Extracting features..." |
| 190 | + trainingData = map(featureFunction, rawTrainingData) |
| 191 | + completeTrainingData = map(featureFunction, completeRawTrainingData) |
| 192 | + validationData = map(featureFunction, rawValidationData) |
| 193 | + testData = map(featureFunction, rawTestData) |
| 194 | + |
| 195 | + # Conduct training and testing |
| 196 | + print "Training..." |
| 197 | + classifier.train(trainingData, trainingLabels, validationData, validationLabels) |
| 198 | + print "Validating..." |
| 199 | + guesses = classifier.classify(validationData) |
| 200 | + correct = [guesses[i] == validationLabels[i] for i in range(len(validationLabels))].count(True) |
| 201 | + print str(correct), ("correct out of " + str(len(validationLabels)) + " (%.1f%%).") % ( |
| 202 | + 100.0 * correct / len(validationLabels)) |
| 203 | + print "Testing..." |
| 204 | + guesses = classifier.classify(testData) |
| 205 | + correct = [guesses[i] == testLabels[i] for i in range(len(testLabels))].count(True) |
| 206 | + print str(correct), ("correct out of " + str(len(testLabels)) + " (%.1f%%).") % (100.0 * correct / len(testLabels)) |
| 207 | + print "Testing training data..." |
| 208 | + guesses = classifier.classify(completeTrainingData) |
| 209 | + correct = [guesses[i] == completeTrainingLabels[i] for i in range(len(completeTrainingLabels))].count(True) |
| 210 | + print str(correct), ("correct out of " + str(len(completeTrainingLabels)) + " (%.1f%%).") % ( |
| 211 | + 100.0 * correct / len(completeTrainingLabels)) |
| 212 | + |
| 213 | + analysis(classifier, guesses, testLabels, testData, rawTestData, printImage) |
| 214 | + |
| 215 | + if ((options.classifier == "perceptron")): |
| 216 | + for l in classifier.legalLabels: |
| 217 | + features_weights = classifier.findHighWeightFeatures(l) |
| 218 | + print ("=== Features with high weight for label %d ===" % l) |
| 219 | + printImage(features_weights) |
| 220 | + |
| 221 | + |
| 222 | +if __name__ == '__main__': |
| 223 | + # Read input |
| 224 | + args, options = readCommand(sys.argv[1:]) |
| 225 | + # Run classifier |
| 226 | + runClassifier(args, options) |
0 commit comments