RandomForest Classification Example using Spark MLlib

In this tutorial, we shall see how to train a model using the RandomForest classifier in Spark MLlib, use the trained model to predict the categories of test data, and calculate the Test Error and Accuracy of the model.

This example uses the older RDD-based spark.mllib API with Java. It is still useful when you are maintaining existing MLlib code, learning how RandomForest.trainClassifier works, or working with LibSVM files. For new Spark machine learning applications, also review the DataFrame-based Spark RandomForestClassifier API because modern Spark pipelines generally use spark.ml.

What Random Forest Classification Does in Spark MLlib

A random forest classifier builds many decision trees and combines their predictions. For classification problems, the model predicts a class label such as 0, 1, or 2. This is different from random forest regression, where the model predicts a numeric value.

In Spark MLlib, the RDD-based random forest classifier can work with binary and multiclass labels. It can also use both continuous features and categorical features, provided the categorical feature information is supplied correctly.

Training using Random Forest classifier

Spark MLlib works only with numbers, so the training data must be prepared in a form that MLlib understands. Preparing the training data has a large effect on the accuracy of a model. It includes the following steps:

  1. Identify the categories (class labels) and assign each one a numeric index.
  2. Identify the features and assign each one an index.
  3. Transform each experiment/observation/example using the indexes of the categories and features.
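The three preparation steps can be sketched in plain Java, without Spark. This is a minimal illustration under stated assumptions, not part of the downloadable example: the class names low/medium/high, the colour feature and the helper class LibSvmLineBuilder are hypothetical. It indexes a class label and a text-valued categorical feature, then writes one observation in the row format used by trainingValues.txt.

```java
import java.util.Arrays;
import java.util.List;

public class LibSvmLineBuilder {

    // Step 1 (hypothetical categories): the position in this list is the numeric label,
    // so low -> 0, medium -> 1, high -> 2.
    public static final List<String> CLASS_NAMES = Arrays.asList("low", "medium", "high");

    // Step 2 (hypothetical categorical feature): text values indexed the same way.
    public static final List<String> COLOURS = Arrays.asList("red", "green", "blue");

    // Returns the index of a value, failing loudly on values that were never indexed.
    public static int indexOf(List<String> values, String value) {
        int index = values.indexOf(value);
        if (index < 0) {
            throw new IllegalArgumentException("Unknown value: " + value);
        }
        return index;
    }

    // Step 3: writes "label 1:v1 2:v2 ..." with 1-based feature indices, like trainingValues.txt.
    public static String toLibSvmLine(int label, double[] features) {
        StringBuilder line = new StringBuilder().append(label);
        for (int i = 0; i < features.length; i++) {
            double v = features[i];
            // Print whole numbers without a trailing ".0" (1.0 -> "1")
            String text = (v == Math.rint(v)) ? String.valueOf((long) v) : String.valueOf(v);
            line.append(' ').append(i + 1).append(':').append(text);
        }
        return line.toString();
    }

    public static void main(String[] args) {
        int label = indexOf(CLASS_NAMES, "medium"); // 1
        double colour = indexOf(COLOURS, "blue");   // 2
        String line = toLibSvmLine(label, new double[] {colour, 1, 5, 1, 1, 1});
        System.out.println(line); // prints "1 1:2 2:1 3:5 4:1 5:1 6:1"
    }
}
```

The printed line has the same shape as the rows of trainingValues.txt: label first, then 1-based index:value pairs.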

Note: Feature values can be discrete (categorical) or continuous. In the trainer program, comments mark which features are configured as discrete and which as continuous. Use them as a reference to configure your own features.

Download the source code of this example here: RandomForestExampleAttachment. To set up a Java project that works with Spark MLlib, refer to Create Java Project with Apache Spark.

Random Forest Classifier Input Format in this Spark MLlib Example

The training and test files in this tutorial use LibSVM-style input. Each line contains one labeled example. The first value is the class label, followed by indexed feature-value pairs.

Part of row | Meaning in this example
0, 1, 2 | Class label to be predicted by the classifier.
1:1 | Feature index 1 has value 1 in the input file.
2:4 | Feature index 2 has value 4 in the input file.
One line | One training or test observation.

In the Java program, categorical feature information is supplied through categoricalFeaturesInfo. Be careful with feature indexing: LibSVM files number features starting from 1, but MLUtils.loadLibSVMFile converts them to zero-based positions. So file index 1 is feature 0 in categoricalFeaturesInfo and in the printed trees, file index 2 is feature 1, and so on.
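To see the index shift concretely, here is a simplified parser for one LibSVM-style row, written for illustration only. The class LibSvmIndexShift is hypothetical and this is not Spark's implementation (loadLibSVMFile builds sparse vectors rather than dense arrays); it only mimics the shift by storing file index i at array position i - 1.

```java
import java.util.Arrays;

public class LibSvmIndexShift {

    // The first token of a row is the class label.
    public static double parseLabel(String line) {
        return Double.parseDouble(line.trim().split("\\s+")[0]);
    }

    // Parses "label i1:v1 i2:v2 ..." into a dense array of size numFeatures,
    // storing 1-based file index i at 0-based position i - 1.
    public static double[] parseFeatures(String line, int numFeatures) {
        String[] tokens = line.trim().split("\\s+");
        double[] features = new double[numFeatures];
        for (int t = 1; t < tokens.length; t++) { // tokens[0] is the label
            String[] pair = tokens[t].split(":");
            int fileIndex = Integer.parseInt(pair[0]);             // 1-based in the file
            features[fileIndex - 1] = Double.parseDouble(pair[1]); // 0-based in the vector
        }
        return features;
    }

    public static void main(String[] args) {
        String line = "0 1:1 2:4 3:1 4:1 5:1 6:3"; // first row of the sample test data
        System.out.println(parseLabel(line));                        // prints 0.0
        System.out.println(Arrays.toString(parseFeatures(line, 6))); // prints [1.0, 4.0, 1.0, 1.0, 1.0, 3.0]
    }
}
```

So 2:4 in the file becomes feature 1 with value 4.0, which is why categoricalFeaturesInfo and the printed trees refer to features 0 to 5 for file indices 1 to 6.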

Sample Training Data for Random Forest

Below is a sample of the training data, already transformed and ready to be fed to RandomForest for training. Each row represents one experiment/observation/example. The format of each row is [category feature1:value feature2:value ..]

Training data: trainingValues.txt

0 1:1 2:1 3:1 4:1 5:1 6:1
0 1:1 2:1 3:1 4:1 5:1 6:1
1 1:2 2:1 3:5 4:1 5:1 6:1
0 1:1 2:1 3:1 4:1 5:1 6:1
1 1:1 2:3 3:1 4:1 5:1 6:1

Spark MLlib RandomForest Training Parameters Used in the Java Program

The random forest model is controlled by several hyperparameters. These settings affect how many trees are trained, how deep each tree can grow, how categorical features are handled, and how repeatable the model is.

Parameter | Used value | Purpose
numClasses | 3 | Number of possible class labels in the training data.
categoricalFeaturesInfo | Map of feature index to category count | Tells Spark which features are categorical and how many distinct values each categorical feature can take.
numTrees | 3 | Number of decision trees in the forest.
featureSubsetStrategy | auto | Lets Spark decide how many features are considered at each split. With more than one tree, "auto" uses "sqrt" for classification.
impurity | gini | Criterion used to choose classification splits.
maxDepth | 30 | Maximum depth allowed for each decision tree.
maxBins | 10 | Maximum number of bins used while splitting continuous and categorical features.
seed | 12345 | Controls repeatability for the random parts of training.
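categoricalFeaturesInfo needs a category count for each discrete feature. One way to get a starting point is to take the largest indexed value of each discrete feature and add 1. The sketch below does that in plain Java; the class CategoryCounts is hypothetical and assumes the values are already indexed as 0, 1, 2, ... and that feature indices are zero-based.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class CategoryCounts {

    // For each categorical feature (zero-based index), returns largest observed value + 1.
    public static Map<Integer, Integer> categoryCounts(List<double[]> rows, int[] categoricalFeatures) {
        Map<Integer, Integer> info = new HashMap<>();
        for (int feature : categoricalFeatures) {
            int maxValue = 0;
            for (double[] row : rows) {
                double v = row[feature];
                // Categorical values must be whole numbers 0, 1, 2, ...
                if (v < 0 || v != Math.rint(v)) {
                    throw new IllegalArgumentException("feature " + feature + " has non-index value " + v);
                }
                maxValue = Math.max(maxValue, (int) v);
            }
            info.put(feature, maxValue + 1);
        }
        return info;
    }

    public static void main(String[] args) {
        // Zero-based feature vectors of the five sample training rows
        List<double[]> rows = Arrays.asList(
                new double[] {1, 1, 1, 1, 1, 1},
                new double[] {1, 1, 1, 1, 1, 1},
                new double[] {2, 1, 5, 1, 1, 1},
                new double[] {1, 1, 1, 1, 1, 1},
                new double[] {1, 3, 1, 1, 1, 1});
        System.out.println(categoryCounts(rows, new int[] {0, 1, 2, 4}));
    }
}
```

On the five sample rows this gives 3 categories for feature 0, 4 for feature 1, 6 for feature 2 and 2 for feature 4. The trainer declares 7 and 10 categories for features 1 and 2 because the test rows contain values up to 6 (2:6) and 9 (3:9). Counts taken from a sample can therefore be too small; base them on every value a feature can take.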

For a small learning dataset, numTrees = 3 keeps the output easy to read. For real data, you usually test several values for numTrees, maxDepth, and maxBins instead of relying on one fixed setting.

Below is the Java class, RandomForestTrainerExample.java, that trains a model and saves it to the local file system.

Trainer Class: RandomForestTrainerExample.java

package com.tut;

import java.io.File;
import java.io.IOException;
import java.util.HashMap;

import org.apache.commons.io.FileUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;

/** RandomForest Classification Example using Spark MLlib
 * @author tutorialkart.com
 */
public class RandomForestTrainerExample {

    public static void main(String[] args) {
        // Windows only: hadoop.home.dir must point to the folder that contains bin\winutils.exe
        System.setProperty("hadoop.home.dir", "D:\\Arjun\\ml\\hadoop\\");

        // Configuring spark
        SparkConf sparkConf = new SparkConf().setAppName("RandomForestExample")
                .setMaster("local[2]")
                .set("spark.executor.memory", "3g")
                .set("spark.driver.memory", "3g");

        // initializing the spark context
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);

        // Load and parse the data file (LibSVM format; 1-based file indices become 0-based features)
        String datapath = "data" + File.separator + "trainingValues.txt";
        JavaRDD<LabeledPoint> trainingData;
        try {
            trainingData = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
        } catch (Exception e1) {
            System.out.println("No training data available.");
            e1.printStackTrace();
            jsc.stop();
            return;
        }

        // Configuration/hyper-parameters to train the random forest model
        int numClasses = 3; // the labels in the data are 0, 1 and 2

        // Maps a zero-based feature index to its number of categories.
        // A categorical feature with N categories must take values 0 to N-1.
        // Features missing from the map are continuous; an empty map means all features are continuous.
        HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
        categoricalFeaturesInfo.put(0, 3);  // feature 0 is discrete, with values from 0 to 2
        categoricalFeaturesInfo.put(1, 7);  // feature 1 is discrete, with values from 0 to 6
        categoricalFeaturesInfo.put(2, 10); // feature 2 is discrete, with values from 0 to 9
        // feature 3 is continuous
        categoricalFeaturesInfo.put(4, 10); // feature 4 is discrete, with values from 0 to 9
        // feature 5 is continuous

        int numTrees = 3;                      // number of decision trees in the random forest
        String featureSubsetStrategy = "auto"; // let Spark choose how many features to consider at each split
        String impurity = "gini";              // impurity measure used to choose the best split
        int maxDepth = 30;                     // maximum depth of each tree (Spark MLlib supports at most 30)
        int maxBins = 10;                      // maximum number of bins per feature; must be >= the largest category count above
        int seed = 12345;                      // fixed seed so that the random choices made during training are repeatable

        // training the classifier with all the hyper-parameters defined above
        final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses,
                categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
                seed);

        // Delete the old model directory, if present, because saving to an existing path fails
        File modelDir = new File("RandForestClsfrMdl");
        if (modelDir.exists()) {
            try {
                FileUtils.forceDelete(modelDir);
                System.out.println("\nDeleting old model completed.");
            } catch (IOException e) {
                System.out.println("Could not delete the old model: " + e.getMessage());
            }
        }

        // saving the random forest model that is generated
        model.save(jsc.sc(), "RandForestClsfrMdl" + File.separator + "model");
        System.out.println("\nRandForestClsfrMdl/model has been created and successfully saved.");

        // printing the random forest model (collection of decision trees)
        System.out.println(model.toDebugString());

        jsc.stop();
    }
}

Running the above Java class trains a model with three decision trees and prints them, as shown in the output below.

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".                
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
                                                                                
RandForestClsfrMdl/model has been created and successfully saved.
TreeEnsembleModel classifier with 3 trees

  Tree 0:
    If (feature 5 <= 6.0)
     If (feature 0 in {1.0})
      If (feature 1 in {3.0})
       Predict: 1.0
      Else (feature 1 not in {3.0})
       If (feature 5 <= 2.0)
        If (feature 2 in {1.0})
         Predict: 0.0
        Else (feature 2 not in {1.0})
         Predict: 1.0
       Else (feature 5 > 2.0)
        Predict: 0.0
     Else (feature 0 not in {1.0})
      Predict: 1.0
    Else (feature 5 > 6.0)
     Predict: 2.0
  Tree 1:
    If (feature 5 <= 6.0)
     If (feature 0 in {1.0})
      If (feature 2 in {6.0})
       Predict: 1.0
      Else (feature 2 not in {6.0})
       If (feature 4 in {3.0})
        Predict: 0.0
       Else (feature 4 not in {3.0})
        Predict: 0.0
     Else (feature 0 not in {1.0})
      Predict: 1.0
    Else (feature 5 > 6.0)
     If (feature 3 <= 1.0)
      Predict: 0.0
     Else (feature 3 > 1.0)
      Predict: 2.0
  Tree 2:
    If (feature 3 <= 1.0)
     If (feature 2 in {5.0,6.0})
      Predict: 1.0
     Else (feature 2 not in {5.0,6.0})
      If (feature 0 in {1.0})
       If (feature 1 in {1.0})
        Predict: 0.0
       Else (feature 1 not in {1.0})
        If (feature 1 in {3.0})
         Predict: 1.0
        Else (feature 1 not in {3.0})
         Predict: 0.0
      Else (feature 0 not in {1.0})
       Predict: 1.0
    Else (feature 3 > 1.0)
     Predict: 2.0

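From the above random forest, the following observations can be made:
  • Features 0, 1, 2 and 4 are treated as discrete (categorical): their splits test set membership, for example [feature 2 not in {5.0,6.0}].
  • Features 3 and 5 are treated as continuous: their splits compare against a threshold, for example [feature 5 > 6.0].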

How to Read the Spark MLlib Random Forest Model Output

The printed model is a collection of decision trees. Each tree contains split rules and a final predicted label. A rule such as feature 0 in {1.0} is a categorical split, while a rule such as feature 5 <= 6.0 is a continuous numeric split.

During prediction, each tree produces its own predicted label, and for classification the RDD-based RandomForestModel combines the tree predictions by majority vote. The final output is the predicted class label for the input feature vector.
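To make the combination step concrete, here is a hand translation of the three printed trees into Java, followed by a majority vote. It is only a reading aid: Spark does not generate this code, the class ForestVoteSketch and its methods are hypothetical, and the tie rule (the smallest label wins) is an assumption of this sketch, since Spark's own tie-breaking may differ. The input is a zero-based feature vector.

```java
import java.util.Map;
import java.util.TreeMap;

public class ForestVoteSketch {

    // Tree 0 from the printed model
    static double tree0(double[] f) {
        if (f[5] <= 6.0) {
            if (f[0] == 1.0) {
                if (f[1] == 3.0) return 1.0;
                if (f[5] <= 2.0) return (f[2] == 1.0) ? 0.0 : 1.0;
                return 0.0;
            }
            return 1.0;
        }
        return 2.0;
    }

    // Tree 1 from the printed model
    static double tree1(double[] f) {
        if (f[5] <= 6.0) {
            if (f[0] == 1.0) {
                if (f[2] == 6.0) return 1.0;
                return 0.0; // both branches of "feature 4 in {3.0}" predict 0.0
            }
            return 1.0;
        }
        return (f[3] <= 1.0) ? 0.0 : 2.0;
    }

    // Tree 2 from the printed model
    static double tree2(double[] f) {
        if (f[3] <= 1.0) {
            if (f[2] == 5.0 || f[2] == 6.0) return 1.0;
            if (f[0] == 1.0) {
                if (f[1] == 1.0) return 0.0;
                return (f[1] == 3.0) ? 1.0 : 0.0;
            }
            return 1.0;
        }
        return 2.0;
    }

    // Majority vote over the three tree predictions
    public static double predict(double[] f) {
        double[] treePredictions = {tree0(f), tree1(f), tree2(f)};
        // Count votes per label; TreeMap keeps labels in ascending order
        TreeMap<Double, Integer> votes = new TreeMap<>();
        for (double p : treePredictions) {
            votes.merge(p, 1, Integer::sum);
        }
        double best = votes.firstKey();
        for (Map.Entry<Double, Integer> e : votes.entrySet()) {
            // Strictly greater: on a tie, the smaller label found first is kept
            if (e.getValue() > votes.get(best)) {
                best = e.getKey();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Test row "2 1:2 2:6 3:9 4:6 5:1 6:8" as a zero-based vector
        System.out.println(predict(new double[] {2, 6, 9, 6, 1, 8})); // prints 2.0
        // Training row "1 1:1 2:3 3:1 4:1 5:1 6:1": trees vote 1.0, 0.0, 1.0
        System.out.println(predict(new double[] {1, 3, 1, 1, 1, 1})); // prints 1.0
    }
}
```

The second call shows why a forest is useful: Tree 1 disagrees with the other two trees, but the vote still returns the expected label 1.0.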

Possible exceptions during training:

You might come across the following exceptions during training. Each one must be fixed before the model can be trained.

java.lang.IllegalArgumentException: requirement failed: DecisionTree requires maxBins

This exception occurs when maxBins is smaller than the number of values of a categorical feature. For example, with maxBins = 2, while categoricalFeaturesInfo declares 10 values for feature 2:
Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: DecisionTree requires maxBins (=2) to be at least as large as the number of values in each categorical feature, but categorical feature 2 has 10 values. Considering remove this and other categorical features with a large number of values, or add more training examples.

Solution: Set maxBins to at least the largest category count in categoricalFeaturesInfo, that is, at least (maximum discrete value + 1) over all discrete features. In this example, maxBins must be at least 10.
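The condition in this exception message can be checked before training starts. A minimal sketch, assuming the category counts are held in a Map like categoricalFeaturesInfo; the class MaxBinsCheck is hypothetical and covers only the categorical-feature requirement stated in the message.

```java
import java.util.HashMap;
import java.util.Map;

public class MaxBinsCheck {

    // Smallest maxBins that satisfies "at least as large as the number of values
    // in each categorical feature": the largest category count in the map.
    public static int minimumMaxBins(Map<Integer, Integer> categoricalFeaturesInfo) {
        int required = 0;
        for (int categories : categoricalFeaturesInfo.values()) {
            required = Math.max(required, categories);
        }
        return required;
    }

    public static void main(String[] args) {
        // Same category counts as in RandomForestTrainerExample
        Map<Integer, Integer> info = new HashMap<>();
        info.put(0, 3);
        info.put(1, 7);
        info.put(2, 10);
        info.put(4, 10);

        int maxBins = 2; // the value that triggers the exception above
        int required = minimumMaxBins(info);
        if (maxBins < required) {
            // prints "maxBins = 2 is too small; use at least 10"
            System.out.println("maxBins = " + maxBins + " is too small; use at least " + required);
        }
    }
}
```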

java.lang.IllegalArgumentException: GiniAggregator given label

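This exception occurs when a label in the training data is not smaller than numClasses. For example, with numClasses = 2, while the training data has three categories (labels 0, 1 and 2):
Caused by: java.lang.IllegalArgumentException: GiniAggregator given label 2.0 but requires label < numClasses (= 2).

Solution: Set numClasses to a value greater than the largest label. With labels 0, 1, ..., k-1, this means numClasses >= k, so numClasses = 3 in this example.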
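A similar pre-check can be done for labels. This hypothetical helper, NumClassesCheck, returns the smallest numClasses that accepts every label, assuming labels are whole numbers starting from 0; it is a sketch, not part of Spark.

```java
public class NumClassesCheck {

    // Labels must be whole numbers 0, 1, 2, ... and each must be < numClasses,
    // so the smallest valid numClasses is the largest label + 1.
    public static int requiredNumClasses(double[] labels) {
        int maxLabel = -1;
        for (double label : labels) {
            if (label < 0 || label != Math.rint(label)) {
                throw new IllegalArgumentException("label " + label + " is not one of 0, 1, 2, ...");
            }
            maxLabel = Math.max(maxLabel, (int) label);
        }
        return maxLabel + 1;
    }

    public static void main(String[] args) {
        double[] labels = {0, 0, 1, 0, 1, 2, 2}; // labels of the sample test rows in this tutorial
        int numClasses = 2;                      // the value that triggers the exception above
        int required = requiredNumClasses(labels);
        if (numClasses < required) {
            // prints "numClasses = 2 is too small; use at least 3"
            System.out.println("numClasses = " + numClasses + " is too small; use at least " + required);
        }
    }
}
```

Note that labels such as 0 and 5 need numClasses = 6, not 2: Spark checks label < numClasses, not the number of distinct labels.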

Spark Random Forest Training Troubleshooting Checks

  • Check that labels in the training file are numeric and start from 0 for this classifier setup.
  • Check that numClasses covers every label present in the training data.
  • For categorical features, keep maxBins at least as large as the highest category count used in categoricalFeaturesInfo.
  • Confirm that the same feature order is used in both training and test files.
  • Use a fixed seed when you want repeatable training output for debugging.

Prediction using the saved Random Forest model

Below is a sample of the test data. It uses the same format as the training data.

0 1:1 2:4 3:1 4:1 5:1 6:3
0 1:1 2:1 3:1 4:1 5:1 6:6
1 1:2 2:1 3:5 4:1 5:1 6:6
0 1:1 2:1 3:1 4:1 5:1 6:1
1 1:2 2:3 3:1 4:1 5:1 6:1
2 1:2 2:6 3:9 4:6 5:1 6:8
2 1:2 2:6 3:9 4:4 5:1 6:8

The following class loads the model saved during training and uses it to predict the test data:

Predictor Class : RandomForestPredictor.java

package com.tut;

import java.io.File;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;

import scala.Tuple2;

/** RandomForest Classification Example using Spark MLlib
 * @author tutorialkart.com
 */
public class RandomForestPredictor {

    public static void main(String[] args) {
        // Windows only: hadoop.home.dir must point to the folder that contains bin\winutils.exe
        System.setProperty("hadoop.home.dir", "D:\\Arjun\\ml\\hadoop\\");

        // Configuring spark
        SparkConf sparkConf = new SparkConf().setAppName("RandomForestExample")
                .setMaster("local[2]")
                .set("spark.executor.memory", "3g")
                .set("spark.driver.memory", "3g");

        // initializing the spark context
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);

        // Loading the model that was saved during training.
        // A local variable (not a static field) is captured by the lambdas below,
        // so Spark serializes the model into the tasks, which also works on a cluster.
        final RandomForestModel model =
                RandomForestModel.load(jsc.sc(), "RandForestClsfrMdl" + File.separator + "model");

        // Load and parse the test data file.
        String datapath = "data" + File.separator + "testValues.txt";
        JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

        // Predict each test row once and pair the prediction with the expected label
        JavaPairRDD<Double, Double> predictionAndLabel = data.mapToPair(
                p -> new Tuple2<Double, Double>(model.predict(p.features()), p.label()));

        // The test data is small, so bring the pairs to the driver and print them there
        System.out.println("\nPredicted : Expected");
        for (Tuple2<Double, Double> pl : predictionAndLabel.collect()) {
            System.out.println(pl._1() + " : " + pl._2());
        }

        // Test error = wrong predictions / total test rows
        long wrong = predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count();
        double testErr = (double) wrong / data.count();
        System.out.println("Test Error: " + testErr);

        jsc.stop();
    }
}

Output

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
                                                                                
Predicted : Expected
1.0 : 1.0
0.0 : 0.0
1.0 : 1.0
0.0 : 0.0
0.0 : 0.0
1.0 : 1.0
0.0 : 0.0
0.0 : 0.0
1.0 : 1.0
1.0 : 1.0
1.0 : 1.0
2.0 : 2.0
0.0 : 0.0
2.0 : 1.0
2.0 : 2.0
0.0 : 0.0
2.0 : 2.0
0.0 : 0.0
2.0 : 2.0
1.0 : 1.0
0.0 : 0.0
Test Error: 0.047619047619047616

For the test data provided, the model misclassifies 1 of the 21 test rows (predicted 2.0, expected 1.0), so the test error is approximately 4.76%. Accuracy is calculated as (1 - testErr) * 100, so the model accuracy is approximately 95.24% for this sample test data.

How Test Error and Accuracy are Computed in the Spark Java Predictor

The predictor compares each predicted label with the actual label from the test file. If the predicted label is different from the actual label, the row is counted as an error. The test error is the number of wrong predictions divided by the total number of test rows.

Metric | Formula used in this example | Meaning
Test Error | Wrong predictions / Total test rows | Lower value is better.
Accuracy | (1 - Test Error) * 100 | Percentage of correctly predicted test rows.
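Both formulas can be reproduced outside Spark from the Predicted : Expected pairs printed by the predictor. This is a plain-Java sketch with a hypothetical class TestErrorSketch; the arrays hold the 21 pairs from the output above, where only one row (predicted 2.0, expected 1.0) is wrong.

```java
public class TestErrorSketch {

    // Fraction of rows whose prediction differs from the expected label
    public static double testError(double[] predicted, double[] expected) {
        if (predicted.length != expected.length || predicted.length == 0) {
            throw new IllegalArgumentException("need two non-empty arrays of equal length");
        }
        long wrong = 0;
        for (int i = 0; i < predicted.length; i++) {
            if (predicted[i] != expected[i]) {
                wrong++;
            }
        }
        return (double) wrong / predicted.length;
    }

    public static void main(String[] args) {
        // The 21 "Predicted : Expected" pairs from the predictor output
        double[] predicted = {1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 2, 0, 2, 2, 0, 2, 0, 2, 1, 0};
        double[] expected  = {1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 2, 0, 1, 2, 0, 2, 0, 2, 1, 0};

        double testErr = testError(predicted, expected);
        double accuracy = (1 - testErr) * 100;
        System.out.println("Test Error: " + testErr);      // prints Test Error: 0.047619047619047616
        System.out.printf("Accuracy: %.2f%%%n", accuracy); // prints Accuracy: 95.24%
    }
}
```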

RDD-based spark.mllib RandomForest vs DataFrame-based spark.ml RandomForestClassifier

This tutorial uses org.apache.spark.mllib.tree.RandomForest, which belongs to the RDD-based MLlib API. Spark also provides org.apache.spark.ml.classification.RandomForestClassifier, which works with DataFrames and Spark ML pipelines.

API | Typical input | Typical use
spark.mllib | RDD of LabeledPoint | Older Java, Scala, or Python MLlib examples and existing RDD-based code.
spark.ml | DataFrame with label and features columns | Modern Spark ML pipelines with transformers, estimators, and evaluators.

If you are starting a new project, the DataFrame-based API is usually easier to combine with feature transformers, parameter grids, cross-validation, and evaluators. If you are maintaining the Java code in this tutorial, keep the input format, categorical feature map, and model path consistent across training and prediction.

Spark MLlib Random Forest Classifier QA Checklist

  • Verify that the training and test files use the same LibSVM feature order.
  • Confirm that every label in the data is covered by numClasses.
  • Check that every categorical feature in categoricalFeaturesInfo has the correct category count.
  • Use a separate test dataset instead of measuring only on the training rows.
  • Review test error, accuracy, and per-class prediction behavior before using the model output.

Spark MLlib RandomForest Classification FAQs

Which Spark API is used in this RandomForest Java example?

This tutorial uses the RDD-based spark.mllib API, specifically org.apache.spark.mllib.tree.RandomForest and RandomForestModel. The newer DataFrame-based API uses org.apache.spark.ml.classification.RandomForestClassifier.

What input format is used for Spark MLlib random forest training?

The example uses LibSVM-style input. Each row starts with a numeric label, followed by feature-value pairs such as 1:1, 2:4, and 3:1. The same feature structure should be used for training and test data.

How do I fix the DecisionTree requires maxBins exception?

Increase maxBins so that it is at least as large as the number of values in every categorical feature. Also check whether a high-cardinality feature should really be treated as categorical.

How is accuracy calculated from test error in this Spark example?

Accuracy is calculated as (1 - testErr) * 100. For example, when the test error is 0.047619047619047616, the approximate accuracy is 95.24%.

Can Spark MLlib RandomForest handle multiclass classification?

Yes. This example uses three classes, represented by labels 0, 1, and 2. Set numClasses to cover the number of labels present in your training data.

What this Spark MLlib RandomForest classifier example demonstrated

In this Apache Spark tutorial, RandomForest Classification Example using Spark MLlib, we trained a RandomForest classifier with the RDD-based Spark MLlib API in Java, saved the model, loaded it to predict the categories of test data, and calculated the test error and accuracy of the model.