Open
Labels
type:support (Further information is requested)
Description
I am trying to do on-the-fly augmentation using the dataset.
train_dataset = (grain.MapDataset.source(source)
.map(augment())
.shuffle(seed=412)
.repeat(num_epochs=config.n_epochs)
.batch(batch_size=config.batch_size, drop_remainder=True)
)
train_dataset = grain.experimental.ThreadPrefetchIterDataset(train_dataset, prefetch_buffer_size=64)
train_dataset = train_dataset.map(jax.device_put)

augment() requires random seeds, but train_dataset will be prefetched in parallel threads, and I am not sure how to pass the random seed. Any suggestions?
P.S. I am also not sure that the shuffle is done correctly. Please correct me if I'm wrong.
Thanks a lot!
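
To make the seeding concern concrete, here is a minimal, library-agnostic sketch of one common approach: derive a private RNG for each element from a base seed plus the element's index, so the result does not depend on which thread handles the element or in what order. It uses only the Python standard library and does not use Grain's API. The names `BASE_SEED`, `augment`, and `run`, as well as the toy "add a small jitter" augmentation, are hypothetical and only for illustration.

```python
import random
from concurrent.futures import ThreadPoolExecutor

BASE_SEED = 412  # hypothetical base seed, chosen to match the shuffle seed above


def augment(index, value, base_seed=BASE_SEED):
    """Toy augmentation: add a small random jitter to `value`.

    The RNG is created per call and seeded from (base_seed, index), so the
    output depends only on those two numbers, not on thread scheduling.
    """
    rng = random.Random(f"{base_seed}:{index}")  # string seeds are deterministic
    return value + rng.uniform(-0.1, 0.1)


def run(num_workers):
    """Augment a small toy dataset with `num_workers` threads."""
    data = [float(i) for i in range(16)]
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # pool.map returns results in input order, whatever the execution order.
        return list(pool.map(augment, range(len(data)), data))


if __name__ == "__main__":
    # One thread and four threads give identical augmented values.
    print(run(1) == run(4))
```

If the index keeps counting across epochs instead of restarting at zero, each epoch would see different augmentations for the same example while staying reproducible for a fixed base seed. Whether the pipeline above exposes such an index to `augment()` is an assumption that would need checking against the Grain documentation.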