
I am trying to use a UDTF in Snowpark, but I am not able to partition by a column. The SQL query I want is something like this:

select mcount.result from CUSTOMER, table(map_count(name) over (partition by name)) mcount;

Here "map_count" is my JavaScript UDTF. Below is the code snippet in Snowpark:

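import com.snowflake.snowpark._
import com.snowflake.snowpark.functions._

val session = Session.builder.configs(configs).create
val df = session.table("CUSTOMER")
val window = Window.partitionBy(col("name"))
// `window` is never used: join() with a TableFunction gives no place to pass it
val result = df.join(TableFunction("map_count"), col("name"))
//result.show()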

Any suggestions on how to use a window PARTITION BY with a table function? Is this even supported in Snowpark?

Raphael Roth
Sella
  • If you are trying to create a UDTF inside of the Snowflake database, then it does not have the "Session" and "Window" objects at all. If this is code outside of Snowflake called through Snowpark it could be different, but since you mention UDTF it appears that you'll be limited to the available objects in the JavaScript engine running for UDTFs in Snowflake. – Greg Pavlik Oct 13 '21 at 19:46
  • I already have the UDTF inside the Snowflake database. I want to use that UDTF in Snowpark with DataFrames, which I was able to do using df.join(TableFunction("map_count"), col("name")). But this generates the Snowflake query SELECT * FROM ( SELECT * FROM ( SELECT * FROM (CUSTOMER)) JOIN TABLE (map_count("NAME"))) LIMIT 10; whereas I want the query SELECT * FROM ( SELECT * FROM ( SELECT * FROM (CUSTOMER)) JOIN TABLE (map_count("NAME") over (partition by NAME) )) LIMIT 10. I couldn't find a way to use a window inside join with TableFunction in Snowpark. – Sella Oct 14 '21 at 04:48
  • It is possible now; see the example below. – Raphael Roth Aug 25 '22 at 06:30

3 Answers


Unfortunately, this is not currently supported in Snowpark. But we are working on it.

Khush Bhatia

As of Snowpark Python 0.8.0 it works as follows. The example calculates the median of each group/partition, i.e. the UDTF acts as a UDAF:

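from statistics import median

from snowflake.snowpark.functions import col, udtf
from snowflake.snowpark.types import FloatType, StringType, StructField, StructType

# assumes an existing Snowpark `session` (used below)

class MyMedian:
    def __init__(self):
        # values collected across the rows of the current partition
        self.values = []

    def process(self, value: float):
        self.values.append(value)
        # emit no rows per input row; `yield from ()` keeps this a generator
        yield from ()

    def end_partition(self):
        # emit a single summary row once the partition is exhausted
        yield ("partition_total", median(self.values))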

output_schema = StructType([
    StructField("label", StringType()),
    StructField("median", FloatType())
])

my_median = udtf(
    MyMedian,
    output_schema=output_schema,
    input_types=[FloatType()]
)

example_df = session.create_dataframe(
    [["A", 2.0],
     ["A", 2.0],
     ["A", 4.0],
     ["B", -1.0],
     ["B", 0.0],
     ["B", 1.0]],
    StructType([
        StructField("Key", StringType()),
        StructField("Value", FloatType())
    ])
)
example_df.show()

-------------------
|"KEY"  |"VALUE"  |
-------------------
|A      |2.0      |
|A      |2.0      |
|A      |4.0      |
|B      |-1.0     |
|B      |0.0      |
|B      |1.0      |
-------------------

Now the usage of my_median:

example_df.join_table_function(my_median("VALUE").over(partition_by=col("KEY")))\
    .show()

------------------------------------------------
|"KEY"  |"VALUE"  |"LABEL"          |"MEDIAN"  |
------------------------------------------------
|A      |NULL     |partition_total  |2.0       |
|B      |NULL     |partition_total  |0.0       |
------------------------------------------------
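To see why this UDTF behaves like an aggregate, the call sequence can be modelled locally in plain Python. This is only a rough sketch under stated assumptions. `run_partitioned` is a hypothetical helper, not part of Snowpark. It assumes one fresh handler per partition, `process()` called once per row and `end_partition()` once per partition. It does not reproduce Snowflake's NULL `VALUE` column or its output ordering.

```python
from collections import defaultdict
from statistics import median


class MyMedian:
    # mirrors the MyMedian handler in this answer
    def __init__(self):
        self.values = []

    def process(self, value):
        self.values.append(value)
        yield from ()  # no output rows per input row

    def end_partition(self):
        yield ("partition_total", median(self.values))


def run_partitioned(handler_cls, rows, key_index, value_index):
    """Hypothetical local model of TABLE(f(value) OVER (PARTITION BY key)).

    Groups rows by key, creates a fresh handler per partition, feeds each
    value to process(), then calls end_partition(); every emitted tuple is
    prefixed with its partition key.
    """
    partitions = defaultdict(list)
    for row in rows:
        partitions[row[key_index]].append(row[value_index])

    output = []
    for key, values in partitions.items():
        handler = handler_cls()
        for value in values:
            # process() is a generator: iterating it is what runs its body
            output.extend((key, *out) for out in handler.process(value))
        output.extend((key, *out) for out in handler.end_partition())
    return output


rows = [("A", 2.0), ("A", 2.0), ("A", 4.0),
        ("B", -1.0), ("B", 0.0), ("B", 1.0)]
result = run_partitioned(MyMedian, rows, key_index=0, value_index=1)
print(result)  # [('A', 'partition_total', 2.0), ('B', 'partition_total', 0.0)]
```

Because `process()` is a generator, its body only runs when the caller iterates it. Snowflake consumes the handler's output in the same way, which is why the empty-generator trick is needed at all.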
Raphael Roth
  • Do you know how to avoid the null `"VALUE"` column? Should that be manually dropped afterwards? – Pietro Jan 17 '23 at 13:44

For now, I think the workaround is to invoke the UDTF with SQL, as in the example below.

I created a dummy CUSTOMER table and a dummy JavaScript table UDF, and then invoked the function using SQL.

Once the DataFrame API supports this, the workaround will be unnecessary, and the DataFrame API version will be cleaner.

import com.snowflake.snowpark._
import com.snowflake.snowpark.functions._

// session.sql() is lazy: a statement only runs when an action such as collect() or show() is called
session.sql("ALTER SESSION SET QUERY_TAG='TEST_1'").collect()
session.sql(""" 
CREATE OR REPLACE FUNCTION MAP_COUNT(NAME STRING) RETURNS TABLE (NUM FLOAT)
LANGUAGE JAVASCRIPT AS 
$$
    {
      // called once for every input row of the partition
      processRow: function (row, rowWriter, context) {
        this.ccount = this.ccount + 1;
      },
      // called once per partition, after its last row: emit the row count
      finalize: function (rowWriter, context) {
        rowWriter.writeRow({NUM: this.ccount});
      },
      // called once per partition, before its first row: reset the counter
      initialize: function (argumentInfo, context) {
        this.ccount = 0;
      }
    }
$$;
""").show()

session.sql("""
CREATE OR REPLACE TABLE CUSTOMER (
CUST_ID INTEGER,
CUST_NAME TEXT
)
""").show()
session.sql("INSERT INTO CUSTOMER SELECT 1, 'John'").show()
session.sql("INSERT INTO CUSTOMER SELECT 2, 'John'").show()
session.sql("INSERT INTO CUSTOMER SELECT 3, 'John'").show()
session.sql("INSERT INTO CUSTOMER SELECT 4, 'Mary'").show()
session.sql("INSERT INTO CUSTOMER SELECT 5, 'Mary'").show()

// Invoke the UDTF through SQL, partitioned by customer name
val res = session.sql("select CUST_NAME,NUM FROM CUSTOMER, TABLE(MAP_COUNT(CUST_NAME) OVER (PARTITION BY CUST_NAME ORDER BY CUST_NAME))")
res.show()
// Output will be (row order may vary)
//-----------------------
//|"CUST_NAME"  |"NUM"  |
//-----------------------
//|Mary         |2.0    |
//|John         |3.0    |
//-----------------------
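The JavaScript handler's lifecycle can be mirrored locally in Python to check the expected counts without a Snowflake connection. The lifecycle is `initialize`, then `processRow` for each row, then `finalize`. `MapCount` and `map_count_over_partitions` are hypothetical names used only for illustration. The sketch assumes one handler per partition, as the JavaScript callbacks describe. It returns partitions in first-seen order, whereas Snowflake gives no ordering guarantee.

```python
from collections import defaultdict


class MapCount:
    """Hypothetical Python stand-in for the JavaScript MAP_COUNT handler."""

    def initialize(self):
        # mirrors initialize(): reset the counter before the first row
        self.ccount = 0

    def process_row(self, row):
        # mirrors processRow(): count the row, emit nothing
        self.ccount += 1

    def finalize(self):
        # mirrors finalize(): one output row; NUM is FLOAT in the SQL signature
        return [{"NUM": float(self.ccount)}]


def map_count_over_partitions(rows, partition_col):
    """Rough local model of TABLE(MAP_COUNT(...) OVER (PARTITION BY partition_col))."""
    partitions = defaultdict(list)
    for row in rows:
        partitions[row[partition_col]].append(row)

    result = []
    for key, part in partitions.items():
        handler = MapCount()
        handler.initialize()
        for row in part:
            handler.process_row(row)
        for out in handler.finalize():
            result.append({partition_col: key, **out})
    return result


customer = [
    {"CUST_ID": 1, "CUST_NAME": "John"},
    {"CUST_ID": 2, "CUST_NAME": "John"},
    {"CUST_ID": 3, "CUST_NAME": "John"},
    {"CUST_ID": 4, "CUST_NAME": "Mary"},
    {"CUST_ID": 5, "CUST_NAME": "Mary"},
]
counts = map_count_over_partitions(customer, "CUST_NAME")
print(counts)  # [{'CUST_NAME': 'John', 'NUM': 3.0}, {'CUST_NAME': 'Mary', 'NUM': 2.0}]
```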
orellabac