
I have a DataFrame produced by some validation code:

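from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# one list per column; zip() turns them into the rows that createDataFrame expects
clinic_id = ['c_1',     'c_1',   'c_1',     'c_2',     'c_3',     'c_1',   'c_2',     'c_2']
name      = ['valid',   'valid', 'invalid', 'missing', 'invalid', 'valid', 'valid',   'missing']
phone     = ['missing', 'valid', 'invalid', 'invalid', 'valid',   'valid', 'missing', 'missing']
city      = ['invalid', 'valid', 'valid',   'missing', 'missing', 'valid', 'invalid', 'missing']
df = spark.createDataFrame(list(zip(clinic_id, name, phone, city)),
                           ['clinic_id', 'name', 'phone', 'city'])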

I counted the number of valid, invalid and missing values in each column, grouped by clinic_id, in PySpark:

from pyspark.sql import functions as F

agg_table = (
    df
    .groupBy('clinic_id')
    .agg(
        # name
        F.sum(F.when(F.col('name') == 'valid', 1).otherwise(0)).alias('validname'),
        F.sum(F.when(F.col('name') == 'invalid', 1).otherwise(0)).alias('invalidname'),
        F.sum(F.when(F.col('name') == 'missing', 1).otherwise(0)).alias('missingname'),
        # phone
        F.sum(F.when(F.col('phone') == 'valid', 1).otherwise(0)).alias('validphone'),
        F.sum(F.when(F.col('phone') == 'invalid', 1).otherwise(0)).alias('invalidphone'),
        F.sum(F.when(F.col('phone') == 'missing', 1).otherwise(0)).alias('missingphone'),
        # city
        F.sum(F.when(F.col('city') == 'valid', 1).otherwise(0)).alias('validcity'),
        F.sum(F.when(F.col('city') == 'invalid', 1).otherwise(0)).alias('invalidcity'),
        F.sum(F.when(F.col('city') == 'missing', 1).otherwise(0)).alias('missingcity'),
    )
)
display(agg_table)  # Databricks notebook helper; use agg_table.show() outside Databricks

Output:
clinic_id  validname  invalidname  missingname  ...  invalidcity  missingcity
---------  ---------  -----------  -----------  ...  -----------  -----------
c_1        3          1            0            ...  1            0
c_2        1          0            2            ...  1            2
c_3        0          1            0            ...  0            1
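To check the numbers in this output without a cluster, here is a plain-Python sketch (not PySpark) of what each sum(when(...).otherwise(0)) column computes: per clinic, the number of rows where that column equals that status. ROWS is simply the example DataFrame written row-wise, and wide_counts is a hypothetical helper used only for illustration. It gives 2 for missingcity in c_2, since two of the three c_2 rows have city == 'missing'.

```python
COLUMNS = ("name", "phone", "city")
STATUSES = ("valid", "invalid", "missing")

# The example DataFrame written row-wise: (clinic_id, name, phone, city)
ROWS = [
    ("c_1", "valid",   "missing", "invalid"),
    ("c_1", "valid",   "valid",   "valid"),
    ("c_1", "invalid", "invalid", "valid"),
    ("c_2", "missing", "invalid", "missing"),
    ("c_3", "invalid", "valid",   "missing"),
    ("c_1", "valid",   "valid",   "valid"),
    ("c_2", "valid",   "missing", "invalid"),
    ("c_2", "missing", "missing", "missing"),
]

def wide_counts(rows):
    """Plain-Python analogue of groupBy('clinic_id').agg(sum(when(col == status, 1).otherwise(0)))."""
    out = {}
    for clinic_id, *values in rows:
        # Every <status><column> count starts at 0, mirroring otherwise(0)
        counts = out.setdefault(clinic_id, {f"{s}{c}": 0 for c in COLUMNS for s in STATUSES})
        for column, value in zip(COLUMNS, values):
            if value in STATUSES:  # when() only adds 1 on an exact match
                counts[f"{value}{column}"] += 1
    return out

agg = wide_counts(ROWS)
print(agg["c_2"]["missingcity"])  # 2
```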

The resulting aggregated table is fine, but it is not ideal for further analysis. I tried pivoting in PySpark to get something like this:

Note: the counts below are made up, not the actual counts from above, but I hope they show what I mean.

clinic_id  category  name  phone  city
---------  --------  ----  -----  ----
c_1        valid     3     1      3
c_1        invalid   1     0      2
c_1        missing   0     2      3
c_2        valid     3     1      3
c_2        invalid   1     0      2
c_2        missing   0     2      3
c_3        valid     3     1      3
c_3        invalid   1     0      2
c_3        missing   0     2      3

I initially searched for pivot/unpivot, then read that this kind of reshaping is also called unstacking, and I also came across mapping.

I tried the approach suggested in "How to unstack dataset (using pivot)?", but it only gives me one column, and I cannot get the desired result when I apply it to my DataFrame of 30 columns.

I also tried the following on the aggregated table (agg_table):

from pyspark.sql.functions import expr

expression = ""
cnt = 0
for column in agg_table.columns:
    if column != 'clinc_id':
        cnt += 1
        expression += f"'{column}' , {column},"
exprs = f"stack({cnt}, {expression[:-1]}) as (Type,Value)"

unpivoted = agg_table.select('clinic_id', expr(exprs))

I get an error that just points to that line, possibly referring to a return value.
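One likely cause, judging from the code alone: the check compares against 'clinc_id' (missing an "i"), so clinic_id itself ends up inside stack(). stack() writes the second element of every pair into the same Value column, so a string column next to bigint counts should fail Spark's type check for stack(). A sketch of building the expression with the key column excluded; build_stack_expr is a hypothetical helper whose string-building mirrors the loop above:

```python
def build_stack_expr(columns, id_cols=("clinic_id",)):
    """Build a Spark SQL stack() expression that unpivots every non-id column."""
    value_cols = [c for c in columns if c not in id_cols]
    # Each pair is a string literal (the label) followed by the column reference
    pairs = ", ".join(f"'{c}', {c}" for c in value_cols)
    return f"stack({len(value_cols)}, {pairs}) as (Type, Value)"

print(build_stack_expr(["clinic_id", "validname", "invalidname", "missingname"]))
# stack(3, 'validname', validname, 'invalidname', invalidname, 'missingname', missingname) as (Type, Value)
```

In the notebook it would then be used as agg_table.select('clinic_id', expr(build_stack_expr(agg_table.columns))); that line has not been run against a cluster here.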

I also tried grouping the results by ID and category, but that is where I am stuck. If I group by one aggregated column, say validname, the aggregate function only counts the values in that column and is not applied to the other count columns. So I thought of adding a column with .withColumn that assigns the three categories to each ID, so that the counts would be grouped by ID and category as in the desired table above, but I have not managed to make this work.
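Attaching a category to each ID amounts to this: for each clinic and each category, pick the matching count column for name, phone and city. A plain-Python sketch of that reshaping, assuming the <status><column> alias convention used in agg_table; wide_to_category_rows is a hypothetical helper, and the input is the c_3 row of the output above:

```python
STATUSES = ("valid", "invalid", "missing")
COLUMNS = ("name", "phone", "city")

def wide_to_category_rows(agg_rows):
    """Turn {clinic_id: {'validname': 0, ...}} into one row per (clinic_id, category)."""
    result = []
    for clinic_id, counts in sorted(agg_rows.items()):
        for status in STATUSES:
            row = {"clinic_id": clinic_id, "category": status}
            for column in COLUMNS:
                # Alias convention from agg_table: <status><column>, e.g. 'invalidphone'
                row[column] = counts.get(f"{status}{column}", 0)
            result.append(row)
    return result

c_3 = {"validname": 0, "invalidname": 1, "missingname": 0,
       "validphone": 1, "invalidphone": 0, "missingphone": 0,
       "validcity": 0, "invalidcity": 0, "missingcity": 1}
rows = wide_to_category_rows({"c_3": c_3})
for row in rows:
    print(row)
```

Building the column name from the status and column avoids parsing alias strings, where a prefix test such as startswith('valid') would be fragile.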

Also, would a SQL approach be easier?
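A SQL-expression route may indeed be shorter. stack(n, ...) splits its arguments into n rows, so three rows of four values (a category label plus one count per column) should produce the category table straight from agg_table. A sketch of a hypothetical helper that generates that expression for any list of columns; the generated string follows Spark SQL's stack() form but has not been run against this data:

```python
def category_stack_expr(columns, statuses=("valid", "invalid", "missing")):
    """Spark SQL stack() turning <status><column> counts into one row per status."""
    rows = []
    for status in statuses:
        # One output row: the status label, then that status's count for each column
        values = ", ".join(f"{status}{c}" for c in columns)
        rows.append(f"'{status}', {values}")
    out_cols = ", ".join(["category", *columns])
    return f"stack({len(statuses)}, {', '.join(rows)}) as ({out_cols})"

print(category_stack_expr(["name", "phone", "city"]))
```

It would then be passed as agg_table.selectExpr('clinic_id', category_stack_expr(['name', 'phone', 'city'])), or used in a SELECT over a temp view registered with agg_table.createOrReplaceTempView('agg_table'). With 30 columns only the column list changes.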

budding pro

1 Answer


I found the right search phrase: "column to row in pyspark". One of the suggested answers that fits my DataFrame is this function:

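from pyspark.sql.functions import array, col, explode, lit, struct

def to_long(df, by):
    # Split the non-key columns into names and their Spark type strings
    cols, dtypes = zip(*((c, t) for (c, t) in df.dtypes if c not in by))
    # An array holds a single element type, so all value columns must share one type
    assert len(set(dtypes)) == 1, "All columns have to be of the same type"

    # Create and explode an array of (column_name, column_value) structs
    kvs = explode(array([
        struct(lit(c).alias("key"), col(c).alias("val")) for c in cols
    ])).alias("kvs")

    return df.select(by + [kvs]).select(by + ["kvs.key", "kvs.val"])

# One row per (clinic_id, original column); rename key/val to the names used below
long_df = (to_long(df, ["clinic_id"])
           .withColumnRenamed("key", "column_names")
           .withColumnRenamed("val", "status"))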

This creates a long DataFrame with three columns: clinic_id, column_names (name, phone or city) and status (valid, invalid or missing).
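What explode(array(struct(...))) does to each input row can be sketched in plain Python: every row becomes one output row per value column, carrying the key columns along. This is an illustration only (to_long_rows is a hypothetical name), run on the first two rows of the example data:

```python
def to_long_rows(rows, by, columns):
    """Plain-Python analogue of to_long: one output row per (input row, value column)."""
    long_rows = []
    for row in rows:
        keys = {k: row[k] for k in by}
        for column in columns:
            # Same role as struct(lit(c).alias("key"), col(c).alias("val")), with renamed fields
            long_rows.append({**keys, "column_names": column, "status": row[column]})
    return long_rows

sample = [
    {"clinic_id": "c_1", "name": "valid", "phone": "missing", "city": "invalid"},
    {"clinic_id": "c_1", "name": "valid", "phone": "valid", "city": "valid"},
]
lr = to_long_rows(sample, ["clinic_id"], ["name", "phone", "city"])
for r in lr:
    print(r)
```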

Then I aggregated it, grouping by clinic_id and status:

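from pyspark.sql import functions as F
result = (
    long_df
    .groupBy('clinic_id', 'status')
    .agg(
        F.sum(F.when(F.col('column_names') == 'name', 1).otherwise(0)).alias('name'),
        F.sum(F.when(F.col('column_names') == 'phone', 1).otherwise(0)).alias('phone'),
        F.sum(F.when(F.col('column_names') == 'city', 1).otherwise(0)).alias('city'),
    )
)
display(result)  # Databricks notebook helper; use result.show() elsewhere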

I got my intended table.
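For reference, grouping by (clinic_id, status) and summing when(column_names == c, 1) amounts to counting each (clinic_id, status, column) triple. A plain-Python sketch of the whole pipeline on the example data (intended_table is a hypothetical helper), which also gives the real counts that the made-up table in the question stood in for:

```python
from collections import Counter

COLUMNS = ("name", "phone", "city")

# The example DataFrame written row-wise: (clinic_id, name, phone, city)
ROWS = [
    ("c_1", "valid",   "missing", "invalid"),
    ("c_1", "valid",   "valid",   "valid"),
    ("c_1", "invalid", "invalid", "valid"),
    ("c_2", "missing", "invalid", "missing"),
    ("c_3", "invalid", "valid",   "missing"),
    ("c_1", "valid",   "valid",   "valid"),
    ("c_2", "valid",   "missing", "invalid"),
    ("c_2", "missing", "missing", "missing"),
]

def intended_table(rows):
    """Count each (clinic_id, status, column), one output row per (clinic_id, status)."""
    # Melt step: one (clinic_id, status, column) triple per cell, like to_long
    counts = Counter(
        (clinic_id, status, column)
        for clinic_id, *values in rows
        for column, status in zip(COLUMNS, values)
    )
    # groupBy only yields (clinic_id, status) pairs that actually occur
    groups = sorted({(clinic_id, status) for clinic_id, status, _ in counts})
    return [
        {"clinic_id": cid, "status": st, **{c: counts[(cid, st, c)] for c in COLUMNS}}
        for cid, st in groups
    ]

for row in intended_table(ROWS):
    print(row)
```

Missing combinations come out as 0 because Counter returns 0 for absent keys, which matches the otherwise(0) sums.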

budding pro