tools/tftp_tester.py (212 lines of code) (raw):

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Simple utility to test a TFTP server such as fbtftp.

Downloads one file (an RRQ carrying the ``tsize`` and ``blksize`` options),
can pretend that chosen packets were lost in either direction, and prints the
MD5 of the received content.
"""

import argparse
import hashlib
import socket
import struct
import sys
import time
import traceback
from enum import Enum

# Block size used when the server does not acknowledge a ``blksize`` option
# (RFC 1350 fixes data packets at 512 bytes).
DEFAULT_BLKSIZE = 512

# Block numbers travel in a 16-bit field; they wrap back to 0 after 65535.
# NOTE(review): assumes the server rolls over to 0 (common, but not mandated
# by RFC 1350).
BLOCK_NUMBER_MASK = 0xFFFF


class Spinner:
    """Tiny terminal spinner used to show that packets are flowing."""

    positions = ["-", "\\", "|", "/"]

    def __init__(self):
        self.cur = 0

    def spin(self):
        """Advance to the next position and return its character."""
        self.cur = (self.cur + 1) % len(self.positions)
        return self.positions[self.cur]

    def show(self):
        """Redraw the spinner at the start of the current terminal line."""
        print("\r{0}".format(self.spin()), end="")
        sys.stdout.flush()


# Helper classes, functions and data structures


class TftpException(Exception):
    """Raised when the transfer fails (server ERROR packet or no reply)."""


class TFTP(Enum):
    """TFTP opcodes used by this tester."""

    RRQ = 1
    DATA = 3
    ACK = 4
    ERROR = 5
    OACK = 6


def str0(v):
    """Return *v* encoded as ASCII followed by a NUL terminator.

    Raises:
        TypeError: if *v* is not a ``str``.
    """
    if not isinstance(v, str):
        raise TypeError("Only strings")
    b = bytearray(v, encoding="ascii")
    b.append(0)
    return b


def as2bytes(i):
    """Pack an int (or a ``TFTP`` opcode) as a big-endian unsigned 16-bit value."""
    if isinstance(i, TFTP):
        i = i.value
    return struct.pack(">H", i)


def get_packet_type(pkt):
    """Return the ``TFTP`` opcode in the first two bytes of *pkt*.

    Raises:
        ValueError: if the opcode is not a member of ``TFTP``.
    """
    return TFTP(int.from_bytes(pkt[0:2], byteorder="big"))


def get_packet_num(pkt):
    """Return bytes 2-3 of *pkt* as a big-endian int (block number for DATA/ACK)."""
    return int.from_bytes(pkt[2:4], byteorder="big")


def get_packet_data(pkt):
    """Return the payload of a DATA packet (everything after the 4-byte header)."""
    return pkt[4:]


def parse_oack(pkt):
    """Return the options carried by an OACK packet as a ``{name: value}`` dict.

    Option names are lower-cased (RFC 2347 treats them case-insensitively) and
    the order in which the server lists them does not matter.
    """
    fields = bytes(pkt[2:]).decode("ascii").split("\x00")
    # fields is [name1, value1, name2, value2, ..., ""]; the trailing empty
    # string produced by the final NUL has no partner and is dropped by zip().
    return {
        name.lower(): value for name, value in zip(fields[0::2], fields[1::2])
    }


class TftpTester:
    """Download one file over TFTP while optionally simulating packet loss.

    ``failsend`` lists packet numbers whose outgoing packet (the RRQ, or the
    ACK sent while at that block) is dropped once; ``failreceive`` lists block
    numbers whose incoming packet is ignored once. In both lists ``-1`` stands
    for the last packet of the transfer.
    """

    def __init__(
        self,
        server,
        port,
        timeout,
        retries,
        filename,
        blksize,
        failsend,
        failreceive,
        verbose,
    ):
        self.server = server
        self.port = int(port)
        self.filename = filename
        self.blksize = int(blksize)
        self.output = bytearray()
        self.hash = hashlib.md5()
        # float so sub-second timeouts work; whole numbers behave as before
        self.timeout = float(timeout)
        self.retries = int(retries)
        self.failsend = [int(i) for i in failsend]
        self.failreceive = [int(i) for i in failreceive]
        self.verbose = verbose
        self.spinner = Spinner()
        self.is_closed = True
        # Block size in effect until the OACK tells us otherwise.
        self.actual_blksize = DEFAULT_BLKSIZE

    def gen_RRQ(self):
        """Build the initial RRQ packet (the expected response is an OACK)."""
        b = bytearray(as2bytes(TFTP.RRQ))
        b.extend(str0(self.filename))
        b.extend(str0("octet"))
        b.extend(str0("tsize"))
        b.extend(str0("0"))
        b.extend(str0("blksize"))
        b.extend(str0(str(self.blksize)))
        return b

    def gen_ACK(self, num):
        """Build the ACK for block *num* (the expected response is DATA)."""
        b = bytearray(as2bytes(TFTP.ACK))
        b.extend(as2bytes(num))
        return b

    def gen_ERROR(self, message):
        """Build an ERROR packet with error code 0 and *message* as text."""
        b = bytearray(as2bytes(TFTP.ERROR))
        b.extend(as2bytes(0))
        b.extend(str0(message))
        return b

    def set_socket(self):
        """Open the IPv6 UDP socket used for the whole transfer."""
        self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        self.sock.settimeout(self.timeout)
        self.is_closed = False

    def send(self, packet):
        """Send *packet* to the server's current address and port."""
        self.sock.sendto(packet, (self.server, self.port))

    def send_and_expect(self, packet, expect, cur):
        """Send *packet* and wait for the matching reply, retrying on timeout.

        Args:
            packet: bytes to send.
            expect: ``TFTP.OACK`` or ``TFTP.DATA`` -- the reply we wait for.
            cur: last block number received (0 before the first DATA); a DATA
                reply is accepted only if it carries the next block number.

        Returns:
            The raw reply packet.

        Raises:
            TftpException: if the server sends an ERROR packet, or if no
                acceptable reply arrives within ``self.retries`` attempts.
        """
        expected_num = (cur + 1) & BLOCK_NUMBER_MASK
        retries = 0
        while retries < self.retries:
            begin = time.time()
            if cur in self.failsend:
                # pretend we sent a packet which was lost
                self.failsend.remove(cur)
            else:
                self.send(packet)
            try:
                answer, sender_addr = self.sock.recvfrom(self.blksize + 4)
                # TFTP servers answer from a per-transfer port; follow it.
                self.port = sender_addr[1]
                if self.verbose:
                    self.spinner.show()
                try:
                    ptype = get_packet_type(answer)
                except ValueError:
                    ptype = None  # unknown opcode: handled as unexpected below
                num = get_packet_num(answer)
                # is this the last packet?
                is_last = (
                    ptype == TFTP.DATA
                    and len(get_packet_data(answer)) < self.actual_blksize
                )
                # replace -1 with the actual packet number in failreceive
                # this allows us to use the same construction for all packets
                if is_last and -1 in self.failreceive:
                    self.failreceive[self.failreceive.index(-1)] = num
                # pretend we didn't receive any message
                if num in self.failreceive:
                    self.failreceive.remove(num)
                    delta = time.time() - begin
                    # wait out the rest of the timeout, never a negative amount
                    time.sleep(max(0.0, self.timeout - delta))
                    raise socket.timeout()
                if ptype == expect:
                    # if it's the next DATA or an OACK, we're good
                    if expect == TFTP.OACK or (
                        expect == TFTP.DATA and num == expected_num
                    ):
                        return answer
                elif ptype == TFTP.ERROR:
                    raise TftpException(
                        answer[4:-1].decode("ascii", errors="replace")
                    )
                else:
                    print("\nUnexpected packet received. Ignoring")
            except socket.timeout:
                retries += 1
        raise TftpException(
            f"no valid response after {self.retries} attempts"
        )

    def loop(self):
        """Run the transfer: RRQ, option negotiation, then the DATA/ACK exchange.

        Every accepted DATA payload is fed into ``self.hash``. The transfer
        ends on the first DATA shorter than the negotiated block size.
        """
        current = 0
        oack = self.send_and_expect(self.gen_RRQ(), TFTP.OACK, current)
        options = parse_oack(oack)
        self.actual_blksize = int(options.get("blksize", DEFAULT_BLKSIZE))
        finished = False
        while not finished:
            resp = self.send_and_expect(
                self.gen_ACK(current), TFTP.DATA, current
            )
            # send_and_expect only returns the next block in sequence
            # (including the 65535 -> 0 wrap), so it is always new data.
            current = get_packet_num(resp)
            data = get_packet_data(resp)
            self.hash.update(data)
            if len(data) < self.actual_blksize:
                finished = True
        # pretend the last ack was lost in transit
        while -1 in self.failsend:
            self.failsend.remove(-1)
            time.sleep(self.timeout)
        self.send(self.gen_ACK(current))
        print("\rFinished")

    def close(self):
        """Close the socket if it is open and print the MD5 of the received data."""
        if not self.is_closed:
            self.sock.close()
            self.is_closed = True
        print(f"md5: {self.hash.hexdigest()}")


