import { map } from 'rxjs/operators';
import { ObservableState } from '@ardoq/rxbeach';
import { Observable, combineLatest } from 'rxjs';
import { ByIdItem } from '../types';

type FilterFlatTreeParams = {
  /** Parent id -> ordered child ids. */
  flattenTree: Record<string, string[]>;
  /** Ids to keep; `null` or an empty array disables filtering. */
  filteredIds: string[] | null;
  /** Ordered ids of the top-level nodes. */
  rootIds: string[];
};

/**
 * Restricts a flattened tree to the nodes listed in `filteredIds`.
 *
 * When `filteredIds` is `null` or empty, the original `flattenTree` and
 * `rootIds` references are returned untouched (no filtering).
 *
 * Otherwise:
 * - `rootIds` keeps only ids present in the filter, in original order;
 * - the returned tree has an entry only for filtered ids that have a child
 *   array in `flattenTree`; each child list keeps only filtered ids.
 *
 * NOTE(review): a node whose parent is filtered out is only reachable if it is
 * itself a root — callers presumably pass an ancestor-closed id set; confirm.
 *
 * @returns the filtered `flattenTree` and `rootIds`.
 */
const filterFlatTree = ({
  flattenTree,
  filteredIds,
  rootIds,
}: FilterFlatTreeParams): Omit<FilterFlatTreeParams, 'filteredIds'> => {
  if (!filteredIds || filteredIds.length === 0) return { flattenTree, rootIds };
  const filteredIdsSet = new Set(filteredIds);
  const filteredRootIds = rootIds.filter(id => filteredIdsSet.has(id));
  const filteredFlattenTree = Object.fromEntries(
    // Explicit tuple type keeps Object.fromEntries on its typed overload
    // (otherwise the result is `any`).
    filteredIds.flatMap((id): Array<[string, string[]]> => {
      const children: unknown = flattenTree[id];
      // Skip ids with no child list: avoids `undefined` values in the
      // Record and crashes on inherited keys such as "constructor".
      return Array.isArray(children)
        ? [[id, (children as string[]).filter(childId => filteredIdsSet.has(childId))]]
        : [];
    })
  );
  return { flattenTree: filteredFlattenTree, rootIds: filteredRootIds };
};

/** Snapshot of a flattened tree plus its node lookup table. */
type FlatTreeSnapshot<Data> = {
  flattenTree: Record<string, string[]>;
  rootIds: string[];
  byId: Record<string, Data>;
};
type FlattenTree$<Data> = Observable<FlatTreeSnapshot<Data>>;
type FilteredIds$ = ObservableState<string[] | null>;

/**
 * Combines the latest tree snapshot with the latest filter ids and emits the
 * tree restricted to those ids via `filterFlatTree`.
 *
 * `byId` is forwarded unchanged. A `null` or empty filter leaves the tree
 * unfiltered. Emits (per `combineLatest`) once both sources have emitted, and
 * again whenever either of them emits.
 *
 * @param flattenTree$ source of tree snapshots
 * @param filteredIds$ source of the current id filter
 */
export const withFiltering$ = <T extends ByIdItem>(
  flattenTree$: FlattenTree$<T>,
  filteredIds$: FilteredIds$
) =>
  combineLatest([flattenTree$, filteredIds$]).pipe(
    map(([snapshot, filteredIds]) => {
      const { flattenTree, rootIds } = filterFlatTree({
        flattenTree: snapshot.flattenTree,
        filteredIds,
        rootIds: snapshot.rootIds,
      });
      return { flattenTree, rootIds, byId: snapshot.byId };
    })
  );
