backend/bms_app/services/operations/wave.py (169 lines of code) (raw):
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from datetime import datetime
from marshmallow import ValidationError
from bms_app.models import SourceDB, Config, OperationType, Wave, db
from bms_app.schema import DMSConfigSchema
from bms_app.services.ansible import AnsibleConfigService
from bms_app.services.control_node import (
DeployControlNodeService, RollbackConrolNodeService
)
from bms_app.services.status_handlers.operation import (
DeploymentOperationStatusHandler
)
from bms_app.services.dms import DMS
from bms_app import settings
from google.cloud import secretmanager
from .base import BaseOperation
from .db_mappings import get_wave_db_mappings_objects
# Module logger, named after this module.
logger = logging.getLogger(name=__name__)
# One Secret Manager client, created at import time and reused by
# BaseWaveOperation._get_dms_config to read DMS passwords.
secrets_client = secretmanager.SecretManagerServiceClient()
class BaseWaveOperation(BaseOperation):
    """Base class to run wave operations.

    Subclasses must set ``OPERATION_TYPE``, ``CONTROL_NODE_CLS`` and
    ``FILTER_DEPLOYABLE_DBS_ONLY``; class creation raises
    ``NotImplementedError`` if any of them is left as ``None``.

    * ``FILTER_DEPLOYABLE_DBS_ONLY`` is passed as ``filter_deployable`` to
      ``get_wave_db_mappings_objects``.
    * ``CONTROL_NODE_CLS.run`` is invoked for the non-DMS (BMS) mappings.
    """

    OPERATION_TYPE = None
    CONTROL_NODE_CLS = None
    FILTER_DEPLOYABLE_DBS_ONLY = None
    GCS_CONFIG_DIR_TMPL = 'wave_{wave_id}_{operation_id}_{dt}'
    # Region for Database Migration Service resources (was hard-coded
    # inline); overridable by subclasses.
    DMS_REGION = 'us-central1'

    def __init_subclass__(cls, **kwargs):
        """Check if all required variables are set in subclasses."""
        super().__init_subclass__(**kwargs)
        required_attrs = (
            'OPERATION_TYPE', 'CONTROL_NODE_CLS', 'FILTER_DEPLOYABLE_DBS_ONLY',
        )
        missing_attrs = [
            attr for attr in required_attrs if getattr(cls, attr) is None
        ]
        if missing_attrs:
            raise NotImplementedError(
                f'{cls.__name__} missing required class variables: {missing_attrs}'
            )

    def run(self, wave_id, db_ids=None):
        """Start this operation for a wave.

        Locks the wave row (``with_for_update``), rejects it if it is already
        running, collects and validates its db mappings, marks the wave as
        running, records the operation and its per-db details, then starts
        pre-deployment. An exception while starting pre-deployment is logged
        and the operation is terminated instead of propagating.
        ``_post_operation_action`` runs and the session is committed in
        either case.

        Args:
            wave_id: id of the ``Wave`` to operate on.
            db_ids: optional db ids, passed through to
                ``get_wave_db_mappings_objects``.

        Raises:
            ValidationError: if the wave is already running; the mapping
                validation may raise as well.
        """
        wave = Wave.query.with_for_update().get(wave_id)
        self._validate_wave_status(wave)
        db_mappings_objects = get_wave_db_mappings_objects(
            wave_id=wave_id,
            db_ids=db_ids,
            filter_deployable=self.FILTER_DEPLOYABLE_DBS_ONLY
        )
        self._validate_db_mappings_objects(db_mappings_objects)
        self._set_wave_status_running(wave)
        operation = self._create_operation_model(wave_id)
        self._log(operation, db_mappings_objects)
        self._create_operation_details_models(operation, db_mappings_objects)
        try:
            self._start_pre_deployment(wave, operation, db_mappings_objects)
        except Exception:
            # Boundary handler: keep the traceback in the logs and mark the
            # operation terminated rather than leaving it in a running state.
            logger.exception(
                'error starting %s %s', operation.operation_type.value, wave_id
            )
            self._handle_pre_deployment_failure(operation)
        self._post_operation_action(db_mappings_objects)
        db.session.commit()

    def _get_dms_config(self, source_db: SourceDB) -> dict:
        """Return the DMS connection settings for ``source_db``.

        Loads the db's ``Config`` row, deserializes ``dms_config_values``
        through ``DMSConfigSchema`` and, when ``password_secret_id`` is set,
        stores the latest version of that Secret Manager secret in
        ``dms_config['password']``.

        Raises:
            Whatever ``Query.one()`` raises when there is not exactly one
            ``Config`` row, and ``marshmallow.ValidationError`` when the
            stored values do not match the schema.
        """
        config: Config = (
            db.session.query(Config).filter_by(db_id=source_db.id).one()
        )
        # Schema.load returns the deserialized data (a dict), not a Schema.
        dms_config = DMSConfigSchema().load(config.dms_config_values)
        # .get(): an optional field may be absent from the loaded data.
        secret_id = dms_config.get('password_secret_id')
        if secret_id:
            secret_name = secrets_client.secret_version_path(
                settings.GCP_PROJECT_NAME, secret_id, 'latest'
            )
            request = secretmanager.AccessSecretVersionRequest(name=secret_name)
            response = secrets_client.access_secret_version(request=request)
            dms_config['password'] = response.payload.data.decode('UTF-8')
        return dms_config

    def _start_dms_pre_deployment_local(self, wave, operation, dms_mappings):
        """Start a DMS migration chain for every mapping in ``dms_mappings``.

        Does nothing (and does not construct a ``DMS`` client) when there are
        no DMS mappings. ``wave`` and ``operation`` are not used here.
        """
        if not dms_mappings:
            return
        logger.info('dms deployment started')
        dms = DMS(project_id=settings.GCP_PROJECT_NAME, region=self.DMS_REGION)
        # TODO: do this async using operations callbacks
        for mapping in dms_mappings:
            self._start_dms_migration(dms, mapping)

    def _start_dms_migration(self, dms, mapping):
        """Chain the DMS calls for one mapping.

        Steps: source connection profile -> destination connection profile
        -> migration job -> start job, each triggered from the previous
        step's ``add_done_callback``.

        This lives in its own function so every mapping gets its own scope.
        The callbacks close over the names below and fire asynchronously; as
        loop variables they would be rebound to a later mapping before the
        callbacks run.
        """
        db_name = mapping.db.db_name
        config = self._get_dms_config(mapping.db)
        source_conn_name = f'waverunner-source-{db_name}'
        dest_conn_name = f'waverunner-target-for-{db_name}'
        job_name = f'waverunner-{db_name}'
        job_display_name = f'Waverunner job for {db_name}'
        # Never log ``config`` itself: it contains the plaintext password.
        logger.info('starting DMS migration for %s', db_name)

        # NOTE(review): the callbacks ignore ``result``, so a failed step
        # still triggers the next one. Confirm what the DMS wrapper returns
        # and check ``result.exception()`` if it is a google.api_core future.
        def start_job(result):
            logger.info('starting job...')
            dms.start_migration_job(job_name)

        def create_job(result):
            logger.info('creating job...')
            dms.create_migration_job(
                name=job_name,
                display_name=job_display_name,
                source_conn=source_conn_name,
                destination_conn=dest_conn_name
            ).add_done_callback(start_job)

        def create_dest_connection(result):
            logger.info('creating destination connection...')
            dms.create_destination_connection_profile(
                name=dest_conn_name,
                source_conn_name=source_conn_name
            ).add_done_callback(create_job)

        logger.info('creating source connection...')
        dms.create_source_connection_profile(
            name=source_conn_name,
            host=mapping.db.server,
            port=config['port'],
            username=config['username'],
            password=config['password']
        ).add_done_callback(create_dest_connection)

    def _start_pre_deployment(self, wave, operation, db_mappings_objects):
        """Start DMS migrations, then handle the BMS mappings.

        Mappings are split on ``is_dms``. DMS mappings go to
        ``_start_dms_pre_deployment_local``. For the remaining (BMS)
        mappings, if any, ansible configs are generated and uploaded to a
        GCS config dir and ``CONTROL_NODE_CLS.run`` is started.
        """
        dms_mappings = [m for m in db_mappings_objects if m.is_dms]
        bms_mappings = [m for m in db_mappings_objects if not m.is_dms]
        logger.debug('dms mappings: %s', dms_mappings)
        logger.debug('bms mappings: %s', bms_mappings)
        self._start_dms_pre_deployment_local(wave, operation, dms_mappings)
        if not bms_mappings:
            return
        gcs_config_dir = self._get_gcs_config_dir(
            wave_id=operation.wave_id,
            operation_id=operation.id
        )
        # generate and upload ansible config files
        AnsibleConfigService(bms_mappings, gcs_config_dir).run()
        # run control node
        self.CONTROL_NODE_CLS.run(
            project=wave.project,
            operation=operation,
            gcs_config_dir=gcs_config_dir,
            wave=wave,
            total_targets=self._count_total_targets(bms_mappings),
        )

    @staticmethod
    def _validate_wave_status(wave):
        """Raise ``ValidationError`` if ``wave`` is already running."""
        if wave.is_running:
            raise ValidationError({'wave': ['is already running']})

    @staticmethod
    def _set_wave_status_running(wave):
        """Flag ``wave`` as running and commit immediately."""
        wave.is_running = True
        db.session.add(wave)
        db.session.commit()

    @staticmethod
    def _handle_pre_deployment_failure(operation):
        """Terminate ``operation`` via its status handler, stamped with now."""
        DeploymentOperationStatusHandler(
            operation,
            completed_at=datetime.now()
        ).terminate()

    @staticmethod
    def _post_operation_action(db_mappings_objects):
        """Hook run after pre-deployment is attempted; no-op by default."""
        pass
