
#include <algorithm>
#include <array>
#include <charconv>
#include <chrono>
#include <climits> // CHAR_BIT, used to size the tokenizer's character LUT
#include <cstdint>
#include <fstream>
#include <memory>
#include <optional>
#include <ranges>
#include <string>
#include <string_view>
#include <variant>
#include <vector>
// I still hate iostream.
#include <iostream>

#include <llvm/ADT/BitVector.h>
#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/DenseSet.h>
#include <llvm/ADT/PostOrderIterator.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringMap.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/PatternMatch.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Support/Allocator.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Transforms/Utils/BasicBlockUtils.h>
#include <llvm/Transforms/Utils/Local.h>

#define UNLIKELY(x) (__builtin_expect((x), 0))

namespace {

struct SrcLoc {
  std::size_t start, end;

  SrcLoc operator|(const SrcLoc &other) const {
    return SrcLoc{std::min(start, other.start), std::max(end, other.end)};
  }
};

struct Error {
  SrcLoc loc;
  std::string desc;
};

struct Token {
  enum Kind {
    // clang-format off
    T_EOF, ERROR,

    // Operators
    COMMA, SEMICOLON, AT, LPAREN, RPAREN, LCURLY, RCURLY, LBRACKET, RBRACKET,
    LANGLE, RANGLE, EQ, EQEQ, LE, GE, SL, SR, INV, XOR, PLUS, PLUSEQ, MINUS,
    MINUSEQ, TIMES, TIMESEQ, DIV, DIVEQ, REM, REMEQ, NOT, NOTEQ, AND, ANDAND,
    OR, OROR,

    // Generic things..
    NUMBER, IDENT,

    // Keywords
    AUTO, REGISTER, IF, ELSE, WHILE, RETURN,

    MAX
    // clang-format on
  };

  Kind kind;
  SrcLoc loc;
};

class Tokenizer {
  const std::string_view prog;
  std::size_t pos;
  std::vector<std::size_t> lineBreaks;

public:
  Tokenizer(std::string_view prog) : prog(prog), pos(0) {}

private:
  enum { W = 1, L, I, N, S, Q, C };
  // Lookup table to classify characters. 0=invalid, W=whitespace, L=line break,
  // I=identifier start, N=number, S=special symbol, Q=quote, C=maybe comment
  static constexpr uint8_t CHAR_LUT[1 << CHAR_BIT] = {
      // clang-format off
      0,0,0,0,0,0,0,0,0,W,L,0,0,W,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
      W,S,Q,S,S,S,S,Q,S,S,S,S,S,S,S,C,N,N,N,N,N,N,N,N,N,N,S,S,S,S,S,S,
      S,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,S,S,S,S,I,
      S,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,S,S,S,S,0,
      // clang-format on
  };

  Token tokenizeIdentifier() {
    std::size_t start = pos;
    for (; pos < prog.size(); pos++) {
      auto lv = CHAR_LUT[(unsigned char)prog[pos]];
      if (lv != I && lv != N)
        break;
    }
    std::string_view val = prog.substr(start, pos - start);
    auto kind = Token::IDENT;
    // There are only a few keywords => the inefficient approach is good enough.
    if (val == "auto")
      kind = Token::AUTO;
    else if (val == "register")
      kind = Token::REGISTER;
    else if (val == "if")
      kind = Token::IF;
    else if (val == "else")
      kind = Token::ELSE;
    else if (val == "while")
      kind = Token::WHILE;
    else if (val == "return")
      kind = Token::RETURN;
    return Token{kind, SrcLoc{start, pos}};
  }

  Token tokenizeNumber() {
    std::size_t start = pos;
    while (pos < prog.size() && CHAR_LUT[(unsigned char)prog[pos]] == N)
      pos++;
    return Token{Token::NUMBER, SrcLoc{start, pos}};
  }

  Token tokenizeOperator() {
    auto kind = Token::ERROR;
    auto next = pos + 1 < prog.size() ? prog[pos + 1] : 0;
    std::size_t start = pos;
    std::size_t len = 1;
    switch (prog[pos]) {
    case ',': kind = Token::COMMA; break;
    case ';': kind = Token::SEMICOLON; break;
    case '@': kind = Token::AT; break;
    case '(': kind = Token::LPAREN; break;
    case ')': kind = Token::RPAREN; break;
    case '[': kind = Token::LBRACKET; break;
    case ']': kind = Token::RBRACKET; break;
    case '{': kind = Token::LCURLY; break;
    case '}': kind = Token::RCURLY; break;
    case '<':
      kind = next == '='   ? (len++, Token::LE)
             : next == '<' ? (len++, Token::SL)
                           : Token::LANGLE;
      break;
    case '>':
      kind = next == '='   ? (len++, Token::GE)
             : next == '>' ? (len++, Token::SR)
                           : Token::RANGLE;
      break;
    case '=': kind = next == '=' ? (len++, Token::EQEQ) : Token::EQ; break;
    case '!': kind = next == '=' ? (len++, Token::NOTEQ) : Token::NOT; break;
    case '+': kind = next == '=' ? (len++, Token::PLUSEQ) : Token::PLUS; break;
    case '-':
      kind = next == '=' ? (len++, Token::MINUSEQ) : Token::MINUS;
      break;
    case '*':
      kind = next == '=' ? (len++, Token::TIMESEQ) : Token::TIMES;
      break;
    case '/': kind = next == '=' ? (len++, Token::DIVEQ) : Token::DIV; break;
    case '%': kind = next == '=' ? (len++, Token::REMEQ) : Token::REM; break;
    case '~': kind = Token::INV; break;
    case '^': kind = Token::XOR; break;
    case '&': kind = next == '&' ? (len++, Token::ANDAND) : Token::AND; break;
    case '|': kind = next == '|' ? (len++, Token::OROR) : Token::OR; break;
    }
    pos += len;
    return Token{kind, {start, pos}};
  }

public:
  Token next() {
    while (pos < prog.size()) {
      unsigned char c = prog[pos];
      switch (CHAR_LUT[c]) {
      case W: pos++; continue;
      case L: lineBreaks.push_back(pos++); continue;
      case I: return tokenizeIdentifier();
      case N: return tokenizeNumber();
      case S: return tokenizeOperator();
      case C: // slash, might be a comment
        if (pos + 1 < prog.size() && prog[pos + 1] == '/')
          for (pos += 2; pos < prog.size() && prog[pos] != '\n'; pos++) {
          }
        else
          return tokenizeOperator();
        continue;
      default:
        return Token{Token::ERROR, {pos, pos + 1}};
        // Note: quotation marks ' and " (Q) are caught here, too.
      }
    }
    return Token{Token::T_EOF, {pos, pos}};
  }

  /// Find line and column of a character offset. Runtime O(log n).
  std::pair<std::size_t, std::size_t> locate(std::size_t pos) const {
    auto lineIt = std::lower_bound(lineBreaks.begin(), lineBreaks.end(), pos);
    std::size_t line = (lineIt - lineBreaks.begin());
    return std::make_pair(line, line ? pos - *(lineIt - 1) - 1 : pos);
  }

  /// Extract a line from the input. Runtime O(1) if tokenized, otw. O(n).
  std::string_view line(std::size_t line) const {
    std::size_t start = line ? lineBreaks[line - 1] + 1 : 0;
    if (line < lineBreaks.size())
      return prog.substr(start, lineBreaks[line] - start);
    // We not necessarily have tokenized the rest of the string.
    std::string_view lineStr = prog.substr(start);
    return lineStr.substr(0, lineStr.find('\n')); // no \n => whole line
  }
};

struct IdentDesc {
  uint64_t value;

  IdentDesc(uint64_t iid, bool addressable = false)
      : value(iid << 1 | addressable) {}
  bool addressable() const { return value & 1; }
  uint64_t id() const { return value >> 1; }
};

class ASTNode {
public:
  enum Kind {
    BLOCK,   // [statements...]
    DECLREG, // [INITIALIZER] ident
    DECLVAR, // [INITIALIZER] ident
    IF,      // [expression, then, else?]
    WHILE,   // [expression, body]
    NUMBER,  // numVal
    IDENT,   // ident
    CALL,    // [arguments...] fnId
    NEG,
    NOT,
    INV,
    ADDROF, // [exp]
    ADD,
    SUB,
    MUL,
    DIV,
    REM,
    SL,
    SR,
    LT,
    GT,
    LE,
    GE,
    EQ,
    NOTEQ,
    AND,
    OR,
    XOR,
    ANDAND,
    OROR,
    ASSIGN,    // [exp, exp],
    SUBSCRIPT, // [exp, exp],
    RETURN,    // [exp?]
    ERROR,     // error node, only for invalid inputs
  };

  const Kind kind;
  const SrcLoc loc;
  ASTNode *child;
  ASTNode *sibling;
  union {
    uint64_t granularity;
    uint64_t numVal;
    uint64_t fnId;
    IdentDesc ident;
  };

  ASTNode(Kind kind, SrcLoc loc, ASTNode *child)
      : kind(kind), loc(loc), child(child), sibling(nullptr) {}

  void printSexpr() const {
    std::string_view name = "???";
    switch (kind) {
    case IDENT: std::cout << "v" << ident.id(); return;
    case NUMBER:
      std::cout << numVal;
      return;

      // regular cases
    case BLOCK: name = "block"; break;
    case DECLREG: name = "register"; break;
    case DECLVAR: name = "auto"; break;
    case IF: name = "if"; break;
    case WHILE: name = "while"; break;
    case CALL: name = "call"; break;
    case NEG: name = "u-"; break;
    case NOT: name = "!"; break;
    case INV: name = "~"; break;
    case ADDROF: name = "&"; break;
    case ADD: name = "+"; break;
    case SUB: name = "-"; break;
    case MUL: name = "*"; break;
    case DIV: name = "/"; break;
    case REM: name = "%"; break;
    case SL: name = "<<"; break;
    case SR: name = ">>"; break;
    case LT: name = "<"; break;
    case GT: name = ">"; break;
    case LE: name = "<="; break;
    case GE: name = ">="; break;
    case EQ: name = "=="; break;
    case NOTEQ: name = "!="; break;
    case AND: name = "&"; break;
    case OR: name = "|"; break;
    case XOR: name = "^"; break;
    case ANDAND: name = "&&"; break;
    case OROR: name = "||"; break;
    case ASSIGN: name = "="; break;
    case SUBSCRIPT: name = "[]"; break;
    case RETURN: name = "return"; break;
    case ERROR: name = "ERROR"; break;
    }
    std::cout << "(" << name;
    if (kind == CALL)
      std::cout << " f" << fnId;
    if (kind == DECLVAR || kind == DECLREG)
      std::cout << " v" << ident.id();
    for (const ASTNode *cur = child; cur; cur = cur->sibling) {
      std::cout << " ";
      cur->printSexpr();
    }
    std::cout << ")";
  }
};

struct Function {
  /// Name
  std::string_view name;
  /// Number of parameters, identifiers 0..<numParams
  const unsigned numParams;
  /// Maximum identifier ID
  unsigned maxIid = 0;

  /// AST, or null if the function is just declared
  const ASTNode *ast = nullptr;

  Function(std::string_view name, unsigned numParams)
      : name(name), numParams(numParams) {}
};

struct Program {
  /// ASTNode allocator. Use a bump ptr allocator to avoid calling malloc/free
  /// for every created node.
  llvm::SpecificBumpPtrAllocator<ASTNode> nodeAllocator;

  /// All declared and defined functions
  std::vector<Function> funcs;
};

template <class T> class Scopes {
  struct IdentEntry {
    unsigned nest;
    unsigned gen;
    T payload;
  };

  unsigned currentNest = 0;
  // LLVM's data structures are much better than libstdc++ (or libc++).
  llvm::SmallVector<unsigned> gens;
  llvm::StringMap<llvm::SmallVector<IdentEntry>> map;

public:
  Scopes() : gens(1), map() {}

  void nest() {
    currentNest++;
    if (gens.size() <= currentNest)
      gens.push_back(0);
  }
  void unnest() { gens[currentNest--]++; }
  bool tryDeclare(std::string_view name, T payload) {
    auto &entry = lookupImpl(name);
    if (!entry.empty() && entry.back().nest == currentNest)
      return false;
    entry.emplace_back(currentNest, gens[currentNest], payload);
    return true;
  }
  T *lookup(std::string_view name) {
    auto &entry = lookupImpl(name);
    return entry.empty() ? nullptr : &entry.back().payload;
  }

private:
  decltype(map)::mapped_type &lookupImpl(std::string_view name) {
    auto &entry = map[name];
    while (!entry.empty() && gens[entry.back().nest] != entry.back().gen)
      entry.pop_back();
    return entry;
  }
};

class Parser {
  Program program;

