wstl1/tools/notebook/extensions/wstl/magics/wstl.py (310 lines of code) (raw):
# Copyright 2020 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Magic command wrapper for whistle mapping language."""
import json
import os
from google.cloud import storage
from googleapiclient import discovery
import grpc
from IPython.core import magic_arguments
from IPython.core.magic import cell_magic
from IPython.core.magic import line_magic
from IPython.core.magic import Magics
from IPython.core.magic import magics_class
from IPython.display import JSON
from google.protobuf import json_format
from wstl.magics import _constants
from wstl.magics import _location
from wstl.proto import wstlservice_pb2
from wstl.proto import wstlservice_pb2_grpc
# Deadline, in seconds, passed as `timeout=` to the incremental transform RPC.
# Environment variables are always strings, so the override is converted to a
# float to match the numeric default.
_GRPC_TIMEOUT = float(os.environ.get("NOTEBOOK_GRPC_TIMEOUT", 10.0))
# Host and port of the gRPC server; WSTLMagics joins them as "<host>:<port>".
_DEFAULT_HOST = os.environ.get("NOTEBOOK_GRPC_HOST", "localhost")
_DEFAULT_PORT = os.environ.get("NOTEBOOK_GRPC_PORT", "50051")
@magics_class
class WSTLMagics(Magics):
    """IPython magics that run Whistle (.wstl) mappings and FHIR validation over gRPC."""

    def __init__(self, shell):
        """Stores the gRPC target built from the module-level host and port.

        Args:
          shell: the IPython shell instance these magics are registered on.
        """
        super().__init__(shell)
        # Target address in "<host>:<port>" form, used for every channel below.
        self.grpc_target = "{}:{}".format(_DEFAULT_HOST, _DEFAULT_PORT)

    @line_magic("wstl-reset")
    def wstl_reset(self, line):
        """Line magic to clear all variables and functions from incremental transformation.

        Sends a delete request for the incremental session whose id is the
        IPython history session number.

        Args:
          line: the text following the magic name; not used.

        Returns:
          A JSON display of the response, or the grpc.RpcError raised by the call.
        """
        session = str(self.shell.history_manager.session_number)
        request = wstlservice_pb2.DeleteIncrementalSessionRequest(
            session_id=session)
        with grpc.insecure_channel(self.grpc_target) as channel:
            client = wstlservice_pb2_grpc.WhistleServiceStub(channel)
            try:
                # NOTE(review): this stub attribute carries the request message's
                # name, unlike the other RPCs called in this file
                # (GetOrCreateIncrementalSession, GetIncrementalTransform,
                # FhirValidate) -- confirm it matches the service definition.
                reply = client.DeleteIncrementalSessionRequest(request)
            except grpc.RpcError as rpc_error:
                return rpc_error
            return JSON(json_format.MessageToDict(reply))

    @magic_arguments.magic_arguments()
    @magic_arguments.argument(
        "--input",
        type=str,
        required=False,
        help="""The input. Supports the following prefix notations:
py://<name_of_python_variable>
json://<inline_json_object_or_array> : python inline dict and list
expressions are supported. e.g. json://{"field":"value"} or
json://[{"first":"value"},{"second":"value"}]
file://<path_to_local_file_system> , supports glob wildcard expressions
and will only load .json or .ndjson file extensions. Each json object/list
defined within an ndjson will be a separate input to the mapping.
""")
    @magic_arguments.argument(
        "--library_config",
        type=str,
        required=False,
        help="""Path to the directory where the library mapping files are located."""
    )
    @magic_arguments.argument(
        "--code_config",
        type=str,
        required=False,
        help="""Path to the directory of FHIR ConceptMaps used for code harmonization."""
    )
    @magic_arguments.argument(
        "--unit_config",
        type=str,
        required=False,
        help="""Path to a unit harmonization file (textproto).""")
    @magic_arguments.argument(
        "--output",
        type=str,
        required=False,
        help="""Name of python variable to store result. e.g. --output temp_var"""
    )
    @cell_magic("wstl")
    def wstl(self, line, cell):
        """Cell magic that evaluates the cell's Whistle code in the incremental session.

        Args:
          line: the magic's argument string, parsed with the options declared above.
          cell: the cell body, forwarded as the Whistle source.

        Returns:
          A JSON display of the transform result, or the error value reported by
          the session or transform helper.
        """
        options = magic_arguments.parse_argstring(self.wstl, line)
        # TODO (): migrate to secure channel.
        with grpc.insecure_channel(self.grpc_target) as channel:
            client = wstlservice_pb2_grpc.WhistleServiceStub(channel)

            session, session_err = _get_or_create_session(client, self.shell)
            if session_err:
                return session_err

            transform, transform_err = _get_incremental_transform(
                client, self.shell, session.session_id, options, cell)
            if transform_err:
                return transform_err

            converted = _response_to_json(transform)
            if options.output:
                # Expose the result in the user's namespace under the given name.
                self.shell.push({options.output: converted})
            return JSON(converted)

    @magic_arguments.magic_arguments()
    @magic_arguments.argument(
        "--version",
        choices=["stu3", "r4"],
        default="r4",
        type=str,
        help="""The fhir version to apply to the validation.
The default is r4.""")
    @magic_arguments.argument(
        "--input",
        type=str,
        help="""The input. Supports the following prefix notations:
py://<name_of_python_variable>
json://<inline_json_object_or_array> : python inline dict and list
expressions are supported. e.g. json://{"field":"value"} or
json://[{"first":"value"},{"second":"value"}]
file://<path_to_local_file_system> , supports glob wildcard expressions
and will only load .json or .ndjson file extensions. Each json object/list
defined within an ndjson will be a separate input to the validation.""")
    @line_magic("fhir_validate")
    def fhir_validate(self, line):
        """Line magic that validates JSON FHIR resource(s) against a FHIR version.

        Args:
          line: the magic's argument string (--version, --input).

        Returns:
          A JSON display of the validation response, or the error value
          reported by the validation helper.
        """
        options = magic_arguments.parse_argstring(self.fhir_validate, line)
        with grpc.insecure_channel(self.grpc_target) as channel:
            client = wstlservice_pb2_grpc.WhistleServiceStub(channel)
            reply, failure = _get_validation(client, self.shell, options.version,
                                             options.input)
            if failure:
                return failure
            return JSON(json_format.MessageToDict(reply))
