Skip to content

Commit 5fe4e11

Browse files
author
Yevgeni Litvin
committed
Parallelize encoding of a single row
Fields that contain data that is not natively supported by Parquet format, such as numpy arrays, are serialized into byte arrays. Images may be compressed using png or jpeg compression. Serializing fields on a thread pool speeds up this process in some cases (e.g. a row contains multiple images). This PR adds a pool executor argument to `dict_to_spark_row` enabling user to pass a pool executor that would be used for parallelizing this serialization. If no pool executor is specified, the encoding/serialization is performed on the caller thread.
1 parent 83a02df commit 5fe4e11

File tree

2 files changed

+75
-15
lines changed

2 files changed

+75
-15
lines changed

petastorm/tests/test_unischema.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525

2626
from petastorm.codecs import ScalarCodec, NdarrayCodec
2727
from petastorm.unischema import Unischema, UnischemaField, dict_to_spark_row, \
28-
insert_explicit_nulls, match_unischema_fields, _new_gt_255_compatible_namedtuple, _fullmatch
28+
insert_explicit_nulls, match_unischema_fields, _new_gt_255_compatible_namedtuple, _fullmatch, encode_row
29+
30+
from concurrent.futures import ThreadPoolExecutor
2931

3032
try:
3133
from unittest import mock
@@ -107,6 +109,28 @@ def test_as_spark_schema_unspecified_codec_type_unknown_scalar_type_raises():
107109
TestSchema.as_spark_schema()
108110

109111

112+
@pytest.mark.parametrize("pool_executor", [None, ThreadPoolExecutor(2)])
def test_encode_row(pool_executor):
    """Verify encode_row encodes fields according to the unischema codecs and raises
    ValueError for unknown or missing fields, both with and without a pool executor."""
    TestSchema = Unischema('TestSchema', [
        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
        UnischemaField('int8_matrix', np.int8, (2, 2), NdarrayCodec(), False),
    ])

    # A well-formed row: every schema field present; the ndarray field is encoded to bytes
    row = {'string_field': 'abc', 'int8_matrix': np.asarray([[1, 2], [3, 4]], dtype=np.int8)}
    encoded_row = encode_row(TestSchema, row, pool_executor)
    assert set(row.keys()) == set(encoded_row)
    assert isinstance(encoded_row['int8_matrix'], bytearray)

    # A field not present in the schema must be rejected by name
    extra_field_row = {'string_field': 'abc', 'int8_matrix': [[1, 2], [3, 4]], 'bogus': 'value'}
    with pytest.raises(ValueError, match='.*not found.*bogus.*'):
        encode_row(TestSchema, extra_field_row, pool_executor)

    # A missing non-nullable field must also be rejected
    extra_field_row = {'string_field': 'abc'}
    with pytest.raises(ValueError, match='int8_matrix is not found'):
        encode_row(TestSchema, extra_field_row, pool_executor)
110134
def test_dict_to_spark_row_field_validation_scalar_types():
111135
"""Test various validations done on data types when converting a dictionary to a spark row"""
112136
TestSchema = Unischema('TestSchema', [

petastorm/unischema.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def from_arrow_schema(cls, parquet_dataset, omit_unsupported_fields=False):
353353
return Unischema('inferred_schema', unischema_fields)
354354

355355

356-
def dict_to_spark_row(unischema, row_dict):
356+
def dict_to_spark_row(unischema, row_dict, pool_executor=None):
    """Converts a single row into a spark Row object.

    Verifies that the data conforms with unischema definition types and encodes the data using the codec specified
    by the unischema.

    :param unischema: an instance of Unischema object
    :param row_dict: a dictionary where the keys match name of fields in the unischema.
    :param pool_executor: if not None, encoding of row fields will be performed using the pool_executor
    :return: a single pyspark.Row object
    """

    # Lazy loading pyspark to avoid creating pyspark dependency on data reading code path
    # (currently works only with make_batch_reader)
    import pyspark

    # Encoding (validation + codec serialization) is delegated to encode_row
    encoded_dict = encode_row(unischema, row_dict, pool_executor)

    field_list = list(unischema.fields.keys())
    # generate a value list which match the schema column order.
    value_list = [encoded_dict[name] for name in field_list]
    # create a row by value list
    row = pyspark.Row(*value_list)
    # set row fields
    row.__fields__ = field_list

    return row
384+
385+
386+
def encode_row(unischema, row_dict, pool_executor=None):
    """Verifies that the data conforms with unischema definition types and encodes the data using the codec specified
    by the unischema.

    :param unischema: an instance of Unischema object
    :param row_dict: a dictionary where the keys match name of fields in the unischema.
    :param pool_executor: if not None, encoding of row fields will be performed using the pool_executor
    :return: a dictionary of encoded fields
    :raises ValueError: if row_dict contains fields unknown to the unischema, or a None value
        is passed for a non-nullable field
    """
    assert isinstance(unischema, Unischema)
    # Add null fields. Be careful not to mutate the input dictionary - that would be an unexpected side effect
    copy_row_dict = copy.copy(row_dict)
    insert_explicit_nulls(unischema, copy_row_dict)

    input_field_names = set(copy_row_dict.keys())
    unischema_field_names = set(unischema.fields.keys())

    unknown_field_names = input_field_names - unischema_field_names

    if unknown_field_names:
        raise ValueError('Following fields of row_dict are not found in '
                         'unischema: {}'.format(', '.join(sorted(unknown_field_names))))

    encoded_dict = dict()
    futures_dict = dict()
    for field_name, value in copy_row_dict.items():
        schema_field = unischema.fields[field_name]
        if value is None:
            if not schema_field.nullable:
                # Fixed: the original message left the '{}' placeholder unformatted
                raise ValueError('Field {} is not "nullable", but got passed a None value'.format(field_name))
        if schema_field.codec:
            if value is None:
                encoded_dict[field_name] = None
            else:
                if pool_executor:
                    # Arguments are bound at submit time, so the encode call can be
                    # submitted directly without a lambda wrapper
                    futures_dict[field_name] = pool_executor.submit(
                        schema_field.codec.encode, schema_field, value)
                else:
                    encoded_dict[field_name] = schema_field.codec.encode(schema_field, value)
        else:
            if isinstance(value, (np.generic,)):
                # Convert numpy scalars to native Python values
                encoded_dict[field_name] = value.tolist()
            else:
                encoded_dict[field_name] = value

    # Wait for and collect results of any asynchronously-submitted encodings
    for k, v in futures_dict.items():
        encoded_dict[k] = v.result()

    return encoded_dict
404440

405441

406442
def insert_explicit_nulls(unischema, row_dict):

0 commit comments

Comments
 (0)