Skip to content

Commit 7016895

Browse files
committed
feat: add support for decimal and decimal64 types in taosws with corresponding tests
1 parent 6a6c5ae commit 7016895

File tree

4 files changed

+154
-3
lines changed

4 files changed

+154
-3
lines changed

taos-ws-py/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

taos-ws-py/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ crate-type = ["cdylib"]
1111

1212
[dependencies]
1313
anyhow = "1"
14+
bigdecimal = "0.4"
1415
chrono = "0.4"
1516
chrono-tz = "0.10.4"
1617
iana-time-zone = "0.1.63"

taos-ws-py/src/lib.rs

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
use std::str::FromStr;
44

55
use ::taos::{sync::*, RawBlock, ResultSet};
6+
use bigdecimal::BigDecimal;
67
use chrono_tz::Tz;
78
use pyo3::prelude::*;
8-
use pyo3::types::{PyDict, PyString, PyTuple};
9+
use pyo3::types::{PyAny, PyDict, PyString, PyTuple, PyType};
910
use pyo3::{create_exception, exceptions::PyException};
1011
use taos::taos_query;
11-
use taos::taos_query::common::{SchemalessPrecision, SchemalessProtocol, SmlDataBuilder};
12+
use taos::taos_query::common::{
13+
views::DecimalView, SchemalessPrecision, SchemalessProtocol, SmlDataBuilder,
14+
};
1215
use taos::Value::{
1316
BigInt, Bool, Double, Float, Geometry, Int, Json, NChar, Null, SmallInt, Timestamp, TinyInt,
1417
UBigInt, UInt, USmallInt, UTinyInt, VarBinary, VarChar,
@@ -618,6 +621,7 @@ enum PyColumnType {
618621
Json,
619622
VarBinary,
620623
Decimal,
624+
Decimal64,
621625
Blob,
622626
MediumBlob,
623627
Geometry,
@@ -935,6 +939,70 @@ fn doubles_to_column(values: Vec<Option<f64>>) -> PyColumnView {
935939
}
936940
}
937941

942+
fn parse_decimal_values(values: Vec<Option<&PyAny>>) -> PyResult<Vec<Option<BigDecimal>>> {
943+
let mut decimals = Vec::with_capacity(values.len());
944+
945+
for (index, value) in values.into_iter().enumerate() {
946+
let decimal = match value {
947+
Some(value) => {
948+
let decimal_type: &PyType = value
949+
.py()
950+
.import("decimal")?
951+
.getattr("Decimal")?
952+
.downcast()?;
953+
if !value.is_instance(decimal_type)? {
954+
let type_name = value.get_type().name()?;
955+
return Err(ProgrammingError::new_err(format!(
956+
"expected decimal.Decimal or None, got {type_name} at index {index}"
957+
)));
958+
}
959+
960+
let decimal_str = value.str()?.to_string();
961+
let is_finite = value.call_method0("is_finite")?.extract::<bool>()?;
962+
if !is_finite {
963+
return Err(ProgrammingError::new_err(format!(
964+
"decimal value must be finite, got '{decimal_str}' at index {index}"
965+
)));
966+
}
967+
968+
let fixed_decimal = value
969+
.call_method1("__format__", ("f",))?
970+
.extract::<String>()?;
971+
let parsed = BigDecimal::from_str(&fixed_decimal).map_err(|_| {
972+
ProgrammingError::new_err(format!(
973+
"failed to parse decimal value '{decimal_str}' at index {index}"
974+
))
975+
})?;
976+
Some(parsed)
977+
}
978+
None => None,
979+
};
980+
decimals.push(decimal);
981+
}
982+
983+
Ok(decimals)
984+
}
985+
986+
#[pyfunction]
987+
fn decimal64_to_column(values: Vec<Option<&PyAny>>) -> PyResult<PyColumnView> {
988+
let decimals = parse_decimal_values(values)?;
989+
let decimal_view = DecimalView::<i64>::from_decimals(decimals)
990+
.map_err(|_| ProgrammingError::new_err("decimal64 scale exceeds maximum 18"))?;
991+
Ok(PyColumnView {
992+
_inner: ColumnView::Decimal64(decimal_view),
993+
})
994+
}
995+
996+
#[pyfunction]
997+
fn decimal_to_column(values: Vec<Option<&PyAny>>) -> PyResult<PyColumnView> {
998+
let decimals = parse_decimal_values(values)?;
999+
let decimal_view = DecimalView::<i128>::from_decimals(decimals)
1000+
.map_err(|_| ProgrammingError::new_err("decimal scale exceeds maximum 38"))?;
1001+
Ok(PyColumnView {
1002+
_inner: ColumnView::Decimal(decimal_view),
1003+
})
1004+
}
1005+
9381006
#[pyfunction]
9391007
fn varchar_to_column(values: Vec<Option<String>>) -> PyColumnView {
9401008
PyColumnView {
@@ -1073,6 +1141,8 @@ fn _taosws(py: Python<'_>, m: &PyModule) -> PyResult<()> {
10731141
m.add_function(wrap_pyfunction!(unsigned_big_ints_to_column, m)?)?;
10741142
m.add_function(wrap_pyfunction!(floats_to_column, m)?)?;
10751143
m.add_function(wrap_pyfunction!(doubles_to_column, m)?)?;
1144+
m.add_function(wrap_pyfunction!(decimal64_to_column, m)?)?;
1145+
m.add_function(wrap_pyfunction!(decimal_to_column, m)?)?;
10761146
m.add_function(wrap_pyfunction!(varchar_to_column, m)?)?;
10771147
m.add_function(wrap_pyfunction!(nchar_to_column, m)?)?;
10781148
m.add_function(wrap_pyfunction!(json_to_column, m)?)?;

taos-ws-py/tests/test_stmt2.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import time
2-
import taosws
2+
from decimal import Decimal
33

4+
import pytest
5+
import taosws
46

57
url = "taosws://localhost:6041/"
68

@@ -96,3 +98,80 @@ def test_stmt2_stable():
9698
assert rows == 4
9799
stmt2_query(conn, "select * from stb1 where a > ?")
98100
after_test(db_name)
101+
102+
103+
def test_decimal_column_rejects_non_decimal_type():
104+
with pytest.raises(taosws.ProgrammingError, match=r"expected decimal\.Decimal or None, got int at index 0"):
105+
taosws.decimal64_to_column([1])
106+
107+
108+
def test_decimal_column_rejects_non_finite_value():
109+
with pytest.raises(taosws.ProgrammingError, match=r"decimal value must be finite, got 'NaN' at index 0"):
110+
taosws.decimal_to_column([Decimal("NaN")])
111+
112+
113+
def test_decimal64_column_rejects_scale_overflow():
114+
with pytest.raises(taosws.ProgrammingError, match=r"decimal64 scale exceeds maximum 18"):
115+
taosws.decimal64_to_column([Decimal("0.1234567890123456789")])
116+
117+
118+
def test_decimal_column_rejects_scale_overflow():
119+
with pytest.raises(taosws.ProgrammingError, match=r"decimal scale exceeds maximum 38"):
120+
taosws.decimal_to_column([Decimal("0.123456789012345678901234567890123456789")])
121+
122+
123+
def test_stmt2_decimal():
124+
db_name = "test_1774702104"
125+
conn = taosws.connect()
126+
try:
127+
conn.execute(f"drop database if exists {db_name}")
128+
conn.execute(f"create database {db_name}")
129+
conn.execute(f"use {db_name}")
130+
conn.execute("create table t_decimal (ts timestamp, d64 decimal(10,2), d128 decimal(20,10))")
131+
132+
stmt2 = conn.stmt2_statement()
133+
stmt2.prepare("insert into t_decimal values (?, ?, ?)")
134+
param = taosws.stmt2_bind_param_view(
135+
table_name="",
136+
tags=None,
137+
columns=[
138+
taosws.millis_timestamps_to_column([1726803356466, 1726803356467, 1726803356468]),
139+
taosws.decimal64_to_column([Decimal("99.9876"), Decimal("1.0234"), None]),
140+
taosws.decimal_to_column(
141+
[Decimal("1234567890.1234567890"), Decimal("1.23E+5"), Decimal("0.1234567890123")]
142+
),
143+
],
144+
)
145+
stmt2.bind([param])
146+
rows = stmt2.execute()
147+
assert rows == 3
148+
149+
taosws.decimal64_to_column([])
150+
taosws.decimal_to_column([])
151+
152+
try:
153+
stmt2.prepare("select d64, d128 from t_decimal where ts >= ? order by ts")
154+
query_param = taosws.stmt2_bind_param_view(
155+
table_name="",
156+
tags=None,
157+
columns=[taosws.millis_timestamps_to_column([1726803356466])],
158+
)
159+
stmt2.bind([query_param])
160+
stmt2.execute()
161+
data = [row for row in stmt2.result_set()]
162+
except (taosws.QueryError, taosws.OperationalError) as err:
163+
if "[0x073A]" in str(err):
164+
pytest.skip("current ws environment cannot query DECIMAL rows: query memory exhausted")
165+
raise
166+
167+
assert len(data) == 3
168+
assert Decimal(data[0][0]) == Decimal("99.99")
169+
assert Decimal(data[0][1]) == Decimal("1234567890.1234567890")
170+
assert Decimal(data[1][0]) == Decimal("1.02")
171+
assert Decimal(data[1][1]) == Decimal("123000")
172+
assert data[2][0] is None
173+
assert Decimal(data[2][1]) == Decimal("0.1234567890")
174+
175+
finally:
176+
conn.execute(f"drop database if exists {db_name}")
177+
conn.close()

0 commit comments

Comments
 (0)