  std::string_view prog;
  Tokenizer t;
  std::optional<Token> peekedToken;
  std::vector<Error> errors;

  uint64_t nextIid = 0;
  Scopes<IdentDesc> scope;

  // Map of function name to function id
  llvm::StringMap<std::size_t> funcMap;

public:
  Parser(std::string_view prog) : prog(prog), t(prog) {}

private:
  Token next() {
    if (peekedToken) {
      Token res = peekedToken.value();
      peekedToken = std::nullopt;
      return res;
    }

    return t.next();
  }
  Token peek() {
    if (!peekedToken)
      peekedToken = next();
    return peekedToken.value();
  }
  bool eof() { return peek().kind == Token::T_EOF; }
  Token expectNext(Token::Kind kind) {
    Token nextToken = next();
    if (UNLIKELY(nextToken.kind != kind)) {
      errors.emplace_back(nextToken.loc, "unexpected token");
      do {
        nextToken = next();
      } while (nextToken.kind != kind && nextToken.kind != Token::T_EOF);
    }
    return nextToken;
  }

  std::string_view tokenValue(Token tok) {
    return prog.substr(tok.loc.start, tok.loc.end - tok.loc.start);
  }

  ASTNode *createNode(ASTNode::Kind kind, SrcLoc loc,
                      ASTNode *children = nullptr) {
    return new (program.nodeAllocator.Allocate()) ASTNode(kind, loc, children);
  }

  /// Declare function with number of parameters and return function id.
  unsigned declareFunc(SrcLoc loc, std::string_view name, std::size_t params) {
    auto [it, inserted] = funcMap.try_emplace(name, program.funcs.size());
    if (inserted)
      program.funcs.emplace_back(Function(name, params));
    else if (program.funcs[it->second].numParams != params)
      errors.emplace_back(loc, "function redeclared with parameter mismatch");
    return it->second;
  }

  ASTNode *parsePrimaryExpression() {
    Token tok = next();
    switch (tok.kind) {
      ASTNode::Kind nodeKind;

    case Token::LPAREN: {
      auto exp = parseExpression();
      (void)expectNext(Token::RPAREN);
      return exp;
    }
    case Token::IDENT: {
      std::string_view identName = tokenValue(tok);
      if (peek().kind != Token::LPAREN) {
        auto node = createNode(ASTNode::IDENT, tok.loc);
        if (IdentDesc *var = scope.lookup(identName))
          node->ident = *var;
        else
          errors.emplace_back(tok.loc, "undeclared variable");
        return node;
      }

      // function call
      (void)next();
      ASTNode *args = nullptr;
      ASTNode **lastArgPtr = &args;
      std::size_t argCount = 0;
      while (!eof() && peek().kind != Token::RPAREN) {
        if (args)
          expectNext(Token::COMMA);
        *lastArgPtr = parseExpression();
        lastArgPtr = &(*lastArgPtr)->sibling;
        argCount++;
      }
      Token rparen = expectNext(Token::RPAREN);

      unsigned fnId = declareFunc(tok.loc | rparen.loc, identName, argCount);
      auto node = createNode(ASTNode::CALL, tok.loc | rparen.loc, args);
      node->fnId = fnId;
      return node;
    }
    case Token::NUMBER: {
      auto node = createNode(ASTNode::NUMBER, tok.loc);
      auto str = tokenValue(tok);
      auto convres =
          std::from_chars(str.data(), str.data() + str.size(), node->numVal);
      (void)convres; // TODO: overflow handling
      return node;
    }
    case Token::MINUS: nodeKind = ASTNode::NEG; goto unaryCommon;
    case Token::NOT: nodeKind = ASTNode::NOT; goto unaryCommon;
    case Token::INV: nodeKind = ASTNode::INV; goto unaryCommon;
    case Token::AND:
      nodeKind = ASTNode::ADDROF;
      goto unaryCommon;
    unaryCommon: {
      auto exp = parseExpression(13);
      auto loc = tok.loc | exp->loc;
      if (nodeKind == ASTNode::ADDROF) {
        if (exp->kind != ASTNode::IDENT && exp->kind != ASTNode::SUBSCRIPT)
          errors.emplace_back(loc, "cannot take address of non-lvalue");
        if (exp->kind == ASTNode::IDENT && !exp->ident.addressable())
          errors.emplace_back(loc, "cannot take address of register");
      }
      return createNode(nodeKind, loc, exp);
    }

    default:
      errors.emplace_back(tok.loc, "invalid token for expression");
      // Just try to keep going...
      return createNode(ASTNode::ERROR, tok.loc);
    }
  }

  static constexpr std::array<std::tuple<ASTNode::Kind, uint8_t, bool>,
                              Token::MAX>
  getOperators() {
    std::array<std::tuple<ASTNode::Kind, uint8_t, bool>, Token::MAX> res;
    res[Token::LBRACKET] = {ASTNode::SUBSCRIPT, 14, false};
    res[Token::TIMES] = {ASTNode::MUL, 12, false};
    res[Token::DIV] = {ASTNode::DIV, 12, false};
    res[Token::REM] = {ASTNode::REM, 12, false};
    res[Token::PLUS] = {ASTNode::ADD, 11, false};
    res[Token::MINUS] = {ASTNode::SUB, 11, false};
    res[Token::SL] = {ASTNode::SL, 10, false};
    res[Token::SR] = {ASTNode::SR, 10, false};
    res[Token::LANGLE] = {ASTNode::LT, 9, false};
    res[Token::RANGLE] = {ASTNode::GT, 9, false};
    res[Token::LE] = {ASTNode::LE, 9, false};
    res[Token::GE] = {ASTNode::GE, 9, false};
    res[Token::EQEQ] = {ASTNode::EQ, 8, false};
    res[Token::NOTEQ] = {ASTNode::NOTEQ, 8, false};
    res[Token::AND] = {ASTNode::AND, 7, false};
    res[Token::XOR] = {ASTNode::XOR, 6, false};
    res[Token::OR] = {ASTNode::OR, 5, false};
    res[Token::ANDAND] = {ASTNode::ANDAND, 4, false};
    res[Token::OROR] = {ASTNode::OROR, 3, false};
    res[Token::EQ] = {ASTNode::ASSIGN, 2, true};
    return res;
  }

  ASTNode *parseExpression(int minPrec = 0) {
    ASTNode *lhs = parsePrimaryExpression();
    while (true) {
      static constexpr auto OPERATORS = getOperators();

      Token op = peek();
      auto [nodeKind, prec, rassoc] = OPERATORS[op.kind];
      if (!prec || prec < minPrec)
        break;
      (void)next();

      if (nodeKind == ASTNode::SUBSCRIPT) {
        lhs->sibling = parseExpression();
        int granularity = 0;
        if (peek().kind == Token::AT) {
          (void)next();
          Token granTok = expectNext(Token::NUMBER);
          std::string_view granStr = tokenValue(granTok);
          if (granStr == "1")
            granularity = 1;
          else if (granStr == "2")
            granularity = 2;
          else if (granStr == "4")
            granularity = 4;
          else if (granStr == "8")
            granularity = 8;
          else
            errors.emplace_back(granTok.loc, "invalid granularity");
        }
        Token rbracket = expectNext(Token::RBRACKET);
        auto loc = lhs->loc | rbracket.loc;
        lhs = createNode(nodeKind, loc, lhs);
        lhs->granularity = granularity;
        continue;
      }

      lhs->sibling = parseExpression(prec + !rassoc);
      if (nodeKind == ASTNode::ASSIGN) {
        if (lhs->kind != ASTNode::IDENT && lhs->kind != ASTNode::SUBSCRIPT)
          errors.emplace_back(lhs->loc, "lhs of assignment must be lvalue");
      }

      auto loc = lhs->loc | lhs->sibling->loc;
      lhs = createNode(nodeKind, loc, lhs);
    }
    return lhs;
  }

  ASTNode *parseStatement() {
    switch (peek().kind) {
    case Token::LCURLY: return parseBlock();
    case Token::WHILE: {
      Token start = next();
      (void)expectNext(Token::LPAREN);
      auto exp = parseExpression();
      (void)expectNext(Token::RPAREN);
      exp->sibling = parseStatement();
      SrcLoc loc = start.loc | exp->sibling->loc;
      return createNode(ASTNode::WHILE, loc, exp);
    }
    case Token::IF: {
      SrcLoc startLoc = next().loc;
      (void)expectNext(Token::LPAREN);
      ASTNode *cond = parseExpression();
      (void)expectNext(Token::RPAREN);
      cond->sibling = parseStatement();
      SrcLoc endLoc = cond->sibling->loc;
      if (peek().kind == Token::ELSE) {
        next();
        cond->sibling->sibling = parseStatement();
        endLoc = cond->sibling->sibling->loc;
      }
      return createNode(ASTNode::IF, startLoc | endLoc, cond);
    }
    case Token::RETURN: {
      SrcLoc startLoc = next().loc;
      ASTNode *child = nullptr;
      if (peek().kind != Token::SEMICOLON)
        child = parseExpression();
      SrcLoc semiLoc = expectNext(Token::SEMICOLON).loc;
      return createNode(ASTNode::RETURN, startLoc | semiLoc, child);
    }
    default: {
      ASTNode *node = parseExpression();
      (void)expectNext(Token::SEMICOLON);
      return node;
    }
    }
  }

  ASTNode *parseDeclStatement() {
    switch (peek().kind) {
    case Token::AUTO:
    case Token::REGISTER: {
      Token storage = next();
      Token nameTok = expectNext(Token::IDENT);
      (void)expectNext(Token::EQ);
      ASTNode *exp = parseExpression();
      Token semi = expectNext(Token::SEMICOLON);

      std::string_view name = tokenValue(nameTok);
      bool addressable = storage.kind != Token::REGISTER;

      IdentDesc desc{nextIid++, addressable};
      if (!scope.tryDeclare(name, desc))
        errors.emplace_back(storage.loc | nameTok.loc,
                            "redundant variable declaration");

      auto nodeKind = addressable ? ASTNode::DECLVAR : ASTNode::DECLREG;
      auto node = createNode(nodeKind, storage.loc | semi.loc, exp);
      node->ident = desc;
      return node;
    }
    default: return parseStatement();
    }
  }

  ASTNode *parseBlock(bool nestScope = true) {
    if (nestScope)
      scope.nest();
    SrcLoc lcurlyLoc = expectNext(Token::LCURLY).loc;
    ASTNode *stmts{nullptr};
    ASTNode **lastStmtPtr = &stmts;
    while (!eof() && peek().kind != Token::RCURLY) {
      *lastStmtPtr = parseDeclStatement();
      lastStmtPtr = &(*lastStmtPtr)->sibling;
    }
    SrcLoc rcurlyLoc = expectNext(Token::RCURLY).loc;
    if (nestScope)
      scope.unnest();
    return createNode(ASTNode::BLOCK, lcurlyLoc | rcurlyLoc, stmts);
  }

