One of the hottest disciplines in the tech industry in 2017 was Deep Learning. Thanks to Deep Learning, many startups put an emphasis on AI, and many frameworks have been developed to make these algorithms easier to implement. Google’s DeepMind even created AlphaGo Zero, which mastered the game of Go without relying on data from human games. The analysis in this post, however, is much more basic than anything developed recently: it uses the popular MNIST dataset, a collection of handwritten digits commonly used to test computer vision models.
I wanted to work on this dataset to get my feet wet with neural networks. In all of my dataset analysis posts, I have yet to work on a dataset that needed a neural network. It’s true that I could have just used an SVM or logistic regression, but neural networks tend to handle the complexity of image data better than those models do.
Each image is 28×28 pixels, is grayscale, and contains a single digit from 0 to 9. The input is each individual pixel (784 pixels in total), and the output is a digit ranging from 0 to 9. Once the network is trained, a file of test cases is fed into the model to predict each digit.
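To make the input and output shapes concrete, here is a minimal NumPy sketch of the preprocessing this implies. The random "images" and the label values are hypothetical stand-ins for real MNIST data, and the one-hot encoding assumes the model is trained with categorical cross-entropy (Keras can also work with integer labels directly).

```python
import numpy as np

# Hypothetical stand-in data: five random 28x28 grayscale "images" with uint8 pixels.
# Real MNIST images would be loaded from the dataset instead.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)
labels = np.array([3, 0, 9, 1, 7])  # hypothetical digit labels

# Flatten each 28x28 image into a 784-element vector: one input per pixel.
X = images.reshape(len(images), 28 * 28)

# One-hot encode the labels so each of the 10 outputs corresponds to one digit 0-9.
Y = np.eye(10, dtype=np.float32)[labels]

print(X.shape)  # (5, 784)
print(Y.shape)  # (5, 10)
```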
In this analysis, I used Keras to train a neural network. I could have used TensorFlow directly to create my model, but TensorFlow requires more low-level work to implement complex models. With Keras, I can use TensorFlow as the backend to build an efficient neural network without writing out every detail.
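To illustrate the kind of low-level detail that Keras’s layers take care of, here is a minimal NumPy sketch of the forward pass of a small fully connected network. The 784 → 128 → 10 layer sizes, the ReLU activation, and the random initialization are hypothetical choices, not the architecture from the notebook, and training (backpropagation and the optimizer), which Keras and TensorFlow also handle, is left out entirely.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical layer sizes: 784 pixel inputs -> 128 hidden units -> 10 digit outputs.
W1 = rng.normal(0.0, 0.01, size=(784, 128)).astype(np.float32)
b1 = np.zeros(128, dtype=np.float32)
W2 = rng.normal(0.0, 0.01, size=(128, 10)).astype(np.float32)
b2 = np.zeros(10, dtype=np.float32)

def relu(z):
    # Element-wise ReLU activation for the hidden layer.
    return np.maximum(z, 0.0)

def softmax(z):
    # Subtract each row's max before exponentiating for numerical stability.
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def forward(x):
    """Forward pass of a two-layer dense network on a batch of flattened images."""
    h = relu(x @ W1 + b1)
    return softmax(h @ W2 + b2)  # one probability per digit for each image

# Stand-in batch of four normalized, flattened images (values in [0, 1)).
x = rng.random((4, 784), dtype=np.float32)
probs = forward(x)
predicted_digits = probs.argmax(axis=1)  # most probable digit per image
```

With untrained random weights the predictions are meaningless; the point is that every weight matrix, bias, and activation has to be managed by hand, which is the bookkeeping a few Keras layer definitions replace.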
The first time I attempted to train the network, the images were not wrangled in any way. This produced a poorly trained network that achieved only 20% accuracy. It’s possible that the raw pixel values, which range from 0 to 255, varied too widely for the network to train well. So in the second attempt, I first normalized the images by dividing each pixel value by 255, scaling them into the range 0 to 1, and then used those images to train the network. This change raised accuracy dramatically, to about 97-99%.
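The normalization step can be sketched in NumPy as follows. The helper name `normalize` and the tiny sample array are hypothetical; the key details are casting to a floating-point type before dividing, and dividing by 255, the maximum value of an 8-bit pixel.

```python
import numpy as np

def normalize(images: np.ndarray) -> np.ndarray:
    """Scale uint8 pixel values from [0, 255] down to [0.0, 1.0]."""
    # Cast first so the division happens in floating point rather than on uint8.
    return images.astype(np.float32) / 255.0

# Hypothetical sample: a black, a mid-gray, and a white pixel.
raw = np.array([[0, 128, 255]], dtype=np.uint8)
scaled = normalize(raw)
print(scaled)
```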
It’s definitely possible to fine-tune the neural network to squeeze out more accuracy, but at ~99% accuracy, the extra effort wouldn’t be justified.
The notebook can be found on my GitHub page. How would you build a neural network on this dataset? Let me know in the comments section.