Unit testing Spark 2.x applications in Java

When you search for specific Spark questions, you might get the impression that Java is the black sheep among the supported languages. Nearly all answers refer to Scala or Python. The same is true for unit testing, which is why I am writing this post. I will show how to create a local context (which is pretty well documented) and how to read Parquet files from a source directory within your project (the process is the same for other formats such as CSV or JSON). This way you can unit test classes that contain Spark functions without a connection to another file system or a resource negotiator.

For the following example, it is important that the folder src/test/resources is registered as a class path / source folder of your project.
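Once the folder is on the class path, a test file can also be located through the class loader instead of a path relative to the working directory. The following is only a minimal sketch of that idea: the class TestResources and its method absolutePath are hypothetical helpers, not part of Spark or JUnit, and the sketch assumes the resources lie on disk (as is usual for a test run started from an IDE or build tool) rather than inside a jar.

```java
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;

/** Hypothetical helper for locating test files on the class path. */
final class TestResources {

    private TestResources() {
        // Utility class, no instances.
    }

    /**
     * Resolves a class path resource such as "tests/myfile.parquet" to an
     * absolute file system path that can be handed to a file reader.
     */
    static String absolutePath(ClassLoader loader, String resourceName) {
        URL url = loader.getResource(resourceName);
        if (url == null) {
            throw new IllegalArgumentException("Resource not found on class path: " + resourceName);
        }
        try {
            // Works for file: URLs; resources packed into a jar are not covered by this sketch.
            return Paths.get(url.toURI()).toAbsolutePath().toString();
        } catch (URISyntaxException e) {
            throw new IllegalArgumentException("Cannot convert to a file path: " + url, e);
        }
    }
}
```

In a test, the returned string could be passed to spark.read().parquet(...) in place of a hard-coded relative path, for example TestResources.absolutePath(getClass().getClassLoader(), "tests/myfile.parquet"), where the file name is a placeholder. A missing file then fails early with a clear message instead of deep inside Spark.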

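The annotations @BeforeClass and @AfterClass mark static methods that JUnit calls once before the first test of the class and once after the last one, respectively. The Spark session is therefore created only once and shared by all tests in the class.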

package de.jofre.spark.tests;
 
import static org.assertj.core.api.Assertions.assertThat;
 
import org.apache.spark.SparkConf;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
 
import de.jofre.test.categories.UnitTest;
 
@Category(UnitTest.class)
public class SparkSessionAndFileReadTest {
	private static Dataset<Row> df;
	private static SparkSession spark;
	private static final Logger logger = LoggerFactory.getLogger(SparkSessionAndFileReadTest.class);
 
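	@BeforeClass
	public static void beforeClass() {
		// Run Spark inside the test JVM on all local cores; no cluster is required.
		// fs.defaultFS is set to file:/// so that paths without a scheme are read from the local file system.
		spark = SparkSession.builder().master("local[*]").config(new SparkConf().set("fs.defaultFS", "file:///"))
				.appName(SparkSessionAndFileReadTest.class.getName()).getOrCreate();
		// The relative path is resolved against the working directory, typically the project root.
		df = spark.read().parquet("src/test/resources/tests/part-r-00001-myfile.snappy.parquet");
		logger.info("Created spark context and dataset with {} rows.", df.count());
	}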
 
	@Test
	public void sortKeepsRowCountTest() {
		logger.info("Testing the spark sorting function.");
		// Sorting must not add or drop rows, so both counts have to match.
		Dataset<Row> sorted = df.sort("id");
		assertThat(sorted.count()).isEqualTo(df.count());
	}
 
	@AfterClass
	public static void afterClass() {
		// Stop the session so that the local Spark context is released after the last test.
		if (spark != null) {
			spark.stop();
		}
	}
}
