
How to Convert PySpark DataFrame Column to Python List?

If you are working with PySpark DataFrames, you may need to extract a column and use it as a Python list for further analysis. In this article, we will explore several ways to extract a PySpark DataFrame column into a Python list. We will start with the basics of the PySpark DataFrame, and then dive into the extraction methods.

Want to quickly create Data Visualizations from a Python Pandas Dataframe with no code?

PyGWalker is a Python library for Exploratory Data Analysis with Visualization. PyGWalker can simplify your Jupyter Notebook data analysis and data visualization workflow by turning your pandas dataframe (and polars dataframe) into a Tableau-style user interface for visual exploration.

Introduction to PySpark Data Frame

A DataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R or Python. Under the hood, a PySpark DataFrame is backed by an RDD (Resilient Distributed Dataset) of Row objects with a schema. PySpark SQL provides a programming interface for working with structured data in Spark, and supports most common data sources such as CSV, JSON, Avro, and Parquet.

To understand better, let's create a simple PySpark DataFrame with an explicit schema and display it.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
 
# create a SparkSession
spark = SparkSession.builder.appName("PySpark_Examples").getOrCreate()
 
# define the schema
schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("name", StringType(), True)])
 
# create the data frame
data = [(1, "John"), (2, "Mary"), (3, "Smith"), (4, "James")]
df = spark.createDataFrame(data, schema=schema)
 
# show the data frame
df.show()

The output will look like:

+---+-----+
| id| name|
+---+-----+
|  1| John|
|  2| Mary|
|  3|Smith|
|  4|James|
+---+-----+

Extracting a Single Column as a List

There are several ways to extract a column from a PySpark DataFrame. We will explore a few of them in this section.

Method 1: Using Collect Function

The collect() function in PySpark returns all the elements of a DataFrame or RDD (Resilient Distributed Dataset) to the driver program as a Python list. We can combine it with rdd.flatMap() to convert a PySpark DataFrame column into a Python list. Here's how:

# extract name column using collect()
name_list = df.select('name').rdd.flatMap(lambda x: x).collect()
 
# print the list
print(name_list)

The output will look like:

['John', 'Mary', 'Smith', 'James']

Here, we used the select() function to project the "name" column from the DataFrame, then rdd.flatMap(lambda x: x) to flatten each one-element Row into its value, and finally the collect() function to bring the data back to the driver as a Python list.
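To build intuition for the flatMap step, here is a plain-Python sketch of the same flattening (no Spark required). Each collected Row behaves like a one-element sequence, modelled here as a tuple; flatMap applies the function to each element and flattens the results:

```python
# Plain-Python illustration of the rdd.flatMap(lambda x: x) step.
# df.select('name') yields one-element Row objects; tuples stand in for them here.
rows = [("John",), ("Mary",), ("Smith",), ("James",)]

# flatMap maps each row to its contents and concatenates the results:
name_list = [value for row in rows for value in row]
print(name_list)  # ['John', 'Mary', 'Smith', 'James']
```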

Method 2: Using List Comprehension

Another way to extract a column from a PySpark DataFrame as a Python list is to use a list comprehension over the collected rows. Here's how:

# extract the name column using list comprehension
name_list = [row.name for row in df.select('name').collect()]
 
# print the list
print(name_list)

The output will look like:

['John', 'Mary', 'Smith', 'James']

Here, we used a list comprehension to extract the "name" column from the DataFrame as a Python list. We first used the select() function to project the column, then collect() to retrieve the rows to the driver; the comprehension reads the name field from each Row object.
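For intuition, here is a minimal stand-in (not actual PySpark) for what collect() hands back: Row objects expose each column as an attribute, which the comprehension reads via row.name. A namedtuple mimics that behavior:

```python
from collections import namedtuple

# Hypothetical stand-in for pyspark.sql.Row: columns are attributes.
Row = namedtuple("Row", ["name"])
collected = [Row("John"), Row("Mary"), Row("Smith"), Row("James")]

# Same comprehension as in the PySpark example above.
name_list = [row.name for row in collected]
print(name_list)  # ['John', 'Mary', 'Smith', 'James']
```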

Method 3: Using toPandas() Function

We can also extract a column from a PySpark DataFrame as a Python list using the toPandas() function (this requires pandas to be installed). However, this method is not recommended for large DataFrames, since the entire column is pulled into driver memory and can cause out-of-memory errors. Here's how:

# extract name column using toPandas()
name_list = df.select('name').toPandas()['name'].tolist()
 
# print the list
print(name_list)

The output will look like:

['John', 'Mary', 'Smith', 'James']

Here, we used the select() function to project the "name" column from the DataFrame, converted the result into a pandas DataFrame with toPandas(), and then called tolist() on the resulting pandas Series to get a Python list.
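The final tolist() step is plain pandas, so it can be sketched without Spark (assuming pandas is installed). Selecting a column from a pandas DataFrame yields a Series, and Series.tolist() converts it to a Python list:

```python
import pandas as pd

# Roughly what df.select('name').toPandas() hands back: a pandas DataFrame.
pdf = pd.DataFrame({"name": ["John", "Mary", "Smith", "James"]})

# Selecting the column gives a pandas Series; tolist() converts it to a list.
name_list = pdf["name"].tolist()
print(name_list)  # ['John', 'Mary', 'Smith', 'James']
```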

Conclusion

In this article, we explored several methods to extract a PySpark DataFrame column into a Python list: the collect() function combined with rdd.flatMap(), a list comprehension over collected Row objects, and the toPandas() function. All three bring data back to the driver, so be especially cautious with toPandas() on large columns. PySpark processes large datasets in parallel by distributing the work across multiple nodes in a cluster, and understanding how to extract data from a PySpark DataFrame is a valuable skill for any data scientist or engineer working with large datasets.

We hope this article was helpful, and you now know how to extract a PySpark DataFrame column to a Python list. If you want to learn more about PySpark and Pandas, check out our other tutorials.

