Classification is the task of identifying the features of an entity and assigning the entity to one of a set of predefined classes or categories. In a supervised machine learning problem, the model learns this mapping from labelled training data and then predicts labels for new data.

A decision tree has a tree-like structure. The root is a decision node and the starting point for classifying a problem instance. Each decision node tests a feature and branches out, with each branch representing one possible outcome of that test. A branch leads either to another decision node or to a leaf that holds a class label, which ends the classification and gives the result.
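The structure described above can be sketched in plain Java. This is only an illustration of the idea: the class and method names (TreeSketch, Node, classify) are hypothetical and are not Spark's internal classes, and the sketch assumes binary splits on numeric features.

```java
// Hypothetical plain-Java sketch of a binary decision tree; not Spark's implementation.
class TreeSketch {

    /** A node is either a decision node (feature test) or a leaf holding a class label. */
    static final class Node {
        final boolean leaf;
        final double label;      // meaningful only for leaves
        final int featureIndex;  // meaningful only for decision nodes
        final double threshold;  // meaningful only for decision nodes
        final Node left;         // branch taken when feature value <= threshold
        final Node right;        // branch taken when feature value > threshold

        private Node(boolean leaf, double label, int featureIndex,
                     double threshold, Node left, Node right) {
            this.leaf = leaf;
            this.label = label;
            this.featureIndex = featureIndex;
            this.threshold = threshold;
            this.left = left;
            this.right = right;
        }

        static Node leaf(double label) {
            return new Node(true, label, -1, Double.NaN, null, null);
        }

        static Node decision(int featureIndex, double threshold, Node left, Node right) {
            return new Node(false, Double.NaN, featureIndex, threshold, left, right);
        }
    }

    /** Walks from the root to a leaf and returns that leaf's class label. */
    static double classify(Node root, double[] features) {
        Node current = root;
        while (!current.leaf) {
            current = features[current.featureIndex] <= current.threshold
                    ? current.left
                    : current.right;
        }
        return current.label;
    }

    public static void main(String[] args) {
        // Root tests feature 0; its right branch leads to a second decision node.
        Node root = Node.decision(0, 0.5,
                Node.leaf(0.0),
                Node.decision(1, 2.0, Node.leaf(1.0), Node.leaf(0.0)));
        System.out.println(classify(root, new double[]{0.9, 1.5}));
    }
}
```

Each call to classify follows exactly one path from the root to a leaf, which is why prediction with a trained tree is cheap compared with training it.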

In this Apache Spark tutorial, we shall build a decision tree like the one described above from training data, using the Decision Tree algorithm in Apache Spark MLlib.

Classification using Decision Trees in Apache Spark MLlib with Java

This example uses the RDD-based MLlib API from the org.apache.spark.mllib package. It is useful when you are maintaining older Spark MLlib Java code or learning how Spark’s original decision tree classifier works. For new Spark machine learning projects, also review the DataFrame-based org.apache.spark.ml API, because Spark’s newer machine learning pipelines are built around DataFrames, transformers, estimators, and pipeline stages.

Decision tree classification flow in Spark MLlib Java

A decision tree classifier in Spark MLlib follows the same basic machine learning flow used by many supervised learning algorithms:

  1. Load labelled data as LabeledPoint records.
  2. Split the data into training and test sets.
  3. Set decision tree parameters such as number of classes, impurity, maximum depth, and maximum bins.
  4. Train the decision tree classifier on the training data.
  5. Predict labels for the test data.
  6. Compare predicted labels with actual labels and calculate accuracy.
  7. Save the trained model if it needs to be reused.

The sample program below uses Spark’s sample_libsvm_data.txt file. LIBSVM format stores each row as a label followed by sparse feature index and value pairs. Spark MLlib can load this format directly with MLUtils.loadLibSVMFile().
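To make the LIBSVM row layout concrete, here is a minimal plain-Java sketch that parses one row. The class name LibSvmLineSketch and the sample row are made up for illustration, and this is not how Spark implements its loader. LIBSVM files use one-based feature indices, and Spark's loader converts them to zero-based indices, so the sketch does the same.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical parser for a single LIBSVM row such as "1 3:0.5 10:2.0".
class LibSvmLineSketch {

    /** The first token of a row is the label. */
    static double parseLabel(String line) {
        return Double.parseDouble(line.trim().split("\\s+")[0]);
    }

    /** Remaining tokens are index:value pairs; indices are shifted to zero-based. */
    static Map<Integer, Double> parseFeatures(String line) {
        String[] tokens = line.trim().split("\\s+");
        Map<Integer, Double> features = new TreeMap<>();
        for (int i = 1; i < tokens.length; i++) {
            String[] pair = tokens[i].split(":");
            int oneBasedIndex = Integer.parseInt(pair[0]);
            features.put(oneBasedIndex - 1, Double.parseDouble(pair[1]));
        }
        return features;
    }

    public static void main(String[] args) {
        String line = "1 3:0.5 10:2.0"; // made-up row, not taken from the sample file
        System.out.println("label=" + parseLabel(line) + " features=" + parseFeatures(line));
    }
}
```

Only non-zero features appear in a row, which keeps files small when most feature values are zero.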

The following is a step-by-step process to build a classifier using the Decision Tree algorithm of MLlib. Before starting, set up a Java project with Apache Spark.

Step 1: Configure Spark for the Java decision tree example

1. Configure Spark.

SparkConf sparkConf = new SparkConf().setAppName("DecisionTreeExample");

In local development, the complete program later uses setMaster("local[2]") so that Spark can run with two local worker threads. In a cluster, the master is normally supplied through spark-submit instead of hard-coding it in the application.

2. Start a Spark context.

JavaSparkContext jsc = new JavaSparkContext(sparkConf);

JavaSparkContext is the entry point for RDD-based Spark code in Java. It is used here to load data, create RDDs, train the model, save the model, and stop the Spark application.

Step 2: Load LIBSVM training data for Spark MLlib classification

3. Load the data and split it into training and test sets. The data file used in this example is in the data folder of the Apache Spark distribution downloaded from the official website.

// path to the sample data in LIBSVM format
String path = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> inputData = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD();

// split the data for training (60%) and testing (40%)
JavaRDD<LabeledPoint>[] tmp = inputData.randomSplit(new double[]{0.6, 0.4});
JavaRDD<LabeledPoint> trainingData = tmp[0]; // training set
JavaRDD<LabeledPoint> testData = tmp[1]; // test set

The training set is used to build the classifier. The test set is held back so that the trained model can be evaluated on records it did not use during training. The split ratio in this example is 60% training and 40% testing. For repeatable results in production examples, you can use the overloaded randomSplit() method that accepts a random seed.
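In Spark, the seeded overload is called as inputData.randomSplit(new double[]{0.6, 0.4}, 11L), where the seed value 11L is an arbitrary choice. The plain-Java sketch below illustrates why a fixed seed gives a repeatable split. It is not Spark's implementation, and the class name SeededSplitSketch is hypothetical. It assigns each record to the training or test part independently, so the 60/40 ratio is approximate rather than exact.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical illustration of a seeded random split; not Spark's randomSplit().
class SeededSplitSketch {

    /** Returns two lists: index 0 is the training part, index 1 is the test part. */
    static List<List<Integer>> split(List<Integer> records, double trainingFraction, long seed) {
        Random random = new Random(seed); // same seed, same sequence of draws
        List<Integer> training = new ArrayList<>();
        List<Integer> test = new ArrayList<>();
        for (Integer record : records) {
            if (random.nextDouble() < trainingFraction) {
                training.add(record);
            } else {
                test.add(record);
            }
        }
        List<List<Integer>> parts = new ArrayList<>();
        parts.add(training);
        parts.add(test);
        return parts;
    }

    public static void main(String[] args) {
        List<Integer> records = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            records.add(i);
        }
        List<List<Integer>> first = split(records, 0.6, 11L);
        List<List<Integer>> second = split(records, 0.6, 11L);
        System.out.println("Same split on both runs: " + first.equals(second));
        System.out.println("Training size: " + first.get(0).size()
                + ", test size: " + first.get(1).size());
    }
}
```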

Step 3: Set DecisionTree classifier parameters in Spark MLlib

4. Set the hyperparameters required by the decision tree.
impurity : The measure used to score candidate splits at each node. For classification, Spark MLlib accepts "gini" or "entropy".
maxDepth : Maximum depth of the tree, counted in levels of nodes below the root. A depth of 0 means the tree is a single leaf node.
maxBins : Maximum number of bins used to discretize continuous features when searching for candidate splits. A higher value lets the algorithm consider finer-grained splits at a higher computation cost.

int numClasses = 2;
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
String impurity = "gini";
int maxDepth = 5;
int maxBins = 32;

The parameter values should match the data and the classification problem:

| Spark MLlib decision tree parameter | Meaning in this Java example | Practical note |
| --- | --- | --- |
| numClasses | Number of possible output labels. | The sample data is a binary classification problem, so the value is 2. |
| categoricalFeaturesInfo | Map that tells Spark which features are categorical and how many categories each feature has. | An empty map means all features are treated as continuous. |
| impurity | Criterion used to choose splits. | For classification, common options include gini and entropy. |
| maxDepth | Maximum tree depth allowed during training. | A larger value may fit training data better but can increase overfitting. |
| maxBins | Maximum number of bins used for splitting continuous features. | For categorical features, this value must be large enough to handle the number of categories. |
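To show what the two impurity options measure, the following plain-Java sketch computes Gini impurity and entropy from per-class record counts. The class name ImpuritySketch is hypothetical and the code is a textbook-style illustration rather than Spark's source. Entropy is computed with a base-2 logarithm here.

```java
// Hypothetical sketch of the two impurity measures used for classification trees.
class ImpuritySketch {

    private static long total(long[] classCounts) {
        long sum = 0;
        for (long count : classCounts) {
            sum += count;
        }
        return sum;
    }

    /** Gini impurity: 1 - sum(p_i^2), where p_i is the fraction of records in class i. */
    static double gini(long[] classCounts) {
        long total = total(classCounts);
        if (total == 0) {
            return 0.0;
        }
        double sumOfSquares = 0.0;
        for (long count : classCounts) {
            double p = (double) count / total;
            sumOfSquares += p * p;
        }
        return 1.0 - sumOfSquares;
    }

    /** Entropy: -sum(p_i * log2(p_i)), skipping classes with no records. */
    static double entropy(long[] classCounts) {
        long total = total(classCounts);
        if (total == 0) {
            return 0.0;
        }
        double result = 0.0;
        for (long count : classCounts) {
            if (count > 0) {
                double p = (double) count / total;
                result -= p * (Math.log(p) / Math.log(2));
            }
        }
        return result;
    }

    public static void main(String[] args) {
        long[] mixed = {5, 5}; // evenly mixed node
        long[] pure = {10, 0}; // node containing a single class
        System.out.println("gini mixed=" + gini(mixed) + " pure=" + gini(pure));
        System.out.println("entropy mixed=" + entropy(mixed) + " pure=" + entropy(pure));
    }
}
```

Both measures are 0 for a node that contains a single class and largest for an evenly mixed node, so a lower value after a split means the split separated the classes better.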

Step 4: Train the Spark MLlib DecisionTreeModel classifier

5. Train a Decision Tree model.

DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
          categoricalFeaturesInfo, impurity, maxDepth, maxBins);

DecisionTree.trainClassifier() returns a DecisionTreeModel. The model stores the learned tree structure and can predict a label for each feature vector. The result can also be printed with toDebugString() to inspect the splits selected by the algorithm.

Step 5: Predict labels and calculate decision tree accuracy

6. Use the model to predict labels for the test data and calculate the accuracy. The generated decision tree can be inspected by converting it to a readable string with toDebugString().

// Predict for the test data using the model trained
JavaPairRDD<Double, Double> predictionAndLabel =
        testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
// calculate the accuracy
double accuracy =
        predictionAndLabel.filter(pl -> pl._1().equals(pl._2())).count() / (double) testData.count();
 
System.out.println("Accuracy is : "+accuracy);
System.out.println("Trained Decision Tree model:\n" + model.toDebugString());

The predictionAndLabel RDD contains pairs in the form (predictedLabel, actualLabel). The accuracy calculation counts how many predictions match the actual labels and divides that count by the number of test records.

Accuracy is easy to understand, but it is not always enough. If one class appears much more often than another class, also inspect class-level metrics such as precision, recall, F1 score, or a confusion matrix. For this small sample program, accuracy is sufficient to show that the model can be trained and used for prediction.
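The class-level metrics mentioned above can be sketched in plain Java for the binary case. The class name BinaryMetricsSketch and the sample arrays are hypothetical, and label 1.0 is assumed to be the positive class. For RDD-based code, Spark also ships evaluation classes such as MulticlassMetrics in org.apache.spark.mllib.evaluation, which this sketch does not use.

```java
// Hypothetical binary-classification metrics computed from predicted and actual labels.
class BinaryMetricsSketch {

    /** Returns {truePositive, falsePositive, trueNegative, falseNegative}; 1.0 is positive. */
    static long[] confusion(double[] predicted, double[] actual) {
        long tp = 0, fp = 0, tn = 0, fn = 0;
        for (int i = 0; i < predicted.length; i++) {
            boolean predictedPositive = predicted[i] == 1.0;
            boolean actualPositive = actual[i] == 1.0;
            if (predictedPositive && actualPositive) tp++;
            else if (predictedPositive) fp++;
            else if (actualPositive) fn++;
            else tn++;
        }
        return new long[]{tp, fp, tn, fn};
    }

    private static double ratio(long numerator, long denominator) {
        return denominator == 0 ? 0.0 : (double) numerator / denominator;
    }

    static double accuracy(long[] c) {
        return ratio(c[0] + c[2], c[0] + c[1] + c[2] + c[3]);
    }

    /** Of the records predicted positive, the fraction that really are positive. */
    static double precision(long[] c) {
        return ratio(c[0], c[0] + c[1]);
    }

    /** Of the records that really are positive, the fraction the model found. */
    static double recall(long[] c) {
        return ratio(c[0], c[0] + c[3]);
    }

    static double f1(long[] c) {
        double p = precision(c);
        double r = recall(c);
        return (p + r) == 0.0 ? 0.0 : 2 * p * r / (p + r);
    }

    public static void main(String[] args) {
        double[] predicted = {1, 1, 0, 0, 1}; // made-up predictions
        double[] actual = {1, 0, 0, 1, 1};    // made-up ground truth
        long[] c = confusion(predicted, actual);
        System.out.println("accuracy=" + accuracy(c) + " precision=" + precision(c)
                + " recall=" + recall(c) + " f1=" + f1(c));
    }
}
```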

Step 6: Save and reuse the Spark MLlib decision tree model

7. Save the trained classifier model to the local file system for future use.

model.save(jsc.sc(), "myDecisionTreeClassificationModel");

Saving the model is useful when training is expensive or when the same classifier must be reused later for batch scoring. The save path must not already exist, otherwise Spark may fail with an output path already exists error. Delete or rename the previous model directory before saving again during repeated local tests.

DecisionTreeModel sameModel = DecisionTreeModel.load(
        jsc.sc(),
        "myDecisionTreeClassificationModel"
);

The loading snippet above shows how a saved MLlib decision tree model can be loaded again from the same model path.
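For repeated local test runs, one option is to remove the previous model directory before calling model.save() again. The sketch below is a hypothetical helper (ModelPathCleanup is not part of Spark) and assumes the model path is on the local file system. A path on HDFS or another remote store would need that store's own API instead.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Comparator;
import java.util.stream.Stream;

// Hypothetical helper for local test runs; not part of Spark.
class ModelPathCleanup {

    /** Deletes a local directory tree if it exists; does nothing otherwise. */
    static void deleteIfExists(Path dir) {
        if (!Files.exists(dir)) {
            return;
        }
        // Reverse order visits files and subdirectories before their parent directory.
        try (Stream<Path> paths = Files.walk(dir)) {
            paths.sorted(Comparator.reverseOrder()).forEach(p -> {
                try {
                    Files.delete(p);
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
            });
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        // Same directory name as the save path used in this tutorial.
        deleteIfExists(Paths.get("myDecisionTreeClassificationModel"));
        System.out.println("Old model directory removed, if it existed.");
    }
}
```

Deleting a saved model is destructive, so a helper like this is better suited to local experiments than to shared or production model paths.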

8. Stop the Spark context.

jsc.stop();

Always stop the Spark context at the end of the program. This releases resources used by the local Spark application or cluster connection.

Complete Java program for Decision Tree classification in Apache Spark MLlib

The complete Java program is given below.

DecisionTreeClassifierExample.java

import scala.Tuple2;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.SparkConf;

/**
 * Classification using Decision Trees in Apache Spark MLlib with Java Example 
 */
public class DecisionTreeClassifierExample {

	public static void main(String[] args) {

		// configure spark
		SparkConf sparkConf = new SparkConf().setAppName("DecisionTreeExample")
										.setMaster("local[2]").set("spark.executor.memory","2g");
		// start a spark context
		JavaSparkContext jsc = new JavaSparkContext(sparkConf);
		
		// path to the sample data in LIBSVM format
		String path = "data/mllib/sample_libsvm_data.txt";
		JavaRDD<LabeledPoint> inputData = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD();
		
		// split the data for training (60%) and testing (40%)
		JavaRDD<LabeledPoint>[] tmp = inputData.randomSplit(new double[]{0.6, 0.4});
		JavaRDD<LabeledPoint> trainingData = tmp[0]; // training set
		JavaRDD<LabeledPoint> testData = tmp[1]; // test set
		
		// Set hyperparameters for the Decision Tree algorithm
		// Empty categoricalFeaturesInfo indicates all features are continuous.
		int numClasses = 2;
		Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
		String impurity = "gini";
		int maxDepth = 5;
		int maxBins = 32;
		
		// Train a Decision Tree model
		DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
				categoricalFeaturesInfo, impurity, maxDepth, maxBins);
		
		// Predict for the test data using the model trained
		JavaPairRDD<Double, Double> predictionAndLabel =
				testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
		// calculate the accuracy
		double accuracy =
				predictionAndLabel.filter(pl -> pl._1().equals(pl._2())).count() / (double) testData.count();
		
		System.out.println("Accuracy is : "+accuracy);
		System.out.println("Trained Decision Tree model:\n" + model.toDebugString());

		// Save model to local for future use
		model.save(jsc.sc(), "myDecisionTreeClassificationModel");

		// stop the spark context
		jsc.stop();
	}
}

Output of the Spark MLlib DecisionTree classifier example

Output

Accuracy is : 0.9787234042553191
Trained Decision Tree model:
DecisionTreeModel classifier of depth 1 with 3 nodes
  If (feature 434 <= 0.0)
   Predict: 0.0
  Else (feature 434 > 0.0)
   Predict: 1.0

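The hyperparameters set only upper limits. Training stops splitting a node once no candidate split improves the impurity, for example when the node already contains records of a single class. In this example, despite maxDepth=5, the trained tree has a depth of only 1 because one split on feature 434 was enough and further splits offered no improvement.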
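The reason a tree can stay shallow can be illustrated with the gain of a split: the parent's impurity minus the size-weighted impurity of the two children. The plain-Java sketch below uses Gini impurity and made-up class counts. The class name InformationGainSketch is hypothetical, and the code is a simplified illustration rather than Spark's training logic.

```java
// Hypothetical sketch: gain of a binary split, measured with Gini impurity.
class InformationGainSketch {

    private static long total(long[] counts) {
        long sum = 0;
        for (long count : counts) {
            sum += count;
        }
        return sum;
    }

    static double gini(long[] classCounts) {
        long total = total(classCounts);
        if (total == 0) {
            return 0.0;
        }
        double sumOfSquares = 0.0;
        for (long count : classCounts) {
            double p = (double) count / total;
            sumOfSquares += p * p;
        }
        return 1.0 - sumOfSquares;
    }

    /** Gain = impurity(parent) - weighted impurity of the left and right children. */
    static double gain(long[] left, long[] right) {
        long[] parent = new long[left.length];
        for (int i = 0; i < left.length; i++) {
            parent[i] = left[i] + right[i];
        }
        long leftTotal = total(left);
        long rightTotal = total(right);
        long all = leftTotal + rightTotal;
        if (all == 0) {
            return 0.0;
        }
        return gini(parent)
                - ((double) leftTotal / all) * gini(left)
                - ((double) rightTotal / all) * gini(right);
    }

    public static void main(String[] args) {
        // A split that separates the two classes completely removes all impurity.
        System.out.println("gain of a perfect split: "
                + gain(new long[]{30, 0}, new long[]{0, 30}));
        // Splitting an already pure node gains nothing, so growing deeper is pointless.
        System.out.println("gain when the node is already pure: "
                + gain(new long[]{15, 0}, new long[]{15, 0}));
    }
}
```

When every reachable node is pure or no split yields a positive gain, a deeper tree cannot improve on the training data, whatever value maxDepth allows.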

Your output may not match the accuracy value exactly because randomSplit() can create a different training and test split when no seed is supplied. The learned tree can also differ if Spark version, data split, or parameter values change. This is normal for a machine learning example.

How to run this DecisionTreeClassifierExample Java program

After creating a Java project with Apache Spark dependencies, place the program in your source folder and make sure the sample data path is correct. If you run the program from a Spark distribution folder, the path data/mllib/sample_libsvm_data.txt usually points to Spark’s bundled sample dataset.

spark-submit \
  --class DecisionTreeClassifierExample \
  --master local[2] \
  target/your-spark-java-project.jar

The exact JAR name depends on your build tool and project configuration. If the program cannot find the sample data file, provide an absolute path or copy the data file into the expected project location.
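One way to avoid editing the source when the data file lives elsewhere is to accept the path as the first program argument and fall back to the default. The sketch below is a hypothetical helper (DataPathSketch is not in the original program) and its existence check assumes a local file path, not an hdfs:// or other remote URI.

```java
import java.nio.file.Files;
import java.nio.file.Paths;

// Hypothetical helper: take the data path from the command line, else use the default.
class DataPathSketch {

    static final String DEFAULT_PATH = "data/mllib/sample_libsvm_data.txt";

    static String resolveDataPath(String[] args) {
        boolean hasArgument = args != null && args.length > 0 && !args[0].trim().isEmpty();
        return hasArgument ? args[0] : DEFAULT_PATH;
    }

    public static void main(String[] args) {
        String path = resolveDataPath(args);
        // Printing the absolute path shows which directory the program is running from.
        if (Files.exists(Paths.get(path))) {
            System.out.println("Using data file: " + Paths.get(path).toAbsolutePath());
        } else {
            System.err.println("Data file not found: " + Paths.get(path).toAbsolutePath());
        }
    }
}
```

With spark-submit, application arguments are placed after the JAR name, so the data path would be appended at the end of the command shown above.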

Common errors in Spark MLlib decision tree classification with Java

| Issue in the Java Spark decision tree example | Likely cause | How to fix it |
| --- | --- | --- |
| FileNotFoundException for sample_libsvm_data.txt | The program is running from a directory where the relative data path is not valid. | Use the correct absolute path or run from the Spark distribution directory that contains the data/mllib folder. |
| Model save path already exists | The folder myDecisionTreeClassificationModel was created by an earlier run. | Delete the old folder or save to a new path before running again. |
| Accuracy changes between runs | randomSplit() is used without a fixed seed. | Use a deterministic seed when you need repeatable local examples. |
| Unexpected poor accuracy | Data split, label quality, feature quality, or tree parameters may not fit the problem. | Check labels, tune maxDepth, review categorical features, and evaluate with more than one metric. |
| Overfitted decision tree | The tree is allowed to grow too deep for the amount of data. | Reduce maxDepth, increase validation, and compare training accuracy with test accuracy. |

RDD-based MLlib DecisionTree vs DataFrame-based Spark ML DecisionTreeClassifier

This tutorial uses the RDD-based DecisionTree class from spark.mllib. Spark also provides a newer DataFrame-based DecisionTreeClassifier in the spark.ml package. The two APIs solve similar classification problems, but their project style is different.

| Comparison point | RDD-based Spark MLlib API | DataFrame-based Spark ML API |
| --- | --- | --- |
| Package style | Uses classes under org.apache.spark.mllib. | Uses classes under org.apache.spark.ml. |
| Data structure | Works mainly with RDDs and LabeledPoint. | Works with DataFrames and columns such as label and features. |
| Pipeline support | Manual flow is common. | Designed for ML pipelines with transformers and estimators. |
| Best use case | Maintaining older MLlib examples and RDD-based Spark applications. | Newer Spark machine learning projects that use DataFrames and pipeline stages. |

If you are learning from older Java examples, the RDD-based version is still useful. If you are starting a new Spark machine learning application, compare this tutorial with Spark’s DataFrame-based decision tree documentation before choosing an API.

Decision tree classification FAQs for Apache Spark MLlib Java

What does categoricalFeaturesInfo mean in Spark MLlib decision trees?

categoricalFeaturesInfo tells Spark which feature indexes are categorical and how many categories each of those features has. In this example, the map is empty, so Spark treats all features as continuous.
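As an illustration of a non-empty map, the sketch below declares two categorical features. The feature indexes and category counts are made up, since the sample dataset in this tutorial has no categorical features, and the class name CategoricalInfoSketch is hypothetical. Feature indexes are zero-based, and a feature with k categories is expected to take the values 0 to k-1.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical example of describing categorical features for a decision tree.
class CategoricalInfoSketch {

    /** Key: zero-based feature index. Value: number of categories of that feature. */
    static Map<Integer, Integer> build() {
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        categoricalFeaturesInfo.put(0, 2);  // feature 0 is binary: values 0 or 1
        categoricalFeaturesInfo.put(4, 10); // feature 4 has 10 categories: values 0..9
        return categoricalFeaturesInfo;
    }

    /** Largest category count in the map; maxBins should be at least this large. */
    static int largestCategoryCount(Map<Integer, Integer> categoricalFeaturesInfo) {
        int largest = 0;
        for (int categories : categoricalFeaturesInfo.values()) {
            largest = Math.max(largest, categories);
        }
        return largest;
    }

    public static void main(String[] args) {
        Map<Integer, Integer> info = build();
        System.out.println("categoricalFeaturesInfo = " + info);
        System.out.println("maxBins should be at least " + largestCategoryCount(info));
    }
}
```

Any feature index that is missing from the map is treated as continuous.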

Why is the trained decision tree depth 1 when maxDepth is 5?

maxDepth is only an upper limit. The algorithm can stop earlier when additional splits are not useful for the data. That is why the printed model can have a smaller depth than the configured maximum depth.

Why does the accuracy change when I run the Java Spark program again?

The example uses randomSplit() without a fixed seed. This can create a different training and test split in different runs. Use a fixed seed when you need repeatable results for debugging or documentation.

Should I use spark.mllib or spark.ml for a new decision tree classifier?

Use spark.ml for most new Spark machine learning projects because it supports DataFrame-based pipelines. Use spark.mllib when maintaining existing RDD-based code or following older examples like the one shown in this tutorial.

Can Spark MLlib decision trees handle multiclass classification?

Yes. Set numClasses to the number of output classes in the labelled data. The labels in the training data should match the expected class range for the classification problem.
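For a classifier with numClasses classes, labels are expected to be the whole numbers 0, 1, ..., numClasses - 1. A small pre-training check such as the hypothetical one below (LabelRangeSketch is not part of Spark) can catch labels like -1 or 3 in a three-class problem before training starts.

```java
// Hypothetical check that labels fall in the range 0 .. numClasses - 1.
class LabelRangeSketch {

    static boolean isValidLabel(double label, int numClasses) {
        boolean wholeNumber = label == Math.floor(label);
        return wholeNumber && label >= 0 && label < numClasses;
    }

    static boolean allValid(double[] labels, int numClasses) {
        for (double label : labels) {
            if (!isValidLabel(label, numClasses)) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        double[] labels = {0.0, 2.0, 1.0, 2.0}; // made-up labels for a 3-class problem
        System.out.println("Valid for numClasses=3: " + allValid(labels, 3));
        System.out.println("Valid for numClasses=2: " + allValid(labels, 2));
    }
}
```

Datasets that encode classes as 1..k or as -1/+1 would need their labels remapped before being used with this range.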
