Applying Stochastic Gradient Descent (SGD) to a single-feature linear regression problem
SGD is an optimisation algorithm that updates model parameters using the gradient of the loss function with respect to those parameters, computed on a small, randomly chosen subset of the data.
In this blog post, I'll explain how SGD works by applying it to a basic ML scenario. In the future I hope to apply it to more complex scenarios such as CNNs and reinforcement learning.
1. Single-feature Linear Regression (Easy)
Our goal is to predict the price of a car based on its age. Our single feature is the age.
Let $x$ be the age of the car and $y$ the actual price.
We guess that there is a linear relationship between price and age, and hence:

$$\hat{y} = wx + b$$

where $\hat{y}$ is the model's prediction of the price, $w$ is the weight, and $b$ is the bias.
We define our "loss" function to measure how wrong our model is with

$$L = \frac{1}{2}(\hat{y} - y)^2$$
Our goal is to minimise this loss function.
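To make this concrete, here is a minimal sketch of the model and the loss in Python (the names `predict` and `loss` are just my own illustrative choices):

```python
def predict(w, b, x):
    # Linear model: predicted price = weight * age + bias
    return w * x + b

def loss(y_hat, y):
    # Squared-error loss, halved so the gradients later have no factor of 2
    return 0.5 * (y_hat - y) ** 2
```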
Step 1: Randomly initialise parameters.
Let's say we start with small random values for $w$ and $b$.
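In code, this step might look something like the following (the sampling range is an arbitrary choice):

```python
import random

# Step 1: start from small random values for the parameters
w = random.uniform(-1, 1)
b = random.uniform(-1, 1)
```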
Step 2: Randomly pick one training example
It's important that the example is picked at random and that it's only a small part of the data (this is the 'stochastic' part).
Suppose the example we pick has age $x$ and actual price $y$.
Our model would predict the price to be $\hat{y} = wx + b$ (check this using the formula from before).
Step 3: Compute the loss
Using our loss function, the loss for this example is $28.125$, which corresponds to a prediction that is $7.5$ away from the actual price.
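Continuing the earlier sketch, Steps 2 and 3 could look like this; the ages and prices below are made-up placeholder data, not the numbers used in this example:

```python
# Hypothetical training data: (age in years, price in thousands)
data = [(1, 20.0), (3, 14.5), (5, 10.0), (8, 6.0)]

# Step 2: pick one example at random (the 'stochastic' part)
x, y = random.choice(data)

# Step 3: compute the prediction and the loss for that single example
y_hat = predict(w, b, x)
L = loss(y_hat, y)
```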
Step 4: Compute the gradients
We need to calculate the derivative of the loss function $L$ with respect to the parameters $w$ and $b$. For $w$ you get the gradient $\frac{\partial L}{\partial w} = (\hat{y} - y)\,x$, and for $b$ it is $\frac{\partial L}{\partial b} = \hat{y} - y$.
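Continuing the sketch, the gradients for the single example we picked are simply:

```python
# Step 4: gradients of the loss with respect to w and b for this one example
error = y_hat - y      # how far off the prediction is
grad_w = error * x     # dL/dw = (y_hat - y) * x
grad_b = error         # dL/db = (y_hat - y)
```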
The gradient tells us the direction and magnitude we should change our parameters to reduce the loss.
If the gradient with respect to a parameter is positive, increasing that parameter would increase the loss, so we need to reduce it.
If it is negative, then we need to increase it.
It might be helpful to picture the loss as a U-shaped curve (a parabola) in each parameter to visualise why the gradient means what it means, since we want to get to the minimum of the loss function.
Step 5: Update the parameters
Let $\theta$ be a parameter we are trying to optimise and let $\eta$ be our learning rate.
Then

$$\theta \leftarrow \theta - \eta \frac{\partial L}{\partial \theta}$$
Use that formula to update both $w$ and $b$.
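In the sketch, the update step looks like this (a learning rate of 0.01 is just an illustrative value):

```python
# Step 5: nudge each parameter against its gradient
lr = 0.01              # learning rate (eta); 0.01 is just an illustrative choice
w = w - lr * grad_w
b = b - lr * grad_b
```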
Step 6: Repeat
Repeat Steps 2 to 5 until the loss converges to a minimum.
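Putting all the steps together, here is a self-contained sketch of the whole SGD loop; the data, learning rate, and number of iterations are all made-up values for illustration:

```python
import random

# Hypothetical training data: (age in years, price in thousands)
data = [(1, 20.0), (3, 14.5), (5, 10.0), (8, 6.0), (10, 4.5)]

# Step 1: random initialisation
w = random.uniform(-1, 1)
b = random.uniform(-1, 1)

lr = 0.01  # learning rate (eta)

for step in range(10_000):
    # Step 2: pick one example at random
    x, y = random.choice(data)

    # Step 3: prediction and loss
    y_hat = w * x + b
    L = 0.5 * (y_hat - y) ** 2

    # Step 4: gradients
    error = y_hat - y
    grad_w = error * x
    grad_b = error

    # Step 5: update
    w -= lr * grad_w
    b -= lr * grad_b

print(f"learned w = {w:.3f}, b = {b:.3f}")
```

With enough iterations the printed $w$ and $b$ should settle close to the line that best fits the (hypothetical) data, although the single-example updates make the path there noisy.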
This example of Stochastic Gradient Descent (SGD) was designed to help you understand the intuition behind the mathematics of deep learning. In the future, I'll make more posts on DL with fastai and PyTorch and explain key concepts like backpropagation.