community/front-end/ofe/website/ghpcfe/views/clusters.py (913 lines of code) (raw):
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" clusters.py """
import csv
import json
from asgiref.sync import sync_to_async
from rest_framework import viewsets
from rest_framework.authentication import (
SessionAuthentication,
TokenAuthentication,
)
from rest_framework.permissions import IsAuthenticated
from rest_framework.decorators import action
from rest_framework.response import Response
from django.shortcuts import render, get_object_or_404, redirect
from django.db import transaction
from django.db.models import Q
from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.auth.views import redirect_to_login
from django.core.exceptions import ValidationError
from django.http import (
HttpResponse,
HttpResponseRedirect,
JsonResponse,
HttpResponseNotFound,
)
from django.urls import reverse
from django.forms import inlineformset_factory
from django.views import generic
from django.views import View
from django.views.generic.edit import CreateView, UpdateView, DeleteView
from django.contrib import messages
from ..models import (
Application,
Cluster,
Credential,
Job,
Filesystem,
FilesystemExport,
MountPoint,
FilesystemImpl,
Role,
ClusterPartition,
VirtualSubnet,
Task,
User,
)
from ..serializers import ClusterSerializer
from ..forms import ClusterForm, ClusterMountPointForm, ClusterPartitionForm
from ..cluster_manager import cloud_info, c2, utils
from ..cluster_manager.clusterinfo import ClusterInfo
from ..views.asyncview import BackendAsyncView
from .view_utils import TerraformLogFile, GCSFile, StreamingFileView
import logging
import secrets
logger = logging.getLogger(__name__)
class ClusterPartitionDeleteView(LoginRequiredMixin, View):
    """AJAX endpoint that deletes a single ClusterPartition by primary key."""

    def post(self, request, *args, **kwargs):
        """Delete the partition named by the ``partition_id`` URL kwarg.

        Returns a JSON payload with ``success`` plus an ``error`` message and
        an appropriate HTTP status on failure.
        """
        partition_id = kwargs.get('partition_id')
        try:
            partition = ClusterPartition.objects.get(pk=partition_id)
            logger.info(
                "Deleting partition point: %s, ID: %s", partition, partition_id
            )
            partition.delete()
        except ClusterPartition.DoesNotExist:
            return JsonResponse(
                {'success': False, 'error': 'Partition not found.'}, status=404
            )
        except Exception as err:  # pylint: disable=broad-except
            # Surface any other failure to the caller rather than a 500 page.
            return JsonResponse({'success': False, 'error': str(err)}, status=500)
        return JsonResponse({'success': True})
class ClusterListView(LoginRequiredMixin, generic.ListView):
    """Custom ListView for Cluster model"""

    model = Cluster
    template_name = "cluster/list.html"

    def get_queryset(self):
        """Admins see every cluster; other users only running ("r") clusters
        on which they are an authorised user."""
        queryset = super().get_queryset()
        if self.request.user.has_admin_role():
            return queryset
        visible_pks = {
            cluster.pk
            for cluster in queryset
            if self.request.user in cluster.authorised_users.all()
            and cluster.status == "r"
        }
        return queryset.filter(pk__in=visible_pks)

    def get_context_data(self, *args, **kwargs):
        context = super().get_context_data(*args, **kwargs)
        # "loading" flags whether any visible cluster is still in a
        # transitional state ("c", "i" or "t") so the page can auto-refresh.
        busy = any(c.status in ("c", "i", "t") for c in self.get_queryset())
        context["loading"] = 1 if busy else 0
        context["admin_view"] = 1 if self.request.user.has_admin_role() else 0
        context["navtab"] = "cluster"
        return context
class ClusterDetailView(LoginRequiredMixin, generic.DetailView):
    """Custom DetailView for Cluster model"""

    model = Cluster
    template_name = "cluster/detail.html"

    def get_context_data(self, **kwargs):
        """Add navigation tab and admin-view flag to the template context."""
        context = super().get_context_data(**kwargs)
        context["navtab"] = "cluster"
        context["admin_view"] = 1 if self.request.user.has_admin_role() else 0
        return context
class ClusterCreateView(LoginRequiredMixin, CreateView):
    """Custom CreateView for Cluster model"""

    def get(self, request, *args, **kwargs):
        """Seed a new Cluster with default values, or redirect to credential
        creation when the user has no credentials yet."""
        credentials = Credential.objects.filter(owner=self.request.user)
        if not credentials.exists():
            # No credentials: send the user to the credentials page instead.
            messages.error(
                self.request,
                "Please create a credential before creating a cluster.",
            )
            return redirect('credentials')
        new_cluster = Cluster(
            cloud_credential=credentials.first(),
            name="cluster",
            owner=request.user,
            status="n",
            spackdir="/opt/cluster/spack",
            num_login_nodes=1,
        )
        new_cluster.save()
        return redirect('backend-create-cluster', pk=new_cluster.pk)
