WebSocketClient.java

package gov.usgs.earthquake.distribution;

import gov.usgs.util.Config;
import gov.usgs.util.DefaultConfigurable;

import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.concurrent.*;
import java.util.logging.Level;
import java.util.logging.Logger;

import javax.websocket.*;

/**
 * Manages a simple connection to a websocket. Can also be overridden for more
 * complex behavior.
 */
@ClientEndpoint
public class WebSocketClient extends DefaultConfigurable implements Runnable {

  /** Initialzation of logger. For us later in file. */
  public static final Logger LOGGER = Logger.getLogger(WebSocketClient.class.getName());

  private Session session;

  private URI endpoint;
  private WebSocketListener listener;
  private int attempts = WebSocketClient.DEFAULT_ATTEMPTS;
  private long timeoutMillis = WebSocketClient.DEFAULT_TIMEOUT_MILLIS;
  private boolean retryOnClose = WebSocketClient.DEFAULT_RETRY_ON_CLOSE;
  private ScheduledExecutorService scheduledExector;
  private long pingIntervalMillis = WebSocketClient.DEFAULT_PING_INTERVAL_MILLIS;
  private long pingWaitMillis = WebSocketClient.DEFAULT_PING_WAIT_MILLIS;
  private boolean pingSent = false;
  private boolean pongReceived = false;
  private long timePingSentMillis;
  private ScheduledFuture<?> pingTask;
  private String pingFailMessage = "";
  private ScheduledFuture<?> anyMessageTask;
  private long anyMessageIntervalMillis = WebSocketClient.DEFAULT_ANY_MESSAGE_INTERVAL_MILLIS;
  private boolean anyMessageReceived = false;

  // Default Attempts and Timeout have been set to values that essentially mean
  // the client will continue to retry "forever".
  // This is the most common way the client will probably be run, as a continual
  // process meant to be up at all times.
  /** Default number of attempts */
  public static final int DEFAULT_ATTEMPTS = 100000;
  /** Default timeout in ms */
  public static final long DEFAULT_TIMEOUT_MILLIS = 500;
  /** Default for trying to retry on close */
  public static final boolean DEFAULT_RETRY_ON_CLOSE = true;
  /** Default for time in between pings to server */
  public static final long DEFAULT_PING_INTERVAL_MILLIS = 15000;
  /**
   * Default for how long to wait for pong response to ping before closing or
   * restarting connection
   */
  public static final long DEFAULT_PING_WAIT_MILLIS = 4000;
  /** Default for how long to wait for any Message(excluding Pongs) */
  // A value of 0 disables this function
  public static final long DEFAULT_ANY_MESSAGE_INTERVAL_MILLIS = 0;

  /**
   * Property name to configure the pingWaitMillis value for this WebSocketClient
   */
  public static final String PING_WAIT_MILLIS_PROPERTY = "pingWaitMillis";

  /**
   * Property name to configure the anyMessageIntervalMillis value for this
   * WebSocketClient
   */
  public static final String ANY_MESSAGE_INTERVAL_MILLIS_PROPERTY = "anyMessageIntervalMillis";

  /**
   * Property name to configure the pingIntervalMillis value for this
   * WebSocketClient
   */
  public static final String PING_INTERVAL_MILLIS_PROPERTY = "pingIntervalMillis";

  /**
   * Property name to configure the connectRetries value for this WebSocketClient
   */
  public static final String CONNECT_RETRIES_PROPERTY = "connectRetries";

  /** Property name to configure the url value for this WebSocketClient */
  public static final String URI_PROPERTY = "url";

  /**
   * Property name to configure the connectTimeout value for this WebSocketClient
   */
  public static final String CONNECT_TIMEOUT_PROPERTY = "connectTimeout";

  /**
   * Property name to configure the retryOnClose value for this WebSocketClient
   */
  public static final String RETRY_ON_CLOSE_PROPERTY = "retryOnClose";

  /**
   * Default constructor required for configurable interface.
   * 
   * @throws Exception
   */
  public WebSocketClient() throws Exception {
    // Must call configure method to set all necessary attributes
    // If not configured properly, expect crazy exceptions at runtime
  }

