|
| 1 | +import os |
| 2 | +import boto3 |
| 3 | +import json |
| 4 | +import time |
| 5 | +from boto3.dynamodb.conditions import Key, Attr |
| 6 | +import StringIO |
| 7 | +import gzip |
| 8 | +import random |
| 9 | +import socket,struct |
| 10 | + |
| 11 | +# 337902548806_CloudTrail_us-east-1_20190322T0500Z_auiiBXjcCAI9OYst.json.gz |
| 12 | + |
| 13 | +# Set up table resources from the env vars. |
| 14 | +assumedRoleStateTableName = os.environ['assumedRoleStateTableName'] |
| 15 | +roleExceptionsTableName = os.environ['roleExceptionsTableName'] |
| 16 | +exfilAlertLogGroup = os.environ['exfilAlertLogGroup'] |
| 17 | +dydbResource = boto3.resource("dynamodb") |
| 18 | +sessionsTable = dydbResource.Table(assumedRoleStateTableName) |
| 19 | +exceptionsTable = dydbResource.Table(roleExceptionsTableName) |
| 20 | + |
| 21 | + |
| 22 | + |
| 23 | +""" Thanks StackExchange for helping me avoid creating a deployment package!""" |
| 24 | +""" https://stackoverflow.com/questions/819355/how-can-i-check-if-an-ip-is-in-a-network-in-python""" |
| 25 | +def addressInNetwork(ip,net): |
| 26 | + ipaddr = struct.unpack('>L',socket.inet_aton(ip))[0] |
| 27 | + netaddr,bits = net.split('/') |
| 28 | + netmask = struct.unpack('>L',socket.inet_aton(netaddr))[0] |
| 29 | + ipaddr_masked = ipaddr & (4294967295<<(32-int(bits))) # Logical AND of IP address and mask will equal the network address if it matches |
| 30 | + if netmask == netmask & (4294967295<<(32-int(bits))): # Validate network address is valid for mask |
| 31 | + return ipaddr_masked == netmask |
| 32 | + else: |
| 33 | + print "***WARNING*** Network",netaddr,"not valid with mask /"+bits |
| 34 | + return ipaddr_masked == netmask |
| 35 | + |
| 36 | + |
| 37 | +def isWhitelisted(RoleArn, SourceIp): |
| 38 | + roleWhitelistResponse = exceptionsTable.get_item( |
| 39 | + Key={ |
| 40 | + 'roleArn': RoleArn, |
| 41 | + } |
| 42 | + ) |
| 43 | + if "Item" not in roleWhitelistResponse.keys(): |
| 44 | + # No whitelist configured for this role |
| 45 | + return False |
| 46 | + whitelistResponse = roleWhitelistResponse['Item'] |
| 47 | + whitelist = whitelistResponse['whitelist'] |
| 48 | + for net in whitelist: |
| 49 | + if addressInNetwork(str(SourceIp), str(net)): |
| 50 | + # print "%s would have fired an alert, but is in the whitelist %s" % (SourceIp, net) |
| 51 | + return True |
| 52 | + return False |
| 53 | + |
| 54 | +def retrieveJsonBodyFromS3(obj): |
| 55 | + """ |
| 56 | + Given a tuple (bucketname, key) |
| 57 | + Create an S3 client, retrieve the contents |
| 58 | + Unzip the body, and return as a python object |
| 59 | + """ |
| 60 | + s3Client = boto3.client("s3") |
| 61 | + s3_file = s3Client.get_object( |
| 62 | + Bucket=obj[0], |
| 63 | + Key=obj[1] |
| 64 | + ) |
| 65 | + body = s3_file['Body'] |
| 66 | + compressedFile = StringIO.StringIO(body.read()) |
| 67 | + decompressedFile = gzip.GzipFile(fileobj=compressedFile) |
| 68 | + jsonBody = json.loads(decompressedFile.read()) |
| 69 | + return jsonBody |
| 70 | + |
| 71 | + |
| 72 | + |
| 73 | +def createSessionState(RoleArn, SessionId, SourceIp="0"): |
| 74 | + # Set ttl 6 hours in the future |
| 75 | + print "Creating initial session state for %s" % RoleArn |
| 76 | + ttl = int(time.time() + 6 * 60 * 60) |
| 77 | + item = { |
| 78 | + "sessionId": SessionId, |
| 79 | + "sourceIp": SourceIp, |
| 80 | + "roleArn": RoleArn, |
| 81 | + "ttl": ttl |
| 82 | + } |
| 83 | + resp = sessionsTable.put_item( |
| 84 | + Item=item |
| 85 | + ) |
| 86 | + |
| 87 | +def recordSuspiciousEvent(event): |
| 88 | + cwClient = boto3.client("logs") |
| 89 | + logStreamName = ''.join(random.choice('0123456789ABCDEF') for i in range(16)) |
| 90 | + |
| 91 | + t = int(round(time.time() * 1000)) |
| 92 | + logEvent = { |
| 93 | + "timestamp": t, |
| 94 | + "message": event |
| 95 | + } |
| 96 | + cwClient.create_log_stream( |
| 97 | + logGroupName = exfilAlertLogGroup, |
| 98 | + logStreamName = logStreamName |
| 99 | + ) |
| 100 | + cwClient.put_log_events( |
| 101 | + logGroupName = exfilAlertLogGroup, |
| 102 | + logStreamName = logStreamName, |
| 103 | + logEvents = [logEvent] |
| 104 | + ) |
| 105 | + print "Wrote suspicious event to %s" % logStreamName |
| 106 | + |
| 107 | +def analyzeNonAssumeRecord(SessionId, SourceIp): |
| 108 | + """ |
| 109 | + Check CloudTrail events originating from assumed roles. |
| 110 | + If no SourceIP has been previously recorded for the session, record it. |
| 111 | + If the SourceIP doesn't match a previously recorded one... |
| 112 | + We may have a credential exfil, return info on the session to be used in an alert body. |
| 113 | + """ |
| 114 | + if not SessionId.startswith("i-"): |
| 115 | + # Not an EC2 assumed role session |
| 116 | + return |
| 117 | + |
| 118 | + sessionResponse = sessionsTable.get_item( |
| 119 | + Key={ |
| 120 | + 'sessionId': SessionId, |
| 121 | + } |
| 122 | + ) |
| 123 | + if "Item" not in sessionResponse.keys(): |
| 124 | + # No recorded session found for this session. |
| 125 | + # Was Exfil detections started less than 6 hours ago? |
| 126 | + print "No session found for %s (%s)" % (SessionId, SourceIp) |
| 127 | + pass |
| 128 | + else: |
| 129 | + session = sessionResponse['Item'] |
| 130 | + roleArn = session['roleArn'] |
| 131 | + previousSourceIp = session['sourceIp'] |
| 132 | + if previousSourceIp == "0": |
| 133 | + # First time seeing this session used since created. |
| 134 | + # Record the source IP |
| 135 | + print "Recording IP for existing session" |
| 136 | + createSessionState(roleArn, SessionId, SourceIp) |
| 137 | + return |
| 138 | + elif previousSourceIp == SourceIp: |
| 139 | + print "Identified activity from AssumedRole with the same as previously identified IP (%s)" % previousSourceIp |
| 140 | + return |
| 141 | + else: |
| 142 | + print "Suspicious behavior here. Send back the original session info" |
| 143 | + return {"roleArn": roleArn, "sourceIp": previousSourceIp} |
| 144 | + |
| 145 | + |
| 146 | +def assessCloudtrailEventRecord(event): |
| 147 | + # Identify EC2 AssumeRoles and record new sessions or inspect calls made by AssumedRoles |
| 148 | + # TODO: Store VPC Endpoint ID if relevant |
| 149 | + |
| 150 | + if (event['eventName'] == "AssumeRole" and |
| 151 | + event['sourceIPAddress'] == "ec2.amazonaws.com" and |
| 152 | + event['eventSource'] == "sts.amazonaws.com"): |
| 153 | + # Fresh EC2 AssumeRole |
| 154 | + # Record the session |
| 155 | + sessionId = event['requestParameters']['roleSessionName'] |
| 156 | + roleArn = event['requestParameters']['roleArn'] |
| 157 | + createSessionState(roleArn, sessionId) |
| 158 | + elif event['userIdentity']['type'] == "AssumedRole": |
| 159 | + sessionId = event['userIdentity']['arn'].split('/')[-1] |
| 160 | + sourceIp = event['sourceIPAddress'] |
| 161 | + violation = analyzeNonAssumeRecord(sessionId, sourceIp) |
| 162 | + if violation is not None: |
| 163 | + # TODO: Check the exceptions table |
| 164 | + if not isWhitelisted(violation['roleArn'], sourceIp): |
| 165 | + alert = {} |
| 166 | + alert['originalSessionInfo'] = violation |
| 167 | + alert['potentialImposterSourceIp'] = sourceIp |
| 168 | + alert['alertMessage'] = "EC2 credentials previously associated with an IP have been used from a source other than the original. This is indicative of instance compromise and credential exfiltration." |
| 169 | + message = json.dumps(alert) |
| 170 | + recordSuspiciousEvent(message) |
| 171 | + |
| 172 | + |
| 173 | + |
| 174 | +def extract_s3file_from_sns_event(event): |
| 175 | + """ |
| 176 | + Unwrap S3 Records from SNS Records |
| 177 | + Return a list of tuples (bucketname, key) |
| 178 | + """ |
| 179 | + |
| 180 | + s3Files = [] |
| 181 | + for snsRecord in event['Records']: |
| 182 | + s3Body = json.loads(snsRecord['Sns']['Message']) |
| 183 | + for s3Record in s3Body['Records']: |
| 184 | + bucketName = s3Record['s3']['bucket']['name'] |
| 185 | + key = s3Record['s3']['object']['key'] |
| 186 | + s3Files.append((bucketName, key)) |
| 187 | + return s3Files |
| 188 | + |
| 189 | +def lambda_handler(event, context): |
| 190 | + cloudtrailFiles = extract_s3file_from_sns_event(event) |
| 191 | + for cloudtrailFile in cloudtrailFiles: |
| 192 | + if not "CloudTrail-Digest" in cloudtrailFile[1]: |
| 193 | + # Don't run digests through the assessment. |
| 194 | + print "Collecting and assessing %s" % cloudtrailFile[1] |
| 195 | + cloudtrailBody = retrieveJsonBodyFromS3(cloudtrailFile) |
| 196 | + for cloudtrailEvent in cloudtrailBody['Records']: |
| 197 | + alerts = assessCloudtrailEventRecord( |
| 198 | + cloudtrailEvent |
| 199 | + ) |
0 commit comments