integration/create_server.go (174 lines of code) (raw):

/* * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. * */ package integration import ( "archive/tar" "compress/gzip" "fmt" "io" "log" "net/http" "os" "os/exec" "path" "path/filepath" "strings" "time" "github.com/facebookincubator/zk/flw" ) const zkFormatURL = "https://archive.apache.org/dist/zookeeper/zookeeper-%s/apache-zookeeper-%s-bin.tar.gz" const defaultArchiveName = "server.tar.gz" const defaultConfigName = "zk.cfg" const maxRetries = 10 // ZKServer represents a configurable Zookeeper server, mainly used for integration testing. type ZKServer struct { Version string Config *ServerConfig cmd *exec.Cmd } // NewZKServer creates a new ZKServer instance using the given version and config map. // Contents of the config struct are written to a file which will later be used by the ZK binary. func NewZKServer(version string, config *ServerConfig) (*ZKServer, error) { file, err := os.Create(defaultConfigName) if err != nil { return nil, err } defer file.Close() if err = config.Marshall(file); err != nil { return nil, fmt.Errorf("error writing config to file: %v", err) } return &ZKServer{ Version: version, Config: config, }, nil } // Run downloads the specified Zookeeper version from Apache's website, // extracts it and runs it using the config specified by configPath. // The server runs in the background until Shutdown is called. func (server *ZKServer) Run() error { zkURL := fmt.Sprintf(zkFormatURL, server.Version, server.Version) workdir, err := os.Getwd() if err != nil { return fmt.Errorf("error getting working directory: %s", err.Error()) } archivePath := filepath.Join(workdir, defaultArchiveName) if _, err := os.Stat(archivePath); os.IsNotExist(err) { err = downloadToFile(zkURL, archivePath) if err != nil { return fmt.Errorf("error downloading file: %s", err) } log.Printf("successfully downloaded archive %s\n", defaultArchiveName) } dirName := "apache-zookeeper-" + server.Version + "-bin" if _, err := os.Stat(filepath.Join(workdir, dirName)); os.IsNotExist(err) { _, err := extractTarGz(archivePath) if err != nil { return fmt.Errorf("error extracting file: %s", err) } } serverScriptPath := filepath.Join(workdir, dirName, "bin/zkServer.sh") err = os.Chmod(serverScriptPath, 0777) if err != nil { return fmt.Errorf("error changing server script permissions: %s", err) } server.cmd = exec.Command(serverScriptPath, "start-foreground", filepath.Join(workdir, defaultConfigName)) server.cmd.Stdout = os.Stdout server.cmd.Stderr = os.Stderr err = server.cmd.Start() if err != nil { return fmt.Errorf("error executing server command: %s", err) } if err = waitForStart(fmt.Sprintf(":%d", server.Config.ClientPort), maxRetries, time.Second); err != nil { return err } return nil } // Shutdown kills the underlying process of a ZKServer instance. func (server *ZKServer) Shutdown() error { log.Printf("Shutdown() called, killing server process") return server.cmd.Process.Kill() } func downloadToFile(sourceURL, filepath string) error { out, err := os.Create(filepath) if err != nil { return err } defer out.Close() response, err := http.Get(sourceURL) if err != nil { return err } defer response.Body.Close() if response.StatusCode != http.StatusOK { return fmt.Errorf("wrong status code while downloading from %s: %d", sourceURL, response.StatusCode) } _, err = io.Copy(out, response.Body) if err != nil { return err } return nil } func extractTarGz(src string) (string, error) { var rootPath string isRootDir := true file, err := os.Open(src) if err != nil { return "", fmt.Errorf("error opening archive file: %v", err.Error()) } defer file.Close() reader, err := gzip.NewReader(file) if err != nil { return "", fmt.Errorf("error creating gzip reader: %v", err.Error()) } tarReader := tar.NewReader(reader) for { header, err := tarReader.Next() if err == io.EOF { break } if err != nil { return "", fmt.Errorf("error reading from tar archive: %s", err.Error()) } switch header.Typeflag { case tar.TypeDir: if isRootDir { rootPath = getRootFromPath(header.Name) isRootDir = false } if err := os.MkdirAll(header.Name, os.ModePerm); err != nil { return "", fmt.Errorf("mkdir failed: %s", err.Error()) } case tar.TypeReg: if err = ensureBaseDir(header.Name); err != nil { return "", fmt.Errorf("error creating file from tar header: %s", err.Error()) } outFile, err := os.Create(header.Name) if err != nil { return "", fmt.Errorf("error creating file from tar header: %s", err.Error()) } if _, err = io.Copy(outFile, tarReader); err != nil { return "", fmt.Errorf("error copying file from archive: %s", err.Error()) } outFile.Close() default: return "", fmt.Errorf("unknown header type detected while extracting from archive %s", src) } } return rootPath, nil } func getRootFromPath(filepath string) string { idx := strings.Index(filepath, "/") if idx == -1 { return filepath } return filepath[:idx] } func ensureBaseDir(filepath string) error { baseDir := path.Dir(filepath) info, err := os.Stat(baseDir) if err == nil && info.IsDir() { return nil } return os.MkdirAll(baseDir, 0755) } // waitForStart blocks until the server from the specified address is up, returns error after max retries otherwise func waitForStart(address string, maxRetry int, interval time.Duration) error { client := &flw.Client{Timeout: time.Second} for i := 0; i < maxRetry; i++ { if _, err := client.Srvr(address); err == nil { return nil } time.Sleep(interval) } return fmt.Errorf("unable to verify health of servers") }