class ClusterUpdateView(LoginRequiredMixin, UpdateView):
    """Custom UpdateView for Cluster model.

    Handles both first-time configuration of a newly seeded cluster
    (status "n") and reconfiguration of a running one, including the
    mount-point and partition inline formsets and their validation.
    """

    model = Cluster
    template_name = "cluster/update_form.html"
    form_class = ClusterForm

    def get_mp_formset(self, **kwargs):
        """Build the inline formset for the cluster's mount points.

        On first use for a cluster without a shared filesystem, this lazily
        creates the built-in shared filesystem with its ``/opt/cluster`` and
        ``/home`` exports and matching mount points.
        """

        def formfield_cb(model_field, **kwargs):
            field = model_field.formfield(**kwargs)
            cluster = self.object
            if model_field.name == "export":
                if cluster.shared_fs is None:
                    # Create and save the shared filesystem, exports, and
                    # mount points for a brand-new cluster.
                    shared_fs = Filesystem(
                        name=f"{cluster.name}-sharedfs",
                        cloud_credential=cluster.cloud_credential,
                        cloud_id=cluster.cloud_id,
                        cloud_state=cluster.cloud_state,
                        cloud_region=cluster.cloud_region,
                        cloud_zone=cluster.cloud_zone,
                        subnet=cluster.subnet,
                        fstype="n",
                        impl_type=FilesystemImpl.BUILT_IN,
                    )
                    shared_fs.save()
                    export = FilesystemExport(
                        filesystem=shared_fs, export_name="/opt/cluster"
                    )
                    export.save()
                    export = FilesystemExport(
                        filesystem=shared_fs, export_name="/home"
                    )
                    export.save()
                    cluster.shared_fs = shared_fs
                    cluster.save()
                    # Create and save mount points for the two exports above.
                    export = cluster.shared_fs.exports.all()[0]
                    mp = MountPoint(
                        export=export,
                        cluster=cluster,
                        mount_order=0,
                        mount_options="defaults,nofail,nosuid",
                        mount_path="/opt/cluster",
                    )
                    mp.save()
                    export = cluster.shared_fs.exports.all()[1]
                    mp = MountPoint(
                        export=export,
                        cluster=cluster,
                        mount_order=1,
                        mount_options="defaults,nofail,nosuid",
                        mount_path="/home",
                    )
                    mp.save()
                # Restrict export choices to mounted/importing ("m"/"i")
                # external filesystems plus this cluster's own built-in one.
                if cluster.shared_fs is not None:
                    fsquery = (
                        Filesystem.objects.exclude(
                            impl_type=FilesystemImpl.BUILT_IN
                        )
                        .filter(cloud_state__in=["m", "i"])
                        .values_list("pk", flat=True)
                    )
                    # Add back our cluster's filesystem
                    fsystems = list(fsquery) + [cluster.shared_fs.id]
                    field.queryset = FilesystemExport.objects.filter(
                        filesystem__in=fsystems
                    )
            return field

        # This creates a new class on the fly
        FormClass = inlineformset_factory(  # pylint: disable=invalid-name
            Cluster,
            MountPoint,
            form=ClusterMountPointForm,
            formfield_callback=formfield_cb,
            can_delete=True,
            extra=0,
        )

        if self.request.POST:
            kwargs["data"] = self.request.POST
        return FormClass(instance=self.object, **kwargs)

    def get_partition_formset(self, **kwargs):
        """Build the inline formset for the cluster's Slurm partitions,
        creating a default "batch" partition if none exist yet."""

        def formfield_cb(model_field, **kwargs):
            field = model_field.formfield(**kwargs)
            cluster = self.object
            if not cluster.partitions.exists():
                logger.info("No partitions exist, creating a default one.")
                # Create and save the default partition with hardcoded values
                default_partition = ClusterPartition(
                    name="batch",
                    machine_type="c2-standard-60",
                    dynamic_node_count=4,
                    vCPU_per_node=30,
                    cluster=cluster,  # Set the cluster for the partition
                )
                default_partition.save()
            return field

        # This creates a new class on the fly
        FormClass = inlineformset_factory(  # pylint: disable=invalid-name
            Cluster,
            ClusterPartition,
            form=ClusterPartitionForm,
            formfield_callback=formfield_cb,
            can_delete=True,
            extra=0,
        )

        if self.request.POST:
            kwargs["data"] = self.request.POST
        return FormClass(instance=self.object, **kwargs)

    def get_success_url(self):
        """Route to the backend action matching the cluster's cloud state."""
        logger.info("Current cluster state %s", self.object.cloud_state)
        if self.object.cloud_state == "m":
            # Perform live cluster reconfiguration
            return reverse(
                "backend-reconfigure-cluster", kwargs={"pk": self.object.pk}
            )
        if self.object.cloud_state == "nm":
            # Cluster infrastructure not yet created - start it
            return reverse(
                "backend-start-cluster", kwargs={"pk": self.object.pk}
            )
        # Fallback: previously this returned None for any other state, which
        # makes Django raise ImproperlyConfigured; go to the detail page.
        return reverse("cluster-detail", kwargs={"pk": self.object.pk})

    def _get_region_info(self):
        # Cache the (slow) cloud region/zone lookup for the request lifetime.
        if not hasattr(self, "region_info"):
            self.region_info = cloud_info.get_region_zone_info(
                "GCP", self.get_object().cloud_credential.detail
            )
        return self.region_info

    def get_context_data(self, **kwargs):
        """Perform extra query to populate instance types data"""
        context = super().get_context_data(**kwargs)
        # Map each usable (imported "i" or managed "m") subnet to its region
        # so the form JS can filter choices.  A duplicate, unfiltered version
        # of this query used to run first and was immediately overwritten;
        # that dead query has been removed.
        subnet_regions = {
            sn.id: sn.cloud_region
            for sn in VirtualSubnet.objects.filter(
                cloud_credential=self.get_object().cloud_credential
            ).filter(Q(cloud_state="i") | Q(cloud_state="m"))
        }
        context["subnet_regions"] = json.dumps(subnet_regions)
        context["object"] = self.object
        context["region_info"] = json.dumps(self._get_region_info())
        context["navtab"] = "cluster"
        context["mountpoints_formset"] = self.get_mp_formset()
        context["cluster_partitions_formset"] = self.get_partition_formset()
        context["title"] = (
            "Create cluster" if self.object.status == "n" else "Update cluster"
        )
        return context

    def form_valid(self, form):
        """Validate disks, partitions and reservations, then save everything
        atomically.  Returns form_invalid with targeted errors on failure."""
        logger.info("In form_valid")
        context = self.get_context_data()
        mountpoints = context["mountpoints_formset"]
        partitions = context["cluster_partitions_formset"]

        if self.object.status == "n":
            # If creating a new cluster, generate a unique cloud id by
            # appending a random hex suffix to the name.  (The previous code
            # re-derived the suffix from the id it had just built, which was
            # redundant.)
            self.object.cloud_id = f"{self.object.name}-{secrets.token_hex(4)}"
            self.object.cloud_region = self.object.subnet.cloud_region

        machine_info = cloud_info.get_machine_types(
            "GCP",
            self.object.cloud_credential.detail,
            self.object.cloud_region,
            self.object.cloud_zone,
        )
        disk_info = {
            x["name"]: x
            for x in cloud_info.get_disk_types(
                "GCP",
                self.object.cloud_credential.detail,
                self.object.cloud_region,
                self.object.cloud_zone,
            )
            if x["name"].startswith("pd-")
        }

        # Only new ("n") or running ("r") clusters may be (re)configured.
        if self.object.status not in ("n", "r"):
            form.add_error(
                None, "It is not newly created cluster or it is not running yet."
            )
            return self.form_invalid(form)

        # Verify Disk Types & Sizes for the controller node
        try:
            my_info = disk_info[self.object.controller_disk_type]
            if self.object.controller_disk_size < my_info["minSizeGB"]:
                form.add_error(
                    "controller_disk_size",
                    "Minimum Disk Size for "
                    f"{self.object.controller_disk_type} is "
                    f"{my_info['minSizeGB']}"
                )
                return self.form_invalid(form)
            if self.object.controller_disk_size > my_info["maxSizeGB"]:
                form.add_error(
                    "controller_disk_size",
                    "Maximum Disk Size for "
                    f"{self.object.controller_disk_type} is "
                    f"{my_info['maxSizeGB']}"
                )
                return self.form_invalid(form)
        except KeyError:
            form.add_error("controller_disk_type", "Invalid Disk Type")
            return self.form_invalid(form)

        # Verify Disk Types & Sizes for the login node(s)
        try:
            my_info = disk_info[self.object.login_node_disk_type]
            if self.object.login_node_disk_size < my_info["minSizeGB"]:
                form.add_error(
                    "login_node_disk_size",
                    "Minimum Disk Size for "
                    f"{self.object.login_node_disk_type} is "
                    f"{my_info['minSizeGB']}"
                )
                return self.form_invalid(form)
            if self.object.login_node_disk_size > my_info["maxSizeGB"]:
                form.add_error(
                    "login_node_disk_size",
                    "Maximum Disk Size for "
                    f"{self.object.login_node_disk_type} is "
                    f"{my_info['maxSizeGB']}"
                )
                return self.form_invalid(form)
        except KeyError:
            form.add_error("login_node_disk_type", "Invalid Disk Type")
            return self.form_invalid(form)

        # Verify formset validity (surprised there's no method to do this)
        for formset, formset_name in [
            (mountpoints, "mountpoints"),
            (partitions, "partitions"),
        ]:
            if not formset.is_valid():
                form.add_error(None, f"Error in {formset_name} section")
                return self.form_invalid(form)

        # Delete any existing MountPoint that no longer appears in the formset
        existing_mount_points = MountPoint.objects.filter(cluster=self.object)
        for mount_point in existing_mount_points:
            if not any(
                mp_form.instance == mount_point for mp_form in mountpoints.forms
            ):
                logger.info(
                    "Deleting mount point: %s, ID: %s",
                    mount_point.mount_path,
                    mount_point.pk,
                )
                mount_point.delete()

        # Delete existing partitions that were marked for deletion in the form
        existing_partitions = ClusterPartition.objects.filter(cluster=self.object)
        logger.info(f"Processing total {len(partitions.forms)} partition forms.")
        # Bug fix: this previously logged len(partitions.forms) again instead
        # of the number of existing partitions.
        logger.info(f"Existing number of partitions is {len(existing_partitions)}.")
        for partition in existing_partitions:
            found = False
            for partition_form in partitions.forms:
                if partition_form.instance == partition:
                    found = True
                    if partition_form.cleaned_data.get('DELETE', False):
                        # Log the intent to delete then delete the partition
                        logger.info(
                            f"Partition: {partition.name} (ID: {partition.pk}) marked for deletion."
                        )
                        partition.delete()
                    else:
                        logger.info(
                            f"No deletion requested for existing partition: {partition.name}."
                        )
            if not found:
                # Log if no corresponding form was found for the partition
                logger.info(f"No form found for Partition: {partition.name}.")

        try:
            with transaction.atomic():
                # Save the modified Cluster object
                self.object.save()
                self.object = form.save()
                mountpoints.instance = self.object
                mountpoints.save()
                partitions.instance = self.object
                parts = partitions.save()
                try:
                    total_nodes_requested = {}
                    for part in parts:
                        # Hyperthreading halves the schedulable vCPUs per node
                        part.vCPU_per_node = machine_info[part.machine_type][
                            "vCPU"
                        ] // (1 if part.enable_hyperthreads else 2)
                        cpu_count = machine_info[part.machine_type]["vCPU"]
                        logger.info(f"{part.machine_type} CPU Count: {cpu_count}")

                        # Tier1 networking validation: requires a supported
                        # machine family and at least 30 vCPUs.
                        if part.enable_tier1_networking:
                            logger.info(
                                "User selected Tier1 networking, checking if nodes in partition are compatible."
                            )
                            tier_1_supported_prefixes = [
                                "n2-", "n2d-", "c2-", "c2d-",
                                "c3-", "c3d-", "m3-", "z3-",
                            ]
                            is_tier_1_compatible = any(
                                part.machine_type.startswith(prefix)
                                for prefix in tier_1_supported_prefixes
                            )
                            if not (cpu_count >= 30 and is_tier_1_compatible):
                                raise ValidationError(
                                    f"VM type {part.machine_type} is not compatible with Tier 1 networking."
                                )

                        # Validate GPU choice against the machine type's
                        # supported accelerators and count limits.
                        if part.GPU_type:
                            try:
                                accel_info = machine_info[part.machine_type][
                                    "accelerators"
                                ][part.GPU_type]
                                if (
                                    part.GPU_per_node < accel_info["min_count"]
                                    or part.GPU_per_node > accel_info["max_count"]
                                ):
                                    raise ValidationError(
                                        "Invalid number of GPUs of type "
                                        f"{part.GPU_type}"
                                    )
                            except KeyError as err:
                                raise ValidationError(
                                    f"Invalid GPU type {part.GPU_type}"
                                ) from err

                        # Machine_type / boot-disk combinations GCP rejects
                        invalid_combinations = [
                            ("c3-", "pd-standard"),
                            ("h3-", "pd-standard"),
                            ("h3-", "pd-ssd"),
                        ]
                        for machine_prefix, disk_type in invalid_combinations:
                            if (
                                part.machine_type.startswith(machine_prefix)
                                and part.boot_disk_type == disk_type
                            ):
                                logger.info("invalid disk")
                                raise ValidationError(
                                    f"Invalid combination: machine_type {part.machine_type} cannot be used with disk_type {disk_type}."
                                )

                        # Sum the total nodes for each reservation
                        if part.reservation_name:
                            if part.reservation_name not in total_nodes_requested:
                                total_nodes_requested[part.reservation_name] = 0
                            total_nodes_requested[part.reservation_name] += (
                                part.dynamic_node_count + part.static_node_count
                            )

                    # Validate total requested nodes against available nodes.
                    # The reservation lookup is loop-invariant, so fetch once
                    # (previously it was re-fetched for every reservation).
                    if total_nodes_requested:
                        reservations = cloud_info.get_vm_reservations(
                            "GCP",
                            self.object.cloud_credential.detail,
                            None,
                            self.object.cloud_zone,
                        )
                        for reservation_name, requested_nodes in (
                            total_nodes_requested.items()
                        ):
                            matching_reservation = reservations.get(
                                reservation_name
                            )
                            # Guard against an unknown reservation name, which
                            # previously crashed with a TypeError.
                            if not matching_reservation:
                                raise ValidationError(
                                    f"Reservation {reservation_name} not found."
                                )
                            available_nodes = int(
                                matching_reservation["instanceProperties"].get(
                                    "availableCount", 0
                                )
                            )
                            if requested_nodes > available_nodes:
                                raise ValidationError(
                                    f"Reservation {reservation_name} does not have enough available nodes. "
                                    f"Requested: {requested_nodes}, Available: {available_nodes}"
                                )
                except KeyError as err:
                    raise ValidationError(
                        "Error in Partition - invalid machine type: "
                        f"{part.machine_type}"
                    ) from err

                # Persist the vCPU_per_node updates made above
                parts = partitions.save()
        except ValidationError as ve:
            form.add_error(None, ve)
            return self.form_invalid(form)

        msg = (
            "Provisioning a new cluster. This may take up to 15 minutes."
        )
        if self.object.status == "r":
            msg = "Reconfiguring running cluster, this may take few minutes."
        messages.success(self.request, msg)

        # Be kind... Check filesystems to verify all in the same zone as us.
        for mp in self.object.mount_points.exclude(
            export__filesystem__impl_type=FilesystemImpl.BUILT_IN
        ):
            if mp.export.filesystem.cloud_zone != self.object.cloud_zone:
                messages.warning(
                    self.request,
                    "Possibly expensive: Filesystem "
                    f"{mp.export.filesystem.name} is in a different zone "
                    f"({mp.export.filesystem.cloud_zone}) than the cluster!",
                )
        return super().form_valid(form)
