Commit 2121f0b

Added logic to cast fields to CDAP types when using the Storage Read API
1 parent 1953f91 commit 2121f0b
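
Background for the change (hedged; the type mapping below is the connector's standard behavior, not stated in the commit): BigQuery stores only 64-bit numerics (INT64, FLOAT64), so columns read back through the Storage Read API surface in Spark as LongType and DoubleType. A CDAP schema that declares 32-bit Schema.Type.INT or Schema.Type.FLOAT fields then no longer matches the produced dataframe, and the fix is a narrowing cast per affected column. A minimal sketch of the pattern, assuming a SparkSession, a table path, and a column named "count":

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.types.DataTypes;

    // Sketch only: the session, table path, and column "count" are assumptions.
    SparkSession spark = SparkSession.builder().appName("cast-sketch").getOrCreate();
    Dataset<Row> ds = spark.read().format("bigquery").load("project.dataset.table");
    // BigQuery INT64 arrives as Spark LongType; cast it down to IntegerType
    // so it lines up with a CDAP field declared as Schema.Type.INT.
    ds = ds.withColumn("count", ds.col("count").cast(DataTypes.IntegerType));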

2 files changed: +50 −3 lines


src/main/java/io/cdap/plugin/gcp/bigquery/sqlengine/BigQuerySQLEngine.java

Lines changed: 5 additions & 1 deletion

@@ -317,7 +317,11 @@ public SQLDatasetProducer getProducer(SQLPullRequest pullRequest, PullCapability
 
     String table = datasets.get(pullRequest.getDatasetName()).getBigQueryTable();
 
-    return new BigQuerySparkDatasetProducer(sqlEngineConfig, datasetProject, dataset, table);
+    return new BigQuerySparkDatasetProducer(sqlEngineConfig,
+                                            datasetProject,
+                                            dataset,
+                                            table,
+                                            pullRequest.getDatasetSchema());
   }
 
   @Override

src/main/java/io/cdap/plugin/gcp/bigquery/sqlengine/BigQuerySparkDatasetProducer.java
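
This producer receives the CDAP schema threaded through from the engine above (via the pull request's getDatasetSchema()) and uses it, after loading the dataframe, to decide which columns need a narrowing cast.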

Lines changed: 45 additions & 2 deletions

@@ -16,17 +16,20 @@
 
 package io.cdap.plugin.gcp.bigquery.sqlengine;
 
+import io.cdap.cdap.api.data.schema.Schema;
 import io.cdap.cdap.etl.api.engine.sql.dataset.RecordCollection;
 import io.cdap.cdap.etl.api.engine.sql.dataset.SQLDataset;
 import io.cdap.cdap.etl.api.engine.sql.dataset.SQLDatasetDescription;
 import io.cdap.cdap.etl.api.engine.sql.dataset.SQLDatasetProducer;
 import io.cdap.cdap.etl.api.sql.engine.dataset.SparkRecordCollectionImpl;
-import io.cdap.plugin.gcp.common.GCPConfig;
 import org.apache.spark.SparkContext;
 import org.apache.spark.sql.DataFrameReader;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataTypes;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.io.Serializable;
 import java.nio.charset.StandardCharsets;

@@ -39,6 +42,8 @@
 public class BigQuerySparkDatasetProducer
   implements SQLDatasetProducer, Serializable {
 
+  private static final Logger LOG = LoggerFactory.getLogger(BigQuerySparkDatasetProducer.class);
+
   private static final String FORMAT = "bigquery";
   private static final String CONFIG_CREDENTIALS_FILE = "credentialsFile";
   private static final String CONFIG_CREDENTIALS = "credentials";

@@ -47,15 +52,19 @@ public class BigQuerySparkDatasetProducer
   private String project;
   private String bqDataset;
   private String bqTable;
+  private Schema schema;
+
 
   public BigQuerySparkDatasetProducer(BigQuerySQLEngineConfig config,
                                       String project,
                                       String bqDataset,
-                                      String bqTable) {
+                                      String bqTable,
+                                      Schema schema) {
     this.config = config;
     this.project = project;
     this.bqDataset = bqDataset;
     this.bqTable = bqTable;
+    this.schema = schema;
   }
 
   @Override

@@ -87,6 +96,7 @@ public RecordCollection produce(SQLDataset sqlDataset) {
 
     // Load path into dataset.
     Dataset<Row> ds = bqReader.load(path);
+    ds = convertFieldTypes(ds);
 
     return new SparkRecordCollectionImpl(ds);
   }

@@ -95,4 +105,37 @@ public RecordCollection produce(SQLDataset sqlDataset) {
   private String encodeBase64(String serviceAccountJson) {
     return Base64.getEncoder().encodeToString(serviceAccountJson.getBytes(StandardCharsets.UTF_8));
   }
+
+  /**
+   * Adjust CDAP types for int and float fields.
+   *
+   * @param ds input dataframe
+   * @return dataframe with updated schema.
+   */
+  private Dataset<Row> convertFieldTypes(Dataset<Row> ds) {
+    for (Schema.Field field : schema.getFields()) {
+      String fieldName = field.getName();
+      Schema fieldSchema = field.getSchema();
+
+      // For nullable types, check the underlying type.
+      if (fieldSchema.isNullable()) {
+        fieldSchema = fieldSchema.getNonNullable();
+      }
+
+      // Handle Int types
+      if (fieldSchema.getType() == Schema.Type.INT && fieldSchema.getLogicalType() == null) {
+        LOG.trace("Converting field {} to Integer", fieldName);
+        ds = ds.withColumn(fieldName, ds.col(fieldName).cast(DataTypes.IntegerType));
+      }
+
+      // Handle float types
+      if (fieldSchema.getType() == Schema.Type.FLOAT && fieldSchema.getLogicalType() == null) {
+        LOG.trace("Converting field {} to Float", fieldName);
+        ds = ds.withColumn(fieldName, ds.col(fieldName).cast(DataTypes.FloatType));
+      }
+    }
+
+    return ds;
+  }
+
 }
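
For illustration, a hypothetical CDAP schema shows which fields convertFieldTypes touches; the record and field names here are invented, not from the commit:

    import io.cdap.cdap.api.data.schema.Schema;

    // Hypothetical schema: "count" is a nullable 32-bit int, "day" is an
    // int-backed DATE logical type.
    Schema schema = Schema.recordOf("output",
      Schema.Field.of("count", Schema.nullableOf(Schema.of(Schema.Type.INT))),
      Schema.Field.of("day", Schema.of(Schema.LogicalType.DATE)));
    // convertFieldTypes casts "count" (read back as LongType) to IntegerType,
    // but skips "day": its logical type is non-null, so the INT branch's
    // getLogicalType() == null guard leaves its date semantics untouched.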
