{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sage Notebook 2 Demonstrating/Testing Ring-BKW Advanced Keying\n",
    "\n",
    "Version 1.0, July 10, 2020\n",
    "\n",
    "Katherine E. Stange\n",
    "\n",
    "This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.\n",
    "\n",
    "\n",
    "This notebook has been prepared as an accompaniment to the paper *Algebraic aspects of solving Ring-LWE, including ring-based improvements in the Blum-Kalai-Wasserman algorithm*. See the website [http://math.colorado.edu/~kstange/ring-bkw.html](http://math.colorado.edu/~kstange/ring-bkw.html). Usage is demonstrated below.\n",
    "\n",
    "\n",
    "The purpose of this worksheet is twofold: (1) to demonstrate/verify mathematical correctness of the advanced keying algorithm, and (2) to compare some simple runtimes with and without advanced keying.\n",
    "\n",
    "\n",
    "Evaluate all cells to recreate the runtime experiments in the paper (use seed=1 for the same pseudorandom input samples)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# In case one needs to install sortedcontainers, one can use this command\n",
    "#import sys\n",
    "#!{sys.executable} -m pip install sortedcontainers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# needed python libraries\n",
    "import sortedcontainers\n",
    "from sortedcontainers import SortedDict\n",
    "import random\n",
    "from sage.misc.sage_timeit import sage_timeit\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Basic functions for bit reversal and rotation of samples\n",
    "\n",
    "# turn an integer into a binary string of nn bits\n",
    "int_to_bin = lambda x, nn: format(x, 'b').zfill(nn)\n",
    "\n",
    "# turn a binary string back into an integer\n",
    "def bin_to_int(b):\n",
    "    out = 0\n",
    "    bp = b[::-1]\n",
    "    for i in range(len(bp)):\n",
    "        out += ZZ(bp[i])*2^i\n",
    "    return out\n",
    "\n",
    "# the bit reversal permutation on nn bits (inversion)\n",
    "def bit_reverse(k,nn):\n",
    "    bp = int_to_bin(k,nn)[::-1]\n",
    "    return bin_to_int(bp)\n",
    "\n",
    "# Sign normalization for vectors whose entries are residues mod q.\n",
    "# When q is not passed it is read off the entries with .modulus(), which assumes\n",
    "# they are IntegerMod elements (as produced by Mod(x,q)); this removes the\n",
    "# dependency on a global q that is only set by the experiment cells.\n",
    "\n",
    "# -1 if the first nonzero entry (or the last entry, when all are zero)\n",
    "# exceeds (q-1)/2, otherwise 1\n",
    "def get_sign(vec, q=None):\n",
    "    i = 0\n",
    "    while vec[i] == 0 and i < len(vec)-1:\n",
    "        i += 1\n",
    "    if q is None:\n",
    "        q = vec[i].modulus()\n",
    "    if ZZ(vec[i]) > (q-1)/2:\n",
    "        return -1\n",
    "    else:\n",
    "        return 1\n",
    "\n",
    "# negate the vector if needed so its first nonzero entry is in the range 1 to (q-1)/2\n",
    "def sign_assign(vec, q=None):\n",
    "    if get_sign(vec, q) == -1:\n",
    "        return [-x for x in vec]\n",
    "    else:\n",
    "        return vec\n",
    "\n",
    "# True when every entry of vec is zero (mod q)\n",
    "def is_zero(vec):\n",
    "    return all(x == 0 for x in vec)\n",
    "\n",
    "# create a class to store the bit-reversal permutation & zeta action\n",
    "#\n",
    "# Vectors have nn = 2^N entries mod q. In the power basis, entry k is the\n",
    "# coefficient of zeta^k; the prioritized basis is the power basis reordered by\n",
    "# the bit reversal permutation. Multiplying by zeta^h shifts coefficients up by h\n",
    "# and negates a coefficient once for each wrap past degree nn (zeta^nn = -1).\n",
    "class bitZeta():\n",
    "    def __init__(self,N,kB,q):\n",
    "        self.N = N\n",
    "        self.nn = 2^N\n",
    "        self.kB = kB\n",
    "        self.B = 2^kB\n",
    "        self.tnn = self.nn*2\n",
    "        # bit_reversal_lookup[i] = bit reversal of (i mod nn); storing 2*nn entries lets\n",
    "        # the negative indices produced by 'index - h' in zeta/bzeta wrap correctly\n",
    "        self.bit_reversal_lookup = [bit_reverse(ZZ(Mod(i,self.nn)),self.N) for i in range(self.tnn)]\n",
    "        # sign_lookup[h][i] = (-1)^w mod q, where w = number of times nn must be added to\n",
    "        # i-h to bring it into [0, nn); the division is exact, so // keeps w an Integer\n",
    "        self.sign_lookup = [ [Mod((-1)^((ZZ(Mod(i-h,self.nn))-(i-h))//self.nn),q) for i in range(self.nn)] for h in range(self.tnn)]\n",
    "    def position_reverse(self,vec): #swap power <--> prioritized bases\n",
    "        return [ vec[self.bit_reversal_lookup[i]] for i in range(self.nn) ]\n",
    "    def zeta_pow(self,vec,h): #apply zeta^h on power basis\n",
    "        h = ZZ(Mod(h,self.tnn))\n",
    "        # reduce the source index mod nn: for h > nn, i-h can drop below -nn,\n",
    "        # which is out of range for a list of length nn\n",
    "        return [ vec[(i-h) % self.nn]*self.sign_lookup[h][i] for i in range(self.nn) ]\n",
    "    def bzeta(self,vec,h,ii): #apply zeta^h on prioritized basis, but only compute block ii (entries B*(ii-1) to B*ii-1)\n",
    "        return [ vec[self.bit_reversal_lookup[self.bit_reversal_lookup[i]-h]]*self.sign_lookup[h][self.bit_reversal_lookup[i]] for i in range(self.B*(ii-1),self.B*ii) ]\n",
    "    def zeta(self,vec,h): #apply zeta^h on prioritized basis\n",
    "        return [ vec[self.bit_reversal_lookup[self.bit_reversal_lookup[i]-h]]*self.sign_lookup[h][self.bit_reversal_lookup[i]] for i in range(self.nn) ]"
   ]
  },
  {
   "cell_type": "code",
"execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Traditional and Advanced Keying BKW Reduction algorithms\n",
    "\n",
    "# Parent class for BKW reduction, with reporting functions\n",
    "#\n",
    "# Samples are lists of n = 2^kn residues mod q, split into n/B blocks of B = 2^kB\n",
    "# entries. tables[i] (1 <= i <= n/B) is a SortedDict keyed by repr() of block i\n",
    "# of the sign-normalized samples stored there; tables[0] is never used. A collision\n",
    "# in table i < n/B yields a difference that vanishes on block i and is passed on to\n",
    "# table i+1; the final table only stores samples.\n",
    "class BKW:\n",
    "\n",
    "    def __init__(self,q,kn,kB,sample_input_list,num_samps,alldiffs=False):\n",
    "        self.q = q # working mod q\n",
    "        self.kn = kn # dimension = n = 2^kn\n",
    "        self.kB = kB # block size = B = 2^kB\n",
    "        self.n = 2^self.kn\n",
    "        self.B = 2^self.kB\n",
    "        # number of blocks; // keeps it an Integer (n/B is a Rational in Sage), since\n",
    "        # it is used as a list index, a range bound and a rotation exponent\n",
    "        self.nB = self.n//self.B\n",
    "        self.sample_input_list = sample_input_list\n",
    "        self.num_samps = num_samps\n",
    "        self.tables = [SortedDict([]) for _ in range(self.nB+1)] # BKW tables\n",
    "        self.passcount = 0 # Counts number of times a difference of samples is passed to next table\n",
    "        self.bit = bitZeta(self.kn,self.kB,self.q) # set up bit reversal / zeta permutations for full samps\n",
    "        if alldiffs == True:\n",
    "            self.table_insert_blind = self.table_insert_blind_alldiffs\n",
    "        else:\n",
    "            self.table_insert_blind = self.table_insert_blind_onediff\n",
    "        return\n",
    "\n",
    "    # paring algorithm, removing duplicates from the final table:\n",
    "    # for the final sample list, keep only one sample per rotation, namely the\n",
    "    # rotation whose last block has the smallest repr (SortedDict order)\n",
    "    def reduce_final_table(self):\n",
    "        finaltable = SortedDict([])\n",
    "        for samp in self.tables[self.nB]: # snapshot the final table, since it is modified while looping\n",
    "            finaltable[samp] = self.tables[self.nB][samp]\n",
    "        for samp in finaltable: # loop through it\n",
    "            samp = self.tables[self.nB][samp][0] # pull out the one sample\n",
    "            start = self.n-self.B\n",
    "            end = self.n\n",
    "            # collect list of rotations zeta^(j*n/B), j = 0..B-1, keyed by their last block\n",
    "            rotations = SortedDict([])\n",
    "            for j in range(self.B):\n",
    "                samp1 = self.bit.zeta(samp,j*(self.nB))\n",
    "                samp1abs = sign_assign(samp1)\n",
    "                rotations[repr(samp1abs[start:end])] = samp1abs\n",
    "            # now look for canonical entry among them\n",
    "            mysamp = rotations.peekitem(0)[1]\n",
    "            myrep = rotations.peekitem(0)[0]\n",
    "            # replace the sample with its canonical version\n",
    "            self.tables[self.nB].pop(repr(samp[start:end]))\n",
    "            if myrep not in self.tables[self.nB]:\n",
    "                self.tables[self.nB][myrep] = [mysamp]\n",
    "        return\n",
    "\n",
    "    # Reports the basic facts about the reduction after it has happened\n",
    "    # (canonicalizes the final table first via reduce_final_table)\n",
    "    def report(self):\n",
    "        self.reduce_final_table()\n",
    "\n",
    "        # report number of passes\n",
    "        print(\"Number of times a sample was passed to another table:\", self.passcount)\n",
    "\n",
    "        # report table sizes\n",
    "        totalsizes = 0\n",
    "        for i in range(1,len(self.tables)):\n",
    "            print(\"Table\", i, \"has\", len(self.tables[i]), \"entries.\")\n",
    "            if i < len(self.tables)-1:\n",
    "                totalsizes += len(self.tables[i])\n",
    "        print(\"Total stored table rows (not counting final table):\", totalsizes)\n",
    "\n",
    "        return\n",
    "\n",
    "    # print out the final table: each key, then its stored samples below it\n",
    "    def show_final(self):\n",
    "        for i in range(len(self.tables[self.nB])):\n",
    "            print(self.tables[self.nB].peekitem(i)[0])\n",
    "            samplist = self.tables[self.nB].peekitem(i)[1]\n",
    "            for samp in samplist:\n",
    "                print(\" \"+str(samp))\n",
    "\n",
    "    # table insertion for traditional BKW, passing on one difference only\n",
    "    def table_insert_blind_onediff(self,samp,i):\n",
    "        table = self.tables[i]\n",
    "        samp1 = sign_assign(samp) # multiply by -1 if needed\n",
    "        # if we have a collision and it's not the last table, pass on the difference\n",
    "        rep1 = repr(samp1[self.B*(i-1):self.B*i])\n",
    "        if i < self.nB and rep1 in table:\n",
    "            tabvec = table[rep1][0] # pull out the one sample from its list\n",
    "            diff = list(vector(samp1) - vector(tabvec))\n",
    "            if not is_zero(diff): # if nonzero, send it down\n",
    "                self.passcount += 1\n",
    "                self.table_insert_blind(diff,i+1)\n",
    "        else: # if we have no collision, or it is last table, just store it\n",
    "            table[rep1] = [samp1] # store as a list of one sample\n",
    "        return\n",
    "\n",
    "    # table insertion for traditional BKW, passing/keeping all differences\n",
    "    def table_insert_blind_alldiffs(self,samp,i):\n",
    "        table = self.tables[i]\n",
    "        samp1 = sign_assign(samp) # multiply by -1 if needed\n",
    "        # if we have a collision and it's not the last table, pass on the difference\n",
    "        rep1 = repr(samp1[self.B*(i-1):self.B*i])\n",
    "        if i < self.nB and rep1 in table:\n",
    "            tabvecs = table[rep1] # pull out the sample list at that rep\n",
    "            for tabvec in tabvecs: # for every difference, pass it down\n",
    "                diff = list(vector(samp1) - vector(tabvec))\n",
    "                if not is_zero(diff): # if nonzero, send it down\n",
    "                    self.passcount += 1\n",
    "                    self.table_insert_blind(diff,i+1)\n",
    "            # store it also (after the loop, so the row is not modified while iterating it)\n",
    "            table[rep1].append(samp1)\n",
    "        else: # if we have no collision, or it is last table, just store it\n",
    "            table[rep1] = [samp1] # store as a list of one sample\n",
    "        return\n",
    "\n",
    "\n",
    "# Traditional BKW completely ring blind (using no rotations)\n",
    "class blind_BKW(BKW):\n",
    "    def __init__(self,q,kn,kB,sample_input_list,num_samps,alldiffs=False):\n",
    "        super().__init__(q,kn,kB,sample_input_list,num_samps,alldiffs)\n",
    "\n",
    "    def run(self):\n",
    "        for s in range(self.num_samps): # for each sample, pass to first table\n",
    "            samp = self.sample_input_list[s]\n",
    "            self.table_insert_blind(samp,1)\n",
    "\n",
    "\n",
    "# Traditional BKW on samples plus their rotations\n",
    "class trad_BKW(BKW):\n",
    "    def __init__(self,q,kn,kB,sample_input_list,num_samps,alldiffs=False):\n",
    "        super().__init__(q,kn,kB,sample_input_list,num_samps,alldiffs)\n",
    "\n",
    "    def run(self):\n",
    "        for s in range(self.num_samps):\n",
    "            samp = self.sample_input_list[s]\n",
    "            self.table_insert_blind(samp,1)\n",
    "            for j in range(1,self.n): # for all rotations of the sample, feed into the first table\n",
    "                samp1 = self.bit.zeta(samp,j)\n",
    "                self.table_insert_blind(samp1,1)\n",
    "\n",
    "\n",
    "# Advanced Keying BKW\n",
    "class adv_BKW(BKW):\n",
    "    def __init__(self,q,kn,kB,sample_input_list,num_samps,alldiffs=False):\n",
    "        super().__init__(q,kn,kB,sample_input_list,num_samps,alldiffs)\n",
    "        if alldiffs == True:\n",
    "            self.table_insert_adv = self.table_insert_adv_alldiffs\n",
    "        else:\n",
    "            self.table_insert_adv = self.table_insert_adv_onediff\n",
    "        return\n",
    "\n",
    "    # Rotate samp by zeta^(j*n/B), j = 0..B-1, and key each rotation by its\n",
    "    # sign-normalized block i. Returns (myrep, mysamps): myrep is the smallest key\n",
    "    # (SortedDict order) and mysamps lists [block, j, full sign-normalized rotation]\n",
    "    # for every rotation that produced that key.\n",
    "    def get_rotations(self,samp,i):\n",
    "        start = self.B*(i-1)\n",
    "        end = self.B*i\n",
    "        # collect list of rotations\n",
    "        rotations = SortedDict([])\n",
    "        samp1abs = sign_assign(samp)\n",
    "        repr1abs = repr(samp1abs[start:end])\n",
    "        if repr1abs not in rotations:\n",
    "            rotations[repr1abs] = [[samp1abs[start:end],0]] # store which rotation it is\n",
    "        else:\n",
    "            rotations[repr1abs].append([samp1abs[start:end],0])\n",
    "        for j in range(1,self.B):\n",
    "            samp1 = self.bit.bzeta(samp,j*(self.nB),i) # only compute the rotation on the block\n",
    "            samp1abs = sign_assign(samp1)\n",
    "            repr1abs = repr(samp1abs)\n",
    "            if repr1abs not in rotations:\n",
    "                rotations[repr1abs] = [[samp1abs,j]]\n",
    "            else:\n",
    "                rotations[repr1abs].append([samp1abs,j])\n",
    "        # now look for canonical entry among them\n",
    "        myrep = rotations.peekitem(0)[0] # the canonical rep\n",
    "        mysamps = rotations.peekitem(0)[1] # all associated canonical samples\n",
    "        # do the full rotations only for the canonical ones\n",
    "        for sampy in mysamps:\n",
    "            samp2 = sign_assign(self.bit.zeta(samp,sampy[1]*self.nB))\n",
    "            sampy.append(samp2) # store full rotation\n",
    "        return myrep, mysamps\n",
    "\n",
    "    # advanced keying, passing on differences with the one stored sample per row\n",
    "    def table_insert_adv_onediff(self,samp,i):\n",
    "        table = self.tables[i]\n",
    "        start = self.B*(i-1)\n",
    "        end = self.B*i\n",
    "        # if we are not in the last table\n",
    "        if i < self.nB:\n",
    "            myrep, mysamps = self.get_rotations(samp,i)\n",
    "            # go through rotations and deal with them\n",
    "            if myrep in table: # if a collision\n",
    "                tabvecs = table[myrep] # list of all vectors already at that row\n",
    "                mysamp = mysamps[0][2] # full rotation\n",
    "                for tabvec in tabvecs: # and each of the old samples\n",
    "                    diff = list(vector(mysamp)-vector(tabvec)) # compute the difference\n",
    "                    if not is_zero(diff): # if nonzero, send it down\n",
    "                        self.passcount += 1\n",
    "                        self.table_insert_adv(diff,i+1)\n",
    "            else: # if not already in the table\n",
    "                table[myrep] = [mysamps[0][2]] # store the first canonical sample at that row\n",
    "        else: # if it's just the final table, just store its abs val\n",
    "            samp1abs = sign_assign(samp)\n",
    "            table[repr(samp1abs[start:end])] = [samp1abs]\n",
    "        return\n",
    "\n",
    "    # advanced keying, passing/keeping all differences\n",
    "    def table_insert_adv_alldiffs(self,samp,i):\n",
    "        table = self.tables[i]\n",
    "        start = self.B*(i-1)\n",
    "        end = self.B*i\n",
    "        # if we are not in the last table\n",
    "        if i < self.nB:\n",
    "            myrep, mysamps = self.get_rotations(samp,i)\n",
    "            # for the canonical ones, compare and pass on\n",
    "            if myrep not in table: # if it's a new rep, store it\n",
    "                table[myrep] = [mysamps[0][2]] # store the first one\n",
    "                mysamps.pop(0) # drop the first entry\n",
    "            # at this point there is something in the table\n",
    "            for mysamp in mysamps: # for each of the new samples\n",
    "                for tabvec in table[myrep]: # and each of the old samples\n",
    "                    diff = list(vector(mysamp[2])-vector(tabvec)) # compute the difference\n",
    "                    if not is_zero(diff): # if nonzero, send it down\n",
    "                        self.passcount += 1\n",
    "                        self.table_insert_adv(diff,i+1)\n",
    "                # store the new sample in the row (after the inner loop, so the row is not modified while iterating it)\n",
    "                table[myrep].append(mysamp[2])\n",
    "        else: # if it's just the final table, just store its abs val\n",
    "            samp1abs = sign_assign(samp)\n",
    "            table[repr(samp1abs[start:end])] = [samp1abs]\n",
    "        return\n",
    "\n",
    "    def run(self):\n",
    "        for s in range(self.num_samps):\n",
    "            samp = self.sample_input_list[s]\n",
    "            self.table_insert_adv(samp,1)\n",
    "            for j in range(1,self.nB): # for each sample, rotate by 0,1,...,n/B-1 and pass to first table\n",
    "                samp1 = self.bit.zeta(samp,j)\n",
    "                self.table_insert_adv(samp1,1)\n",
    "        return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# EXPERIMENT Running\n",
    "\n",
    "blindfalse = None\n",
    "blindtrue = None\n",
    "tradfalse = None\n",
    "tradtrue = None\n",
    "advfalse = None\n",
    "advtrue = None\n",
    "\n",
    "def run_experiment(q,kn,kB,numsamps,alldiffs=False,seed=None,show_final=False):\n",
    " \n",
    " global 
blindfalse\n", " global blindtrue\n", " global tradfalse\n", " global tradtrue\n", " global advfalse\n", " global advtrue\n", "\n", " # Set up parameters\n", "\n", " # entries are mod q\n", "\n", " # kn = dimension of vectors\n", " n = 2^kn #length of vectors\n", "\n", " # kB = length of blocks\n", " B = 2^kB # length of blocks\n", "\n", " # numsamps is number of samples\n", "\n", " # compute initial samples randomly\n", " if seed != None:\n", " random.seed(seed)\n", " sample_number = numsamps*n\n", " sample_input_list = []\n", " for i in range(sample_number):\n", " sample_input_list.append( [Mod(random.randint(0,q),q) for _ in range(n)])\n", "\n", " # # Set up parameters
############### Experiment 1
#
# Runtime comparison at q = 211, n = 2^3 = 8, B = 2^2 = 4, always with seed=1 so
# the pseudorandom input samples are reproducible. Each (numsamps, alldiffs)
# pair below is passed to run_experiment, in order.

# entries are mod q
q = 211

# dimension of vectors
kn = 3
n = 2^kn # length of vectors

# length of blocks
kB = 2
B = 2^kB # length of blocks

# (number of samples, alldiffs) for each run
experiment_runs = [
    (4000, True),
    (2000, True),
    (200, False),
    (250, False),
]

# run experiment once per configuration
for numsamps, alldiffs in experiment_runs:
    run_experiment(q,kn,kB,numsamps,seed=1,alldiffs=alldiffs)