(def u16
  "Byte-aligned codec for an unsigned 16-bit integer, written and read
  high byte first by DataOutputStream/DataInputStream.
  Encoding is only allowed when no bit field is pending (byte-pos 0) and
  decoding only when no partially consumed byte is held (byte-pos -1);
  any other position fails an assertion."
  {:decode (fn [dis current-byte byte-pos]
             (if (not= byte-pos -1)
               (assert nil)
               [0x0 -1 (long (.readUnsignedShort dis))]))
   :encode (fn [dos current-byte byte-pos value]
             (if (not= byte-pos 0)
               (assert nil)
               (do
                 ;; unchecked-short keeps the low 16 bits of values above 0x7fff
                 (.writeShort dos (unchecked-short value))
                 [0x0 0])))})

(def u32
  "Byte-aligned codec for an unsigned 32-bit integer (high byte first).
  Decoded values are returned as non-negative longs. Like the other
  byte-aligned codecs it asserts that encode runs at byte-pos 0 and decode
  at byte-pos -1."
  (let [write-u32 (fn [dos current-byte byte-pos value]
                    (if (not= byte-pos 0)
                      (assert nil)
                      (do
                        (.writeInt dos (unchecked-int value))
                        [0x0 0])))
        read-u32 (fn [dis current-byte byte-pos]
                   (if (not= byte-pos -1)
                     (assert nil)
                     ;; mask off the sign extension of the signed readInt
                     [0x0 -1 (bit-and 0xffffffff (long (.readInt dis)))]))]
    {:encode write-u32
     :decode read-u32}))

(def u8
  "Byte-aligned codec for an unsigned 8-bit integer; decoded values are
  0-255. Encode requires byte-pos 0 and decode byte-pos -1, otherwise an
  assertion fails."
  {:encode (fn [dos current-byte byte-pos value]
             (if-not (= byte-pos 0)
               (assert nil)
               (do
                 (.writeByte dos (unchecked-byte value))
                 [0x0 0])))
   :decode (fn [dis current-byte byte-pos]
             (if-not (= byte-pos -1)
               (assert nil)
               [0x0 -1 (long (.readUnsignedByte dis))]))})

(defn- bit-field
  "Builds a codec for an unsigned field `width` bits wide (1-8) that shares a
  byte with neighbouring bit fields, packed most-significant-bit first: the
  first field in a spec occupies the high-order bits. This is the order DNS
  header flags use on the wire (RFC 1035 section 4.1.1).

  Position protocol (shared with `encode`/`decode`):
  * encode: byte-pos is how many bits of current-byte are already filled;
    returns [new-current-byte new-byte-pos]. `encode` flushes the byte once
    byte-pos reaches 8. Asserts the field fits in the remaining bits.
  * decode: byte-pos -1 means no partially consumed byte is held, so a fresh
    byte is read; otherwise current-byte is the held byte and byte-pos the
    number of its bits already consumed. Returns
    [held-byte new-byte-pos value], with [0x0 -1 value] once the byte is used
    up.

  `to-wire` maps the caller's value to an integer (masked to `width` bits);
  `from-wire` maps the decoded integer back to the caller's representation."
  [width to-wire from-wire]
  (let [mask (dec (bit-shift-left 1 width))]
    {:encode (fn [dos current-byte byte-pos value]
               (let [end (+ byte-pos width)]
                 (assert (<= end 8))
                 [(bit-or current-byte
                          (bit-shift-left (bit-and (to-wire value) mask) (- 8 end)))
                  end]))
     :decode (fn [dis current-byte byte-pos]
               (let [[held start] (if (= -1 byte-pos)
                                    [(.readUnsignedByte dis) 0]
                                    [current-byte byte-pos])
                     end (+ start width)
                     _ (assert (<= end 8))
                     value (from-wire (bit-and (bit-shift-right held (- 8 end)) mask))]
                 (if (= end 8)
                   [0x0 -1 value]
                   [held end value])))}))

(def bool1
  "One-bit flag field: encodes any truthy value as 1 (else 0) and decodes to
  true/false. Packed MSB-first with adjacent bit fields."
  (bit-field 1 (fn [value] (if value 1 0)) (fn [n] (= 1 n))))

