longchute

about

A Symmetric Somewhat Homomorphic Encryption Algorithm

05 Apr 2012

NOTE: I originally posted this on Snipplr. The comments have not been copied over.

This is an implementation of a symmetric somewhat homomorphic encryption (SWHE) scheme from section 3.2 of "Computing Arbitrary Functions of Encrypted Data" by Craig Gentry. It contains a small modification (namely, the addition of a modulus parameter to allow a plaintext space with more than two elements). Examples provided illustrate the encryption/decryption of a value, addition and multiplication, the basic AND and XOR gates, and complex gates (circuits) for NOT, OR, NAND, NOR, IF, and RIGHT ROTATE. Note that I'm not a cryptographer, so I can't vouch for the correctness of this. If you find a bug, PLEASE post a comment below. Also, note that this is a toy, not production code: performing too many consecutive operations lets the accumulated noise outgrow the key, at which point decryption silently returns the wrong value (the ciphertexts themselves also grow quickly, well past machine word size), and it's probably vulnerable to any number of attacks.

NOTE: Remember, this is only a TOY. Do not use in production.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
    #!/usr/bin/env python
    
    import random
    
    def keygen(noise, modulus=2):
        """Generate the secret key: a random odd integer of exactly ``noise ** 2`` bits.

        encrypt() hides the plaintext in a mask smaller than ``2 ** noise`` and
        decrypt() recovers it with ``ciphertext % key``, which only works while
        the accumulated noise is smaller than the key.  Drawing the key
        uniformly from ``noise ** 2`` bits can yield a tiny key (even 1), so the
        most significant bit is forced to guarantee the full key size.

        Args:
            noise: Noise parameter; the key is ``noise ** 2`` bits long.
            modulus: Plaintext modulus.  Accepted for backward compatibility;
                the generated key does not depend on it.

        Returns:
            An odd integer ``k`` with ``k.bit_length() == noise ** 2``.

        Raises:
            ValueError: If ``noise`` is zero (a zero-bit key cannot be built).
        """
        key_bits = noise ** 2
        if key_bits < 1:
            raise ValueError("noise must be non-zero")
        # Top bit set -> full-size key; bottom bit set -> odd key.
        return random.getrandbits(key_bits) | (1 << (key_bits - 1)) | 1
    
    def encrypt(noise, a_key, a_bit, modulus=2):
        """Encrypt ``a_bit`` (taken modulo ``modulus``) under the secret key ``a_key``.

        The ciphertext is ``mask + a_key * q`` where ``mask`` is a random value
        below ``2 ** noise`` that is congruent to ``a_bit`` modulo ``modulus``,
        and ``q`` is a random multiplier of up to ``noise ** 5`` bits.

        The mask is sampled directly from the matching congruence class rather
        than by retrying random draws; retrying never terminates when
        ``a_bit % modulus`` is not representable in ``noise`` bits.

        Args:
            noise: Noise parameter; the mask is smaller than ``2 ** noise``.
            a_key: Secret key as produced by keygen().
            a_bit: Plaintext value; only ``a_bit % modulus`` is encoded.
            modulus: Size of the plaintext space (default 2, i.e. single bits).

        Returns:
            The integer ciphertext.

        Raises:
            ValueError: If ``a_bit % modulus`` is not smaller than ``2 ** noise``,
                so no valid mask exists.
        """
        residue = a_bit % modulus
        mask_bound = 1 << noise
        if residue >= mask_bound:
            raise ValueError(
                "plaintext residue %d does not fit in %d noise bits" % (residue, noise)
            )
        # Valid masks are residue + modulus * k for 0 <= k < candidates.
        candidates = (mask_bound - residue + modulus - 1) // modulus
        a_mask = residue + modulus * random.randrange(candidates)
        return a_mask + (a_key * random.getrandbits(noise ** 5))
    
    def decrypt(a_key, a_bit, modulus=2):
        """Decrypt the ciphertext ``a_bit`` with the secret key ``a_key``.

        Reducing the ciphertext modulo the key leaves its remainder; reducing
        that remainder modulo ``modulus`` gives the returned plaintext value.

        Args:
            a_key: Secret key used for encryption.
            a_bit: Ciphertext (despite the name, this is the encrypted value).
            modulus: Size of the plaintext space (default 2).

        Returns:
            ``(a_bit % a_key) % modulus``.
        """
        remainder = a_bit % a_key
        return remainder % modulus
    
    def simple_example():
        """Encrypt one random 5-bit value, decrypt it again, and print every stage.

        Uses noise 6 and a plaintext modulus of 32; the key, plaintext,
        ciphertext and decrypted value are written to stdout.
        """
        noise_bits = 6
        plain_modulus = 32
        secret = keygen(noise_bits, modulus=plain_modulus)
        plaintext = random.getrandbits(5)
        ciphertext = encrypt(noise_bits, secret, plaintext, modulus=plain_modulus)
        recovered = decrypt(secret, ciphertext, modulus=plain_modulus)
        report = (secret, plaintext, ciphertext, recovered)
        print("simple_example()\n----------------")
        print("key: %d\nplaintext: %d\nencrypted: %d\ndecrypted: %d\n\n" % report)
    
    def multiplication_example():
        """Multiply two ciphertexts and print the decryption of the product.

        Two random 2-bit plaintexts are encrypted under one key (noise 5,
        modulus 16).  The decrypted product is expected to equal the product
        of the plaintexts modulo 16 while the noise stays below the key.
        """
        noise_bits = 5
        plain_modulus = 16
        secret = keygen(noise_bits, modulus=plain_modulus)
        left_plain = random.getrandbits(2)
        right_plain = random.getrandbits(2)
        left_cipher = encrypt(noise_bits, secret, left_plain, modulus=plain_modulus)
        right_cipher = encrypt(noise_bits, secret, right_plain, modulus=plain_modulus)
        product = left_cipher * right_cipher
        recovered = decrypt(secret, product, modulus=plain_modulus)
        print("multiplication_example()\n-------------------------")
        print("a (p): %d\nb (p): %d\n" % (left_plain, right_plain))
        print("a (c): %d\nb (c): %d\n" % (left_cipher, right_cipher))
        print("c: %d\nd: %d\n\n" % (product, recovered))
    
    def addition_example():
        """Add two ciphertexts and print the decryption of the sum.

        Two random 2-bit plaintexts are encrypted under one key (noise 4,
        modulus 8).  The decrypted sum is expected to equal the sum of the
        plaintexts modulo 8 while the noise stays below the key.
        """
        noise_bits = 4
        plain_modulus = 8
        secret = keygen(noise_bits, modulus=plain_modulus)
        left_plain = random.getrandbits(2)
        right_plain = random.getrandbits(2)
        left_cipher = encrypt(noise_bits, secret, left_plain, modulus=plain_modulus)
        right_cipher = encrypt(noise_bits, secret, right_plain, modulus=plain_modulus)
        total = left_cipher + right_cipher
        recovered = decrypt(secret, total, modulus=plain_modulus)
        print("addition_example()\n------------------")
        print("a (p): %d\nb (p): %d\n" % (left_plain, right_plain))
        print("a (c): %d\nb (c): %d\n" % (left_cipher, right_cipher))
        print("c: %d\nd: %d\n\n" % (total, recovered))
        
    def xor_gate():
        """Demonstrate XOR on encrypted bits: add the two ciphertexts (modulus 2).

        Encrypts two random bits with noise 4, adds the ciphertexts, and prints
        the plaintexts, the ciphertexts, the sum and its decryption.
        """
        noise_bits = 4
        secret = keygen(noise_bits)
        left_plain = random.getrandbits(1)
        right_plain = random.getrandbits(1)
        left_cipher = encrypt(noise_bits, secret, left_plain)
        right_cipher = encrypt(noise_bits, secret, right_plain)
        gate_out = left_cipher + right_cipher
        recovered = decrypt(secret, gate_out)
        print("xor_gate() (XOR)\n----------------")
        print("a (p): %d\nb (p): %d\n" % (left_plain, right_plain))
        print("a (c): %d\nb (c): %d\n" % (left_cipher, right_cipher))
        print("c: %d\nd: %d\n\n" % (gate_out, recovered))
        
    def and_gate():
        """Demonstrate AND on encrypted bits: multiply the two ciphertexts (modulus 2).

        Encrypts two random bits with noise 4, multiplies the ciphertexts, and
        prints the plaintexts, the ciphertexts, the product and its decryption.
        """
        noise_bits = 4
        secret = keygen(noise_bits)
        left_plain = random.getrandbits(1)
        right_plain = random.getrandbits(1)
        left_cipher = encrypt(noise_bits, secret, left_plain)
        right_cipher = encrypt(noise_bits, secret, right_plain)
        gate_out = left_cipher * right_cipher
        recovered = decrypt(secret, gate_out)
        print("and_gate() (AND)\n----------------")
        print("a (p): %d\nb (p): %d\n" % (left_plain, right_plain))
        print("a (c): %d\nb (c): %d\n" % (left_cipher, right_cipher))
        print("c: %d\nd: %d\n\n" % (gate_out, recovered))
        
    def or_gate():
        """Demonstrate OR on encrypted bits as (a AND b) XOR (a XOR b).

        Encrypts two random bits with noise 4, combines the ciphertext product
        and the ciphertext sum by addition, and prints the plaintexts, the
        ciphertexts, the combined value and its decryption.
        """
        noise_bits = 4
        secret = keygen(noise_bits)
        left_plain = random.getrandbits(1)
        right_plain = random.getrandbits(1)
        left_cipher = encrypt(noise_bits, secret, left_plain)
        right_cipher = encrypt(noise_bits, secret, right_plain)
        both = left_cipher * right_cipher
        either_not_both = left_cipher + right_cipher
        gate_out = both + either_not_both
        recovered = decrypt(secret, gate_out)
        print("or_gate() (OR)\n--------------")
        print("a (p): %d\nb (p): %d\n" % (left_plain, right_plain))
        print("a (c): %d\nb (c): %d\n" % (left_cipher, right_cipher))
        print("c: %d\nd: %d\n\n" % (gate_out, recovered))
    
    def not_gate():
        """Demonstrate NOT on an encrypted bit: add the constant 1 to the ciphertext.

        Encrypts one random bit with noise 4 and prints the plaintext, the
        ciphertext, the incremented ciphertext and its decryption.
        """
        noise_bits = 4
        secret = keygen(noise_bits)
        plain = random.getrandbits(1)
        cipher = encrypt(noise_bits, secret, plain)
        gate_out = 1 + cipher
        recovered = decrypt(secret, gate_out)
        print("not_gate() (NOT)\n----------------")
        print("a (p): %d\n" % (plain,))
        print("a (c): %d\n" % (cipher,))
        print("c: %d\nd: %d\n\n" % (gate_out, recovered))
        
    def nand_gate():
        """Demonstrate NAND on encrypted bits as NOT (a AND b): 1 plus the product.

        Encrypts two random bits with noise 4 and prints the plaintexts, the
        ciphertexts, the gate output and its decryption.
        """
        noise_bits = 4
        secret = keygen(noise_bits)
        left_plain = random.getrandbits(1)
        right_plain = random.getrandbits(1)
        left_cipher = encrypt(noise_bits, secret, left_plain)
        right_cipher = encrypt(noise_bits, secret, right_plain)
        both = left_cipher * right_cipher
        gate_out = 1 + both
        recovered = decrypt(secret, gate_out)
        print("nand_gate() (NAND)\n------------------")
        print("a (p): %d\nb (p): %d\n" % (left_plain, right_plain))
        print("a (c): %d\nb (c): %d\n" % (left_cipher, right_cipher))
        print("c: %d\nd: %d\n\n" % (gate_out, recovered))
        
    def nor_gate():
        """Demonstrate NOR on encrypted bits as NOT (a OR b).

        The OR circuit (ciphertext product plus ciphertext sum) is incremented
        by the constant 1.  Encrypts two random bits with noise 4 and prints
        the plaintexts, the ciphertexts, the gate output and its decryption.
        """
        noise_bits = 4
        secret = keygen(noise_bits)
        left_plain = random.getrandbits(1)
        right_plain = random.getrandbits(1)
        left_cipher = encrypt(noise_bits, secret, left_plain)
        right_cipher = encrypt(noise_bits, secret, right_plain)
        either = (left_cipher * right_cipher) + (left_cipher + right_cipher)
        gate_out = 1 + either
        recovered = decrypt(secret, gate_out)
        print("nor_gate() (NOR)\n----------------")
        print("a (p): %d\nb (p): %d\n" % (left_plain, right_plain))
        print("a (c): %d\nb (c): %d\n" % (left_cipher, right_cipher))
        print("c: %d\nd: %d\n\n" % (gate_out, recovered))
        
    def if_gate():
        """Demonstrate the IF gate: multiply the ciphertext by the constant 1.

        The multiplication leaves the ciphertext unchanged.  Encrypts one
        random bit with noise 4 and prints the plaintext, the ciphertext, the
        gate output and its decryption.
        """
        noise_bits = 4
        secret = keygen(noise_bits)
        plain = random.getrandbits(1)
        cipher = encrypt(noise_bits, secret, plain)
        gate_out = 1 * cipher
        recovered = decrypt(secret, gate_out)
        print("if_gate() (IF)\n--------------")
        print("a (p): %d\n" % (plain,))
        print("a (c): %d\n" % (cipher,))
        print("c: %d\nd: %d\n\n" % (gate_out, recovered))
        
    def right_rotate():
        """Rotate four encrypted bits one position to the right and print the result.

        Each output slot is ``own + previous + own`` on ciphertexts, with the
        first slot wrapping around to the last.  Adding a slot's own ciphertext
        twice cancels it modulo 2, so each slot is expected to decrypt to the
        bit of the slot before it.
        """
        noise_bits = 4
        secret = keygen(noise_bits)
        # Draw all plaintext bits first, then encrypt them in the same order.
        plain = [random.getrandbits(1) for _ in range(4)]
        cipher = [encrypt(noise_bits, secret, bit) for bit in plain]
        # Index -1 wraps slot 0 around to the last ciphertext.
        rotated = [cipher[i] + cipher[i - 1] + cipher[i] for i in range(4)]
        recovered = [decrypt(secret, value) for value in rotated]
        print("right_rotate() (division mod 2)\n-------------------------------")
        print("a (p): %d\nb (p): %d\nc (p): %d\nd (p): %d\n" % tuple(plain))
        print("a (c): %d\nb (c): %d\nc (c): %d\nd (c): %d\n" % tuple(cipher))
        print("a' (c): %d\nb' (c): %d\nc' (c): %d\nd' (c): %d\n" % tuple(rotated))
        print("a: %d\nb: %d\nc: %d\nd: %d\n\n" % tuple(recovered))
    
    # Run every demonstration once, in the order they are defined above.
    for _demo in (
        simple_example,
        multiplication_example,
        addition_example,
        xor_gate,
        and_gate,
        or_gate,
        not_gate,
        nand_gate,
        nor_gate,
        if_gate,
        right_rotate,
    ):
        _demo()