cookbook-efforts/dpo-orpo-preference/custom_preference_to_argilla.py (70 lines of code) (raw):
import contextlib
import hashlib
from typing import TYPE_CHECKING, Any, Dict, List, Union
from typing_extensions import override
with contextlib.suppress(ImportError):
import argilla as rg
from distilabel.steps import PreferenceToArgilla, StepInput
if TYPE_CHECKING:
from distilabel.steps.typing import StepOutput, RatingQuestion, TextQuestion
class CustomPreferenceToArgilla(PreferenceToArgilla):
    """
    Custom PreferenceToArgilla step that adds metadata properties to the feedback records.
    This allows filtering based on metadata properties in the Argilla UI.

    Attributes:
        metadata_properties: One dict per metadata property. Each dict carries a
            ``"type"`` key (``"float"``, ``"integer"`` or ``"terms"``) that selects
            the Argilla metadata property class; the remaining keys are passed to
            that class's ``parse_obj``. The ``"name"`` key is also used to look up
            the matching value in every input row.
    """

    metadata_properties: List[Dict[str, Any]]

    def load(self) -> None:
        """Run the parent ``load`` and register the configured metadata properties.

        Entries whose ``"type"`` is not one of ``"float"``, ``"integer"`` or
        ``"terms"`` are skipped with a warning; the remaining entries are still
        registered.
        """
        # Imported locally so this method does not rely on a module-level import.
        import warnings

        super().load()
        for metadata_property in self.metadata_properties:
            property_type = metadata_property.get("type")
            # Work on a copy without the "type" key: the configured dicts must
            # stay intact so that `load` can safely be called more than once.
            settings = {
                key: value
                for key, value in metadata_property.items()
                if key != "type"
            }
            if property_type == "float":
                rg_property = rg.FloatMetadataProperty.parse_obj(settings)
            elif property_type == "integer":
                rg_property = rg.IntegerMetadataProperty.parse_obj(settings)
            elif property_type == "terms":
                rg_property = rg.TermsMetadataProperty.parse_obj(settings)
            else:
                # Skip only this entry; one bad entry must not drop the
                # valid properties that follow it.
                warnings.warn(
                    f"Skipping metadata property {metadata_property.get('name')!r}: "
                    f"unsupported type {property_type!r}.",
                    stacklevel=2,
                )
                continue
            self._rg_dataset.add_metadata_property(rg_property)  # type: ignore

    def _rating_rationale_pairs(
        self,
    ) -> List[Union["RatingQuestion", "TextQuestion"]]:
        """Return the parent's questions plus an optional ``improved_response`` text question."""
        questions = super()._rating_rationale_pairs()
        questions.append(
            rg.TextQuestion(  # type: ignore
                name="improved_response",
                title="How would you improve the response?",
                required=False,
            )
        )
        return questions

    @override
    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        """Build one feedback record per input row, add them to the dataset and yield the batch.

        Args:
            inputs: Batch of rows; each row is read for its ``"instruction"`` and
                ``"generations"`` keys, plus any key named by a configured
                metadata property.

        Yields:
            The unmodified ``inputs`` batch, once all records have been added.
        """
        # The property names do not change between rows, so collect them once.
        metadata_names = [
            metadata_property["name"]
            for metadata_property in self.metadata_properties
        ]
        records = []
        for row in inputs:
            # SHA-256 hash of the instruction, stored in the record's `id` field.
            instruction_id = hashlib.sha256(
                row["instruction"].encode("utf-8")  # type: ignore
            ).hexdigest()
            generations = {
                f"{self._generations}-{idx}": generation
                for idx, generation in enumerate(row["generations"])  # type: ignore
            }
            records.append(  # type: ignore
                rg.FeedbackRecord(  # type: ignore
                    fields={
                        "id": instruction_id,
                        "instruction": row["instruction"],  # type: ignore
                        **generations,
                    },
                    suggestions=self._add_suggestions_if_any(row),  # type: ignore
                    # Only attach metadata values that are present in the row.
                    metadata={
                        name: row[name] for name in metadata_names if name in row
                    },
                )
            )
        self._rg_dataset.add_records(records)  # type: ignore
        yield inputs