import sys
import io
import os
import re
import json
import random
import tarfile
import numpy as np
from functools import partial
import torch.distributed as dist
from PIL import Image

import webdataset as wds
from webdataset import ResampledShards, DataPipeline, tarfile_to_samples
from webdataset.filters import pipelinefilter
from webdataset.tariterators import group_by_keys
from webdataset.handlers import reraise_exception
from webdataset.gopen import gopen_schemes, gopen

def pytorch_worker_info(group=None):  # sourcery skip: use-contextlib-suppress
    """Return node and worker info for PyTorch and some distributed environments."""
    rank = 0
    world_size = 1
    worker = 0
    num_workers = 1
    try:
        import torch.distributed

        if torch.distributed.is_available() and torch.distributed.is_initialized():
            group = group or torch.distributed.group.WORLD
            rank = torch.distributed.get_rank(group=group)
            world_size = torch.distributed.get_world_size(group=group)
    except ModuleNotFoundError:
        pass
    try:
        import torch.utils.data

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker = worker_info.id
            num_workers = worker_info.num_workers
    except ModuleNotFoundError:
        pass

    return rank, world_size, worker, num_workers


def pytorch_worker_seed(group=None):
    """Compute a distinct, deterministic RNG seed for each worker and node."""
    rank, world_size, worker, num_workers = pytorch_worker_info(group=group)
    return rank * 1000 + worker

def worker_seed_sat(group=None, seed=0):
    return pytorch_worker_seed(group=group) + seed * 23
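
# Illustrative seed values (assuming 2 data-parallel ranks, 2 dataloader workers each):
# with seed=7 (7 * 23 = 161), worker_seed_sat gives 161 and 162 for workers 0/1 on rank 0,
# and 1161 and 1162 on rank 1, so every (rank, worker, seed) combination resamples shards
# in a different order.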

class ConfiguredResampledShards(ResampledShards):
    def __init__(self, urls, seed, nshards=sys.maxsize, deterministic=True):
        from sat.helpers import print_rank0
        try:
            from megatron.core.parallel_state import get_data_parallel_group
            group = get_data_parallel_group()
            print_rank0("Using megatron data parallel group.")
        except Exception:
            from sat.mpu import get_data_parallel_group
            try:
                group = get_data_parallel_group()
                print_rank0("Using sat data parallel group.")
            except AssertionError:
                group = None
                print_rank0("No data parallel group is specified!")
        worker_seed_sat_this = partial(worker_seed_sat, group=group, seed=seed)
        super().__init__(urls, nshards, worker_seed_sat_this, deterministic)

class SimpleDistributedWebDataset(DataPipeline):
    def __init__(self, path, process_fn, seed, *, shuffle_buffer=1000):
        # Set shuffle_buffer = 1 to disable shuffling; under model parallelism shuffling must be
        # disabled, because ranks in the same model-parallel group would otherwise see different samples.
        try:
            from sat.mpu import get_model_parallel_world_size
            if get_model_parallel_world_size() > 1:
                shuffle_buffer = 1
        except Exception:
            pass
        super().__init__(
            ConfiguredResampledShards(path, seed), # many shards are recommended, otherwise data may not be distributed evenly across workers
            tarfile_to_samples(),
            wds.shuffle(shuffle_buffer),
            process_fn
        )
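
# A minimal usage sketch (the shard pattern and "process_sample" below are hypothetical;
# process_fn is any pipeline stage that consumes and yields sample dicts):
#
#     def process_sample(source):
#         for sample in source:
#             yield sample  # decode / tokenize here
#
#     dataset = SimpleDistributedWebDataset(
#         "/data/shards/part-{000000..000099}.tar", process_sample, seed=42)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=4)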

def tar_file_iterator_with_meta(fileobj, meta_names, skip_meta=r"__[^/]*__($|/)", suffix=None, handler=reraise_exception, meta_stream=None):
    """Iterate over a tar stream, yielding dict(fname=..., data=...) samples together with their meta entries.

    :param fileobj: byte stream suitable for tarfile
    :param meta_names: keys to pick from the accompanying '.meta.jsonl' file
    :param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
    :param suffix: optional string appended to the content of '.txt' members
    :param handler: exception handler
    :param meta_stream: optional already-opened stream of the '.meta.jsonl' file; if None, it is looked up next to the tar file

    """
    stream = tarfile.open(fileobj=fileobj, mode="r|*")
    meta_data = {} # {id: {meta_name: meta_value, meta_name2: meta_value2, ...}}
    if meta_stream is None:
        # look for a '<name>.meta.jsonl' file next to the tar file on the local filesystem
        data_dir, filename = fileobj.name.rsplit('/', 1)
        meta_file_name = filename.split('.')[0] + '.meta.jsonl'
        meta_path = os.path.join(data_dir, meta_file_name)
        if os.path.exists(meta_path):
            meta_stream = open(meta_path, 'r')
    else:
        meta_file_name = meta_stream.name
    
    if meta_stream is not None:
        for lineno, line in enumerate(meta_stream):
            meta_list = []
            try:
                meta_list.append(json.loads(line))
            except Exception as exn:
                from sat.helpers import print_rank0
                print_rank0(f'Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}', level='DEBUG')
                continue
            for item in meta_list:
                if item['key'] not in meta_data:
                    meta_data[item['key']] = {}
                for meta_name in meta_names:
                    if meta_name in item:
                        meta_data[item['key']][meta_name] = item[meta_name]
        meta_stream.close()
    
    try:
        for tarinfo in stream:
            fname = tarinfo.name
            try:
                if not tarinfo.isreg():
                    continue
                if fname is None:
                    continue
                if (
                    "/" not in fname
                    and fname.startswith("__")
                    and fname.endswith("__")
                ):
                    # skipping metadata for now
                    continue
                if skip_meta is not None and re.match(skip_meta, fname):
                    continue
                if fname.endswith('.txt') and suffix is not None:
                    data = (stream.extractfile(tarinfo).read().decode() + suffix).encode()
                else:
                    data = stream.extractfile(tarinfo).read()
                result = dict(fname=fname, data=data)
                yield result
                
                if fname.endswith('.id'):
                    fid = fname.split('.')[0]
                    meta_data_fid = meta_data.get(fid, {})
                    for meta_name in meta_names:
                        meta_fname = fid + '.' + meta_name
                        meta = meta_data_fid.get(meta_name, None)
                        yield dict(fname=meta_fname, data=meta)
                # drop already-yielded members so the streaming TarFile does not keep accumulating TarInfo objects
                stream.members = []
            except Exception as exn:
                if hasattr(exn, "args") and len(exn.args) > 0:
                    exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
                if handler(exn):
                    continue
                else:
                    break
    except Exception as exn:
        print(exn)
    del stream
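
# Illustrative on-disk layout expected above (file names are examples only):
#
#     data/000000.tar          members: 0001.jpg  0001.txt  0001.id  0002.jpg ...
#     data/000000.meta.jsonl   one JSON object per line, e.g.
#         {"key": "0001", "caption_quality": 0.9}
#
# With meta_names=["caption_quality"], the iterator yields each tar member as a
# {"fname": ..., "data": ...} dict and, right after a '<key>.id' member, an extra
# {"fname": "<key>.caption_quality", "data": <value from the jsonl>} entry.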
    
def tar_file_expander_with_meta(data, meta_names, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of tar file contents.

    This returns an iterator over (filename, file_contents).
    """
    for source in data:
        url = source["url"]
        try:
            assert isinstance(source, dict)
            assert "stream" in source
            for sample in tar_file_iterator_with_meta(source["stream"], meta_names, meta_stream=source['meta_stream']):
                assert (
                    isinstance(sample, dict) and "data" in sample and "fname" in sample
                )
                sample["__url__"] = url
                yield sample
        except Exception as exn:
            exn.args = exn.args + (source.get("stream"), source.get("url"))
            if handler(exn):
                continue
            else:
                break

def url_opener(
    data,
    handler=reraise_exception,
    **kw,
):
    """Open URLs and yield a stream of url+stream pairs.

    A variant of webdataset's url_opener that also propagates an optional
    `meta_stream` attribute (attached by gopen_boto3) into the sample.

    Args:
        data: iterator over dict(url=...)
        handler: exception handler.
        kw: keyword arguments for gopen.gopen.

    Yields:
        a stream of url+stream pairs.
    """
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        try:
            stream = gopen(url, **kw)
            if hasattr(stream, 'meta_stream'):
                meta_stream = stream.meta_stream
                del stream.meta_stream
            else:
                meta_stream = None
            sample.update(stream=stream, meta_stream=meta_stream)
            yield sample
        except Exception as exn:
            exn.args = exn.args + (url,)
            if handler(exn):
                continue
            else:
                break
            
def tarfile_samples_with_meta(src, meta_names, handler=reraise_exception):
    streams = url_opener(src, handler=handler)
    files = tar_file_expander_with_meta(streams, meta_names, handler)
    samples = group_by_keys(files, handler=handler)
    return samples  
        
class MetaDistributedWebDataset(DataPipeline):
    '''WebDataset with meta-information files.

    Extra format:
        In the webdataset (tar), each sample has a '<key>.id' member;
        next to each tar file there is a '<name>.meta.jsonl' file with the same stem.
        The '.meta.jsonl' file holds one JSON object per line, each with a 'key' field matching a sample's '.id'.
    '''
    def __init__(self, path, process_fn, seed, *, meta_names=[], nshards=sys.maxsize, shuffle_buffer=1000, include_dirs=None):
        # os.environ['WDS_SHOW_SEED'] = '1'
        if include_dirs is not None: # e.g. include_dirs="/webdatasets/A,/webdatasets/C"; a dir may be weighted as "/webdatasets/A*3" to repeat its tars 3 times
            other_paths = []
            include_dirs = include_dirs.split(',')
            for include_dir in include_dirs:
                if '*' in include_dir:
                    include_dir, n = include_dir.split('*')
                    n = int(n)
                else:
                    n = 1
                for cur_dir, dirs, files in os.walk(include_dir):
                    for f in files:
                        if f.endswith('.tar') and os.path.getsize(os.path.join(cur_dir, f)) > 0:
                            # other_paths.append(os.path.join(cur_dir,f))
                            other_paths.extend([os.path.join(cur_dir,f)]*n)
            # print(f'Adding dataset paths {",".join(other_paths)}')
            from braceexpand import braceexpand
            if len(path) > 0: # not "" 
                path = list(braceexpand(path)) + other_paths
            else:
                path = other_paths
        
        tarfile_samples = partial(tarfile_samples_with_meta, meta_names=meta_names)
        tarfile_to_samples = pipelinefilter(tarfile_samples)

        # if model parallel, shuffle_buffer should be 1 to disable shuffling
        try:
            from sat.mpu import get_model_parallel_world_size
            if get_model_parallel_world_size() > 1:
                shuffle_buffer = 1
        except Exception:
            pass
        
        super().__init__(
            ConfiguredResampledShards(path, seed, nshards=nshards),
            tarfile_to_samples(),
            wds.shuffle(shuffle_buffer),
            process_fn
        )
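
# A minimal usage sketch (paths, meta_names and process_fn below are hypothetical):
#
#     dataset = MetaDistributedWebDataset(
#         "",                               # or a brace-expandable shard pattern
#         process_fn, seed=1234,
#         meta_names=["caption_quality"],
#         include_dirs="/webdatasets/A,/webdatasets/B*3",  # tars under B are listed 3 times
#     )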

# rclone support
from webdataset.gopen import Pipe
def gopen_rclone(url, mode="rb", bufsize=1024*1024*32):
    """Open a URL with `curl`.

    :param url: rclone url, e.g. data:bucket1/foo.tar. data should be configured.
    :param mode: file mode
    :param bufsize: buffer size
    """
    url = url.replace("rclone://", "")
    if mode[0] == "r":
        cmd = f"rclone cat '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"rclone cp - '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 26],
        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")

def gopen_boto3(url, mode="rb", bufsize=8192*2):
    """Open a URL with boto3 API.

    :param url: boto3 url, e.g. boto3://bucket1/foo.tar. data should be configured.
    :param mode: file mode
    :param bufsize: buffer size
    """
    import boto3
    # boto3.set_stream_logger('botocore', level='DEBUG')
    if url.startswith("boto3://"):
        url = url.replace("boto3://", "")
        need_meta = False
    else:
        url = url.replace("metaboto3://", "")
        need_meta = True
    endpoint_url = os.environ.get("S3_ENDPOINT_URL", None)
    access_key = os.environ.get("S3_ACCESS_KEY_ID", None)
    secret_key = os.environ.get("S3_SECRET_ACCESS_KEY", None)
    
    if mode[0] == "r":
        s3_client = boto3.client('s3',
            endpoint_url=endpoint_url,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key
            )
        bucket, key = url.split('/', 1)

        if need_meta:
            # download a meta json
            meta_file_key = key.split('.')[0] + '.meta.jsonl'
            meta_stream = io.BytesIO()
            s3_client.download_fileobj(bucket, meta_file_key, meta_stream)
            meta_stream.seek(0)
            meta_stream.name = meta_file_key
        else:
            meta_stream = None

        # data tar stream
        response = s3_client.get_object(Bucket=bucket, Key=key) # Range optional
        response['Body'].name = key # .name is only read by tar_file_iterator_with_meta when no meta_stream is attached
        response['Body'].meta_stream = meta_stream
        return response['Body']
    else:
        raise ValueError(f"{mode}: unknown mode")

gopen_schemes['rclone'] = gopen_rclone
gopen_schemes['boto3'] = gopen_boto3
gopen_schemes['metaboto3'] = gopen_boto3
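
# A minimal end-to-end sketch for the boto3 schemes (bucket, keys and env values are hypothetical):
#
#     os.environ["S3_ENDPOINT_URL"] = "https://s3.example.com"
#     os.environ["S3_ACCESS_KEY_ID"] = "..."
#     os.environ["S3_SECRET_ACCESS_KEY"] = "..."
#     dataset = MetaDistributedWebDataset(
#         "metaboto3://my-bucket/shards/{000000..000099}.tar",
#         process_fn, seed=0, meta_names=["caption_quality"])
#
# With "metaboto3://", gopen_boto3 also downloads the shard's '.meta.jsonl' from the bucket
# and attaches it as `meta_stream`, which url_opener forwards to tar_file_iterator_with_meta.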