@magics_class
class LoadHL7Magics(Magics):
    """Loads parsed HL7v2 message from GCS or HL7v2 Store."""

    @magic_arguments.magic_arguments()
    @magic_arguments.argument(
        "--project_id",
        type=str,
        help="""ID of the GCP project that the HL7v2 Store belongs to.""",
        required=True)
    @magic_arguments.argument(
        "--region",
        type=str,
        help="""Region of the HL7v2 Store.""",
        required=True)
    @magic_arguments.argument(
        "--dataset_id",
        type=str,
        required=True,
        help="""ID of the dataset that the HL7v2 store belongs to.""")
    @magic_arguments.argument(
        "--hl7v2_store_id",
        type=str,
        required=True,
        help="""ID of the HL7v2 store to load data from.""")
    @magic_arguments.argument(
        "--api_version",
        type=str,
        required=False,
        default="v1beta1",
        choices=["v1", "v1beta1"],
        help="""The version of healthcare api to call. Default to v1beta1.""")
    @magic_arguments.argument(
        "--filter",
        type=str,
        required=False,
        help="""
filter: string, Restricts messages returned to those matching a filter. Syntax:
https://cloud.google.com/appengine/docs/standard/python/search/query_strings
If the filter string contains white space, it must be surrounded by single quotes.
Fields/functions available for filtering are:
* `message_type`, from the MSH-9.1 field. For example,
`NOT message_type = "ADT"`.
* `send_date` or `sendDate`, the YYYY-MM-DD date the message was sent in
the dataset's time_zone, from the MSH-7 segment. For example,
`send_date < "2017-01-02"`.
* `send_time`, the timestamp when the message was sent, using the
RFC3339 time format for comparisons, from the MSH-7 segment. For example,
`send_time < "2017-01-02T00:00:00-05:00"`.
* `send_facility`, the care center that the message came from, from the
MSH-4 segment. For example, `send_facility = "ABC"`.
* `PatientId(value, type)`, which matches if the message lists a patient
having an ID of the given value and type in the PID-2, PID-3, or PID-4
segments. For example, `PatientId("123456", "MRN")`.
* `labels.x`, a string value of the label with key `x` as set using the
Message.labels
map. For example, `labels."priority"="high"`. The operator `:*` can be used
to assert the existence of a label. For example, `labels."priority":*`.""")
    @magic_arguments.argument(
        "--dest_file_name",
        type=str,
        required=False,
        help="""
The destination file path to store the loaded data.
If not provided, the result will be directly returned to the IPython kernel.
""",
        default="")
    @line_magic("load_hl7v2_datastore")
    def load_hl7v2_datastore(self, line):
        """Load parsed HL7v2 messages from the HL7v2 Store specified.

        Args:
          line: the magic's argument string, parsed with the options declared above.

        Returns:
          A confirmation string when --dest_file_name is given (the messages are
          written to that file as JSON); otherwise a JSON display of the messages.
        """
        args = magic_arguments.parse_argstring(self.load_hl7v2_datastore, line)
        hl7v2_messages = _get_message_from_hl7v2_store(args.api_version,
                                                       args.project_id,
                                                       args.region,
                                                       args.dataset_id,
                                                       args.hl7v2_store_id,
                                                       args.filter)
        if args.dest_file_name:
            with open(args.dest_file_name, "w") as dest_file:
                # The helper returns a Python list, which file.write() rejects;
                # serialize it so the file holds valid JSON.
                json.dump(hl7v2_messages, dest_file)
            return "The message was written to {} successfully.".format(
                args.dest_file_name)
        return JSON(hl7v2_messages)

    @magic_arguments.magic_arguments()
    @magic_arguments.argument(
        "--bucket_name",
        type=str,
        help="""The name of the GCS bucket to load data from.""",
        required=True)
    @magic_arguments.argument(
        "--source_blob_name",
        type=str,
        help="""The name of the blob to load.""",
        required=True)
    @magic_arguments.argument(
        "--dest_file_name",
        type=str,
        required=False,
        help="""
The destination file path to store the loaded data.
If not provided, the result will be directly returned to the IPython kernel.
""",
        default="")
    @line_magic("load_hl7v2_gcs")
    def load_hl7v2_gcs(self, line):
        """Load and return a parsed HL7v2 message from a blob in a GCS bucket.

        Args:
          line: the magic's argument string, parsed with the options declared above.

        Returns:
          A confirmation string when --dest_file_name is given (the downloaded
          content is written to that file unchanged); otherwise a JSON display
          of the parsed content.

        Raises:
          ValueError: if the bucket or the blob does not exist.
          json.JSONDecodeError: if the downloaded content is not valid JSON.
        """
        args = magic_arguments.parse_argstring(self.load_hl7v2_gcs, line)
        dest_file_name = args.dest_file_name
        storage_client = storage.Client()
        bucket = storage_client.bucket(args.bucket_name)
        if not bucket.exists():
            raise ValueError(
                "The bucket does not exist. Please check the provided bucket name.")
        blob = bucket.get_blob(args.source_blob_name)
        if not blob:
            raise ValueError(
                "The blob does not exist. Please check the provided blob name.")
        content = blob.download_as_string()
        if blob.content_encoding:
            # NOTE(review): content_encoding is used here as a text codec name;
            # confirm blobs never carry a transfer encoding such as gzip in this
            # field, which bytes.decode() would reject.
            content = content.decode(blob.content_encoding)
        # Check that the downloaded content is JSON before returning or saving it.
        try:
            try:
                result = json.loads(content)
            except TypeError:
                # Fallback for runtimes whose json.loads does not accept bytes.
                result = json.loads(content.decode("UTF-8"))
        except json.JSONDecodeError:
            print(
                "The loaded content is not a valid JSON. Please check the source bucket and blob."
            )
            raise
        if dest_file_name:
            # content is still bytes unless it was decoded above; a text-mode
            # file rejects bytes, so pick the mode that matches the payload.
            mode = "wb" if isinstance(content, bytes) else "w"
            with open(dest_file_name, mode) as dest:
                dest.write(content)
            return "The message was written to {} successfully.".format(
                dest_file_name)
        return JSON(result)
