Documentation for the classes and methods of SPAGHETTI. For installation details, see README.md. For a tutorial with sample usage, see the ./tutorials/ directory.
Set up the dataset to train SPAGHETTI
The dataset class for SPAGHETTI model training. This class inherits from `torch.utils.data.Dataset`.
- args:
  - `path_1`: list of strings, the paths to the images in domain 1
  - `path_2`: list of strings, the paths to the images in domain 2
  - `transform_1`: the transformation for domain 1 images, from `torchvision.transforms.v2`
  - `transform_2`: the transformation for domain 2 images, from `torchvision.transforms.v2`
  - `num_sample`: int, optional, the number of images to sample from each domain
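The arguments above can be illustrated with a small, dependency-free sketch of a two-domain, unpaired dataset. Everything here is hypothetical: the class name `UnpairedPathsSketch`, the `loader` and `seed` parameters, the sampling without replacement for `num_sample`, and the cyclic pairing of indices are assumptions made for illustration, not the library's implementation. A real version would subclass `torch.utils.data.Dataset` and load each path as an image.

```python
import random


class UnpairedPathsSketch:
    """Hypothetical stand-in for a two-domain, unpaired image dataset.

    `loader` is injected so the sketch runs without any image files;
    by default it simply returns the path itself.
    """

    def __init__(self, path_1, path_2, transform_1=None, transform_2=None,
                 num_sample=None, loader=lambda p: p, seed=0):
        rng = random.Random(seed)
        # Assumption: num_sample draws that many paths from each domain
        # without replacement, capped at the domain size.
        if num_sample is not None:
            path_1 = rng.sample(path_1, min(num_sample, len(path_1)))
            path_2 = rng.sample(path_2, min(num_sample, len(path_2)))
        self.path_1, self.path_2 = list(path_1), list(path_2)
        self.transform_1, self.transform_2 = transform_1, transform_2
        self.loader = loader

    def __len__(self):
        # Assumption: one epoch covers the larger of the two domains.
        return max(len(self.path_1), len(self.path_2))

    def __getitem__(self, idx):
        # Wrap the index so the smaller domain is reused cyclically.
        img_1 = self.loader(self.path_1[idx % len(self.path_1)])
        img_2 = self.loader(self.path_2[idx % len(self.path_2)])
        if self.transform_1 is not None:
            img_1 = self.transform_1(img_1)
        if self.transform_2 is not None:
            img_2 = self.transform_2(img_2)
        return img_1, img_2


ds = UnpairedPathsSketch(
    ["a/0.png", "a/1.png", "a/2.png"],
    ["b/0.png", "b/1.png"],
    transform_1=str.upper,  # placeholder for a torchvision.transforms.v2 pipeline
)
print(len(ds), ds[2])
```

Because the two domains are unpaired, the sketch returns one item from each domain per index rather than matching files by name. How the actual class pairs or shuffles the domains is not stated in this documentation.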
Modules for training SPAGHETTI
The function to automatically handle all SPAGHETTI training using pytorch_lightning.
- args:
  - `train_loader`: the PyTorch DataLoader for the training dataset
  - `val_loader`: the PyTorch DataLoader for the validation dataset
  - `batch_size`: int, the batch size for the model. Default 1
  - `weights`: list of floats, the weights for the loss functions in the order of GAN loss, cycle loss, identity loss, and SSIM loss. Default [1.0, 10.0, 5.0, 10.0]
  - `lr`: float, the learning rate for the model. Default 0.0002
  - `save_dir`: str, the directory in which to save the model checkpoints and logs. Default: the current directory
  - `epochs`: int, the number of epochs to train the model. Default 100
  - `name`: str, the name of the model for the logger. Default "my_spaghetti"
  - `num_nodes`: int, the number of nodes used to train the model. Default 1
  - `ngpus_per_node`: int, the number of GPUs per node. Default "auto", which uses all available GPUs
- returns: None
- side effects:
Runs the training and saves model checkpoints, loss logs, and sample images to `save_dir`.
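To make the role of `weights` concrete, here is a minimal sketch of how four loss terms could be combined, assuming the total objective is a plain weighted sum in the documented order (GAN, cycle, identity, SSIM). The helper `combine_losses` is hypothetical and is not part of SPAGHETTI; the loss values passed in are made-up numbers.

```python
DEFAULT_WEIGHTS = [1.0, 10.0, 5.0, 10.0]  # order: GAN, cycle, identity, SSIM


def combine_losses(gan, cycle, identity, ssim, weights=DEFAULT_WEIGHTS):
    """Weighted sum of four loss terms (hypothetical helper)."""
    if len(weights) != 4:
        raise ValueError("weights must hold four values: GAN, cycle, identity, SSIM")
    terms = (gan, cycle, identity, ssim)
    return sum(w * t for w, t in zip(weights, terms))


# Illustrative loss values only; these are not results from any model.
total = combine_losses(gan=0.5, cycle=0.25, identity=0.5, ssim=0.125)
print(total)  # 0.5 + 2.5 + 2.5 + 1.25 = 6.75
```

Under this assumption, raising the second or fourth entry of `weights` puts more emphasis on cycle consistency or structural similarity relative to the adversarial term.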
Modules for performing inference with SPAGHETTI
The main class housing the architecture of SPAGHETTI for inference.
- args:
  - `model_path`: str, the path to the model checkpoint
- returns:
The SPAGHETTI inference model, with the following methods:
  - `pre_processing`: Method to pre-process the images for SPAGHETTI transformation. This is a static method.
    - args:
      - `imgs`: list[`PIL.Image` or `torch.Tensor` or `numpy.ndarray`], the images to pre-process
      - `transform`: None, "default", or a callable from `torchvision.transforms.v2`, the transformation to perform on the images. If None, no transformation is performed. If "default", the default transformation (converting to `torch.Tensor` with range [0, 1], normalizing with mean and std of 0.5, and resizing to (256, 256)) is performed.
    - return: list[`torch.Tensor`], the images after pre-processing. Each image has the shape [C, H, W].
  - `inference`: Method to translate images using the SPAGHETTI model initialized from the model checkpoint.
    - args:
      - `img`: list[`torch.Tensor`] or DataLoader, the image(s) to run inference on. Images must first be pre-processed with the `pre_processing` method. For larger datasets, using a DataLoader is strongly recommended so that inference runs with less memory.
      - `names`: list[str], the names of the images to be saved
      - `save_path`: str or None. If str, the translated images are saved to this path. If None, the translated images are returned but not saved.
    - return: list[`torch.Tensor`], the images after the SPAGHETTI transformation when `save_path` is None, otherwise None.
    - side effects: If `save_path` is not None, saves the images to the specified path.
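The "default" transformation of `pre_processing` normalizes with a mean and std of 0.5. Assuming this follows the usual convention `(value - mean) / std`, a pixel value in [0, 1] is mapped to [-1, 1]. The sketch below shows that arithmetic in plain Python on single values; `normalize` and `denormalize` are hypothetical helpers for illustration, not SPAGHETTI functions. Whether translated outputs need to be mapped back before viewing is not stated in this documentation, so the inverse is shown only as an assumption.

```python
def normalize(value, mean=0.5, std=0.5):
    """Map a pixel value in [0, 1] to [-1, 1] via (value - mean) / std."""
    return (value - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Invert normalize(): map a value in [-1, 1] back to [0, 1]."""
    return value * std + mean


for pixel in (0.0, 0.5, 1.0):
    z = normalize(pixel)
    print(pixel, "->", z, "->", denormalize(z))
```

With these defaults, black (0.0) maps to -1.0, mid-gray (0.5) to 0.0, and white (1.0) to 1.0, and `denormalize` recovers the original value in each case.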