# odps/ipython/magics.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from .. import ODPS, options
from .. import types as odps_types
from ..compat import StringIO, six
from ..df import DataFrame, Delay, NullScalar, Scalar
from ..df.backends.frame import ResultFrame
from ..df.backends.odpssql.types import odps_schema_to_df_schema, odps_type_to_df_type
from ..inter import enter, list_rooms, setup, teardown
from ..models import TableSchema
from ..ui.common import html_notify
from ..ui.progress import (
create_instance_group,
fetch_instance_group,
reload_instance_status,
)
from ..utils import (
init_progress_ui,
replace_sql_parameters,
split_backquoted,
strip_backquotes,
)
# Module-level logger used for SQL-instance progress messages below.
logger = logging.getLogger(__name__)
# Optional numpy-dtype -> ODPS-type mapping, built only when numpy is
# importable.  Integer dtypes map to bigint, float dtypes to double.
try:
    import numpy as np

    # np.float_ was removed in NumPy 2.0 (it was an alias of np.float64);
    # referencing it raised AttributeError, which the except clause below
    # does not catch.  Use the stable dtype names so the mapping is
    # identical on NumPy 1.x and works on NumPy 2.x as well.
    np_int_types = [np.dtype(t) for t in (np.int_, np.int8, np.int16, np.int32, np.int64)]
    np_float_types = [np.dtype(t) for t in (np.float16, np.float32, np.float64)]
    np_to_odps_types = dict(
        [(t, odps_types.bigint) for t in np_int_types]
        + [(t, odps_types.double) for t in np_float_types]
    )
except ImportError:
    # numpy is absent: the mapping simply does not exist.
    pass
# IPython is an optional dependency: when it is unavailable (e.g. bare CI
# environments) the magics class is exported as a None placeholder and the
# whole definition suite below is skipped.
try:
    from IPython.core.magic import Magics, line_cell_magic, line_magic, magics_class
except ImportError:
    # skipped for ci
    ODPSSql = None
    pass
else:
@magics_class
class ODPSSql(Magics):
    """IPython magics for ODPS: SQL execution, room management and persist."""

    # Lazily-initialized ODPS entry object; populated by _set_odps().
    _odps = None
def _set_odps(self):
    """Lazily initialize ``self._odps``.

    Prefers an explicitly configured account/project from ``options``;
    otherwise falls back to the default interactive room.
    """
    if self._odps is not None:
        return  # already initialized

    account = options.account
    project = options.default_project
    if account is None or project is None:
        # No explicit configuration: use the default room's entry object.
        self._odps = enter().odps
    else:
        self._odps = ODPS._from_account(
            account,
            project,
            endpoint=options.endpoint,
            tunnel_endpoint=options.tunnel.endpoint,
        )
@line_magic("enter")
def enter(self, line):
room = line.strip()
if room:
r = enter(room)
self._odps = r.odps
else:
r = enter()
self._odps = r.odps
if "o" not in self.shell.user_ns:
self.shell.user_ns["o"] = self._odps
self.shell.user_ns["odps"] = self._odps
return r
@line_magic("setup")
def setup(self, line):
args = line.strip().split()
name, args = args[0], args[1:]
setup(*args, room=name)
html_notify("Setup succeeded")
@line_magic("teardown")
def teardown(self, line):
name = line.strip()
teardown(name)
html_notify("Teardown succeeded")
@line_magic("list_rooms")
def list_rooms(self, line):
return list_rooms()
@line_magic("stores")
def list_stores(self, line):
line = line.strip()
if line:
room = enter(line)
else:
room = enter()
return room.display()
@staticmethod
def _get_task_percent(task_progress):
    """Return the mean completion ratio over a task's stages.

    Stages reporting zero total workers are skipped; a task with no
    stages at all yields 0.
    """
    stages = task_progress.stages
    if not stages:
        return 0
    done = 0.0
    for stage in stages:
        if stage.total_workers > 0:
            done += float(stage.terminated_workers) / stage.total_workers
    return done / len(stages)
def _to_stdout(self, msg):
    """Write *msg* to standard output.

    Fix: the first parameter was named ``cls`` although this is a plain
    instance method (no ``@classmethod`` decorator); renamed to ``self``.
    Binding behavior for callers (``self._to_stdout(msg)``) is unchanged.
    """
    print(msg)