  /**
   * Constructs the client. Also connects to the server.
   *
   * @param endpoint      the URI to connect to
   * @param listener      a WebSocketListener to handle incoming messages
   * @param attempts      an integer number of times to try the connection
   * @param timeoutMillis a long for the wait time between attempts
   * @param retryOnClose  boolean for if the connection should retry when closed
   * @throws Exception on thread interrupt or connection failure
   * @deprecated use
   *             {@link #WebSocketClient(URI, WebSocketListener, int, long, boolean, long, long, long)},
   *             includes ping configuration
   */
  @Deprecated
  public WebSocketClient(URI endpoint, WebSocketListener listener, int attempts, long timeoutMillis,
      boolean retryOnClose) throws Exception {
    this(endpoint, listener, attempts, timeoutMillis, retryOnClose, DEFAULT_PING_INTERVAL_MILLIS,
        DEFAULT_PING_WAIT_MILLIS, DEFAULT_ANY_MESSAGE_INTERVAL_MILLIS);
  }

  /**
   * Creates a Websocket Client Default values for attempts ant timeoutMillis
   * create an instance which is designed to be up and running at all times
   *
   * @param endpoint                 the URI to connect to
   * @param listener                 a WebSocketListener to handle incoming
   *                                 messages
   * @param attempts                 an integer number of times to try the
   *                                 connection
   * @param timeoutMillis            a long for the wait time between attempts
   * @param retryOnClose             boolean for if the connection should retry
   *                                 when closed
   * @param pingIntervalMillis       how often to send ping in milliseconds. If
   *                                 you don't want to send pings set to 0 or
   *                                 negative value.
   * @param pingWaitMillis           how long to wait, in milliseconds, before
   *                                 declaring socket down and closing and
   *                                 retrying if retryOnClose is set.
   * @param anyMessageIntervalMillis how often to check if any message has
   *                                 arrived(excluding Pongs) restart if one has
   *                                 not.
   * @throws Exception on thread interrupt or connection failure
   */
  public WebSocketClient(URI endpoint, WebSocketListener listener, int attempts, long timeoutMillis,
      boolean retryOnClose, long pingIntervalMillis, long pingWaitMillis, long anyMessageIntervalMillis)
      throws Exception {
    this.setName("websocket_client");

    this.pingWaitMillis = pingWaitMillis;
    LOGGER.config(() -> String.format("[%s] pingWaitMillis=%s", getName(), this.pingWaitMillis));

    this.setAnyMessageIntervalMillis(anyMessageIntervalMillis);
    LOGGER
        .config(() -> String.format("[%s] anyMessageIntervalMillis=%s", getName(), this.getAnyMessageIntervalMillis()));

    this.setPingIntervalMillis(pingIntervalMillis);
    LOGGER
        .config(() -> String.format("[%s] pingIntervalMillis=%s", getName(), this.getPingIntervalMillis()));

    this.listener = listener;

    this.endpoint = endpoint;
    LOGGER.config(() -> String.format("[%s] url=%s", getName(), endpoint));

    this.attempts = attempts;
    LOGGER.config(() -> String.format("[%s] attempts=%s", getName(), this.attempts));

    this.timeoutMillis = timeoutMillis;
    LOGGER.config(() -> String.format("[%s] timeoutMillis=%s", getName(), this.timeoutMillis));

    this.retryOnClose = retryOnClose;
    LOGGER.config(() -> String.format("[%s] retryOnClose=%s", getName(), this.retryOnClose));
  }

  /**
   * Constructs the client
   *
   * @param endpoint the URI to connect to
   * @param listener a WebSocketListener to handle incoming messages
   * @throws Exception thread interrupt or connection failure
   */
  public WebSocketClient(URI endpoint, WebSocketListener listener) throws Exception {
    this(endpoint, listener, DEFAULT_ATTEMPTS, DEFAULT_TIMEOUT_MILLIS, DEFAULT_RETRY_ON_CLOSE,
        DEFAULT_PING_INTERVAL_MILLIS, DEFAULT_PING_WAIT_MILLIS, DEFAULT_ANY_MESSAGE_INTERVAL_MILLIS);
  }

