;; copy and patch style native code generation
;;
;; compile some C code, use relocation records to slice and dice it
;; up.

(require '[clojure.java.io :as io]
         '[clojure.java.shell :as sh])

(import '(java.util Scanner
                    HashMap
                    List)
        '(java.io ByteArrayOutputStream)
        '(java.nio ByteBuffer
                   ByteOrder)
        '(java.lang.foreign Arena
                            Linker
                            FunctionDescriptor
                            ValueLayout
                            MemoryLayout
                            Linker$Option))

;; C source for the stencils, kept as a string so it can be written to
;; disk and compiled.
;;
;; - continuation0 / continuation1 are only declared, never defined, so
;;   every call to them is left as an unresolved reference in the object
;;   file.  continuation1 is declared but not referenced by any stencil
;;   below.
;; - evaluate_add adds b into a and tail-calls continuation0.
;; - do_escape3 stores a through the pointer passed in d, then
;;   tail-calls continuation0 with the last argument zeroed.
;; - entry is the only function without __regcall; it forwards its four
;;   arguments to continuation0.
(def c-code
  "
#include <stdint.h>

__regcall
void continuation0(int64_t a, int64_t b, int64_t c, int64_t d);

__regcall
void continuation1(int64_t a, int64_t b, int64_t c, int64_t d);

__regcall
void evaluate_add(int64_t a, int64_t b, int64_t c, int64_t d) {
  a = a + b;
  [[clang::musttail]] return continuation0(a,b,c,d);
}

__regcall
void  do_escape3(int64_t a, int64_t b, int64_t c, int64_t d) {
  int64_t * e = (int64_t *)d;
  *e = a;
  [[clang::musttail]] return continuation0(a,b,c,0);
}

void entry(int64_t a, int64_t b, int64_t c, int64_t d) {
  return continuation0(a,b,c,d);
}

")


;; Write the stencil source to disk, compile it to an object file and
;; capture objdump's full header dump for parsing.
(spit (io/file "/tmp/stencil.c") c-code)

;; Fail loudly if the compiler rejects the source; otherwise a stale or
;; missing /tmp/stencil.o would only show up later as a confusing
;; Scanner error.
(let [{:keys [exit err]} (sh/sh "clang" "-c" "/tmp/stencil.c" "-fPIC" "-o" "/tmp/stencil.o" "-O2")]
  (when-not (zero? exit)
    (throw (ex-info "clang failed to compile /tmp/stencil.c"
                    {:exit exit :err err}))))

;; Text of `objdump -x` for the compiled object file.
(def o
  (let [{:keys [exit out err]} (sh/sh "objdump" "-x" "/tmp/stencil.o")]
    (when-not (zero? exit)
      (throw (ex-info "objdump failed on /tmp/stencil.o"
                      {:exit exit :err err})))
    out))

;; Token scanner over the objdump text; the parsing forms below consume
;; it sequentially.  Reuses `o` rather than running objdump a second
;; time.
(def s (Scanner. o))

(defn scan-until
  "Consumes tokens from scanner `s` up to and including the first token
  equal to `txt`.  Returns nil.  If the scanner runs out of tokens first,
  the exception from `.next` propagates."
  [s txt]
  (loop []
    (when-not (= txt (.next s))
      (recur))))

;; Advance to the first ".text" token in the objdump output.
(scan-until s ".text")

;; Skip the three hex fields that follow it (presumably the Size, VMA
;; and LMA columns of objdump's section table -- confirm against the
;; objdump version in use).
(dotimes [_ 3]
  (.nextLong s 16))

;; The fourth hex field: used elsewhere in this file as the base added
;; to each symbol's location to get a byte position inside the object
;; file.
(def text-offset (.nextLong s 16))

;; Move on to the symbol table: consume through the "SYMBOL" token and
;; then the single token that follows it.
(scan-until s "SYMBOL")

(.next s)

;; Mutable map of symbol name -> [location size-or-nil], filled in by
;; the loop below.
(def symbol-definitions (HashMap.))

;; Read symbol rows for as long as the next token parses as a hex number.
;; Two row shapes are handled:
;;   <location> *XXX* <token> <symbol>            -> stored with nil size
;;   <location> <f1> <f2> <f3> <hex-size> <symbol>
;; In the second shape the size is kept only when the second flag token
;; is "F"; the "__regcall3__" text is stripped from the stored name.
(while (.hasNextLong s 16)
  (let [location (.nextLong s 16)
        txt (.next s)]
    (if (.startsWith txt "*")
      (let [_ (.next s)
            sym (.next s)]
        (.put symbol-definitions sym [location nil]))
      (let [flag2 (.next s)
            _ (.next s)
            maybe-size (.nextLong s 16)
            sym (.next s)
            size (when (= flag2 "F") maybe-size)]
        (.put symbol-definitions
              (.replaceAll sym "__regcall3__" "")
              [location size])))))

;; Skip lines until the header of the .text relocation table has been
;; consumed, then drop the single line that follows it.
(loop []
  (when (not= "RELOCATION RECORDS FOR [.text]:" (.nextLine s))
    (recur)))

(.nextLine s)

;; Mutable map of relocation location -> raw third-column token.
(def relocation-locations (HashMap.))

;; Each record is read as three tokens: a hex location, a token that is
;; ignored, and the token stored as the value.
(loop []
  (when (.hasNextLong s 16)
    (let [location (.nextLong s 16)
          _ignored (.next s)
          sym (.next s)]
      (.put relocation-locations location sym)
      (recur))))


;; Map of symbol name -> stencil description, built by pairing every
;; sized symbol with each relocation that falls strictly inside its
;; [location, location + size) range.  Each stencil entry holds:
;;   :location  position of the symbol with text-offset added
;;   :size      the symbol's size
;;   :holes     set of relocation target names found inside it
;;   <name>     for each hole, the relocation's offset from the symbol's
;;              location
;; Hole names are the relocation token with an optional "__regcall3__"
;; prefix and the trailing "-0x<hex>" part removed (nil when the token
;; does not match that pattern).
(def stencils
  (reduce
   (fn [acc {:keys [sym location size rloc rsym]}]
     (-> acc
         (assoc-in [sym :location] (+ text-offset location))
         (assoc-in [sym :size] size)
         (assoc-in [sym rsym] (- rloc location))
         (update-in [sym :holes] (fnil conj #{}) rsym)))
   {}
   (for [[sym [location size]] symbol-definitions
         :when size
         [rloc raw-rsym] relocation-locations
         :when (< location rloc (+ location size))
         :let [[_ _ rsym] (re-find #"(__regcall3__)?(.*)-0x\p{XDigit}+" raw-rsym)]]
     {:sym sym :location location :size size :rloc rloc :rsym rsym})))

;; The whole compiled object file, read into memory and wrapped in a
;; ByteBuffer so stencil bytes can be sliced out of it.
(def bin
  (let [sink (ByteArrayOutputStream.)]
    (with-open [source (io/input-stream (io/file "/tmp/stencil.o"))]
      (io/copy source sink))
    (ByteBuffer/wrap (.toByteArray sink))))


(defn patch
  "Copies the stencil named `stencil-name` out of `bin` and fills in its
  holes.

  `args` maps hole name -> value and must cover every hole of the
  stencil except \"continuation0\".  A Long value is written as 8 bytes,
  anything else as a 4-byte int, little-endian, at the hole's offset.

  The copy runs from the stencil's start up to 2 bytes before its
  \"continuation0\" relocation (presumably to drop the jump to the
  continuation so the next stencil can follow directly -- confirm
  against the generated machine code).

  Returns the patched ByteBuffer, flipped and ready to be read."
  [stencil-name args]
  (let [s (get stencils stencil-name)]
    (assert s stencil-name)
    (assert (every? args (remove #{"continuation0"} (:holes s)))
            (pr-str (:holes s)))
    (let [body-length (- (get s "continuation0") 2)
          template (-> (.duplicate bin)
                       (.position (:location s))
                       (.slice)
                       (.limit body-length)
                       (.slice))
          sb (doto (ByteBuffer/allocate body-length)
               (.order ByteOrder/LITTLE_ENDIAN)
               (.put template))]
      (doseq [[hole value] args
              :let [offset (get s hole)]]
        (if (instance? Long value)
          (.putLong sb offset value)
          (.putInt sb offset (unchecked-int value))))
      (.flip sb))))

(defn downcallhandle
  "Builds a downcall method handle for the native function called
  `symbol-name` (anything `name` accepts), resolved through `linker`'s
  default lookup.

  `return` is the return layout, or nil/false for a void function;
  `args` is a sequence of argument layouts.  Throws if the symbol cannot
  be found."
  [linker symbol-name return args]
  (let [target (.orElseThrow (.find (.defaultLookup linker) (name symbol-name)))
        arg-layouts (into-array MemoryLayout args)
        descriptor (if return
                     (FunctionDescriptor/of return arg-layouts)
                     (FunctionDescriptor/ofVoid arg-layouts))
        no-options (into-array Linker$Option [])]
    (.downcallHandle linker target descriptor no-options)))

;; Driver: lay the patched stencils out back to back in a page-aligned
;; native buffer, make that buffer executable, call into it, and print
;; the value the generated code stored.
(with-open [a (Arena/ofConfined)] ; arena to allocate in
  (let [linker (Linker/nativeLinker)
        ;; mprotect so we can execute code
        mprotect (downcallhandle
                  linker
                  :mprotect
                  ValueLayout/JAVA_INT
                  [ValueLayout/ADDRESS
                   ValueLayout/JAVA_LONG
                   ValueLayout/JAVA_INT])
        ;; sysconf to lookup the page size
        sysconf (downcallhandle
                 linker
                 :sysconf
                 ValueLayout/JAVA_LONG
                 [ValueLayout/JAVA_INT])
        ;; some magic constants, ffi doesn't help with C preprocessor
        ;; stuff
        _SC_PAGESIZE (unchecked-byte 30)
        PROT_EXEC 0x4
        PROT_READ 0x1
        PROT_WRITE 0x2
        page-size (.invokeWithArguments sysconf [_SC_PAGESIZE])
        ;; some memory for the return value to get stored in; it is read
        ;; back with JAVA_LONG below, so request 8-byte alignment rather
        ;; than relying on the allocator happening to provide it
        p0 (.allocate a 8 8)
        code-size 128
        ;; memory for the generated code, aligned to a page boundary so
        ;; mprotect can be applied to its address
        ms (.allocate a code-size page-size)
        ob (.asByteBuffer ms)
        ;; stencils are emitted in execution order: each one ends where
        ;; its continuation would have been, so the next one follows on
        sb (patch "entry" {})
        _ (.put ob sb)
        sb (patch "evaluate_add" {})
        _ (.put ob sb)
        sb (patch "do_escape3" {})
        _ (.put ob sb)
        _ (.put ob (unchecked-byte 0xc3)) ;; 0xc3 is asm ret
        _ (while (pos? (rem (.position ob) 32))
            (.put ob (unchecked-byte 0x90))) ;; 0x90 is asm nop
        _ (.flip ob)
        ;; handle that calls straight into the start of the code buffer
        mh (.downcallHandle
            linker
            ms
            (FunctionDescriptor/ofVoid
             (into-array MemoryLayout
                         [ValueLayout/JAVA_LONG
                          ValueLayout/JAVA_LONG
                          ValueLayout/JAVA_LONG
                          ValueLayout/ADDRESS]))
            (into-array Linker$Option []))
        rc (.invokeWithArguments
            mprotect
            [ms
             (long code-size)
             (unchecked-int (bit-or PROT_READ PROT_WRITE PROT_EXEC))])]
    ;; Stop here if the buffer could not be made executable: calling into
    ;; it anyway would crash the JVM.  errno is not read because looking
    ;; it up as a plain symbol is not a reliable way to obtain it; only
    ;; the return code is reported.
    (when-not (zero? rc)
      (throw (ex-info "mprotect failed; refusing to execute generated code"
                      {:return-code rc
                       :code-size code-size
                       :page-size page-size})))
    ;; the last argument is where do_escape3 stores its result
    (.invokeWithArguments mh [1 2 0 p0])
    (prn (.get p0 ValueLayout/JAVA_LONG 0))
    (System/exit 0)))


;; % clj -J--enable-native-access=ALL-UNNAMED -M src/copy-and-patch.clj
;; 3
;; % 

;; Generated At 2024-03-29T15:38:36-0700 original