summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--2024/17/python/main.py123
1 files changed, 56 insertions, 67 deletions
diff --git a/2024/17/python/main.py b/2024/17/python/main.py
index 7e8ecb1..a441d40 100644
--- a/2024/17/python/main.py
+++ b/2024/17/python/main.py
@@ -1,77 +1,66 @@
from fileinput import input
-from itertools import takewhile
+from itertools import takewhile, zip_longest
inp = map(str.strip, input())
-regs = {r[1][0]: int(r[2]) for r in (line.split() for line in takewhile(bool, inp))}
+regs = [int(r[2]) for r in (line.split() for line in takewhile(bool, inp))]
prog = [int(o) for o in next(inp).split()[1].split(",")]
out = []
-ir = 0
-combo = {
- 0: lambda: 0,
- 1: lambda: 1,
- 2: lambda: 2,
- 3: lambda: 3,
- 4: lambda: regs["A"],
- 5: lambda: regs["B"],
- 6: lambda: regs["C"],
- 7: None,
-}
-
-def _adv(op):
- regs["A"] = regs["A"] >> combo[op]()
-
-
-def _bxl(op):
- regs["B"] ^= op
-
-
-def _bst(op):
- regs["B"] = combo[op]() & 0b111
-
-
-def _jnz(op):
- global ir
- if regs["A"]:
- ir = op
- ir -= 2 # Undo IR
-
-
-def _bxc(_):
- regs["B"] ^= regs["C"]
-
-
-def _out(op):
- out.append(combo[op]() & 0b111)
-
-
-def _bdv(op):
- regs["B"] = regs["A"] >> combo[op]()
-
-
-def _cdv(op):
- regs["C"] = regs["A"] >> combo[op]()
-
-
-instrs = {
- 0: _adv,
- 1: _bxl,
- 2: _bst,
- 3: _jnz,
- 4: _bxc,
- 5: _out,
- 6: _bdv,
- 7: _cdv,
-}
-
-while ir <= len(prog) - 1:
- inst, op = prog[ir], prog[ir + 1]
- instrs[inst](op)
- ir += 2
-
-silver = ",".join(map(str, out))
-gold = 0
+def run(prog, ra, rb, rc):
+ out = []
+ ir = 0
+ while ir < len(prog):
+ instr, op = prog[ir], prog[ir + 1]
+ combo = {4: ra, 5: rb, 6: rc}
+ cop = combo.get(op, op)
+ match instr:
+ case 0:
+ ra = ra >> cop
+ case 1:
+ rb ^= op
+ case 2:
+ rb = cop & 0b111
+ case 3 if ra:
+ ir = op - 2
+ case 4:
+ rb ^= rc
+ case 5:
+ yield cop & 0b111
+ case 6:
+ rb = ra >> cop
+ case 7:
+ rc = ra >> cop
+ ir += 2
+ return out
+
+
+def match(prog, ra):
+ return (
+ got == want
+ for got, want in zip_longest(
+ reversed(list(run(prog, ra, 0, 0))),
+ reversed(prog),
+ )
+ )
+
+
+def find_a(prog):
+ q = list(range(8))
+ while True:
+ curr = q.pop(0)
+ if all(match(prog, curr)):
+ return curr
+
+ best = sum(match(prog, curr))
+ for n in range(8):
+ ra = (curr << 3) + n
+ if sum(match(prog, ra)) > best:
+ q.append(ra)
+
+
+silver = ",".join(map(str, run(prog, *regs)))
+gold = find_a(prog)
print("silver:", silver)
print("gold:", gold)