  private void anyMessageClose() {
    try {
      if (this.session.isOpen()) {
        session.close(new CloseReason(CloseReason.CloseCodes.NO_STATUS_CODE, "Any Message Failed!"));
      } else {
        onClose(this.session, new CloseReason(CloseReason.CloseCodes.NO_STATUS_CODE, "Any Message Failed!"));
      }
    } catch (Exception e) {
    }
  }

  /**
   * Called by Websocket interface when a Pong is received in response to a ping
   * that was sent.
   *
   * @param pongMessage Message that was populated from ping.
   * @param session     The websocket session.
   */
  @OnMessage
  public void catchPong(PongMessage pongMessage, Session session) {
    this.pongReceived = true;
  }

  /**
   * Connect to server
   *
   * @throws Exception if error occurs
   */
  public void connect() throws Exception {
    // try to connect to server
    WebSocketContainer container = ContainerProvider.getWebSocketContainer();
    int failedAttempts = 0;
    Exception lastExcept = null;
    for (int i = 0; i < attempts; i++) {
      try {
        int connectAttempt = i + 1;
        LOGGER
            .info(() -> String.format("[%s] attempt %s out of %s connecting to %s", getName(), connectAttempt, attempts,
                this.endpoint.toString()));
        container.connectToServer(this, endpoint);
        break;
      } catch (Exception e) {
        // increment failed attempts, sleep
        failedAttempts++;
        lastExcept = e;
        // Sleep longer and longer between attempts up to a max of one minute
        long sleepInterval = Math.min(60000, (failedAttempts * timeoutMillis));
        LOGGER.info(() -> String.format("[%s] failed to connect to %s, retrying in %s", getName(),
            this.endpoint.toString(), sleepInterval));
        Thread.sleep(sleepInterval);
      }
    }

    // throw connect exception if all attempts fail
    if (failedAttempts == attempts) {
      this.listener.onConnectFail();
      throw lastExcept;
    }
  }

  /**
   * If any message at all, which would include heartbeats, has been received
   * since the last anymessage interval this value will be true.
   *
   * @return
   */
  private boolean hasReceivedAnyMessage() {
    return this.anyMessageReceived;
  }

  /**
   * Checks if there is an open session
   *
   * @return boolean
   * @throws IOException if IO error occurs
   */
  public boolean isConnected() throws IOException {
    return this.session != null && this.session.isOpen();
  }

  /**
   * Sets the session and listener
   *
   * @param session Session
   * @throws IOException if IO error occurs
   */
  @OnOpen
  public void onOpen(Session session) throws IOException {
    this.session = session;
    if (pingIntervalMillis > 0) {
      pingTask = this.scheduledExector.schedule(this, pingIntervalMillis, TimeUnit.MILLISECONDS);
    }
    if (anyMessageIntervalMillis > 0) {
      startAnyMessageReceived();
    }
    this.listener.onOpen(session);
    this.anyMessageReceived = false;
  }

  /**
   * Closes the session on the listener, sets constructor session to null Check if
   * should be retryed
   *
   * @param session Session
   * @param reason  for close
   * @throws IOException if IO error occurs
   */
  @OnClose
  public void onClose(Session session, CloseReason reason) throws IOException {
    this.listener.onClose(session, reason);
    this.session = null;
    pingSent = false;
    pongReceived = false;
    if (pingTask != null) {
      pingTask.cancel(true);
    }
    if (anyMessageTask != null) {
      anyMessageTask.cancel(true);
    }
    if (retryOnClose) {
      try {
        this.connect();
      } catch (Exception e) {
        // failed to reconnect
        this.listener.onReconnectFail();
        // propagate this failure
        throw new IOException(e);
      }
    }
  }

  /**
   * Gives listener the message
   *
   * @param message String
   * @throws IOException if IO error occurs
   */
  @OnMessage
  public void onMessage(String message) throws IOException {
    this.anyMessageReceived = true;
    if (this.anyMessageIntervalMillis > 0) {
      startAnyMessageReceived();
    }
    this.listener.onMessage(message);
  }

