# Source code for cerebras.modelzoo.common.utils.run.monitor_cs2_run

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import getpass
import logging
import os
import re
import shlex
import subprocess
import time
from pathlib import Path

import paramiko

# Configure the root logger when the module is imported: INFO level, with a
# timestamp, the logger name and a left-aligned, 8-character level name in
# front of every message.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s %(name)s: %(levelname)-8s %(message)s",
)

# Regular expressions (used with ``re.fullmatch``) for checkpoint file names.
# The dots are escaped so they only match a literal "."; unescaped, names such
# as "checkpoint_100xmdl" or "modelXckpt-100" would be accepted as checkpoints.
PT_CKPT_PATTERN = r"checkpoint_\d+\.mdl"
TF_CKPT_PATTERN = r"model\.ckpt-\d+"
# Matches a file name in either of the two formats above.
CKPT_PATTERN = f"({PT_CKPT_PATTERN})|({TF_CKPT_PATTERN})"

# don't copy a checkpoint unless it's been untouched for 2 minutes
# (value is in seconds)
CKPT_UNTOUCHED_THRESHOLD = 2 * 60


def parse_args(argv=None):
    """Parse the command-line options of the checkpoint monitor.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            original zero-argument call did.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir_colo",
        help="Model directory on the remote host that is monitored",
    )
    parser.add_argument(
        "--remote_host",
        help="Host name of the remote machine that is reached over ssh",
    )
    parser.add_argument(
        "--model_dir_aws",
        help="Local directory that params, logs and checkpoints are copied to",
    )
    parser.add_argument(
        "--coarse_checkpoint_steps",
        type=int,
        help=(
            "The frequency with which checkpoints are saved with the "
            "intention of long term storage and analysis. Often this interval "
            "is coarser than the frequency with which checkpoints are saved "
            "for restart purposes, see '--keep_last_n_checkpoints'"
        ),
    )
    parser.add_argument(
        "--keep_last_n_checkpoints",
        type=int,
        help=(
            "How many checkpoints to keep on remote for restarts in addition "
            "to those kept according to `coarse_checkpoint_steps` for long "
            "term storage and analysis"
        ),
    )
    parser.add_argument(
        "--polling_interval",
        type=int,
        default=60 * 5,
        help="How often to check for new events (in seconds)",
    )
    parser.add_argument(
        "--analyze_weights",
        action="store_true",
        help="Extract summaries from weights after copying to aws",
    )
    return parser.parse_args(argv)


def exists_remote(remote_host, p):
    """Return whether ``p`` is a regular file or a directory on ``remote_host``.

    The check runs ``test -f`` / ``test -d`` on the remote machine through the
    ``ssh`` command-line client.

    Args:
        remote_host: Host passed to ``ssh``.
        p: Path to probe on the remote machine.

    Returns:
        ``True`` when the remote command exits with status 0; ``False`` for
        any other exit status, which includes a failing ``ssh`` itself.
    """
    # ssh passes the command string to the remote shell, so the path is quoted
    # to survive spaces and shell metacharacters.
    quoted = shlex.quote(str(p))
    # One round-trip covers both the file and the directory case.
    remote_cmd = f"test -f {quoted} || test -d {quoted}"
    return subprocess.call(["ssh", remote_host, remote_cmd]) == 0


def ckpt_name_to_step_num(name):
    """Extract the integer step number embedded in a checkpoint file name.

    Args:
        name: File name that fully matches ``PT_CKPT_PATTERN`` or
            ``TF_CKPT_PATTERN``.

    Returns:
        The step number as an ``int``.

    Raises:
        ValueError: If ``name`` matches neither checkpoint pattern.
    """
    if re.fullmatch(PT_CKPT_PATTERN, name):
        # Drop the "checkpoint_" prefix and the ".mdl" suffix.
        start = len("checkpoint_")
        stop = -len(".mdl")
        return int(name[start:stop])

    if re.fullmatch(TF_CKPT_PATTERN, name):
        # Everything after the "model.ckpt-" prefix is the step number.
        start = len("model.ckpt-")
        return int(name[start:])

    raise ValueError(
        f"attempted to extract step number from invalid checkpoint {name}"
    )


def _remote_int(remote_host, remote_argv):
    """Run ``remote_argv`` on ``remote_host`` via ssh and parse stdout as int.

    Returns ``None`` (after logging a warning) when ssh exits with a non-zero
    status or when the output is not an integer.
    """
    result = subprocess.run(
        ["ssh", remote_host, *remote_argv],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        logging.warning(
            "Remote command %s on %s failed with exit code %s: %s",
            remote_argv,
            remote_host,
            result.returncode,
            result.stderr.strip(),
        )
        return None
    try:
        return int(result.stdout)
    except ValueError:
        logging.warning(
            "Remote command %s on %s returned non-integer output %r",
            remote_argv,
            remote_host,
            result.stdout,
        )
        return None


def _submit_job(cmd):
    """Run a job submission command and return the job id it reports.

    The id is taken to be the last whitespace-separated token on stdout.
    Returns ``None`` (after logging an error) when the command exits with a
    non-zero status or prints nothing.
    """
    result = subprocess.run(cmd, capture_output=True, text=True)
    tokens = result.stdout.split()
    if result.returncode != 0 or not tokens:
        logging.error(
            "Job submission %s failed with exit code %s: %s",
            cmd,
            result.returncode,
            result.stderr.strip(),
        )
        return None
    return tokens[-1]


def maybe_copy_checkpoint(ckpt, args):
    """Launch the copy job for a remote checkpoint once it has stopped changing.

    The checkpoint's modification time and the current time are both read
    from the remote host. If the file has been untouched for more than
    ``CKPT_UNTOUCHED_THRESHOLD`` seconds, ``launch_checkpoint_copy.sh`` is
    submitted through ``cbrun -- sbatch``. With ``args.analyze_weights`` set,
    ``write_weight_summaries.py`` is additionally submitted with an
    ``afterok`` dependency on the copy job.

    Args:
        ckpt: Checkpoint file name (not a full path).
        args: Parsed options; reads ``model_dir_colo``, ``model_dir_aws``,
            ``remote_host`` and ``analyze_weights``.

    Returns:
        ``True`` if the copy job was submitted. ``False`` if the checkpoint
        was modified too recently, or if a remote query or the copy job
        submission failed, so that the caller can retry later.

    Raises:
        ValueError: If ``ckpt`` is not a valid checkpoint file name.
    """
    ckpt_path = os.path.join(args.model_dir_colo, ckpt)
    logs_dir = os.path.join(args.model_dir_aws, "logs")
    step_num = ckpt_name_to_step_num(ckpt)

    # ssh passes its arguments to the remote shell, so the path is quoted.
    modified_time = _remote_int(
        args.remote_host, ["stat", shlex.quote(ckpt_path), "-c", r"%Y"]
    )
    if modified_time is None:
        return False

    # get time from remote machine to remove potential consistency
    # or time zone issues
    current_time = _remote_int(args.remote_host, ["date", r"+%s"])
    if current_time is None:
        return False

    # wait a few minutes before copying checkpoints to avoid
    # copying partially written files
    if current_time - modified_time <= CKPT_UNTOUCHED_THRESHOLD:
        return False

    log_file_path = os.path.join(
        logs_dir, f"logs_process_checkpoint_{step_num}.out"
    )
    slurm_id = _submit_job(
        [
            "cbrun",
            "--",
            "sbatch",
            "-c4",
            "-o",
            log_file_path,
            "launch_checkpoint_copy.sh",
            args.model_dir_colo,
            args.model_dir_aws,
            ckpt,
            args.remote_host,
        ]
    )
    if slurm_id is None:
        # Nothing was launched; report that so the checkpoint is retried.
        return False
    logging.info(
        f"Launched copy and processing of checkpoint {ckpt} with "
        f"slurm id {slurm_id}."
    )

    if args.analyze_weights:
        # Queue up weight analysis to run after checkpoint copy
        framework = "pt" if re.fullmatch(PT_CKPT_PATTERN, ckpt) else "tf"
        aws_ckpt_path = os.path.join(args.model_dir_aws, ckpt)
        summaries_slurm_id = _submit_job(
            [
                "cbrun",
                "--",
                "sbatch",
                "-c8",
                "--open-mode=append",
                "-o",
                log_file_path,
                "-d",
                f"afterok:{slurm_id}",
                "write_weight_summaries.py",
                "--input_path",
                aws_ckpt_path,
                "--output_path",
                aws_ckpt_path + ".wt_summary.txt",
                "--framework",
                framework,
            ]
        )
        # A failed analysis submission is already logged by _submit_job; the
        # copy job itself was launched, so the checkpoint still counts.
        if summaries_slurm_id is not None:
            logging.info(
                f"Queued weight analysis of checkpoint {ckpt} with "
                f"slurm id {summaries_slurm_id} to start after job "
                f"{slurm_id} finishes successfully."
            )

    return True


def _validate_args(args):
    """Raise ``ValueError`` if an option ``main`` cannot run without is unusable."""
    required = (
        "model_dir_colo",
        "remote_host",
        "model_dir_aws",
        "coarse_checkpoint_steps",
    )
    missing = [f"--{name}" for name in required if getattr(args, name) is None]
    if missing:
        raise ValueError(f"missing required option(s): {', '.join(missing)}")
    if args.coarse_checkpoint_steps < 1:
        raise ValueError("--coarse_checkpoint_steps must be a positive integer")


def _sync_checkpoints(sftp, args, processed_checkpoints):
    """Run one pass over the remote checkpoints: copy coarse ones, prune old ones.

    ``processed_checkpoints`` is updated in place with the step numbers whose
    copy job was launched during this pass.
    """
    all_ckpts = sorted(
        (
            f
            for f in sftp.listdir(args.model_dir_colo)
            if re.fullmatch(CKPT_PATTERN, f)
        ),
        key=ckpt_name_to_step_num,
    )
    # Walk from the newest checkpoint to the oldest, so ``i`` counts how many
    # newer checkpoints exist.
    for i, ckpt in enumerate(reversed(all_ckpts)):
        step_num = ckpt_name_to_step_num(ckpt)
        ckpt_path = os.path.join(args.model_dir_colo, ckpt)
        if step_num in processed_checkpoints:
            continue
        if step_num % args.coarse_checkpoint_steps == 0:
            # Only mark it processed once the copy job was really launched;
            # otherwise it is retried on the next pass.
            if maybe_copy_checkpoint(ckpt, args):
                processed_checkpoints.add(step_num)
        elif (
            args.keep_last_n_checkpoints is not None
            and i >= args.keep_last_n_checkpoints
        ):
            logging.info(f"Removing remote checkpoint {ckpt_path}")
            sftp.remove(ckpt_path)


def main():
    """Poll the remote model directory forever, copying params and checkpoints.

    After parsing and validating the command-line options this creates the
    local model and log directories, opens an SSH/SFTP session to the remote
    host as user ``lab`` (the password is prompted for), and then every
    ``--polling_interval`` seconds:

    * copies ``train/params_train.yaml`` once, as soon as it exists remotely;
    * launches copy jobs for checkpoints whose step is a multiple of
      ``--coarse_checkpoint_steps``;
    * removes other remote checkpoints beyond the newest
      ``--keep_last_n_checkpoints``.

    Raises:
        ValueError: If a required option is missing or
            ``--coarse_checkpoint_steps`` is not positive.
    """
    args = parse_args()
    # Fail early with a clear message instead of a TypeError from a None
    # option somewhere inside the polling loop.
    _validate_args(args)

    logs_dir = os.path.join(args.model_dir_aws, "logs")
    # Creates model_dir_aws too; exist_ok avoids the exists-then-mkdir race.
    Path(logs_dir).mkdir(parents=True, exist_ok=True)

    params_copied = False
    params_path = os.path.join(
        args.model_dir_colo, "train", "params_train.yaml"
    )

    # Checkpoints already present locally are not copied again.
    processed_checkpoints = {
        ckpt_name_to_step_num(f)
        for f in os.listdir(args.model_dir_aws)
        if re.fullmatch(CKPT_PATTERN, f)
    }

    ssh = paramiko.SSHClient()
    ssh.set_missing_host_key_policy(paramiko.client.AutoAddPolicy())
    try:
        ssh.connect(
            args.remote_host,
            username="lab",
            password=getpass.getpass(
                f"password for lab@{args.remote_host}: "
            ),
        )
        sftp = ssh.open_sftp()
        try:
            while True:
                # Monotonic clock: the interval is immune to wall-clock jumps.
                tick = time.monotonic()

                # copy params files
                if not params_copied and exists_remote(
                    args.remote_host, params_path
                ):
                    logging.info(
                        f"Copying params {args.remote_host}:{params_path}"
                    )
                    sftp.get(
                        params_path,
                        os.path.join(args.model_dir_aws, "params_train.yaml"),
                    )
                    params_copied = True

                # copy checkpoints
                _sync_checkpoints(sftp, args, processed_checkpoints)

                elapsed = time.monotonic() - tick
                time.sleep(max(args.polling_interval - elapsed, 0))
        finally:
            sftp.close()
    finally:
        # Runs on errors and on Ctrl-C, so the connection is never leaked.
        ssh.close()


# Start the monitor only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()