Machine Learning – Linear Regression using gradient descent vs. neural network
Supervised machine learning broadly encompasses two classes of problems: regression and classification. Regression deals with predicting a continuous variable, while classification deals with predicting a categorical one. For example, predicting a house price is a regression problem, since the price can take any real value. Email spam detection, on the other hand, is a classification problem, as the outcome is a special kind of categorical variable, i.e. a binary one (spam = 1, non-spam = 0). Numerous algorithms that fall into one of these two classes have been developed over the last few decades. This post focuses on regression techniques; I will reserve another post for classification techniques. Feel free to subscribe/bookmark this page for upcoming posts.
- Introduction – Regression
- Ordinary Least Squares (OLS)
- Gradient Descent
- Neural Networks
- Comparison of Algorithms
- Conclusions & Inferences
Regression techniques help us understand how a continuous variable depends on one or more independent (explanatory) variables. Regression can be viewed as a curve-fitting problem: we want to learn a function y = f(X), where X = (x1, x2, ..., xn), that best fits y. How do we quantify the best fit? Most techniques use an error measure such as the sum of squared errors (SSE), also called the cost function or objective function, to quantify the quality of the fit. We want to learn the parameters that give the least error, so formally regression is an optimization problem: minimize SSE with respect to the parameters of f.
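Written out, the cost and the optimization problem described above look as follows, for m training examples (x_i, y_i) and a model f with parameters θ. The symbols m, θ and the hat notation are my own labels rather than the notebook's.

```latex
\mathrm{SSE}(\theta) = \sum_{i=1}^{m} \bigl( y_i - f(\mathbf{x}_i; \theta) \bigr)^2,
\qquad
\hat{\theta} = \arg\min_{\theta} \mathrm{SSE}(\theta)
```

For the linear model in the next paragraph, θ is simply the coefficient vector (a0, a1, ..., an).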
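Linear vs. Non-Linear: When y = f(x) takes the form y = a0 + a1 * x1 + a2 * x2 + ... + an * xn, we call it linear regression: we are fitting a straight line (or, with several variables, a flat hyperplane) to the data. In non-linear regression we learn y = f(x) of more complex forms involving logs, exponents or higher-order polynomials of the independent variables, for example y = a0 + a1 * log(x) + a2 * e^x + a3 * x^3. Note that a model like this one is still linear in its parameters a0..a3, so it can be fit with ordinary linear regression after transforming the inputs; this is the variable-transformation trick used in the OLS section below.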
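To make the "linear in the parameters" point concrete, here is a minimal NumPy sketch, assuming NumPy is available. The coefficients and the x range are made up for illustration and are not taken from the notebook. Each non-linear transformation of x becomes one column of a design matrix, and plain least squares recovers the coefficients.

```python
import numpy as np

# Hypothetical "true" coefficients a0..a3 for y = a0 + a1*log(x) + a2*e^x + a3*x^3
true_a = np.array([1.0, 2.0, 0.5, -0.3])

x = np.linspace(1.0, 3.0, 50)  # keep x > 0 so that log(x) is defined

# Design matrix: each column is one (possibly non-linear) transformation of x.
# The model is still linear in the coefficients a0..a3.
X = np.column_stack([np.ones_like(x), np.log(x), np.exp(x), x**3])
y = X @ true_a  # noise-free synthetic targets

# Ordinary least squares on the transformed features
a_hat, *_ = np.linalg.lstsq(X, y, rcond=None)
print(np.round(a_hat, 3))
```

Because the synthetic targets contain no noise, the estimated coefficients should match true_a up to floating-point error.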
Single vs. Multiple Regression: When y = f(x) has more than one explanatory variable, it is called multiple regression; with a single explanatory variable, it is called simple (single) regression.
Let's take a simple linear regression problem (dataset), apply a couple of algorithms to it, and compare their accuracy, complexity and run times. We will conclude this post with an in-depth understanding of the techniques. The following three techniques are explored in detail using R libraries. Feel free to download the notebook from my GitHub and run it yourself, playing with the parameters and plots. The notebooks are also available in HTML, so you can explore the charts and documentation here. A high-level summary is presented below; I highly recommend checking out the notebook here.
Link to Jupyter notebook (code, charts & data): here
- Ordinary Least Squares
A closed-form solution that uses matrix multiplication and a matrix inverse to find the parameters that minimize the error (SSE). Because it is an analytic solution, it always finds the minimum of the error; SSE is convex, so this minimum is the global one (provided X^T X is invertible). Looking at the (x, y) plot, the relationship is non-linear, so we fit it with linear regression by transforming variables, adding x^2 as a feature.
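The closed form above is the normal equation θ = (X^T X)^(-1) X^T y. Below is a minimal NumPy sketch of it with the x^2 transformation, on made-up quadratic data (the notebook itself uses R, and its dataset is not reproduced here). The helper name ols_normal_equations is hypothetical.

```python
import numpy as np

def ols_normal_equations(X, y):
    """Solve (X^T X) theta = X^T y for theta.

    np.linalg.solve is used instead of forming the inverse explicitly;
    it yields the same parameters as inv(X.T @ X) @ X.T @ y with less
    numerical error.
    """
    return np.linalg.solve(X.T @ X, X.T @ y)

# Hypothetical curved data: y = 4 - 2x + 0.5x^2 plus a little noise
rng = np.random.default_rng(0)
x = np.linspace(-3, 3, 100)
y = 4 - 2 * x + 0.5 * x**2 + rng.normal(scale=0.1, size=x.size)

# Variable transformation: add x^2 as an extra column so a linear model can fit the curve
X = np.column_stack([np.ones_like(x), x, x**2])
theta = ols_normal_equations(X, y)
print(theta)  # should be close to [4, -2, 0.5]
```

Whichever way the system is solved, the cost grows roughly cubically with the number of features, which is why OLS becomes slow as the feature count grows.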
- Gradient Descent
An iterative (open-form) solution based on mathematical optimization: it initializes the parameters randomly and repeatedly adjusts them in the direction of the negative gradient, which reduces SSE at each step. We start with a learning rate of 0.01 and iterate 120,000 times.
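A minimal sketch of batch gradient descent for the same kind of linear model, assuming NumPy. It follows the description above (random initialization, learning rate alpha, fixed number of iterations), but it is my own illustration rather than the notebook's R code; the data, the seeds and the function name gradient_descent are hypothetical. It minimizes SSE/(2m) instead of SSE: the minimizer is the same, but the step size no longer depends on the number of rows m.

```python
import numpy as np

def gradient_descent(X, y, alpha=0.01, n_iters=120_000, seed=0):
    """Batch gradient descent for linear regression.

    Minimizes SSE / (2m), which has the same minimizer as SSE.
    """
    m, n = X.shape
    rng = np.random.default_rng(seed)
    theta = rng.normal(size=n)          # random initialization
    for _ in range(n_iters):
        residual = X @ theta - y        # prediction error for every row
        grad = X.T @ residual / m       # gradient of SSE / (2m)
        theta -= alpha * grad           # step against the gradient
    return theta

# Hypothetical data on a small, well-scaled range: y = 1 + 2x - 3x^2 + noise
rng = np.random.default_rng(1)
x = np.linspace(-1, 1, 100)
y = 1 + 2 * x - 3 * x**2 + rng.normal(scale=0.05, size=x.size)
X = np.column_stack([np.ones_like(x), x, x**2])

theta_gd = gradient_descent(X, y, alpha=0.01, n_iters=20_000)
theta_ols = np.linalg.lstsq(X, y, rcond=None)[0]
print(theta_gd, theta_ols)  # the two estimates should agree closely
```

On this small, well-scaled example 20,000 iterations are enough for the estimate to agree with the closed-form solution. With unscaled features (for example large x values squared) the same learning rate can diverge, so a smaller alpha or feature scaling may be needed.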
- Neural Networks
Neural networks are popular among AI (artificial intelligence) practitioners and are loosely inspired by the workings of the human brain: hidden layers of nodes sit between the input and output nodes, and weighted connections propagate signals from one layer to the next. Internally, a neural network uses gradient descent with back-propagation to learn its weights. We train a neural network with 3 hidden layers of 3 nodes each. The input and output layers each have a single node, as we have a single predictor (explanatory variable) and a single response variable.
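For intuition on what back-propagation does, here is a from-scratch NumPy sketch of a 1-3-3-3-1 network (3 hidden layers of 3 nodes) trained with full-batch gradient descent on squared error. It is not the R library the notebook uses; the tanh activation, the weight initialization, the learning rate, the epoch count and the made-up x^2 data are all my assumptions. Only raw x is fed in, with no x^2 feature.

```python
import numpy as np

def init_params(layer_sizes, rng):
    """Small random weights and zero biases for each layer."""
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        W = rng.normal(scale=1.0 / np.sqrt(n_in), size=(n_in, n_out))
        params.append([W, np.zeros(n_out)])
    return params

def forward(params, X):
    """Return the activations of every layer, starting with the input."""
    activations = [X]
    for i, (W, b) in enumerate(params):
        z = activations[-1] @ W + b
        # tanh on hidden layers, identity on the output layer (regression)
        activations.append(np.tanh(z) if i < len(params) - 1 else z)
    return activations

def train(params, X, y, lr=0.05, epochs=5000):
    """Full-batch gradient descent with back-propagation on MSE / 2."""
    m = X.shape[0]
    losses = []
    for _ in range(epochs):
        acts = forward(params, X)
        err = acts[-1] - y
        losses.append(float(np.mean(err**2)))  # training MSE before this update
        delta = err / m                        # gradient w.r.t. output pre-activation
        for i in reversed(range(len(params))):
            W, b = params[i]
            grad_W = acts[i].T @ delta
            grad_b = delta.sum(axis=0)
            if i > 0:
                # chain rule into the previous hidden layer: tanh'(z) = 1 - tanh(z)^2
                delta = (delta @ W.T) * (1 - acts[i] ** 2)
            params[i][0] = W - lr * grad_W
            params[i][1] = b - lr * grad_b
    return losses

# Hypothetical curved data; the network only sees raw x
rng = np.random.default_rng(0)
x = np.linspace(-1, 1, 100).reshape(-1, 1)
y = x**2 + rng.normal(scale=0.05, size=x.shape)

params = init_params([1, 3, 3, 3, 1], rng)
losses = train(params, x, y, lr=0.05, epochs=5000)
print(losses[0], losses[-1])  # training MSE at the start and end
```

Each layer's gradient is computed from the layer above it via the chain rule; neural network libraries automate exactly this computation.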
Let's compare how well each model fits the data. Note that I am not using a train/test split here, since the emphasis is on the techniques themselves, so all errors reported are training errors. We use RMSE (root mean squared error) to compare the three models.
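RMSE is the square root of the mean squared residual, i.e. sqrt(SSE/m), so on the same data it ranks models exactly as SSE does while staying in the units of y. A minimal NumPy sketch (the function name rmse is my own):

```python
import numpy as np

def rmse(y_true, y_pred):
    """Root mean squared error: sqrt(mean((y - y_hat)^2))."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

print(rmse([1, 2, 3], [1, 2, 5]))  # sqrt(4/3), about 1.1547
```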
Looking at RMSE, the neural network appears to fit the training data best. However, neural networks have a large capacity to memorize training data, which can lead to overfitting (a high-variance model); since we only measure training error here, a low RMSE does not by itself show that the model generalizes. Note also that the neural network required no feature engineering: we passed in only the raw x values, whereas the other two methods used x^2 as an additional feature. Neural networks are known for learning latent features on their own, without the explicit feature engineering usually done by domain experts in each problem space.
The table below summarizes the three techniques: how they differ and when to use each.
|Ordinary Least Squares|Gradient Descent|Neural Network|
|---|---|---|
|Slow when there are many features: the matrix inverse costs roughly cubic time in the number of features|Scales well to large datasets|Slow; training is usually done on GPUs, which handle matrix computations efficiently|
|No hyperparameters|Hyperparameters: epochs (iterations), learning rate (alpha)|Hyperparameters: hidden layers, nodes per layer, activation function, learning rate|
|Feature engineering required|Feature engineering required|Little or no feature engineering needed|
|More data is good|More data is good|More data is good, but NNs need a lot of data to generalize well|
|Good interpretability; findings are easy to communicate|Good interpretability; findings are easy to communicate|Black box; hard to interpret and explain|
|Most commonly used for building offline models on small datasets|Commonly used for large-scale learning involving thousands of features|Commonly used for text, vision and speech, mostly via CNN and RNN architectures|
|Tools: R, SAS, Python|Tools: Python, R|Tools: Deeplearning4j, Theano, Python, R, TensorFlow|
|Always finds the global minimum of SSE (provided X^T X is invertible)|SSE is convex for linear regression, so there are no local minima to get stuck in, but a bad learning rate can cause slow convergence or divergence|Can get stuck in local minima due to bad initialization or learning rate|
What did we not cover that is important in this context?
- Cross Validation
- Feature engineering and reduction (PCA)
- Hyperparameter tuning
- Objective Functions
- Regularization (L1/L2)
- Interaction Effects
There is a lot of content to digest here. Feel free to share any feedback or comments about any part of this blog post; I would love to chat. I will be back in the coming days with a similar post on classification, comparing logistic regression, decision trees, random forests (ensembles) and neural networks.