
I'm trying to write a test for one of my applications, which uses Spark Structured Streaming to write incoming data to an HBase table. The idea is that I mock an HBaseClient instance that creates a BufferedMutator. The TestingBufferedMutator doesn't do any inserting; it just keeps the appends it receives in a list, which the test can inspect later. My problem is that execution never reaches the body of foreachBatch{}, as far as I can tell because nothing was read from the MemoryStream, so even when the trigger fires there is no data ready to be streamed and nothing gets done. I have already checked that the JSON string is read correctly (a println() statement prints it to the console). I also tried adding the empty afterSessionStart() method for overriding, because I thought the issue was that I added the data before the streaming query was set up.
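
For reference, the TestingBufferedMutator is essentially the following (a simplified sketch; I'm assuming here that getSerializableBufferedMutator returns HBase's plain BufferedMutator interface, and depending on the HBase version a few more interface methods may need stubbing):

import java.util.{List => JList}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client.{BufferedMutator, Mutation}

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

class TestingBufferedMutator extends BufferedMutator with Serializable {

  private val received = ListBuffer.empty[Mutation]

  // Record the mutations (the Appends from the loader) instead of writing them to HBase
  override def mutate(mutation: Mutation): Unit = received += mutation
  override def mutate(mutations: JList[_ <: Mutation]): Unit = received ++= mutations.asScala

  // Nothing is buffered, so flush/close are no-ops
  override def flush(): Unit = ()
  override def close(): Unit = ()

  override def getName: TableName = TableName.valueOf("sampleTable")
  override def getConfiguration: Configuration = new Configuration()
  override def getWriteBufferSize: Long = 0L

  // What the test inspects after the run
  def getTestList: List[Mutation] = received.toList
}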

I was trying to implement a modified version of the code in this guide: https://medium.com/swlh/unit-testing-apache-spark-structured-streaming-using-memorystream-8e77e97c5f5d

My question is how exactly MemoryStream works, and what I am doing incorrectly when trying to write the JSON data to the stream. Any help would be appreciated!
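
For context, my understanding of the basic pattern from the guide is roughly this (a simplified, self-contained sketch that uses a memory sink instead of my HBase writer; the names are mine):

import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.OutputMode

val spark = SparkSession.builder().appName("MemoryStream demo").master("local[1]").getOrCreate()
implicit val sqlCtx: SQLContext = spark.sqlContext
import spark.implicits._

// MemoryStream is the streaming source; addData() enqueues records for the next micro-batch
val input: MemoryStream[String] = MemoryStream[String]

val query = input.toDS.toDF
  .writeStream
  .format("memory")                  // in-memory sink, queryable as a table
  .queryName("output")
  .outputMode(OutputMode.Append())
  .start()

input.addData("""{"some": "json"}""")  // enqueue a test record
query.processAllAvailable()            // block until everything added so far has been processed

spark.sql("select * from output").show()
query.stop()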

This is the test code:

class StreamingLoaderOutputSpec extends FunSpec with MockFactory with BeforeAndAfter {

  private val mockHBaseClient = mock[HBaseClient]
  private val testingBufferedMutator = new TestingBufferedMutator()

  private val streamingConfig = StreamingConfig(
    hbaseTableName = "sampleTable",
    sourceKafkaTopics = List(),
    checkpointDirectory="checkpointDir",
    batchInterval = 1,
    batchIntervalTimeUnit = TimeUnit.SECONDS,
    kerberosPrincipal = "kerberosPrincipal",
    kerberosKeyTabFile = "kerberosKeyTabFile",
    dryRun = false
  )
  private var jsonString: String = ""

  before {

    val fileSrc = scala.io.Source.fromFile("src/test/resources/sample_kafka_msg.json")
    jsonString = fileSrc.mkString
    fileSrc.close()

    (mockHBaseClient.getSerializableBufferedMutator _).expects(*).onCall((tableName: String) => {
      testingBufferedMutator
    }).once()
  }

  describe("StreamingLoader") {

    it("should create appends with correct content"){

      val spark: SparkSession = SparkSession.builder()
        .appName("StreamingLoader Test")
        .master("local[1]")
        .getOrCreate()

      implicit val sqlCtx: SQLContext = spark.sqlContext
      import spark.implicits._

      val events: MemoryStream[String] = MemoryStream[String]
      val sessions = events.toDS
      val sessionsDF = sessions.toDF

      assert(sessions.isStreaming, "sessions must be a streaming Dataset")

      println(s"Test JSON content: $jsonString")

      val cb = new CyclicBarrier(1, new Runnable {
        override def run(): Unit = {
          println("Adding dataSession has been started")
          val currentOffset = events.addData(jsonString)
          events.commit(currentOffset.asInstanceOf[LongOffset])
        }
      })

      val loader = new StreamingLoader(
        spark = spark, hbaseClient = mockHBaseClient, streamedDataFrame = sessionsDF, config = streamingConfig) {
        override def afterSessionStart(): Unit = {
          cb.await()
        }
      }

      // execute() starts the streaming query and calls afterSessionStart()
      loader.execute()

      assert(testingBufferedMutator.getTestList.nonEmpty)

      testingBufferedMutator.getTestList.foreach(println(_))

    }

  }

}

This is the application I'm trying to test:

class StreamingLoader(spark: SparkSession,
                      hbaseClient: HBaseClient,
                      streamedDataFrame: DataFrame,
                      config: StreamingConfig) extends Serializable {

  def execute(): Unit = {

    val labelSerializer = new LabelSerializer()

    val bufferedMutator = hbaseClient.getSerializableBufferedMutator(config.hbaseTableName)

    try {
      val query = streamedDataFrame
        .writeStream
        .outputMode(OutputMode.Append())
        .option("checkpointLocation", config.checkpointDirectory)
        .trigger(Trigger.ProcessingTime(config.batchInterval, config.batchIntervalTimeUnit))
        .foreachBatch { (receivedDataFrame: DataFrame, _: Long) =>

          val dataToWrite = Mapper(receivedDataFrame)

          dataToWrite.foreach(row => {

            val result = labelSerializer.process(row)

            if (result.isDefined) {
              val (rowKey, columnNameBytes, payload) = result.get

              // Append instead of Put HBaseClient Object
              val append = new Append(rowKey)
              append.addColumn(TargetColumnFamily, columnNameBytes, payload)

              try {
                bufferedMutator.mutate(append)
              } catch {
                case exception: Exception =>
                  // Swallow the exception and resume processing
                  logger.error("BufferedMutator.mutate raised Exception", exception)
                case throwable: Throwable =>
                  // Throwables are re-thrown, causing the application to fail
                  logger.error("BufferedMutator.mutate failed with Throwable", throwable)
                  throw throwable
              }
            }
          })

          try {
            bufferedMutator.flush()
          } catch {
            case exception: Exception =>
              // Swallow the exception and resume processing
              logger.error("BufferedMutator.flush raised Exception", exception)
            case throwable: Throwable =>
              // Throwables are re-thrown, causing the application to fail
              logger.error("BufferedMutator.flush failed with Throwable", throwable)
              throw throwable
          }
        }
        .start

      logger.info("Session has been started")
      afterSessionStart()

      query.awaitTermination()
      closeQuietly(bufferedMutator)

    } catch {
      case throwable: Throwable =>
        closeQuietly(bufferedMutator)

        throw throwable
    }
  }

  def afterSessionStart(): Unit = {

  }

  private def closeQuietly(resource: Closeable): Unit = {
    try {
      resource.close()
    } catch {
      case exception: Throwable => exception.printStackTrace()
    }
  }
}

EDIT/UPDATE:

I have realized that I was missing the processAllAvailable() call, compared to the guide I linked. I changed afterSessionStart() to take the streaming query as a parameter and moved the awaitTermination() call in there:

def afterSessionStart(query: StreamingQuery): Unit = {
  query.awaitTermination()
}
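
Correspondingly, the end of execute() now hands the query to the hook instead of awaiting termination itself (only the relevant lines shown, as a sketch):

      logger.info("Session has been started")
      afterSessionStart(query)   // the hook is now responsible for waiting on / draining the query

      closeQuietly(bufferedMutator)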

I have also modified the test code so that it calls the method I was missing:

val loader = new StreamingLoader(
    spark = spark, hbaseClient = mockHBaseClient, streamedDataFrame = sessionsDF, config = streamingConfig) {
    override def afterSessionStart(query: StreamingQuery): Unit = {
        println("Adding dataSession has been started")
        val currentOffset = memoryStream.addData(jsonString)
        query.processAllAvailable()
        memoryStream.commit(currentOffset.asInstanceOf[LongOffset])
    }
}

Unfortunately this hasn't solved my problem yet, but the refactoring is essential if someone wants to write the test in a separate file that only uses the transforming streamer class under test.

Another thread I'm trying to use to fix my code: How to perform Unit testing on Spark Structured Streaming?
