Skip to content

Commit 1e3736c

Browse files
committed
#828 Add support for structs and arrays in EBCDIC writer.
1 parent 5e31b62 commit 1e3736c

File tree

6 files changed

+528
-2
lines changed

6 files changed

+528
-2
lines changed

spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ package za.co.absa.cobrix.spark.cobol.utils
1919
import com.fasterxml.jackson.databind.ObjectMapper
2020
import org.apache.hadoop.fs.{FileSystem, Path}
2121
import org.apache.spark.SparkContext
22+
import org.apache.spark.sql.expressions.UserDefinedFunction
2223
import org.apache.spark.sql.functions._
2324
import org.apache.spark.sql.types._
24-
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
25+
import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
2526
import za.co.absa.cobrix.cobol.internal.Logging
2627
import za.co.absa.cobrix.spark.cobol.parameters.MetadataFields.MAX_ELEMENTS
2728
import za.co.absa.cobrix.spark.cobol.utils.impl.HofsWrapper.transform
@@ -550,5 +551,44 @@ object SparkUtils extends Logging {
550551
fields.toList
551552
}
552553

554+
/**
  * A UDF that receives the entire record as a [[Row]] and returns a
  * human-readable string representation of its contents.
  *
  * Nested structs are rendered as `{...}`, arrays as `[...]`, and null
  * values as the literal string "null".
  *
  * Usage (after columns are combined into a struct):
  * {{{
  *   df.withColumn("record_dump", printRowUdf(struct(df.columns.map(col): _*)))
  * }}}
  */
val printRowUdf: UserDefinedFunction = udf { row: Row =>
  // Recursively renders a Row as "name1=value1, name2=value2, ...".
  def rowToString(r: Row): String = {
    val renderedFields = r.schema.fields.zipWithIndex.map { case (field, idx) =>
      val value = if (r.isNullAt(idx)) {
        "null"
      } else {
        r.get(idx) match {
          case nestedRow: Row =>
            s"{${rowToString(nestedRow)}}"
          case seq: Seq[_] =>
            // Spark surfaces array columns as Seq; elements may themselves be Rows.
            val items = seq.map {
              case nestedRow: Row => s"{${rowToString(nestedRow)}}"
              case other          => String.valueOf(other)
            }
            s"[${items.mkString(", ")}]"
          case other =>
            String.valueOf(other)
        }
      }
      s"${field.name}=$value"
    }
    renderedFields.mkString(", ")
  }

  rowToString(row)
}
553593