  void parseFunction() {
    scope.nest();

    // identifier ids are unique within a function.
    nextIid = 0;

    Token nameTok = expectNext(Token::IDENT);
    std::string_view name = tokenValue(nameTok);

    (void)expectNext(Token::LPAREN);
    while (!eof() && peek().kind != Token::RPAREN) {
      if (nextIid)
        (void)expectNext(Token::COMMA);

      Token ident = expectNext(Token::IDENT);
      IdentDesc desc{nextIid++, /*addressable=*/false};
      if (!scope.tryDeclare(tokenValue(ident), desc))
        errors.emplace_back(ident.loc, "redundant parameter declaration");
    }
    Token rparen = expectNext(Token::RPAREN);

    unsigned fnid = declareFunc(nameTok.loc | rparen.loc, name, nextIid);

    // Don't nest scope for the block, parameters belong to the innermost scope.
    program.funcs[fnid].ast = parseBlock(/*nestScope=*/false);
    program.funcs[fnid].maxIid = nextIid;

    scope.unnest();
  }

public:
  std::optional<Program> parseProgram() {
    while (!eof())
      parseFunction();
    if (errors.empty())
      return std::move(program);
    for (const Error &error : errors) {
      auto [lineStart, colStart] = t.locate(error.loc.start);
      std::string_view line = t.line(lineStart);
      // Only show a single line, if lineEnd is different, assume line.size().
      auto [lineEnd, colEnd] = t.locate(error.loc.end);
      colEnd = lineStart == lineEnd ? colEnd : line.size();

      std::size_t markLen = colEnd != colStart ? colEnd - colStart - 1 : 0;
      std::cerr << "input:" << (lineStart + 1) << ":" << (colStart + 1) << ": "
                << error.desc << "\n";
      std::cerr << line << "\n";
      std::cerr << std::string(colStart, ' ') << "^"
                << std::string(markLen, '~') << "\n";
    }
    return std::nullopt;
  }
};

std::string readFile(std::string_view path) {
  auto stream = std::ifstream(path.data(), std::ios::in);
  stream.seekg(0, std::ios::end);
  auto size = stream.tellg();
  stream.seekg(0, std::ios::beg);
  stream.clear();
  if (size != -1) {
    std::vector<char> data(size);
    stream.read(&data[0], size);
    return std::string(&data[0], size);
  }

  std::string res;
  std::vector<char> buf(0x2000);
  do {
    stream.read(&buf[0], buf.size());
    res.append(&buf[0], 0, stream.gcount());
  } while (stream.gcount());
  return res;
}

class LLVMIRGen {
  llvm::LLVMContext &ctx;
  llvm::Module *mod;
  llvm::ArrayRef<llvm::Function *> fns;
  llvm::Function *fn;
  llvm::IRBuilder<> irb;

  using VarBlockMap = llvm::DenseMap<llvm::BasicBlock *, llvm::Value *>;
  using VarDesc = std::variant<std::monostate, VarBlockMap, llvm::AllocaInst *>;
  llvm::SmallVector<VarDesc, 0> varMap;

  llvm::DenseSet<llvm::BasicBlock *> unsealedBlocks;
  using IncompletePhi = std::pair<uint64_t, llvm::PHINode *>;
  llvm::DenseMap<llvm::BasicBlock *, llvm::SmallVector<IncompletePhi>>
      incompletePhis;

  LLVMIRGen(llvm::ArrayRef<llvm::Function *> fns, llvm::Function *fn)
      : ctx(fn->getContext()), mod(fn->getParent()), fns(fns), fn(fn),
        irb(ctx) {}

  void addPhiOperands(uint64_t iid, llvm::PHINode *phi) {
    for (llvm::BasicBlock *pred : llvm::predecessors(phi->getParent()))
      phi->addIncoming(readVar(iid, pred), pred);
  }

  void writeVar(uint64_t iid, llvm::BasicBlock *block, llvm::Value *val) {
    if (auto aip = std::get_if<llvm::AllocaInst *>(&varMap[iid])) {
      irb.CreateStore(val, *aip);
    } else {
      auto &bm = std::get<VarBlockMap>(varMap[iid]);
      bm[block] = val;
    }
  }

  llvm::Value *readVar(uint64_t iid, llvm::BasicBlock *block) {
    if (auto aip = std::get_if<llvm::AllocaInst *>(&varMap[iid]))
      return irb.CreateLoad(irb.getInt64Ty(), *aip);

    auto &bm = std::get<VarBlockMap>(varMap[iid]);
    auto bit = bm.find(block);
    if (bit != bm.end())
      return bit->second;

    bool sealed = !unsealedBlocks.contains(block);
    auto *pred = block->getUniquePredecessor();
    if (pred && sealed) {
      llvm::Value *predVal = readVar(iid, pred);
      bm[block] = predVal;
      return predVal;
    }

    llvm::IRBuilder<> phiIrb(block, block->begin());
    llvm::PHINode *phi = phiIrb.CreatePHI(irb.getInt64Ty(), 2);
    bm[block] = phi;
    if (sealed)
      addPhiOperands(iid, phi);
    else
      incompletePhis[block].emplace_back(iid, phi);
    return phi;
  }

  std::pair<llvm::Type *, llvm::Value *> getSubscriptAddr(const ASTNode &node) {
    llvm::Type *ty =
        irb.getIntNTy(node.granularity ? node.granularity * 8 : 64);
    llvm::Value *base = genValue(*node.child);
    llvm::Value *idx = genValue(*node.child->sibling);
    base = irb.CreateIntToPtr(base, llvm::PointerType::get(ctx, 0));
    if (auto *cstIdx = llvm::dyn_cast<llvm::ConstantInt>(idx);
        cstIdx && cstIdx->isZero())
      return {ty, base};
    return {ty, irb.CreateGEP(ty, base, {idx})};
  }

  void changeBlock(llvm::BasicBlock *bb) {
    // Reorder blocks so that the LLVM-IR follows the program order. This is not
    // required (or even beneficial other than manual inspection of the code).
    bb->moveAfter(irb.GetInsertBlock());
    irb.SetInsertPoint(bb);
  }

  /// Generate a conditional branch based on the expression in the ASTNode.
  void genCondBr(const ASTNode &node, llvm::BasicBlock *thenBB,
                 llvm::BasicBlock *elseBB) {
    if (node.kind == ASTNode::ANDAND) {
      auto secondBB = llvm::BasicBlock::Create(ctx, "", fn);
      genCondBr(*node.child, secondBB, elseBB);
      changeBlock(secondBB);
      genCondBr(*node.child->sibling, thenBB, elseBB);
    } else if (node.kind == ASTNode::OROR) {
      auto secondBB = llvm::BasicBlock::Create(ctx, "", fn);
      genCondBr(*node.child, thenBB, secondBB);
      changeBlock(secondBB);
      genCondBr(*node.child->sibling, thenBB, elseBB);
    } else {
      llvm::Value *cond = genValueAny(node);
      if (cond->getType() != irb.getInt1Ty())
        cond = irb.CreateIsNotNull(cond);
      irb.CreateCondBr(cond, thenBB, elseBB);
    }
  }

  /// Generate LLVM-IR code for an ASTNode and return the result value, or null.
  /// Result is either i64 for integer operations or i1 for logical operations.
  llvm::Value *genValueAny(const ASTNode &node) {
    if (irb.GetInsertBlock()->getTerminator())
      return llvm::UndefValue::get(irb.getInt64Ty());

    switch (node.kind) {
      llvm::Value *val;
      llvm::Instruction::BinaryOps binOp;
      llvm::CmpInst::Predicate cmpPred;

    case ASTNode::IDENT: return readVar(node.ident.id(), irb.GetInsertBlock());
    case ASTNode::NUMBER: return irb.getInt64(node.numVal);
    case ASTNode::ASSIGN:
      val = genValue(*node.child->sibling);
      if (node.child->kind == ASTNode::IDENT) {
        writeVar(node.child->ident.id(), irb.GetInsertBlock(), val);
      } else if (node.child->kind == ASTNode::SUBSCRIPT) {
        auto [ty, ptr] = getSubscriptAddr(*node.child);
        irb.CreateStore(irb.CreateSExtOrTrunc(val, ty), ptr);
      }
      return val;
    case ASTNode::SUBSCRIPT: {
      auto [ty, ptr] = getSubscriptAddr(node);
      return irb.CreateSExtOrTrunc(irb.CreateLoad(ty, ptr), irb.getInt64Ty());
    }
    case ASTNode::ADDROF: {
      if (node.child->kind == ASTNode::IDENT) {
        auto ptr = std::get<llvm::AllocaInst *>(varMap[node.child->ident.id()]);
        return irb.CreatePtrToInt(ptr, irb.getInt64Ty());
      } else if (node.child->kind == ASTNode::SUBSCRIPT) {
        auto [_, ptr] = getSubscriptAddr(*node.child);
        return irb.CreatePtrToInt(ptr, irb.getInt64Ty());
      }
      break;
    }
    case ASTNode::NEG: return irb.CreateNeg(genValue(*node.child));
    case ASTNode::INV: return irb.CreateNot(genValue(*node.child));
    case ASTNode::NOT:
      if (auto *v = genValueAny(*node.child); v->getType()->isIntegerTy(1))
        return irb.CreateNot(v); // i1
      else
        return irb.CreateIsNull(v); // i1

    case ASTNode::ADD: binOp = llvm::Instruction::Add; goto binOpCommon;
    case ASTNode::SUB: binOp = llvm::Instruction::Sub; goto binOpCommon;
    case ASTNode::MUL: binOp = llvm::Instruction::Mul; goto binOpCommon;
    case ASTNode::DIV: binOp = llvm::Instruction::SDiv; goto binOpCommon;
    case ASTNode::REM: binOp = llvm::Instruction::SRem; goto binOpCommon;
    case ASTNode::SL: binOp = llvm::Instruction::Shl; goto binOpCommon;
    case ASTNode::SR: binOp = llvm::Instruction::AShr; goto binOpCommon;
    case ASTNode::AND: binOp = llvm::Instruction::And; goto binOpCommon;
    case ASTNode::OR: binOp = llvm::Instruction::Or; goto binOpCommon;
    case ASTNode::XOR:
      binOp = llvm::Instruction::Xor;
      goto binOpCommon;
    binOpCommon:
      return irb.CreateBinOp(binOp, genValue(*node.child),
                             genValue(*node.child->sibling));

    case ASTNode::EQ: cmpPred = llvm::CmpInst::ICMP_EQ; goto cmpCommon;
    case ASTNode::NOTEQ: cmpPred = llvm::CmpInst::ICMP_NE; goto cmpCommon;
    case ASTNode::LT: cmpPred = llvm::CmpInst::ICMP_SLT; goto cmpCommon;
    case ASTNode::GT: cmpPred = llvm::CmpInst::ICMP_SGT; goto cmpCommon;
    case ASTNode::LE: cmpPred = llvm::CmpInst::ICMP_SLE; goto cmpCommon;
    case ASTNode::GE:
      cmpPred = llvm::CmpInst::ICMP_SGE;
      goto cmpCommon;
    cmpCommon:
      return irb.CreateICmp(cmpPred, genValue(*node.child),
                            genValue(*node.child->sibling)); // i1

    case ASTNode::OROR:
    case ASTNode::ANDAND: {
      auto thenBB = llvm::BasicBlock::Create(ctx, "", fn);
      auto elseBB = llvm::BasicBlock::Create(ctx, "", fn);
      genCondBr(node, thenBB, elseBB);

      auto contBB = llvm::BasicBlock::Create(ctx, "", fn);
      changeBlock(thenBB);
      irb.CreateBr(contBB);
      changeBlock(elseBB);
      irb.CreateBr(contBB);
      changeBlock(contBB);
      llvm::PHINode *phi = irb.CreatePHI(irb.getInt64Ty(), 2);
      phi->addIncoming(irb.getInt64(1), thenBB);
      phi->addIncoming(irb.getInt64(0), elseBB);
      return phi;
    }
    case ASTNode::CALL: {
      llvm::SmallVector<llvm::Value *, 4> args;
      for (const ASTNode *arg = node.child; arg; arg = arg->sibling)
        args.push_back(genValue(*arg));
      return irb.CreateCall(fns[node.fnId]->getFunctionType(), fns[node.fnId],
                            args);
    }

    // Statements
    case ASTNode::BLOCK:
      for (const ASTNode *child = node.child; child; child = child->sibling)
        (void)genValueAny(*child);
      return nullptr;
    case ASTNode::DECLREG:
      val = genValue(*node.child);
      varMap[node.ident.id()] = VarBlockMap{{irb.GetInsertBlock(), val}};
      return nullptr;
    case ASTNode::DECLVAR: {
      llvm::BasicBlock *entryBB = &fn->getEntryBlock();
      llvm::IRBuilder<> allocaIrb(entryBB, entryBB->begin());
      llvm::AllocaInst *alloca = allocaIrb.CreateAlloca(irb.getInt64Ty());
      varMap[node.ident.id()] = alloca;
      irb.CreateStore(genValue(*node.child), alloca);
      return nullptr;
    }
    case ASTNode::IF: {
      auto thenBB = llvm::BasicBlock::Create(ctx, "", fn);
      llvm::BasicBlock *elseBB = nullptr;
      if (node.child->sibling->sibling)
        elseBB = llvm::BasicBlock::Create(ctx, "", fn);
      auto contBB = llvm::BasicBlock::Create(ctx, "", fn);

      genCondBr(*node.child, thenBB, elseBB ? elseBB : contBB);

      changeBlock(thenBB);
      (void)genValueAny(*node.child->sibling);
      if (!irb.GetInsertBlock()->getTerminator())
        irb.CreateBr(contBB);
      if (elseBB) {
        changeBlock(elseBB);
        (void)genValueAny(*node.child->sibling->sibling);
        if (!irb.GetInsertBlock()->getTerminator())
          irb.CreateBr(contBB);
      }
      // In case both branches interrupt control flow...
      if (contBB->hasNPredecessors(0))
        contBB->eraseFromParent();
      else
        changeBlock(contBB);
      return nullptr;
    }
    case ASTNode::WHILE: {
      auto headerBB = llvm::BasicBlock::Create(ctx, "", fn);
      auto bodyBB = llvm::BasicBlock::Create(ctx, "", fn);
      auto contBB = llvm::BasicBlock::Create(ctx, "", fn);

      // Header is missing the predecessor from the loop body.
      unsealedBlocks.insert(headerBB);

      irb.CreateBr(headerBB);
      changeBlock(headerBB);
      genCondBr(*node.child, bodyBB, contBB);
      changeBlock(bodyBB);
      (void)genValueAny(*node.child->sibling);
      if (!irb.GetInsertBlock()->getTerminator())
        irb.CreateBr(headerBB);

      // Seal headerBB
      auto ipit = incompletePhis.find(headerBB);
      if (ipit != incompletePhis.end()) {
        for (const auto &[iid, phi] : ipit->second)
          addPhiOperands(iid, phi);
        incompletePhis.erase(ipit);
      }
      unsealedBlocks.erase(headerBB);

      changeBlock(contBB);
      return nullptr;
    }
    case ASTNode::RETURN:
      if (node.child)
        irb.CreateRet(genValue(*node.child));
      else
        irb.CreateRet(llvm::UndefValue::get(irb.getInt64Ty()));
      return nullptr;
    case ASTNode::ERROR:
      assert(false && "error node in valid AST!");
      return nullptr;
    }
    return nullptr;
  }

