import SockJS from 'sockjs-client';
import { map } from 'rxjs/operators';
import { Subject, Observable, Subscription } from 'rxjs';
import { RxStomp, RxStompConfig, RxStompState } from '@stomp/rx-stomp';
import { RSocketConnector, RSocket, ConnectorConfig } from 'rsocket-core';
import {
  WellKnownMimeType,
  encodeBearerAuthMetadata,
  encodeCompositeMetadata,
  encodeRoute,
} from 'rsocket-composite-metadata';
import { WebsocketClientTransport } from 'rsocket-websocket-client';

/**
 * Factory for the socket transports defined in this file.
 *
 * Assign `jwtGetter` to have the STOMP and RSocket transports pick up a JWT;
 * the getter is looked up lazily, so it may be assigned after a transport has
 * been created.
 */
export class SocketService {
  /** Supplier of the current JWT, or `null` when no authentication should be sent. */
  public jwtGetter: (() => string) | null = null;

  /**
   * Create a STOMP transport whose `beforeConnect` hook adds an
   * `Authorization: Bearer <jwt>` connect header when a `jwtGetter` is set.
   */
  public makeStompTransport() {
    return new StompSocketTransport({
      beforeConnect: (client) => {
        // Without a token supplier the client configuration is left untouched.
        if (!this.jwtGetter) {
          return;
        }

        const authorization = `Bearer ${this.jwtGetter()}`;
        client.configure({
          connectHeaders: { Authorization: authorization },
        });
      },
    });
  }

  /** Create a plain WebSocket transport (no authentication is wired in). */
  public makeVanillaTransport() {
    return new VanillaSocketTransport();
  }

  /**
   * Create an RSocket transport that is handed a token lookup returning the
   * current JWT, or `null` when no `jwtGetter` is set.
   */
  public makeRSocketTransport() {
    return new RSocketTransport(() => (this.jwtGetter ? this.jwtGetter() : null));
  }
}

/**
 * STOMP transport built on `RxStomp`, using SockJS (websocket transport only)
 * as the underlying socket factory.
 */
export class StompSocketTransport {
  /**
   * @param config Base RxStomp configuration applied to every socket this
   *   transport creates. `reconnectDelay` defaults to 0 unless overridden here.
   */
  constructor(private config: Partial<RxStompConfig> = {}) {}

  private socket?: RxStomp;
  private socketSubs: Subscription[] = [];
  private stateSubject = new Subject<RxStompState>();
  /** Last state seen on the current socket; `null` until its first emission. */
  private currentState: RxStompState | null = null;
  /** True from `connect()` until the socket reports OPEN or CLOSED, or `disconnect()` is called. */
  private connectingSocket = false;

  /**
   * Tear down the previous socket (if any) and create a freshly configured
   * `RxStomp` instance whose connection state is tracked and forwarded to
   * `stateSubject`.
   */
  private async setupSocket() {
    this.socketSubs.forEach((s) => s.unsubscribe());
    this.socketSubs = [];
    this.currentState = null;
    if (this.socket) {
      this.socket.deactivate();
    }

    const configuration: RxStompConfig = {
      reconnectDelay: 0,
      ...this.config,
    };

    const socket = new RxStomp();
    socket.configure(configuration);
    this.socket = socket;

    this.socketSubs.push(
      socket.connectionState$.subscribe((state) => {
        // The first emission is the socket's initial CLOSED state, so the check
        // is skipped until a state has been recorded. The comparison must be an
        // explicit null check: RxStompState.CONNECTING is the numeric value 0,
        // and a truthiness test would skip the CONNECTING -> OPEN/CLOSED
        // transition and leave `connectingSocket` stuck at true.
        const settled = state === RxStompState.CLOSED || state === RxStompState.OPEN;
        if (this.currentState !== null && settled) {
          this.connectingSocket = false;
        }
        this.currentState = state;
      })
    );

    this.socketSubs.push(socket.connectionState$.subscribe(this.stateSubject));
  }

  /**
   * Connect to `url`.
   *
   * When a socket is already active or a connection attempt is in flight this
   * is a no-op, unless `force` is set, in which case the existing socket is
   * deactivated and replaced.
   *
   * @param url SockJS endpoint to connect to.
   * @param force Replace an active/connecting socket instead of returning early.
   * @param config Extra RxStomp configuration layered on top of the constructor
   *   config for this connection; `webSocketFactory` is always overridden.
   */
  public async connect(url: string, force?: boolean, config: Partial<RxStompConfig> = {}) {
    const busy = this.socket?.active || this.connectingSocket;
    if (busy && !force) {
      return;
    }

    await this.setupSocket();

    const socket = this.socket!;
    socket.configure({
      ...config,
      webSocketFactory: () => {
        return new SockJS(url, undefined, { transports: 'websocket' });
      },
    });

    socket.activate();

    this.connectingSocket = true;
  }

  /** Stream of connection states, forwarded from every socket this transport creates. */
  public get state() {
    return this.stateSubject.asObservable();
  }

