Skip to content

Commit cdeae2c

Browse files
committed
GH-39519:[Swift] fix null count when using reader
1 parent 1f42e6d commit cdeae2c

File tree

4 files changed

+96
-47
lines changed

4 files changed

+96
-47
lines changed

swift/Arrow/Sources/Arrow/ArrowReader.swift

+8-4
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,17 @@ public class ArrowReader {
5757
private func loadPrimitiveData(_ loadInfo: DataLoadInfo) -> Result<ArrowArrayHolder, ArrowError> {
5858
do {
5959
let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)!
60+
let nullLength = UInt(ceil(Double(node.length) / 8))
6061
try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex)
6162
let nullBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex)!
6263
let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData,
63-
length: UInt(node.nullCount), messageOffset: loadInfo.messageOffset)
64+
length: nullLength, messageOffset: loadInfo.messageOffset)
6465
try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex + 1)
6566
let valueBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex + 1)!
6667
let arrowValueBuffer = makeBuffer(valueBuffer, fileData: loadInfo.fileData,
6768
length: UInt(node.length), messageOffset: loadInfo.messageOffset)
68-
return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowValueBuffer])
69+
return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowValueBuffer],
70+
nullCount: UInt(node.nullCount))
6971
} catch let error as ArrowError {
7072
return .failure(error)
7173
} catch {
@@ -76,10 +78,11 @@ public class ArrowReader {
7678
private func loadVariableData(_ loadInfo: DataLoadInfo) -> Result<ArrowArrayHolder, ArrowError> {
7779
let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)!
7880
do {
81+
let nullLength = UInt(ceil(Double(node.length) / 8))
7982
try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex)
8083
let nullBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex)!
8184
let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData,
82-
length: UInt(node.nullCount), messageOffset: loadInfo.messageOffset)
85+
length: nullLength, messageOffset: loadInfo.messageOffset)
8386
try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex + 1)
8487
let offsetBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex + 1)!
8588
let arrowOffsetBuffer = makeBuffer(offsetBuffer, fileData: loadInfo.fileData,
@@ -88,7 +91,8 @@ public class ArrowReader {
8891
let valueBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex + 2)!
8992
let arrowValueBuffer = makeBuffer(valueBuffer, fileData: loadInfo.fileData,
9093
length: UInt(node.length), messageOffset: loadInfo.messageOffset)
91-
return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowOffsetBuffer, arrowValueBuffer])
94+
return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowOffsetBuffer, arrowValueBuffer],
95+
nullCount: UInt(node.nullCount))
9296
} catch let error as ArrowError {
9397
return .failure(error)
9498
} catch {

swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift

+49-33
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818
import FlatBuffers
1919
import Foundation
2020

21-
private func makeBinaryHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder, ArrowError> {
21+
private func makeBinaryHolder(_ buffers: [ArrowBuffer],
22+
nullCount: UInt) -> Result<ArrowArrayHolder, ArrowError> {
2223
do {
2324
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowBinary), buffers: buffers,
24-
nullCount: buffers[0].length, stride: MemoryLayout<Int8>.stride)
25+
nullCount: nullCount, stride: MemoryLayout<Int8>.stride)
2526
return .success(ArrowArrayHolder(BinaryArray(arrowData)))
2627
} catch let error as ArrowError {
2728
return .failure(error)
@@ -30,10 +31,11 @@ private func makeBinaryHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHold
3031
}
3132
}
3233

33-
private func makeStringHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder, ArrowError> {
34+
private func makeStringHolder(_ buffers: [ArrowBuffer],
35+
nullCount: UInt) -> Result<ArrowArrayHolder, ArrowError> {
3436
do {
3537
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowString), buffers: buffers,
36-
nullCount: buffers[0].length, stride: MemoryLayout<Int8>.stride)
38+
nullCount: nullCount, stride: MemoryLayout<Int8>.stride)
3739
return .success(ArrowArrayHolder(StringArray(arrowData)))
3840
} catch let error as ArrowError {
3941
return .failure(error)
@@ -43,30 +45,32 @@ private func makeStringHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHold
4345
}
4446