def _get_message_from_hl7v2_store(api_version, project, region, dataset,
                                  data_store, filter_str):
    """Lists the messages of an HL7v2 store through the Healthcare API.

    A client for the "healthcare" service is built with discovery.build. The
    original note on this function says credentials come from the service
    account named by GOOGLE_APPLICATION_CREDENTIALS -- that is resolved inside
    the client library, not here.

    Args:
      api_version: the Healthcare API version to build the client for.
      project: the GCP project id.
      region: the location of the dataset.
      dataset: the dataset id.
      data_store: the HL7v2 store id.
      filter_str: optional filter expression; surrounding single quotes are
        stripped before it is sent.

    Returns:
      The "hl7V2Messages" entry of the list response, or an empty list when
      the response has no such entry.
    """
    # TODO(): add paging support for HL7v2 messages.
    client = discovery.build("healthcare", api_version)
    store_path = "projects/{}/locations/{}/datasets/{}/hl7V2Stores/{}".format(
        project, region, dataset, data_store)
    if filter_str:
        filter_str = filter_str.strip("'")
    messages_api = (
        client.projects().locations().datasets().hl7V2Stores().messages())
    list_request = messages_api.list(
        parent=store_path, view="FULL", filter=filter_str)
    listing = list_request.execute()
    return listing.get("hl7V2Messages", [])
def _get_or_create_session(stub, shell):
    """Fetches the incremental transform session, creating it if needed.

    The session id sent to the service is the IPython history session number
    rendered as a string.

    Args:
      stub: gRPC client stub library.
      shell: an instance of the iPython shell that invoked the magic command.

    Returns:
      A (response, error) pair: (IncrementalSessionResponse, None) on success,
      or (None, grpc.RpcError) when the call fails.
    """
    request = wstlservice_pb2.CreateIncrementalSessionRequest(
        session_id=str(shell.history_manager.session_number))
    try:
        return stub.GetOrCreateIncrementalSession(request), None
    except grpc.RpcError as rpc_error:
        return None, rpc_error
