Skip to content

Commit 26c16f2

Browse files
authored
feat(snowflake): transpile CORR with NaN-->NULL (#6619)
* Transpile CORR with NaN-->NULL * Cleanup and address PR feedback
1 parent 68c5e72 commit 26c16f2

File tree

5 files changed

+129
-3
lines changed

5 files changed

+129
-3
lines changed

sqlglot/dialects/duckdb.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,21 @@ def _regr_val_sql(
871871
)
872872

873873

874+
def _maybe_corr_null_to_false(
875+
expression: t.Union[exp.Filter, exp.Window, exp.Corr],
876+
) -> t.Optional[t.Union[exp.Filter, exp.Window, exp.Corr]]:
877+
expr = expression.copy()
878+
corr = expr
879+
while isinstance(corr, (exp.Window, exp.Filter)):
880+
corr = corr.this
881+
882+
if not isinstance(corr, exp.Corr) or not corr.args.get("null_on_zero_variance"):
883+
return None
884+
885+
corr.set("null_on_zero_variance", False)
886+
return expr
887+
888+
874889
def _date_from_parts_sql(self, expression: exp.DateFromParts) -> str:
875890
"""
876891
Snowflake's DATE_FROM_PARTS allows out-of-range values for the month and day input.
@@ -1348,6 +1363,7 @@ class Generator(generator.Generator):
13481363
exp.BitwiseOrAgg: _bitwise_agg_sql,
13491364
exp.BitwiseXorAgg: _bitwise_agg_sql,
13501365
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
1366+
exp.Corr: lambda self, e: self._corr_sql(e),
13511367
exp.CosineDistance: rename_func("LIST_COSINE_DISTANCE"),
13521368
exp.CurrentTime: lambda *_: "CURRENT_TIME",
13531369
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
@@ -2542,3 +2558,37 @@ def bitwisenot_sql(self, expression: exp.BitwiseNot) -> str:
25422558
result_sql = f"~{self.sql(expression, 'this')}"
25432559

25442560
return _gen_with_cast_to_blob(self, expression, result_sql)
2561+
2562+
def window_sql(self, expression: exp.Window) -> str:
2563+
this = expression.this
2564+
if isinstance(this, exp.Corr) or (
2565+
isinstance(this, exp.Filter) and isinstance(this.this, exp.Corr)
2566+
):
2567+
return self._corr_sql(expression)
2568+
2569+
return super().window_sql(expression)
2570+
2571+
def filter_sql(self, expression: exp.Filter) -> str:
2572+
if isinstance(expression.this, exp.Corr):
2573+
return self._corr_sql(expression)
2574+
2575+
return super().filter_sql(expression)
2576+
2577+
def _corr_sql(
2578+
self,
2579+
expression: t.Union[exp.Filter, exp.Window, exp.Corr],
2580+
) -> str:
2581+
if isinstance(expression, exp.Corr) and not expression.args.get(
2582+
"null_on_zero_variance"
2583+
):
2584+
return self.func("CORR", expression.this, expression.expression)
2585+
2586+
corr_expr = _maybe_corr_null_to_false(expression)
2587+
if corr_expr is None:
2588+
if isinstance(expression, exp.Window):
2589+
return super().window_sql(expression)
2590+
if isinstance(expression, exp.Filter):
2591+
return super().filter_sql(expression)
2592+
corr_expr = expression # make mypy happy
2593+
2594+
return self.sql(exp.case().when(exp.IsNan(this=corr_expr), exp.null()).else_(corr_expr))

sqlglot/dialects/snowflake.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,11 @@ class Parser(parser.Parser):
769769
"BIT_XORAGG": exp.BitwiseXorAgg.from_arg_list,
770770
"BITMAP_OR_AGG": exp.BitmapOrAgg.from_arg_list,
771771
"BOOLXOR": _build_bitwise(exp.Xor, "BOOLXOR"),
772+
"CORR": lambda args: exp.Corr(
773+
this=seq_get(args, 0),
774+
expression=seq_get(args, 1),
775+
null_on_zero_variance=True,
776+
),
772777
"DATE": _build_datetime("DATE", exp.DataType.Type.DATE),
773778
"DATEFROMPARTS": _build_date_from_parts,
774779
"DATE_FROM_PARTS": _build_date_from_parts,

sqlglot/expressions.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8300,8 +8300,11 @@ class Upper(Func):
83008300
_sql_names = ["UPPER", "UCASE"]
83018301

83028302

8303-
class Corr(AggFunc):
8304-
arg_types = {"this": True, "expression": True}
8303+
class Corr(Binary, AggFunc):
8304+
# Correlation divides by variance(column). If a column has 0 variance, the denominator
8305+
# is 0 - some dialects return NaN (DuckDB) while others return NULL (Snowflake).
8306+
# `null_on_zero_variance` is set to True at parse time for dialects that return NULL.
8307+
arg_types = {"this": True, "expression": True, "null_on_zero_variance": False}
83058308

83068309

83078310
# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/CUME_DIST.html

tests/dialects/test_postgres.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,6 +1398,36 @@ def test_variance(self):
13981398
},
13991399
)
14001400

1401+
def test_corr(self):
1402+
self.validate_all(
1403+
"SELECT CORR(a, b)",
1404+
write={
1405+
"duckdb": "SELECT CORR(a, b)",
1406+
"postgres": "SELECT CORR(a, b)",
1407+
},
1408+
)
1409+
self.validate_all(
1410+
"SELECT CORR(a, b) OVER (PARTITION BY c)",
1411+
write={
1412+
"duckdb": "SELECT CORR(a, b) OVER (PARTITION BY c)",
1413+
"postgres": "SELECT CORR(a, b) OVER (PARTITION BY c)",
1414+
},
1415+
)
1416+
self.validate_all(
1417+
"SELECT CORR(a, b) FILTER(WHERE c > 0)",
1418+
write={
1419+
"duckdb": "SELECT CORR(a, b) FILTER(WHERE c > 0)",
1420+
"postgres": "SELECT CORR(a, b) FILTER(WHERE c > 0)",
1421+
},
1422+
)
1423+
self.validate_all(
1424+
"SELECT CORR(a, b) FILTER(WHERE c > 0) OVER (PARTITION BY d)",
1425+
write={
1426+
"duckdb": "SELECT CORR(a, b) FILTER(WHERE c > 0) OVER (PARTITION BY d)",
1427+
"postgres": "SELECT CORR(a, b) FILTER(WHERE c > 0) OVER (PARTITION BY d)",
1428+
},
1429+
)
1430+
14011431
def test_regexp_binary(self):
14021432
"""See https://github.com/tobymao/sqlglot/pull/2404 for details."""
14031433
self.assertIsInstance(self.parse_one("'thomas' ~ '.*thomas.*'"), exp.Binary)

tests/dialects/test_snowflake.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1170,7 +1170,6 @@ def test_snowflake(self):
11701170
},
11711171
)
11721172
for func in (
1173-
"CORR",
11741173
"COVAR_POP",
11751174
"COVAR_SAMP",
11761175
):
@@ -4559,6 +4558,45 @@ def test_ceil(self):
45594558
},
45604559
)
45614560

