keys.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. # Copyright(c) 2018 STMicroelectronics International N.V.
  2. # Copyright 2017 Linaro Limited
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. from Cryptodome.Cipher import AES
  17. from Cryptodome.Hash import SHA256
  18. from ecdsa import SigningKey, NIST256p, util
  19. import hashlib
  20. from struct import pack
  21. import translate_key
  22. import numpy
  23. #for AES_CBC lambda pad to 16 bytes by adding the padded value
  24. #(i.e 24 bytes : 0x08 is added 8 times
  25. BS = 16
  26. pad = lambda s: s + (BS - len(s) % BS) * pack("B", 0)
  27. class AES_CBC():
  28. def __init__(self, key):
  29. """Construct an AES_CBC private key with the given key data"""
  30. self.key = key
  31. self.nonce = []
  32. @staticmethod
  33. def generate():
  34. #use random from platform
  35. return AES_CBC(os.urandom(16))
  36. def export_private(self, path):
  37. if "AES_CBC" not in path:
  38. print("path does not contains AES_CBC : AES_CBC key should contain AES_CBC string!!!")
  39. exit(1)
  40. else:
  41. with open(path, 'wb') as f:
  42. f.write(self.key)
  43. def encrypt(self, payload, nonce=[]):
  44. if payload == []:
  45. print("error")
  46. #Fix me AES CBC is possibly 12 bytes
  47. if nonce == []:
  48. nonce = os.urandom(16)
  49. m = hashlib.sha256()
  50. print("block size ="+str(AES.block_size))
  51. encryptor = AES.new(self.key, AES.MODE_CBC, nonce)
  52. encrypted = ""
  53. #check if buffer size is aligned on BS size
  54. if (0 == (len(payload) % BS)):
  55. # we do not need to pad
  56. buffer=payload
  57. encrypted = encryptor.encrypt(buffer)
  58. else:
  59. # Buffer size is not correct (and we do not support ciphertext stealing mode "CBC-CS2" specified in NIST SP 800-38A any more)
  60. raise Exception("AES CBC encryption requires the Firmware Image size to be a multiple of the AES block size (16 bytes)")
  61. #compute sh256 on clear buffer without padding
  62. m.update(payload)
  63. signature = m.digest()
  64. #swap the last two block and truncate if required
  65. return encrypted,signature, nonce
  66. def trans(self,section, name, end, assembly, version):
  67. outcode = translate_key.function(section, name,assembly)
  68. outcode += translate_key.translate(self.key,end,assembly, version)
  69. return outcode
  70. def has_nonce(self):
  71. return True
  72. def has_sign(self):
  73. return False
  74. def has_encrypt(self):
  75. return True
  76. def get_key(self, type):
  77. return self.key
  78. class AES_CTR():
  79. def __init__(self, key):
  80. """Construct an AES_CTR private key with the given key data"""
  81. self.key = key
  82. self.nonce = []
  83. @staticmethod
  84. def generate():
  85. #use random from platform
  86. return AES_CTR(os.urandom(16))
  87. def export_private(self, path):
  88. if "AES_CTR" not in path:
  89. print("path does not contains AES_CTR : AES_CTR key should contain AES_CTR string!!!")
  90. exit(1)
  91. else:
  92. with open(path, 'wb') as f:
  93. f.write(self.key)
  94. def encrypt(self, payload, address, nonce=[]):
  95. if payload == []:
  96. print("error")
  97. m = hashlib.sha256()
  98. #Swap bytes inside 16 bytes block
  99. inarr = numpy.asarray(list(payload), numpy.int8).reshape(-1, 16)
  100. outarr = numpy.fliplr(inarr)
  101. payload = bytearray(outarr)
  102. print("block size ="+str(AES.block_size))
  103. #Encryption
  104. if nonce == []:
  105. encryptor = AES.new(self.key, AES.MODE_CTR, initial_value=address);
  106. else:
  107. encryptor = AES.new(self.key, AES.MODE_CTR, nonce=nonce, initial_value=address);
  108. encrypted = ""
  109. #check if buffer size is aligned on BS size
  110. if (0 == (len(payload) % BS)):
  111. # we do not need to pad
  112. buffer=payload
  113. encrypted = encryptor.encrypt(buffer)
  114. nonce = encryptor.nonce
  115. else:
  116. raise Exception("AES CTR encryption requires the Firmware Image size to be a multiple of the AES block size (16 bytes)")
  117. #Swap bytes inside 16 bytes block
  118. inarr = numpy.asarray(list(encrypted), numpy.int8).reshape(-1, 16)
  119. outarr = numpy.fliplr(inarr)
  120. encrypted = bytearray(outarr)
  121. #compute sh256 on encrypted buffer without padding
  122. m.update(encrypted)
  123. signature = m.digest()
  124. #swap the last two block and truncate if required
  125. return encrypted,signature,nonce
  126. def trans(self,section, name, end, assembly, version):
  127. outcode = translate_key.function(section, name,assembly)
  128. outcode += translate_key.translate(self.key,end,assembly, version)
  129. return outcode
  130. def has_nonce(self):
  131. return True
  132. def has_sign(self):
  133. return False
  134. def has_encrypt(self):
  135. return True
  136. def get_key(self, type):
  137. return self.key
  138. class AES_GCM():
  139. def __init__(self, key):
  140. """Construct an AES_GCM private key with the given key data"""
  141. self.key = key
  142. self.nonce = []
  143. @staticmethod
  144. def generate():
  145. #use random from platform
  146. return AES_GCM(os.urandom(16))
  147. def export_private(self, path):
  148. if "AES_CBC" in path:
  149. print("path contains AES_CBC : AES_GCM key should not contain AES_CBC!!!")
  150. exit(1)
  151. else:
  152. with open(path, 'wb') as f:
  153. f.write(self.key)
  154. def encrypt(self, payload, nonce=[]):
  155. if payload == []:
  156. print("error")
  157. if nonce == []:
  158. nonce = os.urandom(12)
  159. encryptor = AES.new(self.key, AES.MODE_GCM, nonce)
  160. encrypted = encryptor.encrypt(payload)
  161. signature = encryptor.digest()
  162. return encrypted,signature, nonce
  163. def sign(self,payload, nonce):
  164. encryptor = AES.new(self.key, AES.MODE_GCM, nonce)
  165. encryptor.update(payload)
  166. signature = encryptor.digest()
  167. return signature, nonce
  168. def trans(self,section, name, end, assembly, version):
  169. outcode = translate_key.function(section, name,assembly)
  170. outcode += translate_key.translate(self.key,end,assembly, version)
  171. return outcode
  172. def has_nonce(self):
  173. return True
  174. def has_sign(self):
  175. return True
  176. def has_encrypt(self):
  177. return True
  178. def get_key(self, type):
  179. return self.key
  180. class ECDSA256P1():
  181. def __init__(self, key):
  182. """Construct an ECDSA P-256 private key"""
  183. self.key = key
  184. @staticmethod
  185. def generate():
  186. return ECDSA256P1(SigningKey.generate(curve=NIST256p))
  187. def export_private(self, path):
  188. with open(path, 'wb') as f:
  189. f.write(self.key.to_pem())
  190. def trans(self,section, name, end, assembly, version):
  191. vk = self.key.get_verifying_key()
  192. binarykey = vk.to_string()
  193. #generate asm code
  194. outcode = translate_key.function(section, name,assembly)
  195. outcode += translate_key.translate(binarykey,end,assembly, version )
  196. return outcode
  197. def sign(self, payload):
  198. # To make this fixed length, possibly pad with zeros.
  199. sig = self.key.sign(payload, hashfunc=hashlib.sha256)
  200. return sig
  201. def has_nonce(self):
  202. return False
  203. def has_sign(self):
  204. return True
  205. def has_encrypt(self):
  206. return False
  207. def get_key(self, type):
  208. if (type == "public"):
  209. vk = self.key.get_verifying_key()
  210. return vk.to_string()
  211. else:
  212. return self.key.to_pem()
  213. class PAIRING():
  214. def __init__(self, key):
  215. """Construct an PAIRING private key with the given key data"""
  216. self.key = key
  217. self.nonce = []
  218. @staticmethod
  219. def generate():
  220. print("Pairing class unsupported !!")
  221. exit(1)
  222. def export_private(self, path):
  223. print("Pairing class unsupported !!")
  224. exit(1)
  225. def encrypt(self, payload, nonce=[]):
  226. print("Pairing class unsupported !!")
  227. exit(1)
  228. def sign(self,payload, nonce):
  229. print("Pairing class unsupported !!")
  230. exit(1)
  231. def trans(self,section, name, end, assembly, version):
  232. outcode = translate_key.function(section, name,assembly)
  233. outcode += translate_key.translate(self.key,end,assembly, version)
  234. return outcode
  235. def has_nonce(self):
  236. print("Pairing class unsupported !!")
  237. exit(1)
  238. def has_sign(self):
  239. print("Pairing class unsupported !!")
  240. exit(1)
  241. def has_encrypt(self):
  242. print("Pairing class unsupported !!")
  243. exit(1)
  244. def get_key(self, type):
  245. return self.key
  246. def load(path):
  247. with open(path, 'rb') as f:
  248. pem = f.read()
  249. if len(pem) == 16:
  250. if "AES_CBC" in path:
  251. return AES_CBC(pem)
  252. elif "AES_CTR" in path:
  253. return AES_CTR(pem)
  254. else:
  255. return AES_GCM(pem)
  256. elif len(pem) == 32:
  257. return PAIRING(pem)
  258. else:
  259. key = SigningKey.from_pem(pem)
  260. if key.curve.name == 'NIST256p':
  261. return ECDSA256P1(key)
  262. else:
  263. raise Exception("Unsupported")