# Combinations of lists (or maybe how itertools.product works)!?

So here's a little challenge: I need a generator for testing that produces all combinations of inputs from a number of lists. For example:

```python
a, b, c = [1, 2], [1, 2, 3], [1, 2]
```

So how do we go about that?

Python's `itertools.combinations` can't help us here: it picks `r`-length combinations from a single iterable, so it won't take several lists as input.
We need `itertools.product` instead:

```python
import itertools

for combination in itertools.product(*[a, b, c]):
    print(combination)
```

```
(1, 1, 1)
(1, 1, 2)
(1, 2, 1)
(1, 2, 2)
(1, 3, 1)
(1, 3, 2)
(2, 1, 1)
(2, 1, 2)
(2, 2, 1)
(2, 2, 2)
(2, 3, 1)
(2, 3, 2)
```
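To make the difference concrete: `itertools.combinations(iterable, r)` draws `r` items from one iterable, whereas `itertools.product` takes one item from each iterable it is given, so `itertools.product(a, b, c)` is the same call as `itertools.product(*[a, b, c])`. A quick comparison, reusing the lists from above:

```python
import itertools

a, b, c = [1, 2], [1, 2, 3], [1, 2]

# combinations draws r items from ONE iterable (here: pairs taken from b)
print(list(itertools.combinations(b, 2)))  # [(1, 2), (1, 3), (2, 3)]

# product takes one item from EACH iterable; unpacking a list of lists is equivalent
assert list(itertools.product(a, b, c)) == list(itertools.product(*[a, b, c]))

# the number of results is the product of the list lengths: 2 * 3 * 2
print(len(list(itertools.product(a, b, c))))  # 12
```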


So how does it work?

The simplest implementation I can come up with is to keep a list of keys (one index per list) and increment them step by step, like an odometer.

When a key reaches its maximum index, we reset it to zero and carry the increment over to the next key.

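```python
def product(scales):
    # one index ("key") per list, all starting at the first element
    keys = [0 for _ in scales]

    # total number of combinations = product of the list lengths
    counter = 1
    for sub_scale in scales:
        counter *= len(sub_scale)

    for _ in range(counter):
        # pick the current element from every list
        v = [sub_scale[ix] for ix, sub_scale in zip(keys, scales)]
        yield v

        # advance the keys like an odometer, first key fastest
        for pointer, sub_scale in enumerate(scales):
            if keys[pointer] + 1 == len(sub_scale):
                keys[pointer] = 0  # wrap around and carry into the next key
            else:
                keys[pointer] += 1
                break
```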

```python
for combination in product([a, b, c]):
    print(combination)
```

```
[1, 1, 1]
[2, 1, 1]
[1, 2, 1]
[2, 2, 1]
[1, 3, 1]
[2, 3, 1]
[1, 1, 2]
[2, 1, 2]
[1, 2, 2]
[2, 2, 2]
[1, 3, 2]
[2, 3, 2]
```
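This generator advances the first key fastest, so it yields the same 12 combinations as `itertools.product` but in a different order (and as lists rather than tuples). The `itertools` documentation describes `product` as cycling like an odometer with the rightmost element advancing on every iteration. A minimal variant, assuming you want that ordering, simply walks the keys from right to left; the name `product_rightmost_first` is hypothetical:

```python
import itertools


def product_rightmost_first(scales):
    """Odometer-style product that advances the LAST key fastest,
    i.e. the same ordering itertools.product uses."""
    keys = [0] * len(scales)
    total = 1
    for sub_scale in scales:
        total *= len(sub_scale)

    for _ in range(total):
        yield tuple(sub_scale[ix] for ix, sub_scale in zip(keys, scales))

        # walk the keys from right to left, carrying like an odometer
        for pointer in reversed(range(len(scales))):
            if keys[pointer] + 1 == len(scales[pointer]):
                keys[pointer] = 0  # wrap around and carry into the key to the left
            else:
                keys[pointer] += 1
                break


a, b, c = [1, 2], [1, 2, 3], [1, 2]
assert list(product_rightmost_first([a, b, c])) == list(itertools.product(a, b, c))
```

Like `itertools.product`, it yields a single empty tuple for no input lists and nothing at all if any list is empty.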


With this approach I can also pick out the nth combination directly, simply by recalculating the indices from n instead of stepping through every combination before it.

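```python
def nth_combination(n, scales):
    """Return the nth combination (n is 1-based), in the order product() yields them."""
    counter = 1
    for sub_scale in scales:
        counter *= len(sub_scale)
    if not 0 < n <= counter:
        raise ValueError(f"n={n} is outside the range 1..{counter}")

    position = n - 1  # 0-based position of the combination
    values = []
    multiplier = 1
    for sub_scale in scales:
        # index into this list: its "digit" in a mixed-radix number
        ix = (position % (len(sub_scale) * multiplier)) // multiplier
        multiplier *= len(sub_scale)
        values.append(sub_scale[ix])

    return tuple(values)
```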
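The index arithmetic here is mixed-radix counting: each list is a digit whose base is that list's length, with the first list as the least significant digit (the one that changes fastest). A small sketch with `divmod` makes that explicit and adds the inverse mapping. `position_to_indices` and `indices_to_position` are hypothetical helper names, positions are 0-based, and the input is assumed to be in range:

```python
def position_to_indices(m, sizes):
    """Split a 0-based position m into one index per list.
    Assumes 0 <= m < product(sizes); the first list is the least significant digit."""
    indices = []
    for size in sizes:
        m, ix = divmod(m, size)  # remainder = this list's index, quotient carries on
        indices.append(ix)
    return indices


def indices_to_position(indices, sizes):
    """Inverse mapping: rebuild the 0-based position from the per-list indices."""
    m = 0
    multiplier = 1
    for ix, size in zip(indices, sizes):
        m += ix * multiplier
        multiplier *= size
    return m


sizes = [2, 3, 2, 4]  # lengths of the four lists used in the check below
print(position_to_indices(13, sizes))            # [1, 0, 0, 1]
print(indices_to_position([1, 0, 0, 1], sizes))  # 13
```

With the lists `[1, 2], [1, 2, 3], [4, 5], [6, 7, 8, 9]`, the indices `[1, 0, 0, 1]` select the combination `(2, 1, 4, 7)`.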

```python
a, b, c, d = [1, 2], [1, 2, 3], [4, 5], [6, 7, 8, 9]

expected_result = list(itertools.product(*[a, b, c, d]))

all_nth_combinations = [nth_combination(n, [a, b, c, d]) for n in range(1, (2*3*2*4) + 1)]
all_nth_combinations.sort()

for expected, actual in zip(expected_result, all_nth_combinations):
    sign = "==" if expected == actual else "!="
    print(expected, sign, actual)
```

(1, 1, 4, 6) == (1, 1, 4, 6)
(1, 1, 4, 7) == (1, 1, 4, 7)
(1, 1, 4, 8) == (1, 1, 4, 8)
(1, 1, 4, 9) == (1, 1, 4, 9)
(1, 1, 5, 6) == (1, 1, 5, 6)
(1, 1, 5, 7) == (1, 1, 5, 7)
(1, 1, 5, 8) == (1, 1, 5, 8)
(1, 1, 5, 9) == (1, 1, 5, 9)
(1, 2, 4, 6) == (1, 2, 4, 6)
(1, 2, 4, 7) == (1, 2, 4, 7)
(1, 2, 4, 8) == (1, 2, 4, 8)
(1, 2, 4, 9) == (1, 2, 4, 9)
(1, 2, 5, 6) == (1, 2, 5, 6)
(1, 2, 5, 7) == (1, 2, 5, 7)
(1, 2, 5, 8) == (1, 2, 5, 8)
(1, 2, 5, 9) == (1, 2, 5, 9)
(1, 3, 4, 6) == (1, 3, 4, 6)
(1, 3, 4, 7) == (1, 3, 4, 7)
(1, 3, 4, 8) == (1, 3, 4, 8)
(1, 3, 4, 9) == (1, 3, 4, 9)
(1, 3, 5, 6) == (1, 3, 5, 6)
(1, 3, 5, 7) == (1, 3, 5, 7)
(1, 3, 5, 8) == (1, 3, 5, 8)
(1, 3, 5, 9) == (1, 3, 5, 9)
(2, 1, 4, 6) == (2, 1, 4, 6)
(2, 1, 4, 7) == (2, 1, 4, 7)
(2, 1, 4, 8) == (2, 1, 4, 8)
(2, 1, 4, 9) == (2, 1, 4, 9)
(2, 1, 5, 6) == (2, 1, 5, 6)
(2, 1, 5, 7) == (2, 1, 5, 7)
(2, 1, 5, 8) == (2, 1, 5, 8)
(2, 1, 5, 9) == (2, 1, 5, 9)
(2, 2, 4, 6) == (2, 2, 4, 6)
(2, 2, 4, 7) == (2, 2, 4, 7)
(2, 2, 4, 8) =