;;;; -*- Mode: Lisp; Base: 10; Syntax: ANSI-Common-Lisp; Package: kira; Coding: utf-8 -*-
;;;; Copyright © 2022 David Mullen. All Rights Reserved. Origin: <https://cl-pdx.com/kira/>

(in-package :kira)

;;; Element type of the octet vectors and streams that hold encoded objects.
(deftype octet () '(unsigned-byte 8))
;;; An unsigned 64-bit word. NOTE(review): not referenced in the visible
;;; code; presumably one tag octet plus a 56-bit payload -- confirm.
(deftype binary-word () '(unsigned-byte 64))
;;; Signed integers that fit in the 56-bit payload of an immediate object;
;;; larger integers are written by WRITE-BINARY-INTEGER instead.
(deftype binary-fixnum () '(signed-byte 56))

;;; Expand into code that reads a LENGTH-octet unsigned integer from VECTOR
;;; starting at index OFFSET, least significant octet first. LENGTH must be
;;; a literal integer from 1 to 7, because the read is unrolled at expansion
;;; time. The VECTOR and OFFSET forms appear once per octet in the expansion,
;;; so callers should pass variables rather than forms with side effects.
(defmacro %get-binary-byte (vector offset length)
  (check-type length (integer 1 7) "a byte length")
  ;; OCTET and BYTE are gensyms naming the scratch octet and the accumulator.
  (loop with (octet byte) = (mapcar #'gensym (mapcar #'symbol-name '(octet byte)))
        ;; I runs from the highest-addressed (most significant) octet down
        ;; to zero; the accumulator is shifted left before every merge
        ;; except the first.
        with byte-type = `(unsigned-byte ,(* 8 length)) for i fixnum downfrom (1- length) to 0
        collect `(setq ,octet (aref ,vector (the array-index (+ ,offset ,i)))) into decoder-forms
        when (< i (the fixnum (1- length))) collect `(setq ,byte (ash ,byte 8)) into decoder-forms
        collect `(logiorf ,byte ,octet) into decoder-forms
        ;; The two declarations are pushed onto the front so that they
        ;; precede the decoding forms inside the generated LET.
        finally (push `(declare (type ,byte-type ,byte)) decoder-forms)
                (push `(declare (type (unsigned-byte 8) ,octet)) decoder-forms)
                (return `(let ((,octet 0) (,byte 0)) ,.decoder-forms ,byte))))

;;; Step over the encoded object that begins at OFFSET in VECTOR without
;;; decoding it. Returns (VALUES NEW-OFFSET TAG) for an immediate object
;;; (tag below 128, eight octets in total) and (VALUES NEW-OFFSET TAG
;;; CHUNK-OFFSET) for a chunk, where CHUNK-OFFSET is the offset just past the
;;; four-octet header and NEW-OFFSET is CHUNK-OFFSET plus the chunk length.
;;; If the object would extend past the end of the data, control transfers
;;; to CHECK-VECTOR-BOUNDS (presumably it signals the error -- confirm).
(defun skip-binary-object (vector offset)
  ;; NOTE(review): WITH-ARRAY-DATA presumably rebinds VECTOR to the underlying
  ;; simple vector, with START and END as indices into it -- confirm.
  (with-array-data ((vector vector) (start offset) (end nil))
    (declare (type (simple-array (unsigned-byte 8) (*)) vector))
    (locally (declare (type (and fixnum unsigned-byte) offset))
      (prog ((tag-index start))
        (declare (type fixnum tag-index))
        ;; STEP-OFFSET advances START and OFFSET in lockstep by DELTA and
        ;; yields the new OFFSET. START is the one compared against END;
        ;; OFFSET is the caller-visible position that gets returned.
        (macrolet ((step-offset (delta)
                     `(progn (setq start (the fixnum (+ start ,delta)))
                             (when (> start end) (go object-out-of-bounds))
                             (setq offset (the fixnum (+ offset ,delta))))))
          ;; Every object, immediate or chunk, occupies at least four octets,
          ;; so the bounds check in STEP-OFFSET runs before the tag is read.
          (let ((chunk-offset (step-offset 4)) (tag (aref vector tag-index)))
            (declare (type octet tag) (type array-index chunk-offset))
            ;; Immediate object: the tag is followed by seven payload octets.
            (when (< tag 128) (return (values (step-offset 4) tag)))
            ;; Chunk: the three octets after the tag hold the payload length.
            (let* ((length-index (the array-index (1+ tag-index)))
                   (length (%get-binary-byte vector length-index 3)))
              (declare (type array-index length-index) (fixnum length))
              (return (values (step-offset length) tag chunk-offset)))))
       object-out-of-bounds
        ;; Reached only through GO when START has moved past END.
        (check-vector-bounds
         vector start end)))))

