Training Neural Nets in the Browser
Two months ago, Google announced TensorflowJS, a library for training models and running inference in the browser. This opens the door to a lot of amazing web applications.
With TensorflowJS, we can take a pretrained model and personalize it with the user’s own data (keeping that data local and private!). We can also make fast predictions without waiting to send data to a model up in the cloud.
I’m interested in how to teach deep learning well, so I decided to spend a week exploring TensorflowJS thoroughly. As Tensorflow Playground demonstrated, it’s very helpful for deep learning newcomers to be able to tinker easily.
Continuing the idea of experimenting with hyperparameters and architectural choices, I created a demo page where you can train different kinds of models on the MNIST dataset and see in real time how your choices affect the final accuracy. It is loosely based on one of the TFJS tutorials, but I added several model and parameter options, and I fixed the bias in their train/test split. Hop over to my demo page to try it out.
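The post doesn’t describe what the bias in the tutorial’s split was, so the following is only a generic illustration, not the actual fix. If examples are stored in some order (for instance grouped by class) and the split simply takes the last portion as the test set, the test set won’t resemble the training data. Shuffling indices before splitting avoids that. The helper names `makeRng` (a small mulberry32-style seeded generator) and `shuffledSplit`, and the 80/20 split, are hypothetical choices for this sketch.

```javascript
// Sketch: shuffle example indices before splitting, so that ordering in the
// stored data (e.g. examples grouped by class) does not leak into the split.

// Small seeded PRNG (mulberry32-style) so the split is reproducible.
function makeRng(seed) {
  let state = seed >>> 0;
  return () => {
    state = (state + 0x6d2b79f5) >>> 0;
    let t = state;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296; // float in [0, 1)
  };
}

// Hypothetical helper: returns disjoint, shuffled train/test index arrays.
function shuffledSplit(numExamples, testFraction, seed = 42) {
  const rng = makeRng(seed);
  const indices = Array.from({ length: numExamples }, (_, i) => i);
  // Fisher-Yates shuffle: swap each position with a random position at or before it.
  for (let i = indices.length - 1; i > 0; i--) {
    const j = Math.floor(rng() * (i + 1));
    [indices[i], indices[j]] = [indices[j], indices[i]];
  }
  const numTest = Math.round(numExamples * testFraction);
  return { test: indices.slice(0, numTest), train: indices.slice(numTest) };
}

// Toy labels stored in sorted order: 50 examples of class 0, then 50 of class 1.
const labels = Array.from({ length: 100 }, (_, i) => (i < 50 ? 0 : 1));

// Naive split: the last 20% become the test set, which contains only class 1.
const naiveTest = labels.slice(80);

// Shuffled split: the test set draws from both classes.
const { train, test } = shuffledSplit(labels.length, 0.2);
const testLabels = test.map(i => labels[i]);

console.log("naive test classes:", new Set(naiveTest));
console.log("shuffled test classes:", new Set(testLabels));
```

Seeding the shuffle keeps the split identical across page reloads, so accuracy differences between runs come from the model choices rather than from a different test set.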
Like everything in deep learning, TFJS is moving fast. Keep an eye on ml5.js: they’re building a wrapper around TensorflowJS that aims to be a “friendly machine learning library” for artists, hobbyists, and students, and they have some beautiful-looking demos. Personally, I wouldn’t call it “friendly” just yet. It’s still missing a lot of documentation, and at this point none of the examples I tried worked easily. I suspect in another month or two it’ll be in great shape.
Getting to know TensorflowJS
I came to TensorflowJS with a fair amount of Tensorflow experience but only a little web design experience. I suspect the path would actually be easier coming from the opposite direction. TensorflowJS itself is quite easy to pick up and a very natural extension of regular Tensorflow. I had no problems building fun local extensions of the tutorials, though deploying to production was trickier (long story short, tfjs-angular-firebase seems a good way to go).
I began the week by working through the TensorflowJS Tutorials. Much like Tensorflow itself, you can work with TFJS at different levels of abstraction. Each tutorial focuses on one of these levels:
- Math: Polynomial Regressions introduces the lowest level, where you perform math operations directly (addition, matrix multiplication, etc.). It’s not the level where you’d normally implement a model, but the neat thing is that at this level we don’t need to be doing neural nets at all. For example, check out the Realtime tSNE Visualization, which uses TFJS to create real-time interactive visualizations of high-dimensional data.
- Layers: MNIST Training moves one level up. This will feel very familiar to anyone who has used Keras. We start with model = tf.sequential() and then add layers (convolutional, fully connected, pooling, ReLU, etc.) to the model. Then we compile the model and train it using model.fit().
- Pretrained Model: The Pacman Tutorial is by far the most fun. It imports a pretrained model (here, MobileNet) and fine-tunes it on a set of webcam images. We start with four categories (up/down/left/right), although it’s trivial to change this in the code. After fine-tuning, we switch to prediction mode and feed the streaming predictions directly into the Pacman game, so we can control the game with our webcam.
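To make the Math level concrete: the polynomial regression tutorial fits a curve to data using raw operations rather than layers. The sketch below does the same kind of fit in plain JavaScript (no TFJS, so it runs without dependencies): it learns the coefficients of a cubic y = a*x^3 + b*x^2 + c*x + d by gradient descent on mean squared error. The cubic form, the true coefficients, the learning rate of 0.5, and the 2000 iterations are illustrative assumptions, and the helper name `fitCubic` is hypothetical.

```javascript
// Sketch: fit y = a*x^3 + b*x^2 + c*x + d by gradient descent on mean squared
// error, in plain JavaScript (no TFJS) so it runs without dependencies.
function fitCubic(xs, ys, learningRate = 0.5, iterations = 2000) {
  let coeffs = [0, 0, 0, 0]; // [a, b, c, d], initialized to zero
  const n = xs.length;
  for (let step = 0; step < iterations; step++) {
    const grad = [0, 0, 0, 0];
    for (let i = 0; i < n; i++) {
      const features = [xs[i] ** 3, xs[i] ** 2, xs[i], 1];
      const pred = features.reduce((sum, f, k) => sum + f * coeffs[k], 0);
      const err = pred - ys[i];
      // d/d(coeff_k) of mean((pred - y)^2) is (2/n) * sum(err * feature_k)
      for (let k = 0; k < 4; k++) grad[k] += (2 / n) * err * features[k];
    }
    coeffs = coeffs.map((w, k) => w - learningRate * grad[k]);
  }
  return coeffs;
}

// Noise-free synthetic data from known coefficients, with x on a grid in [-1, 1].
const trueCoeffs = [-0.8, -0.2, 0.9, 0.5];
const xs = Array.from({ length: 101 }, (_, i) => -1 + i * 0.02);
const ys = xs.map(x => trueCoeffs[0] * x ** 3 + trueCoeffs[1] * x ** 2 + trueCoeffs[2] * x + trueCoeffs[3]);

const learned = fitCubic(xs, ys);
console.log("learned [a, b, c, d]:", learned.map(w => w.toFixed(3)));
```

In TFJS you would typically hold the coefficients in `tf.variable` tensors and let an optimizer such as `tf.train.sgd` compute the gradients automatically, instead of deriving them by hand as this sketch does.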
Personally, I like to work through tutorials by looking over the code, then switching to a blank page and trying to recreate it myself. It’s more painful than just reading the code samples, but I highly recommend it for code that you want to understand well: it’s amazing how many little details you notice by taking this extra step. In my case, the Pacman tutorial took me a morning to recreate from scratch (at first I glanced back at the original code a fair amount, but I soon needed it less and less).
Based on my few days of playing with TensorflowJS, it works best when you have a straightforward pretrained model that needs fine-tuning. TFJS isn’t geared toward custom loss functions, lambda layers, or other customizations. Things are changing fast, though, and these may soon be easier to do. My main annoyance with TFJS is the lack of guidance: the tutorials are nice, but I couldn’t find good documentation beyond them. I often tried to use a Tensorflow function, only to find later that it doesn’t exist in TFJS.
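The fine-tuning pattern that suits TFJS, as in the Pacman tutorial, splits the work in two: a frozen pretrained network turns each webcam frame into a feature vector, and only a small classifier on top is trained on the user’s examples. The dependency-free sketch below shows just the trainable part. The 2-D vectors are made-up stand-ins for the pretrained network’s activations (real MobileNet activations are far larger), the head is a softmax classifier trained by gradient descent on cross-entropy, and `trainHead` and `predictDirection` are hypothetical names.

```javascript
const DIRECTIONS = ["up", "down", "left", "right"];

// Softmax over a vector of scores (subtract the max for numerical stability).
function softmax(scores) {
  const m = Math.max(...scores);
  const exps = scores.map(s => Math.exp(s - m));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / sum);
}

// Trainable head: one weight vector and bias per class, on top of frozen features.
function trainHead(features, labels, numClasses, learningRate = 0.5, epochs = 200) {
  const dim = features[0].length;
  const W = Array.from({ length: numClasses }, () => new Array(dim).fill(0));
  const b = new Array(numClasses).fill(0);
  for (let epoch = 0; epoch < epochs; epoch++) {
    for (let i = 0; i < features.length; i++) {
      const x = features[i];
      const scores = W.map((w, k) => w.reduce((s, wj, j) => s + wj * x[j], b[k]));
      const probs = softmax(scores);
      // Gradient of cross-entropy with respect to each score is (prob - oneHot).
      for (let k = 0; k < numClasses; k++) {
        const g = probs[k] - (k === labels[i] ? 1 : 0);
        for (let j = 0; j < dim; j++) W[k][j] -= learningRate * g * x[j];
        b[k] -= learningRate * g;
      }
    }
  }
  return { W, b };
}

// Pick the highest-scoring class and return its direction name.
function predictDirection(head, x) {
  const scores = head.W.map((w, k) => w.reduce((s, wj, j) => s + wj * x[j], head.b[k]));
  return DIRECTIONS[scores.indexOf(Math.max(...scores))];
}

// Stand-in "frozen features" for two webcam frames per class (assumed values).
const features = [
  [0.1, 1.0], [-0.1, 0.9],   // up
  [0.1, -1.0], [-0.1, -0.9], // down
  [-1.0, 0.1], [-0.9, -0.1], // left
  [1.0, 0.1], [0.9, -0.1],   // right
];
const labels = [0, 0, 1, 1, 2, 2, 3, 3];
const head = trainHead(features, labels, DIRECTIONS.length);

// "Prediction mode": classify new frames and hand the direction to the game.
console.log(predictDirection(head, [0.05, 0.95]));
console.log(predictDirection(head, [-0.95, 0.0]));
```

Because only the small head is trained, a handful of examples per class and a few seconds of computation are enough, which is why this kind of personalization works well in the browser.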
Between TFJS and TFLite (Tensorflow for mobile devices), it will be very exciting to watch the new deep learning web and mobile apps that emerge. In particular, I think it’ll be a fantastic tool for artists, musicians, and educators.
As part of my tinkering this week, I created a TensorflowJS demo page that lets you try out several different models for training on MNIST. You can try convolutional nets with different filter sizes, or go the fully connected route and see how that compares. You can also try removing the pooling layers or the ReLU layers. Most variations eventually reach good accuracy, although some train more slowly than others. Finally, you can play around with the learning rate and the number of batches to train.
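One way to see what removing the pooling layers does is to track tensor shapes and parameter counts through the stack. TFJS’s Layers API builds such stacks with factories like `tf.layers.conv2d`, `tf.layers.maxPooling2d`, `tf.layers.flatten`, and `tf.layers.dense`; the dependency-free sketch below only does the bookkeeping. It assumes stride-1 convolutions without padding and 2x2 pooling with stride 2, and the layer sizes (5x5 kernels, 8 and 16 filters) are illustrative, not necessarily what the demo uses. `summarize` is a hypothetical helper.

```javascript
// Hypothetical shape/parameter tracker for a sequential conv net.
// Shapes are [height, width, channels]; convolutions use no padding.
function outSize(size, window, stride) {
  return Math.floor((size - window) / stride) + 1;
}

function summarize(inputShape, layers) {
  let shape = inputShape;
  let totalParams = 0;
  const rows = [];
  for (const layer of layers) {
    let params = 0;
    if (layer.type === "conv2d") {
      const [h, w, c] = shape;
      // One kernel per (input channel, filter) pair, plus one bias per filter.
      params = layer.kernel * layer.kernel * c * layer.filters + layer.filters;
      shape = [outSize(h, layer.kernel, 1), outSize(w, layer.kernel, 1), layer.filters];
    } else if (layer.type === "maxPooling2d") {
      const [h, w, c] = shape;
      shape = [outSize(h, layer.pool, layer.pool), outSize(w, layer.pool, layer.pool), c];
    } else if (layer.type === "flatten") {
      shape = [shape.reduce((a, b) => a * b, 1)];
    } else if (layer.type === "dense") {
      // Expects a flattened input: weights plus one bias per unit.
      params = shape[0] * layer.units + layer.units;
      shape = [layer.units];
    } else {
      throw new Error(`unknown layer type: ${layer.type}`);
    }
    totalParams += params;
    rows.push({ type: layer.type, shape, params });
  }
  return { rows, totalParams };
}

const withPooling = [
  { type: "conv2d", kernel: 5, filters: 8 },
  { type: "maxPooling2d", pool: 2 },
  { type: "conv2d", kernel: 5, filters: 16 },
  { type: "maxPooling2d", pool: 2 },
  { type: "flatten" },
  { type: "dense", units: 10 },
];
const withoutPooling = withPooling.filter(l => l.type !== "maxPooling2d");

const pooled = summarize([28, 28, 1], withPooling);
const unpooled = summarize([28, 28, 1], withoutPooling);
console.table(pooled.rows);
console.table(unpooled.rows);
console.log("total params with pooling:", pooled.totalParams);
console.log("total params without pooling:", unpooled.totalParams);
```

Under these assumptions, the flattened output grows from 4x4x16 = 256 to 20x20x16 = 6400 values when pooling is removed, so the final dense layer needs roughly 25 times as many weights. That is one plausible reason some variants train more slowly.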
I’m excited for next week, as I begin my dive into Reinforcement Learning. I’m planning to study the Deep RL Bootcamp and UC Berkeley’s CS 294 over the next three weeks and will continue to tell the tales here.