|
3 | 3 | use std::str::FromStr; |
4 | 4 |
|
5 | 5 | use ::taos::{sync::*, RawBlock, ResultSet}; |
| 6 | +use bigdecimal::BigDecimal; |
6 | 7 | use chrono_tz::Tz; |
7 | 8 | use pyo3::prelude::*; |
8 | | -use pyo3::types::{PyDict, PyString, PyTuple}; |
| 9 | +use pyo3::types::{PyAny, PyDict, PyString, PyTuple, PyType}; |
9 | 10 | use pyo3::{create_exception, exceptions::PyException}; |
10 | 11 | use taos::taos_query; |
11 | | -use taos::taos_query::common::{SchemalessPrecision, SchemalessProtocol, SmlDataBuilder}; |
| 12 | +use taos::taos_query::common::{ |
| 13 | + views::DecimalView, SchemalessPrecision, SchemalessProtocol, SmlDataBuilder, |
| 14 | +}; |
12 | 15 | use taos::Value::{ |
13 | 16 | BigInt, Bool, Double, Float, Geometry, Int, Json, NChar, Null, SmallInt, Timestamp, TinyInt, |
14 | 17 | UBigInt, UInt, USmallInt, UTinyInt, VarBinary, VarChar, |
@@ -618,6 +621,7 @@ enum PyColumnType { |
618 | 621 | Json, |
619 | 622 | VarBinary, |
620 | 623 | Decimal, |
| 624 | + Decimal64, |
621 | 625 | Blob, |
622 | 626 | MediumBlob, |
623 | 627 | Geometry, |
@@ -935,6 +939,70 @@ fn doubles_to_column(values: Vec<Option<f64>>) -> PyColumnView { |
935 | 939 | } |
936 | 940 | } |
937 | 941 |
|
| 942 | +fn parse_decimal_values(values: Vec<Option<&PyAny>>) -> PyResult<Vec<Option<BigDecimal>>> { |
| 943 | + let mut decimals = Vec::with_capacity(values.len()); |
| 944 | + |
| 945 | + for (index, value) in values.into_iter().enumerate() { |
| 946 | + let decimal = match value { |
| 947 | + Some(value) => { |
| 948 | + let decimal_type: &PyType = value |
| 949 | + .py() |
| 950 | + .import("decimal")? |
| 951 | + .getattr("Decimal")? |
| 952 | + .downcast()?; |
| 953 | + if !value.is_instance(decimal_type)? { |
| 954 | + let type_name = value.get_type().name()?; |
| 955 | + return Err(ProgrammingError::new_err(format!( |
| 956 | + "expected decimal.Decimal or None, got {type_name} at index {index}" |
| 957 | + ))); |
| 958 | + } |
| 959 | + |
| 960 | + let decimal_str = value.str()?.to_string(); |
| 961 | + let is_finite = value.call_method0("is_finite")?.extract::<bool>()?; |
| 962 | + if !is_finite { |
| 963 | + return Err(ProgrammingError::new_err(format!( |
| 964 | + "decimal value must be finite, got '{decimal_str}' at index {index}" |
| 965 | + ))); |
| 966 | + } |
| 967 | + |
| 968 | + let fixed_decimal = value |
| 969 | + .call_method1("__format__", ("f",))? |
| 970 | + .extract::<String>()?; |
| 971 | + let parsed = BigDecimal::from_str(&fixed_decimal).map_err(|_| { |
| 972 | + ProgrammingError::new_err(format!( |
| 973 | + "failed to parse decimal value '{decimal_str}' at index {index}" |
| 974 | + )) |
| 975 | + })?; |
| 976 | + Some(parsed) |
| 977 | + } |
| 978 | + None => None, |
| 979 | + }; |
| 980 | + decimals.push(decimal); |
| 981 | + } |
| 982 | + |
| 983 | + Ok(decimals) |
| 984 | +} |
| 985 | + |
| 986 | +#[pyfunction] |
| 987 | +fn decimal64_to_column(values: Vec<Option<&PyAny>>) -> PyResult<PyColumnView> { |
| 988 | + let decimals = parse_decimal_values(values)?; |
| 989 | + let decimal_view = DecimalView::<i64>::from_decimals(decimals) |
| 990 | + .map_err(|_| ProgrammingError::new_err("decimal64 scale exceeds maximum 18"))?; |
| 991 | + Ok(PyColumnView { |
| 992 | + _inner: ColumnView::Decimal64(decimal_view), |
| 993 | + }) |
| 994 | +} |
| 995 | + |
| 996 | +#[pyfunction] |
| 997 | +fn decimal_to_column(values: Vec<Option<&PyAny>>) -> PyResult<PyColumnView> { |
| 998 | + let decimals = parse_decimal_values(values)?; |
| 999 | + let decimal_view = DecimalView::<i128>::from_decimals(decimals) |
| 1000 | + .map_err(|_| ProgrammingError::new_err("decimal scale exceeds maximum 38"))?; |
| 1001 | + Ok(PyColumnView { |
| 1002 | + _inner: ColumnView::Decimal(decimal_view), |
| 1003 | + }) |
| 1004 | +} |
| 1005 | + |
938 | 1006 | #[pyfunction] |
939 | 1007 | fn varchar_to_column(values: Vec<Option<String>>) -> PyColumnView { |
940 | 1008 | PyColumnView { |
@@ -1073,6 +1141,8 @@ fn _taosws(py: Python<'_>, m: &PyModule) -> PyResult<()> { |
1073 | 1141 | m.add_function(wrap_pyfunction!(unsigned_big_ints_to_column, m)?)?; |
1074 | 1142 | m.add_function(wrap_pyfunction!(floats_to_column, m)?)?; |
1075 | 1143 | m.add_function(wrap_pyfunction!(doubles_to_column, m)?)?; |
| 1144 | + m.add_function(wrap_pyfunction!(decimal64_to_column, m)?)?; |
| 1145 | + m.add_function(wrap_pyfunction!(decimal_to_column, m)?)?; |
1076 | 1146 | m.add_function(wrap_pyfunction!(varchar_to_column, m)?)?; |
1077 | 1147 | m.add_function(wrap_pyfunction!(nchar_to_column, m)?)?; |
1078 | 1148 | m.add_function(wrap_pyfunction!(json_to_column, m)?)?; |
|
0 commit comments