def load()

in src/graph_notebook/magics/graph_magic.py [0:0]


    def load(self, line='', local_ns: dict = None):
        """Render an interactive form for submitting a data load job, and optionally submit it.

        When the client targets a Neptune Analytics domain, the form submits an
        incremental load as an openCypher ``CALL neptune.load({...})`` query via
        ``self.client.opencypher_http``. Otherwise it submits a bulk load request via
        ``self.client.load``. If "Poll Load Status" is TRUE, it then polls
        ``self.client.load_status`` until the job reports one of ``FINAL_LOAD_STATUSES``.

        Args:
            line: The magic's argument string. It is parsed with ``argparse`` to
                pre-populate the form; ``--run`` submits the form immediately.
            local_ns: Notebook namespace passed to ``store_to_ns`` so the load
                response can be stored under the ``--store-to`` variable name.
        """
        # Evaluated once and reused below, including during status polling.
        is_analytics = self.client.is_analytics_domain()
        if is_analytics:
            load_type = ANALYTICS_LOAD_TYPES[0]
            load_formats = VALID_INCREMENTAL_FORMATS
        else:
            load_type = DB_LOAD_TYPES[0]
            load_formats = VALID_BULK_FORMATS

        # TODO: change widgets to let any arbitrary inputs be added by users
        parser = argparse.ArgumentParser()
        parser.add_argument('-s', '--source', default='s3://')
        try:
            parser.add_argument('-l', '--loader-arn', default=self.graph_notebook_config.load_from_s3_arn)
        except AttributeError:
            print("Missing required configuration option 'load_from_s3_arn'. Please ensure that you have provided a "
                  "valid Neptune cluster endpoint URI in the 'host' field of %graph_notebook_config.")
            return
        parser.add_argument('-f', '--format', choices=load_formats, default=FORMAT_CSV)
        parser.add_argument('-p', '--parallelism', choices=PARALLELISM_OPTIONS, default=PARALLELISM_HIGH)
        try:
            parser.add_argument('-r', '--region', default=self.graph_notebook_config.aws_region)
        except AttributeError:
            print("Missing required configuration option 'aws_region'. Please ensure that you have provided a "
                  "valid Neptune cluster endpoint URI in the 'host' field of %graph_notebook_config.")
            return
        parser.add_argument('--no-fail-on-error', action='store_true', default=False)
        # NOTE(review): with action='store_true' and default=True this flag is always True when
        # parsed; it can only be switched to FALSE through the form's dropdown.
        parser.add_argument('--update-single-cardinality', action='store_true', default=True)
        parser.add_argument('--store-to', type=str, default='', help='store query result to this variable')
        parser.add_argument('--run', action='store_true', default=False)
        parser.add_argument('-m', '--mode', choices=LOAD_JOB_MODES, default=MODE_AUTO)
        parser.add_argument('-q', '--queue-request', action='store_true', default=False)
        parser.add_argument('-d', '--dependencies', action='append', default=[])
        parser.add_argument('-e', '--no-edge-ids', action='store_true', default=False)
        parser.add_argument('-c', '--concurrency', type=int, default=1)
        parser.add_argument('--named-graph-uri', type=str, default=DEFAULT_NAMEDGRAPH_URI,
                            help='The default graph for all RDF formats when no graph is specified. '
                                 'Default is http://aws.amazon.com/neptune/vocab/v01/DefaultNamedGraph.')
        parser.add_argument('--base-uri', type=str, default=DEFAULT_BASE_URI,
                            help='The base URI for RDF/XML and Turtle formats. '
                                 'Default is http://aws.amazon.com/neptune/default')
        parser.add_argument('--allow-empty-strings', action='store_true', default=False,
                            help='Load empty strings found in node and edge property values.')
        parser.add_argument('-n', '--nopoll', action='store_true', default=False)

        args = parser.parse_args(line.split())
        button = widgets.Button(description="Submit")
        output = widgets.Output()
        widget_width = '25%'
        label_width = '16%'

        source = widgets.Text(
            value=args.source,
            placeholder='Type something',
            disabled=False,
            layout=widgets.Layout(width=widget_width)
        )

        arn = widgets.Text(
            value=args.loader_arn,
            placeholder='Type something',
            disabled=False,
            layout=widgets.Layout(width=widget_width)
        )

        source_format = widgets.Dropdown(
            options=load_formats,
            value=args.format,
            disabled=False,
            layout=widgets.Layout(width=widget_width)
        )

        # Initial CSS display value ('flex' = shown, 'none' = hidden) of the
        # format-dependent rows; the observers below update them when the format changes.
        ids_hbox_visibility = 'none'
        gremlin_parser_options_hbox_visibility = 'none'
        named_graph_hbox_visibility = 'none'
        base_uri_hbox_visibility = 'none'
        concurrency_hbox_visibility = 'none'

        if load_type == 'incremental':
            concurrency_hbox_visibility = 'flex'
        else:
            if source_format.value.lower() == FORMAT_CSV:
                gremlin_parser_options_hbox_visibility = 'flex'
            elif source_format.value.lower() == FORMAT_OPENCYPHER:
                ids_hbox_visibility = 'flex'
            elif source_format.value.lower() in RDF_LOAD_FORMATS:
                named_graph_hbox_visibility = 'flex'
                if source_format.value.lower() in BASE_URI_FORMATS:
                    base_uri_hbox_visibility = 'flex'

        region_box = widgets.Text(
            value=args.region,
            placeholder=args.region,
            disabled=False,
            layout=widgets.Layout(width=widget_width)
        )

        fail_on_error = widgets.Dropdown(
            options=['TRUE', 'FALSE'],
            value=str(not args.no_fail_on_error).upper(),
            disabled=False,
            layout=widgets.Layout(width=widget_width)
        )

        parallelism = widgets.Dropdown(
            options=PARALLELISM_OPTIONS,
            value=args.parallelism,
            disabled=False,
            layout=widgets.Layout(width=widget_width)
        )

        allow_empty_strings = widgets.Dropdown(
            options=['TRUE', 'FALSE'],
            value=str(args.allow_empty_strings).upper(),
            disabled=False,
            layout=widgets.Layout(display=gremlin_parser_options_hbox_visibility,
                                  width=widget_width)
        )

        named_graph_uri = widgets.Text(
            value=args.named_graph_uri,
            placeholder='http://named-graph-string',
            disabled=False,
            layout=widgets.Layout(display=named_graph_hbox_visibility,
                                  width=widget_width)
        )

        base_uri = widgets.Text(
            value=args.base_uri,
            placeholder='http://base-uri-string',
            disabled=False,
            layout=widgets.Layout(display=base_uri_hbox_visibility,
                                  width=widget_width)
        )

        update_single_cardinality = widgets.Dropdown(
            options=['TRUE', 'FALSE'],
            value=str(args.update_single_cardinality).upper(),
            disabled=False,
            layout=widgets.Layout(width=widget_width)
        )

        mode = widgets.Dropdown(
            options=LOAD_JOB_MODES,
            value=args.mode,
            disabled=False,
            layout=widgets.Layout(width=widget_width)
        )

        user_provided_edge_ids = widgets.Dropdown(
            options=['TRUE', 'FALSE'],
            value=str(not args.no_edge_ids).upper(),
            disabled=False,
            layout=widgets.Layout(display=ids_hbox_visibility,
                                  width=widget_width)
        )

        queue_request = widgets.Dropdown(
            options=['TRUE', 'FALSE'],
            value=str(args.queue_request).upper(),
            disabled=False,
            layout=widgets.Layout(width=widget_width)
        )

        dependencies = widgets.Textarea(
            value="\n".join(args.dependencies),
            placeholder='load_A_id\nload_B_id',
            disabled=False,
            layout=widgets.Layout(width=widget_width)
        )

        concurrency = widgets.BoundedIntText(
            value=args.concurrency,
            placeholder=1,
            min=1,
            max=2 ** 16,
            disabled=False,
            layout=widgets.Layout(display=concurrency_hbox_visibility,
                                  width=widget_width)
        )

        poll_status = widgets.Dropdown(
            options=['TRUE', 'FALSE'],
            value=str(not args.nopoll).upper(),
            disabled=False,
            layout=widgets.Layout(width=widget_width)
        )

        # Create a series of HBox containers that will hold the widgets and labels
        # that make up the %load form. Some of the labels and widgets are created
        # in two parts to support the validation steps that come later. In the case
        # of validation errors this allows additional text to easily be added to an
        # HBox describing the issue.
        source_hbox_label = widgets.Label('Source:',
                                          layout=widgets.Layout(width=label_width,
                                                                display="flex",
                                                                justify_content="flex-end"))

        source_hbox = widgets.HBox([source_hbox_label, source])

        format_hbox_label = widgets.Label('Format:',
                                          layout=widgets.Layout(width=label_width,
                                                                display="flex",
                                                                justify_content="flex-end"))

        source_format_hbox = widgets.HBox([format_hbox_label, source_format])

        region_hbox = widgets.HBox([widgets.Label('Region:',
                                                  layout=widgets.Layout(width=label_width,
                                                                        display="flex",
                                                                        justify_content="flex-end")),
                                    region_box])

        arn_hbox_label = widgets.Label('Load ARN:',
                                       layout=widgets.Layout(width=label_width,
                                                             display="flex",
                                                             justify_content="flex-end"))

        arn_hbox = widgets.HBox([arn_hbox_label, arn])

        mode_hbox = widgets.HBox([widgets.Label('Mode:',
                                                layout=widgets.Layout(width=label_width,
                                                                      display="flex",
                                                                      justify_content="flex-end")),
                                  mode])

        fail_hbox = widgets.HBox([widgets.Label('Fail on Error:',
                                                layout=widgets.Layout(width=label_width,
                                                                      display="flex",
                                                                      justify_content="flex-end")),
                                  fail_on_error])

        parallelism_hbox = widgets.HBox([widgets.Label('Parallelism:',
                                                       layout=widgets.Layout(width=label_width,
                                                                             display="flex",
                                                                             justify_content="flex-end")),
                                         parallelism])

        allow_empty_strings_hbox_label = widgets.Label('Allow Empty Strings:',
                                                       layout=widgets.Layout(width=label_width,
                                                                             display=gremlin_parser_options_hbox_visibility,
                                                                             justify_content="flex-end"))

        allow_empty_strings_hbox = widgets.HBox([allow_empty_strings_hbox_label, allow_empty_strings])

        named_graph_uri_hbox_label = widgets.Label('Named Graph URI:',
                                                   layout=widgets.Layout(width=label_width,
                                                                         display=named_graph_hbox_visibility,
                                                                         justify_content="flex-end"))

        named_graph_uri_hbox = widgets.HBox([named_graph_uri_hbox_label, named_graph_uri])

        base_uri_hbox_label = widgets.Label('Base URI:',
                                            layout=widgets.Layout(width=label_width,
                                                                  display=base_uri_hbox_visibility,
                                                                  justify_content="flex-end"))

        base_uri_hbox = widgets.HBox([base_uri_hbox_label, base_uri])

        cardinality_hbox = widgets.HBox([widgets.Label('Update Single Cardinality:',
                                                       layout=widgets.Layout(width=label_width,
                                                                             display="flex",
                                                                             justify_content="flex-end")),
                                         update_single_cardinality])

        queue_hbox = widgets.HBox([widgets.Label('Queue Request:',
                                                 layout=widgets.Layout(width=label_width,
                                                                       display="flex", justify_content="flex-end")),
                                   queue_request])

        dep_hbox_label = widgets.Label('Dependencies:',
                                       layout=widgets.Layout(width=label_width,
                                                             display="flex", justify_content="flex-end"))

        dep_hbox = widgets.HBox([dep_hbox_label, dependencies])

        ids_hbox_label = widgets.Label('User Provided Edge Ids:',
                                       layout=widgets.Layout(width=label_width,
                                                             display=ids_hbox_visibility,
                                                             justify_content="flex-end"))

        ids_hbox = widgets.HBox([ids_hbox_label, user_provided_edge_ids])

        concurrency_hbox_label = widgets.Label('Concurrency:',
                                               layout=widgets.Layout(width=label_width,
                                                                     display=concurrency_hbox_visibility,
                                                                     justify_content="flex-end"))

        concurrency_hbox = widgets.HBox([concurrency_hbox_label, concurrency])

        poll_status_label = widgets.Label('Poll Load Status:',
                                          layout=widgets.Layout(width=label_width,
                                                                display="flex",
                                                                justify_content="flex-end"))

        poll_status_hbox = widgets.HBox([poll_status_label, poll_status])

        def update_edge_ids_options(change):
            """Show the edge-id option only for openCypher; otherwise hide it and reset it to TRUE."""
            if change.new.lower() == FORMAT_OPENCYPHER:
                ids_display = 'flex'
            else:
                ids_display = 'none'
                user_provided_edge_ids.value = 'TRUE'
            user_provided_edge_ids.layout.display = ids_display
            ids_hbox_label.layout.display = ids_display

        def update_parserconfig_options(change):
            """Show the parser options that apply to the newly selected format, resetting hidden ones."""
            if change.new.lower() == FORMAT_CSV:
                empty_strings_display = 'flex'
                named_graph_display = 'none'
                base_uri_display = 'none'
                named_graph_uri.value = ''
                base_uri.value = ''
            elif change.new.lower() == FORMAT_OPENCYPHER:
                empty_strings_display = 'none'
                allow_empty_strings.value = 'FALSE'
                named_graph_display = 'none'
                base_uri_display = 'none'
                named_graph_uri.value = ''
                base_uri.value = ''
            else:
                # Any other format is treated as RDF here.
                empty_strings_display = 'none'
                allow_empty_strings.value = 'FALSE'
                named_graph_display = 'flex'
                named_graph_uri.value = DEFAULT_NAMEDGRAPH_URI
                if change.new.lower() in BASE_URI_FORMATS:
                    base_uri_display = 'flex'
                    base_uri.value = DEFAULT_BASE_URI
                else:
                    base_uri_display = 'none'
                    base_uri.value = ''

            allow_empty_strings.layout.display = empty_strings_display
            allow_empty_strings_hbox_label.layout.display = empty_strings_display
            named_graph_uri.layout.display = named_graph_display
            named_graph_uri_hbox_label.layout.display = named_graph_display
            base_uri.layout.display = base_uri_display
            base_uri_hbox_label.layout.display = base_uri_display

        source_format.observe(update_edge_ids_options, names='value')
        source_format.observe(update_parserconfig_options, names='value')

        basic_load_boxes = [source_hbox, source_format_hbox, region_hbox, fail_hbox]
        # load arguments for Analytics incremental load
        incremental_load_boxes = [concurrency_hbox]
        # load arguments for Neptune bulk load
        bulk_load_boxes = [arn_hbox, mode_hbox, parallelism_hbox, cardinality_hbox,
                           queue_hbox, dep_hbox, ids_hbox, allow_empty_strings_hbox,
                           named_graph_uri_hbox, base_uri_hbox, poll_status_hbox]
        submit_load_boxes = [button, output]

        if load_type == 'incremental':
            display_boxes = basic_load_boxes + incremental_load_boxes + submit_load_boxes
        else:
            display_boxes = basic_load_boxes + bulk_load_boxes + submit_load_boxes

        display(*display_boxes)

        def on_button_clicked(b):
            """Validate the form, submit the load, and report (and optionally poll) its status."""
            # Reset each row to label + input, dropping error text appended by a previous attempt.
            source_hbox.children = (source_hbox_label, source,)
            arn_hbox.children = (arn_hbox_label, arn,)
            source_format_hbox.children = (format_hbox_label, source_format,)
            allow_empty_strings_hbox.children = (allow_empty_strings_hbox_label, allow_empty_strings,)
            named_graph_uri_hbox.children = (named_graph_uri_hbox_label, named_graph_uri,)
            base_uri_hbox.children = (base_uri_hbox_label, base_uri,)
            dep_hbox.children = (dep_hbox_label, dependencies,)
            concurrency_hbox.children = (concurrency_hbox_label, concurrency,)

            validated = True
            if not (source.value.startswith('s3://') and len(source.value) > 7) and not source.value.startswith('/'):
                validated = False
                # The inline CSS colours the message, so no separate widget style object is set.
                source_validation_label = widgets.HTML(
                    '<p style="color:red;">Source must be an s3 bucket or file path</p>')
                source_hbox.children += (source_validation_label,)

            if source_format.value == '':
                validated = False
                source_format_validation_label = widgets.HTML('<p style="color:red;">Format cannot be blank.</p>')
                source_format_hbox.children += (source_format_validation_label,)

            # Parsed unconditionally so it is always bound when the bulk request is built below.
            dependencies_list = list(filter(None, dependencies.value.split('\n')))

            if load_type == 'bulk':
                # only do this validation if we are using an s3 bucket.
                if not arn.value.startswith('arn:aws') and source.value.startswith("s3://"):
                    validated = False
                    arn_validation_label = widgets.HTML('<p style="color:red;">Load ARN must start with "arn:aws"</p>')
                    arn_hbox.children += (arn_validation_label,)
                # NOTE(review): rejects 64 or more ids (at most 63 accepted). Presumably this is because
                # the new job itself takes one of the 64 queue slots; confirm against the loader limits.
                if len(dependencies_list) >= 64:
                    validated = False
                    dep_validation_label = widgets.HTML(
                        '<p style="color:red;">A maximum of 64 jobs may be queued at once.</p>')
                    dep_hbox.children += (dep_validation_label,)

            if not validated:
                return

            # replace any env variables in source.value with their values, can use $foo or ${foo}.
            # Particularly useful for ${AWS_REGION}
            source_exp = os.path.expandvars(source.value)
            logger.info(f'using source_exp: {source_exp}')
            try:
                kwargs = {
                    'failOnError': fail_on_error.value,
                    'region': region_box.value
                }

                if load_type == 'incremental':
                    kwargs.update({
                        # Use the env-expanded source, as the bulk path does.
                        'source': source_exp,
                        'format': source_format.value,
                        'concurrency': concurrency.value
                    })
                else:
                    bulk_load_kwargs = {
                        'mode': mode.value,
                        'parallelism': parallelism.value,
                        'updateSingleCardinalityProperties': update_single_cardinality.value,
                        'queueRequest': queue_request.value,
                        'parserConfiguration': {}
                    }

                    # Only send dependencies when the user actually listed some.
                    if dependencies_list:
                        bulk_load_kwargs['dependencies'] = dependencies_list

                    if source_format.value.lower() == FORMAT_OPENCYPHER:
                        bulk_load_kwargs['userProvidedEdgeIds'] = user_provided_edge_ids.value
                    elif source_format.value.lower() == FORMAT_CSV:
                        if allow_empty_strings.value == 'TRUE':
                            bulk_load_kwargs['parserConfiguration']['allowEmptyStrings'] = True
                    elif source_format.value.lower() in RDF_LOAD_FORMATS:
                        if named_graph_uri.value:
                            bulk_load_kwargs['parserConfiguration']['namedGraphUri'] = named_graph_uri.value
                        if base_uri.value and source_format.value.lower() in BASE_URI_FORMATS:
                            bulk_load_kwargs['parserConfiguration']['baseUri'] = base_uri.value

                    kwargs.update(bulk_load_kwargs)

                source_hbox.close()
                source_format_hbox.close()
                region_hbox.close()
                arn_hbox.close()
                mode_hbox.close()
                fail_hbox.close()
                parallelism_hbox.close()
                cardinality_hbox.close()
                queue_hbox.close()
                dep_hbox.close()
                poll_status_hbox.close()
                ids_hbox.close()
                allow_empty_strings_hbox.close()
                named_graph_uri_hbox.close()
                base_uri_hbox.close()
                concurrency_hbox.close()
                button.close()

                load_submit_status_output = widgets.Output()
                load_submit_hbox = widgets.HBox([load_submit_status_output])
                with output:
                    display(load_submit_hbox)
                with load_submit_status_output:
                    print(f"{load_type.capitalize()} load request submitted, waiting for response...")
                    display_html(HTML(loading_wheel_html))
                try:
                    if load_type == 'incremental':
                        # Render kwargs as openCypher map entries: ints and the TRUE/FALSE failOnError
                        # value unquoted, every other value as a double-quoted string.
                        load_oc_entries = []
                        for param, value in kwargs.items():
                            if isinstance(value, int) or param == 'failOnError':
                                value_substr = str(value)
                            else:
                                value_substr = '"' + value + '"'
                            load_oc_entries.append(param + ': ' + value_substr)
                        if source_format.value == FORMAT_NTRIPLE:
                            load_oc_entries.append('blankNodeHandling: "convertToIri"')
                        load_oc_params = '{' + ', '.join(load_oc_entries) + '}'
                        load_oc_query = f"CALL neptune.load({load_oc_params})"
                        oc_load = self.client.opencypher_http(load_oc_query)
                    else:
                        if source.value.startswith("s3://"):
                            load_res = self.client.load(str(source_exp), source_format.value, arn.value, **kwargs)
                        else:
                            load_res = self.client.load(str(source_exp), source_format.value, **kwargs)
                except Exception as e:
                    load_submit_status_output.clear_output()
                    with output:
                        print(f"Failed to submit {load_type.capitalize()} load request.")
                        logger.error(e)
                    return
                load_submit_status_output.clear_output()

                if load_type == 'incremental':
                    oc_load.raise_for_status()
                    load_result = oc_load.json()
                    store_to_ns(args.store_to, load_result, local_ns)
                    with output:
                        print("Load completed.\n")
                        print(json.dumps(load_result, indent=2))
                else:
                    try:
                        load_res.raise_for_status()
                    except Exception:
                        # Ignore failure to retrieve status, we handle missing status below.
                        pass
                    try:
                        load_result = load_res.json()
                    except ValueError as e:
                        # The response body could not be parsed as JSON; report it instead of
                        # letting the exception escape the button callback.
                        with output:
                            print('Something went wrong.')
                            logger.error(e)
                        return
                    store_to_ns(args.store_to, load_result, local_ns)

                    if 'status' not in load_result or load_result['status'] != '200 OK':
                        with output:
                            print('Something went wrong.')
                            logger.error(load_result)
                        return

                    load_id = load_result["payload"]["loadId"]
                    if poll_status.value == 'FALSE':
                        start_msg_label = widgets.Label('Load started successfully!')
                        polling_msg_label = widgets.Label(
                            f'You can run "%load_status {load_id}" '
                            f'in another cell to check the current status of your bulk load.')
                        start_msg_hbox = widgets.HBox([start_msg_label])
                        polling_msg_hbox = widgets.HBox([polling_msg_label])
                        vbox = widgets.VBox([start_msg_hbox, polling_msg_hbox])
                        with output:
                            display(vbox)
                    else:
                        poll_interval = 5
                        load_id_label = widgets.Label(f'Load ID: {load_id}')
                        interval_output = widgets.Output()
                        job_status_output = widgets.Output()
                        load_id_hbox = widgets.HBox([load_id_label])
                        status_hbox = widgets.HBox([interval_output])
                        vbox = widgets.VBox([load_id_hbox, status_hbox, job_status_output])
                        with output:
                            display(vbox)

                        # Ticks once per second, updating the countdown text, and queries the
                        # load status after more than poll_interval seconds have elapsed.
                        last_poll_time = time.time()
                        new_interval = True
                        while True:
                            time_elapsed = int(time.time() - last_poll_time)
                            time_remaining = poll_interval - time_elapsed
                            interval_output.clear_output()
                            if time_elapsed > poll_interval:
                                with interval_output:
                                    print('checking status...')
                                job_status_output.clear_output()
                                with job_status_output:
                                    display_html(HTML(loading_wheel_html))
                                new_interval = True
                                try:
                                    load_status_res = self.client.load_status(load_id)
                                    load_status_res.raise_for_status()
                                    interval_check_response = load_status_res.json()
                                except Exception as e:
                                    logger.error(e)
                                    with job_status_output:
                                        print('Something went wrong updating job status. Ending.')
                                        return
                                job_status_output.clear_output()
                                with job_status_output:
                                    # parse status & execution_time differently for Analytics and NeptuneDB
                                    status_payload = interval_check_response["payload"]
                                    if is_analytics:
                                        overall_status = status_payload["status"]
                                        total_time_spent = status_payload["timeElapsedSeconds"]
                                    else:
                                        overall_status = status_payload["overallStatus"]["status"]
                                        total_time_spent = status_payload["overallStatus"]["totalTimeSpent"]
                                    print(f'Overall Status: {overall_status}')
                                    if overall_status in FINAL_LOAD_STATUSES:
                                        execution_time = total_time_spent
                                        if execution_time == 0:
                                            execution_time_statement = '<1 second'
                                        elif execution_time > 59:
                                            execution_time_statement = str(datetime.timedelta(seconds=execution_time))
                                        else:
                                            execution_time_statement = f'{execution_time} seconds'
                                        print('Total execution time: ' + execution_time_statement)
                                        interval_output.close()
                                        print('Done.')
                                        return
                                last_poll_time = time.time()
                            else:
                                if new_interval:
                                    with job_status_output:
                                        display_html(HTML(loading_wheel_html))
                                    new_interval = False
                                with interval_output:
                                    print(f'checking status in {time_remaining} seconds')
                            time.sleep(1)

            except HTTPError as httpEx:
                output.clear_output()
                with output:
                    print(httpEx.response.content.decode('utf-8'))

        button.on_click(on_button_clicked)
        if args.run:
            on_button_clicked(None)