Source code for seaquest.model_class
from abc import ABC, abstractmethod
import os
import pathlib
[docs]
class OutputContext:
    """Context manager that handles the working-directory switch for saving output files on the PVC.

    Entering the context changes the process working directory to
    ``output_dir``; leaving it changes back to the directory that was current
    when the context was entered.

    The same instance may be entered several times (nested ``with`` blocks).
    Every entry pushes the directory it came from onto an internal stack and
    every exit pops one, so each level returns to the directory it started in.

    Note that ``os.chdir`` changes the working directory of the whole process,
    not just of the calling thread.
    """

    def __init__(self, output_dir: pathlib.Path):
        """Store the directory to switch into.

        Args:
            output_dir: Directory that becomes the working directory while the
                context is active. It is passed to ``os.chdir`` as given.
        """
        self.output_dir = output_dir
        # One saved directory per active ``with`` level (innermost last).
        self._saved_dirs = []

    def __enter__(self) -> "OutputContext":
        """Change into ``output_dir`` and remember the directory we came from.

        Returns:
            This context object.

        Raises:
            OSError: Propagated from ``os.chdir`` when the switch fails
                (for example when ``output_dir`` does not exist).
        """
        previous_dir = os.getcwd()
        # Kept for backward compatibility: the directory saved by the most
        # recent entry has always been exposed as ``curr_wd``.
        self.curr_wd = previous_dir
        os.chdir(self.output_dir)
        # Push only after the switch succeeded; ``__exit__`` is not called when
        # ``__enter__`` raises, so a failed entry must not leave a stale item.
        self._saved_dirs.append(previous_dir)
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Change back to the directory saved by the matching ``__enter__``.

        Returns ``None``, so an exception raised inside the ``with`` body is
        not suppressed.
        """
        self.curr_wd = self._saved_dirs.pop()
        os.chdir(self.curr_wd)
[docs]
class NautPipelineModel(ABC):
    """Abstract base that every pipeline model inherits from.

    It fixes the minimum constructor arguments (an output directory and a
    data file) and the methods a concrete model has to provide
    (:meth:`train` and :meth:`infer`).
    """

    def __init__(self, output_dir: pathlib.Path, data_file: pathlib.Path):
        """Record the paths the model works with.

        Args:
            output_dir: Stored unchanged as ``self.output_dir``.
            data_file: Stored unchanged as ``self.data_file``.
        """
        super().__init__()
        self.data_file = data_file
        self.output_dir = output_dir

    @abstractmethod
    def train(self):
        """Abstract hook; concrete subclasses must override it."""

    @abstractmethod
    def infer(self):
        """Abstract hook; concrete subclasses must override it."""