Uber’s Petastorm library provides a data reader for PyTorch that reads files generated by PySpark. Clone the project from GitHub for more information.
As a data scientist, I spend more time wrangling data than building models, and the data I work with ranges from hundreds of millions to billions of records. Spark has been of great use to me thanks to its capability to process big data. Usually, I run all the dirty jobs with Spark and generate ready-to-use files for the downstream model training process.
However, there is a gap between Spark and PyTorch: the data reader. Because Spark runs in parallel, it writes multiple partitions of files by default. It’s painful that PyTorch doesn’t provide a data loader that supports multiple files out of the box. While PyTorch allows a good deal of customisation, writing a high-efficiency data loader is not easy. Then I found Petastorm.
Petastorm is an open-source data access library developed by Uber that provides more than a data loader. It supports multiple machine learning frameworks, such as TensorFlow, PyTorch, and PySpark. To install Petastorm, run
pip install petastorm
Check out their repository for more details.
In this post, I would like to demonstrate how I process data with PySpark, train a model with PyTorch, and bridge the gap between the two with Petastorm.
We use the famous Iris flower data set as part of the demonstration. We load the CSV file with PySpark and create two new columns: a features column, by assembling all the feature values into a single vector, and a label column, by converting the class string to an integer.
We wrap this logic in a function, transform_iris_data(data: DataFrame).
After that, we split the data frame into a train set and a test set. Finally, we save the two data frames in Parquet format.
We build a simple three-layer network for this easy training task.
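A three-layer network for Iris could look like the sketch below; the hidden size and activation are my own choices, since the post doesn’t specify them:

```python
import torch
from torch import nn

class IrisNet(nn.Module):
    # Three linear layers: 4 input features in, 3 Iris classes out.
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(4, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 3),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Returns raw logits; pair with nn.CrossEntropyLoss during training.
        return self.net(x)
```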
Now, it’s time to build the data reader. Building one is simple: call make_batch_reader with the dataset URL and the number of epochs, then wrap the returned reader in Petastorm’s PyTorch DataLoader, which takes the batch size.
Keep in mind that the dataset_url must include a URL scheme. For example, if the data path is /some/localpath/a_dataset, you should pass in file:///some/localpath/a_dataset. Petastorm also supports s3://, but I haven’t tried it yet.
The final step is to train our model.
Using a Petastorm data loader is just like using a PyTorch data loader: wrap it in a for loop.
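A hypothetical epoch of that loop might be written as below; the function name, the .float()/.long() casts, and the optimizer choice are my assumptions, not code from the post:

```python
import torch
from torch import nn

def train_epoch(model: nn.Module, loader, optimizer, loss_fn) -> None:
    model.train()
    # Each Petastorm batch behaves like a dict keyed by column name.
    for batch in loader:
        features = batch["features"].float()
        labels = batch["label"].long()
        optimizer.zero_grad()
        loss = loss_fn(model(features), labels)
        loss.backward()
        optimizer.step()
```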
Remember that the data frame previously generated by PySpark consists of only two columns? We can retrieve the values by writing data['label']. If everything goes smoothly, you get a model with the following test performance:
(classification report with per-class precision, recall, f1-score, and support)
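That precision/recall/f1-score/support layout matches scikit-learn’s classification_report; an evaluation step producing it could be sketched as follows (the helper name and the dict-style batch access are my assumptions):

```python
import torch
from sklearn.metrics import classification_report

def evaluate(model: torch.nn.Module, loader) -> str:
    # Collect predictions over the test loader; each batch is a dict of columns.
    y_true, y_pred = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            logits = model(batch["features"].float())
            y_pred.extend(logits.argmax(dim=1).tolist())
            y_true.extend(batch["label"].long().tolist())
    # classification_report prints per-class precision, recall, f1, and support.
    return classification_report(y_true, y_pred)
```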
Petastorm is a reliable and helpful library. Be sure to check out their latest features!