@@ -353,7 +353,7 @@ def from_arrow_schema(cls, parquet_dataset, omit_unsupported_fields=False):
353353 return Unischema ('inferred_schema' , unischema_fields )
354354
355355
def dict_to_spark_row(unischema, row_dict, pool_executor=None):
    """Converts a single row into a spark Row object.

    Verifies that the data confirms with unischema definition types and encodes the data
    using the codec specified by the unischema.

    :param unischema: an instance of Unischema object
    :param row_dict: a dictionary where the keys match name of fields in the unischema.
    :param pool_executor: if not None, encoding of row fields will be performed using the pool_executor
    :return: a single pyspark.Row object
    """
    # Lazy loading pyspark to avoid creating pyspark dependency on data reading code path
    # (currently works only with make_batch_reader)
    import pyspark

    encoded = encode_row(unischema, row_dict, pool_executor)

    # Values must follow the schema's column order when positionally passed to Row.
    column_names = list(unischema.fields.keys())
    ordered_values = [encoded[column] for column in column_names]

    row = pyspark.Row(*ordered_values)
    # Attach the field names explicitly so the Row carries the schema's column names.
    row.__fields__ = column_names
    return row
384+
385+
def encode_row(unischema, row_dict, pool_executor=None):
    """Verifies that the data confirms with unischema definition types and encodes the data
    using the codec specified by the unischema.

    :param unischema: an instance of Unischema object
    :param row_dict: a dictionary where the keys match name of fields in the unischema.
    :param pool_executor: if not None, encoding of row fields will be performed using the pool_executor
    :return: a dictionary of encoded fields
    :raises ValueError: if row_dict contains field names unknown to the unischema, or if a
        None value is passed for a field that is not nullable.
    """
    assert isinstance(unischema, Unischema)
    # Add null fields. Be careful not to mutate the input dictionary - that would be an unexpected side effect
    copy_row_dict = copy.copy(row_dict)
    insert_explicit_nulls(unischema, copy_row_dict)

    unknown_field_names = set(copy_row_dict.keys()) - set(unischema.fields.keys())
    if unknown_field_names:
        raise ValueError('Following fields of row_dict are not found in '
                         'unischema: {}'.format(', '.join(sorted(unknown_field_names))))

    encoded_dict = {}
    futures_dict = {}
    for field_name, value in copy_row_dict.items():
        schema_field = unischema.fields[field_name]
        if value is None and not schema_field.nullable:
            # Bug fix: the message previously contained a bare '{}' placeholder that was
            # never formatted with the offending field name.
            raise ValueError('Field {} is not "nullable", but got passed a None value'.format(field_name))
        if schema_field.codec:
            if value is None:
                encoded_dict[field_name] = None
            elif pool_executor:
                # codec.encode already takes (field, value) positionally, so it can be
                # submitted directly without a wrapping lambda.
                futures_dict[field_name] = pool_executor.submit(
                    schema_field.codec.encode, schema_field, value)
            else:
                encoded_dict[field_name] = schema_field.codec.encode(schema_field, value)
        elif isinstance(value, (np.generic,)):
            # Convert numpy scalars into native Python values.
            encoded_dict[field_name] = value.tolist()
        else:
            encoded_dict[field_name] = value

    # Collect the results of any asynchronously submitted encodings.
    for field_name, future in futures_dict.items():
        encoded_dict[field_name] = future.result()

    return encoded_dict
404440
405441
406442def insert_explicit_nulls (unischema , row_dict ):
0 commit comments