原文:https://animeshtrivedi.github.io/spark-parquet-reading
Spark 如何讀取Parquet文件
Apache Parquet 是一種流行的列式存儲格式,它把數據存儲爲一堆文件。
Spark讀取parquet依賴以下API:
val parquetFileDF = spark.read.parquet("test.parquet")
test.parquet文件格式爲<int, Array[Byte]>。
關鍵對象
在 Spark SQL 中,各種操作都在各自的類中實現,其名稱都以Exec作爲後綴。
1.DataSourceScanExec類掌管的是對數據源的讀取。讀取Parquet文件的相關代碼從這裏開始,在ParquetFileFormat類中結束。
2.ParquetFileFormat中有一個buildReader函數,返回一個(PartitionedFile => Iterator[InternalRow])。此函數中生成了一個迭代器:
val iter = new RecordReaderIterator(parquetReader)
這裏parquetReader是一個VectorizedParquetRecordReader。RecordReaderIterator包裝了一個scala迭代器,以Hadoop RecordReader<K,V>風格。它由 VectorizedParquetRecordReader(及其基類 SpecificParquetRecordReaderBase<Object>)實現。
- VectorizedParquetRecordReader做了什麼?根據文件中的comment:一個專門的RecordReader,直接使用Parquet column API 讀入InternalRows或ColumnarBatches,基於parquet-mr的ColumnReader。VectorizedParquetRecordReader 對象分配後,調用initialize(split, hadoopAttemptContext)函數和initBatch(partitionSchema, file.partitionValues)函數。
- initialize調用父類SpecificParquetRecordReaderBase的initialize函數。在這個函數中,會讀取文件schema,推斷請求的schema,並且實例化一個ParquetFileReader的讀取器。在initialize結束時,我們知道讀取的InputFileSplit中有多少行,這存儲在totalRawCount變量中。
- initBatch主要工作是分配columnarBatch對象,後面會詳細討論。
4.VectorizedParquetRecordReader 中 RecordReader 接口的實現需要多加關注,它在使用步驟 2 中的迭代器時調用的是什麼?在調用nextKeyValue()時,該函數首先調用了resultBatch(),然後調用nextBatch()。請記住,我們總是在Batch Mode下操作(returnColumnarBatch 設置爲 true),nextBatch用數據填充columnarBatch,且這個變量會在getCurrentValue函數中返回。getCurrentKey 在 SpecificParquetRecordReaderBase 的基類中實現,且始終返回null。
現在,我們知道了迭代器中返回了什麼變量。從這開始有兩個方向,首先我們描述ColumnarBatch是怎麼被Parquet數據填充。然後我們描述誰使用了步驟2中生成的iter迭代器。
ColumnarBatch 如何被填充?
在 VectorizedParquetRecordReader.nextBatch() 函數中,如果尚未讀取所有行,則調用 checkEndOfRowGroup() 函數。然後,checkEndOfRowGroup 函數讀取一個rowGroup(可以將rowGroup視爲以列格式存儲的一定數量行的集合),然後爲requestedSchema 中的每個請求列分配一個VectorizedColumnReader 對象。VectorizedColumnReader 構造函數接受一個 ColumnDescriptor(可以在schema中找到)和一個 PageReader(可以從 rowGroup 中找到,一個 Parquet API 調用)。
另外,missingColumns是確實列的一個bitmap(可能是缺失的列或 Spark 不打算讀取的列)。然後,在nextBatch中調用readBatch(num, columnarBatch.column(i)),會在之前checkEndOfRowGroup(基本上是每列)函數分配的所有VectorizedColumnReader對象上調用。(因此,ColumnarBatch 和 ColumnVector 只是 VectorizedColumnReader 使用的原始內存)。所以在 readBatch 中,傳遞了行數和 ColumnVector(存儲在 ColumnarBatch 中)。什麼是ColumnVector?我們可以將其視爲一個類型數組,由 rowId 索引。
/**
* An interface representing in-memory columnar data in Spark. This interface defines the main APIs
* to access the data, as well as their batched versions. The batched versions are considered to be
* faster and preferable whenever possible.
*
* Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values
* in this ColumnVector.
*
* Spark only calls specific `get` method according to the data type of this {@link ColumnVector},
* e.g. if it's int type, Spark is guaranteed to only call {@link #getInt(int)} or
* {@link #getInts(int, int)}.
*
* ColumnVector supports all the data types including nested types. To handle nested types,
* ColumnVector can have children and is a tree structure. Please refer to {@link #getStruct(int)},
* {@link #getArray(int)} and {@link #getMap(int)} for the details about how to implement nested
* types.
*
* ColumnVector is expected to be reused during the entire data loading process, to avoid allocating
* memory again and again.
*
* ColumnVector is meant to maximize CPU efficiency but not to minimize storage footprint.
* Implementations should prefer computing efficiency over storage efficiency when design the
* format. Since it is expected to reuse the ColumnVector instance while loading data, the storage
* footprint is negligible.
*/
@Evolving
public abstract class ColumnVector implements AutoCloseable {
總之,原始數據存儲在 ColumnVector 中,ColumnVector 本身存儲在 ColumnBatch 對象中。ColumnVector 是在 readBatch 函數中作爲存儲空間傳遞的。 在 readBatch 函數內部,它首先調用 readPage() 函數,該函數查看我們正在讀取哪個版本的 parquet 文件(v1 或 v2,我不知道區別),然後初始化一堆對象,即 defColumn: VectorizedRleValuesReader、replicationLevelColumn:ValuesReaderIntIterator、definitionLevelColumn:ValuesReaderIntIterator 和 dataColumn:VectorizedRleValuesReader。這些變量中的 ValuesReaderIntIterator 來自 parquet-mr,而 VectorizedRleValuesReader 來自 Spark。接下來,有一堆 read[Type]Batch() 函數被調用,這些函數又調用 defColumn.read[Type]s() 函數。 (這裏的 [Type] 是一些類型,如 Int、Short、Binary 等)。 在 VectorizedRleValuesReader 上的這些函數中,數據被讀取、解碼(可能來自 RLE),然後插入到此處傳遞的 ColumnVector 中。
Scala[ColumnBatch] 迭代器在哪裏被消費?
迭代器根據 reader 是否處於批處理模式返回兩種不同的類型,code如下:
@Override
public Object getCurrentValue() {
if (returnColumnarBatch) return columnarBatch;
return columnarBatch.getRow(batchIdx - 1);
}
其中,columnarBatch的類型是ColumnarBatch,columnarBatch.getRow 返回一個 ColumnarBatch.Row 類型的嵌套類。這個迭代器以某種方式傳遞給wholestage code generation。消費這個迭代器並且實例化UnsafeRow的code示例如下:
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */ private Object[] references;
/* 007 */ private scala.collection.Iterator[] inputs;
/* 008 */ private scala.collection.Iterator scan_input;
/* 009 */ private org.apache.spark.sql.execution.metric.SQLMetric scan_numOutputRows;
/* 010 */ private org.apache.spark.sql.execution.metric.SQLMetric scan_scanTime;
/* 011 */ private long scan_scanTime1;
/* 012 */ private org.apache.spark.sql.execution.vectorized.ColumnarBatch scan_batch;
/* 013 */ private int scan_batchIdx;
/* 014 */ private org.apache.spark.sql.execution.vectorized.ColumnVector scan_colInstance0;
/* 015 */ private org.apache.spark.sql.execution.vectorized.ColumnVector scan_colInstance1;
/* 016 */ private UnsafeRow scan_result;
/* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder scan_holder;
/* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter scan_rowWriter;
/* 019 */
/* 020 */ public GeneratedIterator(Object[] references) {
/* 021 */ this.references = references;
/* 022 */ }
/* 023 */
/* 024 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 025 */ partitionIndex = index;
/* 026 */ this.inputs = inputs;
/* 027 */ scan_input = inputs[0];
/* 028 */ this.scan_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 029 */ this.scan_scanTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
/* 030 */ scan_scanTime1 = 0;
/* 031 */ scan_batch = null;
/* 032 */ scan_batchIdx = 0;
/* 033 */ scan_colInstance0 = null;
/* 034 */ scan_colInstance1 = null;
/* 035 */ scan_result = new UnsafeRow(2);
/* 036 */ this.scan_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(scan_result, 32);
/* 037 */ this.scan_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(scan_holder, 2);
/* 038 */
/* 039 */ }
/* 040 */
/* 041 */ private void scan_nextBatch() throws java.io.IOException {
/* 042 */ long getBatchStart = System.nanoTime();
/* 043 */ if (scan_input.hasNext()) {
/* 044 */ scan_batch = (org.apache.spark.sql.execution.vectorized.ColumnarBatch)scan_input.next();
/* 045 */ scan_numOutputRows.add(scan_batch.numRows());
/* 046 */ scan_batchIdx = 0;
/* 047 */ scan_colInstance0 = scan_batch.column(0);
/* 048 */ scan_colInstance1 = scan_batch.column(1);
/* 049 */
/* 050 */ }
/* 051 */ scan_scanTime1 += System.nanoTime() - getBatchStart;
/* 052 */ }
/* 053 */
/* 054 */ protected void processNext() throws java.io.IOException {
/* 055 */ if (scan_batch == null) {
/* 056 */ scan_nextBatch();
/* 057 */ }
/* 058 */ while (scan_batch != null) {
/* 059 */ int numRows = scan_batch.numRows();
/* 060 */ while (scan_batchIdx < numRows) {
/* 061 */ int scan_rowIdx = scan_batchIdx++;
/* 062 */ boolean scan_isNull = scan_colInstance0.isNullAt(scan_rowIdx);
/* 063 */ int scan_value = scan_isNull ? -1 : (scan_colInstance0.getInt(scan_rowIdx));
/* 064 */ boolean scan_isNull1 = scan_colInstance1.isNullAt(scan_rowIdx);
/* 065 */ byte[] scan_value1 = scan_isNull1 ? null : (scan_colInstance1.getBinary(scan_rowIdx));
/* 066 */ scan_holder.reset();
/* 067 */
/* 068 */ scan_rowWriter.zeroOutNullBytes();
/* 069 */
/* 070 */ if (scan_isNull) {
/* 071 */ scan_rowWriter.setNullAt(0);
/* 072 */ } else {
/* 073 */ scan_rowWriter.write(0, scan_value);
/* 074 */ }
/* 075 */
/* 076 */ if (scan_isNull1) {
/* 077 */ scan_rowWriter.setNullAt(1);
/* 078 */ } else {
/* 079 */ scan_rowWriter.write(1, scan_value1);
/* 080 */ }
/* 081 */ scan_result.setTotalSize(scan_holder.totalSize());
/* 082 */ append(scan_result);
/* 083 */ if (shouldStop()) return;
/* 084 */ }
/* 085 */ scan_batch = null;
/* 086 */ scan_nextBatch();
/* 087 */ }
/* 088 */ scan_scanTime.add(scan_scanTime1 / (1000 * 1000));
/* 089 */ scan_scanTime1 = 0;
/* 090 */ }
/* 091 */ }
在scan_nextBatch方法中,我們通過調用next()讀取一個新的ColumnarBatch。然後我們獲取ColumnVectors對象(變量 scan_colInstance0/scan_colInstance1)。通過numRows()方法,我們可以得到ColumnarBatch的行數,通過調用ColumnVector對象的get[Type](rowId: Int)獲取最終的值。
這些值在BufferHolder和UnsafeRowWriter對象的幫助下表示爲UnsafeRow:
/* 035 */ scan_result = new UnsafeRow(2);
/* 036 */ this.scan_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(scan_result, 32);
/* 037 */ this.scan_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(scan_holder, 2);