171

I work on a dataframe with two columns, mvv and count.

+---+-----+
|mvv|count|
+---+-----+
| 1 |  5  |
| 2 |  9  |
| 3 |  3  |
| 4 |  1  |
+---+-----+

I would like to obtain two lists, one containing the mvv values and one containing the count values. Something like:

mvv = [1,2,3,4]
count = [5,9,3,1]

So I tried the following code. The first line should return a Python list of rows, and I wanted to see the first value:

mvv_list = mvv_count_df.select('mvv').collect()
firstvalue = mvv_list[0].getInt(0)

But I get an error message with the second line:

AttributeError: getInt

nha
a.moussa
  • As of Spark 2.3, this code is the fastest and least likely to cause OutOfMemory exceptions: `list(df.select('mvv').toPandas()['mvv'])`. [Arrow was integrated into PySpark](https://arrow.apache.org/blog/2017/07/26/spark-arrow/) which sped up `toPandas` significantly. Don't use the other approaches if you're using Spark 2.3+. See my answer for more benchmarking details. – Powers Jul 28 '20 at 14:57

11 Answers

220

Let's see why the way you are doing it is not working. First, you are trying to get an integer from a Row type; the output of your collect is like this:

>>> mvv_list = mvv_count_df.select('mvv').collect()
>>> mvv_list[0]
Out: Row(mvv=1)

If you do something like this:

>>> firstvalue = mvv_list[0].mvv
Out: 1

You will get the mvv value. If you want all the values as a list, you can do something like this:

>>> mvv_array = [int(row.mvv) for row in mvv_list]
>>> mvv_array
Out: [1,2,3,4]

But if you try the same for the count column (this time collecting the full mvv_count_df, which has both columns), you get:

>>> mvv_count = [int(row.count) for row in mvv_count_df.collect()]
Out: TypeError: int() argument must be a string or a number, not 'builtin_function_or_method'

This happens because count is a built-in method of Row, and the column has the same name. A workaround is to rename the count column to _count:

>>> renamed_df = mvv_count_df.selectExpr("mvv as mvv", "count as _count")
>>> mvv_count = [int(row._count) for row in renamed_df.collect()]

But this workaround is not needed, as you can access the column using the dictionary syntax:

>>> mvv_array = [int(row['mvv']) for row in mvv_count_df.collect()]
>>> mvv_count = [int(row['count']) for row in mvv_count_df.collect()]

And it will finally work!

roschach
Thiago Baldim
  • it works great for the first column, but it does not work for the column count i think because of (the function count of spark) – a.moussa Jul 27 '16 at 12:16
  • Can you add what are you doing with the count? Add here in the comments. – Thiago Baldim Jul 27 '16 at 12:19
  • thanks for your response So this line work mvv_list = [int(i.mvv) for i in mvv_count.select('mvv').collect()] but not this one count_list = [int(i.count) for i in mvv_count.select('count').collect()] return invalid syntax – a.moussa Jul 27 '16 at 12:19
  • Don't need to add this `select('count')` use like this: `count_list = [int(i.count) for i in mvv_list.collect()]` I will add the example to the response. – Thiago Baldim Jul 27 '16 at 12:28
  • this line count_list = [int(i.count) for i in mvv_list.collect()] also return: int() argument must be a string or a number, not 'builtin_function_or_method' – a.moussa Jul 27 '16 at 12:34
  • How do you instanciate your `mvv_list`? – Thiago Baldim Jul 27 '16 at 12:39
  • mvv_list result of many transformation of an initial dataframe. The tab print in my question result of mvv_list.show(). – a.moussa Jul 27 '16 at 12:52
  • #Thiago Baldim i found a solution but it not elegant count_list = [i[1] for i in mvv_list.collect()] it work because i know that count at index = 2. Do youknow a solution which allow you to precise the column name 'count' despite of the index ? – a.moussa Jul 27 '16 at 13:02
  • It should work with i.count. Really strange... Do something... Add here in the comments the `mvv_list.collect()` that I will be able to understande the issue. – Thiago Baldim Jul 27 '16 at 13:05
  • here is the line where i instantiate mvv_list: mvv_list = mvv_logs.groupBy('mvv').count() – a.moussa Jul 27 '16 at 13:16
  • I added a solution above in the answer. – Thiago Baldim Jul 27 '16 at 16:49
  • @a.moussa `[i['count'] for i in mvv_list.collect()]` works to make it explicit to use the column named 'count' and not the `count` function – user989762 Aug 28 '18 at 10:21
  • @ThiagoBaldim I want to access column value where column is month like 202111 which is variable and whose name is coming from another list. I am trying something like 'list[0].i.Month' which is not working. So i.Month is inside for loop whose value is changing in each iteration like 202111, 202112 and so on. can you please help me how to access it – Pardeep Naik Nov 11 '21 at 15:37
180

The following one-liner gives the list you want:

mvv = mvv_count_df.select("mvv").rdd.flatMap(lambda x: x).collect()
Neo
  • Performance wise this solution is much faster than your solution mvv_list = [int(i.mvv) for i in mvv_count.select('mvv').collect()] – Chanaka Fernando Dec 21 '18 at 19:29
  • Wouldn't this just work for OP's question?: mvv = mvv_count_df.select("mvv").rdd.flatMap(list).collect() – eemilk Nov 05 '20 at 10:57
43

This will give you all the elements as a list.

mvv_list = list(
    mvv_count_df.select('mvv').toPandas()['mvv']
)
Muhammad Raihan Muhaimin
  • This is the fastest and most efficient solution for Spark 2.3+. See the benchmarking results in my answer. – Powers Jul 28 '20 at 14:59
38

I ran a benchmarking analysis and list(mvv_count_df.select('mvv').toPandas()['mvv']) is the fastest method. I'm very surprised.

I ran the different approaches on 100 thousand / 100 million row datasets using a 5 node i3.xlarge cluster (each node has 30.5 GBs of RAM and 4 cores) with Spark 2.4.5. Data was evenly distributed on 20 snappy compressed Parquet files with a single column.

Here's the benchmarking results (runtimes in seconds):

+-------------------------------------------------------------+---------+-------------+
|                          Code                               | 100,000 | 100,000,000 |
+-------------------------------------------------------------+---------+-------------+
| df.select("col_name").rdd.flatMap(lambda x: x).collect()    |     0.4 | 55.3        |
| list(df.select('col_name').toPandas()['col_name'])          |     0.4 | 17.5        |
| df.select('col_name').rdd.map(lambda row : row[0]).collect()|     0.9 | 69          |
| [row[0] for row in df.select('col_name').collect()]         |     1.0 | OOM         |
| [r[0] for r in df.select('col_name').toLocalIterator()]     |     1.2 | *           |
+-------------------------------------------------------------+---------+-------------+

* cancelled after 800 seconds

Golden rules to follow when collecting data on the driver node:

  • Try to solve the problem with other approaches. Collecting data to the driver node is expensive, doesn't harness the power of the Spark cluster, and should be avoided whenever possible.
  • Collect as few rows as possible. Aggregate, deduplicate, filter, and prune columns before collecting the data. Send as little data to the driver node as you can (see the sketch below).
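
For instance, here is a minimal sketch of the second rule, assuming a hypothetical events DataFrame with mvv and count columns: aggregate and filter on the cluster first, then collect only the small result.

from pyspark.sql import functions as F

# Hypothetical example: shrink the data on the cluster before collecting,
# so only the aggregated, filtered rows ever reach the driver.
small_df = (
    events
    .groupBy("mvv")
    .agg(F.sum("count").alias("count"))
    .filter(F.col("count") > 0)
    .select("mvv")
)
mvv = [row["mvv"] for row in small_df.collect()]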

toPandas was significantly improved in Spark 2.3. It's probably not the best approach if you're using a Spark version earlier than 2.3.
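
Note that on Spark 2.3/2.4, the Arrow-based conversion behind this speedup may need to be switched on explicitly. A minimal sketch, assuming the Spark 2.x config name (renamed to spark.sql.execution.arrow.pyspark.enabled in Spark 3.x):

# Enable Arrow-accelerated toPandas (off by default in Spark 2.3/2.4).
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
mvv = list(df.select("mvv").toPandas()["mvv"])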

See here for more details / benchmarking results.

Powers
  • This really is surprising since I would have imagined `toPandas` to perform one of the worst since we are doing an additional data structure transformation. The Spark team must have really done a good job with optimization. Thanks for the benchmark! – THIS USER NEEDS HELP Feb 17 '22 at 21:53
  • Could you also test the @phgui answer? It also looks quite efficient. `mvv_list = df.select(collect_list("mvv")).collect()[0][0]` – Bohdan Pylypenko Jun 28 '22 at 20:09
23

On my data I got these benchmarks:

>>> data.select(col).rdd.flatMap(lambda x: x).collect()

0.52 sec

>>> [row[col] for row in data.collect()]

0.271 sec

>>> list(data.select(col).toPandas()[col])

0.427 sec

The result is the same.

luminousmen
  • If you use `toLocalIterator` instead of `collect` it should even be more memory efficient `[row[col] for row in data.toLocalIterator()]` – oglop May 19 '20 at 07:37
21

The following code will help you:

mvv_count_df.select('mvv').rdd.map(lambda row : row[0]).collect()
Itachi
  • This should be the accepted answer. The reason is that you are staying in a Spark context throughout the process and then you collect at the end, as opposed to getting out of the Spark context earlier, which may cause a larger collect depending on what you are doing. – AntiPawn79 Jan 18 '19 at 17:30
8

A possible solution is to use the collect_list() function from pyspark.sql.functions. It aggregates all column values into a pyspark array, which is converted into a Python list when collected:

from pyspark.sql.functions import collect_list

mvv_list   = df.select(collect_list("mvv")).collect()[0][0]
count_list = df.select(collect_list("count")).collect()[0][0]
phgui
6

If you get the error below:

AttributeError: 'list' object has no attribute 'collect'

This code will solve your issue:

mvv_list = mvv_count_df.select('mvv').collect()

mvv_array = [int(i.mvv) for i in mvv_list]
LaSul
anirban sen
  • I got that error too and this solution solved the problem. But why did I get the error? (Many others don't seem to get that!) – Bikash Gyawali May 01 '19 at 12:23
6

Let's create the dataframe in question

df_test = spark.createDataFrame(
    [
        (1, 5),
        (2, 9),
        (3, 3),
        (4, 1),
    ],
    ['mvv', 'count']
)
df_test.show()

Which gives

+---+-----+
|mvv|count|
+---+-----+
|  1|    5|
|  2|    9|
|  3|    3|
|  4|    1|
+---+-----+

and then apply rdd.flatMap(f).collect() to get the list

test_list = df_test.select("mvv").rdd.flatMap(list).collect()
print(type(test_list))
print(test_list)

which gives

<type 'list'>
[1, 2, 3, 4]
eemilk
6

You can first collect the df, which will return a list of Row type:

row_list = df.select('mvv').collect()

Then iterate over the rows to convert them to a list:

sno_id_array = [int(row.mvv) for row in row_list]

sno_id_array
[1,2,3,4]

Using flatMap:

sno_id_array = df.select("mvv").rdd.flatMap(lambda x: x).collect()
Strick
4

Despite the many answers, some of them won't work when you need a list to be used in combination with when and isin commands. The simplest approach that yields a flat list of values is to use a list comprehension with [0] to avoid Row names:

flatten_list_from_spark_df=[i[0] for i in df.select("your column").collect()]

The other approach is to convert to a pandas DataFrame and then use the list function, but it is not as convenient or as effective as this.
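
For reference, a minimal sketch of that pandas-based alternative, assuming the same df and column name as above:

# Convert the selected column to pandas, then to a plain Python list.
flatten_list_from_pandas = list(df.select("your column").toPandas()["your column"])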

ashkan