parquet_flask/aws/aws_s3.py (95 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from io import BytesIO

from parquet_flask.aws.aws_cred import AwsCred
from parquet_flask.utils.file_utils import FileUtils

LOGGER = logging.getLogger(__name__)


class AwsS3(AwsCred):
    """Helper around an S3 client that works on a single "target" object.

    Call :meth:`set_s3_url` first; the object-level methods (stream, size,
    tagging, download) then operate on the bucket/key parsed from that URL.
    The listing helper :meth:`get_child_s3_files` takes an explicit
    bucket and prefix instead.
    """

    def __init__(self):
        super().__init__()
        # URL prefixes accepted by split_s3_url.
        self.__valid_s3_schemas = ['s3://', 's3a://', 's3s://']
        # Client obtained from AwsCred.get_client; presumably a boto3 S3
        # client — TODO confirm against AwsCred.
        self.__s3_client = self.get_client('s3')
        self.__target_bucket = None
        self.__target_key = None

    def get_s3_stream(self):
        """Return the ``Body`` of a ``get_object`` call on the target object."""
        return self.__s3_client.get_object(Bucket=self.__target_bucket, Key=self.__target_key)['Body']

    def read_small_txt_file(self):
        """Read the whole target object and decode it as UTF-8 text.

        Intended for small files only: the entire object is read into memory.

        :return: text file contents
        """
        stream = self.get_s3_stream()
        try:
            return stream.read().decode('UTF-8')
        finally:
            # Release the underlying HTTP connection even if decoding fails.
            stream.close()

    def get_s3_obj_size(self):
        """Return the size in bytes of the target object, or -1 if unknown."""
        s3_obj_head = self.__s3_client.head_object(
            Bucket=self.__target_bucket,
            Key=self.__target_key,
        )
        # Read the parsed ContentLength field. The former `is None` check ran
        # after int(), so it could never trigger; now a missing length maps to -1.
        s3_obj_size = s3_obj_head.get('ContentLength')
        if s3_obj_size is None:
            return -1
        return int(s3_obj_size)

    def __get_all_s3_files_under(self, bucket, prefix, with_versions=False):
        """Yield every listing entry under ``bucket``/``prefix`` across all pages.

        :param bucket: bucket name
        :param prefix: key prefix to list
        :param with_versions: when exactly ``True``, list object versions
            (``list_object_versions`` / ``Versions``) instead of current
            objects (``list_objects_v2`` / ``Contents``)
        """
        if with_versions is True:
            list_method_name, page_key = 'list_object_versions', 'Versions'
        else:
            list_method_name, page_key = 'list_objects_v2', 'Contents'
        paginator = self.__s3_client.get_paginator(list_method_name)
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            # A page may lack the entry key entirely (e.g. nothing matched).
            yield from page.get(page_key, [])

    def get_child_s3_files(self, bucket, prefix, additional_checks=lambda x: True, with_versions=False):
        """Yield ``(key, size)`` for listing entries under ``bucket``/``prefix``.

        :param bucket: bucket name
        :param prefix: key prefix to list
        :param additional_checks: predicate on the raw listing entry; entries
            for which it returns a falsy value are skipped
        :param with_versions: list object versions instead of current objects
        """
        for file_obj in self.__get_all_s3_files_under(bucket, prefix, with_versions=with_versions):
            if additional_checks(file_obj):
                yield file_obj['Key'], file_obj['Size']

    def set_s3_url(self, s3_url):
        """Set the target bucket/key from ``s3_url`` and return ``self`` for chaining."""
        LOGGER.debug('setting s3_url: %s', s3_url)
        self.__target_bucket, self.__target_key = self.split_s3_url(s3_url)
        LOGGER.debug('props: %s, %s', self.__target_bucket, self.__target_key)
        return self

    def split_s3_url(self, s3_url):
        """Split an S3 URL into ``(bucket, key)``.

        :param s3_url: URL starting with 's3://', 's3a://' or 's3s://'
        :return: tuple ``(bucket, key)``; ``key`` is '' when the URL names
            only a bucket
        :raises ValueError: if the URL does not start with a supported scheme
        """
        s3_schema = [k for k in self.__valid_s3_schemas if s3_url.startswith(k)]
        if len(s3_schema) != 1:
            raise ValueError('invalid s3 url: {}'.format(s3_url))
        # partition() copes with a URL that has no '/' after the bucket; the
        # previous find()-based slicing returned ('', <bucket>) in that case.
        bucket, _, key = s3_url[len(s3_schema[0]):].partition('/')
        return bucket, key

    def __tag_existing_obj(self, other_tags=None):
        """Write ``other_tags`` as the complete TagSet of the target object.

        Values are stringified. Does nothing when ``other_tags`` is empty or None.
        """
        if not other_tags:
            return
        tags = {
            'TagSet': [{'Key': key, 'Value': str(val)} for key, val in other_tags.items()],
        }
        self.__s3_client.put_object_tagging(Bucket=self.__target_bucket, Key=self.__target_key, Tagging=tags)

    def add_tags_to_obj(self, other_tags=None):
        """Merge ``other_tags`` into the target object's existing tags.

        Existing tags are fetched first and new values override them on key
        collision. The merged dict is written back as a whole TagSet, so
        existing tags are preserved.

        :param other_tags: dict of tag name -> value
        :return: True if tags were written; False if ``other_tags`` is empty
            or the tagging response had no ``TagSet``
        """
        if not other_tags:
            return False
        response = self.__s3_client.get_object_tagging(Bucket=self.__target_bucket, Key=self.__target_key)
        if 'TagSet' not in response:
            return False
        all_tags = {tag['Key']: tag['Value'] for tag in response['TagSet']}
        all_tags.update(other_tags)
        self.__tag_existing_obj(all_tags)
        return True

    def download(self, local_dir, file_name=None):
        """Download the target object into ``local_dir``.

        :param local_dir: existing local directory
        :param file_name: local file name; defaults to the basename of the
            target key
        :return: path of the downloaded local file
        :raises ValueError: if ``local_dir`` does not exist
        """
        if not FileUtils.dir_exist(local_dir):
            raise ValueError('missing directory')
        if file_name is None:
            LOGGER.debug('setting the downloading filename from target_key: %s', self.__target_key)
            file_name = os.path.basename(self.__target_key)
        local_file_path = os.path.join(local_dir, file_name)
        LOGGER.debug('downloading to local_file_path: %s', local_file_path)
        self.__s3_client.download_file(self.__target_bucket, self.__target_key, local_file_path)
        LOGGER.debug('file downloaded')
        return local_file_path