4547
private func makeFloatHolder(_ floatType: org_apache_arrow_flatbuf_FloatingPoint,
46-
buffers: [ArrowBuffer]
48+
buffers: [ArrowBuffer],
49+
nullCount: UInt
4750
) -> Result<ArrowArrayHolder, ArrowError> {
4851
switch floatType.precision {
4952
case .single:
50-
return makeFixedHolder(Float.self, buffers: buffers, arrowType: ArrowType.ArrowFloat)
53+
return makeFixedHolder(Float.self, buffers: buffers, arrowType: ArrowType.ArrowFloat, nullCount: nullCount)
5154
case .double:
52-
return makeFixedHolder(Double.self, buffers: buffers, arrowType: ArrowType.ArrowDouble)
55+
return makeFixedHolder(Double.self, buffers: buffers, arrowType: ArrowType.ArrowDouble, nullCount: nullCount)
5356
default:
5457
return .failure(.unknownType("Float precision \(floatType.precision) currently not supported"))
5558
}
5659
}
5760

5861
private func makeDateHolder(_ dateType: org_apache_arrow_flatbuf_Date,
59-
buffers: [ArrowBuffer]
62+
buffers: [ArrowBuffer],
63+
nullCount: UInt
6064
) -> Result<ArrowArrayHolder, ArrowError> {
6165
do {
6266
if dateType.unit == .day {
6367
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowString), buffers: buffers,
64-
nullCount: buffers[0].length, stride: MemoryLayout<Date>.stride)
68+
nullCount: nullCount, stride: MemoryLayout<Date>.stride)
6569
return .success(ArrowArrayHolder(Date32Array(arrowData)))
6670
}
6771

6872
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowString), buffers: buffers,
69-
nullCount: buffers[0].length, stride: MemoryLayout<Date>.stride)
73+
nullCount: nullCount, stride: MemoryLayout<Date>.stride)
7074
return .success(ArrowArrayHolder(Date64Array(arrowData)))
7175
} catch let error as ArrowError {
7276
return .failure(error)
@@ -76,19 +80,20 @@ private func makeDateHolder(_ dateType: org_apache_arrow_flatbuf_Date,
7680
}
7781

7882
private func makeTimeHolder(_ timeType: org_apache_arrow_flatbuf_Time,
79-
buffers: [ArrowBuffer]
83+
buffers: [ArrowBuffer],
84+
nullCount: UInt
8085
) -> Result<ArrowArrayHolder, ArrowError> {
8186
do {
8287
if timeType.unit == .second || timeType.unit == .millisecond {
8388
let arrowUnit: ArrowTime32Unit = timeType.unit == .second ? .seconds : .milliseconds
8489
let arrowData = try ArrowData(ArrowTypeTime32(arrowUnit), buffers: buffers,
85-
nullCount: buffers[0].length, stride: MemoryLayout<Time32>.stride)
90+
nullCount: nullCount, stride: MemoryLayout<Time32>.stride)
8691
return .success(ArrowArrayHolder(FixedArray<Time32>(arrowData)))
8792
}
8893

8994
let arrowUnit: ArrowTime64Unit = timeType.unit == .microsecond ? .microseconds : .nanoseconds
9095
let arrowData = try ArrowData(ArrowTypeTime64(arrowUnit), buffers: buffers,
91-
nullCount: buffers[0].length, stride: MemoryLayout<Time64>.stride)
96+
nullCount: nullCount, stride: MemoryLayout<Time64>.stride)
9297
return .success(ArrowArrayHolder(FixedArray<Time64>(arrowData)))
9398
} catch let error as ArrowError {
9499
return .failure(error)
@@ -97,10 +102,11 @@ private func makeTimeHolder(_ timeType: org_apache_arrow_flatbuf_Time,
97102
}
98103
}
99104