(def u4
  "4-bit unsigned field; the value is masked to its low 4 bits on encode.
  Packed MSB-first with adjacent bit fields."
  (bit-field 4 identity identity))

(def u3
  "3-bit unsigned field; the value is masked to its low 3 bits on encode.
  Packed MSB-first with adjacent bit fields."
  (bit-field 3 identity identity))

(def labels
  "Codec for an uncompressed domain name such as \"_ssh._tcp.local\".

  Encode (byte-pos must be 0): each non-empty dot-separated label is
  written as a length byte followed by its UTF-8 bytes, then a terminating
  zero byte. Empty labels are skipped, so the root name \"\" encodes as a
  single zero byte instead of two. Labels longer than 63 bytes fail an
  assertion instead of writing a wrapped length.

  Decode (byte-pos must be -1): reads length-prefixed labels until a zero
  byte and joins them with dots. A length byte above 63 (e.g. a compression
  pointer, which this codec does not follow) throws ex-info. Truncated input
  throws EOFException instead of silently yielding a short label."
  {:encode (fn [dos current-byte byte-pos value]
             (if (= byte-pos 0)
               (do
                 (doseq [^String part (.split ^String value "\\.")
                         :when (pos? (count part))
                         ;; length prefix must count bytes, not chars
                         :let [^bytes bs (.getBytes part "UTF-8")]]
                   (assert (<= (alength bs) 63)
                           (str "DNS label longer than 63 bytes: " part))
                   (.writeByte dos (alength bs))
                   (.write dos bs))
                 (.writeByte dos 0x0)
                 [current-byte byte-pos])
               (assert nil)))
   :decode (fn [dis current-byte byte-pos]
             (if (= byte-pos -1)
               (loop [parts []]
                 (let [n (.readUnsignedByte dis)]
                   (cond
                     (zero? n)
                     [current-byte byte-pos (apply str (interpose \. parts))]

                     (> n 63)
                     (throw (ex-info "Unsupported label length byte (compression pointers are not handled)"
                                     {:length-byte n}))

                     :else
                     (let [arr (byte-array n)]
                       (.readFully dis arr)
                       (recur (conj parts (String. arr "UTF-8")))))))
               (assert false)))})

(defn fixed
  "Returns a spec-producing fn (its `data` argument is ignored) yielding a
  single entry that decodes exactly `size` bytes, at a byte boundary, into a
  String (platform default charset) stored at `data-path`.
  Encoding is not supported and fails an assertion."
  [size data-path]
  (fn [data]
    [[data-path {:encode (fn [dos current-byte byte-pos value]
                           (assert false))
                 :decode (fn [dis current-byte byte-pos]
                           (assert (neg? byte-pos))
                           (let [arr (byte-array size)]
                             ;; readFully throws EOFException on truncated
                             ;; input; plain read could return fewer bytes
                             ;; and leave the rest silently zeroed.
                             (.readFully dis arr)
                             [current-byte byte-pos (String. arr)]))}]]))

(defn txt-rdata
  "Returns a spec (a vector, not a fn) with one entry that decodes `size` raw
  bytes, at a byte boundary, into a single String (platform default charset)
  at `data-path`. No per-string length prefixes are parsed; any such bytes
  end up inside the String. Encoding is not supported and fails an assertion."
  [size data-path]
  [[data-path {:encode (fn [dos current-byte byte-pos value]
                         (assert false))
               :decode (fn [dis current-byte byte-pos]
                         (assert (neg? byte-pos))
                         (let [arr (byte-array size)]
                           ;; readFully fails loudly on truncated input
                           ;; instead of returning a short, zero-padded read.
                           (.readFully dis arr)
                           [current-byte byte-pos (String. arr)]))}]])