(defun copy-binary-object (vector offset stream)
  "Copy the encoded object starting at OFFSET in VECTOR to STREAM verbatim.
Returns the three values produced by SKIP-BINARY-OBJECT for that object:
the offset just past it, its tag, and its chunk offset."
  (multiple-value-bind (object-end tag chunk-start)
      (skip-binary-object vector offset)
    (declare (type array-index object-end))
    (locally (declare (type array-index offset))
      (write-sequence vector stream :start offset :end object-end)
      (values object-end tag chunk-start))))

(defmacro constant-octet-vector (&rest elements)
  "Expand into a literal octet vector built at macroexpansion time.
Each form in ELEMENTS is passed through MACRO-EVAL to obtain its octet."
  (map '(vector (unsigned-byte 8)) #'macro-eval elements))

(defmacro %vector-output-stream-finalize-binary-chunk
    (chunk-start length effective-tag stream)
  "Expand into code that patches a chunk header inside STREAM's output buffer.
The tag octet is stored at CHUNK-START (offset by the ioblock displacement)
and the low 24 bits of LENGTH go into the next three octets, least
significant octet first. The expansion yields EFFECTIVE-TAG. The LENGTH
form appears three times in the expansion, so pass a variable."
  (with-gensyms (ioblock displacement out-buffer octets header)
    `(let* ((,ioblock (basic-stream-ioblock ,stream))
            (,displacement (vector-output-stream-ioblock-displacement
                            ,ioblock)))
       (declare (type array-index ,displacement))
       (declare (optimize (safety 0) (speed 3)))
       (let* ((,out-buffer (ioblock-outbuf ,ioblock))
              (,octets (io-buffer-buffer ,out-buffer))
              (,header (+ ,displacement ,chunk-start)))
         (declare (type array-index ,header))
         (declare (type (simple-array octet (*)) ,octets))
         (prog1 (setf (aref ,octets ,header) ,effective-tag)
           ;; One SETF per length octet, generated at expansion time.
           ,@(loop for index from 1 to 3
                   for position from 0 by 8
                   collect `(setf (aref ,octets (+ ,header ,index))
                                  (ldb (byte 8 ,position) ,length))))))))

