in src/graph_notebook/magics/graph_magic.py [0:0]
def load(self, line='', local_ns: dict = None):
    """Display a form for submitting a data load job, and optionally submit it.

    On a Neptune Analytics graph (``self.client.is_analytics_domain()``) the
    form submits an openCypher ``CALL neptune.load({...})`` incremental load via
    ``self.client.opencypher_http``. Otherwise it submits a bulk load via
    ``self.client.load``. It can then poll ``self.client.load_status`` until
    the job reports a status in ``FINAL_LOAD_STATUSES``.

    Args:
        line: Magic-line arguments (``-s/--source``, ``-f/--format``, ``--run``,
            ...). They pre-populate the form widgets. ``--run`` submits the form
            immediately instead of waiting for the Submit button.
        local_ns: Notebook namespace that ``--store-to`` stores the load
            response into.

    Returns:
        None. Progress and results are displayed in the notebook. If
        ``load_from_s3_arn`` or ``aws_region`` is missing from the config, a
        message is printed and nothing is displayed.
    """
    if self.client.is_analytics_domain():
        load_type = ANALYTICS_LOAD_TYPES[0]
        load_formats = VALID_INCREMENTAL_FORMATS
    else:
        load_type = DB_LOAD_TYPES[0]
        load_formats = VALID_BULK_FORMATS

    # TODO: change widgets to let any arbitrary inputs be added by users
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--source', default='s3://')
    try:
        parser.add_argument('-l', '--loader-arn', default=self.graph_notebook_config.load_from_s3_arn)
    except AttributeError:
        print(f"Missing required configuration option 'load_from_s3_arn'. Please ensure that you have provided a "
              "valid Neptune cluster endpoint URI in the 'host' field of %graph_notebook_config.")
        return
    parser.add_argument('-f', '--format', choices=load_formats, default=FORMAT_CSV)
    parser.add_argument('-p', '--parallelism', choices=PARALLELISM_OPTIONS, default=PARALLELISM_HIGH)
    try:
        parser.add_argument('-r', '--region', default=self.graph_notebook_config.aws_region)
    except AttributeError:
        print("Missing required configuration option 'aws_region'. Please ensure that you have provided a "
              "valid Neptune cluster endpoint URI in the 'host' field of %graph_notebook_config.")
        return
    parser.add_argument('--no-fail-on-error', action='store_true', default=False)
    # NOTE(review): store_true with default=True means this flag can never be
    # switched off from the magic line; only the form dropdown can change it.
    parser.add_argument('--update-single-cardinality', action='store_true', default=True)
    parser.add_argument('--store-to', type=str, default='', help='store query result to this variable')
    parser.add_argument('--run', action='store_true', default=False)
    parser.add_argument('-m', '--mode', choices=LOAD_JOB_MODES, default=MODE_AUTO)
    parser.add_argument('-q', '--queue-request', action='store_true', default=False)
    parser.add_argument('-d', '--dependencies', action='append', default=[])
    parser.add_argument('-e', '--no-edge-ids', action='store_true', default=False)
    parser.add_argument('-c', '--concurrency', type=int, default=1)
    parser.add_argument('--named-graph-uri', type=str, default=DEFAULT_NAMEDGRAPH_URI,
                        help='The default graph for all RDF formats when no graph is specified. '
                             'Default is http://aws.amazon.com/neptune/vocab/v01/DefaultNamedGraph.')
    parser.add_argument('--base-uri', type=str, default=DEFAULT_BASE_URI,
                        help='The base URI for RDF/XML and Turtle formats. '
                             'Default is http://aws.amazon.com/neptune/default')
    parser.add_argument('--allow-empty-strings', action='store_true', default=False,
                        help='Load empty strings found in node and edge property values.')
    parser.add_argument('-n', '--nopoll', action='store_true', default=False)
    args = parser.parse_args(line.split())

    widget_width = '25%'
    label_width = '16%'

    def _field_layout(display=None):
        # Layout shared by every input widget. ``display`` is only set on widgets
        # whose visibility depends on the load type or the selected format.
        if display is None:
            return widgets.Layout(width=widget_width)
        return widgets.Layout(display=display, width=widget_width)

    def _label(text, display="flex"):
        # Right-aligned label placed to the left of its input widget.
        return widgets.Label(text, layout=widgets.Layout(width=label_width,
                                                         display=display,
                                                         justify_content="flex-end"))

    def _bool_dropdown(flag, display=None):
        # 'TRUE'/'FALSE' dropdown pre-selected from a Python boolean.
        return widgets.Dropdown(options=['TRUE', 'FALSE'],
                                value=str(flag).upper(),
                                disabled=False,
                                layout=_field_layout(display))

    button = widgets.Button(description="Submit")
    output = widgets.Output()
    source = widgets.Text(value=args.source, placeholder='Type something', disabled=False,
                          layout=_field_layout())
    arn = widgets.Text(value=args.loader_arn, placeholder='Type something', disabled=False,
                       layout=_field_layout())
    source_format = widgets.Dropdown(options=load_formats, value=args.format, disabled=False,
                                     layout=_field_layout())

    # Initial visibility of the type- and format-specific options. The
    # source_format observers below keep these in sync once the user changes
    # the format.
    ids_hbox_visibility = 'none'
    gremlin_parser_options_hbox_visibility = 'none'
    named_graph_hbox_visibility = 'none'
    base_uri_hbox_visibility = 'none'
    concurrency_hbox_visibility = 'none'
    if load_type == 'incremental':
        concurrency_hbox_visibility = 'flex'
    else:
        initial_format = source_format.value.lower()
        if initial_format == FORMAT_CSV:
            gremlin_parser_options_hbox_visibility = 'flex'
        elif initial_format == FORMAT_OPENCYPHER:
            ids_hbox_visibility = 'flex'
        elif initial_format in RDF_LOAD_FORMATS:
            named_graph_hbox_visibility = 'flex'
            if initial_format in BASE_URI_FORMATS:
                base_uri_hbox_visibility = 'flex'

    region_box = widgets.Text(value=args.region, placeholder=args.region, disabled=False,
                              layout=_field_layout())
    fail_on_error = _bool_dropdown(not args.no_fail_on_error)
    parallelism = widgets.Dropdown(options=PARALLELISM_OPTIONS, value=args.parallelism, disabled=False,
                                   layout=_field_layout())
    allow_empty_strings = _bool_dropdown(args.allow_empty_strings, gremlin_parser_options_hbox_visibility)
    named_graph_uri = widgets.Text(value=args.named_graph_uri, placeholder='http://named-graph-string',
                                   disabled=False, layout=_field_layout(named_graph_hbox_visibility))
    base_uri = widgets.Text(value=args.base_uri, placeholder='http://base-uri-string',
                            disabled=False, layout=_field_layout(base_uri_hbox_visibility))
    update_single_cardinality = _bool_dropdown(args.update_single_cardinality)
    mode = widgets.Dropdown(options=LOAD_JOB_MODES, value=args.mode, disabled=False,
                            layout=_field_layout())
    user_provided_edge_ids = _bool_dropdown(not args.no_edge_ids, ids_hbox_visibility)
    queue_request = _bool_dropdown(args.queue_request)
    dependencies = widgets.Textarea(value="\n".join(args.dependencies), placeholder='load_A_id\nload_B_id',
                                    disabled=False, layout=_field_layout())
    concurrency = widgets.BoundedIntText(value=args.concurrency, placeholder=1, min=1, max=2 ** 16,
                                         disabled=False, layout=_field_layout(concurrency_hbox_visibility))
    poll_status = _bool_dropdown(not args.nopoll)

    # Create a series of HBox containers that will hold the widgets and labels
    # that make up the %load form. Some labels are kept in their own variables
    # so that, on validation errors, on_button_clicked can rebuild the HBox and
    # append text describing the issue.
    source_hbox_label = _label('Source:')
    source_hbox = widgets.HBox([source_hbox_label, source])
    format_hbox_label = _label('Format:')
    source_format_hbox = widgets.HBox([format_hbox_label, source_format])
    region_hbox = widgets.HBox([_label('Region:'), region_box])
    arn_hbox_label = _label('Load ARN:')
    arn_hbox = widgets.HBox([arn_hbox_label, arn])
    mode_hbox = widgets.HBox([_label('Mode:'), mode])
    fail_hbox = widgets.HBox([_label('Fail on Error:'), fail_on_error])
    parallelism_hbox = widgets.HBox([_label('Parallelism:'), parallelism])
    allow_empty_strings_hbox_label = _label('Allow Empty Strings:', gremlin_parser_options_hbox_visibility)
    allow_empty_strings_hbox = widgets.HBox([allow_empty_strings_hbox_label, allow_empty_strings])
    named_graph_uri_hbox_label = _label('Named Graph URI:', named_graph_hbox_visibility)
    named_graph_uri_hbox = widgets.HBox([named_graph_uri_hbox_label, named_graph_uri])
    base_uri_hbox_label = _label('Base URI:', base_uri_hbox_visibility)
    base_uri_hbox = widgets.HBox([base_uri_hbox_label, base_uri])
    cardinality_hbox = widgets.HBox([_label('Update Single Cardinality:'), update_single_cardinality])
    queue_hbox = widgets.HBox([_label('Queue Request:'), queue_request])
    dep_hbox_label = _label('Dependencies:')
    dep_hbox = widgets.HBox([dep_hbox_label, dependencies])
    ids_hbox_label = _label('User Provided Edge Ids:', ids_hbox_visibility)
    ids_hbox = widgets.HBox([ids_hbox_label, user_provided_edge_ids])
    concurrency_hbox_label = _label('Concurrency:', concurrency_hbox_visibility)
    concurrency_hbox = widgets.HBox([concurrency_hbox_label, concurrency])
    poll_status_label = _label('Poll Load Status:')
    poll_status_hbox = widgets.HBox([poll_status_label, poll_status])

    def update_edge_ids_options(change):
        # The "User Provided Edge Ids" option is only shown for openCypher; it is
        # reset to 'TRUE' whenever another format is chosen.
        if change.new.lower() == FORMAT_OPENCYPHER:
            ids_visibility = 'flex'
        else:
            ids_visibility = 'none'
            user_provided_edge_ids.value = 'TRUE'
        user_provided_edge_ids.layout.display = ids_visibility
        ids_hbox_label.layout.display = ids_visibility

    def update_parserconfig_options(change):
        # Show only the parserConfiguration options that apply to the newly
        # selected format, and reset the hidden ones to their neutral values.
        new_format = change.new.lower()
        if new_format == FORMAT_CSV:
            parser_options_visibility = 'flex'
            named_graph_visibility = 'none'
            base_uri_visibility = 'none'
            named_graph_uri.value = ''
            base_uri.value = ''
        elif new_format == FORMAT_OPENCYPHER:
            parser_options_visibility = 'none'
            allow_empty_strings.value = 'FALSE'
            named_graph_visibility = 'none'
            base_uri_visibility = 'none'
            named_graph_uri.value = ''
            base_uri.value = ''
        else:
            parser_options_visibility = 'none'
            allow_empty_strings.value = 'FALSE'
            named_graph_visibility = 'flex'
            named_graph_uri.value = DEFAULT_NAMEDGRAPH_URI
            if new_format in BASE_URI_FORMATS:
                base_uri_visibility = 'flex'
                base_uri.value = DEFAULT_BASE_URI
            else:
                base_uri_visibility = 'none'
                base_uri.value = ''
        allow_empty_strings.layout.display = parser_options_visibility
        allow_empty_strings_hbox_label.layout.display = parser_options_visibility
        named_graph_uri.layout.display = named_graph_visibility
        named_graph_uri_hbox_label.layout.display = named_graph_visibility
        base_uri.layout.display = base_uri_visibility
        base_uri_hbox_label.layout.display = base_uri_visibility

    source_format.observe(update_edge_ids_options, names='value')
    source_format.observe(update_parserconfig_options, names='value')

    basic_load_boxes = [source_hbox, source_format_hbox, region_hbox, fail_hbox]
    # load arguments for Analytics incremental load
    incremental_load_boxes = [concurrency_hbox]
    # load arguments for Neptune bulk load
    bulk_load_boxes = [arn_hbox, mode_hbox, parallelism_hbox, cardinality_hbox,
                       queue_hbox, dep_hbox, ids_hbox, allow_empty_strings_hbox,
                       named_graph_uri_hbox, base_uri_hbox, poll_status_hbox]
    submit_load_boxes = [button, output]
    if load_type == 'incremental':
        display_boxes = basic_load_boxes + incremental_load_boxes + submit_load_boxes
    else:
        display_boxes = basic_load_boxes + bulk_load_boxes + submit_load_boxes
    display(*display_boxes)

    def _oc_string_literal(value):
        # Render a double-quoted openCypher string literal. Backslashes and
        # double quotes are escaped so user-entered text (e.g. the source path)
        # cannot terminate the literal early and break the query.
        escaped = str(value).replace('\\', '\\\\').replace('"', '\\"')
        return '"' + escaped + '"'

    def _build_incremental_load_query(params):
        # Build "CALL neptune.load({k: v, ...})". Integers and failOnError
        # ('TRUE'/'FALSE') are emitted unquoted; all other values are quoted.
        entries = []
        for param, value in params.items():
            if isinstance(value, int) or param == 'failOnError':
                entries.append(f'{param}: {value}')
            else:
                entries.append(f'{param}: {_oc_string_literal(value)}')
        if source_format.value.lower() == FORMAT_NTRIPLE:
            entries.append('blankNodeHandling: "convertToIri"')
        return 'CALL neptune.load({' + ', '.join(entries) + '})'

    def _poll_load_status(load_id):
        # Poll the bulk-load status every ``poll_interval`` seconds. A countdown
        # is refreshed once per second until a final status is reported or the
        # status request fails.
        poll_interval = 5
        load_id_hbox = widgets.HBox([widgets.Label(f'Load ID: {load_id}')])
        interval_output = widgets.Output()
        job_status_output = widgets.Output()
        status_hbox = widgets.HBox([interval_output])
        vbox = widgets.VBox([load_id_hbox, status_hbox, job_status_output])
        with output:
            display(vbox)

        last_poll_time = time.time()
        new_interval = True
        while True:
            time_elapsed = int(time.time() - last_poll_time)
            time_remaining = poll_interval - time_elapsed
            interval_output.clear_output()
            if time_elapsed > poll_interval:
                with interval_output:
                    print('checking status...')
                job_status_output.clear_output()
                with job_status_output:
                    display_html(HTML(loading_wheel_html))
                new_interval = True
                try:
                    load_status_res = self.client.load_status(load_id)
                    load_status_res.raise_for_status()
                    interval_check_response = load_status_res.json()
                except Exception as e:
                    logger.error(e)
                    with job_status_output:
                        print('Something went wrong updating job status. Ending.')
                    return

                job_status_output.clear_output()
                with job_status_output:
                    # parse status & execution_time differently for Analytics and NeptuneDB
                    payload = interval_check_response["payload"]
                    if self.client.is_analytics_domain():
                        overall_status = payload["status"]
                        total_time_spent = payload["timeElapsedSeconds"]
                    else:
                        overall_status = payload["overallStatus"]["status"]
                        total_time_spent = payload["overallStatus"]["totalTimeSpent"]
                    print(f'Overall Status: {overall_status}')
                    if overall_status in FINAL_LOAD_STATUSES:
                        if total_time_spent == 0:
                            execution_time_statement = '<1 second'
                        elif total_time_spent > 59:
                            execution_time_statement = str(datetime.timedelta(seconds=total_time_spent))
                        else:
                            execution_time_statement = f'{total_time_spent} seconds'
                        print('Total execution time: ' + execution_time_statement)
                        interval_output.close()
                        print('Done.')
                        return
                last_poll_time = time.time()
            else:
                if new_interval:
                    with job_status_output:
                        display_html(HTML(loading_wheel_html))
                    new_interval = False
                with interval_output:
                    print(f'checking status in {time_remaining} seconds')
            time.sleep(1)

    def on_button_clicked(b):
        # Reset every HBox that may carry a validation message, so that repeated
        # submissions do not accumulate stale error labels.
        source_hbox.children = (source_hbox_label, source,)
        arn_hbox.children = (arn_hbox_label, arn,)
        source_format_hbox.children = (format_hbox_label, source_format,)
        allow_empty_strings_hbox.children = (allow_empty_strings_hbox_label, allow_empty_strings,)
        named_graph_uri_hbox.children = (named_graph_uri_hbox_label, named_graph_uri,)
        base_uri_hbox.children = (base_uri_hbox_label, base_uri,)
        dep_hbox.children = (dep_hbox_label, dependencies,)
        concurrency_hbox.children = (concurrency_hbox_label, concurrency,)

        validated = True
        validation_label_style = DescriptionStyle(color='red')
        if not (source.value.startswith('s3://') and len(source.value) > 7) and not source.value.startswith('/'):
            validated = False
            source_validation_label = widgets.HTML(
                '<p style="color:red;">Source must be an s3 bucket or file path</p>')
            source_validation_label.style = validation_label_style
            source_hbox.children += (source_validation_label,)

        if source_format.value == '':
            validated = False
            source_format_validation_label = widgets.HTML('<p style="color:red;">Format cannot be blank.</p>')
            source_format_hbox.children += (source_format_validation_label,)

        dependencies_list = []
        if load_type == 'bulk':
            # only do this validation if we are using an s3 bucket.
            if not arn.value.startswith('arn:aws') and source.value.startswith("s3://"):
                validated = False
                arn_validation_label = widgets.HTML('<p style="color:red;">Load ARN must start with "arn:aws"</p>')
                arn_hbox.children += (arn_validation_label,)

            dependencies_list = list(filter(None, dependencies.value.split('\n')))
            # NOTE(review): exactly 64 dependencies is also rejected; presumably
            # the new job itself counts toward the 64-job queue limit — confirm.
            if len(dependencies_list) >= 64:
                validated = False
                dep_validation_label = widgets.HTML(
                    '<p style="color:red;">A maximum of 64 jobs may be queued at once.</p>')
                dep_hbox.children += (dep_validation_label,)

        if not validated:
            return

        # replace any env variables in source.value with their values, can use $foo or ${foo}.
        # Particularly useful for ${AWS_REGION}. Applied to both load types.
        source_exp = os.path.expandvars(source.value)
        logger.info(f'using source_exp: {source_exp}')

        try:
            kwargs = {
                'failOnError': fail_on_error.value,
                'region': region_box.value
            }
            if load_type == 'incremental':
                kwargs.update({
                    'source': source_exp,
                    'format': source_format.value,
                    'concurrency': concurrency.value
                })
            else:
                bulk_load_kwargs = {
                    'mode': mode.value,
                    'parallelism': parallelism.value,
                    'updateSingleCardinalityProperties': update_single_cardinality.value,
                    'queueRequest': queue_request.value,
                    'parserConfiguration': {}
                }
                # Only send dependencies when the user actually listed some.
                if dependencies_list:
                    bulk_load_kwargs['dependencies'] = dependencies_list
                selected_format = source_format.value.lower()
                if selected_format == FORMAT_OPENCYPHER:
                    bulk_load_kwargs['userProvidedEdgeIds'] = user_provided_edge_ids.value
                elif selected_format == FORMAT_CSV:
                    if allow_empty_strings.value == 'TRUE':
                        bulk_load_kwargs['parserConfiguration']['allowEmptyStrings'] = True
                elif selected_format in RDF_LOAD_FORMATS:
                    if named_graph_uri.value:
                        bulk_load_kwargs['parserConfiguration']['namedGraphUri'] = named_graph_uri.value
                    if base_uri.value and selected_format in BASE_URI_FORMATS:
                        bulk_load_kwargs['parserConfiguration']['baseUri'] = base_uri.value
                kwargs.update(bulk_load_kwargs)

            # The form is consumed once a request is about to be submitted.
            for box in (source_hbox, source_format_hbox, region_hbox, arn_hbox, mode_hbox,
                        fail_hbox, parallelism_hbox, cardinality_hbox, queue_hbox, dep_hbox,
                        poll_status_hbox, ids_hbox, allow_empty_strings_hbox, named_graph_uri_hbox,
                        base_uri_hbox, concurrency_hbox, button):
                box.close()

            load_submit_status_output = widgets.Output()
            load_submit_hbox = widgets.HBox([load_submit_status_output])
            with output:
                display(load_submit_hbox)
            with load_submit_status_output:
                print(f"{load_type.capitalize()} load request submitted, waiting for response...")
                display_html(HTML(loading_wheel_html))

            try:
                if load_type == 'incremental':
                    oc_load = self.client.opencypher_http(_build_incremental_load_query(kwargs))
                elif source.value.startswith("s3://"):
                    load_res = self.client.load(source_exp, source_format.value, arn.value, **kwargs)
                else:
                    load_res = self.client.load(source_exp, source_format.value, **kwargs)
            except Exception as e:
                load_submit_status_output.clear_output()
                with output:
                    print(f"Failed to submit {load_type.capitalize()} load request.")
                    logger.error(e)
                return

            load_submit_status_output.clear_output()
            if load_type == 'incremental':
                oc_load.raise_for_status()
                load_result = oc_load.json()
                store_to_ns(args.store_to, load_result, local_ns)
                with output:
                    print("Load completed.\n")
                    print(json.dumps(load_result, indent=2))
                return

            try:
                load_res.raise_for_status()
            except Exception:
                # Ignore failure to retrieve status, we handle missing status below.
                pass
            load_result = load_res.json()
            store_to_ns(args.store_to, load_result, local_ns)

            if 'status' not in load_result or load_result['status'] != '200 OK':
                with output:
                    print('Something went wrong.')
                    logger.error(load_result)
                return

            load_id = load_result["payload"]["loadId"]
            if poll_status.value == 'FALSE':
                start_msg_label = widgets.Label('Load started successfully!')
                polling_msg_label = widgets.Label(
                    f'You can run "%load_status {load_id}" '
                    f'in another cell to check the current status of your bulk load.')
                vbox = widgets.VBox([widgets.HBox([start_msg_label]), widgets.HBox([polling_msg_label])])
                with output:
                    display(vbox)
                return

            _poll_load_status(load_id)
        except HTTPError as httpEx:
            output.clear_output()
            with output:
                print(httpEx.response.content.decode('utf-8'))

    button.on_click(on_button_clicked)
    if args.run:
        on_button_clicked(None)