def augment_config_from_db()

in mephisto/tools/scripts.py [0:0]


def augment_config_from_db(script_cfg: DictConfig, db: "MephistoDB") -> DictConfig:
    """
    Check the database for validity of the incoming MephistoConfig, and ensure
    that the config has all the necessary provider fields set.

    Requester resolution for ``script_cfg.mephisto.provider``:

    - If ``requester_name`` is unset, a requester is picked based on
      ``_provider_type`` (which itself defaults to ``"mock"`` when unset):
      ``"mock"`` uses ``get_mock_requester(db)``; any other provider type uses
      the requesters returned by ``db.find_requesters(provider_type=...)``,
      taking the last one when several match.
    - If ``requester_name`` is set, it must exist in the database, and the
      provider type of that requester replaces any ``_provider_type`` that was
      in the config.

    Before returning, the user is asked to confirm via ``input()`` when the
    provider is ``"mturk"``, and again when the provider is ``"mturk"`` or
    ``"mturk_sandbox"`` while the architect is neither ``"heroku"`` nor
    ``"ec2"``.

    :param script_cfg: full script config; its ``mephisto.provider`` node is
        updated in place with ``requester_name`` and ``_provider_type``.
    :param db: database used to look up registered requesters.
    :return: the same ``script_cfg`` object that was passed in.
    :raises SystemExit: with code 1 when no matching requester is registered.
    """
    # Use sys.exit rather than the site-injected ``exit`` builtin, which is
    # meant for the interactive shell and is absent under ``python -S``.
    import sys

    cfg = script_cfg.mephisto
    requester_name = cfg.provider.get("requester_name", None)
    provider_type = cfg.provider.get("_provider_type", None)
    architect_type = cfg.architect.get("_architect_type", None)

    if requester_name is None:
        if provider_type is None:
            print("No requester specified, defaulting to mock")
            provider_type = "mock"
        if provider_type == "mock":
            requester_name = get_mock_requester(db).requester_name
        else:
            reqs = db.find_requesters(provider_type=provider_type)
            # TODO (#93) proper logging
            if not reqs:
                print(
                    f"No requesters found for provider type {provider_type}, please "
                    f"register one. You can register with `mephisto register {provider_type}`, "
                    f"or `python mephisto/client/cli.py register {provider_type}` if you haven't "
                    "installed Mephisto using poetry."
                )
                sys.exit(1)
            # With a single match reqs[-1] is reqs[0]; with several, the last
            # one is reported as "the most recent".
            # NOTE(review): that relies on find_requesters returning requesters
            # in creation order — confirm against the DB implementation.
            requester_name = reqs[-1].requester_name
            if len(reqs) == 1:
                print(
                    f"Found one `{provider_type}` requester to launch with: {requester_name}"
                )
            else:
                print(
                    f"Found many `{provider_type}` requesters to launch with, "
                    f"choosing the most recent: {requester_name}"
                )
    else:
        # Ensure provided requester exists
        reqs = db.find_requesters(requester_name=requester_name)
        if not reqs:
            print(
                f"No requesters found under name {requester_name}, "
                "have you registered with `mephisto register`?"
            )
            sys.exit(1)
        # The registered requester is authoritative for which provider to use.
        provider_type = reqs[0].provider_type

    if provider_type in ["mturk"]:
        input(
            f"This task is going to launch live on {provider_type}, press enter to continue: "
        )
    # Presumably a non-remote architect cannot serve MTurk workers, so this
    # combination is treated as a likely misconfiguration needing confirmation.
    if provider_type in ["mturk_sandbox", "mturk"] and architect_type not in [
        "heroku",
        "ec2",
    ]:
        input(
            f"This task is going to launch live on {provider_type}, but your "
            f"provided architect is {architect_type}, are you sure you "
            "want to do this? : "
        )

    cfg.provider.requester_name = requester_name
    cfg.provider._provider_type = provider_type
    return script_cfg