class ClusterDeleteView(LoginRequiredMixin, DeleteView):
    """Custom DeleteView for Cluster model"""

    model = Cluster
    template_name = "cluster/check_delete.html"

    def get_context_data(self, **kwargs):
        ctx = super().get_context_data(**kwargs)
        ctx["navtab"] = "cluster"
        return ctx

    def get_success_url(self):
        # NOTE(review): Django resolves the success URL before deleting the
        # row, so this lookup is expected to succeed - confirm on upgrade.
        cluster = Cluster.objects.get(pk=self.kwargs["pk"])
        messages.success(self.request, f"Cluster {cluster.name} deleted.")
        return reverse("clusters")
class ClusterDestroyView(LoginRequiredMixin, generic.DetailView):
    """Custom View to confirm Cluster destroy"""

    model = Cluster
    template_name = "cluster/check_destroy.html"

    def get_context_data(self, **kwargs):
        """List the applications and jobs that would be lost on destroy."""
        context = super().get_context_data(**kwargs)
        apps = Application.objects.filter(cluster=context["cluster"].id)
        context["applications"] = apps
        context["jobs"] = Job.objects.filter(application__in=apps)
        context["navtab"] = "cluster"
        return context
class ClusterCostView(LoginRequiredMixin, generic.DetailView):
    """Custom view for a cluster's cost analysis"""

    model = Cluster
    template_name = "cluster/cost.html"

    def get_context_data(self, **kwargs):
        """Rank users and applications by their spend on this cluster."""
        context = super().get_context_data(**kwargs)
        context["navtab"] = "cluster"
        cluster_id = context["cluster"].id

        # Only users with non-zero spend on this cluster are listed.
        user_spends = (
            (user.total_spend(cluster_id=cluster_id), user)
            for user in User.objects.all()
        )
        cluster_users = [
            (spend, user.total_jobs(cluster_id=cluster_id), user)
            for spend, user in user_spends
            if spend > 0
        ]
        cluster_apps = [
            (app.total_spend(), app)
            for app in Application.objects.filter(cluster=cluster_id)
        ]

        context["users_by_spend"] = sorted(
            cluster_users, key=lambda x: x[0], reverse=True
        )
        context["apps_by_spend"] = sorted(
            cluster_apps, key=lambda x: x[0], reverse=True
        )
        return context
