AwsProductSender.java

package gov.usgs.earthquake.aws;

import gov.usgs.earthquake.distribution.ConfigurationException;
import gov.usgs.earthquake.distribution.InvalidSignatureException;
import gov.usgs.earthquake.distribution.ProductSender;
import gov.usgs.earthquake.product.Content;
import gov.usgs.earthquake.product.Product;
import gov.usgs.earthquake.product.ProductId;
import gov.usgs.earthquake.product.ProductSignature;
import gov.usgs.earthquake.product.URLContent;
import gov.usgs.earthquake.product.io.JsonProduct;
import gov.usgs.util.Config;
import gov.usgs.util.CryptoUtils;
import gov.usgs.util.DefaultConfigurable;
import gov.usgs.util.FileUtils;
import gov.usgs.util.XmlUtils;
import gov.usgs.util.CryptoUtils.Version;

import java.io.File;
import java.io.StringReader;
import java.net.URL;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.PrivateKey;
import java.time.Duration;
import java.util.Date;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.json.Json;
import javax.json.JsonObject;

/**
 * Send products using the AWS Hub API.
 *
 * <p>Sending a product involves up to three HTTP steps: request signed upload urls for the
 * product contents, upload those contents in parallel, then post the product json to the hub.
 * Failures on each step that look transient (server errors) are retried once after a random
 * 1-5 second back off.
 */
public class AwsProductSender extends DefaultConfigurable implements ProductSender {

  /** Logger shared by all instances of this sender. */
  public static final Logger LOGGER = Logger.getLogger(AwsProductSender.class.getName());

  /** Property key for connect timeout property */
  public static final String CONNECT_TIMEOUT_PROPERTY = "connectTimeout";
  /** Base URL for Hub API. */
  public static final String HUB_URL_PROPERTY = "url";
  /** Max JSON notification size for sending products */
  public static final String MAX_PAYLOAD_BYTES_PROPERTY = "maxPayloadBytes";
  /** Private Key to sign products, if signProducts is true. */
  public static final String PRIVATE_KEY_PROPERTY = "privateKey";
  /** Property key for read timeout property */
  public static final String READ_TIMEOUT_PROPERTY = "readTimeout";
  /** Whether to sign products using private key. */
  public static final String SIGN_PRODUCTS_PROPERTY = "signProducts";
  /** Resolver class to use to get a secret for the private key */
  public static final String SECRET_RESOLVER_PROPERTY = "secretResolver";
  /** Secret key to retrieve */
  public static final String PRIVATE_KEY_SECRET_NAME_PROPERTY = "secretName";

  /** Default value to use for connect timeout property if not configured */
  public static final int DEFAULT_CONNECT_TIMEOUT = 5000;
  /** Default value to use for read timeout property if not configured */
  public static final int DEFAULT_READ_TIMEOUT = 30000;

  /**
   * Default, and largest allowed, JSON payload size in bytes.
   *
   * <p>API Gateway `PostToConnection` endpoint has a payload maximum of 128 KB.
   */
  public static final long DEFAULT_MAX_PAYLOAD_BYTES = 128 * 1024;

  /** client used for http requests */
  protected HttpClient httpClient;
  /** url where products are sent */
  protected URL hubUrl;
  /** signing key */
  protected PrivateKey privateKey;
  /** whether to sign products */
  protected boolean signProducts = false;

  /** Connection timeout in milliseconds. 5s seems excessive, but be cautious for now */
  protected int connectTimeout = AwsProductSender.DEFAULT_CONNECT_TIMEOUT;
  /** Per-request response timeout in milliseconds, applied to every HTTP request sent. */
  protected int readTimeout = AwsProductSender.DEFAULT_READ_TIMEOUT;

  /** Maximum serialized product json size, in bytes, checked before sending. */
  private long maxPayloadBytes = DEFAULT_MAX_PAYLOAD_BYTES;

  /** Empty class constructor; builds an HTTP client using the default connect timeout. */
  public AwsProductSender() {
    this.httpClient = buildHttpClient(this.connectTimeout);
  }