554594
}
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
/*
2+
* Copyright 2018 ABSA Group Limited
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package za.co.absa.cobrix.spark.cobol.writer
18+
19+
import org.apache.spark.rdd.RDD
20+
import org.apache.spark.sql.types.{ArrayType, StructType}
21+
import org.apache.spark.sql.{DataFrame, Row}
22+
import za.co.absa.cobrix.cobol.parser.Copybook
23+
import za.co.absa.cobrix.cobol.parser.ast.datatype.{Decimal, Integral}
24+
import za.co.absa.cobrix.cobol.parser.ast.{Group, Primitive}
25+
import za.co.absa.cobrix.cobol.parser.recordformats.RecordFormat
26+
import za.co.absa.cobrix.cobol.reader.parameters.ReaderParameters
27+
import za.co.absa.cobrix.cobol.reader.schema.CobolSchema
28+
import za.co.absa.cobrix.spark.cobol.writer.WriterAst._
29+
30+
import scala.collection.mutable
31+
32+
class NestedRecordCombiner extends RecordCombiner {
  import NestedRecordCombiner._

  /**
    * Converts a DataFrame into an RDD of fixed-size EBCDIC record buffers,
    * optionally prefixed with a 4-byte RDW header for variable-length formats.
    */
  override def combine(df: DataFrame, cobolSchema: CobolSchema, readerParameters: ReaderParameters): RDD[Array[Byte]] = {
    val useRdwHeader = readerParameters.recordFormat == RecordFormat.VariableLength
    val bigEndianRdw = readerParameters.isRdwBigEndian

    // The header value can be shifted by configuration: 4 extra bytes when the
    // header itself counts towards the record length, plus a user-supplied adjustment.
    val rdwSelfLength = if (readerParameters.isRdwPartRecLength) 4 else 0
    val headerValueLong = cobolSchema.getRecordSize.toLong + rdwSelfLength.toLong + readerParameters.rdwAdjustment.toLong

    if (headerValueLong < 0) {
      throw new IllegalArgumentException(
        s"Invalid RDW length $headerValueLong. Check 'is_rdw_part_of_record_length' and 'rdw_adjustment'."
      )
    }
    if (bigEndianRdw && headerValueLong > 0xFFFFL) {
      // Standard big-endian RDW only has 2 length bytes.
      throw new IllegalArgumentException(
        s"RDW length $headerValueLong exceeds 65535 and cannot be encoded in big-endian mode."
      )
    }
    if (!bigEndianRdw && headerValueLong > Int.MaxValue.toLong) {
      throw new IllegalArgumentException(
        s"RDW length $headerValueLong exceeds ${Int.MaxValue} and cannot be encoded safely."
      )
    }

    // Variable-length records carry a 4-byte RDW header before the data.
    val totalRecordSize = if (useRdwHeader) cobolSchema.getRecordSize + 4 else cobolSchema.getRecordSize
    val dataStartOffset = if (useRdwHeader) 4 else 0

    processRDD(df.rdd, cobolSchema.copybook, df.schema, totalRecordSize, headerValueLong.toInt, dataStartOffset, useRdwHeader, bigEndianRdw)
  }
}
70+
71+
object NestedRecordCombiner {
72+
/**
  * Renders a primitive field's copybook definition as "PIC [USAGE]" text.
  * COMP fields use their compact usage; DISPLAY is the default for numerics.
  */
def getFieldDefinition(field: Primitive): String = {
  val picture = field.dataType.originalPic.getOrElse(field.dataType.pic)

  val usageClause = field.dataType match {
    case integral: Integral => integral.compact.fold("USAGE IS DISPLAY")(_.toString)
    case decimal: Decimal   => decimal.compact.fold("USAGE IS DISPLAY")(_.toString)
    case _                  => ""
  }

  s"$picture $usageClause".trim
}
83+
84+
/**
  * Builds the writer AST for the whole record from the copybook root group
  * and the DataFrame schema.
  */
def constructWriterAst(copybook: Copybook, schema: StructType): GroupField = {
  // The root group's getter is the identity: the top-level Row is the record itself.
  val rootGetter: GroupGetter = row => row
  buildGroupField(getAst(copybook), schema, rootGetter)
}
87+
88+
/**
  * Serializes every Row of the RDD into a pre-zeroed byte buffer of `recordSize` bytes,
  * writing a 4-byte RDW header first when `hasRdw` is set.
  *
  * @param recordSize         Full buffer size per record (data plus RDW header when present).
  * @param recordLengthHeader Value to encode into the RDW header.
  * @param startOffset        Offset at which record data starts (0, or 4 after an RDW).
  */
def processRDD(rdd: RDD[Row], copybook: Copybook, schema: StructType, recordSize: Int, recordLengthHeader: Int, startOffset: Int, hasRdw: Boolean, isRdwBigEndian: Boolean): RDD[Array[Byte]] = {
  val writerAst = constructWriterAst(copybook, schema)

  rdd.mapPartitions { partition =>
    partition.map { row =>
      val record = new Array[Byte](recordSize)

      if (hasRdw) {
        if (isRdwBigEndian) {
          record(0) = ((recordLengthHeader >> 8) & 0xFF).toByte
          record(1) = (recordLengthHeader & 0xFF).toByte
          // The last two bytes are reserved and defined by IBM as binary zeros on all platforms.
          record(2) = 0
          record(3) = 0
        } else {
          record(0) = (recordLengthHeader & 0xFF).toByte
          record(1) = ((recordLengthHeader >> 8) & 0xFF).toByte
          // This is non-standard. But so are little-endian RDW headers.
          // As an advantage, it has no effect for small records but adds support for big records (> 64KB).
          record(2) = ((recordLengthHeader >> 16) & 0xFF).toByte
          record(3) = ((recordLengthHeader >> 24) & 0xFF).toByte
        }
      }

      writeToBytes(writerAst, row, record, startOffset)

      record
    }
  }
}
118+
119+
/**
  * Returns the effective root group of the copybook: when the AST has a single
  * group child (the usual 01-level record), that child is used; otherwise the
  * synthetic root itself.
  */
def getAst(copybook: Copybook): Group = {
  val root = copybook.ast
  root.children.toList match {
    case (onlyGroup: Group) :: Nil => onlyGroup
    case _                         => root
  }
}
128+
129+
/**
  * Walks a copybook group and the matching Spark StructType together and builds
  * the [[WriterAst]] node for that group.
  *
  * Children that REDEFINE other fields are skipped entirely. Fillers become
  * [[Filler]] placeholders so their bytes are still accounted for; primitives
  * and sub-groups are resolved against the Spark schema, and children missing
  * from the schema are dropped.
  *
  * @param group  The copybook group whose children are converted.
  * @param schema The Spark StructType corresponding to `group`.
  * @param getter Extracts the Row that belongs to this group from the outer Row.
  * @return A [[GroupField]] covering the matched children.
  */
private def buildGroupField(group: Group, schema: StructType, getter: GroupGetter): GroupField = {
  val convertedChildren = group.children
    .filter(_.redefines.isEmpty)
    .flatMap {
      case filler if filler.isFiller => Some(Filler(filler.binaryProperties.actualSize))
      case primitive: Primitive      => buildPrimitiveNode(primitive, schema)
      case subGroup: Group           => buildGroupNode(subGroup, schema)
    }
  GroupField(convertedChildren.toSeq, group, getter)
}
149+
150+
/**
  * Creates a writer node for a primitive copybook field.
  *
  * The field is located in the Spark schema by case-insensitive name. Fields
  * with OCCURS become [[PrimitiveArray]] nodes whose getter extracts the whole
  * Spark array column; plain fields become [[PrimitiveField]] nodes.
  *
  * @return `None` when the field is absent from the schema (e.g. filtered out during reading).
  */
private def buildPrimitiveNode(p: Primitive, schema: StructType): Option[WriterAst] = {
  val fieldIndex = schema.fieldNames.indexWhere(_.equalsIgnoreCase(p.name))

  if (fieldIndex < 0) {
    None
  } else if (p.occurs.isDefined) {
    // OCCURS on a primitive: the Spark column is an array of scalar values.
    Some(PrimitiveArray(p, row => row.getAs[mutable.WrappedArray[AnyRef]](fieldIndex)))
  } else {
    Some(PrimitiveField(p, row => row.get(fieldIndex)))
  }
}
171+
172+
/**
  * Creates a writer node for a group copybook field.
  *
  * The field is located in the Spark schema by case-insensitive name. Groups
  * with OCCURS map to [[GroupArray]] (the Spark column must be an array of
  * structs); plain groups map to [[GroupField]] (the Spark column must be a
  * struct). In both cases the children are built by recursing into the nested
  * Spark StructType.
  *
  * @return `None` when the field is absent from the schema.
  * @throws IllegalArgumentException when the Spark data type does not match
  *                                  the shape implied by the copybook.
  */
private def buildGroupNode(g: Group, schema: StructType): Option[WriterAst] = {
  val fieldIndex = schema.fieldNames.indexWhere(_.equalsIgnoreCase(g.name))

  if (fieldIndex < 0) {
    None
  } else if (g.occurs.isDefined) {
    // Array of structs - the element type must be a StructType.
    schema(fieldIndex).dataType match {
      case ArrayType(elementType: StructType, _) =>
        // Element rows are passed to the child AST directly, so the element
        // getter is the identity.
        val elementAst = buildGroupField(g, elementType, row => row)
        Some(GroupArray(elementAst, g, row => row.getAs[mutable.WrappedArray[AnyRef]](fieldIndex)))
      case other =>
        throw new IllegalArgumentException(
          s"Expected ArrayType(StructType) for group field '${g.name}' with OCCURS, but got $other")
    }
  } else {
    // Nested struct.
    schema(fieldIndex).dataType match {
      case nestedSchema: StructType =>
        // buildGroupField already produces GroupField(children, g, getter) with
        // this exact getter, so it is returned directly instead of being
        // redundantly re-wrapped in another GroupField.
        Some(buildGroupField(g, nestedSchema, row => row.getAs[Row](fieldIndex)))
      case other =>
        throw new IllegalArgumentException(
          s"Expected StructType for group field '${g.name}', but got $other")
    }
  }
}
210+
211+
/**
  * Recursively walks `ast` and writes every primitive value from `row` into `ar`.
  *
  * The return value is the number of bytes the node occupies in the record
  * (`binaryProperties.actualSize`, or the declared size for a filler). Callers
  * use it to advance the running offset when laying out a group's children.
  *
  * For plain (non-array) fields the offset is forwarded to
  * [[Copybook.setPrimitiveField]], which adds it to `field.binaryProperties.offset`.
  *
  * For array fields (both primitive and group-of-primitives) each element is written
  * using the `fieldStartOffsetOverride` parameter so the exact byte position can be
  * supplied. The row array may contain fewer elements than the copybook allows — any
  * missing tail elements are silently skipped, leaving those bytes as zeroes.
  *
  * @param ast           The [[WriterAst]] node to process.
  * @param row           The Spark [[Row]] from which values are read.
  * @param ar            The target byte array (record buffer).
  * @param currentOffset Byte offset at which this node starts: the RDW prefix
  *                      length (0 or 4) at the top level, plus accumulated
  *                      child sizes in recursive calls.
  * @return The number of bytes this node occupies in the record.
  */
private def writeToBytes(ast: WriterAst, row: Row, ar: Array[Byte], currentOffset: Int): Int = {
  ast match {
    // ── Filler ──────────────────────────────────────────────────────
    // Nothing to write; report the size so the caller skips these bytes.
    case Filler(size) => size

    // ── Plain primitive ──────────────────────────────────────────────────────
    case PrimitiveField(cobolField, getter) =>
      val value = getter(row)
      if (value != null) {
        // A null value leaves the pre-zeroed bytes of the buffer untouched.
        Copybook.setPrimitiveField(cobolField, ar, value, 0, currentOffset)
      }
      cobolField.binaryProperties.actualSize

    // ── Plain nested group ───────────────────────────────────────────────────
    case GroupField(children, cobolField, getter) =>
      val nestedRow = getter(row)
      if (nestedRow != null) {
        // Children are laid out back to back; each child reports its own size.
        var writtenBytes = 0
        children.foreach(child =>
          writtenBytes += writeToBytes(child, nestedRow, ar, currentOffset + writtenBytes)
        )
      }
      cobolField.binaryProperties.actualSize

    // ── Array of primitives (OCCURS on a primitive field) ───────────────────
    case PrimitiveArray(cobolField, arrayGetter) =>
      val arr = arrayGetter(row)
      if (arr != null) {
        val maxElements = cobolField.arrayMaxSize // copybook upper bound
        // NOTE(review): dataSize is assumed to be the size of a single array
        // element — confirm against cobrix's BinaryProperties semantics.
        val elementSize = cobolField.binaryProperties.dataSize
        val baseOffset = currentOffset
        // Extra elements beyond the copybook bound are dropped; missing tail
        // elements leave zero bytes in the buffer.
        val elementsToWrite = math.min(arr.length, maxElements)

        var i = 0
        while (i < elementsToWrite) {
          val value = arr(i)
          if (value != null) {
            val elementOffset = baseOffset + i * elementSize
            // fieldStartOffsetOverride is the absolute position; pass it so
            // setPrimitiveField does not add binaryProperties.offset on top again.
            Copybook.setPrimitiveField(cobolField, ar, value, fieldStartOffsetOverride = elementOffset)
          }
          i += 1
        }
      }
      cobolField.binaryProperties.actualSize

    // ── Array of groups (OCCURS on a group field) ───────────────────────────
    case GroupArray(groupField: GroupField, cobolField, arrayGetter) =>
      val arr = arrayGetter(row)
      if (arr != null) {
        val maxElements = cobolField.arrayMaxSize
        // NOTE(review): dataSize is assumed to be the per-element stride here;
        // an earlier comment suggested actualSize / maxElements instead — confirm
        // which matches cobrix's BinaryProperties semantics for OCCURS groups.
        val elementSize = cobolField.binaryProperties.dataSize
        val baseOffset = currentOffset
        val elementsToWrite = math.min(arr.length, maxElements)

        var i = 0
        while (i < elementsToWrite) {
          val elementRow = arr(i).asInstanceOf[Row]
          if (elementRow != null) {
            // Each element starts at a fixed stride from the array's base offset;
            // the recursive call lays out the group's children from there.
            val elementStartOffset = baseOffset + i * elementSize
            writeToBytes(groupField, elementRow, ar, elementStartOffset)
          }
          i += 1
        }
      }
      cobolField.binaryProperties.actualSize
  }
}
299+
}

spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/RecordCombinerSelector.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ object RecordCombinerSelector {
3232
* @return A `RecordCombiner` implementation suitable for combining records based on the given schema and parameters.
3333
*/
3434
def selectCombiner(cobolSchema: CobolSchema, readerParameters: ReaderParameters): RecordCombiner = {
  // Always selects NestedRecordCombiner. NOTE(review): both parameters are
  // currently unused here — presumably reserved for future selection logic
  // between combiner implementations; confirm before removing them.
  new NestedRecordCombiner
}
3737

3838
}

0 commit comments

Comments
 (0)