class ClusterLogFileView(LoginRequiredMixin, StreamingFileView):
    """View for cluster provisioning logs"""

    # GCS bucket holding controller logs, resolved once at import time
    bucket = utils.load_config()["server"]["gcs_bucket"]

    # Index-addressable catalogue of viewable log files
    valid_logs = [
        {"title": "Terraform Log", "type": TerraformLogFile, "args": ()},
        {
            "title": "Startup Log",
            "type": GCSFile,
            "args": (bucket, "tmp/setup.log"),
        },
        {
            "title": "Ansible Sync Log",
            "type": GCSFile,
            "args": (bucket, "tmp/ansible.log"),
        },
        {
            "title": "System Log",
            "type": GCSFile,
            "args": (bucket, "var/log/messages"),
        },
        {
            "title": "Slurm slurmctld.log",
            "type": GCSFile,
            "args": (bucket, "var/log/slurm/slurmctld.log"),
        },
        {
            "title": "Slurm resume.log",
            "type": GCSFile,
            "args": (bucket, "var/log/slurm/resume.log"),
        },
        {
            "title": "Slurm suspend.log",
            "type": GCSFile,
            "args": (bucket, "var/log/slurm/suspend.log"),
        },
    ]

    def _create_file_info_object(self, logfile_info, *args, **kwargs):
        # Instantiate the log reader with its static args plus any extras
        return logfile_info["type"](*logfile_info["args"], *args, **kwargs)

    def get_file_info(self):
        """Resolve the requested log entry to a streaming file object."""
        # NOTE(review): a missing "logid" defaults to -1, i.e. the last
        # entry in valid_logs - confirm that is intentional.
        log_index = self.kwargs.get("logid", -1)
        cluster = get_object_or_404(Cluster, pk=self.kwargs.get("pk"))
        ci = ClusterInfo(cluster)
        tf_dir = ci.get_terraform_dir()
        bucket_prefix = f"clusters/{cluster.id}/controller_logs"

        entry = self.valid_logs[log_index]
        if entry["type"] == TerraformLogFile:
            extra_args = [tf_dir]
        elif entry["type"] == GCSFile:
            extra_args = [bucket_prefix]
        else:
            extra_args = []
        return self._create_file_info_object(entry, *extra_args)