def _get_incremental_transform(stub, shell, session_id, wstl_args, cell):
    """Invokes, through a gRPC request, an incremental Whistle transform.

    Each optional magic argument is resolved with _location.parse_location and
    attached to the request only when it yields something. An --input argument
    that resolves to nothing aborts the call with an error message.

    Args:
      stub: gRPC client stub library.
      shell: an instance of the iPython shell that invoked the magic command.
      session_id: the incremental transformation session id.
      wstl_args: the arguments to the wstl magic command.
      cell: the contents of the cell, containing whistle.

    Returns:
      A (response, error) pair: (TransformResponse, None) on success, or
      (None, error) where error is a grpc.RpcError or a message string.
    """
    request = wstlservice_pb2.IncrementalTransformRequest()
    request.session_id = session_id
    request.wstl = cell

    library_arg = wstl_args.library_config
    if library_arg:
        libraries = _location.parse_location(
            shell,
            library_arg,
            file_ext=_constants.WSTL_FILE_EXT,
            load_contents=False)
        if libraries:
            request.library_config.extend(libraries)

    code_arg = wstl_args.code_config
    if code_arg:
        concept_maps = _location.parse_location(
            shell,
            code_arg,
            file_ext=_constants.JSON_FILE_EXT,
            load_contents=False)
        if concept_maps:
            request.code_config.extend(concept_maps)

    unit_arg = wstl_args.unit_config
    if unit_arg:
        unit_files = _location.parse_location(
            shell,
            unit_arg,
            file_ext=_constants.TEXTPROTO_FILE_EXT,
            load_contents=False)
        if unit_files:
            # Only the first resolved location is used for the unit config.
            request.unit_config = unit_files[0]

    input_arg = wstl_args.input
    if input_arg:
        payloads = _location.parse_location(
            shell,
            input_arg,
            file_ext=_constants.JSON_FILE_EXT,
            load_contents=True)
        if not payloads:
            return None, "no inputs matching arguement {}".format(input_arg)
        request.input.extend(payloads)

    try:
        return stub.GetIncrementalTransform(request, timeout=_GRPC_TIMEOUT), None
    except grpc.RpcError as rpc_error:
        return None, rpc_error
def _convert_message_to_json(transform_record):
    """Turns one TransformedRecords proto into a JSON-compatible value.

    Args:
      transform_record: a TransformRecords to convert to JSON.

    Returns:
      The parsed `output` string when that field is set; otherwise the `error`
      field as a dict when set; otherwise the whole record as a dict.
    """
    if transform_record.HasField("output"):
        # The output field holds JSON text, so it is parsed rather than converted.
        return json.loads(transform_record.output)
    if transform_record.HasField("error"):
        message = transform_record.error
    else:
        message = transform_record
    return json_format.MessageToDict(message)
def _response_to_json(response):
    """Converts each element within a TransformResponse result into JSON.

    Args:
      response: the TransformResponse from a GetIncrementalTransform request.

    Returns:
      The single converted record when the response holds exactly one result;
      otherwise a list of converted records, which is empty when the response
      has no results.
    """
    converted = [_convert_message_to_json(record) for record in response.results]
    # Exactly one result is returned bare. Zero results yield an empty list
    # instead of indexing into an empty sequence.
    if len(converted) == 1:
        return converted[0]
    return converted
def _get_validation(stub, shell, version, input_arg):
    """Validates the input JSON resource(s) against the given FHIR version.

    Args:
      stub: gRPC client stub library.
      shell: an instance of the iPython shell that invoked the magic command.
      version: the FHIR version to validate against; compared case-insensitively
        with "r4" and "stu3".
      input_arg: the location of the FHIR resource(s), resolved with
        _location.parse_location.

    Returns:
      A (response, error) pair: (ValidationResponse, None) on success, or
      (None, error) where error is a ValueError for an unsupported version or
      the grpc.RpcError raised by the call.
    """
    request = wstlservice_pb2.ValidationRequest(
        input=_location.parse_location(shell, input_arg))
    fhir_versions = wstlservice_pb2.ValidationRequest.FhirVersion
    normalized = version.lower()
    if normalized == "r4":
        request.fhir_version = fhir_versions.R4
    elif normalized == "stu3":
        request.fhir_version = fhir_versions.STU3
    else:
        unsupported = ("FHIR version {} is incorrect or not supported,\n"
                       "{} are supported versions").format(
                           version, fhir_versions.keys())
        return None, ValueError(unsupported)
    try:
        reply = stub.FhirValidate(request)
    except grpc.RpcError as rpc_error:
        return None, rpc_error
    return reply, None