# Sage Notebook 2: Demonstrating and Testing Ring-BKW Advanced Keying

Version 1.0, July 10, 2020

Katherine E. Stange

This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.


This notebook accompanies the paper *Algebraic aspects of solving Ring-LWE, including ring-based improvements in the Blum-Kalai-Wasserman algorithm*. See the website http://math.colorado.edu/~kstange/ring-bkw.html. Usage is demonstrated below.


The purpose of this worksheet is twofold: (1) to demonstrate and verify the mathematical correctness of the advanced keying algorithm, and (2) to compare some simple runtimes of BKW with and without advanced keying.


Evaluate all cells to recreate the runtime experiments in the paper (use seed=1 for the same pseudorandom input samples).

In [1]:
# In case one needs to install sortedcontainers, one can use this command
#import sys
#!{sys.executable} -m pip install sortedcontainers

In [2]:
# needed python libraries
import sortedcontainers
from sortedcontainers import SortedDict
import random
from sage.misc.sage_timeit import sage_timeit
import time

In [3]:
# Basic functions for bit reversal and rotation of samples

# turn an integer into a binary string of nn bits
int_to_bin = lambda x, nn: format(x, 'b').zfill(nn)

# turn a binary string back into an integer
def bin_to_int(b):
    out = 0
    bp = b[::-1]
    for i in range(len(bp)):
        out += ZZ(bp[i])*2^i
    return out

# the bit reversal permutation on nn bits (it is its own inverse)
def bit_reverse(k,nn):
    bp = int_to_bin(k,nn)[::-1]
    return bin_to_int(bp)

# change the vector so its first nonzero entry is in the range 1 to (q-1)/2
# vectors have entries mod q; get_sign and is_zero read the global q set in the experiment cells
def get_sign(vec):
    i = 0
    while vec[i] == 0 and i < len(vec)-1:
        i += 1
    if ZZ(vec[i]) > (q-1)/2:
        return -1
    else:
        return 1

def sign_assign(vec):
    if get_sign(vec) == -1:
        return [-vec[_] for _ in range(len(vec))]
    else:
        return vec

def is_zero(vec):
    i = 0
    while vec[i] == 0 and i < len(vec)-1:
        i += 1
    if vec[i] == Mod(0,q):
        return True
    else:
        return False

# create a class to store the bit-reversal permutation & zeta action
class bitZeta():
    def __init__(self,N,kB,q):
        self.N = N
        self.nn = 2^N
        self.kB = kB
        self.B = 2^kB
        self.tnn = self.nn*2
        self.bit_reversal_lookup = [bit_reverse(ZZ(Mod(i,self.nn)),self.N) for i in range(self.tnn)]
        self.sign_lookup = [ [Mod((-1)^((ZZ(Mod(i-h,self.nn))-(i-h))/self.nn),q) for i in range(self.nn)] for h in range(self.tnn)]
    def position_reverse(self,vec): #swap power <--> prioritized bases
        return [ vec[self.bit_reversal_lookup[i]] for i in range(self.nn) ]
    def zeta_pow(self,vec,h): #apply zeta on power basis
        h = ZZ(Mod(h,self.tnn))
        # index mod nn so that shifts with h >= nn do not run off the start of vec
        return [ vec[(i-h) % self.nn]*self.sign_lookup[h][i] for i in range(self.nn) ]
    def bzeta(self,vec,h,ii): #apply zeta on prioritized basis, but only compute the entries of block ii
        return [ vec[self.bit_reversal_lookup[self.bit_reversal_lookup[i]-h]]*self.sign_lookup[h][self.bit_reversal_lookup[i]] for i in range(self.B*(ii-1),self.B*ii) ]
    def zeta(self,vec,h): #apply zeta on prioritized basis
        return [ vec[self.bit_reversal_lookup[self.bit_reversal_lookup[i]-h]]*self.sign_lookup[h][self.bit_reversal_lookup[i]] for i in range(self.nn) ]

In [4]:
# Traditional and Advanced Keying BKW Reduction algorithms

# Parent class for BKW reduction, with reporting functions
class BKW:

    def __init__(self,q,kn,kB,sample_input_list,num_samps,alldiffs=False):
        self.q = q # working mod q
        self.kn = kn # dimension = n = 2^kn
        self.kB = kB # block size = B = 2^kB
        self.n = 2^self.kn
        self.B = 2^self.kB
        self.nB = self.n//self.B # number of blocks n/B (an integer, since n and B are powers of 2)
        self.sample_input_list = sample_input_list
        self.num_samps = num_samps
        self.tables = [SortedDict([]) for _ in range(self.nB+1)] # BKW tables (index 0 is unused)
        self.passcount = 0 # Counts number of times a difference of samples is passed to next table
        self.bit = bitZeta(self.kn,self.kB,self.q) # set up bit reversal / zeta permutations for full samps
        if alldiffs == True:
            self.table_insert_blind = self.table_insert_blind_alldiffs
        else:
            self.table_insert_blind = self.table_insert_blind_onediff
        return
 
    def reduce_final_table(self):

        # paring algorithm, removing duplicates from the final table
        # for final sample list, include only one per rotation
        finaltable = SortedDict([])
        for samp in self.tables[self.nB]: # copy the final table, since the original is modified in the loop below
            finaltable[samp] = self.tables[self.nB][samp]
        for samp in finaltable: # loop through the copy
            samp = self.tables[self.nB][samp][0] # pull out the one sample
            start = self.n-self.B
            end = self.n
            # collect list of rotations
            rotations = SortedDict([])
            for j in range(self.B):
                samp1 = self.bit.zeta(samp,j*(self.nB))
                samp1abs = sign_assign(samp1)
                rotations[repr(samp1abs[start:end])] = samp1abs
            # now look for canonical entry among them
            mysamp = rotations.peekitem(0)[1]
            myrep = rotations.peekitem(0)[0]
            # replace the sample with its canonical version
            self.tables[self.nB].pop(repr(samp[start:end]))
            if myrep not in self.tables[self.nB]:
                self.tables[self.nB][myrep] = [mysamp]
        return

    def report(self): # Reports the basic facts about the reduction after it has happened

        self.reduce_final_table()

        # report number of passes
        print("Number of times a sample was passed to another table:", self.passcount)

        # report table sizes
        totalsizes = 0
        for i in range(1,len(self.tables)):
            print("Table", i, "has", len(self.tables[i]), "entries.")
            if i < len(self.tables)-1:
                totalsizes += len(self.tables[i])
        print("Total stored table rows (not counting final table):", totalsizes)

        return

    def show_final(self): # print out the final table
        for i in range(len(self.tables[self.nB])):
            print(self.tables[self.nB].peekitem(i)[0])
            samplist = self.tables[self.nB].peekitem(i)[1]
            for samp in samplist:
                print(" "+str(samp))

    # table insertion for traditional BKW, passing on one difference only
    def table_insert_blind_onediff(self,samp,i):
        table = self.tables[i]
        samp1 = sign_assign(samp) # multiply by -1 if needed
        # if we have a collision and it's not the last table, pass on the difference
        rep1 = repr(samp1[self.B*(i-1):self.B*i])
        if i < self.nB and rep1 in table:
            tabvec = table[rep1][0] # pull out the one sample from its list
            diff = list(vector(samp1) - vector(tabvec))
            if not is_zero(diff): # if nonzero, send it down
                self.passcount += 1
                self.table_insert_blind(diff,i+1)
        else: # if we have no collision, or it is last table, just store it
            table[rep1] = [samp1] # store as a list of one sample
        return

    # table insertion for traditional BKW, passing on and keeping all differences
    def table_insert_blind_alldiffs(self,samp,i):
        table = self.tables[i]
        samp1 = sign_assign(samp) # multiply by -1 if needed
        # if we have a collision and it's not the last table, pass on the difference
        rep1 = repr(samp1[self.B*(i-1):self.B*i])
        if i < self.nB and rep1 in table:
            tabvecs = table[rep1] # pull out the sample list at that rep
            for tabvec in tabvecs: # for every stored sample, pass the difference down
                diff = list(vector(samp1) - vector(tabvec))
                if not is_zero(diff): # if nonzero, send it down
                    self.passcount += 1
                    self.table_insert_blind(diff,i+1)
            table[rep1].append(samp1) # store it also
        else: # if we have no collision, or it is last table, just store it
            table[rep1] = [samp1] # store as a list of one sample
        return


# Traditional BKW completely ring blind (using no rotations)
class blind_BKW(BKW):
    def __init__(self,q,kn,kB,sample_input_list,num_samps,alldiffs=False):
        super().__init__(q,kn,kB,sample_input_list,num_samps,alldiffs)

    def run(self):
        for s in range(self.num_samps): # for each sample, pass to first table
            samp = self.sample_input_list[s]
            self.table_insert_blind(samp,1)


# Traditional BKW on samples plus their rotations
class trad_BKW(BKW):
    def __init__(self,q,kn,kB,sample_input_list,num_samps,alldiffs=False):
        super().__init__(q,kn,kB,sample_input_list,num_samps,alldiffs)

    def run(self):
        for s in range(self.num_samps):
            samp = self.sample_input_list[s]
            self.table_insert_blind(samp,1)
            for j in range(1,self.n): # for all rotations of the sample, feed into the first table
                samp1 = self.bit.zeta(samp,j)
                self.table_insert_blind(samp1,1)
 
 
# Advanced Keying BKW
class adv_BKW(BKW):
    def __init__(self,q,kn,kB,sample_input_list,num_samps,alldiffs=False):
        super().__init__(q,kn,kB,sample_input_list,num_samps,alldiffs)
        if alldiffs == True:
            self.table_insert_adv = self.table_insert_adv_alldiffs
        else:
            self.table_insert_adv = self.table_insert_adv_onediff
        return

    def get_rotations(self,samp,i):
        start = self.B*(i-1)
        end = self.B*i
        # collect list of rotations
        rotations = SortedDict([])
        samp1abs = sign_assign(samp)
        repr1abs = repr(samp1abs[start:end])
        if repr1abs not in rotations:
            rotations[repr1abs] = [[samp1abs[start:end],0]] # store which rotation it is
        else:
            rotations[repr1abs].append([samp1abs[start:end],0])
        for j in range(1,self.B):
            samp1 = self.bit.bzeta(samp,j*(self.nB),i) # only compute the rotation on the block
            samp1abs = sign_assign(samp1)
            repr1abs = repr(samp1abs)
            if repr1abs not in rotations:
                rotations[repr1abs] = [[samp1abs,j]]
            else:
                rotations[repr1abs].append([samp1abs,j])
        # now look for canonical entry among them
        myrep = rotations.peekitem(0)[0] # the canonical rep
        mysamps = rotations.peekitem(0)[1] # all associated canonical samples
        # do the full rotations only for the canonical ones
        for sampy in mysamps:
            samp2 = sign_assign(self.bit.zeta(samp,sampy[1]*self.nB))
            sampy.append(samp2) # store full rotation
        return myrep, mysamps

    def table_insert_adv_onediff(self,samp,i):
        table = self.tables[i]
        start = self.B*(i-1)
        end = self.B*i
        # if we are not in the last table
        if i < self.nB:
            myrep, mysamps = self.get_rotations(samp,i)
            # go through rotations and deal with them
            if myrep in table: # if a collision
                tabvecs = table[myrep] # list of all vectors already at that row
                mysamp = mysamps[0][2] # full rotation
                for tabvec in tabvecs: # and each of the old samples
                    diff = list(vector(mysamp)-vector(tabvec)) # compute the difference
                    if not is_zero(diff): # if nonzero, send it down
                        self.passcount += 1
                        self.table_insert_adv(diff,i+1)
            else: # if not already in the table
                table[myrep] = [mysamps[0][2]] # store the first canonical sample at that row
        else: # if it's just the final table, just store its abs val
            samp1abs = sign_assign(samp)
            table[repr(samp1abs[start:end])] = [samp1abs]
        return

    def table_insert_adv_alldiffs(self,samp,i):
        table = self.tables[i]
        start = self.B*(i-1)
        end = self.B*i
        # if we are not in the last table
        if i < self.nB:
            myrep, mysamps = self.get_rotations(samp,i)
            # for the canonical ones, compare and pass on
            if myrep not in table: # if it's a new rep, store it
                table[myrep] = [mysamps[0][2]] # store the first one
                mysamps.pop(0) # drop the first entry
            # at this point there is something in the table
            for mysamp in mysamps: # for each of the new samples
                for tabvec in table[myrep]: # and each of the old samples
                    diff = list(vector(mysamp[2])-vector(tabvec)) # compute the difference
                    if not is_zero(diff): # if nonzero, send it down
                        self.passcount += 1
                        self.table_insert_adv(diff,i+1)
                table[myrep].append(mysamp[2]) # store the new samples in the row
        else: # if it's just the final table, just store its abs val
            samp1abs = sign_assign(samp)
            table[repr(samp1abs[start:end])] = [samp1abs]
        return

    def run(self):
        for s in range(self.num_samps):
            samp = self.sample_input_list[s]
            self.table_insert_adv(samp,1)
            for j in range(1,self.nB): # for each sample, rotate by 0,1,...,n/B-1 and pass to first table
                samp1 = self.bit.zeta(samp,j)
                self.table_insert_adv(samp1,1)
        return
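`adv_BKW.get_rotations` keys a sample by a canonical representative of its current block. It chooses that representative among the sign-normalized blocks of the rotations zeta^(j*n/B), for j = 0, ..., B-1.

The plain-Python sketch below illustrates why such a key is well defined on rotation orbits. It is a simplification under stated assumptions, not the notebook's code:

- helper names are hypothetical, and entries are Python integers mod q;
- it takes the smallest integer tuple as canonical, where the notebook takes the first `repr` string of a `SortedDict`, so the chosen representative can differ;
- the sign is normalized on the block alone.

Since zeta^(B*n/B) = zeta^n = -1, and sign normalization identifies a block with its negative, rotating a sample by zeta^(n/B) only permutes the candidate keys. The minimum therefore does not change. This is consistent with `adv_BKW.run` feeding only the rotations zeta^j, j = 0, ..., n/B-1, into the first table.

```python
import random

def bit_reverse(k, nbits):
    return int(format(k, "b").zfill(nbits)[::-1], 2)

def zeta_prioritized(vec, h, q, nbits):
    """Multiply by x^h in Z_q[x]/(x^n + 1); vec is in the prioritized (bit-reversed) basis."""
    n = len(vec)
    power = [vec[bit_reverse(i, nbits)] for i in range(n)]    # prioritized -> power basis
    shifted = []
    for i in range(n):
        j = i - h
        wraps = (j % n - j) // n                                # each wrap contributes x^n = -1
        shifted.append((power[j % n] * (-1) ** wraps) % q)
    return [shifted[bit_reverse(i, nbits)] for i in range(n)]  # power -> prioritized basis

def sign_normalize(block, q):
    """Return block or -block (mod q) so that the first nonzero entry is in 1..(q-1)//2."""
    for c in block:
        if c != 0:
            if c > (q - 1) // 2:
                return tuple((-d) % q for d in block)
            break
    return tuple(block)

def canonical_key(vec, i, q, nbits, B):
    """Hypothetical advanced-keying key: the smallest sign-normalized copy of block i
    (1-indexed) among the rotations of vec by x^(j*n/B), j = 0, ..., B-1."""
    n = len(vec)
    start, end = B * (i - 1), B * i
    return min(sign_normalize(zeta_prioritized(vec, j * (n // B), q, nbits)[start:end], q)
               for j in range(B))

q, N, kB = 17, 4, 2          # the sizes used in Experiment 2: n = 16, B = 4
n, B = 2 ** N, 2 ** kB
random.seed(1)
samp = [random.randrange(q) for _ in range(n)]

# rotating the sample by x^(n/B) does not change its key for any block
rotated = zeta_prioritized(samp, n // B, q, N)
for i in range(1, n // B + 1):
    assert canonical_key(rotated, i, q, N, B) == canonical_key(samp, i, q, N, B)
```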
 

In [5]:
# EXPERIMENT Running

# the BKW objects are stored as globals so that the timeit statement strings below can find them
blindfalse = None
blindtrue = None
tradfalse = None
tradtrue = None
advfalse = None
advtrue = None

def run_experiment(q,kn,kB,numsamps,alldiffs=False,seed=None,show_final=False):

    global blindfalse
    global blindtrue
    global tradfalse
    global tradtrue
    global advfalse
    global advtrue

    # Set up parameters

    # entries are mod q

    # kn = log2 of the dimension of vectors
    n = 2^kn #length of vectors

    # kB = log2 of the length of blocks
    B = 2^kB # length of blocks

    # numsamps is number of samples

    # compute initial samples randomly
    if seed is not None:
        random.seed(seed)
    # numsamps*n samples are generated: blind BKW uses all of them, matching the
    # numsamps*n rotated inputs that traditional BKW derives from the first numsamps samples
    sample_number = numsamps*n
    sample_input_list = []
    for i in range(sample_number):
        # random.randint(0,q) includes q (which is 0 mod q), so 0 is drawn slightly more often than other residues
        sample_input_list.append( [Mod(random.randint(0,q),q) for _ in range(n)])

    # Run each test in turn, reporting time it took and results

    #Blind False
    print("*************************************")
    print("Running Blind BKW with One Diff")

    blindfalse = blind_BKW(q,kn,kB,sample_input_list,numsamps*n,alldiffs=False)
    timed_blindfalse = timeit("blindfalse.run()",number=1,repeat=1)
    print(timed_blindfalse)
    blindfalse.report()
    if show_final == True:
        blindfalse.show_final()

    #Trad False
    print("*************************************")
    print("Running Traditional BKW with One Diff")
    tradfalse = trad_BKW(q,kn,kB,sample_input_list,numsamps,alldiffs=False)
    timed_tradfalse = timeit("tradfalse.run()",number=1,repeat=1)
    print(timed_tradfalse)
    tradfalse.report()
    if show_final == True:
        tradfalse.show_final()

    #Advanced False
    print("*************************************")
    print("Running Advanced Keying BKW with One Diff")
    advfalse = adv_BKW(q,kn,kB,sample_input_list,numsamps,alldiffs=False)
    timed_advfalse = timeit("advfalse.run()",number=1,repeat=1)
    print(timed_advfalse)
    advfalse.report()
    if show_final == True:
        advfalse.show_final()

    if alldiffs:

        #Blind True
        print("*************************************")
        print("Running Blind BKW with All Diffs")
        blindtrue = blind_BKW(q,kn,kB,sample_input_list,numsamps*n,alldiffs=True)
        timed_blindtrue = timeit("blindtrue.run()",number=1,repeat=1)
        print(timed_blindtrue)
        blindtrue.report()
        if show_final == True:
            blindtrue.show_final()

        #Trad True
        print("*************************************")
        print("Running Traditional BKW with All Diffs")
        tradtrue = trad_BKW(q,kn,kB,sample_input_list,numsamps,alldiffs=True)
        timed_tradtrue = timeit("tradtrue.run()",number=1,repeat=1)
        print(timed_tradtrue)
        tradtrue.report()
        if show_final == True:
            tradtrue.show_final()

        #Advanced True
        print("*************************************")
        print("Running Advanced Keying BKW with All Diffs")
        advtrue = adv_BKW(q,kn,kB,sample_input_list,numsamps,alldiffs=True)
        timed_advtrue = timeit("advtrue.run()",number=1,repeat=1)
        print(timed_advtrue)
        advtrue.report()
        if show_final == True:
            advtrue.show_final()
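The following may help when reading the table sizes reported below. In blind and traditional BKW, each table row is keyed by a sign-normalized block in (Z/qZ)^B. All four experiments use odd q. For odd q, the zero block is its own negative and every nonzero block v pairs with a different block -v. So there are (q^B - 1)/2 + 1 = (q^B + 1)/2 possible keys per table.

The plain-Python sketch below uses hypothetical helper names and Python integers mod q. It checks this count by brute force on small parameters and evaluates it for the (q, B) pairs of Experiments 1 to 4. Advanced keying additionally identifies rotations, so its tables have fewer possible keys; the sketch does not count those.

```python
from itertools import product

def sign_normalize(block, q):
    """Return block or -block (mod q) so that the first nonzero entry is in 1..(q-1)//2."""
    for c in block:
        if c != 0:
            if c > (q - 1) // 2:
                return tuple((-d) % q for d in block)
            break
    return tuple(block)

def count_keys_brute_force(q, B):
    """Number of distinct sign-normalized blocks in (Z/q)^B, by enumeration."""
    return len({sign_normalize(v, q) for v in product(range(q), repeat=B)})

def count_keys_formula(q, B):
    """(q^B + 1)/2 for odd q: the zero block plus one key per pair {v, -v}."""
    return (q ** B + 1) // 2

# brute-force check on small parameters (odd q only)
for q, B in [(3, 2), (5, 3), (7, 2), (3, 8)]:
    assert count_keys_brute_force(q, B) == count_keys_formula(q, B)

# (q, B) for Experiments 1-4 in this notebook
for q, B in [(211, 4), (17, 4), (7, 4), (3, 8)]:
    print(f"q={q}, B={B}: {count_keys_formula(q, B)} possible keys per table")
```

For the experiment parameters this prints 991059721, 41761, 1201 and 3281 possible keys, respectively.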

In [6]:
# Set up parameters
############### Experiment 1

# entries are mod q
q = 211

# dimension of vectors
kn = 3
n = 2^kn #length of vectors

# length of blocks
kB = 2
B = 2^kB # length of blocks

# number of samples
numsamps = 4000

# run experiment
run_experiment(q,kn,kB,numsamps,seed=1,alldiffs=True)

*************************************
Running Blind BKW with One Diff
1 loop, best of 1: 1.64 s per loop
Number of times a sample was passed to another table: 1
Table 1 has 31999 entries.
Table 2 has 1 entries.
Total stored table rows (not counting final table): 31999
*************************************
Running Traditional BKW with One Diff
1 loop, best of 1: 1.94 s per loop
Number of times a sample was passed to another table: 4
Table 1 has 31996 entries.
Table 2 has 1 entries.
Total stored table rows (not counting final table): 31996
*************************************
Running Advanced Keying BKW with One Diff
1 loop, best of 1: 2.38 s per loop
Number of times a sample was passed to another table: 1
Table 1 has 7999 entries.
Table 2 has 1 entries.
Total stored table rows (not counting final table): 7999
*************************************
Running Blind BKW with All Diffs
1 loop, best of 1: 1.3 s per loop
Number of times a sample was passed to another table: 1
Table 1 has 31999 

In [7]:
# Set up parameters
############### Experiment 2

# entries are mod q
q = 17

# dimension of vectors
kn = 4
n = 2^kn #length of vectors

# length of blocks
kB = 2
B = 2^kB # length of blocks

# number of samples
numsamps = 2000

# run experiment
run_experiment(q,kn,kB,numsamps,seed=1,alldiffs=True)

*************************************
Running Blind BKW with One Diff
1 loop, best of 1: 7.41 s per loop
Number of times a sample was passed to another table: 11644
Table 1 has 21538 entries.
Table 2 has 9295 entries.
Table 3 has 1152 entries.
Table 4 has 15 entries.
Total stored table rows (not counting final table): 31985
*************************************
Running Traditional BKW with One Diff
1 loop, best of 1: 8.09 s per loop
Number of times a sample was passed to another table: 11811
Table 1 has 21441 entries.
Table 2 has 9319 entries.
Table 3 has 1228 entries.
Table 4 has 3 entries.
Total stored table rows (not counting final table): 31988
*************************************
Running Advanced Keying BKW with One Diff
1 loop, best of 1: 5.26 s per loop
Number of times a sample was passed to another table: 2952
Table 1 has 5361 entries.
Table 2 has 2329 entries.
Table 3 has 307 entries.
Table 4 has 3 entries.
Total stored table rows (not counting final table): 7997
************

In [8]:
# Set up parameters
############### Experiment 3

# entries are mod q
q = 7

# dimension of vectors
kn = 5
n = 2^kn #length of vectors

# length of blocks
kB = 2
B = 2^kB # length of blocks

# number of samples
numsamps = 200

# run experiment
run_experiment(q,kn,kB,numsamps,seed=1,alldiffs=False)

*************************************
Running Blind BKW with One Diff
1 loop, best of 1: 8.74 s per loop
Number of times a sample was passed to another table: 14955
Table 1 has 1181 entries.
Table 2 has 1196 entries.
Table 3 has 1166 entries.
Table 4 has 1105 entries.
Table 5 has 922 entries.
Table 6 has 588 entries.
Table 7 has 210 entries.
Table 8 has 31 entries.
Total stored table rows (not counting final table): 6368
*************************************
Running Traditional BKW with One Diff
1 loop, best of 1: 9.33 s per loop
Number of times a sample was passed to another table: 14989
Table 1 has 1185 entries.
Table 2 has 1184 entries.
Table 3 has 1163 entries.
Table 4 has 1089 entries.
Table 5 has 920 entries.
Table 6 has 636 entries.
Table 7 has 209 entries.
Table 8 has 13 entries.
Total stored table rows (not counting final table): 6386
*************************************
Running Advanced Keying BKW with One Diff
1 loop, best of 1: 4.52 s per loop
Number of times a sample was 

In [9]:
# Set up parameters
############### Experiment 4

# entries are mod q
q = 3

# dimension of vectors
kn = 6
n = 2^kn #length of vectors

# length of blocks
kB = 3
B = 2^kB # length of blocks

# number of samples
numsamps = 250

# run experiment
run_experiment(q,kn,kB,numsamps,seed=1,alldiffs=False)

*************************************
Running Blind BKW with One Diff
1 loop, best of 1: 31.5 s per loop
Number of times a sample was passed to another table: 35782
Table 1 has 2874 entries.
Table 2 has 3201 entries.
Table 3 has 3135 entries.
Table 4 has 2837 entries.
Table 5 has 2324 entries.
Table 6 has 1282 entries.
Table 7 has 335 entries.
Table 8 has 12 entries.
Total stored table rows (not counting final table): 15988
*************************************
Running Traditional BKW with One Diff
1 loop, best of 1: 33.2 s per loop
Number of times a sample was passed to another table: 35888
Table 1 has 2889 entries.
Table 2 has 3182 entries.
Table 3 has 3059 entries.
Table 4 has 2881 entries.
Table 5 has 2346 entries.
Table 6 has 1304 entries.
Table 7 has 332 entries.
Table 8 has 7 entries.
Total stored table rows (not counting final table): 15993
*************************************
Running Advanced Keying BKW with One Diff
1 loop, best of 1: 9.98 s per loop
Number of times a sample