4561+
def test_corr(self):
4562+
self.validate_all(
4563+
"SELECT CORR(a, b)",
4564+
read={
4565+
"snowflake": "SELECT CORR(a, b)",
4566+
"postgres": "SELECT CORR(a, b)",
4567+
},
4568+
write={
4569+
"snowflake": "SELECT CORR(a, b)",
4570+
"postgres": "SELECT CORR(a, b)",
4571+
"duckdb": "SELECT CASE WHEN ISNAN(CORR(a, b)) THEN NULL ELSE CORR(a, b) END",
4572+
},
4573+
)
4574+
self.validate_all(
4575+
"SELECT CORR(a, b) OVER (PARTITION BY c)",
4576+
read={
4577+
"snowflake": "SELECT CORR(a, b) OVER (PARTITION BY c)",
4578+
"postgres": "SELECT CORR(a, b) OVER (PARTITION BY c)",
4579+
},
4580+
write={
4581+
"snowflake": "SELECT CORR(a, b) OVER (PARTITION BY c)",
4582+
"postgres": "SELECT CORR(a, b) OVER (PARTITION BY c)",
4583+
"duckdb": "SELECT CASE WHEN ISNAN(CORR(a, b) OVER (PARTITION BY c)) THEN NULL ELSE CORR(a, b) OVER (PARTITION BY c) END",
4584+
},
4585+
)
4586+
4587+
self.validate_all(
4588+
"SELECT CORR(a, b) FILTER(WHERE c > 0)",
4589+
write={
4590+
"duckdb": "SELECT CASE WHEN ISNAN(CORR(a, b) FILTER(WHERE c > 0)) THEN NULL ELSE CORR(a, b) FILTER(WHERE c > 0) END",
4591+
},
4592+
)
4593+
self.validate_all(
4594+
"SELECT CORR(a, b) FILTER(WHERE c > 0) OVER (PARTITION BY d)",
4595+
write={
4596+
"duckdb": "SELECT CASE WHEN ISNAN(CORR(a, b) FILTER(WHERE c > 0) OVER (PARTITION BY d)) THEN NULL ELSE CORR(a, b) FILTER(WHERE c > 0) OVER (PARTITION BY d) END",
4597+
},
4598+
)
4599+
45624600
def test_encryption_functions(self):
45634601
# ENCRYPT
45644602
self.validate_identity("ENCRYPT(value, 'passphrase')")

0 commit comments

Comments
 (0)