import { useWizardState } from '@invisible/common/components/providers/active-wizard-provider'
import { getUUIDFromNamespace } from '@invisible/common/helpers'
import { SnackbarContext } from '@invisible/common/providers'
import { IStepRunEventTypeEnum } from '@invisible/concorde/gql-client'
import { useLoggedInUser } from '@invisible/hooks/use-logged-in-user'
import { logger } from '@invisible/logger/client'
import { useContext, useQuery } from '@invisible/trpc/client'
import { theme } from '@invisible/ui/mui-theme-v2'
import { ThemeProvider } from '@mui/material/styles'
import { Prisma } from '@prisma/client'
import { InitialValue, Json } from 'libs/common/components/process-base/src/lib/hooks/useBaseRunCreate'
import { baseRunVariableCreationFromBaseRunInput } from 'libs/common/components/process-base/src/lib/hooks/useBaseRunCreateManyWizardAction'
import useBaseRunVariableFindManyByBaseRunId from 'libs/common/components/process-base/src/lib/hooks/useBaseRunVariableFindManyByBaseRunId'
import { flatten, sampleSize, shuffle } from 'lodash/fp'
import pMap from 'p-map'
import pTimes from 'p-times'
import { useContext as useReactContext, useEffect, useMemo, useState } from 'react'
import { useGate } from 'statsig-react'
import { JsonValue } from 'type-fest'

import {
  NEXT_PUBLIC_CONCORDE_URL,
  NEXT_PUBLIC_MTC_RLHF_WITHOUT_CLOUDFARE_URL,
} from '../../../../config/env'
import { useFirstManualStepForBaseRun } from '../../hooks/useFirstManualStepForBaseRun'
import { usePollQueryAndSaveModelTask } from '../../hooks/usePollQueryAndSaveModelTask'
import { useStepRunEventLogger } from '../../hooks/useStepRunEventLogger'
import { TTextDirection, TTextRenderMode } from '../common/types'
import { Layout } from './components/Layout'
import { MainContent } from './components/MainContent'
import { Sidebar } from './components/Sidebar'
import { RLHFContext } from './context'
import { useMutations } from './hooks/useMutations'
import {
  IContent,
  INormalizedConversationData,
  IProps,
  RLHFContextDataType,
  TFindChildBaseRunsData,
  TManyBaseRunData,
  TModelConfig,
} from './types'

/**
 * Multimodal RLHF wizard action (WAC).
 *
 * Lets a worker hold a multi-turn conversation:
 * - Each submitted prompt (text plus optional attached files) is stored as a child base run of
 *   `baseRun` in `config.promptsBaseId`.
 * - The configured models' responses are stored as child base runs of that prompt in
 *   `config.responsesBaseId`.
 *
 * The component reports readiness to the surrounding wizard through `setReadyForSubmit`. That
 * covers the turn-count requirement plus required prompt/response metadata.
 *
 * Model responses come from one of two paths:
 * - Directly from the RLHF query-model endpoint.
 * - When the `enable-rlhf-concorde-query-model` gate is on, from a Concorde query-and-save task
 *   that is enqueued and then polled until it finishes.
 */