  /// Generate value for ASTNode as 64-bit integer
  llvm::Value *genValue(const ASTNode &node) {
    llvm::Value *v = genValueAny(node);
    if (v->getType() == irb.getInt64Ty())
      return v;
    return irb.CreateZExt(v, irb.getInt64Ty());
  }

  void genFunction(const Function &func) {
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(ctx, "", fn);
    irb.SetInsertPoint(entryBB);
    varMap.resize(func.maxIid);
    for (unsigned i = 0; i < func.numParams; i++)
      varMap[i] = VarBlockMap{{entryBB, fn->getArg(i)}};

    (void)genValueAny(*func.ast);

    if (!irb.GetInsertBlock()->getTerminator())
      irb.CreateRet(llvm::UndefValue::get(irb.getInt64Ty()));

    varMap.clear();
  }

public:
  static std::unique_ptr<llvm::Module> genIR(llvm::LLVMContext &ctx,
                                             const Program &prog) {
    auto modUP = std::make_unique<llvm::Module>("mod", ctx);
    llvm::Module *mod = modUP.get();

    llvm::Type *i64 = llvm::Type::getInt64Ty(ctx);
    auto linkage = llvm::GlobalValue::ExternalLinkage;

    llvm::SmallVector<llvm::Function *> irFuncs;
    irFuncs.reserve(prog.funcs.size());
    llvm::SmallVector<llvm::Type *, 4> argTys;
    for (const Function &func : prog.funcs) {
      argTys.resize(func.numParams, i64);
      auto fnTy = llvm::FunctionType::get(i64, argTys, false);
      irFuncs.push_back(llvm::Function::Create(fnTy, linkage, func.name, mod));
    }
    for (unsigned i = 0; i < prog.funcs.size(); i++)
      if (prog.funcs[i].ast)
        LLVMIRGen{irFuncs, irFuncs[i]}.genFunction(prog.funcs[i]);

    return modUP;
  }
};

// TARGET_OP(name, retinfo, opinfo, clobbers)
//   retinfo: 0=void; 1=reg; 2=reg-fixed (bits 7:4 indicate register);
//            3=reg-tied (must be identical to first op), 4=terminator.
//   opinfo:  packed 8-bit units, least-significant byte describes the first
//            operand; elements: 0=none/end; 1=reg; 2=reg-fixed (bits 7:4
//            indicate register); 3=reg-opt (can be unset, for memory
//            operands); 4=imm32; 5=imm64; 6=imm-scale (1/2/4/8, for memory
//            operands); 7=basic block, 8=call target.
//   clobbers: bit set of registers that are clobbered.
//
// Notes:
// - XOR64zero is special; while it does have operands in reality, they are not
//   read. We therefore model this as an instruction that doesn't have inputs.
// - Memory operands are always 0x04030603 (base, scale, idx, off), i.e.
//   reg-opt base, imm-scale, reg-opt index, imm32 displacement.
#define TARGET_OPS                                                             \
  TARGET_OP(PseudoPhi, 0x01, 0x0105, 0)           /* Phi block:r64 */          \
  TARGET_OP(PseudoRegArg, 0x01, 0x05, 0)          /* Register argument i */    \
  TARGET_OP(PseudoSP, 0x42, 0, 0)                 /* Stack pointer */          \
  TARGET_OP(PseudoFP, 0x52, 0, 0)                 /* Frame pointer */          \
  TARGET_OP(RET, 0x04, 0x02, 0)                   /* RET (rax) */              \
  TARGET_OP(JMP, 0x04, 0x07, 0)                   /* JMP target */             \
  TARGET_OP(JCC, 0x04, 0x070705, 0)               /* Jcc cond, then, else */   \
  TARGET_OP(CALL, 0x02, 0x92821222627208, 0x0fc7) /* CALL target, r64... */    \
  TARGET_OP(SETCC8r, 0x01, 0x05, 0)               /* SETCC r8, cond */         \
  TARGET_OP(LEA64rm, 0x01, 0x04030603, 0)         /* LEA r64, mem */           \
  TARGET_OP(MOV64rr, 0x01, 0x01, 0)               /* MOV r64, r64 */           \
  TARGET_OP(MOV64ri, 0x01, 0x05, 0)               /* MOV r64, imm */           \
  TARGET_OP(MOV64rm, 0x01, 0x04030603, 0)         /* MOV r64, mem */           \
  TARGET_OP(MOV64mr, 0x00, 0x0104030603, 0)       /* MOV mem, r64 */           \
  TARGET_OP(MOV64mi, 0x00, 0x0404030603, 0)       /* MOV mem, imm */           \
  TARGET_OP(MOV32rr, 0x01, 0x01, 0)               /* MOV r32, r32 */           \
  TARGET_OP(MOV32ri, 0x01, 0x05, 0)               /* MOV r32, imm */           \
  TARGET_OP(MOV32rm, 0x01, 0x04030603, 0)         /* MOV r32, mem */           \
  TARGET_OP(MOV32mr, 0x00, 0x0104030603, 0)       /* MOV mem, r32 */           \
  TARGET_OP(MOV32mi, 0x00, 0x0404030603, 0)       /* MOV mem, imm */           \
  TARGET_OP(MOV16mr, 0x00, 0x0104030603, 0)       /* MOV mem, r16 */           \
  TARGET_OP(MOV16mi, 0x00, 0x0404030603, 0)       /* MOV mem, imm */           \
  TARGET_OP(MOV8mr, 0x00, 0x0104030603, 0)        /* MOV mem, r8 */            \
  TARGET_OP(MOV8mi, 0x00, 0x0404030603, 0)        /* MOV mem, imm */           \
  TARGET_OP(MOVZXB32rr, 0x01, 0x01, 0)            /* MOVZX r32, r8 */          \
  TARGET_OP(MOVZXB32rm, 0x01, 0x04030603, 0)      /* MOVZX r32, mem8 */        \
  TARGET_OP(MOVZXW32rr, 0x01, 0x01, 0)            /* MOVZX r32, r16 */         \
  TARGET_OP(MOVZXW32rm, 0x01, 0x04030603, 0)      /* MOVZX r32, mem16 */       \
  TARGET_OP(MOVSXB64rr, 0x01, 0x01, 0)            /* MOVSX r64, r8 */          \
  TARGET_OP(MOVSXB64rm, 0x01, 0x04030603, 0)      /* MOVSX r64, mem8 */        \
  TARGET_OP(MOVSXW64rr, 0x01, 0x01, 0)            /* MOVSX r64, r16 */         \
  TARGET_OP(MOVSXW64rm, 0x01, 0x04030603, 0)      /* MOVSX r64, mem16 */       \
  TARGET_OP(MOVSXD64rr, 0x01, 0x01, 0)            /* MOVSX r64, r32 */         \
  TARGET_OP(MOVSXD64rm, 0x01, 0x04030603, 0)      /* MOVSX r64, mem32 */       \
  TARGET_OP(NEG64r, 0x03, 0x01, 0)                /* NEG r64 */                \
  TARGET_OP(NOT64r, 0x03, 0x01, 0)                /* NOT r64 */                \
  TARGET_OP(ADD64rr, 0x03, 0x0101, 0)             /* ADD r64, r64 */           \
  TARGET_OP(ADD64ri, 0x03, 0x0401, 0)             /* ADD r64, imm */           \
  TARGET_OP(ADD64rm, 0x03, 0x0403060301, 0)       /* ADD r64, mem */           \
  TARGET_OP(ADD64mr, 0x00, 0x0104030603, 0)       /* ADD mem, r64 */           \
  TARGET_OP(ADD64mi, 0x00, 0x0404030603, 0)       /* ADD mem, imm */           \
  TARGET_OP(SUB64rr, 0x03, 0x0101, 0)             /* SUB r64, r64 */           \
  TARGET_OP(SUB64ri, 0x03, 0x0401, 0)             /* SUB r64, imm */           \
  TARGET_OP(SUB64rm, 0x03, 0x0403060301, 0)       /* SUB r64, mem */           \
  TARGET_OP(SUB64mr, 0x00, 0x0104030603, 0)       /* SUB mem, r64 */           \
  TARGET_OP(SUB64mi, 0x00, 0x0404030603, 0)       /* SUB mem, imm */           \
  TARGET_OP(XOR64rr, 0x03, 0x0101, 0)             /* XOR r64, r64 */           \
  TARGET_OP(XOR64ri, 0x03, 0x0401, 0)             /* XOR r64, imm */           \
  TARGET_OP(XOR64rm, 0x03, 0x0403060301, 0)       /* XOR r64, mem */           \
  TARGET_OP(XOR64mr, 0x00, 0x0104030603, 0)       /* XOR mem, r64 */           \
  TARGET_OP(XOR64mi, 0x00, 0x0404030603, 0)       /* XOR mem, imm */           \
  TARGET_OP(XOR64zero, 0x01, 0, 0)                /* XOR reg = zero */         \
  TARGET_OP(AND64rr, 0x03, 0x0101, 0)             /* AND r64, r64 */           \
  TARGET_OP(AND64ri, 0x03, 0x0401, 0)             /* AND r64, imm */           \
  TARGET_OP(AND64rm, 0x03, 0x0403060301, 0)       /* AND r64, mem */           \
  TARGET_OP(AND64mr, 0x00, 0x0104030603, 0)       /* AND mem, r64 */           \
  TARGET_OP(AND64mi, 0x00, 0x0404030603, 0)       /* AND mem, imm */           \
  TARGET_OP(OR64rr, 0x03, 0x0101, 0)              /* OR r64, r64 */            \
  TARGET_OP(OR64ri, 0x03, 0x0401, 0)              /* OR r64, imm */            \
  TARGET_OP(OR64rm, 0x03, 0x0403060301, 0)        /* OR r64, mem */            \
  TARGET_OP(OR64mr, 0x00, 0x0104030603, 0)        /* OR mem, r64 */            \
  TARGET_OP(OR64mi, 0x00, 0x0404030603, 0)        /* OR mem, imm */            \
  TARGET_OP(CMP64rr, 0x00, 0x0101, 0)             /* CMP r64, r64 */           \
  TARGET_OP(CMP64ri, 0x00, 0x0401, 0)             /* CMP r64, imm */           \
  TARGET_OP(CMP64rm, 0x00, 0x0403060301, 0)       /* CMP r64, mem */           \
  TARGET_OP(CMP64mr, 0x00, 0x0104030603, 0)       /* CMP mem, r64 */           \
  TARGET_OP(CMP64mi, 0x00, 0x0404030603, 0)       /* CMP mem, imm */           \
  TARGET_OP(TEST64rr, 0x00, 0x0101, 0)            /* TEST r64, r64 */          \
  TARGET_OP(TEST64ri, 0x00, 0x0401, 0)            /* TEST r64, imm */          \
  TARGET_OP(TEST64rm, 0x00, 0x0403060301, 0)      /* TEST r64, mem */          \
  TARGET_OP(TEST64mr, 0x00, 0x0104030603, 0)      /* TEST mem, r64 */          \
  TARGET_OP(TEST64mi, 0x00, 0x0404030603, 0)      /* TEST mem, imm */          \
  TARGET_OP(IMUL64rr, 0x03, 0x0101, 0)            /* IMUL r64, r64 */          \
  TARGET_OP(IMUL64rri, 0x01, 0x0401, 0)           /* IMUL r64, r64, imm */     \
  TARGET_OP(IMUL64rm, 0x03, 0x0403060301, 0)      /* IMUL r64, mem */          \
  TARGET_OP(IMUL64rmi, 0x01, 0x0404030603, 0)     /* IMUL r64, mem, imm */     \
  TARGET_OP(SHL64ri, 0x03, 0x0501, 0)             /* SHL r64, imm */           \
  TARGET_OP(SHL64rCL, 0x03, 0x1201, 0)            /* SHL r64, cl */            \
  TARGET_OP(SHR64ri, 0x03, 0x0501, 0)             /* SHR r64, imm */           \
  TARGET_OP(SHR64rCL, 0x03, 0x1201, 0)            /* SHR r64, cl */            \
  TARGET_OP(SAR64ri, 0x03, 0x0501, 0)             /* SAR r64, imm */           \
  TARGET_OP(SAR64rCL, 0x03, 0x1201, 0)            /* SAR r64, cl */            \
  TARGET_OP(CQO, 0x22, 0x02, 0)                   /* CQO (rdx, rax) */         \
  TARGET_OP(IDIV64r_quot, 0x02, 0x220201, 0x0004) /* IDIV r64 (quotient) */    \
  TARGET_OP(IDIV64r_rem, 0x22, 0x220201, 0x0001)  /* IDIV r64 (remainder) */

// Expand the TARGET_OPS table into the machine opcode enumeration.
#define TARGET_OP(opname, ...) opname,
enum class MOpc { TARGET_OPS };
#undef TARGET_OP

/// Static per-opcode description, decoded from the TARGET_OPS table
/// (see the field-format comment above the table).
struct MOpcInfo {
  uint8_t ret;       ///< retinfo: result kind.
  uint16_t clobbers; ///< Bit set of clobbered physical registers.
  uint64_t args;     ///< opinfo: packed 8-bit operand descriptors.
};
// Note: the macro's (ret, ops, clobbers) argument order is rearranged into
// the struct's {ret, clobbers, ops} field order here.
#define TARGET_OP(opname, ret, ops, clobbers) {ret, clobbers, ops},
static constexpr MOpcInfo OpInfos[] = {TARGET_OPS};
#undef TARGET_OP

/// A register number (virtual, or physical in regAlloc's per-register state).
/// The all-ones id serves as the "invalid"/unset sentinel.
struct MReg {
  uint32_t id;
  MReg() : id(~uint32_t{0}) {}
  explicit MReg(uint32_t id) : id(id) {}