  /**
   * Construct a sender for a specific hub.
   *
   * @param url URL for sender hub
   */
  public AwsProductSender(URL url) {
    this();
    this.hubUrl = url;
  }

  /**
   * Build the HTTP client used for hub and content upload requests.
   *
   * @param connectTimeoutMillis connection timeout in milliseconds, must be positive
   * @return client that follows normal redirects and requests HTTP/2
   */
  private static HttpClient buildHttpClient(final int connectTimeoutMillis) {
    return HttpClient.newBuilder()
        .connectTimeout(Duration.ofMillis(connectTimeoutMillis))
        .followRedirects(HttpClient.Redirect.NORMAL)
        .version(HttpClient.Version.HTTP_2)
        .build();
  }

  /**
   * Configure this sender.
   *
   * <p>Reads the hub url (required), signing options, connect/read timeouts and payload limit,
   * then rebuilds the HTTP client so the new connect timeout takes effect.
   *
   * @param config configuration section for this sender
   * @throws Exception if a required property is missing or a value is invalid
   */
  @Override
  public void configure(Config config) throws Exception {
    super.configure(config);

    final String url = config.getProperty(HUB_URL_PROPERTY);
    if (url == null) {
      throw new ConfigurationException(String.format("[%s] %s is required", getName(), HUB_URL_PROPERTY));
    }
    hubUrl = new URL(url);
    LOGGER.config("[" + getName() + "] url=" + hubUrl.toString());

    final String sign = config.getProperty(SIGN_PRODUCTS_PROPERTY);
    if (sign != null) {
      signProducts = Boolean.parseBoolean(sign);
    }
    LOGGER.config("[" + getName() + "] sign products=" + signProducts);

    if (signProducts) {
      configurePrivateKey(config);
    }

    this.connectTimeout = parseTimeoutProperty(config, CONNECT_TIMEOUT_PROPERTY, this.connectTimeout);
    this.readTimeout = parseTimeoutProperty(config, READ_TIMEOUT_PROPERTY, this.readTimeout);

    final String maxPayloadConfiguration = config.getProperty(MAX_PAYLOAD_BYTES_PROPERTY);
    if (maxPayloadConfiguration != null) {
      final long maxPayload = Long.parseLong(maxPayloadConfiguration.trim());
      if (maxPayload <= 0 || maxPayload > DEFAULT_MAX_PAYLOAD_BYTES) {
        throw new ConfigurationException("[" + getName() + "] " + MAX_PAYLOAD_BYTES_PROPERTY
            + " must be between 1 and " + DEFAULT_MAX_PAYLOAD_BYTES + " bytes, got " + maxPayload);
      }
      this.maxPayloadBytes = maxPayload;
    }

    // Do this last so all newly configured values may be used
    this.httpClient = buildHttpClient(this.connectTimeout);
  }

  /**
   * Load the signing key.
   *
   * <p>If both a private key file AND a secret name are given, the private key file takes
   * precedence over the SecretResolver.
   *
   * @param config configuration section for this sender
   * @throws Exception if no private key can be configured
   */
  private void configurePrivateKey(final Config config) throws Exception {
    final String key = config.getProperty(PRIVATE_KEY_PROPERTY);
    final String secretName = config.getProperty(PRIVATE_KEY_SECRET_NAME_PROPERTY);

    if (Objects.nonNull(key)) {
      privateKey = CryptoUtils.readOpenSSHPrivateKey(FileUtils.readFile(new File(key)), null);
      LOGGER.config(() -> String.format("[%s] privateKey=%s", getName(), key));
    } else if (Objects.nonNull(secretName)) {
      LOGGER.config(() -> String.format("[%s] private key secretName=%s", getName(), secretName));

      final String secretResolverConfigSection = config.getProperty(SECRET_RESOLVER_PROPERTY);
      if (Objects.isNull(secretResolverConfigSection)) {
        throw new ConfigurationException(String.format("[%s] requires a %s when %s is given", getName(),
            SECRET_RESOLVER_PROPERTY, PRIVATE_KEY_SECRET_NAME_PROPERTY));
      }

      // the resolver object is looked up in the global config, not the section passed in
      final SecretResolver secretResolver =
          (SecretResolver) Config.getConfig().getObject(secretResolverConfigSection);
      if (Objects.isNull(secretResolver)) {
        throw new ConfigurationException(String.format("[%s] %s '%s' could not be loaded", getName(),
            SECRET_RESOLVER_PROPERTY, secretResolverConfigSection));
      }

      final String secretValue = secretResolver.getPlaintextSecret(secretName);
      // a missing secret leaves privateKey unset and is reported by the check below
      if (Objects.nonNull(secretValue)) {
        privateKey = CryptoUtils.readOpenSSHPrivateKey(secretValue.getBytes(StandardCharsets.UTF_8), null);
      }
    }

    if (Objects.isNull(privateKey)) {
      // no key successfully configured
      throw new ConfigurationException(
          String.format("[%s] %s requires a private key configured for signing", getName(), SIGN_PRODUCTS_PROPERTY));
    }
  }

