|
| 1 | +/* |
| 2 | + * Copyright 2018 ABSA Group Limited |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +package za.co.absa.cobrix.spark.cobol.writer |
| 18 | + |
| 19 | +import org.apache.spark.rdd.RDD |
| 20 | +import org.apache.spark.sql.types.{ArrayType, StructType} |
| 21 | +import org.apache.spark.sql.{DataFrame, Row} |
| 22 | +import za.co.absa.cobrix.cobol.parser.Copybook |
| 23 | +import za.co.absa.cobrix.cobol.parser.ast.datatype.{Decimal, Integral} |
| 24 | +import za.co.absa.cobrix.cobol.parser.ast.{Group, Primitive} |
| 25 | +import za.co.absa.cobrix.cobol.parser.recordformats.RecordFormat |
| 26 | +import za.co.absa.cobrix.cobol.reader.parameters.ReaderParameters |
| 27 | +import za.co.absa.cobrix.cobol.reader.schema.CobolSchema |
| 28 | +import za.co.absa.cobrix.spark.cobol.writer.WriterAst._ |
| 29 | + |
| 30 | +import scala.collection.mutable |
| 31 | + |
/**
  * Combines a DataFrame with a copybook into raw binary records, supporting
  * nested groups and OCCURS arrays.
  *
  * For variable-length record formats each record is prefixed with a 4-byte
  * RDW header; for fixed-length formats the payload starts at byte 0.
  */
class NestedRecordCombiner extends RecordCombiner {
  import NestedRecordCombiner._

