Iris CTF 2024
Baby Charge [50 pts]
For your first assignment, break ChaCha20!
nc babycha.chal.irisc.tf 10100
babycha.tar.gz
We’re given an encryption oracle with two options – 1. encrypt any plaintext we want and 2. encrypt the flag.
We are also given the source code for the service. Here it is:
# https://en.wikipedia.org/wiki/Salsa20#ChaCha20_adoption
from Crypto.Util.number import long_to_bytes, bytes_to_long
import secrets
def ROTL(a, b):
return (((a) << (b)) | ((a % 2**32) >> (32 - (b)))) % 2**32
def qr(x, a, b, c, d):
x[a] += x[b]; x[d] ^= x[a]; x[d] = ROTL(x[d],16)
x[c] += x[d]; x[b] ^= x[c]; x[b] = ROTL(x[b],12)
x[a] += x[b]; x[d] ^= x[a]; x[d] = ROTL(x[d], 8)
x[c] += x[d]; x[b] ^= x[c]; x[b] = ROTL(x[b], 7)
ROUNDS = 20
def chacha_block(inp):
x = list(inp)
for i in range(0, ROUNDS, 2):
qr(x, 0, 4, 8, 12)
qr(x, 1, 5, 9, 13)
qr(x, 2, 6, 10, 14)
qr(x, 3, 7, 11, 15)
qr(x, 0, 5, 10, 15)
qr(x, 1, 6, 11, 12)
qr(x, 2, 7, 8, 13)
qr(x, 3, 4, 9, 14)
return [(a+b) % 2**32 for a, b in zip(x, inp)]
def chacha_init(key, nonce, counter):
assert len(key) == 32
assert len(nonce) == 8
state = [0 for _ in range(16)]
state[0] = bytes_to_long(b"expa"[::-1])
state[1] = bytes_to_long(b"nd 3"[::-1])
state[2] = bytes_to_long(b"2-by"[::-1])
state[3] = bytes_to_long(b"te k"[::-1])
key = bytes_to_long(key)
nonce = bytes_to_long(nonce)
for i in range(8):
state[i+4] = key & 0xffffffff
key >>= 32
state[12] = (counter >> 32) & 0xffffffff
state[13] = counter & 0xffffffff
state[14] = (nonce >> 32) & 0xffffffff
state[15] = nonce & 0xffffffff
return state
state = chacha_init(secrets.token_bytes(32), secrets.token_bytes(8), 0)
buffer = b""
def encrypt(data):
global state, buffer
output = []
for b in data:
if len(buffer) == 0:
buffer = b"".join(long_to_bytes(x).rjust(4, b"\x00") for x in state)
state = chacha_block(state)
output.append(b ^ buffer[0])
buffer = buffer[1:]
return bytes(output)
flag = b"fake_flag{FAKE_FLAG}"
if __name__ == "__main__":
print("""This cipher is approved by Disk Jockey B.
1. Encrypt input
2. Encrypt flag
""")
while True:
inp = input("> ")
match inp:
case '1':
print(encrypt(input("? ").encode()).hex())
case '2':
print(encrypt(flag).hex())
case _:
print("Bye!")
exit()
Most of the program involves carrying out the actual encryption.
After researching ChaCha20 and testing some of the program’s functions in my own local file, I realized a few key things.
- Similar to AES, ChaCha20 shuffles the bytes to make it practically impossible to determine the
state
- ChaCha20 is a stream cipher, so if we can find the
state
, we can decrypt easily - The initialization of ChaCha20 in this program is not random in certain parts
- ChaCha20’s
state
changes in blocks of 64 bytes - ChaCha20’s next
state
is determined solely by the previousstate
.
After a while of trying to figure out how to exploit 3, I realized it was somewhat of a red herring. In fact, this problem is trivial if you focus on the other aspects of this program.
Essentially, since we are provided an encryption oracle, if we send in 64 null bytes to be encrypted, the oracle will return the current state
. Knowing the current state
, it is easy to find the next state by copying the provided code, and thus we can easily XOR it with the ciphertext of the flag and get the flag!
Here is the implementation:
from pwn import *
from Crypto.Util.number import long_to_bytes
import binascii
def ROTL(a, b):
return (((a) << (b)) | ((a % 2**32) >> (32 - (b)))) % 2**32
def qr(x, a, b, c, d):
x[a] += x[b]; x[d] ^= x[a]; x[d] = ROTL(x[d],16)
x[c] += x[d]; x[b] ^= x[c]; x[b] = ROTL(x[b],12)
x[a] += x[b]; x[d] ^= x[a]; x[d] = ROTL(x[d], 8)
x[c] += x[d]; x[b] ^= x[c]; x[b] = ROTL(x[b], 7)
ROUNDS = 20
def chacha_block(inp):
x = list(inp)
for i in range(0, ROUNDS, 2):
qr(x, 0, 4, 8, 12)
qr(x, 1, 5, 9, 13)
qr(x, 2, 6, 10, 14)
qr(x, 3, 7, 11, 15)
qr(x, 0, 5, 10, 15)
qr(x, 1, 6, 11, 12)
qr(x, 2, 7, 8, 13)
qr(x, 3, 4, 9, 14)
return [(a+b) % 2**32 for a, b in zip(x, inp)]
p = remote('babycha.chal.irisc.tf', 10100)
p.sendlineafter(b'> ', b'1')
p.sendlineafter(b'? ', b'\x00'*64)
state_hex = p.recvuntil(b'\n')[:-1]
# print(state_hex)
state = []
for i in range(0, len(state_hex), 8):
state.append(int(state_hex[i:i+8], 16))
state = chacha_block(state)
# print(state)
buffer = b"".join(long_to_bytes(x).rjust(4, b"\x00") for x in state)
p.sendlineafter(b'> ', b'2')
ct = p.recvuntil(b'\n')[:-1]
# print(ct)
ct = binascii.unhexlify(ct)
flag = ''
for i in range(len(ct)):
flag += chr(ct[i] ^ buffer[i])
print(flag)
Running the script gets the flag!
irisctf{initialization_is_no_problem}