  /**
   * Read an optional timeout property.
   *
   * @param config   configuration section for this sender
   * @param property property name to read
   * @param current  value to keep when the property is not set
   * @return configured timeout in milliseconds, or {@code current} when not set
   * @throws ConfigurationException if the configured timeout is not positive
   */
  private int parseTimeoutProperty(final Config config, final String property, final int current)
      throws ConfigurationException {
    final String value = config.getProperty(property);
    if (value == null) {
      return current;
    }
    final int timeout = Integer.parseInt(value.trim());
    // HttpClient/HttpRequest builders reject zero and negative durations; fail at configure time
    if (timeout <= 0) {
      throw new ConfigurationException(String.format(
          "[%s] %s must be a positive number of milliseconds, got %d", getName(), property, timeout));
    }
    LOGGER.config(() -> String.format("[%s] %s = %d", getName(), property, timeout));
    return timeout;
  }

  /**
   * Check that a serialized product fits within the configured payload limit.
   *
   * <p>Size is measured in UTF-8 bytes, the encoding used by
   * {@code HttpRequest.BodyPublishers.ofString} when the json is posted.
   *
   * @param json product in json format
   * @throws Exception MaxPayloadExceeded when the encoded size exceeds the configured maximum
   */
  public void checkProductJsonSize(final JsonObject json) throws Exception {
    final long jsonSize = json.toString().getBytes(StandardCharsets.UTF_8).length;
    if (jsonSize > this.maxPayloadBytes) {
      throw new MaxPayloadExceeded(
          "Maximum payload (" + this.maxPayloadBytes + " bytes) exceeded. Actual size: " + jsonSize + " bytes.");
    }
  }