;; Resource-record section codec. The owner :name is read as a bare u16,
;; which only works when every record name is a 2-byte compression pointer.
;; TODO: support inline labels as well as pointers.
(defn rr-record
  "Returns a spec-producing fn for the resource records of one message
  section. `k-count` is the key holding the record count (e.g. :ancount) and
  `k-data` the key under which records are stored by index (e.g. :answers).

  Each record gets :name, :type, :class, :ttl, :rdlength and :rdata. The
  :rdata layout is chosen from the record's already-decoded :type:
    1  (A)   -> u32
    16 (TXT) -> String of the :rdlength raw bytes (see txt-rdata)
    12 (PTR) -> String of :rdlength - 2 bytes, then the last 2 bytes as a u16
                stored under [k-data i :garbage]
    33 (SRV) -> :priority, :weight and :port u16s, a :target String of
                :rdlength - 8 bytes, then the last 2 bytes as a u16 stored
                under [k-data i :rdata :garbage]
  Any other type is decoded as a raw byte array of :rdlength bytes. This
  skips records the code does not understand instead of aborting the whole
  decode. Encoding such records is not supported."
  [k-count k-data]
  (fn [data]
    (for [i (range (k-count data))
          entry [[[k-data i :name] u16]
                 [[k-data i :type] u16]
                 [[k-data i :class] u16]
                 [[k-data i :ttl] u32]
                 [[k-data i :rdlength] u16]
                 [[k-data i :rdata]
                  (fn [data]
                    (let [rdata-path [k-data i :rdata]
                          rdlength (get-in data [k-data i :rdlength])]
                      (case (get-in data [k-data i :type])
                        1 [[rdata-path u32]]
                        16 (txt-rdata rdlength rdata-path)
                        ;; NOTE(review): the trailing 2 bytes are presumably a
                        ;; name-compression pointer; they are read, not followed.
                        12 [[rdata-path (fixed (- rdlength 2) rdata-path)]
                            [[k-data i :garbage] u16]]
                        ;; 8 = 6 bytes of priority/weight/port + the trailing u16
                        33 [[(conj rdata-path :priority) u16]
                            [(conj rdata-path :weight) u16]
                            [(conj rdata-path :port) u16]
                            [(conj rdata-path :target)
                             (fixed (- rdlength 8) (conj rdata-path :target))]
                            [(conj rdata-path :garbage) u16]]
                        ;; Unknown type: consume exactly rdlength bytes so the
                        ;; following records still line up.
                        [[rdata-path
                          {:encode (fn [dos current-byte byte-pos value]
                                     (assert false "encoding unrecognized rdata types is not supported"))
                           :decode (fn [dis current-byte byte-pos]
                                     (assert (neg? byte-pos))
                                     (let [arr (byte-array rdlength)]
                                       (.readFully dis arr)
                                       [current-byte byte-pos arr]))}]])))]]]
      entry)))

(defn dns
  "Returns the codec spec for a whole DNS message. The fixed header fields
  (:id, the flag fields :response :opcode :aa :tc :rd :ra :z :rcode, and the
  four section counts) come first in RFC 1035 header order. They are
  followed by the :questions, :answers, :authority and :additional sections,
  whose lengths are taken from :qdcount, :ancount, :nscount and :arcount."
  []
  (let [header [[:id u16]
                [:response bool1] [:opcode u4] [:aa bool1] [:tc bool1] [:rd bool1]
                [:ra bool1] [:z u3] [:rcode u4]
                [:qdcount u16] [:ancount u16] [:nscount u16] [:arcount u16]]
        ;; expands to a name/qtype/qclass triple per question
        question-spec (fn [data]
                        (mapcat (fn [i]
                                  [[[:questions i :name] labels]
                                   [[:questions i :qtype] u16]
                                   [[:questions i :qclass] u16]])
                                (range (:qdcount data))))]
    (-> (mapv (fn [[k codec]] [[k] codec]) header)
        (conj [[:questions] question-spec]
              [[:answers] (rr-record :ancount :answers)]
              [[:authority] (rr-record :nscount :authority)]
              [[:additional] (rr-record :arcount :additional)]))))

