;;;; -*- Mode: Lisp; Base: 10; Syntax: ANSI-Common-Lisp; Package: kira; Coding: utf-8 -*-
;;;; Copyright © 2021 David Mullen. All Rights Reserved. Origin: <https://cl-pdx.com/kira/>

(in-package :kira)

(define-symbol-macro +speck-iv+
  ;; Initial value used by DIGEST (via WITH-SPECK-CONTEXT).
  ;; These are the plaintext words of the Speck96/144 test vector
  ;; from the original paper <https://eprint.iacr.org/2013/404>.
  ;; The paper lists them high word first (656d6974206e 69202c726576).
  ;; Here element 0 is the low word (PLAINTEXT.LO) and element 1 is the
  ;; high word (PLAINTEXT.HI). A symbol macro rather than DEFCONSTANT
  ;; avoids redefinition errors for a non-EQL vector constant.
  #(#x69202C726576 #x656D6974206E))

  ;; Speck96/144 parameters: a block of two 48-bit words, a key of
  ;; +SPECK-KEY-WORDS+ (three) 48-bit words, and 29 rounds.
  (defconstant +speck-rounds+     29)
  (defconstant +speck-word-bits+  48)
  (defconstant +speck-key-words+   3)
  ;; Output widths: DIGEST yields H || G (four words, 192 bits) and
  ;; MAC yields the two-word final ciphertext block (96 bits).
  (defconstant +digest-bits+     192)
  (defconstant +mac-bits+         96))

  ;; Octets per Speck word (6), the mask used to reduce arithmetic
  ;; modulo 2^48, and the type of a single Speck word.
  (defconstant +speck-word-size+ (/ +speck-word-bits+ 8))
  (defconstant +speck-word-mask+ (1- (ash 1 +speck-word-bits+)))
  (deftype speck-word () `(unsigned-byte ,+speck-word-bits+)))

(defmacro %ror (x r)
  (let ((mask (1- (ash 1 r))))
    `(logior (ash ,x (- ,r))
             (the speck-word
               (ash (logand ,x ,mask)
                    (- +speck-word-bits+

(defmacro %rol (x r)
  "Rotate the SPECK-WORD X left by R bits.
The expansion evaluates X (and R) twice, so both should be
side-effect-free forms such as variables or literals."
  ;; High part: X shifted left and truncated to the word width.
  ;; Low part:  the bits that fell off the top, brought back down.
  `(logior (ldb (byte +speck-word-bits+ 0) (ash ,x ,r))
           (ash ,x (- (- +speck-word-bits+ ,r)))))

(defmacro %round-function (x y k)
  "Apply one Speck round (rotation amounts 8 and 3) to the words X and Y
with round key K, updating X and Y in place and returning the new Y.
X and Y must be variables: each is read more than once and assigned."
  `(progn
     ;; X <- (((X >>> 8) + Y) mod 2^48) xor K
     (setq ,x (logand (+ (%ror ,x 8) ,y) +speck-word-mask+))
     (setq ,x (logxor ,x ,k))
     ;; Y <- (Y <<< 3) xor X, using the freshly updated X.
     (setq ,y (logxor (%rol ,y 3) ,x))))

(defmacro with-speck-key ((key &rest variables) &body body)
  "Evaluate BODY with each of VARIABLES bound to the successive elements
of KEY, each checked with REQUIRE-TYPE to be a SPECK-WORD.
The expansion also binds VECTOR, START and END (via WITH-ARRAY-DATA)
around BODY; those names are visible to BODY."
  (let ((bindings
          (loop for var in variables
                for index of-type array-index from 0
                collect `(,var (require-type (aref vector (+ start ,index))
                                             'speck-word)))))
    `(with-array-data ((vector ,key) (start 0) (end ,(length variables)))
       (let (,@bindings)
         (declare (type speck-word ,@variables))
         ,@body))))

(defun make-key-schedule (key &optional key-schedule)
  "Expand KEY into +SPECK-ROUNDS+ round keys and return them in a simple-vector.
KEY supplies three SPECK-WORDs: KEY[0] is round key 0, and KEY[1] and KEY[2]
seed the two alternating l-words of the expansion. If KEY-SCHEDULE is
non-NIL it is filled in place and returned; otherwise a fresh vector is made."
  (let ((schedule (or key-schedule (make-array +speck-rounds+))))
    (with-speck-key (key k l0 l1)
      ;; The first round key is the first key word unchanged.
      (setf (svref schedule 0) k)
      ;; Each later round key comes from one Speck round on (l, k), keyed
      ;; by the step number. Even steps use L0 and odd steps use L1.
      (dotimes (step (1- +speck-rounds+) schedule)
        (declare (type (unsigned-byte 16) step))
        (if (evenp step)
            (%round-function l0 k step)
            (%round-function l1 k step))
        (setf (svref schedule (1+ step)) k)))))

(defmacro with-speck-context ((plaintext &optional key) &body body)
  `(let ((ciphertext.lo 0) (ciphertext.hi 0) (plaintext.lo 0) (plaintext.hi 0))
     (declare (type speck-word ciphertext.lo ciphertext.hi plaintext.lo plaintext.hi))
     (let ((key-schedule (make-array +speck-rounds+)))
       (declare (optimize (safety 0) (speed 3)))
       (let ((plaintext-vector ,plaintext))
         (setq plaintext.lo (aref plaintext-vector 0))
         (setq plaintext.hi (aref plaintext-vector 1)))
       ,@(when key `((make-key-schedule ,key key-schedule)))

(defmacro %encrypt-within-speck-context (&optional lo hi new-key)
  "Encrypt one two-word block with the round keys in KEY-SCHEDULE.
Must be expanded inside WITH-SPECK-CONTEXT; it reads and writes that
context's KEY-SCHEDULE, PLAINTEXT.LO/HI and CIPHERTEXT.LO/HI.

LO and HI are forms giving the low and high plaintext words. An omitted
word defaults to the current PLAINTEXT.LO or PLAINTEXT.HI. When either is
supplied, the block actually encrypted is stored back into PLAINTEXT.LO/HI.
Each of LO and HI is evaluated exactly once. When NEW-KEY is supplied,
KEY-SCHEDULE is first rebuilt from it with MAKE-KEY-SCHEDULE. The result
is left in CIPHERTEXT.LO/HI."
  `(loop with y of-type speck-word = ,(or lo 'plaintext.lo)
         with x of-type speck-word = ,(or hi 'plaintext.hi)
         ;; INITIALLY runs after the WITH inits but before the first
         ;; FOR step, so round key 0 is read from the new schedule.
         ,@(when new-key `(initially (make-key-schedule ,new-key key-schedule)))
         ;; Record the block from the already-evaluated Y and X instead
         ;; of evaluating LO and HI a second time.
         ,@(when (or lo hi) `(initially (setq plaintext.lo y plaintext.hi x)))
         for i of-type (unsigned-byte 16) from 0 below +speck-rounds+
         for round-key of-type speck-word = (svref key-schedule i)
         do (%round-function x y round-key)
         finally (setq ciphertext.lo y)
                 (setq ciphertext.hi x)))

(defun %get-speck-word (vector i start end buffer)
  "Return the SPECK-WORD that begins at octet I of VECTOR's data [START, END).
If fewer than +SPECK-WORD-SIZE+ octets remain, the tail is copied into
BUFFER, followed by one #x80 octet and zero fill, and the word is read
from BUFFER instead. If I is at or past END, the result is the data
length (- END START) masked to 48 bits.
Words are read with %GET-BINARY-BYTE; per the file's IV comment the layout
is presumably little-endian (NOTE(review): confirm against that helper)."
  (declare (type (simple-array octet (*)) vector buffer)
           (type array-index i start end)
           (optimize (safety 0) (speed 3)))
  (cond ((>= i end)
         ;; Past the data: the final word encodes the length.
         (logand (- end start) +speck-word-mask+))
        ((>= (- end i) +speck-word-size+)
         ;; A whole word is available directly in VECTOR.
         (%get-binary-byte vector i #.+speck-word-size+))
        (t
         ;; Partial tail: pad it out to a full word inside BUFFER.
         (let ((tail (- end i)))
           (declare (type array-index tail))
           (replace buffer vector :start2 i :end2 end)
           ;; TAIL is also the index of the first padding octet.
           (setf (aref buffer tail) #x80)
           (fill buffer 0 :start (1+ tail))
           (%get-binary-byte buffer 0 #.+speck-word-size+)))))

(defmacro put-key (key &rest key-words)
  "Store KEY-WORDS into elements 0, 1, 2, ... of the simple-vector KEY.
Expands into one SETF of SVREF places and returns the last value stored,
or NIL when no KEY-WORDS are given. KEY is evaluated once per word."
  (let ((pairs '()))
    (loop for key-word in key-words
          for index from 0
          do (push `(svref ,key ,index) pairs)
             (push key-word pairs))
    `(setf ,@(nreverse pairs))))

(defmacro with-speck-input ((data start end buffer) &body body)
  "Normalize DATA to an octet vector, then evaluate BODY over its contents.
DATA must name a variable; it is reassigned in place. Strings are encoded
with ENCODE-STRING-TO-OCTETS. Any other sequence that is not already a
(VECTOR OCTET) is coerced to one, which signals an error if an element is
not an octet. BODY runs with DATA, START and END rebound by WITH-ARRAY-DATA
and BUFFER bound to a fresh +SPECK-WORD-SIZE+ octet scratch vector."
  `(let ((,buffer (make-array +speck-word-size+ :element-type 'octet)))
     ;; Consumers such as %GET-SPECK-WORD run at (SAFETY 0) with the data
     ;; declared as an octet array. A plain VECTORP check would let e.g. a
     ;; SIMPLE-VECTOR through unconverted, so every non-octet vector is
     ;; coerced here.
     (cond ((stringp ,data)
            (setq ,data (encode-string-to-octets ,data)))
           ((not (typep ,data '(vector octet)))
            (setq ,data (coerce ,data '(vector octet)))))
     (with-array-data ((,data ,data) (,start 0) ,end) ,@body)))

(defun digest (data)
  "Hirose compression function."
  (with-speck-context (+speck-iv+)
    (with-speck-input (data start end buffer)
      (loop with key = (make-array +speck-key-words+)
            with g.lo of-type speck-word = plaintext.lo
            with g.hi of-type speck-word = plaintext.hi
            with h.lo of-type speck-word = plaintext.lo
            with h.hi of-type speck-word = plaintext.hi
            with virtual-end of-type array-index = (+ end +speck-word-size+)
            for i of-type array-index from start below virtual-end by +speck-word-size+
            do (put-key key (%get-speck-word data i start end buffer) h.lo h.hi)
               (%encrypt-within-speck-context g.lo g.hi key)
               (logxorf ciphertext.lo plaintext.lo)
               (logxorf ciphertext.hi plaintext.hi)
               (setq g.lo ciphertext.lo g.hi ciphertext.hi)
               (logxorf plaintext.lo +speck-word-mask+)
               (logxorf plaintext.hi +speck-word-mask+)
               (logxorf ciphertext.lo plaintext.lo)
               (logxorf ciphertext.hi plaintext.hi)
               (setq h.lo ciphertext.lo h.hi ciphertext.hi)
            ;; Concatenation (H || G) produces a 192-bit integer.
            finally (return (logior (ash g.hi (* 3 +speck-word-bits+))
                                    (ash g.lo (* 2 +speck-word-bits+))
                                    (ash h.hi +speck-word-bits+)

(defun make-key (password &aux (digest (digest password)))
  "Derive a three-word Speck96/144 key vector from PASSWORD.
The words are the three lowest 48-bit fields of (DIGEST PASSWORD), lowest
field first. The remaining high-order digest bits are not used."
  (flet ((word (index)
           ;; INDEX-th 48-bit field, counting from the least significant end.
           (ldb (byte +speck-word-bits+ (* index +speck-word-bits+)) digest)))
    (vector (word 0) (word 1) (word 2))))

(defmacro %get-cbc-word (vector i start end buffer)
  "Return the CBC-MAC input word at octet offset I.
- Offsets before START yield the data length (- END START), masked to 48 bits.
- Offsets inside [START, END) read data via %GET-SPECK-WORD.
- Offset END exactly yields the padding word #x80.
- Offsets beyond END yield zero.
All arguments may be evaluated more than once; pass variables."
  `(if (< ,i ,end)
       (if (< ,i ,start)
           (logand (- ,end ,start) +speck-word-mask+)
           (%get-speck-word ,vector ,i ,start ,end ,buffer))
       (if (= ,i ,end) #x80 #x00)))

(defun mac (data key)
  "CBC-MAC based on Speck96/144."
  ;; Zero IV: the context starts with CIPHERTEXT.LO/HI at 0, and KEY is
  ;; expanded into the round-key schedule. The result is the final
  ;; two-word ciphertext block packed as HI * 2^48 + LO.
  (with-speck-context (#(0 0) key)
    (with-speck-input (data start end buffer)
      ;; Start one word before START so the first word read is the length
      ;; prefix. Each pass consumes two words (one block).
      (do ((i (- start +speck-word-size+) (+ i (* 2 +speck-word-size+))))
          ((>= i end)
           (logior (ash ciphertext.hi +speck-word-bits+) ciphertext.lo))
        (declare (type fixnum i))
        (let ((j (+ i +speck-word-size+)))
          (declare (type fixnum j))
          (setq plaintext.lo (%get-cbc-word data i start end buffer)
                plaintext.hi (%get-cbc-word data j start end buffer)))
        ;; CBC chaining: XOR the new block with the previous ciphertext.
        (%encrypt-within-speck-context (logxor plaintext.lo ciphertext.lo)
                                       (logxor plaintext.hi ciphertext.hi))))))