def _parse_args(argv):
    """Build the command-line parser and parse *argv*."""
    parser = argparse.ArgumentParser(
        description="Simple utility to test fbtftp server"
    )
    parser.add_argument(
        "--server", default="::1", help="server IP address (default: ::1)"
    )
    parser.add_argument(
        "--port", default=69, type=int, help="server tftp port (default: udp/69)"
    )
    parser.add_argument(
        "--timeout",
        default=5,
        type=float,
        help="timeout interval in seconds (default: 5)",
    )
    parser.add_argument(
        "--retries", default=5, type=int, help="number of retries (default: 5)"
    )
    parser.add_argument("--filename", required=True, help="remote file name")
    parser.add_argument(
        "--blksize",
        default=1228,
        type=int,
        help="block size in bytes (default: 1228)",
    )
    parser.add_argument(
        "--failreceive",
        default=[],
        type=int,
        help="list of packets which will be ignored",
        nargs="+",
    )
    parser.add_argument(
        "--failsend",
        default=[],
        type=int,
        help="list of packets which will not be sent",
        nargs="+",
    )
    parser.add_argument(
        "--verbose", "-v", action="count", help="display a spinner"
    )
    return parser.parse_args(argv)


def _send_error(tester, message):
    """Best-effort: tell the server we are giving up, if a socket is open."""
    if tester.is_closed:
        return  # set_socket() never succeeded; there is nothing to send on
    try:
        tester.send(tester.gen_ERROR(message))
    except OSError:
        # The transfer is already failing; don't mask the original error.
        pass


def main(argv=None):
    """Run the tester from the command line; return the process exit status."""
    args = _parse_args(sys.argv[1:] if argv is None else argv)
    verbose = bool(args.verbose)
    t = TftpTester(
        server=args.server,
        port=args.port,
        filename=args.filename,
        blksize=args.blksize,
        timeout=args.timeout,
        retries=args.retries,
        failsend=args.failsend,
        failreceive=args.failreceive,
        verbose=verbose,
    )
    try:
        t.set_socket()
        t.loop()
    except Exception as ex:
        _send_error(t, "system error")
        if t.verbose:
            traceback.print_exc()
        else:
            print(f"Error: {ex}")
        return 1
    except KeyboardInterrupt:
        _send_error(t, "aborted by user request")
        print("Aborted")
        return 1
    finally:
        t.close()
    return 0


if __name__ == "__main__":
    sys.exit(main())