in econml/utilities.py [0:0]
def einsum_sparse(subscripts, *arrs):
    """
    Evaluate the Einstein summation convention on the operands.

    Using the Einstein summation convention, many common multi-dimensional array operations can be represented
    in a simple fashion.  This function provides a way to compute such summations.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
        Unlike `np.einsum`, ellipses are not supported and the output must be explicitly included
        (i.e. the string must have the form ``"<in1>,<in2>,...-><out>"``).

    arrs : list of COO arrays
        These are the sparse arrays for the operation.

    Returns
    -------
    SparseArray
        The sparse array calculated based on the Einstein summation convention.

    Raises
    ------
    ValueError
        If `subscripts` does not contain exactly one ``->`` separator.
    AssertionError
        If the number of subscript groups differs from the number of arrays, a group's length differs
        from the rank of its array, an output index is repeated or does not appear in any input,
        or an index is used for axes of different sizes.  These checks are ``assert`` statements,
        so they are skipped when Python runs with ``-O``.
    """
    inputs, outputs = subscripts.split('->')
    inputs = inputs.split(',')
    outputInds = set(outputs)
    allInds = set.union(*[set(i) for i in inputs])

    # same number of input definitions as arrays
    assert len(inputs) == len(arrs), "number of subscript groups must match the number of arrays"

    # input definitions have same number of dimensions as each array
    assert all(arr.ndim == len(spec) for (arr, spec) in zip(arrs, inputs)), \
        "each subscript group must have one index per array dimension"

    # all result indices are unique
    assert len(outputInds) == len(outputs), "output indices must be unique"

    # all result indices must match at least one input index
    assert outputInds <= allInds, "every output index must appear in some input"

    # map indices to all array, axis pairs for that index
    indMap = {c: [(n, i) for n in range(len(inputs)) for (i, x) in enumerate(inputs[n]) if x == c] for c in allInds}

    for c in indMap:
        # each index has the same cardinality wherever it appears
        assert len({arrs[n].shape[i] for (n, i) in indMap[c]}) == 1, \
            "an index must have the same size everywhere it appears"

    # State: list of (set of letters, list of (corresponding indices, value))
    # Algo: while list contains more than one entry
    #         take two entries
    #         sort both lists by intersection of their indices
    #         merge compatible entries (where intersection of indices is equal - in the resulting list,
    #         take the union of indices and the product of values), stepping through each list linearly

    # TODO: might be faster to break into connected components first
    #       e.g. for "ab,d,bc->ad", the two components "ab,bc" and "d" are independent,
    #       so compute their content separately, then take cartesian product
    #       this would save a few pointless sorts by empty tuples

    # TODO: Consider investigating other performance ideas for these cases
    #       where the dense method beat the sparse method (usually sparse is faster)
    #       e,facd,c->cfed
    #        sparse: 0.0335489
    #        dense:  0.011465999999999997
    #       gbd,da,egb->da
    #        sparse: 0.0791625
    #        dense:  0.007319099999999995
    #       dcc,d,faedb,c->abe
    #        sparse: 1.2868097
    #        dense:  0.44605229999999985

    def merge(x1, x2):
        # Join two (letters, entries) states on the indices they share: a sort-merge join in which
        # matching entries contribute the union of their coordinates and the product of their values.
        (s1, l1), (s2, l2) = x1, x2
        # intersection / union of the index letters; sorted so that the join order (and hence the
        # floating-point accumulation order later on) does not depend on set iteration order,
        # which varies between interpreter runs because of string hash randomization
        keys = sorted(set(s1) & set(s2))
        outS = ''.join(sorted(set(s1 + s2)))
        # for each output letter: (take it from the first operand?, position within that operand)
        outMap = [(True, s1.index(c)) if c in s1 else (False, s2.index(c)) for c in outS]

        def keyGetter(s):
            inds = [s.index(c) for c in keys]
            return lambda p: tuple(p[0][ind] for ind in inds)

        kg1 = keyGetter(s1)
        kg2 = keyGetter(s2)
        l1.sort(key=kg1)
        l2.sort(key=kg2)
        n1, n2 = len(l1), len(l2)
        i1 = i2 = 0
        outL = []
        while i1 < n1 and i2 < n2:
            k1, k2 = kg1(l1[i1]), kg2(l2[i2])
            if k1 < k2:
                i1 += 1
            elif k2 < k1:
                i2 += 1
            else:
                # find the full run of entries sharing this key in each list ...
                j1, j2 = i1, i2
                while j1 < n1 and kg1(l1[j1]) == k1:
                    j1 += 1
                while j2 < n2 and kg2(l2[j2]) == k2:
                    j2 += 1
                # ... and emit the cartesian product of the two runs
                for c1, d1 in l1[i1:j1]:
                    for c2, d2 in l2[i2:j2]:
                        outL.append((tuple(c1[charIdx] if inFirst else c2[charIdx] for inFirst, charIdx in outMap),
                                     d1 * d2))
                i1 = j1
                i2 = j2
        return outS, outL

    # when indices are repeated within an array, pre-filter the coordinates and data
    def filter_inds(coords, data, n):
        spec = inputs[n]
        counts = Counter(spec)
        repeated = [c for c in counts if counts[c] > 1]
        if repeated:
            mask = np.full(len(data), True)
            for k in repeated:
                inds = [i for i, x in enumerate(spec) if x == k]
                # keep only entries whose coordinates agree on every axis labelled with k
                for other in inds[1:]:
                    mask &= (coords[:, inds[0]] == coords[:, other])
            # vectorized check; the builtin all() would walk the array element by element
            if not mask.all():
                return coords[mask, :], data[mask]
        return coords, data

    xs = []
    for n, (s, arr) in enumerate(zip(inputs, arrs)):
        coords, data = filter_inds(arr.coords.T, arr.data, n)
        xs.append((s, list(zip(coords, data))))

    # TODO: would using einsum's paths to optimize the order of merging help?
    while len(xs) > 1:
        xs.append(merge(xs.pop(), xs.pop()))

    # sum out every index that is not part of the output
    results = defaultdict(int)
    for (s, entries) in xs:
        coordMap = [s.index(c) for c in outputs]
        for (c, d) in entries:
            results[tuple(c[i] for i in coordMap)] += d

    return sp.COO(np.array(list(results.keys())).T if results else
                  np.empty((len(outputs), 0)),
                  np.array(list(results.values())),
                  [arrs[indMap[c][0][0]].shape[indMap[c][0][1]] for c in outputs])