in src/graph_notebook/magics/graph_magic.py [0:0]
def load(self, line='', local_ns: dict = None):
    """Display an interactive form for submitting a Neptune bulk load job.

    ``line`` is parsed with argparse to pre-populate the form fields. The form
    is submitted when the user clicks "Submit", or immediately when ``--run``
    is given. On a successful request the form is closed and, unless polling is
    disabled, the job status is polled until it reaches one of
    ``FINAL_LOAD_STATUSES``.

    Args:
        line: the magic's argument string (source, loader ARN, format, ...).
        local_ns: namespace handed to ``store_to_ns`` together with the
            ``--store-to`` variable name so the load response can be stored.
    """
    # TODO: change widgets to let any arbitrary inputs be added by users
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--source', default='s3://')
    try:
        parser.add_argument('-l', '--loader-arn', default=self.graph_notebook_config.load_from_s3_arn)
    except AttributeError:
        print("Missing required configuration option 'load_from_s3_arn'. Please ensure that you have provided a "
              "valid Neptune cluster endpoint URI in the 'host' field of %graph_notebook_config.")
        return
    parser.add_argument('-f', '--format', choices=LOADER_FORMAT_CHOICES, default=FORMAT_CSV)
    parser.add_argument('-p', '--parallelism', choices=PARALLELISM_OPTIONS, default=PARALLELISM_HIGH)
    try:
        parser.add_argument('-r', '--region', default=self.graph_notebook_config.aws_region)
    except AttributeError:
        print("Missing required configuration option 'aws_region'. Please ensure that you have provided a "
              "valid Neptune cluster endpoint URI in the 'host' field of %graph_notebook_config.")
        return
    parser.add_argument('--fail-on-failure', action='store_true', default=False)
    parser.add_argument('--update-single-cardinality', action='store_true', default=True)
    parser.add_argument('--store-to', type=str, default='', help='store query result to this variable')
    parser.add_argument('--run', action='store_true', default=False)
    parser.add_argument('-m', '--mode', choices=LOAD_JOB_MODES, default=MODE_AUTO)
    parser.add_argument('-q', '--queue-request', action='store_true', default=False)
    parser.add_argument('-d', '--dependencies', action='append', default=[])
    parser.add_argument('-e', '--no-edge-ids', action='store_true', default=False)
    parser.add_argument('--named-graph-uri', type=str, default=DEFAULT_NAMEDGRAPH_URI,
                        help='The default graph for all RDF formats when no graph is specified. '
                             'Default is http://aws.amazon.com/neptune/vocab/v01/DefaultNamedGraph.')
    parser.add_argument('--base-uri', type=str, default=DEFAULT_BASE_URI,
                        help='The base URI for RDF/XML and Turtle formats. '
                             'Default is http://aws.amazon.com/neptune/default')
    parser.add_argument('--allow-empty-strings', action='store_true', default=False,
                        help='Load empty strings found in node and edge property values.')
    parser.add_argument('-n', '--nopoll', action='store_true', default=False)
    args = parser.parse_args(line.split())

    button = widgets.Button(description="Submit")
    output = widgets.Output()
    widget_width = '25%'
    label_width = '16%'

    def _label(text, display="flex"):
        # Fixed-width, right-aligned label shown to the left of an input widget.
        return widgets.Label(text,
                             layout=widgets.Layout(width=label_width,
                                                   display=display,
                                                   justify_content="flex-end"))

    source = widgets.Text(
        value=args.source,
        placeholder='Type something',
        disabled=False,
        layout=widgets.Layout(width=widget_width)
    )
    arn = widgets.Text(
        value=args.loader_arn,
        placeholder='Type something',
        disabled=False,
        layout=widgets.Layout(width=widget_width)
    )
    source_format = widgets.Dropdown(
        options=LOADER_FORMAT_CHOICES,
        value=args.format,
        disabled=False,
        layout=widgets.Layout(width=widget_width)
    )

    # Format-specific rows start hidden; only those relevant to the initial
    # format are shown. The observers below keep this in sync afterwards.
    ids_hbox_visibility = 'none'
    gremlin_parser_options_hbox_visibility = 'none'
    named_graph_hbox_visibility = 'none'
    base_uri_hbox_visibility = 'none'
    if source_format.value.lower() == FORMAT_CSV:
        gremlin_parser_options_hbox_visibility = 'flex'
    elif source_format.value.lower() == FORMAT_OPENCYPHER:
        ids_hbox_visibility = 'flex'
    elif source_format.value.lower() in RDF_LOAD_FORMATS:
        named_graph_hbox_visibility = 'flex'
        if source_format.value.lower() in BASE_URI_FORMATS:
            base_uri_hbox_visibility = 'flex'

    # Initialise from --region (which itself defaults to the configured region)
    # so a region passed on the command line is honoured.
    region_box = widgets.Text(
        value=args.region,
        placeholder=args.region,
        disabled=False,
        layout=widgets.Layout(width=widget_width)
    )
    fail_on_error = widgets.Dropdown(
        options=['TRUE', 'FALSE'],
        value=str(args.fail_on_failure).upper(),
        disabled=False,
        layout=widgets.Layout(width=widget_width)
    )
    parallelism = widgets.Dropdown(
        options=PARALLELISM_OPTIONS,
        value=args.parallelism,
        disabled=False,
        layout=widgets.Layout(width=widget_width)
    )
    allow_empty_strings = widgets.Dropdown(
        options=['TRUE', 'FALSE'],
        value=str(args.allow_empty_strings).upper(),
        disabled=False,
        layout=widgets.Layout(display=gremlin_parser_options_hbox_visibility,
                              width=widget_width)
    )
    named_graph_uri = widgets.Text(
        value=args.named_graph_uri,
        placeholder='http://named-graph-string',
        disabled=False,
        layout=widgets.Layout(display=named_graph_hbox_visibility,
                              width=widget_width)
    )
    base_uri = widgets.Text(
        value=args.base_uri,
        placeholder='http://base-uri-string',
        disabled=False,
        layout=widgets.Layout(display=base_uri_hbox_visibility,
                              width=widget_width)
    )
    update_single_cardinality = widgets.Dropdown(
        options=['TRUE', 'FALSE'],
        value=str(args.update_single_cardinality).upper(),
        disabled=False,
        layout=widgets.Layout(width=widget_width)
    )
    mode = widgets.Dropdown(
        options=LOAD_JOB_MODES,
        value=args.mode,
        disabled=False,
        layout=widgets.Layout(width=widget_width)
    )
    user_provided_edge_ids = widgets.Dropdown(
        options=['TRUE', 'FALSE'],
        value=str(not args.no_edge_ids).upper(),
        disabled=False,
        layout=widgets.Layout(display=ids_hbox_visibility,
                              width=widget_width)
    )
    queue_request = widgets.Dropdown(
        options=['TRUE', 'FALSE'],
        value=str(args.queue_request).upper(),
        disabled=False,
        layout=widgets.Layout(width=widget_width)
    )
    dependencies = widgets.Textarea(
        value="\n".join(args.dependencies),
        placeholder='load_A_id\nload_B_id',
        disabled=False,
        layout=widgets.Layout(width=widget_width)
    )
    poll_status = widgets.Dropdown(
        options=['TRUE', 'FALSE'],
        value=str(not args.nopoll).upper(),
        disabled=False,
        layout=widgets.Layout(width=widget_width)
    )

    # Create a series of HBox containers that will hold the widgets and labels
    # that make up the %load form. Some of the labels and widgets are kept in
    # named variables so that validation can rebuild a row's children and
    # append a message describing the issue.
    source_hbox_label = _label('Source:')
    source_hbox = widgets.HBox([source_hbox_label, source])
    format_hbox_label = _label('Format:')
    source_format_hbox = widgets.HBox([format_hbox_label, source_format])
    region_hbox = widgets.HBox([_label('Region:'), region_box])
    arn_hbox_label = _label('Load ARN:')
    arn_hbox = widgets.HBox([arn_hbox_label, arn])
    mode_hbox = widgets.HBox([_label('Mode:'), mode])
    fail_hbox = widgets.HBox([_label('Fail on Error:'), fail_on_error])
    parallelism_hbox = widgets.HBox([_label('Parallelism:'), parallelism])
    allow_empty_strings_hbox_label = _label('Allow Empty Strings:', gremlin_parser_options_hbox_visibility)
    allow_empty_strings_hbox = widgets.HBox([allow_empty_strings_hbox_label, allow_empty_strings])
    named_graph_uri_hbox_label = _label('Named Graph URI:', named_graph_hbox_visibility)
    named_graph_uri_hbox = widgets.HBox([named_graph_uri_hbox_label, named_graph_uri])
    base_uri_hbox_label = _label('Base URI:', base_uri_hbox_visibility)
    base_uri_hbox = widgets.HBox([base_uri_hbox_label, base_uri])
    cardinality_hbox = widgets.HBox([_label('Update Single Cardinality:'), update_single_cardinality])
    queue_hbox = widgets.HBox([_label('Queue Request:'), queue_request])
    dep_hbox_label = _label('Dependencies:')
    dep_hbox = widgets.HBox([dep_hbox_label, dependencies])
    ids_hbox_label = _label('User Provided Edge Ids:', ids_hbox_visibility)
    ids_hbox = widgets.HBox([ids_hbox_label, user_provided_edge_ids])
    poll_status_label = _label('Poll Load Status:')
    poll_status_hbox = widgets.HBox([poll_status_label, poll_status])

    # Form rows in display order; the same list is used to close the form.
    form_rows = [source_hbox,
                 source_format_hbox,
                 region_hbox,
                 arn_hbox,
                 mode_hbox,
                 fail_hbox,
                 parallelism_hbox,
                 cardinality_hbox,
                 queue_hbox,
                 dep_hbox,
                 poll_status_hbox,
                 ids_hbox,
                 allow_empty_strings_hbox,
                 named_graph_uri_hbox,
                 base_uri_hbox]

    def update_edge_ids_options(change):
        # The "User Provided Edge Ids" row only applies to openCypher loads.
        if change.new.lower() == FORMAT_OPENCYPHER:
            ids_visibility = 'flex'
        else:
            ids_visibility = 'none'
            user_provided_edge_ids.value = 'TRUE'
        user_provided_edge_ids.layout.display = ids_visibility
        ids_hbox_label.layout.display = ids_visibility

    def update_parserconfig_options(change):
        # Show only the parser options relevant to the newly selected format and
        # reset the hidden ones so stale values are not submitted.
        new_format = change.new.lower()
        if new_format == FORMAT_CSV:
            parser_options_visibility = 'flex'
            named_graph_visibility = 'none'
            base_uri_visibility = 'none'
            named_graph_uri.value = ''
            base_uri.value = ''
        elif new_format == FORMAT_OPENCYPHER:
            parser_options_visibility = 'none'
            allow_empty_strings.value = 'FALSE'
            named_graph_visibility = 'none'
            base_uri_visibility = 'none'
            named_graph_uri.value = ''
            base_uri.value = ''
        else:
            parser_options_visibility = 'none'
            allow_empty_strings.value = 'FALSE'
            named_graph_visibility = 'flex'
            named_graph_uri.value = DEFAULT_NAMEDGRAPH_URI
            if new_format in BASE_URI_FORMATS:
                base_uri_visibility = 'flex'
                base_uri.value = DEFAULT_BASE_URI
            else:
                base_uri_visibility = 'none'
                base_uri.value = ''
        allow_empty_strings.layout.display = parser_options_visibility
        allow_empty_strings_hbox_label.layout.display = parser_options_visibility
        named_graph_uri.layout.display = named_graph_visibility
        named_graph_uri_hbox_label.layout.display = named_graph_visibility
        base_uri.layout.display = base_uri_visibility
        base_uri_hbox_label.layout.display = base_uri_visibility

    source_format.observe(update_edge_ids_options, names='value')
    source_format.observe(update_parserconfig_options, names='value')

    display(*form_rows, button, output)

    def on_button_clicked(b):
        # Rebuild each validated row so messages from a previous click are not
        # appended a second time.
        source_hbox.children = (source_hbox_label, source,)
        arn_hbox.children = (arn_hbox_label, arn,)
        source_format_hbox.children = (format_hbox_label, source_format,)
        allow_empty_strings_hbox.children = (allow_empty_strings_hbox_label, allow_empty_strings,)
        named_graph_uri_hbox.children = (named_graph_uri_hbox_label, named_graph_uri,)
        base_uri_hbox.children = (base_uri_hbox_label, base_uri,)
        dep_hbox.children = (dep_hbox_label, dependencies,)

        dependencies_list = list(filter(None, dependencies.value.split('\n')))

        validated = True
        validation_label_style = DescriptionStyle(color='red')
        if not (source.value.startswith('s3://') and len(source.value) > 7) and not source.value.startswith('/'):
            validated = False
            source_validation_label = widgets.HTML(
                '<p style="color:red;">Source must be an s3 bucket or file path</p>')
            source_validation_label.style = validation_label_style
            source_hbox.children += (source_validation_label,)

        if source_format.value == '':
            validated = False
            source_format_validation_label = widgets.HTML('<p style="color:red;">Format cannot be blank.</p>')
            source_format_hbox.children += (source_format_validation_label,)

        # Only require a loader ARN when loading from an s3 bucket.
        if not arn.value.startswith('arn:aws') and source.value.startswith("s3://"):
            validated = False
            arn_validation_label = widgets.HTML('<p style="color:red;">Load ARN must start with "arn:aws"</p>')
            arn_hbox.children += (arn_validation_label,)

        if not len(dependencies_list) < 64:
            validated = False
            dep_validation_label = widgets.HTML(
                '<p style="color:red;">A maximum of 64 jobs may be queued at once.</p>')
            dep_hbox.children += (dep_validation_label,)

        if not validated:
            return

        # replace any env variables in source.value with their values, can use $foo or ${foo}.
        # Particularly useful for ${AWS_REGION}
        source_exp = os.path.expandvars(source.value)
        logger.info(f'using source_exp: {source_exp}')
        try:
            kwargs = {
                'failOnError': fail_on_error.value,
                'parallelism': parallelism.value,
                'updateSingleCardinalityProperties': update_single_cardinality.value,
                'queueRequest': queue_request.value,
                # Previously the Mode field was displayed but never submitted.
                'mode': mode.value,
                # Use the (possibly user-edited) form value, not the config default.
                'region': region_box.value,
                'parserConfiguration': {}
            }

            # Only send dependencies when the user actually listed some
            # (the Textarea widget itself is always truthy).
            if dependencies_list:
                kwargs['dependencies'] = dependencies_list

            if source_format.value.lower() == FORMAT_OPENCYPHER:
                kwargs['userProvidedEdgeIds'] = user_provided_edge_ids.value
            elif source_format.value.lower() == FORMAT_CSV:
                if allow_empty_strings.value == 'TRUE':
                    kwargs['parserConfiguration']['allowEmptyStrings'] = True
            elif source_format.value.lower() in RDF_LOAD_FORMATS:
                if named_graph_uri.value:
                    kwargs['parserConfiguration']['namedGraphUri'] = named_graph_uri.value
                if base_uri.value and source_format.value.lower() in BASE_URI_FORMATS:
                    kwargs['parserConfiguration']['baseUri'] = base_uri.value

            if source.value.startswith("s3://"):
                load_res = self.client.load(str(source_exp), source_format.value, arn.value, **kwargs)
            else:
                load_res = self.client.load(str(source_exp), source_format.value, **kwargs)

            load_res.raise_for_status()
            load_result = load_res.json()
            store_to_ns(args.store_to, load_result, local_ns)

            for row in form_rows:
                row.close()
            button.close()

            if 'status' not in load_result or load_result['status'] != '200 OK':
                # `output` is intentionally still open here so the error is visible.
                output.clear_output()
                with output:
                    print('Something went wrong.')
                    print(load_result)
                logger.error(load_result)
                return
            output.close()

            if poll_status.value == 'FALSE':
                start_msg_label = widgets.Label('Load started successfully!')
                polling_msg_label = widgets.Label(f'You can run "%load_status {load_result["payload"]["loadId"]}" '
                                                  'in another cell to check the current status of your bulk load.')
                start_msg_hbox = widgets.HBox([start_msg_label])
                polling_msg_hbox = widgets.HBox([polling_msg_label])
                vbox = widgets.VBox([start_msg_hbox, polling_msg_hbox])
                display(vbox)
            else:
                poll_interval = 5
                load_id_label = widgets.Label(f'Load ID: {load_result["payload"]["loadId"]}')
                interval_output = widgets.Output()
                job_status_output = widgets.Output()
                load_id_hbox = widgets.HBox([load_id_label])
                status_hbox = widgets.HBox([interval_output])
                vbox = widgets.VBox([load_id_hbox, status_hbox, job_status_output])
                display(vbox)

                last_poll_time = time.time()
                while True:
                    time_elapsed = int(time.time() - last_poll_time)
                    time_remaining = poll_interval - time_elapsed
                    interval_output.clear_output()
                    if time_elapsed > poll_interval:
                        with interval_output:
                            print('checking status...')
                        job_status_output.clear_output()
                        with job_status_output:
                            display_html(HTML(loading_wheel_html))
                        try:
                            load_status_res = self.client.load_status(load_result['payload']['loadId'])
                            load_status_res.raise_for_status()
                            interval_check_response = load_status_res.json()
                        except Exception as e:
                            logger.error(e)
                            with job_status_output:
                                print('Something went wrong updating job status. Ending.')
                            return
                        job_status_output.clear_output()
                        with job_status_output:
                            overall_status = interval_check_response["payload"]["overallStatus"]
                            print(f'Overall Status: {overall_status["status"]}')
                            if overall_status["status"] in FINAL_LOAD_STATUSES:
                                execution_time = overall_status["totalTimeSpent"]
                                if execution_time == 0:
                                    execution_time_statement = '<1 second'
                                elif execution_time > 59:
                                    execution_time_statement = str(datetime.timedelta(seconds=execution_time))
                                else:
                                    execution_time_statement = f'{execution_time} seconds'
                                print('Total execution time: ' + execution_time_statement)
                                interval_output.close()
                                print('Done.')
                                return
                        last_poll_time = time.time()
                    else:
                        with interval_output:
                            print(f'checking status in {time_remaining} seconds')
                    time.sleep(1)
        except HTTPError as httpEx:
            output.clear_output()
            with output:
                print(httpEx.response.content.decode('utf-8'))

    button.on_click(on_button_clicked)
    if args.run:
        on_button_clicked(None)