  /**
   * If a ping has failed this method will close the session in the proper manner.
   *
   * @param session
   * @param pingFailMessage
   */
  private void pingFailClose(Session session, String pingFailMessage) {
    if (pingFailMessage.isBlank()) {
      pingFailMessage = "Pong not received!";
    }
    // String for CloseReason has to be lest than 123 bytes.
    if (pingFailMessage.length() > 90) {
      pingFailMessage = pingFailMessage.substring(0, 90);
    }
    try {
      if (session.isOpen()) {
        session.close(new CloseReason(CloseReason.CloseCodes.NO_STATUS_CODE, "Ping/Pong Failed! " + pingFailMessage));
      } else {
        onClose(session,
            new CloseReason(CloseReason.CloseCodes.NO_STATUS_CODE, "Ping/Pong Failed! " + pingFailMessage));
      }
    } catch (Exception e) {
    }
    pingFailMessage = "";
  }

  @Override
  public void run() {
    if (!pingSent) {
      // Add timestamp to ping message that could be evaluated when pong is received
      this.timePingSentMillis = System.currentTimeMillis();
      byte[] data = ("" + this.timePingSentMillis).getBytes();
      try {
        pongReceived = false;
        // Set pingSent=true here because the pingTask will always be started and the
        // failure
        // will occur when the pong has not been received even if the ping fails to
        // send.
        pingSent = true;
        this.session.getAsyncRemote().sendPing(ByteBuffer.wrap(data));

      } catch (Exception e) {
        // This message is used to indicate that the failure was actually in the sending
        // of the ping.
        pingFailMessage = "Sending Ping failed.  E=" + e.getLocalizedMessage();
      }

      pingTask = this.scheduledExector.schedule(this, pingWaitMillis, TimeUnit.MILLISECONDS);
    } else {
      if (!pongReceived) {
        pingFailClose(session, pingFailMessage);
      } else {
        pingSent = false;
        pongReceived = false;
        pingTask = this.scheduledExector.schedule(this, pingIntervalMillis, TimeUnit.MILLISECONDS);
      }
    }
  }

  @Override
  public void configure(Config config) throws Exception {
    this.pingWaitMillis = Long.parseLong(
        config.getProperty(WebSocketClient.PING_WAIT_MILLIS_PROPERTY,
            String.valueOf(WebSocketClient.DEFAULT_PING_WAIT_MILLIS)));
    LOGGER.config(() -> String.format("[%s] pingWaitMillis=%s", getName(), this.pingWaitMillis));

    long anyMessageIntervalMillis = Long.parseLong(
        config.getProperty(WebSocketClient.ANY_MESSAGE_INTERVAL_MILLIS_PROPERTY,
            String.valueOf(WebSocketClient.DEFAULT_ANY_MESSAGE_INTERVAL_MILLIS)));
    this.setAnyMessageIntervalMillis(anyMessageIntervalMillis);

    LOGGER
        .config(() -> String.format("[%s] anyMessageIntervalMillis=%s", getName(), this.getAnyMessageIntervalMillis()));

    long pingIntervalMillis = Long.parseLong(
        config.getProperty(WebSocketClient.PING_INTERVAL_MILLIS_PROPERTY,
            String.valueOf(WebSocketClient.DEFAULT_PING_INTERVAL_MILLIS)));
    this.setPingIntervalMillis(pingIntervalMillis);
    LOGGER
        .config(() -> String.format("[%s] pingIntervalMillis=%s", getName(), this.getPingIntervalMillis()));

    String uri = config.getProperty(WebSocketClient.URI_PROPERTY);
    if (Objects.isNull(uri)) {
      throw new ConfigurationException(
          String.format("[%s] missing required property %s", this.getName(), WebSocketClient.URI_PROPERTY));
    }
    try {
      this.endpoint = new URI(uri);
    } catch (Exception e) {
      throw new ConfigurationException(
          String.format("[%s] invalid url given for %s", this.getName(), WebSocketClient.URI_PROPERTY));
    }

    LOGGER.config(() -> String.format("[%s] url=%s", getName(), uri));

    this.attempts = Integer.parseInt(
        config.getProperty(WebSocketClient.CONNECT_RETRIES_PROPERTY,
            String.valueOf(WebSocketClient.DEFAULT_ATTEMPTS)));

    LOGGER.config(() -> String.format("[%s] attempts=%s", getName(), this.attempts));

    this.timeoutMillis = Long.parseLong(config.getProperty(WebSocketClient.CONNECT_TIMEOUT_PROPERTY,
        String.valueOf(WebSocketClient.DEFAULT_TIMEOUT_MILLIS)));
    LOGGER.config(() -> String.format("[%s] timeoutMillis=%s", getName(), this.timeoutMillis));

    this.retryOnClose = Boolean.parseBoolean(config.getProperty(WebSocketClient.RETRY_ON_CLOSE_PROPERTY,
        String.valueOf(WebSocketClient.DEFAULT_RETRY_ON_CLOSE)));
    LOGGER.config(() -> String.format("[%s] retryOnClose=%s", getName(), this.retryOnClose));
  }

