# community-content/vertex_model_garden/model_oss/autogluon/train.py
"""AutoGluon training binary. """
import argparse
import json
from typing import Any
from autogluon.tabular import TabularPredictor
import pandas as pd
class BaseConfig:
  """Base config; `to_dict()` exports the non-None attributes as a dict."""
  def to_dict(self) -> dict[str, Any]:
    return {
        key: value for key, value in self.__dict__.items() if value is not None
    }
class DataConfig(BaseConfig):
  """Config for the location of the training data."""
  def __init__(self, train_data_path: Any) -> None:
    self.train_data_path = train_data_path
class ProblemConfig(BaseConfig):
  """Config for the prediction target and problem type."""
  def __init__(self, label: Any, problem_type: Any) -> None:
    self.label = label
    self.problem_type = problem_type
class EvaluationConfig(BaseConfig):
  """Config for the evaluation metric."""
  def __init__(self, eval_metric: Any) -> None:
    self.eval_metric = eval_metric
class TrainingConfig(BaseConfig):
"""Config for training."""
def __init__(
self,
time_limit: Any,
presets: Any,
hyperparameters: Any,
model_save_path: str,
) -> None:
self.time_limit = time_limit
self.hyperparameters = hyperparameters
self.presets = presets
self.model_save_path = model_save_path
def parse_args() -> (
tuple[DataConfig, ProblemConfig, EvaluationConfig, TrainingConfig]
):
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="AutoGluon Tabular Predictor")
# Add arguments for each config class
parser.add_argument(
"--train_data_path",
type=str,
required=True,
help="Path to the input data CSV file.",
)
parser.add_argument(
"--label", type=str, required=True, help="Target variable column name."
)
parser.add_argument(
"--problem_type",
type=str,
choices=["binary", "multiclass", "regression", "quantile"],
default=None,
help="Problem type.",
)
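  # When --problem_type is not set, AutoGluon infers the problem type from
  # the values in the label column.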
parser.add_argument(
"--eval_metric", type=str, default=None, help="Evaluation metric to use."
)
  # Add arguments for TrainingConfig.
parser.add_argument(
"--time_limit",
type=int,
default=None,
help="Time limit in seconds for training.",
)
parser.add_argument(
"--presets",
type=str,
default="medium_quality",
help="Presets used for training ",
)
parser.add_argument(
"--hyperparameters",
type=json.loads,
default=None,
help="Hyperparameter dictionary in JSON format.",
)
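  # Example (illustrative; keys follow AutoGluon's per-model hyperparameters
  # convention, e.g. "GBM" for LightGBM):
  #   --hyperparameters='{"GBM": {"num_boost_round": 100}}'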
parser.add_argument(
"--model_save_path",
type=str,
default=None,
help="Path to save the trained model.",
)
args = parser.parse_args()
data_config = DataConfig(train_data_path=args.train_data_path)
problem_config = ProblemConfig(
label=args.label, problem_type=args.problem_type
)
eval_config = EvaluationConfig(eval_metric=args.eval_metric)
training_config = TrainingConfig(
time_limit=args.time_limit,
presets=args.presets,
hyperparameters=args.hyperparameters,
model_save_path=args.model_save_path,
)
return data_config, problem_config, eval_config, training_config
def main() -> None:
  """Trains an AutoGluon TabularPredictor from the parsed arguments."""
data_config, problem_config, eval_config, training_config = parse_args()
# Load the training data.
data = pd.read_csv(data_config.train_data_path)
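  # The CSV is expected to have a header row that includes the label column.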
# Create a TabularPredictor.
  predictor = TabularPredictor(
      label=problem_config.label,
      problem_type=problem_config.problem_type,
      eval_metric=eval_config.eval_metric,
      path=training_config.model_save_path,
  )
  # Fit the model.
predictor.fit(
data,
presets=training_config.presets,
time_limit=training_config.time_limit,
hyperparameters=training_config.hyperparameters,
)
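  # AutoGluon persists the fitted predictor under `path` (model_save_path),
  # so no explicit save call is needed here.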
if __name__ == "__main__":
main()