100-
private func makeBoolHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder, ArrowError> {
105+
private func makeBoolHolder(_ buffers: [ArrowBuffer],
106+
nullCount: UInt) -> Result<ArrowArrayHolder, ArrowError> {
101107
do {
102108
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowBool), buffers: buffers,
103-
nullCount: buffers[0].length, stride: MemoryLayout<UInt8>.stride)
109+
nullCount: nullCount, stride: MemoryLayout<UInt8>.stride)
104110
return .success(ArrowArrayHolder(BoolArray(arrowData)))
105111
} catch let error as ArrowError {
106112
return .failure(error)
@@ -111,11 +117,12 @@ private func makeBoolHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder
111117

112118
private func makeFixedHolder<T>(
113119
_: T.Type, buffers: [ArrowBuffer],
114-
arrowType: ArrowType.Info
120+
arrowType: ArrowType.Info,
121+
nullCount: UInt
115122
) -> Result<ArrowArrayHolder, ArrowError> {
116123
do {
117124
let arrowData = try ArrowData(ArrowType(arrowType), buffers: buffers,
118-
nullCount: buffers[0].length, stride: MemoryLayout<T>.stride)
125+
nullCount: nullCount, stride: MemoryLayout<T>.stride)
119126
return .success(ArrowArrayHolder(FixedArray<T>(arrowData)))
120127
} catch let error as ArrowError {
121128
return .failure(error)
@@ -124,9 +131,10 @@ private func makeFixedHolder<T>(
124131
}
125132
}
126133

127-
func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity
134+
func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity function_body_length
128135
_ field: org_apache_arrow_flatbuf_Field,
129-
buffers: [ArrowBuffer]
136+
buffers: [ArrowBuffer],
137+
nullCount: UInt
130138
) -> Result<ArrowArrayHolder, ArrowError> {
131139
let type = field.typeType
132140
switch type {
@@ -135,45 +143,53 @@ func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity
135143
let bitWidth = intType.bitWidth
136144
if bitWidth == 8 {
137145
if intType.isSigned {
138-
return makeFixedHolder(Int8.self, buffers: buffers, arrowType: ArrowType.ArrowInt8)
146+
return makeFixedHolder(Int8.self, buffers: buffers,
147+
arrowType: ArrowType.ArrowInt8, nullCount: nullCount)
139148
} else {
140-
return makeFixedHolder(UInt8.self, buffers: buffers, arrowType: ArrowType.ArrowUInt8)
149+
return makeFixedHolder(UInt8.self, buffers: buffers,
150+
arrowType: ArrowType.ArrowUInt8, nullCount: nullCount)
141151
}
142152
} else if bitWidth == 16 {
143153
if intType.isSigned {
144-
return makeFixedHolder(Int16.self, buffers: buffers, arrowType: ArrowType.ArrowInt16)
154+
return makeFixedHolder(Int16.self, buffers: buffers,
155+
arrowType: ArrowType.ArrowInt16, nullCount: nullCount)
145156
} else {
146-
return makeFixedHolder(UInt16.self, buffers: buffers, arrowType: ArrowType.ArrowUInt16)
157+
return makeFixedHolder(UInt16.self, buffers: buffers,
158+
arrowType: ArrowType.ArrowUInt16, nullCount: nullCount)
147159
}
148160
} else if bitWidth == 32 {
149161
if intType.isSigned {
150-
return makeFixedHolder(Int32.self, buffers: buffers, arrowType: ArrowType.ArrowInt32)
162+
return makeFixedHolder(Int32.self, buffers: buffers,
163+
arrowType: ArrowType.ArrowInt32, nullCount: nullCount)
151164
} else {
152-
return makeFixedHolder(UInt32.self, buffers: buffers, arrowType: ArrowType.ArrowUInt32)
165+
return makeFixedHolder(UInt32.self, buffers: buffers,
166+
arrowType: ArrowType.ArrowUInt32, nullCount: nullCount)
153167
}
154168
} else if bitWidth == 64 {
155169
if intType.isSigned {
156-
return makeFixedHolder(Int64.self, buffers: buffers, arrowType: ArrowType.ArrowInt64)
170+
return makeFixedHolder(Int64.self, buffers: buffers,
171+
arrowType: ArrowType.ArrowInt64, nullCount: nullCount)
157172
} else {
158-
return makeFixedHolder(UInt64.self, buffers: buffers, arrowType: ArrowType.ArrowUInt64)
173+
return makeFixedHolder(UInt64.self, buffers: buffers,
174+
arrowType: ArrowType.ArrowUInt64, nullCount: nullCount)
159175
}
160176
}
161177
return .failure(.unknownType("Int width \(bitWidth) currently not supported"))
162178
case .bool:
163-
return makeBoolHolder(buffers)
179+
return makeBoolHolder(buffers, nullCount: nullCount)
164180
case .floatingpoint:
165181
let floatType = field.type(type: org_apache_arrow_flatbuf_FloatingPoint.self)!
166-
return makeFloatHolder(floatType, buffers: buffers)
182+
return makeFloatHolder(floatType, buffers: buffers, nullCount: nullCount)
167183
case .utf8:
168-
return makeStringHolder(buffers)
184+
return makeStringHolder(buffers, nullCount: nullCount)
169185
case .binary:
170-
return makeBinaryHolder(buffers)
186+
return makeBinaryHolder(buffers, nullCount: nullCount)
171187
case .date:
172188
let dateType = field.type(type: org_apache_arrow_flatbuf_Date.self)!
173-
return makeDateHolder(dateType, buffers: buffers)
189+
return makeDateHolder(dateType, buffers: buffers, nullCount: nullCount)
174190
case .time:
175191
let timeType = field.type(type: org_apache_arrow_flatbuf_Time.self)!
176-
return makeTimeHolder(timeType, buffers: buffers)
192+
return makeTimeHolder(timeType, buffers: buffers, nullCount: nullCount)
177193
default:
178194
return .failure(.unknownType("Type \(type) currently not supported"))
179195
}

swift/Arrow/Tests/ArrowTests/IPCTests.swift

+33-7
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,16 @@ func makeSchema() -> ArrowSchema {
6464
return schemaBuilder.addField("col1", type: ArrowType(ArrowType.ArrowUInt8), isNullable: true)
6565
.addField("col2", type: ArrowType(ArrowType.ArrowString), isNullable: false)
6666
.addField("col3", type: ArrowType(ArrowType.ArrowDate32), isNullable: false)
67+
.addField("col4", type: ArrowType(ArrowType.ArrowInt32), isNullable: false)
68+
.addField("col5", type: ArrowType(ArrowType.ArrowFloat), isNullable: false)
6769
.finish()
6870
}
6971

7072
func makeRecordBatch() throws -> RecordBatch {
7173
let uint8Builder: NumberArrayBuilder<UInt8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
7274
uint8Builder.append(10)
73-
uint8Builder.append(22)
74-
uint8Builder.append(33)
75+
uint8Builder.append(nil)
76+
uint8Builder.append(nil)
7577
uint8Builder.append(44)
7678
let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
7779
stringBuilder.append("test10")
@@ -85,13 +87,28 @@ func makeRecordBatch() throws -> RecordBatch {
8587
date32Builder.append(date2)
8688
date32Builder.append(date1)
8789
date32Builder.append(date2)
88-
let intHolder = ArrowArrayHolder(try uint8Builder.finish())
90+
let int32Builder: NumberArrayBuilder<Int32> = try ArrowArrayBuilders.loadNumberArrayBuilder()
91+
int32Builder.append(1)
92+
int32Builder.append(2)
93+
int32Builder.append(3)
94+
int32Builder.append(4)
95+
let floatBuilder: NumberArrayBuilder<Float> = try ArrowArrayBuilders.loadNumberArrayBuilder()
96+
floatBuilder.append(211.112)
97+
floatBuilder.append(322.223)
98+
floatBuilder.append(433.334)
99+
floatBuilder.append(544.445)
100+
101+
let uint8Holder = ArrowArrayHolder(try uint8Builder.finish())
89102
let stringHolder = ArrowArrayHolder(try stringBuilder.finish())
90103
let date32Holder = ArrowArrayHolder(try date32Builder.finish())
104+
let int32Holder = ArrowArrayHolder(try int32Builder.finish())
105+
let floatHolder = ArrowArrayHolder(try floatBuilder.finish())
91106
let result = RecordBatch.Builder()
92-
.addColumn("col1", arrowArray: intHolder)
107+
.addColumn("col1", arrowArray: uint8Holder)
93108
.addColumn("col2", arrowArray: stringHolder)
94109
.addColumn("col3", arrowArray: date32Holder)
110+
.addColumn("col4", arrowArray: int32Holder)
111+
.addColumn("col5", arrowArray: floatHolder)
95112
.finish()
96113
switch result {
97114
case .success(let recordBatch):
@@ -182,15 +199,20 @@ final class IPCFileReaderTests: XCTestCase {
182199
XCTAssertEqual(recordBatches.count, 1)
183200
for recordBatch in recordBatches {
184201
XCTAssertEqual(recordBatch.length, 4)
185-
XCTAssertEqual(recordBatch.columns.count, 3)
186-
XCTAssertEqual(recordBatch.schema.fields.count, 3)
202+
XCTAssertEqual(recordBatch.columns.count, 5)
203+
XCTAssertEqual(recordBatch.schema.fields.count, 5)
187204
XCTAssertEqual(recordBatch.schema.fields[0].name, "col1")
188205
XCTAssertEqual(recordBatch.schema.fields[0].type.info, ArrowType.ArrowUInt8)
189206
XCTAssertEqual(recordBatch.schema.fields[1].name, "col2")
190207
XCTAssertEqual(recordBatch.schema.fields[1].type.info, ArrowType.ArrowString)
191208
XCTAssertEqual(recordBatch.schema.fields[2].name, "col3")
192209
XCTAssertEqual(recordBatch.schema.fields[2].type.info, ArrowType.ArrowDate32)
210+
XCTAssertEqual(recordBatch.schema.fields[3].name, "col4")
211+
XCTAssertEqual(recordBatch.schema.fields[3].type.info, ArrowType.ArrowInt32)
212+
XCTAssertEqual(recordBatch.schema.fields[4].name, "col5")
213+
XCTAssertEqual(recordBatch.schema.fields[4].type.info, ArrowType.ArrowFloat)
193214
let columns = recordBatch.columns
215+
XCTAssertEqual(columns[0].nullCount, 2)
194216
let dateVal =
195217
"\((columns[2].array as! AsString).asString(0))" // swiftlint:disable:this force_cast
196218
XCTAssertEqual(dateVal, "2014-09-10 00:00:00 +0000")
@@ -227,13 +249,17 @@ final class IPCFileReaderTests: XCTestCase {
227249
case .success(let result):
228250
XCTAssertNotNil(result.schema)
229251
let schema = result.schema!
230-
XCTAssertEqual(schema.fields.count, 3)
252+
XCTAssertEqual(schema.fields.count, 5)
231253
XCTAssertEqual(schema.fields[0].name, "col1")
232254
XCTAssertEqual(schema.fields[0].type.info, ArrowType.ArrowUInt8)
233255
XCTAssertEqual(schema.fields[1].name, "col2")
234256
XCTAssertEqual(schema.fields[1].type.info, ArrowType.ArrowString)
235257
XCTAssertEqual(schema.fields[2].name, "col3")
236258
XCTAssertEqual(schema.fields[2].type.info, ArrowType.ArrowDate32)
259+
XCTAssertEqual(schema.fields[3].name, "col4")
260+
XCTAssertEqual(schema.fields[3].type.info, ArrowType.ArrowInt32)
261+
XCTAssertEqual(schema.fields[4].name, "col5")
262+
XCTAssertEqual(schema.fields[4].type.info, ArrowType.ArrowFloat)
237263
case.failure(let error):
238264
throw error
239265
}

0 commit comments

Comments
 (0)