def load()

in src/graph_notebook/magics/graph_magic.py [0:0]


    def load(self, line='', local_ns: dict = None):
        """Render the ``%load`` form and submit a bulk load request.

        The option string in ``line`` is parsed with ``argparse`` and used to
        pre-populate an ipywidgets form. Pressing "Submit" (or passing
        ``--run``) validates the form, calls ``self.client.load`` and stores
        the JSON response through ``store_to_ns`` under the ``--store-to``
        name. Unless polling is disabled (``--nopoll`` or the form dropdown),
        ``self.client.load_status`` is then queried roughly every five
        seconds until the overall status is one of ``FINAL_LOAD_STATUSES``.

        Args:
            line: Magic option string; split on whitespace before parsing.
            local_ns: Namespace handed to ``store_to_ns`` together with the
                ``--store-to`` variable name and the load response.

        Returns:
            None. If ``load_from_s3_arn`` or ``aws_region`` is missing from
            ``self.graph_notebook_config`` a message is printed and the
            method returns before any widget is displayed.
        """
        # TODO: change widgets to let any arbitrary inputs be added by users
        parser = argparse.ArgumentParser()
        parser.add_argument('-s', '--source', default='s3://')
        try:
            parser.add_argument('-l', '--loader-arn', default=self.graph_notebook_config.load_from_s3_arn)
        except AttributeError:
            print("Missing required configuration option 'load_from_s3_arn'. Please ensure that you have provided a "
                  "valid Neptune cluster endpoint URI in the 'host' field of %graph_notebook_config.")
            return
        parser.add_argument('-f', '--format', choices=LOADER_FORMAT_CHOICES, default=FORMAT_CSV)
        parser.add_argument('-p', '--parallelism', choices=PARALLELISM_OPTIONS, default=PARALLELISM_HIGH)
        try:
            parser.add_argument('-r', '--region', default=self.graph_notebook_config.aws_region)
        except AttributeError:
            print("Missing required configuration option 'aws_region'. Please ensure that you have provided a "
                  "valid Neptune cluster endpoint URI in the 'host' field of %graph_notebook_config.")
            return
        parser.add_argument('--fail-on-failure', action='store_true', default=False)
        # NOTE(review): store_true combined with default=True means this flag
        # can never switch the option off from the command line; only the
        # form dropdown can. Kept as-is to preserve the existing interface.
        parser.add_argument('--update-single-cardinality', action='store_true', default=True)
        parser.add_argument('--store-to', type=str, default='', help='store query result to this variable')
        parser.add_argument('--run', action='store_true', default=False)
        parser.add_argument('-m', '--mode', choices=LOAD_JOB_MODES, default=MODE_AUTO)
        parser.add_argument('-q', '--queue-request', action='store_true', default=False)
        parser.add_argument('-d', '--dependencies', action='append', default=[])
        parser.add_argument('-e', '--no-edge-ids', action='store_true', default=False)
        parser.add_argument('--named-graph-uri', type=str, default=DEFAULT_NAMEDGRAPH_URI,
                            help='The default graph for all RDF formats when no graph is specified. '
                                 'Default is http://aws.amazon.com/neptune/vocab/v01/DefaultNamedGraph.')
        parser.add_argument('--base-uri', type=str, default=DEFAULT_BASE_URI,
                            help='The base URI for RDF/XML and Turtle formats. '
                                 'Default is http://aws.amazon.com/neptune/default')
        parser.add_argument('--allow-empty-strings', action='store_true', default=False,
                            help='Load empty strings found in node and edge property values.')
        parser.add_argument('-n', '--nopoll', action='store_true', default=False)

        args = parser.parse_args(line.split())
        button = widgets.Button(description="Submit")
        output = widgets.Output()
        widget_width = '25%'
        label_width = '16%'

        def field_layout(display=None):
            # Build a fresh Layout per widget; ``display`` is only passed for
            # the format-dependent fields that may start out hidden.
            if display is None:
                return widgets.Layout(width=widget_width)
            return widgets.Layout(display=display, width=widget_width)

        def form_label(text, display='flex'):
            # Right-aligned label that sits to the left of an input widget.
            return widgets.Label(text,
                                 layout=widgets.Layout(width=label_width,
                                                       display=display,
                                                       justify_content='flex-end'))

        def true_false_dropdown(flag, display=None):
            # Dropdown offering 'TRUE'/'FALSE', preselected from a Python bool.
            return widgets.Dropdown(options=['TRUE', 'FALSE'],
                                    value=str(flag).upper(),
                                    disabled=False,
                                    layout=field_layout(display))

        source = widgets.Text(
            value=args.source,
            placeholder='Type something',
            disabled=False,
            layout=field_layout()
        )

        arn = widgets.Text(
            value=args.loader_arn,
            placeholder='Type something',
            disabled=False,
            layout=field_layout()
        )

        source_format = widgets.Dropdown(
            options=LOADER_FORMAT_CHOICES,
            value=args.format,
            disabled=False,
            layout=field_layout()
        )

        # Format-specific rows start hidden and are revealed to match the
        # initially selected format.
        ids_hbox_visibility = 'none'
        gremlin_parser_options_hbox_visibility = 'none'
        named_graph_hbox_visibility = 'none'
        base_uri_hbox_visibility = 'none'

        initial_format = source_format.value.lower()
        if initial_format == FORMAT_CSV:
            gremlin_parser_options_hbox_visibility = 'flex'
        elif initial_format == FORMAT_OPENCYPHER:
            ids_hbox_visibility = 'flex'
        elif initial_format in RDF_LOAD_FORMATS:
            named_graph_hbox_visibility = 'flex'
            if initial_format in BASE_URI_FORMATS:
                base_uri_hbox_visibility = 'flex'

        # Seed the box from the parsed --region option (which itself defaults
        # to the configured region) so that -r/--region is honoured.
        region_box = widgets.Text(
            value=args.region,
            placeholder=args.region,
            disabled=False,
            layout=field_layout()
        )

        fail_on_error = true_false_dropdown(args.fail_on_failure)

        parallelism = widgets.Dropdown(
            options=PARALLELISM_OPTIONS,
            value=args.parallelism,
            disabled=False,
            layout=field_layout()
        )

        allow_empty_strings = true_false_dropdown(args.allow_empty_strings,
                                                  display=gremlin_parser_options_hbox_visibility)

        named_graph_uri = widgets.Text(
            value=args.named_graph_uri,
            placeholder='http://named-graph-string',
            disabled=False,
            layout=field_layout(named_graph_hbox_visibility)
        )

        base_uri = widgets.Text(
            value=args.base_uri,
            placeholder='http://base-uri-string',
            disabled=False,
            layout=field_layout(base_uri_hbox_visibility)
        )

        update_single_cardinality = true_false_dropdown(args.update_single_cardinality)

        # NOTE(review): the selected mode is displayed but never added to the
        # request kwargs below - confirm whether self.client.load should
        # receive it.
        mode = widgets.Dropdown(
            options=LOAD_JOB_MODES,
            value=args.mode,
            disabled=False,
            layout=field_layout()
        )

        user_provided_edge_ids = true_false_dropdown(not args.no_edge_ids, display=ids_hbox_visibility)

        queue_request = true_false_dropdown(args.queue_request)

        dependencies = widgets.Textarea(
            value="\n".join(args.dependencies),
            placeholder='load_A_id\nload_B_id',
            disabled=False,
            layout=field_layout()
        )

        poll_status = true_false_dropdown(not args.nopoll)

        # Create a series of HBox containers that will hold the widgets and labels
        # that make up the %load form. Some of the labels and widgets are created
        # in two parts to support the validation steps that come later. In the case
        # of validation errors this allows additional text to easily be added to an
        # HBox describing the issue.
        source_hbox_label = form_label('Source:')
        source_hbox = widgets.HBox([source_hbox_label, source])

        format_hbox_label = form_label('Format:')
        source_format_hbox = widgets.HBox([format_hbox_label, source_format])

        region_hbox = widgets.HBox([form_label('Region:'), region_box])

        arn_hbox_label = form_label('Load ARN:')
        arn_hbox = widgets.HBox([arn_hbox_label, arn])

        mode_hbox = widgets.HBox([form_label('Mode:'), mode])

        fail_hbox = widgets.HBox([form_label('Fail on Error:'), fail_on_error])

        parallelism_hbox = widgets.HBox([form_label('Parallelism:'), parallelism])

        allow_empty_strings_hbox_label = form_label('Allow Empty Strings:',
                                                    display=gremlin_parser_options_hbox_visibility)
        allow_empty_strings_hbox = widgets.HBox([allow_empty_strings_hbox_label, allow_empty_strings])

        named_graph_uri_hbox_label = form_label('Named Graph URI:', display=named_graph_hbox_visibility)
        named_graph_uri_hbox = widgets.HBox([named_graph_uri_hbox_label, named_graph_uri])

        base_uri_hbox_label = form_label('Base URI:', display=base_uri_hbox_visibility)
        base_uri_hbox = widgets.HBox([base_uri_hbox_label, base_uri])

        cardinality_hbox = widgets.HBox([form_label('Update Single Cardinality:'), update_single_cardinality])

        queue_hbox = widgets.HBox([form_label('Queue Request:'), queue_request])

        dep_hbox_label = form_label('Dependencies:')
        dep_hbox = widgets.HBox([dep_hbox_label, dependencies])

        ids_hbox_label = form_label('User Provided Edge Ids:', display=ids_hbox_visibility)
        ids_hbox = widgets.HBox([ids_hbox_label, user_provided_edge_ids])

        poll_status_label = form_label('Poll Load Status:')
        poll_status_hbox = widgets.HBox([poll_status_label, poll_status])

        def update_edge_ids_options(change):
            # Show the edge-id row only for openCypher; otherwise hide it and
            # reset its value to 'TRUE'.
            if change.new.lower() == FORMAT_OPENCYPHER:
                edge_ids_display = 'flex'
            else:
                edge_ids_display = 'none'
                user_provided_edge_ids.value = 'TRUE'
            user_provided_edge_ids.layout.display = edge_ids_display
            ids_hbox_label.layout.display = edge_ids_display

        def update_parserconfig_options(change):
            # Toggle the parser-configuration rows to match the new format and
            # reset the values of the rows that get hidden.
            selected_format = change.new.lower()
            if selected_format == FORMAT_CSV:
                parser_options_display = 'flex'
                named_graph_display = 'none'
                base_uri_display = 'none'
                named_graph_uri.value = ''
                base_uri.value = ''
            elif selected_format == FORMAT_OPENCYPHER:
                parser_options_display = 'none'
                allow_empty_strings.value = 'FALSE'
                named_graph_display = 'none'
                base_uri_display = 'none'
                named_graph_uri.value = ''
                base_uri.value = ''
            else:
                # Every remaining format is handled as an RDF format here.
                parser_options_display = 'none'
                allow_empty_strings.value = 'FALSE'
                named_graph_display = 'flex'
                named_graph_uri.value = DEFAULT_NAMEDGRAPH_URI
                if selected_format in BASE_URI_FORMATS:
                    base_uri_display = 'flex'
                    base_uri.value = DEFAULT_BASE_URI
                else:
                    base_uri_display = 'none'
                    base_uri.value = ''

            allow_empty_strings.layout.display = parser_options_display
            allow_empty_strings_hbox_label.layout.display = parser_options_display
            named_graph_uri.layout.display = named_graph_display
            named_graph_uri_hbox_label.layout.display = named_graph_display
            base_uri.layout.display = base_uri_display
            base_uri_hbox_label.layout.display = base_uri_display

        source_format.observe(update_edge_ids_options, names='value')
        source_format.observe(update_parserconfig_options, names='value')

        form_rows = (source_hbox,
                     source_format_hbox,
                     region_hbox,
                     arn_hbox,
                     mode_hbox,
                     fail_hbox,
                     parallelism_hbox,
                     cardinality_hbox,
                     queue_hbox,
                     dep_hbox,
                     poll_status_hbox,
                     ids_hbox,
                     allow_empty_strings_hbox,
                     named_graph_uri_hbox,
                     base_uri_hbox)

        display(*form_rows, button, output)

        def on_button_clicked(b):
            # Drop validation messages left over from a previous click by
            # restoring each row to just its label and input widget.
            source_hbox.children = (source_hbox_label, source,)
            arn_hbox.children = (arn_hbox_label, arn,)
            source_format_hbox.children = (format_hbox_label, source_format,)
            # The reset must target the HBox; the Dropdown has no children.
            allow_empty_strings_hbox.children = (allow_empty_strings_hbox_label, allow_empty_strings,)
            named_graph_uri_hbox.children = (named_graph_uri_hbox_label, named_graph_uri,)
            base_uri_hbox.children = (base_uri_hbox_label, base_uri,)
            dep_hbox.children = (dep_hbox_label, dependencies,)

            # One load id per line; blank lines are ignored.
            dependencies_list = list(filter(None, dependencies.value.split('\n')))

            validated = True
            validation_label_style = DescriptionStyle(color='red')
            if not (source.value.startswith('s3://') and len(source.value) > 7) and not source.value.startswith('/'):
                validated = False
                source_validation_label = widgets.HTML(
                    '<p style="color:red;">Source must be an s3 bucket or file path</p>')
                source_validation_label.style = validation_label_style
                source_hbox.children += (source_validation_label,)

            if source_format.value == '':
                validated = False
                source_format_validation_label = widgets.HTML('<p style="color:red;">Format cannot be blank.</p>')
                source_format_hbox.children += (source_format_validation_label,)

            if not arn.value.startswith('arn:aws') and source.value.startswith(
                    "s3://"):  # only do this validation if we are using an s3 bucket.
                validated = False
                arn_validation_label = widgets.HTML('<p style="color:red;">Load ARN must start with "arn:aws"</p>')
                arn_hbox.children += (arn_validation_label,)

            # NOTE(review): 64 dependencies are rejected although the message
            # says "maximum of 64"; presumably the new job itself counts
            # toward the limit - confirm.
            if len(dependencies_list) >= 64:
                validated = False
                dep_validation_label = widgets.HTML(
                    '<p style="color:red;">A maximum of 64 jobs may be queued at once.</p>')
                dep_hbox.children += (dep_validation_label,)

            if not validated:
                return

            # replace any env variables in source.value with their values, can use $foo or ${foo}.
            # Particularly useful for ${AWS_REGION}
            source_exp = os.path.expandvars(source.value)
            logger.info(f'using source_exp: {source_exp}')
            try:
                kwargs = {
                    'failOnError': fail_on_error.value,
                    'parallelism': parallelism.value,
                    'updateSingleCardinalityProperties': update_single_cardinality.value,
                    'queueRequest': queue_request.value,
                    # Send what is in the form so edits to the Region box
                    # (and -r/--region) take effect.
                    'region': region_box.value,
                    'parserConfiguration': {}
                }

                # Test the parsed list, not the Textarea widget (always truthy).
                if dependencies_list:
                    kwargs['dependencies'] = dependencies_list

                submitted_format = source_format.value.lower()
                if submitted_format == FORMAT_OPENCYPHER:
                    kwargs['userProvidedEdgeIds'] = user_provided_edge_ids.value
                elif submitted_format == FORMAT_CSV:
                    if allow_empty_strings.value == 'TRUE':
                        kwargs['parserConfiguration']['allowEmptyStrings'] = True
                elif submitted_format in RDF_LOAD_FORMATS:
                    if named_graph_uri.value:
                        kwargs['parserConfiguration']['namedGraphUri'] = named_graph_uri.value
                    if base_uri.value and submitted_format in BASE_URI_FORMATS:
                        kwargs['parserConfiguration']['baseUri'] = base_uri.value

                # The loader ARN is only passed along for S3 sources.
                if source.value.startswith("s3://"):
                    load_res = self.client.load(str(source_exp), source_format.value, arn.value, **kwargs)
                else:
                    load_res = self.client.load(str(source_exp), source_format.value, **kwargs)
                load_res.raise_for_status()
                load_result = load_res.json()
                store_to_ns(args.store_to, load_result, local_ns)

                # The request has been sent; tear the form down.
                for row in form_rows:
                    row.close()
                button.close()
                output.close()

                if 'status' not in load_result or load_result['status'] != '200 OK':
                    with output:
                        print('Something went wrong.')
                        print(load_result)
                        logger.error(load_result)
                    return

                load_id = load_result['payload']['loadId']

                if poll_status.value == 'FALSE':
                    start_msg_label = widgets.Label('Load started successfully!')
                    polling_msg_label = widgets.Label(f'You can run "%load_status {load_id}" '
                                                      f'in another cell to check the current status of your bulk load.')
                    start_msg_hbox = widgets.HBox([start_msg_label])
                    polling_msg_hbox = widgets.HBox([polling_msg_label])
                    vbox = widgets.VBox([start_msg_hbox, polling_msg_hbox])
                    display(vbox)
                else:
                    poll_interval = 5
                    load_id_label = widgets.Label(f'Load ID: {load_id}')
                    interval_output = widgets.Output()
                    job_status_output = widgets.Output()
                    load_id_hbox = widgets.HBox([load_id_label])
                    status_hbox = widgets.HBox([interval_output])
                    vbox = widgets.VBox([load_id_hbox, status_hbox, job_status_output])
                    display(vbox)

                    # Tick once per second: refresh the countdown, and query
                    # the load status once the poll interval has elapsed.
                    last_poll_time = time.time()
                    while True:
                        time_elapsed = int(time.time() - last_poll_time)
                        time_remaining = poll_interval - time_elapsed
                        interval_output.clear_output()
                        if time_elapsed > poll_interval:
                            with interval_output:
                                print('checking status...')
                            job_status_output.clear_output()
                            with job_status_output:
                                display_html(HTML(loading_wheel_html))
                            try:
                                load_status_res = self.client.load_status(load_id)
                                load_status_res.raise_for_status()
                                interval_check_response = load_status_res.json()
                            except Exception as e:
                                logger.error(e)
                                with job_status_output:
                                    print('Something went wrong updating job status. Ending.')
                                    return
                            job_status_output.clear_output()
                            with job_status_output:
                                overall_status = interval_check_response["payload"]["overallStatus"]
                                print(f'Overall Status: {overall_status["status"]}')
                                if overall_status["status"] in FINAL_LOAD_STATUSES:
                                    execution_time = overall_status["totalTimeSpent"]
                                    if execution_time == 0:
                                        execution_time_statement = '<1 second'
                                    elif execution_time > 59:
                                        execution_time_statement = str(datetime.timedelta(seconds=execution_time))
                                    else:
                                        execution_time_statement = f'{execution_time} seconds'
                                    print('Total execution time: ' + execution_time_statement)
                                    interval_output.close()
                                    print('Done.')
                                    return
                            last_poll_time = time.time()
                        else:
                            with interval_output:
                                print(f'checking status in {time_remaining} seconds')
                        time.sleep(1)

            except HTTPError as httpEx:
                output.clear_output()
                with output:
                    print(httpEx.response.content.decode('utf-8'))

        button.on_click(on_button_clicked)
        if args.run:
            # --run submits immediately without waiting for a button click.
            on_button_clicked(None)