class ClusterLogView(LoginRequiredMixin, generic.DetailView):
    """View to display cluster log files"""

    model = Cluster
    template_name = "cluster/log.html"

    def get_context_data(self, **kwargs):
        """Expose the catalogue of available log files by index and title."""
        context = super().get_context_data(**kwargs)
        context["log_files"] = [
            {"id": idx, "title": item["title"]}
            for idx, item in enumerate(ClusterLogFileView.valid_logs)
        ]
        context["navtab"] = "cluster"
        return context
class ClusterCostExportView(LoginRequiredMixin, generic.DetailView):
    """Export raw cost data per cluster as CSV"""

    model = Cluster

    def get(self, request, *args, **kwargs):
        """Return a CSV attachment listing every job run on this cluster."""
        response = HttpResponse(content_type="text/csv")
        # Bug fix: RFC 6266 requires double quotes around the filename; the
        # previous single-quoted value made browsers save the literal name
        # 'report.csv' (including the quotes).
        response["Content-Disposition"] = 'attachment; filename="report.csv"'
        writer = csv.writer(response)
        writer.writerow(["Job ID", "User", "Application", "Partition",
                         "Number of Nodes", "Ranks per Node", "Runtime (sec)",
                         "Node Price (per hour)", "Job Cost"])
        for job in Job.objects.filter(
                cluster=self.kwargs["pk"]).values_list(
                    "id", "user__username", "application__name",
                    "partition__name", "number_of_nodes", "ranks_per_node",
                    "runtime", "node_price", "job_cost"):
            writer.writerow(job)
        return response
