Linear Discriminant Analysis in R: An Introduction
How does Linear Discriminant Analysis (LDA) work and how do you use it in R? This post answers these questions and provides an introduction to Linear Discriminant Analysis.
Linear Discriminant Analysis (LDA) is a well-established machine learning technique and classification method for predicting categories. Its main advantages, compared to other classification algorithms such as neural networks and random forests, are that the model is interpretable and that prediction is easy. Linear Discriminant Analysis is also frequently used as a dimensionality reduction technique in pattern recognition and machine learning.
If you want to quickly do your own linear discriminant analysis, use this handy template!
The intuition behind Linear Discriminant Analysis
Linear Discriminant Analysis takes a data set of cases (also known as observations) as input. For each case, you need a categorical variable to define the class and several numeric predictor variables. We often visualize this input data as a matrix, such as shown below, with each case being a row and each variable a column. In this example, the categorical variable is called "class" and the numeric predictor variables are the other columns.
This example, discussed below, relates to classes of motor vehicles based on images of those vehicles.
Think of each case as a point in N-dimensional space, where N is the number of predictor variables. Every point is labeled by its category. (Although it focuses on t-SNE, this video neatly illustrates what we mean by dimensional space).
The LDA algorithm uses this data to divide the space of predictor variables into regions. The regions are labeled by categories and have linear boundaries, hence the "L" in LDA. The model then predicts the category of a new, unseen case according to the region it lies in: all cases within a region are assigned to the same category.
The linear boundaries are a consequence of assuming that the predictor variables for each category have the same multivariate Gaussian distribution. Although in practice this assumption may not be 100% true, if it is approximately valid then LDA can still perform well.
Mathematically, LDA uses the input data to derive the coefficients of a scoring function for each category. Each function takes as arguments the numeric predictor variables of a case. It then scales each variable according to its category-specific coefficients and outputs a score. The LDA model looks at the score from each function and uses the highest score to allocate a case to a category (prediction). We call these scoring functions the discriminant functions.
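To make the scoring idea concrete, here is a minimal sketch in R. The categories, coefficients and values below are made up purely for illustration; they are not fitted values from the vehicle example.

# Hypothetical discriminant functions, each of the form
# constant + sum(coefficient * predictor value)
coefs <- list(
  bus = c(constant = -12.0, COMPACTNESS = 0.10, ELONGATEDNESS = 0.25),
  van = c(constant = -9.5, COMPACTNESS = 0.15, ELONGATEDNESS = 0.18)
)
new.case <- c(COMPACTNESS = 90, ELONGATEDNESS = 40)  # one unseen case

# Evaluate each discriminant function; the predicted category is the one
# with the highest score
scores <- sapply(coefs, function(b) b[["constant"]] + sum(b[names(new.case)] * new.case))
names(which.max(scores))  # "van" with these made-up numbers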
I am going to stop describing the model there and move on to some practical examples. If you would like more detail, I suggest one of my favorite reads, The Elements of Statistical Learning (section 4.3).
Linear Discriminant Analysis Example
Predicting the type of vehicle
Even though my eyesight is far from perfect, I can normally tell the difference between a car, a van, and a bus. I might not distinguish a Saab 9000 from an Opel Manta though. They are cars made around 30 years ago (I can't remember!). Despite my unfamiliarity, I would hope to do a decent job if given a few examples of both.
I will demonstrate Linear Discriminant Analysis by predicting the type of vehicle in an image. The 4 vehicle categories are a double-decker bus, Chevrolet van, Saab 9000 and Opel Manta 400. The input features are not the raw image pixels but are 18 numerical features calculated from silhouettes of the vehicles. You can read more about the data behind this LDA example here.
To start, I load the 846 instances into a data.frame called vehicles. The columns are labeled by the variables, with the target outcome column called class. The earlier table shows this data.
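If you are following along outside Displayr, the loading step might look like the sketch below. The file name vehicles.csv is a hypothetical local copy of the data; all that matters is that the data.frame ends up with a class column plus the 18 numeric features.

# A sketch of loading the data into a data.frame called vehicles
# ("vehicles.csv" is a hypothetical local copy of the data set)
vehicles <- read.csv("vehicles.csv", stringsAsFactors = TRUE)
str(vehicles)            # 846 observations: class plus 18 numeric predictors
table(vehicles$class)    # counts per vehicle category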
flipMultivariates: A new R package
The package I am going to use is called flipMultivariates (click on the link to get it). It is based on the MASS package, but extends it in the following ways:
- Handling of weighted data
- Graphical outputs
- Options for missing data
- Ability to output discriminant functions
The package is installed with the following R code.
library(devtools)
install_github("Displayr/flipMultivariates")
Then the model is created with the following two lines of code.
library(flipMultivariates)
lda <- LDA(class ~ ., data = vehicles)
The output is shown below. The subtitle shows that the model identifies buses and vans well but struggles to tell the difference between the two car models. The first four columns show the means for each variable by category. High values are shaded in blue and low values in red, with values significant at the 5% level in bold. The R-Squared column shows the proportion of variance within each row that is explained by the categories. On this measure, ELONGATEDNESS is the best discriminator.
Customizing the LDA model with alternative inputs in the code
The LDA function in flipMultivariates has a lot more to offer than just the default. Consider the code below:
lda.2 <- LDA(class ~ COMPACTNESS + CIRCULARITY + DISTANCE.CIRCULARITY + RADIUS.RATIO,
             data = vehicles,
             output = "Scatterplot",
             prior = "Equal",
             subset = vehicles$ELONGATEDNESS < 50,
             weight = ifelse(vehicles$class == "saab", 2, 1))
I've set a few new arguments, which include:
- output: The default output is Means, which is what you got in the very first output. In the code immediately above, I changed this to Scatterplot. I explain how to interpret the scatterplot in the next section. Also available are Prediction-Accuracy Table, Detail and Discriminant Functions. The last of these produces a table of the coefficients of the discriminant functions described earlier.
- prior: This argument sets the prior probabilities of category membership. Observed is the default, which uses the frequencies of the input data. It is also possible to specify a value for each category as a vector that sums to 1 (see the sketch after this list). Equal implies that vehicles are equally distributed across the categories.
- subset: This is a vector of TRUE/FALSE values, which in the code above limits the analysis to cases with ELONGATEDNESS scores below 50.
- weight: This is a vector of values used to scale the cases. In this example, I overweight the saab cars.
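As promised, here is a quick sketch of the vector form of prior. The values are arbitrary and only for illustration; there should be one per category and they should sum to 1 (typically matching the order of the factor levels of class).

# A sketch of explicit prior probabilities passed as a vector
# (arbitrary values; one per category, summing to 1)
lda.3 <- LDA(class ~ COMPACTNESS + CIRCULARITY + DISTANCE.CIRCULARITY + RADIUS.RATIO,
             data = vehicles,
             prior = c(0.4, 0.2, 0.2, 0.2))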
It is also possible to control the treatment of missing data with the missing argument (not shown in the code example above). The options are Exclude cases with missing data (the default), Error if missing data, and Imputation (replace missing values with estimates). Imputation allows the user to specify additional variables, which the model uses to estimate replacements for missing data points.
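As a sketch, a call asking for imputation might look like the lines below. I am assuming the option string matches the name listed above; check ?LDA if it does not.

# A sketch of the missing argument: impute rather than drop incomplete cases
# (the option string is assumed to match the name listed above; see ?LDA)
lda.4 <- LDA(class ~ ., data = vehicles,
             missing = "Imputation (replace missing values with estimates)")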
The R command ?LDA gives more information on all of the arguments.
Interpreting the Linear Discriminant Analysis output
The block of code above produces the following scatterplot. (Note that I am no longer using all the predictor variables in this example, for the sake of clarity.) I am going to talk about two aspects of interpreting the scatterplot: how each dimension separates the categories, and how the predictor variables correlate with the dimensions.
I said above that I would stop writing about the model. However, to explain the scatterplot I am going to have to mention a few more points about the algorithm. If you prefer to gloss over this, please skip ahead.
An alternative view of linear discriminant analysis is that it projects the data into a space of (number of categories - 1) dimensions. In this example that space has 3 dimensions (4 vehicle categories minus one). While this aspect of dimension reduction has some similarity to Principal Components Analysis (PCA), there is a difference: LDA chooses dimensions that maximally separate the categories (in the transformed space). The LDA model orders the dimensions by how much separation each achieves (the first dimension achieves the most separation, and so forth). Hence the scatterplot shows the means of each category plotted in the first two dimensions of this space. So in our example, the first dimension (the horizontal axis) distinguishes the cars (right) from the bus and van categories (left). However, the same dimension does not separate the cars well.
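If you want to inspect these projected dimensions directly, here is a sketch using the underlying MASS package (which flipMultivariates builds on); this is my own illustration rather than part of the flipMultivariates output.

# A sketch of the (categories - 1)-dimensional projection using MASS::lda;
# predict()$x holds each case's coordinates on the linear discriminants
fit <- MASS::lda(class ~ ., data = vehicles)
ld <- predict(fit)$x
plot(ld[, 1], ld[, 2], col = vehicles$class,
     xlab = "First linear discriminant", ylab = "Second linear discriminant")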
Also shown are the correlations between the predictor variables and these new dimensions. Because DISTANCE.CIRCULARITY has a high value along the first linear discriminant, it positively correlates with this first dimension. It has a value of almost zero along the second linear discriminant, hence it is virtually uncorrelated with the second dimension. Note that the scatterplot scales the correlations to appear on the same scale as the means, so you can't just read their values from the axis. In other words, the means are the primary data, whereas the scatterplot adjusts the correlations to "fit" on the chart.
The Prediction-Accuracy Table
Finally, I will leave you with this chart to consider the model's accuracy. Changing the output argument in the code above to Prediction-Accuracy Table produces the following:
So from this, you can see what the model gets right and wrong (in terms of correctly predicting the class of vehicle). Ideally all the cases would lie on the diagonal of this matrix, shown by the deep shading along the diagonal. But here we get some misallocations (no model is ever perfect). For instance, 19 cases that the model predicted as Opel are actually in the bus category (observed). Given the shades of red and the numbers that lie outside the diagonal (particularly the confusion between Opel and saab), this LDA model is far from perfect.
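If you want to reproduce this kind of table by hand, a sketch using base R and MASS (rather than the flipMultivariates output) is below; table() cross-tabulates the observed classes against the model's predictions on the same data.

# A sketch of a prediction-accuracy (confusion) table built by hand
fit <- MASS::lda(class ~ ., data = vehicles)
predicted <- predict(fit)$class
table(Observed = vehicles$class, Predicted = predicted)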
Try it yourself
I created the analyses in this post with R in Displayr. You can review the underlying data and code or run your own LDA analyses here. I used the flipMultivariates package (available on GitHub).
Displayr also makes Linear Discriminant Analysis and other machine learning tools available through menus, alleviating the need to write code. There's even a template custom made for Linear Discriminant Analysis, so you can just add your data and go.
This dataset originates from the Turing Institute, Glasgow, Scotland, which closed in 1994, so I doubt they care, but I'm crediting the source anyway.