def language_predict()

in cookbook-efforts/dpo-orpo-preference/aya_dpo_gen.py [0:0]


def language_predict(inputs: StepInput) -> StepOutput:
    """
    A step to predict the language of the generated text.

    Sometimes models fail to generate text in the desired language.
    This step helps to identify such cases using an external language
    prediction model (``laurievb/OpenLID`` queried through ``InferenceClient``).

    Each row in ``inputs`` is annotated in place with two keys:

    - ``predicted_generation_language``: label of the highest-scoring
      prediction, or ``"error"`` if the lookup failed for that row.
    - ``predicted_generation_language_score``: that prediction's score,
      clamped to ``[0.0, 1.0]``; ``0.0`` when the lookup failed.

    Failures are handled best-effort: any exception raised while processing
    a row is printed and the row is marked as ``"error"`` instead of aborting
    the whole batch.

    Args:
        inputs: batch of rows; each row is expected to hold the generated
            text under the ``"generation"`` key.

    Yields:
        The same ``inputs`` batch, after annotation.
    """
    # The client is loop-invariant, so construct it once per batch rather
    # than once per row.
    client = InferenceClient("laurievb/OpenLID")
    for row in inputs:
        try:
            # Collapse newlines into spaces before classification
            # (presumably the classifier expects single-line input — TODO confirm).
            cleaned_text = row["generation"].replace("\n", " ")
            predictions = client.text_classification(cleaned_text)
            # Select the highest-scoring label explicitly instead of relying on
            # the response ordering. An empty response raises ValueError here,
            # which is handled below like any other per-row failure.
            top_prediction = max(predictions, key=lambda pred: pred.score)
            row["predicted_generation_language"] = top_prediction.label
            # Clamp on both ends so the stored score always lies in [0, 1].
            row["predicted_generation_language_score"] = max(
                0.0, min(1.0, top_prediction.score)
            )
        except Exception as e:
            # Deliberate best-effort: record the failure on this row and
            # continue with the rest of the batch.
            print(e)
            row["predicted_generation_language"] = "error"
            row["predicted_generation_language_score"] = 0.0
    yield inputs