class DeploymentService(BaseWaveOperation):
    """Wave deployment operation.

    Selects db mappings with ``filter_deployable=True`` and runs
    ``DeployControlNodeService`` as the control node.
    """

    OPERATION_TYPE = OperationType.DEPLOYMENT
    CONTROL_NODE_CLS = DeployControlNodeService
    FILTER_DEPLOYABLE_DBS_ONLY = True
class RollbackService(BaseWaveOperation):
    """Wave rollback operation for an explicit set of dbs.

    Mappings are selected with ``filter_deployable=False`` and
    ``RollbackConrolNodeService`` is used as the control node. ``db_ids`` is
    mandatory. After pre-deployment is attempted, each db's
    ``restore_config`` (when present) is deleted.
    """

    OPERATION_TYPE = OperationType.ROLLBACK
    CONTROL_NODE_CLS = RollbackConrolNodeService
    FILTER_DEPLOYABLE_DBS_ONLY = False

    def run(self, wave_id, db_ids=None):
        """Validate ``db_ids`` and delegate to ``BaseWaveOperation.run``.

        ``db_ids`` defaults to ``None`` to match the base signature. A
        missing or empty value is then reported as the intended
        ``ValidationError`` instead of a ``TypeError`` for a missing
        positional argument.

        Raises:
            ValidationError: if ``db_ids`` is not provided or is empty.
        """
        if not db_ids:
            raise ValidationError(
                {'db_ids': ['is required for rollback']}
            )
        return super().run(wave_id, db_ids)

    @staticmethod
    def _post_operation_action(db_mappings_objects):
        """Delete the ``restore_config`` of each rolled-back db, if any.

        Deletions are only staged on ``db.session``; committing is left to
        the caller.
        """
        for obj in db_mappings_objects:
            restore_config = obj.db.restore_config
            if restore_config:
                db.session.delete(restore_config)