azext_edge/edge/util/file_operations.py (101 lines of code) (raw):

# coding=utf-8
# ----------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License file in the project root for license information.
# ----------------------------------------------------------------------------------------------

import csv
import json
import yaml
import os
from pathlib import PurePath
from typing import Any, Callable, List, Optional, Type, Union
from azure.cli.core.azclierror import FileOperationError, InvalidArgumentValueError
from knack.log import get_logger

logger = get_logger(__name__)


# TODO: unit test
def dump_content_to_file(
    content: List[dict],
    file_name: str,
    extension: str,
    fieldnames: Optional[List[str]] = None,
    output_dir: Optional[str] = None,
    replace: bool = False,
) -> PurePath:
    """Serialize ``content`` to ``<output_dir>/<file_name>.<extension>``.

    Extensions ending in ``csv`` are written with ``csv.DictWriter``; ``json``
    and ``yaml``/``yml`` are dumped to a string first and then written. Any
    other extension writes ``content`` as-is, so it must already be a string.

    Args:
        content: Rows (for CSV) or an arbitrary serializable object.
        file_name: File name without the extension.
        extension: File extension without the leading dot.
        fieldnames: CSV header; defaults to the keys of the first row.
        output_dir: Target directory; normalized (and created) by
            ``normalize_dir``. Defaults to the current directory.
        replace: Allow overwriting an existing file.

    Returns:
        The path of the written file.
        NOTE: the value is the ``str`` produced by ``os.path.join`` even
        though the annotation says ``PurePath``.

    Raises:
        FileExistsError: The file exists and ``replace`` is False.
    """
    output_dir = normalize_dir(output_dir)
    file_path = os.path.join(output_dir, f"{file_name}.{extension}")
    if os.path.exists(file_path):
        if not replace:
            raise FileExistsError(f"File {file_path} already exists. Please choose another file name or add replace.")
        logger.warning(f"The file {file_path} will be overwritten.")

    if extension.endswith("csv"):
        with open(file_path, "w", newline="", encoding="utf-8") as f:
            if not fieldnames:
                if not content:
                    # No header supplied and no rows to derive one from:
                    # leave an empty file instead of failing on content[0].
                    return file_path
                fieldnames = content[0].keys()
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(content)
        return file_path

    # These let you dump to a string before writing to file
    if extension == "json":
        content = json.dumps(content, indent=2)
    elif extension in ["yaml", "yml"]:
        content = yaml.dump(content)
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(content)
    return file_path


def normalize_dir(dir_path: Optional[str] = None) -> PurePath:
    """Return ``dir_path`` as an absolute ``PurePath``, creating it if missing.

    A falsy ``dir_path`` means the current directory, and ``~`` is expanded
    to the user's home directory.
    """
    if not dir_path:
        dir_path = "."
    if "~" in dir_path:
        dir_path = os.path.expanduser(dir_path)
    dir_path = os.path.abspath(dir_path)
    dir_pure_path = PurePath(dir_path)
    # The explicit existence check is kept on purpose: when the path already
    # exists as a regular file, makedirs is skipped rather than raising here.
    if not os.path.exists(str(dir_pure_path)):
        os.makedirs(dir_pure_path, exist_ok=True)
    return dir_pure_path


