modules/agent-framework/airavata-agent/agent.go (371 lines of code) (raw):

package main import ( protos "airavata-agent/protos" "bufio" "bytes" "context" "encoding/json" "fmt" "io" "log" "net" "net/http" "os" "os/exec" "strings" "golang.org/x/crypto/ssh" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) func main() { args := os.Args[1:] serverUrl := args[0] agentId := args[1] grpcStreamChannel := make(chan struct{}) kernelChannel := make(chan struct{}) conn, err := grpc.NewClient(serverUrl, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { log.Fatalf("did not connect: %v", err) } defer conn.Close() c := protos.NewAgentCommunicationServiceClient(conn) stream, err := c.CreateMessageBus(context.Background()) if err != nil { log.Fatalf("Error creating stream: %v", err) } log.Printf("Trying to connect to %s with agent id %s", serverUrl, agentId) if err := stream.Send(&protos.AgentMessage{Message: &protos.AgentMessage_AgentPing{AgentPing: &protos.AgentPing{AgentId: agentId}}}); err != nil { log.Fatalf("Failed to connect to the server: %v", err) } else { log.Printf("Connected to the server...") } go func() { log.Printf("Starting jupyter kernel") cmd := exec.Command("micromamba", "run", "-n", "base", "python", "/opt/jupyter/kernel.py") //cmd := exec.Command("jupyter/venv/bin/python", "jupyter/kernel.py") stdout, err := cmd.StdoutPipe() if err != nil { fmt.Println("[agent.go] Error creating StdoutPipe:", err) return } // Get stderr pipe stderr, err := cmd.StderrPipe() if err != nil { fmt.Println("[agent.go] Error creating StderrPipe:", err) return } log.Printf("[agent.go] Starting command for execution") // Start the command if err := cmd.Start(); err != nil { fmt.Println("[agent.go] Error starting command:", err) return } // Create channels to read from stdout and stderr stdoutScanner := bufio.NewScanner(stdout) stderrScanner := bufio.NewScanner(stderr) // Stream stdout go func() { for stdoutScanner.Scan() { fmt.Printf("[agent.go] stdout: %s\n", stdoutScanner.Text()) } }() // Stream stderr go func() { for stderrScanner.Scan() { fmt.Printf("[agent.go] stderr: %s\n", stderrScanner.Text()) } }() // Wait for the command to finish if err := cmd.Wait(); err != nil { fmt.Println("[agent.go] Error waiting for command:", err) return } fmt.Println("[agent.go] Command finished") }() go func() { for { in, err := stream.Recv() if err == io.EOF { close(grpcStreamChannel) return } if err != nil { log.Fatalf("[agent.go] Failed to receive a message : %v", err) } log.Printf("[agent.go] Received message %s", in.Message) switch x := in.GetMessage().(type) { case *protos.ServerMessage_PythonExecutionRequest: log.Printf("[agent.go] Recived a python execution request") executionId := x.PythonExecutionRequest.ExecutionId sessionId := x.PythonExecutionRequest.SessionId code := x.PythonExecutionRequest.Code workingDir := x.PythonExecutionRequest.WorkingDir libraries := x.PythonExecutionRequest.Libraries log.Printf("[agent.go] Execution id %s", executionId) log.Printf("[agent.go] Session id %s", sessionId) log.Printf("[agent.go] Code %s", code) log.Printf("[agent.go] Working Dir %s", workingDir) log.Printf("[agent.go] Libraries %s", libraries) go func() { // setup the venv venvCmd := fmt.Sprintf(` agentId="%s" pkgs="%s" if [ ! -f "/tmp/$agentId/venv" ]; then mkdir -p /tmp/$agentId python3 -m venv /tmp/$agentId/venv fi source /tmp/$agentId/venv/bin/activate python3 -m pip install $pkgs `, agentId, strings.Join(libraries, " ")) log.Println("[agent.go] venv setup:", venvCmd) venvExc := exec.Command("bash", "-c", venvCmd) venvOut, venvErr := venvExc.CombinedOutput() if venvErr != nil { fmt.Println("[agent.go] venv setup: ERR", venvErr) return } venvStdout := string(venvOut) fmt.Println("[agent.go] venv setup:", venvStdout) // execute the python code pyCmd := fmt.Sprintf(` workingDir="%s"; agentId="%s"; cd $workingDir; source /tmp/$agentId/venv/bin/activate; python3 <<EOF %s EOF`, workingDir, agentId, code) log.Println("[agent.go] python code:", pyCmd) pyExc := exec.Command("bash", "-c", pyCmd) pyOut, pyErr := pyExc.CombinedOutput() if pyErr != nil { fmt.Println("[agent.go] python code: ERR", pyErr) } // send the result back to the server pyStdout := string(pyOut) if err := stream.Send(&protos.AgentMessage{Message: &protos.AgentMessage_PythonExecutionResponse{ PythonExecutionResponse: &protos.PythonExecutionResponse{ SessionId: sessionId, ExecutionId: executionId, ResponseString: pyStdout}}}); err != nil { log.Printf("[agent.go] Failed to send execution result to server: %v", err) } else { log.Printf("[agent.go] Sent execution result to the server: %v", pyStdout) } }() case *protos.ServerMessage_CommandExecutionRequest: log.Printf("[agent.go] Recived a command execution request") executionId := x.CommandExecutionRequest.ExecutionId execArgs := x.CommandExecutionRequest.Arguments log.Printf("[agent.go] Execution id %s", executionId) cmd := exec.Command(execArgs[0], execArgs[1:]...) log.Printf("[agent.go] Completed execution with the id %s", executionId) output, err := cmd.CombinedOutput() // combined output of stdout and stderr if err != nil { log.Printf("[agent.go] command execution failed: %s", err) } outputString := string(output) log.Printf("[agent.go] Execution output is %s", outputString) if err := stream.Send(&protos.AgentMessage{Message: &protos.AgentMessage_CommandExecutionResponse{ CommandExecutionResponse: &protos.CommandExecutionResponse{ExecutionId: executionId, ResponseString: outputString}}}); err != nil { log.Printf("[agent.go] Failed to send execution result to server: %v", err) } case *protos.ServerMessage_JupyterExecutionRequest: log.Printf("[agent.go] Recived a jupyter execution request") executionId := x.JupyterExecutionRequest.ExecutionId sessionId := x.JupyterExecutionRequest.SessionId code := x.JupyterExecutionRequest.Code log.Printf("[agent.go] Execution ID: %s, Session ID: %s, Code: %s", executionId, sessionId, code) url := "http://127.0.0.1:15000/start" client := &http.Client{} req, err := http.NewRequest("GET", url, nil) if err != nil { log.Printf("[agent.go] Failed to create the request start jupyter kernel: %v", err) jupyterResponse := "Failed while running the cell in remote. Please retry" if err := stream.Send(&protos.AgentMessage{Message: &protos.AgentMessage_JupyterExecutionResponse{ JupyterExecutionResponse: &protos.JupyterExecutionResponse{ExecutionId: executionId, ResponseString: jupyterResponse, SessionId: sessionId}}}); err != nil { log.Printf("[agent.go] Failed to send jupyter execution result to server: %v", err) } return } log.Printf("[agent.go] Sending the jupyter kernel start request to server...") resp, err := client.Do(req) if err != nil { log.Printf("[agent.go] Failed to send the request start jupyter kernel: %v", err) jupyterResponse := "Failed while running the cell in remote. Please retry" if err := stream.Send(&protos.AgentMessage{Message: &protos.AgentMessage_JupyterExecutionResponse{ JupyterExecutionResponse: &protos.JupyterExecutionResponse{ExecutionId: executionId, ResponseString: jupyterResponse, SessionId: sessionId}}}); err != nil { log.Printf("[agent.go] Failed to send jupyter execution result to server: %v", err) } return } log.Printf("[agent.go] Successfully sent the jupyter kernel start request to server") defer func() { err := resp.Body.Close() if err != nil { log.Printf("[agent.go] Failed to close the response body for kernel start: %v", err) jupyterResponse := "Failed while running the cell in remote. Please retry" if err := stream.Send(&protos.AgentMessage{Message: &protos.AgentMessage_JupyterExecutionResponse{ JupyterExecutionResponse: &protos.JupyterExecutionResponse{ExecutionId: executionId, ResponseString: jupyterResponse, SessionId: sessionId}}}); err != nil { log.Printf("[agent.go] Failed to send jupyter execution result to server: %v", err) } return } }() body, err := io.ReadAll(resp.Body) if err != nil { log.Printf("[agent.go] Failed to read response for start jupyter kernel: %v", err) jupyterResponse := "Failed while running the cell in remote. Please retry" if err := stream.Send(&protos.AgentMessage{Message: &protos.AgentMessage_JupyterExecutionResponse{ JupyterExecutionResponse: &protos.JupyterExecutionResponse{ExecutionId: executionId, ResponseString: jupyterResponse, SessionId: sessionId}}}); err != nil { log.Printf("[agent.go] Failed to send jupyter execution result to server: %v", err) } return } log.Printf("[agent.go] Starting to marshal execution request JSON data...") url = "http://127.0.0.1:15000/execute" data := map[string]string{ "code": code, "executionId": executionId, } jsonData, err := json.Marshal(data) if err != nil { log.Fatalf("[agent.go] Failed to marshal JSON: %v", err) jupyterResponse := "Failed while running the cell in remote. Please retry" if err := stream.Send(&protos.AgentMessage{Message: &protos.AgentMessage_JupyterExecutionResponse{ JupyterExecutionResponse: &protos.JupyterExecutionResponse{ExecutionId: executionId, ResponseString: jupyterResponse, SessionId: sessionId}}}); err != nil { log.Printf("[agent.go] Failed to send jupyter execution result to server: %v", err) } return } log.Printf("[agent.go] Successful marshaling the JSON data") req, err = http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) if err != nil { log.Printf("[agent.go] Failed to create the request run jupyter kernel: %v", err) jupyterResponse := "Failed while running the cell in remote. Please retry" if err := stream.Send(&protos.AgentMessage{Message: &protos.AgentMessage_JupyterExecutionResponse{ JupyterExecutionResponse: &protos.JupyterExecutionResponse{ExecutionId: executionId, ResponseString: jupyterResponse, SessionId: sessionId}}}); err != nil { log.Printf("[agent.go] Failed to send jupyter execution result to server: %v", err) } return } req.Header.Set("Content-Type", "application/json") client = &http.Client{} resp, err = client.Do(req) if err != nil { log.Printf("[agent.go] Failed to send the request run jupyter kernel: %v", err) jupyterResponse := "Failed while running the cell in remote. Please retry" if err := stream.Send(&protos.AgentMessage{Message: &protos.AgentMessage_JupyterExecutionResponse{ JupyterExecutionResponse: &protos.JupyterExecutionResponse{ExecutionId: executionId, ResponseString: jupyterResponse, SessionId: sessionId}}}); err != nil { log.Printf("[agent.go] Failed to send jupyter execution result to server: %v", err) } return } defer func() { log.Printf("[agent.go] Closing the response...") err := resp.Body.Close() if err != nil { log.Printf("[agent.go] Failed to close the response body for kernel execution: %v", err) jupyterResponse := "Failed while running the cell in remote. Please retry" if err := stream.Send(&protos.AgentMessage{Message: &protos.AgentMessage_JupyterExecutionResponse{ JupyterExecutionResponse: &protos.JupyterExecutionResponse{ExecutionId: executionId, ResponseString: jupyterResponse, SessionId: sessionId}}}); err != nil { log.Printf("[agent.go] Failed to send jupyter execution result to server: %v", err) } return } }() log.Printf("[agent.go] Sending the jupyter execution " + executionId + "result to server...") body, err = io.ReadAll(resp.Body) if err != nil { log.Printf("[agent.go] Failed to read response for run jupyter kernel: %v", err) jupyterResponse := "Failed while running the cell in remote. Please retry" if err := stream.Send(&protos.AgentMessage{Message: &protos.AgentMessage_JupyterExecutionResponse{ JupyterExecutionResponse: &protos.JupyterExecutionResponse{ExecutionId: executionId, ResponseString: jupyterResponse, SessionId: sessionId}}}); err != nil { log.Printf("[agent.go] Failed to send jupyter execution result to server: %v", err) } return } jupyterResponse := string(body) log.Println("[agent.go] Jupyter execution " + executionId + "response: " + jupyterResponse) if err := stream.Send(&protos.AgentMessage{Message: &protos.AgentMessage_JupyterExecutionResponse{ JupyterExecutionResponse: &protos.JupyterExecutionResponse{ExecutionId: executionId, ResponseString: jupyterResponse, SessionId: sessionId}}}); err != nil { log.Printf("[agent.go] Failed to send jupyter execution result to server: %v", err) } case *protos.ServerMessage_TunnelCreationRequest: log.Printf("[agent.go] Received a tunnel creation request") host := x.TunnelCreationRequest.DestinationHost destPort := x.TunnelCreationRequest.DestinationPort srcPort := x.TunnelCreationRequest.SourcePort keyPath := x.TunnelCreationRequest.SshKeyPath sshUser := x.TunnelCreationRequest.SshUserName log.Printf("[agent.go] Tunnel details - Host: %s, DestPort: %s, SrcPort: %s, KeyPath: %s, SSH User: %s", host, destPort, srcPort, keyPath, sshUser) openRemoteTunnel(host, destPort, srcPort, sshUser, keyPath) } } }() <-grpcStreamChannel <-kernelChannel if err := stream.CloseSend(); err != nil { log.Fatalf("failed to close the stream: %v", err) } } func openRemoteTunnel(remoteHost string, remotePort string, localPort string, sshUser string, sshKeyFile string) { log.Printf("Opening remote SSH tunnel - Remote Host: %s, Remote Port: %s, Local Port: %s", remoteHost, remotePort, localPort) // SSH server details sshHost := remoteHost + ":22" //sshPassword := "your_ssh_password" // Remote and local ports localHost := "localhost" key, err := os.ReadFile(sshKeyFile) if err != nil { log.Fatalf("unable to read private key: %v", err) } // Create the Signer for this private key. signer, err := ssh.ParsePrivateKey(key) if err != nil { log.Fatalf("unable to parse private key: %v", err) } // Create SSH client configuration sshConfig := &ssh.ClientConfig{ User: sshUser, Auth: []ssh.AuthMethod{ //ssh.Password(sshPassword), ssh.PublicKeys(signer), }, HostKeyCallback: ssh.InsecureIgnoreHostKey(), // Replace with proper host key verification for production } log.Println("Connecting to SSH server...") // Connect to the SSH server sshConn, err := ssh.Dial("tcp", sshHost, sshConfig) if err != nil { log.Fatalf("Failed to dial SSH: %s", err) } defer sshConn.Close() log.Println("SSH connection established.") // Listen on the remote port remoteListener, err := sshConn.Listen("tcp", fmt.Sprintf("0.0.0.0:%s", remotePort)) if err != nil { log.Fatalf("Failed to listen on remote port %s: %s", remotePort, err) } defer remoteListener.Close() log.Printf("Reverse SSH tunnel established. Listening on remote port %s", remotePort) for { remoteConn, err := remoteListener.Accept() if err != nil { log.Printf("Failed to accept remote connection: %s", err) continue } go handleConnection(remoteConn, localHost, localPort) } } func handleConnection(remoteConn net.Conn, localHost, localPort string) { log.Printf("Handling connection to local host %s:%s", localHost, localPort) defer remoteConn.Close() // Connect to the local host localConn, err := net.Dial("tcp", net.JoinHostPort(localHost, localPort)) if err != nil { log.Printf("Failed to connect to local host %s:%s: %s", localHost, localPort, err) return } defer localConn.Close() // Create channels to signal when copying is done done := make(chan struct{}) log.Println("Starting data transfer between remote and local connections...") // Start copying data between remote and local connections go copyConn(remoteConn, localConn, done) go copyConn(localConn, remoteConn, done) // Wait for both copy operations to complete <-done <-done log.Println("Data transfer completed.") } func copyConn(writer, reader net.Conn, done chan struct{}) { defer func() { // Signal that copying is done done <- struct{}{} }() _, err := io.Copy(writer, reader) if err != nil { if err != io.EOF { log.Printf("Data copy error: %s", err) } } }