# template_flow.py
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
import os
from metaflow import (
FlowSpec,
IncludeFile,
Parameter,
card,
current,
step,
environment,
kubernetes,
pypi
)
from metaflow.cards import Markdown
# Google Cloud Storage settings: placeholders referenced by the (commented-out)
# CloudStorageAPIClient example in the train step and by the summary printed in
# the end step. Replace with your real GCP project and bucket names.
GCS_PROJECT_NAME = "project-name-here"
GCS_BUCKET_NAME = "bucket-name-here"
class TemplateFlow(FlowSpec):
    """
    This flow is a template for you to use
    for orchestration of your model.

    Steps:
        start -> train -> end
    """

    # You can import the contents of files from your file system to use in flows.
    # This is meant for small files—in this example, a bit of config.
    example_config = IncludeFile("example_config", default="./example_config.json")

    # This is an example of a parameter. You can toggle this when you call the flow
    # with python template_flow.py run --offline False
    offline_wandb = Parameter(
        "offline",
        help="Do not connect to W&B servers when training",
        type=bool,
        default=True,
    )

    # You can uncomment and adjust this decorator when it's time to scale your flow remotely.
    # @kubernetes(image="url-to-docker-image:tag", cpu=1)
    @card(type="default")
    @step
    def start(self):
        """
        Each flow has a 'start' step.
        You can use it for collecting/preprocessing data or other setup tasks.
        """
        self.next(self.train)

    @card
    @environment(
        vars={
            # Default to "" so an unset variable never puts None into the vars
            # mapping — @environment forwards these values to the (possibly
            # remote) task environment, which requires strings, not None.
            "WANDB_API_KEY": os.getenv("WANDB_API_KEY", ""),
            "WANDB_ENTITY": os.getenv("WANDB_ENTITY", ""),
            "WANDB_PROJECT": os.getenv("WANDB_PROJECT", ""),
        }
    )
    @pypi(
        python="3.10.8",
        packages={
            "scikit-learn": "1.5.0",
            "matplotlib": "3.8.0",
            "wandb": "0.18.1",
        },
    )
    @kubernetes(image="registry.hub.docker.com/dexterp37/mlops-copilot-demo:v3", gpu=1)
    @step
    def train(self):
        """
        In this step you can train your model,
        save checkpoints and artifacts,
        and deliver data to Weights and Biases
        for experiment evaluation
        """
        import json

        import wandb
        from sklearn.datasets import load_iris
        from sklearn.linear_model import LogisticRegression
        from sklearn.model_selection import train_test_split

        # from mozmlops.cloud_storage_api_client import CloudStorageAPIClient # noqa: F401

        # This can help you fetch and upload artifacts to
        # GCS. Check out help(CloudStorageAPIClient) for more details.
        # It does require the account you're running the flow from to have
        # access to Google Cloud Storage.
        # storage_client = CloudStorageAPIClient(
        #     project_name=GCS_PROJECT_NAME, bucket_name=GCS_BUCKET_NAME
        # )

        config_as_dict = json.loads(self.example_config)
        print(f"The config file says: {config_as_dict.get('example_key')}")

        # Only open a W&B tracking run when online tracking was requested;
        # keep a handle so the run can be closed explicitly below.
        tracking_run = None
        if not self.offline_wandb:
            tracking_run = wandb.init(project=os.getenv("WANDB_PROJECT"))
            wandb_url = tracking_run.get_url()
            current.card.append(Markdown("# Weights & Biases"))
            current.card.append(
                Markdown(f"Your training run is tracked [here]({wandb_url}).")
            )

        print("All set. Running training.")

        # Model training goes here; for example, a LogisticRegression model on the iris dataset.
        # Of course, replace this example with YOUR model training code :).

        # Load the Iris dataset
        iris = load_iris()
        X, y = iris.data, iris.target

        # Split the dataset into training and testing sets
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

        # Initialize the classifier
        clf = LogisticRegression(max_iter=300)

        # Train the classifier on the training data
        clf.fit(X_train, y_train)

        # Make predictions on the test data
        y_pred = clf.predict(X_test)

        # Cloud paths (<flow name>/<run id>/<file>) where the example
        # checkpoints below would be stored.
        prediction_path = os.path.join(  # noqa: F841
            current.flow_name, current.run_id, "y_predictions.txt"
        )
        observed_path = os.path.join(current.flow_name, current.run_id, "y_test.txt")  # noqa: F841

        # Example: How you'd store a checkpoint in the cloud
        predictions_for_storage = bytearray(y_pred)  # noqa: F841
        # storage_client.store(data=predictions_for_storage, storage_path=prediction_path)
        observed_values_for_storage = bytearray(y_test)  # noqa: F841
        # storage_client.store(
        #     data=observed_values_for_storage, storage_path=observed_path
        # )

        # Example: How you'd fetch a checkpoint from the cloud
        # storage_client.fetch(
        #     remote_path=prediction_path, local_path="y_predictions.txt"
        # )

        # Close the W&B run explicitly so its data is flushed now rather than
        # only at interpreter exit.
        if tracking_run is not None:
            tracking_run.finish()

        self.next(self.end)

    @step
    def end(self):
        """
        This is the mandatory 'end' step: it prints some helpful information
        to access the model and the used dataset.
        """
        print(
            f"""
            Flow complete.

            See artifacts at {GCS_BUCKET_NAME}.
            """
        )
# Script entry point: constructing the FlowSpec subclass hands control to the
# Metaflow CLI (e.g. `python template_flow.py run`).
if __name__ == "__main__":
    TemplateFlow()