# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import random
from PIL import Image
# isort: off
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../"))
# isort: on
def _resize_and_save(
    input_img_path, output_img_path, width, height, is_lbl_transform
):
    """Resize an image to ``width`` x ``height`` and save the result.

    Args:
        input_img_path: Path of the image file to read.
        output_img_path: Path the resized image is written to.
        width: Width of the output image in pixels.
        height: Height of the output image in pixels.
        is_lbl_transform: When truthy, nearest-neighbour resampling is used
            instead of bicubic (presumably so that label values are not
            blended by interpolation -- confirm against the caller).
    """
    resample = Image.NEAREST if is_lbl_transform else Image.BICUBIC
    # The context manager closes the source file even when resize() or
    # save() raises, which the manual open()/close() pair did not.
    with Image.open(input_img_path) as img:
        img_resized = img.resize((width, height), resample=resample)
        img_resized.save(output_img_path)
def _center_crop_and_save(
    input_img_path, output_img_path, width, height, is_lbl_transform
):
    """Crop a ``width`` x ``height`` window from the image centre and save it.

    Args:
        input_img_path: Path of the image file to read.
        output_img_path: Path the cropped image is written to.
        width: Width of the crop window in pixels.
        height: Height of the crop window in pixels.
        is_lbl_transform: Not read by this function; the signature mirrors
            ``_resize_and_save``.
    """
    # The context manager closes the source file even when crop() or save()
    # raises, which the manual open()/close() pair did not.
    with Image.open(input_img_path) as img:
        img_w, img_h = img.size
        left = (img_w - width) // 2
        top = (img_h - height) // 2
        # Equal to (img_w + width) // 2 and (img_h + height) // 2 under floor
        # division, so the box is always exactly width x height.
        right = left + width
        bottom = top + height
        img_cropped = img.crop((left, top, right, bottom))
        img_cropped.save(output_img_path)
def get_parser_args():
    """Build the command-line argument parser for this script.

    The parser defines ``--input_dir``, ``--width``, ``--height`` and
    ``--transform`` as required options and ``--output_dir`` as an optional
    one that defaults to ``None``.

    Returns:
        argparse.ArgumentParser: The configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="input image file dir",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="directory for output images",
    )
    parser.add_argument(
        "--width",
        type=int,
        required=True,
        help="width of output image",
    )
    parser.add_argument(
        "--height",
        type=int,
        required=True,
        help="height of output image",
    )
    parser.add_argument(
        "--transform",
        type=str,
        required=True,
        choices=["resize", "center-crop"],
        help="transform to apply to the input image, choose from [resize, center-crop]",
    )
    return parser
if __name__ == "__main__":
    # Script entry point: parse the CLI options, transform every split of
    # the dataset, then spot-check the size of a few output files per split.
    args = get_parser_args().parse_args()
    input_dir = args.input_dir
    width = args.width
    height = args.height
    transform_op = args.transform

    if args.output_dir is None:
        output_dir = f"/cb/datasets/cv/scratch/demo/inria_aerial_{width}_{height}_{transform_op}/src_files/AerialImageDataset"
    else:
        output_dir = args.output_dir
    # NOTE: os.makedirs raises FileExistsError if output_dir already exists.
    os.makedirs(output_dir)

    for split in ["train", "val", "test"]:
        img_names = sorted(os.listdir(os.path.join(input_dir, split, "images")))
        # False only for the "test" split; forwarded to
        # apply_transform_to_files and also gates the "gt" spot check below.
        apply_transform_to_lbl = split != "test"
        apply_transform_to_files(
            img_names,
            input_dir,
            output_dir,
            width,
            height,
            split,
            transform_op=transform_op,
            apply_transform_to_lbl=apply_transform_to_lbl,
        )

        # spot check
        out_img_names = sorted(
            os.listdir(os.path.join(output_dir, split, "images"))
        )
        # random.sample raises ValueError when asked for more items than the
        # population holds, so cap the sample size for small splits.
        test_imgs = random.sample(out_img_names, min(5, len(out_img_names)))
        for out_img in test_imgs:
            print(f"-- testing {out_img} -- ")
            paths_to_check = [
                os.path.join(output_dir, split, "images", out_img)
            ]
            if apply_transform_to_lbl:
                paths_to_check.append(
                    os.path.join(output_dir, split, "gt", out_img)
                )
            for path_to_check in paths_to_check:
                # The context manager closes the file even when a size
                # assertion fails.
                with Image.open(path_to_check) as im:
                    assert im.width == width, "Width mismatch"
                    assert im.height == height, "Height mismatch"