  /// Whether this holds a real register number (not the sentinel).
  bool isValid() const { return id != ~uint32_t{0}; }
  bool operator==(const MReg &other) const { return id == other.id; }
};
struct MOp {
  using RegOp = std::pair<MReg, bool>; // Register and kill flag.
  std::variant<uint64_t, RegOp, llvm::Function *> v;
  MOp(uint64_t v) : v(v) {}
  MOp(MReg v) : v(RegOp{v, false}) {}
  MOp(llvm::Function *v) : v(v) {}

  bool isVal() const { return std::holds_alternative<uint64_t>(v); }
  uint64_t getVal() const { return std::get<uint64_t>(v); }
  bool isReg() const { return std::holds_alternative<RegOp>(v); }
  MReg getReg() const { return std::get<RegOp>(v).first; }
  bool isValidReg() const { return isReg() && getReg().isValid(); }
  void setKill(bool kill = true) { std::get<RegOp>(v).second = kill; }
  bool isKill() const { return std::get<RegOp>(v).second; }

  void print(llvm::raw_ostream &os, bool skipKill = false) const;
};
/// Print the operand with a leading space; a live register prefixed "k" marks
/// a kill (suppressed when \p skipKill is set).
void MOp::print(llvm::raw_ostream &os, bool skipKill) const {
  if (isReg()) {
    const auto &[reg, kill] = std::get<RegOp>(v);
    os << (kill && !skipKill ? " k" : " ") << int32_t(reg.id);
  } else if (isVal()) {
    os << " " << int64_t(getVal());
  } else {
    os << " " << std::get<llvm::Function *>(v)->getName();
  }
}
/// A single machine instruction in the low-level (virtual-register) IR.
struct MInst {
  MOpc opcode;                ///< Instruction opcode.
  MOp def;                    ///< Destination reg, or invalid.
  llvm::SmallVector<MOp> ops; ///< Operands

  void print(llvm::raw_ostream &os) const;
};

/// Print the instruction as "NAME [def] op...". A PseudoPhi is printed as one
/// line per (block, value) pair; the def's kill flag only appears on the last
/// line.
void MInst::print(llvm::raw_ostream &os) const {
#define TARGET_OP(opname, ...) #opname,
  static constexpr std::string_view names[] = {TARGET_OPS};
#undef TARGET_OP
  if (opcode == MOpc::PseudoPhi) {
    size_t numOps = ops.size();
    for (size_t idx = 0; idx < numOps; idx += 2) {
      os << "PseudoPhi";
      bool lastPair = idx + 2 >= numOps;
      def.print(os, /*skipKill=*/!lastPair);
      ops[idx].print(os);
      ops[idx + 1].print(os);
      os << "\n";
    }
    return;
  }
  os << names[unsigned(opcode)];
  if (def.isValidReg())
    def.print(os);
  for (const MOp &operand : ops)
    operand.print(os);
  os << "\n";
}
/// A machine basic block: instructions plus CFG edges and liveness sets.
struct MBlock {
  llvm::SmallVector<MBlock *> preds;
  llvm::SmallVector<MBlock *> succs;
  llvm::SmallVector<MInst> insts;
  llvm::DenseSet<uint32_t> liveIn;  ///< Registers live at block entry.
  llvm::DenseSet<uint32_t> liveOut; ///< Registers live at block exit.
  unsigned number;                  ///< Index into MFunc::blocks.

  explicit MBlock(unsigned number) : number(number) {}

  /// Range over the successor block numbers encoded in the terminator's
  /// operands (empty for RET).
  auto successors() const {
    const MInst &term = insts.back();
    // Initialize to an empty range: with the assert compiled out (NDEBUG),
    // an unexpected terminator would otherwise read uninitialized values (UB).
    unsigned start = 0, end = 0;
    switch (term.opcode) {
    case MOpc::JMP: start = 0, end = 1; break; // JMP target
    case MOpc::JCC: start = 1, end = 3; break; // JCC cond, then, else
    case MOpc::RET: start = 0, end = 0; break;
    default: assert(false && "invalid terminator");
    }
    return std::views::drop(term.ops, start) | std::views::take(end - start) |
           std::views::transform([](const MOp &op) { return op.getVal(); });
  }
};
/// A machine function: blocks, virtual register counter, and frame size.
struct MFunc {
  std::string name;
  llvm::SmallVector<std::unique_ptr<MBlock>> blocks;
  unsigned nextReg = 0; ///< Next unallocated virtual register number.
  /// Frame offset accumulator; decremented by 8 per slot and interpreted as
  /// a (negative) FP-relative offset.
  uint64_t frameSize = 0;

  MFunc(std::string_view name) : name(name) {}

  /// Allocate a fresh virtual register.
  MReg allocReg() { return MReg{nextReg++}; }

  /// Append a new block; blocks are numbered in creation order.
  MBlock *addBlock() {
    blocks.emplace_back(std::make_unique<MBlock>(blocks.size()));
    return blocks.back().get();
  }

  /// Print a liveness set, if non-empty. Takes the set by const reference;
  /// the previous by-value parameter copied the entire DenseSet per call.
  static void printSet(llvm::raw_ostream &os, std::string_view name,
                       const llvm::DenseSet<uint32_t> &set) {
    if (!set.empty()) {
      os << name;
      for (uint32_t reg : set)
        os << " " << reg;
      os << '\n';
    }
  }
  void print(llvm::raw_ostream &os) const {
    os << name << ":\n";
    for (const auto &block : blocks) {
      printSet(os, "%livein", block->liveIn);
      for (const auto &inst : block->insts)
        inst.print(os);
      printSet(os, "%liveout", block->liveOut);
    }
  }
};

/// Instruction selection: lowers the LLVM-IR of one function into MFunc/MInst
/// form with virtual registers. Code is emitted in reverse program order (see
/// run() and iselInst()).
class ISel {
  llvm::Function *fn;
  MFunc *mfn;  ///< Function currently being built.
  MBlock *mbb; ///< Block instructions are currently appended to.

  llvm::DenseMap<llvm::BasicBlock *, MBlock *> blockMap;
  /// Virtual register assigned to each used LLVM value.
  llvm::DenseMap<llvm::Value *, MReg> valueMap;
  /// Constants referenced by the last-selected instruction; materialized
  /// right after it in run() and then cleared.
  llvm::SmallVector<std::pair<llvm::Value *, MReg>> consts;
  /// Per-phi list of (pred block number, incoming vreg) operand pairs,
  /// collected while selecting the predecessors' branches.
  llvm::DenseMap<llvm::PHINode *, llvm::SmallVector<MOp>> phiMap;

  MReg fp; ///< Lazily allocated frame-pointer vreg (see getFP).
  MReg sp; ///< Lazily allocated stack-pointer vreg (see getSP).
  llvm::DenseMap<llvm::AllocaInst *, int> allocaMap; ///< Alloca -> frame off.

public:
  ISel(llvm::Function *fn) : fn(fn) {}

  /// Map an integer comparison predicate to the condition-code value used by
  /// the JCC/SETCC8r opcodes.
  static uint64_t getCond(llvm::CmpInst::Predicate pred) {
    switch (pred) {
    case llvm::CmpInst::ICMP_EQ: return 0x4;
    case llvm::CmpInst::ICMP_NE: return 0x5;
    case llvm::CmpInst::ICMP_UGE: return 0x3;
    case llvm::CmpInst::ICMP_ULT: return 0x2;
    case llvm::CmpInst::ICMP_UGT: return 0x7;
    case llvm::CmpInst::ICMP_ULE: return 0x6;
    case llvm::CmpInst::ICMP_SGE: return 0xd;
    case llvm::CmpInst::ICMP_SLT: return 0xc;
    case llvm::CmpInst::ICMP_SGT: return 0xf;
    case llvm::CmpInst::ICMP_SLE: return 0xe;
    default:
      assert(false && "unsupported cmp predicate");
      // Previously control fell off the end of this non-void function when
      // asserts are disabled (NDEBUG) -- undefined behavior.
      return 0;
    }
  }

  /// Whether \p c is representable as a sign-extended 32-bit immediate.
  static bool isImm32(uint64_t c) { return int64_t(c) == int32_t(c); }

  MReg getFP() { return fp.isValid() ? fp : (fp = mfn->allocReg()); }
  MReg getSP() { return sp.isValid() ? sp : (sp = mfn->allocReg()); }

  /// Get (or assign) the virtual register holding \p v.
  MReg intoReg(llvm::Value *v) {
    // Treat allocas as constant. This will cause the address to be recomputed
    // at every use that isn't handled in intoAddr.
    if (llvm::isa<llvm::Constant, llvm::AllocaInst>(v)) {
      MReg res = mfn->allocReg();
      consts.emplace_back(v, res);
      return res;
    }
    assert((llvm::isa<llvm::Argument, llvm::Instruction>(v)));
    // Look through no-op trunc, ptrtoint, and inttoptr.
    if (llvm::isa<llvm::TruncInst, llvm::PtrToIntInst, llvm::IntToPtrInst>(v))
      return intoReg(llvm::cast<llvm::Instruction>(v)->getOperand(0));
    auto [it, inserted] = valueMap.try_emplace(v, MReg{});
    if (inserted)
      it->second = mfn->allocReg();
    return it->second;
  }

