Skip to content

Commit 7b4e898

Browse files
author
Wei Runpu
committed
clean code
1 parent 2d2ffac commit 7b4e898

File tree

1 file changed

+1
-56
lines changed

1 file changed

+1
-56
lines changed

train_control_lora_flux.py

Lines changed: 1 addition & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -658,61 +658,6 @@ def parse_args(input_args=None):
658658
return args
659659

660660

661-
def _resolve_column(column_names, requested, default_index, arg_name, description):
    """Return the dataset column to use for one input, validating it exists.

    Args:
        column_names: Column names of the dataset's "train" split.
        requested: Column name supplied on the command line, or None.
        default_index: Position in ``column_names`` used when ``requested`` is None.
        arg_name: Name of the CLI argument (without leading dashes), used in errors.
        description: Human-readable column role, used in the log message.

    Raises:
        ValueError: If ``requested`` is not a dataset column, or if the dataset
            has too few columns for the positional default.
    """
    if requested is None:
        if default_index >= len(column_names):
            # The positional default would otherwise fail with a bare IndexError.
            raise ValueError(
                f"`--{arg_name}` was not given and the dataset has only {len(column_names)} column(s), "
                f"so no default can be chosen. Dataset columns are: {', '.join(column_names)}"
            )
        column = column_names[default_index]
        logger.info(f"{description} column defaulting to {column}")
        return column

    if requested not in column_names:
        raise ValueError(
            f"`--{arg_name}` value '{requested}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
        )
    return requested


def get_train_dataset(args, accelerator):
    """Load the training dataset, validate its columns, and shuffle it.

    The dataset is loaded from ``args.dataset_name`` and/or from the JSON(L)
    file(s) in ``args.jsonl_for_train``; when both are given, the JSON source
    replaces the one loaded by name (same precedence as before).

    Args:
        args: Parsed arguments. Reads ``dataset_name``, ``dataset_config_name``,
            ``cache_dir``, ``jsonl_for_train``, ``image_column``,
            ``caption_column``, ``conditioning_image_column``, ``seed`` and
            ``max_train_samples``.
        accelerator: Object providing the ``main_process_first()`` context manager.

    Returns:
        The shuffled "train" split, truncated to ``args.max_train_samples``
        entries when that value is not None.

    Raises:
        ValueError: If no dataset source was given, or a required column
            cannot be resolved.
    """
    dataset = None
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    if args.jsonl_for_train is not None:
        # Load from json; this takes precedence over `dataset_name`.
        dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir)
        dataset = dataset.flatten_indices()

    if dataset is None:
        # Without this guard the failure is an opaque
        # "'NoneType' object is not subscriptable" on the next line.
        raise ValueError("No training data specified: provide either `--dataset_name` or `--jsonl_for_train`.")

    column_names = dataset["train"].column_names

    # Validate the image / caption / conditioning-image columns. The resolved
    # names are not returned; the calls exist for validation and logging.
    _resolve_column(column_names, args.image_column, 0, "image_column", "image")
    _resolve_column(column_names, args.caption_column, 1, "caption_column", "caption")
    _resolve_column(
        column_names,
        args.conditioning_image_column,
        2,
        "conditioning_image_column",
        "conditioning image",
    )

    with accelerator.main_process_first():
        train_dataset = dataset["train"].shuffle(seed=args.seed)
        if args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(args.max_train_samples))
    return train_dataset
714-
715-
716661
class TrainRemovalDataset(torch.utils.data.Dataset):
717662
def __init__(self,
718663
data_root,
@@ -779,7 +724,7 @@ def prepare_train_dataset(dataset, accelerator):
779724
)
780725

781726
crop_transform = PairedRandomCrop(size=args.resolution)
782-
727+
783728
def preprocess_train(examples):
784729
images = examples["images"].convert("RGB") if not isinstance(examples["images"], str) else Image.open(examples["images"]).convert("RGB")
785730
backgrounds = examples["background"].convert("RGB") if not isinstance(examples["background"], str) else Image.open(examples["background"]).convert("RGB")

0 commit comments

Comments
 (0)