# For APIs
class ClusterViewSet(viewsets.ModelViewSet):
    """Custom ModelViewSet for Cluster model"""

    permission_classes = (IsAuthenticated,)
    serializer_class = ClusterSerializer

    def get_queryset(self):
        """Cluster admins see all clusters; ordinary users only those they
        are authorised to use."""
        user = self.request.user
        is_cluster_admin = any(
            role.id == Role.CLUSTERADMIN for role in user.roles.all()
        )
        if is_cluster_admin:
            return Cluster.objects.all().order_by("name")
        return Cluster.objects.filter(
            authorised_users__id=user.id
        ).order_by("name")

    @action(methods=["get"], detail=True, permission_classes=[IsAuthenticated])
    def get_users(self, request, unused_pk):
        """List the usernames and uids authorised on this cluster."""
        cluster = self.get_object()
        return Response(
            [
                {"username": member.username, "uid": member.id}
                for member in cluster.authorised_users.all()
            ]
        )

    @action(methods=["get"], detail=True, permission_classes=[IsAuthenticated])
    def get_instance_limits(self, request, unused_pk):
        """Report per-instance-type node limits for this cluster."""
        cluster = self.get_object()
        return Response(
            [
                {"instance_name": instance.name, "nodes": node_count}
                for instance, node_count in cluster.instance_limits()
            ]
        )

    @action(
        methods=["get"],
        detail=True,
        permission_classes=[IsAuthenticated],
        url_path="filesystem.fact",
        suffix=".fact",
    )
    def ansible_filesystem(self, request, unused_pk):
        """Emit this cluster's mount points as an Ansible fact document."""
        # Translate single-letter fstype codes to Ansible filesystem names
        fs_type_translator = {
            " ": "none",
            "n": "nfs",
            "e": "efs",
            "l": "lustre",
            "b": "beegfs",
        }
        cluster = self.get_object()
        mounts = [
            {
                "path": mount.mount_path,
                "src": mount.mount_source,
                "fstype": fs_type_translator[mount.fstype],
                "opts": mount.mount_options,
            }
            for mount in cluster.mount_points.all()
        ]
        return JsonResponse({"mounts": mounts})
class InstancePricingViewSet(viewsets.ViewSet):
    """ModelviewSet providing GCP instance pricing"""

    permission_classes = (IsAuthenticated,)
    authentication_classes = [SessionAuthentication, TokenAuthentication]

    def retrieve(self, request, pk=None):
        """Price the machine type of the partition identified by ``pk``."""
        partition = get_object_or_404(ClusterPartition, pk=pk)
        cluster = partition.cluster
        machine = partition.machine_type
        price = cloud_info.get_instance_pricing(
            "GCP",
            cluster.cloud_credential.detail,
            cluster.cloud_region,
            cluster.cloud_zone,
            machine,
            (partition.GPU_type, partition.GPU_per_node),
        )
        # TODO: Currency is hard-coded to USD
        return JsonResponse(
            {"instance": machine, "price": price, "currency": "USD"}
        )

    def list(self, request):
        # Listing all prices is not supported; return an empty document.
        return JsonResponse({})
class InstanceAvailabilityViewSet(viewsets.ViewSet):
    """ModelviewSet providing GCP instance availability across locations"""

    permission_classes = (IsAuthenticated,)
    authentication_classes = [SessionAuthentication, TokenAuthentication]

    def retrieve(self, request, pk=None):
        """Return machine-type details for ``pk`` in the given region/zone,
        or an empty document when the location is invalid or the API fails."""
        cluster = get_object_or_404(
            Cluster, pk=request.query_params.get("cluster", -1)
        )
        region = request.query_params.get("region", None)
        zone = request.query_params.get("zone", None)
        try:
            region_info = cloud_info.get_region_zone_info(
                "GCP", cluster.cloud_credential.detail
            )
            if zone not in region_info.get(region, []):
                return JsonResponse({})
            machine_info = cloud_info.get_machine_types(
                "GCP", cluster.cloud_credential.detail, region, zone
            )
            return JsonResponse(machine_info.get(pk, {}))
        # Want to fail gracefully here - but log the failure like list() does
        # instead of swallowing it silently (consistency fix).
        except Exception as err:  # pylint: disable=broad-except
            logger.exception("Exception during cloud API query:", exc_info=err)
        return JsonResponse({})

    def list(self, request):
        """List all machine-type names available in the given region/zone."""
        cluster = get_object_or_404(
            Cluster, pk=request.query_params.get("cluster", -1)
        )
        region = request.query_params.get("region", None)
        zone = request.query_params.get("zone", None)
        try:
            region_info = cloud_info.get_region_zone_info(
                "GCP", cluster.cloud_credential.detail
            )
            if zone not in region_info.get(region, []):
                logger.info(
                    "Unable to retrieve data for zone %s in region %s",
                    zone,
                    region,
                )
                return JsonResponse({})
            machine_info = cloud_info.get_machine_types(
                "GCP", cluster.cloud_credential.detail, region, zone
            )
            return JsonResponse({"machine_types": list(machine_info.keys())})
        # Can't do a lot about API failures, just log it and move on
        except Exception as err:  # pylint: disable=broad-except
            logger.exception("Exception during cloud API query:", exc_info=err)
        return JsonResponse({})
class DiskAvailabilityViewSet(viewsets.ViewSet):
    """API View providing GCP disk availability across locations"""

    permission_classes = (IsAuthenticated,)
    authentication_classes = [SessionAuthentication, TokenAuthentication]

    def list(self, request):
        """List disk types available in the requested region/zone, returning
        an empty document on invalid locations or cloud API failures."""
        cluster = get_object_or_404(
            Cluster, pk=request.query_params.get("cluster", -1)
        )
        region = request.query_params.get("region", None)
        zone = request.query_params.get("zone", None)
        try:
            region_info = cloud_info.get_region_zone_info(
                "GCP", cluster.cloud_credential.detail
            )
            if zone in region_info.get(region, []):
                disks = cloud_info.get_disk_types(
                    "GCP", cluster.cloud_credential.detail, region, zone
                )
                return JsonResponse({"disks": disks})
            logger.info(
                "Unable to retrieve data for zone %s in region %s",
                zone,
                region,
            )
        # Can't do a lot about API failures, just log it and move on
        except Exception as err:  # pylint: disable=broad-except
            logger.exception("Exception during cloud API query:", exc_info=err)
        return JsonResponse({})
