Machine Learning – Linear Regression using gradient descent vs. neural network

Supervised machine learning broadly encompasses two classes of problems: regression and classification. Regression deals with predicting a continuous variable, while classification deals with predicting a category. For example, predicting a house price is a regression problem because the outcome can take any real value, whereas email spam detection is a classification problem because the outcome is a special kind of categorical variable, i.e. binary (spam 1 / non-spam 0). Numerous algorithms built over the last few decades fall under one of these two classes. This post focuses on regression techniques; I will reserve another post for classification techniques. Feel free to subscribe/bookmark this page for upcoming posts.

  • Introduction – Regression
  • Algorithms
    • Ordinary Least Squares (OLS)
    • Gradient Descent
    • Neural Networks
  • Comparison of Algorithms
  • Conclusions & Inferences

Regression techniques help us understand the relationship of a continuous variable as a function of one or more independent (aka explanatory) variables. It can be considered a curve-fitting problem where we are interested in learning a function y = f(X), where X = (x1, x2, x3, ..., xn), that best fits y. Now, how do we quantify best fit? Most techniques use some measure of error, such as SSE (sum of squared errors), also called the cost function or objective function, to quantify the quality of fit. We want to learn the parameters that give the least error. So, formally, we can define this as an optimization problem: find the parameters of f that minimize SSE.
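To make the cost function concrete, here is a minimal Python sketch of SSE for a straight-line model. The function name `sse` and the toy data are my own illustrative assumptions, not taken from the notebook.

```python
def sse(params, xs, ys):
    """Sum of squared errors for the line y = a0 + a1 * x."""
    a0, a1 = params
    return sum((y - (a0 + a1 * x)) ** 2 for x, y in zip(xs, ys))

# Hypothetical toy data generated from y = 1 + 2x (not the post's dataset).
xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]

perfect_fit = sse((1.0, 2.0), xs, ys)  # the parameters that generated the data
poor_fit = sse((0.0, 1.0), xs, ys)     # an arbitrary, worse guess

print(perfect_fit, poor_fit)
```

The better the parameters, the smaller the SSE, which is exactly what the optimization is trying to drive down.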

Linear vs. Non-Linear: When y = f(x) takes the form y = a0 + a1 * x1 + a2 * x2 + ... + an * xn, we call it linear regression, where we are trying to fit a straight line to the data. In non-linear regression we learn y = f(x) of more complex forms involving logs, exponents, or higher-order polynomials of the independent variables. For example, y = a0 + a1 * log(x1) + a2 * e^x + a3 * x^3.

Single vs. Multiple Regression: When y = f(x) has more than one explanatory variable, it is called multiple regression.

Let's take a simple linear regression problem (dataset), apply a couple of algorithms, and compare their accuracy, complexity and run times. We will conclude this post with an in-depth understanding of the techniques. The following three techniques have been explored in detail using R libraries. Feel free to download the notebook from my GitHub and try running it yourself, playing with the parameters and plots. The notebooks are also available in HTML so you can explore the charts & documentation here. A high-level summary is presented below; I highly recommend checking the notebook here.

Link to jupyter notebook (code, charts & data) here

  1. Ordinary Least Squares
    A closed-form solution that uses matrix multiplication & inversion to find the parameters which minimize the error (SSE). This is an analytic solution, so it always finds a minimum (of the error). Looking at the (x, y) plot, we attempt to fit a non-linear model using linear regression by exploiting variable transformation techniques.
  2. Gradient Descent
    An open-form solution that uses mathematical optimization: the parameters are initialized randomly and then iteratively adjusted, following the gradient in the direction that reduces SSE. We start with a learning rate of 0.01 and iterate 120000 times.
  3. Neural Networks
    Now generally popular among AI (Artificial Intelligence) practitioners, neural networks mimic the working of the human brain, modelling hidden layers with weights propagating across layers between input & output nodes. Internally, they use gradient descent with back-propagation to learn the weights. We use 3 hidden layers, each with 3 nodes, and train a neural network model. The input & output layers are single nodes, as we have a single explanatory variable and a single response variable.
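The first two techniques can be sketched in a few lines of plain Python. This is only an illustration under my own assumptions, not the code from the R notebook: the data is a made-up toy set, the closed form is written out for the two-parameter line y = a0 + a1 * x instead of using a general matrix inverse, gradient descent starts from zeros (rather than random values) so the run is reproducible, it follows the gradient of the mean squared error (SSE divided by n, which has the same minimizer), and it uses 5000 iterations instead of the 120000 mentioned above.

```python
def ols_fit(xs, ys):
    """Closed-form least squares for y = a0 + a1 * x.

    For two parameters, the normal equations reduce to these formulas.
    """
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    a1 = sxy / sxx
    a0 = mean_y - a1 * mean_x
    return a0, a1


def gradient_descent_fit(xs, ys, learning_rate=0.01, iterations=5000):
    """Iteratively adjust (a0, a1) along the negative gradient of the MSE."""
    n = len(xs)
    a0, a1 = 0.0, 0.0  # assumption: fixed start instead of random initialization
    for _ in range(iterations):
        residuals = [(a0 + a1 * x) - y for x, y in zip(xs, ys)]
        grad_a0 = 2.0 / n * sum(residuals)
        grad_a1 = 2.0 / n * sum(r * x for r, x in zip(residuals, xs))
        a0 -= learning_rate * grad_a0
        a1 -= learning_rate * grad_a1
    return a0, a1


# Hypothetical toy data, roughly y = 1 + 2x with a little noise.
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [1.1, 2.9, 5.2, 6.8, 9.0]

ols_params = ols_fit(xs, ys)
gd_params = gradient_descent_fit(xs, ys)
print("OLS:", ols_params)
print("GD: ", gd_params)
```

On a well-behaved problem like this one, both approaches should land on essentially the same parameters; the difference is that OLS gets there in one step, while gradient descent needs a learning rate and an iteration count to be chosen.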



Let's compare which of the models performs best in terms of fitting the data. Please note that I am not using a train/test approach here, as the emphasis is on technique, so we are talking about training error. Let's use RMSE (root mean square error) to compare the three models:

  • Ordinary Least Squares: 0.000534
  • Gradient Descent: 0.000984
  • Neural Network: 0.000278
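For reference, RMSE itself is simple to compute. The sketch below uses made-up numbers, not the notebook's predictions, and the function name `rmse` is my own.

```python
import math


def rmse(actual, predicted):
    """Root mean square error between two equal-length sequences."""
    squared_errors = [(a - p) ** 2 for a, p in zip(actual, predicted)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


# Hypothetical values: three perfect predictions and one that is off by 2.
actual = [1.0, 2.0, 3.0, 4.0]
predicted = [1.0, 2.0, 3.0, 6.0]

score = rmse(actual, predicted)
print(score)
```

Because the errors are squared before averaging, a single large miss dominates the score, which is worth keeping in mind when comparing models on RMSE alone.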

Looking at RMSE, the neural network seems to be doing a great job of learning the data. NNs are known to have good memorizing capability, which can cause overfitting and lead to a high-variance system. We must also note that the neural network did not involve any kind of feature engineering: we just passed in the x values, whereas in the other methods we had x^2 as a feature. NNs are known to be great at identifying latent features without the explicit feature engineering usually done by domain experts of the respective problem spaces.

The comparison below explains more about each of the 3 techniques: how they differ, when to use which, etc.

Ordinary Least Squares
  • Closed form, analytic solution
  • Slow, involves matrix inverse computation
  • Feature engineering required
  • More data is good
  • Good interpretability. Easy to communicate findings
  • Most commonly used for building offline models on small datasets
  • Tools: R, SAS, Python

Gradient Descent
  • Open form, iterative optimization
  • Fast for large datasets
  • Hyperparameters: epochs (iterations), alpha (learning rate)
  • Feature engineering required
  • More data is good
  • Good interpretability. Easy to communicate findings
  • Commonly used for large-scale learning involving thousands of features
  • Tools: Python, R
  • Can get stuck in local minima due to bad initialization/learning rate

Neural Network
  • Slow, training usually done on GPUs, which can handle matrix computations easily
  • Hyperparameters: hidden layers, nodes, activation function, learning rate
  • Little/no feature engineering needed
  • While more data is good, NN requires a lot of data for good generalization
  • Black box, difficult to explain, suffers from poor interpretability
  • Used commonly in text/vision/speech; mostly uses two variant architectures, CNN and RNN
  • Tools: Deeplearning4j, Theano, Python, R, Tensorflow
  • Can get stuck in local minima due to bad initialization/learning rate

What did we not talk about that is nevertheless important in this context?

  • Cross Validation
  • Feature engineering and reduction (PCA)
  • Hyperparameter tuning
  • Sampling
  • Objective Functions
  • Regularization L1/L2
  • Interaction Effects

There is a lot of content to digest here; feel free to share any feedback/comments you have about any part of the blog post, I would love to chat. I will be back with a similar post on classification in the coming days, comparing logistic regression, decision trees, random forests (ensembles) and neural networks.

We talk about big data quite often these days, so I wanted to put down some fundamentals about data. Do you know the singular form of data? How does data differ from information and knowledge? How do insights convert to actions? Here is my attempt at answering some of these.
Data is often raw, stored in binary and represented as a number, a character or a string. Data is the plural of datum. Information is anything which puts context to data. For example, the number 89 by itself doesn’t mean anything unless we give it a context, say, the car’s speed is 89 kmph. Knowledge, on the other hand, is about knowing how things around us work & how the larger world is interconnected. It is obtained from experience, experiment, research, etc. In the infographic below I have tried to explain these terms using a simplified version of the intelligence that can be embedded in cars. More on the infographic follows.
Data 2 Actions
I have tried to take a very simplified version here. Now just imagine: when we talk about intelligent cars, we are usually talking about 100s of such parameters instead of just a single variable (the speed of the car), all collected from multiple sensors as a real-time stream to make such decisions. Knowledge is obtained through predictive algorithms which continuously learn [in AI terms, “adapt”] from data and help in making recommendations about the safety of the vehicle. Now imagine these 100s of parameters collected every millisecond from 1000s of connected cars around you: this is what forms “Big Data”.
Hope you liked this post. I will be writing more articles, as I have been getting requests from friends around the globe. Stay tuned!

Clustering is an unsupervised classification (learning) technique, where the objective is to maximize inter-cluster distance while minimizing intra-cluster distance. By unsupervised, we mean clustering, segmenting or classifying data based on all the available attributes, with no class information available. Supervised classification, on the other hand, uses class information.
As usual, before we jump into the ‘how’, let’s answer the ‘why’. Clustering is applied to a variety of problems, ranging from biological systems to exploratory analysis of data (as a pre-processing technique). Many predictive analytics algorithms use clustering solutions as one of their components. It is used by all major brands for CRM, to understand their customers better. Another use of clustering is in outlier detection or identifying fraudulent transactions. If you have heard about a site called, it works extensively with clustering algorithms, where sites are segmented/clustered based on website attributes like category of domain, number of users, traffic, content type, corporate or personal, blog, image blog, video blog, etc. For example, if you entered INMOBI, you would get a list of companies in this space, mainly its competitors: mojiva, Millenialmedia, Admob, Quattro, Mobclix, etc. If you are looking for an image hosting site and want to know the alternatives/options, this will be helpful.

We talk about similarity in terms of distance measures like

(i) Euclidean Distance

(ii) Manhattan Distance
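Both distance measures are easy to write down. Here is a minimal Python sketch; the function names and the two sample points are my own illustrative choices.

```python
import math


def euclidean_distance(p, q):
    """Straight-line distance: square root of the summed squared differences."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q)))


def manhattan_distance(p, q):
    """City-block distance: sum of the absolute differences per coordinate."""
    return sum(abs(a - b) for a, b in zip(p, q))


# Two hypothetical points in a 2-D attribute space.
p = (0.0, 0.0)
q = (3.0, 4.0)

euclid = euclidean_distance(p, q)
manhattan = manhattan_distance(p, q)
print(euclid, manhattan)
```

The two measures can rank neighbours differently, so the choice of distance affects which points end up in the same cluster.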

Turning raw data into insights often involves integrating data from multiple disparate sources (not limited to structured ones), analyzing the data, visualizing it and socializing the results/insights with the broader audience to whom they are of interest. In this cycle of turning data into insights, visualization plays a vital role and hence is the topic of this blog post. First, visualization can aid in analyzing huge data by identifying patterns which are easier to interpret visually than in a tabular layout of numbers. Second, visualization can help represent the numbers using visuals which are easy for everyone to read and understand. One can convey the insights of an analysis with visuals that are grasped in a minute or two, which might have taken 3-4 minutes using text or a table of numbers. This is an important factor to consider, especially when you are delivering the findings to the CEO/CFO/CXO/CIO of a company, as they often have limited time.

London Cholera Outbreak visualized

Going back to the history of visualization: the most famous early example of mapping epidemiological data was Dr. John Snow’s map of deaths from a cholera outbreak in London in 1854, in relation to the locations of public water pumps. The original (high-res PDF copies from UCLA) spawned many imitators, including this simplified version by Gilbert in 1958. Tufte (1983, p. 24) says, ”Snow observed that cholera occurred almost entirely among those who lived near (and drank from) the Broad Street water pump. He had the handle of the contaminated pump removed, ending the neighborhood epidemic which had taken more than 500 lives.”

For the last 6 months, I have been closely following trends in information management. Below are a few of my observations.

  • Data source explosion: Business problems are gaining complexity day by day, hence there is a huge demand for analyzing data from a multitude of sources to help companies frame strategies for growth. GPS data accumulated by telecom companies offers insights into customers' current location and enables context-aware recommendations. In fact, some telecom companies have introduced location-based pricing. Sensor data helps identify security threats to secure networks. Social network data has opened up as a channel for marketing services/products. Analysis of such closely knit data leads to behavioral & contextual targeting. Traditional data analysis tools/algorithms fail to perform efficiently because such data is huge and needs newer data structures for efficient analysis.
  • Databases going beyond relational are gaining popularity: NoSQL DBs and graph/tree/XML-based databases.
  • Open source tools continue to emerge (R, RapidMiner, Weka).
  • Growing need for massive dataset analysis.
  • Artificial Intelligence (AI) and NLP are gaining popularity among data analysts (in addition to ML techniques).
  • Multimedia Analytics: Need for gathering critical metrics like customer footfall, and for quantifying customer satisfaction using facial expressions. All these applications demand high-end signal processing (both image & video). There is a lot of scope for innovation in this area.
  • Privacy-preserving techniques for data analysis. This in turn encourages companies to outsource some of their critical data analysis to third parties.
  • Agile methodologies for analytics projects, to cope with rapidly changing customer/business needs.
  • Bio-Inspiration/Bio-Imitation: Learning from nature/natural processes and developing analogous techniques which could potentially solve real-world problems. Some classic examples are the development of neural networks inspired by the working of the human brain, solving path optimization problems based on ant colonies, the 280 degree view of the honey bee (vision), etc.
  • More and more data is being made publicly available.
  • Real-time data integration, insight generation and business decisions.
  • Complex visualization techniques through new technologies like Adobe Flex, MS Silverlight, etc., which are known for generating RIAs (Rich Internet Applications).

I am sure these are just a few items on the list, which is by no means exhaustive. Feel free to share your comments.
