diff --git a/src/main/java/com/lambdaworks/crypto/SCryptUtil.java b/src/main/java/com/lambdaworks/crypto/SCryptUtil.java
index ca29a00..9fb85db 100644
--- a/src/main/java/com/lambdaworks/crypto/SCryptUtil.java
+++ b/src/main/java/com/lambdaworks/crypto/SCryptUtil.java
@@ -2,9 +2,14 @@
package com.lambdaworks.crypto;
-import java.io.UnsupportedEncodingException;
+import java.nio.ByteBuffer;
+import java.nio.CharBuffer;
+import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
import static com.lambdaworks.codec.Base64.*;
@@ -28,9 +33,11 @@
* @author Will Glozer
*/
public class SCryptUtil {
+
/**
* Hash the supplied plaintext password and generate output in the format described
- * in {@link SCryptUtil}.
+ * in {@link SCryptUtil}. This method maybe unsafe as it uses Strings, which are not guaranteed to be
+ * freed immediately.
*
* @param passwd Password.
* @param N CPU cost parameter.
@@ -40,29 +47,90 @@ public class SCryptUtil {
* @return The hashed password.
*/
public static String scrypt(String passwd, int N, int r, int p) {
+ byte[] bytes = passwd.getBytes(StandardCharsets.UTF_8);
+ try {
+ return new String(scrypt(bytes, N, r, p), StandardCharsets.UTF_8);
+ } finally {
+ wipeArray(bytes);
+ }
+ }
+
+ /**
+ * Hash the supplied plaintext password and generate output in the format described
+ * in {@link SCryptUtil}. This call will aggressively clean up password data in memory.
+ *
+ * @param passwd Password in UTF-8 encoding.
+ * @param N CPU cost parameter.
+ * @param r Memory cost parameter.
+ * @param p Parallelization parameter.
+ *
+ * @return The hashed password.
+ */
+ public static byte[] scrypt(byte[] passwordBytes, int N, int r, int p) {
try {
byte[] salt = new byte[16];
SecureRandom.getInstance("SHA1PRNG").nextBytes(salt);
- byte[] derived = SCrypt.scrypt(passwd.getBytes("UTF-8"), salt, N, r, p, 32);
-
- String params = Long.toString(log2(N) << 16L | r << 8 | p, 16);
-
- StringBuilder sb = new StringBuilder((salt.length + derived.length) * 2);
- sb.append("$s0$").append(params).append('$');
- sb.append(encode(salt)).append('$');
- sb.append(encode(derived));
-
- return sb.toString();
- } catch (UnsupportedEncodingException e) {
- throw new IllegalStateException("JVM doesn't support UTF-8?");
+ byte[] derived = SCrypt.scrypt(passwordBytes, salt, N, r, p, 32);
+
+ byte[] params = Long.toString(log2(N) << 16L | r << 8 | p, 16).getBytes(StandardCharsets.UTF_8);
+
+ byte[] prefix = "$s0$".getBytes(StandardCharsets.UTF_8);
+ byte[] dollar = "$".getBytes(StandardCharsets.UTF_8);
+ final char[] charEncodedSalt = encode(salt);
+ byte[] byteEncodedSalt = toBytes(charEncodedSalt);
+ wipeArray(charEncodedSalt);
+ final char[] charEncodedDerived = encode(derived);
+ byte[] byteEncodedDerived = toBytes(charEncodedDerived);
+ wipeArray(charEncodedDerived);
+
+ byte[] result = new byte[prefix.length
+ + params.length
+ + dollar.length
+ + byteEncodedSalt.length
+ + dollar.length
+ + byteEncodedDerived.length];
+ System.arraycopy(prefix, 0, result, 0, prefix.length);
+ System.arraycopy(params, 0, result, prefix.length, params.length);
+ System.arraycopy(dollar, 0, result, prefix.length
+ + params.length, dollar.length);
+ System.arraycopy(byteEncodedSalt, 0, result, prefix.length
+ + params.length
+ + dollar.length, byteEncodedSalt.length);
+ System.arraycopy(dollar, 0, result, prefix.length
+ + params.length
+ + dollar.length
+ + byteEncodedSalt.length, dollar.length);
+ System.arraycopy(byteEncodedDerived, 0, result, prefix.length
+ + params.length
+ + dollar.length
+ + byteEncodedSalt.length
+ + dollar.length, byteEncodedDerived.length);
+ wipeArray(salt);
+ wipeArray(derived);
+ wipeArray(params);
+ wipeArray(byteEncodedSalt);
+ wipeArray(byteEncodedDerived);
+ return result;
} catch (GeneralSecurityException e) {
throw new IllegalStateException("JVM doesn't support SHA1PRNG or HMAC_SHA256?");
}
}
+ private static byte[] toBytes(char[] chars) {
+ CharBuffer charBuffer = CharBuffer.wrap(chars);
+ ByteBuffer byteBuffer = StandardCharsets.UTF_8.encode(charBuffer);
+ try {
+ return Arrays.copyOfRange(byteBuffer.array(), byteBuffer.position(), byteBuffer.limit());
+ } finally {
+ wipeArray(byteBuffer.array());
+ }
+ }
+
/**
* Compare the supplied plaintext password to a hashed password.
+ * This method maybe unsafe as it uses Strings, which are not guaranteed to be
+ * freed immediatelly.
*
* @param passwd Plaintext password.
* @param hashed scrypt hashed password.
@@ -70,22 +138,39 @@ public static String scrypt(String passwd, int N, int r, int p) {
* @return true if passwd matches hashed value.
*/
public static boolean check(String passwd, String hashed) {
+ return check(passwd.getBytes(StandardCharsets.UTF_8), hashed.getBytes(StandardCharsets.UTF_8));
+ }
+
+ /**
+ * Compare the supplied plaintext password to a hashed password.
+ * This call will aggressively clean up password data in memory.
+ *
+ * @param passwd Plaintext password encoded in UTF-8.
+ * @param hashed scrypt hashed password encoded in UTF-8.
+ *
+ * @return true if passwd matches hashed value.
+ */
+ public static boolean check(byte[] passwordBytes, byte[] hashed) {
try {
- String[] parts = hashed.split("\\$");
+ byte[][] parts = split(hashed, "$".getBytes(StandardCharsets.UTF_8));
- if (parts.length != 5 || !parts[1].equals("s0")) {
+ if (parts.length != 5 || !Arrays.equals(parts[1], "s0".getBytes(StandardCharsets.UTF_8))) {
throw new IllegalArgumentException("Invalid hashed value");
}
- long params = Long.parseLong(parts[2], 16);
- byte[] salt = decode(parts[3].toCharArray());
- byte[] derived0 = decode(parts[4].toCharArray());
+ long params = Long.parseLong(new String(parts[2], StandardCharsets.UTF_8), 16);
+ final char[] charEncodedSalt = toCharArray(parts[3]);
+ byte[] salt = decode(charEncodedSalt);
+ wipeArray(charEncodedSalt);
+ final char[] charEncodedDerived = toCharArray(parts[4]);
+ byte[] derived0 = decode(charEncodedDerived);
+ wipeArray(charEncodedDerived);
int N = (int) Math.pow(2, params >> 16 & 0xffff);
int r = (int) params >> 8 & 0xff;
int p = (int) params & 0xff;
- byte[] derived1 = SCrypt.scrypt(passwd.getBytes("UTF-8"), salt, N, r, p, 32);
+ byte[] derived1 = SCrypt.scrypt(passwordBytes, salt, N, r, p, 32);
if (derived0.length != derived1.length) return false;
@@ -93,14 +178,53 @@ public static boolean check(String passwd, String hashed) {
for (int i = 0; i < derived0.length; i++) {
result |= derived0[i] ^ derived1[i];
}
+ wipeArray(derived0);
+ wipeArray(derived1);
return result == 0;
- } catch (UnsupportedEncodingException e) {
- throw new IllegalStateException("JVM doesn't support UTF-8?");
} catch (GeneralSecurityException e) {
throw new IllegalStateException("JVM doesn't support SHA1PRNG or HMAC_SHA256?");
}
}
+ private static char[] toCharArray(byte[] bs) {
+ char[] result = new char[bs.length];
+ for(int i = 0; i < bs.length; i++) {
+ result[i] = (char) bs[i];
+ }
+ return result;
+ }
+
+ /**
+ * Splites a byte array into chunks delimited by separatorBytes, similar
+ * to {@link String#split}.
+ */
+ private static byte[][] split(byte[] array, byte[] separatorBytes) {
+ List result = new ArrayList();
+ int lastSplitIndex = 0;
+ for(int i = 0; i < array.length - separatorBytes.length; i++)
+ {
+ boolean found = false;
+ int j = 0;
+ while (j < separatorBytes.length && array[i + j] == separatorBytes[j])
+ j++;
+ found = (j == separatorBytes.length);
+ if(found) {
+ byte[] substring = new byte[i - lastSplitIndex];
+ System.arraycopy(array, lastSplitIndex, substring, 0, i - lastSplitIndex);
+ result.add(substring);
+ lastSplitIndex = i + j;
+ } else if(array.length == i + separatorBytes.length + 1) {
+ byte[] substring = new byte[i + separatorBytes.length + 1 - lastSplitIndex];
+ System.arraycopy(array, lastSplitIndex, substring, 0, i + separatorBytes.length + 1 - lastSplitIndex);
+ result.add(substring);
+ }
+ }
+ if(result.isEmpty()) {
+ result.add(array);
+ }
+ return result.toArray(new byte[0][]);
+ }
+
private static int log2(int n) {
int log = 0;
if ((n & 0xffff0000 ) != 0) { n >>>= 16; log = 16; }
@@ -109,4 +233,13 @@ private static int log2(int n) {
if (n >= 4 ) { n >>>= 2; log += 2; }
return log + (n >>> 1);
}
+
+ private static void wipeArray(byte[] array) {
+ Arrays.fill(array, (byte) 0);
+ }
+
+ private static void wipeArray(char[] array) {
+ Arrays.fill(array, (char) 0);
+ }
+
}