Source code for modelzoo.vision.pytorch.dit.display_images

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
from pathlib import Path

import torch
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm

"""
Usage:

python modelzoo/vision/pytorch/dit/display_images.py \
--folder_path=<path to folder from running `sample_generator_dit.py> \
--file_ext=.png --nrow=8 --image_size=100

"""


[docs]def display_images(folder_path, file_ext=".png", nrow=4, image_size=100): """ Display images in a folder in grid Args: folder_path (str): Path to folder containing images with extension `file_ext` file_ext (str): File extension of images in `folder_path`. Used for glob nrow (int): Number of images to display in each row of the grid image_size (int): Resize image to this size before displaying on grid. """ image_paths = sorted(Path(folder_path).glob("*" + file_ext)) if len(image_paths) == 0: raise ValueError( f"Folder {folder_path} does not contain any {file_ext} images, cannot create grid." ) image_tensors = [] transform = transforms.Compose( [ transforms.Resize(size=(image_size, image_size)), transforms.PILToTensor(), transforms.Lambda(lambda x: x.to(torch.float32)), ] ) fname = os.path.basename(folder_path) grid_path = os.path.join(folder_path, f"{fname}_grid.png") if os.path.exists(grid_path): logging.info(f"Image grid already exists, nothing to do, exiting.") else: for im in tqdm(image_paths): img = Image.open(im) image_tensors.append(transform(img)) save_image( torch.stack(image_tensors, dim=0), grid_path, nrow=nrow, normalize=True, )
[docs]def get_parser_args(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument("--folder_path", type=str, required=True) parser.add_argument( "--file_ext", type=str, default=".png", ) parser.add_argument( "--nrow", type=int, default=4, help="Number of images per row", ) parser.add_argument( "--image_size", type=int, default=100, help="Resize image before displaying", ) args = parser.parse_args() return args
if __name__ == "__main__": args = get_parser_args() display_images(args.folder_path, args.file_ext, args.nrow, args.image_size)