# Other supporting views
class BackendCreateCluster(BackendAsyncView):
    """A view to make async call to create a new cluster"""

    @sync_to_async
    def get_orm(self, cluster_id):
        """Fetch the cluster and its credential detail in a sync context."""
        cluster = Cluster.objects.get(pk=cluster_id)
        creds = cluster.cloud_credential.detail
        return (cluster, creds)

    def cmd(self, unused_task_id, unused_token, cluster, creds):
        """Background task body: prepare the cluster's infrastructure."""
        ci = ClusterInfo(cluster)
        ci.prepare(creds)

    async def get(self, request, pk):
        """this will invoke the background tasks and return immediately"""
        # Mixins don't yet work with Async views
        if not await sync_to_async(lambda: request.user.is_authenticated)():
            # Bug fix: get_full_path must be called; previously the bound
            # method object itself was passed as the "next" URL.
            return redirect_to_login(request.get_full_path())
        await self.test_user_is_cluster_admin(request.user)

        args = await self.get_orm(pk)
        await self.create_task("Create Cluster", *args)
        return HttpResponseRedirect(
            reverse("cluster-update", kwargs={"pk": pk})
        )
class BackendReconfigureCluster(BackendAsyncView):
    """View to reconfigure the cluster."""

    @sync_to_async
    def get_orm(self, cluster_id):
        """Fetch the cluster in a sync context for the async handler."""
        cluster = Cluster.objects.get(pk=cluster_id)
        return (cluster,)

    def cmd(self, unused_task_id, unused_token, cluster):
        """Background task body: regenerate config and live-reconfigure."""
        ci = ClusterInfo(cluster)
        ci.update()
        ci.reconfigure_cluster()

    async def get(self, request, pk):
        """this will invoke the background tasks and return immediately"""
        # Mixins don't yet work with Async views
        if not await sync_to_async(lambda: request.user.is_authenticated)():
            # Bug fix: call get_full_path() instead of passing the method
            return redirect_to_login(request.get_full_path())
        await self.test_user_is_cluster_admin(request.user)

        args = await self.get_orm(pk)
        await self.create_task("Live Reconfigure the Cluster", *args)
        return HttpResponseRedirect(
            reverse("cluster-detail", kwargs={"pk": pk})
        )
class BackendStartCluster(BackendAsyncView):
    """A view to make async call to start a cluster"""

    @sync_to_async
    def get_orm(self, cluster_id):
        """Fetch the cluster and its credential detail in a sync context."""
        cluster = Cluster.objects.get(pk=cluster_id)
        creds = cluster.cloud_credential.detail
        return (cluster, creds)

    def cmd(self, unused_task_id, unused_token, cluster, creds):
        """Background task body: bring the cluster infrastructure up."""
        ci = ClusterInfo(cluster)
        ci.start_cluster(creds)

    async def get(self, request, pk):
        """this will invoke the background tasks and return immediately"""
        # Mixins don't yet work with Async views
        if not await sync_to_async(lambda: request.user.is_authenticated)():
            # Bug fix: call get_full_path() instead of passing the method
            return redirect_to_login(request.get_full_path())
        await self.test_user_is_cluster_admin(request.user)

        args = await self.get_orm(pk)
        await self.create_task("Start Cluster", *args)
        return HttpResponseRedirect(
            reverse("cluster-detail", kwargs={"pk": pk})
        )
class BackendDestroyCluster(BackendAsyncView):
    """A view to make async call to destroy a cluster"""

    @sync_to_async
    def get_orm(self, cluster_id):
        """Fetch the cluster in a sync context for the async handler."""
        cluster = Cluster.objects.get(pk=cluster_id)
        return (cluster,)

    def cmd(self, unused_task_id, unused_token, cluster):
        """Background task body: tear down the cluster infrastructure."""
        ci = ClusterInfo(cluster)
        ci.stop_cluster()

    async def post(self, request, pk):
        """this will invoke the background tasks and return immediately"""
        # Mixins don't yet work with Async views
        if not await sync_to_async(lambda: request.user.is_authenticated)():
            # Bug fix: call get_full_path() instead of passing the method
            return redirect_to_login(request.get_full_path())
        await self.test_user_is_cluster_admin(request.user)

        args = await self.get_orm(pk)
        await self.create_task("Destroy Cluster", *args)
        return HttpResponseRedirect(
            reverse("cluster-detail", kwargs={"pk": pk})
        )
class BackendSyncCluster(LoginRequiredMixin, generic.View):
    """Backend handler for cluster syncing"""

    def get(self, request, pk, *args, **kwargs):
        # Callback invoked by the C2 channel when the cluster answers the
        # SYNC command; it records the reported status on the Cluster row.
        def response(message):
            logger.info("Received SYNC Complete: %s", message)
            if message.get("cluster_id") != pk:
                # Mismatched IDs are logged but the status update still
                # proceeds for the cluster addressed by this view.
                logger.error(
                    "Cluster ID mismatch versus to callback: "
                    "expected %s, %s",
                    pk,
                    message.get("cluster_id"),
                )
            cluster = Cluster.objects.get(pk=pk)
            # Default to "r" when the message carries no status field
            cluster.status = message.get("status", "r")
            cluster.save()
            return True

        cluster = get_object_or_404(Cluster, pk=pk)
        # Mark the cluster "i" while the SYNC command is in flight
        # (presumably "in progress" - matches the loading states in list view)
        cluster.status = "i"
        cluster.save()
        c2.send_command(pk, "SYNC", data={}, on_response=response)
        # Redirect immediately; the callback updates status asynchronously
        return HttpResponseRedirect(
            reverse("cluster-detail", kwargs={"pk": pk})
        )
