;;;; -*- Mode: Lisp; Base: 10; Syntax: ANSI-Common-Lisp; Package: kira; Coding: utf-8 -*-
;;;; Copyright © 2022 David Mullen. All Rights Reserved. Origin: <https://cl-pdx.com/kira/>

(in-package :kira)

;;; Initial chaining value for the DIGEST function. It is the test vector
;;; given for Speck96/144 in the original paper <https://eprint.iacr.org/2013/404>,
;;; except that the words are stored in little-endian order, i.e. as #(lo hi).
;;; CBC-MAC does not use this value; it starts from a zeroed IV instead.
(define-symbol-macro +speck-iv+ #(#x69202C726576 #x656D6974206E))

;;; Basic cipher parameters: 48-bit words, a three-word key, 29 rounds.
;;; NOTE(review): EVAL-ALWAYS is defined elsewhere; presumably it makes
;;; the constants available at compile time for the macros and #. forms.
(eval-always
  (defconstant +speck-word-bits+ 48)
  (defconstant +speck-key-words+ 3)
  (defconstant +speck-rounds+ 29))

;;; Quantities derived from the word width: the word type itself, the
;;; number of octets per word, and an all-ones mask of one word.
(eval-always
  (deftype speck-word () (list 'unsigned-byte +speck-word-bits+))
  (defconstant +speck-word-size+ (/ +speck-word-bits+ 8))
  (defconstant +speck-word-mask+ (ldb (byte +speck-word-bits+ 0) -1)))

;;; Output and key sizes in bits: a MAC is one block (two words), a
;;; digest is two blocks (four words), a key is +SPECK-KEY-WORDS+ words.
(eval-always
  (defconstant +speck-key-bits+ (* +speck-key-words+ +speck-word-bits+))
  (defconstant +mac-bits+ (ash +speck-word-bits+ 1))
  (defconstant +digest-bits+ (ash +speck-word-bits+ 2)))

(defmacro %ror (x r)
  "Expand into a rotation of the speck word X to the right by R bits.
R must be a literal integer, because the mask for the low R bits is
computed at macroexpansion time. The form X appears twice in the
expansion, so it should be a variable or other side-effect-free form."
  (let ((low-bits (ldb (byte r 0) -1))
        (lift `(- +speck-word-bits+ ,r)))
    ;; High part: X shifted down by R. Low part: the R bits that fell
    ;; off the bottom, moved up to the top of the word.
    `(logior (ash ,x (- ,r))
             (the speck-word (ash (logand ,x ,low-bits) ,lift)))))

(defmacro %rol (x r)
  "Expand into a rotation of the speck word X to the left by R bits.
The form X appears twice in the expansion, so it should be a variable
or other side-effect-free form."
  (let ((shifted-up `(logand (ash ,x ,r) +speck-word-mask+))
        ;; (- R +SPECK-WORD-BITS+) is negative, so this is a right shift
        ;; that brings the top R bits down to the bottom of the word.
        (wrapped-around `(ash ,x (- ,r +speck-word-bits+))))
    `(logior ,shifted-up ,wrapped-around)))

(defmacro %round-function (x y k)
  "Expand into one Speck round that updates the variables X and Y in
place using the round key K. The value of the expansion is the new Y."
  `(progn
     (setq ,x (%ror ,x 8))                            ; X <- X ror 8
     (setq ,x (logand (+ ,x ,y) +speck-word-mask+))   ; X <- X + Y (mod 2^n)
     (setq ,x (logxor ,x ,k))                         ; X <- X xor K
     (setq ,y (%rol ,y 3))                            ; Y <- Y rol 3
     (setq ,y (logxor ,y ,x))))                       ; Y <- Y xor X

(defmacro with-speck-key ((key &rest variables) &body body)
  "Run BODY with each of VARIABLES bound to successive elements of KEY.
Every element is passed through REQUIRE-TYPE with type SPECK-WORD. The
expansion binds VECTOR, START and END via WITH-ARRAY-DATA, so those
names are visible to (and shadowed for) BODY."
  ;; NOTE(review): WITH-ARRAY-DATA is defined elsewhere; presumably it
  ;; exposes the underlying simple vector and the offset of KEY within it.
  (let ((bindings
          (loop for var in variables
                for offset from 0
                collect `(,var (require-type (aref vector (+ start ,offset)) 'speck-word)))))
    `(with-array-data ((vector ,key) (start 0) (end ,(length variables)))
       (let ,bindings
         (declare (type speck-word ,@variables))
         ,@body))))

(defun make-key-schedule (key &optional key-schedule)
  "Expand the three-word KEY into +SPECK-ROUNDS+ round keys.
The round keys are stored into KEY-SCHEDULE, a simple vector that is
freshly allocated when not supplied, and that vector is returned."
  (orf key-schedule (make-array +speck-rounds+))
  ;; K is the running round key (first key word); L0 and L1 are the
  ;; two remaining key words, which take turns being mixed into K.
  (with-speck-key (key k l0 l1)
    (setf (svref key-schedule 0) k)
    (dotimes (i (1- +speck-rounds+) key-schedule)
      (declare (type (unsigned-byte 16) i))
      ;; The round index doubles as the round key of the schedule.
      (if (oddp i)
          (%round-function l1 k i)
          (%round-function l0 k i))
      (setf (svref key-schedule (1+ i)) k))))

(defmacro with-speck-context ((plaintext &optional key) &body body)
  "Run BODY with the working variables of the cipher bound.
The expansion deliberately exposes these names to BODY: CIPHERTEXT.LO,
CIPHERTEXT.HI, PLAINTEXT.LO, PLAINTEXT.HI (speck words, initially 0)
and KEY-SCHEDULE (a fresh vector of +SPECK-ROUNDS+ elements). When
PLAINTEXT evaluates to a vector, its first two elements initialise
PLAINTEXT.LO and PLAINTEXT.HI. When a KEY form is given, KEY-SCHEDULE
is filled from it before BODY runs. BODY is compiled with
(SAFETY 0) (SPEED 3)."
  ;; NOTE(review): WHEREAS is defined elsewhere; presumably it runs its
  ;; body only when the bound value is non-NIL.
  `(let ((ciphertext.lo 0) (ciphertext.hi 0) (plaintext.lo 0) (plaintext.hi 0))
     (declare (type speck-word ciphertext.lo ciphertext.hi plaintext.lo plaintext.hi)
              (ignorable plaintext.lo plaintext.hi))
     (let ((key-schedule (make-array +speck-rounds+)))
       (declare (optimize (safety 0) (speed 3)))
       (whereas ((plaintext-vector ,plaintext))
         (setq plaintext.lo (svref plaintext-vector 0)
               plaintext.hi (svref plaintext-vector 1)))
       ,@(and key (list `(make-key-schedule ,key key-schedule)))
       ,@body)))

(defmacro %encrypt-within-speck-context (&optional lo hi new-key)
  "Encrypt one block inside WITH-SPECK-CONTEXT.
LO and HI are forms for the low and high input words; when omitted,
PLAINTEXT.LO and PLAINTEXT.HI are used. When HI is supplied, the input
words are also stored into PLAINTEXT.LO and PLAINTEXT.HI. When NEW-KEY
is supplied, KEY-SCHEDULE is rebuilt from it before the rounds run.
The result is left in CIPHERTEXT.LO and CIPHERTEXT.HI; the value of
the expansion is NIL."
  ;; The input words are captured first, so LO and HI are evaluated
  ;; exactly once and may refer to PLAINTEXT.* or CIPHERTEXT.*.
  ;; Rekeying and the PLAINTEXT.* update are ordinary forms ahead of
  ;; the LOOP rather than INITIALLY clauses: the round key is read on
  ;; every iteration, and the standard allows the first value of a
  ;; FOR ... = clause to be computed before the loop prologue, which
  ;; would read the first round key from the stale schedule.
  `(let* ((y ,(or lo 'plaintext.lo))
          (x ,(or hi 'plaintext.hi)))
     (declare (type speck-word x y))
     ,@(when new-key `((make-key-schedule ,new-key key-schedule)))
     ,@(when hi `((setq plaintext.lo y plaintext.hi x)))
     (loop for i of-type (unsigned-byte 16) from 0 below +speck-rounds+
           do (let ((round-key (svref key-schedule i)))
                (declare (type speck-word round-key))
                (%round-function x y round-key))
           finally (setq ciphertext.lo y ciphertext.hi x))))

(defun %get-speck-word (vector i start end buffer)
  "Return the speck word of the padded message that begins at index I.
VECTOR holds the message octets between START and END. BUFFER is a
scratch octet vector of +SPECK-WORD-SIZE+ elements; it is overwritten
when the word at I is the final, partial one. At or past END the
result is the trailing length word instead of message data."
  (declare (type (simple-array octet (*)) vector buffer)
           (type array-index i start end)
           (optimize (safety 0) (speed 3)))
  (cond ((< i end)
         ;; MD-compliant padding mechanism. %GET-BINARY-BYTE does no
         ;; bounds checking, so a final word with fewer than six octets
         ;; remaining is assembled in BUFFER and read from there.
         (let ((tail-octets (- end i)))
           (declare (type array-index tail-octets))
           (unless (>= tail-octets +speck-word-size+)
             (replace buffer vector :start2 i :end2 end)
             ;; The padding marker goes right after the copied octets,
             ;; and the rest of the buffer is zeroed.
             (setf (aref buffer tail-octets) #x80)
             (fill buffer 0 :start (1+ tail-octets))
             (setq vector buffer i 0))
           (%get-binary-byte vector i #.+speck-word-size+)))
        (t
         ;; Last word: the 40-bit message length, shifted up by one
         ;; octet to leave room for the padding marker. The marker is
         ;; only placed here when the data ended exactly at I.
         (let ((length-word (ldb (byte 40 0) (- end start))))
           (declare (type (unsigned-byte 48) length-word))
           (ashf length-word 8)
           (if (= i end) (logior length-word #x80) length-word)))))

(defmacro put-key (key &rest key-words)
  "Expand into a SETF that stores each of KEY-WORDS, in order, into
the simple vector KEY starting at index 0. The form KEY is repeated
for every word, so it should be a variable."
  `(setf ,@(loop for key-word in key-words
                 for index from 0
                 append `((svref ,key ,index) ,key-word))))

(defmacro with-speck-input ((data start end buffer) &body body)
  "Normalise the variable DATA and run BODY over its contents.
A string is replaced by the result of ENCODE-STRING-TO-OCTETS; any
other non-vector is coerced to (VECTOR OCTET); other vectors are used
as they are. BODY then runs with DATA, START and END bound through
WITH-ARRAY-DATA, and with BUFFER bound to a fresh octet vector of
+SPECK-WORD-SIZE+ elements. DATA must be a variable, since it is
assigned to."
  `(let ((,buffer (make-array +speck-word-size+ :element-type 'octet)))
     (cond ((stringp ,data)
            (setq ,data (encode-string-to-octets ,data)))
           ((not (vectorp ,data))
            (setq ,data (coerce ,data '(vector octet)))))
     (with-array-data ((,data ,data) (,start 0) ,end) ,@body)))

(defun digest (data)
  "Hirose compression function.
Hash DATA (a string, an octet vector, or a sequence coercible to one;
see WITH-SPECK-INPUT) and return the result as a non-negative integer
of four speck words. Both chaining halves, G and H, start from
+SPECK-IV+. Each step consumes one padded message word and runs the
cipher twice under the same key."
  (with-speck-context (+speck-iv+)
    (with-speck-input (data start end buffer)
      (loop with key = (make-array +speck-key-words+)
            ;; One extra word past END so that the length word produced
            ;; by %GET-SPECK-WORD is always processed.
            with end-with-padding of-type array-index = (+ end +speck-word-size+)
            with g.lo of-type speck-word = plaintext.lo with h.lo of-type speck-word = g.lo
            with g.hi of-type speck-word = plaintext.hi with h.hi of-type speck-word = g.hi
            for i of-type array-index from start below end-with-padding by +speck-word-size+
            ;; The key for this step is (message word, H.lo, H.hi).
            do (put-key key (%get-speck-word data i start end buffer) h.lo h.hi)
               ;; First encryption: rekey, and encrypt the current G.
               ;; This also copies G into PLAINTEXT.LO/PLAINTEXT.HI.
               (%encrypt-within-speck-context g.lo g.hi key)
               ;; Feed-forward. NOTE(review): LOGXORF is defined elsewhere;
               ;; presumably it XORs its second argument into the first place.
               (logxorf ciphertext.lo plaintext.lo)
               (logxorf ciphertext.hi plaintext.hi)
               (setq g.lo ciphertext.lo g.hi ciphertext.hi)
               ;; Second encryption: same key schedule, with the previous G
               ;; (still held in PLAINTEXT.*) XORed with an all-ones constant.
               (logxorf plaintext.lo +speck-word-mask+)
               (logxorf plaintext.hi +speck-word-mask+)
               (%encrypt-within-speck-context)
               (logxorf ciphertext.lo plaintext.lo)
               (logxorf ciphertext.hi plaintext.hi)
               (setq h.lo ciphertext.lo h.hi ciphertext.hi)
            ;; Concatenation (H || G) makes a double-block output.
            ;; In the returned integer G occupies the two most significant
            ;; words and H the two least significant words.
            finally (return (logior (ash g.hi (* 3 +speck-word-bits+))
                                    (ash g.lo (* 2 +speck-word-bits+))
                                    (ash h.hi +speck-word-bits+)
                                    h.lo))))))

(defun make-key (input)
  "Return a fresh simple vector of +SPECK-KEY-WORDS+ speck words made
from INPUT. An integer INPUT is used directly; anything else is first
hashed with DIGEST. Word I of the key is bits [48*I, 48*I+48) of that
integer, so the least significant word comes first."
  ;; INPUT is normalised before iterating rather than in a LOOP
  ;; INITIALLY clause: the standard places the prologue after the
  ;; initial settings supplied by FOR clauses, so an INITIALLY-based
  ;; conversion is not guaranteed to run before the first LDB.
  (let ((bits (if (integerp input) input (digest input)))
        (vector (make-array +speck-key-words+)))
    (declare (type simple-vector vector))
    (dotimes (i +speck-key-words+ vector)
      (setf (svref vector i)
            (ldb (byte +speck-word-bits+ (* i +speck-word-bits+)) bits)))))

(defmacro %get-cbc-word (vector i start end buffer)
  "Expand into the word of the CBC-MAC input stream at index I.
Exactly at END the word is the padding marker #x80, and past END it
is zero. Before START it is the data length (END - START) masked to
one word. Anywhere else it is read with %GET-SPECK-WORD. The argument
forms are repeated in the expansion, so they should be variables."
  `(cond ((= ,i ,end) #x80)
         ((> ,i ,end) #x00)
         ((< ,i ,start) (logand (- ,end ,start) +speck-word-mask+))
         (t (%get-speck-word ,vector ,i ,start ,end ,buffer))))

(defun mac (data key)
  "CBC-MAC based on Speck96/144.
Return the tag for DATA under KEY as a non-negative integer of two
speck words (high word first). DATA is normalised by WITH-SPECK-INPUT;
KEY is a vector of three speck words as read by MAKE-KEY-SCHEDULE.
The chaining value starts at zero. The word stream fed to the cipher
is the data length, then the data words, then padding (see
%GET-CBC-WORD), taken two words per block."
  (with-speck-context (nil key)
    (with-speck-input (data start end buffer)
      ;; I starts one word *before* START, which is where %GET-CBC-WORD
      ;; supplies the length prefix. It is therefore negative whenever
      ;; START is below the word size, so it is declared FIXNUM: an
      ;; ARRAY-INDEX declaration would be false, and this code is
      ;; compiled with (SAFETY 0), where declarations are trusted.
      (loop with i of-type fixnum = (- start +speck-word-size+)
            ;; Done once the block that reached END has been encrypted.
            when (> i end) return (logior (ash ciphertext.hi +speck-word-bits+) ciphertext.lo)
            do (setq plaintext.lo (%get-cbc-word data i start end buffer) i (+ i +speck-word-size+))
               (setq plaintext.hi (%get-cbc-word data i start end buffer) i (+ i +speck-word-size+))
               ;; CBC chaining: XOR the previous ciphertext into the block.
               (%encrypt-within-speck-context (logxor plaintext.lo ciphertext.lo)
                                              (logxor plaintext.hi ciphertext.hi))))))

(defun digest-object (object)
  "Serialise OBJECT with WRITE-BINARY-OBJECT and return the DIGEST of
the resulting output."
  (let ((encoded (with-output-to-vector (binary-stream)
                   (write-binary-object object binary-stream))))
    (digest encoded)))

(defun derive-key (master-key label)
  "Derive a key of +SPECK-KEY-WORDS+ words from MASTER-KEY and LABEL.
The digest of LABEL is encrypted under MASTER-KEY one block (two
words) at a time, starting from its least significant bits, and the
ciphertext words are collected, low word first, until the key is
full. Returns the key as a fresh simple vector."
  (let ((digest (digest-object label)))
    (with-speck-context (nil master-key)
      (let ((derived-key (make-array +speck-key-words+))
            (filled 0))
        (declare (type array-index filled))
        (flet ((store-word (word)
                 ;; Append WORD; true once the key is complete.
                 (setf (aref derived-key filled) word)
                 (incf filled)
                 (= filled +speck-key-words+)))
          (loop for lo-position fixnum from 0 by (* 2 +speck-word-bits+)
                do (%encrypt-within-speck-context
                    (ldb (byte +speck-word-bits+ lo-position) digest)
                    (ldb (byte +speck-word-bits+ (+ lo-position +speck-word-bits+)) digest))
                   (when (or (store-word ciphertext.lo)
                             (store-word ciphertext.hi))
                     (return derived-key))))))))

(defun ensure-key (name)
  "Return the key stored under the KEY property of the symbol NAME.
The fallback derives one from the MASTER-KEY property of the symbol
SYSTEM, with NAME as the label."
  ;; NOTE(review): ORF is defined elsewhere; presumably it stores the
  ;; derived key in the place only when the place is currently NIL.
  (orf (get name 'key)
       (derive-key (get 'system 'master-key) name)))

(defun make-timestamp-time (&optional local-time (key (ensure-key 'timestamp-key)))
  "Return two values: the integer time code for LOCAL-TIME and its
authentication code. The time code is day * 10^8 + second * 1000 +
millisecond. The authentication code is the MAC, under KEY, of the
decimal representation of the time code, reduced modulo 10^10.
LOCAL-TIME defaults to the result of GET-LOCAL-TIME."
  ;; NOTE(review): the value of NORMALIZE-LOCAL-TIME is discarded, so
  ;; it presumably normalises its argument in place -- confirm.
  (cond (local-time (normalize-local-time local-time))
        (t (setq local-time (get-local-time))))
  (let* ((day (local-time-day local-time))
         (sec (local-time-sec local-time))
         (msec (local-time-msec local-time))
         (time (+ (* day 100000000)
                  (* (the (mod 100000) sec) 1000)
                  (the (mod 1000) msec))))
    (values time (mod (mac (format nil "~D" time) key) 10000000000))))

(defun make-timestamp (&rest make-timestamp-time-arguments)
  "Return an authenticated timestamp as a lowercase base-36 string.
The arguments are passed on to MAKE-TIMESTAMP-TIME. The encoded
integer is the time code followed by the ten decimal digits of the
authentication code."
  (multiple-value-bind (time code) (apply #'make-timestamp-time make-timestamp-time-arguments)
    (let ((combined (+ (* time 10000000000) code)))
      (nstring-downcase (write-to-string combined :base 36 :radix nil)))))

(defun parse-timestamp (timestamp &key (start 0) end)
  "Parse LOCAL-TIME and MAC components from a TIMESTAMP.
TIMESTAMP is lexed between START and END as an unsigned base-36
number. Returns two values: a LOCAL-TIME made from the encoded day,
second and millisecond with zone 0, and the authentication code taken
from the ten least significant decimal digits. LEXER-ERROR is called
when no number can be read."
  (with-lexer (timestamp start end)
    (with-lexer-error ("timestamp")
      (let ((encoded (or (lexer-unsigned 36) (lexer-error))))
        ;; Peel the fields off from the least significant end:
        ;; 10 digits of code, then 8 digits of millisecond-of-day.
        (multiple-value-bind (stamp mac-code) (floor encoded 10000000000)
          (multiple-value-bind (day-number msec-of-day) (floor stamp 100000000)
            (declare (type (integer 0 (100000000)) msec-of-day))
            (multiple-value-bind (seconds millis) (floor msec-of-day 1000)
              (declare (type (unsigned-byte 32) seconds) (type (unsigned-byte 16) millis))
              (values (make-local-time :day day-number :sec seconds :msec millis :zone 0)
                      mac-code))))))))

(defun authenticate-timestamp (timestamp &optional (key (ensure-key 'timestamp-key)))
  "Return the LOCAL-TIME encoded in TIMESTAMP if its authentication
code matches the one recomputed under KEY, and NIL otherwise."
  (multiple-value-bind (local-time provided-code) (parse-timestamp timestamp)
    (multiple-value-bind (time expected-code) (make-timestamp-time local-time key)
      (declare (ignore time))
      (and (= provided-code expected-code) local-time))))

(defun make-session-cookie (id &optional (code (compute-session-code id)))
  "Pack ID and its authentication CODE into a single integer: ID goes
above the low +MAC-BITS+ bits, which hold CODE. CODE defaults to
COMPUTE-SESSION-CODE of ID and must be an unsigned integer of
+MAC-BITS+ bits."
  (let* ((id-field (ash id +mac-bits+))
         (code-field (require-type code '(unsigned-byte #.+mac-bits+))))
    (logior id-field code-field)))

(defun compute-session-code (id)
  "Return the MAC of the binary encoding of ID, computed with the key
obtained from (ENSURE-KEY 'SESSION-KEY)."
  (let* ((encoded-id (with-output-to-vector (binary-stream)
                       (write-binary-object id binary-stream)))
         (session-key (ensure-key 'session-key)))
    (mac encoded-id session-key)))

(defun decode-cookie (authenticated-cookie)
  "Split AUTHENTICATED-COOKIE into two values: the part above the low
+MAC-BITS+ bits (the id) and the low +MAC-BITS+ bits (the code)."
  (let ((code-field (byte +mac-bits+ 0)))
    (values (ash authenticated-cookie (- +mac-bits+))
            (ldb code-field authenticated-cookie))))