ja/4_evaluation/project/create_dataset.py (47 lines of code) (raw):
import argparse
import argilla as rg
from datasets import Dataset
################################################################################
# スクリプトのパラメータ
################################################################################
parser = argparse.ArgumentParser(
description="注釈付きArgillaデータからHugging Faceデータセットを作成します。"
)
parser.add_argument(
"--argilla_api_key",
type=str,
default="argilla.apikey",
help="ArgillaのAPIキー",
)
parser.add_argument(
"--argilla_api_url",
type=str,
default="http://localhost:6900",
help="ArgillaのAPI URL",
)
parser.add_argument(
"--dataset_path",
type=str,
default="exam_questions",
help="Argillaデータセットのパス",
)
parser.add_argument(
"--dataset_repo_id",
type=str,
default="burtenshaw/exam_questions",
help="Hugging FaceデータセットリポジトリID",
)
args = parser.parse_args()
################################################################################
# Argillaクライアントを初期化し、データセットを読み込む
################################################################################
client = rg.Argilla(api_key=args.argilla_api_key, api_url=args.argilla_api_url)
dataset = client.datasets(args.dataset_path)
################################################################################
# Argillaレコードを処理
################################################################################
dataset_rows = []
for record in dataset.records(with_suggestions=True, with_responses=True):
row = record.fields
if len(record.responses) == 0:
answer = record.suggestions["correct_answer"].value
row["correct_answer"] = answer
else:
for response in record.responses:
if response.question_name == "correct_answer":
row["correct_answer"] = response.value
dataset_rows.append(row)
################################################################################
# Hugging Faceデータセットを作成し、Hubにプッシュ
################################################################################
hf_dataset = Dataset.from_list(dataset_rows)
hf_dataset.push_to_hub(repo_id=args.dataset_repo_id)
print(f"データセットが{args.dataset_repo_id}に正常にプッシュされました")