def read_file_content(file_path: str, read_as_binary: bool = False) -> Union[bytes, str]:
    """Read a file and return its content as bytes or decoded text.

    Args:
        file_path: Path to the file; ``~`` is expanded and the path is made
            absolute before reading.
        read_as_binary: Return the raw bytes instead of decoded text.

    Raises:
        FileOperationError: The path does not exist, is not a file, or the
            content cannot be decoded with any of the attempted encodings.
    """
    from pathlib import Path

    logger.debug("Processing %s", file_path)
    pure_path = Path(os.path.abspath(os.path.expanduser(file_path)))

    if not pure_path.exists():
        raise FileOperationError(f"{file_path} does not exist.")

    if not pure_path.is_file():
        raise FileOperationError(f"{file_path} is not a file.")

    if read_as_binary:
        logger.debug("Reading %s as binary", file_path)
        return pure_path.read_bytes()

    # Try with 'utf-8-sig' first, so that BOM in WinOS won't cause trouble.
    for encoding in ["utf-8-sig", "utf-8"]:
        try:
            logger.debug("Reading %s as %s", file_path, encoding)
            return pure_path.read_text(encoding=encoding)
        except UnicodeError:  # UnicodeDecodeError is a subclass
            pass

    raise FileOperationError(f"Failed to decode file {file_path}.")


def _read_csv_rows(lines: List[str]) -> List[dict]:
    """Parse CSV lines eagerly so parse errors surface at the call site."""
    return list(csv.DictReader(lines))


def deserialize_file_content(file_path: str) -> Any:
    """Load a JSON, YAML or CSV file into Python objects.

    When the extension is one of ``json``, ``yaml``, ``yml`` or ``csv`` only
    the matching loader is used and a parse failure is raised. For any other
    extension the loaders are tried in the order JSON, YAML, CSV, moving on
    to the next one while the result so far is falsy.

    Returns:
        The deserialized content. CSV content is returned as a list of row
        dicts.

    Raises:
        FileOperationError: The file cannot be read, a known-extension file
            fails to parse, or no loader produced a result for an unknown
            extension.
    """
    extension = file_path.split(".")[-1]
    valid_extension = extension in ["json", "yaml", "yml", "csv"]
    content = read_file_content(file_path)
    result = None
    if not valid_extension or extension == "json":
        # will always be a list or dict
        result = _try_loading_as(
            loader=json.loads, content=content, error_type=json.JSONDecodeError, raise_error=valid_extension
        )
    if (not result and not valid_extension) or extension in ["yaml", "yml"]:
        # can be list, dict, str, int, bool, none
        result = _try_loading_as(
            loader=yaml.safe_load, content=content, error_type=yaml.YAMLError, raise_error=valid_extension
        )
    if (not result and not valid_extension) or extension == "csv":
        # DictReader is lazy: materialize the rows inside the guarded loader
        # so csv.Error is caught there and callers get a plain list.
        result = _try_loading_as(
            loader=_read_csv_rows, content=content.splitlines(), error_type=csv.Error, raise_error=valid_extension
        )
    if result is not None or valid_extension:
        return result
    raise FileOperationError(f"File contents for {file_path} cannot be read.")


def validate_file_extension(file_name: str, expected_exts: List[str]) -> str:
    """Return the extension of ``file_name`` if it is one of ``expected_exts``.

    The comparison is case-insensitive. The returned extension keeps its
    original casing and includes the leading dot (``os.path.splitext``).

    Raises:
        InvalidArgumentValueError: The extension is not in ``expected_exts``.
    """
    ext = os.path.splitext(file_name)[1]
    lowercased_exts = [expected.lower() for expected in expected_exts]
    if ext.lower() not in lowercased_exts:
        exts_text = ", ".join(expected_exts)
        raise InvalidArgumentValueError(
            f"Invalid file extension found for {file_name}, only {exts_text} file extensions are supported."
        )

    return ext


def _try_loading_as(
    loader: Callable,
    content: Union[str, List[str]],
    error_type: Type[Exception],
    raise_error: bool = True,
) -> Optional[Any]:
    """Call ``loader(content)`` and handle ``error_type`` failures.

    Returns the loader's result. If ``error_type`` is raised, it is re-raised
    as ``FileOperationError`` when ``raise_error`` is True; otherwise the
    failure is swallowed and ``None`` is returned.
    """
    try:
        return loader(content)
    except error_type as e:
        if raise_error:
            # Chain the original exception so the parse location is kept.
            raise FileOperationError(e) from e
    return None