def main()

in scripts/cost_estimate.py [0:0]


def main():
    """Print a Markdown table of token counts and derived costs per task.

    Each task named in the comma-separated list below is run through
    ``evaluator.simple_evaluate`` with a ``DryrunLM`` instance as the model.
    The model's ``tokencost`` attribute is reset to zero before each task and
    read back afterwards. The per-task rows are sorted by token count in
    descending order, a ``**Total**`` row is appended, and the table is
    printed to stdout.
    """
    # Multipliers applied to (tokens / 1000), in the same order as the
    # "Ada", "Babbage", "Curie", "Davinci" columns.
    # NOTE(review): presumably prices per 1,000 tokens -- confirm the unit.
    rates = (0.0008, 0.0012, 0.006, 0.06)

    def cost_row(label, tokens):
        # One table row: label, raw token count, then one cost per rate.
        return [label, tokens] + [tokens / 1000 * rate for rate in rates]

    lm = DryrunLM()

    task_list = "arc_challenge,arc_easy,boolq,cola,copa,headqa,hellaswag,lambada,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,record,rte,sciq,sst,triviaqa,webqs,wic,wikitext,winogrande,wnli,wsc"
    rows = []
    for taskname in task_list.split(","):
        # Start each task from a zeroed counter so the value read below
        # belongs to this task only.
        lm.tokencost = 0
        single_task = {taskname: tasks.get_task(taskname)()}
        evaluator.simple_evaluate(
            lm=lm,
            task_dict=single_task,
            num_fewshot=0,
            limit=None,
            bootstrap_iters=10,
        )

        print(taskname, lm.tokencost)
        rows.append(cost_row(taskname, lm.tokencost))

    # Imported only once all tasks have been evaluated, as in the original.
    from pytablewriter import MarkdownTableWriter

    writer = MarkdownTableWriter()
    writer.headers = ["Task", "Tokens", "Ada", "Babbage", "Curie", "Davinci"]

    # Largest token count first; the total is computed before its own row is
    # appended so it is not counted twice and always stays at the bottom.
    rows.sort(key=lambda row: row[1], reverse=True)
    totcost = sum(row[1] for row in rows)
    rows.append(cost_row("**Total**", totcost))

    writer.value_matrix = rows

    print(writer.dumps())