class BackendClusterStatus(LoginRequiredMixin, generic.View):
    """Backend handler returning a cluster's current status as JSON"""

    def get(self, request, pk, *args, **kwargs):
        """
        This handles GET request with parameter pk.
        for example: /backend/cluster-status/50
        """
        cluster = get_object_or_404(Cluster, pk=pk)
        logger.info("Current cluster %s status: %s", pk, cluster.status)
        return JsonResponse({'status': cluster.status})
class BackendAuthUserGCP(BackendAsyncView):
    """Backend handler to authorise GCP users on the cluster"""

    @sync_to_async
    def get_orm(self, cluster_id):
        """Fetch the cluster in a sync context for the async handler."""
        cluster = Cluster.objects.get(pk=cluster_id)
        return cluster

    def cmd(self, task_id, token, cluster, username):
        # This code path is disabled; the interactive flow in
        # BackendAuthUserGCP2 is used instead.
        # from ..cluster_manager.update_cluster import auth_user_gcloud
        # auth_user_gcloud(cluster, token, username, task_id)
        raise NotImplementedError()

    async def get(self, request, pk):
        """this will invoke the background tasks and return immediately"""
        # Mixins don't yet work with Async views
        if not await sync_to_async(lambda: request.user.is_authenticated)():
            # Bug fix: call get_full_path() instead of passing the method
            return redirect_to_login(request.get_full_path())
        await self.test_user_access_to_cluster(request.user, pk)

        cluster = await self.get_orm(pk)
        record = await self.create_task(
            "Auth User GCP", cluster, request.user.username
        )
        return JsonResponse({"taskid": record.id})
class BackendAuthUserGCP2(LoginRequiredMixin, generic.View):
    """Handler for stage 2 of the GCP user auth process"""

    # Process - A "GET" to get started from user's browser
    # This will send a C2 command to cluster to start the process
    # Cluster will then respond with a URL for user to visit
    # We use the 'Task' DB entry to inform client browser of new URL
    # Client POSTs back to this class the verify key
    # We use C2 to UPDATE to send to cluster
    # We get an ACK back from cluster, and delete the Task
    # Google side should update browser message to show completion

    def get(self, request, pk):
        # Kick off the registration flow: create a tracking Task and send the
        # REGISTER_USER_GCS command to the cluster over C2.
        cluster = get_object_or_404(Cluster, pk=pk)
        user = request.user
        try:
            # Requires a Google social account; its uid identifies the user
            # to the cluster-side registration script.
            user_uid = user.socialaccount_set.first().uid
        except AttributeError:
            # User doesn't have a Google SocialAccount.
            messages.error(
                request,
                "You are not signed in with a Google Account. This is required",
            )
            return HttpResponseRedirect(
                reverse("user-gcp-auth", kwargs={"pk": pk})
            )

        logger.info(
            "Beginning User GCS authorization process for %s on %s",
            user,
            cluster.name,
        )
        # The Task row is how the browser polls for progress/URL updates
        task = Task.objects.create(
            owner=user,
            title="Auth User GCP",
            data={"status": "Contacting Cluster"},
        )
        task.save()
        # Capture plain values (not ORM objects) for use inside the callback,
        # which runs later on the C2 response thread.
        cluster_id = cluster.id
        cluster_name = cluster.name
        task_id = task.id

        def callback(message):
            # Merge each status message from the cluster into the Task's data
            # so the browser can observe it; delete the Task once the cluster
            # reports a final exit_status.
            logger.info(
                "GCS Auth Status message received from cluster %s: %s",
                cluster_name,
                message["status"],
            )
            task = Task.objects.get(pk=task_id)
            task.data.update(message)
            task.save()
            if "exit_status" in message:
                logger.info(
                    "Final result from cluster %s for user auth to GCS was "
                    "status code %s",
                    cluster_name,
                    message["exit_status"],
                )
                task.delete()

        message_data = {
            "login_uid": user_uid,
        }
        comm_id = c2.send_command(
            cluster_id,
            "REGISTER_USER_GCS",
            on_response=callback,
            data=message_data,
        )
        # Store the C2 comm id so the later POST can be matched to this flow
        task.data["comm_id"] = comm_id
        task.save()

        return JsonResponse({"taskid": task_id})

    def post(self, request, pk):
        # Stage 2: browser posts back the verify key, which we forward to the
        # cluster via a C2 UPDATE on the original comm channel.
        cluster = get_object_or_404(Cluster, pk=pk)
        try:
            logger.debug("Received POST from browser for GCS Auth.")
            task_id = request.POST["task_id"]
            task = get_object_or_404(Task, pk=task_id)
            comm_id = request.POST["comm_id"]
            verify_key = request.POST["verify_key"]
            # The comm_id must match the ack recorded on the Task, otherwise
            # the POST does not belong to this auth flow.
            if task.data.get("ackid", None) != comm_id:
                logger.error(
                    "Ack ID mismatch: expected %s, received %s",
                    task.data.get("ackid", None),
                    comm_id,
                )
                return HttpResponseNotFound()
            c2.send_update(cluster.id, comm_id, data={"verify_key": verify_key})
        except KeyError as ke:
            logger.error("Missing POST data", exc_info=ke)
            return HttpResponseNotFound()
        return JsonResponse({})
class AuthUserGCP(LoginRequiredMixin, generic.View):
    """A view keep interactive watch on adding a user's GCS auth creds"""

    def get(self, request, pk):
        """Render the interactive GCS-auth watch page for the cluster."""
        cluster = Cluster.objects.get(pk=pk)
        context = {"cluster": cluster, "navtab": "cluster"}
        return render(request, "cluster/user_auth_gcp.html", context=context)