tools/memorystore-cluster-ops-framework/mrc_framework.py (190 lines of code) (raw):

#!/usr/bin/env python # Copyright 2023 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import redis import subprocess import time from google.cloud import logging import os import json import requests def read_config(): """ Read the configuration file. Returns: dict: The configuration dictionary. """ with open('config.json', 'r') as config_file: config = json.load(config_file) return config OUTPUT_LOGS = read_config()['OUTPUT_LOGS'] class redisCluster(redis.cluster.RedisCluster): """ A class to interact with a Redis cluster. Args: host (str): The host of the Redis cluster. port (int): The port of the Redis cluster. password (str): The password of the Redis cluster. """ def __init__(self, host, port, password): """ Initialize the RedisCluster object. Args: host (str): The host of the Redis cluster. port (int): The port of the Redis cluster. password (str): The password of the Redis cluster. """ self.host = host self.port = port self.password = password self.client = redis.cluster.RedisCluster(host=self.host, port=self.port, password=self.password, decode_responses=False) # Connect to Redis redis_client = self.client # Get all nodes in the cluster self.cluster_nodes = redis_client.execute_command('CLUSTER NODES') super().__init__(host=host, port=port, password=password) def getDBSize(self): """ Get the total number of keys in the Redis cluster. Returns: int: The total number of keys in the Redis cluster. """ cluster_nodes = self.cluster_nodes key_count = {} for node in cluster_nodes: # Check if the node is a master if 'master' not in str(cluster_nodes[node]['flags']): continue # Extract the IP and port of the node node_ip, node_port = node.split(':') # Connect to each node and get the key count node_client = redis.Redis(host=node_ip, port=int(node_port), decode_responses=True) key_count[node] = node_client.dbsize() node_client.close() # Print the key count for each node for node, count in key_count.items(): node_ip = node.split(':')[0] node_port = node.split(':')[1] totalDBSize = sum(key_count.values()) return totalDBSize def nrandomkeys(self, n): """ Get a list of `n` random keys from the Redis cluster. Args: n (int): The number of keys to return. Returns: set: A set of `n` random keys from the Redis cluster. """ redis_client = self.client size = min(n, self.getDBSize()) keys = set() counter = 0 while counter <=size : key = redis_client.randomkey() keys.add(key) counter += 1 return keys def delAllKeys(self): """ Delete all keys from the Redis cluster. """ cluster_nodes = self.cluster_nodes for node in cluster_nodes: node_port = node.split(':')[1] node_ip = node.split(':')[0] node_client = redis.cluster.RedisCluster(host=node_ip, port=int(node_port), decode_responses=True) node_client.flushdb() write_log(f"DB successfully flushed", target=OUTPUT_LOGS) def getVal(self, key): """ Get the value of a key from the Redis cluster. Args: key (str): The key to get the value of. Returns: str: The value of the key. """ redis_client = self.client key_type = redis_client.type(key) if key_type == b'string': return redis_client.get(key) elif key_type == b'hash': return redis_client.hgetall(key) elif key_type == b'list': return redis_client.lrange(key, 0, -1) elif key_type == b'set': return redis_client.smembers(key) elif key_type == b'zset': return redis_client.zrange(key, 0, -1) else: write_log(f"Key type not supported", target=OUTPUT_LOGS) # Add more cases as needed return None def backup_cluster(self, cluster_name, gcs_bucket, file_type): """ Backup the Redis cluster to a GCS bucket. Args: cluster_name (str): The name of the Redis cluster. gcs_bucket (str): The name of the GCS bucket to backup the cluster to. """ # Generate timestamp timestamp = time.strftime("%Y%m%d%H%M%S") write_log(f"Exporting Redis data to GCS at {timestamp}",target=OUTPUT_LOGS) prefix = f"{gcs_bucket}/mrc-redis-backups/{cluster_name}" # Construct the output filename output_filename = f"export_{timestamp}.{file_type}" # Construct the path path = f"{prefix}/{output_filename}" write_log(f"File will be placed at {path}", target=OUTPUT_LOGS) # Check if the directory exists in case of local storage. if not os.path.exists(prefix) and not gcs_bucket.startswith("gs://") : # Create the directory os.makedirs(prefix) # Construct the bash command riot_path = read_config()['riot_bin_path'] does_file_exist(riot_path) bash_command = f"{riot_path}/riot -h {self.host} -p {self.port} -c file-export {path}" write_log(f"Executing bash command: {bash_command}", target=OUTPUT_LOGS) webhook_url = read_config()['SLACK_WEBHOOK_URL'] # Run the bash command exec_subprocess(bash_command) write_log(f"Export successful. File uploaded to: {path}", target=OUTPUT_LOGS) send_slack_message( webhook_url=webhook_url, message=f"Backup successful for cluster {cluster_name} on {timestamp}") def restore_cluster(self, restore_file, mode = 'append'): """ Restore the Redis cluster from a GCS backup. Args: restore_file (str): The path to the GCS backup file. """ riot_path = read_config()['riot_bin_path'] does_file_exist(riot_path) does_file_exist(restore_file) if mode == 'append': pass elif mode == 'replace': self.delAllKeys() else: write_log(f"Invalid mode", target=OUTPUT_LOGS) exit(1) bash_command = f"{riot_path}/riot -h {self.host} -p {self.port} -c dump-import {restore_file}" write_log(f"Executing bash command: {bash_command}", target=OUTPUT_LOGS) exec_subprocess(bash_command) write_log(f"Import successful.", target=OUTPUT_LOGS) def exec_subprocess(bash_command): try: result = subprocess.run(bash_command, shell=True, check=True,capture_output = True, text = True) write_log(f"{result.stdout}", target=OUTPUT_LOGS) write_log(f"{result.stderr}", target=OUTPUT_LOGS) except subprocess.CalledProcessError as e: write_log(f"Error: {e.stderr}", target=OUTPUT_LOGS) exit(1) def does_file_exist(file): """ Check if a file exists. Args: file (str): The path to the file. Returns: bool: True if the file exists, False otherwise. """ if os.path.exists(file): return True else: write_log(f"Error: {file} does not exist.",target=OUTPUT_LOGS) exit(1) # Write to the log def write_log(message, target = "console"): """ Write a message to the log. Args: message (str): The message to write to the log. """ if target == "console": print(message) elif target == "cloud-logging": client = logging.Client() logger = client.logger('ms-validation-framework-logs') logger.log_text(message) else: print(message) client = logging.Client() logger = client.logger('ms-validation-framework-logs') logger.log_text(message) def replicate_data(source , target, replication_mode = 'snapshot', verification_mode = ''): """ Replicate data from one Redis cluster to another. Args: source (redisCluster): The source Redis cluster. target (redisCluster): The target Redis cluster. """ # Generate timestamp sourcehost = source.host sourceport = source.port tgthost = target.host tgtport = target.port verificiation_mode = verification_mode replication_mode = replication_mode # Construct the bash command riot_path = read_config()['riot_bin_path'] does_file_exist(riot_path) bash_command = f"{riot_path}/riot -h {sourcehost} -p {sourceport} --cluster replicate --mode={replication_mode} -h {tgthost} -p {tgtport} --cluster {verificiation_mode}" write_log(f"Executing bash command: {bash_command}", target=OUTPUT_LOGS) try: subprocess.run(bash_command, shell=True, check=True) except subprocess.CalledProcessError as e: write_log(f"Error: {e.stderr}", target=OUTPUT_LOGS) exit(1) write_log(f"Replication successful", target=OUTPUT_LOGS) def validateCounts(source, target): """ Validate that the number of keys in two Redis clusters are the same. Args: source (redisCluster): The source Redis cluster. target (redisCluster): The target Redis cluster. """ source_size = source.getDBSize() target_size = target.getDBSize() if source_size == target_size: write_log(f"Source and target DB sizes match: {source_size}. Count validation successful", target=OUTPUT_LOGS) return True else: write_log(f"Source and target DB sizes do not match: {source_size} != {target_size}", target=OUTPUT_LOGS) return False def deepValidate(sampling_factor, src, tgt): """ Deep validate the data in two Redis clusters. Args: sampling_factor (float): The sampling factor to use for the deep validation. src (redisCluster): The source Redis cluster. tgt (redisCluster): The target Redis cluster. """ sample_count = int(round(sampling_factor * src.getDBSize())) # Get the samples to test in a list samples = src.nrandomkeys(sample_count) validationPassed = True for key in samples: key_exists = tgt.exists(key) if key_exists: srcVal = src.getVal(key) tgtVal = tgt.getVal(key) if srcVal != tgtVal: write_log(f"Invalid Value for key '{key}':\n {tgtVal} \n {srcVal}", target=OUTPUT_LOGS) validationPassed = False else: pass else: write_log(f"Key '{key}' is NOT present in Redis.", target=OUTPUT_LOGS) validationPassed = False if validationPassed: write_log(f"Deep validation successful", target=OUTPUT_LOGS) else: write_log(f"Deep validation failed", target=OUTPUT_LOGS) return validationPassed def send_slack_message(webhook_url, message): payload = { "text": message } headers = { "Content-Type": "application/json" } response = requests.post(webhook_url, data=json.dumps(payload), headers=headers) if response.status_code == 200: write_log(f"Message sent successfully!", target=OUTPUT_LOGS) else: write_log(f"Message failed to send to SLACK", target=OUTPUT_LOGS)