in api/controllers/datasets.py
def create(credentials, tid, name):
    ensure_owner_or_admin(tid, credentials["id"])
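
    # Dataset names end up in S3 object keys and in the datasets table, so keep
    # them to a conservative character set and length.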
    if not bool(re.fullmatch("[a-zA-Z0-9-]{1,62}", name)):
        bottle.abort(
            400,
            "Invalid dataset name - must only contain alphanumeric characters "
            + "or '-' and must be shorter than 63 characters",
        )

    dataset_upload = bottle.request.files.get("file")
    tm = TaskModel()
    task = tm.get(tid)
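
    # The task's annotation config declares the delta metrics; each one expects
    # its own perturbed copy of the dataset, uploaded under a form field named
    # after the metric type.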
    delta_dataset_uploads = []
    delta_metric_types = [
        config["type"]
        for config in util.json_decode(task.annotation_config_json)["delta_metrics"]
    ]
    for delta_metric_type in delta_metric_types:
        delta_dataset_uploads.append(
            (bottle.request.files.get(delta_metric_type), delta_metric_type)
        )
    uploads = [(dataset_upload, None)] + delta_dataset_uploads

    parsed_uploads = []
    # Ensure correct format
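    # Every line must be a JSON object with a 'uid' and a list-valued 'tags';
    # perturbed (delta) uploads additionally need an 'input_id' pointing back to
    # the original example, and each object must pass the task's annotation
    # verifier in dataset_upload mode.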
    for upload, perturb_prefix in uploads:
        try:
            parsed_upload = [
                util.json_decode(line)
                for line in upload.file.read().decode("utf-8").splitlines()
            ]
        except Exception as ex:
            logger.exception(ex)
            bottle.abort(400, "Could not parse dataset file. Is it a utf-8 jsonl?")

        for io in parsed_upload:
            try:
                assert "uid" in io, "'uid' must be present for every example"
                assert (
                    "tags" in io
                ), "there must be a field called 'tags' on every line of the jsonl"
                assert isinstance(
                    io["tags"], list
                ), "'tags' must be a list on every line of the jsonl"
                if perturb_prefix is not None:
                    assert "input_id" in io, (
                        "'input_id' must be present for every example for"
                        + " perturbed dataset uploads"
                    )
            except Exception as ex:
                bottle.abort(400, str(ex))
            verified, message = task.verify_annotation(
                io, mode=AnnotationVerifierMode.dataset_upload
            )
            if not verified:
                bottle.abort(400, message)
        parsed_uploads.append((parsed_upload, perturb_prefix))

    # Upload to s3
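    # Re-serialize each upload line by line in the task's model-I/O format into a
    # temporary file, then push it to the task's S3 bucket; perturb_prefix is
    # passed to get_data_s3_path when the object key is built.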
    for parsed_upload, perturb_prefix in parsed_uploads:
        try:
            s3_client = boto3.client(
                "s3",
                aws_access_key_id=config["eval_aws_access_key_id"],
                aws_secret_access_key=config["eval_aws_secret_access_key"],
                region_name=config["eval_aws_region"],
            )
            with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp:
                for datum in parsed_upload:
                    tmp.write(util.json_encode(task.convert_to_model_io(datum)) + "\n")
                tmp.close()
                response = s3_client.upload_file(
                    tmp.name,
                    task.s3_bucket,
                    get_data_s3_path(task.task_code, name + ".jsonl", perturb_prefix),
                )
                os.remove(tmp.name)
                if response:
                    logger.info(response)
        except Exception as ex:
            logger.exception(f"Failed to load {name} to S3 due to {ex}.")
            bottle.abort(400, "Issue loading dataset to S3")

    # Create an entry in the db for the dataset, or skip if one already exists.
    d = DatasetModel()
    updated_existing_dataset = False
    if not d.getByName(name):  # avoid id increment for unsuccessful creation
        if d.create(
            name=name,
            task_id=tid,
            rid=0,
            access_type=AccessTypeEnum.hidden,
            longdesc=None,
            source_url=None,
        ):
            logger.info(f"Registered {name} in datasets db.")
    else:
        updated_existing_dataset = True

    # Evaluate all models
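    # model_id="*" requests evaluation of every model on the new dataset, and
    # reload_datasets=True signals the evaluation server to reload its datasets
    # before scoring so it picks up the files just written to S3.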
    eval_config = {
        "aws_access_key_id": config["eval_aws_access_key_id"],
        "aws_secret_access_key": config["eval_aws_secret_access_key"],
        "aws_region": config["eval_aws_region"],
        "evaluation_sqs_queue": config["evaluation_sqs_queue"],
    }
    send_eval_request(
        model_id="*",
        dataset_name=name,
        config=eval_config,
        eval_server_id=task.eval_server_id,
        logger=logger,
        reload_datasets=True,
    )
    return util.json_encode(
        {"success": "ok", "updated_existing_dataset": updated_existing_dataset}
    )
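
As a rough illustration only, not part of the file above: a client-side sketch of how this handler might be exercised, assuming it is exposed as an authenticated POST route carrying the task id and dataset name in the URL and that the auth layer expects a bearer token. The host, route path, token, and the "fairness" delta field below are hypothetical; the "file" field name and the utf-8 jsonl format come from the handler itself.

import requests

# Hypothetical values; only the multipart field "file" and the jsonl format are
# taken directly from the handler above.
URL = "https://api.example.org/datasets/create/7/my-new-dataset"
TOKEN = "..."  # whatever the deployment's auth layer expects

with open("my-new-dataset.jsonl", "rb") as f:
    resp = requests.post(
        URL,
        files={
            "file": f,
            # plus one field per delta metric type declared in the task's
            # annotation config, e.g. "fairness": open(..., "rb")
        },
        headers={"Authorization": f"Bearer {TOKEN}"},
    )
print(resp.status_code, resp.text)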