Skip to content

Commit 11a0cbb

Browse files
committed
dict writer to expose codes ptype
Signed-off-by: Onur Satici <[email protected]>
1 parent 9686835 commit 11a0cbb

File tree

4 files changed

+56
-20
lines changed

4 files changed

+56
-20
lines changed

vortex-array/src/builders/dict/bytes.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use vortex_buffer::BitBufferMut;
99
use vortex_buffer::BufferMut;
1010
use vortex_buffer::ByteBufferMut;
1111
use vortex_dtype::DType;
12+
use vortex_dtype::PType;
1213
use vortex_dtype::UnsignedPType;
1314
use vortex_error::VortexExpect;
1415
use vortex_error::vortex_panic;
@@ -195,6 +196,10 @@ impl<Code: UnsignedPType> DictEncoder for BytesDictBuilder<Code> {
195196
.into_array()
196197
}
197198
}
199+
200+
fn codes_ptype(&self) -> PType {
201+
Code::PTYPE
202+
}
198203
}
199204

200205
#[cfg(test)]

vortex-array/src/builders/dict/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
use bytes::bytes_dict_builder;
55
use primitive::primitive_dict_builder;
6+
use vortex_dtype::PType;
67
use vortex_dtype::match_each_native_ptype;
78
use vortex_error::VortexResult;
89
use vortex_error::vortex_bail;
@@ -37,6 +38,9 @@ pub trait DictEncoder: Send {
3738

3839
/// Clear the encoder state to make it ready for a new round of encoding.
3940
fn reset(&mut self) -> ArrayRef;
41+
42+
/// Returns the PType of the codes this encoder produces.
43+
fn codes_ptype(&self) -> PType;
4044
}
4145

4246
pub fn dict_encoder(array: &dyn Array, constraints: &DictConstraints) -> Box<dyn DictEncoder> {

vortex-array/src/builders/dict/primitive.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use vortex_buffer::BitBufferMut;
99
use vortex_buffer::BufferMut;
1010
use vortex_dtype::NativePType;
1111
use vortex_dtype::Nullability;
12+
use vortex_dtype::PType;
1213
use vortex_dtype::UnsignedPType;
1314
use vortex_error::vortex_panic;
1415
use vortex_utils::aliases::hash_map::Entry;
@@ -145,6 +146,10 @@ where
145146
)
146147
.into_array()
147148
}
149+
150+
fn codes_ptype(&self) -> PType {
151+
Code::PTYPE
152+
}
148153
}
149154

150155
#[cfg(test)]

vortex-layout/src/layouts/dict/writer.rs

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ use vortex_array::builders::dict::DictEncoder;
2727
use vortex_array::builders::dict::dict_encoder;
2828
use vortex_btrblocks::BtrBlocksCompressor;
2929
use vortex_dtype::DType;
30+
use vortex_dtype::Nullability;
3031
use vortex_dtype::PType;
3132
use vortex_error::VortexError;
33+
use vortex_error::VortexExpect;
3234
use vortex_error::VortexResult;
3335
use vortex_error::vortex_err;
3436
use vortex_io::kanal_ext::KanalExt;
@@ -61,9 +63,9 @@ pub struct DictLayoutConstraints {
6163
/// Maximum dictionary length. Limited to `u16` because dictionaries with more than 64k unique
6264
/// values provide diminishing compression returns given typical chunk sizes (~8k elements).
6365
///
64-
/// The codes dtype is chosen dynamically based on the actual dictionary size:
65-
/// - [`PType::U8`] when the dictionary has at most 255 entries
66-
/// - [`PType::U16`] when the dictionary has more than 255 entries
66+
/// The codes dtype is determined upfront from this constraint:
67+
/// - [`PType::U8`] when `max_len <= 255`
68+
/// - [`PType::U16`] when `max_len > 255`
6769
pub max_len: u16,
6870
}
6971

@@ -237,7 +239,11 @@ impl LayoutStrategy for DictStrategy {
237239
}
238240