(defn encode
  "Serializes `data` according to `spec` and returns a byte array.

  `spec` is a sequence of [path codec] entries, processed in order:
  * if codec is a fn, it is called with `data` and must return further
    entries, which are spliced in at that point (used for variable-length
    sections);
  * otherwise codec is a map whose :encode fn is called with the
    DataOutputStream, the pending partial byte, the number of its bits
    already filled (byte-pos) and the value at `path` in `data`. It returns
    [current-byte byte-pos], and the pending byte is written once byte-pos
    reaches 8.
  Asserts that every encoded path has a non-nil value. A partially filled
  byte still pending when the spec runs out is written as-is rather than
  dropped."
  [spec data]
  (let [baos (java.io.ByteArrayOutputStream.)
        ds (java.io.DataOutputStream. baos)]
    (loop [[[path codec] & spec] spec
           current-byte 0x0
           byte-pos 0x0]
      (cond
        (not path)
        (do
          ;; flush trailing bit fields that did not complete a byte
          (when-not (zero? byte-pos)
            (.writeByte ds current-byte))
          (.toByteArray baos))

        (fn? codec)
        (recur (concat (codec data) spec) current-byte byte-pos)

        :else
        (let [value (get-in data path)
              _ (assert (not (nil? value)) path)
              _ (assert (map? codec))
              [current-byte byte-pos] ((:encode codec) ds current-byte byte-pos value)]
          (if (= byte-pos 8)
            (do
              (.writeByte ds current-byte)
              (recur spec 0x0 0x0))
            (recur spec current-byte byte-pos)))))))

(defn decode
  "Parses `bytes` according to `spec` and returns the decoded data map.

  Entries are [path codec] pairs handled in order. A fn codec is called with
  the data decoded so far and its returned entries are spliced in next, so
  later sections can depend on earlier counts. A map codec's :decode fn
  receives the DataInputStream, the held partial byte and byte-pos (-1 when
  no partial byte is held) and returns [current-byte byte-pos value]. The
  value is stored with assoc-in at `path`. Bytes left after the spec is
  exhausted are ignored."
  [spec bytes]
  (let [in (java.io.DataInputStream. (java.io.ByteArrayInputStream. bytes))]
    (loop [todo spec
           held 0x0
           pos -1
           acc {}]
      (let [[[path codec] & later] todo]
        (cond
          (not path) acc
          (fn? codec) (recur (concat (codec acc) later) held pos acc)
          :else (let [[held' pos' value] ((:decode codec) in held pos)]
                  (recur later held' pos' (assoc-in acc path value))))))))

(import '(java.net DatagramSocket
                   DatagramPacket
                   InetSocketAddress
                   InetAddress))

;; Example: send one mDNS query for "_ssh._tcp.local" (qtype 255, qclass 1)
;; to 224.0.0.251:5353 from local UDP port 35739, then decode the first
;; datagram received on that port.
(let [packet (encode (dns)
                     {:id (rand-int 100)
                      :response false
                      :opcode 0
                      :aa false
                      :tc false
                      :rd false
                      :ra false
                      :z 0
                      :rcode 0
                      :qdcount 1
                      :ancount 0
                      :nscount 0
                      :arcount 0
                      :questions [{:name "_ssh._tcp.local"
                                   :qtype 255
                                   :qclass 1}]
                      :answers []
                      :authority []
                      :additional []})
      group (InetAddress/getByName "224.0.0.251")
      query (DatagramPacket. packet (count packet) group 5353)
      reply (DatagramPacket. (byte-array 1024) 1024)]
  ;; Create the socket unbound: SO_REUSEADDR only has a defined effect when
  ;; set before the socket is bound.
  (with-open [sock (doto (DatagramSocket. nil)
                     (.setReuseAddress true)
                     (.setBroadcast true)
                     ;; fail with SocketTimeoutException instead of blocking
                     ;; forever when nobody answers
                     (.setSoTimeout 5000))]
    (.bind sock (InetSocketAddress. 35739))
    (.send sock query)
    (.receive sock reply)
    ;; Decode only the bytes actually received, not the whole 1024-byte
    ;; buffer (the packet was created with offset 0).
    (decode (dns) (java.util.Arrays/copyOf (.getData reply) (.getLength reply)))))

Generated At 2023-11-27T10:23:17-0800 original