random and augmentation in datasets #1102

@pcasl

I am trying to do on-the-fly augmentation in a Grain dataset pipeline:

train_dataset = (grain.MapDataset.source(source)
                     .map(augment())
                     .shuffle(seed=412)
                     .repeat(num_epochs=config.n_epochs)
                     .batch(batch_size=config.batch_size, drop_remainder=True)
                    )
train_dataset = grain.experimental.ThreadPrefetchIterDataset(train_dataset, prefetch_buffer_size=64)
train_dataset = train_dataset.map(jax.device_put)

augment() requires random seeds, but train_dataset is prefetched in parallel threads, and I am not sure how to pass the random seed to it. Any suggestions?
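
To make the concern concrete, here is a minimal sketch of one possible approach, in plain NumPy rather than Grain's API: derive each record's random generator from a base seed plus the record's index, so the result does not depend on which thread processes the record or in what order. The `augment` body (random flip plus noise), `BASE_SEED`, and the assumption that the record index is available to the function are all hypothetical; how the index would reach the function inside a Grain pipeline is not shown. Grain's MapDataset also appears to offer a `random_map(..., seed=...)` transform intended for seeded transforms; the exact signature should be checked against the Grain documentation.

```python
import numpy as np

BASE_SEED = 412  # hypothetical; reuses the shuffle seed from the pipeline above


def augment(record: np.ndarray, index: int, base_seed: int = BASE_SEED) -> np.ndarray:
    """Hypothetical augmentation: random horizontal flip plus small Gaussian noise.

    The generator is seeded from (base_seed, index), so the output for a given
    record is the same no matter which thread runs it or when.
    """
    rng = np.random.default_rng([base_seed, index])
    out = record.astype(np.float32)
    if rng.random() < 0.5:
        out = out[..., ::-1]  # flip along the last axis
    noise = rng.normal(0.0, 0.01, size=out.shape).astype(np.float32)
    return out + noise


# Processing records in a different order gives identical per-record results.
records = [np.arange(6).reshape(2, 3) for _ in range(4)]
forward = [augment(r, i) for i, r in enumerate(records)]
backward = {i: augment(records[i], i) for i in reversed(range(len(records)))}
assert all(np.array_equal(forward[i], backward[i]) for i in range(len(records)))
```

With `repeat`, the same index would produce the same augmentation in every epoch unless the index keeps counting across epochs or the epoch number is added to the seed list.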

P.S. I am also not sure that the shuffle is done correctly. Please correct me if I'm wrong.

Thanks a lot!