  /**
   * Send a product to the hub.
   *
   * <p>Re-signs the product first when signing is enabled. Contents other than a single inline
   * ("") content are uploaded to signed urls before the product json is sent. A hub response
   * meaning the product was already sent is logged and treated as success.
   *
   * @param product product to send
   * @throws Exception if signing, uploading or sending fails
   */
  @Override
  public void sendProduct(final Product product) throws Exception {
    final ProductId id = product.getId();

    // re-sign if configured
    if (signProducts) {
      resignProduct(product);
    }

    // convert to json
    final String urn = product.getId().toString();
    final JsonObject json = new JsonProduct().getJsonObject(product);

    final long start = System.currentTimeMillis();
    final long afterUploadContent;
    try {
      this.checkProductJsonSize(json);
      if (hasUploadableContents(product)) {
        LOGGER.fine(() -> "Getting upload urls for " + json.toString());
        // response is product with signed content urls for upload
        Product uploadProduct;
        try {
          uploadProduct = getUploadUrls(json, urn);
        } catch (HttpException e) {
          // retry once on server error, otherwise propagate exception as usual
          if (e.getResponse().statusCode() < 500) {
            throw e;
          }
          LOGGER.log(Level.FINE, "[" + getName() + "] get upload urls exception, trying again", e);
          sleepBeforeRetry();
          uploadProduct = getUploadUrls(json, urn);
        }

        final long afterGetUploadUrls = System.currentTimeMillis();
        LOGGER.fine(
            "[" + getName() + "] get upload urls " + id.toString() + " (" + (afterGetUploadUrls - start) + " ms) ");

        try {
          uploadContents(product, uploadProduct);
        } catch (HttpException hex) {
          // S3 throttling ("503 Slow Down") arrives with an XML error body (Code SlowDown),
          // so match on the status alone; re-uploading content is safe to repeat
          if (hex.getResponse().statusCode() != 503) {
            throw hex;
          }
          LOGGER.fine(() -> String.format("[%s] 503 slow down exception, trying again", this.getName()));
          sleepBeforeRetry();
          uploadContents(product, uploadProduct);
        }

        afterUploadContent = System.currentTimeMillis();
        LOGGER.fine("[" + getName() + "] upload contents " + id.toString() + " ("
            + (afterUploadContent - afterGetUploadUrls) + " ms) ");
      } else {
        afterUploadContent = System.currentTimeMillis();
      }

      try {
        sendProduct(json, urn);
      } catch (HttpException hex) {
        // retry once on server error, otherwise propagate exception as usual
        if (hex.getResponse().statusCode() < 500) {
          throw hex;
        }
        LOGGER.log(Level.FINE, "[" + getName() + "] send product exception, trying again", hex);
        sleepBeforeRetry();
        sendProduct(json, urn);
      }

      final long afterSendProduct = System.currentTimeMillis();
      LOGGER.fine("[" + getName() + "] send product " + id.toString() + " (" + (afterSendProduct - afterUploadContent)
          + " ms) ");
    } catch (ProductAlreadySentException pase) {
      // hub already has product
      LOGGER.info("[" + getName() + "] hub already has product");
    } catch (Exception e) {
      LOGGER.log(Level.WARNING, "Exception sending product " + id.toString(), e);
      throw e;
    } finally {
      final long end = System.currentTimeMillis();
      LOGGER.info("[" + getName() + "] send product total " + id.toString() + " (" + (end - start) + " ms) ");
    }
  }

  /**
   * Re-sign a product with this sender's private key and record its signature history.
   *
   * @param product product to re-sign, modified in place
   * @throws Exception if signing fails
   */
  private void resignProduct(final Product product) throws Exception {
    if (Objects.nonNull(product.getSignature())) {
      // add original signature to history
      final ProductSignature originalSignature = new ProductSignature(
          product.getProperties().get("original-signature"),
          Version.fromString(product.getProperties().get("original-signature-version")));
      product.addToSignatureHistory(originalSignature);
    }
    // add current signature to history
    product.addToSignatureHistory(new ProductSignature(product.getSignature(), product.getSignatureVersion()));

    product.sign(privateKey, CryptoUtils.Version.SIGNATURE_V2);
    LOGGER.fine(() -> String.format("[%s] Resigned product from original-signature %s to new signature %s",
        getName(),
        product.getProperties().get("original-signature"),
        product.getSignature()));
    // add new signature to history
    product.addToSignatureHistory(new ProductSignature(product.getSignature(), product.getSignatureVersion()));
  }

  /**
   * Whether a product has contents that must be uploaded before sending.
   *
   * @param product product to inspect
   * @return true when it has contents, and they are not only a single inline ("") content
   */
  private static boolean hasUploadableContents(final Product product) {
    return product.getContents().size() > 0
        && !(product.getContents().size() == 1 && product.getContents().get("") != null);
  }

  /**
   * Sleep for a random back off of 1-5 seconds before retrying a request.
   *
   * @throws InterruptedException if interrupted while sleeping
   */
  private static void sleepBeforeRetry() throws InterruptedException {
    Thread.sleep(1000 + Math.round(4000 * Math.random()));
  }

