Skip to content
Snippets Groups Projects
Select Git revision
  • 64b0b7947c4724c41913833e4170cabfa65fe797
  • main default protected
2 results

day_17_max_n.cpp

Blame
  • Max Nilsson's avatar
    Max Nilsson authored
    da2659dc
    History
    day_17_max_n.cpp 4.62 KiB
    #include <bits/stdc++.h>
    
    #ifdef LOCAL
    #include "./cpp-dump/cpp-dump.hpp"
    #define pr(x) cpp_dump(x)
    #else
    // Stand-in for cpp-dump when building without LOCAL, so that pr(...) calls
    // elsewhere in the file still compile: pr(x) prints "x => value" to stderr,
    // rendering vectors as [a, b, c].
    template <class T>
    void pr_print(std::ostream& os, const T& value) { os << value; }
    template <class T>
    void pr_print(std::ostream& os, const std::vector<T>& values) {
        os << '[';
        for (std::size_t i = 0; i < values.size(); ++i) {
            if (i) os << ", ";
            pr_print(os, values[i]);
        }
        os << ']';
    }
    #define pr(x) do { std::cerr << #x << " => "; pr_print(std::cerr, (x)); std::cerr << '\n'; } while (0)
    #endif
    
    using namespace std;
    
    // Competitive-programming shorthands.
    #define Mod(x,y) (((x)%(y)+(y))%(y))   // non-negative remainder; evaluates x and y twice
    #define rep(i, a, b) for(ll (i) = (a); (i) < (b); ++(i))
    #define all(x) begin(x), end(x)
    #define pb push_back
    #define gcd __gcd
    #define sz(x) (ll)((x).size())          // argument parenthesized so sz(c ? a : b) expands correctly
    
    typedef long long ll;
    typedef unsigned long long ull;
    typedef pair<ll, ll> pii;
    typedef vector<ll> vi;
    typedef vector<pii> vii;
    
    // When true, same() traces each executed instruction and the registers via pr().
    const bool debug = false;
    
    // The simulated machine's three registers; reg[0] is the one same() seeds with A.
    vector<ll> reg = {0, 0, 0};
    // output: values emitted by the simulated program.
    // A_bin:  candidate bits of A; A_bin[i] is bit i, and -1 marks a bit not yet fixed.
    vector<ll> output, A_bin;
    
    // Index of the opcode the simulated program will execute next.
    ll pointer = 0;
    
    // Decode a combo operand: 0..3 stand for themselves, 4/5/6 select
    // reg[0]/reg[1]/reg[2]. Any other value is rejected by the assert.
    ll get_operand(ll x) {
        switch (x) {
            case 0: case 1: case 2: case 3:
                return x;
            case 4: case 5: case 6:
                return reg[x - 4];
            default:
                assert(false);
                return 0;
        }
    }
    
    // reg[0] / 2^shift (truncating), shared by opcodes 0, 6 and 7.
    // A shift of 63 or more leaves nothing of a non-negative 64-bit value, so 0
    // is returned without forming 1LL << shift (which would overflow); a
    // non-positive shift divides by 1. NOTE(review): assumes reg[0] >= 0 —
    // reg[0] is only ever rewritten as reg[0] / 2^k, so this holds whenever the
    // starting A is non-negative.
    static ll divide_a_by_pow2(ll shift) {
        const ll num = reg[0];
        if (shift <= 0) return num;
        if (shift >= 63) return 0;
        return num / (1LL << shift);
    }
    
    // Execute one instruction (opcode, operand x) against reg, pointer and output.
    // Returns the value appended to output for opcode 5, LLONG_MAX otherwise.
    // The pointer advances by 2 unless opcode 3 jumps (reg[0] != 0).
    //
    // The combo operand is decoded only by the opcodes that use it (0, 2, 5, 6, 7).
    // Opcodes 1 and 3 use x as a literal and opcode 4 ignores it, so x == 7 there
    // must not reach get_operand, whose assert rejects 7.
    ll instruction(ll opcode, ll x) {
        ll res = LLONG_MAX;
        bool jumped = false;
        switch (opcode) {
            case 0:  // reg[0] = reg[0] / 2^combo
                reg[0] = divide_a_by_pow2(get_operand(x));
                break;
            case 1:  // reg[1] ^= literal operand
                reg[1] ^= x;
                break;
            case 2: {  // reg[1] = combo mod 8
                const ll combo = get_operand(x);
                reg[1] = Mod(combo, 8);
                break;
            }
            case 3:  // jump to the literal operand unless reg[0] == 0
                if (reg[0] != 0) {
                    pointer = x;
                    jumped = true;
                }
                break;
            case 4:  // reg[1] ^= reg[2]; operand ignored
                reg[1] ^= reg[2];
                break;
            case 5: {  // emit combo mod 8
                const ll combo = get_operand(x);
                res = Mod(combo, 8);
                output.pb(res);
                break;
            }
            case 6:  // reg[1] = reg[0] / 2^combo
                reg[1] = divide_a_by_pow2(get_operand(x));
                break;
            case 7:  // reg[2] = reg[0] / 2^combo
                reg[2] = divide_a_by_pow2(get_operand(x));
                break;
            default:  // unknown opcodes have no effect besides advancing
                break;
        }
        if (!jumped) pointer += 2;
        return res;
    }
    
    bool same(vector<ll> instructions, ll A) {
        reg[0] = reg[1] = reg[2] = 0;
        reg[0] = A;
        output.clear();
        pointer = 0;
        int ind = 0;
        while(pointer < instructions.size()) {
            if (debug) {
                pr(instructions[pointer]);
                pr(instructions[pointer+1]);
            }
            ll res = instruction(instructions[pointer], instructions[pointer+1]);
            if (debug) pr(reg);
    
            if (instructions.size() < output.size()) return false;
            if (res < LLONG_MAX) {
                if (instructions[ind] != res) return false;
                ind++;
            }
        }
    
        if (instructions.size() != output.size()) return false;
        rep(i, 0, sz(instructions)) if (instructions[i] != output[i]) return false;
    
        return true;
    }
    
    // Enumerate the 3-bit values 4*a3 + 2*a2 + a1 compatible with the given bits
    // (a_3 most significant). A bit equal to -1 is free and takes both 0 and 1;
    // any other value is used as given. With 0/1/-1 inputs the results come out
    // in ascending order.
    vi generate(int a_3, int a_2, int a_1) {
        auto options = [](int bit) {
            return bit == -1 ? vector<int>{0, 1} : vector<int>{bit};
        };
        const vector<int> highs = options(a_3);
        const vector<int> mids = options(a_2);
        const vector<int> lows = options(a_1);
    
        vi values;
        for (int hi : highs)
            for (int mid : mids)
                for (int lo : lows)
                    values.pb(4 * hi + 2 * mid + lo);
        return values;
    }
    
    // Split the low three bits of x into {bit2, bit1, bit0}, most significant first.
    vi to_bin(int x) {
        vi bits;
        for (int shift = 2; shift >= 0; --shift) bits.pb((x >> shift) & 1);
        return bits;
    }
    
    bool valid(int ptr, int x) {
        rep(i, 0, 3) if (A_bin[ptr-i] != -1 && A_bin[ptr-i] != to_bin(x)[i]) return false;
        return true;
    }
    
    // Assemble the integer encoded by A_bin, where A_bin[i] holds bit i and any
    // entry other than 1 (a 0 or the "unknown" marker -1) counts as 0.
    // Leading unset bits contribute nothing, so no scan for the highest 1-bit is
    // needed; in particular an A_bin with no bit set yields 0 instead of walking
    // off the front of the vector. Uses A_bin's actual size rather than assuming
    // 128 entries. NOTE(review): assumes no 1-bit above index 62 (else overflow).
    ll get_ans() {
        ll ans = 0;
        for (ll i = sz(A_bin) - 1; i >= 0; --i) {
            ans = 2 * ans + (A_bin[i] == 1 ? 1 : 0);
        }
        return ans;
    }
    
    vector<ll> instructions;
    bool pos(int instruction_ind) {
        if (instruction_ind == sz(instructions)) {
            return same(instructions, get_ans());
        }
        bool res = false;
        int ptr = 3*instruction_ind+2;
        int a_3 = A_bin[ptr], a_2 = A_bin[ptr-1], a_1 = A_bin[ptr-2];
        int v0 = instructions[instruction_ind];
        vi minus_ones_pre;
        rep(i, 0, 3) if (A_bin[ptr-i] == -1) minus_ones_pre.pb(i);
    
        for (int a : generate(a_3, a_2, a_1)) {
            A_bin[ptr] = to_bin(a)[0];
            A_bin[ptr-1] = to_bin(a)[1];
            A_bin[ptr-2] = to_bin(a)[2];
    
            int X = v0^(a^6);
    
            int new_ptr = ptr+(a^3);
    
            if (valid(new_ptr, X)) {
                vi minus_ones;
                rep(i, 0, 3) if (A_bin[new_ptr-i] == -1) minus_ones.pb(i);
                rep(i, 0, 3) if (A_bin[new_ptr-i] = to_bin(X)[i]);
    
                if (pos(instruction_ind+1)) return true;
                for (int i : minus_ones) A_bin[new_ptr-i] = -1;
            }
        }
        for (int i : minus_ones_pre) A_bin[ptr-i] = -1;
        return false;
    }
    
    void solve() {
        cin >> reg[0] >> reg[1] >> reg[2];
        ll x;
        while(cin >> x) instructions.pb(x);
    
        rep(i, 0, 128) A_bin.pb(-1);
        bool ok = pos(0);
        ll ans = get_ans();
        pr(ans);
    }
    
    
    int main() {
        // Untie and desynchronize the C++ streams for faster bulk I/O, and use
        // fixed 15-digit formatting for any floating-point output.
        std::ios::sync_with_stdio(false);
        std::cin.tie(nullptr);
        std::cout << std::setprecision(15) << std::fixed;
    
    #ifdef LOCAL
        // Local runs read the input from input.txt instead of the terminal.
        freopen("input.txt", "r", stdin);
    #endif
    
        solve();
    }