@@ -658,61 +658,6 @@ def parse_args(input_args=None):
658658 return args
659659
660660
661- def get_train_dataset (args , accelerator ):
662- dataset = None
663- if args .dataset_name is not None :
664- # Downloading and loading a dataset from the hub.
665- dataset = load_dataset (
666- args .dataset_name ,
667- args .dataset_config_name ,
668- cache_dir = args .cache_dir ,
669- )
670- if args .jsonl_for_train is not None :
671- # load from json
672- dataset = load_dataset ("json" , data_files = args .jsonl_for_train , cache_dir = args .cache_dir )
673- dataset = dataset .flatten_indices ()
674- # Preprocessing the datasets.
675- # We need to tokenize inputs and targets.
676- column_names = dataset ["train" ].column_names
677-
678- # 6. Get the column names for input/target.
679- if args .image_column is None :
680- image_column = column_names [0 ]
681- logger .info (f"image column defaulting to { image_column } " )
682- else :
683- image_column = args .image_column
684- if image_column not in column_names :
685- raise ValueError (
686- f"`--image_column` value '{ args .image_column } ' not found in dataset columns. Dataset columns are: { ', ' .join (column_names )} "
687- )
688-
689- if args .caption_column is None :
690- caption_column = column_names [1 ]
691- logger .info (f"caption column defaulting to { caption_column } " )
692- else :
693- caption_column = args .caption_column
694- if caption_column not in column_names :
695- raise ValueError (
696- f"`--caption_column` value '{ args .caption_column } ' not found in dataset columns. Dataset columns are: { ', ' .join (column_names )} "
697- )
698-
699- if args .conditioning_image_column is None :
700- conditioning_image_column = column_names [2 ]
701- logger .info (f"conditioning image column defaulting to { conditioning_image_column } " )
702- else :
703- conditioning_image_column = args .conditioning_image_column
704- if conditioning_image_column not in column_names :
705- raise ValueError (
706- f"`--conditioning_image_column` value '{ args .conditioning_image_column } ' not found in dataset columns. Dataset columns are: { ', ' .join (column_names )} "
707- )
708-
709- with accelerator .main_process_first ():
710- train_dataset = dataset ["train" ].shuffle (seed = args .seed )
711- if args .max_train_samples is not None :
712- train_dataset = train_dataset .select (range (args .max_train_samples ))
713- return train_dataset
714-
715-
716661class TrainRemovalDataset (torch .utils .data .Dataset ):
717662 def __init__ (self ,
718663 data_root ,
@@ -779,7 +724,7 @@ def prepare_train_dataset(dataset, accelerator):
779724 )
780725
781726 crop_transform = PairedRandomCrop (size = args .resolution )
782-
727+
783728 def preprocess_train (examples ):
784729 images = examples ["images" ].convert ("RGB" ) if not isinstance (examples ["images" ], str ) else Image .open (examples ["images" ]).convert ("RGB" )
785730 backgrounds = examples ["background" ].convert ("RGB" ) if not isinstance (examples ["background" ], str ) else Image .open (examples ["background" ]).convert ("RGB" )
0 commit comments