0%

Crypto-剪枝

简单记录一下目前接触到的 Crypto 剪枝算法。

剪枝算法

剪枝算法,形象地说,就是通过某些判断,剪掉一些错误的枝条,从而加快搜索速度的算法。在数据结构中学过,搜索算法一般是基于两种方法 深度优先 DFS 和广度优先 BFS )来进行的,本文主要介绍的是基于深度搜索的剪枝算法。

搜索方式

一般可以采取首尾剪枝。

剪枝条件

需要明确剪枝的条件,把错误的树枝剪掉,把可能正确的树枝留下。

结束条件

知道找到我们需要的正确答案之后,搜索就可以结束了。

具体例子

ps:以下例子都是基于大整数分解问题和异或运算的。

(1)已知 p⊕q

1
2
3
4
5
6
7
8
9
10
from Crypto.Util.number import *
p = getPrime(128)
q = getPrime(128)
n = p*q
xor = p^q
print(f"n = {n}")
print(f"xor = {xor}")

#n = 81273634095521392491945168120330007101677085824054016784875224305683560308213
#xor = 55012774068906519160740720236510369652

已知条件:

搜索方式:

  • 从高位向低位搜索
  • 若xor当前位为1,则可能为两种情况:p为1,q为0 或者 p为0,q为1;反之xor当前位为0,则p为1,q为1 或者 p为0,q为0.

剪枝条件:

  • 将p和q剩下位全部填充为1,需要满足 p*q > n
  • 将p和q剩下位全部填充为0,需要满足 p*q < n

结束条件:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
n = 81273634095521392491945168120330007101677085824054016784875224305683560308213
xor = 55012774068906519160740720236510369652
pbits = 128
ph = ''
qh = ''
xor = str(bin(xor)[2:]).zfill(pbits)

def find(ph,qh):
l0 = len(ph)
l1 = len(qh)
tmp0 = ph + '0' * (pbits-l0)
tmp1 = ph + '1' * (pbits-l0)
tmq0 = qh + '0' * (pbits-l1)
tmq1 = qh + '1' * (pbits-l1)
if int(tmp0,2) * int(tmq0,2) > n:#剪枝条件1
return
if int(tmp1,2) * int(tmq1,2) < n:#剪枝条件2
return

if l0 == pbits:#结束条件
if int(ph,2) * int(qh,2) == n:
print(f'p = {int(ph,2)}')
print(f'q = {int(qh,2)}')
return

else:
if xor[l1] == '1':
find(ph+'0',qh+'1')
find(ph + '1',qh+'0')
else:
find(ph+'1',qh+'1')
find(ph + '0',qh+'0')

find(ph,qh)


#运行结果
'''
p = 270451921611135557038833183249275131423
q = 300510470073047693263940829088190906731
p = 300510470073047693263940829088190906731
q = 270451921611135557038833183249275131423
'''

ps:基于深度搜索的思想去理解这个代码。

(2)已知 p⊕(q>>kbits)

1
2
3
4
5
6
7
8
9
10
11
12
from Crypto.Util.number import *
p = getPrime(128)
q = getPrime(128)
n = p*q
kbits = 16
_q = q>>kbits
xor = p^_q
print(f"n = {n}")
print(f"xor = {xor}")

#n = 64562232639256893416069755621246602817297999377249269503641314167726888737493
#xor = 309280967555048700343199196922406211930

已知条件:

搜索方式:

  • 从高位向低位搜索
  • 这种情况,p的高kbits位已知,与xor的高kbits位相同。那搜索就从xor的kbits位开始,即p的第kbits位,q的第1位。
  • 若xor当前位为1,则可能为两种情况:p为1,q为0 或者 p为0,q为1;反之xor当前位为0,则p为1,q为1 或者 p为0,q为0.(这里的p或者q为1指的都是xor当前位对应的p和q的位置)

剪枝条件:

  • 将p和q剩下位全部填充为1,需要满足 p*q > n
  • 将p和q剩下位全部填充为0,需要满足 p*q < n
  • 这里要注意把p的已知的高kbits位加上

结束条件:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
n = 63643664048526756484794345795699229288597054138478072698289292748237001669839
xor = 188681898001072579822021026951939419279
kbits = 16
pbits = 128
xor = str(bin(xor)[2:])
ph = xor[:kbits]
qh = ''
xor = xor[kbits:]
qh = ''
def find(ph,qh):
l0 = len(ph)
l1 = len(qh)
tmp0 = ph + '0' * (pbits-l0)
tmp1 = ph + '1' * (pbits-l0)
tmq0 = qh + '0' * (pbits-l1)
tmq1 = qh + '1' * (pbits-l1)
if int(tmp0,2) * int(tmq0,2) > n:#剪枝条件1
return
if int(tmp1,2) * int(tmq1,2) < n:#剪枝条件2
return

if l0 == pbits:#结束条件
if n % int(ph,2) == 0:
print(f'p = {int(ph,2)}')
return

else:
if xor[l1] == '1':
find(ph+'0',qh+'1')
find(ph + '1',qh+'0')
else:
find(ph+'1',qh+'1')
find(ph + '0',qh+'0')
find(ph,qh)


#运行结果
'''
p = 188678698133681304596906537936293804297
'''

(3)已知 p⊕(q>>kbits) 但前kbits位未知

1
2
3
4
5
6
7
8
9
10
11
12
13
from Crypto.Util.number import *
p = getPrime(128)
q = getPrime(128)
n = p*q
kbits = 16
_q = q>>kbits
xor = p^_q
xor = int(bin(xor)[2:][kbits:],2)
print(f"n = {n}")
print(f"xor = {xor}")

#n = 67993063298729224384929426280013061841686812014564261424956366758045322225691
#xor = 4614553526345165212539618883307366

和第(2)种情况相比,只是少了p的高kbits,直接爆破,就和第二种情况一样了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
n = 67993063298729224384929426280013061841686812014564261424956366758045322225691
xor = 4614553526345165212539618883307366
kbits = 16
pbits = 128
xor = str(bin(xor)[2:])

def find(ph,qh):
l0 = len(ph)
l1 = len(qh)
tmp0 = ph + '0' * (pbits-l0)
tmp1 = ph + '1' * (pbits-l0)
tmq0 = qh + '0' * (pbits-l1)
tmq1 = qh + '1' * (pbits-l1)
if int(tmp0,2) * int(tmq0,2) > n:#剪枝条件1
return
if int(tmp1,2) * int(tmq1,2) < n:#剪枝条件2
return

if l0 == pbits:#结束条件
if n % int(ph,2) == 0:
print(f'p = {int(ph,2)}')
return

else:
if xor[l1] == '1':
find(ph+'0',qh+'1')
find(ph + '1',qh+'0')
else:
find(ph+'1',qh+'1')
find(ph + '0',qh+'0')

for i in range(2**kbits):
ph = bin(i)[2:].zfill(kbits)
qh = ''
find(ph,qh)

#运行结果
'''
p = 252330886747767934000827792522692714557
'''