const MultimodalRLHFWAC = ({
  baseRun,
  multimodalRLHF: config,
  stepRun,
  isReadOnly,
  id: configId,
}: IProps) => {
  const { value: enableMtcRlhfNonCloudfareRoute } = useGate('enable-mtc-rlhf-non-cloudfare-route')
  const { value: enableRLHFConcordeQueryModel } = useGate('enable-rlhf-concorde-query-model')
  const { maybeLogStepRunEvent } = useStepRunEventLogger()
  const [loggedInUser] = useLoggedInUser()

  const reactQueryContext = useContext()
  const [startedPromptForCurrentTurn, setStartedPromptForCurrentTurn] = useState(false)
  const [isResponseRegenerating, setIsResponseRegenerating] = useState(false)
  const [promptText, setPromptText] = useState('')
  const [attachedFiles, setAttachedFiles] = useState<IContent[]>([])

  const [tab, setTab] = useState('1')
  const [showInfoSection, setShowInfoSection] = useState(false)
  const [isLastResponseEmpty, setIsLastResponseEmpty] = useState(false)
  const [isFetchingResponses, setisFetchingResponses] = useState(false)

  // Total number of responses generated per prompt across all configured models
  // (each model contributes `numOfCalls`, defaulting to 1).
  const numOfResponsesPerPrompt =
    config.models?.reduce((acc, curr) => acc + (curr.numOfCalls ?? 1), 0) ?? 0

  // Initially show one response index per generated response, capped at five.
  const [visibleResponseIndices, setVisibleResponseIndices] = useState<number[]>(() =>
    [1, 2, 3, 4, 5].slice(0, numOfResponsesPerPrompt)
  )
  const [collapseAll, setCollapseAll] = useState(false)
  const [isUpdated, setIsUpdated] = useState(false)
  const [responseMetadataValidationFailures, setResponseMetadataValidationFailures] = useState(
    [] as string[]
  )
  const [promptMetadataValidationFailures, setPromptMetadataValidationFailures] = useState(
    [] as string[]
  )
  const { showSnackbar } = useReactContext(SnackbarContext)

  const { dispatch } = useWizardState()
  const startPolling = usePollQueryAndSaveModelTask({
    interval: 5000, // 5 seconds
    timeout: 300000, // 5 minutes
  })
  const conversationData = useBaseRunVariableFindManyByBaseRunId({
    baseRunIds: [baseRun.id],
  })

  const { data: prompts } = useQuery([
    'baseRun.findChildBaseRuns',
    {
      baseId: config?.promptsBaseId as string,
      parentBaseRunId: baseRun.id,
    },
  ])
  const { data: firstManualStepRun } = useFirstManualStepForBaseRun({
    baseRunId: baseRun.id,
  })
  const firstManualStepRunCreatedAt = useMemo(
    () => firstManualStepRun?.createdAt ?? '',
    [firstManualStepRun]
  )

  // Conversation-level settings read from this base run's variables.
  const normalizedConversationData: INormalizedConversationData = useMemo(() => {
    const variableValue = (baseVariableId: unknown) =>
      conversationData?.find((v) => v.baseVariableId === baseVariableId)?.value

    return {
      maxTurn: variableValue(config?.conversationMaxTurnsBaseVariableId) as number,
      minTurn: variableValue(config?.conversationMinTurnsBaseVariableId) as number,
      preamble: variableValue(config?.conversationPreambleBaseVariableId) as string,
      instruction: variableValue(config?.conversationInstructionBaseVariableId) as string,
      modelTemperature: variableValue(config?.conversationModelTempBaseVariableId) as number,
    }
  }, [conversationData, config])

  // Parses prompt base runs into typed objects (plus their raw variables), ordered by prompt index.
  const normalizedPrompts = useMemo(
    () =>
      (prompts ?? [])
        .map((prompt) => {
          const variableValue = (baseVariableId: unknown) =>
            prompt.baseRunVariables.find((variable) => variable.baseVariable.id === baseVariableId)
              ?.value

          return {
            id: prompt.id,
            text: variableValue(config?.promptTextBaseVariableId) as string,
            index: variableValue(config?.promptIndexBaseVariableId) as number,
            attachedFiles: variableValue(config?.promptAttachedFilesBaseVariableId) as JsonValue,
            acceptedResponse: variableValue(config?.promptResponseBaseVariableId) as string,
            acceptedContent:
              variableValue(config?.promptAcceptedResponseContentBaseVariableId) ??
              ([] as JsonValue),
            acceptedExtraMetadata:
              variableValue(config?.promptAcceptedResponseExtraMetadataBaseVariableId) ??
              ({} as JsonValue),
            acceptedModel: variableValue(config?.promptAcceptedModelBaseVariableId) as string,
            responseId: variableValue(config?.responseIdBaseVariableId) as string,
            createdAt: prompt.createdAt,
            baseRunVariables: prompt.baseRunVariables,
          }
        })
        .sort((a, b) => a.index - b.index),
    [prompts, config]
  )

  // Character count (not model tokens) of all prompt texts plus accepted responses.
  const tokenCount = useMemo(
    () =>
      normalizedPrompts
        .map((prompt) => (prompt?.acceptedResponse?.length ?? 0) + (prompt?.text?.length ?? 0))
        .reduce((sum, length) => sum + length, 0),
    [normalizedPrompts]
  )

  // First prompt that has no accepted response yet (blocks new prompts while present).
  const unsubmittedPrompt = normalizedPrompts.find((prompt) => !prompt.acceptedResponse)

  const numOfPrompts = normalizedPrompts.length
  // 1-based span location of the turn currently being composed; used for step-run event logging.
  const nextSpanLocation = numOfPrompts + 1

  const { mutations, mutationStates } = useMutations({
    stepRun,
    showSnackbar,
    reactQueryContext,
    maybeLogStepRunEvent,
    nextSpanLocation: numOfPrompts,
    config,
    isResponseRegenerating,
    setIsResponseRegenerating,
  })

  const { isCreatingBaseRun, isDeletingBaseRuns, isCreatingManyBaseRuns, isUpdatingBaseVariables } =
    mutationStates

  // Uses the Prompt IDs as the parent IDs to fetch all Responses.
  const { data: allResponses } = useQuery([
    'baseRun.findManyByParents',
    {
      parentIds: normalizedPrompts.map((prompt) => prompt.id),
      includeBaseVariableIds: (config?.responseMetadata?.fields ?? [])
        .map((metadata) => metadata.baseVariableId as string)
        .filter(Boolean),
    },
  ])

  // Collects the parent prompt IDs of all responses whose required metadata is missing.
  useEffect(() => {
    const validateResponseMetadata = (response: TManyBaseRunData[number]) => {
      if (!response.baseRunVariables) return true

      for (const variable of response.baseRunVariables) {
        if (
          (config.responseMetadata?.fields ?? []).some(
            (metadata) =>
              metadata.required &&
              metadata.baseVariableId === variable.baseVariableId &&
              // A multiselect stored as a comma-separated string counts as empty when it has no items.
              ((metadata.type === 'multiselect' &&
                typeof variable.value === 'string' &&
                variable.value.split(',').filter((v) => v).length === 0) ||
                variable.value === null)
          )
        ) {
          return false
        }
      }

      return true
    }

    setResponseMetadataValidationFailures(
      (allResponses ?? [])
        .filter((response) => !validateResponseMetadata(response))
        .map((r) => r.parentId as string)
    )
  }, [allResponses, config.responseMetadata?.fields])

  // Collects the IDs of all prompts whose required metadata is missing.
  useEffect(() => {
    const validatePromptMetadata = (prompt: TFindChildBaseRunsData[number]) => {
      if (!prompt.baseRunVariables) return true

      for (const variable of prompt.baseRunVariables) {
        if (
          (config.promptMetadata?.fields ?? []).some(
            (metadata) =>
              metadata.required &&
              metadata.baseVariableId === variable.baseVariable.id &&
              variable.value === null
          )
        ) {
          return false
        }
      }

      return true
    }

    setPromptMetadataValidationFailures(
      (prompts ?? [])
        .filter((prompt) => !validatePromptMetadata(prompt))
        .map((p) => p.id as string)
    )
  }, [prompts, config.promptMetadata?.fields])

  // The wizard Submit button is enabled only when every readiness key is true.
  useEffect(() => {
    dispatch({
      type: 'setReadyForSubmit',
      key: 'MULTIMODALRLHF-ResponseMetadata',
      value: responseMetadataValidationFailures.length === 0,
    })
  }, [responseMetadataValidationFailures, dispatch])

  useEffect(() => {
    dispatch({
      type: 'setReadyForSubmit',
      key: 'MULTIMODALRLHF-PromptMetadata',
      value: promptMetadataValidationFailures.length === 0,
    })
  }, [promptMetadataValidationFailures, dispatch])

  // Ready when every prompt has an accepted response (unless ending without one is allowed)
  // and the minimum number of turns has been reached.
  useEffect(() => {
    // When both conversation min/max turn variables are set they take precedence over
    // `config.minMaxTurns`; the final fallback is a single turn.
    const minRequiredTurns =
      (normalizedConversationData.minTurn && normalizedConversationData.maxTurn
        ? normalizedConversationData.minTurn
        : undefined) ??
      config.minMaxTurns?.[0] ??
      1

    dispatch({
      type: 'setReadyForSubmit',
      key: 'MULTIMODALRLHF',
      value:
        (config.allowEndingWithoutResponse || !unsubmittedPrompt) &&
        (prompts?.length ?? 0) >= minRequiredTurns,
    })
  }, [unsubmittedPrompt, prompts, dispatch, config, normalizedConversationData])

  /**
   * Queries one model `perModelCallCount` times (at least once) and concatenates the results.
   *
   * Each successful call is expected to return a JSON array of responses. Any call that fails,
   * returns non-JSON, or returns a non-array is logged and reported via snackbar, and contributes
   * no entries to the result.
   */
  const queryModel = async (input: {
    query: string
    attachedFiles: IContent[]
    model: string
    chatHistory: {
      id: string
      text: string
      attachedFiles: JsonValue
      acceptedResponse: string
      acceptedContent: JsonValue
      acceptedExtraMetadata: JsonValue
    }[]
    perModelCallCount?: number
    meta?: Record<string, unknown>
  }) => {
    const count = input?.perModelCallCount ?? 1
    const body = JSON.stringify({
      baseRunId: baseRun.id,
      chatHistory: input.chatHistory.map((chat) => ({
        ...chat,
        content: chat.attachedFiles,
      })),
      query: input.query,
      content: input.attachedFiles,
      model: input.model,
      meta: input.meta,
      preamble: config?.conversationPreambleBaseVariableId
        ? normalizedConversationData.preamble
        : null,
    })

    const rlhfEndpoint = enableMtcRlhfNonCloudfareRoute
      ? `${NEXT_PUBLIC_MTC_RLHF_WITHOUT_CLOUDFARE_URL}/query-model`
      : '/api/wacs/rlhf/query-model'

    const results = await pTimes(Math.max(count, 1), async () => {
      try {
        const res = await fetch(rlhfEndpoint, {
          method: 'POST',
          credentials: 'include',
          headers: {
            'Content-Type': 'application/json',
          },
          body,
        })
        // A response body can only be consumed once: read it as text, then parse, so the raw
        // content is still available for logging when it is not valid JSON.
        const rawBody = await res.text()
        try {
          return JSON.parse(rawBody)
        } catch (error) {
          logger.error('Failed to query model.', {
            reqBody: body,
            responseContent: rawBody,
            error,
          })
          return {
            message: rawBody,
          }
        }
      } catch (error) {
        logger.error('Failed to query model.', {
          reqBody: body,
          error,
        })
        return {
          message: error instanceof Error ? error.message : String(error),
        }
      }
    })

    return results.reduce((acc, curr) => {
      if (!Array.isArray(curr)) {
        showSnackbar({
          message: `Unable to fetch response for model: ${input.model}. \n${curr?.message ?? ''}`,
          variant: 'error',
        })
        logger.error(`Error fetching model: ${input.model} response.`, {
          message: curr?.message ?? '',
        })
        return acc
      }
      return [...acc, ...curr]
    }, [])
  }

  /**
   * Enqueues a Concorde task that queries the models and saves the responses server-side.
   *
   * @returns the task id to poll, or `undefined` when the request failed (the failure has
   *   already been logged and shown to the user).
   */
  const queryModelAndSave = async (text: string, attachedFiles: IContent[]) => {
    try {
      const resp = await fetch(`${NEXT_PUBLIC_CONCORDE_URL}/api/rlhf/query-model-and-save`, {
        method: 'POST',
        credentials: 'include',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          query: text,
          wacConfigId: configId,
          stepRunId: stepRun.id,
          attachedFiles: attachedFiles,
        }),
      })
      if (!resp.ok) {
        throw new Error(`query-model-and-save responded with status ${resp.status}`)
      }
      const { task_id: taskId } = await resp.json()
      if (!taskId) {
        throw new Error('query-model-and-save response did not include a task_id')
      }
      return taskId
    } catch (err) {
      logger.error('Failed to enqueue task for query model & save.', {
        error: err,
        body: { text, attachedFiles },
      })
      showSnackbar({
        message: 'Failed to query model.',
        variant: 'error',
      })
    }
  }

  // Optionally samples `numOfModelsToSample` models when random sampling is enabled.
  const randomizeModels = (models: TModelConfig[]) => {
    if (config.allowRandomSampling && config.numOfModelsToSample) {
      return sampleSize(config.numOfModelsToSample, models)
    }
    return models
  }

  // Optionally shuffles response order so workers cannot infer the model from position.
  const randomizeModelResponses = (
    responses: {
      model: string
      index: number
      message: string
      metaFields: baseRunVariableCreationFromBaseRunInput[]
      content: JsonValue
      extraMetadata: JsonValue
    }[]
  ) => {
    if (config.randomizeModelResponseOrder) {
      return shuffle(responses)
    }
    return responses
  }

  const handleInputChange = (e: React.ChangeEvent<HTMLInputElement>) => {
    const newValue = e.target.value

    setPromptText(newValue)

    // Log "typing_prompt" once per turn, on the first keystroke.
    if (!startedPromptForCurrentTurn) {
      setStartedPromptForCurrentTurn(true)
      maybeLogStepRunEvent({
        name: 'typing_prompt',
        stepRunId: stepRun.id as string,
        spanId: getUUIDFromNamespace([stepRun.id as string, String(nextSpanLocation)]),
        spanType: 'TURN',
        spanLocation: nextSpanLocation,
        type: IStepRunEventTypeEnum.Span,
        timestamp: new Date(),
      })
    }

    // Clearing the input re-arms the "typing_prompt" event.
    if (newValue.length === 0) {
      setStartedPromptForCurrentTurn(false)
    }
  }

  // Shapes normalized prompts into the chat-history entries expected by `queryModel`.
  const toChatHistory = (historyPrompts: typeof normalizedPrompts) =>
    historyPrompts.map((prompt) => ({
      ...prompt,
      acceptedContent: prompt.acceptedContent as JsonValue,
      acceptedExtraMetadata: prompt.acceptedExtraMetadata as JsonValue,
    }))

  /**
   * Queries every (optionally sampled) configured model in parallel for `query`.
   * Returns the flattened list of responses, shaped for persistence as response base runs.
   */
  const fetchModelResponses = async (
    query: string,
    files: IContent[],
    chatHistory: Parameters<typeof queryModel>[0]['chatHistory']
  ) => {
    const modelResponses = await pMap(randomizeModels(config.models ?? []), async (model) => {
      const metaParams = model.responseMetaParams ?? []
      const responses = await queryModel({
        query,
        attachedFiles: files,
        model: model.name,
        chatHistory,
        perModelCallCount: model?.numOfCalls,
        meta: {
          ...model.params.reduce((acc, param) => ({ ...acc, [param.name]: param.value }), {}),
          // A per-run temperature variable overrides any temperature in the model params.
          ...(config.conversationModelTempBaseVariableId &&
          normalizedConversationData.modelTemperature
            ? {
                temperature: Number.parseFloat(`${normalizedConversationData.modelTemperature}`),
              }
            : {}),
        },
      })
      return responses.map(
        (
          response: {
            text: string
            model?: string
            meta?: Prisma.JsonObject
            content?: JsonValue
            extraMetadata?: JsonValue
          },
          index: number
        ) => ({
          message: response.text,
          model: model.name,
          index: index + 1,
          // Copy the configured response meta keys into their target base variables.
          metaFields: metaParams
            .map((p) => ({
              value: (typeof response.meta === 'object' && response.meta
                ? (response.meta as Record<string, JsonValue>)[p.key]
                : undefined) as JsonValue,
              baseVariableId: p.baseVariableId,
            }))
            .filter((p) => p.baseVariableId),
          content: response.content ?? [],
          extraMetadata: response.extraMetadata ?? {},
        })
      )
    })
    return flatten(modelResponses)
  }

  // True only when at least one response came back and none of them is empty
  // (`every` alone would accept an empty list, i.e. every model call failing).
  const hasUsableResponses = (responses: { message?: unknown }[]) =>
    responses.length > 0 && responses.every((response) => Boolean(response.message))

  // Persists model responses as child base runs of the given prompt base run.
  const persistModelResponses = async (
    parentBaseRunId: string,
    responses: Parameters<typeof randomizeModelResponses>[0]
  ) => {
    await mutations.createBaseRuns({
      baseId: config?.responsesBaseId as string,
      parentBaseRunId,
      initialValuesArray: randomizeModelResponses(responses).map((response, index) => [
        {
          baseVariableId: config?.responseTextBaseVariableId as string,
          value: response.message,
        },
        {
          baseVariableId: config?.responseOriginalTextBaseVariableId as string,
          value: response.message,
        },
        {
          baseVariableId: config?.responseModelBaseVariableId as string,
          value: response.model,
        },
        {
          // Position after (optional) shuffling, 1-based.
          baseVariableId: config?.responseIndexBaseVariableId as string,
          value: index + 1,
        },
        {
          baseVariableId: config?.responseContentBaseVariableId as string,
          value: response.content ?? [],
        },
        {
          baseVariableId: config?.responseExtraMetadataBaseVariableId as string,
          value: response.extraMetadata ?? {},
        },
        ...response.metaFields,
      ]),
      sourceStepRunId: stepRun.id,
    })
  }

  // Initial variable values for a new prompt base run; entries without a configured base
  // variable id are dropped.
  const buildPromptInitialValues = (text: string, files: IContent[]): InitialValue[] =>
    [
      {
        baseVariableId: config?.promptTextBaseVariableId,
        value: text,
      },
      {
        baseVariableId: config?.promptIndexBaseVariableId,
        value: (prompts?.length ?? 0) + 1,
      },
      {
        baseVariableId: config?.promptAttachedFilesBaseVariableId,
        value: files.map((file) => ({
          type: file.type,
          url: file.url,
        })) as Json,
      },
    ].filter((item): item is InitialValue => Boolean(item.baseVariableId))

  /**
   * Re-queries the models for an existing prompt and saves the new responses under it.
   * Sets `isLastResponseEmpty` when no usable responses come back.
   */
  const resubmitPrompt = async (failedPrompt: {
    id: string
    text: string
    attachedFiles: IContent[]
    index: number
  }) => {
    dispatch({
      type: 'setReadyForSubmit',
      key: 'MULTIMODALRLHF',
      value: config.allowEndingWithoutResponse ?? false,
    })
    setisFetchingResponses(true)
    setIsLastResponseEmpty(false)

    const responses = await fetchModelResponses(
      failedPrompt.text,
      failedPrompt.attachedFiles,
      // NOTE(review): only the failed prompt itself is excluded; any later turns are still sent as
      // history. Confirm regeneration is only offered for the latest turn.
      toChatHistory(normalizedPrompts.filter((prompt) => prompt.id !== failedPrompt.id))
    ).finally(() => setisFetchingResponses(false))

    if (hasUsableResponses(responses)) {
      await persistModelResponses(failedPrompt.id, responses)
    } else {
      setIsLastResponseEmpty(true)
    }
  }

  /**
   * Submits a new prompt.
   *
   * Concorde path: the server queries the models and saves everything, and this function polls
   * for completion. Otherwise the models are queried here, then the prompt and its responses are
   * created as base runs.
   */
  const handlePromptSubmission = async (text: string, attachedFiles: IContent[]) => {
    maybeLogStepRunEvent({
      name: 'prompt_submitted',
      stepRunId: stepRun.id as string,
      spanId: getUUIDFromNamespace([stepRun.id as string, String(nextSpanLocation)]),
      spanType: 'TURN',
      spanLocation: nextSpanLocation,
      type: IStepRunEventTypeEnum.Span,
      timestamp: new Date(),
    })

    const initialValues = buildPromptInitialValues(text, attachedFiles)
    // Validate before touching wizard readiness or querying models. Bailing out later would waste
    // model calls and leave the wizard Submit button disabled.
    if (!enableRLHFConcordeQueryModel && initialValues.length === 0) {
      showSnackbar({
        message: 'Cannot create prompt: No valid base variable IDs found',
        variant: 'error',
      })
      return
    }

    dispatch({
      type: 'setReadyForSubmit',
      key: 'MULTIMODALRLHF',
      value: config.allowEndingWithoutResponse ?? false,
    })
    setisFetchingResponses(true)
    setIsLastResponseEmpty(false)

    if (enableRLHFConcordeQueryModel) {
      const taskId = await queryModelAndSave(text, attachedFiles)
      if (!taskId) {
        // queryModelAndSave has already logged the failure and notified the user.
        setisFetchingResponses(false)
        return
      }
      void startPolling(taskId)
        .then(() => {
          reactQueryContext.invalidateQueries('baseRun.findChildBaseRuns')
          reactQueryContext.invalidateQueries('baseRun.findManyByParents')
          setPromptText('')
          setAttachedFiles([])
          setStartedPromptForCurrentTurn(false)
        })
        .catch((err) => {
          logger.error('Query polling failed: ', { error: err, taskId })
          showSnackbar({
            message: `Failed to process request: ${err}`,
            variant: 'error',
          })
        })
        .finally(() => {
          setisFetchingResponses(false)
        })
      return
    }

    const responses = await fetchModelResponses(
      text,
      attachedFiles,
      toChatHistory(normalizedPrompts)
    ).finally(() => setisFetchingResponses(false))

    const prompt = await mutations.createBaseRun({
      baseId: config?.promptsBaseId as string,
      parentBaseRunId: baseRun.id,
      stepRunId: stepRun.id,
      initialValues,
    })

    if (hasUsableResponses(responses)) {
      await persistModelResponses(prompt.id, responses)
    } else {
      setIsLastResponseEmpty(true)
    }

    setPromptText('')
    setAttachedFiles([])
    setStartedPromptForCurrentTurn(false)
  }

  /**
   * After confirmation, deletes a prompt, every prompt created after it, and all of their
   * responses. Then logs a `turn_deleted` event for each removed turn.
   */
  const handleDeletePrompt = async (prompt: { id: string; responseId: string; index: number }) => {
    if (
      !window.confirm(
        'You are about to delete this prompt and its responses. \nThis will also delete subsequent prompts and responses if any. \nAre you sure?'
      )
    )
      return
    const spanLocation = prompt.index + 1

    const currentPrompt = prompts?.find((p) => p.id === prompt.id)
    // Without the prompt's timestamp we cannot tell which turns follow it, so delete only the
    // requested prompt rather than treating every prompt as "subsequent".
    const subsequentPrompts = currentPrompt
      ? (prompts ?? []).filter(
          (p) => p.createdAt > currentPrompt.createdAt && p.id !== prompt.id
        )
      : []
    const promptsResponses = (
      await Promise.all(
        [...subsequentPrompts, prompt].map((p) =>
          reactQueryContext.fetchQuery([
            'baseRun.findChildBaseRuns',
            {
              baseId: config?.responsesBaseId as string,
              parentBaseRunId: p.id as string,
            },
          ])
        )
      )
    ).flat()
    const baseRunsIdsToDelete = [
      prompt.id,
      ...subsequentPrompts.map((p) => p.id),
      ...promptsResponses.map((r) => r.id),
    ]
    await mutations.deleteBaseRuns({
      baseRunIds: baseRunsIdsToDelete,
      stepRunId: stepRun.id,
    })
    setStartedPromptForCurrentTurn(false)
    // Drop the deleted prompts from the cached prompt list immediately.
    reactQueryContext.queryClient.setQueryData<TFindChildBaseRunsData | undefined>(
      [
        'baseRun.findChildBaseRuns',
        {
          baseId: config?.promptsBaseId as string,
          parentBaseRunId: baseRun.id,
        },
      ],
      (prevData) => {
        if (!prevData) return
        return prevData.filter((run) => !baseRunsIdsToDelete.includes(run.id))
      }
    )
    maybeLogStepRunEvent({
      name: 'turn_deleted',
      stepRunId: stepRun.id as string,
      spanId: getUUIDFromNamespace([stepRun.id as string, String(spanLocation)]),
      spanType: 'TURN',
      spanLocation: spanLocation,
      type: IStepRunEventTypeEnum.Span,
      timestamp: new Date(),
    })

    subsequentPrompts.forEach((_, index) => {
      const subsequentPromptLocation = spanLocation + index + 1
      maybeLogStepRunEvent({
        name: 'turn_deleted',
        stepRunId: stepRun.id as string,
        spanId: getUUIDFromNamespace([stepRun.id as string, String(subsequentPromptLocation)]),
        spanType: 'TURN',
        spanLocation: subsequentPromptLocation,
        type: IStepRunEventTypeEnum.Span,
        timestamp: new Date(),
      })
    })
  }

  /**
   * After confirmation, clears a prompt's accepted-response variables and deletes its
   * responses. Then it re-queries the models via `resubmitPrompt`.
   */
  const handleRefetchPromptResponses = async (prompt: {
    id: string
    text: string
    attachedFiles: IContent[]
    index: number
  }) => {
    if (
      !window.confirm(
        "You are about to delete this prompt's responses and fetch new ones. \nAre you sure?"
      )
    )
      return
    setisFetchingResponses(true)
    try {
      const spanLocation = prompt.index + 1
      maybeLogStepRunEvent({
        name: 'response_regenerate_clicked',
        stepRunId: stepRun.id as string,
        spanId: getUUIDFromNamespace([stepRun.id as string, String(spanLocation)]),
        spanType: 'TURN',
        spanLocation: spanLocation,
        type: IStepRunEventTypeEnum.Span,
        timestamp: new Date(),
      })
      setIsResponseRegenerating(true)

      const clearedVariableIds = [
        config?.promptResponseBaseVariableId,
        config?.promptAcceptedModelBaseVariableId,
        config?.promptAcceptedResponseContentBaseVariableId,
        config?.promptAcceptedResponseExtraMetadataBaseVariableId,
        config?.responseIdBaseVariableId,
      ]
      await mutations.updateBaseRunVariables({
        stepRunId: stepRun.id,
        data: clearedVariableIds.map((baseVariableId) => ({
          baseRunId: prompt.id,
          baseVariableId: baseVariableId as string,
          value: null,
        })),
      })

      const promptResponses = await reactQueryContext.fetchQuery([
        'baseRun.findChildBaseRuns',
        {
          baseId: config?.responsesBaseId as string,
          parentBaseRunId: prompt.id,
        },
      ])

      const baseRunsIdsToDelete = promptResponses.map((r) => r.id)
      if (baseRunsIdsToDelete.length > 0) {
        await mutations.deleteBaseRuns({
          baseRunIds: baseRunsIdsToDelete,
          stepRunId: stepRun.id,
        })
      }
      await resubmitPrompt(prompt)
    } finally {
      // Never leave the conversation stuck in a loading state if a mutation fails.
      setisFetchingResponses(false)
    }
  }

  const rlhfContextValues = {
    stepRunId: stepRun.id,
    baseRunId: baseRun.id,
    config,
    isReadOnly,
    firstManualStepRunCreatedAt,
    numOfPrompts,
    visibleResponseIndices,
    responseMetadataValidationFailures,
    promptMetadataValidationFailures,
    setIsUpdated,
    updateBaseRunVariables: (
      data: { baseRunId: string; baseVariableId: string; value: JsonValue }[]
    ) => mutations.updateBaseRunVariables({ stepRunId: stepRun.id, data: [...data] }),
    resubmitPrompt,
    deletePrompt: handleDeletePrompt,
    submitPrompt: handlePromptSubmission,
    refetchPromptResponses: handleRefetchPromptResponses,
    loaders: {
      isFetchingResponses,
      isCreatingManyBaseRuns,
      isCreatingBaseRun,
      isDeletingBaseRuns,
      isUpdatingBaseVariables,
    },
  }

  const isMinMaxTurnsVariablesSet = Boolean(normalizedConversationData.minTurn)
  const textDirection = config.promptResponseTextDirection ?? 'auto'

  const maxAllowedTurns = isMinMaxTurnsVariablesSet
    ? (normalizedConversationData.maxTurn ?? 1)
    : (config.minMaxTurns?.[1] ?? 1)

  return (
    <ThemeProvider theme={theme}>
      <RLHFContext.Provider value={rlhfContextValues as RLHFContextDataType}>
        <Layout showInfoSection={showInfoSection} isUpdatingBaseVariables={isUpdatingBaseVariables}>
          {showInfoSection && (
            <Sidebar
              tab={tab}
              setTab={setTab}
              instruction={normalizedConversationData.instruction}
              preamble={normalizedConversationData.preamble}
              textDirection={textDirection as TTextDirection}
            />
          )}

          <MainContent
            headerProps={{
              showInfoSection,
              setShowInfoSection,
              tokenCount,
              isMinMaxTurnsVariablesSet,
              normalizedConversationData,
              config,
              numOfResponsesPerPrompt,
              visibleResponseIndices,
              setVisibleResponseIndices,
              isUpdated,
              isUpdatingBaseVariables,
              collapseAll,
              setCollapseAll,
            }}
            conversationProps={{
              prompts: normalizedPrompts,
              isFetchingResponses,
              isCreatingBaseRun,
              isCreatingManyBaseRuns,
              isLastResponseEmpty,
              collapseAll,
              loggedInUser,
              defaultRenderMode: config.defaultRenderMode as TTextRenderMode,
            }}
            promptInputProps={{
              promptText,
              attachedFiles,
              handleInputChange,
              setAttachedFiles,
              handlePromptSubmission,
              isDisabled:
                Boolean(unsubmittedPrompt) ||
                isCreatingBaseRun ||
                isFetchingResponses ||
                isCreatingManyBaseRuns,
              isReadOnly,
              loggedInUser,
              promptBucketName: config.promptBucketName ?? '',
              promptBucketFolderPath: config.promptBucketFolderPath ?? '',
              allowFileUpload: config.allowFileUpload ?? false,
            }}
            config={config}
            prompts={prompts}
            maxAllowedTurns={maxAllowedTurns}
          />
        </Layout>
      </RLHFContext.Provider>
    </ThemeProvider>
  )
}

export { MultimodalRLHFWAC }
