# persist_tensor_via_oss()
#
# in odps/mars_extension/legacy/oss.py [0:0]

def persist_tensor_via_oss(odps, *args, **kwargs):
    """Persist a Mars tensor to an ODPS table, staging the data on OSS.

    The tensor is written to OSS in COO (coordinate) form via Mars'
    ``write_coo``, an ODPS *external* table backed by that OSS location is
    created, and finally the data is copied from the external table into the
    target ODPS table.

    Positional ``args`` (unpacked below):
        tensor: a Mars tensor, or a dict mapping partition-spec strings
            (e.g. ``"pt=a,ds=b"``) to Mars tensors — one tensor per
            table partition.
        table_name: name of the destination ODPS table.
        dim_columns: list of column names for the tensor's dimensions.
        value_column: column name for the tensor's values.

    Keyword ``kwargs`` (all popped; the remainder is ignored):
        session: Mars session (defaults to ``Session.default_or_local()``).
        oss_endpoint, oss_access_id, oss_access_key, oss_bucket_name,
        oss_path: OSS connection info and the staging path. Required.

    Raises:
        KeyError: if any required OSS keyword argument is missing.
        TypeError: if partitioned tensors disagree on partition keys, or a
            tensor has more dimensions than ``dim_columns``.
    """
    from mars.session import Session
    from .tensor.datastore import write_coo

    session = kwargs.pop("session", Session.default_or_local())
    oss_endpoint = kwargs.pop("oss_endpoint")
    oss_access_id = kwargs.pop("oss_access_id")
    oss_access_key = kwargs.pop("oss_access_key")
    oss_bucket_name = kwargs.pop("oss_bucket_name")
    oss_path = kwargs.pop("oss_path")

    # Normalize oss_path to be bucket-relative: strip a leading
    # "oss://<bucket>/" prefix if the caller passed a full URI.
    oss_prefix = "oss://%s/" % oss_bucket_name
    if oss_path.startswith(oss_prefix):
        oss_path = oss_path[len(oss_prefix) :]

    oss_opts = dict(
        endpoint=oss_endpoint,
        bucket_name=oss_bucket_name,
        access_id=oss_access_id,
        secret_access_key=oss_access_key,
    )

    tensor, table_name, dim_columns, value_column = args
    # NOTE(review): oss_dir deliberately omits the bucket name — the bucket
    # travels separately inside oss_opts; confirm _write_shape_to_oss /
    # write_coo expect bucket-relative "oss://<path>" URIs.
    oss_dir = "oss://%s" % oss_path
    # Remove any leftover objects from a previous run at this staging path.
    _clean_oss_object(oss_path, **oss_opts)

    t_type = None  # dtype shared by all tensors; used for ODPS column typing
    partitions = None  # partition key names (only set in the dict case)

    # submit tensor to mars cluster
    tensors = []
    if isinstance(tensor, dict):
        # Partitioned case: one tensor per partition spec, e.g. "pt=a,ds=b".
        for p, t in tensor.items():
            if t_type is None:
                t_type = t.dtype
            p_spec = PartitionSpec(p)
            # All partitions must use the same partition key names.
            if partitions is None:
                partitions = p_spec.keys
            else:
                if set(partitions) != set(p_spec.keys):
                    raise TypeError("all tensors partitions name must be the same.")

            if t.ndim > len(dim_columns):
                raise TypeError("tensor dimensions cannot more than dim_columns length")

            # write shape to oss
            # "pt=a,ds=b" -> "pt=a/ds=b" so each partition gets its own
            # directory level under meta/ and data/.
            shape_path = "%s/meta/%s/shape" % (oss_dir, p.replace(",", "/"))
            _write_shape_to_oss(t.shape, shape_path, **oss_opts)

            # write data to oss
            data_path = "%s/data/%s" % (oss_dir, p.replace(",", "/"))
            writer_tensor = write_coo(
                t, data_path, dim_columns, value_column, global_index=True, **oss_opts
            )
            tensors.append(writer_tensor)

        # Execute all partition writes in one batch on the Mars cluster.
        # NOTE(review): if the dict is empty, t_type stays None and the
        # np_to_odps_types lookup below will fail — confirm callers never
        # pass an empty dict.
        session.run(tensors)
    else:
        # Unpartitioned case: a single tensor, flat meta/ and data/ layout.
        shape_path = oss_dir + "/meta/shape"
        _write_shape_to_oss(tensor.shape, shape_path, **oss_opts)

        t_type = tensor.dtype
        data_path = oss_dir + "/data"
        writer_tensor = write_coo(
            tensor, data_path, dim_columns, value_column, global_index=True, **oss_opts
        )
        session.run(writer_tensor)

    # persist to odps table
    # The external table mirrors the Parquet files written by write_coo:
    # with global_index=True each dimension yields a local and a "global_"
    # index column, hence 2 * len(dim_columns) bigint columns.
    ext_table_name = "mars_persist_ext_%s" % str(uuid.uuid4()).replace("-", "_")
    column_types = ["bigint"] * len(dim_columns) + [np_to_odps_types[t_type]]
    ext_column_types = ["bigint"] * (2 * len(dim_columns)) + [np_to_odps_types[t_type]]
    column_names = dim_columns + [value_column]
    ext_column_names = list(chain(*([c, "global_" + c] for c in dim_columns))) + [
        value_column
    ]
    if partitions:
        # Both target and external tables carry the same string-typed
        # partition columns.
        if isinstance(partitions, six.string_types):
            partitions = [partitions]
        target_schema = TableSchema.from_lists(
            column_names, column_types, partitions, ["string"] * len(partitions)
        )
        ext_schema = TableSchema.from_lists(
            ext_column_names, ext_column_types, partitions, ["string"] * len(partitions)
        )
    else:
        target_schema = TableSchema.from_lists(column_names, column_types)
        ext_schema = TableSchema.from_lists(ext_column_names, ext_column_types)

    # External table whose location points at the Parquet data staged above.
    # NOTE(review): the OSS access id/key are embedded in the location URI —
    # they may surface in table metadata/logs; confirm this is acceptable.
    ext_table = odps.create_table(
        ext_table_name,
        ext_schema,
        external_stored_as="PARQUET",
        location="oss://%s:%s@%s/%s/%s/data"
        % (
            oss_opts["access_id"],
            oss_opts["secret_access_key"],
            oss_opts["endpoint"].split("://")[1],
            oss_opts["bucket_name"],
            oss_path,
        ),
    )
    if partitions:
        # `tensor` is a dict here (partitions is only set in the dict
        # branch), so its keys are the partition specs to create.
        for partition in tensor.keys():
            ext_table.create_partition(partition)
    odps.create_table(table_name, target_schema, if_not_exists=True)
    # Select the "global_" index columns (renamed back to the plain dim
    # names) plus the value column, and copy them into the target table.
    ext_df = ext_table.to_df()
    fields = [
        ext_df["global_" + f].rename(f) for f in target_schema.names[:-1]
    ] + target_schema.names[-1:]
    if partitions:
        fields = fields + partitions
        ext_df[fields].persist(table_name, partitions=partitions)
    else:
        ext_df[fields].persist(table_name)