  /** Whether the current socket reports itself as connected; false when no socket exists. */
  public connected() {
    return this.socket?.connected() ?? false;
  }

  /** Deactivate the current socket, if any, and clear the connecting flag. */
  public disconnect() {
    this.connectingSocket = false;
    return this.socket?.deactivate();
  }

  /**
   * Watch a channel on the socket, emitting each message body parsed as JSON.
   *
   * This needs to be subscribed directly on the socket, and will be closed when the socket dies
   *
   * As such, this method will likely need to be called each time the socket is connected/reconnected to
   *
   * @param channel STOMP destination to watch.
   * @throws Error when called before `connect()` has created a socket.
   */
  public watchChannel<T = any>(channel: string) {
    const socket = this.socket;
    if (!socket) {
      throw new Error('StompSocketTransport.watchChannel called before connect()');
    }

    return socket.watch(channel).pipe(map((message) => JSON.parse(message.body) as T));
  }
}

/**
 * Transport over a plain browser `WebSocket`, with JSON-decoded messages.
 */
export class VanillaSocketTransport {
  private socket?: WebSocket;
  private socketSubs: Subscription[] = [];
  /**
   * True while the current socket is connecting or connected; cleared by
   * `disconnect()` and when the current socket closes.
   */
  private connectingSocket = false;

  /**
   * Close the previous socket (if any) and open a new one to `url`.
   */
  private async setupSocket(url: string) {
    this.socketSubs.forEach((s) => s.unsubscribe());
    this.socketSubs = [];
    if (this.socket) {
      this.socket.close();
    }

    const socket = new WebSocket(url);

    // Once this socket closes (remote close or failed connection), allow a
    // later non-forced connect() to open a new one. The identity check keeps a
    // superseded socket's late close event from clearing the flag of its
    // replacement.
    socket.addEventListener('close', () => {
      if (this.socket === socket) {
        this.connectingSocket = false;
      }
    });

    this.socket = socket;
  }

  /**
   * Open a socket to `url`.
   *
   * A no-op while a socket is connecting or connected, unless `force` is set,
   * in which case the existing socket is closed and replaced.
   */
  public async connect(url: string, force?: boolean) {
    if (this.connectingSocket && !force) {
      return;
    }

    await this.setupSocket(url);

    this.connectingSocket = true;
  }

  /** Whether the current socket is in the OPEN ready state; false when no socket exists. */
  public connected() {
    return this.socket?.readyState === WebSocket.OPEN;
  }

  /** Close the current socket, if any, and clear the connecting flag. */
  public disconnect() {
    this.connectingSocket = false;
    return this.socket?.close();
  }

  /**
   * Watch every message arriving on the socket, emitting each one parsed as JSON.
   *
   * This needs to be subscribed directly on the socket, and will be closed when the socket dies
   *
   * As such, this method will likely need to be called each time the socket is connected/reconnected to
   *
   * The returned observable errors on a socket `error` event or when a message
   * is not valid JSON, and completes when the socket closes. Unsubscribing
   * removes the listeners that were attached for that subscription.
   */
  public watchChannel<T = any>() {
    return new Observable<T>((sub) => {
      // Bind to the socket that exists at subscription time so teardown
      // detaches from the same instance even if the transport reconnects.
      const socket = this.socket;
      if (!socket) {
        sub.error(new Error('VanillaSocketTransport.watchChannel subscribed before connect()'));
        return;
      }

      const onMessage = (message: { data: any }) => {
        let response: T;
        try {
          response = JSON.parse(message.data) as T;
        } catch (e) {
          // Surface malformed payloads to the subscriber rather than throwing
          // inside the DOM event handler.
          sub.error(e);
          return;
        }
        sub.next(response);
      };
      const onError = (error: any) => {
        sub.error(error);
      };
      const onClose = () => {
        sub.complete();
      };

      socket.addEventListener('message', onMessage);
      socket.addEventListener('error', onError);
      socket.addEventListener('close', onClose);

      return () => {
        socket.removeEventListener('message', onMessage);
        socket.removeEventListener('error', onError);
        socket.removeEventListener('close', onClose);
      };
    });
  }
}

/**
 * RSocket transport over a WebSocket connection.
 *
 * When the supplied token lookup returns a JWT, it is sent as bearer
 * authentication in composite metadata on the setup payload.
 */
export class RSocketTransport {
  private socket?: RSocket;
  private socketSubs: Subscription[] = [];
  /** Promise of the in-flight connection attempt, if any. */
  private connectingSocket?: Promise<void>;
  /** URL of the most recent connection attempt. */
  private url: string = '';

  private _connected = false;

  /**
   * Identifies the latest connection attempt. It is bumped by every
   * `setupSocket()` and by `disconnect()`, so an attempt that settles after
   * being superseded or cancelled can tell it is stale.
   */
  private attempt = 0;

  /**
   * @param jwtGetter Returns the JWT to authenticate with, or `null` to connect
   *   without a setup payload.
   */
  constructor(private jwtGetter: () => string | null) {}

