vsock_sample/py/vsock-sample.py (79 lines of code) (raw):
#!/usr/bin/env python3
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import codecs
import socket
import sys
class VsockStream:
    """Client side of a vsock stream connection.

    Wraps one ``AF_VSOCK`` stream socket: connect to an endpoint, send bytes,
    print whatever the peer sends back, and close the socket.
    """

    def __init__(self, conn_tmo=5):
        """Create an unconnected client.

        Args:
            conn_tmo: Timeout, in seconds, set on the socket by ``connect``.
        """
        self.conn_tmo = conn_tmo
        # No socket exists until connect() succeeds; disconnect() relies on
        # this to be a safe no-op on an unconnected client.
        self.sock = None

    def connect(self, endpoint):
        """Connect to the remote endpoint.

        Any socket left over from a previous ``connect`` is closed first.
        The timeout stays on the socket after connecting, so it also bounds
        later ``send_data`` / ``recv_data`` calls.

        Args:
            endpoint: Address passed to ``socket.connect``; callers in this
                file pass a ``(cid, port)`` tuple.

        Raises:
            OSError: If the connection attempt fails or times out.
        """
        self.disconnect()
        sock = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
        try:
            sock.settimeout(self.conn_tmo)
            sock.connect(endpoint)
        except BaseException:
            # Do not leak the descriptor when the attempt fails or is
            # interrupted; the original error is re-raised unchanged.
            sock.close()
            raise
        self.sock = sock

    def send_data(self, data):
        """Send all of ``data`` (a bytes-like object) to the remote endpoint."""
        self.sock.sendall(data)

    def recv_data(self):
        """Print data from the remote endpoint until it closes the stream.

        Bytes are read 1024 at a time and decoded as UTF-8 incrementally, so
        a multi-byte character split across two reads is printed correctly.
        Malformed bytes are printed as U+FFFD instead of raising. A trailing
        newline is printed once the peer has closed the connection.

        Raises:
            OSError: If reading from the socket fails or times out.
        """
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        while True:
            chunk = self.sock.recv(1024)
            if not chunk:
                break
            print(decoder.decode(chunk), end='', flush=True)
        # Flush any incomplete trailing sequence, then end the line.
        print(decoder.decode(b"", final=True))

    def disconnect(self):
        """Close the client socket; does nothing if there is no open socket."""
        if self.sock is not None:
            self.sock.close()
            self.sock = None
def client_handler(args):
    """Run the ``client`` sub-command: connect, send a greeting, disconnect.

    Args:
        args: Parsed command-line namespace providing ``cid`` and ``port``.
    """
    client = VsockStream()
    client.connect((args.cid, args.port))
    try:
        client.send_data('Hello, world!'.encode())
    finally:
        # Close the connection even when sending fails.
        client.disconnect()
class VsockListener:
    """Server side of a vsock stream connection.

    Owns one listening ``AF_VSOCK`` stream socket and serves accepted clients
    one at a time, either printing what they send or sending them data.
    """

    def __init__(self, conn_backlog=128):
        """Create an unbound listener.

        Args:
            conn_backlog: Backlog value passed to ``socket.listen``.
        """
        self.conn_backlog = conn_backlog
        # Set by bind(); None while no listening socket exists.
        self.sock = None

    def bind(self, port):
        """Bind and listen for connections on the specified port.

        The socket is bound to ``socket.VMADDR_CID_ANY``.

        Args:
            port: Local vsock port number to listen on.

        Raises:
            OSError: If binding or listening fails.
        """
        sock = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
        try:
            sock.bind((socket.VMADDR_CID_ANY, port))
            sock.listen(self.conn_backlog)
        except BaseException:
            # Do not leak the descriptor when setup fails.
            sock.close()
            raise
        self.sock = sock

    def recv_data(self):
        """Accept clients forever and print everything each one sends.

        Clients are served one at a time. Each client's data is followed by
        a newline, and its socket is always closed, even if printing raises.
        This method does not return normally; it ends only when ``accept``
        (or printing) raises.
        """
        while True:
            from_client, _remote_addr = self.sock.accept()
            try:
                self._print_stream(from_client)
            finally:
                from_client.close()

    @staticmethod
    def _print_stream(conn):
        """Print UTF-8 text read from ``conn`` until EOF or a socket error."""
        # Incremental decoding keeps a multi-byte character intact when it is
        # split across two reads; malformed bytes become U+FFFD so one bad
        # client cannot crash the accept loop.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        while True:
            try:
                # Read 1024 bytes at a time
                chunk = conn.recv(1024)
            except OSError:
                # Best effort: a failing connection just ends this client.
                break
            if not chunk:
                break
            print(decoder.decode(chunk), end='', flush=True)
        # Flush any incomplete trailing sequence, then end the line.
        print(decoder.decode(b"", final=True))

    def send_data(self, data):
        """Accept clients forever and send ``data`` to each one.

        Each client socket is closed after sending, including when
        ``sendall`` raises (the error still propagates to the caller).

        Args:
            data: Bytes-like object sent to every client that connects.
        """
        while True:
            to_client, _remote_addr = self.sock.accept()
            try:
                to_client.sendall(data)
            finally:
                to_client.close()

    def close(self):
        """Close the listening socket; does nothing if it is not open."""
        if self.sock is not None:
            self.sock.close()
            self.sock = None
def server_handler(args):
    """Run the ``server`` sub-command: listen on a port and print received data.

    Args:
        args: Parsed command-line namespace providing ``port``.
    """
    listener = VsockListener()
    listener.bind(args.port)
    # Serves clients in a loop; control stays here while the server runs.
    listener.recv_data()
def main(argv=None):
    """Parse the command line and dispatch to the selected sub-command.

    Prints the usage line and exits with status 1 when no sub-command was
    selected.

    Args:
        argv: Argument list to parse, without the program name. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(prog='vsock-sample')
    parser.add_argument("--version", action="version",
                        help="Prints version information.",
                        version='%(prog)s 0.1.0')
    subparsers = parser.add_subparsers(title="options")

    client_parser = subparsers.add_parser("client", description="Client",
                                          help="Connect to a given cid and port.")
    client_parser.add_argument("cid", type=int, help="The remote endpoint CID.")
    client_parser.add_argument("port", type=int, help="The remote endpoint port.")
    client_parser.set_defaults(func=client_handler)

    server_parser = subparsers.add_parser("server", description="Server",
                                          help="Listen on a given port.")
    server_parser.add_argument("port", type=int, help="The local port to listen on.")
    server_parser.set_defaults(func=server_handler)

    args = parser.parse_args(argv)
    # ``func`` is only set by a sub-command. Checking the parsed result rather
    # than the length of sys.argv also covers invocations such as
    # ``vsock-sample --`` that pass arguments but select no sub-command.
    handler = getattr(args, "func", None)
    if handler is None:
        parser.print_usage()
        sys.exit(1)

    handler(args)


if __name__ == "__main__":
    main()