Source code for cerebras.modelzoo.data_preparation.vision.unet.resize_skm_tea_files

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
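
"""Resize SKM-TEA HDF5 image volumes (echo1, echo2, seg) to a target
(width, height, depth) and update the "matrix_shape" entries in the
train/val/test annotation JSON files accordingly."""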

import argparse
import glob
import json
import os

import h5py
import torch
import torch.nn as nn
from tqdm import tqdm

# isort: off
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../"))
# isort: on

from cerebras.modelzoo.common.utils.utils import check_and_create_output_dirs


def _resize_and_save(
    input_img_path,
    output_img_path,
    width,
    height,
    depth,
):
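    """Resize the echo1, echo2, and seg volumes of a single SKM-TEA HDF5
    file to (width, height, depth) and write them to output_img_path.
    """
    # nn.Upsample defaults to mode="nearest", so resized one-hot
    # segmentation labels remain valid 0/1 values.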
    resize_transform = nn.Upsample((width, height, depth))
    with h5py.File(input_img_path, "r") as f:
        echo1 = torch.Tensor(f["echo1"][()])
        echo2 = torch.Tensor(f["echo2"][()])
        seg = torch.Tensor(f["seg"][()])

        img_w, img_h, img_d, num_classes = seg.shape
        # view needed for volumetric interpolation
        echo1_resized = resize_transform(echo1.view(1, 1, img_w, img_h, img_d))
        # convert to np.int16 to be consistent with original dataset format
        echo1_resized = (
            echo1_resized.view(width, height, depth).to(torch.int16).numpy()
        )
        echo2_resized = resize_transform(echo2.view(1, 1, img_w, img_h, img_d))
        echo2_resized = (
            echo2_resized.view(width, height, depth).to(torch.int16).numpy()
        )
        # permute to move the one-hot labels into the channel dimension
        # expected by the interpolation
        seg_resized = resize_transform(
            seg.permute(3, 0, 1, 2).view(1, num_classes, img_w, img_h, img_d)
        )
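        # permute back to channels-last and restore the boolean one-hot
        # encoding used by the original dataset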
        seg_resized = (
            seg_resized.permute(0, 2, 3, 4, 1)
            .view(width, height, depth, num_classes)
            .to(torch.bool)
            .numpy()
        )

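    # write the resized volumes out under the same dataset names
    # ("echo1", "echo2", "seg") as the source file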
    with h5py.File(output_img_path, "w") as w:
        w.create_dataset(
            "echo1", data=echo1_resized, shape=(width, height, depth)
        )

        w.create_dataset(
            "echo2", data=echo2_resized, shape=(width, height, depth)
        )

        w.create_dataset(
            "seg",
            data=seg_resized,
            shape=(width, height, depth, num_classes),
        )


def get_parser_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="directory containing original dataset.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="directory where dataset will be stored.",
    )
    parser.add_argument(
        "--width",
        type=int,
        required=True,
        help="width of output volume",
    )
    parser.add_argument(
        "--height",
        type=int,
        required=True,
        help="height of output volume",
    )
    parser.add_argument(
        "--depth",
        type=int,
        required=True,
        help="depth of output volume",
    )
    return parser


if __name__ == "__main__":
    args = get_parser_args().parse_args()
    width, height, depth = args.width, args.height, args.depth

    input_json_dir = os.path.join(
        args.input_dir, "v1-release/annotations/v1.0.0"
    )
    input_hdf5_dir = os.path.join(args.input_dir, "v1-release/image_files")
    output_json_dir = os.path.join(
        args.output_dir, "v1-release/annotations/v1.0.0"
    )
    output_hdf5_dir = os.path.join(args.output_dir, "v1-release/image_files")

    check_and_create_output_dirs(output_json_dir, filetype="json")
    check_and_create_output_dirs(output_hdf5_dir, filetype="h5")

    print(f"\nStarting writing data to {args.output_dir}.")
    for file_name in tqdm(glob.glob1(input_hdf5_dir, "*.h5")):
        input_img_path = os.path.join(input_hdf5_dir, file_name)
        _resize_and_save(
            input_img_path,
            f"{output_hdf5_dir}/{file_name}",
            width,
            height,
            depth,
        )

    # save new volume dimensions to dataset dict
    splits = ["train", "val", "test"]
    for split in splits:
        with open(f"{input_json_dir}/{split}.json") as f:
            dataset_dict = json.load(f)
        for i in range(len(dataset_dict["images"])):
            dataset_dict["images"][i]["matrix_shape"] = [
                width,
                height,
                depth,
            ]
        with open(f"{output_json_dir}/{split}.json", "w") as f:
            json.dump(dataset_dict, f, indent=4)
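
# Example invocation (hypothetical paths and target shape; adjust to your
# local SKM-TEA download):
#   python resize_skm_tea_files.py \
#       --input_dir /path/to/skm_tea \
#       --output_dir /path/to/skm_tea_resized \
#       --width 512 --height 512 --depth 160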