RandomForest Classification Example using Spark MLlib
In this tutorial, we shall see how to train and generate a model using the RandomForest classifier, and then use the generated model on test data to predict the categories and calculate the test error and accuracy of the model.
This example uses the older RDD-based spark.mllib API with Java. It is still useful when you are maintaining existing MLlib code, learning how RandomForest.trainClassifier works, or working with LibSVM files. For new Spark machine learning applications, also review the DataFrame-based Spark RandomForestClassifier API because modern Spark pipelines generally use spark.ml.
What Random Forest Classification Does in Spark MLlib
A random forest classifier builds many decision trees and combines their predictions. For classification problems, the model predicts a class label such as 0, 1, or 2. This is different from random forest regression, where the model predicts a numeric value.
In Spark MLlib, the RDD-based random forest classifier can work with binary and multiclass labels. It can also use both continuous features and categorical features, provided the categorical feature information is supplied correctly.
Training using Random Forest classifier
Spark MLlib understands only numbers, so the training data has to be prepared in a numeric form that MLlib understands. Preparing the training data is a key step that largely determines the accuracy of a model. It includes the following:
- Identify the categories. And index the categories.
- Identify the features. And index the features.
- Transform the experiments/observations/examples using indexes of categories and features
Note: Feature values can be discrete or continuous. Comments in the program show which features are configured as discrete and which are left as continuous. With this as a reference, features can be configured as per your requirement.
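The preparation steps above can be sketched in plain Java. This is only a minimal illustration, not the code from the tutorial attachment: the class name LibSvmRowBuilder, its method names, and the category names "class-a" and "class-b" are hypothetical. It assigns an index to each category the first time it is seen and writes one observation as a LibSVM-style row.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper for illustration; names and sample values are assumptions.
final class LibSvmRowBuilder {

    // Returns the index of a category, assigning the next free zero-based index on first use.
    static int indexOf(Map<String, Integer> categoryIndex, String category) {
        Integer existing = categoryIndex.get(category);
        if (existing != null) {
            return existing;
        }
        int next = categoryIndex.size();
        categoryIndex.put(category, next);
        return next;
    }

    // Builds one row in the form "<label> 1:<v1> 2:<v2> ...".
    // Feature indices written to the file are one-based.
    static String toLibSvmLine(int label, double[] features) {
        StringBuilder line = new StringBuilder().append(label);
        for (int i = 0; i < features.length; i++) {
            double v = features[i];
            line.append(' ').append(i + 1).append(':');
            // Write whole numbers without a trailing ".0" to match the sample files.
            if (v == Math.rint(v) && !Double.isInfinite(v)) {
                line.append((long) v);
            } else {
                line.append(v);
            }
        }
        return line.toString();
    }

    public static void main(String[] args) {
        Map<String, Integer> categoryIndex = new LinkedHashMap<>();
        int label = indexOf(categoryIndex, "class-a"); // first category gets index 0
        indexOf(categoryIndex, "class-b");             // second category gets index 1
        // prints "0 1:1 2:1 3:1 4:1 5:1 6:1"
        System.out.println(toLibSvmLine(label, new double[] {1, 1, 1, 1, 1, 1}));
    }
}
```

Keeping the category-to-index map around matters: the same map has to be used later to turn predicted numeric labels back into category names.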
Download the source code of the ongoing example here, RandomForestExampleAttachment. For setting up a Java project to work with Spark MLlib, please refer to Create Java Project with Apache Spark.
Random Forest Classifier Input Format in this Spark MLlib Example
The training and test files in this tutorial use LibSVM-style input. Each line contains one labeled example. The first value is the class label, followed by indexed feature-value pairs.
| Part of row | Meaning in this example |
| 0, 1, 2 | Class label to be predicted by the classifier. |
| 1:1 | Feature index 1 has value 1 in the input file. |
| 2:4 | Feature index 2 has value 4 in the input file. |
| One line | One training or test observation. |
In the Java program, categorical feature information is supplied through categoricalFeaturesInfo. Be careful with feature indexing: LibSVM files usually show feature indices starting from 1, while Spark vectors are handled with zero-based feature positions inside the program.
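To make the indexing point concrete, here is a small plain-Java sketch of how a LibSVM row maps onto zero-based positions. It is not how Spark implements MLUtils.loadLibSVMFile; the class LibSvmLineParser and its methods are hypothetical and only mimic the index shift.

```java
import java.util.Arrays;

// Hypothetical parser for illustration; it only demonstrates the one-based to zero-based shift.
final class LibSvmLineParser {

    // Reads the label, which is the first token of the row.
    static double parseLabel(String line) {
        return Double.parseDouble(line.trim().split("\\s+")[0]);
    }

    // Reads "<index>:<value>" pairs into a dense array.
    // A one-based file index i is stored at array position i - 1.
    static double[] parseFeatures(String line, int numFeatures) {
        String[] tokens = line.trim().split("\\s+");
        double[] features = new double[numFeatures];
        for (int t = 1; t < tokens.length; t++) {
            String[] pair = tokens[t].split(":");
            int oneBasedIndex = Integer.parseInt(pair[0]);
            features[oneBasedIndex - 1] = Double.parseDouble(pair[1]);
        }
        return features;
    }

    public static void main(String[] args) {
        String row = "1 1:2 2:1 3:5 4:1 5:1 6:1";
        // prints "1.0 [2.0, 1.0, 5.0, 1.0, 1.0, 1.0]"
        System.out.println(parseLabel(row) + " " + Arrays.toString(parseFeatures(row, 6)));
    }
}
```

With this mapping in mind, an entry for key 0 in categoricalFeaturesInfo refers to the feature written as 1:value in the file, and the feature numbers in the printed trees are zero-based as well.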
Sample Training Data for Random Forest
Below is a sample of the transformed data, ready to be fed to RandomForest for training. Each row represents an experiment/observation/example. The format of each row is [category feature1:value feature2:value ..]
Training data: trainingValues.txt
0 1:1 2:1 3:1 4:1 5:1 6:1
0 1:1 2:1 3:1 4:1 5:1 6:1
1 1:2 2:1 3:5 4:1 5:1 6:1
0 1:1 2:1 3:1 4:1 5:1 6:1
1 1:1 2:3 3:1 4:1 5:1 6:1
Spark MLlib RandomForest Training Parameters Used in the Java Program
The random forest model is controlled by several hyperparameters. These settings affect how many trees are trained, how deep each tree can grow, how categorical features are handled, and how repeatable the model is.
| Parameter | Used value | Purpose |
| numClasses | 3 | Number of possible class labels in the training data. |
| categoricalFeaturesInfo | Map of feature index to category count | Tells Spark which features are categorical and how many distinct values each categorical feature can take. |
| numTrees | 3 | Number of decision trees in the forest. |
| featureSubsetStrategy | auto | Lets Spark decide how many features are considered at each split. |
| impurity | gini | Criterion used to choose classification splits. |
| maxDepth | 30 | Maximum depth allowed for each decision tree. |
| maxBins | 10 | Maximum number of bins used while splitting continuous and categorical features. |
| seed | 12345 | Controls repeatability for the random parts of training. |
For a small learning dataset, numTrees = 3 keeps the output easy to read. For real data, you usually test several values for numTrees, maxDepth, and maxBins instead of relying on one fixed setting.
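The impurity parameter above is set to gini. As a rough illustration of what that criterion measures, the sketch below computes the standard Gini impurity, 1 minus the sum of squared class fractions, for the labels at one tree node. The class GiniImpurity is hypothetical; Spark evaluates impurity internally while choosing splits, and this is not its implementation.

```java
// Hypothetical helper for illustration of the Gini impurity formula.
final class GiniImpurity {

    // classCounts[k] is the number of rows at a node whose label is k.
    // Returns 1 - sum_k (p_k * p_k), where p_k = classCounts[k] / total.
    static double gini(long[] classCounts) {
        long total = 0;
        for (long count : classCounts) {
            total += count;
        }
        if (total == 0) {
            return 0.0; // an empty node is treated as pure in this sketch
        }
        double sumOfSquares = 0.0;
        for (long count : classCounts) {
            double p = (double) count / total;
            sumOfSquares += p * p;
        }
        return 1.0 - sumOfSquares;
    }

    public static void main(String[] args) {
        // The five sample training rows have three rows with label 0 and two with label 1.
        System.out.println(gini(new long[] {3, 2, 0})); // approximately 0.48
        // A node that contains a single class has impurity 0.
        System.out.println(gini(new long[] {5, 0, 0}));
    }
}
```

A split is attractive when the child nodes it produces have lower impurity than the parent, which is why a pure node (impurity 0) ends in a Predict line.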
Below is the Java class, RandomForestTrainerExample.java, that trains a model and saves it to the local file system.
Trainer Class: RandomForestTrainerExample.java
package com.tut;

import java.io.File;
import java.io.IOException;
import java.util.HashMap;

import org.apache.commons.io.FileUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;

/** RandomForest Classification Example using Spark MLlib
 * @author tutorialkart.com
 */
public class RandomForestTrainerExample {

    public static void main(String[] args) {
        // hadoop home dir [folder whose bin directory contains winutils.exe]
        System.setProperty("hadoop.home.dir", "D:\\Arjun\\ml\\hadoop\\");

        // Configuring spark
        SparkConf sparkConf = new SparkConf().setAppName("RandomForestExample")
                .setMaster("local[2]")
                .set("spark.executor.memory", "3g")
                .set("spark.driver.memory", "3g");

        // initializing the spark context
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);

        // Load and parse the data file.
        String datapath = "data" + File.separator + "trainingValues.txt";
        JavaRDD<LabeledPoint> trainingData;
        try {
            trainingData = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
        } catch (Exception e1) {
            System.out.println("No training data available.");
            e1.printStackTrace();
            jsc.stop();
            return;
        }

        // Configuration/Hyper parameters to train random forest model
        int numClasses = 3;

        // Features listed in categoricalFeaturesInfo are treated as discrete (categorical).
        // Features that are not listed are treated as continuous.
        // Keys are zero-based feature positions; values are the number of categories.
        HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
        categoricalFeaturesInfo.put(0, 3);  // feature 0 is discrete, with values from 0 to 2
        categoricalFeaturesInfo.put(1, 7);  // feature 1 is discrete, with values from 0 to 6
        categoricalFeaturesInfo.put(2, 10); // feature 2 is discrete, with values from 0 to 9
        // feature 3 is considered continuous valued
        categoricalFeaturesInfo.put(4, 10); // feature 4 is discrete, with values from 0 to 9
        // feature 5 is considered continuous valued

        int numTrees = 3; // number of decision trees to be included in the Random Forest
        String featureSubsetStrategy = "auto"; // let the algorithm choose how many features to consider at each split
        String impurity = "gini"; // criterion used to measure how mixed the labels are when choosing a split
        int maxDepth = 30; // maximum depth to which a decision tree can grow
        int maxBins = 10; // maximum number of bins used to discretize feature values when searching for splits
        int seed = 12345; // fixed seed, so that the random parts of training are repeatable across runs

        // training the classifier with all the hyper-parameters defined above
        final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses,
                categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
                seed);

        // Delete the old model directory if it is present, and save the new model
        File modelDir = new File("RandForestClsfrMdl");
        if (modelDir.exists()) {
            try {
                FileUtils.forceDelete(modelDir);
                System.out.println("\nDeleting old model completed.");
            } catch (IOException e) {
                System.out.println("\nCould not delete the old model directory.");
                e.printStackTrace();
            }
        }

        // saving the random forest model that is generated
        model.save(jsc.sc(), "RandForestClsfrMdl" + File.separator + "model");
        System.out.println("\nRandForestClsfrMdl/model has been created and successfully saved.");

        // printing the random forest model (collection of decision trees)
        System.out.println(model.toDebugString());

        jsc.stop();
    }
}
When the above Java class is run, a model with three decision trees is generated, as shown in the output below.
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
RandForestClsfrMdl/model has been created and successfully saved.
TreeEnsembleModel classifier with 3 trees
Tree 0:
  If (feature 5 <= 6.0)
    If (feature 0 in {1.0})
      If (feature 1 in {3.0})
        Predict: 1.0
      Else (feature 1 not in {3.0})
        If (feature 5 <= 2.0)
          If (feature 2 in {1.0})
            Predict: 0.0
          Else (feature 2 not in {1.0})
            Predict: 1.0
        Else (feature 5 > 2.0)
          Predict: 0.0
    Else (feature 0 not in {1.0})
      Predict: 1.0
  Else (feature 5 > 6.0)
    Predict: 2.0
Tree 1:
  If (feature 5 <= 6.0)
    If (feature 0 in {1.0})
      If (feature 2 in {6.0})
        Predict: 1.0
      Else (feature 2 not in {6.0})
        If (feature 4 in {3.0})
          Predict: 0.0
        Else (feature 4 not in {3.0})
          Predict: 0.0
    Else (feature 0 not in {1.0})
      Predict: 1.0
  Else (feature 5 > 6.0)
    If (feature 3 <= 1.0)
      Predict: 0.0
    Else (feature 3 > 1.0)
      Predict: 2.0
Tree 2:
  If (feature 3 <= 1.0)
    If (feature 2 in {5.0,6.0})
      Predict: 1.0
    Else (feature 2 not in {5.0,6.0})
      If (feature 0 in {1.0})
        If (feature 1 in {1.0})
          Predict: 0.0
        Else (feature 1 not in {1.0})
          If (feature 1 in {3.0})
            Predict: 1.0
          Else (feature 1 not in {3.0})
            Predict: 0.0
      Else (feature 0 not in {1.0})
        Predict: 1.0
  Else (feature 3 > 1.0)
    Predict: 2.0
From the above random forest, the following observations can be made:
- Features 0, 1, 2 and 4 are treated as discrete, as seen in set-membership rules such as [feature 2 not in {5.0,6.0}].
- Features 3 and 5 are treated as continuous, as seen in threshold rules such as [feature 5 > 6.0].
How to Read the Spark MLlib Random Forest Model Output
The printed model is a collection of decision trees. Each tree contains split rules and a final predicted label. A rule such as feature 0 in {1.0} is a categorical split, while a rule such as feature 5 <= 6.0 is a continuous numeric split.
During prediction, each tree gives a prediction and the random forest combines the tree predictions. The final output is the predicted class label for the input feature vector.
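The combining step can be pictured as a majority vote over the labels predicted by the individual trees. The sketch below is a plain-Java illustration under that assumption. The class MajorityVote is hypothetical, and breaking ties in favor of the smallest label is a choice made only for this sketch, not a statement about how Spark resolves ties.

```java
// Hypothetical helper for illustration of majority voting across trees.
final class MajorityVote {

    // treePredictions holds one predicted label per tree, each in the range 0..numClasses-1.
    // Returns the label with the most votes; ties go to the smallest label in this sketch.
    static int combine(int[] treePredictions, int numClasses) {
        int[] votes = new int[numClasses];
        for (int prediction : treePredictions) {
            votes[prediction]++;
        }
        int best = 0;
        for (int label = 1; label < numClasses; label++) {
            if (votes[label] > votes[best]) {
                best = label;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Three trees, as in this tutorial: two vote for label 1, one votes for label 0.
        System.out.println(combine(new int[] {0, 1, 1}, 3)); // prints 1
    }
}
```

With only three trees a single tree can swing the vote, which is one reason larger forests tend to give more stable predictions.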
Possible exceptions during training:
You might come across some of the exceptions below, which have to be taken care of.
java.lang.IllegalArgumentException – requirement failed – DecisionTree requires maxBins
When maxBins = 2 and the maximum number of discrete values for a feature in our training data is 10:
Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: DecisionTree requires maxBins (=2) to be at least as large as the number of values in each categorical feature, but categorical feature 2 has 10 values. Considering remove this and other categorical features with a large number of values, or add more training examples.
Solution: Set maxBins to at least the largest category count declared in categoricalFeaturesInfo, that is, at least (largest discrete value + 1) among all the features with discrete values.
java.lang.IllegalArgumentException: GiniAggregator given label
When numClasses = 2 and the training data has three categories [0,1,2]:
Caused by: java.lang.IllegalArgumentException: GiniAggregator given label 2.0 but requires label < numClasses (= 2).
Solution: Set numClasses to at least (largest label + 1). With labels 0, 1 and 2, numClasses must be at least 3.
Spark Random Forest Training Troubleshooting Checks
- Check that labels in the training file are numeric and start from 0 for this classifier setup.
- Check that numClasses covers every label present in the training data.
- For categorical features, keep maxBins at least as large as the highest category count used in categoricalFeaturesInfo.
- Confirm that the same feature order is used in both training and test files.
- Use a fixed seed when you want repeatable training output for debugging.
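Two of these checks, the label range and the maxBins requirement, can be run before training to catch both exceptions described above. The following is a minimal plain-Java sketch; the class TrainingConfigCheck and its method are hypothetical and do not belong to Spark.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical pre-training check for illustration; not a Spark API.
final class TrainingConfigCheck {

    // Returns a list of problems found; an empty list means both checks passed.
    static List<String> validate(double[] labels, int numClasses,
                                 Map<Integer, Integer> categoricalFeaturesInfo, int maxBins) {
        List<String> problems = new ArrayList<>();
        // Labels must be whole numbers in the range 0..numClasses-1.
        for (double label : labels) {
            if (label < 0 || label >= numClasses || label != Math.floor(label)) {
                problems.add("label " + label + " is outside 0.." + (numClasses - 1));
                break; // one message is enough for this sketch
            }
        }
        // maxBins must be at least as large as every declared category count.
        for (Map.Entry<Integer, Integer> entry : categoricalFeaturesInfo.entrySet()) {
            if (entry.getValue() > maxBins) {
                problems.add("feature " + entry.getKey() + " has " + entry.getValue()
                        + " categories but maxBins is " + maxBins);
            }
        }
        return problems;
    }

    public static void main(String[] args) {
        // Same categorical settings as the trainer class in this tutorial.
        Map<Integer, Integer> info = new HashMap<>();
        info.put(0, 3);
        info.put(1, 7);
        info.put(2, 10);
        info.put(4, 10);
        double[] labels = {0, 1, 2};
        System.out.println(validate(labels, 3, info, 10)); // prints [] (no problems)
        System.out.println(validate(labels, 2, info, 10)); // reports the label range problem
    }
}
```

In a real job the labels would come from the training RDD, for example by collecting its distinct label values, rather than from a hard-coded array.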
Prediction using the saved model from the training part of this RandomForest Classification Example using Spark MLlib:
A sample of the test data is shown below. The format of the test data is the same as that of the training data.
Test data: testValues.txt
0 1:1 2:4 3:1 4:1 5:1 6:3
0 1:1 2:1 3:1 4:1 5:1 6:6
1 1:2 2:1 3:5 4:1 5:1 6:6
0 1:1 2:1 3:1 4:1 5:1 6:1
1 1:2 2:3 3:1 4:1 5:1 6:1
2 1:2 2:6 3:9 4:6 5:1 6:8
2 1:2 2:6 3:9 4:4 5:1 6:8
Prediction using the model generated during training :
Predictor Class : RandomForestPredictor.java
package com.tut;

import scala.Tuple2;

import java.io.File;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;

/** RandomForest Classification Example using Spark MLlib
 * @author tutorialkart.com
 */
public class RandomForestPredictor {

    public static void main(String[] args) {
        // hadoop home dir [folder whose bin directory contains winutils.exe]
        System.setProperty("hadoop.home.dir", "D:\\Arjun\\ml\\hadoop\\");

        // Configuring spark
        SparkConf sparkConf1 = new SparkConf().setAppName("RandomForestExample")
                .setMaster("local[2]")
                .set("spark.executor.memory", "3g")
                .set("spark.driver.memory", "3g");

        // initializing the spark context
        JavaSparkContext jsc = new JavaSparkContext(sparkConf1);

        // loading the model, that is generated during training
        // (a final local variable, so that the functions below capture the model itself)
        final RandomForestModel model =
                RandomForestModel.load(jsc.sc(), "RandForestClsfrMdl" + File.separator + "model");

        // Load and parse the test data file.
        String datapath = "data" + File.separator + "testValues.txt";
        JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

        System.out.println("\nPredicted : Expected");

        // Evaluate model on test instances and pair each prediction with the expected label
        JavaPairRDD<Double, Double> predictionAndLabel =
                data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
                    @Override
                    public Tuple2<Double, Double> call(LabeledPoint p) {
                        // predict once and reuse the value for printing and for the result
                        double prediction = model.predict(p.features());
                        // with a local master this prints to the console of this program
                        System.out.println(prediction + " : " + p.label());
                        return new Tuple2<>(prediction, p.label());
                    }
                });

        // count the test samples/experiments for which the predicted category is wrong
        long wrongPredictions =
                predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
                    @Override
                    public Boolean call(Tuple2<Double, Double> pl) {
                        return !pl._1().equals(pl._2());
                    }
                }).count();

        // test error = wrong predictions / total test rows
        double testErr = (double) wrongPredictions / data.count();
        System.out.println("Test Error: " + testErr);

        jsc.stop();
    }
}
Output
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
Predicted : Expected
1.0 : 1.0
0.0 : 0.0
1.0 : 1.0
0.0 : 0.0
0.0 : 0.0
1.0 : 1.0
0.0 : 0.0
0.0 : 0.0
1.0 : 1.0
1.0 : 1.0
1.0 : 1.0
2.0 : 2.0
0.0 : 0.0
2.0 : 1.0
2.0 : 2.0
0.0 : 0.0
2.0 : 2.0
0.0 : 0.0
2.0 : 2.0
1.0 : 1.0
0.0 : 0.0
Test Error: 0.047619047619047616
For the test data we provided, the model has a test error of approximately 4.76%. Accuracy is calculated as (1 - testErr) * 100, so the model accuracy is approximately 95.24% for this sample test data.
How Test Error and Accuracy are Computed in the Spark Java Predictor
The predictor compares each predicted label with the actual label from the test file. If the predicted label is different from the actual label, the row is counted as an error. The test error is the number of wrong predictions divided by the total number of test rows.
| Metric | Formula used in this example | Meaning |
| Test Error | Wrong predictions / Total test rows | Lower value is better. |
| Accuracy | (1 - Test Error) * 100 | Percentage of correctly predicted test rows. |
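The two formulas can be checked without Spark. The sketch below is a plain-Java illustration using the 21 predicted and expected labels printed in the output above. The class TestErrorCalculator is hypothetical and is not part of the predictor class.

```java
// Hypothetical helper for illustration of the two formulas in the table.
final class TestErrorCalculator {

    // Test error = wrong predictions / total test rows.
    static double testError(double[] predicted, double[] expected) {
        if (predicted.length == 0 || predicted.length != expected.length) {
            throw new IllegalArgumentException("arrays must be non-empty and of equal length");
        }
        int wrong = 0;
        for (int i = 0; i < predicted.length; i++) {
            if (predicted[i] != expected[i]) {
                wrong++;
            }
        }
        return (double) wrong / predicted.length;
    }

    // Accuracy in percent = (1 - test error) * 100.
    static double accuracyPercent(double testError) {
        return (1.0 - testError) * 100.0;
    }

    public static void main(String[] args) {
        // Labels copied from the "Predicted : Expected" output; only row 14 differs.
        double[] predicted = {1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 2, 0, 2, 2, 0, 2, 0, 2, 1, 0};
        double[] expected  = {1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 2, 0, 1, 2, 0, 2, 0, 2, 1, 0};
        double testErr = testError(predicted, expected);
        System.out.println("Test Error: " + testErr); // 1 wrong row out of 21
        System.out.println("Accuracy: " + accuracyPercent(testErr) + "%");
    }
}
```

One wrong row out of 21 gives 1/21, which is the 0.0476 test error reported by the predictor.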
RDD-based spark.mllib RandomForest vs DataFrame-based spark.ml RandomForestClassifier
This tutorial uses org.apache.spark.mllib.tree.RandomForest, which belongs to the RDD-based MLlib API. Spark also provides org.apache.spark.ml.classification.RandomForestClassifier, which works with DataFrames and Spark ML pipelines.
| API | Typical input | Typical use |
| spark.mllib | RDD of LabeledPoint | Older Java, Scala, or Python MLlib examples and existing RDD-based code. |
| spark.ml | DataFrame with label and features columns | Modern Spark ML pipelines with transformers, estimators, and evaluators. |
If you are starting a new project, the DataFrame-based API is usually easier to combine with feature transformers, parameter grids, cross-validation, and evaluators. If you are maintaining the Java code in this tutorial, keep the input format, categorical feature map, and model path consistent across training and prediction.
Spark MLlib Random Forest Classifier QA Checklist
- Verify that the training and test files use the same LibSVM feature order.
- Confirm that every label in the data is covered by numClasses.
- Check that every categorical feature in categoricalFeaturesInfo has the correct category count.
- Use a separate test dataset instead of measuring only on the training rows.
- Review test error, accuracy, and per-class prediction behavior before using the model output.
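For the last item, per-class behavior can be reviewed with a small confusion matrix. The sketch below is plain Java and uses the predicted and expected labels from the predictor output in this tutorial. The class ConfusionMatrix and its methods are hypothetical names chosen for this illustration.

```java
// Hypothetical helper for illustration of a per-class review.
final class ConfusionMatrix {

    // counts[a][p] is the number of rows with actual label a that were predicted as p.
    static int[][] build(int[] actual, int[] predicted, int numClasses) {
        int[][] counts = new int[numClasses][numClasses];
        for (int i = 0; i < actual.length; i++) {
            counts[actual[i]][predicted[i]]++;
        }
        return counts;
    }

    // Recall for a label = correctly predicted rows of that label / all rows of that label.
    // Returns NaN when the label does not occur in the data.
    static double recall(int[][] counts, int label) {
        int total = 0;
        for (int count : counts[label]) {
            total += count;
        }
        return total == 0 ? Double.NaN : (double) counts[label][label] / total;
    }

    public static void main(String[] args) {
        int[] actual    = {1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 2, 0, 1, 2, 0, 2, 0, 2, 1, 0};
        int[] predicted = {1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 2, 0, 2, 2, 0, 2, 0, 2, 1, 0};
        int[][] counts = build(actual, predicted, 3);
        for (int label = 0; label < 3; label++) {
            System.out.println("label " + label + " recall: " + recall(counts, label));
        }
    }
}
```

For this output, the only error is a row with actual label 1 that was predicted as 2, so labels 0 and 2 have a recall of 1.0 while label 1 has 7 correct rows out of 8.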
Spark MLlib RandomForest Classification FAQs
Which Spark API is used in this RandomForest Java example?
This tutorial uses the RDD-based spark.mllib API, specifically org.apache.spark.mllib.tree.RandomForest and RandomForestModel. The newer DataFrame-based API uses org.apache.spark.ml.classification.RandomForestClassifier.
What input format is used for Spark MLlib random forest training?
The example uses LibSVM-style input. Each row starts with a numeric label, followed by feature-value pairs such as 1:1, 2:4, and 3:1. The same feature structure should be used for training and test data.
How do I fix the DecisionTree requires maxBins exception?
Increase maxBins so that it is at least as large as the number of values in every categorical feature. Also check whether a high-cardinality feature should really be treated as categorical.
How is accuracy calculated from test error in this Spark example?
Accuracy is calculated as (1 - testErr) * 100. For example, when the test error is 0.047619047619047616, the approximate accuracy is 95.24%.
Can Spark MLlib RandomForest handle multiclass classification?
Yes. This example uses three classes, represented by labels 0, 1, and 2. Set numClasses to cover the number of labels present in your training data.
What this Spark MLlib RandomForest classifier example demonstrated
In this Apache Spark tutorial, RandomForest Classification Example using Spark MLlib, we have learned how to train a random forest model and use it to predict classes for a classification problem with Apache Spark MLlib.
TutorialKart.com