diff --git a/src/main/java/com/lambdaworks/crypto/SCryptUtil.java b/src/main/java/com/lambdaworks/crypto/SCryptUtil.java index ca29a00..9fb85db 100644 --- a/src/main/java/com/lambdaworks/crypto/SCryptUtil.java +++ b/src/main/java/com/lambdaworks/crypto/SCryptUtil.java @@ -2,9 +2,14 @@ package com.lambdaworks.crypto; -import java.io.UnsupportedEncodingException; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import static com.lambdaworks.codec.Base64.*; @@ -28,9 +33,11 @@ * @author Will Glozer */ public class SCryptUtil { + /** * Hash the supplied plaintext password and generate output in the format described - * in {@link SCryptUtil}. + * in {@link SCryptUtil}. This method maybe unsafe as it uses Strings, which are not guaranteed to be + * freed immediately. * * @param passwd Password. * @param N CPU cost parameter. @@ -40,29 +47,90 @@ public class SCryptUtil { * @return The hashed password. */ public static String scrypt(String passwd, int N, int r, int p) { + byte[] bytes = passwd.getBytes(StandardCharsets.UTF_8); + try { + return new String(scrypt(bytes, N, r, p), StandardCharsets.UTF_8); + } finally { + wipeArray(bytes); + } + } + + /** + * Hash the supplied plaintext password and generate output in the format described + * in {@link SCryptUtil}. This call will aggressively clean up password data in memory. + * + * @param passwd Password in UTF-8 encoding. + * @param N CPU cost parameter. + * @param r Memory cost parameter. + * @param p Parallelization parameter. + * + * @return The hashed password. + */ + public static byte[] scrypt(byte[] passwordBytes, int N, int r, int p) { try { byte[] salt = new byte[16]; SecureRandom.getInstance("SHA1PRNG").nextBytes(salt); - byte[] derived = SCrypt.scrypt(passwd.getBytes("UTF-8"), salt, N, r, p, 32); - - String params = Long.toString(log2(N) << 16L | r << 8 | p, 16); - - StringBuilder sb = new StringBuilder((salt.length + derived.length) * 2); - sb.append("$s0$").append(params).append('$'); - sb.append(encode(salt)).append('$'); - sb.append(encode(derived)); - - return sb.toString(); - } catch (UnsupportedEncodingException e) { - throw new IllegalStateException("JVM doesn't support UTF-8?"); + byte[] derived = SCrypt.scrypt(passwordBytes, salt, N, r, p, 32); + + byte[] params = Long.toString(log2(N) << 16L | r << 8 | p, 16).getBytes(StandardCharsets.UTF_8); + + byte[] prefix = "$s0$".getBytes(StandardCharsets.UTF_8); + byte[] dollar = "$".getBytes(StandardCharsets.UTF_8); + final char[] charEncodedSalt = encode(salt); + byte[] byteEncodedSalt = toBytes(charEncodedSalt); + wipeArray(charEncodedSalt); + final char[] charEncodedDerived = encode(derived); + byte[] byteEncodedDerived = toBytes(charEncodedDerived); + wipeArray(charEncodedDerived); + + byte[] result = new byte[prefix.length + + params.length + + dollar.length + + byteEncodedSalt.length + + dollar.length + + byteEncodedDerived.length]; + System.arraycopy(prefix, 0, result, 0, prefix.length); + System.arraycopy(params, 0, result, prefix.length, params.length); + System.arraycopy(dollar, 0, result, prefix.length + + params.length, dollar.length); + System.arraycopy(byteEncodedSalt, 0, result, prefix.length + + params.length + + dollar.length, byteEncodedSalt.length); + System.arraycopy(dollar, 0, result, prefix.length + + params.length + + dollar.length + + byteEncodedSalt.length, dollar.length); + System.arraycopy(byteEncodedDerived, 0, result, prefix.length + + params.length + + dollar.length + + byteEncodedSalt.length + + dollar.length, byteEncodedDerived.length); + wipeArray(salt); + wipeArray(derived); + wipeArray(params); + wipeArray(byteEncodedSalt); + wipeArray(byteEncodedDerived); + return result; } catch (GeneralSecurityException e) { throw new IllegalStateException("JVM doesn't support SHA1PRNG or HMAC_SHA256?"); } } + private static byte[] toBytes(char[] chars) { + CharBuffer charBuffer = CharBuffer.wrap(chars); + ByteBuffer byteBuffer = StandardCharsets.UTF_8.encode(charBuffer); + try { + return Arrays.copyOfRange(byteBuffer.array(), byteBuffer.position(), byteBuffer.limit()); + } finally { + wipeArray(byteBuffer.array()); + } + } + /** * Compare the supplied plaintext password to a hashed password. + * This method maybe unsafe as it uses Strings, which are not guaranteed to be + * freed immediatelly. * * @param passwd Plaintext password. * @param hashed scrypt hashed password. @@ -70,22 +138,39 @@ public static String scrypt(String passwd, int N, int r, int p) { * @return true if passwd matches hashed value. */ public static boolean check(String passwd, String hashed) { + return check(passwd.getBytes(StandardCharsets.UTF_8), hashed.getBytes(StandardCharsets.UTF_8)); + } + + /** + * Compare the supplied plaintext password to a hashed password. + * This call will aggressively clean up password data in memory. + * + * @param passwd Plaintext password encoded in UTF-8. + * @param hashed scrypt hashed password encoded in UTF-8. + * + * @return true if passwd matches hashed value. + */ + public static boolean check(byte[] passwordBytes, byte[] hashed) { try { - String[] parts = hashed.split("\\$"); + byte[][] parts = split(hashed, "$".getBytes(StandardCharsets.UTF_8)); - if (parts.length != 5 || !parts[1].equals("s0")) { + if (parts.length != 5 || !Arrays.equals(parts[1], "s0".getBytes(StandardCharsets.UTF_8))) { throw new IllegalArgumentException("Invalid hashed value"); } - long params = Long.parseLong(parts[2], 16); - byte[] salt = decode(parts[3].toCharArray()); - byte[] derived0 = decode(parts[4].toCharArray()); + long params = Long.parseLong(new String(parts[2], StandardCharsets.UTF_8), 16); + final char[] charEncodedSalt = toCharArray(parts[3]); + byte[] salt = decode(charEncodedSalt); + wipeArray(charEncodedSalt); + final char[] charEncodedDerived = toCharArray(parts[4]); + byte[] derived0 = decode(charEncodedDerived); + wipeArray(charEncodedDerived); int N = (int) Math.pow(2, params >> 16 & 0xffff); int r = (int) params >> 8 & 0xff; int p = (int) params & 0xff; - byte[] derived1 = SCrypt.scrypt(passwd.getBytes("UTF-8"), salt, N, r, p, 32); + byte[] derived1 = SCrypt.scrypt(passwordBytes, salt, N, r, p, 32); if (derived0.length != derived1.length) return false; @@ -93,14 +178,53 @@ public static boolean check(String passwd, String hashed) { for (int i = 0; i < derived0.length; i++) { result |= derived0[i] ^ derived1[i]; } + wipeArray(derived0); + wipeArray(derived1); return result == 0; - } catch (UnsupportedEncodingException e) { - throw new IllegalStateException("JVM doesn't support UTF-8?"); } catch (GeneralSecurityException e) { throw new IllegalStateException("JVM doesn't support SHA1PRNG or HMAC_SHA256?"); } } + private static char[] toCharArray(byte[] bs) { + char[] result = new char[bs.length]; + for(int i = 0; i < bs.length; i++) { + result[i] = (char) bs[i]; + } + return result; + } + + /** + * Splites a byte array into chunks delimited by separatorBytes, similar + * to {@link String#split}. + */ + private static byte[][] split(byte[] array, byte[] separatorBytes) { + List result = new ArrayList(); + int lastSplitIndex = 0; + for(int i = 0; i < array.length - separatorBytes.length; i++) + { + boolean found = false; + int j = 0; + while (j < separatorBytes.length && array[i + j] == separatorBytes[j]) + j++; + found = (j == separatorBytes.length); + if(found) { + byte[] substring = new byte[i - lastSplitIndex]; + System.arraycopy(array, lastSplitIndex, substring, 0, i - lastSplitIndex); + result.add(substring); + lastSplitIndex = i + j; + } else if(array.length == i + separatorBytes.length + 1) { + byte[] substring = new byte[i + separatorBytes.length + 1 - lastSplitIndex]; + System.arraycopy(array, lastSplitIndex, substring, 0, i + separatorBytes.length + 1 - lastSplitIndex); + result.add(substring); + } + } + if(result.isEmpty()) { + result.add(array); + } + return result.toArray(new byte[0][]); + } + private static int log2(int n) { int log = 0; if ((n & 0xffff0000 ) != 0) { n >>>= 16; log = 16; } @@ -109,4 +233,13 @@ private static int log2(int n) { if (n >= 4 ) { n >>>= 2; log += 2; } return log + (n >>> 1); } + + private static void wipeArray(byte[] array) { + Arrays.fill(array, (byte) 0); + } + + private static void wipeArray(char[] array) { + Arrays.fill(array, (char) 0); + } + }