  /**
   * Close the current socket (if any) and start a new connection to `url`.
   *
   * @returns A promise that resolves once the connection is established and
   *   rejects when it fails.
   */
  private async setupSocket(url: string) {
    this.socketSubs.forEach((s) => s.unsubscribe());
    this.socketSubs = [];

    const previous = this.socket;
    this.socket = undefined;
    this._connected = false;
    if (previous) {
      previous.close();
    }

    this.url = url;
    this.connectingSocket = undefined;
    const attempt = ++this.attempt;

    const config: ConnectorConfig = {
      transport: new WebsocketClientTransport({
        url: url,
        wsCreator: (wsUrl) => new WebSocket(wsUrl) as any,
      }),
    };

    const jwt = this.jwtGetter();
    if (jwt) {
      config.setup = {
        /**
         * Spring's implementation of the RSocket defaults to assuming metadata is sent as composite, and will not parse it correctly unless it is
         */
        metadataMimeType: WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.string,
        payload: {
          data: undefined,
          metadata: encodeCompositeMetadata([
            [WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION, encodeBearerAuthMetadata(jwt)],
          ]),
        },
      };
    }

    const connector = new RSocketConnector(config);

    const pending: Promise<void> = connector.connect().then(
      (rsocket) => {
        if (attempt !== this.attempt) {
          // A newer connect() or a disconnect() happened while this attempt
          // was in flight: drop the connection instead of overwriting state.
          rsocket.close();
          return;
        }

        rsocket.onClose(() => {
          // Only reset state when the socket that closed is still the current one.
          if (this.socket === rsocket) {
            this._connected = false;
            this.socket = undefined;
          }
        });

        this.socket = rsocket;
        this._connected = true;
        this.connectingSocket = undefined;
      },
      (error: unknown) => {
        // Forget the failed attempt so a later non-forced connect() retries
        // instead of returning this rejected promise again.
        if (attempt === this.attempt) {
          this.connectingSocket = undefined;
        }
        throw error;
      }
    );

    this.connectingSocket = pending;
    return pending;
  }

  /**
   * Connect to `url`.
   *
   * Reuses the in-flight attempt or the established socket when the URL is
   * unchanged, unless `force` is set. A changed URL always reconnects.
   *
   * @returns The pending connection promise, the already-connected socket, or
   *   the promise of a newly started connection.
   */
  public async connect(url: string, force?: boolean) {
    const mustReconnect = force || url !== this.url;

    if (!mustReconnect && this.connectingSocket) {
      return this.connectingSocket;
    }
    if (!mustReconnect && this.socket && this._connected) {
      return this.socket;
    }
    return this.setupSocket(url);
  }

  /** Whether a connection is currently established. */
  public connected() {
    return this._connected;
  }

  /**
   * Close the current socket, if any, and cancel any in-flight connection
   * attempt so it cannot mark the transport as connected afterwards.
   */
  public disconnect() {
    this.attempt++;
    this.connectingSocket = undefined;

    const socket = this.socket;
    this.socket = undefined;
    this._connected = false;
    return socket?.close();
  }

  /**
   * Watch a channel on the socket, emitting each payload's data parsed as JSON.
   *
   * This needs to be subscribed directly on the socket, and will be closed when the socket dies
   *
   * As such, this method will likely need to be called each time the socket is connected/reconnected to
   *
   * Unsubscribing cancels the underlying request stream.
   *
   * @param channel Route to request the stream from.
   */
  public watchChannel<T = any>(channel: string = 'user/queue/notifications') {
    return new Observable<T>((sub) => {
      const socket = this.socket;
      if (!socket) {
        sub.error(new Error('RSocketTransport.watchChannel subscribed before the socket is connected'));
        return;
      }

      // Flow control: 11 payloads are requested up front, then 10 more after
      // every 10th payload received.
      const initialRequest = 11;
      const batchSize = 10;
      let payloadsReceived = 0;

      const requester = socket.requestStream(
        {
          data: undefined,
          /**
           * Spring's implementation of the RSocket defaults to assuming metadata is sent as composite, and will not parse it correctly unless it is
           */
          metadata: encodeCompositeMetadata([[WellKnownMimeType.MESSAGE_RSOCKET_ROUTING, encodeRoute(channel)]]),
        },
        initialRequest,
        {
          onError: (e) => sub.error(e),
          onNext: (payload, isComplete) => {
            try {
              if (payload.data) {
                sub.next(JSON.parse(payload.data.toString()));
              }
            } catch (e) {
              // The subscriber is terminated by the error; stop here rather
              // than requesting more payloads for a dead subscription.
              sub.error(e);
              return;
            }

            payloadsReceived++;

            if (payloadsReceived % batchSize === 0) {
              requester.request(batchSize);
            }

            if (isComplete) {
              sub.complete();
            }
          },
          onComplete: () => {
            sub.complete();
          },
          onExtension: () => {},
        }
      );

      return () => {
        requester.cancel();
      };
    });
  }
}
