Source code for seaquest.runner
from importlib import import_module
import os
import pathlib
import pkgutil
import sys
from seaquest.utils.loggus import init_logger
from seaquest.utils.validate import _parse_runner_args
# Module-level logger. It is named after the module when imported, or after this
# file's stem when the file is executed as a script (where __name__ is "__main__").
logger = init_logger(__name__ if __name__ != "__main__" else pathlib.Path(__file__).stem, level="debug")
[docs]
def _load_model_dir(model_dir):
    """Import the model directory as a Python module

    Parameters
    ----------
    model_dir: str
        The importable name of the model directory

    Returns
    -------
    namespace: module
        The imported module object from which the model can be loaded

    Raises
    ------
    ValueError
        If importing ``model_dir`` raises ``ModuleNotFoundError``; the original
        error is chained as the cause
    """
    logger.info("Attempting to load model directory ...")
    try:
        package = import_module(model_dir)
    except ModuleNotFoundError as err:
        message = f"Model directory '{model_dir}' or one of its imports not found or is not a module!"
        raise ValueError(message) from err
    else:
        logger.info("Model directory successfully loaded!")
        return package
[docs]
def _load_model_class(namespace, model_name):
    """Find and return the model class inside an imported model package

    The direct submodules of ``namespace`` are imported one by one, in the
    order reported by ``pkgutil.iter_modules``. The search stops at the first
    submodule that exposes an attribute called ``model_name``; the remaining
    submodules are not imported.

    Parameters
    ----------
    namespace: module
        The imported package that holds the model (its ``__path__`` and
        ``__name__`` attributes are used to enumerate the submodules)
    model_name: str
        The name of the model

    Returns
    -------
    model: class
        The attribute called ``model_name`` of the first matching submodule

    Raises
    ------
    ModuleNotFoundError
        If no submodule of ``namespace`` exposes ``model_name``
    """
    logger.info("Attempting to load model class ...")
    prefix = namespace.__name__ + "."
    for _, submodule_name, _ in pkgutil.iter_modules(namespace.__path__, prefix):
        try:
            model = getattr(import_module(submodule_name), model_name)
        except ModuleNotFoundError as err:
            # Best effort: a submodule with a missing dependency must not stop
            # the search, but leave a trace so the failure can be diagnosed.
            logger.debug(f"Skipping submodule '{submodule_name}': {err}")
            continue
        except AttributeError:
            # This submodule does not provide the model; try the next one.
            continue
        logger.info(f"Successfully imported model '{model_name}' from directory '{namespace.__name__}'")
        return model
    raise ModuleNotFoundError(f"Module '{model_name}' not found in directory {namespace.__name__}")
[docs]
def _load_model_function(model, fun):
    """Look up a function on a model by name

    Parameters
    ----------
    model: class or object
        The model on which the function is looked up
    fun: str
        The name of the function to look up

    Returns
    -------
    model_function: object
        The attribute of ``model`` called ``fun``

    Raises
    ------
    ValueError
        If ``model`` has no attribute called ``fun`` (or that attribute is ``None``)
    """
    model_function = getattr(model, fun, None)
    if model_function is None:
        # Instances have no ``__name__``; fall back to the name of their type so
        # that the intended ValueError is raised instead of an AttributeError.
        owner = getattr(model, "__name__", type(model).__name__)
        raise ValueError(f"Function '{fun}' not found in model '{owner}'")
    return model_function
[docs]
def main(args: dict) -> None:
    """Dynamically import a model class, instantiate it and run one of its functions.

    Parameters
    ----------
    args: dict
        The runner arguments. The following keys are read:

        output_dir: str
            Handed to the model constructor as a ``pathlib.Path``
            (keyword ``output_dir``).
        md_dir: str
            Directory where the model is located, taken relative to the current
            working directory. Its parent is put at the front of ``sys.path``
            and its name is imported as a package.
        data_file: str
            Handed to the model constructor as a ``pathlib.Path``
            (keyword ``data_file``).
        model_name: str
            Name of the model class to import.
        model_fun: str
            Function to execute on the model ('train' or 'infer').
        model_keyword_args: dict, optional
            Extra keyword arguments forwarded to the model constructor.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``model_fun`` is neither 'train' nor 'infer', or if the model
        directory cannot be imported.
    ModuleNotFoundError
        If the model class cannot be found in the model directory.
    """
    model_fun = args["model_fun"]
    model_name = args["model_name"]
    # Reject unsupported functions up front, before anything is imported, the
    # working directory is changed or a model is instantiated.
    if model_fun not in ("train", "infer"):
        raise ValueError(f"Provided function should be either train or infer (was {model_fun})")

    output_dir = pathlib.Path(args["output_dir"])
    model_dir = pathlib.Path.cwd() / args["md_dir"]
    data_file = pathlib.Path(args["data_file"])

    # Make the model directory importable by its own name.
    sys.path.insert(0, str(model_dir.parent))  # sys.path entries must be str
    namespace = _load_model_dir(model_dir=model_dir.name)
    model_class = _load_model_class(namespace=namespace, model_name=model_name)

    # change dir to model_dir so that all relative imports of the model work.
    # Packages without an __init__.py have no usable __file__, so fall back to
    # the directory that was imported.
    package_file = getattr(namespace, "__file__", None)
    os.chdir(os.path.dirname(package_file) if package_file else model_dir)

    # A missing or empty "model_keyword_args" entry means no extra arguments.
    model_kwargs = args.get("model_keyword_args") or {}
    model = model_class(output_dir=output_dir, data_file=data_file, **model_kwargs)

    # run the experiments
    logger.info(f"Starting function '{model_fun}' from model '{model_name}'")
    getattr(model, model_fun)()  # model_fun was validated above
    logger.info(f"Function '{model_fun}' from model '{model_name}' completed successfully")
if __name__ == "__main__":
    # Script entry point: parse the command line and hand the result to main().
    main(_parse_runner_args(sys.argv))