in torchdata/datapipes/iter/util/samplemultiplexer.py [0:0]
def __iter__(self) -> Iterator[T_co]:
    """Yield items drawn from the wrapped streams according to their weights.

    Each ``(pipe, weight)`` pair in ``self.pipes_and_weights`` becomes a live
    iterator. On every step one stream is chosen at random using
    ``self.random``, with probability proportional to its weight, and its next
    item is yielded. A chosen stream that turns out to be exhausted is dropped,
    and sampling continues over the remaining streams. Their relative weights
    are unchanged. Once a single stream remains, its remaining items are
    yielded in order without further sampling.

    Weights do not need to sum to 1. Only their ratios matter.

    Yields:
        Every item of every stream exactly once. Items from the same stream
        keep their original relative order.

    Raises:
        ValueError: if, while two or more streams remain, a weight is negative
            or the weights do not have a positive sum. In that case no stream
            could be selected reliably.
    """
    from bisect import bisect_right
    from itertools import accumulate

    # This list is private to this call, so dropping entries below never
    # touches ``self.pipes_and_weights``.
    streams = [(iter(pipe), weight) for pipe, weight in self.pipes_and_weights]
    while len(streams) > 1:
        weights = [weight for _, weight in streams]
        # Running sums, accumulated left to right. They only change when a
        # stream is dropped, so they are rebuilt here rather than on every draw.
        cumulative = list(accumulate(weights))
        total = cumulative[-1]
        # bisect needs a non-decreasing table. ``not total > 0`` also rejects NaN.
        if min(weights) < 0 or not total > 0:
            raise ValueError(f"Expecting non-negative weights with a positive sum, got {weights}")
        # Keep sampling from this set of streams until one of them runs dry.
        while True:
            # Scale the draw instead of renormalizing the weights, so no
            # rounding error accumulates as streams are dropped.
            # bisect_right returns the first index whose running sum exceeds
            # the draw. This is the same stream a linear "r < s" scan picks.
            index = bisect_right(cumulative, self.random.random() * total)
            if index == len(streams):
                # The scaled draw rounded up to ``total`` (only possible under
                # extreme rounding). Draw again, as a linear scan would.
                continue
            it = streams[index][0]
            try:
                item = next(it)
            except StopIteration:
                # Drop the exhausted stream by position. Comparing iterators
                # would depend on their __eq__ and could remove the wrong ones.
                del streams[index]
                break
            # The yield stays outside the ``try``. Otherwise a StopIteration
            # thrown into this generator would look like the stream running dry.
            yield item
    # At most one stream is left. Drain it directly; with no streams at all,
    # yield nothing.
    if streams:
        yield from streams[0][0]