import datetime
import time

import torch


def format_secs_to_time(time_delta):
    """Render a duration given in seconds as an ``H:MM:SS`` string.

    The fractional part of the second is dropped (truncated, not rounded).
    Because the text comes from ``str(datetime.timedelta)``, durations of a
    day or more keep timedelta's ``"N day(s), "`` prefix.
    """
    rendered = str(datetime.timedelta(seconds=time_delta))
    # timedelta prints microseconds after a '.', keep only what precedes it
    whole_part, _sep, _fraction = rendered.partition(".")
    return whole_part


def format_secs_to_sec_fractions(time_delta):
    """Render a duration given in seconds with two decimal places.

    The output has hundredth-of-a-second resolution (e.g. ``1.5`` -> ``"1.50"``),
    not milliseconds.
    """
    return format(time_delta, ".2f")


def format_secs_to_sec(time_delta):
    """Reduce a duration given in seconds to whole seconds.

    Note: unlike the other ``format_*`` helpers this returns an ``int``,
    not a string. Conversion uses ``int()``, which truncates toward zero.
    """
    whole_seconds = int(time_delta)
    return whole_seconds


class Timer(object):
    """Simple stopwatch that measures elapsed time in seconds.

    Readings come from ``time.perf_counter()``, a monotonic high-resolution
    clock. Measured intervals are therefore unaffected by system clock
    adjustments, which can make ``time.time()`` deltas jump or go negative.
    The raw values in ``start_time`` / ``last_time`` are only meaningful
    relative to each other, not as epoch timestamps.
    """

    def __init__(self):
        # perf_counter() reading taken at start(); None while the timer is stopped
        self.start_time = None
        # perf_counter() reading at the last start()/delta() call; None while stopped
        self.last_time = None

    def _require_started(self):
        """Raise RuntimeError unless start() has been called since the last stop()."""
        if self.start_time is None:
            raise RuntimeError("Timer has not been started; call start() first")

    def start(self):
        """Start (or restart) the timer and reset the delta reference point."""
        self.start_time = time.perf_counter()
        self.last_time = self.start_time

    def delta(self):
        """Return seconds since the previous delta() call (or since start()).

        Raises:
            RuntimeError: if the timer is not running.
        """
        self._require_started()
        now = time.perf_counter()
        delta = now - self.last_time
        self.last_time = now
        return delta

    def elapsed(self):
        """Return seconds since start() without stopping the timer.

        Raises:
            RuntimeError: if the timer is not running.
        """
        self._require_started()
        return time.perf_counter() - self.start_time

    def stop(self):
        """Stop the timer, reset its state and return total seconds since start().

        Raises:
            RuntimeError: if the timer is not running.
        """
        elapsed = self.elapsed()
        self.start_time = None
        self.last_time = None
        return elapsed


class DeviceAgnosticTimer(object):
    """Times a code region in seconds, on CUDA if available, otherwise on the host.

    Whether CUDA is used is decided once, at construction time. In CUDA mode,
    ``start``/``stop`` record ``torch.cuda.Event`` objects on the current
    stream, and ``stop`` synchronizes the device before reading the interval.
    In CPU mode, the monotonic ``time.perf_counter()`` clock is used.
    """

    def __init__(self):
        self.is_cuda_available = torch.cuda.is_available()
        # CUDA mode: torch.cuda.Event objects; CPU mode: perf_counter() readings.
        # None until start() is called.
        self.start_event = None
        self.end_event = None

    def start(self):
        """Begin timing; creates fresh CUDA events each call in CUDA mode."""
        if self.is_cuda_available:
            self.start_event = torch.cuda.Event(enable_timing=True)
            self.end_event = torch.cuda.Event(enable_timing=True)
            self.start_event.record()
        else:
            self.start_event = time.perf_counter()

    def stop(self):
        """Finish timing and return the elapsed time in seconds.

        Raises:
            RuntimeError: if start() has not been called.
        """
        if self.start_event is None:
            raise RuntimeError("DeviceAgnosticTimer.stop() called before start()")
        if self.is_cuda_available:
            self.end_event.record()
            # elapsed_time() is only valid once both events have completed
            torch.cuda.synchronize()
            # Event.elapsed_time reports milliseconds; convert to seconds
            diff = self.start_event.elapsed_time(self.end_event) / 1000
        else:
            self.end_event = time.perf_counter()
            diff = self.end_event - self.start_event
        return diff
