SSL-Based Masked Image Prediction
In this post, I want to talk about self-supervised learning (SSL), which lets us learn from unlabeled data. Labeled data is expensive: annotation is time-consuming and costly, so curating large labeled datasets is often not feasible. Unlabeled data, by contrast, is cheap, because large amounts of it already exist and can be collected from the internet.
Learning from unlabeled data is attractive from the perspectives of scalability, breadth, and generalized knowledge. To get started in this direction, I began with one of the first examples my professor, Yuyin Zhou at UCSC, introduced: masked image prediction. The goal of masked image prediction is to train a deep neural network that, given a masked image, reconstructs the full image. More generally, self-supervised learning obtains its supervisory signal from the data itself, so its success depends heavily on the choice of learning objective.
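Written as an objective, masked image prediction trains an encoder $E_\theta$ and a decoder $D_\theta$ on images $x_i$, each paired with a random binary mask $m_i$ (1 = visible pixel, 0 = hidden pixel). The post does not name the loss, so the pixel-wise mean squared error below is an assumption:

$$
\mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \left\lVert D_\theta\big(E_\theta(m_i \odot x_i)\big) - x_i \right\rVert_2^2
$$

Here $\odot$ is element-wise multiplication, so $m_i \odot x_i$ is the corrupted input, and the target is the original image $x_i$ itself. That is exactly where the supervisory signal "comes from the data": no human label appears anywhere in the loss.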
For masked image prediction, the methodology I used was:
- Load the MNIST data with a DataLoader.
- Train for 10 epochs. In each epoch, for every image in MNIST:
  - Generate a random mask.
  - Apply the mask to the image.
  - Pass the corrupted image to a CNN-based encoder-decoder model.
  - Compute the loss between the generated output and the original MNIST image.
  - Backpropagate.
- Also evaluate performance on the test set.
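The steps above can be sketched in PyTorch roughly as follows. This is a minimal sketch under my own assumptions, not the original code: the patch-based masking (7×7 patches, 50% hidden), the two-layer convolutional encoder and decoder, the MSE loss, and the names `random_patch_mask`, `MaskedAutoencoder`, `train_one_epoch`, and `evaluate` are all hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def random_patch_mask(batch_size, img_size=28, patch=7, mask_ratio=0.5, device="cpu"):
    """Hypothetical masking scheme: hide a random subset of square patches.

    Returns a (batch_size, 1, img_size, img_size) tensor, 1 = visible, 0 = hidden.
    Assumes img_size is divisible by patch (28 / 7 = 4 patches per side).
    """
    n = img_size // patch
    # One independent keep/hide decision per patch.
    keep = (torch.rand(batch_size, 1, n, n, device=device) >= mask_ratio).float()
    # Upsample the patch grid to pixel resolution.
    return keep.repeat_interleave(patch, dim=2).repeat_interleave(patch, dim=3)


class MaskedAutoencoder(nn.Module):
    """Assumed CNN encoder-decoder for 1x28x28 MNIST images."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),   # 28 -> 14
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 14 -> 7
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 7 -> 14
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),   # 14 -> 28
            nn.Sigmoid(),  # pixel values in [0, 1], matching ToTensor() output
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))


def train_one_epoch(model, loader, optimizer, device="cpu"):
    """One pass over the data; returns the mean per-pixel MSE."""
    model.train()
    total = 0.0
    for images, _ in loader:  # labels are ignored: the image is its own target
        images = images.to(device)
        mask = random_patch_mask(images.size(0), device=device)  # new mask every step
        recon = model(images * mask)          # the model only sees the corrupted image
        loss = F.mse_loss(recon, images)      # compare against the original image
        optimizer.zero_grad()
        loss.backward()                       # backpropagate
        optimizer.step()
        total += loss.item() * images.size(0)
    return total / len(loader.dataset)


@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Mean per-pixel MSE on held-out data with freshly sampled masks."""
    model.eval()
    total = 0.0
    for images, _ in loader:
        images = images.to(device)
        mask = random_patch_mask(images.size(0), device=device)
        recon = model(images * mask)
        total += F.mse_loss(recon, images).item() * images.size(0)
    return total / len(loader.dataset)
```

On real data, the loaders could come from `torchvision.datasets.MNIST(root, train=True, download=True, transform=transforms.ToTensor())` (and `train=False` for the test set) wrapped in `torch.utils.data.DataLoader`, followed by ten rounds of `train_one_epoch` and `evaluate` with an optimizer such as `torch.optim.Adam(model.parameters(), lr=1e-3)`. Masking whole patches rather than scattered pixels is a design choice: isolated missing pixels can often be filled in by interpolating from neighbors, while hiding larger regions tends to force the network to use more of the digit's global structure.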
 
During inference, we simply give the model a masked image, and it predicts the reconstructed image. I have linked the code and demo in the references.
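A minimal inference sketch, assuming `model` is any trained PyTorch module that maps a (B, 1, 28, 28) masked image to a same-shaped reconstruction. The `reconstruct` helper and its `keep_visible` option are my own illustration, not part of the original code.

```python
import torch


@torch.no_grad()
def reconstruct(model, image, mask, keep_visible=True):
    """Hypothetical helper: run a trained model on a masked image.

    image, mask: (1, 28, 28) or (B, 1, 28, 28) tensors; mask is 1 = visible, 0 = hidden.
    """
    model.eval()
    if image.dim() == 3:  # add a batch dimension for a single image
        image, mask = image.unsqueeze(0), mask.unsqueeze(0)
    recon = model(image * mask)  # the model only ever sees the masked input
    if keep_visible:
        # Keep the pixels we already know; use the prediction only where the mask hid them.
        recon = mask * image + (1 - mask) * recon
    return recon
```

Pasting the visible pixels back (`keep_visible=True`) is optional. If the training loss covers the whole image, the model also re-predicts visible pixels, and overwriting them with the known values guarantees that only the hidden regions change.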
There are several other interesting self-supervised learning objectives, such as image colorization, rotation prediction, next-token prediction, dense feature prediction with a student-teacher model, solving jigsaw puzzles, 3D point cloud learning, and 2D-3D feature learning, which I want to explore in the future.
References