  /// Append the 4-element memory operand (base, scale, idx, off) addressing
  /// \p v to \p ops.
  void intoAddr(llvm::Value *v, llvm::SmallVectorImpl<MOp> &ops) {
    if (auto *alloca = llvm::dyn_cast<llvm::AllocaInst>(v)) {
      // Assign a fresh 8-byte frame slot on first use; addressed via FP.
      auto [it, inserted] = allocaMap.try_emplace(alloca, 0);
      if (inserted)
        it->second = (mfn->frameSize -= 8);
      ops.append({getFP(), uint64_t{1}, MReg{}, uint64_t(it->second)});
    } else if (auto *gep = llvm::dyn_cast<llvm::GetElementPtrInst>(v)) {
      const llvm::DataLayout &DL = fn->getParent()->getDataLayout();
      uint64_t tysz = DL.getTypeAllocSize(gep->getSourceElementType());
      assert((tysz & (tysz - 1)) == 0 && tysz < 16 && "invalid gep type");
      llvm::Value *off = gep->getOperand(1);
      uint64_t disp = 0;
      // Fold a constant index into the displacement when it fits in imm32.
      if (auto *off_const = llvm::dyn_cast<llvm::ConstantInt>(off)) {
        if (auto off_val = off_const->getSExtValue() * tysz; isImm32(off_val)) {
          disp = off_val;
          off = nullptr;
        }
      }
      ops.append({intoReg(gep->getPointerOperand()), tysz,
                  off ? intoReg(off) : MReg{}, uint64_t(disp)});
    } else {
      // Plain pointer value: base register only.
      ops.append({intoReg(v), uint64_t{1}, MReg{}, uint64_t{0}});
    }
  }

  /// Append an instruction to the current block.
  void addInst(MOpc opcode, MReg dst, llvm::ArrayRef<MOp> ops) {
    mbb->insts.emplace_back(MInst{
        .opcode = opcode,
        .def = dst,
        .ops = {ops.begin(), ops.end()},
    });
  }

  void iselInst(llvm::Instruction &inst);

  /// Emit the CMP/TEST feeding a SETcc/Jcc. Because code is emitted in
  /// reverse, this is called after the flag consumer has been added.
  void iselICmp(llvm::ICmpInst &icmp) {
    llvm::Use *ops = icmp.op_begin();
    MOp lhs = intoReg(ops[0]);
    if (auto *ci = llvm::dyn_cast<llvm::ConstantInt>(ops[1])) {
      // Comparison against zero is a TEST of the register with itself.
      if (int64_t val = ci->getSExtValue(); val == 0)
        addInst(MOpc::TEST64rr, {}, {lhs, lhs});
      else if (int32_t(val) == val)
        addInst(MOpc::CMP64ri, {}, {lhs, uint64_t(val)});
      else
        addInst(MOpc::CMP64rr, {}, {lhs, intoReg(ops[1])});
    } else {
      addInst(MOpc::CMP64rr, {}, {lhs, intoReg(ops[1])});
    }
  }

