import static javax.crypto.Cipher.DECRYPT_MODE;
import static javax.crypto.Cipher.ENCRYPT_MODE;
import static javax.crypto.Cipher.SECRET_KEY;
import static javax.crypto.Cipher.UNWRAP_MODE;
import static javax.crypto.Cipher.WRAP_MODE;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Base64;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.KeyGenerator;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import javax.xml.XMLConstants;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerException;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;

import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

/**
 * Command-line tool that encrypts or decrypts, in place, the text content of every XML element
 * with a given tag name.
 *
 * <p>For each element, a fresh random AES data key encrypts the text with
 * {@code AES/CBC/ISO10126Padding}. That data key is wrapped with the fixed user key using
 * {@code AESWRAP}. The element text is then replaced by the Base64 encoding of
 * {@code wrappedDataKey || iv || ciphertext}.
 *
 * <p>NOTE(review): the user key is hard-coded in {@link #BASE64_KEY}; anyone holding this source or
 * the compiled class can decrypt. Consider supplying it externally (environment, key store).
 *
 * <p>NOTE(review): CBC gives confidentiality only. The data ciphertext carries no MAC, so tampering
 * with it is not detected (only the wrapped key is integrity-checked by AESWRAP). Fixing this would
 * change the stored format.
 */
public final class Encrypt {
    private static final int BITS_PER_BYTE = 8;
    /** Extra bytes AES key wrap adds to the wrapped key (its 64-bit integrity block). */
    private static final int PADDING_BYTES = 8;
    private static final String BASE64_KEY = "FkWidY82O35T9j9djc1/MH0OtbcoCe+ryBham14WWlU=";
    private static final String DATA_ENCRYPTION = "AES/CBC/ISO10126Padding";
    private static final String DATA_KEY_TYPE = "AES";
    private static final String KEY_ENCRYPTION = "AESWRAP";
    private static final String USER_KEY_TYPE = "AES";
    private static final int DATA_KEY_SIZE = 256;

    /** Not instantiable; all functionality is static. */
    private Encrypt() {
    }

    /**
     * Entry point: {@code encrypt [-e|-d] <filename> <element>}.
     *
     * <p>Encrypts when no mode flag is given. Exits with status 1 on a usage error or when
     * processing the file fails.
     *
     * @param args command-line arguments
     */
    public static void main(final String[] args) {
        boolean encrypt = true;
        int i = 0;

        // An optional leading flag selects the mode; any other dash-prefixed argument is an error.
        if (i < args.length && args[i].startsWith("-")) {
            final String mode = args[i++];
            if (mode.equals("-e")) {
                encrypt = true;
            } else if (mode.equals("-d")) {
                encrypt = false;
            } else {
                usage();
            }
        }

        if ((args.length - i) != 2) {
            usage();
        }

        final String filename = args[i++];
        final String element = args[i];

        try {
            process(encrypt, BASE64_KEY, filename, element);
        } catch (final Exception e) {
            // Report the failure and signal it through the exit status so callers/scripts notice.
            e.printStackTrace();
            System.exit(1);
        }
    }

    /** Prints usage to standard error and terminates the JVM with status 1; never returns. */
    private static void usage() {
        System.err.println("encrypt [-e|-d] <filename> <element>");
        System.err.println("Example: encrypt -d sample.xml wst:TokenType");
        System.exit(1);
    }

