
I would like to search my text column in a pyspark data frame for phrases. Here is an example to show you what I mean.

    sentenceData = spark.createDataFrame([
        (0, "Hi I heard about Spark"),
        (4, "I wish Java could use case classes"),
        (11, "Logistic regression models are neat")
    ], ["id", "sentence"])

If the sentence contains "heard about spark" then categorySpark=1 and categoryHeard=1.

If the sentence contains "java OR regression" then categoryCool=1.
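As a regex for `rlike()`, "java OR regression" becomes an alternation with `|`. Spaces inside a pattern are literal, so a pattern written as `'Java | regression'` looks for `'Java '` (trailing space) or `' regression'` (leading space), and it misses the word when it starts the text or is followed by punctuation. The sketch below uses Python's `re` module only to demonstrate this locally. It assumes that `re` behaves like the Java regex engine behind `rlike()` for these simple constructs (literal spaces, `|`, `\b`). The last two sample sentences are hypothetical additions.

```python
import re

sentences = [
    "Hi I heard about Spark",
    "I wish Java could use case classes",
    "Logistic regression models are neat",
    "I use Java.",          # hypothetical: "Java" followed by punctuation
    "regression is neat",   # hypothetical: "regression" at the start
]

spaced = re.compile(r"Java | regression")         # alternatives are "Java " and " regression"
bounded = re.compile(r"\b(?:Java|regression)\b")  # whole words, no literal spaces

for s in sentences:
    print(f"{s!r:40} spaced={bool(spaced.search(s))} bounded={bool(bounded.search(s))}")
```

The word-boundary version can be passed to `rlike()` unchanged as a raw string, e.g. `rlike(r"\b(?:Java|regression)\b")`.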

I have about 28 of these checks to make (perhaps better written as regular expressions).

    sentenceData.withColumn('categoryCool', sentenceData['sentence'].rlike('Java|regression')).show()

returns:

    +---+--------------------+------------+
    | id|            sentence|categoryCool|
    +---+--------------------+------------+
    |  0|Hi I heard about ...|       false|
    |  4|I wish Java could...|        true|
    | 11|Logistic regressi...|        true|
    +---+--------------------+------------+

This is what I want, but I'd like to add it to a pipeline as a transformation step.

Climbs_lika_Spyder



I combined a Medium article and another Stack Overflow answer to answer my own question. I hope someone finds this helpful someday.

    from pyspark.ml import Pipeline, Transformer
    from pyspark.sql import functions as F
    from pyspark.sql.types import StringType
    
    sentenceData = spark.createDataFrame([
        (0, "Hi I heard about Spark"),
        (4, "I wish Java could use case classes"),
        (11, "Logistic regression models are neat")
    ], ["id", "sentence"])
    
    class OneSearchMultiLabelExtractor(Transformer):
        """Runs one rlike() search and writes the boolean result to every column in outputCols."""
        def __init__(self, rlikeSearch, outputCols, inputCol='fullText'):
            # Params.__init__ assigns the uid and sets up the param maps used by the inherited copy()
            super().__init__()
            self.inputCol = inputCol
            self.outputCols = outputCols
            self.rlikeSearch = rlikeSearch
    
        def check_input_type(self, schema):
            field = schema[self.inputCol]
            if not isinstance(field.dataType, StringType):
                raise TypeError('OneSearchMultiLabelExtractor input type %s did not match StringType' % field.dataType)
    
        def check_output_type(self):
            if not isinstance(self.outputCols, list):
                raise TypeError('OneSearchMultiLabelExtractor output columns must be a list')
    
        def _transform(self, df):
            self.check_input_type(df.schema)
            self.check_output_type()
            # Build the search expression once and reuse it for every label column
            searchResult = F.col(self.inputCol).rlike(self.rlikeSearch)
            for outputCol in self.outputCols:
                df = df.withColumn(outputCol, searchResult)
            return df
    
    # No spaces around "|": inside a regex they would be matched literally
    dex = OneSearchMultiLabelExtractor(inputCol='sentence', rlikeSearch='Java|regression', outputCols=['coolCategory'])
    FeaturesPipeline = Pipeline(stages=[dex])
    Featpip = FeaturesPipeline.fit(sentenceData)
    Featpip.transform(sentenceData).show()