  /**
    * Converts each row of `df` into a byte array laid out according to the copybook.
    *
    * @param df               The input DataFrame whose schema matches the copybook.
    * @param cobolSchema      The parsed copybook wrapper (provides record size and AST).
    * @param readerParameters Options controlling record format and RDW adjustments.
    * @return An RDD of fixed-size byte arrays, one per input row.
    * @throws IllegalArgumentException when the adjusted RDW length is negative or
    *                                  does not fit the chosen RDW encoding.
    */
  override def combine(df: DataFrame, cobolSchema: CobolSchema, readerParameters: ReaderParameters): RDD[Array[Byte]] = {
    val useRdw = readerParameters.recordFormat == RecordFormat.VariableLength
    val bigEndianRdw = readerParameters.isRdwBigEndian

    // When the RDW is considered part of the record length, the header value grows by 4.
    val rdwInclusionAdjustment = if (readerParameters.isRdwPartRecLength) 4 else 0
    val userAdjustment = readerParameters.rdwAdjustment

    val payloadSize = cobolSchema.getRecordSize
    val bufferSize = if (useRdw) payloadSize + 4 else payloadSize
    val dataOffset = if (useRdw) 4 else 0

    // Computed in Long space first so overflow/underflow can be detected before truncation.
    val headerValueLong = payloadSize.toLong + rdwInclusionAdjustment.toLong + userAdjustment.toLong

    if (headerValueLong < 0) {
      throw new IllegalArgumentException(
        s"Invalid RDW length $headerValueLong. Check 'is_rdw_part_of_record_length' and 'rdw_adjustment'."
      )
    }
    if (bigEndianRdw && headerValueLong > 0xFFFFL) {
      throw new IllegalArgumentException(
        s"RDW length $headerValueLong exceeds 65535 and cannot be encoded in big-endian mode."
      )
    }
    if (!bigEndianRdw && headerValueLong > Int.MaxValue.toLong) {
      throw new IllegalArgumentException(
        s"RDW length $headerValueLong exceeds ${Int.MaxValue} and cannot be encoded safely."
      )
    }

    processRDD(df.rdd, cobolSchema.copybook, df.schema, bufferSize, headerValueLong.toInt, dataOffset, useRdw, bigEndianRdw)
  }
}
| 70 | + |
| 71 | +object NestedRecordCombiner { |
| 72 | + def getFieldDefinition(field: Primitive): String = { |
| 73 | + val pic = field.dataType.originalPic.getOrElse(field.dataType.pic) |
| 74 | + |
| 75 | + val usage = field.dataType match { |
| 76 | + case dt: Integral => dt.compact.map(_.toString).getOrElse("USAGE IS DISPLAY") |
| 77 | + case dt: Decimal => dt.compact.map(_.toString).getOrElse("USAGE IS DISPLAY") |
| 78 | + case _ => "" |
| 79 | + } |
| 80 | + |
| 81 | + s"$pic $usage".trim |
| 82 | + } |
| 83 | + |
| 84 | + def constructWriterAst(copybook: Copybook, schema: StructType): GroupField = { |
| 85 | + buildGroupField(getAst(copybook), schema, row => row) |
| 86 | + } |
| 87 | + |
| 88 | + def processRDD(rdd: RDD[Row], copybook: Copybook, schema: StructType, recordSize: Int, recordLengthHeader: Int, startOffset: Int, hasRdw: Boolean, isRdwBigEndian: Boolean): RDD[Array[Byte]] = { |
| 89 | + val writerAst = constructWriterAst(copybook, schema) |
| 90 | + |
| 91 | + rdd.mapPartitions { rows => |
| 92 | + rows.map { row => |
| 93 | + val ar = new Array[Byte](recordSize) |
| 94 | + |
| 95 | + if (hasRdw) { |
| 96 | + if (isRdwBigEndian) { |
| 97 | + ar(0) = ((recordLengthHeader >> 8) & 0xFF).toByte |
| 98 | + ar(1) = (recordLengthHeader & 0xFF).toByte |
| 99 | + // The last two bytes are reserved and defined by IBM as binary zeros on all platforms. |
| 100 | + ar(2) = 0 |
| 101 | + ar(3) = 0 |
| 102 | + } else { |
| 103 | + ar(0) = (recordLengthHeader & 0xFF).toByte |
| 104 | + ar(1) = ((recordLengthHeader >> 8) & 0xFF).toByte |
| 105 | + // This is non-standard. But so are little-endian RDW headers. |
| 106 | + // As an advantage, it has no effect for small records but adds support for big records (> 64KB). |
| 107 | + ar(2) = ((recordLengthHeader >> 16) & 0xFF).toByte |
| 108 | + ar(3) = ((recordLengthHeader >> 24) & 0xFF).toByte |
| 109 | + } |
| 110 | + } |
| 111 | + |
| 112 | + writeToBytes(writerAst, row, ar, startOffset) |
| 113 | + |
| 114 | + ar |
| 115 | + } |
| 116 | + } |
| 117 | + } |
| 118 | + |
| 119 | + def getAst(copybook: Copybook): Group = { |
| 120 | + val rootAst = copybook.ast |
| 121 | + |
| 122 | + if (rootAst.children.length == 1 && rootAst.children.head.isInstanceOf[Group]) { |
| 123 | + rootAst.children.head.asInstanceOf[Group] |
| 124 | + } else { |
| 125 | + rootAst |
| 126 | + } |
| 127 | + } |
| 128 | + |
| 129 | + /** |
| 130 | + * Recursively walks the copybook group and the Spark StructType in lockstep, producing |
| 131 | + * [[WriterAst]] nodes whose getters extract the correct value from a [[org.apache.spark.sql.Row]]. |
| 132 | + * |
| 133 | + * @param group A copybook Group node whose children will be processed. |
| 134 | + * @param schema The Spark StructType that corresponds to `group`. |
| 135 | + * @param getter A function that, given the "outer" Row, returns the Row that belongs to this group. |
| 136 | + * @return A [[GroupField]] covering all non-filler, non-redefines children found in both |
| 137 | + * the copybook and the Spark schema. |
| 138 | + */ |
| 139 | + private def buildGroupField(group: Group, schema: StructType, getter: GroupGetter): GroupField = { |
| 140 | + val children = group.children.withFilter { stmt => |
| 141 | + stmt.redefines.isEmpty |
| 142 | + }.flatMap { |
| 143 | + case s if s.isFiller => Some(Filler(s.binaryProperties.actualSize)) |
| 144 | + case p: Primitive => buildPrimitiveNode(p, schema) |
| 145 | + case g: Group => buildGroupNode(g, schema) |
| 146 | + } |
| 147 | + GroupField(children.toSeq, group, getter) |
| 148 | + } |
| 149 | + |
| 150 | + /** |
| 151 | + * Builds a [[WriterAst]] node for a primitive copybook field, using the field's index in the |
| 152 | + * supplied Spark schema to create a getter function. |
| 153 | + * |
| 154 | + * Returns `None` when the field is absent from the schema (e.g. filtered out during reading). |
| 155 | + */ |
| 156 | + private def buildPrimitiveNode(p: Primitive, schema: StructType): Option[WriterAst] = { |
| 157 | + val fieldName = p.name |
| 158 | + val fieldIndexOpt = schema.fields.zipWithIndex.find { case (field, _) => |
| 159 | + field.name.equalsIgnoreCase(fieldName) |
| 160 | + }.map(_._2) |
| 161 | + |
| 162 | + fieldIndexOpt.map { idx => |
| 163 | + if (p.occurs.isDefined) { |
| 164 | + // Array of primitives |
| 165 | + PrimitiveArray(p, row => row.getAs[mutable.WrappedArray[AnyRef]](idx)) |
| 166 | + } else { |
| 167 | + PrimitiveField(p, row => row.get(idx)) |
| 168 | + } |
| 169 | + } |
| 170 | + } |
| 171 | + |
| 172 | + /** |
| 173 | + * Builds a [[WriterAst]] node for a group copybook field. For groups with OCCURS the getter |
| 174 | + * extracts an array; for plain groups it extracts the nested Row. In both cases the children |
| 175 | + * are built by recursing into the nested Spark StructType. |
| 176 | + * |
| 177 | + * Returns `None` when the field is absent from the schema. |
| 178 | + */ |
| 179 | + private def buildGroupNode(g: Group, schema: StructType): Option[WriterAst] = { |
| 180 | + val fieldName = g.name |
| 181 | + val fieldIndexOpt = schema.fields.zipWithIndex.find { case (field, _) => |
| 182 | + field.name.equalsIgnoreCase(fieldName) |
| 183 | + }.map(_._2) |
| 184 | + |
| 185 | + fieldIndexOpt.flatMap { idx => |
| 186 | + if (g.occurs.isDefined) { |
| 187 | + // Array of structs – the element type must be a StructType |
| 188 | + schema(idx).dataType match { |
| 189 | + case ArrayType(elementType: StructType, _) => |
| 190 | + val childAst = buildGroupField(g, elementType, row => row) |
| 191 | + Some(GroupArray(childAst, g, row => row.getAs[mutable.WrappedArray[AnyRef]](idx))) |
| 192 | + case other => |
| 193 | + throw new IllegalArgumentException( |
| 194 | + s"Expected ArrayType(StructType) for group field '${g.name}' with OCCURS, but got $other") |
| 195 | + } |
| 196 | + } else { |
| 197 | + // Nested struct |
| 198 | + schema(idx).dataType match { |
| 199 | + case nestedSchema: StructType => |
| 200 | + val childGetter: GroupGetter = row => row.getAs[Row](idx) |
| 201 | + val childAst = buildGroupField(g, nestedSchema, childGetter) |
| 202 | + Some(GroupField(childAst.children, g, childGetter)) |
| 203 | + case other => |
| 204 | + throw new IllegalArgumentException( |
| 205 | + s"Expected StructType for group field '${g.name}', but got $other") |
| 206 | + } |
| 207 | + } |
| 208 | + } |
| 209 | + } |
| 210 | + |
  /**
    * Recursively walks `ast` and writes every primitive value from `row` into `ar`.
    *
    * Null values (and null nested rows) are skipped, leaving the corresponding
    * bytes of the buffer as zeros.
    *
    * For plain (non-array) fields `currentOffset` is forwarded to
    * [[Copybook.setPrimitiveField]] as its start-offset argument.
    * NOTE(review): the exact contract of that argument (whether it is added on top
    * of `field.binaryProperties.offset` or used as-is) is not visible here; the
    * accumulation of written sizes in the GroupField branch only yields correct
    * absolute positions under one of those interpretations — confirm against
    * `Copybook.setPrimitiveField`.
    *
    * For array fields (both primitive and group arrays) each element is written
    * using the `fieldStartOffsetOverride` parameter so the exact byte position can
    * be supplied. The row array may contain fewer elements than the copybook
    * allows — any missing tail elements are silently skipped, leaving those bytes
    * as zeros.
    *
    * @param ast           The [[WriterAst]] node to process.
    * @param row           The Spark [[Row]] from which values are read.
    * @param ar            The target byte array (record buffer).
    * @param currentOffset Byte offset at which this node starts; the initial call
    *                      passes the RDW prefix length (0 for fixed-length records,
    *                      4 for variable-length ones).
    * @return The number of bytes this node occupies (`binaryProperties.actualSize`),
    *         used by the caller to advance its running offset.
    */
  private def writeToBytes(ast: WriterAst, row: Row, ar: Array[Byte], currentOffset: Int): Int = {
    ast match {
      // ── Filler: nothing is written, the bytes remain zeros; only the size is
      // reported so the caller skips over the reserved area. ──────────────────
      case Filler(size) => size

      // ── Plain primitive: write the single value unless it is null. ─────────
      case PrimitiveField(cobolField, getter) =>
        val value = getter(row)
        if (value != null) {
          Copybook.setPrimitiveField(cobolField, ar, value, 0, currentOffset)
        }
        cobolField.binaryProperties.actualSize

      // ── Plain nested group: extract the nested Row and write each child at
      // the running offset (parent offset + bytes written so far). ────────────
      case GroupField(children, cobolField, getter) =>
        val nestedRow = getter(row)
        if (nestedRow != null) {
          var writtenBytes = 0
          children.foreach(child =>
            writtenBytes += writeToBytes(child, nestedRow, ar, currentOffset + writtenBytes)
          )
        }
        cobolField.binaryProperties.actualSize

      // ── Array of primitives (OCCURS on a primitive field) ─────────────────
      case PrimitiveArray(cobolField, arrayGetter) =>
        val arr = arrayGetter(row)
        if (arr != null) {
          val maxElements = cobolField.arrayMaxSize // copybook upper bound
          // NOTE(review): dataSize is used as the per-element size here — confirm
          // that for OCCURS fields binaryProperties.dataSize is a single element's
          // size (and actualSize the full array span).
          val elementSize = cobolField.binaryProperties.dataSize
          val baseOffset = currentOffset
          // Fewer row elements than the copybook allows: the tail stays zeroed.
          val elementsToWrite = math.min(arr.length, maxElements)

          var i = 0
          while (i < elementsToWrite) {
            val value = arr(i)
            if (value != null) {
              val elementOffset = baseOffset + i * elementSize
              // fieldStartOffsetOverride supplies the absolute position so that
              // setPrimitiveField does not add binaryProperties.offset on top again.
              Copybook.setPrimitiveField(cobolField, ar, value, fieldStartOffsetOverride = elementOffset )
            }
            i += 1
          }
        }
        cobolField.binaryProperties.actualSize

      // ── Array of groups (OCCURS on a group field) ─────────────────────────
      case GroupArray(groupField: GroupField, cobolField, arrayGetter) =>
        val arr = arrayGetter(row)
        if (arr != null) {
          val maxElements = cobolField.arrayMaxSize
          // NOTE(review): same per-element-size assumption as above — the original
          // comment claimed "actualSize divided by maxElements", but the code uses
          // dataSize directly; verify the two agree for group OCCURS fields.
          val elementSize = cobolField.binaryProperties.dataSize
          val baseOffset = currentOffset
          val elementsToWrite = math.min(arr.length, maxElements)

          var i = 0
          while (i < elementsToWrite) {
            val elementRow = arr(i).asInstanceOf[Row]
            if (elementRow != null) {
              // Each element's children write relative to the element's start,
              // so pass the element's absolute start offset down the recursion.
              val elementStartOffset = baseOffset + i * elementSize
              writeToBytes(groupField, elementRow, ar, elementStartOffset)
            }
            i += 1
          }
        }
        cobolField.binaryProperties.actualSize
    }
  }
| 299 | +} |
0 commit comments