  MFunc run();
};

/// Select machine instructions for \p inst into the current block (mbb).
///
/// Note: run() visits blocks bottom-up and instructions in reverse order, so
/// the addInst calls below emit code in *reverse* program order (the final
/// result first); run() flips each block's instruction list at the end.
void ISel::iselInst(llvm::Instruction &inst) {
  using namespace llvm::PatternMatch;

  MReg dst = valueMap.lookup(&inst);
  if (!dst.isValid()) {
    // No later use registered a register for this instruction: if it has no
    // side effects it is dead (originally dead, or fused into a consumer).
    if (llvm::wouldInstructionBeTriviallyDead(&inst, nullptr))
      return;
    if (!inst.getType()->isVoidTy())
      dst = mfn->allocReg();
  }
  auto ops = inst.op_begin();
  uint64_t c;
  llvm::Value *v;
  switch (inst.getOpcode()) {
  case llvm::Instruction::PHI:
    // Phi moves are recorded in phiMap by the Br case of each predecessor and
    // materialized as PseudoPhi in run(); here we only check the invariant
    // that critical edges were split.
    for (llvm::BasicBlock *pred : llvm::predecessors(inst.getParent())) {
      auto *br = llvm::cast<llvm::BranchInst>(pred->getTerminator());
      assert(br->isUnconditional() && "PHI with critical edge");
    }
    break;
  case llvm::Instruction::Alloca:
  case llvm::Instruction::PtrToInt:
  case llvm::Instruction::IntToPtr:
  case llvm::Instruction::Trunc:
    // These are always folded away by intoReg/intoAddr.
    assert(false && "alloca/ptrtoint/inttoptr/trunc used in MIR??");
    break;
  case llvm::Instruction::Ret:
    if (auto *retval = llvm::cast<llvm::ReturnInst>(inst).getReturnValue();
        retval && !llvm::isa<llvm::UndefValue>(retval))
      addInst(MOpc::RET, {}, {intoReg(retval)});
    else
      addInst(MOpc::RET, {}, {});
    break;
  case llvm::Instruction::Add:
    if (match(&inst, m_c_Add(m_Value(v), m_ConstantInt(c))) && isImm32(c))
      addInst(MOpc::ADD64ri, dst, {intoReg(v), c});
    else
      addInst(MOpc::ADD64rr, dst, {intoReg(ops[0]), intoReg(ops[1])});
    break;
  case llvm::Instruction::Sub:
    // 0 - x is NEG.
    if (match(ops[0].get(), m_Zero()))
      addInst(MOpc::NEG64r, dst, {intoReg(ops[1])});
    else if (match(ops[1].get(), m_ConstantInt(c)) && isImm32(c))
      addInst(MOpc::SUB64ri, dst, {intoReg(ops[0]), c});
    else
      addInst(MOpc::SUB64rr, dst, {intoReg(ops[0]), intoReg(ops[1])});
    break;
  case llvm::Instruction::Mul:
    if (match(ops[1].get(), m_ConstantInt(c)) && isImm32(c))
      addInst(MOpc::IMUL64rri, dst, {intoReg(ops[0]), c});
    else
      addInst(MOpc::IMUL64rr, dst, {intoReg(ops[0]), intoReg(ops[1])});
    break;
  case llvm::Instruction::Xor:
    // x ^ -1 is NOT.
    if (match(&inst, m_c_Xor(m_Value(v), m_AllOnes())))
      addInst(MOpc::NOT64r, dst, {intoReg(v)});
    else if (match(&inst, m_c_Xor(m_Value(v), m_ConstantInt(c))) && isImm32(c))
      addInst(MOpc::XOR64ri, dst, {intoReg(v), c});
    else
      addInst(MOpc::XOR64rr, dst, {intoReg(ops[0]), intoReg(ops[1])});
    break;
  case llvm::Instruction::And:
    if (match(&inst, m_c_And(m_Value(v), m_ConstantInt(c))) && isImm32(c))
      addInst(MOpc::AND64ri, dst, {intoReg(v), c});
    else
      addInst(MOpc::AND64rr, dst, {intoReg(ops[0]), intoReg(ops[1])});
    break;
  case llvm::Instruction::Or:
    if (match(&inst, m_c_Or(m_Value(v), m_ConstantInt(c))) && isImm32(c))
      addInst(MOpc::OR64ri, dst, {intoReg(v), c});
    else
      addInst(MOpc::OR64rr, dst, {intoReg(ops[0]), intoReg(ops[1])});
    break;
  case llvm::Instruction::Shl:
    if (match(ops[1].get(), m_ConstantInt(c)))
      addInst(MOpc::SHL64ri, dst, {intoReg(ops[0]), c});
    else
      addInst(MOpc::SHL64rCL, dst, {intoReg(ops[0]), intoReg(ops[1])});
    break;
  case llvm::Instruction::AShr:
    if (match(ops[1].get(), m_ConstantInt(c)))
      addInst(MOpc::SAR64ri, dst, {intoReg(ops[0]), c});
    else
      addInst(MOpc::SAR64rCL, dst, {intoReg(ops[0]), intoReg(ops[1])});
    break;
  case llvm::Instruction::ZExt:
    if (match(ops[0].get(), m_ICmp(m_Value(), m_Value()))) {
      // icmp is setcc, which always sets the full byte.
      addInst(MOpc::MOVZXB32rr, dst, {intoReg(ops[0])});
      break;
    }
    switch (ops[0]->getType()->getIntegerBitWidth()) {
    case 1: addInst(MOpc::AND64ri, dst, {intoReg(ops[0]), uint64_t{1}}); break;
    case 8: addInst(MOpc::MOVZXB32rr, dst, {intoReg(ops[0])}); break;
    case 16: addInst(MOpc::MOVZXW32rr, dst, {intoReg(ops[0])}); break;
    case 32: addInst(MOpc::MOV32rr, dst, {intoReg(ops[0])}); break;
    default: assert(0 && "unhandled ZExt");
    }
    break;
  case llvm::Instruction::SExt:
    // Only fuse adjacent loads.
    if (auto *load = llvm::dyn_cast<llvm::LoadInst>(ops[0]);
        load && load->getNextNode() == &inst) {
      llvm::SmallVector<MOp, 4> mops;
      intoAddr(load->getOperand(0), mops);
      switch (load->getType()->getIntegerBitWidth()) {
      case 32: addInst(MOpc::MOVSXD64rm, dst, mops); break;
      case 16: addInst(MOpc::MOVSXW64rm, dst, mops); break;
      case 8: addInst(MOpc::MOVSXB64rm, dst, mops); break;
      default: assert(false);
      }
      break;
    }
    switch (ops[0]->getType()->getIntegerBitWidth()) {
    case 8: addInst(MOpc::MOVSXB64rr, dst, {intoReg(ops[0])}); break;
    case 16: addInst(MOpc::MOVSXW64rr, dst, {intoReg(ops[0])}); break;
    case 32: addInst(MOpc::MOVSXD64rr, dst, {intoReg(ops[0])}); break;
    default: assert(0 && "unhandled SExt");
    }
    break;
  case llvm::Instruction::ICmp: {
    // Reverse emission: SETcc first, then the flag-setting CMP/TEST.
    auto &cmp = llvm::cast<llvm::ICmpInst>(inst);
    addInst(MOpc::SETCC8r, dst, {getCond(cmp.getPredicate())});
    iselICmp(cmp);
    break;
  }
  case llvm::Instruction::Call: {
    auto *call = llvm::cast<llvm::CallInst>(&inst);
    llvm::SmallVector<MOp> mops;
    llvm::SmallVector<std::pair<unsigned, llvm::Value *>> stackArgs;
    mops.push_back(call->getCalledFunction());
    // The first six arguments become CALL register operands; the rest are
    // stored to the stack below.
    for (const auto &[i, arg] : llvm::enumerate(call->args())) {
      if (i >= 6)
        stackArgs.emplace_back(8 * (i - 6), arg);
      else
        mops.emplace_back(intoReg(arg));
    }
    addInst(MOpc::CALL, dst, mops);
    // Reverse emission: these SP-relative stores precede the CALL in final
    // program order.
    for (auto [off, v] : llvm::reverse(stackArgs)) {
      if (match(v, m_ConstantInt(c)) && isImm32(c))
        addInst(MOpc::MOV64mi, {}, {getSP(), 1, MReg{}, off, c});
      else
        addInst(MOpc::MOV64mr, {}, {getSP(), 1, MReg{}, off, intoReg(v)});
    }
    break;
  }
  case llvm::Instruction::Load: {
    llvm::SmallVector<MOp, 4> mops;
    intoAddr(ops[0], mops);
    switch (inst.getType()->getIntegerBitWidth()) {
    case 64: addInst(MOpc::MOV64rm, dst, mops); break;
    case 32: addInst(MOpc::MOV32rm, dst, mops); break;
    case 16: addInst(MOpc::MOVZXW32rm, dst, mops); break;
    case 8: addInst(MOpc::MOVZXB32rm, dst, mops); break;
    default: assert(false);
    }
    break;
  }
  case llvm::Instruction::Store: {
    llvm::SmallVector<MOp, 5> mops;
    intoAddr(ops[1], mops);
    // Store an immediate directly when it fits in a sign-extended imm32.
    if (const auto *ci = llvm::dyn_cast<llvm::ConstantInt>(&*ops[0]);
        ci && int32_t(ci->getSExtValue()) == ci->getSExtValue()) {
      mops.push_back(ci->getSExtValue());
      switch (ops[0]->getType()->getIntegerBitWidth()) {
      case 64: addInst(MOpc::MOV64mi, {}, mops); break;
      case 32: addInst(MOpc::MOV32mi, {}, mops); break;
      case 16: addInst(MOpc::MOV16mi, {}, mops); break;
      case 8: addInst(MOpc::MOV8mi, {}, mops); break;
      default: assert(false);
      }
    } else {
      mops.push_back(intoReg(ops[0]));
      switch (ops[0]->getType()->getIntegerBitWidth()) {
      case 64: addInst(MOpc::MOV64mr, {}, mops); break;
      case 32: addInst(MOpc::MOV32mr, {}, mops); break;
      case 16: addInst(MOpc::MOV16mr, {}, mops); break;
      case 8: addInst(MOpc::MOV8mr, {}, mops); break;
      default: assert(false);
      }
    }
    break;
  }
  case llvm::Instruction::GetElementPtr: {
    // A GEP that was not fused into a memory operand becomes an LEA.
    llvm::SmallVector<MOp, 4> mops;
    intoAddr(&inst, mops);
    addInst(MOpc::LEA64rm, dst, mops);
    break;
  }
  case llvm::Instruction::Br: {
    auto *br = llvm::dyn_cast<llvm::BranchInst>(&inst);
    if (!br->isConditional()) {
      llvm::BasicBlock *succ = br->getSuccessor(0);
      addInst(MOpc::JMP, {}, {uint64_t{blockMap.lookup(succ)->number}});
      // Record the incoming values of the successor's phis for this edge;
      // run() turns these into PseudoPhi operand pairs later.
      llvm::BasicBlock *bb = inst.getParent();
      for (auto &phi : succ->phis()) {
        auto &phiOps = phiMap[&phi];
        phiOps.push_back(uint64_t{mbb->number});
        phiOps.push_back(intoReg(phi.getIncomingValueForBlock(bb)));
      }
    } else {
      llvm::Value *cond = br->getCondition();
      uint64_t thenBB = blockMap.lookup(br->getSuccessor(0))->number;
      uint64_t elseBB = blockMap.lookup(br->getSuccessor(1))->number;
      if (auto *cmpInst = llvm::dyn_cast<llvm::ICmpInst>(cond)) {
        // Fuse the compare into the branch (reverse order: JCC, then cmp).
        addInst(MOpc::JCC, {},
                {getCond(cmpInst->getPredicate()), thenBB, elseBB});
        iselICmp(*cmpInst);
      } else {
        // Generic boolean: branch if the low bit is set (cond 5 is the NE
        // condition, see getCond's ICMP_NE mapping).
        addInst(MOpc::JCC, {}, {5, thenBB, elseBB});
        addInst(MOpc::TEST64ri, {}, {intoReg(br->getCondition()), 1});
      }
    }
    break;
  }
  default: llvm::dbgs() << "unhandled: " << inst << "\n"; assert(0);
  }
}

/// Lower the LLVM function into an MFunc.
MFunc ISel::run() {
  // We need to do this later for regalloc anyway, so we can also do it now on
  // LLVM-IR and save the effort of implementing this on the low-level IR.
  llvm::SplitAllCriticalEdges(*fn);

  MFunc res(fn->getName());
  mfn = &res;

  // Initialize basic blocks.
  for (auto &bb : *fn)
    blockMap[&bb] = res.addBlock();

  // Traverse blocks in post order (bottom up) and instructions inside blocks in
  // reverse order so that uses are seen before the defs. Phis are visited with
  // successors. All used values are inserted into valueMap with their assigned
  // virtual register. If an instruction has no side effects and has no entry
  // in valueMap, we know that it is dead (either because it was dead before or
  // because it was fused into other instructions).
  //
  // We therefore generate code *in reverse order* here. Instructions are
  // reversed at the end below.
  for (auto &bb : llvm::post_order(fn)) {
    mbb = blockMap[bb];
    for (auto &inst : llvm::reverse(*bb)) {
      iselInst(inst);

      // For now, place constants immediately after an instruction.
      // (Emission is reversed, so appending here puts the materialization
      // right before the use in final program order.)
      for (auto &[v, reg] : consts) {
        if (auto *ci = llvm::dyn_cast<llvm::ConstantInt>(v)) {
          // Use the xor-zero idiom for zero constants.
          if (uint64_t c = ci->getZExtValue(); c == 0)
            addInst(MOpc::XOR64zero, reg, {});
          else
            addInst(MOpc::MOV64ri, reg, {ci->getZExtValue()});
        } else if (llvm::isa<llvm::AllocaInst>(v)) {
          // Alloca "constant": recompute the frame address with LEA.
          llvm::SmallVector<MOp, 4> ops;
          intoAddr(v, ops);
          addInst(MOpc::LEA64rm, reg, ops);
        } else {
          llvm::dbgs() << "unhandled const: " << *v << "\n";
          assert(0);
        }
      }
      consts.clear();
    }
  }

  // Add (used) Phi instructions once all basic blocks have been visited.
  // After the final reversal below, these end up at the top of their block.
  for (auto &bb : *fn) {
    mbb = blockMap[&bb];
    for (auto &phi : bb.phis()) {
      MReg dst = valueMap.lookup(&phi);
      if (!dst.isValid())
        continue;
      addInst(MOpc::PseudoPhi, dst, std::move(phiMap[&phi]));
    }
  }

  // Finish entry block: add instructions to get arguments, SP, and FP.
  mbb = res.blocks[0].get();
  for (auto &arg : llvm::reverse(fn->args())) {
    MReg dst = valueMap.lookup(&arg);
    if (dst.isValid()) {
      if (arg.getArgNo() >= 6) {
        // Stack-passed argument: loaded FP-relative from the caller's frame.
        addInst(MOpc::MOV64rm, dst,
                {getFP(), 1, MReg{}, 16 + (arg.getArgNo() - 6) * 8});
      } else {
        addInst(MOpc::PseudoRegArg, dst, {arg.getArgNo()});
      }
    }
  }
  (void)getFP(); // We always need a FP for later regalloc.
  if (fp.isValid())
    addInst(MOpc::PseudoFP, fp, {});
  if (sp.isValid())
    addInst(MOpc::PseudoSP, sp, {});

  // Reverse instructions into correct order.
  for (auto &mbb : res.blocks)
    std::reverse(mbb->insts.begin(), mbb->insts.end());
  return res;
}

/// Iterative backward liveness analysis. Computes per-block liveIn/liveOut
/// sets of virtual-register ids and sets the kill flag on each operand that
/// is the last use of its register (and on defs whose result is unused).
void computeLiveness(MFunc &mfn) {
  bool changed = true;
  while (changed) {
    changed = false;
    // DFS from the entry block with an explicit stack. Entries are
    // (block, visit): pushed once with visit=false (push successors) and
    // once with visit=true (process), so blocks are processed in post order.
    llvm::BitVector visitedBlocks(mfn.blocks.size());
    llvm::SmallVector<std::pair<uint32_t, bool>> stack;
    stack.emplace_back(0, false);
    while (!stack.empty()) {
      auto [blockNum, visit] = stack.pop_back_val();
      if (!visit) {
        if (visitedBlocks[blockNum])
          continue;
        visitedBlocks[blockNum] = true;
        stack.emplace_back(blockNum, true);
        for (uint32_t succ : mfn.blocks[blockNum]->successors())
          stack.emplace_back(succ, false);
        continue;
      }

      // liveOut = union of successors' liveIn, with each successor phi's def
      // replaced by the value incoming from this block.
      MBlock &block = *mfn.blocks[blockNum];
      block.liveOut.clear();
      for (uint32_t succNum : block.successors()) {
        MBlock &succ = *mfn.blocks[succNum];
        block.liveOut.insert(succ.liveIn.begin(), succ.liveIn.end());
        for (MInst &phi : succ.insts) {
          // Phis sit at the top of the block (see ISel::run), so stop at the
          // first non-phi.
          if (phi.opcode != MOpc::PseudoPhi)
            break;
          block.liveOut.erase(phi.def.getReg().id);
          for (unsigned i = 0; i < phi.ops.size(); i += 2) {
            if (phi.ops[i].getVal() == blockNum) {
              block.liveOut.insert(phi.ops[i + 1].getReg().id);
              break;
            }
          }
        }
      }
      // Walk instructions backwards: a use makes its register live (the
      // first backward sighting is the last use, i.e. the kill); a def makes
      // it dead again. What remains at the top is liveIn.
      size_t oldLiveInSize = block.liveIn.size();
      block.liveIn = block.liveOut;
      for (MInst &inst : llvm::reverse(block.insts)) {
        if (inst.opcode == MOpc::PseudoPhi) {
          // A phi def is dead (killed immediately) if nothing below used it.
          inst.def.setKill(!block.liveIn.count(inst.def.getReg().id));
          continue;
        }
        for (MOp &op : llvm::reverse(inst.ops)) {
          if (op.isValidReg()) {
            auto [_, inserted] = block.liveIn.insert(op.getReg().id);
            op.setKill(inserted);
          }
        }
        if (inst.def.isValidReg())
          inst.def.setKill(!block.liveIn.erase(inst.def.getReg().id));
      }
      // Liveness sets grow monotonically across iterations, so a size
      // comparison suffices as a change test for the fixpoint.
      if (block.liveIn.size() != oldLiveInSize)
        changed = true;
    }
  }
}

/// Very simple local register allocator for x86-64-style machine code.
///
/// Blocks are visited in reverse post-order and registers are assigned
/// greedily within each block. Across block boundaries nothing stays in a
/// physical register: every live value is evicted to an RBP-relative stack
/// slot at each terminator. Phi nodes are lowered on the predecessor side
/// into stores to the phi's stack slot, using a reader count to order the
/// moves and registers 14/15 as scratch to break move cycles. The
/// PseudoPhi instructions themselves are stripped in a final pass.
void regAlloc(MFunc &mfn) {
  struct VarState {
    uint64_t stackSlot; ///< RBP-relative, 0 implies none.
    MReg reg;           ///< Current physical register.
  };
  llvm::DenseMap<uint32_t, VarState> vars;
  struct RegState {
    MReg value[16]; ///< Map from phys reg to virtual reg.
  };

  // Compute post-order for RPO traversal.
  llvm::SmallVector<uint32_t> pot;
  {
    pot.reserve(mfn.blocks.size());
    llvm::BitVector visitedBlocks(mfn.blocks.size());
    // Explicit DFS stack; the bool marks the second (post-order) visit.
    llvm::SmallVector<std::pair<uint32_t, bool>> stack;
    stack.emplace_back(0, false);
    while (!stack.empty()) {
      auto [blockNum, visit] = stack.pop_back_val();
      if (!visit) {
        if (visitedBlocks[blockNum])
          continue;
        visitedBlocks[blockNum] = true;
        stack.emplace_back(blockNum, true);
        for (uint32_t succ : mfn.blocks[blockNum]->successors())
          stack.emplace_back(succ, false);
        continue;
      }
      pot.push_back(blockNum);
    }
  }

  for (uint32_t blockNum : llvm::reverse(pot)) {
    llvm::SmallVector<MInst> newInsts;
    RegState regState; // fresh per block: registers hold nothing on entry
    unsigned nextReg = 0;

    // Spill whatever virtual register currently occupies `physReg` to its
    // stack slot (allocating one by growing the frame downward if it has
    // none yet) and mark the physical register free.
    auto evictPhysReg = [&](MReg physReg) {
      MReg regVal = regState.value[physReg.id];
      if (!regVal.isValid())
        return; // Register unused.
      if (!vars[regVal.id].stackSlot) {
        uint64_t stackSlot = vars[regVal.id].stackSlot = (mfn.frameSize -= 8);
        // MOV [RBP + stackSlot], physReg  (base=5/RBP, scale=1, no index).
        newInsts.emplace_back(MInst{
            .opcode = MOpc::MOV64mr,
            .def = MReg(),
            .ops = {MOp{5}, uint64_t{1}, MReg{}, stackSlot, physReg},
        });
      }
      vars[regVal.id].reg = MReg{};
      regState.value[physReg.id] = MReg{};
    };

    for (MInst &inst : mfn.blocks[blockNum]->insts) {
      if (inst.opcode == MOpc::PseudoPhi) {
        newInsts.push_back(inst); // keep for now, discard later
        continue;
      }
      uint16_t free = 0xffcf; // Usable registers; bits 4/5 (RSP/RBP) excluded.
      const MOpcInfo &info = OpInfos[unsigned(inst.opcode)];
      // Assign a physical register to every register operand, walking the
      // operands in reverse. info.args packs one constraint byte per
      // operand: low nibble 2 means "fixed register" with the register id
      // in the high nibble.
      for (unsigned i = inst.ops.size(); i-- > 0;) {
        MOp &op = inst.ops[i];
        unsigned argInfo = info.args >> (8 * i) & 0xff;
        if (!op.isValidReg())
          continue;
        MReg dstReg = (argInfo & 0xf) == 2 ? MReg{argInfo >> 4} : MReg{};
        VarState &varInfo = vars[op.getReg().id];
        MReg curPhysReg = varInfo.reg;
        // Unconstrained operands prefer their current register, otherwise
        // take the lowest free one.
        if (!dstReg.isValid())
          dstReg =
              curPhysReg.isValid() ? curPhysReg : MReg(llvm::countr_zero(free));
        // If dstReg is occupied, spill.
        if (curPhysReg.id != dstReg.id)
          evictPhysReg(dstReg);

        if (!curPhysReg.isValid()) {
          // Value currently lives only in memory: reload from its slot.
          newInsts.emplace_back(MInst{
              .opcode = MOpc::MOV64rm,
              .def = dstReg,
              .ops = {MOp{5}, uint64_t{1}, MReg{}, varInfo.stackSlot},
          });
        } else if (dstReg.id != curPhysReg.id) {
          // In a register, but not the required one: copy it over.
          newInsts.emplace_back(MInst{
              .opcode = MOpc::MOV64rr,
              .def = dstReg,
              .ops = {curPhysReg},
          });
        }
        varInfo.reg = dstReg;
        regState.value[dstReg.id] = op.getReg();

        // Reserve the register so later operands pick a different one.
        free &= ~uint16_t(1 << dstReg.id);
        op = dstReg;
      }
      // Evict clobbers
      // (iterates the set bits of info.clobbers; i tracks the absolute bit
      // position across the shifts).
      for (unsigned i = 0, msk = info.clobbers; msk;) {
        i += llvm::countr_zero(msk) + 1;
        evictPhysReg(MReg(i - 1));
        msk >>= llvm::countr_zero(msk) + 1;
      }
      MReg defReg;
      if (inst.def.isValidReg()) {
        // info.ret encodes the result constraint: low nibble 2 = fixed
        // register (id in the high nibble), 3 = tied to the first operand.
        if ((info.ret & 0xf) == 2)
          defReg = MReg(info.ret >> 4);
        else if (info.ret == 3)
          defReg = inst.ops[0].getReg();
        else if (inst.opcode == MOpc::PseudoRegArg) {
          // SysV x86-64 argument registers: RDI, RSI, RDX, RCX, R8, R9.
          static constexpr uint8_t argRegs[] = {7, 6, 2, 1, 8, 9};
          defReg = MReg{argRegs[inst.ops[0].getVal()]};
        } else {
          // Unconstrained result: round-robin over the low registers.
          defReg = MReg{nextReg}; // really stupid...
          nextReg += 1;
          if (nextReg == 4)
            nextReg += 2; // skip RSP/RBP.
        }
        evictPhysReg(defReg);
        regState.value[defReg.id] = inst.def.getReg();
        vars[inst.def.getReg().id].reg = defReg;
        inst.def = defReg;
      } else if (info.ret == 4) {
        // Terminator.
        // Across basic blocks, all values live in memory.
        for (auto [idx, reg] : llvm::enumerate(regState.value))
          evictPhysReg(MReg(idx));
        // Move to phis. We only store phis in memory.
        for (uint32_t succNum : mfn.blocks[blockNum]->successors()) {
          struct Phi {
            MReg phi; ///< The phi's own register (move destination).
            MReg val; ///< Value flowing in on this edge (move source).
            unsigned readers = 0; ///< # phis in this batch reading our old value.
          };
          llvm::SmallVector<Phi> phis;
          // Maps a phi's def id to its index in `phis`.
          llvm::DenseMap<uint32_t, size_t> phiIdxMap;
          for (MInst &phi : mfn.blocks[succNum]->insts) {
            if (phi.opcode != MOpc::PseudoPhi)
              break; // phis are grouped at the block start
            // Find the (pred block, value) operand pair for this edge.
            for (unsigned i = 0; i < phi.ops.size(); i += 2) {
              if (phi.ops[i].getVal() == blockNum) {
                phiIdxMap[phi.def.getReg().id] = phis.size();
                phis.push_back(Phi{phi.def.getReg(), phi.ops[i + 1].getReg()});
                break;
              }
            }
          }
          for (Phi &phi : phis) {
            // Ensure that every phi has a stack slot.
            if (!vars[phi.phi.id].stackSlot)
              vars[phi.phi.id].stackSlot = (mfn.frameSize -= 8);
            // If this phi reads another phi's old value, bump that source's
            // reader count: its slot must not be overwritten until read.
            auto it = phiIdxMap.find(phi.val.id);
            if (it != phiIdxMap.end())
              phis[it->second].readers += 1;
          }
          // Topological move scheduling: a phi's slot may be written only
          // once nothing still reads its old value (readers == 0).
          llvm::SmallVector<Phi *> queue;
          for (Phi &phi : phis)
            if (!phi.readers) {
              phiIdxMap.erase(phi.phi.id);
              queue.push_back(&phi);
            }
          unsigned assigned = 0;
          // Virtual register whose old value is parked in register 15 while
          // a move cycle is being broken; invalid when no cycle is open.
          MReg curTempVal;
          while (assigned != phis.size()) {
            if (!queue.empty()) {
              Phi *phi = queue.pop_back_val();
              if (phi->val.id == curTempVal.id) {
                // Move from tmp
                // (register 15 holds the cycle-breaking copy of the value).
                newInsts.emplace_back(MInst{
                    .opcode = MOpc::MOV64mr,
                    .def = MReg{},
                    .ops = {MOp{5}, uint64_t{1}, MReg{},
                            vars[phi->phi.id].stackSlot, MReg{15}},
                });
                curTempVal = MReg{};
                phiIdxMap.erase(phi->val.id);
                assigned += 1;
                continue;
              } else {
                MReg curPhysReg = vars[phi->val.id].reg;
                if (!curPhysReg.isValid()) {
                  // Memory-to-memory move needs a scratch register (14).
                  curPhysReg = MReg{14};
                  newInsts.emplace_back(MInst{
                      .opcode = MOpc::MOV64rm,
                      .def = curPhysReg,
                      .ops = {MOp{5}, uint64_t{1}, MReg{},
                              vars[phi->val.id].stackSlot},
                  });
                }
                // Store the incoming value into the phi's stack slot.
                newInsts.emplace_back(MInst{
                    .opcode = MOpc::MOV64mr,
                    .def = MReg{},
                    .ops = {MOp{5}, uint64_t{1}, MReg{},
                            vars[phi->phi.id].stackSlot, curPhysReg},
                });
              }
              assigned += 1;
              // Consuming this value may unblock the phi it came from.
              auto it = phiIdxMap.find(phi->val.id);
              if (it != phiIdxMap.end()) {
                if (!--phis[it->second].readers) {
                  queue.push_back(&phis[it->second]);
                  phiIdxMap.erase(it);
                }
              }
            } else {
              // Only cycles remain, pick an arbitrary phi.
              auto it = phiIdxMap.begin();
              Phi *phi = &phis[it->second];
              assert(phi->readers == 1);
              phi->readers = 0;
              curTempVal = phi->phi;

              // Park the phi's old value in register 15 so its slot can be
              // overwritten; the reader picks it up from there later.
              MReg curPhysReg = vars[phi->phi.id].reg;
              if (!curPhysReg.isValid()) {
                newInsts.emplace_back(MInst{
                    .opcode = MOpc::MOV64rm,
                    .def = MReg{15},
                    .ops = {MOp{5}, uint64_t{1}, MReg{},
                            vars[phi->phi.id].stackSlot},
                });
              } else {
                newInsts.emplace_back(MInst{
                    .opcode = MOpc::MOV64rr,
                    .def = MReg{15},
                    .ops = {curPhysReg},
                });
              }

              queue.push_back(phi);
            }
          }
        }
      }
      newInsts.push_back(inst);
    }
    mfn.blocks[blockNum]->insts = newInsts;
  }
  // Strip the leading PseudoPhi instructions: they have been fully lowered
  // into stores on the predecessor side.
  for (auto &block : mfn.blocks) {
    unsigned count = 0;
    for (MInst &phi : block->insts) {
      if (phi.opcode != MOpc::PseudoPhi)
        break;
      ++count;
    }
    if (count)
      block->insts.erase(block->insts.begin(), block->insts.begin() + count);
  }
}

} // end anonymous namespace

