
Concurrent SNP_GET_REPORT requests result in reports with wrong user data #265

Open
Trundle opened this issue Feb 27, 2025 · 1 comment


Trundle commented Feb 27, 2025

This is caused by a bug in the sev-guest driver. The commit "virt: sev-guest: Reduce the scope of SNP command mutex" narrowed the command mutex to snp_send_guest_request, on the reasoning that all shared state is handled in that function. However, SNP_GET_REPORT requests share a single buffer for the report's user data: https://github.com/torvalds/linux/blob/1e15510b71c99c6e49134d756df91069f7d18141/drivers/virt/coco/sev-guest/sev-guest.c#L74

Writes to this buffer are no longer guarded by the mutex: https://github.com/torvalds/linux/blob/1e15510b71c99c6e49134d756df91069f7d18141/drivers/virt/coco/sev-guest/sev-guest.c#L83

As a result, concurrent SNP_GET_REPORT requests can overwrite each other's user data.
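To illustrate why the narrower mutex scope matters, here is a toy Python model. It is not the driver code: `ModelDevice`, `get_report_narrow`, `get_report_wide` and `race` are hypothetical names, and the "firmware" simply echoes the shared buffer back. Events force the interleaving "A copies its data, B runs, A sends", so the outcome is deterministic instead of depending on timing.

```python
import threading

DATA_SIZE = 64


class ModelDevice:
    """Toy model (not kernel code): one user-data buffer shared by every request."""

    def __init__(self):
        self.user_data = bytearray(DATA_SIZE)  # stands in for the shared report buffer
        self.cmd_mutex = threading.Lock()      # stands in for the SNP command mutex

    def _send(self):
        # Stand-in for the firmware round trip: the "report" echoes the shared buffer.
        return bytes(self.user_data)

    def get_report_narrow(self, data, after_write=lambda: None):
        # Narrow scope: the copy into the shared buffer happens outside the mutex.
        self.user_data[:] = data
        after_write()
        with self.cmd_mutex:
            return self._send()

    def get_report_wide(self, data, after_write=lambda: None):
        # Wide scope: the copy and the send form one critical section.
        with self.cmd_mutex:
            self.user_data[:] = data
            after_write()
            return self._send()


def race(method_name, timeout=1.0):
    """Force the interleaving "A copies, B runs, A sends" and return both reports."""
    dev = ModelDevice()
    get_report = getattr(dev, method_name)
    a_wrote, b_wrote = threading.Event(), threading.Event()
    results = {}

    def request_a():
        def pause():
            a_wrote.set()
            # Let B overwrite the buffer. If the mutex covers the copy, B is
            # blocked on the mutex and this wait just times out.
            b_wrote.wait(timeout)
        results["a"] = get_report(bytes([1]) * DATA_SIZE, pause)

    def request_b():
        a_wrote.wait()
        results["b"] = get_report(bytes([2]) * DATA_SIZE, b_wrote.set)

    threads = [threading.Thread(target=request_a), threading.Thread(target=request_b)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


narrow = race("get_report_narrow")
wide = race("get_report_wide")
print(narrow["a"][:4])  # b'\x02\x02\x02\x02': A received B's user data
print(wide["a"][:4])    # b'\x01\x01\x01\x01': A received its own user data
```

In the narrow variant nothing stops B's copy from landing between A's copy and A's send; holding the mutex across both steps closes that window at the cost of serializing the copy as well.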

The following Python script demonstrates the issue. Two threads request attestation reports, one with user data of all 0x01 bytes and one with all 0x02 bytes, and each prints the report's data only if it doesn't match the data it requested.

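```python
import ctypes
import fcntl
import os
from threading import Thread

# SNP_GET_REPORT = _IOWR('S', 0x0, struct snp_guest_request_ioctl), a 32-byte struct.
# (0xffffffffc0205300 is the same value sign-extended to 64 bits.)
GET_REPORT = 0xC0205300
DATA_SIZE = 64          # size of the user data / REPORT_DATA field
REPORT_DATA_IDX = 0x70  # 0x20-byte response header + REPORT_DATA offset 0x50 in the report


def test(n):
    # struct snp_report_req: user_data[64], vmpl (u32), rsvd[28]
    report_req = (ctypes.c_char * (DATA_SIZE + 4 + 28))()
    data = bytes([n] * DATA_SIZE)
    report_req[:DATA_SIZE] = data
    # Response buffer (struct snp_report_resp is 4000 bytes; 4096 leaves headroom).
    report_resp = (ctypes.c_char * 4096)()
    # struct snp_guest_request_ioctl: msg_version (u8 + 7 bytes padding), then
    # req_data, resp_data and exitinfo2 as little-endian u64 values.
    req = (
        b"\1" + b"\0" * 7
        + ctypes.addressof(report_req).to_bytes(8, byteorder="little")
        + ctypes.addressof(report_resp).to_bytes(8, byteorder="little")
        + b"\0" * 8
    )
    fd = os.open("/dev/sev-guest", os.O_RDONLY)
    try:
        for _ in range(128):
            fcntl.ioctl(fd, GET_REPORT, req, True)
            got = report_resp[REPORT_DATA_IDX:REPORT_DATA_IDX + DATA_SIZE]
            # Only print mismatches: REPORT_DATA should echo this thread's data.
            if got != data:
                print(got)
    finally:
        os.close(fd)


def main():
    threads = [Thread(target=test, args=(n,)) for n in (1, 2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()


if __name__ == "__main__":
    main()
```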
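The reproducer hard-codes three magic numbers. The sketch below re-derives them, under stated assumptions: x86-64 (little-endian, 8-byte aligned u64), Linux's asm-generic ioctl encoding, struct layouts transcribed from include/uapi/linux/sev-guest.h, and the 0x20-byte response header plus REPORT_DATA offset 0x50 taken from the SEV-SNP firmware ABI as we understand it. The helper names (`_ioc`, `_iowr`, `pack_by_hand`, the ctypes classes) are hypothetical. The ioctl number comes out as 0xc0205300; 0xffffffffc0205300 is the same value once truncated to 32 bits.

```python
import ctypes

# asm-generic ioctl encoding (used on x86-64):
# bits 0-7 = nr, 8-15 = type, 16-29 = size, 30-31 = direction.
_IOC_WRITE = 1
_IOC_READ = 2


def _ioc(direction, typ, nr, size):
    return (direction << 30) | (size << 16) | (ord(typ) << 8) | nr


def _iowr(typ, nr, size):
    return _ioc(_IOC_READ | _IOC_WRITE, typ, nr, size)


class SnpReportReq(ctypes.Structure):
    # Mirrors struct snp_report_req in include/uapi/linux/sev-guest.h.
    _fields_ = [
        ("user_data", ctypes.c_uint8 * 64),
        ("vmpl", ctypes.c_uint32),
        ("rsvd", ctypes.c_uint8 * 28),
    ]


class SnpGuestRequestIoctl(ctypes.Structure):
    # Mirrors struct snp_guest_request_ioctl; the exitinfo2 union is one u64 here.
    _fields_ = [
        ("msg_version", ctypes.c_uint8),
        ("req_data", ctypes.c_uint64),
        ("resp_data", ctypes.c_uint64),
        ("exitinfo2", ctypes.c_uint64),
    ]


SNP_GET_REPORT = _iowr("S", 0x0, ctypes.sizeof(SnpGuestRequestIoctl))

# Response layout (assumed): 0x20-byte header (status, report_size, reserved),
# then the attestation report, whose REPORT_DATA field starts at offset 0x50.
RESP_HDR_SIZE = 0x20
REPORT_DATA_OFFSET = 0x50
REPORT_DATA_IDX = RESP_HDR_SIZE + REPORT_DATA_OFFSET


def pack_by_hand(req_addr, resp_addr):
    # The same byte string the reproducer builds manually.
    return (b"\1" + b"\0" * 7
            + req_addr.to_bytes(8, byteorder="little")
            + resp_addr.to_bytes(8, byteorder="little")
            + b"\0" * 8)


print(hex(SNP_GET_REPORT))          # 0xc0205300
print(ctypes.sizeof(SnpReportReq))  # 96
print(hex(REPORT_DATA_IDX))         # 0x70
# ctypes zero-fills the padding after msg_version, so the two encodings match.
print(bytes(SnpGuestRequestIoctl(1, 0x1000, 0x2000)) == pack_by_hand(0x1000, 0x2000))  # True
```

Using ctypes structures instead of concatenated bytes keeps the padding and field order tied to the C definitions, which makes a layout mistake easier to spot.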

The script should never print anything, but running it in an SEV-SNP guest VM with Linux 6.13 prints output like the following:

b'\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x02\x02\x02\x02\x02\x02\x02\x02\x02\x02\x02\x02\x02\x02\x02\x02'
...

Note in particular the mix of \x01 and \x02 bytes: a single report contains user data from both requests.
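To make such output easier to read, a hypothetical helper (`describe_report_data`, not part of the original script) can summarize a mismatching REPORT_DATA value as runs of bytes and classify it as torn (bytes from more than one request) or swapped (entirely another request's data). It assumes, as the reproducer does, that each request fills its user data with a single repeated byte; the input below is made up for illustration.

```python
from itertools import groupby


def describe_report_data(got: bytes, expected: bytes) -> str:
    """Summarize REPORT_DATA as byte runs, assuming each request uses one repeated byte."""
    runs = [(value, len(list(group))) for value, group in groupby(got)]
    summary = ", ".join(f"{count} x 0x{value:02x}" for value, count in runs)
    if got == expected:
        kind = "ok"
    elif len(runs) > 1:
        kind = "torn"     # bytes from more than one request
    else:
        kind = "swapped"  # entirely another request's user data
    return f"{kind}: {summary}"


print(describe_report_data(bytes([1]) * 48 + bytes([2]) * 16, bytes([1]) * 64))
# torn: 48 x 0x01, 16 x 0x02
```

In the reproducer, printing `describe_report_data(...)` of the received and requested data in place of the raw bytes in the mismatch branch would produce one compact line per bad report.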

@tlendacky
Collaborator

Adding @nikunjad
