backend/bms_app/services/operations/wave.py (169 lines of code) (raw):

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from datetime import datetime

from marshmallow import ValidationError

from bms_app.models import SourceDB, Config, OperationType, Wave, db
from bms_app.schema import DMSConfigSchema
from bms_app.services.ansible import AnsibleConfigService
from bms_app.services.control_node import (
    DeployControlNodeService, RollbackConrolNodeService
)
from bms_app.services.status_handlers.operation import (
    DeploymentOperationStatusHandler
)
from bms_app.services.dms import DMS
from bms_app import settings
from google.cloud import secretmanager

from .base import BaseOperation
from .db_mappings import get_wave_db_mappings_objects


logger = logging.getLogger(__name__)

# Shared Secret Manager client, created once when this module is imported.
secrets_client = secretmanager.SecretManagerServiceClient()


class BaseWaveOperation(BaseOperation):
    """Base class to run wave operations.

    Subclasses must set OPERATION_TYPE, CONTROL_NODE_CLS and
    FILTER_DEPLOYABLE_DBS_ONLY; ``__init_subclass__`` raises
    NotImplementedError at class-definition time if any is left as None.
    """

    OPERATION_TYPE = None
    CONTROL_NODE_CLS = None
    FILTER_DEPLOYABLE_DBS_ONLY = None
    GCS_CONFIG_DIR_TMPL = 'wave_{wave_id}_{operation_id}_{dt}'
    # Region passed to the DMS client for connection profiles and jobs.
    DMS_REGION = 'us-central1'

    def __init_subclass__(cls, **kwargs):
        """Check if all required variables are set in subclasses."""
        super().__init_subclass__(**kwargs)
        required_attrs = (
            'OPERATION_TYPE',
            'CONTROL_NODE_CLS',
            'FILTER_DEPLOYABLE_DBS_ONLY',
        )
        missing_attrs = [
            attr for attr in required_attrs if getattr(cls, attr) is None
        ]
        if missing_attrs:
            raise NotImplementedError(
                f'{cls.__name__} missing required class variables: {missing_attrs}'
            )

    def run(self, wave_id, db_ids=None):
        """Start this operation for a wave.

        Locks the wave row, checks it is not already running, marks it as
        running, creates the operation models and starts pre-deployment.
        A failure while starting pre-deployment is logged and handed to
        ``_handle_pre_deployment_failure`` instead of being propagated.
        ``_post_operation_action`` runs in both cases, followed by a commit.

        Args:
            wave_id: id of the Wave to operate on.
            db_ids: optional ids restricting which databases are included.

        Raises:
            ValidationError: if the wave is already running (and presumably
                from ``_validate_db_mappings_objects`` — defined in
                BaseOperation, confirm there).
        """
        wave = Wave.query.with_for_update().get(wave_id)
        self._validate_wave_status(wave)

        db_mappings_objects = get_wave_db_mappings_objects(
            wave_id=wave_id,
            db_ids=db_ids,
            filter_deployable=self.FILTER_DEPLOYABLE_DBS_ONLY
        )
        self._validate_db_mappings_objects(db_mappings_objects)

        self._set_wave_status_running(wave)

        operation = self._create_operation_model(wave_id)
        self._log(operation, db_mappings_objects)
        self._create_operation_details_models(operation, db_mappings_objects)

        try:
            self._start_pre_deployment(wave, operation, db_mappings_objects)
        except Exception:
            # Deliberate top-level boundary: log with traceback, then mark
            # the operation as terminated instead of re-raising.
            logger.exception(
                'error starting %s %s',
                operation.operation_type.value,
                wave_id
            )
            self._handle_pre_deployment_failure(operation)

        self._post_operation_action(db_mappings_objects)
        db.session.commit()

    def _get_dms_config(self, source_db: SourceDB) -> dict:
        """Load the DMS connection settings stored for ``source_db``.

        When the stored config names a Secret Manager secret through
        ``password_secret_id``, the latest version of that secret is read
        and placed under the ``password`` key.

        Returns:
            The mapping produced by ``DMSConfigSchema().load``.

        Raises:
            sqlalchemy NoResultFound / MultipleResultsFound: from ``.one()``
                when there is not exactly one Config row for the DB.
            marshmallow.ValidationError: if the stored values fail the schema.
        """
        config: Config = (
            db.session.query(Config).filter_by(db_id=source_db.id).one()
        )
        dms_config = DMSConfigSchema().load(config.dms_config_values)

        secret_id = dms_config.get('password_secret_id')
        if secret_id:
            secret_name = secrets_client.secret_version_path(
                settings.GCP_PROJECT_NAME, secret_id, 'latest'
            )
            request = secretmanager.AccessSecretVersionRequest(name=secret_name)
            response = secrets_client.access_secret_version(request=request)
            dms_config['password'] = response.payload.data.decode('UTF-8')
        return dms_config

    def _start_dms_pre_deployment_local(self, wave, operation, dms_mappings):
        """Start a DMS migration for every mapping in ``dms_mappings``.

        ``wave`` and ``operation`` are not used yet; they are kept so the
        signature matches the other pre-deployment steps.
        """
        logger.info('dms deployment started')
        dms = DMS(project_id=settings.GCP_PROJECT_NAME, region=self.DMS_REGION)
        # TODO: do this async using operations callbacks
        for mapping in dms_mappings:
            self._start_dms_migration(dms, mapping)

    def _start_dms_migration(self, dms, mapping):
        """Chain the DMS calls for one mapping through ``add_done_callback``.

        Order: create source connection profile -> create destination
        connection profile -> create migration job -> start migration job.

        This lives in its own method so each callback closes over this
        mapping's names. Closures defined directly in the caller's loop
        would all read the last iteration's values once the futures finish.
        """
        config = self._get_dms_config(mapping.db)

        db_name = mapping.db.db_name
        source_conn_name = f'waverunner-source-{db_name}'
        dest_conn_name = f'waverunner-target-for-{db_name}'
        job_name = f'waverunner-{db_name}'
        job_display_name = f'Waverunner job for {db_name}'

        # Never log ``config`` itself: it may hold the decrypted password.
        logger.debug('loaded DMS source config for %s', db_name)

        # NOTE(review): the callbacks ignore ``result``, so a failed step
        # presumably still triggers the next one — confirm what the DMS
        # futures expose and stop the chain on error.
        def start_job(result):
            logger.info('starting job...')
            dms.start_migration_job(job_name)

        def create_job(result):
            logger.info('creating job...')
            dms.create_migration_job(
                name=job_name,
                display_name=job_display_name,
                source_conn=source_conn_name,
                destination_conn=dest_conn_name
            ).add_done_callback(start_job)

        def create_dest_connection(result):
            logger.info('creating destination connection...')
            dms.create_destination_connection_profile(
                name=dest_conn_name,
                source_conn_name=source_conn_name
            ).add_done_callback(create_job)

        logger.info('creating source connection...')
        dms.create_source_connection_profile(
            name=source_conn_name,
            host=mapping.db.server,
            port=config['port'],
            username=config['username'],
            password=config['password']
        ).add_done_callback(create_dest_connection)

    def _start_pre_deployment(self, wave, operation, db_mappings_objects):
        """Start DMS migrations, then generate ansible configs and start the control node.

        Mappings with ``is_dms`` set go to ``_start_dms_pre_deployment_local``.
        For the remaining (BMS) mappings, ansible config files are generated
        and uploaded to ``gcs_config_dir``, and then ``CONTROL_NODE_CLS.run``
        is called. Nothing further happens when there are no BMS mappings.
        """
        dms_mappings = [m for m in db_mappings_objects if m.is_dms]
        bms_mappings = [m for m in db_mappings_objects if not m.is_dms]
        logger.debug('dms: %s', dms_mappings)
        logger.debug('bms: %s', bms_mappings)

        # NOTE(review): this runs for every subclass, including
        # RollbackService — confirm that starting DMS migrations during a
        # rollback is intended.
        if dms_mappings:
            self._start_dms_pre_deployment_local(wave, operation, dms_mappings)

        if not bms_mappings:
            return

        gcs_config_dir = self._get_gcs_config_dir(
            wave_id=operation.wave_id,
            operation_id=operation.id
        )

        # generate and upload ansible config files
        AnsibleConfigService(bms_mappings, gcs_config_dir).run()

        # run control node
        self.CONTROL_NODE_CLS.run(
            project=wave.project,
            operation=operation,
            gcs_config_dir=gcs_config_dir,
            wave=wave,
            total_targets=self._count_total_targets(bms_mappings),
        )

    @staticmethod
    def _validate_wave_status(wave):
        """Raise ValidationError if ``wave`` is already running."""
        if wave.is_running:
            raise ValidationError({'wave': ['is already running']})

    @staticmethod
    def _set_wave_status_running(wave):
        """Mark ``wave`` as running and commit immediately.

        The commit also ends the transaction that holds the row lock taken
        by ``with_for_update`` in ``run``.
        """
        wave.is_running = True
        db.session.add(wave)
        db.session.commit()

    @staticmethod
    def _handle_pre_deployment_failure(operation):
        """Terminate ``operation`` through its status handler, stamped now."""
        DeploymentOperationStatusHandler(
            operation,
            completed_at=datetime.now()
        ).terminate()

    @staticmethod
    def _post_operation_action(db_mappings_objects):
        """Hook run after pre-deployment is started; no-op by default."""