  /**
   * Connect the client
   * 
   * @throws Exception
   */
  @Override
  public void startup() throws Exception {
    this.connect();
  }

  /**
   * Sets retry to false, then closes session
   *
   * @throws Exception if error occurs
   */
  @Override
  public void shutdown() throws Exception {
    this.retryOnClose = false;
    this.session.close();
    if (this.scheduledExector != null) {
      this.scheduledExector.shutdownNow();
    }
  }

  /** @param listener set WebSocketListener */
  public void setListener(WebSocketListener listener) {
    this.listener = listener;
  }

  /**
   * Start a task to check if any message has been received during the
   * anyMessageIntervalMillis time period.
   */
  private void startAnyMessageReceived() {
    anyMessageReceived = false;
    if (anyMessageTask != null) {
      anyMessageTask.cancel(false);
    }
    anyMessageTask = this.scheduledExector.schedule(new AnyMessageRunner(this), anyMessageIntervalMillis,
        TimeUnit.MILLISECONDS);
  }

  /**
   * This class is used to asyncronously handle tracking if any message has been
   * received.
   */
  private class AnyMessageRunner implements Runnable {
    private WebSocketClient webSocketClient;

    public AnyMessageRunner(WebSocketClient webSocketClient) {
      this.webSocketClient = webSocketClient;
    }

    @Override
    public void run() {
      if (!webSocketClient.hasReceivedAnyMessage()) {
        webSocketClient.anyMessageClose();
      } else {
        webSocketClient.startAnyMessageReceived();
      }
    }
  }

  public long getAnyMessageIntervalMillis() {
    return anyMessageIntervalMillis;
  }

  public int getAttempts() {
    return attempts;
  }

  public URI getEndpoint() {
    return this.endpoint;
  }

  public long getPingIntervalMillis() {
    return this.pingIntervalMillis;
  }

  public long getPingWaitMillis() {
    return pingWaitMillis;
  }

  public ScheduledExecutorService getScheduledExector() {
    return scheduledExector;
  }

  public long getTimeoutMillis() {
    return timeoutMillis;
  }

  public boolean isRetryOnClose() {
    return retryOnClose;
  }

  public void setPingIntervalMillis(long pingIntervalMillis) throws Exception {
    if (this.isConnected()) {
      throw new ConfigurationException("Can not change pingIntervalMillis after client is connected");
    }
    this.pingIntervalMillis = pingIntervalMillis;

    if (this.pingIntervalMillis > 0 && this.scheduledExector == null) {
      this.scheduledExector = Executors.newSingleThreadScheduledExecutor();
    }
  }

  public void setAnyMessageIntervalMillis(long anyMessageIntervalMillis) throws Exception {
    if (this.isConnected()) {
      throw new ConfigurationException("Can not change anyMessageIntervalMillis after client is connected");
    }
    this.anyMessageIntervalMillis = anyMessageIntervalMillis;

    if (this.anyMessageIntervalMillis > 0 && this.scheduledExector == null) {
      this.scheduledExector = Executors.newSingleThreadScheduledExecutor();
    }
  }

}