@line_magic("set")
def set_hint(self, line):
if "=" not in line:
raise ValueError("Hint for sql is not allowed")
key, val = line.strip().strip(";").split("=", 1)
key, val = key.strip(), val.strip()
settings = options.sql.settings
if settings is None:
options.sql.settings = {key: val}
else:
options.sql.settings[key] = val
@line_cell_magic("sql")
def execute(self, line, cell=""):
self._set_odps()
content = line + "\n" + cell
content = content.strip()
sql = None
hints = dict()
splits = content.split(";")
for s in splits:
stripped = s.strip()
if stripped.lower().startswith("set "):
hint = stripped.split(" ", 1)[1]
k, v = hint.split("=", 1)
k, v = k.strip(), v.strip()
hints[k] = v
elif len(stripped) == 0:
continue
else:
if sql is None:
sql = s
else:
sql = "%s;%s" % (sql, s)
# replace user defined parameters
sql = replace_sql_parameters(sql, self.shell.user_ns)
if sql:
progress_ui = init_progress_ui()
group_id = create_instance_group("SQL Query")
progress_ui.add_keys(group_id)
instance = self._odps.run_sql(sql, hints=hints)
if logger.getEffectiveLevel() <= logging.INFO:
logger.info(
"Instance ID: %s\n Log view: %s",
instance.id,
instance.get_logview_address(),
)
reload_instance_status(self._odps, group_id, instance.id)
progress_ui.status("Executing")
percent = 0
while not instance.is_terminated(retry=True):
last_percent = percent
reload_instance_status(self._odps, group_id, instance.id)
inst_progress = fetch_instance_group(group_id).instances.get(
instance.id
)
if inst_progress is not None and len(inst_progress.tasks) > 0:
percent = sum(
self._get_task_percent(task)
for task in six.itervalues(inst_progress.tasks)
) / len(inst_progress.tasks)
else:
percent = 0
percent = min(1, max(percent, last_percent))
progress_ui.update(percent)
progress_ui.update_group()
time.sleep(1)
instance.wait_for_success()
progress_ui.update(1)
try:
with instance.open_reader() as reader:
try:
import pandas as pd
try:
from pandas.io.parsers import (
ParserError as CParserError,
)
except ImportError:
pass
try:
from pandas.parser import CParserError # noqa
except ImportError:
CParserError = ValueError # noqa
if not hasattr(reader, "raw"):
res = ResultFrame(
[rec.values for rec in reader],
schema=odps_schema_to_df_schema(reader._schema),
)
else:
try:
res = pd.read_csv(StringIO(reader.raw))
if len(res.values) > 0:
schema = DataFrame(res).schema
else:
cols = res.columns.tolist()
schema = odps_schema_to_df_schema(
TableSchema.from_lists(
cols, ["string" for _ in cols]
)
)
res = ResultFrame(res.values, schema=schema)
except (ValueError, CParserError):
res = reader.raw
except (ImportError, ValueError):
if not hasattr(reader, "raw"):
res = ResultFrame(
[rec.values for rec in reader],
schema=odps_schema_to_df_schema(reader._schema),
)
else:
try:
columns = [
odps_types.Column(
name=col.name,
typo=odps_type_to_df_type(col.type),
)
for col in reader._columns
]
res = ResultFrame(list(reader), columns=columns)
except TypeError:
res = reader.raw
html_notify("SQL execution succeeded")
return res
finally:
progress_ui.close()
@line_magic("persist")
def persist(self, line):
try:
import pandas as pd
has_pandas = True
except (ImportError, ValueError):
has_pandas = False
self._set_odps()
line = line.strip().strip(";")
frame_name, table_name = line.split(None, 1)
if "." in table_name:
parts = split_backquoted(table_name, ".")
if len(parts) == 3:
project_name, schema_name, table_name = parts
else:
project_name, table_name = parts
schema_name = None
table_name = strip_backquotes(table_name)
else:
project_name = schema_name = None
frame = self.shell.user_ns[frame_name]
if self._odps.exist_table(
table_name, project=project_name, schema=schema_name
):
raise TypeError("%s already exists" % table_name)
if isinstance(frame, DataFrame):
frame.persist(
name=table_name,
project=project_name,
schema=schema_name,
notify=False,
)
elif has_pandas and isinstance(frame, pd.DataFrame):
frame = DataFrame(frame)
frame.persist(
name=table_name,
project=project_name,
schema=schema_name,
notify=False,
)
html_notify("Persist succeeded")
def load_ipython_extension(ipython):
    """IPython entry point: register the magics and inject common names."""
    ipython.register_magics(ODPSSql)

    # Do global import when load extension
    exported = {
        "DataFrame": DataFrame,
        "Scalar": Scalar,
        "NullScalar": NullScalar,
        "options": options,
        "TableSchema": TableSchema,
        "Delay": Delay,
    }
    ipython.user_ns.update(exported)