# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json

import matplotlib.pyplot as plt


def parse_args(argv=None):
    """Parse the command-line options of the memory-usage comparison script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``sys.argv[1:]`` is parsed, exactly like ``ArgumentParser.parse_args()``.

    Returns:
        argparse.Namespace with ``dir``, ``memory_threshold`` and ``filter_partition``.
    """
    parser = argparse.ArgumentParser()
    # Required: every input and output path is built from this directory, so a
    # missing value would otherwise surface later as a confusing "None/..." path.
    parser.add_argument("--dir", type=str, required=True, help="Directory containing the memory usage data")
    parser.add_argument(
        "--memory_threshold",
        # float so fractional thresholds are accepted (the plots in this script
        # label memory values in GB); integer inputs behave exactly as before.
        type=float,
        default=0.0,
        help="Memory threshold to filter data that is below this value (only filters 1st `--filter_partition` of the points which should roughly correspond to the model loading)",
    )
    parser.add_argument(
        "--filter_partition",
        type=float,
        default=1 / 3,
        help="Fraction (0-1) of the leading points in which points below `--memory_threshold` are dropped",
    )
    return parser.parse_args(argv)


def filter_data(data, memory_threshold, filter_partition, key):
    """Drop low-memory samples from the leading part of a trace.

    Only the first ``int(len(timestamps) * filter_partition)`` samples are
    eligible for removal; among those, a sample is removed when its value is
    strictly below ``memory_threshold``. Every later sample is always kept.

    Args:
        data: Mapping holding a ``"timestamps"`` sequence and a ``key`` sequence.
        memory_threshold: Values below this are dropped inside the leading window.
        filter_partition: Fraction of the samples that forms the leading window.
        key: Name of the series in ``data`` to filter alongside the timestamps.

    Returns:
        Tuple ``(times, values)`` of lists holding the surviving samples.
    """
    times = data["timestamps"]
    values = data[key]
    window_end = int(len(times) * filter_partition)

    # Keep a sample unless it is both inside the leading window AND below the
    # threshold. ``not v < threshold`` (rather than ``v >= threshold``) keeps
    # the exact comparison semantics for values such as NaN.
    survivors = [
        (t, v)
        for idx, (t, v) in enumerate(zip(times, values))
        if idx >= window_end or not v < memory_threshold
    ]
    return [t for t, _ in survivors], [v for _, v in survivors]


def _plot_memory_metric(data, labels, colors, memory_threshold, filter_partition, key, ylabel, title):
    """Create a new figure with one line per run for the ``key`` series and return it."""
    fig, ax = plt.subplots(figsize=(15, 5))
    for i, (data_item, label) in enumerate(zip(data, labels)):
        # The fixed palette only covers the first len(colors) runs; later runs
        # fall back to matplotlib's colour cycle instead of being silently
        # dropped (zip() with the palette used to truncate them).
        color_kwargs = {"color": colors[i]} if i < len(colors) else {}
        timestamps, values = filter_data(data_item, memory_threshold, filter_partition, key)
        ax.plot(timestamps, values, label=label, linewidth=2, **color_kwargs)

    ax.set_xlabel("Time (s)", fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, pad=15)
    ax.grid(True, linestyle="--", alpha=0.7)
    ax.legend(frameon=True, fancybox=True, shadow=True, fontsize=10)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    # Lay out this specific figure rather than whatever pyplot considers current.
    fig.tight_layout()
    return fig


def compare_memory_usage(data, labels, memory_threshold, filter_partition):
    """Plot allocated and reserved memory over time for several runs.

    Args:
        data: One mapping per run; each is read through ``filter_data`` using the
            ``"timestamps"``, ``"allocated_memory"`` and ``"reserved_memory"`` keys.
        labels: Legend label for each entry of ``data``, in the same order.
        memory_threshold: Forwarded to ``filter_data``.
        filter_partition: Forwarded to ``filter_data``.

    Returns:
        ``(fig1, fig2)``: the allocated-memory figure and the reserved-memory figure.

    Raises:
        ValueError: If ``data`` and ``labels`` have different lengths.
    """
    # Materialize once: both figures iterate the runs, so a one-shot iterable
    # would otherwise leave the second figure empty.
    data = list(data)
    labels = list(labels)
    if len(data) != len(labels):
        raise ValueError(f"Got {len(data)} data series but {len(labels)} labels")

    plt.style.use("seaborn-v0_8")
    colors = ["#2ecc71", "#e74c3c", "#3498db", "#f1c40f"]

    fig1 = _plot_memory_metric(
        data,
        labels,
        colors,
        memory_threshold,
        filter_partition,
        "allocated_memory",
        "Allocated Memory (GB)",
        "Allocated Memory Usage Over Time",
    )
    fig2 = _plot_memory_metric(
        data,
        labels,
        colors,
        memory_threshold,
        filter_partition,
        "reserved_memory",
        "Reserved Memory (GB)",
        "Reserved Memory Usage Over Time",
    )
    return fig1, fig2


if __name__ == "__main__":
    cli_args = parse_args()
    run_dir = cli_args.dir

    def _load_run(file_name):
        # Read one JSON memory-usage file from the run directory.
        with open(f"{run_dir}/{file_name}") as handle:
            return json.load(handle)

    # Files are read in this order; the plotting order is defined separately below.
    not_fixed = _load_run("torch_optimizer_before_fsdp_not_fixed_memory_usage.json")
    after_fsdp = _load_run("torch_optimizer_after_fsdp_memory_usage.json")
    fixed = _load_run("torch_optimizer_before_fsdp_fixed_memory_usage.json")
    accelerate_run = _load_run("accelerate_memory_usage.json")

    # (trace, legend label) pairs in plotting order.
    runs = [
        (not_fixed, "Optimizer Before FSDP (w/o fix)"),
        (fixed, "Optimizer Before FSDP (w/ fix)"),
        (after_fsdp, "Optimizer After FSDP"),
        (accelerate_run, "Accelerate"),
    ]
    data = [trace for trace, _ in runs]
    labels = [label for _, label in runs]

    allocated_fig, reserved_fig = compare_memory_usage(
        data, labels, cli_args.memory_threshold, cli_args.filter_partition
    )
    allocated_fig.savefig(f"{run_dir}/allocated_memory.png")
    reserved_fig.savefig(f"{run_dir}/reserved_memory.png")