(defmacro with-binary-chunk ((stream &optional tag) &body body)
  "Write the output of BODY to STREAM as a single chunk.
A four-octet placeholder header is written first, then BODY runs. Once
BODY is done the header is patched with the tag octet and the payload
length as three octets, least significant first. The tag is TAG when that
form is supplied; otherwise it is the value of the last form of BODY.
The expansion yields the tag. STREAM must support FILE-POSITION.
An error is signalled if the tag is not an octet or if the payload length
does not fit in 24 bits; the length is checked before the header is
touched, so an oversized chunk is never recorded with a truncated length."
  (with-gensyms (effective-tag chunk-start chunk-end length chunk-offset)
    (let ((placeholder-sequence (constant-octet-vector 255 255 255 255)))
      (rebinding (stream)
        (let ((writer-forms
               `((write-sequence
                  ,placeholder-sequence
                  ,stream) ,@body)))
          ;; The correct tag is written after the chunk size is known.
          ;; For vector streams, this is optimized CCL-specific code,
          ;; namely in %VECTOR-OUTPUT-STREAM-FINALIZE-BINARY-CHUNK.
          `(let* ((,chunk-start (file-position ,stream))
                  (,effective-tag (locally ,@writer-forms))
                  (,chunk-offset (+ ,chunk-start 4))
                  (,chunk-end (file-position ,stream))
                  (,length (- ,chunk-end ,chunk-offset)))
             (declare (type array-index ,chunk-start))
             (declare (type array-index ,chunk-offset))
             (declare (type array-index ,chunk-end))
             (declare (type array-index ,length))
             ,@(when tag `((setq ,effective-tag ,tag)))
             (check-type ,effective-tag octet "a tag")
             ;; The header has room for a 24-bit length only. The vector
             ;; stream path stores it with LDB at safety 0 and would
             ;; silently drop the high bits, so validate for both paths.
             (require-type ,length '(unsigned-byte 24))
             (if (typep ,stream 'vector-output-stream)
                 (%vector-output-stream-finalize-binary-chunk
                  ,chunk-start ,length ,effective-tag ,stream)
                 ;; Generic path: seek back, overwrite the placeholder,
                 ;; then return to the end of the chunk.
                 (progn (file-position ,stream ,chunk-start)
                        (write-byte ,effective-tag ,stream)
                        (write-binary-length ,length ,stream)
                        (file-position ,stream ,chunk-end)
                        ,effective-tag))))))))

(defun write-binary-length (length stream)
  "Write LENGTH to STREAM as three octets, least significant octet first.
LENGTH must be of type (UNSIGNED-BYTE 24). Returns the value of the final
WRITE-BYTE call, which was passed the most significant octet."
  (require-type length '(unsigned-byte 24))
  (locally (declare (type array-index length))
    (let ((last-result nil))
      (dolist (position '(0 8 16) last-result)
        (setq last-result
              (write-byte (ldb (byte 8 position) length) stream))))))

;;; Tag octets of the binary format. Tags below 128 mark immediate objects
;;; (tag plus seven payload octets, see WRITE-BINARY-WORD); tags of 128 and
;;; above mark chunks (tag, three-octet length, payload). The constants are
;;; wrapped in EVAL-ALWAYS because GET-BINARY-DECODER reads them with #. at
;;; read time.
(eval-always
  (defconstant +fixnum-tag+          0)
  (defconstant +object-tag+          1)
  (defconstant +character-tag+       2)
  (defconstant +utc-tag+            64)
  (defconstant +text-tag+          128)
  (defconstant +kira-tag+          129)
  (defconstant +keyword-tag+       130)
  (defconstant +bignum-tag+        131)
  (defconstant +ratio-tag+         132)
  (defconstant +local-time-tag+    133)
  (defconstant +proper-list-tag+   134)
  (defconstant +dotted-list-tag+   135)
  ;; Decoded by GET-BINARY-OCTET-VECTOR; no writer for this tag is visible
  ;; in this part of the file.
  (defconstant +object-record-tag+ 136)
  (defconstant +complex-tag+       137)
  (defconstant +simple-vector-tag+ 138)
  (defconstant +octet-vector-tag+  139)
  (defconstant +bit-vector-tag+    140)
  (defconstant +uri-tag+           141)
  (defconstant +array-tag+         142)
  (defconstant +single-float-tag+  143)
  (defconstant +double-float-tag+  144)
  (defconstant +hash-table-tag+    145)
  (defconstant +pathname-tag+      146))

(defun write-binary-octet-vector (vector stream)
  "Write VECTOR to STREAM as a complete octet-vector chunk.
The tag octet, the three-octet length and the octets themselves are written
directly. Returns the value of the WRITE-BYTE call that wrote the tag."
  (let ((tag (write-byte +octet-vector-tag+ stream)))
    (write-binary-length (length vector) stream)
    (write-sequence vector stream)
    tag))

(defun kira-symbol-p (symbol)
  "True if looking up SYMBOL's name in the KIRA package finds SYMBOL itself."
  (let ((found (find-symbol (symbol-name symbol)
                            (load-time-package :kira))))
    (eq found symbol)))

;;; Dispatch on the type of OBJECT. Objects with a fixed-size encoding are
;;; written to STREAM here and their tag (an integer) is returned. For every
;;; other type nothing is written and a writer function is returned instead;
;;; WRITE-BINARY-OBJECT calls it inside WITH-BINARY-CHUNK. The order of the
;;; clauses matters wherever the types overlap (STRING and octet vectors
;;; before VECTOR, NULL before the symbol clause, BINARY-FIXNUM before
;;; INTEGER, VECTOR before ARRAY). Unsupported types signal the ETYPECASE
;;; error.
(defun %write-binary-object (object stream)
  (declare (optimize (safety 0) (speed 3)))
  (etypecase object
    (cons #'write-binary-cons)
    (string #'write-binary-string)
    (keyword #'write-binary-keyword)
    ;; Immediate: the 56-bit payload is the value of OBJECT-ID.
    (object (write-binary-word +object-tag+ (object-id object) stream))
    ;; Written as a complete chunk here, so a tag is returned, not a writer.
    ((vector (unsigned-byte 8)) (write-binary-octet-vector object stream))
    (character (write-binary-word +character-tag+ (char-code object) stream))
    ;; NIL is a proper-list chunk header with a zero length.
    (null (let ((null-seq (constant-octet-vector +proper-list-tag+ 0 0 0)))
            (prog1 +proper-list-tag+ (write-sequence null-seq stream))))
    ((and symbol (satisfies kira-symbol-p)) #'write-binary-symbol)
    ;; Negative fixnums are biased by 2^56 so the payload is unsigned;
    ;; GET-BINARY-OBJECT reverses this for payloads of 2^55 and above.
    (binary-fixnum (locally (declare (type (signed-byte 57) object))
                     (when (minusp object) (incf object (ash 1 56)))
                     (write-binary-word +fixnum-tag+ object stream)))
    (integer #'write-binary-integer)
    (ratio #'write-binary-ratio)
    (single-float #'write-binary-single-float)
    (double-float #'write-binary-double-float)
    (local-time
     (normalize-local-time object)
     (let ((day (local-time-day object))
           (sec (local-time-sec object))
           (msec (local-time-msec object))
           (zone (local-time-zone object)))
       (declare (type fixnum day sec msec))
       ;; Zone 0 with a day that fits in 29 bits packs into one immediate:
       ;; day in bits 27-55, sec in bits 10-26, msec in bits 0-9.
       ;; NOTE(review): assumes NORMALIZE-LOCAL-TIME leaves SEC below 2^17
       ;; and MSEC below 2^10 -- confirm.
       (if (and (eql zone 0) (typep day '(unsigned-byte 29)))
           (let* ((utc (logior (the fixnum (ash day 17)) sec))
                  (utc (logior (the fixnum (ash utc 10)) msec)))
             (declare (type (unsigned-byte 56) utc))
             (write-binary-word +utc-tag+ utc stream))
           #'write-binary-time)))
    (bit-vector #'write-binary-bit-vector)
    (hash-table #'write-binary-hash-table)
    ;; Any vector not matched above is written element by element.
    (vector #'write-binary-simple-vector)
    (array #'write-binary-array)
    (complex #'write-binary-complex)
    (pathname #'write-binary-pathname)
    (uri #'write-binary-uri)))

(defun write-binary-word (tag integer stream)
  "Write an immediate object to STREAM: TAG, then INTEGER as seven octets.
TAG must be of type (UNSIGNED-BYTE 7) and INTEGER of type (UNSIGNED-BYTE 56).
The payload is written least significant octet first. Returns TAG."
  (check-type tag (unsigned-byte 7) "a tag")
  (check-type integer (unsigned-byte 56))
  (locally (declare (type (unsigned-byte 56) integer))
    (write-byte (the (unsigned-byte 7) tag) stream)
    (loop for position from 0 below 56 by 8
          do (write-byte (ldb (byte 8 position) integer) stream))
    tag))

(defun write-binary-object (object stream)
  "Encode OBJECT onto STREAM. Returns OBJECT and the tag that was written.
%WRITE-BINARY-OBJECT either writes the object itself and returns a fixnum
tag, or returns a writer function. The writer is called inside
WITH-BINARY-CHUNK, and the tag it returns is patched into the chunk header."
  (let ((tag-or-writer (%write-binary-object object stream)))
    (if (typep tag-or-writer 'fixnum)
        (values object tag-or-writer)
        (let ((chunk-tag (with-binary-chunk (stream)
                           (funcall tag-or-writer object stream))))
          (values object chunk-tag)))))

;;; Return the function that decodes the payload of a chunk tagged TAG.
;;; Decoders are called with the vector and the start and end offsets of
;;; the payload (see GET-BINARY-OBJECT). TAG must be an octet; a tag with
;;; no clause below signals the ECASE error. The #. reader macro inserts
;;; the constant values at read time, since ECASE keys are not evaluated.
(defun get-binary-decoder (tag)
  (check-type tag octet "a tag")
  (ecase (the octet tag)
    (#.+text-tag+ #'get-binary-string)
    (#.+kira-tag+ #'get-binary-symbol)
    (#.+keyword-tag+ #'get-binary-keyword)
    (#.+bignum-tag+ #'get-binary-integer)
    (#.+ratio-tag+ #'get-binary-ratio)
    (#.+local-time-tag+ #'get-binary-time)
    (#.+proper-list-tag+ #'get-binary-proper-list)
    (#.+dotted-list-tag+ #'get-binary-dotted-list)
    ;; Object records are returned as raw octets, like octet vectors.
    (#.+object-record-tag+ #'get-binary-octet-vector)
    (#.+complex-tag+ #'get-binary-complex)
    (#.+simple-vector-tag+ #'get-binary-simple-vector)
    (#.+octet-vector-tag+ #'get-binary-octet-vector)
    (#.+bit-vector-tag+ #'get-binary-bit-vector)
    (#.+uri-tag+ #'get-binary-uri)
    (#.+array-tag+ #'get-binary-array)
    (#.+single-float-tag+ #'get-binary-single-float)
    (#.+double-float-tag+ #'get-binary-double-float)
    (#.+hash-table-tag+ #'get-binary-hash-table)
    (#.+pathname-tag+ #'get-binary-pathname)))

;;; Decode the object that begins at OBJECT-OFFSET in VECTOR. Returns the
;;; object and the offset just past its encoding. Chunks are handed to the
;;; decoder chosen by GET-BINARY-DECODER; immediates are decoded inline from
;;; the seven payload octets that follow the tag.
(defun get-binary-object (vector object-offset)
  (multiple-value-bind (new-offset tag chunk-offset)
      (skip-binary-object vector object-offset)
    (declare (type (unsigned-byte 8) tag) (type array-index new-offset))
    ;; SKIP-BINARY-OBJECT returns a CHUNK-OFFSET only for chunks.
    ;; NOTE(review): WHEREAS is presumably a when-let style macro that runs
    ;; its body only when DECODER is non-NIL -- confirm.
    (whereas ((decoder (when chunk-offset (get-binary-decoder tag))))
      (let ((object (funcall decoder vector chunk-offset new-offset)))
        (return-from get-binary-object (values object new-offset))))
    ;; Immediate object: read the 56-bit payload after the tag octet.
    (with-array-data ((vector vector) (offset object-offset) end)
      (declare (type (simple-array (unsigned-byte 8) (*)) vector))
      (let* ((immediate-offset (the array-index (1+ offset)))
             (immediate-byte (%get-binary-byte vector immediate-offset 7)))
        (declare (fixnum immediate-offset) (type (unsigned-byte 56) immediate-byte))
        (values
         (cond
           ((= tag +object-tag+)
            (get-object immediate-byte))
           ((= tag +utc-tag+)
            ;; Inverse of the packing in %WRITE-BINARY-OBJECT: day in bits
            ;; 27-55, seconds in bits 10-26, milliseconds in bits 0-9.
            (let* ((day (ldb (byte 29 27) immediate-byte))
                   (sec (ldb (byte 17 10) immediate-byte))
                   (msec (ldb (byte 10 0) immediate-byte))
                   (local-time (make-local-time :zone 0)))
              (declare (type (unsigned-byte 29) day))
              (declare (type (unsigned-byte 17) sec))
              (declare (type (unsigned-byte 10) msec))
              (setf (local-time-day local-time) day)
              (setf (local-time-sec local-time) sec)
              (setf (local-time-msec local-time) msec)
              local-time))
           ((= tag +fixnum-tag+)
            (let ((integer immediate-byte))
              ;; Make (unsigned) IMMEDIATE-BYTE signed: payloads of 2^55 and
              ;; above stand for negative numbers biased by 2^56.
              (declare (type (signed-byte 57) integer))
              (when (>= integer (expt 2 55))
                (decf integer (expt 2 56)))
              integer))
           ((= tag +character-tag+) (code-char immediate-byte))
           (t (error "Immediate tag undefined: ~D." tag)))
         new-offset)))))

;;; Property list naming UTF-8 character encoding and Unix line termination.
;;; NOTE(review): not referenced in the visible code; presumably the
;;; external format for streams that the string writers below write to.
(defstatic +binary-external-format+
  '(:character-encoding :utf-8
    :line-termination :unix))

(defun write-binary-string (string stream)
  "Write the characters of STRING to STREAM and return +TEXT-TAG+.
The octets produced depend on STREAM's character encoding, which should be
UTF-8 because that is what GET-BINARY-STRING decodes."
  (write-string string stream)
  +text-tag+)

(defun get-binary-string (vector start end)
  "Return the string whose UTF-8 encoding occupies VECTOR from START to END."
  (decode-string-from-octets vector
                             :start start
                             :end end
                             :external-format :utf-8))

(defun write-binary-symbol (symbol stream)
  "Write the name of SYMBOL to STREAM and return +KIRA-TAG+.
Only the name is written; GET-BINARY-SYMBOL interns it in the KIRA package."
  (let ((name (symbol-name symbol)))
    (write-string name stream)
    +kira-tag+))

(defmacro %get-binary-symbol (vector start end package-designator)
  "Expand into code that decodes a string and interns it as a symbol.
The string is read from VECTOR between START and END. The package is the
one LOAD-TIME-PACKAGE yields for PACKAGE-DESIGNATOR."
  `(intern (get-binary-string ,vector ,start ,end)
           (load-time-package ,package-designator)))

(defun get-binary-symbol (vector start end)
  "Intern the name stored in VECTOR from START to END in the KIRA package.
Returns the values of INTERN."
  (%get-binary-symbol vector
                      start
                      end
                      :kira))

(defun write-binary-keyword (keyword stream)
  "Write the name of KEYWORD to STREAM and return +KEYWORD-TAG+."
  (let ((name (symbol-name keyword)))
    (write-string name stream)
    +keyword-tag+))

(defun get-binary-keyword (vector start end)
  "Intern the name stored in VECTOR from START to END in the KEYWORD package.
Returns the values of INTERN."
  (%get-binary-symbol vector
                      start
                      end
                      :keyword))

(defun write-binary-integer (integer stream)
  "Write INTEGER to STREAM in two's complement and return +BIGNUM-TAG+.
The smallest whole number of octets that holds the value and its sign bit
is used, and the octets are written least significant first."
  (let* ((octet-count (ceiling (1+ (integer-length integer)) 8))
         (bit-count (* 8 octet-count))
         ;; A negative value is biased by 2^BIT-COUNT to make it unsigned.
         (unsigned (if (minusp integer)
                       (+ integer (ash 1 bit-count))
                       integer)))
    (loop for position from 0 below bit-count by 8
          do (write-byte (the octet (ldb (byte 8 position) unsigned)) stream))
    +bignum-tag+))

(defun get-binary-integer (vector start end)
  "Decode the two's complement integer stored in VECTOR from START to END.
The octets are least significant first, and the top bit of the last octet
is the sign. An empty range decodes as 0: it has no sign bit, and without
the explicit check the sign adjustment below would turn it into -1."
  (with-array-data ((vector vector) (start start) (end end))
    (declare (type (simple-array (unsigned-byte 8) (*)) vector))
    (loop with integer-size of-type array-index = (- end start)
          ;; LIMIT is 2^(8*size); values of LIMIT/2 and above are negative.
          with limit = (ash 1 (the fixnum (ash integer-size 3))) and integer = 0
          for position of-type fixnum from 0 by 8
          for offset of-type fixnum from start below end
          do (setf (ldb (byte 8 position) integer) (the octet (aref vector offset)))
          finally (return (cond ((zerop integer-size) 0)
                                ((< integer (ash limit -1)) integer)
                                (t (- integer limit)))))))

;;; Run BODY with MACRO-NAME defined as a local macro that decodes the next
;;; object of a chunk payload. START and END must be variable names: START
;;; is rebound (after both are checked to be fixnums) and then advanced by
;;; each call. (MACRO-NAME) returns the object and T, or NIL once START has
;;; reached END. (MACRO-NAME T) signals an error instead of returning NIL.
;;; The ERRORP argument is inspected at macroexpansion time, so it must be
;;; a literal.
(defmacro with-binary-chunk-iterator ((macro-name) (vector start end) &body body)
  `(let ((,start (prog1 (require-type ,start 'fixnum) (require-type ,end 'fixnum))))
     ,(with-gensyms (object)
        ;; GET stores the decoded object and the new offset in one step.
        ;; G-CLAUSE is the COND clause taken while objects remain.
        ;; COND-FORM is the body of the local macro; the nested backquote
        ;; leaves only the ERRORP test to be evaluated at each use.
        (let* ((get `(setf (values ,object ,start) (get-binary-object ,vector ,start)))
               (g-clause `((< (the fixnum ,start) (the fixnum ,end)) (values ,get t)))
               (cond-form ``(cond ,',g-clause (t ,(when errorp '(error "No more objects."))))))
          `(macrolet ((,macro-name (&optional errorp) ,cond-form))
             (let (,object) (declare (ignorable ,object)) ,@body))))))

(defun write-binary-ratio (ratio stream)
  "Write the numerator, then the denominator, of RATIO to STREAM.
Both are written as nested objects. Returns +RATIO-TAG+."
  (dolist (part (list (numerator ratio) (denominator ratio)) +ratio-tag+)
    (write-binary-object part stream)))

(defun get-binary-ratio (vector start end)
  "Decode a numerator and a denominator from VECTOR and return their quotient.
Signals an error if the chunk holds fewer than two objects."
  (with-binary-chunk-iterator (next-part) (vector start end)
    (let* ((top (next-part t))
           (bottom (next-part t)))
      (rationalize (/ top bottom)))))

(defun write-binary-cons (cons stream)
  "Write the elements of the list CONS to STREAM as nested objects.
A proper list yields +PROPER-LIST-TAG+. Any other list has its terminating
atom written as the last object and yields +DOTTED-LIST-TAG+."
  (cond ((proper-list-p cons)
         (loop for element in cons
               do (write-binary-object element stream))
         +proper-list-tag+)
        (t
         ;; Walk the conses; the atom that ends the list is written last.
         (do ((tail cons (cdr tail)))
             ((atom tail)
              (write-binary-object tail stream)
              +dotted-list-tag+)
           (write-binary-object (car tail) stream)))))

(defun get-binary-proper-list (vector start end)
  "Decode every object in VECTOR from START to END into a fresh list."
  (with-binary-chunk-iterator (next-element) (vector start end)
    (let ((elements '()))
      (loop
        (multiple-value-bind (element morep) (next-element)
          (unless morep
            (return (nreverse elements)))
          (push element elements))))))

;;; Decode the objects in VECTOR from START to END into a dotted list: the
;;; last object becomes the final CDR and the others become the elements.
;;; Decoding lags one element behind, so that the last object can be
;;; attached with RPLACD instead of being collected as an element.
;;; NOTE(review): COLLECTING and COLLECT are defined elsewhere; this code
;;; relies on COLLECT returning the cons it just added -- confirm.
(defun get-binary-dotted-list (vector start end &aux element have-element-p)
  (with-binary-chunk-iterator (get-list-element) (vector start end)
    (collecting (dotted-list)
      ;; TAIL is the most recently collected cons, and PENDING-ELEMENT is
      ;; the object read on the previous iteration but not yet placed.
      (loop with tail and pending-element and pending-element-p
            do (setf (values element have-element-p) (get-list-element)) while have-element-p
            when pending-element-p do (setq tail (collect pending-element :into dotted-list))
            do (setq pending-element element pending-element-p t)
            ;; No objects at all gives NIL. With two or more, the last one
            ;; becomes the CDR of TAIL. With exactly one, the value of
            ;; COLLECT is returned.
            finally (cond ((not pending-element-p) (return nil))
                          (tail (rplacd tail pending-element) (return dotted-list))
                          (t (return (collect pending-element :into dotted-list))))))))

(defun write-binary-time (local-time stream)
  "Write the day, sec, msec and zone of LOCAL-TIME to STREAM, in that order.
Each is written as a nested object. Returns +LOCAL-TIME-TAG+."
  (flet ((emit (part)
           (write-binary-object part stream)))
    (emit (local-time-day local-time))
    (emit (local-time-sec local-time))
    (emit (local-time-msec local-time))
    (emit (local-time-zone local-time))
    +local-time-tag+))

(defun get-binary-time (vector start end)
  "Decode day, sec, msec and zone from VECTOR and build a local-time.
Signals an error if the chunk holds fewer than four objects."
  (with-binary-chunk-iterator (next-part) (vector start end)
    (let* ((day (next-part t))
           (sec (next-part t))
           (msec (next-part t))
           (zone (next-part t)))
      (make-local-time :day day :sec sec :msec msec :zone zone))))

(defun write-binary-complex (complex stream)
  "Write the real part, then the imaginary part, of COMPLEX to STREAM.
Both are written as nested objects. Returns +COMPLEX-TAG+."
  (flet ((emit (part)
           (write-binary-object part stream)))
    (emit (realpart complex))
    (emit (imagpart complex))
    +complex-tag+))

(defun get-binary-complex (vector start end)
  "Decode a real part and an imaginary part from VECTOR and combine them.
Signals an error if the chunk holds fewer than two objects."
  (with-binary-chunk-iterator (next-part) (vector start end)
    (let* ((real-part (next-part t))
           (imaginary-part (next-part t)))
      (complex real-part imaginary-part))))

(defun write-binary-uri (uri stream)
  "Write the string form of URI (from URI-STRING) to STREAM.
Returns +URI-TAG+."
  (write-string (uri-string uri) stream)
  +uri-tag+)

(defun get-binary-uri (vector start end)
  "Decode the string in VECTOR from START to END and make a URI from it."
  (let ((uri-text (get-binary-string vector start end)))
    (%make-uri-from-string uri-text)))

(defun get-binary-octet-vector (vector start end)
  "Return the octets of VECTOR from START to END.
The result of GET-MMAP-DISPLACED-VECTOR is used when it is non-NIL;
otherwise the range is copied with SUBSEQ."
  (let ((displaced (get-mmap-displaced-vector vector start end)))
    (if displaced
        displaced
        (subseq vector start end))))

(defun write-binary-simple-vector (vector stream)
  "Write each element of VECTOR to STREAM, in index order, as a nested object.
Returns +SIMPLE-VECTOR-TAG+."
  (dotimes (index (length vector) +simple-vector-tag+)
    (write-binary-object (aref vector index) stream)))

(defun get-binary-simple-vector (vector start end)
  "Decode every object in VECTOR from START to END into a buffer.
The buffer comes from MAKE-BUFFER and is grown with VECTOR-PUSH-EXTEND. Its
initial size is one eighth of the payload length in octets."
  (with-binary-chunk-iterator (next-element) (vector start end)
    (let* ((size-hint (ash (the fixnum (- end start)) -3))
           (result (make-buffer size-hint t)))
      (declare (type fixnum size-hint))
      (loop
        (multiple-value-bind (element morep) (next-element)
          (unless morep
            (return result))
          (vector-push-extend element result))))))

(defun write-binary-bit-vector (bit-vector stream)
  "Write the bits of BIT-VECTOR to STREAM and return +BIT-VECTOR-TAG+.
The first octet written is the number of unused bits in the last data octet
(0 to 7). The bits follow, eight per octet, with the lowest-indexed bit in
the least significant position. A partly filled last octet is written only
when the bit count is not a multiple of eight."
  (with-array-data ((bit-vector bit-vector) (start 0) end)
    ;; The bit count is only asserted to be an array index. A chunk holds
    ;; up to 2^24-1 octets, so an encodable bit vector can have well over
    ;; 2^24 bits, and a 24-bit assertion here would be false for those.
    (loop with odd-bits of-type (mod 8) = (logand (the array-index (- end start)) 7)
          ;; The table maps ODD-BITS to the count of unused bits, (8 - n) mod 8.
          initially (write-byte (aref (constant-octet-vector 0 7 6 5 4 3 2 1) odd-bits) stream)
          with octet of-type (unsigned-byte 8) = 0 for position of-type array-index upfrom 0
          for i of-type fixnum from start below end for bit of-type bit = (sbit bit-vector i)
          for octet-position of-type (mod 8) = (logand position 7)
          when (= bit 1) do (logiorf octet (the octet (ash 1 octet-position)))
          ;; A full octet has been assembled: flush it and start the next.
          when (= octet-position 7) do (progn (write-byte octet stream) (setq octet 0))
          finally (when (plusp odd-bits) (write-byte octet stream)) (return +bit-vector-tag+))))

;;; Decode a bit vector from VECTOR between START and END. The first octet
;;; is the number of unused bits in the last data octet; the remaining
;;; octets hold the bits, lowest-indexed bit in the least significant
;;; position. Signals an error if the range is empty.
;;; NOTE(review): the unused-bit count is trusted. A malformed count larger
;;; than the number of bits present would make LENGTH negative, contrary to
;;; its declaration.
(defun get-binary-bit-vector (vector start end)
  (with-array-data ((vector vector) (start start) (end end))
    (declare (type (simple-array (unsigned-byte 8) (*)) vector))
    (when (= end start) (error "Missing count of unused bits."))
    ;; Reading the count octet also advances START to the first data octet.
    (let* ((unused-bits (prog1 (aref vector start) (incf start)))
           (length-in-octets (the (unsigned-byte 24) (- end start)))
           (even-length (the array-index (ash length-in-octets 3)))
           (length (the array-index (- even-length unused-bits)))
           (bit-vector (make-array length :element-type 'bit)))
      (declare (type (unsigned-byte 8) unused-bits))
      (declare (type (unsigned-byte 24) length-in-octets))
      (declare (type array-index even-length length))
      (declare (type simple-bit-vector bit-vector))
      (when (zerop length) (return-from get-binary-bit-vector bit-vector))
      ;; OCTET holds the current data octet shifted right by POSITION, so
      ;; the next bit is always its lowest bit. After eight bits, the next
      ;; octet is fetched and POSITION starts again from zero.
      (loop with position of-type (mod 8) = 0 and offset fixnum = start
            for i of-type array-index from 0 below length
            for octet of-type (unsigned-byte 8) = (aref vector offset)
            then (cond ((< position 7) (incf position) (ash octet -1))
                       (t (setq position 0) (aref vector (incf offset))))
            ;; Take the least significant bit of the octet.
            do (setf (sbit bit-vector i) (logand octet 1))
            finally (return bit-vector)))))

(defun make-underlying-vector (array)
  "Return a vector displaced to ARRAY that spans all of its elements.
The vector has ARRAY's element type and shares storage with ARRAY."
  (make-array (array-total-size array)
              :element-type (array-element-type array)
              :displaced-to array))

(defun write-binary-array (array stream)
  "Write ARRAY to STREAM as two nested objects and return +ARRAY-TAG+.
The first object is the list of dimensions. The second is a vector,
displaced to ARRAY, that holds all of its elements."
  (write-binary-object (array-dimensions array) stream)
  (write-binary-object (make-underlying-vector array) stream)
  +array-tag+)

(defun get-binary-array (vector start end)
  "Decode a dimension list and an element vector and return an array.
The array is displaced to the decoded vector and has its element type.
Signals an error if the chunk holds fewer than two objects."
  (with-binary-chunk-iterator (next-part) (vector start end)
    (let* ((dimensions (next-part t))
           (storage (next-part t)))
      (make-array dimensions
                  :displaced-to storage
                  :element-type (array-element-type storage)))))

(defun write-binary-single-float (single-float stream)
  "Write SINGLE-FLOAT to STREAM as the rational that RATIONALIZE returns.
The rational is written as a nested object. Returns +SINGLE-FLOAT-TAG+."
  (write-binary-object (rationalize single-float) stream)
  +single-float-tag+)

(defun get-binary-single-float (vector start end)
  "Decode the object at START in VECTOR and coerce it to a single float."
  (with-array-data ((vector vector) (start start) (end end))
    (let ((value (get-binary-object vector start)))
      (coerce value 'single-float))))

(defun write-binary-double-float (double-float stream)
  "Write DOUBLE-FLOAT to STREAM as the rational that RATIONALIZE returns.
The rational is written as a nested object. Returns +DOUBLE-FLOAT-TAG+."
  (write-binary-object (rationalize double-float) stream)
  +double-float-tag+)

(defun get-binary-double-float (vector start end)
  "Decode the object at START in VECTOR and coerce it to a double float."
  (with-array-data ((vector vector) (start start) (end end))
    (let ((value (get-binary-object vector start)))
      (coerce value 'double-float))))

(defun write-binary-hash-table (hash-table stream)
  "Write HASH-TABLE to STREAM as nested objects and return +HASH-TABLE-TAG+.
The test and the size are written first, then each key followed by its
value, in the order the hash table iterator yields the entries."
  (flet ((emit (thing)
           (write-binary-object thing stream)))
    (emit (hash-table-test hash-table))
    (emit (hash-table-size hash-table))
    (with-hash-table-iterator (next-entry hash-table)
      (loop
        (multiple-value-bind (morep key value) (next-entry)
          (cond (morep
                 (emit key)
                 (emit value))
                (t
                 (return +hash-table-tag+))))))))

(defun get-binary-hash-table (vector start end)
  "Decode a hash table from VECTOR between START and END.
The first two objects are the test and the size. The rest are alternating
keys and values. Signals an error if the test or size is missing, or if a
key has no value after it."
  (with-binary-chunk-iterator (next-part) (vector start end)
    (let* ((test (next-part t))
           (size (next-part t))
           (table (make-hash-table :test test :size size)))
      (declare (type hash-table table))
      (loop
        (multiple-value-bind (key present-p) (next-part)
          (unless present-p
            (return table))
          (setf (gethash key table) (next-part t)))))))

(defun write-binary-pathname (pathname stream)
  "Write the namestring of PATHNAME to STREAM and return +PATHNAME-TAG+."
  (let ((name (namestring pathname)))
    (write-string name stream)
    +pathname-tag+))

(defun get-binary-pathname (vector start end)
  "Decode the string in VECTOR from START to END and convert it to a pathname."
  (let ((name (get-binary-string vector start end)))
    (declare (type string name))
    (pathname name)))