class DeploymentService(BaseWaveOperation):
    """Deploy the deployable databases of a wave."""

    FILTER_DEPLOYABLE_DBS_ONLY = True
    OPERATION_TYPE = OperationType.DEPLOYMENT
    CONTROL_NODE_CLS = DeployControlNodeService


class RollbackService(BaseWaveOperation):
    """Roll back an explicit set of databases in a wave."""

    FILTER_DEPLOYABLE_DBS_ONLY = False
    OPERATION_TYPE = OperationType.ROLLBACK
    CONTROL_NODE_CLS = RollbackConrolNodeService

    def run(self, wave_id, db_ids=None):
        """Run a rollback; ``db_ids`` must be a non-empty collection.

        ``db_ids`` defaults to None only so that omitting it raises the
        ValidationError below instead of a TypeError.
        """
        if not db_ids:
            raise ValidationError(
                {'db_ids': ['is required for rollback']}
            )
        super().run(wave_id, db_ids)

    @staticmethod
    def _post_operation_action(db_mappings_objects):
        """Delete the restore config of each rolled-back database.

        NOTE(review): ``run`` calls this even when starting pre-deployment
        failed — confirm that deleting restore configs then is intended.
        """
        for obj in db_mappings_objects:
            if obj.db.restore_config:
                db.session.delete(obj.db.restore_config)