  /**
   * Get content upload urls.
   *
   * @param json product in json format.
   * @param urn  product id, used in the request path.
   * @return product with content urls set to upload URLs.
   * @throws Exception InvalidSignatureException on 401, ProductAlreadySentException on 409,
   *                   HttpException on any other non-200 response
   */
  protected Product getUploadUrls(final JsonObject json, final String urn) throws Exception {
    final URL url = new URL(hubUrl, String.format("products/%s/uploads", urn));
    final HttpResponse<String> response = postProductJson(url, json);
    final int responseCode = response.statusCode();

    // check for errors
    if (responseCode == 401) {
      throw new InvalidSignatureException("Invalid product signature");
    } else if (responseCode == 409) {
      throw new ProductAlreadySentException();
    } else if (responseCode != 200) {
      throw new HttpException(response, String.format("Error [%d] getting upload urls", responseCode));
    }

    // successful response body is parsed as a product whose content urls are upload urls
    final JsonObject getUploadUrlsResponse = Json.createReader(new StringReader(response.body())).readObject();
    return new JsonProduct().getProduct(getUploadUrlsResponse);
  }

  /**
   * Post product json to a hub url.
   *
   * This is a HTTP POST method, with a JSON content body with a "product"
   * property with the product.
   *
   * @param url     url of connection
   * @param product product in json format
   * @return new HTTP POST response
   * @throws Exception Exception
   */
  protected HttpResponse<String> postProductJson(final URL url, final JsonObject product) throws Exception {
    final HttpRequest request = HttpRequest.newBuilder()
        .POST(HttpRequest.BodyPublishers.ofString(product.toString()))
        .header("Content-Type", "application/json")
        .timeout(Duration.ofMillis(this.readTimeout))
        .uri(url.toURI())
        .build();
    return this.httpClient.send(request, HttpResponse.BodyHandlers.ofString());
  }

  /**
   * Send product after content has been uploaded.
   *
   * @param json product in json format.
   * @param urn  product id, used in the request path.
   * @return product with content urls pointing to hub.
   * @throws Exception InvalidSignatureException on 401, ProductAlreadySentException on 409,
   *                   MaxPayloadExceeded on 413, HttpException on any other non-200 response
   */
  protected Product sendProduct(final JsonObject json, String urn) throws Exception {
    // relative path, like getUploadUrls, so any path prefix in hubUrl is kept
    final URL url = new URL(hubUrl, String.format("products/%s", urn));
    final HttpResponse<String> response = postProductJson(url, json);
    final int responseCode = response.statusCode();

    // check for errors
    if (responseCode == 401) {
      throw new InvalidSignatureException("Invalid product signature");
    } else if (responseCode == 409) {
      throw new ProductAlreadySentException();
    } else if (responseCode == 413) {
      throw new MaxPayloadExceeded("Maximum payload (" + DEFAULT_MAX_PAYLOAD_BYTES + " bytes) exceeded.");
    } else if (responseCode == 422) {
      throw new HttpException(response, "Content validation errors: " + response.body());
    } else if (responseCode != 200) {
      throw new HttpException(response, String.format("Error [%d] sending product", responseCode));
    }

    // successful response is json object with "notification" property
    // that has "created" and "product" properties with hub urls for contents.
    final JsonObject sendProductResponse = Json.createReader(new StringReader(response.body())).readObject();
    final JsonObject notification = sendProductResponse.getJsonObject("notification");
    final Product product = new JsonProduct().getProduct(notification.getJsonObject("product"));
    // optional "notification_id" of the broadcast; null when absent, json null, or not a string
    // (JsonObject.isNull would throw NullPointerException when the key is missing)
    final String notificationId = sendProductResponse.getString("notification_id", null);
    LOGGER.fine("[" + getName() + "] notification id " + notificationId + " " + product.getId().toString());
    return product;
  }