239241
enum DictionaryChunk {
240-
Codes((SequenceId, ArrayRef)),
242+
Codes {
243+
seq_id: SequenceId,
244+
codes: ArrayRef,
245+
codes_ptype: PType,
246+
},
241247
Values((SequenceId, ArrayRef)),
242248
}
243249

@@ -299,26 +305,33 @@ impl DictStreamState {
299305
match self.encoder.take() {
300306
None => match start_encoding(&self.constraints, &remaining) {
301307
EncodingState::Continue((encoder, encoded)) => {
302-
res.push(labeler.codes(encoded));
308+
let ptype = encoder.codes_ptype();
309+
res.push(labeler.codes(encoded, ptype));
303310
self.encoder = Some(encoder);
304311
}
305312
EncodingState::Done((values, encoded, unencoded)) => {
306-
res.push(labeler.codes(encoded));
313+
// The encoder was created and consumed within `start_encoding`, so recover the codes ptype from the encoded array's dtype.
314+
let ptype = PType::try_from(encoded.dtype())
315+
.vortex_expect("codes should be primitive");
316+
res.push(labeler.codes(encoded, ptype));
307317
res.push(labeler.values(values));
308318
to_be_encoded = Some(unencoded);
309319
}
310320
},
311-
Some(encoder) => match encode_chunk(encoder, &remaining) {
312-
EncodingState::Continue((encoder, encoded)) => {
313-
res.push(labeler.codes(encoded));
314-
self.encoder = Some(encoder);
315-
}
316-
EncodingState::Done((values, encoded, unencoded)) => {
317-
res.push(labeler.codes(encoded));
318-
res.push(labeler.values(values));
319-
to_be_encoded = Some(unencoded);
321+
Some(encoder) => {
322+
let ptype = encoder.codes_ptype();
323+
match encode_chunk(encoder, &remaining) {
324+
EncodingState::Continue((encoder, encoded)) => {
325+
res.push(labeler.codes(encoded, ptype));
326+
self.encoder = Some(encoder);
327+
}
328+
EncodingState::Done((values, encoded, unencoded)) => {
329+
res.push(labeler.codes(encoded, ptype));
330+
res.push(labeler.values(values));
331+
to_be_encoded = Some(unencoded);
332+
}
320333
}
321-
},
334+
}
322335
}
323336
}
324337
res
@@ -342,8 +355,12 @@ impl DictChunkLabeler {
342355
Self { sequence_pointer }
343356
}
344357

345-
fn codes(&mut self, chunk: ArrayRef) -> DictionaryChunk {
346-
DictionaryChunk::Codes((self.sequence_pointer.advance(), chunk))
358+
fn codes(&mut self, chunk: ArrayRef, ptype: PType) -> DictionaryChunk {
359+
DictionaryChunk::Codes {
360+
seq_id: self.sequence_pointer.advance(),
361+
codes: chunk,
362+
codes_ptype: ptype,
363+
}
347364
}
348365

349366
fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
@@ -398,7 +415,11 @@ impl Stream for DictionaryTransformer {
398415
}
399416

400417
match self.input.poll_next_unpin(cx) {
401-
Poll::Ready(Some(Ok(DictionaryChunk::Codes((seq_id, codes))))) => {
418+
Poll::Ready(Some(Ok(DictionaryChunk::Codes {
419+
seq_id,
420+
codes,
421+
codes_ptype,
422+
}))) => {
402423
if self.active_codes_tx.is_none() {
403424
// Start a new group
404425
let (codes_tx, codes_rx) = kanal::bounded_async::<SequencedChunk>(1);
@@ -407,7 +428,8 @@ impl Stream for DictionaryTransformer {
407428
self.active_codes_tx = Some(codes_tx.clone());
408429
self.active_values_tx = Some(values_tx);
409430

410-
let codes_dtype = codes.dtype().clone();
431+
// Use the `codes_ptype` carried by the chunk instead of reading the dtype from the codes array.
432+
let codes_dtype = DType::Primitive(codes_ptype, Nullability::NonNullable);
411433

412434
// Send first codes.
413435
self.pending_send =

0 commit comments

Comments
 (0)