Classification is a supervised machine learning task where a model learns from labeled examples and predicts one label from a fixed set of classes. In Apache Spark MLlib, logistic regression is commonly used for binary classification, and it can also be configured for multiclass classification.
Logistic regression does not predict a continuous number like linear regression. It estimates class probabilities from input features and then assigns the record to the class with the highest score. For example, a record may be classified as spam or not spam, fraudulent or valid, or one category among several numeric labels.
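To make the scoring step concrete, here is a minimal plain-Java sketch of how a binary logistic regression model could turn features into a probability and then into a label. It is not Spark code: the class name, the weights, and the feature values are hypothetical, and the 0.5 threshold is an assumption used for illustration.

```java
/** Plain-Java sketch of binary logistic regression scoring (not Spark code). */
class LogisticScoreSketch {

    /** Logistic (sigmoid) function: maps any real score into the range (0, 1). */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Probability of class 1 for one record: sigmoid(intercept + weights . features). */
    static double probabilityOfClassOne(double[] weights, double intercept, double[] features) {
        double z = intercept;
        for (int i = 0; i < weights.length; i++) {
            z += weights[i] * features[i];
        }
        return sigmoid(z);
    }

    /** Returns label 1.0 when the probability reaches the threshold, otherwise 0.0. */
    static double predict(double[] weights, double intercept, double[] features, double threshold) {
        return probabilityOfClassOne(weights, intercept, features) >= threshold ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        double[] weights = {1.5, -2.0}; // hypothetical learned weights
        double[] record = {2.0, 0.5};   // hypothetical feature values
        double p = probabilityOfClassOne(weights, 0.0, record);
        System.out.println("P(class 1) = " + p + ", label = " + predict(weights, 0.0, record, 0.5));
    }
}
```

With a threshold of 0.5, picking label 1 is the same as picking the class with the higher probability, which is why the output is a class label rather than a continuous value.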
In Spark MLlib, the training data is usually prepared as labeled feature vectors. Each row contains a label and a vector of numeric features. The label is the known class during training, and the features are the measured values used by the algorithm to learn the relationship between inputs and classes.
This tutorial uses the RDD-based MLlib API in Java, specifically org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS. Spark also provides the newer DataFrame-based ML API under org.apache.spark.ml. If you are starting a new Spark machine learning project, also review the official Spark ML classification and regression guide at spark.apache.org/docs/latest/ml-classification-regression.html.
Classification using Logistic Regression in Apache Spark MLlib with Java
In this Apache Spark tutorial, we walk step by step through an example that builds a logistic regression model for classification using Spark MLlib.
What the Spark MLlib logistic regression example does
The program loads a sample LIBSVM file, splits the records into training and test datasets, trains a logistic regression classifier, evaluates the prediction accuracy, and saves the trained model to disk. The example is useful for understanding the Java workflow before moving to larger datasets or a production pipeline.
- Input format: labeled points loaded from a LIBSVM text file.
- Training step: logistic regression with LBFGS optimization.
- Evaluation step: predictions are compared with actual labels using MulticlassMetrics.
- Output: model accuracy and a saved classifier model directory.
Data format expected by MLUtils.loadLibSVMFile
The sample file used here is data/mllib/sample_libsvm_data.txt from the Apache Spark distribution. LIBSVM format stores each record as a label followed by sparse feature entries. A simplified line looks like the following.
label index1:value1 index2:value2 index3:value3
For logistic regression, labels must be numeric. In a binary classification problem, labels are commonly 0 and 1. In a multiclass problem, labels are commonly indexed as 0, 1, 2, and so on.
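The line layout above can be illustrated with a small plain-Java parser. This is only a sketch of the format, not the implementation behind MLUtils.loadLibSVMFile; the class name and the sample line are hypothetical. LIBSVM files use one-based feature indices, and Spark converts them to zero-based indices when loading, so the sketch does the same.

```java
import java.util.TreeMap;

/** Plain-Java sketch that parses one LIBSVM-style line (not Spark's own parser). */
class LibSvmLineSketch {
    final double label;
    final TreeMap<Integer, Double> features; // zero-based feature index -> value

    LibSvmLineSketch(double label, TreeMap<Integer, Double> features) {
        this.label = label;
        this.features = features;
    }

    /** Parses "label index1:value1 index2:value2 ..." into a label and sparse features. */
    static LibSvmLineSketch parse(String line) {
        String[] tokens = line.trim().split("\\s+");
        double label = Double.parseDouble(tokens[0]);
        TreeMap<Integer, Double> features = new TreeMap<>();
        for (int i = 1; i < tokens.length; i++) {
            String[] pair = tokens[i].split(":");
            int oneBasedIndex = Integer.parseInt(pair[0]);
            // Shift the one-based file index to a zero-based vector index.
            features.put(oneBasedIndex - 1, Double.parseDouble(pair[1]));
        }
        return new LibSvmLineSketch(label, features);
    }

    public static void main(String[] args) {
        LibSvmLineSketch record = parse("1 3:0.5 10:2.0"); // hypothetical sample line
        System.out.println("label=" + record.label + " features=" + record.features);
    }
}
```

Only the non-zero entries are stored, which is what makes the format compact for sparse data.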
Step 1: Configure Spark for the logistic regression classifier
1. Configure Spark.
SparkConf conf = new SparkConf().setAppName("LogisticRegressionClassifier")
.setMaster("local[2]").set("spark.executor.memory","2g");
The application name helps identify the job in Spark logs or the Spark UI. The master value local[2] runs Spark locally with two worker threads, which is enough for a small tutorial dataset.
Step 2: Start the Spark context in Java
2. Start a spark context.
JavaSparkContext jsc = new JavaSparkContext(conf);
The Spark context is the entry point for RDD operations. It connects your Java program to the Spark execution environment and is required for loading the input data file.
Step 3: Load LIBSVM data and split it into training and test sets
3. Load the data and split it into training and test sets. The data file used in this example is in the data folder of the Apache Spark distribution downloaded from the official website.
// Path to the input data in LIBSVM format.
String path = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD();
// Split the initial RDD into two: 80% training data, 20% test data.
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] {0.8, 0.2}, 11L);
JavaRDD<LabeledPoint> training = splits[0].cache();
JavaRDD<LabeledPoint> test = splits[1];
The split uses 80% of the labeled records for training and 20% for testing. The seed value 11L makes the random split reproducible, so repeated runs produce the same split as long as the input data is the same.
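The effect of a fixed seed can be sketched in plain Java. This is not how Spark's randomSplit is implemented; it only illustrates why the same seed and the same input give the same split. The class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

/** Plain-Java sketch of a reproducible random split (not Spark's sampling algorithm). */
class SeededSplitSketch {

    /** Returns [training, test]; each record goes to training with probability trainFraction. */
    static <T> List<List<T>> split(List<T> records, double trainFraction, long seed) {
        Random random = new Random(seed); // same seed -> same sequence of random draws
        List<T> training = new ArrayList<>();
        List<T> test = new ArrayList<>();
        for (T record : records) {
            if (random.nextDouble() < trainFraction) {
                training.add(record);
            } else {
                test.add(record);
            }
        }
        return Arrays.asList(training, test);
    }

    public static void main(String[] args) {
        List<Integer> records = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            records.add(i);
        }
        List<List<Integer>> first = split(records, 0.8, 11L);
        List<List<Integer>> second = split(records, 0.8, 11L);
        System.out.println("Same split on both runs: " + first.equals(second));
    }
}
```

Because each record is assigned independently, the split sizes are only approximately 80% and 20%; the same holds for a weighted random split in general.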
The training RDD is cached because the algorithm may reuse it during model training. Caching avoids recomputing the same RDD lineage repeatedly for this small example and becomes more useful with larger datasets.
Step 4: Train the Spark MLlib logistic regression model
4. Train a Logistic Regression model.
LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
.setNumClasses(10)
.run(training.rdd());
The setNumClasses(10) call tells the algorithm that the classification problem has ten possible class labels. For a binary classification dataset, this value would usually be 2. Always set the number of classes to match the labels in your dataset.
LogisticRegressionWithLBFGS uses the LBFGS optimization method to fit the logistic regression model. After the model is trained, it can take a feature vector and return the predicted class label.
Step 5: Predict labels and calculate classification accuracy
5. Use the model to predict on the test data, and calculate accuracy.
// Predict a label for each test record and pair it with the actual label.
JavaPairRDD<Object, Object> predictionAndLabels = test.mapToPair(p ->
new Tuple2<>(model.predict(p.features()), p.label()));
// get evaluation metrics
MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
double accuracy = metrics.accuracy();
System.out.println("Accuracy = " + accuracy);
The predictionAndLabels RDD stores pairs of predicted label and actual label. MulticlassMetrics then calculates evaluation values from those pairs. Accuracy is the fraction of test records where the predicted label matches the true label.
Accuracy is useful for a first check, but it may not be enough when the dataset is imbalanced. For practical classification work, also inspect confusion matrix, precision, recall, and F-measure when those metrics are relevant to your problem.
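The following plain-Java sketch shows why accuracy alone can mislead on imbalanced data. It computes accuracy, precision, recall, and F1 for a binary problem from predicted and actual labels. It does not use MulticlassMetrics; the class name and the data in main are hypothetical, and treating label 1.0 as the positive class is an assumption.

```java
/** Plain-Java sketch of binary evaluation metrics (label 1.0 is treated as positive). */
class BinaryMetricsSketch {
    private long tp, fp, tn, fn;

    BinaryMetricsSketch(double[] predicted, double[] actual) {
        if (predicted.length != actual.length) {
            throw new IllegalArgumentException("predicted and actual must have the same length");
        }
        for (int i = 0; i < predicted.length; i++) {
            boolean predictedPositive = predicted[i] == 1.0;
            boolean actualPositive = actual[i] == 1.0;
            if (predictedPositive && actualPositive) tp++;
            else if (predictedPositive) fp++;
            else if (actualPositive) fn++;
            else tn++;
        }
    }

    double accuracy() {
        long total = tp + fp + tn + fn;
        return total == 0 ? 0.0 : (double) (tp + tn) / total;
    }

    double precision() {
        return tp + fp == 0 ? 0.0 : (double) tp / (tp + fp);
    }

    double recall() {
        return tp + fn == 0 ? 0.0 : (double) tp / (tp + fn);
    }

    double f1() {
        double p = precision();
        double r = recall();
        return p + r == 0.0 ? 0.0 : 2 * p * r / (p + r);
    }

    public static void main(String[] args) {
        // Hypothetical imbalanced data: 95 negatives, 5 positives, model always predicts 0.
        double[] actual = new double[100];
        double[] predicted = new double[100];
        for (int i = 95; i < 100; i++) {
            actual[i] = 1.0;
        }
        BinaryMetricsSketch m = new BinaryMetricsSketch(predicted, actual);
        System.out.println("accuracy=" + m.accuracy() + " recall=" + m.recall());
    }
}
```

In the hypothetical data, a model that never predicts the positive class still reaches high accuracy while its recall is zero, which is the situation the extra metrics are meant to expose.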
Step 6: Save the trained Spark logistic regression classifier
6. Save the trained classifier model to local disk for future use.
model.save(jsc.sc(), "LogisticRegressionClassifier");
This creates a model directory named LogisticRegressionClassifier. If the directory already exists, Spark may throw an error, so delete or rename the existing directory before saving again.
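One way to clear an old model directory before saving is sketched below in plain Java. It assumes the model path is on the local file system, as in this tutorial; a path on HDFS or another distributed store would need a different approach. The class and method names are hypothetical.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** Plain-Java sketch: remove a local model directory so a later save can recreate it. */
class ModelDirCleanupSketch {

    /** Deletes dir and everything under it; returns false when dir does not exist. */
    static boolean deleteDirectoryIfExists(Path dir) {
        if (!Files.exists(dir)) {
            return false;
        }
        try (Stream<Path> walk = Files.walk(dir)) {
            // Reverse order lists children before their parent directories.
            List<Path> deepestFirst = walk.sorted(Comparator.reverseOrder())
                                          .collect(Collectors.toList());
            for (Path entry : deepestFirst) {
                Files.delete(entry);
            }
            return true;
        } catch (IOException e) {
            throw new UncheckedIOException("Could not delete " + dir, e);
        }
    }

    public static void main(String[] args) {
        boolean removed = deleteDirectoryIfExists(Paths.get("LogisticRegressionClassifier"));
        System.out.println(removed ? "Removed old model directory" : "Nothing to remove");
        // The tutorial's model.save(...) call would follow here.
    }
}
```

Deleting silently is convenient for a tutorial; in a real project you may prefer to save each model under a new, versioned directory name instead.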
Step 7: Stop the Spark context after training and evaluation
7. Stop the Spark context.
jsc.stop();
Stopping the Spark context releases the resources used by the Spark application. For local runs, this closes the local Spark execution environment used by the Java program.
Complete Java program for Logistic Regression classification in Spark MLlib
The complete example program is given below.
LogisticRegressionClassifierExample.java
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import scala.Tuple2;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
/**
 * Example for Logistic Regression Classifier
 */
public class LogisticRegressionClassifierExample {
    public static void main(String[] args) {
        // configure spark
        SparkConf conf = new SparkConf().setAppName("LogisticRegressionClassifier")
                .setMaster("local[2]").set("spark.executor.memory", "2g");
        // start a spark context
        SparkContext jsc = new SparkContext(conf);
        // path to the input data in LIBSVM format
        String path = "data/mllib/sample_libsvm_data.txt";
        JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc, path).toJavaRDD();
        // Split the initial RDD into two: 80% training data, 20% test data.
        JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] {0.8, 0.2}, 11L);
        JavaRDD<LabeledPoint> training = splits[0].cache();
        JavaRDD<LabeledPoint> test = splits[1];
        // Run training algorithm to build the model.
        LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
                .setNumClasses(10)
                .run(training.rdd());
        // Predict a label for each test record and pair it with the actual label.
        JavaPairRDD<Object, Object> predictionAndLabels = test.mapToPair(p ->
                new Tuple2<>(model.predict(p.features()), p.label()));
        // get evaluation metrics
        MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
        double accuracy = metrics.accuracy();
        System.out.println("Accuracy = " + accuracy);
        // After training, save the model to local disk for future predictions
        model.save(jsc, "LogisticRegressionClassifier");
        // stop the spark context
        jsc.stop();
    }
}
Output
Accuracy = 0.9523809523809523
How to read the accuracy result from this Spark classifier
The output value 0.9523809523809523 means that the model predicted about 95.24% of the test records correctly for this particular split of the sample data. Because the dataset is small and the split is controlled by a random seed, the value should be treated as a tutorial result, not as a general benchmark for logistic regression.
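A short plain-Java sketch shows how coarse accuracy is on a small test set. The test-set size of 21 is an assumption, chosen only because 20 correct out of 21 is approximately 0.9524; the tutorial does not state the real number of test records. The class name is hypothetical.

```java
/** Plain-Java sketch: how much one wrong prediction moves accuracy on a small test set. */
class SmallTestSetSketch {

    /** Accuracy as the fraction of correctly predicted records. */
    static double accuracy(int correct, int total) {
        if (total <= 0) {
            throw new IllegalArgumentException("total must be positive");
        }
        return (double) correct / total;
    }

    public static void main(String[] args) {
        int total = 21; // ASSUMPTION: illustrative test-set size, not taken from the tutorial
        for (int wrong = 0; wrong <= 3; wrong++) {
            System.out.printf("%d wrong out of %d -> accuracy %.4f%n",
                    wrong, total, accuracy(total - wrong, total));
        }
    }
}
```

With so few records, a single extra error shifts accuracy by several percentage points, which is one reason to treat the printed value as a tutorial result only.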
If you change the input data, the random split seed, the train-test ratio, the number of classes, or model parameters, the accuracy can change. For real projects, evaluate the model on data that was not used during training and choose metrics that match the cost of wrong predictions in your use case.
Common mistakes in Spark MLlib Logistic Regression with Java
- Wrong number of classes: setNumClasses() must match the labels in the dataset. Use 2 for binary classification and the actual class count for multiclass classification.
- Non-numeric labels: MLlib expects numeric labels. Convert string classes such as yes, no, cat, or dog into numeric labels before training.
- Unscaled features: Logistic regression can be sensitive to feature scales. Consider feature normalization or standardization when your feature values have very different ranges.
- Training and testing on the same data: Always keep separate test data, otherwise the reported accuracy may be misleading.
- Existing model output directory: Saving to an existing path can fail. Use a fresh model directory or remove the previous one before saving.
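Two of the points above, numeric labels and feature scaling, can be sketched in plain Java. This is an illustration of the idea only, not a Spark API; the class and method names are hypothetical, and the standardization uses the population standard deviation as a simplifying assumption.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Plain-Java sketch of two preprocessing steps: label indexing and z-score scaling. */
class PreprocessingSketch {

    /** Assigns 0.0, 1.0, 2.0, ... to distinct string classes in order of first appearance. */
    static Map<String, Double> indexLabels(String[] rawLabels) {
        Map<String, Double> index = new LinkedHashMap<>();
        for (String raw : rawLabels) {
            if (!index.containsKey(raw)) {
                index.put(raw, (double) index.size());
            }
        }
        return index;
    }

    /** Rescales one feature column to zero mean and unit (population) standard deviation. */
    static double[] standardize(double[] column) {
        double mean = 0.0;
        for (double v : column) {
            mean += v;
        }
        mean /= column.length;
        double variance = 0.0;
        for (double v : column) {
            variance += (v - mean) * (v - mean);
        }
        double std = Math.sqrt(variance / column.length);
        double[] scaled = new double[column.length];
        for (int i = 0; i < column.length; i++) {
            // A constant column has std 0; map it to 0.0 instead of dividing by zero.
            scaled[i] = std == 0.0 ? 0.0 : (column[i] - mean) / std;
        }
        return scaled;
    }

    public static void main(String[] args) {
        System.out.println(indexLabels(new String[] {"cat", "dog", "cat", "yes"}));
        System.out.println(java.util.Arrays.toString(standardize(new double[] {1.0, 2.0, 3.0})));
    }
}
```

Whatever mapping and scaling you apply to the training data must be applied in the same way to the test data and to any records scored later.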
RDD-based MLlib logistic regression versus Spark ML DataFrame API
The example above uses the older RDD-based MLlib package name org.apache.spark.mllib. This is still helpful when you are maintaining existing Spark Java code that already uses RDDs and LabeledPoint.
For new applications, the DataFrame-based API under org.apache.spark.ml is usually preferred because it works with pipelines, transformers, estimators, and DataFrame columns. The official Java API documentation for LogisticRegression is available at spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/LogisticRegression.html.
FAQ on Classification using Logistic Regression in Apache Spark MLlib with Java
Can Logistic Regression in Spark MLlib be used for multiclass classification?
Yes. The RDD-based MLlib example uses LogisticRegressionWithLBFGS with setNumClasses(10), which configures the classifier for ten classes. For binary classification, set the number of classes to 2.
What is the role of LabeledPoint in this Spark MLlib Java example?
LabeledPoint represents one training or testing record. It contains the known label and the feature vector for that record. The logistic regression model learns from these labeled feature vectors.
Why does the example use LIBSVM data for logistic regression?
LIBSVM is a simple text format for labeled sparse feature vectors. Spark MLlib provides MLUtils.loadLibSVMFile, which makes it convenient for a small classification example without writing extra parsing code.
Is accuracy enough to evaluate a Spark logistic regression classifier?
Accuracy is a useful starting point, but it is not always enough. If the dataset is imbalanced or some mistakes are more costly than others, review additional metrics such as precision, recall, F-measure, and the confusion matrix.
Which Spark API should I use for new logistic regression projects in Java?
For new projects, consider the DataFrame-based org.apache.spark.ml.classification.LogisticRegression API because it integrates with Spark ML pipelines. Use the RDD-based org.apache.spark.mllib API when maintaining older RDD-based code or learning from legacy examples.
Editorial QA checklist for this Spark MLlib logistic regression tutorial
- Confirm that the tutorial clearly distinguishes classification from regression.
- Verify that every Java code block uses a PrismJS-compatible language-java class.
- Check that output-only examples use the output class.
- Ensure that the sample data path points to the Spark distribution’s data/mllib/sample_libsvm_data.txt file.
- Confirm that the explanation states when to use the RDD-based MLlib API and when to consider the DataFrame-based Spark ML API.
TutorialKart.com