  /**
   * Upload content to a signed url.
   *
   * @param path      content path.
   * @param content   content to upload.
   * @param signedUrl url where content should be uploaded.
   * @return HTTP result
   * @throws Exception HttpException on a non-200 response, or any send error
   */
  protected HttpResponse<String> uploadContent(final String path, final Content content, final URL signedUrl)
      throws Exception {
    final long start = System.currentTimeMillis();

    final HttpRequest request = HttpRequest.newBuilder()
        .uri(signedUrl.toURI())
        .timeout(Duration.ofMillis(this.readTimeout))
        .header("Content-Type", content.getContentType())
        .header("Content-Encoding", "aws-chunked")
        // these values are part of signed url and are required
        .header("x-amz-meta-modified", XmlUtils.formatDate(content.getLastModified()))
        .header("x-amz-meta-sha256", content.getSha256())
        .PUT(new ContentPublisher(content))
        .build();

    // send content
    final HttpResponse<String> response = this.httpClient.send(request, HttpResponse.BodyHandlers.ofString());

    final long elapsed = System.currentTimeMillis() - start;
    if (response.statusCode() != 200) {
      throw new HttpException(response,
          String.format("Error [%d] uploading content %s (%d ms)", response.statusCode(), path, elapsed));
    }
    LOGGER.finer(() -> String.format("[%s]  uploaded content %s (size= %d bytes) (time= %d ms)", this.getName(), path,
        content.getLength(), elapsed));
    return response;
  }

  /**
   * Upload product contents.
   *
   * Runs uploads in parallel using a parallel stream.
   *
   * This can be called within a custom ForkJoinPool to use a non-default pool,
   * the default pool is shared by the process and based on number of available
   * cores.
   *
   * @param product       product to upload.
   * @param uploadProduct product with signed upload urls.
   * @return upload results, keyed by content path
   * @throws Exception if any upload errors occur; every failure is logged and the
   *                   last one iterated is thrown
   */
  protected Map<String, HttpResponse<String>> uploadContents(final Product product, final Product uploadProduct)
      throws Exception {
    // collect results
    final ConcurrentHashMap<String, HttpResponse<String>> uploadResults = new ConcurrentHashMap<>();
    final ConcurrentHashMap<String, Exception> uploadExceptions = new ConcurrentHashMap<>();
    // upload contents in parallel; inline ("") content is sent with the product json
    uploadProduct.getContents().keySet().parallelStream().filter(path -> !"".equals(path)).forEach(path -> {
      try {
        Content uploadContent = uploadProduct.getContents().get(path);
        if (!(uploadContent instanceof URLContent)) {
          throw new IllegalStateException("Expected URLContent for " + product.getId().toString() + " path '" + path
              + "' but got " + uploadContent);
        }
        uploadResults.put(path,
            uploadContent(path, product.getContents().get(path), ((URLContent) uploadContent).getURL()));
      } catch (Exception e) {
        uploadExceptions.put(path, e);
      }
    });
    if (!uploadExceptions.isEmpty()) {
      Exception last = null;
      // log all
      for (final Map.Entry<String, Exception> entry : uploadExceptions.entrySet()) {
        last = entry.getValue();
        LOGGER.log(Level.WARNING, "Exception uploading content " + entry.getKey(), last);
      }
      // throw last
      throw last;
    }
    return uploadResults;
  }

  /** Exception thrown when a product json payload is larger than the allowed maximum. */
  public static class MaxPayloadExceeded extends Exception {

    /**
     * Create a payload size exception.
     *
     * @param message description including the limit and, when known, the actual size
     */
    public MaxPayloadExceeded(String message) {
      super(message);
    }
  }

  /**
   * Getter for signProducts
   *
   * @return whether products are re-signed before sending
   */
  public boolean getSignProducts() {
    return signProducts;
  }

  /**
   * Setter for signProducts
   *
   * @param sign whether products should be re-signed before sending
   */
  public void setSignProducts(final boolean sign) {
    this.signProducts = sign;
  }

  /**
   * Getter for privateKey
   *
   * @return key used to sign products, may be null
   */
  public PrivateKey getPrivateKey() {
    return privateKey;
  }

  /**
   * Setter for privateKey
   *
   * @param key key used to sign products
   */
  public void setPrivateKey(final PrivateKey key) {
    this.privateKey = key;
  }

}