# Source code for cerebras.modelzoo.common.registry

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

'''
Registry of models, dataset processors, losses, LR schedulers, datasets,
configs and search paths for the Cerebras modelzoo.
'''
import importlib
import os
import pathlib
from pathlib import Path

import cerebras.modelzoo as modelzoo


class Registry:
    """Central name-to-object registry for the Cerebras modelzoo.

    All state lives in the class attribute ``mapping`` and every method is a
    ``classmethod``, so the registry is shared process-wide whether it is
    accessed through the class or through an instance.

    ``mapping`` holds one dict per region:

    * ``"model"``: name -> dict with keys ``"class"``, ``"run"``,
      ``"datasetprocessor"`` and ``"dataset"``.
    * ``"datasetprocessor"``, ``"lr_scheduler"``, ``"loss"``, ``"dataset"``,
      ``"config"``: name -> registered object.
    * ``"paths"``: path kind (e.g. ``"model_path"``) -> list of directories.

    The ``list_*`` and ``get_*_class`` lookups first call ``_import_modules``
    so that modules containing ``register_*`` decorators have been imported.
    """

    mapping = {
        "model": {},
        "datasetprocessor": {},
        "lr_scheduler": {},
        "loss": {},
        "dataset": {},
        "paths": {},
        "config": {},
    }
    # Flipped to True once ``_import_modules`` has completed successfully,
    # so the (potentially expensive) import sweep runs at most once.
    _modules_imported = False

    @classmethod
    def _import_modules_for_registry(
        cls, directory_path: str, import_files_regex: str
    ):
        """Import every file below ``directory_path`` matching a glob.

        Importing a module executes its ``register_*`` decorators, which is
        how entries end up in ``mapping``. Despite its name,
        ``import_files_regex`` is a ``pathlib`` glob pattern (it is passed
        to ``Path.rglob``), e.g. ``"**/model.py"``. Module names are derived
        from each file's path relative to the ``cerebras.modelzoo`` package.

        Raises:
            Exception: if any matched module fails to import; the original
                error is chained as ``__cause__``.
        """
        modelzoo_path = os.path.dirname(os.path.realpath(modelzoo.__file__))
        for file in Path(directory_path).rglob(import_files_regex):
            relative = os.path.relpath(file, modelzoo_path)
            # "a/b/model.py" -> "a.b.model"
            module_path = "cerebras.modelzoo.{}".format(
                os.path.splitext(relative)[0].replace(os.path.sep, ".")
            )
            try:
                importlib.import_module(module_path)
            except Exception as ex:
                raise Exception(
                    "Registry Import Failure: {}".format(ex)
                ) from ex

    @classmethod
    def _import_modules(cls):
        """Import all registrable modules once (models, losses, processors,
        then configs).

        Raises:
            KeyError: if any of ``"model_path"``, ``"loss_path"`` or
                ``"datasetprocessor_path"`` has not been registered via
                ``register_paths``.
        """
        if cls._modules_imported:
            return
        # (path kind, glob) pairs, processed in this order.
        import_plan = (
            ("model_path", "**/model.py"),
            ("loss_path", "**/*.py"),
            ("datasetprocessor_path", "**/*Processor*.py"),
            ("model_path", "**/config.py"),
        )
        paths = cls.mapping["paths"]
        for kind, pattern in import_plan:
            for path in paths[kind]:
                cls._import_modules_for_registry(
                    path, import_files_regex=pattern
                )
        cls._modules_imported = True

    @classmethod
    def _register(cls, region, name):
        """Return a decorator that stores its target at mapping[region][name].

        Raises:
            KeyError: (from the decorator) if ``name`` is already registered
                in ``region``.
        """

        def wrap(obj):
            if name in cls.mapping[region]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping[region][name]
                    )
                )
            cls.mapping[region][name] = obj
            return obj

        return wrap

    @classmethod
    def register_model(cls, model_name, datasetprocessor=None, dataset=None):
        """Decorator registering a model class under one or more names.

        Args:
            model_name: a name, or a ``list`` of names, for the model.
            datasetprocessor: names of the dataset processors the model uses
                (defaults to a fresh empty list per call).
            dataset: names of the datasets the model uses (defaults to a
                fresh empty list per call).

        Each entry also records ``"run"``, the model's run directory as
        found by ``register_run_path`` (``None`` if not found).

        Raises:
            KeyError: if a name is already registered as a model.
            ValueError: if ``"model_path"`` has not been registered.
        """
        # Fresh lists per call: a shared mutable default would be aliased
        # by every model registered without explicit dependencies.
        if datasetprocessor is None:
            datasetprocessor = []
        if dataset is None:
            dataset = []

        def wrap(model_cls):
            names = model_name if isinstance(model_name, list) else [model_name]
            for name in names:
                if name in cls.mapping["model"]:
                    raise KeyError(
                        "Name '{}' already registered for {}.".format(
                            name, cls.mapping["model"][name]
                        )
                    )
                # Build the entry fully before inserting it, so a failure
                # in register_run_path does not leave a half-filled entry.
                cls.mapping["model"][name] = {
                    "class": model_cls,
                    "run": cls.register_run_path(name),
                    "datasetprocessor": datasetprocessor,
                    "dataset": dataset,
                }
            return model_cls

        return wrap

    @classmethod
    def register_datasetprocessor(cls, name):
        """Decorator registering a dataset processor class under ``name``."""
        return cls._register("datasetprocessor", name)

    @classmethod
    def register_loss(cls, name):
        """Decorator registering a loss class under ``name``."""
        return cls._register("loss", name)

    @classmethod
    def register_lr_scheduler(cls, name):
        """Decorator registering an LR scheduler class under ``name``."""
        return cls._register("lr_scheduler", name)

    @classmethod
    def register_dataset(cls, name):
        """Decorator registering a dataset class under ``name``."""
        return cls._register("dataset", name)

    @classmethod
    def register_paths(cls, kind, path):
        """Append ``path`` to the list of directories for ``kind``.

        ``kind`` is e.g. ``"model_path"``, ``"loss_path"`` or
        ``"datasetprocessor_path"`` (the kinds read by ``_import_modules``).
        """
        cls.mapping["paths"].setdefault(kind, []).append(path)

    @classmethod
    def register_config(cls, name):
        """Decorator registering a config class under ``name``.

        Unlike the other register_* decorators, re-registering an existing
        name silently overwrites the previous entry.
        """

        def wrap(model_cls):
            cls.mapping["config"][name] = model_cls
            return model_cls

        return wrap

    @classmethod
    def get_path(cls, kind, name):
        """Return the first ``<path>/<name>`` directory among ``kind``'s paths.

        Returns:
            The joined path of the first existing directory, or ``None``.

        Raises:
            ValueError: if no paths of ``kind`` have been registered.
        """
        if kind not in cls.mapping["paths"]:
            raise ValueError("{} not initialised in registry".format(kind))
        for path in cls.mapping["paths"][kind]:
            candidate = os.path.join(path, name)
            if os.path.isdir(candidate):
                return candidate
        return None

    @classmethod
    def register_run_path(cls, name):
        """Look up the run directory for model ``name`` under "model_path"."""
        return cls.get_path("model_path", name)

    @classmethod
    def unregister(cls, region, name):
        """Remove ``name`` from ``region`` and return it (``None`` if absent).

        ``region`` is one of the ``mapping`` keys, e.g. ``'model'``,
        ``'loss'``, ``'lr_scheduler'``, ``'datasetprocessor'``,
        ``'dataset'``.

        Raises:
            KeyError: if ``region`` is not a known region.
        """
        if cls.mapping.get(region) is None:
            raise KeyError("Undefined {}".format(region))
        return cls.mapping[region].pop(name, None)

    @classmethod
    def _list_region(cls, region, model_name):
        """Shared body of ``list_datasetprocessor`` / ``list_dataset``.

        With ``model_name=None`` returns all names in ``region`` (sorted);
        otherwise returns the model's declared ``region`` dependencies after
        checking each is registered.

        Raises:
            ValueError: if the model, or one of its dependencies, is not
                registered.
        """
        cls._import_modules()
        if model_name is None:
            return sorted(cls.mapping[region].keys())
        if model_name not in cls.mapping["model"]:
            raise ValueError("{} model is not registered".format(model_name))
        required = cls.mapping["model"][model_name][region]
        for item in required:
            if item not in cls.mapping[region]:
                raise ValueError(
                    "{} {} is not registered".format(item, region)
                )
        return required

    @classmethod
    def list_models(cls):
        """Return the sorted names of all registered models."""
        cls._import_modules()
        return sorted(cls.mapping["model"].keys())

    @classmethod
    def list_loss(cls):
        """Return the sorted names of all registered losses."""
        cls._import_modules()
        return sorted(cls.mapping["loss"].keys())

    @classmethod
    def list_datasetprocessor(cls, model_name=None):
        """List all dataset processors, or those declared by ``model_name``."""
        return cls._list_region("datasetprocessor", model_name)

    @classmethod
    def list_lr_scheduler(cls):
        """Return the sorted names of all registered LR schedulers."""
        cls._import_modules()
        return sorted(cls.mapping["lr_scheduler"].keys())

    @classmethod
    def list_dataset(cls, model_name=None):
        """List all datasets, or those declared by ``model_name``."""
        return cls._list_region("dataset", model_name)

    @classmethod
    def get_model_class(cls, name):
        """Return the class registered for model ``name``.

        Raises:
            ValueError: if ``name`` is not a registered model.
        """
        cls._import_modules()
        if name in cls.mapping["model"]:
            return cls.mapping["model"][name]["class"]
        raise ValueError("{} model is not registered".format(name))

    @classmethod
    def get_config_class(cls, name):
        """Return the config class registered for ``name``, or ``None``."""
        cls._import_modules()
        return cls.mapping["config"].get(name)

    @classmethod
    def get_loss_class(cls, name):
        """Return the loss class registered for ``name``, or ``None``."""
        cls._import_modules()
        return cls.mapping["loss"].get(name, None)

    @classmethod
    def get_run_path(cls, name):
        """Return the recorded run path for a registered model, otherwise
        search "model_path" directories for ``name``."""
        if name in cls.mapping["model"]:
            return cls.mapping["model"][name]["run"]
        return cls.get_path("model_path", name)
# Module-level Registry instance. Registry keeps all of its state in the
# class attribute ``mapping`` and exposes only classmethods, so this instance
# shares the same registry as direct ``Registry.*`` calls.
registry = Registry()