int main(int argc, char **argv) {
  enum class Mode {
    CHECK,
    AST,
    LLVM,
    ISEL,
    LIVENESS,
    REGALLOC,
  } mode = Mode::AST;
  bool printTimes = false;

  int c;
  while ((c = getopt(argc, argv, "acliLrSt")) != -1) {
    switch (c) {
    case 'c': mode = Mode::CHECK; break;
    case 'a': mode = Mode::AST; break;
    case 'l': mode = Mode::LLVM; break;
    case 'i': mode = Mode::ISEL; break;
    case 'L': mode = Mode::LIVENESS; break;
    case 'r': mode = Mode::REGALLOC; break;
    case 't': printTimes = true; break;
    default:
      fprintf(stderr, "usage: %s <bfprog>\n", argv[0]);
      return EXIT_FAILURE;
    }
  }

  if (optind >= argc)
    return 1;
  auto fc = readFile(argv[optind]);
  auto time_argparse_end = std::chrono::steady_clock::now();

  auto p = Parser{fc}.parseProgram();
  auto time_parse_end = std::chrono::steady_clock::now();
  if (printTimes)
    std::cerr << "parsing: "
              << std::chrono::duration_cast<std::chrono::milliseconds>(
                     time_parse_end - time_argparse_end)
                     .count()
              << "ms\n";

  if (!p)
    return 1;
  if (mode == Mode::CHECK)
    return 0;
  if (mode == Mode::AST) {
    for (const Function &fn : p->funcs) {
      if (fn.ast) {
        fn.ast->printSexpr();
        std::cout << "\n";
      }
    }
    return 0;
  }

  llvm::LLVMContext ctx;
  auto mod = LLVMIRGen::genIR(ctx, *p);
  auto time_irgen_end = std::chrono::steady_clock::now();
  if (printTimes)
    std::cerr << "irgen: "
              << std::chrono::duration_cast<std::chrono::milliseconds>(
                     time_irgen_end - time_parse_end)
                     .count()
              << "ms\n";

  if (mode == Mode::LLVM) {
    if (llvm::verifyModule(*mod, &llvm::errs()))
      return 1;
    mod->print(llvm::outs(), nullptr);
    return 0;
  }

  llvm::SmallVector<MFunc> mfns;
  for (llvm::Function &fn : *mod) {
    if (fn.empty())
      continue;
    mfns.push_back(ISel(&fn).run());
  }
  auto time_isel_end = std::chrono::steady_clock::now();
  if (printTimes)
    std::cerr << "isel: "
              << std::chrono::duration_cast<std::chrono::milliseconds>(
                     time_isel_end - time_irgen_end)
                     .count()
              << "ms\n";

  if (mode == Mode::ISEL) {
    for (MFunc &mfn : mfns)
      mfn.print(llvm::outs());
    return 0;
  }
  if (mode == Mode::LIVENESS) {
    for (MFunc &mfn : mfns)
      computeLiveness(mfn);
    for (MFunc &mfn : mfns)
      mfn.print(llvm::outs());
    return 0;
  }
  if (mode == Mode::REGALLOC) {
    for (MFunc &mfn : mfns)
      regAlloc(mfn);
    for (MFunc &mfn : mfns)
      mfn.print(llvm::outs());
    return 0;
  }
  return 0;
}
