backend/time-series-forecasting/services/forecast_job_service.py (47 lines of code) (raw):

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import datetime
from typing import Any, Dict
import logging

from models import completed_forecast_job, dataset
from models.forecast_job_request import ForecastJobRequest
from training_methods import training_method

logger = logging.getLogger(__name__)


class ForecastJobService:
    """
    This service handles model training, evaluation and prediction.
    """

    def __init__(
        self, training_registry: Dict[str, training_method.TrainingMethod]
    ) -> None:
        """Create a service backed by the given training-method registry.

        Args:
            training_registry (Dict[str, training_method.TrainingMethod]):
                Mapping from training method id to the training method that
                `run` dispatches to. A request whose `training_method_id` is
                not a key of this mapping is reported as unsupported.
        """
        super().__init__()

        # TODO: Register training methods
        self._training_registry = training_registry

    def run(
        self, request: ForecastJobRequest
    ) -> completed_forecast_job.CompletedForecastJob:
        """Run model training, evaluation and prediction for a request.

        The training method is looked up in the registry by
        `request.training_method_id`, then its `train`, `evaluate` and
        `predict` methods are called in that order from this method.

        This method does not propagate `Exception`s: any failure (including
        an unknown training method id) is logged and recorded in the returned
        job's `error_message`. Results of steps that completed before the
        failure stay on the returned job.

        Args:
            request (ForecastJobRequest): The job request. This method reads
                its `training_method_id`, `dataset`, `model_parameters` and
                `prediction_parameters` attributes.

        Returns:
            completed_forecast_job.CompletedForecastJob: The job result with
                `model_uri`, `evaluation_uri` and `prediction_uri` set by the
                steps that succeeded, `error_message` set if a step failed
                (otherwise `None`), and `end_time` set to the UTC time at
                which processing stopped.
        """
        # Named `method` so the local does not shadow the imported
        # `training_method` module.
        method = self._training_registry.get(request.training_method_id)

        # Start training
        output = completed_forecast_job.CompletedForecastJob(
            end_time=datetime.datetime.now(datetime.timezone.utc),  # Placeholder time
            request=request,
            model_uri=None,
            error_message=None,
        )

        try:
            # Raised inside the `try` on purpose: an unsupported method is
            # reported through `error_message` like any other failure.
            if method is None:
                raise ValueError(
                    f"Training method '{request.training_method_id}' is not supported"
                )

            # Train model
            output.model_uri = method.train(
                dataset=request.dataset,
                model_parameters=request.model_parameters,
                prediction_parameters=request.prediction_parameters,
            )

            # Run evaluation
            output.evaluation_uri = method.evaluate(model=output.model_uri)

            # Run prediction
            output.prediction_uri = method.predict(
                dataset=request.dataset,
                model=output.model_uri,
                model_parameters=request.model_parameters,
                prediction_parameters=request.prediction_parameters,
            )
        except Exception as exception:
            # `str()` of an exception raised without arguments is empty; fall
            # back to `repr()` so a failed job never carries a falsy
            # `error_message` that could be mistaken for success.
            error_message = str(exception) or repr(exception)
            # `logger.exception` logs at ERROR level and keeps the traceback.
            logger.exception("%s", error_message)
            output.error_message = error_message
        finally:
            # Replace the placeholder with the actual completion time.
            output.end_time = datetime.datetime.now(datetime.timezone.utc)

        return output