@@ -16,17 +16,20 @@
 
 package io.cdap.plugin.gcp.bigquery.sqlengine;
 
+import io.cdap.cdap.api.data.schema.Schema;
 import io.cdap.cdap.etl.api.engine.sql.dataset.RecordCollection;
 import io.cdap.cdap.etl.api.engine.sql.dataset.SQLDataset;
 import io.cdap.cdap.etl.api.engine.sql.dataset.SQLDatasetDescription;
 import io.cdap.cdap.etl.api.engine.sql.dataset.SQLDatasetProducer;
 import io.cdap.cdap.etl.api.sql.engine.dataset.SparkRecordCollectionImpl;
-import io.cdap.plugin.gcp.common.GCPConfig;
 import org.apache.spark.SparkContext;
 import org.apache.spark.sql.DataFrameReader;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataTypes;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.io.Serializable;
 import java.nio.charset.StandardCharsets;
@@ -39,6 +42,8 @@
 public class BigQuerySparkDatasetProducer
     implements SQLDatasetProducer, Serializable {
 
+  private static final Logger LOG = LoggerFactory.getLogger(BigQuerySparkDatasetProducer.class);
+
   private static final String FORMAT = "bigquery";
   private static final String CONFIG_CREDENTIALS_FILE = "credentialsFile";
   private static final String CONFIG_CREDENTIALS = "credentials";
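These two option keys line up with the spark-bigquery-connector's credential options: "credentials" carries a base64-encoded service account JSON (produced by encodeBase64 further down), while "credentialsFile" points at a key file on disk. Below is a minimal standalone sketch of how a reader is typically wired up with them; the producer's actual produce() setup is elided from this diff, so the read method, its parameters, and the table string are illustrative only.

import org.apache.spark.sql.DataFrameReader;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import java.nio.charset.StandardCharsets;
import java.util.Base64;

public class BigQueryReadSketch {
  // Mirrors the constants in the producer above.
  private static final String FORMAT = "bigquery";
  private static final String CONFIG_CREDENTIALS = "credentials";
  private static final String CONFIG_CREDENTIALS_FILE = "credentialsFile";

  /** Loads "project.dataset.table", authenticating with inline JSON or a key file. */
  static Dataset<Row> read(SparkSession spark, String table,
                           String serviceAccountJson, String serviceAccountPath) {
    DataFrameReader reader = spark.read().format(FORMAT);
    if (serviceAccountJson != null) {
      // The connector expects inline credentials as base64-encoded JSON.
      reader = reader.option(CONFIG_CREDENTIALS,
        Base64.getEncoder().encodeToString(serviceAccountJson.getBytes(StandardCharsets.UTF_8)));
    } else if (serviceAccountPath != null) {
      reader = reader.option(CONFIG_CREDENTIALS_FILE, serviceAccountPath);
    }
    return reader.load(table);
  }
}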
@@ -47,15 +52,19 @@ public class BigQuerySparkDatasetProducer
   private String project;
   private String bqDataset;
   private String bqTable;
+  private Schema schema;
+
 
   public BigQuerySparkDatasetProducer(BigQuerySQLEngineConfig config,
                                       String project,
                                       String bqDataset,
-                                      String bqTable) {
+                                      String bqTable,
+                                      Schema schema) {
     this.config = config;
     this.project = project;
     this.bqDataset = bqDataset;
     this.bqTable = bqTable;
+    this.schema = schema;
   }
 
   @Override
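The engine-side call site is not part of this diff; hypothetically, the caller now passes the dataset's CDAP output schema along with the table coordinates:

// Hypothetical wiring: "outputSchema" stands for whatever CDAP schema the engine
// tracked for this dataset, and the string arguments are illustrative.
BigQuerySparkDatasetProducer producer =
  new BigQuerySparkDatasetProducer(config, "my-project", "my_dataset", "my_table", outputSchema);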
@@ -87,6 +96,7 @@ public RecordCollection produce(SQLDataset sqlDataset) {
 
     // Load path into dataset.
     Dataset<Row> ds = bqReader.load(path);
+    ds = convertFieldTypes(ds);
 
     return new SparkRecordCollectionImpl(ds);
   }
@@ -95,4 +105,37 @@ public RecordCollection produce(SQLDataset sqlDataset) {
 
   private String encodeBase64(String serviceAccountJson) {
     return Base64.getEncoder().encodeToString(serviceAccountJson.getBytes(StandardCharsets.UTF_8));
   }
+
+  /**
+   * Narrows columns to the Spark types the CDAP schema expects: BigQuery INT64 and
+   * FLOAT64 load as long and double, so plain int and float fields are cast down.
+   * @param ds input dataframe
+   * @return dataframe with narrowed column types
+   */
+  private Dataset<Row> convertFieldTypes(Dataset<Row> ds) {
+    for (Schema.Field field : schema.getFields()) {
+      String fieldName = field.getName();
+      Schema fieldSchema = field.getSchema();
+
+      // For nullable types, check the underlying type.
+      if (fieldSchema.isNullable()) {
+        fieldSchema = fieldSchema.getNonNullable();
+      }
+
+      // Narrow plain int fields; int-backed logical types (e.g. date) are left as-is.
+      if (fieldSchema.getType() == Schema.Type.INT && fieldSchema.getLogicalType() == null) {
+        LOG.trace("Converting field {} to Integer", fieldName);
+        ds = ds.withColumn(fieldName, ds.col(fieldName).cast(DataTypes.IntegerType));
+      }
+
+      // Narrow plain float fields.
+      if (fieldSchema.getType() == Schema.Type.FLOAT && fieldSchema.getLogicalType() == null) {
+        LOG.trace("Converting field {} to Float", fieldName);
+        ds = ds.withColumn(fieldName, ds.col(fieldName).cast(DataTypes.FloatType));
+      }
+    }
+
+    return ds;
+  }
+
 }
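To see the effect of convertFieldTypes in isolation, the standalone sketch below (not part of the plugin) builds a DataFrame with the long and double columns the connector would produce for INT64/FLOAT64 data and applies the same casts:

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.lit;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;

public class CastSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]").appName("cast-sketch").getOrCreate();

    // Stand-ins for BigQuery INT64 / FLOAT64 columns, which arrive as long / double.
    Dataset<Row> ds = spark.range(3).toDF("id").withColumn("score", lit(1.5d));
    ds.printSchema();  // id: long, score: double

    // The same casts convertFieldTypes applies for CDAP int / float fields.
    ds = ds.withColumn("id", col("id").cast(DataTypes.IntegerType))
           .withColumn("score", col("score").cast(DataTypes.FloatType));
    ds.printSchema();  // id: integer, score: float

    spark.stop();
  }
}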