    /**
     * Parses {@code filename}, replaces the text content of every element whose tag name equals
     * {@code element} with its encrypted or decrypted form, and rewrites the file only if at least
     * one matching element was found.
     *
     * @param encrypt {@code true} to encrypt, {@code false} to decrypt
     * @param base64Key Base64-encoded user (key-wrapping) key
     * @param filename path of the XML file to rewrite in place
     * @param element tag name to match, as passed to {@link Document#getElementsByTagName(String)}
     * @throws Exception on any parse, crypto, or write failure
     */
    private static void process(final boolean encrypt, final String base64Key,
        final String filename, final String element) throws Exception {
        final byte[] key = Base64.getDecoder().decode(base64Key);
        final File file = new File(filename);

        final DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
        // Harden against XXE: never resolve external entities or fetch external DTDs.
        factory.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true);
        factory.setFeature("http://xml.org/sax/features/external-general-entities", false);
        factory.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
        factory.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false);
        factory.setXIncludeAware(false);

        final DocumentBuilder builder = factory.newDocumentBuilder();
        final Document document = builder.parse(file);
        final NodeList nodeList = document.getElementsByTagName(element);
        final int len = nodeList.getLength();

        for (int i = 0; i < len; i++) {
            final Node node = nodeList.item(i);
            final String text = node.getTextContent();
            node.setTextContent(encrypt ? encrypt(text, key) : decrypt(text, key));
        }

        if (len > 0) {
            write(document, file);
        }
    }

    /**
     * Serializes {@code document} to {@code file} using an identity {@link Transformer} with
     * default output properties.
     */
    private static void write(final Document document, final File file)
        throws TransformerException {
        final Transformer transformer = TransformerFactory.newInstance().newTransformer();
        transformer.transform(new DOMSource(document), new StreamResult(file));
    }

    /**
     * Encrypts {@code text} under a freshly generated AES data key and wraps that key with the user
     * key.
     *
     * @param text plaintext; encoded as UTF-8 before encryption
     * @param key raw user key bytes used to wrap the data key
     * @return Base64 (basic, no line breaks) of {@code wrappedDataKey || iv || ciphertext}
     */
    private static String encrypt(final String text, final byte[] key)
        throws NoSuchAlgorithmException, NoSuchPaddingException, InvalidKeyException,
        IllegalBlockSizeException, BadPaddingException, IOException {
        final byte[] content = text.getBytes(StandardCharsets.UTF_8);
        final Key userKey = new SecretKeySpec(key, USER_KEY_TYPE);

        final KeyGenerator generator = KeyGenerator.getInstance(DATA_KEY_TYPE);
        generator.init(DATA_KEY_SIZE, new SecureRandom());
        final Key dataKey = generator.generateKey();

        final Cipher keyCipher = Cipher.getInstance(KEY_ENCRYPTION);
        keyCipher.init(WRAP_MODE, userKey);
        final byte[] encryptedDataKey = keyCipher.wrap(dataKey);

        // No IV is supplied, so the provider generates a random one; it is read back via getIV().
        final Cipher dataCipher = Cipher.getInstance(DATA_ENCRYPTION);
        dataCipher.init(ENCRYPT_MODE, dataKey);
        final byte[] encryptedData = dataCipher.doFinal(content);

        final ByteArrayOutputStream encrypted = new ByteArrayOutputStream();
        encrypted.write(encryptedDataKey);
        encrypted.write(dataCipher.getIV());
        encrypted.write(encryptedData);
        return Base64.getEncoder().encodeToString(encrypted.toByteArray());
    }

    /**
     * Decrypts a value produced by {@link #encrypt(String, byte[])}.
     *
     * @param text Base64 of {@code wrappedDataKey || iv || ciphertext}; decoded with the MIME
     *     decoder, which ignores characters outside the Base64 alphabet (e.g. line breaks)
     * @param key raw user key bytes used to unwrap the data key
     * @return the decrypted plaintext, decoded as UTF-8
     * @throws IllegalArgumentException if the decoded value is too short to contain a wrapped key,
     *     an IV and at least one cipher block
     */
    private static String decrypt(final String text, final byte[] key)
        throws NoSuchAlgorithmException, NoSuchPaddingException, InvalidKeyException,
        IllegalBlockSizeException, BadPaddingException, InvalidAlgorithmParameterException {
        final byte[] encrypted = Base64.getMimeDecoder().decode(text);
        final Cipher keyCipher = Cipher.getInstance(KEY_ENCRYPTION);
        final Cipher dataCipher = Cipher.getInstance(DATA_ENCRYPTION);
        final int blockSize = dataCipher.getBlockSize();
        final int dataKeyLen = (DATA_KEY_SIZE / BITS_PER_BYTE) + PADDING_BYTES;
        final int dataOffset = dataKeyLen + blockSize;
        final int minLength = dataOffset + blockSize;

        // Reject truncated input with a clear message instead of an index-out-of-bounds failure.
        if (encrypted.length < minLength) {
            throw new IllegalArgumentException("Encrypted value is " + encrypted.length
                + " bytes; expected at least " + minLength);
        }

        final byte[] encryptedDataKey = new byte[dataKeyLen];
        System.arraycopy(encrypted, 0, encryptedDataKey, 0, dataKeyLen);

        keyCipher.init(UNWRAP_MODE, new SecretKeySpec(key, USER_KEY_TYPE));
        final Key dataKey = keyCipher.unwrap(encryptedDataKey, DATA_KEY_TYPE, SECRET_KEY);

        dataCipher.init(DECRYPT_MODE, dataKey,
            new IvParameterSpec(encrypted, dataKeyLen, blockSize));
        final byte[] content =
            dataCipher.doFinal(encrypted, dataOffset, encrypted.length - dataOffset);
        return new String(content, StandardCharsets.UTF_8);
    }
}
